Many Model Training and Inference

Classes

class snowflake.ml.modeling.distributors.many_model.many_model_training.ManyModelTraining(train_func: Callable[[DataConnector, PartitionContext], Any], stage_name: str, serde: ModelSerde | None = None)

Bases: DPF

Specialized distributed model training across data partitions.

ManyModelTraining extends DPF for machine learning scenarios where you need to train a separate model on each data partition. It automatically handles model serialization and artifact management, and provides framework-specific serialization options.

Key features:
  • Automatic model serialization and saving to Snowflake stages

  • Built-in serializers for popular frameworks (XGBoost, scikit-learn, PyTorch, TensorFlow), with ModelSerde available for implementing custom formats

  • Consistent artifact naming and organization

  • Progress monitoring

  • Seamless integration with model inference workflows

The user-provided training function should return a trained model object. The class automatically handles saving the model using the specified serialization method.
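
To make the contract above concrete, the following is a local, standard-library-only sketch of the same pattern: group rows by a partition key, call a training function once per group, and pickle whatever it returns. It does not use Snowflake at all. The helper names (train_mean_model, train_per_partition) and the partition_id/model.pkl directory layout are assumptions made for illustration, not the library's implementation.

```python
import pickle
import tempfile
from collections import defaultdict
from pathlib import Path
from statistics import mean


def train_mean_model(rows):
    """Toy training function: the 'model' is just the mean of the target."""
    return {"mean_target": mean(row["target"] for row in rows)}


def train_per_partition(rows, partition_by, train_func, out_dir):
    """Group rows by partition key, train one model per group, pickle each."""
    partitions = defaultdict(list)
    for row in rows:
        partitions[row[partition_by]].append(row)

    saved = {}
    for partition_id, part_rows in partitions.items():
        model = train_func(part_rows)  # the function only returns the model
        # Assumed layout for this sketch: <out_dir>/<partition_id>/model.pkl
        path = Path(out_dir) / str(partition_id) / "model.pkl"
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(model, f)  # saving happens outside the training function
        saved[partition_id] = path
    return saved


rows = [
    {"region": "east", "target": 10.0},
    {"region": "east", "target": 20.0},
    {"region": "west", "target": 5.0},
]
out_dir = tempfile.mkdtemp()
saved_models = train_per_partition(rows, "region", train_mean_model, out_dir)
```

The training function never touches the filesystem; it returns an object and the surrounding code decides how and where to persist it, which is the same division of labor the class describes.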

Typical workflow:
  1. Train models across data partitions using ManyModelTraining

  2. Monitor training progress and wait for completion

  3. Run inference on new data using ManyModelInference with the training run ID

  4. Access results from both training artifacts and inference outputs

  5. Optional: Restore a previous DPF run (useful for retrieving earlier results)

Example

Complete end-to-end many model workflow:

1. Define Training Function

>>> def train_xgboost_model(data_connector, context):
...     from xgboost import XGBRegressor
...
...     df = data_connector.to_pandas()
...     X = df[['feature1', 'feature2', 'feature3']]
...     y = df['target']
...
...     model = XGBRegressor()
...     model.fit(X, y)
...     return model  # Automatically saved as model.pkl

2. Train Models Across Partitions

>>> from snowflake.ml.modeling.distributors.many_model import ManyModelTraining
>>> trainer = ManyModelTraining(train_xgboost_model, "model_stage")
>>> training_run = trainer.run(
...     partition_by="region",
...     snowpark_dataframe=sales_data,
...     run_id="regional_models_v1"
... )
>>> training_run.wait()

3. Run Inference with Trained Models

>>> def predict_sales(data_connector, model, context):
...     df = data_connector.to_pandas()
...     X = df[['feature1', 'feature2', 'feature3']]
...     predictions = model.predict(X)
...
...     results = df.copy()
...     results['predictions'] = predictions
...     results['region'] = context.partition_id
...
...     # Two persistence strategies (choose one or both based on your needs):
...
...     # Strategy 1: Stage artifacts - for framework management and debugging
...     context.upload_to_stage(results, "predictions.csv",
...         write_function=lambda df, path: df.to_csv(path, index=False))
...
...     # Strategy 2: Snowflake table - for immediate downstream consumption
...     predictions_df = context.session.create_dataframe(results)
...     predictions_df.write.mode("append").save_as_table("sales_predictions")
...
...     return predictions
>>> from snowflake.ml.modeling.distributors.many_model import ManyModelInference
>>> inference = ManyModelInference(predict_sales, "model_stage", "regional_models_v1")
>>> inference_run = inference.run(
...     partition_by="region",
...     snowpark_dataframe=new_sales_data,
...     run_id="predictions_2024"
... )
>>> inference_run.wait()

4. Access Results (Two Different Approaches)

>>> # Approach 1: Access stage artifacts (for detailed analysis and debugging)
>>> for partition_id in training_run.partition_details:
...     trained_model = training_run.get_model(partition_id)
...     predictions_csv = inference_run.partition_details[partition_id].stage_artifacts_manager.get(
...         "predictions.csv")
>>> # Approach 2: Query Snowflake table (for analytics and aggregations)
>>> from snowflake.runtime.utils.session_utils import get_session
>>> from snowflake.snowpark.functions import avg, count
>>> session = get_session()
>>> all_predictions = session.table("sales_predictions")
>>> # Average and count of predictions per region
>>> summary_stats = all_predictions.group_by("region").agg(
...     avg("predictions"), count("predictions")).collect()

See also

  • ManyModelInference: For running inference with trained models

  • DPF: Base class for general distributed processing

Initialize ManyModelTraining.

Parameters:
  • train_func (Callable[[DataConnector, PartitionContext], Any]) – User function that takes (DataConnector, PartitionContext) and returns a trained model. Both parameters are automatically provided by the framework. Follows the same signature pattern as regular DPF functions for consistency.

  • stage_name (str) – Stage name for storing artifacts

  • serde (Optional[ModelSerde]) – ModelSerde instance for handling model serialization and deserialization. Built-in options are PickleSerde (XGBoost, scikit-learn), TorchSerde, and TensorFlowSerde; for custom formats, implement ModelSerde. Defaults to PickleSerde.

Examples

Default pickle serialization:

>>> def train_xgb_model(data_connector, context):
...     df = data_connector.to_pandas()
...     X = df[['feature1', 'feature2']]
...     y = df['target']
...
...     from xgboost import XGBRegressor
...     model = XGBRegressor()
...     model.fit(X, y)
...     return model  # Automatically saved as model.pkl
...
>>> trainer = ManyModelTraining(train_xgb_model, "models_stage")

PyTorch model training:

>>> def train_pytorch_model(data_connector, context):
...     import torch
...     import torch.nn as nn
...
...     df = data_connector.to_pandas()
...     # ... prepare data for PyTorch ...
...
...     model = nn.Sequential(nn.Linear(10, 1))
...     # ... training logic ...
...     return model  # Automatically saved as model.pt (the TorchSerde default)
...
>>> from snowflake.ml.modeling.distributors.many_model import TorchSerde
>>> trainer = ManyModelTraining(train_pytorch_model, "models_stage", serde=TorchSerde())

TensorFlow model training:

>>> def train_tf_model(data_connector, context):
...     import tensorflow as tf
...
...     df = data_connector.to_pandas()
...     # ... prepare data for TensorFlow ...
...
...     model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
...     # ... training logic ...
...     return model  # Automatically saved as model.keras (the TensorFlowSerde default)
...
>>> from snowflake.ml.modeling.distributors.many_model import TensorFlowSerde
>>> trainer = ManyModelTraining(train_tf_model, "models_stage", serde=TensorFlowSerde())

Custom serialization for specialized model formats:

>>> from snowflake.ml.modeling.distributors.many_model import ModelSerde
>>> import json
>>>
>>> class ScikitLearnSerde(ModelSerde):
...     '''Custom serializer for scikit-learn models with metadata'''
...
...     @property
...     def filename(self) -> str:
...         return "sklearn_model.joblib"
...
...     def write(self, model, file_path: str) -> None:
...         import joblib
...         # Save model with metadata
...         model_data = {
...             'model': model,
...             'feature_names': getattr(model, 'feature_names_in_', None),
...             'model_type': type(model).__name__
...         }
...         joblib.dump(model_data, file_path)
...
...     def read(self, file_path: str):
...         import joblib
...         return joblib.load(file_path)
>>>
>>> def train_sklearn_model(data_connector, context):
...     from sklearn.ensemble import RandomForestRegressor
...     df = data_connector.to_pandas()
...     X, y = df[['feature1', 'feature2']], df['target']
...
...     model = RandomForestRegressor()
...     model.fit(X, y)
...     return model  # Automatically saved with metadata
...
>>> trainer = ManyModelTraining(train_sklearn_model, "models_stage", serde=ScikitLearnSerde())
run(*, partition_by: str, snowpark_dataframe: DataFrame, run_id: str, on_existing_artifacts: str = 'error', execution_options: ExecutionOptions | None = None) -> ManyModelRun

Execute distributed model training across data partitions.

Trains separate models on each partition of the data and automatically saves them to the specified Snowflake stage using the configured serialization method.

Parameters:
  • partition_by (str) – Column name to partition the DataFrame by. Each unique value creates a separate partition with its own trained model.

  • snowpark_dataframe (snowpark.DataFrame) – DataFrame containing training data.

  • run_id (str) – Unique identifier for this training run. Used to organize model artifacts and enable future inference runs.

  • on_existing_artifacts (str, optional) –

    Action to take when training artifacts already exist. Defaults to “error”.

    • "error": Raises error if training artifacts already exist

    • "overwrite": Replaces existing training artifacts

  • execution_options (ExecutionOptions, optional) – Configuration for distributed execution.
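
The two on_existing_artifacts policies can be pictured with a local-filesystem analogy. The sketch below is only that: prepare_run_dir is a hypothetical helper written for this illustration, and it says nothing about how the library manages stage artifacts internally.

```python
import shutil
import tempfile
from pathlib import Path


def prepare_run_dir(root, run_id, on_existing_artifacts="error"):
    """Return an empty directory for run_id, honoring the chosen policy."""
    if on_existing_artifacts not in ("error", "overwrite"):
        raise ValueError(f"unsupported policy: {on_existing_artifacts!r}")

    run_dir = Path(root) / run_id
    if run_dir.exists() and any(run_dir.iterdir()):
        if on_existing_artifacts == "error":
            raise FileExistsError(f"artifacts already exist for run_id {run_id!r}")
        shutil.rmtree(run_dir)  # "overwrite": discard the previous artifacts

    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir


root = tempfile.mkdtemp()
first = prepare_run_dir(root, "store_models_2024")
(first / "model.pkl").write_bytes(b"old")  # simulate artifacts from an earlier run

try:
    prepare_run_dir(root, "store_models_2024")  # default policy is "error"
    raised = False
except FileExistsError:
    raised = True

second = prepare_run_dir(root, "store_models_2024", on_existing_artifacts="overwrite")
```

Keeping "error" as the default means a reused run_id fails loudly instead of silently replacing earlier models.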

Returns:

Enhanced run handle with model-specific capabilities for accessing trained models and monitoring training progress.

Return type:

ManyModelRun

Example

>>> trainer = ManyModelTraining(my_train_func, "models_stage")
>>> run = trainer.run(
...     partition_by="store_id",
...     snowpark_dataframe=sales_data,
...     run_id="store_models_2024"
... )
>>> run.wait()
>>> # Models are now saved and can be used for inference
class snowflake.ml.modeling.distributors.many_model.many_model_inference.ManyModelInference(inference_func: Callable[[DataConnector, Any, PartitionContext], Any], stage_name: str, training_run_id: str, serde: ModelSerde | None = None)

Bases: DPF

Specialized distributed model inference across data partitions.

ManyModelInference extends DPF for machine learning scenarios where you need to run inference using previously trained models on different data partitions. It automatically handles model loading and artifact retrieval, and provides framework-specific deserialization.

Key features:
  • Automatic model loading from previous training runs

  • Built-in deserializers for popular frameworks (XGBoost, scikit-learn, PyTorch, TensorFlow), or reuse the ModelSerde implementation from training

  • Seamless integration with ManyModelTraining workflows

  • Built-in error handling for model loading

  • Consistent artifact naming and retrieval

The user-provided inference function receives a pre-loaded model object along with the data. The class automatically handles loading the correct model for each partition.
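
As a local illustration of that behavior, the sketch below loads one pickled model per partition and passes it, together with only that partition's rows, to an inference function. It uses the standard library only. The helper names and the partition_id/model.pkl layout are assumptions for this example and do not describe how the library stores or locates models.

```python
import pickle
import tempfile
from pathlib import Path


def predict_with_offset(rows, model, partition_id):
    """Toy inference function: add the partition's learned offset to x."""
    return [row["x"] + model["offset"] for row in rows]


def run_inference_per_partition(rows, partition_by, inference_func, model_dir):
    """Load each partition's model and apply it to that partition's rows only."""
    partitions = {}
    for row in rows:
        partitions.setdefault(row[partition_by], []).append(row)

    results = {}
    for partition_id, part_rows in partitions.items():
        model_path = Path(model_dir) / str(partition_id) / "model.pkl"
        with open(model_path, "rb") as f:
            model = pickle.load(f)  # the model is loaded before the function runs
        results[partition_id] = inference_func(part_rows, model, partition_id)
    return results


# Stand-in for a finished training run: one pickled model per partition.
model_dir = Path(tempfile.mkdtemp())
for partition_id, offset in {"east": 100, "west": 200}.items():
    (model_dir / partition_id).mkdir()
    (model_dir / partition_id / "model.pkl").write_bytes(pickle.dumps({"offset": offset}))

new_rows = [
    {"region": "east", "x": 1},
    {"region": "west", "x": 2},
    {"region": "east", "x": 3},
]
predictions = run_inference_per_partition(new_rows, "region", predict_with_offset, model_dir)
```

The inference function never opens a model file itself; it receives an already-deserialized object, mirroring the (data_connector, model, context) signature.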

Note

Your inference function signature should be (data_connector, model, context). The framework supplies all three arguments automatically; you do not pass them yourself.

Note

This class is typically used as step 3 in the many model workflow after training models with ManyModelTraining. See ManyModelTraining for the complete workflow.

Example

Basic inference function where each partition gets only its own data subset and the system automatically loads the corresponding trained model for that partition:

>>> def predict_with_model(data_connector, model, context):
...     df = data_connector.to_pandas()
...     X = df[['feature1', 'feature2', 'feature3']]
...
...     predictions = model.predict(X)
...     results = df.copy()
...     results['predictions'] = predictions
...
...     context.upload_to_stage(results, "predictions.csv",
...         write_function=lambda df, path: df.to_csv(path, index=False))
...     return results
...
>>> inference = ManyModelInference(
...     predict_with_model,
...     "model_stage",
...     training_run_id="regional_models_v1"
... )
>>> run = inference.run(
...     partition_by="region",
...     snowpark_dataframe=new_data,
...     run_id="predictions_2024"
... )
>>> run.wait()

Initialize ManyModelInference.

Parameters:
  • inference_func (Callable[[DataConnector, Any, PartitionContext], Any]) – User function that takes (DataConnector, model, PartitionContext) and returns results. All three parameters are automatically provided by the framework. Follows the signature pattern of regular DPF functions, with the loaded model added as the second argument.

  • stage_name (str) – Stage name where training artifacts are stored

  • training_run_id (str) – Run ID from the training phase to load models from

  • serde (Optional[ModelSerde]) – ModelSerde instance for handling model deserialization. Built-in options are PickleSerde (XGBoost, scikit-learn), TorchSerde, and TensorFlowSerde; for custom formats, pass the same ModelSerde implementation used during training. Defaults to PickleSerde.

Examples

Default pickle deserialization for XGBoost models:

>>> def predict_with_xgb(data_connector, model, context):
...     df = data_connector.to_pandas()
...     X = df[['feature1', 'feature2']]
...
...     predictions = model.predict(X)
...     results = df.copy()
...     results['predictions'] = predictions
...     results['partition_id'] = context.partition_id
...
...     # Two persistence strategies (choose based on your use case):
...
...     # Strategy 1: Stage artifacts - for framework management and retrieval
...     context.upload_to_stage(results, "predictions.csv",
...         write_function=lambda df, path: df.to_csv(path, index=False))
...
...     # Strategy 2: Snowflake table - for immediate analytics and dashboards
...     results_df = context.session.create_dataframe(results)
...     results_df.write.mode("append").save_as_table("my_predictions_table")
...
...     return predictions
...
>>> inference = ManyModelInference(predict_with_xgb, "models_stage", "training_run_v1")

PyTorch model inference:

>>> def predict_with_pytorch(data_connector, model, context):
...     import torch
...
...     df = data_connector.to_pandas()
...     # ... prepare data for PyTorch, producing the feature array X_test ...
...
...     model.eval()
...     with torch.no_grad():
...         predictions = model(torch.tensor(X_test, dtype=torch.float32))
...
...     context.upload_to_stage(predictions.numpy(), "predictions.pkl")
...     return predictions
...
>>> from snowflake.ml.modeling.distributors.many_model import TorchSerde
>>> inference = ManyModelInference(
...     predict_with_pytorch, "models_stage", "training_run_v1", serde=TorchSerde())

TensorFlow model inference:

>>> def predict_with_tf(data_connector, model, context):
...     import tensorflow as tf
...
...     df = data_connector.to_pandas()
...     # ... prepare data for TensorFlow, producing the feature array X_test ...
...
...     predictions = model.predict(X_test)
...     context.upload_to_stage(predictions, "predictions.pkl")
...     return predictions
...
>>> from snowflake.ml.modeling.distributors.many_model import TensorFlowSerde
>>> inference = ManyModelInference(
...     predict_with_tf, "models_stage", "training_run_v1", serde=TensorFlowSerde())

Custom deserialization for specialized model formats:

>>> from snowflake.ml.modeling.distributors.many_model import ModelSerde
>>> import json
>>>
>>> # Using the same ScikitLearnSerde from training
>>> class ScikitLearnSerde(ModelSerde):
...     @property
...     def filename(self) -> str:
...         return "sklearn_model.joblib"
...
...     def write(self, model, file_path: str) -> None:
...         import joblib
...         model_data = {
...             'model': model,
...             'feature_names': getattr(model, 'feature_names_in_', None),
...             'model_type': type(model).__name__
...         }
...         joblib.dump(model_data, file_path)
...
...     def read(self, file_path: str):
...         import joblib
...         return joblib.load(file_path)
>>>
>>> def predict_with_sklearn(data_connector, model_data, context):
...     df = data_connector.to_pandas()
...     X = df[['feature1', 'feature2']]
...
...     # Access model and metadata
...     model = model_data['model']
...     model_type = model_data['model_type']
...
...     predictions = model.predict(X)
...
...     # Create results with model metadata
...     results_df = df.copy()
...     results_df['predictions'] = predictions
...     results_df['model_type'] = model_type
...     results_df['partition_id'] = context.partition_id
...
...     # Hybrid persistence strategy - different data types to different locations:
...
...     # Detailed metadata to stage (for debugging and audit)
...     def write_json(obj, path):
...         # Close the file deterministically instead of leaking the handle
...         with open(path, 'w') as f:
...             json.dump(obj, f)
...
...     metadata = {'model_type': model_type, 'feature_count': len(X.columns)}
...     context.upload_to_stage(metadata, "model_metadata.json", write_function=write_json)
...
...     # Predictions with basic metadata to table (for analytics)
...     snowpark_df = context.session.create_dataframe(results_df)
...     snowpark_df.write.mode("append").save_as_table("sklearn_predictions")
...
...     return predictions
...
>>> inference = ManyModelInference(
...     predict_with_sklearn, "models_stage", "training_run_v1", serde=ScikitLearnSerde())
run(*, partition_by: str, snowpark_dataframe: DataFrame, run_id: str, on_existing_artifacts: str = 'error', execution_options: ExecutionOptions | None = None) -> ManyModelRun

Execute distributed model inference across data partitions.

Loads previously trained models from the specified training run and runs inference on each partition of the new data. Models are automatically loaded and passed to the user-provided inference function.

Parameters:
  • partition_by (str) – Column name to partition the DataFrame by. Must match the partitioning used during training to ensure correct model loading.

  • snowpark_dataframe (snowpark.DataFrame) – DataFrame containing data for inference.

  • run_id (str) – Unique identifier for this inference run. Each inference run should have a unique ID to avoid confusion and enable proper tracking.

  • on_existing_artifacts (str, optional) –

    Action to take when inference artifacts already exist. Defaults to “error”.

    • "error": Raises error if inference artifacts already exist

    • "overwrite": Replaces existing inference artifacts

  • execution_options (ExecutionOptions, optional) – Configuration for distributed execution.
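
Because partition_by must match the partitioning used during training, it can help to compare partition IDs before launching an inference run. The following is a minimal sketch using plain Python collections. find_unmatched_partitions is a hypothetical helper, and how you collect the two ID lists (for example, from the keys of the training run's partition_details and from the distinct values of the partition column) is left as an assumption.

```python
def find_unmatched_partitions(trained_ids, inference_ids):
    """Return inference partition IDs that have no trained model, sorted."""
    return sorted(set(inference_ids) - set(trained_ids))


# Illustrative values only; in practice these come from your own runs and data.
trained_ids = ["east", "north", "west"]    # partitions that have a trained model
inference_ids = ["east", "south", "west"]  # distinct partition values in the new data

missing = find_unmatched_partitions(trained_ids, inference_ids)
if missing:
    print(f"No trained model for partitions: {missing}")
```

Partitions that were trained but are absent from the new data are harmless here; only partitions without a model are reported.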

Returns:

Enhanced run handle with model-specific capabilities for accessing inference results and monitoring execution progress.

Return type:

ManyModelRun

Note

While inference loads existing models (read-only operation), it still creates execution logs and result artifacts. Using unique run_id values helps prevent accidental overwrites and improves traceability.
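
One simple way to keep run_id values unique is to derive them from a prefix, a UTC timestamp, and a short random suffix. The helper below is a hypothetical convenience for illustration, not part of the library. It uses only letters, digits, and underscores, but you should still confirm that the resulting IDs suit your own naming conventions.

```python
import uuid
from datetime import datetime, timezone


def make_run_id(prefix):
    """Build a run ID from a prefix, a UTC timestamp, and a short random suffix."""
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    suffix = uuid.uuid4().hex[:8]  # random part guards against same-second collisions
    return f"{prefix}_{timestamp}_{suffix}"


run_id_a = make_run_id("predictions")
run_id_b = make_run_id("predictions")
```

The timestamp keeps IDs sortable and human-readable, while the suffix distinguishes runs started within the same second.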

Example

>>> inference = ManyModelInference(my_predict_func, "models_stage", "training_run_v1")
>>> run = inference.run(
...     partition_by="store_id",
...     snowpark_dataframe=new_sales_data,
...     run_id="predictions_2024"
... )
>>> run.wait()

Model Serialization

class snowflake.ml.modeling.distributors.many_model.ModelSerde

Bases: ABC

Base class for model serialization and deserialization.

This class defines the interface for saving and loading models. Implement this class only when you need to support a custom model format not covered by the built-in serializers. For common frameworks, use one of the provided serializers: PickleSerde, TorchSerde, or TensorFlowSerde.

abstract property filename: str

The filename to use for the serialized model.

abstractmethod read(file_path: str) -> Any

Deserialize a model from the given file path.

Parameters:

file_path – Full path to the serialized model

Returns:

The deserialized model object

abstractmethod write(model: Any, file_path: str) -> None

Serialize a model to the given file path.

Parameters:
  • model – The model object to serialize

  • file_path – Full path where the model should be saved
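
The interface above (a filename property plus read and write methods) can be exercised without Snowflake. The sketch below defines a local stand-in ABC with the same three members so that it runs anywhere; in real code you would subclass the library's ModelSerde instead. JsonSerde is a hypothetical serializer for models that are plain JSON-compatible dictionaries.

```python
import json
import os
import tempfile
from abc import ABC, abstractmethod
from typing import Any


class ModelSerde(ABC):
    """Local stand-in mirroring the documented interface (not the library class)."""

    @property
    @abstractmethod
    def filename(self) -> str: ...

    @abstractmethod
    def write(self, model: Any, file_path: str) -> None: ...

    @abstractmethod
    def read(self, file_path: str) -> Any: ...


class JsonSerde(ModelSerde):
    """Hypothetical serializer for models stored as JSON-compatible dicts."""

    @property
    def filename(self) -> str:
        return "model.json"

    def write(self, model: Any, file_path: str) -> None:
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(model, f)

    def read(self, file_path: str) -> Any:
        with open(file_path, encoding="utf-8") as f:
            return json.load(f)


# Round trip: what write() saves, read() must be able to restore.
serde = JsonSerde()
model = {"coefficients": [0.5, -1.25], "intercept": 2.0}
path = os.path.join(tempfile.mkdtemp(), serde.filename)
serde.write(model, path)
restored = serde.read(path)
```

A quick local round trip like this is a cheap way to check a custom serializer before using it in a distributed run, where the same implementation has to be supplied for both training and inference.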

class snowflake.ml.modeling.distributors.many_model.PickleSerde(filename: str = 'model.pkl')

Bases: ModelSerde

Serializer for scikit-learn, XGBoost, and other pickle-compatible models.

Parameters:

filename (str) – The filename for the serialized model. Defaults to “model.pkl”.

__init__(filename: str = 'model.pkl')
class snowflake.ml.modeling.distributors.many_model.TorchSerde(filename: str = 'model.pt', weights_only: bool = False)

Bases: ModelSerde

Serializer for PyTorch models.

Parameters:
  • filename (str) – The filename for the serialized model. Defaults to “model.pt”.

  • weights_only (bool) – If True, saves only the model’s weights. Defaults to False.

__init__(filename: str = 'model.pt', weights_only: bool = False)
class snowflake.ml.modeling.distributors.many_model.TensorFlowSerde(filename: str = 'model.keras')

Bases: ModelSerde

Serializer for TensorFlow/Keras models.

Parameters:

filename (str) – The filename for the serialized model. Defaults to “model.keras”.

__init__(filename: str = 'model.keras')