Hyperparameter Tuner
Classes
- class entities.tuner_config.TunerConfig(metric: str, mode: str, search_alg: entities.search_algorithm.SearchAlgorithm = <factory>, num_trials: int = 5, max_concurrent_trials: int | None = None, resource_per_trial: dict | None = None)
Bases:
object
Configuration class for the tuning process.
- metric
The name of the metric to optimize. This should correspond to a key in the metrics dictionary that the training function passes to TunerContext.report().
- Type:
str
- mode
The optimization mode for the metric. Must be either "min" for minimization or "max" for maximization.
- Type:
str
- search_alg
The search algorithm used to explore the hyperparameter space. Defaults to random search (RandomSearch).
- Type:
SearchAlgorithm
The maximum number of parameter configurations to try. Defaults to 5. Note: In a grid search, the num_trials parameter is set to 1, such that each unique parameter combination in the grid is evaluated with exactly one trial.
- Type:
int
- max_concurrent_trials
The maximum number of concurrently running trials per node. If not specified, it defaults to the total number of nodes in the cluster. If provided, this value must be a positive integer.
- Type:
Optional[int]
- resource_per_trial
An optional dictionary specifying the resources allocated to each trial. For example, {'CPU': 1} reserves 1 CPU and {'GPU': 1} reserves 1 GPU. If this parameter is not provided, the per-trial resource allocation is inferred from the max_concurrent_trials setting and the total cluster resources.
- Type:
Optional[dict]
Example
>>> from snowflake.ml.modeling.tune import TunerConfig
>>> config = TunerConfig(
...     metric="accuracy",
...     mode="max",
...     num_trials=5,
... )
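The constraints documented for mode and max_concurrent_trials can be illustrated with a minimal local sketch. TunerConfigSketch below is a hypothetical stand-in, not the library class: it mirrors the documented fields (search_alg is omitted so the sketch has no library dependency) and assumes validation happens at construction time in __post_init__. Whether the real TunerConfig validates eagerly is not stated in this section.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TunerConfigSketch:
    """Hypothetical stand-in mirroring the documented TunerConfig fields.

    search_alg is left out so the sketch has no library dependency.
    """

    metric: str
    mode: str
    num_trials: int = 5
    max_concurrent_trials: Optional[int] = None
    resource_per_trial: Optional[dict] = None

    def __post_init__(self) -> None:
        # Documented constraint: mode is either "min" or "max".
        if self.mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {self.mode!r}")
        # Documented constraint: if given, max_concurrent_trials is a positive int.
        # bool is rejected explicitly because it is a subclass of int.
        limit = self.max_concurrent_trials
        if limit is not None and (
            isinstance(limit, bool) or not isinstance(limit, int) or limit <= 0
        ):
            raise ValueError("max_concurrent_trials must be a positive integer")


# A valid configuration: maximize accuracy, at most 2 concurrent trials.
config = TunerConfigSketch(metric="accuracy", mode="max", max_concurrent_trials=2)

# An invalid mode is rejected at construction time.
try:
    TunerConfigSketch(metric="loss", mode="minimize")
except ValueError as err:
    print(err)  # prints: mode must be 'min' or 'max', got 'minimize'
```

Failing fast in the constructor keeps an invalid mode from surfacing only after trials have already been scheduled.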
- class entities.tuner_context.TunerContext(*, hyper_params: Dict[str, Any], progress_reporter: Callable[[Dict[str, Any], Any | None], None], dataset_map: Dict[str, Type[DataConnector]] | None = None)
Bases:
object
A context class for managing configuration, reporting, and dataset information in Ray Tune trials.
This class provides a centralized way to access configuration parameters, progress reporting functions, and dataset mappings within a Ray Tune trial.
Initialize a TunerContext instance.
- Parameters:
hyper_params (Dict[str, Any]) – Configuration dictionary for the trial.
progress_reporter (Callable) – Function for reporting progress and metrics.
dataset_map (Optional[Dict[str, Type[DataConnector]]]) – Mapping of dataset names to DataConnector types.
- get_dataset_map() → Dict[str, Type[DataConnector]] | None
Retrieve the dataset mapping.
- Returns:
A mapping of dataset names to DataConnector types, if available.
- Return type:
Optional[Dict[str, Type[DataConnector]]]
- get_hyper_params() → Dict[str, Any]
Retrieve the configuration dictionary.
- Returns:
The configuration dictionary for the trial.
- Return type:
Dict[str, Any]
- report(metrics: Dict[str, Any], model: Any | None = None) → None
Report metrics and, optionally, the trained model.
This method reports the performance metrics of a model and, if provided, the model itself. The reported metrics are used to guide the selection of the next set of hyperparameters during optimization.
- Parameters:
metrics (Dict[str, Any]) – A dictionary containing the performance metrics of the model. The keys are metric names, and the values are the corresponding metric values.
model (Optional[Any], optional) – The trained model to be reported. Defaults to None.
- Returns:
This method doesn’t return anything.
- Return type:
None
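A training function typically reads its hyperparameters from the context, trains, and reports a metrics dictionary whose keys include TunerConfig.metric. The sketch below uses a hypothetical TunerContextSketch that mirrors the documented constructor and methods. In the real library the tuner supplies the context to each trial, and how a training function obtains it is not covered in this section, so the sketch passes it explicitly. The learning_rate hyperparameter and the toy "training" step are placeholders, and dataset handling is omitted.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

Reporter = Callable[[Dict[str, Any], Optional[Any]], None]


class TunerContextSketch:
    """Hypothetical stand-in exposing the documented TunerContext methods."""

    def __init__(
        self,
        *,
        hyper_params: Dict[str, Any],
        progress_reporter: Reporter,
        dataset_map: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._hyper_params = hyper_params
        self._progress_reporter = progress_reporter
        self._dataset_map = dataset_map

    def get_hyper_params(self) -> Dict[str, Any]:
        return self._hyper_params

    def get_dataset_map(self) -> Optional[Dict[str, Any]]:
        return self._dataset_map

    def report(self, metrics: Dict[str, Any], model: Optional[Any] = None) -> None:
        # Hand the metrics (and the optional model) to the reporter callable.
        self._progress_reporter(metrics, model)


def train_func(context: TunerContextSketch) -> None:
    """Hypothetical training function: read hyperparameters, train, report."""
    params = context.get_hyper_params()
    lr = params["learning_rate"]
    # Placeholder for real training: a toy score that peaks at lr == 0.1.
    accuracy = 1.0 - abs(lr - 0.1)
    model = {"learning_rate": lr}  # placeholder for a trained model object
    # The "accuracy" key should match TunerConfig.metric.
    context.report(metrics={"accuracy": accuracy}, model=model)


# Record every report so the result can be inspected after the trial.
reports: List[Tuple[Dict[str, Any], Optional[Any]]] = []
ctx = TunerContextSketch(
    hyper_params={"learning_rate": 0.05},
    progress_reporter=lambda metrics, model: reports.append((metrics, model)),
)
train_func(ctx)
```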
- class entities.tuner_results.TunerResults(results: pandas.core.frame.DataFrame, best_result: pandas.core.frame.DataFrame, best_model: Any)
Bases:
object
- best_model: Any
- best_result: DataFrame
- results: DataFrame
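How metric and mode determine a "best" trial can be sketched without pandas. The best_trial helper below is hypothetical: it treats a list of dicts as a stand-in for the rows of results and assumes that the best trial is the one with the largest (mode="max") or smallest (mode="min") value of the configured metric. The record keys (trial, learning_rate, accuracy) are illustrative only, not the actual DataFrame columns.

```python
from typing import Any, Dict, List


def best_trial(results: List[Dict[str, Any]], metric: str, mode: str) -> Dict[str, Any]:
    """Return the record that is best for `metric` under `mode`.

    `results` is a plain list of per-trial records standing in for the rows
    of TunerResults.results (a pandas DataFrame in the documented API).
    """
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    if not results:
        raise ValueError("no trial results to choose from")
    # "max" keeps the largest metric value, "min" the smallest.
    chooser = max if mode == "max" else min
    return chooser(results, key=lambda row: row[metric])


trials = [
    {"trial": 0, "learning_rate": 0.01, "accuracy": 0.81},
    {"trial": 1, "learning_rate": 0.10, "accuracy": 0.93},
    {"trial": 2, "learning_rate": 0.50, "accuracy": 0.77},
]
print(best_trial(trials, metric="accuracy", mode="max")["trial"])  # prints 1
```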
- class entities.search_space.SearchSpace(*args, **kwargs)
Bases:
Dict[str, SamplingFunction | float | int | str | bool | List[float | int | str | bool]]
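The base type above means a SearchSpace maps string keys to a sampling function, a single scalar, or a flat list of scalars. The hypothetical checker below makes that shape concrete; SamplingFunction here is a local placeholder class, not an import from the library, and the check is a reading aid rather than the library's own validation.

```python
from typing import Any, Dict


class SamplingFunction:
    """Local placeholder for entities.sampling_function.SamplingFunction."""


SCALARS = (float, int, str, bool)


def is_valid_search_space(space: Dict[Any, Any]) -> bool:
    """Check a dict against the documented SearchSpace base type."""
    for key, value in space.items():
        if not isinstance(key, str):
            return False
        # A sampling function or a single scalar is allowed as-is.
        if isinstance(value, (SamplingFunction, *SCALARS)):
            continue
        # A list is allowed only if every element is a scalar (no nesting).
        if isinstance(value, list) and all(isinstance(v, SCALARS) for v in value):
            continue
        return False
    return True


print(is_valid_search_space({"n_estimators": [50, 100, 200], "criterion": "gini"}))  # True
print(is_valid_search_space({"max_depth": (3, 5)}))  # False: tuples are not lists
```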
- class entities.sampling_function.SamplingFunction
Bases:
object
- class entities.sampling_function.Uniform(lower: float, upper: float)
Bases:
SamplingFunction
- lower: float
- upper: float
- class entities.sampling_function.LogUniform(lower: float, upper: float)
Bases:
SamplingFunction
- lower: float
- upper: float
- class entities.sampling_function.RandInt(lower: float, upper: float)
Bases:
SamplingFunction
- lower: float
- upper: float
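The class names suggest the usual semantics: Uniform draws evenly between its bounds, LogUniform draws evenly in log space (so each order of magnitude is equally likely), and RandInt draws integers. The sketch below assumes those semantics, plus Ray Tune's convention that RandInt excludes upper. The sample() methods are hypothetical; the documented classes expose only lower and upper. Seeding the generator shows the kind of reproducibility that RandomSearch's random_state is presumably meant to provide.

```python
import math
import random
from dataclasses import dataclass

# Hypothetical stand-ins: the documented classes expose only lower and upper;
# the sample() methods below are illustrative and not part of that API.


@dataclass
class Uniform:
    lower: float
    upper: float

    def sample(self, rng: random.Random) -> float:
        # Any value between lower and upper is equally likely.
        return rng.uniform(self.lower, self.upper)


@dataclass
class LogUniform:
    lower: float
    upper: float

    def sample(self, rng: random.Random) -> float:
        # Uniform in log space; requires 0 < lower < upper.
        return math.exp(rng.uniform(math.log(self.lower), math.log(self.upper)))


@dataclass
class RandInt:
    lower: float
    upper: float

    def sample(self, rng: random.Random) -> int:
        # Assumed Ray Tune convention: lower inclusive, upper exclusive.
        return rng.randrange(int(self.lower), int(self.upper))


space = {
    "learning_rate": LogUniform(1e-4, 1e-1),
    "subsample": Uniform(0.5, 1.0),
    "max_depth": RandInt(3, 10),
}

# A fixed seed makes the draws reproducible (the role random_state is assumed to play).
rng = random.Random(42)
config = {name: fn.sample(rng) for name, fn in space.items()}
print(config)
```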
- class entities.search_algorithm.SearchAlgorithm
Bases:
object
- class entities.search_algorithm.BayesOpt(utility_kwargs: Dict[str, Any] | None = None)
Bases:
SearchAlgorithm
- utility_kwargs: Dict[str, Any] | None = None
- class entities.search_algorithm.RandomSearch(random_state: int | None = None)
Bases:
SearchAlgorithm
- random_state: int | None = None
- class entities.search_algorithm.GridSearch
Bases:
SearchAlgorithm
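The num_trials note states that GridSearch evaluates each unique parameter combination exactly once. Assuming the grid axes are the list-valued entries of the search space (an assumption, since this section does not spell out how GridSearch reads the space), the set of trials is the Cartesian product of those lists, which the hypothetical grid_configs helper below enumerates with itertools.product.

```python
import itertools
from typing import Any, Dict, List


def grid_configs(space: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Expand a search space into every unique parameter combination.

    Assumption: list-valued entries are grid axes and any other value is held
    fixed. Sampling functions are out of scope for this sketch.
    """
    names = list(space)
    # Wrap non-list values in a one-element list so each entry is an axis.
    axes = [value if isinstance(value, list) else [value] for value in space.values()]
    return [dict(zip(names, combo)) for combo in itertools.product(*axes)]


space = {"max_depth": [3, 5, 7], "learning_rate": [0.01, 0.1], "booster": "gbtree"}
configs = grid_configs(space)
print(len(configs))  # prints 6: 3 x 2 combinations, one trial each
```

With three depths and two learning rates, this grid yields six configurations, so under the assumption above six trials run.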