CatBoost¶
Le registre de modèles Snowflake ML prend en charge les modèles créés à l’aide de CatBoost (modèles dérivés de catboost.CatBoost, tels que catboost.CatBoostClassifier, catboost.CatBoostRegressor et catboost.CatBoostRanker).
Les options supplémentaires suivantes peuvent être utilisées dans le dictionnaire options lors de l’appel à log_model :
Option |
Description |
|---|---|
|
Une liste des noms des méthodes disponibles sur l’objet modèle. Les modèles CatBoost disposent des méthodes cibles suivantes par défaut, en supposant que la méthode existe : |
|
S’il faut activer l’explicabilité pour le modèle à l’aide de SHAP. La valeur par défaut est |
|
La version de l’environnement d’exécution CUDA à utiliser lors du déploiement sur une plateforme avec GPU ; la valeur par défaut est 11.8. S’il est défini manuellement sur |
Vous devez spécifier le paramètre sample_input_data ou signatures lors de la journalisation d’un modèle CatBoost afin que le registre connaisse les signatures des méthodes cibles.
Exemples¶
Ces exemples supposent que reg est une instance de snowflake.ml.registry.Registry.
CatBoostClassifier¶
L’exemple suivant illustre les étapes clés pour entraîner un classificateur CatBoost, le connecter au registre de modèles Snowflake ML, et utiliser le modèle enregistré pour l’inférence et l’explicabilité. Le flux de travail comprend :
Entraîne un classificateur CatBoost sur un ensemble de données d’échantillon.
Connecte le modèle au registre de modèles Snowflake ML.
Effectue des prédictions et extrait les probabilités de prédiction.
Obtient les valeurs SHAP pour les prédictions du modèle.
import catboost
from sklearn import datasets, model_selection
# Load dataset
cal_data = datasets.load_breast_cancer(as_frame=True)
cal_X = cal_data.data
cal_y = cal_data.target
# Normalize column names (replace spaces with underscores)
cal_X.columns = [col.replace(' ', '_') for col in cal_X.columns]
cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(
cal_X, cal_y, test_size=0.2
)
# Train CatBoost Classifier
classifier = catboost.CatBoostClassifier(
iterations=100,
learning_rate=0.1,
depth=6,
verbose=False
)
classifier.fit(cal_X_train, cal_y_train)
# Log the model
model_ref = reg.log_model(
model=classifier,
model_name="my_catboost_classifier",
version_name="v1",
sample_input_data=cal_X_test,
)
# Make predictions
result_df = model_ref.run(cal_X_test[-10:], function_name="predict")
# Get prediction probabilities
proba_df = model_ref.run(cal_X_test[-10:], function_name="predict_proba")
# Get explanations (SHAP values)
explanations_df = model_ref.run(cal_X_test[-10:], function_name="explain")
CatBoostRegressor¶
L’exemple suivant illustre les étapes clés pour entraîner une régression CatBoost, la connecter au registre de modèles Snowflake ML, et utiliser le modèle enregistré pour l’inférence. Le flux de travail comprend :
Entraîne une régression CatBoost sur un ensemble de données d’échantillon.
Connecte le modèle au registre de modèles Snowflake ML.
Effectue des prédictions.
import catboost
from sklearn import datasets, model_selection
# Load dataset
cal_data = datasets.load_diabetes(as_frame=True)
cal_X = cal_data.data
cal_y = cal_data.target
cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(
cal_X, cal_y, test_size=0.2
)
# Train CatBoost Regressor
regressor = catboost.CatBoostRegressor(
iterations=100,
learning_rate=0.1,
depth=6,
verbose=False
)
regressor.fit(cal_X_train, cal_y_train)
# Log the model
model_ref = reg.log_model(
model=regressor,
model_name="my_catboost_regressor",
version_name="v1",
sample_input_data=cal_X_test,
)
# Make predictions
result_df = model_ref.run(cal_X_test[-10:], function_name="predict")
Désactivation de l’explicabilité¶
Si vous n’avez pas besoin de fonctionnalités d’explicabilité, vous pouvez les désactiver pendant la journalisation pour réduire la taille du modèle et les dépendances :
model_ref = reg.log_model(
model=classifier,
model_name="my_catboost_classifier_no_explain",
version_name="v1",
sample_input_data=cal_X_test,
options={"enable_explainability": False},
)