Inference in Snowflake Warehouses

The Snowflake Model Registry runs model methods in a warehouse by default. Running models in a warehouse is appropriate for small- to medium-sized CPU-only models whose dependencies can be satisfied by packages available in the Snowflake conda channel.

Note

You can also run models on a Snowpark Container Services (SPCS) compute pool. This approach is more appropriate for large models that can benefit from distributed inference. See Model Serving in Snowpark Container Services for more information.

To get started, log the model in the model registry. Only models that have been logged in the registry are available for inference. Logging a model, or loading an existing model from the registry by means of registry.get_model(...).version(...), returns a ModelVersion object on which you can call the run method.
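The lookup of an existing version can be wrapped in a small helper, shown here as a minimal sketch. It assumes registry is a snowflake.ml.registry.Registry, or any object with the same get_model(...).version(...) shape. The helper name load_model_version and the model and version names MY_MODEL and V1 are hypothetical.

```python
from typing import Any


def load_model_version(registry: Any, model_name: str, version_name: str) -> Any:
    """Return the ModelVersion of a model that has already been logged.

    Assumes `registry` behaves like snowflake.ml.registry.Registry:
    get_model(name) returns a model whose version(name) returns a ModelVersion.
    """
    model = registry.get_model(model_name)
    return model.version(version_name)


# Example usage (assumes an existing Snowpark `session` and a model logged
# under the hypothetical name MY_MODEL, version V1):
#   from snowflake.ml.registry import Registry
#   reg = Registry(session=session)
#   mv = load_model_version(reg, "MY_MODEL", "V1")
```

The returned object is the same kind of ModelVersion that logging a model gives you, so either path leads to mv.run.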

To call a method of a model version, use mv.run, where mv is a ModelVersion object. Specify the name of the function to be called and pass a Snowpark or pandas DataFrame that contains the inference data, along with any required parameters. The method is executed in a Snowflake warehouse.

The return value of the method is a Snowpark or pandas DataFrame, matching the type of DataFrame passed in. Snowpark DataFrames are evaluated lazily, so the method is run only when the DataFrame’s collect, show, or to_pandas method is called.
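Because only Snowpark results are lazy, code that accepts either kind of result may want to force evaluation explicitly. The following is a minimal sketch under that assumption. It uses duck typing: anything with a to_pandas method is treated as a lazy Snowpark DataFrame, and anything else is returned unchanged. The helper name materialize is hypothetical.

```python
from typing import Any


def materialize(result: Any) -> Any:
    """Force evaluation of an mv.run result and return local data.

    Snowpark DataFrames are lazy: calling to_pandas() is what actually runs
    the model method in the warehouse. A pandas DataFrame is already
    materialized, so it is returned as-is.
    """
    if hasattr(result, "to_pandas"):
        return result.to_pandas()  # triggers the warehouse query
    return result
```
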

Note

Invoking a method runs it in the warehouse specified in the session you’re using to connect to the registry. See Specifying a Warehouse.
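A sketch of running a method on a specific warehouse follows. It assumes the Snowpark Session methods get_current_warehouse and use_warehouse. The helper name run_on_warehouse and the warehouse name passed to it are hypothetical. Because Snowpark results are lazy, the sketch evaluates the result before restoring the previous warehouse. Otherwise the query would run on whichever warehouse is current when the DataFrame is finally collected.

```python
from typing import Any


def run_on_warehouse(session: Any, mv: Any, data: Any, warehouse: str,
                     function_name: str = "predict") -> Any:
    """Run mv.run on `warehouse`, then restore the session's previous warehouse."""
    previous = session.get_current_warehouse()
    session.use_warehouse(warehouse)
    try:
        result = mv.run(data, function_name=function_name)
        # Evaluate now (Snowpark is lazy) so the query executes on `warehouse`.
        if hasattr(result, "to_pandas"):
            result = result.to_pandas()
        return result
    finally:
        # Only restore if the session had a warehouse to begin with.
        if previous is not None:
            session.use_warehouse(previous)
```

If you do not need to restore the previous warehouse, calling session.use_warehouse once before mv.run is sufficient.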

The following example illustrates running the predict method of a model. This model’s predict method does not require any parameters besides the inference data (test_features here). If it did, they would be passed as additional arguments after the inference data:

# test_features: a Snowpark DataFrame holding the inference data
remote_prediction = mv.run(test_features, function_name="predict")
remote_prediction.show()   # show() triggers evaluation of the lazy Snowpark DataFrame

To see what methods can be called on a given model, call mv.show_functions(). The return value of this method is a list of ModelFunctionInfo objects. Each of these objects includes the following attributes:

  • name: The name of the function that can be called from Python or SQL.

  • target_method: The name of the Python method in the original logged model.
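To see at a glance which callable name maps to which original method, the list returned by mv.show_functions() can be turned into a dictionary. This is a minimal sketch. Because the entries may be dict-like or may expose attributes, the sketch supports both shapes. The helper names function_map and _field are hypothetical.

```python
from typing import Any, Dict, Iterable


def _field(info: Any, key: str) -> Any:
    # Support dict-like entries as well as attribute-style entries.
    if isinstance(info, dict):
        return info[key]
    return getattr(info, key)


def function_map(functions: Iterable[Any]) -> Dict[str, str]:
    """Map each callable function name to the Python method it wraps."""
    return {_field(f, "name"): _field(f, "target_method") for f in functions}


# Example usage (assumes `mv` is a ModelVersion):
#   print(function_map(mv.show_functions()))
```

The keys are the names you would pass as function_name to mv.run.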

Tip

You can also call model methods in SQL. See Calling model methods.