Storing Custom Models in the Snowflake Model Registry
The Snowflake model registry allows you to register (log) models and use them for inference inside Snowflake. The registry supports several types of models:
scikit-learn
XGBoost
LightGBM
CatBoost
PyTorch
TensorFlow
MLflow PyFunc
Sentence Transformer
Hugging Face pipeline
The model registry API also allows you to log other types of models, including those trained using external tools or obtained from open source repositories, as long as they are serializable and derived from the snowflake.ml.model.custom_model.CustomModel class.
Note
This guide provides an example of logging a custom model.
See also this blog post and accompanying Quickstarts for information about importing models from AWS SageMaker or from Azure ML.
This topic explains how to create custom models, log them to the Snowflake Model Registry, and deploy them. A common use case for this feature is to define pipelines of multiple classes, such as a few transformers or imputers followed by a predictor or classifier. In such cases, the custom model class itself is relatively simple: it calls these classes in sequence, passing the result of one as the input to the next.
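To make the chaining idea concrete, it can be sketched in plain pandas before you wrap it in a custom model class. The step functions below (impute, scale, classify), run_pipeline, and the toy data are hypothetical stand-ins for real transformers, imputers, and a classifier; only the pattern of feeding each step's output into the next carries over to a real custom model.

```python
from functools import reduce
from typing import Callable, List

import pandas as pd


# Hypothetical pipeline steps: each takes a DataFrame and returns a DataFrame.
def impute(df: pd.DataFrame) -> pd.DataFrame:
    # Replace missing values with the column mean
    return df.fillna(df.mean())


def scale(df: pd.DataFrame) -> pd.DataFrame:
    # Min-max scale every column to the range [0, 1]
    return (df - df.min()) / (df.max() - df.min())


def classify(df: pd.DataFrame) -> pd.DataFrame:
    # Toy "classifier": label rows whose scaled feature sum exceeds 1.0
    return pd.DataFrame({"label": (df.sum(axis=1) > 1.0).astype(int)})


def run_pipeline(
    steps: List[Callable[[pd.DataFrame], pd.DataFrame]], df: pd.DataFrame
) -> pd.DataFrame:
    # Feed the output of each step into the next one, in order
    return reduce(lambda acc, step: step(acc), steps, df)


input_df = pd.DataFrame({"x": [0.0, None, 4.0], "y": [1.0, 3.0, 5.0]})
result = run_pipeline([impute, scale, classify], input_df)
print(result)
```

Inside a custom model, the body of run_pipeline is essentially what the decorated inference method does, with the steps retrieved from the model context instead of defined inline.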
Defining Model Context
Models often require one or more static files, such as configuration files or model weights, in addition to the code. Custom models must provide information on all such artifacts so the registry knows to package them along with the model. Artifacts can be declared using the ModelContext class as follows.
from snowflake.ml.model import custom_model
mc = custom_model.ModelContext(
    artifacts={
        'config': 'local_model_dir/config.json'
    },
)
The paths to artifact files in the model context are relative to the current directory in the environment from which the model is logged. The Snowflake Model Registry uses this information to ensure that all necessary code and data are deployed to the warehouse where your model will run. At runtime, the model can find those artifacts by calling self.context.path('config') (the value 'config' is the same as the key in the dictionary passed to ModelContext).
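Because the artifact path is resolved relative to the current directory, local_model_dir/config.json has to exist there when you log the model. The following sketch creates it with Python's standard library. The bias value of 0.5 is a hypothetical placeholder; the only detail taken from this topic is that the config contains a bias key, which the full custom model example reads in its constructor.

```python
import json
import os

# Create the artifact at the relative path declared in ModelContext
os.makedirs("local_model_dir", exist_ok=True)
config_path = os.path.join("local_model_dir", "config.json")

# Hypothetical contents: a single "bias" value used by the model at inference time
with open(config_path, "w") as f:
    json.dump({"bias": 0.5}, f)

# Read it back the way the model would after resolving self.context.path('config')
with open(config_path) as f:
    bias = json.load(f)["bias"]
print(bias)
```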
Besides static files, a model can compose other models or pipelines of a supported type (for example, a snowflake.ml.modeling.pipeline or a scikit-learn model). The registry already knows how to log these types of objects, so you can pass the Python objects directly in the model context using the models argument, as shown here. You don’t need to package these objects yourself. This can be useful for bootstrap aggregation (bagging), or for preprocessing or postprocessing.
Note
model1 and model2 are objects of any type of model natively supported by the registry. preproc, registered under the key feature_preproc, is a snowflake.ml.modeling.pipeline object.
mc = custom_model.ModelContext(
    artifacts={
        'config': 'local_model_dir/config.json'
    },
    models={
        'm1': model1,
        'm2': model2,
        'feature_preproc': preproc
    }
)
The model registry serializes these model references when logging the model and rehydrates them at runtime. Your model can retrieve references to these subordinate models using, for example, self.context.model_ref('m1'). If a subordinate model exposes a predict method, your code can call it directly on the retrieved reference, for example self.context.model_ref('m1').predict(...).
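When a custom model composes several subordinate models, its predict method typically looks each one up by key and combines their outputs. The sketch below shows a bagging-style average in plain NumPy and pandas so it can run without Snowflake. LinearStandIn and bagged_predict are hypothetical names, and the plain dictionary stands in for the lookups that self.context.model_ref('m1') and self.context.model_ref('m2') would perform inside a real custom model.

```python
import numpy as np
import pandas as pd


class LinearStandIn:
    """Hypothetical stand-in for a fitted model that exposes predict()."""

    def __init__(self, weights):
        self.weights = np.asarray(weights, dtype=float)

    def predict(self, X: pd.DataFrame) -> np.ndarray:
        # Weighted sum of the feature columns, one value per row
        return X.to_numpy(dtype=float) @ self.weights


def bagged_predict(models: dict, X: pd.DataFrame) -> pd.DataFrame:
    # Stack each model's predictions as a column, then average across models
    preds = np.column_stack([m.predict(X) for m in models.values()])
    return pd.DataFrame({"output": preds.mean(axis=1)})


# In a custom model, these would come from self.context.model_ref(...)
models = {"m1": LinearStandIn([1.0, 0.0]), "m2": LinearStandIn([0.0, 1.0])}
X = pd.DataFrame({"a": [1.0, 2.0], "b": [3.0, 6.0]})
print(bagged_predict(models, X))
```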
In summary, a custom model’s context declares the Python objects to be serialized along with the model, as well as the model’s artifacts: local files that the model uses and that must be stored in Snowflake along with the code. Your model class uses the context to locate the subordinate models and the artifacts; this works whether your model runs locally or in Snowflake.
Writing the Custom Model Class
To tell the model registry how to log and deploy your custom model, inherit from snowflake.ml.model.custom_model.CustomModel.
Models can expose multiple inference methods (for example, scikit-learn models expose predict and predict_proba methods). To declare inference functions in your custom model, define them as public methods of your subclass and decorate them with @custom_model.inference_api. This decorator indicates that a method is part of the model’s public API, allowing it to be called from Python or SQL via the model registry. Methods decorated with inference_api must accept and return a pandas DataFrame. Any number of methods can be decorated with inference_api.
Note
The requirement for the public API to accept and return a pandas DataFrame is the same as for vectorized UDFs. As with vectorized UDFs, these inference APIs can also be called from Python with a Snowpark DataFrame.
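A quick way to catch contract violations before logging is to call the inference logic locally on a small sample and check what comes back. The helper below, check_inference_contract, is hypothetical and not part of the registry API, and double_it is a toy inference function. Its row-count check assumes a regular (non-table) function, which, like a vectorized UDF, is expected to return one output row per input row.

```python
import pandas as pd


def check_inference_contract(method, sample: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical local check for a DataFrame-in, DataFrame-out inference method."""
    out = method(sample)
    if not isinstance(out, pd.DataFrame):
        raise TypeError(f"expected a pandas DataFrame, got {type(out).__name__}")
    # Assumption: a plain (non-table) function returns one output row per input row
    if len(out) != len(sample):
        raise ValueError(f"expected {len(sample)} rows, got {len(out)}")
    return out


def double_it(df: pd.DataFrame) -> pd.DataFrame:
    # Toy inference logic: one output value per input row
    return pd.DataFrame({"output": df["x"] * 2})


result = check_inference_contract(double_it, pd.DataFrame({"x": [1, 2, 3]}))
print(result["output"].tolist())  # [2, 4, 6]
```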
A skeleton custom model class with a public API is shown below.
from snowflake.ml.model import custom_model
import pandas as pd

class ExamplePipelineModel(custom_model.CustomModel):
    @custom_model.inference_api
    def predict(self, input: pd.DataFrame) -> pd.DataFrame:
        ...  # inference logic goes here
        return pd.DataFrame(...)
Putting all the pieces together, the following is a fully functional custom model. Note the use of context.path to open the config file that holds the bias value, and of self.context.model_ref to obtain references to the subordinate models so that their predict methods can be called.
import json

import pandas as pd
from snowflake.ml.model import custom_model

class ExamplePipelineModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        # Read the bias value from the 'config' artifact declared in the model context
        with open(context.path('config')) as f:
            self.bias = json.load(f)['bias']

    @custom_model.inference_api
    def predict(self, input: pd.DataFrame) -> pd.DataFrame:
        # Preprocess the input, then feed m1's predictions into m2
        features = self.context.model_ref('feature_preproc').transform(input)
        model_output = self.context.model_ref('m2').predict(
            self.context.model_ref('m1').predict(features)
        )
        return pd.DataFrame({'output': model_output + self.bias})
Using the Custom Model
You can test your new custom model (pipeline) by running it locally as follows:
my_model_pipeline = ExamplePipelineModel(mc)
output_df = my_model_pipeline.predict(input_df)
You can also log the model in the registry and deploy it to Snowflake. As shown in the next code example, provide conda_dependencies (or pip_requirements) to specify the libraries that the model class needs. Provide sample_input_data (a pandas DataFrame) so the registry can infer the model’s input signature, or provide a model signature explicitly.
from snowflake.ml.registry import Registry

reg = Registry(session=sp_session, database_name="ML", schema_name="REGISTRY")
mv = reg.log_model(my_model_pipeline,
                   model_name="my_custom_model_pipeline",
                   version_name="v1",
                   conda_dependencies=["scikit-learn"],
                   comment="My Custom ML Model Pipeline",
                   sample_input_data=train_features)
output_df = mv.run(input_df)
Table Function Inference
As of Snowpark ML 1.5.4, you can log models with inference methods that return multiple columns. To do so, log your model with the option {"function_type": "TABLE_FUNCTION"} and use the @inference_api decorator as above. In the following example, the decorated method returns a pandas DataFrame that includes two output columns.
class ExampleTableFunctionModel(custom_model.CustomModel):
    @custom_model.inference_api
    def predict(self, input: pd.DataFrame) -> pd.DataFrame:
        # Each column of the returned DataFrame becomes an output column
        output_df = pd.DataFrame({
            "OUTPUT1": input["INPUT1"] + 1,
            "OUTPUT2": input["INPUT2"] + 1,
        })
        return output_df
my_model = ExampleTableFunctionModel()
reg = Registry(session=sp_session, database_name="ML", schema_name="REGISTRY")
mv = reg.log_model(my_model,
                   model_name="my_custom_table_function_model",
                   version_name="v1",
                   options={"function_type": "TABLE_FUNCTION"},
                   sample_input_data=train_features
                   )
output_df = mv.run(input_df)
If the model includes multiple inference methods, use the method_options option to log the model, indicating which are FUNCTION and which are TABLE_FUNCTION:
reg = Registry(session=sp_session, database_name="ML", schema_name="REGISTRY")
mv = reg.log_model(my_model,
                   model_name="my_custom_table_function_model",
                   version_name="v1",
                   options={
                       "method_options": {
                           "METHOD1": {"function_type": "TABLE_FUNCTION"},
                           "METHOD2": {"function_type": "FUNCTION"}
                       }
                   },
                   sample_input_data=train_features
                   )
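If a model has many inference methods, it can help to build the method_options dictionary programmatically and reject typos in the function type before calling log_model. build_method_options is a hypothetical helper, not part of the registry API; METHOD1 and METHOD2 are the same placeholder method names used in the example above, and the resulting dictionary has the same shape as the options passed there.

```python
VALID_FUNCTION_TYPES = {"FUNCTION", "TABLE_FUNCTION"}


def build_method_options(function_types: dict) -> dict:
    """Hypothetical helper: build log_model options from {method name: function type}."""
    method_options = {}
    for method, ftype in function_types.items():
        # Fail early on anything other than the two documented function types
        if ftype not in VALID_FUNCTION_TYPES:
            raise ValueError(f"unknown function_type {ftype!r} for method {method!r}")
        method_options[method] = {"function_type": ftype}
    return {"method_options": method_options}


options = build_method_options({"METHOD1": "TABLE_FUNCTION", "METHOD2": "FUNCTION"})
print(options)
```

The resulting dictionary would then be passed as options=options to reg.log_model.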
You can also invoke the logged model’s table function inference method from SQL, as follows:
SELECT OUTPUT1, OUTPUT2
FROM input_table,
    TABLE(
        MY_MODEL!PREDICT(input_table.INPUT1, input_table.INPUT2)
    )
To learn more about partitioned model inference methods, where the inference method takes a data partition as input and outputs multiple rows and columns, see Partitioned Custom Models.