PythonでのDataFramesのユーザー定義テーブル関数(UDTFs)の作成

Snowpark APIは、Pythonで作成されたハンドラーを使用してユーザー定義のテーブル関数を作成するために使用できるメソッドを提供します。このトピックでは、これらのタイプの関数を作成する方法について説明します。

このトピックの内容:

概要

Snowpark APIを使用して、ユーザー定義のテーブル関数(UDTF)を作成できます。

これは、 Pythonでの DataFrames 用ユーザー定義関数(UDFs)の作成 で説明されているように、APIを使用してスカラーユーザー定義関数(UDF)を作成するのと同様の方法で行います。主な違いには、UDTFを登録するときに必要なUDFハンドラー要件とパラメーター値が含まれます。

SnowparkでUDTFを作成して登録するには、次が必要です。

  • UDTFハンドラーを実装します。

    ハンドラーにはUDTFのロジックが含まれています。UDTFハンドラーは、UDTFが呼び出されたときにSnowflakeが実行時に呼び出す関数を実装する必要があります。詳細については、 UDTFハンドラーの実装 をご参照ください。

  • UDTFとそのハンドラーをSnowflakeデータベースに登録します。

    Snowpark APIを使用して、UDTFとそのハンドラーを登録できます。UDTFを登録すると、SQLから、またはSnowparkAPIを使用して呼び出すことができます。登録の詳細については、 UDTFの登録 をご参照ください。

UDTF の呼び出しについては、 ユーザー定義のテーブル関数(UDTFs)の呼び出し をご参照ください。

UDTFハンドラーの実装

Pythonでの UDTF の記述 で詳細に説明されているように、UDTFハンドラークラスは、UDTFが呼び出されたときにSnowflakeが呼び出すメソッドを実装する必要があります。UDTFをSnowparkAPIに登録する場合でも、CREATE FUNCTIONステートメントを使用してSQLで作成する場合でも、作成したクラスをハンドラーとして使用できます。

ハンドラークラスのメソッドは、UDTFが受け取った行とパーティションを処理するように設計されています。

UDTFハンドラークラスは、Snowflakeが実行時に呼び出す以下を実装します。

  • __init__ メソッド。オプションです。入力パーティションのステートフル処理を初期化するために呼び出されます。

  • process メソッド。必須です。入力行ごとに呼び出されます。このメソッドは、表形式の値をタプルとして返します。

  • end_partition メソッド。オプションです。入力パーティションの処理を完了するために呼び出されます。

    Snowflakeは、正常に処理するためにタイムアウトが調整された大型のパーティションをサポートしていますが、特に大型のパーティションでは、処理中にタイムアウトする可能性があります(end_partition の完了に時間がかかりすぎる場合など)。特定の使用シナリオに合わせてタイムアウトのしきい値を調整する必要がある場合は、 Snowflakeサポート にお問い合わせください。

ハンドラーの詳細と例については、 Pythonでの UDTF の記述 をご参照ください。

UDTFの登録

UDTFハンドラーを実装したら、Snowpark APIを使用して、SnowflakeデータベースにUDTFを登録できます。UDTFを登録すると、UDTFが作成されて呼び出せるようになります。

スカラーUDFの場合と同様に、UDTFを名前付き関数または匿名関数として登録できます。スカラーUDFの登録に関する関連情報については、 匿名 UDF の作成 および 名前付き UDF の作成と登録 をご参照ください。

UDTFを登録するときは、SnowflakeがUDTFを作成するために必要なパラメーター値を指定します。(これらのパラメーターの多くは、SQLのCREATE FUNCTIONステートメントの句に機能的に対応しています。詳細については、 CREATE FUNCTION をご参照ください。)

これらのパラメーターのほとんどは、スカラーUDFを作成するときに指定するパラメーターと同じです(詳細については、 Pythonでの DataFrames 用ユーザー定義関数(UDFs)の作成 をご参照ください)。主な違いは、UDTFが表形式の値を返すという事実と、そのハンドラーが関数ではなくクラスであるという事実によるものです。パラメーターの完全なリストについては、以下にリンクされているAPIsのドキュメントをご参照ください。

UDTFをSnowparkに登録するには、次のいずれかを使用して、データベースにUDTFを作成するために必要なパラメーター値を指定します。これらのオプションを区別する情報については、スカラーUDFを登録するための同様のオプションについて説明している UDFRegistration をご参照ください。

UDTFの入力タイプと出力スキーマの定義

UDTFを登録するときに、関数のパラメーターと出力値に関する詳細を指定します。これを行うのは、関数自体が、関数の基になるハンドラーの型に正確に対応する型を宣言するためです。

