scikit-learn¶
레지스트리는 scikit-learn을 사용하여 생성된 모델(sklearn.base.BaseEstimator
또는 sklearn.pipeline.Pipeline
에서 파생된 모델)을 지원합니다.
log_model
을 호출할 때 options
사전에서 다음 추가 옵션을 사용할 수 있습니다.
옵션 |
설명 |
---|---|
|
모델 오브젝트에서 사용할 수 있는 메서드 이름 목록입니다. scikit-learn 모델에는 대상 메서드가 존재한다고 가정하면 기본적으로 |
레지스트리가 대상 메서드의 서명을 알도록 scikit-learn 모델을 로깅할 때 sample_input_data
또는 signatures
매개 변수를 지정해야 합니다.
예¶
이 예에서는 RandomForestClassifier
및 Pipeline
이 학습되어 Model Registry에 로그됩니다.
from snowflake.ml.registry import Registry
from sklearn import datasets, ensemble
# create a session and set DATABASE and SCHEMA
# session = ...
registry = Registry(session=session, database_name=DATABASE, schema_name=SCHEMA)
iris_X, iris_y = datasets.load_iris(return_X_y=True, as_frame=True)
# Rename columns so they are valid Snowflake identifiers
column_name_map = {
'sepal length (cm)': 'sepal_length',
'sepal width (cm)': 'sepal_width',
'petal length (cm)': 'petal_length',
'petal width (cm)': 'petal_width'
}
iris_X = iris_X.rename(columns=column_name_map)
# Train the model
clf = ensemble.RandomForestClassifier(random_state=42)
clf.fit(iris_X, iris_y)
# Log the model in the registry
model_ref = registry.log_model(
clf,
model_name="RandomForestClassifier",
version_name="v1",
sample_input_data=iris_X,
options={
"method_options": {
"predict": {"case_sensitive": True},
"predict_proba": {"case_sensitive": True},
"predict_log_proba": {"case_sensitive": True},
}
},
)
# Generate predictions
model_ref.run(iris_X[-10:], function_name='"predict_proba"')
# Pipelines can also be logged in the registry
from sklearn import pipeline, preprocessing
pipe = pipeline.Pipeline([
('scaler', preprocessing.StandardScaler()),
('classifier', ensemble.RandomForestClassifier(random_state=42)),
])
pipe.fit(iris_X, iris_y)
model_ref = registry.log_model(
pipe,
model_name="Pipeline",
version_name="v1",
sample_input_data=iris_X,
options={
"method_options": {
"predict": {"case_sensitive": True},
"predict_proba": {"case_sensitive": True},
"predict_log_proba": {"case_sensitive": True},
}
},
)
# Generate predictions
model_ref.run(iris_X[-10:], function_name='"predict_proba"')
참고
scikit-learn 전처리를 XGBoost 모델과 결합하여 scikit-learn 파이프라인으로 사용할 수 있습니다.