CatBoost¶
Die Snowflake ML Model Registry unterstützt Modelle, die mit CatBoost erstellt wurden (von catboost.CatBoost abgeleitete Modelle wie catboost.CatBoostClassifier, catboost.CatBoostRegressor und catboost.CatBoostRanker).
Die folgenden zusätzlichen Optionen können im options-Dictionary verwendet werden, wenn Sie log_model abrufen:
Option |
Beschreibung |
|---|---|
|
Liste der Namen der Methoden, die für das Modellobjekt verfügbar sind. CatBoost-Modelle haben standardmäßig die folgenden Zielmethoden, vorausgesetzt, die Methode existiert: |
|
Gibt an, ob die Erklärbarkeit für das Modell mit SHAP aktiviert werden soll. Die Standardeinstellung ist |
|
Die Version der CUDA-Laufzeitumgebung, die beim Bereitstellen auf einer Plattform mit GPU verwendet werden soll. Der Standardwert ist 11.8. Wird das Modell manuell auf |
Sie müssen entweder den Parameter sample_input_data oder signatures angeben, wenn Sie ein CatBoost-Modell protokollieren, damit die Registry die Signaturen der Zielmethoden kennt.
Beispiele¶
Diese Beispiele gehen davon aus, dass reg eine Instanz von snowflake.ml.registry.Registry ist.
CatBoostClassifier¶
Das folgende Beispiel zeigt die wichtigsten Schritte zum Trainieren eines CatBoost-Klassifikators, zum Protokollieren dieses Klassifikators in der Snowflake MLModel Registry und zum Verwenden des registrierten Modells zu Ableitungs- und Erklärbarkeitszwecken. Der Workflow umfasst:
Trainieren eines CatBoost-Klassifikators für ein Beispiel-Datenset
Protokollieren des Modells in der Snowflake ML-Modell-Registry
Erstellen von Vorhersagen und Abrufen der Vorhersagewahrscheinlichkeiten
Abrufen der SHAP-Werte für die Vorhersagen des Modells
import catboost
from sklearn import datasets, model_selection
# Load dataset
cal_data = datasets.load_breast_cancer(as_frame=True)
cal_X = cal_data.data
cal_y = cal_data.target
# Normalize column names (replace spaces with underscores)
cal_X.columns = [col.replace(' ', '_') for col in cal_X.columns]
cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(
cal_X, cal_y, test_size=0.2
)
# Train CatBoost Classifier
classifier = catboost.CatBoostClassifier(
iterations=100,
learning_rate=0.1,
depth=6,
verbose=False
)
classifier.fit(cal_X_train, cal_y_train)
# Log the model
model_ref = reg.log_model(
model=classifier,
model_name="my_catboost_classifier",
version_name="v1",
sample_input_data=cal_X_test,
)
# Make predictions
result_df = model_ref.run(cal_X_test[-10:], function_name="predict")
# Get prediction probabilities
proba_df = model_ref.run(cal_X_test[-10:], function_name="predict_proba")
# Get explanations (SHAP values)
explanations_df = model_ref.run(cal_X_test[-10:], function_name="explain")
CatBoostRegressor¶
Das folgende Beispiel zeigt die wichtigsten Schritte zum Trainieren eines CatBoost-Regressors, zum Protokollieren in der Snowflake ML Model Registry und zum Verwenden des registrierten Modells zu Ableitungszwecken. Der Workflow umfasst:
Trainieren eines CatBoost-Regressors für ein Beispiel-Datenset
Protokollieren des Modells in der Snowflake ML-Modell-Registry
Erstellen von Vorhersagen
import catboost
from sklearn import datasets, model_selection
# Load dataset
cal_data = datasets.load_diabetes(as_frame=True)
cal_X = cal_data.data
cal_y = cal_data.target
cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(
cal_X, cal_y, test_size=0.2
)
# Train CatBoost Regressor
regressor = catboost.CatBoostRegressor(
iterations=100,
learning_rate=0.1,
depth=6,
verbose=False
)
regressor.fit(cal_X_train, cal_y_train)
# Log the model
model_ref = reg.log_model(
model=regressor,
model_name="my_catboost_regressor",
version_name="v1",
sample_input_data=cal_X_test,
)
# Make predictions
result_df = model_ref.run(cal_X_test[-10:], function_name="predict")
Deaktivieren der Erklärbarkeit¶
Wenn Sie keine Erklärbarkeits-Features benötigen, können Sie diese während der Protokollierung deaktivieren, um die Modellgröße und die Abhängigkeiten zu reduzieren:
model_ref = reg.log_model(
model=classifier,
model_name="my_catboost_classifier_no_explain",
version_name="v1",
sample_input_data=cal_X_test,
options={"enable_explainability": False},
)