例については、このトピックの および snowflake.snowpark.udtf.UDTFRegistration リファレンスをご参照ください。

UDTFを登録するときに、次を指定します。

  • 登録関数の input_types パラメーターの値としての入力パラメーターの型。 process メソッドの宣言で型ヒントを指定する場合、 input_types パラメーターはオプションです。

    この値を snowflake.snowpark.types DataType に基づく型のリストとして指定します。例えば、 input_types=[StringType(), IntegerType()] を指定できます。

  • 登録関数の output_schema パラメーターの値としての表形式の出力のスキーマ。

    output_schema の値は、次のいずれかになります。

    • UDTFの戻り値の列の名前のリスト。

      リストには列名のみが含まれるため、 process メソッドの宣言で型のヒントも指定する必要があります。

    • 出力テーブルの列名 および タイプを表す StructType

      次の例のコードは、スキーマを値として output 変数に割り当て、UDTFを登録するときにその変数を使用します。

      >>> from snowflake.snowpark.types import StructField, StructType, StringType, IntegerType, FloatType
      >>> from snowflake.snowpark.functions import udtf, table_function
      >>> schema = StructType([
      ...     StructField("symbol", StringType())
      ...     StructField("cost", IntegerType()),
      ... ])
      >>> @udtf(output_schema=schema,input_types=[StringType(), IntegerType(), FloatType()],stage_location="straut_udf",is_permanent=True,name="test_udtf",replace=True)
      ... class StockSale:
      ...     def process(self, symbol, quantity, price):
      ...         cost = quantity * price
      ...         yield (symbol, cost)
      
      Copy

以下は例の簡単なリストです。その他の例については、 snowflake.snowpark.udtf.UDTFRegistration をご参照ください。

udtf関数を使用したUDTFの登録

関数を登録します。

>>> from snowflake.snowpark.types import IntegerType, StructField, StructType
>>> from snowflake.snowpark.functions import udtf, lit
>>> class GeneratorUDTF:
...     def process(self, n):
...         for i in range(n):
...             yield (i, )
>>> generator_udtf = udtf(GeneratorUDTF, output_schema=StructType([StructField("number", IntegerType())]), input_types=[IntegerType()])
Copy

関数を呼び出します。

>>> session.table_function(generator_udtf(lit(3))).collect()  # Query it by calling it
[Row(NUMBER=0), Row(NUMBER=1), Row(NUMBER=2)]
>>> session.table_function(generator_udtf.name, lit(3)).collect()  # Query it by using the name
[Row(NUMBER=0), Row(NUMBER=1), Row(NUMBER=2)]
Copy

register関数を使用したUDTFの登録

関数を登録します。

>>> from collections import Counter
>>> from typing import Iterable, Tuple
>>> from snowflake.snowpark.functions import lit
>>> class MyWordCount:
...     def __init__(self):
...         self._total_per_partition = 0
...
...     def process(self, s1: str) -> Iterable[Tuple[str, int]]:
...         words = s1.split()
...         self._total_per_partition = len(words)
...         counter = Counter(words)
...         yield from counter.items()
...
...     def end_partition(self):
...         yield ("partition_total", self._total_per_partition)
>>> udtf_name = "word_count_udtf"
>>> word_count_udtf = session.udtf.register(
...     MyWordCount, ["word", "count"], name=udtf_name, is_permanent=False, replace=True
... )
Copy

関数を呼び出します。

>>> # Call it by its name
>>> df1 = session.table_function(udtf_name, lit("w1 w2 w2 w3 w3 w3"))
>>> df1.show()
-----------------------------
|"WORD"           |"COUNT"  |
-----------------------------
|w1               |1        |
|w2               |2        |
|w3               |3        |
|partition_total  |6        |
-----------------------------
Copy

register_from_file関数を使用したUDTFの登録

関数を登録します。

>>> from snowflake.snowpark.types import IntegerType, StructField, StructType
>>> from snowflake.snowpark.functions import udtf, lit
>>> _ = session.sql("create or replace temp stage mystage").collect()
>>> _ = session.file.put("tests/resources/test_udtf_dir/test_udtf_file.py", "@mystage", auto_compress=False)
>>> generator_udtf = session.udtf.register_from_file(
...     file_path="@mystage/test_udtf_file.py",
...     handler_name="GeneratorUDTF",
...     output_schema=StructType([StructField("number", IntegerType())]),
...     input_types=[IntegerType()]
... )
Copy

関数を呼び出します。

>>> session.table_function(generator_udtf(lit(3))).collect()
[Row(NUMBER=0), Row(NUMBER=1), Row(NUMBER=2)]
Copy