Benutzerdefinierte Python-Aggregatfunktionen¶
Benutzerdefinierte Aggregatfunktionen (UDAFs) nehmen eine oder mehrere Zeilen als Eingabe entgegen und erzeugen eine einzelne Zeile als Ausgabe. Sie arbeiten mit Werten über mehrere Zeilen hinweg, um mathematische Berechnungen wie Summe, Durchschnitt, Zählung, Ermitteln von Minimum- oder Maximumwert, Standardabweichung und Schätzung sowie andere nicht mathematische Operationen auszuführen.
Python UDAFs bietet Ihnen die Möglichkeit, eigene Aggregatfunktionen zu schreiben, die den systemdefinierten Funktionen von Snowflake SQL Aggregatfunktionen ähnlich sind.
Sie können auch Ihre eigene UDAFs mit Snowpark-APIs erstellen, wie in Erstellen von benutzerdefinierten Aggregatfunktionen (UDAFs) für DataFrames in Python beschrieben.
Einschränkungen¶
aggregate_state
hat eine maximale Größe von 8 MB in einer serialisierten Version, also versuchen Sie, die Größe des Aggregatzustands zu kontrollieren.Sie können eine UDAF nicht als Fensterfunktion (d. h. mit einer OVER-Klausel) aufrufen.
IMMUTABLE wird bei einer Aggregatfunktion nicht unterstützt (wenn Sie den Parameter AGGREGATE verwenden). Daher sind alle Aggregatfunktionen standardmäßig VOLATILE.
Benutzerdefinierte Aggregatfunktionen können nicht in Verbindung mit der WITHIN GROUP-Klausel verwendet werden. Abfragen werden nicht ausgeführt.
Schnittstelle für Aggregatfunktionshandler¶
Eine Aggregatfunktion aggregiert die Zustände in untergeordneten Knoten. Anschließend werden diese aggregierten Zustände serialisiert und an den übergeordneten Knoten gesendet, wo sie zusammengeführt werden und das Endergebnis berechnet wird.
Um eine Aggregatfunktion zu definieren, müssen Sie eine Python-Klasse definieren (die der Handler der Funktion ist), die Methoden enthält, die Snowflake zur Laufzeit aufruft. Diese Methoden werden in der folgenden Tabelle beschrieben. Beispiele finden Sie an anderer Stelle unter diesem Thema.
Methode |
Anforderung |
Beschreibung |
---|---|---|
|
Erforderlich |
Initialisiert den internen Status eines Aggregats. |
|
Erforderlich |
Gibt den aktuellen Status eines Aggregats zurück.
|
|
Erforderlich |
Akkumuliert den Status des Aggregats auf der Grundlage der neuen Eingabezeile. |
|
Erforderlich |
Kombiniert zwei Aggregatzwischenstatus. |
|
Erforderlich |
Erzeugt das Endergebnis auf der Grundlage des aggregierten Status. |

