Distributed Partition Function (DPF)¶
Classes¶
- class snowflake.ml.modeling.distributors.distributed_partition_function.dpf.DPF(func: Callable[[DataConnector, PartitionContext], None], stage_name: str)¶
Bases:
object
Distributed Partition Function (DPF) for executing user-defined functions across data partitions.
DPF enables distributed processing by partitioning a Snowpark DataFrame and executing a user-provided function on each partition concurrently using Ray’s distributed computing framework. Results and artifacts are automatically stored in Snowflake stages for easy retrieval.
- Key features:
Automatic data partitioning and distribution
Concurrent execution across Ray cluster
Built-in artifact management and persistence
Progress monitoring
Error handling and partial failure recovery
- Typical workflow:
Define your data processing function
Create DPF instance with function and stage
Execute distributed processing on partitioned data
Monitor progress until completion
Retrieve results from each partition
Optional: Restore completed runs later
Example
Complete end-to-end workflow for distributed analytics:
1. Data Structure
- sales_data is a Snowpark DataFrame with columns:
region: ‘North’, ‘South’, ‘East’, ‘West’
customer_id: unique customer identifiers
amount: transaction amounts
order_date: transaction dates
2. Define Processing Function
The data will be partitioned by region, with each partition containing all sales records for that specific region. The processing function receives this regional data subset and performs analytics on it:
>>> def process_sales_data(data_connector, context):
...     df = data_connector.to_pandas()
...     print(f"Processing {len(df)} records for region: {context.partition_id}")
...
...     # Perform region-specific analytics
...     summary = {
...         'region': context.partition_id,
...         'total_sales': df['amount'].sum(),
...         'avg_order_value': df['amount'].mean(),
...         'customer_count': df['customer_id'].nunique(),
...         'record_count': len(df)
...     }
...
...     # Save results to stage for later retrieval
...     import json
...     context.upload_to_stage(summary, "sales_summary.json",
...                             write_function=lambda obj, path: json.dump(obj, open(path, 'w')))
3. Execute Distributed Processing
>>> dpf = DPF(process_sales_data, "analytics_stage")
>>> run = dpf.run(
...     partition_by="region",  # Creates separate partitions for North, South, East, West
...     snowpark_dataframe=sales_data,
...     run_id="regional_analytics_2024"
... )
4. Monitor Progress and Wait for Completion
>>> final_status = run.wait()  # Shows progress bar by default
>>> print(f"Job completed with status: {final_status}")
5. Retrieve Results from Each Partition
>>> if final_status == RunStatus.SUCCESS:
...     # Get results from each region
...     import json
...     all_results = []
...     for partition_id, details in run.partition_details.items():
...         summary = details.stage_artifacts_manager.get("sales_summary.json",
...                                                       read_function=lambda path: json.load(open(path, 'r')))
...         all_results.append(summary)
...
...     # Combine results across all regions
...     total_sales = sum(r['total_sales'] for r in all_results)
...     total_customers = sum(r['customer_count'] for r in all_results)
... else:
...     # Handle failures - check logs for failed partitions
...     for partition_id, details in run.partition_details.items():
...         if details.status != "DONE":
...             error_logs = details.logs
6. Optional: Restore Results Later
>>> # Restore the completed run from the stage and access the same results as above without re-running.
>>> restored_run = DPFRun.restore_from("regional_analytics_2024", "analytics_stage")
See also
ManyModelTraining
Specialized DPF for distributed model training with automatic model serialization
ManyModelInference
Specialized DPF for distributed model inference with automatic model loading
Initializes the Distributed Partition Function (DPF).
- Parameters:
func (Callable[[DataConnector, PartitionContext], None]) – A user-provided function to be executed. It accepts a DataConnector object and PartitionContext for a given partition.
stage_name (str) –
The stage name (either fully qualified or not) used to store or retrieve training data.
Examples:
my_stage                  # Unqualified stage name (uses the current session's database and schema)
my_db.my_schema.my_stage  # Fully qualified stage name (specifies database, schema, and stage)
- run(*, partition_by: str, snowpark_dataframe: DataFrame, run_id: str, on_existing_artifacts: str = 'error', execution_options: ExecutionOptions | None = None) → DPFRun¶
Executes the user-defined function on partitions of the Snowpark DataFrame in a distributed manner.
Partitions the DataFrame by the specified column and executes the function on each partition concurrently using Ray’s distributed computing framework.
- Parameters:
partition_by (str) – Column name to partition the DataFrame by. Each unique value creates a separate partition processed independently.
snowpark_dataframe (snowpark.DataFrame) – DataFrame to be partitioned and processed. Must contain a single query without post-actions.
run_id (str) –
Unique identifier for this execution run. This ID serves multiple purposes:
Creates a dedicated directory {stage_name}/{run_id}/ to organize all run artifacts
Enables tracking and management of execution state across the distributed system
Best practice: use descriptive names like experiment_2024_01_15 or model_v1_retrain
on_existing_artifacts (str, optional) –
Action to take when artifacts already exist. Defaults to "error".
"error": Raises an error if {stage_name}/{run_id}/ artifacts exist
"overwrite": Replaces existing {stage_name}/{run_id}/ artifacts
execution_options (ExecutionOptions, optional) –
Configuration for distributed execution. Uses defaults if not provided. Available options:
use_head_node (bool): Whether the Ray head node participates in workload execution (True, default) or acts solely as a coordinator (False)
loading_wh (str): Optional dedicated warehouse for loading data from Snowflake tables into the stage to accelerate data loading
- Returns:
Run management object for monitoring status, canceling jobs, and accessing results.
- Return type:
DPFRun
- class snowflake.ml.modeling.distributors.distributed_partition_function.dpf_run.DPFRun(run_id: str, *, orchestrator_ref: ObjectRef, run_object_ref: ObjectRef)¶
Bases:
object
A handle to a single execution of a Distributed Partition Function (DPF) run.
Note
Users do not create DPFRun instances directly. They are returned by DPF.run() or DPFRun.restore_from().
This object provides methods to monitor execution progress, manage the running job, and retrieve results from the distributed execution across partitions.
- Key capabilities:
Monitor execution status and progress
Cancel running jobs
Access partition-level details and artifacts
Wait for completion with optional progress display
Restore completed runs from persistent storage
Example
>>> dpf = DPF(process_sales_data, "analytics_stage")
>>> run = dpf.run(partition_by="region", snowpark_dataframe=sales_data, run_id="analytics_2024")
>>> final_status = run.wait()  # Wait with progress bar
>>> if final_status == RunStatus.SUCCESS:
...     for partition_id, details in run.partition_details.items():
...         summary = details.stage_artifacts_manager.get("sales_summary.json")
- cancel() → None¶
Cancels the ongoing distributed execution.
Note
Partitions that have already completed may not be affected.
Partial results, logs, or other artifacts might remain after cancellation.
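For illustration, a minimal sketch of cancelling a run that is no longer needed; dpf and sales_data are the placeholders from the example above, and the run ID is made up:
>>> run = dpf.run(partition_by="region", snowpark_dataframe=sales_data, run_id="cancel_demo")
>>> run.cancel()   # request cancellation of partitions that have not finished yet
>>> run.status     # already-completed partitions may still report DONE; partial artifacts may remain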
- get_progress() → Dict[str, List[SinglePartitionDetails]]¶
Returns the current progress of the distributed run.
- Returns:
A dictionary grouping partition details by their current status.
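For illustration, a sketch that summarizes progress by state, assuming run from the example above; the exact key strings are assumed here to mirror the PartitionStatus names:
>>> progress = run.get_progress()
>>> for state, partitions in progress.items():
...     print(f"{state}: {len(partitions)} partition(s)")
>>> # Partition IDs still waiting to be processed (key name assumed)
>>> pending_ids = [p.partition_id for p in progress.get("PENDING", [])]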
- property partition_details: Dict[str, SinglePartitionDetails]¶
Provides details for each partition processed.
- Returns:
A dictionary mapping partition IDs to their respective execution details.
- classmethod restore_from(run_id: str, stage_name: str) → DPFRun¶
Restore a completed DPF run from its persisted state.
- Parameters:
run_id (str) – The run ID of the execution to restore.
stage_name (str) – The stage name where the run artifacts are stored.
- Returns:
A read-only DPFRun instance with access to the completed run’s results.
- Return type:
DPFRun
- Raises:
ValueError – If the run state cannot be restored from the specified location.
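A minimal sketch of restoring a finished run and re-reading one partition's artifact; the run ID, stage name, and sales_summary.json filename are carried over from the end-to-end example above, and the import path follows the class path shown in the signature:
>>> import json
>>> from snowflake.ml.modeling.distributors.distributed_partition_function.dpf_run import DPFRun
>>> restored = DPFRun.restore_from("regional_analytics_2024", "analytics_stage")
>>> north = restored.partition_details["North"]
>>> summary = north.stage_artifacts_manager.get(
...     "sales_summary.json",
...     read_function=lambda path: json.load(open(path, 'r')),
... )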
- property run_id: str¶
The unique identifier for this execution run.
- Returns:
The run ID specified when the DPF was executed.
- Return type:
str
- property status: RunStatus¶
Retrieves the overall status of the distributed execution.
This property aggregates the status across all partitions to provide a summary of the entire run.
- Returns:
The current status of the execution.
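For illustration, a non-blocking status check, assuming run from the earlier examples; the import path follows the entities module shown below:
>>> from snowflake.ml.modeling.distributors.distributed_partition_function.entities import RunStatus
>>> if run.status == RunStatus.IN_PROGRESS:
...     print("run is still executing; check back later")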
- wait(show_progress: bool = True) → RunStatus¶
Wait for the DPF run to complete.
- Parameters:
show_progress (bool, optional) – Whether to display a progress bar during execution. Defaults to True. Uses Streamlit UI if available, otherwise falls back to tqdm.
- Returns:
The final status of the run (SUCCESS, FAILURE, PARTIAL_FAILURE, or CANCELLED).
- Return type:
RunStatus
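A sketch of waiting without the progress display and branching on the outcome, assuming run from the earlier examples; the import paths follow the module paths shown in this document:
>>> from snowflake.ml.modeling.distributors.distributed_partition_function.entities import (
...     PartitionStatus, RunStatus)
>>> final_status = run.wait(show_progress=False)
>>> if final_status == RunStatus.SUCCESS:
...     print("all partitions finished")
... else:
...     failed_logs = {pid: d.logs for pid, d in run.partition_details.items()
...                    if d.status != PartitionStatus.DONE}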
- class snowflake.ml.modeling.distributors.distributed_partition_function.entities.ExecutionOptions(use_head_node: bool = True, loading_wh: str | None = None)
Bases:
object
Configuration options for workload execution.
- use_head_node
This option controls whether the head node participates in the workload execution or solely acts as a coordinator. If True, the Ray head node will execute user-provided functions alongside worker nodes. If False, only worker nodes will execute user-provided functions.
- Type:
bool
- loading_wh
A dedicated warehouse used for loading data from Snowflake tables into the Snowflake stage. This is typically a larger warehouse used to accelerate data loading. Once data loading is complete, the session switches back to the original warehouse.
- Type:
str | None
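A sketch of passing these options to DPF.run(); the warehouse name is a placeholder, dpf and sales_data come from the example above, and the import path follows the class path shown in the signature:
>>> from snowflake.ml.modeling.distributors.distributed_partition_function.entities import ExecutionOptions
>>> options = ExecutionOptions(use_head_node=False, loading_wh="LOADING_WH_XL")
>>> run = dpf.run(
...     partition_by="region",
...     snowpark_dataframe=sales_data,
...     run_id="regional_analytics_2024_xl",
...     execution_options=options,
... )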
- class snowflake.ml.modeling.distributors.distributed_partition_function.entities.SinglePartitionDetails(*, partition_id: str, status: PartitionStatus = PartitionStatus.PENDING, stage_path: str | None = None)
Bases:
object
A class that encapsulates the details of a single model training job, including its status, model, and associated logs. It manages the loading of the model and training logs from a specified stage path, and provides safeguards in case of errors during these operations.
- partition_id
The unique identifier for the partition of the job.
- Type:
str
- stage_path
The path where the model and logs are stored in the system.
- Type:
str
- status
The current status of the partition (default is PENDING).
- Type:
PartitionStatus
- Raises:
RuntimeError – If loading the model or logs from the stage path fails due to system errors.
- property logs: str
- property partition_id: str
- property stage_artifacts_manager: StageArtifactsManager
Returns a StageArtifactsManager instance for interacting with this partition’s saved artifacts.
- property status: PartitionStatus
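For illustration, a sketch that inspects one partition's details after a run, assuming run and the "North" partition from the DPF example above:
>>> details = run.partition_details["North"]
>>> print(details.partition_id)                # "North"
>>> print(details.status)                      # a PartitionStatus value, e.g. DONE
>>> print(details.logs)                        # logs captured for this partition
>>> manager = details.stage_artifacts_manager  # access artifacts saved by the DPF function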
- class snowflake.ml.modeling.distributors.distributed_partition_function.entities.RunStatus(value)
Bases:
Enum
Enum representing the status of an overall distributed run.
- SUCCESS
The run completed successfully.
- Type:
str
- IN_PROGRESS
The run is actively executing.
- Type:
str
- PARTIAL_FAILURE
The run completed, but some partitions failed.
- Type:
str
- FAILURE
The run failed completely.
- Type:
str
- CANCELLED
The run was cancelled.
- Type:
str
- class snowflake.ml.modeling.distributors.distributed_partition_function.entities.PartitionStatus(value)
Bases:
str
,Enum
Enum representing the possible status of a partition execution.
- PENDING
The partition is waiting to be processed.
- Type:
str
- RUNNING
The partition is currently being processed.
- Type:
str
- FAILED
The partition has failed.
- Type:
str
- DONE
The partition has successfully completed.
- Type:
str
- CANCELLED
The partition has been cancelled; this happens when the user explicitly invokes the cancel call.
- Type:
str
- INTERNAL_ERROR
The partition encountered an internal error during processing.
- Type:
str
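A sketch of using these values to collect logs from failed partitions, assuming run from the earlier examples; the import path follows the class path shown in the signature:
>>> from snowflake.ml.modeling.distributors.distributed_partition_function.entities import PartitionStatus
>>> failed_logs = {
...     pid: details.logs
...     for pid, details in run.partition_details.items()
...     if details.status in (PartitionStatus.FAILED, PartitionStatus.INTERNAL_ERROR)
... }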
- class snowflake.ml.modeling.distributors.distributed_partition_function.partition_context.PartitionContext(*, session: Session, partition_id: str, stage_path_prefix: str)¶
Bases:
object
Provides context and utilities for partition-specific operations within DPF functions.
This class is automatically passed as the second argument to user-defined DPF functions, providing access to partition metadata and artifact management capabilities.
- Key capabilities:
Access to partition identifier and Snowflake session
Upload artifacts to the partition’s dedicated stage directory
Automatic handling of temporary files and serialization
Note
Users don’t create PartitionContext instances directly. They are provided automatically by the DPF execution framework.
Example
>>> def my_dpf_function(data_connector, context):
...     # Access partition information
...     print(f"Processing partition: {context.partition_id}")
...
...     # Process data
...     df = data_connector.to_pandas()
...     results = perform_analysis(df)
...
...     # Save results to stage
...     context.upload_to_stage(results, "analysis_results.pkl")
- property partition_id: str¶
- property session: Session¶
- property stage_path_prefix: str¶
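For illustration, a sketch of a DPF function that uses these properties directly; the print format and the query are arbitrary examples, not part of the API:
>>> def my_dpf_function(data_connector, context):
...     df = data_connector.to_pandas()
...     # Partition metadata provided by the framework
...     print(f"partition={context.partition_id}, stage prefix={context.stage_path_prefix}")
...     # The Snowpark session is available for ad-hoc queries
...     current_wh = context.session.sql("SELECT CURRENT_WAREHOUSE()").collect()
...     context.upload_to_stage({"rows": len(df)}, "row_count.pkl")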
- upload_to_stage(obj: Any, filename: str, *, write_function: Callable[[Any, str], None] | None = None)¶
Saves a Python object to a file and uploads it to the partition’s stage directory.
This method handles the complete workflow of serializing objects and uploading them to Snowflake stages, with automatic temporary file management.
- How it works:
If write_function is NOT provided (default): uses Python’s pickle module
If write_function IS provided: uses your custom serialization logic
- Parameters:
obj – The Python object to save.
filename (str) – The destination filename for the artifact (e.g., “model.pkl”).
write_function (Optional[Callable[[Any, str], None]]) – Custom serialization function that takes the object and local file path as arguments.
Example
Common usage patterns within a DPF function:
>>> def my_analytics_function(data_connector, context):
...     df = data_connector.to_pandas()
...
...     # Save pickle objects (default behavior)
...     results = {'total': df['amount'].sum(), 'count': len(df)}
...     context.upload_to_stage(results, "summary.pkl")
...
...     # Save JSON data
...     import json
...     context.upload_to_stage(
...         results, "summary.json",
...         write_function=lambda obj, path: json.dump(obj, open(path, 'w'))
...     )
...
...     # Save CSV file
...     df_processed = df.groupby('category').sum()
...     context.upload_to_stage(
...         df_processed, "aggregated.csv",
...         write_function=lambda df, path: df.to_csv(path, index=False)
...     )
Note
The method automatically handles temporary file creation and cleanup. Files are uploaded with auto_compress=False and overwrite=True.
- class snowflake.ml.modeling.distributors.distributed_partition_function.stage_artifacts_manager.StageArtifactsManager(session, stage_path_prefix: str)¶
Bases:
object
Provides a user-friendly interface to list, download, and load artifacts that were saved to a specific stage path for a given partition.
This class handles the common workflow of retrieving artifacts from Snowflake stages, with automatic deserialization for Python objects and flexible handling of different file types.
Example
Basic usage after a DPF run completes:
>>> # Get the stage artifacts manager for a specific partition
>>> manager = run.partition_details["North"].stage_artifacts_manager

>>> # List all available artifacts
>>> files = manager.list()
>>> print(files)  # ['model.pkl', 'metrics.json', 'predictions.csv']

>>> # Download and deserialize a pickle file (default behavior)
>>> model = manager.get("model.pkl")

>>> # Download and load JSON with a custom deserializer
>>> import json
>>> metrics = manager.get("metrics.json", read_function=lambda path: json.load(open(path, 'r')))

>>> # Download a raw file to a specific local directory
>>> import tempfile
>>> local_dir = tempfile.mkdtemp()
>>> csv_path = manager.download("predictions.csv", local_dir)
>>> print(f"Downloaded to: {csv_path}")
>>> # Load a text file with a custom deserializer
>>> config = manager.get("config.txt", read_function=lambda path: open(path, 'r').read())
- download(remote_filename: str, local_destination_dir: str) → str¶
Downloads a single raw artifact file to a specified local directory.
- Parameters:
remote_filename – The name of the file to download (e.g., “model.pkl”).
local_destination_dir – The local directory to save the file in.
- Returns:
The full local path to the downloaded file.
- get(filename: str, *, read_function: Callable[[str], Any] | None = None) → Any¶
Downloads an artifact to a temporary location and deserializes it into a Python object.
How it works:
If ‘read_function’ is NOT provided (default): The file is assumed to be a pickle file and will be loaded with pickle.load.
If ‘read_function’ IS provided: That function will be called to load the object from the downloaded file path.
- Parameters:
filename – The name of the artifact to get (e.g., “model.pkl”).
read_function – Optional. A function that takes one argument (the local file path) and returns the deserialized Python object. Example: lambda path: json.load(open(path, ‘r’))
- Returns:
The deserialized Python object.
- list() → List[str]¶
Lists the filenames of all artifacts in the partition’s stage directory.