scikit-learn

Le registre prend en charge les modèles créés à l’aide de scikit-learn (modèles dérivés de sklearn.base.BaseEstimator ou sklearn.pipeline.Pipeline).

Les options supplémentaires suivantes peuvent être utilisées dans le dictionnaire options lors de l’appel à log_model :

Option

Description

target_methods

Une liste des noms des méthodes disponibles sur l’objet modèle. Les modèles scikit-learn ont les méthodes cibles suivantes par défaut, en supposant que la méthode existe : predict, transform, predict_proba, predict_log_proba, decision_function.

Vous devez spécifier le paramètre sample_input_data ou signatures lorsque vous enregistrez un modèle scikit-learn afin que le registre connaisse les signatures des méthodes cibles.

Exemple

Dans cet exemple, RandomForestClassifier et Pipeline sont entraînés et connectés au registre des modèles.

from snowflake.ml.registry import Registry
from sklearn import datasets, ensemble

# create a session and set DATABASE and SCHEMA
# session = ...

registry = Registry(session=session, database_name=DATABASE, schema_name=SCHEMA)

iris_X, iris_y = datasets.load_iris(return_X_y=True, as_frame=True)

# Rename columns so they are valid Snowflake identifiers
column_name_map = {
        'sepal length (cm)': 'sepal_length',
        'sepal width (cm)': 'sepal_width',
        'petal length (cm)': 'petal_length',
        'petal width (cm)': 'petal_width'
}
iris_X = iris_X.rename(columns=column_name_map)

# Train the model
clf = ensemble.RandomForestClassifier(random_state=42)
clf.fit(iris_X, iris_y)

# Log the model in the registry
model_ref = registry.log_model(
    clf,
    model_name="RandomForestClassifier",
    version_name="v1",
    sample_input_data=iris_X,
    options={
        "method_options": {
            "predict": {"case_sensitive": True},
            "predict_proba": {"case_sensitive": True},
            "predict_log_proba": {"case_sensitive": True},
        }
    },
)

# Generate predictions
model_ref.run(iris_X[-10:], function_name='"predict_proba"')

# Pipelines can also be logged in the registry
from sklearn import pipeline, preprocessing

pipe = pipeline.Pipeline([
    ('scaler', preprocessing.StandardScaler()),
    ('classifier', ensemble.RandomForestClassifier(random_state=42)),
])
pipe.fit(iris_X, iris_y)

model_ref = registry.log_model(
    pipe,
    model_name="Pipeline",
    version_name="v1",
    sample_input_data=iris_X,
    options={
        "method_options": {
            "predict": {"case_sensitive": True},
            "predict_proba": {"case_sensitive": True},
            "predict_log_proba": {"case_sensitive": True},
        }
    },
)

# Generate predictions
model_ref.run(iris_X[-10:], function_name='"predict_proba"')
Copy

Note

Vous pouvez combiner le prétraitement scikit-learn avec un modèle XGBoost sous la forme d’un pipeline scikit-learn.