Snowflake Model Registry: Partitioned Custom Models

Many datasets can be partitioned into multiple independent subsets. For example, a dataset containing sales data for a chain of stores can be partitioned by store number, and a separate model can then be trained for each partition. Training and inference on the partitions can run in parallel, reducing the wall-clock time for these operations. Furthermore, because individual stores likely differ in how their features affect their sales, per-store models can produce more accurate predictions at the store level than a single model trained on all stores.

The Snowflake Model Registry supports distributed processing of training and inference of partitioned data when:

  • The dataset contains a column that reliably identifies partitions in the data.

  • The data in each individual partition is uncorrelated with the data in the other partitions and contains enough rows to train the model.

  • The model is stateless: it performs both fitting (training) and inference (prediction) each time it is called and does not persist weights or other model state between calls.
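To make the stateless requirement concrete, here is a minimal plain-Python sketch of a function that fits and predicts within a single call. It is an illustration only, not Snowpark ML code. The function name `fit_and_predict` and the mean-based "model" are hypothetical, and a real partitioned model would receive a pandas DataFrame rather than a list.

```python
from statistics import mean


def fit_and_predict(sales: list[float], horizon: int) -> list[float]:
    """Fit a trivial mean model to one partition and forecast `horizon` steps.

    Hypothetical example: everything is computed from the arguments and no
    weights are stored between calls, so the function is stateless.
    """
    if not sales:
        raise ValueError("partition has no rows to train on")
    level = mean(sales)        # "fit": the only parameter is the mean
    return [level] * horizon   # "predict": repeat the fitted level


store_7 = [100.0, 120.0, 110.0]
print(fit_and_predict(store_7, horizon=2))  # [110.0, 110.0]
```

Calling the function again on the same partition returns the same result, because nothing learned in one call carries over to the next.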

With the Snowflake Model Registry, you implement partitioned training and inference using custom models. When you run the model, the registry partitions the dataset, fits and predicts each partition in parallel using all the nodes and cores in your warehouse, and then combines the results into a single dataset.
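The following is a minimal local emulation of that flow in plain Python, under stated assumptions: records are dictionaries, the hypothetical `predict_partition` model simply predicts each store's mean sales, and a thread pool stands in for the warehouse's nodes and cores. It is not how the registry is implemented. It only sketches the split, the parallel fit-and-predict, and the combine steps.

```python
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from statistics import mean


def predict_partition(rows: list[dict]) -> list[dict]:
    """Hypothetical per-partition model: fit the mean of SALES, predict it per row."""
    level = mean(r["SALES"] for r in rows)
    return [{**r, "PREDICTION": level} for r in rows]


def run_partitioned(rows: list[dict], partition_column: str) -> list[dict]:
    # 1. Split the input into independent partitions keyed by the partition column.
    partitions: dict = defaultdict(list)
    for r in rows:
        partitions[r[partition_column]].append(r)
    # 2. Fit and predict each partition in parallel (threads stand in for
    #    the warehouse's nodes and cores; map() keeps partition order).
    with ThreadPoolExecutor() as pool:
        results = list(pool.map(predict_partition, partitions.values()))
    # 3. Combine the per-partition outputs into a single result set.
    return [row for part in results for row in part]


data = [
    {"STORE_NUMBER": 1, "SALES": 10.0},
    {"STORE_NUMBER": 1, "SALES": 20.0},
    {"STORE_NUMBER": 2, "SALES": 5.0},
]
for row in run_partitioned(data, "STORE_NUMBER"):
    print(row)
```

Because each partition is processed independently, adding workers shortens the wall-clock time without changing the per-store predictions.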

Note

Partitioned training and inference requires Snowpark ML (snowflake-ml-python package) version 1.5.0 or later.

Defining and logging the custom model

As explained in Writing the Custom Model Class, you declare custom model inference methods with the @custom_model.partitioned_inference_api decorator (Snowpark ML version 1.5.4 or later) or @custom_model.inference_api decorator (Snowpark ML version 1.5.0 to 1.5.3).

import pandas as pd
from snowflake.ml.model import custom_model

class ExampleForecastingModel(custom_model.CustomModel):

    @custom_model.partitioned_inference_api
    def predict(self, input: pd.DataFrame) -> pd.DataFrame:
        # All data in the partition will be loaded in the input dataframe.
        # … implement model logic here …
        # output_df is the DataFrame your model logic produces for this partition.
        return output_df

my_model = ExampleForecastingModel()
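For illustration, the sketch below shows the kind of per-partition logic that could replace the placeholder: an ordinary least-squares trend fitted to one store's sales history and extrapolated a few periods ahead. The function `forecast_trend` is hypothetical and uses plain Python lists so it runs without Snowflake. Because it refits on every call, it also satisfies the stateless requirement.

```python
def forecast_trend(sales: list[float], horizon: int) -> list[float]:
    """Fit y = intercept + slope * t by ordinary least squares, then extrapolate.

    Hypothetical per-partition model logic; t is the 0-based period index.
    """
    n = len(sales)
    if n < 2:
        raise ValueError("need at least two observations to fit a trend")
    t_mean = (n - 1) / 2
    y_mean = sum(sales) / n
    # Slope = covariance(t, y) / variance(t); intercept passes through the means.
    cov = sum((t - t_mean) * (y - y_mean) for t, y in enumerate(sales))
    var = sum((t - t_mean) ** 2 for t in range(n))
    slope = cov / var
    intercept = y_mean - slope * t_mean
    # Predict the next `horizon` periods after the last observation.
    return [intercept + slope * t for t in range(n, n + horizon)]


print(forecast_trend([10.0, 12.0, 14.0], horizon=2))  # [16.0, 18.0]
```

Inside `predict`, you might call it as `forecast_trend(input["SALES"].tolist(), horizon=4)` and wrap the result in a pandas DataFrame. The `SALES` column name and the horizon are assumptions.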

When logging the model, provide a function_type of TABLE_FUNCTION in the options dictionary along with any other options your model requires.

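from snowflake.ml.registry import Registry

# sp_session is an existing Snowpark Session; train_features is sample input data.
reg = Registry(session=sp_session, database_name="ML", schema_name="REGISTRY")
mv = reg.log_model(my_model,
  model_name="my_model",
  version_name="v1",
  options={"function_type": "TABLE_FUNCTION"},  # log the methods as table functions
  conda_dependencies=["scikit-learn"],
  sample_input_data=train_features
)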

If your custom model also has regular (non-table) functions as methods, you can instead use the method_options dictionary to specify the type of each method.

model_version = reg.log_model(my_model,
    model_name="my_model",
    version_name="v1",
    options={
        "method_options": {
            "METHOD1": {"function_type": "TABLE_FUNCTION"},  # partitioned table function
            "METHOD2": {"function_type": "FUNCTION"}         # regular (non-table) function
        }
    },
    conda_dependencies=["scikit-learn"],
    sample_input_data=train_features
)
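If a model has many methods, building the `method_options` dictionary programmatically can keep it consistent with the class. The helper below is a hypothetical convenience, not part of Snowpark ML. It only produces a dictionary with the same shape as the `options` argument in the previous example.

```python
def build_log_options(table_methods: set[str], all_methods: list[str]) -> dict:
    """Hypothetical helper: mark each method as a table function or a regular function."""
    return {
        "method_options": {
            name: {"function_type": "TABLE_FUNCTION" if name in table_methods else "FUNCTION"}
            for name in all_methods
        }
    }


options = build_log_options({"METHOD1"}, ["METHOD1", "METHOD2"])
print(options)
```

You would then pass the result as `options=options` to `reg.log_model`, alongside the other arguments.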

Performing training and inference

Use the run method of a Python ModelVersion object to invoke the table function methods in a partitioned fashion. Pass partition_column to specify the name of the column whose numeric or string value identifies the partition of each record. As usual, you can pass either a Snowpark DataFrame or a pandas DataFrame (the latter is useful for local testing), and you receive the same type of DataFrame as the result. In these examples, we partition on store number.

mv.run(
  input_df,
  function_name="PREDICT",
  partition_column="STORE_NUMBER"
)

You can also call these methods from SQL, partitioning the data with an OVER (PARTITION BY …) clause, as shown here.

SELECT OUTPUT1, OUTPUT2, input_table.STORE_NUMBER
  FROM input_table,
      table(
          MY_MODEL!PREDICT(input_table.INPUT1, input_table.INPUT2)
          OVER (PARTITION BY input_table.STORE_NUMBER)
      )
  ORDER BY input_table.STORE_NUMBER;

The input data is automatically split among the nodes and cores in your warehouse, and the partitions are processed in parallel.

Tip

Many datasets can be partitioned in more than one way. Because the partition column is specified when you call the model, not when you log it, you can try out different partitioning schemes without changing the model. For example, in the hypothetical store sales dataset, you could partition by store number or by state or province to see which produces more accurate predictions.

This also means you don’t need a separate model for unpartitioned processing. If you don’t specify a partition column, no partitioning is done, and all the data is processed together as usual.

Example

See the notebook on our Google Drive for an example, including sample data.