Fonctions agrégées définies par l’utilisateur Python¶
Les fonctions agrégées définies par l’utilisateur (UDAFs) prennent une ou plusieurs lignes en entrée et produisent une seule ligne en sortie. Elles agissent sur les valeurs de lignes pour effectuer des calculs mathématiques tels que la somme, la moyenne, le comptage, les valeurs minimale/maximale, l’écart type et l’estimation, ainsi que d’autres opérations non mathématiques.
Les UDAFs Python vous permettent d’écrire vos propres fonctions d’agrégation qui sont similaires aux fonctions d’agrégation SQL définies par le système Snowflake.
Vous pouvez également créer vos propres UDAFs en utilisant des APIs Snowpark comme décrit dans Création de fonctions définies par l’utilisateur (UDAFs) pour DataFrames dans Python.
Limitations¶
aggregate_state
a une taille maximale de 8 MB dans une version sérialisée, essayez donc de contrôler la taille de l’état agrégé.Vous ne pouvez pas appeler une UDAF en tant que fonction de fenêtre (en d’autres termes, avec une clause OVER).
IMMUTABLE n’est pas pris en charge sur une fonction d’agrégation (lorsque vous utilisez le paramètre AGGREGATE). Par conséquent, toutes les fonctions d’agrégation sont VOLATILE par défaut.
Les fonctions d’agrégation définies par l’utilisateur ne peuvent pas être utilisées conjointement avec la clause WITHIN GROUP. Les requêtes ne pourront pas être exécutées.
Interface pour le gestionnaire (handler) de la fonction d’agrégation¶
Une fonction d’agrégation regroupe les états des nœuds enfants, puis ces états agrégés sont sérialisés et envoyés au nœud parent où ils sont fusionnés et où le résultat final est calculé.
Pour définir une fonction agrégée, vous devez définir une classe Python (qui est le gestionnaire (handler) de la fonction) qui comprend des méthodes que Snowflake appelle au moment de l’exécution. Ces méthodes sont décrites dans le tableau ci-dessous. Voir les exemples ailleurs dans cette rubrique.
Méthode |
Exigence |
Description |
---|---|---|
|
Obligatoire |
Initialise l’état interne d’un agrégat. |
|
Obligatoire |
Renvoie l’état actuel d’un agrégat.
|
|
Obligatoire |
Accumule l’état de l’agrégat sur la base de la nouvelle ligne d’entrée. |
|
Obligatoire |
Combine deux états agrégés intermédiaires. |
|
Obligatoire |
Produit le résultat final sur la base de l’état agrégé. |

