snowflake.ml.registry.Registry

class snowflake.ml.registry.Registry(session: Session, *, database_name: Optional[str] = None, schema_name: Optional[str] = None)

Bases: object

Opens a registry within a pre-created Snowflake schema.

Parameters:
  • session – The Snowpark Session to connect with Snowflake.

  • database_name – The name of the database. If None, the current database of the session will be used. Defaults to None.

  • schema_name – The name of the schema. If None, the current schema of the session will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None.

Raises:

ValueError – If no database is specified and the session has no active database.
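
The database and schema resolution rules above can be sketched in plain Python. The helper resolve_registry_location and its arguments are hypothetical and are not part of snowflake-ml-python. The session's current database and schema are passed in as plain strings (or None) rather than read from a Snowpark Session. The sketch follows the rules exactly as stated here and is not the library's own code, which may treat edge cases differently.

```python
from typing import Optional


def resolve_registry_location(
    current_database: Optional[str],
    current_schema: Optional[str],
    database_name: Optional[str] = None,
    schema_name: Optional[str] = None,
) -> str:
    """Hypothetical sketch: pick the registry's "database.schema" location."""
    # An explicit database_name wins; otherwise fall back to the session's database.
    database = database_name if database_name is not None else current_database
    if database is None:
        # Mirrors the documented ValueError when no database is available.
        raise ValueError("No database specified and no active database in the session.")
    # An explicit schema_name wins; then the session's schema; then PUBLIC.
    if schema_name is not None:
        schema = schema_name
    elif current_schema is not None:
        schema = current_schema
    else:
        schema = "PUBLIC"
    return f"{database}.{schema}"


print(resolve_registry_location("ML_DB", None))  # falls back to the PUBLIC schema
```

The returned string has the same database.schema shape that the location attribute reports.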

Methods

delete_model(model_name: str) → None

Delete the model by its name.

Parameters:

model_name – The name of the model to be deleted.

get_model(model_name: str) → Model

Get the model object by its name.

Parameters:

model_name – The name of the model.

Returns:

The corresponding model object.

log_model(model: Union[catboost.CatBoost, lightgbm.LGBMModel, lightgbm.Booster, CustomModel, sklearn.base.BaseEstimator, sklearn.pipeline.Pipeline, xgboost.XGBModel, xgboost.Booster, torch.nn.Module, torch.jit.ScriptModule, tensorflow.Module, base.BaseEstimator, mlflow.pyfunc.PyFuncModel, transformers.Pipeline, sentence_transformers.SentenceTransformer, HuggingFacePipelineModel, snowflake.ml.model.models.llm.LLM], *, model_name: str, version_name: Optional[str] = None, comment: Optional[str] = None, metrics: Optional[Dict[str, Any]] = None, conda_dependencies: Optional[List[str]] = None, pip_requirements: Optional[List[str]] = None, python_version: Optional[str] = None, signatures: Optional[Dict[str, ModelSignature]] = None, sample_input_data: Optional[Union[pd.DataFrame, ndarray[Any, dtype[Union[np.int8, np.int16, np.int32, np.int64, np.float32, np.float64, np.uint8, np.uint16, np.uint32, np.uint64, np.bool_, np.str_, np.bytes_, np.datetime64]]], Sequence[Union[ndarray[Any, dtype[Union[np.int8, np.int16, np.int32, np.int64, np.float32, np.float64, np.uint8, np.uint16, np.uint32, np.uint64, np.bool_, np.str_, np.bytes_, np.datetime64]]], torch.Tensor, tensorflow.Tensor, tensorflow.Variable]], Sequence[Union[int, float, bool, str, bytes, _SupportedBuiltinsList]], snowflake.snowpark.DataFrame]] = None, code_paths: Optional[List[str]] = None, ext_modules: Optional[List[module]] = None, options: Optional[Union[BaseModelSaveOption, CatBoostModelSaveOptions, CustomModelSaveOption, LGBMModelSaveOptions, SKLModelSaveOptions, XGBModelSaveOptions, SNOWModelSaveOptions, PyTorchSaveOptions, TorchScriptSaveOptions, TensorflowSaveOptions, MLFlowSaveOptions, HuggingFaceSaveOptions, SentenceTransformersSaveOptions, LLMSaveOptions]] = None) → ModelVersion

Log a model with various parameters and metadata.

Parameters:
  • model – Model object of a supported type, such as Scikit-learn, XGBoost, LightGBM, CatBoost, Snowpark ML, PyTorch, TorchScript, TensorFlow, TensorFlow Keras, MLflow, HuggingFace Pipeline, Sentence Transformers, PEFT-finetuned LLM, or Custom Model.

  • model_name – Name to identify the model.

  • version_name – Version identifier for the model. The combination of model_name and version_name must be unique. If not specified, a random name is generated.

  • comment – Comment associated with the model version. Defaults to None.

  • metrics – A JSON-serializable dictionary of metrics linked to the model version. Defaults to None.

  • signatures – Model data signatures for the inputs and outputs of the target methods. If None, sample_input_data is used to infer the signatures of models whose signatures cannot be inferred automatically. If specified, sample_input_data should not be provided. Defaults to None.

  • sample_input_data – Sample input data to infer model signatures from. Defaults to None.

  • conda_dependencies – List of Conda package specifications. Use “[channel::]package [operator version]” syntax to specify a dependency. This is the recommended way to specify dependencies. If no channel is specified, the Snowflake Anaconda Channel is used. Defaults to None.

  • pip_requirements – List of pip package specifications. Defaults to None. This is currently not supported, because the model can only be executed in a Snowflake warehouse, where all dependencies must be retrieved from the Snowflake Anaconda Channel.

  • python_version – Python version in which the model is run. Defaults to None.

  • code_paths – List of directories containing code to import. Defaults to None.

  • ext_modules – List of external modules to pickle with the model object. Only supported when logging the following model types: Scikit-learn, Snowpark ML, PyTorch, TorchScript, and Custom Model. Defaults to None.

  • options (Dict[str, Any], optional) –

    Additional model saving options.

    Model Saving Options include:

    • embed_local_ml_library: Embed the local Snowpark ML library into the code directory.

      Automatically overridden to True if the local Snowpark ML version is not available in the Snowflake Anaconda Channel; otherwise, defaults to False.

    • relax_version: Whether to relax the version constraints of the dependencies.

      If True, any ==x.y.z specifier is replaced with >=x.y, <(x+1). Defaults to True.

    • function_type: Set the method function type globally. To set function types for individual methods, see function_type in method_options.

    • method_options: Per-method saving options including:
      • case_sensitive: Indicates whether the method and its signature should be case-sensitive.

        If True, you must double-quote the method name when referring to it in SQL. This helps if you need case to distinguish your methods or features, or if your method or feature names contain non-alphabetic characters. Defaults to False.

      • max_batch_size: Maximum batch size that the method can accept in the Snowflake warehouse.

        Defaults to None, in which case Snowflake determines it automatically.

      • function_type: One of the supported model method function types (FUNCTION or TABLE_FUNCTION).
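
The relax_version rule can be sketched in plain Python. The sketch applies the rule to a dependency written in the “[channel::]package [operator version]” syntax that conda_dependencies accepts. The function relax_pin is hypothetical and only handles exact three-part pins. It is not the library's own implementation, which may format its output differently.

```python
import re

# Matches "[channel::]package == x.y.z" pins, e.g. "conda-forge::numpy==1.24.3".
_PIN = re.compile(
    r"^(?:(?P<channel>[^:\s]+)::)?(?P<name>[A-Za-z0-9_.\-]+)\s*==\s*"
    r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)$"
)


def relax_pin(spec: str) -> str:
    """Hypothetical sketch: rewrite ==x.y.z as >=x.y,<(x+1); leave other specs as-is."""
    match = _PIN.match(spec.strip())
    if match is None:
        return spec  # not an exact three-part pin, so nothing to relax
    prefix = f"{match['channel']}::" if match["channel"] else ""
    major, minor = int(match["major"]), int(match["minor"])
    # Allow any later minor or patch release within the same major version.
    return f"{prefix}{match['name']}>={major}.{minor},<{major + 1}"


for dep in ["numpy==1.24.3", "conda-forge::pandas==2.0.1", "scikit-learn>=1.2"]:
    print(dep, "->", relax_pin(dep))
```

In this sketch, an already-open range such as scikit-learn>=1.2 passes through unchanged. Only exact pins are widened.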

Returns:

ModelVersion object corresponding to the model just logged.

Return type:

ModelVersion
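
The options argument is a plain dictionary. The sketch below assumes that method_options maps each target method name to its own options dictionary. The method names and values are illustrative only. The hypothetical helper sql_method_name shows what case_sensitive means in practice, using Snowflake's general identifier rules: unquoted identifiers are resolved case-insensitively and stored in uppercase, while double-quoted identifiers keep their exact case.

```python
import re

# Illustrative options for log_model (assumed shape: method name -> per-method options).
options = {
    "relax_version": False,
    "method_options": {
        "predict": {"case_sensitive": False, "max_batch_size": 500},
        "predictProba": {"case_sensitive": True},
    },
}

# Snowflake unquoted identifiers: a letter or underscore first,
# then letters, digits, underscores, or dollar signs.
_UNQUOTED = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*$")


def sql_method_name(name: str, case_sensitive: bool) -> str:
    """Hypothetical helper: how a method name would be written in Snowflake SQL."""
    if case_sensitive:
        # Quoted identifiers keep their exact case; embedded double quotes are doubled.
        return '"' + name.replace('"', '""') + '"'
    if not _UNQUOTED.match(name):
        raise ValueError(f"{name!r} needs case_sensitive=True to be referenced in SQL")
    # Unquoted identifiers are case-insensitive and stored in uppercase.
    return name.upper()


for method, method_opts in options["method_options"].items():
    print(method, "->", sql_method_name(method, method_opts.get("case_sensitive", False)))
```

A dictionary like this would be passed as options to log_model. The helper itself is not part of the library.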

models() → List[Model]

Get all models in the schema where the registry is opened.

Returns:

A list of Model objects representing all models in the opened registry.

show_models() → DataFrame

Show information about all models in the schema where the registry is opened.

Returns:

A pandas DataFrame containing information about all models in the schema.

Attributes

location

Get the location (database.schema) of the registry.