Pythonユーザー定義集計関数¶
ユーザー定義集計関数(UDAFs)は1つ以上の行を入力として受け取り、1行の出力を生成します。複数行全体の値を操作して、合計、平均、カウント、最小値/最大値の探索、標準偏差、推定などの数学的計算に加え、一部の非数学的な演算も実行します。
Python UDAFs は、Snowflakeのシステム定義の SQL 独自の集計関数 に似たものを作成する方法を提供します。
また、 Pythonにおける DataFrames 用ユーザー定義関数(UDAFs)の作成 で説明されているように、Snowpark APIs を使用して独自の UDAFs を作成することもできます。
制限事項¶
aggregate_state
のシリアル化されたバージョンの最大サイズは8 MB であるため、集計状態のサイズを制御してみてください。UDAF を ウィンドウ関数 として呼び出すことはできません(言い換えれば、OVER 句を使用して)。
IMMUTABLE は集計関数ではサポートされていません(AGGREGATE パラメーターを使用する場合)。したがって、すべての集計関数はデフォルトで VOLATILE です。
ユーザー定義集約関数は、 WITHIN GROUP 句と共に使用することはできません。クエリの実行に失敗します。
集計関数ハンドラーのインターフェイス¶
集計関数は子ノードの状態を集計し、最終的にそれらの集計された状態はシリアル化され、マージされて最終的な結果が計算される親ノードに送信されます。
集計関数を定義するには、実行時にSnowflakeが呼び出すメソッドを含むPythonクラス(関数のハンドラー)を定義する必要があります。これらのメソッドは、以下の表に説明されています。このトピックの他の例をご参照ください。
メソッド |
要件 |
説明 |
---|---|---|
|
必須 |
集計の内部状態を初期化します。 |
|
必須 |
集計の現在の状態を返します。
|
|
必須 |
新しい入力行に基づいて集計状態を累積します。 |
|
必須 |
2つの中間集計状態を組み合わせます。 |
|
必須 |
集計状態に基づいて最終結果を生成します。 |

