Python User-Defined Aggregate Functions
Introduction
User-defined aggregate functions (UDAFs) take one or more rows as input and produce a single row of output. They operate on values across multiple rows to perform mathematical calculations such as sum, average, counting, finding minimum or maximum values, standard deviation, and estimation, as well as some non-mathematical operations. Python UDAFs provide a way for you to write your own aggregate functions that are similar to the Snowflake system-defined SQL Aggregate Functions.
Limitations
The serialized aggregate_state has a maximum size of 8 MB, so keep the aggregate state as small as possible.
Tuples cannot be used in the intermediate aggregate state.
UDAFs cannot currently be called as a window function (i.e. with an OVER clause).
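One rough way to keep an eye on the size limitation is to measure the pickled size of a candidate state locally before deploying a handler. The sketch below is only an estimate under stated assumptions: it assumes that the size of pickle's output is a reasonable proxy for the serialized size that Snowflake measures, and it interprets 8 MB as 8 * 1024 * 1024 bytes. The names serialized_size and fits_in_limit are hypothetical helpers, not part of any Snowflake API.

```python
import pickle

# Assumption: 8 MB is interpreted as 8 * 1024 * 1024 bytes.
MAX_STATE_BYTES = 8 * 1024 * 1024


def serialized_size(state):
    """Return the number of bytes pickle needs for this state (a local estimate)."""
    return len(pickle.dumps(state))


def fits_in_limit(state, limit=MAX_STATE_BYTES):
    """Return True if the pickled state is no larger than the limit."""
    return serialized_size(state) <= limit


# A compact state (a running sum and a count) stays far below the limit.
small_state = {"sum": 47500, "count": 6}
print(serialized_size(small_state), fits_in_limit(small_state))

# A state that keeps every input value grows with the data and can exceed the limit.
large_state = list(range(3_000_000))
print(serialized_size(large_state), fits_in_limit(large_state))
```

States that summarize the input (sums, counts, bounded heaps) stay small regardless of row count, while states that retain every input value grow with the data.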
The Interface for Aggregate Functions
To define an aggregate function, you must define a Python class (which is the function’s handler) that includes methods with the following names:
__init__
Initializes the internal state of an aggregate.
aggregate_state
The method must have a @property decorator.
Returns the internal state of an aggregate.
An aggregate state object can be any Python data type serializable by the Python pickle library.
For simple aggregate states, use a primitive Python data type. For more complex aggregate states, use Python data classes.
accumulate
Accumulates the state of the aggregate based on the new input row.
merge
Combines two intermediate aggregated states.
finish
Produces the final result based on the aggregated state.
An aggregate function first accumulates state in child nodes. Those aggregate states are then serialized and sent to the parent node, where they are merged and the final result is calculated.
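The flow described above can be imitated locally to see how the five methods fit together. The following is a minimal sketch, not a description of how Snowflake schedules the work: RowCount is a hypothetical handler that simply counts rows, run_distributed is a hypothetical local driver, and the use of pickle.dumps and pickle.loads to move state between "nodes" is an assumption based on the requirement that the state be serializable by pickle.

```python
import pickle


class RowCount:
    """Hypothetical handler that counts input rows."""

    def __init__(self):
        self._count = 0

    @property
    def aggregate_state(self):
        return self._count

    def accumulate(self, input_value):
        self._count += 1

    def merge(self, other_count):
        self._count += other_count

    def finish(self):
        return self._count


def run_distributed(handler_cls, partitions):
    """Run one handler per partition, then merge the serialized states in a parent."""
    serialized_states = []
    for rows in partitions:
        child = handler_cls()
        for row in rows:
            child.accumulate(*row)
        # Each child state is serialized before it is handed to the parent.
        serialized_states.append(pickle.dumps(child.aggregate_state))

    parent = handler_cls()
    for blob in serialized_states:
        parent.merge(pickle.loads(blob))
    return parent.finish()


# Three partitions of two rows each; every row is a tuple of arguments.
partitions = [[(10000,), (5000,)], [(7500,), (3500,)], [(1500,), (20000,)]]
result = run_distributed(RowCount, partitions)
print(result)  # 6
```

Note that merge receives the other handler's aggregate state, not the other handler object, which is why the state must carry everything needed to combine partial results.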

Example: Calculate a Sum
In this example, we define a user-defined aggregate function to return the sum of the numeric values.
create or replace aggregate function python_sum(a int)
returns int
language python
runtime_version=3.8
handler='PythonSum'
as $$
class PythonSum:
    def __init__(self):
        # This aggregate state is a primitive Python data type.
        self._partial_sum = 0

    @property
    def aggregate_state(self):
        return self._partial_sum

    def accumulate(self, input_value):
        self._partial_sum += input_value

    def merge(self, other_partial_sum):
        self._partial_sum += other_partial_sum

    def finish(self):
        return self._partial_sum
$$;
Next, we’ll create a table of test data.
create or replace table sales(item string, price int);
insert into sales values ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
select * from sales;
Now, we will call our user-defined function, which is named python_sum.
select python_sum(price) from sales;
We can compare our results with the output of the Snowflake system-defined SQL function, SUM, and see that the result is the same.
select sum(price) from sales;
We can also group the sum values by item type in the sales table.
select item, python_sum(price) from sales group by item;
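To sanity-check the handler logic without a Snowflake session, the same class can be exercised in plain Python. This is only a local sketch: using one handler instance per item is an assumed stand-in for GROUP BY, and the rows list mirrors the test data inserted into the sales table.

```python
from collections import defaultdict


class PythonSum:
    def __init__(self):
        self._partial_sum = 0

    @property
    def aggregate_state(self):
        return self._partial_sum

    def accumulate(self, input_value):
        self._partial_sum += input_value

    def merge(self, other_partial_sum):
        self._partial_sum += other_partial_sum

    def finish(self):
        return self._partial_sum


rows = [("car", 10000), ("motorcycle", 5000), ("car", 7500),
        ("motorcycle", 3500), ("motorcycle", 1500), ("car", 20000)]

# One handler instance per group, imitating GROUP BY item.
handlers = defaultdict(PythonSum)
for item, price in rows:
    handlers[item].accumulate(price)

grouped = {item: handler.finish() for item, handler in handlers.items()}
print(grouped)  # {'car': 37500, 'motorcycle': 10000}

# Merging the per-group states gives the overall sum.
parent = PythonSum()
for handler in handlers.values():
    parent.merge(handler.aggregate_state)
total = parent.finish()
print(total)  # 47500
```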
Example: Calculate an Average
In this example, we define a user-defined aggregate function to return the average of the numeric values.
create or replace aggregate function python_avg(a int)
returns float
language python
runtime_version=3.8
handler='PythonAvg'
as $$
from dataclasses import dataclass

@dataclass
class AvgAggState:
    count: int
    sum: int

class PythonAvg:
    def __init__(self):
        # This aggregate state is an object data type.
        self._agg_state = AvgAggState(0, 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input_value):
        sum = self._agg_state.sum
        count = self._agg_state.count
        self._agg_state.sum = sum + input_value
        self._agg_state.count = count + 1

    def merge(self, other_agg_state):
        sum = self._agg_state.sum
        count = self._agg_state.count
        other_sum = other_agg_state.sum
        other_count = other_agg_state.count
        self._agg_state.sum = sum + other_sum
        self._agg_state.count = count + other_count

    def finish(self):
        sum = self._agg_state.sum
        count = self._agg_state.count
        return sum / count
$$;
Next, we’ll create a table of test data.
create or replace table sales(item string, price int);
insert into sales values ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
Now, we will call our user-defined function, which is named python_avg.
select python_avg(price) from sales;
We can compare our results with the output of the Snowflake system-defined SQL function, AVG, and see that the result is the same.
select avg(price) from sales;
We can also group the average values by item type in the sales table.
select item, python_avg(price) from sales group by item;
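The reason the aggregate state holds a sum and a count, rather than a running average, is that partial averages cannot be merged correctly when partitions differ in size. The local sketch below illustrates this with a condensed copy of the handler. The split of the test prices into part_a and part_b is an arbitrary assumption, and the aggregate helper is hypothetical.

```python
from dataclasses import dataclass


@dataclass
class AvgAggState:
    count: int
    sum: int


class PythonAvg:
    def __init__(self):
        self._agg_state = AvgAggState(0, 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input_value):
        self._agg_state.sum += input_value
        self._agg_state.count += 1

    def merge(self, other_agg_state):
        self._agg_state.sum += other_agg_state.sum
        self._agg_state.count += other_agg_state.count

    def finish(self):
        return self._agg_state.sum / self._agg_state.count


def aggregate(values):
    """Hypothetical helper: accumulate all values into a fresh handler."""
    handler = PythonAvg()
    for value in values:
        handler.accumulate(value)
    return handler


# Two partitions of unequal size (an arbitrary split of the test prices).
part_a = [10000, 5000, 7500, 3500]
part_b = [1500, 20000]
child_a = aggregate(part_a)
child_b = aggregate(part_b)

# Merging (sum, count) states reproduces the average over all rows.
parent = PythonAvg()
parent.merge(child_a.aggregate_state)
parent.merge(child_b.aggregate_state)
merged_avg = parent.finish()
true_avg = sum(part_a + part_b) / len(part_a + part_b)
print(merged_avg == true_avg)  # True

# Averaging the two partial averages gives a different, incorrect value.
avg_of_avgs = (child_a.finish() + child_b.finish()) / 2
print(avg_of_avgs)  # 8625.0
```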
Example: Return Only Unique Values
In this example, we take arrays as input and return a single array containing only the unique values found across all of them.
create or replace aggregate function pythonGetUniqueValues(input array)
returns array
language python
runtime_version = 3.8
handler = 'PythonGetUniqueValues'
AS $$
class PythonGetUniqueValues:
    def __init__(self):
        self._agg_state = set()

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        self._agg_state.update(input)

    def merge(self, other_agg_state):
        self._agg_state.update(other_agg_state)

    def finish(self):
        return list(self._agg_state)
$$;
create or replace table array_table(x array) as
select array_construct(0, 1, 2, 3, 4, 'foo', 'bar', 'snowflake') union all
select array_construct(1, 3, 5, 7, 9, 'foo', 'barbar', 'snowpark') union all
select array_construct(0, 2, 4, 6, 8, 'snow');
select * from array_table;
select pythonGetUniqueValues(x) from array_table;
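Because the state is a Python set, the order of the returned values is not guaranteed. The following local sketch runs the same handler over the three test arrays and, purely for display, sorts the result with a hypothetical key that groups values by type name so that integers and strings are never compared with each other. Passing each array to accumulate as a Python list is an assumption about how array arguments arrive in the handler.

```python
class PythonGetUniqueValues:
    def __init__(self):
        self._agg_state = set()

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        self._agg_state.update(input)

    def merge(self, other_agg_state):
        self._agg_state.update(other_agg_state)

    def finish(self):
        return list(self._agg_state)


rows = [
    [0, 1, 2, 3, 4, "foo", "bar", "snowflake"],
    [1, 3, 5, 7, 9, "foo", "barbar", "snowpark"],
    [0, 2, 4, 6, 8, "snow"],
]

# Two children process different rows; a parent merges their states.
child_a = PythonGetUniqueValues()
child_a.accumulate(rows[0])
child_a.accumulate(rows[1])
child_b = PythonGetUniqueValues()
child_b.accumulate(rows[2])

parent = PythonGetUniqueValues()
parent.merge(child_a.aggregate_state)
parent.merge(child_b.aggregate_state)
result = parent.finish()
print(len(result))  # 16

# Sort by (type name, value) for a stable display order across mixed types.
ordered = sorted(result, key=lambda v: (type(v).__name__, v))
print(ordered)
```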
Example: Return a Count of Strings
In this example, we count the occurrences of each string, ignoring case, and return the counts in an object.
create or replace aggregate function pythonMapCount(input string)
returns object
language python
runtime_version = 3.8
handler = 'PythonMapCount'
AS $$
from collections import defaultdict

class PythonMapCount:
    def __init__(self):
        self._agg_state = defaultdict(int)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        # Increment count of lowercase input
        self._agg_state[input.lower()] += 1

    def merge(self, other_agg_state):
        for item, count in other_agg_state.items():
            self._agg_state[item] += count

    def finish(self):
        return dict(self._agg_state)
$$;
create or replace table string_table(x string);
insert into string_table select 'foo' from table(generator(rowcount => 1000));
insert into string_table select 'bar' from table(generator(rowcount => 2000));
insert into string_table select 'snowflake' from table(generator(rowcount => 50));
insert into string_table select 'snowpark' from table(generator(rowcount => 123));
insert into string_table select 'SnOw' from table(generator(rowcount => 1));
insert into string_table select 'snow' from table(generator(rowcount => 4));
select pythonMapCount(x) from string_table;
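The counts this handler should produce for the test data can be checked locally. The sketch below feeds the same strings through two child handlers and a parent, then compares the result with collections.Counter as an independent reference. Splitting the rows in half is an arbitrary assumption used only to exercise merge.

```python
from collections import Counter, defaultdict


class PythonMapCount:
    def __init__(self):
        self._agg_state = defaultdict(int)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        # Increment count of lowercase input
        self._agg_state[input.lower()] += 1

    def merge(self, other_agg_state):
        for item, count in other_agg_state.items():
            self._agg_state[item] += count

    def finish(self):
        return dict(self._agg_state)


rows = (["foo"] * 1000 + ["bar"] * 2000 + ["snowflake"] * 50
        + ["snowpark"] * 123 + ["SnOw"] * 1 + ["snow"] * 4)

# Split the rows into two halves to exercise merge.
mid = len(rows) // 2
children = []
for part in (rows[:mid], rows[mid:]):
    child = PythonMapCount()
    for value in part:
        child.accumulate(value)
    children.append(child)

parent = PythonMapCount()
for child in children:
    parent.merge(child.aggregate_state)
result = parent.finish()

# 'SnOw' and 'snow' are counted together because input is lowercased.
print(result["snow"])  # 5

expected = dict(Counter(value.lower() for value in rows))
print(result == expected)  # True
```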
Example: Return Top K Largest Values
In this example, we return a list of the top k largest values. We accumulate negated input values on a min heap, then return the top k largest values.
create or replace aggregate function pythonTopK(input int, k int)
returns array
language python
runtime_version = 3.8
handler = 'PythonTopK'
AS $$
import heapq
from dataclasses import dataclass
from typing import List

@dataclass
class AggState:
    minheap: List[int]
    k: int

class PythonTopK:
    def __init__(self):
        self._agg_state = AggState([], 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    @staticmethod
    def get_top_k_items(minheap, k):
        # Return k smallest elements if there are more than k elements on the min heap.
        if len(minheap) > k:
            return [heapq.heappop(minheap) for _ in range(k)]
        return minheap

    def accumulate(self, input, k):
        self._agg_state.k = k
        # Store the input as negative value, as heapq is a min heap.
        heapq.heappush(self._agg_state.minheap, -input)
        # Store only top k items on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def merge(self, other_agg_state):
        k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k
        # Remember k so that a later merge with an empty state (k == 0) does not discard the heap.
        self._agg_state.k = k
        # Merge two min heaps by popping off elements from one and pushing them onto another.
        while other_agg_state.minheap:
            heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))
        # Store only k elements on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def finish(self):
        return [-x for x in self._agg_state.minheap]
$$;
create or replace table numbers_table(num_column int);
insert into numbers_table select 5 from table(generator(rowcount => 10));
insert into numbers_table select 1 from table(generator(rowcount => 10));
insert into numbers_table select 9 from table(generator(rowcount => 10));
insert into numbers_table select 7 from table(generator(rowcount => 10));
insert into numbers_table select 10 from table(generator(rowcount => 10));
insert into numbers_table select 3 from table(generator(rowcount => 10));
-- Return top 15 largest values from numbers_table.
select pythonTopK(num_column, 15) from numbers_table;
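The negation trick used by this handler can be seen in isolation. Python's heapq module only provides a min heap, so pushing negated values makes the largest original values pop first. The sketch below is a simplified illustration rather than the handler itself: top_k_via_negation is a hypothetical function that keeps every value on the heap instead of trimming to k, and heapq.nlargest serves as an independent reference.

```python
import heapq


def top_k_via_negation(values, k):
    """Return the k largest values by pushing negated values onto a min heap."""
    heap = []
    for value in values:
        heapq.heappush(heap, -value)
    # Popping the smallest negated values yields the largest original values.
    count = min(k, len(heap))
    return [-heapq.heappop(heap) for _ in range(count)]


# Same data as numbers_table: ten copies each of 5, 1, 9, 7, 10, and 3.
values = [5] * 10 + [1] * 10 + [9] * 10 + [7] * 10 + [10] * 10 + [3] * 10

top_15 = top_k_via_negation(values, 15)
print(top_15)  # ten 10s followed by five 9s

print(top_15 == heapq.nlargest(15, values))  # True
```

The handler goes one step further and trims the heap to k elements after every accumulate and merge, which keeps the aggregate state bounded regardless of how many rows are processed.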