Snowpark ML Ops: モデルレジストリのプレビュー API からの移行

Snowflakeは以前、モデル登録のプレビューを一部の顧客に非公開で提供していました。このトピックで説明するレジストリ機能には、プレビューバージョンと比較して、機能と APIs に大きな変更があります。最も注目すべき点は、レジストリのコア機能が、新しいスキーマレベルのモデルオブジェクトを使用して、Snowflake内部でネイティブにホストされるようになったことです。

注釈

パブリックプレビューバージョンでは、 Snowpark Container Services (SPCS)へのモデル展開はまだサポートされていません。この機能に依存している場合は、当面の間プライベートプレビューのレジストリを使い続けてください。

このテーブルは、2つのレジストリ実装の主な相違点をまとめたものです。非公開プレビューバージョンの API は「プレビュー API」と呼ばれ、現在一般公開されている API は「パブリック API」と呼ばれます。

プレビュー API

パブリック API

メタデータはテーブルに格納されます。モデルはステージに格納されます。レジストリ API はPythonライブラリであり、レジストリに格納されたモデルに対してこれらのオブジェクトを作成し、維持します。

  • レジストリを作成するには、ユーザーはスキーマ、テーブル、ステージを作成する権限を持っている必要があります。

  • Python API の外部でモデルのメタデータを更新すると、レジストリの一貫性が失われる可能性があります。

  • モデルを使用するには、明示的に展開する必要があります。

  • 個々のモデルがロールベースのアクセス制御を持つことはできません(ユーザー定義機能であるモデル展開は可能)。

  • レジストリ API を移植するには、レジストリ機能全体を新しい言語で実装する必要があります。

モデルは、テーブルやステージのような、スキーマレベルのネイティブオブジェクトです。Pythonレジストリ API は、Python内部で SQL を使用してモデルオブジェクトとのインタラクションを容易にするクラスです。

  • モデルは既存のスキーマに格納されます。スキーマはレジストリとして使用するために特別な準備をする必要はなく、ユーザーが所有していないスキーマにモデルを作成するための必要な権限はただ1つです。

  • メタデータはモデルオブジェクトの外部に格納されないため、レジストリの一貫性が失われることはありません。

  • モデルには SQL やPythonから呼び出せるメソッドが含まれており、明示的に展開する必要はありません。

  • 特定のモデルで個別に使用権限を付与することができます。

  • Pythonライブラリは SQL の上に薄いレイヤーを重ねたものであり、レジストリ API を容易に他の言語に移植できます。

以下のセクションでは、2つの APIs の違いについて詳しく説明します。

レジストリ API のインポートおよびアクセス

レジストリ APIs は、どちらもメインのSnowpark ML パッケージ、 snowflake.ml にあります。

プレビュー API

from snowflake.ml.registry import model_registry
Copy

レジストリの機能にアクセスするには、 model_registry.ModelRegistry を使用します。

パブリック API

from snowflake.ml.registry import Registry
Copy

レジストリの機能にアクセスするには、 Registry クラスを使用します。

レジストリの作成

プレビュー API には、Pythonライブラリで作成したレジストリが必要です。

プレビュー API

model_registry.create_model_registry(...)
Copy

レジストリを初めて使用する前に必要です。

パブリック API

該当なし。どの既存のスキーマもレジストリとして使用できます。

レジストリの開始

レジストリを開いて新しいモデルを追加したり、すでに登録されているモデルを操作したりします。

プレビュー API

reg = model_registry.ModelRegistry(
          session=session,
          database_name="MODEL_REGISTRY")
Copy

パブリック API

reg = Registry(
          session=session,
          database_name="ML",
          schema_name="REGISTRY")
Copy

モデルのログ

レジストリにモデルを追加することを ログする といいます。APIs はどちらもこの目的のために log_model というレジストリメソッドを使用します。このメソッドには、パブリック API に2つの小さな相違があります。

  • 以前は model_version と呼ばれていたモデルのバージョンを指定するパラメーターは、そのセマンティクスをよりよく反映するために、 version_name と呼ばれるようになりました。

  • モデルのログ時にタグを設定することはできません。代わりに、モデルの set_tag メソッドを使用してログした後にタグを追加します。

モデルへの参照の取得

モデルへの参照を取得すると、そのモデルのメタデータを更新したり、他の操作を実行したりすることができます。

プレビュー API

レジストリからモデルを取得すると、常に特定のバージョンのモデルが返されます。したがって、モデルを取得する際には、希望のバージョンを指定する必要があります。

model = model_registry.ModelReference(
            registry=registry,
            model_name="my_model",
            model_version="101")
Copy

パブリック API

モデルのバージョンは、モデル本体とは別のものです。モデルへの参照を取得するには、

m = reg.get_model("my_model")
Copy

特定のバージョンへの参照を取得するには、まず上記のようにモデルへの参照を取得し、次に目的のバージョンを取得します。モデルオブジェクトには default という属性があり、デフォルトとして指定したモデルのバージョンが含まれていることに注意してください。これは文字列ではなく、実際の ModelVersion オブジェクトです。

mv = m.version('v1')
mv = m.default
Copy

モデルの展開

プレビュー API

モデルは、ウェアハウス(ユーザー定義関数として)かSnowpark Container Services(サービスとして)に、明示的に展開する必要があります。この例では、ウェアハウスへの展開を示します。

model.deploy(
    deployment_name="my_warehouse_predict",
    target_method="predict",
    permanent=True)
Copy

パブリック API

モデルを明示的に展開する必要はありません。

推論に対するモデルの使用

推論 とは、テストデータに基づいて予測を行うためにモデルを使用することです。

