Python user-defined aggregate functions¶
User-defined aggregate functions (UDAFs) take one or more rows as input and produce a single row of output. They operate on values across multiple rows to perform mathematical calculations such as sum, average, counting, finding minimum or maximum values, standard deviation, and estimation, as well as some non-mathematical operations.
Python UDAFs provide a way for you to write your own aggregate functions that are similar to the Snowflake system-defined SQL aggregate functions.
You can also create your own UDAFs using Snowpark APIs as described in Creating User-Defined Aggregate Functions (UDAFs) for DataFrames in Python.
Limitations¶
The
aggregate_state
has a maximum size of 8 MB in a serialized version, so try to control the size of the aggregate state.You can’t call a UDAF as a window function (in other words, with an OVER clause).
Event table logging from Python UDAFs is not currently supported.
IMMUTABLE is not supported on an aggregate function (when you use the AGGREGATE parameter). Therefore, all aggregate functions are VOLATILE by default.
Interface for aggregate function handler¶
An aggregate function aggregates state in child nodes and then, eventually, those aggregate states are serialized and sent to the parent node where they get merged and the final result is calculated.
To define an aggregate function, you must define a Python class (which is the function’s handler) that includes methods that Snowflake invokes at run time. Those methods are described in the table below. See examples elsewhere in this topic.
Method |
Requirement |
Description |
---|---|---|
|
Required |
Initializes the internal state of an aggregate. |
|
Required |
Returns the current state of an aggregate.
|
|
Required |
Accumulates the state of the aggregate based on the new input row. |
|
Required |
Combines two intermediate aggregated states. |
|
Required |
Produces the final result based on the aggregated state. |
Example: Calculate a sum¶
Code in the following example defines a python_sum
user-defined aggregate function (UDAF) to return the sum of the numeric values.
Create the UDAF.
CREATE OR REPLACE AGGREGATE FUNCTION PYTHON_SUM(a INT) RETURNS INT LANGUAGE PYTHON RUNTIME_VERSION=3.8 handler = 'PythonSum' AS $$ class PythonSum: def __init__(self): # This aggregate state is a primitive Python data type. self._partial_sum = 0 @property def aggregate_state(self): return self._partial_sum def accumulate(self, input_value): self._partial_sum += input_value def merge(self, other_partial_sum): self._partial_sum += other_partial_sum def finish(self): return self._partial_sum $$;
Create a table of test data.
CREATE OR REPLACE TABLE sales(item STRING, price INT); INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000); SELECT * FROM sales;
Call the
python_sum
UDAF.SELECT python_sum(price) FROM sales;
Compare results with the output of the Snowflake system-defined SQL function, SUM, and see that the result is the same.
SELECT sum(col) FROM sales;
Group by sum values by the item type in the sales table.
SELECT item, python_sum(price) FROM sales GROUP BY item;
Example: Calculate an average¶
Code in the following example defines a python_avg
user-defined aggregate function to return the average of the numeric values.
Create the function.
CREATE OR REPLACE AGGREGATE FUNCTION python_avg(a INT) RETURNS FLOAT LANGUAGE PYTHON RUNTIME_VERSION = 3.8 HANDLER = 'PythonAvg' AS $$ from dataclasses import dataclass @dataclass class AvgAggState: count: int sum: int class PythonAvg: def __init__(self): # This aggregate state is an object data type. self._agg_state = AvgAggState(0, 0) @property def aggregate_state(self): return self._agg_state def accumulate(self, input_value): sum = self._agg_state.sum count = self._agg_state.count self._agg_state.sum = sum + input_value self._agg_state.count = count + 1 def merge(self, other_agg_state): sum = self._agg_state.sum count = self._agg_state.count other_sum = other_agg_state.sum other_count = other_agg_state.count self._agg_state.sum = sum + other_sum self._agg_state.count = count + other_count def finish(self): sum = self._agg_state.sum count = self._agg_state.count return sum / count $$;
Create a table of test data.
CREATE OR REPLACE TABLE sales(item STRING, price INT); INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
Call the
python_avg
user-defined function.SELECT python_avg(price) FROM sales;
Compare results with the output of the Snowflake system-defined SQL function, AVG, and see that the result is the same.
SELECT avg(price) FROM sales;
Group average values by the item type in the sales table.
SELECT item, python_avg(price) FROM sales GROUP BY item;
Example: Return only unique values¶
Code in the following example takes an array and returns an array containing only the unique values.
CREATE OR REPLACE AGGREGATE FUNCTION pythonGetUniqueValues(input ARRAY)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.8
HANDLER = 'PythonGetUniqueValues'
AS $$
class PythonGetUniqueValues:
def __init__(self):
self._agg_state = set()
@property
def aggregate_state(self):
return self._agg_state
def accumulate(self, input):
self._agg_state.update(input)
def merge(self, other_agg_state):
self._agg_state.update(other_agg_state)
def finish(self):
return list(self._agg_state)
$$;
CREATE OR REPLACE TABLE array_table(x array) AS
SELECT ARRAY_CONSTRUCT(0, 1, 2, 3, 4, 'foo', 'bar', 'snowflake') UNION ALL
SELECT ARRAY_CONSTRUCT(1, 3, 5, 7, 9, 'foo', 'barbar', 'snowpark') UNION ALL
SELECT ARRAY_CONSTRUCT(0, 2, 4, 6, 8, 'snow');
SELECT * FROM array_table;
SELECT pythonGetUniqueValues(x) FROM array_table;
Example: Return a count of strings¶
Code in the following example returns counts of all instances of strings in an object.
CREATE OR REPLACE AGGREGATE FUNCTION pythonMapCount(input STRING)
RETURNS OBJECT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.8
HANDLER = 'PythonMapCount'
AS $$
from collections import defaultdict
class PythonMapCount:
def __init__(self):
self._agg_state = defaultdict(int)
@property
def aggregate_state(self):
return self._agg_state
def accumulate(self, input):
# Increment count of lowercase input
self._agg_state[input.lower()] += 1
def merge(self, other_agg_state):
for item, count in other_agg_state.items():
self._agg_state[item] += count
def finish(self):
return dict(self._agg_state)
$$;
CREATE OR REPLACE TABLE string_table(x STRING);
INSERT INTO string_table SELECT 'foo' FROM TABLE(GENERATOR(ROWCOUNT => 1000));
INSERT INTO string_table SELECT 'bar' FROM TABLE(GENERATOR(ROWCOUNT => 2000));
INSERT INTO string_table SELECT 'snowflake' FROM TABLE(GENERATOR(ROWCOUNT => 50));
INSERT INTO string_table SELECT 'snowpark' FROM TABLE(GENERATOR(ROWCOUNT => 123));
INSERT INTO string_table SELECT 'SnOw' FROM TABLE(GENERATOR(ROWCOUNT => 1));
INSERT INTO string_table SELECT 'snow' FROM TABLE(GENERATOR(ROWCOUNT => 4));
SELECT pythonMapCount(x) FROM string_table;
Example: Return top k largest values¶
Code in the following example returns a list of the top largest values for k
. The code accumulates negated input values on a min
heap, then returns the top k
largest values.
CREATE OR REPLACE AGGREGATE FUNCTION pythonTopK(input INT, k INT)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.8
HANDLER = 'PythonTopK'
AS $$
import heapq
from dataclasses import dataclass
import itertools
from typing import List
@dataclass
class AggState:
minheap: List[int]
k: int
class PythonTopK:
def __init__(self):
self._agg_state = AggState([], 0)
@property
def aggregate_state(self):
return self._agg_state
@staticmethod
def get_top_k_items(minheap, k):
# Return k smallest elements if there are more than k elements on the min heap.
if (len(minheap) > k):
return [heapq.heappop(minheap) for i in range(k)]
return minheap
def accumulate(self, input, k):
self._agg_state.k = k
# Store the input as negative value, as heapq is a min heap.
heapq.heappush(self._agg_state.minheap, -input)
# Store only top k items on the min heap.
self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)
def merge(self, other_agg_state):
k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k
# Merge two min heaps by popping off elements from one and pushing them onto another.
while(len(other_agg_state.minheap) > 0):
heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))
# Store only k elements on the min heap.
self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)
def finish(self):
return [-x for x in self._agg_state.minheap]
$$;
CREATE OR REPLACE TABLE numbers_table(num_column INT);
INSERT INTO numbers_table SELECT 5 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 1 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 9 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 7 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 10 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 3 FROM TABLE(GENERATOR(ROWCOUNT => 10));
-- Return top 15 largest values from numbers_table.
SELECT pythonTopK(num_column, 15) FROM numbers_table;