PyTorch Distributor

Classes

class snowflake.ml.modeling.distributors.pytorch.pytorch_trainer.PyTorchDistributor(train_func: Callable, scaling_config: PyTorchScalingConfig | None = None)

Bases: BaseRunnable

PyTorchDistributor enables users to run distributed training with PyTorch on a ContainerRuntime cluster.

PyTorchDistributor is responsible for setting up the environment, scheduling the training processes, managing the communication between the processes, and collecting the results.

Initialize the PyTorchDistributor.

Parameters:
  • train_func (Callable) – A callable object that defines the training logic to be executed.

  • scaling_config (PyTorchScalingConfig) – Configuration for scaling and other settings related to the training job.

run(*, dataset_map: Dict[str, DataConnector] | None = None, hyper_params: Dict[str, str] | None = None, artifact_stage_location: str | None = None) PyTorchDistributorResult

Runs the distributed training job.

Parameters:
  • dataset_map (Optional[Dict[str, DataConnector]], optional) – A mapping of dataset names to their corresponding DataConnector instances. Defaults to None.

  • hyper_params (Optional[Dict[str, str]], optional) – A dictionary of hyperparameters to be used during training. Defaults to None.

  • artifact_stage_location (Optional[str], optional) – Stage location where training artifacts such as checkpoints are persisted. When provided, an ArtifactManager is available via Context.get_artifact_manager(). Defaults to None.

Returns:

The result of the training process, which may include the trained model or other outputs.

Return type:

PyTorchDistributorResult

Raises:

RuntimeError – If there is an error during the execution of the training process.
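
For illustration, a minimal usage sketch is shown below. The import paths follow the class signatures on this page; the training function body, hyperparameter names, and resource values are placeholders rather than recommendations.

    from snowflake.ml.modeling.distributors.pytorch.pytorch_trainer import (
        PyTorchDistributor,
        PyTorchScalingConfig,
    )
    from snowflake.ml.modeling.distributors.pytorch.scaling_config import WorkerResourceConfig

    def train_func():
        # Training logic goes here; use get_context() (documented below) to read
        # ranks, datasets, and hyperparameters inside this function.
        ...

    distributor = PyTorchDistributor(
        train_func=train_func,
        scaling_config=PyTorchScalingConfig(
            num_nodes=1,
            num_workers_per_node=1,
            resource_requirements_per_worker=WorkerResourceConfig(num_cpus=4, num_gpus=0),
        ),
    )

    # dataset_map, hyper_params, and artifact_stage_location are optional keyword-only arguments.
    result = distributor.run(hyper_params={"learning_rate": "1e-3"})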

class snowflake.ml.modeling.distributors.pytorch.pytorch_trainer.PyTorchScalingConfig(num_nodes: int, num_workers_per_node: int, resource_requirements_per_worker: WorkerResourceConfig)

Bases: BaseScalingConfig

Scaling configuration for training PyTorch models.

This class defines the scaling configuration for a PyTorch training job, including the number of nodes, the number of workers per node, and the resource requirements for each worker.

num_nodes

The number of nodes to use for training.

Type:

int

num_workers_per_node

The number of workers to use per node.

Type:

int

resource_requirements_per_worker

The resource requirements for each worker, such as the number of CPUs and GPUs.

Type:

WorkerResourceConfig

Initialize the PyTorchScalingConfig.

Parameters:
  • num_nodes (int) – The number of nodes to use for training.

  • num_workers_per_node (int) – The number of workers to use per node.

  • resource_requirements_per_worker (WorkerResourceConfig) – The resource requirements for each worker, such as the number of CPUs and GPUs.

class snowflake.ml.modeling.distributors.pytorch.scaling_config.WorkerResourceConfig(num_cpus: int = 1, num_gpus: int = 0)

Bases: object

Resource requirements per worker.

This class defines the resource requirements for each worker in a distributed training job, specifying the number of CPU and GPU resources to allocate.

num_cpus

The number of CPU cores to reserve for each worker.

Type:

int

num_gpus

The number of GPUs to reserve for each worker. Default is 0, indicating no GPUs are reserved.

Type:

int

Initialize the WorkerResourceConfig.

Parameters:
  • num_cpus (int) – The number of CPU cores to reserve for each worker.

  • num_gpus (int) – The number of GPUs to reserve for each worker. Default is 0, indicating no GPUs are reserved.
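
As a sketch, a multi-node GPU configuration might look like the following; the specific values are illustrative and depend on the resources your cluster actually provides.

    from snowflake.ml.modeling.distributors.pytorch.pytorch_trainer import PyTorchScalingConfig
    from snowflake.ml.modeling.distributors.pytorch.scaling_config import WorkerResourceConfig

    # Two nodes with four workers each; one GPU per worker is a common mapping
    # when each worker drives a single training process.
    scaling_config = PyTorchScalingConfig(
        num_nodes=2,
        num_workers_per_node=4,
        resource_requirements_per_worker=WorkerResourceConfig(num_cpus=2, num_gpus=1),
    )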

class snowflake.ml.modeling.distributors.pytorch.context.Context(*args, **kwargs)

Bases: Protocol

Context for setting up the PyTorch distributed environment for training scripts.

Context defines the necessary methods to manage and retrieve information about the distributed training environment, including worker and node ranks, world size, and backend configurations.

Definitions:

Node: A physical instance or a container.
Worker: A worker process in the context of distributed training.
WorkerGroup: The set of workers that execute the same function (e.g., trainers).
LocalWorkerGroup: A subset of the workers in the worker group running on the same node.
RANK: The rank of the worker within a worker group.
WORLD_SIZE: The total number of workers in a worker group.
LOCAL_RANK: The rank of the worker within a local worker group.
LOCAL_WORLD_SIZE: The size of the local worker group.
rdzv_id: An ID that uniquely identifies the worker group for a job. This ID is used by each node to join as a member of a particular worker group.
rdzv_backend: The backend of the rendezvous (e.g., c10d). This is typically a strongly consistent key-value store.
rdzv_endpoint: The rendezvous backend endpoint; usually in the form <host>:<port>.

get_artifact_manager() ArtifactManager | None

Return the artifact manager object, which helps sync checkpoint files to shared persistent storage. A stage must be specified via the artifact_stage_location argument to the PyTorchDistributor.run() method.

Returns:

An instance of the ArtifactManager if the artifact_stage_location argument is provided to PyTorchDistributor.run(). Otherwise, returns None.

Return type:

Optional[ArtifactManager]

get_dataset_map() Dict[str, Type[DataConnector]] | None

Return the dataset map provided to the trainer.run(…) method.

Returns:

A dictionary mapping dataset names to their DataConnector types.

Return type:

Optional[Dict[str, Type[DataConnector]]]
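
A sketch of reading the dataset map inside a training function is shown below. The key "train" is a hypothetical name chosen by whoever calls run(), and the to_torch_dataset() call is an assumption about the DataConnector interface that may vary by snowflake-ml-python version.

    from snowflake.ml.modeling.distributors.pytorch.context import get_context

    def train_func():
        ctx = get_context()
        dataset_map = ctx.get_dataset_map() or {}
        train_connector = dataset_map.get("train")  # "train" is a caller-chosen key
        if train_connector is not None:
            # Assumed conversion helper; check your DataConnector version for the exact API.
            train_dataset = train_connector.to_torch_dataset()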

get_default_backend() str

Return default backend selected by MCE.

Returns:

The default backend being used for distributed training.

Return type:

str

get_hyper_params() Dict[str, str] | None

Return the hyperparameter map provided to the trainer.run(…) method.

Returns:

A dictionary mapping hyperparameter names to their values.

Return type:

Optional[Dict[str, str]]
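
Because hyperparameter values arrive as strings, numeric settings need to be parsed inside the training function, as in the sketch below. The key names ("learning_rate", "epochs") are hypothetical and defined entirely by the caller of run().

    from snowflake.ml.modeling.distributors.pytorch.context import get_context

    def train_func():
        ctx = get_context()
        hp = ctx.get_hyper_params() or {}
        learning_rate = float(hp.get("learning_rate", "1e-3"))  # values are strings; parse them
        epochs = int(hp.get("epochs", "1"))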

get_local_rank() int

Return the local rank for the current worker.

Local rank is a unique local ID for a worker (or process) running on the current node.

For example, if training is running on 2 nodes (servers) each with 4 GPUs, then the local ranks of the workers (or processes) running on node 0 will be [0, 1, 2, 3], and similarly the four workers (or processes) running on node 1 will have local_rank values [0, 1, 2, 3].

Returns:

The local rank of the current process.

Return type:

int
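
A common use of the local rank is pinning each worker to its own GPU, as in the sketch below. This assumes one GPU per worker so that local ranks map one-to-one onto CUDA devices; the model is a placeholder.

    import torch
    from snowflake.ml.modeling.distributors.pytorch.context import get_context

    def train_func():
        ctx = get_context()
        local_rank = ctx.get_local_rank()
        device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
        model = torch.nn.Linear(10, 1).to(device)  # placeholder model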

get_local_world_size() int

Return the number of workers running on the current node.

For example, if training is running on 2 nodes (servers) each with 4 GPUs, then local_world_size will be 4 for all processes on both nodes.

Returns:

The number of workers on the current node.

Return type:

int

get_master_addr() str

Return IP address of the master node.

This is typically the address of the node with node_rank 0.

Returns:

The IP address of the master node.

Return type:

str

get_master_port() int

Return port on master_addr that hosts the rendezvous server.

Returns:

The port number for rendezvous communication.

Return type:

int
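
Taken together, the master address, master port, rank, world size, and default backend are enough to initialize a torch.distributed process group manually, as sketched below. Depending on how the runtime launches workers, the process group may already be configured for you, so treat this as an illustration rather than a required step.

    import torch.distributed as dist
    from snowflake.ml.modeling.distributors.pytorch.context import get_context

    def train_func():
        ctx = get_context()
        dist.init_process_group(
            backend=ctx.get_default_backend(),
            init_method=f"tcp://{ctx.get_master_addr()}:{ctx.get_master_port()}",
            rank=ctx.get_rank(),
            world_size=ctx.get_world_size(),
        )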

get_metrics_reporter() MetricsReporter

Return the metrics reporter object for logging training metrics. Only metrics reported by rank 0 are returned by the trainer.run() method.

Returns:

An instance of the MetricsReporter interface.

Return type:

MetricsReporter

get_model_dir() str

Return the path to a directory where the model should be saved. All the model artifacts written to this directory will be preserved after the training job.

Returns:

The path to the model directory.

Return type:

str
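
A sketch of persisting final weights to the model directory is shown below; writing only from global rank 0 avoids duplicate writes, and the file name is arbitrary.

    import os
    import torch
    from snowflake.ml.modeling.distributors.pytorch.context import get_context

    def save_model(model: torch.nn.Module) -> None:
        ctx = get_context()
        if ctx.get_rank() == 0:  # write once, from the rank-0 worker
            torch.save(model.state_dict(), os.path.join(ctx.get_model_dir(), "model.pt"))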

get_node_rank() int

Return the rank of the current node across all nodes.

Node rank is a unique ID given to each node to identify it uniquely across all nodes in the world.

For example, if training is running on 2 nodes (servers) each with 4 GPUs, then node ranks will be [0, 1] respectively.

Returns:

The rank of the current node.

Return type:

int

get_rank() int

Return the rank of the current process across all processes.

Rank is the unique ID given to a process to identify it uniquely across the world. It should be a number between 0 and world_size - 1. Some frameworks also call it world_rank, to distinguish it from local_rank. For example, if training is running on 2 nodes (servers) each with 4 GPUs, then the ranks will be [0, 1, 2, 3, 4, 5, 6, 7], i.e., from 0 to world_size - 1.

Returns:

The rank of the current process.

Return type:

int

get_supported_backends() List[str]

Return the list of backends supported by MCE.

Returns:

A list containing names of supported backends.

Return type:

List[str]
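
A small sketch of choosing a backend: prefer NCCL when the runtime supports it, otherwise fall back to the default. The preference for "nccl" is an example policy, not a requirement of the API.

    from snowflake.ml.modeling.distributors.pytorch.context import get_context

    def pick_backend() -> str:
        ctx = get_context()
        supported = ctx.get_supported_backends()
        return "nccl" if "nccl" in supported else ctx.get_default_backend()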

get_world_size() int

Return the number of workers (or processes) participating in the job.

For example, if training is running on 2 nodes (servers) each with 4 GPUs, then the world size is 8 (2 nodes * 4 GPUs per node). Usually, each GPU corresponds to a training process.

Returns:

The total number of workers in the job.

Return type:

int
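
Rank and world size are also what PyTorch's DistributedSampler needs to shard a dataset across workers, as in the sketch below; the tensor dataset is placeholder data.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler
    from snowflake.ml.modeling.distributors.pytorch.context import get_context

    def build_loader() -> DataLoader:
        ctx = get_context()
        dataset = TensorDataset(torch.randn(1024, 10), torch.randn(1024, 1))  # placeholder data
        sampler = DistributedSampler(
            dataset,
            num_replicas=ctx.get_world_size(),
            rank=ctx.get_rank(),
            shuffle=True,
        )
        return DataLoader(dataset, batch_size=32, sampler=sampler)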

Functions

snowflake.ml.modeling.distributors.pytorch.context.get_context() Context

Fetches the context object that contains worker-specific runtime information.

Returns:

An instance of the Context interface that provides methods for managing the distributed training environment.

Return type:

Context

Raises:

RuntimeError – If the PyTorch context is not available.
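
As a final sketch, get_context() is meant to be called from inside the train_func executed by PyTorchDistributor; outside that environment it raises RuntimeError as noted above.

    from snowflake.ml.modeling.distributors.pytorch.context import get_context

    def train_func():
        ctx = get_context()
        print(
            f"rank {ctx.get_rank()} of {ctx.get_world_size()} "
            f"(node {ctx.get_node_rank()}, local rank {ctx.get_local_rank()})"
        )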