PyTorch¶
Le registre de modèles Snowflake ML prend en charge les modèles créés à l’aide de PyTorch (modèles dérivés de torch.nn.Module).
Les options supplémentaires suivantes peuvent être utilisées dans le dictionnaire options lors de l’appel à log_model :
Option |
Description |
|---|---|
|
Une liste des noms des méthodes disponibles sur l’objet modèle. Les modèles PyTorch ont la méthode cible suivante par défaut : |
|
La version de l’environnement d’exécution CUDA à utiliser lors du déploiement sur une plateforme avec GPU ; la valeur par défaut est 11.8. S’il est défini manuellement sur |
|
Si le modèle attend plusieurs entrées de tenseur. La valeur par défaut est |
Vous devez spécifier le paramètre sample_input_data ou signatures lors de la journalisation d’un modèle PyTorch afin que le registre connaisse les signatures des méthodes cibles.
Note
Lorsque vous utilisez des pandas DataFrames (qui utilisent float64 par défaut), assurez-vous que les couches de votre modèle PyTorch sont créées avec dtype=torch.float64 pour éviter les erreurs de non correspondance dtype.
Exemple¶
Cet exemple suppose que reg est une instance de snowflake.ml.registry.Registry.