Pythonユーザー定義集計関数

ユーザー定義集計関数(UDAFs)は1つ以上の行を入力として受け取り、1行の出力を生成します。複数行全体の値を操作して、合計、平均、カウント、最小値/最大値の探索、標準偏差、推定などの数学的計算に加え、一部の非数学的な演算も実行します。

Python UDAFs は、Snowflakeのシステム定義の SQL 独自の集計関数 に似たものを作成する方法を提供します。

また、 Pythonにおける DataFrames 用ユーザー定義関数(UDAFs)の作成 で説明されているように、Snowpark APIs を使用して独自の UDAFs を作成することもできます。

制限事項

  • aggregate_state のシリアル化されたバージョンの最大サイズは8 MB であるため、集計状態のサイズを制御してみてください。

  • UDAF を ウィンドウ関数 として呼び出すことはできません(言い換えれば、OVER 句を使用して)。

  • IMMUTABLE は集計関数ではサポートされていません(AGGREGATE パラメーターを使用する場合)。したがって、すべての集計関数はデフォルトで VOLATILE です。

  • ユーザー定義集約関数は、 WITHIN GROUP 句と共に使用することはできません。クエリの実行に失敗します。

集計関数ハンドラーのインターフェイス

集計関数は子ノードの状態を集計し、最終的にそれらの集計された状態はシリアル化され、マージされて最終的な結果が計算される親ノードに送信されます。

集計関数を定義するには、実行時にSnowflakeが呼び出すメソッドを含むPythonクラス(関数のハンドラー)を定義する必要があります。これらのメソッドは、以下の表に説明されています。このトピックの他の例をご参照ください。

メソッド

要件

説明

__init__

必須

集計の内部状態を初期化します。

aggregate_state

必須

集計の現在の状態を返します。

  • メソッドには、 @property デコレーター が必要です。

  • 集計状態オブジェクトは、 Python pickleライブラリ でシリアル化が可能な任意のPythonデータ型にすることができます。

  • 単純な集計状態の場合は、Pythonのプリミティブデータ型を使用します。より複雑な集計状態の場合は、 Python データクラス を使用します。

accumulate

必須

新しい入力行に基づいて集計状態を累積します。

merge

必須

2つの中間集計状態を組み合わせます。

finish

必須

集計状態に基づいて最終結果を生成します。

入力値が子ノードに蓄積され、最終結果を生成するために親ノードに送信されマージされる様子を示す図。

例: 合計の計算

次の例のコードでは、 python_sum ユーザー定義集計関数(UDAF)を定義し、数値の合計を返します。

  1. UDAFを作成します。

    CREATE OR REPLACE AGGREGATE FUNCTION PYTHON_SUM(a INT)
    RETURNS INT
    LANGUAGE PYTHON
    RUNTIME_VERSION = 3.9
    handler = 'PythonSum'
    AS $$
    class PythonSum:
      def __init__(self):
        # This aggregate state is a primitive Python data type.
        self._partial_sum = 0
    
      @property
      def aggregate_state(self):
        return self._partial_sum
    
      def accumulate(self, input_value):
        self._partial_sum += input_value
    
      def merge(self, other_partial_sum):
        self._partial_sum += other_partial_sum
    
      def finish(self):
        return self._partial_sum
    $$;
    
    Copy
  2. テストデータのテーブルを作成します。

    CREATE OR REPLACE TABLE sales(item STRING, price INT);
    
    INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
    
    SELECT * FROM sales;
    
    Copy
  3. python_sum UDAF を呼び出します。

    SELECT python_sum(price) FROM sales;
    
    Copy
  4. Snowflakeシステム定義の SQL 関数、 SUM の出力と結果を比較し、結果が同じであることを確認します。

    SELECT sum(col) FROM sales;
    
    Copy
  5. 売上テーブルのアイテム型別の合計値でグループ化します。

    SELECT item, python_sum(price) FROM sales GROUP BY item;
    
    Copy

例: 平均値の計算

次の例のコードは、数値の平均を返す python_avg ユーザー定義集計関数を定義します。

  1. 関数を作成します。

    CREATE OR REPLACE AGGREGATE FUNCTION python_avg(a INT)
    RETURNS FLOAT
    LANGUAGE PYTHON
    RUNTIME_VERSION = 3.9
    HANDLER = 'PythonAvg'
    AS $$
    from dataclasses import dataclass
    
    @dataclass
    class AvgAggState:
        count: int
        sum: int
    
    class PythonAvg:
        def __init__(self):
            # This aggregate state is an object data type.
            self._agg_state = AvgAggState(0, 0)
    
        @property
        def aggregate_state(self):
            return self._agg_state
    
        def accumulate(self, input_value):
            sum = self._agg_state.sum
            count = self._agg_state.count
    
            self._agg_state.sum = sum + input_value
            self._agg_state.count = count + 1
    
        def merge(self, other_agg_state):
            sum = self._agg_state.sum
            count = self._agg_state.count
    
            other_sum = other_agg_state.sum
            other_count = other_agg_state.count
    
            self._agg_state.sum = sum + other_sum
            self._agg_state.count = count + other_count
    
        def finish(self):
            sum = self._agg_state.sum
            count = self._agg_state.count
            return sum / count
    $$;
    
    Copy
  2. テストデータのテーブルを作成します。

    CREATE OR REPLACE TABLE sales(item STRING, price INT);
    INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
    
    Copy
  3. python_avg ユーザー定義関数を呼び出します。

    SELECT python_avg(price) FROM sales;
    
    Copy
  4. Snowflakeシステム定義の SQL 関数、 AVG の出力と結果を比較し、結果が同じであることを確認します。

    SELECT avg(price) FROM sales;
    
    Copy
  5. 売上テーブルのアイテム型別の平均値でグループ化します。

    SELECT item, python_avg(price) FROM sales GROUP BY item;
    
    Copy

