scikit-learn

レジストリは、scikit-learnを使用して作成されたモデル(sklearn.base.BaseEstimator または sklearn.pipeline.Pipeline から派生したモデル)をサポートしています。

以下の追加オプションは、 options ディクショナリで log_model を呼び出すときに使用できます。

オプション

説明

target_methods

モデルオブジェクトで利用可能なメソッド名。メソッドが存在すると仮定して、scikit-learnモデルはデフォルトで以下のターゲットメソッドを持ちます。 predicttransformpredict_probapredict_log_probadecision_function

scikit-learnモデルをログする場合は、 sample_input_datasignatures のどちらかのパラメーターを指定する必要があります。

この例では、 RandomForestClassifierPipeline がトレーニングされ、モデルレジストリにログされます。

from snowflake.ml.registry import Registry
from sklearn import datasets, ensemble

# create a session and set DATABASE and SCHEMA
# session = ...

registry = Registry(session=session, database_name=DATABASE, schema_name=SCHEMA)

iris_X, iris_y = datasets.load_iris(return_X_y=True, as_frame=True)

# Rename columns so they are valid Snowflake identifiers
column_name_map = {
        'sepal length (cm)': 'sepal_length',
        'sepal width (cm)': 'sepal_width',
        'petal length (cm)': 'petal_length',
        'petal width (cm)': 'petal_width'
}
iris_X = iris_X.rename(columns=column_name_map)

# Train the model
clf = ensemble.RandomForestClassifier(random_state=42)
clf.fit(iris_X, iris_y)

# Log the model in the registry
model_ref = registry.log_model(
    clf,
    model_name="RandomForestClassifier",
    version_name="v1",
    sample_input_data=iris_X,
    options={
        "method_options": {
            "predict": {"case_sensitive": True},
            "predict_proba": {"case_sensitive": True},
            "predict_log_proba": {"case_sensitive": True},
        }
    },
)

# Generate predictions
model_ref.run(iris_X[-10:], function_name='"predict_proba"')

# Pipelines can also be logged in the registry
from sklearn import pipeline, preprocessing

pipe = pipeline.Pipeline([
    ('scaler', preprocessing.StandardScaler()),
    ('classifier', ensemble.RandomForestClassifier(random_state=42)),
])
pipe.fit(iris_X, iris_y)

model_ref = registry.log_model(
    pipe,
    model_name="Pipeline",
    version_name="v1",
    sample_input_data=iris_X,
    options={
        "method_options": {
            "predict": {"case_sensitive": True},
            "predict_proba": {"case_sensitive": True},
            "predict_log_proba": {"case_sensitive": True},
        }
    },
)

# Generate predictions
model_ref.run(iris_X[-10:], function_name='"predict_proba"')
Copy

注釈

scikit-learn パイプラインとして、scikit-learn の前処理と XGBoost モデルを組み合わせることができます。