Hyperparameter Tuner

Classes

class entities.tuner_config.TunerConfig(metric: str, mode: str, search_alg: entities.search_algorithm.SearchAlgorithm = <factory>, num_trials: int = 5, max_concurrent_trials: int | None = None, resource_per_trial: dict | None = None)

Bases: object

Configuration class for the tuning process.

metric

The name of the metric to optimize. This should correspond to a key in the metrics dictionary reported by the training function.

Type:

str

mode

The optimization mode for the metric. Must be either “min” for minimization or “max” for maximization.

Type:

str
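The following plain-Python sketch shows how `metric` and `mode` work together to pick a winning trial from the metric dictionaries that trials report. It is an illustration, not the library's implementation. The helper `select_best` and the trial data are hypothetical.

```python
from typing import Any, Dict, List


def select_best(trials: List[Dict[str, Any]], metric: str, mode: str) -> Dict[str, Any]:
    """Return the trial whose `metric` value is best under `mode` (hypothetical helper)."""
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    # Every trial must report the metric under the same key named in the config.
    missing = [t for t in trials if metric not in t]
    if missing:
        raise KeyError(f"{len(missing)} trial(s) did not report metric {metric!r}")
    pick = max if mode == "max" else min
    return pick(trials, key=lambda t: t[metric])


# Hypothetical metric dictionaries, one per trial.
trials = [
    {"accuracy": 0.81, "loss": 0.52},
    {"accuracy": 0.87, "loss": 0.41},
    {"accuracy": 0.84, "loss": 0.38},
]
print(select_best(trials, "accuracy", "max"))  # {'accuracy': 0.87, 'loss': 0.41}
print(select_best(trials, "loss", "min"))      # {'accuracy': 0.84, 'loss': 0.38}
```

The same trials can produce different winners depending on which metric is optimized and in which direction.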

search_alg

The search algorithm to use for exploring the hyperparameter space. Defaults to random search.

Type:

SearchAlgorithm

num_trials

The maximum number of parameter configurations to try. Defaults to 5. Note: with grid search, num_trials is set to 1 so that each unique parameter combination in the grid is evaluated in exactly one trial.

Type:

int
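The grid-search note can be made concrete. With a grid, the total number of trials is determined by the number of value combinations, and each combination runs once. The sketch below enumerates a hypothetical grid with `itertools.product`. The parameter names and values are made up for illustration.

```python
import itertools
from typing import Any, Dict, List


def grid_configs(grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Enumerate every combination of grid values, one dict per trial."""
    names = list(grid)
    return [dict(zip(names, values)) for values in itertools.product(*grid.values())]


# Hypothetical grid: 3 learning rates x 2 depths = 6 unique combinations.
grid = {"learning_rate": [0.01, 0.05, 0.1], "max_depth": [3, 5]}
configs = grid_configs(grid)
print(len(configs))  # 6
print(configs[0])    # {'learning_rate': 0.01, 'max_depth': 3}
```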

max_concurrent_trials

The maximum number of concurrently running trials per node. If not specified, it defaults to the total number of nodes in the cluster. If provided, the value must be a positive integer.

Type:

Optional[int]

resource_per_trial

An optional dictionary specifying the resources allocated per trial. For example, {'CPU': 1} reserves 1 CPU and {'GPU': 1} reserves 1 GPU. When this parameter is not provided, the resource allocation per trial is inferred based on the max_concurrent_trials setting and total cluster resources.

Type:

Optional[dict]

Example

>>> from snowflake.ml.modeling.tune import TunerConfig
>>> config = TunerConfig(
...     metric="accuracy",
...     mode="max",
...     num_trials=5,
... )
class entities.tuner_context.TunerContext(*, hyper_params: Dict[str, Any], progress_reporter: Callable[[Dict[str, Any], Any | None], None], dataset_map: Dict[str, Type[DataConnector]] | None = None)

Bases: object

A context class for managing configuration, reporting, and dataset information in Ray Tune trials.

This class provides a centralized way to access configuration parameters, progress reporting functions, and dataset mappings within a Ray Tune trial.

Initialize a TunerContext instance.

Parameters:
  • hyper_params (Dict[str, Any]) – Configuration dictionary for the trial.

  • progress_reporter (Callable) – Function for reporting progress and metrics.

  • dataset_map (Optional[Dict[str, Type[DataConnector]]]) – Mapping of dataset names to DataConnector types.

get_dataset_map() → Dict[str, Type[DataConnector]] | None

Retrieve the dataset mapping.

Returns:

A mapping of dataset names to DataConnector types, if available.

Return type:

Optional[Dict[str, Type[DataConnector]]]

get_hyper_params() → Dict[str, Any]

Retrieve the configuration dictionary.

Returns:

The configuration dictionary for the trial.

Return type:

Dict[str, Any]

report(metrics: Dict[str, Any], model: Any | None = None) → None

Report metrics and, optionally, the trained model.

This method reports the performance metrics of a model and, if provided, the model itself. The reported metrics guide the selection of the next set of hyperparameters in the optimization process.

Parameters:
  • metrics (Dict[str, Any]) – A dictionary containing the performance metrics of the model. The keys are metric names, and the values are the corresponding metric values.

  • model (Optional[Any], optional) – The trained model to be reported. Defaults to None.

Returns:

This method doesn’t return anything.

Return type:

None
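The methods above are typically used together inside a training function: read the trial's hyperparameters, train, then report metrics. The sketch below uses `FakeTunerContext`, a hypothetical stand-in that mirrors only the documented `get_hyper_params`, `get_dataset_map`, and `report` methods so the example runs without Snowflake. In practice the tuner creates the real context. Because this page does not show how a training function obtains it, the sketch passes the context in explicitly. The names `train_func` and `learning_rate` and the scoring formula are also hypothetical.

```python
from typing import Any, Callable, Dict, Optional, Type


class FakeTunerContext:
    """Hypothetical stand-in mirroring the documented TunerContext methods."""

    def __init__(
        self,
        *,
        hyper_params: Dict[str, Any],
        progress_reporter: Callable[[Dict[str, Any], Optional[Any]], None],
        dataset_map: Optional[Dict[str, Type[Any]]] = None,
    ) -> None:
        self._hyper_params = hyper_params
        self._progress_reporter = progress_reporter
        self._dataset_map = dataset_map

    def get_hyper_params(self) -> Dict[str, Any]:
        return self._hyper_params

    def get_dataset_map(self) -> Optional[Dict[str, Type[Any]]]:
        return self._dataset_map

    def report(self, metrics: Dict[str, Any], model: Optional[Any] = None) -> None:
        # In the real tuner, reported metrics guide the next hyperparameter
        # selection; this stand-in only forwards them to the reporter.
        self._progress_reporter(metrics, model)


def train_func(ctx: FakeTunerContext) -> None:
    params = ctx.get_hyper_params()
    # Placeholder "training": the score peaks at learning_rate == 0.1 (illustrative only).
    accuracy = 1.0 - abs(params["learning_rate"] - 0.1)
    ctx.report({"accuracy": accuracy}, model={"learning_rate": params["learning_rate"]})


reports = []
ctx = FakeTunerContext(
    hyper_params={"learning_rate": 0.05},
    progress_reporter=lambda metrics, model: reports.append((metrics, model)),
)
train_func(ctx)
print(reports)
print(ctx.get_dataset_map())  # None: no dataset_map was supplied
```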

class entities.tuner_results.TunerResults(results: pandas.core.frame.DataFrame, best_result: pandas.core.frame.DataFrame, best_model: Any)

Bases: object

best_model: Any
best_result: DataFrame
results: DataFrame
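This page does not describe the three attributes further. The sketch below assumes that `results` holds one row per trial and that `best_result` is the row that is best for the configured `metric` and `mode`. Under that assumption, which this page does not confirm, the relationship can be reproduced with pandas. The column names and values are invented.

```python
import pandas as pd

# Hypothetical per-trial results: one row per trial.
results = pd.DataFrame(
    {
        "trial_id": ["t0", "t1", "t2"],
        "learning_rate": [0.01, 0.05, 0.1],
        "accuracy": [0.81, 0.87, 0.84],
    }
)

metric, mode = "accuracy", "max"
best_idx = results[metric].idxmax() if mode == "max" else results[metric].idxmin()
# Select with a list so the result stays a one-row DataFrame, matching best_result's type.
best_result = results.loc[[best_idx]]
print(best_result)
```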
class entities.search_space.SearchSpace(*args, **kwargs)

Bases: Dict[str, SamplingFunction | float | int | str | bool | List[float | int | str | bool]]

class entities.sampling_function.SamplingFunction

Bases: object

class entities.sampling_function.Uniform(lower: float, upper: float)

Bases: SamplingFunction

lower: float
upper: float

class entities.sampling_function.LogUniform(lower: float, upper: float)

Bases: SamplingFunction

lower: float
upper: float

class entities.sampling_function.RandInt(lower: float, upper: float)

Bases: SamplingFunction

lower: float
upper: float
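The three sampling functions list only their bounds. This sketch assumes the reading that tuning libraries commonly use. Uniform draws a float uniformly between lower and upper. LogUniform draws uniformly in log space, which suits parameters such as learning rates that span orders of magnitude. RandInt draws an integer with lower inclusive and upper exclusive. The functions below are hypothetical plain-Python equivalents, not the library's samplers. In particular, the exclusive upper bound for RandInt is an assumption.

```python
import math
import random


def sample_uniform(rng: random.Random, lower: float, upper: float) -> float:
    # Equal-width sub-intervals of [lower, upper] are equally likely.
    return rng.uniform(lower, upper)


def sample_log_uniform(rng: random.Random, lower: float, upper: float) -> float:
    # Uniform in log space, so each order of magnitude gets equal probability mass.
    if lower <= 0 or upper <= 0:
        raise ValueError("LogUniform bounds must be positive")
    return math.exp(rng.uniform(math.log(lower), math.log(upper)))


def sample_rand_int(rng: random.Random, lower: float, upper: float) -> int:
    # Integer in [lower, upper); the exclusive upper bound is an assumption.
    return rng.randrange(int(lower), int(upper))


rng = random.Random(0)
lrs = [sample_log_uniform(rng, 1e-4, 1e-1) for _ in range(1000)]
# About one third of the draws should land in each decade between 1e-4 and 1e-1.
print(sum(lr < 1e-3 for lr in lrs) / len(lrs))
depths = [sample_rand_int(rng, 3, 8) for _ in range(1000)]
print(sorted(set(depths)))  # [3, 4, 5, 6, 7]
```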
class entities.search_algorithm.SearchAlgorithm

Bases: object

class entities.search_algorithm.BayesOpt(utility_kwargs: Dict[str, Any] | None = None)

Bases: SearchAlgorithm

utility_kwargs: Dict[str, Any] | None = None

class entities.search_algorithm.RandomSearch(random_state: int | None = None)

Bases: SearchAlgorithm

random_state: int | None = None
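random_state presumably seeds the random search so that repeated runs explore the same configurations. The sketch below demonstrates that property in plain Python, using a hypothetical `random_search` helper over a SearchSpace-like dict and a hypothetical `Uniform` stand-in. It also makes two assumptions about how SearchSpace entries are interpreted: list values are treated as categorical choices, and scalar constants pass through unchanged. The real algorithm's sampling order may differ.

```python
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class Uniform:
    """Hypothetical stand-in for entities.sampling_function.Uniform."""

    lower: float
    upper: float


def sample_value(rng: random.Random, spec: Any) -> Any:
    if isinstance(spec, Uniform):
        return rng.uniform(spec.lower, spec.upper)
    if isinstance(spec, list):
        return rng.choice(spec)  # assumption: a list is a categorical choice
    return spec  # constants (float, int, str, bool) pass through unchanged


def random_search(
    space: Dict[str, Any], num_trials: int, random_state: Optional[int] = None
) -> List[Dict[str, Any]]:
    # random.Random(None) seeds from system entropy, so unseeded runs differ.
    rng = random.Random(random_state)
    return [{name: sample_value(rng, spec) for name, spec in space.items()} for _ in range(num_trials)]


# Hypothetical search space mixing a sampling function, a list, and a constant.
space = {
    "subsample": Uniform(0.5, 1.0),
    "max_depth": [3, 5, 7],
    "objective": "binary",
}
run_a = random_search(space, num_trials=5, random_state=42)
run_b = random_search(space, num_trials=5, random_state=42)
print(run_a == run_b)  # True: same seed, same configurations
```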
class entities.search_algorithm.GridSearch

Bases: SearchAlgorithm