例: 一意な値のみを返します。

次の例のコードは配列を受け取り、一意な値のみを含む配列を返します。

CREATE OR REPLACE AGGREGATE FUNCTION pythonGetUniqueValues(input ARRAY)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonGetUniqueValues'
AS $$
class PythonGetUniqueValues:
    def __init__(self):
        self._agg_state = set()

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        self._agg_state.update(input)

    def merge(self, other_agg_state):
        self._agg_state.update(other_agg_state)

    def finish(self):
        return list(self._agg_state)
$$;


CREATE OR REPLACE TABLE array_table(x array) AS
SELECT ARRAY_CONSTRUCT(0, 1, 2, 3, 4, 'foo', 'bar', 'snowflake') UNION ALL
SELECT ARRAY_CONSTRUCT(1, 3, 5, 7, 9, 'foo', 'barbar', 'snowpark') UNION ALL
SELECT ARRAY_CONSTRUCT(0, 2, 4, 6, 8, 'snow');

SELECT * FROM array_table;

SELECT pythonGetUniqueValues(x) FROM array_table;
Copy

例: 文字列の数を返します。

以下の例のコードは、オブジェクト内のすべての文字列インスタンスの数を返します。

CREATE OR REPLACE AGGREGATE FUNCTION pythonMapCount(input STRING)
RETURNS OBJECT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonMapCount'
AS $$
from collections import defaultdict

class PythonMapCount:
    def __init__(self):
        self._agg_state = defaultdict(int)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        # Increment count of lowercase input
        self._agg_state[input.lower()] += 1

    def merge(self, other_agg_state):
        for item, count in other_agg_state.items():
            self._agg_state[item] += count

    def finish(self):
        return dict(self._agg_state)
$$;

CREATE OR REPLACE TABLE string_table(x STRING);
INSERT INTO string_table SELECT 'foo' FROM TABLE(GENERATOR(ROWCOUNT => 1000));
INSERT INTO string_table SELECT 'bar' FROM TABLE(GENERATOR(ROWCOUNT => 2000));
INSERT INTO string_table SELECT 'snowflake' FROM TABLE(GENERATOR(ROWCOUNT => 50));
INSERT INTO string_table SELECT 'snowpark' FROM TABLE(GENERATOR(ROWCOUNT => 123));
INSERT INTO string_table SELECT 'SnOw' FROM TABLE(GENERATOR(ROWCOUNT => 1));
INSERT INTO string_table SELECT 'snow' FROM TABLE(GENERATOR(ROWCOUNT => 4));

SELECT pythonMapCount(x) FROM string_table;
Copy

例: 最大値の上位k件を返します。

以下の例のコードは、 k の最大値の上位リストを返します。このコードは、最小ヒープに反転させた入力値を蓄積し、上位 k 件の最大値を返します。

CREATE OR REPLACE AGGREGATE FUNCTION pythonTopK(input INT, k INT)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonTopK'
AS $$
import heapq
from dataclasses import dataclass
import itertools
from typing import List

@dataclass
class AggState:
    minheap: List[int]
    k: int

class PythonTopK:
    def __init__(self):
        self._agg_state = AggState([], 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    @staticmethod
    def get_top_k_items(minheap, k):
      # Return k smallest elements if there are more than k elements on the min heap.
      if (len(minheap) > k):
        return [heapq.heappop(minheap) for i in range(k)]
      return minheap

    def accumulate(self, input, k):
        self._agg_state.k = k

        # Store the input as negative value, as heapq is a min heap.
        heapq.heappush(self._agg_state.minheap, -input)

        # Store only top k items on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def merge(self, other_agg_state):
        k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k

        # Merge two min heaps by popping off elements from one and pushing them onto another.
        while(len(other_agg_state.minheap) > 0):
            heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))

        # Store only k elements on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def finish(self):
        return [-x for x in self._agg_state.minheap]
$$;


CREATE OR REPLACE TABLE numbers_table(num_column INT);
INSERT INTO numbers_table SELECT 5 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 1 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 9 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 7 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 10 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 3 FROM TABLE(GENERATOR(ROWCOUNT => 10));

-- Return top 15 largest values from numbers_table.
SELECT pythonTopK(num_column, 15) FROM numbers_table;
Copy