Benutzerdefinierte Python-Aggregatfunktionen

Benutzerdefinierte Aggregatfunktionen (UDAFs) nehmen eine oder mehrere Zeilen als Eingabe entgegen und erzeugen eine einzelne Zeile als Ausgabe. Sie arbeiten mit Werten über mehrere Zeilen hinweg, um mathematische Berechnungen wie Summe, Durchschnitt, Zählung, Ermitteln von Minimum- oder Maximumwert, Standardabweichung und Schätzung sowie andere nicht mathematische Operationen auszuführen.

Python UDAFs bietet Ihnen die Möglichkeit, eigene Aggregatfunktionen zu schreiben, die den systemdefinierten Funktionen von Snowflake SQL Aggregatfunktionen ähnlich sind.

Sie können auch Ihre eigene UDAFs mit Snowpark-APIs erstellen, wie in Erstellen von benutzerdefinierten Aggregatfunktionen (UDAFs) für DataFrames in Python beschrieben.

Einschränkungen

  • aggregate_state hat eine maximale Größe von 8 MB in einer serialisierten Version, also versuchen Sie, die Größe des Aggregatzustands zu kontrollieren.

  • Sie können eine UDAF nicht als Fensterfunktion (d. h. mit einer OVER-Klausel) aufrufen.

  • IMMUTABLE wird bei einer Aggregatfunktion nicht unterstützt (wenn Sie den Parameter AGGREGATE verwenden). Daher sind alle Aggregatfunktionen standardmäßig VOLATILE.

  • Benutzerdefinierte Aggregatfunktionen können nicht in Verbindung mit der WITHIN GROUP-Klausel verwendet werden. Abfragen werden nicht ausgeführt.

Schnittstelle für Aggregatfunktionshandler

Eine Aggregatfunktion aggregiert die Zustände in untergeordneten Knoten. Anschließend werden diese aggregierten Zustände serialisiert und an den übergeordneten Knoten gesendet, wo sie zusammengeführt werden und das Endergebnis berechnet wird.

Um eine Aggregatfunktion zu definieren, müssen Sie eine Python-Klasse definieren (die der Handler der Funktion ist), die Methoden enthält, die Snowflake zur Laufzeit aufruft. Diese Methoden werden in der folgenden Tabelle beschrieben. Beispiele finden Sie an anderer Stelle unter diesem Thema.

Methode

Anforderung

Beschreibung

__init__

Erforderlich

Initialisiert den internen Status eines Aggregats.

aggregate_state

Erforderlich

Gibt den aktuellen Status eines Aggregats zurück.

  • Die Methode muss einen @property-Decorator haben.

  • Ein Aggregatstatusobjekt kann ein beliebiger Python-Datentyp sein, der von der Python Pickle-Bibliothek serialisiert werden kann.

  • Für einfache Aggregatstatus verwenden Sie einen primitiven Python-Datentyp. Für komplexere Aggregatstatus verwenden Sie Python-Datenklassen.

accumulate

Erforderlich

Akkumuliert den Status des Aggregats auf der Grundlage der neuen Eingabezeile.

merge

Erforderlich

Kombiniert zwei Aggregatzwischenstatus.

finish

Erforderlich

Erzeugt das Endergebnis auf der Grundlage des aggregierten Status.

Die Abbildung zeigt, wie Eingabewerte in untergeordneten Knoten akkumuliert werden und dann an einen übergeordneten Knoten gesendet und dort zusammengeführt werden, um ein Endergebnis zu erhalten.

Beispiel: Summe berechnen