プレビュー API

推論を実行するためにモデルを展開したときに使用した展開名を指定します。

result_dataframe = model.predict(
    "my_warehouse_predict", test_dataframe)
Copy

パブリック API

モデルはウェアハウスで実行されます。モデルのメソッドは、Pythonから、または SQL から呼び出すことができます。

Python

モデルバージョンの run メソッドを使用してメソッドを呼び出します。

remote_prediction = mv.run(
    test_features, function_name="predict")
Copy

SQL

単純な SELECT クエリを使用してデフォルトバージョンのメソッドを呼び出すか、 WITH 句を使用してバージョンを指定することができます。

-- Use default version
SELECT my_model!predict() FROM test_table;

-- Use a specific version
WITH my_model_v1 AS MODEL my_model VERSION "1"
     SELECT my_model_v1!predict() FROM test_table;
Copy

説明のアクセスおよび更新

プレビュー API

モデルリファレンスは、説明のゲッターメソッドとセッターメソッドを提供します。(この API では、モデル参照は常にモデルの特定のバージョンに対するものです。)

print(model.get_model_description())

model.set_model_description("A better description")
Copy

パブリック API

モデルもモデルバージョンも、同等の comment 属性と description 属性を通して説明へのアクセスを提供します。

print(m.comment)
m.comment = "A better description"

print(m.description)
m.description = "A better description"

print(mv.comment)
mv.comment = "A better description"

print(mv.description)
mv.description = "A better description"
Copy

タグのアクセスおよび更新

プレビュー API

タグはモデルのバージョンレベルで設定され、アクセスされます(モデル参照は常に特定のバージョンを参照)。

すべてのタグを取得する

print(model.get_tags())
Copy

タグを追加するか、新しいタグ値を設定する

model.set_tag("minor_rev", "1")
Copy

タグを削除する

model.remove_tag("minor_rev")
Copy

パブリック API

タグはモデルレベルで設定され(モデルはバージョンのコレクションで構成される)、 SQL タグを使用して実装されます。タグを作成し、その許容値を定義する方法については、 オブジェクトのタグ付け をご参照ください。

すべてのタグを取得する

print(m.show_tags())
Copy

タグを追加するか、新しいタグ値を設定する

m.set_tag("minor_rev", "1")
Copy

タグを削除する

m.unset_tag("minor_rev")
Copy

メトリクスのアクセスおよび更新

どちらも APIs で、メトリクスはモデルバージョンレベルで設定されます。

プレビュー API

スカラーメトリクスを設定する

model.set_metric("test_accuracy", test_accuracy)
Copy

階層(ディクショナリ)メトリクスを設定する

model.set_metric("dataset_test", {"accuracy": test_accuracy})
Copy

多価(行列)のメトリクスを設定する

model.set_metric("confusion_matrix", test_confusion_matrix)
Copy

すべてのメトリクスを取得する

print(model.get_metrics())
Copy

メトリクスを削除する

model.remove_metric("test_accuracy")
Copy

パブリック API

スカラーメトリクスを設定する

m.set_metric("test_accuracy", test_accuracy)
Copy

階層(ディクショナリ)メトリクスを設定する

mo.set_metric("dataset_test", {"accuracy": test_accuracy})
Copy

多価(行列)のメトリクスを設定する

m.set_metric("confusion_matrix", test_confusion_matrix)
Copy

すべてのメトリクスを取得する

print(m.get_metrics())
Copy

メトリクスを削除する

m.remove_metric("test_accuracy")
Copy

モデルの削除

プレビュー API

モデルの特定のバージョンのみを削除することができます。モデルを完全に削除するには、すべてのバージョンを削除します。

registry.delete_model(
    model_name="my_model",
    model_version="100")
Copy

パブリック API

モデルを削除すると、そのすべてのバージョンが削除されます。現在、1つのバージョンのみを削除することはできません。

reg.delete_model("mymodel")
Copy

モデルのバージョンのリスト

プレビュー API

list_models メソッドは、すべてのモデルバージョンの DataFrame を返します。これをフィルターして、特定のモデルのバージョンのみを表示することができます。

model_list = registry.list_models()
model_list.filter(model_list["NAME"] == "mymodel").show()
Copy

パブリック API

モデル参照を取得すると、モデルのバージョンを ModelVersion インスタンスのリストとして、またはモデルのバージョンに関する情報を含む DataFrame として取得することができます。

ModelVersions インスタンスのリストを取得する

version_list = m.versions()
Copy

情報 DataFrame を取得する

version_df = m.show_versions()
Copy

モデルのライフサイクルの管理

タグを使用してモデルのライフサイクルを管理することが意図されています。たとえば、モデルの現在のステータスを記録するために stage というタグを作成し、「experimental」、「alpha」、「beta」、「production」、「deprecated」、「obsolete」といった値を使用します。

パブリック API では、タグは SQL タグオブジェクトを使用して実装されます。タグを作成し、その許容値を定義する方法については、 オブジェクトのタグ付け をご参照ください。

パブリック API にはモデルの デフォルトバージョン という概念もあり、これは特に SQL でバージョンが指定されていない場合に使用されるモデルです。モデルの新しいバージョンをトレーニングし、新しいモデルが幅広く使用できるようになったら、デフォルトバージョンを更新することができます。モデルの default 属性を使用してデフォルトのバージョンを設定することができます。

m.default = "2"
Copy

そして、次のように ModelVersion オブジェクトとしてモデルのデフォルトバージョンを取得することができます。

mv = m.default
Copy

あるいは、その predict メソッドをすぐに呼び出すこともできます。

m.default.run(test_features, function_name="predict"))
Copy