User-defined aggregate functions (UDAFs) take one or more rows as input and produce a single row of output.
They operate on values across multiple rows to perform mathematical calculations such as sum, average, counting,
finding minimum or maximum values, standard deviation, and estimation, as well as some non-mathematical operations.
Python UDAFs let you write your own aggregate functions that work like the system-defined SQL aggregate functions in
Snowflake.
When serialized, the aggregate_state can be at most 64 MB, so keep the aggregate state as small as possible.
You can’t call a UDAF as a window function (in other words, with an OVER clause).
IMMUTABLE is not supported on an aggregate function (a function created with the AGGREGATE keyword). Therefore, all aggregate
functions are VOLATILE by default.
User-defined aggregate functions cannot be used with the WITHIN GROUP clause; queries that do so fail to execute.
An aggregate function accumulates state in child nodes. Those aggregate states are then serialized and sent to a parent node,
where they are merged and the final result is calculated.
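You can model the child/parent flow described above in a single process to reason about a handler before you deploy it. The sketch below rests on several assumptions and does not reproduce Snowflake's execution engine. Each "child node" is a separate handler instance. pickle stands in for whatever serialization Snowflake actually uses. The parent starts from a fresh instance and merges every child state before it calls finish. CountRows and run_aggregate are hypothetical names used only for illustration.

```python
import pickle
from typing import Any, Iterable, List, Sequence


class CountRows:
    """Hypothetical handler that counts input rows; used only to drive this sketch."""

    def __init__(self):
        self._count = 0

    @property
    def aggregate_state(self):
        return self._count

    def accumulate(self, input_value):
        self._count += 1

    def merge(self, other_count):
        self._count += other_count

    def finish(self):
        return self._count


def run_aggregate(handler_cls, partitions: Sequence[Iterable[Any]]):
    """Simulate child nodes accumulating and one parent node merging (assumed model)."""
    serialized_states: List[bytes] = []
    for rows in partitions:
        child = handler_cls()  # one handler instance per hypothetical child node
        for value in rows:
            child.accumulate(value)  # fold each row into the child's state
        # pickle is a stand-in for Snowflake's real serialization format.
        serialized_states.append(pickle.dumps(child.aggregate_state))

    parent = handler_cls()  # assumption: the parent starts from a fresh state
    for payload in serialized_states:
        parent.merge(pickle.loads(payload))  # combine each child's partial state
    return parent.finish()  # compute the final result once, after all merges


print(run_aggregate(CountRows, [[1, 2, 3], [4, 5], []]))  # 5
```

An empty partition contributes the initial state (0 here). This is why a handler's constructor must produce a state that is safe to merge.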
To define an aggregate function, you must define a Python class (the function's handler) whose members Snowflake invokes at
run time: a constructor, the aggregate_state property, and the accumulate, merge, and finish methods. The examples in this
topic implement each of them.
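Here is a minimal handler skeleton, with a comment on each member describing its role in the child/parent flow. PythonMax is a hypothetical example and not a Snowflake-provided class. Two choices in it are assumptions of this sketch: the parent merges into a fresh instance, and None inputs are skipped defensively.

```python
class PythonMax:
    """Hypothetical handler that returns the maximum input value (None if no rows)."""

    def __init__(self):
        # Creates a fresh, empty aggregate state.
        self._max = None

    @property
    def aggregate_state(self):
        # The partial state that is serialized and sent to the parent node.
        return self._max

    def accumulate(self, input_value):
        # Called once per input row to fold that row into the state.
        # Skipping None is a defensive choice in this sketch.
        if input_value is not None and (self._max is None or input_value > self._max):
            self._max = input_value

    def merge(self, other_max):
        # Called with another instance's aggregate_state; reuses the row logic.
        self.accumulate(other_max)

    def finish(self):
        # Called once, after all merges, to produce the return value.
        return self._max


# Local usage: two "child" instances merged into a fresh "parent".
left, right = PythonMax(), PythonMax()
for v in [3, 8, 1]:
    left.accumulate(v)
for v in [5, 2]:
    right.accumulate(v)
parent = PythonMax()
parent.merge(left.aggregate_state)
parent.merge(right.aggregate_state)
print(parent.finish())  # 8
```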
The following example defines a python_sum user-defined aggregate function (UDAF) that returns the sum of its numeric input values.
Create the UDAF.
```sql
CREATE OR REPLACE AGGREGATE FUNCTION PYTHON_SUM(a INT)
RETURNS INT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.12
HANDLER = 'PythonSum'
AS $$
class PythonSum:
    def __init__(self):
        # This aggregate state is a primitive Python data type.
        self._partial_sum = 0

    @property
    def aggregate_state(self):
        return self._partial_sum

    def accumulate(self, input_value):
        self._partial_sum += input_value

    def merge(self, other_partial_sum):
        self._partial_sum += other_partial_sum

    def finish(self):
        return self._partial_sum
$$;
```
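Addition is associative and commutative, so PythonSum should return the same result however the rows are split among child nodes. The sketch below copies the handler class from the example and runs it outside Snowflake. A hypothetical helper named simulate spreads the input round-robin across a chosen number of partitions and then merges the partial sums on a fresh parent instance.

```python
class PythonSum:
    def __init__(self):
        self._partial_sum = 0

    @property
    def aggregate_state(self):
        return self._partial_sum

    def accumulate(self, input_value):
        self._partial_sum += input_value

    def merge(self, other_partial_sum):
        self._partial_sum += other_partial_sum

    def finish(self):
        return self._partial_sum


def simulate(values, n_partitions):
    """Hypothetical helper: split values round-robin across child instances, then merge."""
    children = [PythonSum() for _ in range(n_partitions)]
    for i, value in enumerate(values):
        children[i % n_partitions].accumulate(value)
    parent = PythonSum()
    for child in children:
        parent.merge(child.aggregate_state)
    return parent.finish()


values = [4, -2, 10, 7, 1]
for n in (1, 2, 3):
    print(n, simulate(values, n))  # prints "1 20", "2 20", "3 20"
```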
The following example defines a python_avg user-defined aggregate function that returns the average of its numeric input values.
Create the function.
```sql
CREATE OR REPLACE AGGREGATE FUNCTION python_avg(a INT)
RETURNS FLOAT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.12
HANDLER = 'PythonAvg'
AS $$
from dataclasses import dataclass

@dataclass
class AvgAggState:
    count: int
    sum: int

class PythonAvg:
    def __init__(self):
        # This aggregate state is an object data type.
        self._agg_state = AvgAggState(0, 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input_value):
        sum = self._agg_state.sum
        count = self._agg_state.count

        self._agg_state.sum = sum + input_value
        self._agg_state.count = count + 1

    def merge(self, other_agg_state):
        sum = self._agg_state.sum
        count = self._agg_state.count

        other_sum = other_agg_state.sum
        other_count = other_agg_state.count

        self._agg_state.sum = sum + other_sum
        self._agg_state.count = count + other_count

    def finish(self):
        sum = self._agg_state.sum
        count = self._agg_state.count
        return sum / count
$$;
```
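The average example carries a running sum and count instead of a partial average. Partial averages from partitions of different sizes cannot be combined correctly. The local sketch below demonstrates this with a condensed copy of the handler logic. It also adds one hypothetical change that is not in the original: finish returns None instead of raising ZeroDivisionError when no rows were accumulated. This topic does not say whether Snowflake ever calls finish on an empty state, so treat the guard as defensive.

```python
from dataclasses import dataclass


@dataclass
class AvgAggState:
    count: int
    sum: int


class PythonAvg:
    def __init__(self):
        self._agg_state = AvgAggState(0, 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input_value):
        self._agg_state.sum += input_value
        self._agg_state.count += 1

    def merge(self, other_agg_state):
        self._agg_state.sum += other_agg_state.sum
        self._agg_state.count += other_agg_state.count

    def finish(self):
        # Hypothetical guard (not in the original example): no rows means no average.
        if self._agg_state.count == 0:
            return None
        return self._agg_state.sum / self._agg_state.count


def partial(values):
    """Build one child's state from a list of values."""
    handler = PythonAvg()
    for v in values:
        handler.accumulate(v)
    return handler


left, right = partial([1, 2, 3]), partial([10])
parent = PythonAvg()
parent.merge(left.aggregate_state)
parent.merge(right.aggregate_state)
print(parent.finish())  # 4.0, that is, 16 / 4

# Averaging the two partial averages would give (2.0 + 10.0) / 2, which is wrong.
print((left.finish() + right.finish()) / 2)  # 6.0
```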
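The following example defines a pythonTopK UDAF that returns a list of the k largest input values. The handler stores negated
input values on a min heap (Python's heapq module), so the k smallest heap entries correspond to the k largest inputs. After each
accumulate and merge call, it keeps at most k entries on the heap.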
```sql
CREATE OR REPLACE AGGREGATE FUNCTION pythonTopK(input INT, k INT)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.12
HANDLER = 'PythonTopK'
AS $$
import heapq
from dataclasses import dataclass
from typing import List

@dataclass
class AggState:
    minheap: List[int]
    k: int

class PythonTopK:
    def __init__(self):
        self._agg_state = AggState([], 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    @staticmethod
    def get_top_k_items(minheap, k):
        # Return k smallest elements if there are more than k elements on the min heap.
        if len(minheap) > k:
            return [heapq.heappop(minheap) for i in range(k)]
        return minheap

    def accumulate(self, input, k):
        self._agg_state.k = k

        # Store the input as negative value, as heapq is a min heap.
        heapq.heappush(self._agg_state.minheap, -input)

        # Store only top k items on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def merge(self, other_agg_state):
        k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k
        # Record k so that this state still carries it if it is merged again.
        self._agg_state.k = k

        # Merge two min heaps by popping off elements from one and pushing them onto another.
        while len(other_agg_state.minheap) > 0:
            heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))

        # Store only k elements on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def finish(self):
        return [-x for x in self._agg_state.minheap]
$$;
```
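The handler above trims its heap by popping k entries whenever the heap grows past k, which costs roughly O(k log k) per row. A common alternative, sketched below as a hypothetical BoundedTopK class that is not part of the original example, stores raw values in a min heap capped at k entries. It uses heapq.heappushpop, so each new row costs O(log k) and no negation is needed. Treat it as a sketch of the data structure rather than a drop-in Snowflake handler.

```python
import heapq
from dataclasses import dataclass, field
from typing import List


@dataclass
class TopKState:
    heap: List[int] = field(default_factory=list)  # min heap of the largest values kept so far
    k: int = 0


class BoundedTopK:
    """Hypothetical alternative: keep at most k raw values in a min heap."""

    def __init__(self):
        self._state = TopKState()

    @property
    def aggregate_state(self):
        return self._state

    def _offer(self, value, k):
        heap = self._state.heap
        if len(heap) < k:
            heapq.heappush(heap, value)
        elif k > 0:
            # Push the value, then drop the smallest one, in O(log k).
            heapq.heappushpop(heap, value)

    def accumulate(self, input_value, k):
        self._state.k = k
        self._offer(input_value, k)

    def merge(self, other_state):
        self._state.k = max(self._state.k, other_state.k)
        for value in other_state.heap:
            self._offer(value, self._state.k)

    def finish(self):
        # Sort so that the largest value comes first.
        return sorted(self._state.heap, reverse=True)


a, b = BoundedTopK(), BoundedTopK()
for v in [5, 1, 9, 7]:
    a.accumulate(v, 3)
for v in [10, 3, 9]:
    b.accumulate(v, 3)
parent = BoundedTopK()
parent.merge(a.aggregate_state)
parent.merge(b.aggregate_state)
print(parent.finish())  # [10, 9, 9]
```

Because the smallest retained value always sits at heap[0], a new row that is not larger than it is rejected in a single comparison.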
Create a table of sample values, then call the function.

```sql
CREATE OR REPLACE TABLE numbers_table(num_column INT);
INSERT INTO numbers_table SELECT 5 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 1 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 9 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 7 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 10 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 3 FROM TABLE(GENERATOR(ROWCOUNT => 10));

-- Return top 15 largest values from numbers_table.
SELECT pythonTopK(num_column, 15) FROM numbers_table;
```
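To see what the query should return, the sketch below runs the pythonTopK handler logic locally on the same rows that the INSERT statements load (ten copies each of 5, 1, 9, 7, 10, and 3). The rows are spread round-robin across four hypothetical child nodes. With k = 15, the 15 largest values are the ten 10s and five of the 9s. The result is sorted before printing because the handler returns heap order, which is sorted only after a trim. In this copy, merge also records the resolved k in the state, a small defensive change.

```python
import heapq
import itertools
from dataclasses import dataclass
from typing import List


@dataclass
class AggState:
    minheap: List[int]
    k: int


class PythonTopK:
    def __init__(self):
        self._agg_state = AggState([], 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    @staticmethod
    def get_top_k_items(minheap, k):
        # Pop the k smallest negated values (the k largest inputs) once the heap exceeds k.
        if len(minheap) > k:
            return [heapq.heappop(minheap) for _ in range(k)]
        return minheap

    def accumulate(self, input_value, k):
        self._agg_state.k = k
        heapq.heappush(self._agg_state.minheap, -input_value)
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def merge(self, other_agg_state):
        k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k
        self._agg_state.k = k  # defensive: keep k on the merged state
        while other_agg_state.minheap:
            heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def finish(self):
        return [-x for x in self._agg_state.minheap]


# Same data as numbers_table: ten copies each of 5, 1, 9, 7, 10, and 3.
rows = list(itertools.chain.from_iterable([v] * 10 for v in (5, 1, 9, 7, 10, 3)))

children = [PythonTopK() for _ in range(4)]  # four hypothetical child nodes
for i, value in enumerate(rows):
    children[i % 4].accumulate(value, 15)

parent = PythonTopK()
for child in children:
    parent.merge(child.aggregate_state)

result = sorted(parent.finish(), reverse=True)
print(result)  # ten 10s followed by five 9s
```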