例: 合計の計算¶
次の例のコードでは、 python_sum
ユーザー定義集計関数(UDAF)を定義し、数値の合計を返します。
UDAFを作成します。
CREATE OR REPLACE AGGREGATE FUNCTION PYTHON_SUM(a INT) RETURNS INT LANGUAGE PYTHON RUNTIME_VERSION = 3.9 handler = 'PythonSum' AS $$ class PythonSum: def __init__(self): # This aggregate state is a primitive Python data type. self._partial_sum = 0 @property def aggregate_state(self): return self._partial_sum def accumulate(self, input_value): self._partial_sum += input_value def merge(self, other_partial_sum): self._partial_sum += other_partial_sum def finish(self): return self._partial_sum $$;
テストデータのテーブルを作成します。
CREATE OR REPLACE TABLE sales(item STRING, price INT); INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000); SELECT * FROM sales;
python_sum
UDAF を呼び出します。SELECT python_sum(price) FROM sales;
Snowflakeシステム定義の SQL 関数、 SUM の出力と結果を比較し、結果が同じであることを確認します。
SELECT sum(col) FROM sales;
売上テーブルのアイテム型別の合計値でグループ化します。
SELECT item, python_sum(price) FROM sales GROUP BY item;
例: 平均値の計算¶
次の例のコードは、数値の平均を返す python_avg
ユーザー定義集計関数を定義します。
関数を作成します。
CREATE OR REPLACE AGGREGATE FUNCTION python_avg(a INT) RETURNS FLOAT LANGUAGE PYTHON RUNTIME_VERSION = 3.9 HANDLER = 'PythonAvg' AS $$ from dataclasses import dataclass @dataclass class AvgAggState: count: int sum: int class PythonAvg: def __init__(self): # This aggregate state is an object data type. self._agg_state = AvgAggState(0, 0) @property def aggregate_state(self): return self._agg_state def accumulate(self, input_value): sum = self._agg_state.sum count = self._agg_state.count self._agg_state.sum = sum + input_value self._agg_state.count = count + 1 def merge(self, other_agg_state): sum = self._agg_state.sum count = self._agg_state.count other_sum = other_agg_state.sum other_count = other_agg_state.count self._agg_state.sum = sum + other_sum self._agg_state.count = count + other_count def finish(self): sum = self._agg_state.sum count = self._agg_state.count return sum / count $$;
テストデータのテーブルを作成します。
CREATE OR REPLACE TABLE sales(item STRING, price INT); INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
python_avg
ユーザー定義関数を呼び出します。SELECT python_avg(price) FROM sales;
Snowflakeシステム定義の SQL 関数、 AVG の出力と結果を比較し、結果が同じであることを確認します。
SELECT avg(price) FROM sales;
売上テーブルのアイテム型別の平均値でグループ化します。
SELECT item, python_avg(price) FROM sales GROUP BY item;
例: 一意な値のみを返します。¶
次の例のコードは配列を受け取り、一意な値のみを含む配列を返します。
CREATE OR REPLACE AGGREGATE FUNCTION pythonGetUniqueValues(input ARRAY)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonGetUniqueValues'
AS $$
class PythonGetUniqueValues:
def __init__(self):
self._agg_state = set()
@property
def aggregate_state(self):
return self._agg_state
def accumulate(self, input):
self._agg_state.update(input)
def merge(self, other_agg_state):
self._agg_state.update(other_agg_state)
def finish(self):
return list(self._agg_state)
$$;
CREATE OR REPLACE TABLE array_table(x array) AS
SELECT ARRAY_CONSTRUCT(0, 1, 2, 3, 4, 'foo', 'bar', 'snowflake') UNION ALL
SELECT ARRAY_CONSTRUCT(1, 3, 5, 7, 9, 'foo', 'barbar', 'snowpark') UNION ALL
SELECT ARRAY_CONSTRUCT(0, 2, 4, 6, 8, 'snow');
SELECT * FROM array_table;
SELECT pythonGetUniqueValues(x) FROM array_table;
例: 文字列の数を返します。¶
以下の例のコードは、オブジェクト内のすべての文字列インスタンスの数を返します。
CREATE OR REPLACE AGGREGATE FUNCTION pythonMapCount(input STRING)
RETURNS OBJECT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonMapCount'
AS $$
from collections import defaultdict
class PythonMapCount:
def __init__(self):
self._agg_state = defaultdict(int)
@property
def aggregate_state(self):
return self._agg_state
def accumulate(self, input):
# Increment count of lowercase input
self._agg_state[input.lower()] += 1
def merge(self, other_agg_state):
for item, count in other_agg_state.items():
self._agg_state[item] += count
def finish(self):
return dict(self._agg_state)
$$;
CREATE OR REPLACE TABLE string_table(x STRING);
INSERT INTO string_table SELECT 'foo' FROM TABLE(GENERATOR(ROWCOUNT => 1000));
INSERT INTO string_table SELECT 'bar' FROM TABLE(GENERATOR(ROWCOUNT => 2000));
INSERT INTO string_table SELECT 'snowflake' FROM TABLE(GENERATOR(ROWCOUNT => 50));
INSERT INTO string_table SELECT 'snowpark' FROM TABLE(GENERATOR(ROWCOUNT => 123));
INSERT INTO string_table SELECT 'SnOw' FROM TABLE(GENERATOR(ROWCOUNT => 1));
INSERT INTO string_table SELECT 'snow' FROM TABLE(GENERATOR(ROWCOUNT => 4));
SELECT pythonMapCount(x) FROM string_table;
例: 最大値の上位k件を返します。¶
以下の例のコードは、 k
の最大値の上位リストを返します。このコードは、最小ヒープに反転させた入力値を蓄積し、上位 k
件の最大値を返します。
CREATE OR REPLACE AGGREGATE FUNCTION pythonTopK(input INT, k INT)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonTopK'
AS $$
import heapq
from dataclasses import dataclass
import itertools
from typing import List
@dataclass
class AggState:
minheap: List[int]
k: int
class PythonTopK:
def __init__(self):
self._agg_state = AggState([], 0)
@property
def aggregate_state(self):
return self._agg_state
@staticmethod
def get_top_k_items(minheap, k):
# Return k smallest elements if there are more than k elements on the min heap.
if (len(minheap) > k):
return [heapq.heappop(minheap) for i in range(k)]
return minheap
def accumulate(self, input, k):
self._agg_state.k = k
# Store the input as negative value, as heapq is a min heap.
heapq.heappush(self._agg_state.minheap, -input)
# Store only top k items on the min heap.
self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)
def merge(self, other_agg_state):
k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k
# Merge two min heaps by popping off elements from one and pushing them onto another.
while(len(other_agg_state.minheap) > 0):
heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))
# Store only k elements on the min heap.
self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)
def finish(self):
return [-x for x in self._agg_state.minheap]
$$;
CREATE OR REPLACE TABLE numbers_table(num_column INT);
INSERT INTO numbers_table SELECT 5 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 1 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 9 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 7 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 10 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 3 FROM TABLE(GENERATOR(ROWCOUNT => 10));
-- Return top 15 largest values from numbers_table.
SELECT pythonTopK(num_column, 15) FROM numbers_table;