PyTorch¶
Die Snowflake ML Model Registry unterstützt Modelle, die mit PyTorch erstellt wurden (von torch.nn.Module abgeleitete Modelle).
Die folgenden zusätzlichen Optionen können im options-Dictionary verwendet werden, wenn Sie log_model abrufen:
Option |
Beschreibung |
|---|---|
|
Liste mit den Namen der für das Modellobjekt verfügbaren Methoden. PyTorch-Modelle haben standardmäßig die folgende Zielmethode: |
|
Die Version der CUDA-Laufzeitumgebung, die beim Bereitstellen auf einer Plattform mit GPU verwendet werden soll. Der Standardwert ist 11.8. Wird das Modell manuell auf |
|
Ob das Modell mehrere Tensor-Eingaben erwartet. Die Standardeinstellung ist |
Sie müssen entweder den Parameter sample_input_data oder signatures angeben, wenn Sie ein PyTorch-Modell protokollieren, damit die Registry die Signaturen der Zielmethoden kennt.
Bemerkung
Bei Verwendung von pandas DataFrames (die standardmäßig float64 verwenden) stellen Sie sicher, dass Ihre PyTorch-Modellebenen mit dtype=torch.float64 erstellt werden, um dtype-Konfliktfehler zu vermeiden.
Beispiel¶
In diesem Beispiel wird davon ausgegangen, dass reg eine Instanz von snowflake.ml.registry.Registry ist.