Bring your own model types via serialized files

The model registry supports several built-in model types. You can also log other types of models, including those trained using external tools or obtained from open source repositories, as long as they are serializable and you wrap them in a class that extends the snowflake.ml.model.custom_model.CustomModel class.

This guide explains how to:

  • Create custom models.

  • Log them to the Snowflake Model Registry.

  • Deploy them for inference.

Note

This quickstart provides an example of logging a custom PyCaret model.

Defining model context by keyword arguments

Snowflake ML allows an arbitrary number of keyword arguments when instantiating the ModelContext class, allowing you to easily include parameters, configuration files, or instances of your own model classes when defining and initializing a custom model.

Each attribute of the model context can be either a model of a supported type, such as one of the built-in model types, or a path, such as a path to a model, parameter, or configuration file, or to a directory that contains one.

Below is an example demonstrating how to provide keyword arguments by way of the model context, and how to use them in a custom model class:

import pandas as pd
from snowflake.ml.model import custom_model

# Initialize ModelContext with keyword arguments.
# my_model is a model you trained or loaded earlier; it can be any kind of model.
mc = custom_model.ModelContext(
    my_model=my_model,
)

# Define a custom model class that utilizes the context
class ExampleBringYourOwnModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)

    @custom_model.inference_api
    def predict(self, input: pd.DataFrame) -> pd.DataFrame:
        # Use the model 'my_model' from the context to make predictions
        model_output = self.context['my_model'].predict(input)
        return pd.DataFrame({'output': model_output})

Testing and logging a custom model

You can test a custom model by running it locally.

# input_df is a pandas DataFrame of sample inputs
my_model = ExampleBringYourOwnModel(mc)
output_df = my_model.predict(input_df)

When the model works as intended, log it to the Snowflake Model Registry. As shown in the next code example, provide conda_dependencies (or pip_requirements) to specify the libraries that the model class needs. Provide sample_input_data (a pandas or Snowpark DataFrame) to infer the input signature for the model. Alternatively, provide a model signature using the signatures argument.

from snowflake.ml.registry import Registry

# sp_session is an existing Snowpark session.
# train_features is a pandas or Snowpark DataFrame of sample inputs.
reg = Registry(session=sp_session, database_name="ML", schema_name="REGISTRY")
mv = reg.log_model(
    my_model,
    model_name="my_custom_model_pipeline",
    version_name="v1",
    conda_dependencies=["scikit-learn"],
    comment="My Custom ML Model Pipeline",
    sample_input_data=train_features,
)
output_df = mv.run(input_df)
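When choosing conda_dependencies or pip_requirements, it can help to pin the exact versions you tested locally. The sketch below assumes the packages are installed in your local environment and builds name==version strings with the standard library's importlib.metadata; the helper name pinned_requirements is hypothetical, not part of Snowflake ML.

```python
from importlib.metadata import PackageNotFoundError, version


def pinned_requirements(packages):
    """Return 'name==version' strings for locally installed packages.

    Hypothetical helper: packages that are not installed locally are
    returned unpinned so the caller can decide how to handle them.
    """
    pinned = []
    for name in packages:
        try:
            pinned.append(f"{name}=={version(name)}")
        except PackageNotFoundError:
            # Not installed in this environment; leave it unpinned.
            pinned.append(name)
    return pinned


# Pin the library that the custom model class needs.
deps = pinned_requirements(["scikit-learn"])
print(deps)
```

Conda and PyPI package names usually match but not always, so review the list before passing it to log_model.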

Example: Logging a PyCaret model

PyCaret is a low-code, high-efficiency third-party package that Snowflake doesn’t support natively. You can use the same approach for other model types that Snowflake doesn't support natively.

Step 1: Define the model context

Before logging the model, define a ModelContext that refers to your own model type that is not natively supported by Snowflake ML. In this case, we specify the path to the serialized (pickled) model using the context’s model_file attribute.

pycaret_mc = custom_model.ModelContext(
    model_file="pycaret_best_model.pkl",
)
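The model_file attribute points to a serialized model on local disk. With PyCaret, this file is typically produced by pycaret.classification.save_model. To illustrate what a pickled model file is, the sketch below uses the standard pickle module with a hypothetical ThresholdModel class standing in for a trained model. It writes the object to pycaret_best_model.pkl in a temporary directory and loads it back.

```python
import os
import pickle
import tempfile


class ThresholdModel:
    """Hypothetical stand-in for a trained model with a predict method."""

    def __init__(self, threshold):
        self.threshold = threshold

    def predict(self, values):
        # Label each value 1 if it reaches the threshold, else 0.
        return [1 if v >= self.threshold else 0 for v in values]


model_dir = tempfile.mkdtemp()
model_path = os.path.join(model_dir, "pycaret_best_model.pkl")

# Serialize (pickle) the trained object to a file...
with open(model_path, "wb") as f:
    pickle.dump(ThresholdModel(threshold=0.5), f)

# ...and later deserialize it to make predictions.
with open(model_path, "rb") as f:
    restored = pickle.load(f)

print(restored.predict([0.2, 0.5, 0.9]))  # prints [0, 1, 1]
```

Unpickling can execute arbitrary code, so load pickle files only from sources you trust.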

Step 2: Create a custom model class

Define a custom model class to log a model type without native support. In this example, a PyCaretModel class, derived from CustomModel, is defined so the model can be logged in the registry.

import pandas as pd
from pycaret.classification import load_model, predict_model
from snowflake.ml.model import custom_model

class PyCaretModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        # load_model appends '.pkl' itself, so remove the suffix from the path
        model_dir = self.context["model_file"][:-4]
        self.model = load_model(model_dir, verbose=False)
        self.model.memory = '/tmp/'  # Use a writable directory for the pipeline cache

    @custom_model.inference_api
    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        model_output = predict_model(self.model, data=X)
        return pd.DataFrame({
            "prediction_label": model_output['prediction_label'],
            "prediction_score": model_output['prediction_score']
        })
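PyCaret's load_model appends .pkl to the name it receives, which is why the class slices the last four characters off the context path. Slicing assumes the path always ends in .pkl. A slightly more defensive variant, sketched below with a hypothetical helper name, removes the suffix only when it is present.

```python
def strip_pkl_suffix(path: str) -> str:
    """Remove a trailing '.pkl' from path, if present (hypothetical helper)."""
    suffix = ".pkl"
    if path.endswith(suffix):
        return path[: -len(suffix)]
    return path


print(strip_pkl_suffix("/models/pycaret_best_model.pkl"))  # prints /models/pycaret_best_model
print(strip_pkl_suffix("/models/pycaret_best_model"))      # prints /models/pycaret_best_model
```

Inside __init__, this would replace the slice: model_dir = strip_pkl_suffix(self.context["model_file"]). On Python 3.9 and later, str.removesuffix(".pkl") does the same thing.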

Note

As shown, the code sets the model’s memory directory to /tmp/. Snowflake’s warehouse nodes have restricted directory access. /tmp is always writable and is a safe choice when the model needs a place to write files. This might not be necessary for other types of models.
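If you are unsure whether a directory is writable where the model runs, you can check before pointing the model's cache at it. The sketch below is an illustrative assumption, not Snowflake guidance: the helper name pick_writable_dir is hypothetical, and it falls back to tempfile.gettempdir(), which on Linux usually resolves to /tmp.

```python
import os
import tempfile


def pick_writable_dir(preferred: str = "/tmp/") -> str:
    """Return preferred if it is a writable directory, else the system temp dir."""
    if os.path.isdir(preferred) and os.access(preferred, os.W_OK):
        return preferred
    # Fall back to the platform's temporary directory.
    return tempfile.gettempdir()


cache_dir = pick_writable_dir()
print(cache_dir)
```

In PyCaretModel.__init__, the hard-coded value could then become self.model.memory = pick_writable_dir().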

Step 3: Test the custom model

Test the PyCaret model locally using code like the following.

test_data = [
    [1, 237, 1, 1.75, 1.99, 0.00, 0.00, 0, 0, 0.5, 1.99, 1.75, 0.24, 'No', 0.0, 0.0, 0.24, 1],
    # Additional test rows...
]
col_names = ['Id', 'WeekofPurchase', 'StoreID', 'PriceCH', 'PriceMM', 'DiscCH', 'DiscMM',
            'SpecialCH', 'SpecialMM', 'LoyalCH', 'SalePriceMM', 'SalePriceCH',
            'PriceDiff', 'Store7', 'PctDiscMM', 'PctDiscCH', 'ListPriceDiff', 'STORE']

test_df = pd.DataFrame(test_data, columns=col_names)

my_pycaret_model = PyCaretModel(pycaret_mc)
output_df = my_pycaret_model.predict(test_df)
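Beyond inspecting output_df by eye, a few assertions can catch shape or column mistakes before the model is logged. The helper below is a hypothetical sketch. It checks that there is one output row per input row and that both expected columns exist. Assuming prediction_score is a probability, it also checks that scores fall between 0 and 1. The example DataFrames are made up for illustration and are not real model output.

```python
import pandas as pd


def check_prediction_output(input_df: pd.DataFrame, output_df: pd.DataFrame) -> None:
    """Raise AssertionError if output_df does not look like valid predictions."""
    expected = {"prediction_label", "prediction_score"}
    missing = expected - set(output_df.columns)
    assert not missing, f"missing output columns: {sorted(missing)}"
    assert len(output_df) == len(input_df), "expected one prediction per input row"
    # Assumes scores are probabilities in [0, 1] (bounds inclusive).
    assert output_df["prediction_score"].between(0, 1).all(), "score out of range"


# Illustrative data only; in practice pass test_df and the real output_df.
example_input = pd.DataFrame({"x": [1, 2]})
example_output = pd.DataFrame(
    {"prediction_label": ["label_a", "label_b"], "prediction_score": [0.91, 0.64]}
)
check_prediction_output(example_input, example_output)
```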

Step 4: Define a model signature

In this example, use the sample data to infer a model signature for input validation:

from snowflake.ml.model import model_signature

predict_signature = model_signature.infer_signature(input_data=test_df, output_data=output_df)
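Because the signature is inferred from the column types of test_df, it is worth confirming that pandas inferred the types you expect. The sketch below rebuilds a DataFrame from the single sample row shown in Step 3 and prints a few column dtypes. Whole-number columns come out as int64, decimal columns as float64, and the text column Store7 as a non-numeric (object or string) dtype.

```python
import pandas as pd

# The single sample row and column names from Step 3.
test_data = [
    [1, 237, 1, 1.75, 1.99, 0.00, 0.00, 0, 0, 0.5, 1.99, 1.75, 0.24, 'No', 0.0, 0.0, 0.24, 1],
]
col_names = ['Id', 'WeekofPurchase', 'StoreID', 'PriceCH', 'PriceMM', 'DiscCH', 'DiscMM',
             'SpecialCH', 'SpecialMM', 'LoyalCH', 'SalePriceMM', 'SalePriceCH',
             'PriceDiff', 'Store7', 'PctDiscMM', 'PctDiscCH', 'ListPriceDiff', 'STORE']

test_df = pd.DataFrame(test_data, columns=col_names)

# Inspect the dtypes that the signature will be inferred from.
print(test_df.dtypes[["Id", "PriceCH", "SpecialCH", "Store7"]])
```

If a column comes out with an unexpected type, for example a whole-number column that should accept decimals, cast it with astype before inferring the signature.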

Step 5: Log the model

The following code logs (registers) the model in the Snowflake Model Registry.

from snowflake.ml.registry import Registry

# session is an existing Snowpark session
snowml_registry = Registry(session)

custom_mv = snowml_registry.log_model(
    my_pycaret_model,
    model_name="my_pycaret_best_model",
    version_name="version_1",
    conda_dependencies=["pycaret==3.0.2", "scipy==1.11.4", "joblib==1.2.0"],
    options={"relax_version": False},
    signatures={"predict": predict_signature},
    comment="My PyCaret classification experiment using the CustomModel API",
)

Step 6: Verify the model in the registry

To verify that the model is available in the Model Registry, use the show_models method.

snowml_registry.show_models()

Step 7: Make predictions with the registered model

Use the model version's run method to make predictions with the registered model.

snowpark_df = session.create_dataframe(test_data, schema=col_names)

custom_mv.run(snowpark_df).show()

Next Steps

After deploying a PyCaret model by way of the Snowflake Model Registry, you can view the model in Snowsight. Navigate to the Models page under AI & ML. If you do not see it there, make sure you are using the ACCOUNTADMIN role or the role you used to log the model.

To call the model from SQL, use a query like the following:

SELECT
    my_pycaret_best_model!predict(*) AS predict_dict,
    predict_dict['prediction_label']::text AS prediction_label,
    predict_dict['prediction_score']::double AS prediction_score
FROM pycaret_input_data;