Der Code im folgenden Beispiel definiert eine benutzerdefinierte Aggregatfunktion (UDAF) python_sum, die die Summe der numerischen Werte zurückgibt.

  1. Erstellen Sie den UDAF.

    CREATE OR REPLACE AGGREGATE FUNCTION PYTHON_SUM(a INT)
    RETURNS INT
    LANGUAGE PYTHON
    RUNTIME_VERSION = 3.9
    handler = 'PythonSum'
    AS $$
    class PythonSum:
      def __init__(self):
        # This aggregate state is a primitive Python data type.
        self._partial_sum = 0
    
      @property
      def aggregate_state(self):
        return self._partial_sum
    
      def accumulate(self, input_value):
        self._partial_sum += input_value
    
      def merge(self, other_partial_sum):
        self._partial_sum += other_partial_sum
    
      def finish(self):
        return self._partial_sum
    $$;
    
    Copy
  2. Erstellen Sie eine Tabelle mit Testdaten.

    CREATE OR REPLACE TABLE sales(item STRING, price INT);
    
    INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
    
    SELECT * FROM sales;
    
    Copy
  3. Rufen Sie die python_sum-UDAF auf.

    SELECT python_sum(price) FROM sales;
    
    Copy
  4. Vergleichen Sie die Ergebnisse mit der Ausgabe der systemdefinierten Snowflake SQL-Funktion SUM, und überzeugen Sie sich, dass das Ergebnis dasselbe ist.

    SELECT sum(col) FROM sales;
    
    Copy
  5. Gruppieren Sie die Summenwerte in der Verkaufstabelle nach dem Artikeltyp.

    SELECT item, python_sum(price) FROM sales GROUP BY item;
    
    Copy

Beispiel: Durchschnitt berechnen

Der Code im folgenden Beispiel definiert eine benutzerdefinierte Aggregatfunktion python_avg, die den Durchschnitt der numerischen Werte zurückgibt.

  1. Erstellen Sie die Funktion.

    CREATE OR REPLACE AGGREGATE FUNCTION python_avg(a INT)
    RETURNS FLOAT
    LANGUAGE PYTHON
    RUNTIME_VERSION = 3.9
    HANDLER = 'PythonAvg'
    AS $$
    from dataclasses import dataclass
    
    @dataclass
    class AvgAggState:
        count: int
        sum: int
    
    class PythonAvg:
        def __init__(self):
            # This aggregate state is an object data type.
            self._agg_state = AvgAggState(0, 0)
    
        @property
        def aggregate_state(self):
            return self._agg_state
    
        def accumulate(self, input_value):
            sum = self._agg_state.sum
            count = self._agg_state.count
    
            self._agg_state.sum = sum + input_value
            self._agg_state.count = count + 1
    
        def merge(self, other_agg_state):
            sum = self._agg_state.sum
            count = self._agg_state.count
    
            other_sum = other_agg_state.sum
            other_count = other_agg_state.count
    
            self._agg_state.sum = sum + other_sum
            self._agg_state.count = count + other_count
    
        def finish(self):
            sum = self._agg_state.sum
            count = self._agg_state.count
            return sum / count
    $$;
    
    Copy
  2. Erstellen Sie eine Tabelle mit Testdaten.

    CREATE OR REPLACE TABLE sales(item STRING, price INT);
    INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
    
    Copy
  3. Rufen Sie die benutzerdefinierte Funktion python_avg auf.

    SELECT python_avg(price) FROM sales;
    
    Copy
  4. Vergleichen Sie die Ergebnisse mit der Ausgabe der systemdefinierten Snowflake SQL-Funktion AVG, und überzeugen Sie sich, dass das Ergebnis dasselbe ist.

    SELECT avg(price) FROM sales;
    
    Copy
  5. Gruppieren Sie die Durchschnittswerte in der Verkaufstabelle nach dem Artikeltyp.

    SELECT item, python_avg(price) FROM sales GROUP BY item;
    
    Copy

Beispiel: Nur eindeutige Werte zurückgeben

Der Code im folgenden Beispiel nimmt ein Array und gibt ein Array zurück, das nur die eindeutigen (diskreten) Werte enthält.

CREATE OR REPLACE AGGREGATE FUNCTION pythonGetUniqueValues(input ARRAY)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonGetUniqueValues'
AS $$
class PythonGetUniqueValues:
    def __init__(self):
        self._agg_state = set()

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        self._agg_state.update(input)

    def merge(self, other_agg_state):
        self._agg_state.update(other_agg_state)

    def finish(self):
        return list(self._agg_state)
