Many Model Training and Inference¶
Classes¶
- class snowflake.ml.modeling.distributors.many_model.many_model_training.ManyModelTraining(train_func: Callable[[DataConnector, PartitionContext], Any], stage_name: str, serde: ModelSerde | None = None)¶
Bases: DPF
Specialized distributed model training across data partitions.
ManyModelTraining extends DPF specifically for machine learning scenarios where you need to train separate models on different data partitions. It automatically handles model serialization and artifact management, and provides framework-specific serialization options.
- Key features:
Automatic model serialization and saving to Snowflake stages
Built-in serializers for popular frameworks (XGBoost, scikit-learn, PyTorch, TensorFlow), or implement ModelSerde for custom formats
Consistent artifact naming and organization
Progress monitoring
Seamless integration with model inference workflows
The user-provided training function should return a trained model object. The class automatically handles saving the model using the specified serialization method.
- Typical workflow:
Train models across data partitions using ManyModelTraining
Monitor training progress and wait for completion
Run inference on new data using ManyModelInference with the training run ID
Access results from both training artifacts and inference outputs
Optional: Restore previous DPF run (useful for retrieving previous results)
Example
Complete end-to-end many model workflow:
1. Define Training Function
2. Train Models Across Partitions
3. Run Inference with Trained Models
4. Access Results (Two Different Approaches)
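The four workflow steps above can be sketched as follows. This is a minimal, illustrative example: the "model" is just each partition's mean target value (anything pickleable works with the default PickleSerde), and the orchestration calls, which require a live Snowflake session, are shown as comments. Column, stage, and run names (TARGET, ID, REGION, @model_stage, train_v1) are hypothetical.

```python
# Steps 1 and 3: user functions for training and inference. Both are plain
# Python; the framework supplies every argument automatically.

def train_func(data_connector, context):
    # Each invocation sees only its own partition's rows.
    df = data_connector.to_pandas()
    target = list(df["TARGET"])  # hypothetical column name
    # The returned model is saved to the stage by the framework.
    return {"mean_target": sum(target) / len(target)}

def inference_func(data_connector, model, context):
    # `model` is this partition's trained model, loaded by the framework.
    df = data_connector.to_pandas()
    return [model["mean_target"] for _ in df["ID"]]  # one prediction per row

# Steps 2 and 4, shown as comments because they need a Snowflake session:
#
#   from snowflake.ml.modeling.distributors.many_model.many_model_training import ManyModelTraining
#   from snowflake.ml.modeling.distributors.many_model.many_model_inference import ManyModelInference
#
#   trainer = ManyModelTraining(train_func, stage_name="@model_stage")
#   train_run = trainer.run(partition_by="REGION",
#                           snowpark_dataframe=train_df, run_id="train_v1")
#
#   inference = ManyModelInference(inference_func, stage_name="@model_stage",
#                                  training_run_id="train_v1")
#   infer_run = inference.run(partition_by="REGION",
#                             snowpark_dataframe=new_df, run_id="infer_v1")
```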
See also
ManyModelInference: For running inference with trained models
DPF: Base class for general distributed processing
Initialize ManyModelTraining.
- Parameters:
train_func (Callable[[DataConnector, PartitionContext], Any]) – User function that takes (DataConnector, PartitionContext) and returns a trained model. Both parameters are automatically provided by the framework. Follows the same signature pattern as regular DPF functions for consistency.
stage_name (str) – Stage name for storing artifacts
serde (Optional[ModelSerde]) – ModelSerde instance for handling model serialization and deserialization. Built-in options: PickleSerde (XGBoost, scikit-learn), TorchSerde, TensorFlowSerde, or implement ModelSerde for custom formats. Defaults to PickleSerde.
Examples
Default pickle serialization:
PyTorch model training:
TensorFlow model training:
Custom serialization for specialized model formats:
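The serialization options above might be selected as follows. This is a sketch: the training functions (`train_xgb`, `train_torch`, `train_tf`) and the stage name are hypothetical and assumed to be defined elsewhere.

```python
from snowflake.ml.modeling.distributors.many_model import (
    PickleSerde, TensorFlowSerde, TorchSerde,
)
from snowflake.ml.modeling.distributors.many_model.many_model_training import (
    ManyModelTraining,
)

# Default: PickleSerde covers XGBoost, scikit-learn, and anything pickleable.
trainer = ManyModelTraining(train_xgb, stage_name="@model_stage")

# Same serializer, with an explicit artifact filename.
named_trainer = ManyModelTraining(
    train_xgb, stage_name="@model_stage",
    serde=PickleSerde(filename="xgb_model.pkl"),
)

# PyTorch: save the full module, or only its weights.
torch_trainer = ManyModelTraining(
    train_torch, stage_name="@model_stage",
    serde=TorchSerde(filename="model.pt", weights_only=True),
)

# TensorFlow/Keras: saved in the .keras format.
tf_trainer = ManyModelTraining(
    train_tf, stage_name="@model_stage",
    serde=TensorFlowSerde(filename="model.keras"),
)
```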
- run(*, partition_by: str, snowpark_dataframe: DataFrame, run_id: str, on_existing_artifacts: str = 'error', execution_options: ExecutionOptions | None = None) ManyModelRun¶
Execute distributed model training across data partitions.
Trains separate models on each partition of the data and automatically saves them to the specified Snowflake stage using the configured serialization method.
- Parameters:
partition_by (str) – Column name to partition the DataFrame by. Each unique value creates a separate partition with its own trained model.
snowpark_dataframe (snowpark.DataFrame) – DataFrame containing training data.
run_id (str) – Unique identifier for this training run. Used to organize model artifacts and enable future inference runs.
on_existing_artifacts (str, optional) –
Action for existing artifacts. Defaults to “error”.
"error": Raises an error if training artifacts already exist
"overwrite": Replaces existing training artifacts
execution_options (ExecutionOptions, optional) – Configuration for distributed execution.
- Returns:
Enhanced run handle with model-specific capabilities for accessing trained models and monitoring training progress.
- Return type:
ManyModelRun
Example
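A sketch of a training run, assuming `trainer` is a configured ManyModelTraining and `train_df` is a Snowpark DataFrame with a REGION column (all names hypothetical). One model is trained per distinct REGION value and saved to the stage under the given run ID.

```python
run = trainer.run(
    partition_by="REGION",
    snowpark_dataframe=train_df,
    run_id="train_2024_q1",
    on_existing_artifacts="overwrite",  # replace artifacts from an earlier run
)
# `run` is a ManyModelRun handle for monitoring progress and accessing
# the trained models.
```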
- class snowflake.ml.modeling.distributors.many_model.many_model_inference.ManyModelInference(inference_func: Callable[[DataConnector, Any, PartitionContext], Any], stage_name: str, training_run_id: str, serde: ModelSerde | None = None)¶
Bases: DPF
Specialized distributed model inference across data partitions.
ManyModelInference extends DPF for machine learning scenarios where you need to run inference using previously trained models on different data partitions. It automatically handles model loading and artifact retrieval, and provides framework-specific deserialization.
- Key features:
Automatic model loading from previous training runs
Built-in deserializers for popular frameworks (XGBoost, scikit-learn, PyTorch, TensorFlow), or use the same ModelSerde from training
Seamless integration with ManyModelTraining workflows
Built-in error handling for model loading
Consistent artifact naming and retrieval
The user-provided inference function receives a pre-loaded model object along with the data. The class automatically handles loading the correct model for each partition.
Note
Your inference function signature should be (data_connector, model, context); all three parameters are automatically provided by the framework. You don't pass these values yourself.
Note
This class is typically used as step 3 in the many model workflow, after training models with ManyModelTraining. See ManyModelTraining for the complete workflow.
Example
Basic inference function where each partition gets only its own data subset and the system automatically loads the corresponding trained model for that partition:
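A minimal sketch of such a function. The FEATURE column and the dict-based "model" are illustrative stand-ins; with a real estimator the last line would be something like `model.predict(...)`.

```python
def inference_func(data_connector, model, context):
    # `data_connector` holds only this partition's rows; `model` is the
    # partition's trained model, already deserialized by the framework.
    df = data_connector.to_pandas()
    features = list(df["FEATURE"])  # hypothetical column name
    # Stand-in for model.predict(...): here the "model" is a dict holding a
    # single coefficient, so prediction is just a scaling.
    return [model["coef"] * x for x in features]
```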
Initialize ManyModelInference.
- Parameters:
inference_func (Callable[[DataConnector, Any, PartitionContext], Any]) – User function that takes (DataConnector, model, PartitionContext) and returns results. All three parameters are automatically provided by the framework. Follows the same signature pattern as regular DPF functions for consistency.
stage_name (str) – Stage name where training artifacts are stored
training_run_id (str) – Run ID from the training phase to load models from
serde (Optional[ModelSerde]) – ModelSerde instance for handling model deserialization. Built-in options: PickleSerde (XGBoost, scikit-learn), TorchSerde, TensorFlowSerde, or use the same ModelSerde implementation from training. Defaults to PickleSerde.
Examples
Default pickle deserialization for XGBoost models:
PyTorch model inference:
TensorFlow model inference:
Custom deserialization for specialized model formats:
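The deserialization options above might be selected like this; a sketch, where `predict_fn`, the stage name, and the training run ID are hypothetical. The key point is that the serde configuration should match what training used, so the artifact filename and format line up.

```python
from snowflake.ml.modeling.distributors.many_model import TorchSerde
from snowflake.ml.modeling.distributors.many_model.many_model_inference import (
    ManyModelInference,
)

# Default PickleSerde: covers XGBoost, scikit-learn, and other pickled models.
inference = ManyModelInference(
    predict_fn, stage_name="@model_stage", training_run_id="train_2024_q1",
)

# PyTorch: reuse the same serde settings that training used.
torch_inference = ManyModelInference(
    predict_fn, stage_name="@model_stage", training_run_id="train_2024_q1",
    serde=TorchSerde(filename="model.pt"),
)
```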
- run(*, partition_by: str, snowpark_dataframe: DataFrame, run_id: str, on_existing_artifacts: str = 'error', execution_options: ExecutionOptions | None = None) ManyModelRun¶
Execute distributed model inference across data partitions.
Loads previously trained models from the specified training run and runs inference on each partition of the new data. Models are automatically loaded and passed to the user-provided inference function.
- Parameters:
partition_by (str) – Column name to partition the DataFrame by. Must match the partitioning used during training to ensure correct model loading.
snowpark_dataframe (snowpark.DataFrame) – DataFrame containing data for inference.
run_id (str) – Unique identifier for this inference run. Each inference run should have a unique ID to avoid confusion and enable proper tracking.
on_existing_artifacts (str, optional) –
Action for existing artifacts. Defaults to “error”.
"error": Raises an error if inference artifacts already exist
"overwrite": Replaces existing inference artifacts
execution_options (ExecutionOptions, optional) – Configuration for distributed execution.
- Returns:
Enhanced run handle with model-specific capabilities for accessing inference results and monitoring execution progress.
- Return type:
ManyModelRun
Note
While inference loads existing models (a read-only operation), it still creates execution logs and result artifacts. Using unique run_id values helps prevent accidental overwrites and improves traceability.
Example
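A sketch of an inference run, assuming `inference` is a configured ManyModelInference and `new_df` is a Snowpark DataFrame partitioned by the same column used during training (names hypothetical). A fresh run_id keeps this run's logs and results separate from earlier ones.

```python
result_run = inference.run(
    partition_by="REGION",          # must match the training partition column
    snowpark_dataframe=new_df,
    run_id="infer_2024_q1_v2",      # unique per inference run
)
```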
Model Serialization¶
- class snowflake.ml.modeling.distributors.many_model.ModelSerde¶
Bases: ABC
Base class for model serialization and deserialization.
This class defines the interface for saving and loading models. Implement this class only when you need to support a custom model format not covered by the built-in serializers. For common frameworks, use one of the provided serializers: PickleSerde, TorchSerde, or TensorFlowSerde.
- abstract property filename: str¶
The filename to use for the serialized model.
- abstractmethod read(file_path: str) Any¶
Deserialize a model from the given file path.
- Parameters:
file_path – Full path to the serialized model
- Returns:
The deserialized model object
- abstractmethod write(model: Any, file_path: str) None¶
Serialize a model to the given file path.
- Parameters:
model – The model object to serialize
file_path – Full path where the model should be saved
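As an illustration of this interface, a custom serde that stores a model's parameters as JSON might look like the following. This is a standalone sketch so it runs without Snowflake installed; in real use the class would subclass ModelSerde from snowflake.ml.modeling.distributors.many_model, whose abstract members the three below mirror.

```python
import json

class JsonSerde:
    """Illustrative serializer that stores a dict-like model as JSON.

    In real use: subclass ModelSerde and keep these same three members.
    """

    @property
    def filename(self) -> str:
        # Name used for the artifact saved in each partition's directory.
        return "model.json"

    def write(self, model, file_path: str) -> None:
        # Serialize the model object to the given path.
        with open(file_path, "w") as f:
            json.dump(model, f)

    def read(self, file_path: str):
        # Deserialize and return the model object from the given path.
        with open(file_path) as f:
            return json.load(f)
```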
- class snowflake.ml.modeling.distributors.many_model.PickleSerde(filename: str = 'model.pkl')¶
Bases: ModelSerde
Serializer for scikit-learn, XGBoost, and other pickle-compatible models.
- Parameters:
filename (str) – The filename for the serialized model. Defaults to “model.pkl”.
- __init__(filename: str = 'model.pkl')¶
- class snowflake.ml.modeling.distributors.many_model.TorchSerde(filename: str = 'model.pt', weights_only: bool = False)¶
Bases: ModelSerde
Serializer for PyTorch models.
- Parameters:
filename (str) – The filename for the serialized model. Defaults to “model.pt”.
weights_only (bool) – If True, saves only the model’s weights. Defaults to False.
- __init__(filename: str = 'model.pt', weights_only: bool = False)¶
- class snowflake.ml.modeling.distributors.many_model.TensorFlowSerde(filename: str = 'model.keras')¶
Bases: ModelSerde
Serializer for TensorFlow/Keras models.
- Parameters:
filename (str) – The filename for the serialized model. Defaults to “model.keras”.
- __init__(filename: str = 'model.keras')¶