Bring your own model types via serialized files¶
The Snowflake Model Registry supports logging built-in model types directly.
Other model types can be logged by wrapping them in snowflake.ml.model.custom_model.CustomModel. Serializable models trained using external tools or obtained from open source repositories can be used with CustomModel.
This guide explains how to:
Create a custom model.
Create model context with files and model objects.
Log the custom model to the Snowflake Model Registry.
Deploy the model for inference.
Note
This quickstart provides an example of logging a custom PyCaret model.
Defining model context by keyword arguments¶
The snowflake.ml.model.custom_model.ModelContext class can be instantiated with user-defined keyword arguments. The values can be either string file paths or instances of supported model types. The files and serialized models are packaged with the model for use in the model inference logic.
Files can be serialized models, configuration files, or files containing parameters. A common use is to load a pickle or JSON file in the custom model's __init__ method or inference method.
Below is an example demonstrating how to provide models and files with the model context, and use them in a custom model class:
import pickle
import pandas as pd
from snowflake.ml.model import custom_model

# Initialize ModelContext with keyword arguments
# my_model can be any supported model type
# my_file_path is a local pickle file path
model_context = custom_model.ModelContext(
    my_model=my_model,
    my_file_path='/path/to/file.pkl',
)

# Define a custom model class that utilizes the context
class ExampleBringYourOwnModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        # Use 'my_file_path' key from the context to load the pickled object
        with open(self.context['my_file_path'], 'rb') as f:
            self.obj = pickle.load(f)

    @custom_model.inference_api
    def predict(self, input: pd.DataFrame) -> pd.DataFrame:
        # Use the model with key 'my_model' from the context to make predictions
        model_output = self.context['my_model'].predict(input)
        return pd.DataFrame({'output': model_output})

# Instantiate the custom model with the model context. This instance can be logged in the model registry.
my_model = ExampleBringYourOwnModel(model_context)
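The file-handling half of this pattern can be tried without Snowflake. The following is a minimal sketch, using only the Python standard library, of writing a pickle file and a JSON configuration file and loading them back the way a custom model's __init__ method might. The file names (params.pkl, config.json), their contents, and the load_artifacts helper are hypothetical and chosen for illustration only.

```python
import json
import pickle
import tempfile
from pathlib import Path


def load_artifacts(pickle_path: str, json_path: str) -> dict:
    """Load a pickled object and a JSON config from local file paths.

    Hypothetical helper: it mirrors what a custom model's __init__ could do
    with file paths obtained from the model context.
    """
    with open(pickle_path, "rb") as f:
        params = pickle.load(f)
    with open(json_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    return {"params": params, "config": config}


# Create two illustrative artifact files in a temporary directory.
artifact_dir = Path(tempfile.mkdtemp())
pickle_path = artifact_dir / "params.pkl"
json_path = artifact_dir / "config.json"

with open(pickle_path, "wb") as f:
    pickle.dump({"threshold": 0.5}, f)
with open(json_path, "w", encoding="utf-8") as f:
    json.dump({"labels": ["CH", "MM"]}, f)

# Load both files back from their string paths.
artifacts = load_artifacts(str(pickle_path), str(json_path))
print(artifacts)
```

String paths like these could then be passed to ModelContext as keyword arguments. Only unpickle files from sources you trust, because pickle.load can execute arbitrary code.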
Testing and logging a custom model¶
You can test a custom model by running it locally.
my_model = ExampleBringYourOwnModel(model_context)
output_df = my_model.predict(input_df)
When the model works as intended, log it to the Snowflake Model Registry. As shown in the next code example, provide conda_dependencies (or pip_requirements) to specify the libraries that the model class needs. Provide sample_input_data (a pandas or Snowpark DataFrame) to infer the input signature for the model. Alternatively, provide a model signature.
from snowflake.ml.registry import Registry

# sp_session is an existing Snowpark session
reg = Registry(session=sp_session, database_name="ML", schema_name="REGISTRY")
mv = reg.log_model(my_model,
                   model_name="my_custom_model",
                   version_name="v1",
                   conda_dependencies=["scikit-learn"],
                   comment="My Custom ML Model",
                   sample_input_data=train_features)
output_df = mv.run(input_df)
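When the logged model should use the same library versions that were used for local testing, the dependency list can be built from the local environment. The following is a minimal sketch using only the standard library; the pin_installed_versions helper and the second package name are hypothetical, and a version installed locally is not guaranteed to be available in the channel that Snowflake resolves packages from, so treat the result as a starting point.

```python
from importlib import metadata
from typing import List


def pin_installed_versions(packages: List[str]) -> List[str]:
    """Return 'name==version' for locally installed packages, else the bare name."""
    pinned = []
    for name in packages:
        try:
            pinned.append(f"{name}=={metadata.version(name)}")
        except metadata.PackageNotFoundError:
            # Not installed locally: leave the name unpinned rather than guess.
            pinned.append(name)
    return pinned


# The second name is deliberately one that should not be installed anywhere.
deps = pin_installed_versions(
    ["scikit-learn", "hypothetical-package-that-is-not-installed"]
)
print(deps)
```

The resulting list has the same shape as the value passed to conda_dependencies in the example above.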
Example: Logging a PyCaret model¶
PyCaret is a low-code third-party machine learning package that Snowflake doesn't support natively. Other unsupported model types can be brought to the registry in a similar way.
Step 1: Define the model context¶
Before logging the model, define a ModelContext that refers to the model type that Snowflake ML doesn't support natively. In this case, the path to the serialized (pickled) model is passed as the context's model_file keyword argument.
pycaret_model_context = custom_model.ModelContext(
    model_file='pycaret_best_model.pkl',
)
Step 2: Create a custom model class¶
Define a custom model class to log a model type without native support. In this example, a PyCaretModel class, derived from CustomModel, is defined so the model can be logged in the registry.
import pandas as pd
from pycaret.classification import load_model, predict_model
from snowflake.ml.model import custom_model

class PyCaretModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        # PyCaret's load_model appends '.pkl' itself, so remove the suffix from the path
        model_dir = self.context["model_file"][:-4]
        self.model = load_model(model_dir, verbose=False)
        self.model.memory = '/tmp/'  # Update memory directory

    @custom_model.inference_api
    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        model_output = predict_model(self.model, data=X)
        return pd.DataFrame({
            "prediction_label": model_output['prediction_label'],
            "prediction_score": model_output['prediction_score']
        })
Note
As shown, set the model's memory directory to /tmp/. Snowflake's warehouse nodes have restricted directory access. /tmp is always writeable and is a safe choice when the model needs a place to write files. This might not be necessary for other types of models.
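The [:-4] slice in the class above assumes that the file name always ends in a four-character suffix such as .pkl. A slightly more defensive variant, sketched below with the standard library, removes the suffix only when it is actually .pkl. The strip_pkl_suffix helper is hypothetical and is not part of Snowflake ML or PyCaret.

```python
from pathlib import Path


def strip_pkl_suffix(model_file: str) -> str:
    """Return the path without a trailing '.pkl'; other paths are returned unchanged."""
    path = Path(model_file)
    if path.suffix == ".pkl":
        return str(path.with_suffix(""))
    return model_file


print(strip_pkl_suffix("pycaret_best_model.pkl"))
print(strip_pkl_suffix("pycaret_best_model"))
```

With plain slicing, a name that has no suffix would silently lose its last four characters; the helper leaves such a name untouched.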
Step 3: Test the custom model¶
Test the PyCaret model locally using code like the following.
test_data = [
    [1, 237, 1, 1.75, 1.99, 0.00, 0.00, 0, 0, 0.5, 1.99, 1.75, 0.24, 'No', 0.0, 0.0, 0.24, 1],
    # Additional test rows...
]
col_names = ['Id', 'WeekofPurchase', 'StoreID', 'PriceCH', 'PriceMM', 'DiscCH', 'DiscMM',
             'SpecialCH', 'SpecialMM', 'LoyalCH', 'SalePriceMM', 'SalePriceCH',
             'PriceDiff', 'Store7', 'PctDiscMM', 'PctDiscCH', 'ListPriceDiff', 'STORE']
test_df = pd.DataFrame(test_data, columns=col_names)
my_pycaret_model = PyCaretModel(pycaret_model_context)
output_df = my_pycaret_model.predict(test_df)
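Hand-written test rows are easy to get wrong by one value. As a small sanity check before building the DataFrame, the row lengths can be compared with the number of column names. The sketch below uses only the standard library; the check_row_widths helper and the deliberately short second row are hypothetical additions for illustration.

```python
from typing import Any, List


def check_row_widths(rows: List[List[Any]], columns: List[str]) -> List[int]:
    """Return the indices of rows whose length differs from the column count."""
    expected = len(columns)
    return [i for i, row in enumerate(rows) if len(row) != expected]


col_names = ['Id', 'WeekofPurchase', 'StoreID', 'PriceCH', 'PriceMM', 'DiscCH', 'DiscMM',
             'SpecialCH', 'SpecialMM', 'LoyalCH', 'SalePriceMM', 'SalePriceCH',
             'PriceDiff', 'Store7', 'PctDiscMM', 'PctDiscCH', 'ListPriceDiff', 'STORE']

test_data = [
    # Row taken from the example above (18 values, one per column).
    [1, 237, 1, 1.75, 1.99, 0.00, 0.00, 0, 0, 0.5, 1.99, 1.75, 0.24, 'No', 0.0, 0.0, 0.24, 1],
    # Hypothetical malformed row, included to show a failure being reported.
    [2, 238, 1],
]

bad_rows = check_row_widths(test_data, col_names)
print(bad_rows)  # → [1]
```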
Step 4: Define a model signature¶
In this example, use the sample data to infer a model signature for input validation:
from snowflake.ml.model import model_signature

predict_signature = model_signature.infer_signature(input_data=test_df, output_data=output_df)
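As a purely conceptual illustration of what inferring a signature from sample data for input validation means, the sketch below records a type per column from one sample row and then checks another row against it. It uses only the standard library, it does not reproduce how model_signature.infer_signature works internally, and both helper functions are hypothetical.

```python
from typing import Any, Dict, List


def infer_column_types(columns: List[str], sample_row: List[Any]) -> Dict[str, type]:
    """Record the Python type observed for each column in a sample row."""
    return {name: type(value) for name, value in zip(columns, sample_row)}


def mismatched_columns(schema: Dict[str, type], columns: List[str], row: List[Any]) -> List[str]:
    """Return the names of columns whose value type differs from the schema."""
    return [name for name, value in zip(columns, row) if type(value) is not schema[name]]


# A three-column subset of the example data, to keep the sketch short.
columns = ['PriceCH', 'Store7', 'STORE']
schema = infer_column_types(columns, [1.75, 'No', 1])

valid = mismatched_columns(schema, columns, [1.99, 'Yes', 1])
invalid = mismatched_columns(schema, columns, [1.99, 0, 1])
print(valid)    # → []
print(invalid)  # → ['Store7']
```

An actual model signature carries richer information than Python types, so this shows only the general idea of deriving a schema from samples and rejecting inputs that do not match it.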
Step 5: Log the model¶
The following code logs (registers) the model in the Snowflake Model Registry.
snowml_registry = Registry(session)
custom_mv = snowml_registry.log_model(
    my_pycaret_model,
    model_name="my_pycaret_best_model",
    version_name="version_1",
    conda_dependencies=["pycaret==3.0.2", "scipy==1.11.4", "joblib==1.2.0"],
    options={"relax_version": False},
    signatures={"predict": predict_signature},
    comment='My PyCaret classification experiment using the CustomModel API'
)
Step 6: Verify the model in the registry¶
To verify that the model is available in the Model Registry, use the show_models function.
snowml_registry.show_models()
Step 7: Make predictions with the registered model¶
Use the run function to call the model for prediction.
snowpark_df = session.create_dataframe(test_data, schema=col_names)
custom_mv.run(snowpark_df).show()
Next Steps¶
After deploying a PyCaret model by way of the Snowflake Model Registry, you can view the model in Snowsight. Navigate to the Models page under AI & ML. If you do not see it there, make sure you are using the ACCOUNTADMIN role or the role you used to log the model.
To call the model from SQL, use a query like the following:
SELECT
    my_pycaret_best_model!predict(*) AS predict_dict,
    predict_dict['prediction_label']::text AS prediction_label,
    predict_dict['prediction_score']::double AS prediction_score
FROM pycaret_input_data;