LightGBM¶
Le registre de modèles Snowflake ML prend en charge les modèles créés à l’aide de LightGBM (modèles dérivés du wrapper API scikit-learn, par exemple lightgbm.LGBMClassifier ou de l’API native, par ex. lightgbm.Booster).
Les options supplémentaires suivantes peuvent être utilisées dans le dictionnaire options lors de l’appel à log_model :
Option |
Description |
|---|---|
|
Une liste des noms des méthodes disponibles sur l’objet modèle. Les modèles dérivés de l’API scikit-learn (par ex. |
|
S’il faut activer l’explicabilité pour le modèle à l’aide de SHAP. La valeur par défaut est |
|
La version de l’environnement d’exécution CUDA à utiliser lors du déploiement sur une plateforme avec GPU ; la valeur par défaut est 11.8. S’il est défini manuellement sur |
Vous devez spécifier le paramètre sample_input_data ou signatures lors de la journalisation d’un modèle LightGBM afin que le registre connaisse les signatures des méthodes cibles.
Exemples¶
Ces exemples supposent que reg est une instance de snowflake.ml.registry.Registry.
API scikit-learn (LGBMClassifier)¶
L’exemple suivant illustre les étapes clés pour entraîner un classificateur LightGBM utilisant l’API scikit-learn, le connecter au registre de modèle Snowflake ML et utiliser le modèle enregistré pour l’inférence et l’explicabilité. Le flux de travail comprend :
Entraîne un classificateur LightGBM sur un ensemble de données d’échantillon.
Connecte le modèle au registre de modèles Snowflake ML.
Effectue des prédictions et extrait les probabilités de prédiction.
Obtient les valeurs SHAP pour les prédictions du modèle.
import lightgbm as lgb
from sklearn import datasets, model_selection
# Load dataset
cal_data = datasets.load_breast_cancer(as_frame=True)
cal_X = cal_data.data
cal_y = cal_data.target
# Normalize column names (replace spaces with underscores)
cal_X.columns = [col.replace(' ', '_') for col in cal_X.columns]
cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(
cal_X, cal_y, test_size=0.2
)
# Train LightGBM Classifier
classifier = lgb.LGBMClassifier(
n_estimators=100,
learning_rate=0.05,
num_leaves=31
)
classifier.fit(cal_X_train, cal_y_train)
# Log the model
model_ref = reg.log_model(
model=classifier,
model_name="my_lightgbm_classifier",
version_name="v1",
sample_input_data=cal_X_test,
)
# Make predictions
result_df = model_ref.run(cal_X_test[-10:], function_name="predict")
# Get prediction probabilities
proba_df = model_ref.run(cal_X_test[-10:], function_name="predict_proba")
# Get explanations (SHAP values)
explanations_df = model_ref.run(cal_X_test[-10:], function_name="explain")
API native (Booster)¶
L’exemple suivant illustre les étapes clés pour entraîner un modèle LightGBM utilisant l’API Snowflake ML native, le connecter au registre de modèles Snowflake ML et utiliser le modèle enregistré pour l’inférence. Le workflow effectue les opérations suivantes :
Entraîne un modèle LightGBM sur un ensemble de données d’échantillon.
Connecte le modèle au registre de modèles Snowflake ML.
Effectue des prédictions.
import lightgbm as lgb
import pandas as pd
from sklearn import datasets, model_selection
# Load dataset
cal_data = datasets.load_breast_cancer()
cal_X = pd.DataFrame(cal_data.data, columns=cal_data.feature_names)
cal_y = cal_data.target
# Normalize column names (replace spaces with underscores)
cal_X.columns = [col.replace(' ', '_') for col in cal_X.columns]
cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(
cal_X, cal_y, test_size=0.2
)
# Prepare LightGBM Data Structure
lgb_train = lgb.Dataset(cal_X_train, cal_y_train)
# Define parameters and train the model
params = {
'objective': 'binary',
'metric': 'binary_logloss',
'boosting_type': 'gbdt',
'num_leaves': 31,
'learning_rate': 0.05,
'feature_fraction': 0.9,
}
num_round = 100
booster = lgb.train(
params,
lgb_train,
num_round
)
# Log the model
model_ref = reg.log_model(
model=booster,
model_name="my_lightgbm_booster",
version_name="v1",
sample_input_data=cal_X_test,
)
# Make predictions
result_df = model_ref.run(cal_X_test[-10:], function_name="predict")