Beispiel: Summe berechnen¶
Der Code im folgenden Beispiel definiert eine benutzerdefinierte Aggregatfunktion (UDAF) python_sum
, die die Summe der numerischen Werte zurückgibt.
Erstellen Sie den UDAF.
CREATE OR REPLACE AGGREGATE FUNCTION PYTHON_SUM(a INT) RETURNS INT LANGUAGE PYTHON RUNTIME_VERSION = 3.9 handler = 'PythonSum' AS $$ class PythonSum: def __init__(self): # This aggregate state is a primitive Python data type. self._partial_sum = 0 @property def aggregate_state(self): return self._partial_sum def accumulate(self, input_value): self._partial_sum += input_value def merge(self, other_partial_sum): self._partial_sum += other_partial_sum def finish(self): return self._partial_sum $$;
Erstellen Sie eine Tabelle mit Testdaten.
CREATE OR REPLACE TABLE sales(item STRING, price INT); INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000); SELECT * FROM sales;
Rufen Sie die
python_sum
-UDAF auf.SELECT python_sum(price) FROM sales;
Vergleichen Sie die Ergebnisse mit der Ausgabe der systemdefinierten Snowflake SQL-Funktion SUM, und überzeugen Sie sich, dass das Ergebnis dasselbe ist.
SELECT sum(col) FROM sales;
Gruppieren Sie die Summenwerte in der Verkaufstabelle nach dem Artikeltyp.
SELECT item, python_sum(price) FROM sales GROUP BY item;
Beispiel: Durchschnitt berechnen¶
Der Code im folgenden Beispiel definiert eine benutzerdefinierte Aggregatfunktion python_avg
, die den Durchschnitt der numerischen Werte zurückgibt.
Erstellen Sie die Funktion.
CREATE OR REPLACE AGGREGATE FUNCTION python_avg(a INT) RETURNS FLOAT LANGUAGE PYTHON RUNTIME_VERSION = 3.9 HANDLER = 'PythonAvg' AS $$ from dataclasses import dataclass @dataclass class AvgAggState: count: int sum: int class PythonAvg: def __init__(self): # This aggregate state is an object data type. self._agg_state = AvgAggState(0, 0) @property def aggregate_state(self): return self._agg_state def accumulate(self, input_value): sum = self._agg_state.sum count = self._agg_state.count self._agg_state.sum = sum + input_value self._agg_state.count = count + 1 def merge(self, other_agg_state): sum = self._agg_state.sum count = self._agg_state.count other_sum = other_agg_state.sum other_count = other_agg_state.count self._agg_state.sum = sum + other_sum self._agg_state.count = count + other_count def finish(self): sum = self._agg_state.sum count = self._agg_state.count return sum / count $$;
Erstellen Sie eine Tabelle mit Testdaten.
CREATE OR REPLACE TABLE sales(item STRING, price INT); INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
Rufen Sie die benutzerdefinierte Funktion
python_avg
auf.SELECT python_avg(price) FROM sales;
Vergleichen Sie die Ergebnisse mit der Ausgabe der systemdefinierten Snowflake SQL-Funktion AVG, und überzeugen Sie sich, dass das Ergebnis dasselbe ist.
SELECT avg(price) FROM sales;
Gruppieren Sie die Durchschnittswerte in der Verkaufstabelle nach dem Artikeltyp.
SELECT item, python_avg(price) FROM sales GROUP BY item;
Beispiel: Nur eindeutige Werte zurückgeben¶
Der Code im folgenden Beispiel nimmt ein Array und gibt ein Array zurück, das nur die eindeutigen (diskreten) Werte enthält.
CREATE OR REPLACE AGGREGATE FUNCTION pythonGetUniqueValues(input ARRAY)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonGetUniqueValues'
AS $$
class PythonGetUniqueValues:
def __init__(self):
self._agg_state = set()
@property
def aggregate_state(self):
return self._agg_state
def accumulate(self, input):
self._agg_state.update(input)
def merge(self, other_agg_state):
self._agg_state.update(other_agg_state)
def finish(self):
return list(self._agg_state)
$$;
CREATE OR REPLACE TABLE array_table(x array) AS
SELECT ARRAY_CONSTRUCT(0, 1, 2, 3, 4, 'foo', 'bar', 'snowflake') UNION ALL
SELECT ARRAY_CONSTRUCT(1, 3, 5, 7, 9, 'foo', 'barbar', 'snowpark') UNION ALL
SELECT ARRAY_CONSTRUCT(0, 2, 4, 6, 8, 'snow');
SELECT * FROM array_table;
SELECT pythonGetUniqueValues(x) FROM array_table;
Beispiel: Zählung von Zeichenfolgen zurückgeben¶
Der Code im folgenden Beispiel gibt die Anzahl aller Instanzen von Zeichenfolgen in einem Objekt zurück.
CREATE OR REPLACE AGGREGATE FUNCTION pythonMapCount(input STRING)
RETURNS OBJECT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonMapCount'
AS $$
from collections import defaultdict
class PythonMapCount:
def __init__(self):
self._agg_state = defaultdict(int)
@property
def aggregate_state(self):
return self._agg_state
def accumulate(self, input):
# Increment count of lowercase input
self._agg_state[input.lower()] += 1
def merge(self, other_agg_state):
for item, count in other_agg_state.items():
self._agg_state[item] += count
def finish(self):
return dict(self._agg_state)
$$;
CREATE OR REPLACE TABLE string_table(x STRING);
INSERT INTO string_table SELECT 'foo' FROM TABLE(GENERATOR(ROWCOUNT => 1000));
INSERT INTO string_table SELECT 'bar' FROM TABLE(GENERATOR(ROWCOUNT => 2000));
INSERT INTO string_table SELECT 'snowflake' FROM TABLE(GENERATOR(ROWCOUNT => 50));
INSERT INTO string_table SELECT 'snowpark' FROM TABLE(GENERATOR(ROWCOUNT => 123));
INSERT INTO string_table SELECT 'SnOw' FROM TABLE(GENERATOR(ROWCOUNT => 1));
INSERT INTO string_table SELECT 'snow' FROM TABLE(GENERATOR(ROWCOUNT => 4));
SELECT pythonMapCount(x) FROM string_table;
Beispiel: Top-k-Werte zurückgeben¶
Der Code im folgenden Beispiel gibt eine Liste der größten Werte für k
zurück. Der Code akkumuliert negierte Eingabewerte auf einem Min-Heap und gibt dann die obersten k
größten Werte zurück.
CREATE OR REPLACE AGGREGATE FUNCTION pythonTopK(input INT, k INT)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonTopK'
AS $$
import heapq
from dataclasses import dataclass
import itertools
from typing import List
@dataclass
class AggState:
minheap: List[int]
k: int
class PythonTopK:
def __init__(self):
self._agg_state = AggState([], 0)
@property
def aggregate_state(self):
return self._agg_state
@staticmethod
def get_top_k_items(minheap, k):
# Return k smallest elements if there are more than k elements on the min heap.
if (len(minheap) > k):
return [heapq.heappop(minheap) for i in range(k)]
return minheap
def accumulate(self, input, k):
self._agg_state.k = k
# Store the input as negative value, as heapq is a min heap.
heapq.heappush(self._agg_state.minheap, -input)
# Store only top k items on the min heap.
self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)
def merge(self, other_agg_state):
k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k
# Merge two min heaps by popping off elements from one and pushing them onto another.
while(len(other_agg_state.minheap) > 0):
heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))
# Store only k elements on the min heap.
self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)
def finish(self):
return [-x for x in self._agg_state.minheap]
$$;
CREATE OR REPLACE TABLE numbers_table(num_column INT);
INSERT INTO numbers_table SELECT 5 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 1 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 9 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 7 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 10 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 3 FROM TABLE(GENERATOR(ROWCOUNT => 10));
-- Return top 15 largest values from numbers_table.
SELECT pythonTopK(num_column, 15) FROM numbers_table;