scikit-learn

Le registre prend en charge les modèles créés à l’aide de scikit-learn (modèles dérivés de sklearn.base.BaseEstimator ou sklearn.pipeline.Pipeline).

Les options supplémentaires suivantes peuvent être utilisées dans le dictionnaire options lors de l’appel à log_model :

Option

Description

target_methods

Une liste des noms des méthodes disponibles sur l’objet modèle. Les modèles scikit-learn ont les méthodes cibles suivantes par défaut, en supposant que la méthode existe : predict, transform, predict_proba, predict_log_proba, decision_function.

Vous devez spécifier le paramètre sample_input_data ou signatures lorsque vous enregistrez un modèle scikit-learn afin que le registre connaisse les signatures des méthodes cibles.

Exemple

from sklearn import datasets, ensemble

iris_X, iris_y = datasets.load_iris(return_X_y=True, as_frame=True)
clf = ensemble.RandomForestClassifier(random_state=42)
clf.fit(iris_X, iris_y)
model_ref = registry.log_model(
    clf,
    model_name="RandomForestClassifier",
    version_name="v1",
    sample_input_data=iris_X,
    options={
        "method_options": {
            "predict": {"case_sensitive": True},
            "predict_proba": {"case_sensitive": True},
            "predict_log_proba": {"case_sensitive": True},
        }
    },
)
model_ref.run(iris_X[-10:], function_name='"predict_proba"')
Copy

Pipeline :

from sklearn import datasets, ensemble, pipeline, preprocessing

iris_X, iris_y = datasets.load_iris(return_X_y=True, as_frame=True)
pipe = pipeline.Pipeline([
    ('scaler', preprocessing.StandardScaler()),
    ('classifier', ensemble.RandomForestClassifier(random_state=42)),
])
pipe.fit(iris_X, iris_y)
model_ref = registry.log_model(
    pipe,
    model_name="Pipeline",
    version_name="v1",
    sample_input_data=iris_X,
    options={
        "method_options": {
            "predict": {"case_sensitive": True},
            "predict_proba": {"case_sensitive": True},
            "predict_log_proba": {"case_sensitive": True},
        }
    },
)
model_ref.run(iris_X[-10:], function_name='"predict_proba"')
Copy

Note

Vous pouvez combiner le prétraitement scikit-learn avec un modèle XGBoost sous la forme d’un pipeline scikit-learn.