Distributed training¶
The Snowflake Container Runtime provides a flexible training environment that you can use to train models on Snowflake’s infrastructure. You can use open source packages, or use Snowflake ML distributed trainers for multi-node and multi-device training.
Distributed trainers automatically scale your machine learning workloads across multiple nodes and GPUs. Snowflake distributors intelligently manage cluster resources without requiring complex configuration, making distributed training accessible and efficient.
Use standard open source libraries when you:

- Work with small datasets on single-node environments
- Rapidly prototype and experiment with models
- Lift and shift existing workflows without distributed requirements
Use Snowflake distributed trainers to:

- Train models on datasets that are larger than the memory of a single compute node
- Use multiple GPUs efficiently
- Automatically leverage all available compute in multi-node ML Jobs or scaled notebook clusters
Snowflake ML distributed training¶
Snowflake ML provides distributed trainers for popular machine learning frameworks, including XGBoost, LightGBM, and PyTorch. These trainers are optimized to run on Snowflake’s infrastructure and can automatically scale across multiple nodes and GPUs.
- Automatic resource management - Snowflake automatically discovers and uses all available cluster resources
- Simplified setup - the Container Runtime environment is backed by a Ray cluster provided by Snowflake, with no user configuration required
- Seamless Snowflake integration - direct compatibility with Snowflake data connectors and stages
- Optional scaling configuration - advanced users can fine-tune resource allocation when needed
Data loading¶
For both open source and Snowflake distributed trainers, the most performant way to ingest data is with the Snowflake Data Connector:
Training methods¶
Open source training¶
Use standard open source libraries when you need maximum flexibility and control over your training process. With open source training, you directly use popular ML frameworks like XGBoost, LightGBM, and PyTorch with minimal modifications, while still benefiting from Snowflake’s infrastructure and data connectivity.
The following examples train a model with XGBoost and LightGBM.
To train with open source XGBoost, load the data with the data connector, convert it to a pandas DataFrame, and use the XGBoost library directly:
Distributed training¶
The distributed XGBEstimator class has a similar API, with a few key differences:

- The XGBoost training parameters are passed to the XGBEstimator during class initialization through the `params` parameter.
- The DataConnector object can be passed directly into the estimator's `fit` function, along with the input columns defining the features and the label column defining the target.
- You can provide a scaling configuration when instantiating the XGBEstimator class. However, Snowflake defaults to using all available resources.
Evaluating the model¶
You can evaluate models by passing an eval_set and using verbose_eval to print the evaluation results to the console. Inference can then be done as a second step. The distributed estimator offers a predict method for convenience, but it does not run inference in a distributed fashion. We recommend converting the fitted model into an open source XGBoost estimator after training to run inference and to log the model to the model registry.
Registering the model¶
To register the model to the Snowflake model registry, use the open source booster provided by estimator.get_booster and returned from estimator.fit. For more information, see XGBoost.
PyTorch¶
The Snowflake PyTorch Distributor natively supports Distributed Data Parallel (DDP) models on the Snowflake backend. To use DDP on Snowflake, use open source PyTorch modules with a few Snowflake-specific modifications:

- Load data using the ShardedDataConnector to automatically shard data into the number of partitions that matches the `world_size` of the distributed trainer. Call `get_shard` within a Snowflake training context to retrieve the shard associated with that worker process.
- Inside the training function, use the `context` object to get process-specific information such as rank, local rank, and the data required for training.
- Save the model using the context's `get_model_dir` to find the location to store the model. This stores the model locally for single-node training, and syncs the model to a Snowflake stage for distributed training. If no stage location is provided, your user stage is used by default.
Load data¶
Train model¶
Retrieving the model¶
If you are using multi-node DDP, the model is automatically synchronized to a Snowflake stage as the shared persistent storage.
The following code gets the model from a stage. It uses the artifact_stage_location parameter to specify the location of the stage that stores the model artifact. The stage_location variable holds the location of the model in the stage after training completes. The model artifact is saved under "DB_NAME.SCHEMA_NAME.STAGE_NAME/model/{request_id}".