Funções de agregação definidas pelo usuário em Python

Funções agregadas definidas pelo usuário (UDAFs) usam uma ou mais linhas como entrada e produzem uma única linha de saída. Elas operam em valores de várias linhas para realizar cálculos matemáticos como soma, média, contagem, valores mínimos/máximos, desvio padrão e estimativa, assim como algumas operações não matemáticas.

As UDAFs Python fornecem uma maneira para você escrever suas próprias funções de agregação de são semelhantes às funções de agregação SQL definidas pelo sistema Snowflake.

Também é possível criar suas próprias UDAFs usando as APIs Snowpark, conforme descrito em Criação de funções agregadas definidas pelo usuário (UDAFs) para DataFrames em Python.

Limitações

  • O aggregate_state tem um tamanho máximo de 8 MB em uma versão serializada, então tente controlar o tamanho do estado agregado.

  • Você não pode chamar uma UDAF como uma função de janela (em outras palavras, com uma cláusula OVER).

  • IMMUTABLE não é suportado em uma função agregada (quando você usa o parâmetro AGGREGATE). Portanto, todas as funções agregadas são VOLATILE por padrão.

  • Funções de agregação definidas pelo usuário não podem ser usadas em conjunto com a cláusula WITHIN GROUP. As consultas não serão executadas.

Interface para manipulador de função agregada

Uma função de agregação agrega estados em nós filhos e, então, eventualmente, esses estados agregados são serializados e enviados ao nó pai, onde são mesclados e o resultado final é calculado.

Para definir uma função de agregação, você deve definir uma classe Python (que é o manipulador da função) que inclua métodos que o Snowflake invoca em tempo de execução. Esses métodos são descritos na tabela abaixo. Veja exemplos em outras partes deste tópico.

Método

Requisito

Descrição

__init__

Obrigatório

Inicializa o estado interno de um agregado.

aggregate_state

Obrigatório

Retorna o estado atual de um agregado.

  • O método deve ter um decorador @propriedade.

  • Um objeto de estado agregado pode ser qualquer tipo de dados Python serializável pela biblioteca Python Pickle.

  • Para estados agregados simples, use um tipo de dados primitivo do Python. Para estados agregados mais complexos, use classes de dados Python.

accumulate

Obrigatório

Acumula o estado do agregado com base na nova linha de entrada.

merge

Obrigatório

Combina dois estados agregados intermediários.

finish

Obrigatório

Produz o resultado final com base no estado agregado.

Diagrama mostrando valores de entrada sendo acumulados em nós filhos e depois enviados para um nó pai e mesclados para produzir um resultado final.

Exemplo: cálculo de uma soma

O código no exemplo a seguir define uma função de agregação python_sum definida pelo usuário (UDAF) para retornar a soma dos valores numéricos.

  1. Crie a UDAF.

    CREATE OR REPLACE AGGREGATE FUNCTION PYTHON_SUM(a INT)
    RETURNS INT
    LANGUAGE PYTHON
    RUNTIME_VERSION = 3.9
    handler = 'PythonSum'
    AS $$
    class PythonSum:
      def __init__(self):
        # This aggregate state is a primitive Python data type.
        self._partial_sum = 0
    
      @property
      def aggregate_state(self):
        return self._partial_sum
    
      def accumulate(self, input_value):
        self._partial_sum += input_value
    
      def merge(self, other_partial_sum):
        self._partial_sum += other_partial_sum
    
      def finish(self):
        return self._partial_sum
    $$;
    
    Copy
  2. Crie uma tabela de dados de teste.

    CREATE OR REPLACE TABLE sales(item STRING, price INT);
    
    INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
    
    SELECT * FROM sales;
    
    Copy
  3. Chame a python_sum UDAF

    SELECT python_sum(price) FROM sales;
    
    Copy
  4. Compare os resultados com a saída da função SQL definida pelo usuário do Snowflake, SUM e veja que o resultado é o mesmo.

    SELECT sum(col) FROM sales;
    
    Copy
  5. Agrupe por soma os valores por tipo de item na tabela de vendas.

    SELECT item, python_sum(price) FROM sales GROUP BY item;
    
    Copy

Exemplo: cálculo de uma média

O código no exemplo a seguir define uma função de agregação python_avg definida pelo usuário para retornar a média dos valores numéricos.

  1. Crie a função.

    CREATE OR REPLACE AGGREGATE FUNCTION python_avg(a INT)
    RETURNS FLOAT
    LANGUAGE PYTHON
    RUNTIME_VERSION = 3.9
    HANDLER = 'PythonAvg'
    AS $$
    from dataclasses import dataclass
    
    @dataclass
    class AvgAggState:
        count: int
        sum: int
    
    class PythonAvg:
        def __init__(self):
            # This aggregate state is an object data type.
            self._agg_state = AvgAggState(0, 0)
    
        @property
        def aggregate_state(self):
            return self._agg_state
    
        def accumulate(self, input_value):
            sum = self._agg_state.sum
            count = self._agg_state.count
    
            self._agg_state.sum = sum + input_value
            self._agg_state.count = count + 1
    
        def merge(self, other_agg_state):
            sum = self._agg_state.sum
            count = self._agg_state.count
    
            other_sum = other_agg_state.sum
            other_count = other_agg_state.count
    
            self._agg_state.sum = sum + other_sum
            self._agg_state.count = count + other_count
    
        def finish(self):
            sum = self._agg_state.sum
            count = self._agg_state.count
            return sum / count
    $$;
    
    Copy
  2. Crie uma tabela de dados de teste.

    CREATE OR REPLACE TABLE sales(item STRING, price INT);
    INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
    
    Copy
  3. Chame a função definida pelo usuário python_avg.

    SELECT python_avg(price) FROM sales;
    
    Copy
  4. Compare os resultados com a saída da função SQL definida pelo usuário do Snowflake, AVG e veja que o resultado é o mesmo.

    SELECT avg(price) FROM sales;
    
    Copy
  5. Agrupe valores médios por tipo de item na tabela de vendas.

    SELECT item, python_avg(price) FROM sales GROUP BY item;
    
    Copy