$$;


CREATE OR REPLACE TABLE array_table(x array) AS
SELECT ARRAY_CONSTRUCT(0, 1, 2, 3, 4, 'foo', 'bar', 'snowflake') UNION ALL
SELECT ARRAY_CONSTRUCT(1, 3, 5, 7, 9, 'foo', 'barbar', 'snowpark') UNION ALL
SELECT ARRAY_CONSTRUCT(0, 2, 4, 6, 8, 'snow');

SELECT * FROM array_table;

SELECT pythonGetUniqueValues(x) FROM array_table;
Copy

Beispiel: Zählung von Zeichenfolgen zurückgeben

Der Code im folgenden Beispiel gibt die Anzahl aller Instanzen von Zeichenfolgen in einem Objekt zurück.

CREATE OR REPLACE AGGREGATE FUNCTION pythonMapCount(input STRING)
RETURNS OBJECT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonMapCount'
AS $$
from collections import defaultdict

class PythonMapCount:
    def __init__(self):
        self._agg_state = defaultdict(int)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        # Increment count of lowercase input
        self._agg_state[input.lower()] += 1

    def merge(self, other_agg_state):
        for item, count in other_agg_state.items():
            self._agg_state[item] += count

    def finish(self):
        return dict(self._agg_state)
$$;

CREATE OR REPLACE TABLE string_table(x STRING);
INSERT INTO string_table SELECT 'foo' FROM TABLE(GENERATOR(ROWCOUNT => 1000));
INSERT INTO string_table SELECT 'bar' FROM TABLE(GENERATOR(ROWCOUNT => 2000));
INSERT INTO string_table SELECT 'snowflake' FROM TABLE(GENERATOR(ROWCOUNT => 50));
INSERT INTO string_table SELECT 'snowpark' FROM TABLE(GENERATOR(ROWCOUNT => 123));
INSERT INTO string_table SELECT 'SnOw' FROM TABLE(GENERATOR(ROWCOUNT => 1));
INSERT INTO string_table SELECT 'snow' FROM TABLE(GENERATOR(ROWCOUNT => 4));

SELECT pythonMapCount(x) FROM string_table;
Copy

Beispiel: Top-k-Werte zurückgeben

Der Code im folgenden Beispiel gibt eine Liste der größten Werte für k zurück. Der Code akkumuliert negierte Eingabewerte auf einem Min-Heap und gibt dann die obersten k größten Werte zurück.

CREATE OR REPLACE AGGREGATE FUNCTION pythonTopK(input INT, k INT)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonTopK'
AS $$
import heapq
from dataclasses import dataclass
import itertools
from typing import List

@dataclass
class AggState:
    minheap: List[int]
    k: int

class PythonTopK:
    def __init__(self):
        self._agg_state = AggState([], 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    @staticmethod
    def get_top_k_items(minheap, k):
      # Return k smallest elements if there are more than k elements on the min heap.
      if (len(minheap) > k):
        return [heapq.heappop(minheap) for i in range(k)]
      return minheap

    def accumulate(self, input, k):
        self._agg_state.k = k

        # Store the input as negative value, as heapq is a min heap.
        heapq.heappush(self._agg_state.minheap, -input)

        # Store only top k items on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def merge(self, other_agg_state):
        k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k

        # Merge two min heaps by popping off elements from one and pushing them onto another.
        while(len(other_agg_state.minheap) > 0):
            heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))

        # Store only k elements on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def finish(self):
        return [-x for x in self._agg_state.minheap]
$$;


CREATE OR REPLACE TABLE numbers_table(num_column INT);
INSERT INTO numbers_table SELECT 5 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 1 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 9 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 7 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 10 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 3 FROM TABLE(GENERATOR(ROWCOUNT => 10));

-- Return top 15 largest values from numbers_table.
SELECT pythonTopK(num_column, 15) FROM numbers_table;
Copy