Exemple : calculer une somme¶
Le code de l’exemple suivant définit une fonction d’agrégation définie par l’utilisateur python_sum
(UDAF) pour renvoyer la somme des valeurs numériques.
Créez l’UDAF.
CREATE OR REPLACE AGGREGATE FUNCTION PYTHON_SUM(a INT) RETURNS INT LANGUAGE PYTHON RUNTIME_VERSION = 3.9 handler = 'PythonSum' AS $$ class PythonSum: def __init__(self): # This aggregate state is a primitive Python data type. self._partial_sum = 0 @property def aggregate_state(self): return self._partial_sum def accumulate(self, input_value): self._partial_sum += input_value def merge(self, other_partial_sum): self._partial_sum += other_partial_sum def finish(self): return self._partial_sum $$;
Créez une table de données de test.
CREATE OR REPLACE TABLE sales(item STRING, price INT); INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000); SELECT * FROM sales;
Appelez l”
python_sum
UDAF.SELECT python_sum(price) FROM sales;
Comparez les résultats avec la sortie de la fonction SQL définie par le système Snowflake, SUM, et constatez que le résultat est le même.
SELECT sum(col) FROM sales;
Regroupez par des valeurs de somme par type d’article dans le tableau des ventes.
SELECT item, python_sum(price) FROM sales GROUP BY item;
Exemple : calculer une moyenne¶
Le code de l’exemple suivant définit une fonction d’agrégation définie par l’utilisateur python_avg
pour renvoyer la moyenne des valeurs numériques.
Créez la fonction.
CREATE OR REPLACE AGGREGATE FUNCTION python_avg(a INT) RETURNS FLOAT LANGUAGE PYTHON RUNTIME_VERSION = 3.9 HANDLER = 'PythonAvg' AS $$ from dataclasses import dataclass @dataclass class AvgAggState: count: int sum: int class PythonAvg: def __init__(self): # This aggregate state is an object data type. self._agg_state = AvgAggState(0, 0) @property def aggregate_state(self): return self._agg_state def accumulate(self, input_value): sum = self._agg_state.sum count = self._agg_state.count self._agg_state.sum = sum + input_value self._agg_state.count = count + 1 def merge(self, other_agg_state): sum = self._agg_state.sum count = self._agg_state.count other_sum = other_agg_state.sum other_count = other_agg_state.count self._agg_state.sum = sum + other_sum self._agg_state.count = count + other_count def finish(self): sum = self._agg_state.sum count = self._agg_state.count return sum / count $$;
Créez une table de données de test.
CREATE OR REPLACE TABLE sales(item STRING, price INT); INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
Appelez la fonction définie par l’utilisateur
python_avg
.SELECT python_avg(price) FROM sales;
Comparez les résultats avec la sortie de la fonction SQL définie par le système Snowflake, AVG, et constatez que le résultat est le même.
SELECT avg(price) FROM sales;
Regroupez les valeurs moyennes par type d’article dans le tableau des ventes.
SELECT item, python_avg(price) FROM sales GROUP BY item;
Exemple : ne renvoyer que les valeurs uniques¶
Le code de l’exemple suivant prend un tableau et renvoie un tableau contenant uniquement les valeurs uniques.
CREATE OR REPLACE AGGREGATE FUNCTION pythonGetUniqueValues(input ARRAY)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonGetUniqueValues'
AS $$
class PythonGetUniqueValues:
def __init__(self):
self._agg_state = set()
@property
def aggregate_state(self):
return self._agg_state
def accumulate(self, input):
self._agg_state.update(input)
def merge(self, other_agg_state):
self._agg_state.update(other_agg_state)
def finish(self):
return list(self._agg_state)
$$;
CREATE OR REPLACE TABLE array_table(x array) AS
SELECT ARRAY_CONSTRUCT(0, 1, 2, 3, 4, 'foo', 'bar', 'snowflake') UNION ALL
SELECT ARRAY_CONSTRUCT(1, 3, 5, 7, 9, 'foo', 'barbar', 'snowpark') UNION ALL
SELECT ARRAY_CONSTRUCT(0, 2, 4, 6, 8, 'snow');
SELECT * FROM array_table;
SELECT pythonGetUniqueValues(x) FROM array_table;
Exemple : renvoyer un décompte de chaînes¶
Le code de l’exemple suivant renvoie le nombre de toutes les instances de chaînes dans un objet.
CREATE OR REPLACE AGGREGATE FUNCTION pythonMapCount(input STRING)
RETURNS OBJECT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonMapCount'
AS $$
from collections import defaultdict
class PythonMapCount:
def __init__(self):
self._agg_state = defaultdict(int)
@property
def aggregate_state(self):
return self._agg_state
def accumulate(self, input):
# Increment count of lowercase input
self._agg_state[input.lower()] += 1
def merge(self, other_agg_state):
for item, count in other_agg_state.items():
self._agg_state[item] += count
def finish(self):
return dict(self._agg_state)
$$;
CREATE OR REPLACE TABLE string_table(x STRING);
INSERT INTO string_table SELECT 'foo' FROM TABLE(GENERATOR(ROWCOUNT => 1000));
INSERT INTO string_table SELECT 'bar' FROM TABLE(GENERATOR(ROWCOUNT => 2000));
INSERT INTO string_table SELECT 'snowflake' FROM TABLE(GENERATOR(ROWCOUNT => 50));
INSERT INTO string_table SELECT 'snowpark' FROM TABLE(GENERATOR(ROWCOUNT => 123));
INSERT INTO string_table SELECT 'SnOw' FROM TABLE(GENERATOR(ROWCOUNT => 1));
INSERT INTO string_table SELECT 'snow' FROM TABLE(GENERATOR(ROWCOUNT => 4));
SELECT pythonMapCount(x) FROM string_table;
Exemple : renvoyer les k premières valeurs les plus élevées¶
Le code de l’exemple suivant renvoie une liste des plus grandes valeurs pour k
. Le code accumule les valeurs d’entrée négatives sur un tas min, puis renvoie les k
valeurs les plus importantes.
CREATE OR REPLACE AGGREGATE FUNCTION pythonTopK(input INT, k INT)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonTopK'
AS $$
import heapq
from dataclasses import dataclass
import itertools
from typing import List
@dataclass
class AggState:
minheap: List[int]
k: int
class PythonTopK:
def __init__(self):
self._agg_state = AggState([], 0)
@property
def aggregate_state(self):
return self._agg_state
@staticmethod
def get_top_k_items(minheap, k):
# Return k smallest elements if there are more than k elements on the min heap.
if (len(minheap) > k):
return [heapq.heappop(minheap) for i in range(k)]
return minheap
def accumulate(self, input, k):
self._agg_state.k = k
# Store the input as negative value, as heapq is a min heap.
heapq.heappush(self._agg_state.minheap, -input)
# Store only top k items on the min heap.
self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)
def merge(self, other_agg_state):
k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k
# Merge two min heaps by popping off elements from one and pushing them onto another.
while(len(other_agg_state.minheap) > 0):
heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))
# Store only k elements on the min heap.
self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)
def finish(self):
return [-x for x in self._agg_state.minheap]
$$;
CREATE OR REPLACE TABLE numbers_table(num_column INT);
INSERT INTO numbers_table SELECT 5 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 1 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 9 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 7 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 10 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 3 FROM TABLE(GENERATOR(ROWCOUNT => 10));
-- Return top 15 largest values from numbers_table.
SELECT pythonTopK(num_column, 15) FROM numbers_table;