Exemplo: retorno apenas de valores únicos

O código no exemplo a seguir pega uma matriz e retorna uma matriz contendo apenas os valores exclusivos.

CREATE OR REPLACE AGGREGATE FUNCTION pythonGetUniqueValues(input ARRAY)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonGetUniqueValues'
AS $$
class PythonGetUniqueValues:
    def __init__(self):
        self._agg_state = set()

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        self._agg_state.update(input)

    def merge(self, other_agg_state):
        self._agg_state.update(other_agg_state)

    def finish(self):
        return list(self._agg_state)
$$;


CREATE OR REPLACE TABLE array_table(x array) AS
SELECT ARRAY_CONSTRUCT(0, 1, 2, 3, 4, 'foo', 'bar', 'snowflake') UNION ALL
SELECT ARRAY_CONSTRUCT(1, 3, 5, 7, 9, 'foo', 'barbar', 'snowpark') UNION ALL
SELECT ARRAY_CONSTRUCT(0, 2, 4, 6, 8, 'snow');

SELECT * FROM array_table;

SELECT pythonGetUniqueValues(x) FROM array_table;
Copy

Exemplo: retorna uma contagem de cadeias de caracteres

O código no exemplo a seguir retorna contagens de todas as instâncias de cadeias de caracteres em um objeto.

CREATE OR REPLACE AGGREGATE FUNCTION pythonMapCount(input STRING)
RETURNS OBJECT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonMapCount'
AS $$
from collections import defaultdict

class PythonMapCount:
    def __init__(self):
        self._agg_state = defaultdict(int)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        # Increment count of lowercase input
        self._agg_state[input.lower()] += 1

    def merge(self, other_agg_state):
        for item, count in other_agg_state.items():
            self._agg_state[item] += count

    def finish(self):
        return dict(self._agg_state)
$$;

CREATE OR REPLACE TABLE string_table(x STRING);
INSERT INTO string_table SELECT 'foo' FROM TABLE(GENERATOR(ROWCOUNT => 1000));
INSERT INTO string_table SELECT 'bar' FROM TABLE(GENERATOR(ROWCOUNT => 2000));
INSERT INTO string_table SELECT 'snowflake' FROM TABLE(GENERATOR(ROWCOUNT => 50));
INSERT INTO string_table SELECT 'snowpark' FROM TABLE(GENERATOR(ROWCOUNT => 123));
INSERT INTO string_table SELECT 'SnOw' FROM TABLE(GENERATOR(ROWCOUNT => 1));
INSERT INTO string_table SELECT 'snow' FROM TABLE(GENERATOR(ROWCOUNT => 4));

SELECT pythonMapCount(x) FROM string_table;
Copy

Exemplo: retorno dos maiores valores k

O código no exemplo a seguir retorna uma lista dos maiores valores para k. O código acumula valores de entrada negados em um heap mínimo e, em seguida, retorna os primeiros valores maiores k.

CREATE OR REPLACE AGGREGATE FUNCTION pythonTopK(input INT, k INT)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonTopK'
AS $$
import heapq
from dataclasses import dataclass
import itertools
from typing import List

@dataclass
class AggState:
    minheap: List[int]
    k: int

class PythonTopK:
    def __init__(self):
        self._agg_state = AggState([], 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    @staticmethod
    def get_top_k_items(minheap, k):
      # Return k smallest elements if there are more than k elements on the min heap.
      if (len(minheap) > k):
        return [heapq.heappop(minheap) for i in range(k)]
      return minheap

    def accumulate(self, input, k):
        self._agg_state.k = k

        # Store the input as negative value, as heapq is a min heap.
        heapq.heappush(self._agg_state.minheap, -input)

        # Store only top k items on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def merge(self, other_agg_state):
        k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k

        # Merge two min heaps by popping off elements from one and pushing them onto another.
        while(len(other_agg_state.minheap) > 0):
            heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))

        # Store only k elements on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def finish(self):
        return [-x for x in self._agg_state.minheap]
$$;


CREATE OR REPLACE TABLE numbers_table(num_column INT);
INSERT INTO numbers_table SELECT 5 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 1 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 9 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 7 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 10 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 3 FROM TABLE(GENERATOR(ROWCOUNT => 10));

-- Return top 15 largest values from numbers_table.
SELECT pythonTopK(num_column, 15) FROM numbers_table;
Copy