Storing Custom Models in the Snowflake Model Registry

The Snowflake model registry allows you to register (log) models and use them for inference inside Snowflake. The registry supports several types of models:

  • Snowpark ML Modeling

  • scikit-learn

  • XGBoost

  • LightGBM

  • CatBoost

  • PyTorch

  • TensorFlow

  • MLflow PyFunc

  • Sentence Transformer

  • Hugging Face pipeline

The model registry API also allows you to log other types of models, including those trained using external tools or obtained from open source repositories, as long as they are serializable and derived from the snowflake.ml.model.custom_model.CustomModel class.

Note

See this blog post and accompanying Quickstarts for information about importing models from AWS SageMaker or from Azure ML.

This topic explains how to create models, log them to the Snowflake Model Registry, and deploy them. A common use case for this feature is to define pipelines of multiple classes, such as a few transformers or imputers followed by a predictor or classifier. In such cases, the custom model class itself is relatively simple, calling these classes in sequence and passing the result of one as the input to the next.

Defining Model Context

Models often require one or more static files, such as configuration files or model weights, in addition to the code. Custom models must provide information on all such artifacts so the registry knows to pack them along with the model. Artifacts can be declared using the ModelContext class as follows.

from snowflake.ml.model import custom_model

mc = custom_model.ModelContext(
    artifacts={
        'config': 'local_model_dir/config.json'
    },
)

The paths to artifact files in model context are relative to the current directory in the environment from which the model is logged. The Snowflake Model Registry uses this information to ensure that all necessary code and data is deployed to the warehouse where your model will run. At runtime, the model can find those artifacts by calling self.context.path('config') (the value 'config' is the same as the key in the dictionary passed to ModelContext).
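For example, a model class might read the configuration artifact declared above in its constructor. The following is a minimal sketch; MyModel is a hypothetical subclass, and the JSON parsing simply reflects that the artifact above happens to be a JSON file.

import json

from snowflake.ml.model import custom_model

class MyModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        # 'config' is the key declared in the ModelContext artifacts dictionary.
        with open(self.context.path('config')) as f:
            self.config = json.load(f)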

Besides static files, a model can compose other models or pipelines of a supported type (for example, a snowflake.ml.modeling.pipeline or a scikit-learn model). The registry already knows how to log these types of objects, so you can pass the Python objects directly in the model context using model_refs, as shown here. You don’t need to package these objects yourself. This can be useful for bootstrap aggregation (bagging) or for preprocessing or postprocessing.

Note

model1 and model2 are objects of any type of model natively supported by the registry. preproc is a snowflake.ml.modeling.pipeline object, referenced in the model context under the key 'feature_preproc'.

mc = custom_model.ModelContext(
    artifacts={
        'config': 'local_model_dir/config.json'
    },
    model_refs={
        'm1': model1,
        'm2': model2,
        'feature_preproc': preproc
    }
)

The model registry serializes these model references when logging the model and rehydrates them at runtime. Your model can retrieve references to these subordinate models using, for example, self.context.model_ref('m1'). If the model exposes a predict method, your code can call it directly from the retrieved model reference, for example with self.context.model_ref('m1').predict().

In summary, a custom model’s context declares the Python objects to be serialized with the model, along with its artifacts: local files that the model uses and that must be stored in Snowflake alongside the code. Your model class uses the context to locate both the subordinate models and the artifacts; this works whether your model is running locally or in Snowflake.

Writing the Custom Model Class

To tell the model registry how to log and deploy your custom model, inherit from snowflake.ml.model.custom_model.CustomModel.

Models can expose multiple inference methods (for example, scikit-learn models expose predict and predict_proba methods). To declare inference functions in your custom model, define them as public methods of your subclass and decorate them with @custom_model.inference_api. This decorator indicates that a method is part of the model’s public API, allowing it to be called from Python or SQL via the model registry. Methods decorated with inference_api must accept and return a pandas DataFrame. Any number of methods can be decorated with inference_api.

Note

The requirement for the public API to accept and return a pandas DataFrame is the same as for vectorized UDFs. As with vectorized UDFs, these inference APIs can also be called from Python with a Snowpark DataFrame.

A skeleton custom model class with a public API is shown below.

from snowflake.ml.model import custom_model
import pandas as pd

class ExamplePipelineModel(custom_model.CustomModel):

    @custom_model.inference_api
    def run(self, input: pd.DataFrame) -> pd.DataFrame:
        ...
        return pd.DataFrame(...)
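Because any number of methods can be decorated with inference_api, a single model can expose several entry points. The following sketch (using the same imports as the skeleton above) assumes a hypothetical subordinate classifier registered under the key 'm1' in model_refs; the class and column names are illustrative.

class ExampleClassifier(custom_model.CustomModel):

    @custom_model.inference_api
    def predict(self, input: pd.DataFrame) -> pd.DataFrame:
        # Delegate to the subordinate classifier declared in model_refs.
        return pd.DataFrame(
            {'label': self.context.model_ref('m1').predict(input)})

    @custom_model.inference_api
    def predict_proba(self, input: pd.DataFrame) -> pd.DataFrame:
        # Expose class probabilities as a second public method.
        return pd.DataFrame(self.context.model_ref('m1').predict_proba(input))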

Putting all the pieces together, the following is a fully functional custom model. Note the use of context.path to open the configuration file that supplies the bias, and self.context.model_ref to obtain references to the subordinate models so that their predict methods can be called.

import json

class ExamplePipelineModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        # Read the bias from the JSON configuration artifact declared in the model context.
        with open(context.path('config')) as f:
            self.bias = json.load(f)['bias']

    @custom_model.inference_api
    def run(self, input: pd.DataFrame) -> pd.DataFrame:
        features = self.context.model_ref('feature_preproc').transform(input)
        model_output = self.context.model_ref('m2').predict(
            self.context.model_ref('m1').predict(features)
        )
        return pd.DataFrame({
            'output': model_output + self.bias})

Using the Custom Model

Now, you can test your new custom model (pipeline) by running it locally as follows:

my_model_pipeline = ExamplePipelineModel(mc)
output_df = my_model_pipeline.run(input_df)
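Here, input_df is a pandas DataFrame whose columns match the features the pipeline was fitted on. For example, with hypothetical column names:

import pandas as pd

# Hypothetical feature columns; use the columns your preprocessing pipeline expects.
input_df = pd.DataFrame({'feature_a': [1.0, 2.0], 'feature_b': [0.5, 0.25]})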

Or log it in the registry and deploy it to Snowflake. As shown here, provide conda_dependencies (or pip_requirements) to specify the libraries that the model class needs.

from snowflake.ml.registry import Registry

reg = Registry(session=sp_session, database_name="ML", schema_name="REGISTRY")
mv = reg.log_model(my_model_pipeline,
            model_name="my_custom_modelpipeline",
            version_name="v1",
            conda_dependencies=["scikit-learn"],
            comment="My custom ML model pipeline",
            sample_input_data=train_features)

# Run inference in Snowflake using the logged model version.
output_df = mv.run(input_df)
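In a later session, you can retrieve the logged model version from the registry and run inference in Snowflake. The following is a minimal sketch, assuming the registry connection shown above; the function_name argument selects which decorated method to invoke.

model = reg.get_model("my_custom_modelpipeline")
mv = model.version("v1")
output_df = mv.run(input_df, function_name="run")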