scikit-learn¶
Die Registry unterstützt Modelle, die mit Scikit-learn erstellt wurden (von sklearn.base.BaseEstimator
oder sklearn.pipeline.Pipeline
abgeleitete Modelle).
Die folgenden zusätzlichen Optionen können im options
-Dictionary verwendet werden, wenn Sie log_model
abrufen:
Option |
Beschreibung |
---|---|
|
Liste der Namen der Methoden, die für das Modellobjekt verfügbar sind. Scikit-learn-Modelle haben standardmäßig die folgenden Zielmethoden, vorausgesetzt, die Methode existiert: |
Sie müssen entweder den Parameter sample_input_data
oder signatures
angeben, wenn Sie ein Scikit-learn-Modell protokollieren, damit die Registry die Signaturen der Zielmethoden kennt.
Beispiel¶
In diesem Beispiel werden ein RandomForestClassifier
und eine Pipeline
trainiert und in der Modellregistrierung erfasst.
from snowflake.ml.registry import Registry
from sklearn import datasets, ensemble
# create a session and set DATABASE and SCHEMA
# session = ...
registry = Registry(session=session, database_name=DATABASE, schema_name=SCHEMA)
iris_X, iris_y = datasets.load_iris(return_X_y=True, as_frame=True)
# Rename columns so they are valid Snowflake identifiers
column_name_map = {
'sepal length (cm)': 'sepal_length',
'sepal width (cm)': 'sepal_width',
'petal length (cm)': 'petal_length',
'petal width (cm)': 'petal_width'
}
iris_X = iris_X.rename(columns=column_name_map)
# Train the model
clf = ensemble.RandomForestClassifier(random_state=42)
clf.fit(iris_X, iris_y)
# Log the model in the registry
model_ref = registry.log_model(
clf,
model_name="RandomForestClassifier",
version_name="v1",
sample_input_data=iris_X,
options={
"method_options": {
"predict": {"case_sensitive": True},
"predict_proba": {"case_sensitive": True},
"predict_log_proba": {"case_sensitive": True},
}
},
)
# Generate predictions
model_ref.run(iris_X[-10:], function_name='"predict_proba"')
# Pipelines can also be logged in the registry
from sklearn import pipeline, preprocessing
pipe = pipeline.Pipeline([
('scaler', preprocessing.StandardScaler()),
('classifier', ensemble.RandomForestClassifier(random_state=42)),
])
pipe.fit(iris_X, iris_y)
model_ref = registry.log_model(
pipe,
model_name="Pipeline",
version_name="v1",
sample_input_data=iris_X,
options={
"method_options": {
"predict": {"case_sensitive": True},
"predict_proba": {"case_sensitive": True},
"predict_log_proba": {"case_sensitive": True},
}
},
)
# Generate predictions
model_ref.run(iris_X[-10:], function_name='"predict_proba"')
Bemerkung
Sie können die scikit-learn-Vorverarbeitung mit einem XGBoost-Modell in einer scikit-learn-Pipeline kombinieren.