Distributed Partition Function (DPF)

Classes

class snowflake.ml.modeling.distributors.distributed_partition_function.dpf.DPF(func: Callable[[DataConnector, PartitionContext], None], stage_name: str)

Bases: object

Distributed Partition Function (DPF) for executing user-defined functions across data partitions.

DPF enables distributed processing by partitioning a Snowpark DataFrame and executing a user-provided function on each partition concurrently using Ray’s distributed computing framework. Results and artifacts are automatically stored in Snowflake stages for easy retrieval.

Key features:
  • Automatic data partitioning and distribution

  • Concurrent execution across Ray cluster

  • Built-in artifact management and persistence

  • Progress monitoring

  • Error handling and partial failure recovery

Typical workflow:
  1. Define your data processing function

  2. Create DPF instance with function and stage

  3. Execute distributed processing on partitioned data

  4. Monitor progress until completion

  5. Retrieve results from each partition

  6. Optional: Restore completed runs later

Example

Complete end-to-end workflow for distributed analytics:

1. Data Structure

sales_data is a Snowpark DataFrame with columns:
  • region: ‘North’, ‘South’, ‘East’, ‘West’

  • customer_id: unique customer identifiers

  • amount: transaction amounts

  • order_date: transaction dates

2. Define Processing Function

The data will be partitioned by region, with each partition containing all sales records for that specific region. The processing function receives this regional data subset and performs analytics on it:

>>> def process_sales_data(data_connector, context):
...     df = data_connector.to_pandas()
...     print(f"Processing {len(df)} records for region: {context.partition_id}")
...
...     # Perform region-specific analytics
...     summary = {
...         'region': context.partition_id,
...         'total_sales': df['amount'].sum(),
...         'avg_order_value': df['amount'].mean(),
...         'customer_count': df['customer_id'].nunique(),
...         'record_count': len(df)
...     }
...
...     # Save results to stage for later retrieval
...     import json
...     context.upload_to_stage(summary, "sales_summary.json",
...         write_function=lambda obj, path: json.dump(obj, open(path, 'w')))

3. Execute Distributed Processing

>>> dpf = DPF(process_sales_data, "analytics_stage")
>>> run = dpf.run(
...     partition_by="region",  # Creates separate partitions for North, South, East, West
...     snowpark_dataframe=sales_data,
...     run_id="regional_analytics_2024"
... )

4. Monitor Progress and Wait for Completion

>>> final_status = run.wait()  # Shows progress bar by default
>>> print(f"Job completed with status: {final_status}")

5. Retrieve Results from Each Partition

>>> if final_status == RunStatus.SUCCESS:
...     # Get results from each region
...     import json
...     all_results = []
...     for partition_id, details in run.partition_details.items():
...         summary = details.stage_artifacts_manager.get("sales_summary.json",
...             read_function=lambda path: json.load(open(path, 'r')))
...         all_results.append(summary)
...
...     # Combine results across all regions
...     total_sales = sum(r['total_sales'] for r in all_results)
...     total_customers = sum(r['customer_count'] for r in all_results)
... else:
...     # Handle failures - check logs for failed partitions
...     for partition_id, details in run.partition_details.items():
...         if details.status != "DONE":
...             error_logs = details.logs

6. Optional: Restore Results Later

>>> # Restore completed run from stage and access same results as above without re-running.
>>> restored_run = DPFRun.restore_from("regional_analytics_2024", "analytics_stage")

See also

ManyModelTraining:

Specialized DPF for distributed model training with automatic model serialization

ManyModelInference:

Specialized DPF for distributed model inference with automatic model loading

Initializes a Distributed Partition Function (DPF).

Parameters:
  • func (Callable[[DataConnector, PartitionContext], None]) – A user-provided function to be executed. It accepts a DataConnector object and PartitionContext for a given partition.

  • stage_name (str) –

    The stage name (either fully qualified or not) used to store or retrieve training data.

    Examples:

    my_stage  # Unqualified stage name (uses the current session's database and schema)
    my_db.my_schema.my_stage  # Fully qualified stage name (specifies database, schema, and stage)
    

run(*, partition_by: str, snowpark_dataframe: DataFrame, run_id: str, on_existing_artifacts: str = 'error', execution_options: ExecutionOptions | None = None) → DPFRun

Executes the user-defined function on partitions of the Snowpark DataFrame in a distributed manner.

Partitions the DataFrame by the specified column and executes the function on each partition concurrently using Ray’s distributed computing framework.

Parameters:
  • partition_by (str) – Column name to partition the DataFrame by. Each unique value creates a separate partition processed independently.

  • snowpark_dataframe (snowpark.DataFrame) – DataFrame to be partitioned and processed. Must contain a single query without post-actions.

  • run_id (str) –

    Unique identifier for this execution run. This ID serves multiple purposes:

    • Creates a dedicated directory {stage_name}/{run_id}/ to organize all run artifacts

    • Enables tracking and management of execution state across the distributed system

    Best practices: Use descriptive names like experiment_2024_01_15 or model_v1_retrain

  • on_existing_artifacts (str, optional) –

    Action for existing artifacts. Defaults to “error”.

    • "error": Raises error if {stage_name}/{run_id}/ artifacts exist

    • "overwrite": Replaces existing {stage_name}/{run_id}/ artifacts

  • execution_options (ExecutionOptions, optional) –

    Configuration for distributed execution. Uses defaults if not provided. Available options:

    • use_head_node (bool): Whether the Ray head node participates in workload execution (True, default) or acts solely as coordinator (False)

    • loading_wh (str): Optional dedicated warehouse for loading data from Snowflake tables into the stage to accelerate data loading

Returns:

Run management object for monitoring status, canceling jobs, and accessing results.

Return type:

DPFRun
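
For illustration, a minimal sketch of reusing a run_id (assuming the dpf and sales_data objects from the example above): the first call relies on the default on_existing_artifacts="error", and a later call with the same run_id replaces the earlier artifacts by passing "overwrite".

>>> # First execution; raises if {stage_name}/{run_id}/ already holds artifacts (default "error").
>>> run = dpf.run(
...     partition_by="region",
...     snowpark_dataframe=sales_data,
...     run_id="regional_analytics_2024",
... )
>>> # Re-run later and replace the previous artifacts stored under this run_id.
>>> run = dpf.run(
...     partition_by="region",
...     snowpark_dataframe=sales_data,
...     run_id="regional_analytics_2024",
...     on_existing_artifacts="overwrite",
... )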

class snowflake.ml.modeling.distributors.distributed_partition_function.dpf_run.DPFRun(run_id: str, *, orchestrator_ref: ObjectRef, run_object_ref: ObjectRef)

Bases: object

A handle to a single execution of a Distributed Partition Function (DPF) run.

Note

Users do not create DPFRun instances directly. They are returned by DPF.run() or DPFRun.restore_from().

This object provides methods to monitor execution progress, manage the running job, and retrieve results from the distributed execution across partitions.

Key capabilities:
  • Monitor execution status and progress

  • Cancel running jobs

  • Access partition-level details and artifacts

  • Wait for completion with optional progress display

  • Restore completed runs from persistent storage

Example

>>> dpf = DPF(process_sales_data, "analytics_stage")
>>> run = dpf.run(partition_by="region", snowpark_dataframe=sales_data, run_id="analytics_2024")
>>> final_status = run.wait()  # Wait with progress bar
>>> if final_status == RunStatus.SUCCESS:
...     for partition_id, details in run.partition_details.items():
...         summary = details.stage_artifacts_manager.get("sales_summary.json")
cancel() → None

Cancels the ongoing distributed execution.

Note

  • Partitions that have already completed may not be affected.

  • Partial results, logs, or other artifacts might remain after cancellation.
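
A sketch of one way cancel() might be used, assuming the dpf and sales_data objects from the examples above (the run_id and timeout are illustrative):

>>> import time
>>> run = dpf.run(partition_by="region", snowpark_dataframe=sales_data, run_id="cancellable_run")
>>> time.sleep(300)  # allow the run five minutes before deciding
>>> if run.status == RunStatus.IN_PROGRESS:
...     run.cancel()
...     # Partitions that finished before cancellation may still have usable artifacts.
...     done = [pid for pid, d in run.partition_details.items() if d.status == PartitionStatus.DONE]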

get_progress() → Dict[str, List[SinglePartitionDetails]]

Returns the current progress of the distributed run.

Returns:

A dictionary grouping partition details by their current status.
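
For example, a lightweight progress report can be printed from the returned mapping; this sketch assumes the dictionary keys are the status labels used for grouping:

>>> progress = run.get_progress()
>>> for status, partitions in progress.items():
...     print(f"{status}: {len(partitions)} partition(s) -> {[p.partition_id for p in partitions]}")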

property partition_details: Dict[str, SinglePartitionDetails]

Provides details for each partition processed.

Returns:

A dictionary mapping partition IDs to their respective execution details.

classmethod restore_from(run_id: str, stage_name: str) → DPFRun

Restore a completed DPF run from its persisted state.

Parameters:
  • run_id (str) – The run ID of the execution to restore.

  • stage_name (str) – The stage name where the run artifacts are stored.

Returns:

A read-only DPFRun instance with access to the completed run’s results.

Return type:

DPFRun

Raises:

ValueError – If the run state cannot be restored from the specified location.

property run_id: str

The unique identifier for this execution run.

Returns:

The run ID specified when the DPF was executed.

Return type:

str

property status: RunStatus

Retrieves the overall status of the distributed execution.

This property aggregates the status across all partitions to provide a summary of the entire run.

Returns:

The current status of the execution.

wait(show_progress: bool = True) → RunStatus

Wait for the DPF run to complete.

Parameters:

show_progress (bool, optional) – Whether to display a progress bar during execution. Defaults to True. Uses Streamlit UI if available, otherwise falls back to tqdm.

Returns:

The final status of the run (SUCCESS, FAILURE, PARTIAL_FAILURE, or CANCELLED).

Return type:

RunStatus
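
In non-interactive settings (for example, a scheduled job with no Streamlit UI or terminal), the progress display can be suppressed:

>>> final_status = run.wait(show_progress=False)  # block until completion without a progress bar
>>> print(f"Run finished with status: {final_status}")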

class snowflake.ml.modeling.distributors.distributed_partition_function.entities.ExecutionOptions(use_head_node: bool = True, loading_wh: str | None = None)

Bases: object

Configuration options for workload execution.

use_head_node

This option controls whether the head node participates in the workload execution or solely acts as a coordinator. If True, the Ray head node will execute user-provided functions alongside worker nodes. If False, only worker nodes will execute user-provided functions.

Type:

bool

loading_wh

A dedicated warehouse used for loading data from Snowflake tables into the Snowflake stage. This is typically a larger warehouse that accelerates data loading. Once loading is complete, the session switches back to the original warehouse.

Type:

str | None
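
A sketch of constructing these options and passing them to DPF.run(); the warehouse name is a placeholder, and the import path follows the class signature above (the installed package may expose a shorter alias):

>>> from snowflake.ml.modeling.distributors.distributed_partition_function.entities import ExecutionOptions
>>> options = ExecutionOptions(
...     use_head_node=False,            # head node coordinates only; workers execute the user function
...     loading_wh="LARGE_LOADING_WH",  # placeholder name for a larger, dedicated loading warehouse
... )
>>> run = dpf.run(
...     partition_by="region",
...     snowpark_dataframe=sales_data,
...     run_id="regional_analytics_2024_v2",
...     execution_options=options,
... )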

class snowflake.ml.modeling.distributors.distributed_partition_function.entities.SinglePartitionDetails(*, partition_id: str, status: PartitionStatus = PartitionStatus.PENDING, stage_path: str | None = None)

Bases: object

A class that encapsulates the details of a single partition job, including its status, saved artifacts (such as a trained model), and associated logs. It manages loading of these artifacts and logs from a specified stage path, and provides safeguards in case of errors during these operations.

partition_id

The unique identifier for the partition of the job.

Type:

str

stage_path

The stage path where the partition's artifacts (for example, a trained model) and logs are stored.

Type:

str

status

The current status of the partition (default is PENDING).

Type:

PartitionStatus

Raises:

RuntimeError – If loading the model or logs from the stage path fails due to system errors.

property logs: str
property partition_id: str
property stage_artifacts_manager: StageArtifactsManager

Returns a StageArtifactsManager instance for interacting with this partition’s saved artifacts.

property status: PartitionStatus
class snowflake.ml.modeling.distributors.distributed_partition_function.entities.RunStatus(value)

Bases: Enum

Enum representing the status of an overall distributed run.

SUCCESS

The run completed successfully.

Type:

str

IN_PROGRESS

The run is actively executing.

Type:

str

PARTIAL_FAILURE

The run completed, but some partitions failed.

Type:

str

FAILURE

The run failed completely.

Type:

str

CANCELLED

The run was cancelled.

Type:

str
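
For instance, a PARTIAL_FAILURE can be followed up by checking partition-level status (a sketch, assuming a run object as in the examples above):

>>> final_status = run.wait()
>>> if final_status == RunStatus.PARTIAL_FAILURE:
...     failed = [pid for pid, d in run.partition_details.items() if d.status != PartitionStatus.DONE]
...     print(f"Partitions to investigate or re-run: {failed}")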

class snowflake.ml.modeling.distributors.distributed_partition_function.entities.PartitionStatus(value)

Bases: str, Enum

Enum representing the possible status of a partition execution.

PENDING

The partition is waiting to be processed.

Type:

str

RUNNING

The partition is currently being processed.

Type:

str

FAILED

The partition has failed.

Type:

str

DONE

The partition has successfully completed.

Type:

str

CANCELLED

The partition was cancelled. This happens when the user explicitly invokes cancel().

Type:

str

INTERNAL_ERROR

The partition encountered an internal error during processing.

Type:

str
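
These statuses are mainly useful for triage after a run; for example, collecting logs from partitions that did not complete (a sketch using the partition details documented above):

>>> for pid, details in run.partition_details.items():
...     if details.status in (PartitionStatus.FAILED, PartitionStatus.INTERNAL_ERROR):
...         print(f"--- logs for partition {pid} ---")
...         print(details.logs)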

class snowflake.ml.modeling.distributors.distributed_partition_function.partition_context.PartitionContext(*, session: Session, partition_id: str, stage_path_prefix: str)

Bases: object

Provides context and utilities for partition-specific operations within DPF functions.

This class is automatically passed as the second argument to user-defined DPF functions, providing access to partition metadata and artifact management capabilities.

Key capabilities:
  • Access to partition identifier and Snowflake session

  • Upload artifacts to the partition’s dedicated stage directory

  • Automatic handling of temporary files and serialization

Note

Users don’t create PartitionContext instances directly. They are provided automatically by the DPF execution framework.

Example

>>> def my_dpf_function(data_connector, context):
...     # Access partition information
...     print(f"Processing partition: {context.partition_id}")
...
...     # Process data
...     df = data_connector.to_pandas()
...     results = perform_analysis(df)
...
...     # Save results to stage
...     context.upload_to_stage(results, "analysis_results.pkl")
property partition_id: str
property session: Session
property stage_path_prefix: str
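
The properties can be read inside the user-defined function; in this sketch, inspect_context is an illustrative function name and the printed stage layout is whatever the framework configured (not a guaranteed format):

>>> def inspect_context(data_connector, context):
...     # Partition identity and the stage prefix under which upload_to_stage() writes artifacts.
...     print(context.partition_id, context.stage_path_prefix)
...     # The active Snowpark session is available for auxiliary work, e.g. checking the warehouse.
...     print(context.session.get_current_warehouse())
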
upload_to_stage(obj: Any, filename: str, *, write_function: Callable[[Any, str], None] | None = None)

Saves a Python object to a file and uploads it to the partition’s stage directory.

This method handles the complete workflow of serializing objects and uploading them to Snowflake stages, with automatic temporary file management.

How it works:
  • If write_function is NOT provided (default): Uses Python’s pickle module

  • If write_function IS provided: Uses your custom serialization logic

Parameters:
  • obj – The Python object to save.

  • filename (str) – The destination filename for the artifact (e.g., “model.pkl”).

  • write_function (Optional[Callable[[Any, str], None]]) – Custom serialization function that takes the object and local file path as arguments.

Example

Common usage patterns within a DPF function:

>>> def my_analytics_function(data_connector, context):
...     df = data_connector.to_pandas()
...
...     # Save pickle objects (default behavior)
...     results = {'total': df['amount'].sum(), 'count': len(df)}
...     context.upload_to_stage(results, "summary.pkl")
...
...     # Save JSON data
...     import json
...     context.upload_to_stage(
...         results, "summary.json",
...         write_function=lambda obj, path: json.dump(obj, open(path, 'w'))
...     )
...
...     # Save CSV file
...     df_processed = df.groupby('category').sum()
...     context.upload_to_stage(
...         df_processed, "aggregated.csv",
...         write_function=lambda df, path: df.to_csv(path, index=False)
...     )

Note

The method automatically handles temporary file creation and cleanup. Files are uploaded with auto_compress=False and overwrite=True.

class snowflake.ml.modeling.distributors.distributed_partition_function.stage_artifacts_manager.StageArtifactsManager(session, stage_path_prefix: str)

Bases: object

Provides a user-friendly interface to list, download, and load artifacts that were saved to a specific stage path for a given partition.

This class handles the common workflow of retrieving artifacts from Snowflake stages, with automatic deserialization for Python objects and flexible handling of different file types.

Example

Basic usage after a DPF run completes:

>>> # Get the stage artifacts manager for a specific partition
>>> manager = run.partition_details["North"].stage_artifacts_manager
>>> # List all available artifacts
>>> files = manager.list()
>>> print(files)  # ['model.pkl', 'metrics.json', 'predictions.csv']
>>> # Download and deserialize a pickle file (default behavior)
>>> model = manager.get("model.pkl")
>>> # Download and load JSON with custom deserializer
>>> import json
>>> metrics = manager.get("metrics.json", read_function=lambda path: json.load(open(path, 'r')))
>>> # Download raw file to specific local directory
>>> import tempfile
>>> local_dir = tempfile.mkdtemp()
>>> csv_path = manager.download("predictions.csv", local_dir)
>>> print(f"Downloaded to: {csv_path}")
>>> # Load a text file with a custom deserializer
>>> config = manager.get("config.txt", read_function=lambda path: open(path, 'r').read())
download(remote_filename: str, local_destination_dir: str) → str

Downloads a single raw artifact file to a specified local directory.

Parameters:
  • remote_filename – The name of the file to download (e.g., “model.pkl”).

  • local_destination_dir – The local directory to save the file in.

Returns:

The full local path to the downloaded file.

get(filename: str, *, read_function: Callable[[str], Any] | None = None) → Any

Downloads an artifact to a temporary location and deserializes it into a Python object.

How it works:

  • If ‘read_function’ is NOT provided (default): The file is assumed to be a pickle file and will be loaded with pickle.load.

  • If ‘read_function’ IS provided: That function will be called to load the object from the downloaded file path.

Parameters:
  • filename – The name of the artifact to get (e.g., “model.pkl”).

  • read_function – Optional. A function that takes one argument (the local file path) and returns the deserialized Python object. Example: lambda path: json.load(open(path, ‘r’))

Returns:

The deserialized Python object.

list() → List[str]

Lists the filenames of all artifacts in the partition’s stage directory.