Many Model Training and Inference¶
Classes¶
- class snowflake.ml.modeling.distributors.many_model.many_model_training.ManyModelTraining(train_func: Callable[[DataConnector, PartitionContext], Any], stage_name: str, serde: ModelSerde | None = None)¶
Bases: DPF
Specialized distributed model training across data partitions.
ManyModelTraining extends DPF specifically for machine learning scenarios where you need to train separate models on different data partitions. It automatically handles model serialization, artifact management, and provides framework-specific serialization options.
- Key features:
Automatic model serialization and saving to Snowflake stages
Built-in serializers for popular frameworks (XGBoost, scikit-learn, PyTorch, TensorFlow), or implement ModelSerde for custom formats
Consistent artifact naming and organization
Progress monitoring
Seamless integration with model inference workflows
The user-provided training function should return a trained model object. The class automatically handles saving the model using the specified serialization method.
- Typical workflow:
Train models across data partitions using ManyModelTraining
Monitor training progress and wait for completion
Run inference on new data using ManyModelInference with the training run ID
Access results from both training artifacts and inference outputs
Optional: Restore previous DPF run (useful for retrieving previous results)
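As a sketch of the optional restore step, assuming DPF exposes a restore-style helper (the restore_from name and signature below are illustrative only; check the DPF reference for the exact API):
>>> # Hypothetical restore call; name and signature are illustrative, not confirmed API
>>> previous_run = DPF.restore_from("model_stage", "regional_models_v1")
>>> previous_run.partition_details  # inspect artifacts from the earlier run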
Example
Complete end-to-end many model workflow:
1. Define Training Function
>>> def train_xgboost_model(data_connector, context):
...     from xgboost import XGBRegressor
...     df = data_connector.to_pandas()
...     X = df[['feature1', 'feature2', 'feature3']]
...     y = df['target']
...
...     model = XGBRegressor()
...     model.fit(X, y)
...     return model  # Automatically saved as model.pkl
2. Train Models Across Partitions
>>> trainer = ManyModelTraining(train_xgboost_model, "model_stage")
>>> training_run = trainer.run(
...     partition_by="region",
...     snowpark_dataframe=sales_data,
...     run_id="regional_models_v1"
... )
>>> training_run.wait()
3. Run Inference with Trained Models
>>> def predict_sales(data_connector, model, context):
...     df = data_connector.to_pandas()
...     X = df[['feature1', 'feature2', 'feature3']]
...     predictions = model.predict(X)
...
...     results = df.copy()
...     results['predictions'] = predictions
...     results['region'] = context.partition_id
...
...     # Two persistence strategies (choose one or both based on your needs):
...
...     # Strategy 1: Stage artifacts - for framework management and debugging
...     context.upload_to_stage(results, "predictions.csv",
...                             write_function=lambda df, path: df.to_csv(path, index=False))
...
...     # Strategy 2: Snowflake table - for immediate downstream consumption
...     predictions_df = context.session.create_dataframe(results)
...     predictions_df.write.mode("append").save_as_table("sales_predictions")
...
...     return predictions
>>> from snowflake.ml.modeling.distributors.many_model import ManyModelInference
>>> inference = ManyModelInference(predict_sales, "model_stage", "regional_models_v1")
>>> inference_run = inference.run(
...     partition_by="region",
...     snowpark_dataframe=new_sales_data,
...     run_id="predictions_2024"
... )
>>> inference_run.wait()
4. Access Results (Two Different Approaches)
>>> # Approach 1: Access stage artifacts (for detailed analysis and debugging)
>>> for partition_id in training_run.partition_details:
...     trained_model = training_run.get_model(partition_id)
...     predictions_csv = inference_run.partition_details[partition_id].stage_artifacts_manager.get(
...         "predictions.csv")
>>> # Approach 2: Query Snowflake table (for analytics and aggregations)
>>> from snowflake.runtime.utils.session_utils import get_session
>>> from snowflake.snowpark.functions import avg, count
>>> session = get_session()
>>> all_predictions = session.table("sales_predictions")
>>> summary_stats = all_predictions.group_by("region").agg(
...     avg("predictions"), count("predictions")).collect()
See also
ManyModelInference: For running inference with trained models
DPF: Base class for general distributed processing
Initialize ManyModelTraining.
- Parameters:
train_func (Callable[[DataConnector, PartitionContext], Any]) – User function that takes (DataConnector, PartitionContext) and returns a trained model. Both parameters are automatically provided by the framework. Follows the same signature pattern as regular DPF functions for consistency.
stage_name (str) – Stage name for storing artifacts
serde (Optional[ModelSerde]) – ModelSerde instance for handling model serialization and deserialization. Built-in options: PickleSerde (XGBoost, scikit-learn), TorchSerde, TensorFlowSerde, or implement ModelSerde for custom formats. Defaults to PickleSerde.
Examples
Default pickle serialization:
>>> def train_xgb_model(data_connector, context):
...     df = data_connector.to_pandas()
...     X = df[['feature1', 'feature2']]
...     y = df['target']
...
...     from xgboost import XGBRegressor
...     model = XGBRegressor()
...     model.fit(X, y)
...     return model  # Automatically saved as model.pkl
...
>>> trainer = ManyModelTraining(train_xgb_model, "models_stage")
PyTorch model training:
>>> def train_pytorch_model(data_connector, context):
...     import torch
...     import torch.nn as nn
...
...     df = data_connector.to_pandas()
...     # ... prepare data for PyTorch ...
...
...     model = nn.Sequential(nn.Linear(10, 1))
...     # ... training logic ...
...     return model  # Automatically saved as model.pt
...
>>> from snowflake.ml.modeling.distributors.many_model import TorchSerde
>>> trainer = ManyModelTraining(train_pytorch_model, "models_stage", serde=TorchSerde())
TensorFlow model training:
>>> def train_tf_model(data_connector, context):
...     import tensorflow as tf
...
...     df = data_connector.to_pandas()
...     # ... prepare data for TensorFlow ...
...
...     model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
...     # ... training logic ...
...     return model  # Automatically saved as model.keras
...
>>> from snowflake.ml.modeling.distributors.many_model import TensorFlowSerde
>>> trainer = ManyModelTraining(train_tf_model, "models_stage", serde=TensorFlowSerde())
Custom serialization for specialized model formats:
>>> from snowflake.ml.modeling.distributors.many_model import ModelSerde
>>>
>>> class ScikitLearnSerde(ModelSerde):
...     '''Custom serializer for scikit-learn models with metadata'''
...
...     @property
...     def filename(self) -> str:
...         return "sklearn_model.joblib"
...
...     def write(self, model, file_path: str) -> None:
...         import joblib
...         # Save model with metadata
...         model_data = {
...             'model': model,
...             'feature_names': getattr(model, 'feature_names_in_', None),
...             'model_type': type(model).__name__
...         }
...         joblib.dump(model_data, file_path)
...
...     def read(self, file_path: str):
...         import joblib
...         return joblib.load(file_path)
>>>
>>> def train_sklearn_model(data_connector, context):
...     from sklearn.ensemble import RandomForestRegressor
...     df = data_connector.to_pandas()
...     X, y = df[['feature1', 'feature2']], df['target']
...
...     model = RandomForestRegressor()
...     model.fit(X, y)
...     return model  # Automatically saved with metadata
...
>>> trainer = ManyModelTraining(train_sklearn_model, "models_stage", serde=ScikitLearnSerde())
- run(*, partition_by: str, snowpark_dataframe: DataFrame, run_id: str, on_existing_artifacts: str = 'error', execution_options: ExecutionOptions | None = None) ManyModelRun ¶
Execute distributed model training across data partitions.
Trains separate models on each partition of the data and automatically saves them to the specified Snowflake stage using the configured serialization method.
- Parameters:
partition_by (str) – Column name to partition the DataFrame by. Each unique value creates a separate partition with its own trained model.
snowpark_dataframe (snowpark.DataFrame) – DataFrame containing training data.
run_id (str) – Unique identifier for this training run. Used to organize model artifacts and enable future inference runs.
on_existing_artifacts (str, optional) – Action for existing artifacts. Defaults to “error”.
"error": Raises error if training artifacts already exist
"overwrite": Replaces existing training artifacts
execution_options (ExecutionOptions, optional) – Configuration for distributed execution.
- Returns:
Enhanced run handle with model-specific capabilities for accessing trained models and monitoring training progress.
- Return type:
ManyModelRun
Example
>>> trainer = ManyModelTraining(my_train_func, "models_stage")
>>> run = trainer.run(
...     partition_by="store_id",
...     snowpark_dataframe=sales_data,
...     run_id="store_models_2024"
... )
>>> run.wait()
>>> # Models are now saved and can be used for inference
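If the run_id already has artifacts from an earlier attempt, a rerun can replace them instead of raising, as sketched below (reusing the trainer and data from the example above):
>>> run = trainer.run(
...     partition_by="store_id",
...     snowpark_dataframe=sales_data,
...     run_id="store_models_2024",
...     on_existing_artifacts="overwrite"
... )
>>> run.wait()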
- class snowflake.ml.modeling.distributors.many_model.many_model_inference.ManyModelInference(inference_func: Callable[[DataConnector, Any, PartitionContext], Any], stage_name: str, training_run_id: str, serde: ModelSerde | None = None)¶
Bases: DPF
Specialized distributed model inference across data partitions.
ManyModelInference extends DPF for machine learning scenarios where you need to run inference using previously trained models on different data partitions. It automatically handles model loading, artifact retrieval, and provides framework-specific deserialization.
- Key features:
Automatic model loading from previous training runs
Built-in deserializers for popular frameworks (XGBoost, scikit-learn, PyTorch, TensorFlow), or use the same ModelSerde from training
Seamless integration with ManyModelTraining workflows
Built-in error handling for model loading
Consistent artifact naming and retrieval
The user-provided inference function receives a pre-loaded model object along with the data. The class automatically handles loading the correct model for each partition.
Note
Your inference function signature should be (data_connector, model, context); all three parameters are automatically provided by the framework. You don’t pass these values yourself.
Note
This class is typically used as step 3 in the many model workflow, after training models with ManyModelTraining. See ManyModelTraining for the complete workflow.
Example
Basic inference function where each partition gets only its own data subset and the system automatically loads the corresponding trained model for that partition:
>>> def predict_with_model(data_connector, model, context):
...     df = data_connector.to_pandas()
...     X = df[['feature1', 'feature2', 'feature3']]
...
...     predictions = model.predict(X)
...     results = df.copy()
...     results['predictions'] = predictions
...
...     context.upload_to_stage(results, "predictions.csv")
...     return results
...
>>> inference = ManyModelInference(
...     predict_with_model,
...     "model_stage",
...     training_run_id="regional_models_v1"
... )
>>> run = inference.run(
...     partition_by="region",
...     snowpark_dataframe=new_data,
...     run_id="predictions_2024"
... )
>>> run.wait()
Initialize ManyModelInference.
- Parameters:
inference_func (Callable[[DataConnector, Any, PartitionContext], Any]) – User function that takes (DataConnector, model, PartitionContext) and returns results. All three parameters are automatically provided by the framework. Follows the same signature pattern as regular DPF functions for consistency.
stage_name (str) – Stage name where training artifacts are stored
training_run_id (str) – Run ID from the training phase to load models from
serde (Optional[ModelSerde]) – ModelSerde instance for handling model deserialization. Built-in options: PickleSerde (XGBoost, scikit-learn), TorchSerde, TensorFlowSerde, or use the same ModelSerde implementation from training. Defaults to PickleSerde.
Examples
Default pickle deserialization for XGBoost models:
>>> def predict_with_xgb(data_connector, model, context):
...     df = data_connector.to_pandas()
...     X = df[['feature1', 'feature2']]
...
...     predictions = model.predict(X)
...     results = df.copy()
...     results['predictions'] = predictions
...     results['partition_id'] = context.partition_id
...
...     # Two persistence strategies (choose based on your use case):
...
...     # Strategy 1: Stage artifacts - for framework management and retrieval
...     context.upload_to_stage(results, "predictions.csv",
...                             write_function=lambda df, path: df.to_csv(path, index=False))
...
...     # Strategy 2: Snowflake table - for immediate analytics and dashboards
...     results_df = context.session.create_dataframe(results)
...     results_df.write.mode("append").save_as_table("my_predictions_table")
...
...     return predictions
...
>>> inference = ManyModelInference(predict_with_xgb, "models_stage", "training_run_v1")
PyTorch model inference:
>>> def predict_with_pytorch(data_connector, model, context):
...     import torch
...
...     df = data_connector.to_pandas()
...     # ... prepare data for PyTorch, producing an X_test feature array ...
...
...     model.eval()
...     with torch.no_grad():
...         predictions = model(torch.tensor(X_test, dtype=torch.float32))
...
...     context.upload_to_stage(predictions.numpy(), "predictions.pkl")
...     return predictions
...
>>> from snowflake.ml.modeling.distributors.many_model import TorchSerde
>>> inference = ManyModelInference(
...     predict_with_pytorch, "models_stage", "training_run_v1", serde=TorchSerde())
TensorFlow model inference:
>>> def predict_with_tf(data_connector, model, context):
...     import tensorflow as tf
...
...     df = data_connector.to_pandas()
...     # ... prepare data for TensorFlow, producing an X_test feature array ...
...
...     predictions = model.predict(X_test)
...     context.upload_to_stage(predictions, "predictions.pkl")
...     return predictions
...
>>> from snowflake.ml.modeling.distributors.many_model import TensorFlowSerde
>>> inference = ManyModelInference(
...     predict_with_tf, "models_stage", "training_run_v1", serde=TensorFlowSerde())
Custom deserialization for specialized model formats:
>>> # Using the same ScikitLearnSerde from training
>>> class ScikitLearnSerde(ModelSerde):
...     @property
...     def filename(self) -> str:
...         return "sklearn_model.joblib"
...
...     def write(self, model, file_path: str) -> None:
...         import joblib
...         model_data = {
...             'model': model,
...             'feature_names': getattr(model, 'feature_names_in_', None),
...             'model_type': type(model).__name__
...         }
...         joblib.dump(model_data, file_path)
...
...     def read(self, file_path: str):
...         import joblib
...         return joblib.load(file_path)
>>>
>>> def predict_with_sklearn(data_connector, model_data, context):
...     import json
...     df = data_connector.to_pandas()
...     X = df[['feature1', 'feature2']]
...
...     # Access model and metadata
...     model = model_data['model']
...     model_type = model_data['model_type']
...
...     predictions = model.predict(X)
...
...     # Create results with model metadata
...     results_df = df.copy()
...     results_df['predictions'] = predictions
...     results_df['model_type'] = model_type
...     results_df['partition_id'] = context.partition_id
...
...     # Hybrid persistence strategy - different data types to different locations:
...
...     # Detailed metadata to stage (for debugging and audit)
...     metadata = {'model_type': model_type, 'feature_count': len(X.columns)}
...     context.upload_to_stage(metadata, "model_metadata.json",
...                             write_function=lambda obj, path: json.dump(obj, open(path, 'w')))
...
...     # Predictions with basic metadata to table (for analytics)
...     snowpark_df = context.session.create_dataframe(results_df)
...     snowpark_df.write.mode("append").save_as_table("sklearn_predictions")
...
...     return predictions
...
>>> inference = ManyModelInference(
...     predict_with_sklearn, "models_stage", "training_run_v1", serde=ScikitLearnSerde())
- run(*, partition_by: str, snowpark_dataframe: DataFrame, run_id: str, on_existing_artifacts: str = 'error', execution_options: ExecutionOptions | None = None) ManyModelRun ¶
Execute distributed model inference across data partitions.
Loads previously trained models from the specified training run and runs inference on each partition of the new data. Models are automatically loaded and passed to the user-provided inference function.
- Parameters:
partition_by (str) – Column name to partition the DataFrame by. Must match the partitioning used during training to ensure correct model loading.
snowpark_dataframe (snowpark.DataFrame) – DataFrame containing data for inference.
run_id (str) – Unique identifier for this inference run. Each inference run should have a unique ID to avoid confusion and enable proper tracking.
on_existing_artifacts (str, optional) – Action for existing artifacts. Defaults to “error”.
"error": Raises error if inference artifacts already exist
"overwrite": Replaces existing inference artifacts
execution_options (ExecutionOptions, optional) – Configuration for distributed execution.
- Returns:
Enhanced run handle with model-specific capabilities for accessing inference results and monitoring execution progress.
- Return type:
ManyModelRun
Note
While inference loads existing models (read-only operation), it still creates execution logs and result artifacts. Using unique run_id values helps prevent accidental overwrites and improves traceability.
Example
>>> inference = ManyModelInference(my_predict_func, "models_stage", "training_run_v1")
>>> run = inference.run(
...     partition_by="store_id",
...     snowpark_dataframe=new_sales_data,
...     run_id="predictions_2024"
... )
>>> run.wait()
Model Serialization¶
- class snowflake.ml.modeling.distributors.many_model.ModelSerde¶
Bases: ABC
Base class for model serialization and deserialization.
This class defines the interface for saving and loading models. Implement this class only when you need to support a custom model format not covered by the built-in serializers. For common frameworks, use one of the provided serializers: PickleSerde, TorchSerde, or TensorFlowSerde.
- abstract property filename: str¶
The filename to use for the serialized model.
- abstractmethod read(file_path: str) Any ¶
Deserialize a model from the given file path.
- Parameters:
file_path – Full path to the serialized model
- Returns:
The deserialized model object
- abstractmethod write(model: Any, file_path: str) None ¶
Serialize a model to the given file path.
- Parameters:
model – The model object to serialize
file_path – Full path where the model should be saved
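Example
A minimal sketch of a custom implementation, assuming the trained model can be represented as a JSON-serializable dict such as hand-rolled coefficients (the JsonDictSerde class and its dict layout are illustrative, not part of the library):
>>> from snowflake.ml.modeling.distributors.many_model import ModelSerde
>>>
>>> class JsonDictSerde(ModelSerde):
...     '''Illustrative serde for models stored as JSON-serializable dicts.'''
...
...     @property
...     def filename(self) -> str:
...         return "model.json"
...
...     def write(self, model, file_path: str) -> None:
...         import json
...         # 'model' is assumed to be a plain dict (e.g., coefficient values)
...         with open(file_path, "w") as f:
...             json.dump(model, f)
...
...     def read(self, file_path: str):
...         import json
...         with open(file_path) as f:
...             return json.load(f)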
- class snowflake.ml.modeling.distributors.many_model.PickleSerde(filename: str = 'model.pkl')¶
Bases: ModelSerde
Serializer for scikit-learn, XGBoost, and other pickle-compatible models.
- Parameters:
filename (str) – The filename for the serialized model. Defaults to “model.pkl”.
- __init__(filename: str = 'model.pkl')¶
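Example
A short usage sketch: overriding the default artifact name and passing the serde to ManyModelTraining (train_xgb_model and "models_stage" are reused from the training examples above):
>>> from snowflake.ml.modeling.distributors.many_model import PickleSerde
>>> serde = PickleSerde(filename="xgb_model.pkl")  # saved as xgb_model.pkl instead of model.pkl
>>> trainer = ManyModelTraining(train_xgb_model, "models_stage", serde=serde)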
- class snowflake.ml.modeling.distributors.many_model.TorchSerde(filename: str = 'model.pt', weights_only: bool = False)¶
Bases: ModelSerde
Serializer for PyTorch models.
- Parameters:
filename (str) – The filename for the serialized model. Defaults to “model.pt”.
weights_only (bool) – If True, saves only the model’s weights. Defaults to False.
- __init__(filename: str = 'model.pt', weights_only: bool = False)¶
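Example
A short usage sketch: saving only the model’s weights under a custom filename (train_pytorch_model and "models_stage" are reused from the training examples above; per the weights_only parameter described here, the stored artifact then holds just the weights rather than the full module):
>>> from snowflake.ml.modeling.distributors.many_model import TorchSerde
>>> serde = TorchSerde(filename="regressor.pt", weights_only=True)
>>> trainer = ManyModelTraining(train_pytorch_model, "models_stage", serde=serde)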
- class snowflake.ml.modeling.distributors.many_model.TensorFlowSerde(filename: str = 'model.keras')¶
Bases: ModelSerde
Serializer for TensorFlow/Keras models.
- Parameters:
filename (str) – The filename for the serialized model. Defaults to “model.keras”.
- __init__(filename: str = 'model.keras')¶
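Example
A short usage sketch: sharing one TensorFlowSerde instance between training and inference so both phases use the same artifact name (function and stage names are reused from the examples above):
>>> from snowflake.ml.modeling.distributors.many_model import TensorFlowSerde
>>> serde = TensorFlowSerde(filename="sales_model.keras")
>>> trainer = ManyModelTraining(train_tf_model, "models_stage", serde=serde)
>>> inference = ManyModelInference(predict_with_tf, "models_stage", "training_run_v1", serde=serde)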