Python User-Defined Aggregate Functions

Introduction

User-defined aggregate functions (UDAFs) take one or more rows as input and produce a single row of output. They operate on values across multiple rows to perform mathematical calculations such as sum, average, counting, finding minimum or maximum values, standard deviation, and estimation, as well as some non-mathematical operations. Python UDAFs provide a way for you to write your own aggregate functions that are similar to the Snowflake system-defined SQL Aggregate Functions.

Limitations

  • The serialized aggregate_state has a maximum size of 8 MB, so keep the aggregate state as small as possible.

  • Tuples cannot be used in the intermediate aggregate state.

  • UDAFs cannot currently be called as a window function (i.e. with an OVER clause).
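Because the size limit applies to the pickled form of the state, one way to keep an eye on it while developing a handler is to measure the pickled size locally. The helpers below are hypothetical and not part of any Snowflake API. The sketch assumes 8 MB means 8 * 1024 * 1024 bytes; the exact measurement Snowflake applies may differ.

```python
import pickle

# Assumption: the 8 MB limit is interpreted as 8 * 1024 * 1024 bytes.
MAX_STATE_BYTES = 8 * 1024 * 1024


def serialized_state_size(state) -> int:
    """Return the size in bytes of the pickled aggregate state."""
    return len(pickle.dumps(state))


def fits_state_limit(state) -> bool:
    """Hypothetical local check: does the pickled state fit under the limit?"""
    return serialized_state_size(state) <= MAX_STATE_BYTES


small_state = {"count": 3, "sum": 42}
large_state = list(range(2_000_000))  # roughly 10 MB once pickled

print(fits_state_limit(small_state))  # True
print(fits_state_limit(large_state))
```

A state that grows with the number of input rows (such as a list of every value seen) is the usual way to run into the limit.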

The Interface for Aggregate Functions

To define an aggregate function, you must define a Python class (which is the function’s handler) that includes methods with the following names:

  • __init__

    • Initializes the internal state of an aggregate.

  • aggregate_state

    • The method must have a @property decorator.

    • Returns the internal state of an aggregate.

    • An aggregate state object can be any Python data type serializable by the Python pickle library.

    • For simple aggregate states, use a primitive Python data type. For more complex aggregate states, use Python data classes.

  • accumulate

    • Accumulates the state of the aggregate based on the new input row.

  • merge

    • Combines two intermediate aggregated states.

  • finish

    • Produces the final result based on the aggregated state.
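A minimal sketch of a handler class with the five required members follows. The class name PythonCount and its row-counting behavior are hypothetical, chosen only to show the shape of the interface. The last lines check that the state survives a pickle round trip, which the aggregate_state requirement above implies.

```python
import pickle


class PythonCount:
    """Hypothetical handler that counts input rows."""

    def __init__(self):
        # Initialize the internal state of the aggregate.
        self._count = 0

    @property
    def aggregate_state(self):
        # Must be a property; the returned value must be picklable.
        return self._count

    def accumulate(self, input_value):
        # Update the state from one new input row.
        self._count += 1

    def merge(self, other_count):
        # Combine another intermediate state (an int here) into this one.
        self._count += other_count

    def finish(self):
        # Produce the final result from the aggregated state.
        return self._count


counter = PythonCount()
for value in ["a", "b", "c"]:
    counter.accumulate(value)

state_copy = pickle.loads(pickle.dumps(counter.aggregate_state))
print(state_copy, counter.finish())  # 3 3
```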

An aggregate function accumulates state in child nodes. Those aggregate states are then serialized and sent to a parent node, where they are merged and the final result is calculated.

Figure: Input values are accumulated in child nodes, then sent to a parent node, where they are merged to produce a final result.
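The flow above can be simulated locally in plain Python. In this sketch, each inner list stands in for the rows processed by one child node, each child's aggregate_state is pickled as if it were being sent to the parent, and a fresh parent handler merges the unpickled states before calling finish(). The PythonMax handler and the run_locally helper are hypothetical illustrations; Snowflake's actual scheduling of child and parent nodes is not modeled.

```python
import pickle
from typing import Iterable, List, Optional


class PythonMax:
    """Hypothetical handler that returns the maximum input value."""

    def __init__(self):
        self._max: Optional[int] = None

    @property
    def aggregate_state(self):
        return self._max

    def accumulate(self, input_value):
        if self._max is None or input_value > self._max:
            self._max = input_value

    def merge(self, other_max):
        # A child that saw no rows contributes None, which is ignored.
        if other_max is not None and (self._max is None or other_max > self._max):
            self._max = other_max

    def finish(self):
        return self._max


def run_locally(handler_cls, partitions: List[Iterable]):
    """Accumulate per child, pickle each state, then merge in a parent."""
    child_states = []
    for rows in partitions:
        child = handler_cls()
        for value in rows:
            child.accumulate(value)
        # Serialize the state as if sending it to the parent node.
        child_states.append(pickle.dumps(child.aggregate_state))

    parent = handler_cls()
    for blob in child_states:
        parent.merge(pickle.loads(blob))
    return parent.finish()


print(run_locally(PythonMax, [[3, 9], [4], [], [7, 1]]))  # 9
```

Because merge must accept states from any child in any order, it has to handle the "no rows seen" state as well as ordinary ones.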

Example: Calculate a Sum

In this example, we define a user-defined aggregate function to return the sum of the numeric values.

create or replace aggregate function python_sum(a int)
returns int
language python
runtime_version=3.8
handler='PythonSum'
as $$
class PythonSum:
  def __init__(self):
    # This aggregate state is a primitive Python data type.
    self._partial_sum = 0

  @property
  def aggregate_state(self):
    return self._partial_sum

  def accumulate(self, input_value):
    self._partial_sum += input_value

  def merge(self, other_partial_sum):
    self._partial_sum += other_partial_sum

  def finish(self):
    return self._partial_sum
$$;

Next, we’ll create a table of test data.

create or replace table sales(item string, price int);

insert into sales values ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);

select * from sales;

Now, we will call our user-defined function, which is named python_sum.

select python_sum(price) from sales;

We can compare our results with the output of the Snowflake system-defined SQL function, SUM, and see that the result is the same.

select sum(price) from sales;

We can also group the sums by the item type in the sales table.

select item, python_sum(price) from sales group by item;
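Before deploying, the handler can be exercised locally. The sketch below uses a copy of the PythonSum class and the same six sales rows, and compares the result with Python's built-in sum, mirroring the comparison with SUM above. The grouped part assumes, for illustration, that GROUP BY gives each item its own independent aggregate instance.

```python
class PythonSum:
    # Local copy of the PythonSum handler.
    def __init__(self):
        self._partial_sum = 0

    @property
    def aggregate_state(self):
        return self._partial_sum

    def accumulate(self, input_value):
        self._partial_sum += input_value

    def merge(self, other_partial_sum):
        self._partial_sum += other_partial_sum

    def finish(self):
        return self._partial_sum


sales = [("car", 10000), ("motorcycle", 5000), ("car", 7500),
         ("motorcycle", 3500), ("motorcycle", 1500), ("car", 20000)]

# Roughly: select python_sum(price) from sales;
total = PythonSum()
for _, price in sales:
    total.accumulate(price)
print(total.finish())  # 47500
print(total.finish() == sum(price for _, price in sales))  # True

# Roughly: select item, python_sum(price) from sales group by item;
per_item = {}
for item, price in sales:
    per_item.setdefault(item, PythonSum()).accumulate(price)
per_item_sums = {item: handler.finish() for item, handler in per_item.items()}
print(per_item_sums)  # {'car': 37500, 'motorcycle': 10000}
```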

Example: Calculate an Average

In this example, we define a user-defined aggregate function to return the average of the numeric values.

create or replace aggregate function python_avg(a int)
returns float
language python
runtime_version=3.8
handler='PythonAvg'
as $$
from dataclasses import dataclass

@dataclass
class AvgAggState:
    count: int
    sum: int

class PythonAvg:
    def __init__(self):
        # This aggregate state is an object data type.
        self._agg_state = AvgAggState(0, 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input_value):
        sum = self._agg_state.sum
        count = self._agg_state.count

        self._agg_state.sum = sum + input_value
        self._agg_state.count = count + 1

    def merge(self, other_agg_state):
        sum = self._agg_state.sum
        count = self._agg_state.count

        other_sum = other_agg_state.sum
        other_count = other_agg_state.count

        self._agg_state.sum = sum + other_sum
        self._agg_state.count = count + other_count

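    def finish(self):
        sum = self._agg_state.sum
        count = self._agg_state.count
        # Avoid division by zero when no rows were accumulated; None is returned as NULL.
        if count == 0:
            return None
        return sum / count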
$$;
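A local check of PythonAvg follows. It splits the six prices across two simulated child nodes, pickles each child's AvgAggState (a dataclass, which pickle handles without extra work), merges the states into a parent handler, and calls finish(). This copy of the handler returns None when no rows were accumulated, an assumption chosen so that empty input does not raise ZeroDivisionError.

```python
import pickle
from dataclasses import dataclass
from typing import Optional


@dataclass
class AvgAggState:
    count: int
    sum: int


class PythonAvg:
    def __init__(self):
        self._agg_state = AvgAggState(0, 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input_value):
        self._agg_state.sum += input_value
        self._agg_state.count += 1

    def merge(self, other_agg_state):
        self._agg_state.sum += other_agg_state.sum
        self._agg_state.count += other_agg_state.count

    def finish(self) -> Optional[float]:
        # Assumption: report no rows as None rather than dividing by zero.
        if self._agg_state.count == 0:
            return None
        return self._agg_state.sum / self._agg_state.count


prices = [10000, 5000, 7500, 3500, 1500, 20000]
children = [PythonAvg(), PythonAvg()]
for i, price in enumerate(prices):
    children[i % 2].accumulate(price)

parent = PythonAvg()
for child in children:
    # Pickle the dataclass state as if sending it to the parent node.
    parent.merge(pickle.loads(pickle.dumps(child.aggregate_state)))

print(parent.finish())  # approximately 7916.67
print(PythonAvg().finish())  # None
```

Keeping count and sum (rather than a running average) is what makes merge correct: two partial averages cannot be combined without knowing how many rows each one covers.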

Next, we’ll create a table of test data.

create or replace table sales(item string, price int);
insert into sales values ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);

Now, we will call our user-defined function, which is named python_avg.

select python_avg(price) from sales;

We can compare our results with the output of the Snowflake system-defined SQL function, AVG, and see that the result is the same.

select avg(price) from sales;

We can also group the averages by the item type in the sales table.

select item, python_avg(price) from sales group by item;

Example: Return Only Unique Values

In this example, the function takes an array from each row and returns a single array containing the unique values found across all of the input arrays.

create or replace aggregate function pythonGetUniqueValues(input array)
returns array
language python
runtime_version = 3.8
handler = 'PythonGetUniqueValues'
AS $$
class PythonGetUniqueValues:
    def __init__(self):
        self._agg_state = set()

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        self._agg_state.update(input)

    def merge(self, other_agg_state):
        self._agg_state.update(other_agg_state)

    def finish(self):
        return list(self._agg_state)
$$;


create or replace table array_table(x array) as
select array_construct(0, 1, 2, 3, 4, 'foo', 'bar', 'snowflake') union all
select array_construct(1, 3, 5, 7, 9, 'foo', 'barbar', 'snowpark') union all
select array_construct(0, 2, 4, 6, 8, 'snow');

select * from array_table;

select pythonGetUniqueValues(x) from array_table;
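The same handler can be tried locally on the three arrays from array_table. The sketch assumes each ARRAY value arrives in the handler as a Python list. Because a Python set has no defined order, the order of the returned list is not guaranteed, so compare results as sets rather than as lists.

```python
import pickle


class PythonGetUniqueValues:
    # Local copy of the handler.
    def __init__(self):
        self._agg_state = set()

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        self._agg_state.update(input)

    def merge(self, other_agg_state):
        self._agg_state.update(other_agg_state)

    def finish(self):
        return list(self._agg_state)


rows = [
    [0, 1, 2, 3, 4, "foo", "bar", "snowflake"],
    [1, 3, 5, 7, 9, "foo", "barbar", "snowpark"],
    [0, 2, 4, 6, 8, "snow"],
]

# Two simulated child nodes: one gets the first two rows, one gets the last.
child_a, child_b = PythonGetUniqueValues(), PythonGetUniqueValues()
for row in rows[:2]:
    child_a.accumulate(row)
child_b.accumulate(rows[2])

parent = PythonGetUniqueValues()
for child in (child_a, child_b):
    parent.merge(pickle.loads(pickle.dumps(child.aggregate_state)))

result = parent.finish()
print(len(result))  # 16
```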

Example: Return a Count of Strings

In this example, we return an object that maps each string, lowercased, to the number of times it appears.

create or replace aggregate function pythonMapCount(input string)
returns object
language python
runtime_version = 3.8
handler = 'PythonMapCount'
AS $$
from collections import defaultdict

class PythonMapCount:
    def __init__(self):
        self._agg_state = defaultdict(int)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        # Increment count of lowercase input
        self._agg_state[input.lower()] += 1

    def merge(self, other_agg_state):
        for item, count in other_agg_state.items():
            self._agg_state[item] += count

    def finish(self):
        return dict(self._agg_state)
$$;


create or replace table string_table(x string);
insert into string_table select 'foo' from table(generator(rowcount => 1000));
insert into string_table select 'bar' from table(generator(rowcount => 2000));
insert into string_table select 'snowflake' from table(generator(rowcount => 50));
insert into string_table select 'snowpark' from table(generator(rowcount => 123));
insert into string_table select 'SnOw' from table(generator(rowcount => 1));
insert into string_table select 'snow' from table(generator(rowcount => 4));

select pythonMapCount(x) from string_table;
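A local sketch of PythonMapCount on the same strings as string_table follows, with the rows dealt round-robin to three simulated child nodes. Because accumulate lowercases its input, 'SnOw' and 'snow' are counted under the same key. The number of children and the round-robin split are arbitrary choices for illustration.

```python
import pickle
from collections import defaultdict


class PythonMapCount:
    # Local copy of the handler.
    def __init__(self):
        self._agg_state = defaultdict(int)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        # Increment count of lowercase input
        self._agg_state[input.lower()] += 1

    def merge(self, other_agg_state):
        for item, count in other_agg_state.items():
            self._agg_state[item] += count

    def finish(self):
        return dict(self._agg_state)


rows = (["foo"] * 1000 + ["bar"] * 2000 + ["snowflake"] * 50
        + ["snowpark"] * 123 + ["SnOw"] * 1 + ["snow"] * 4)

# Deal rows round-robin to three simulated child nodes.
children = [PythonMapCount() for _ in range(3)]
for i, value in enumerate(rows):
    children[i % 3].accumulate(value)

parent = PythonMapCount()
for child in children:
    # defaultdict(int) pickles fine because int is a picklable factory.
    parent.merge(pickle.loads(pickle.dumps(child.aggregate_state)))

counts = parent.finish()
print(counts["snow"])  # 5
```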

Example: Return Top K Largest Values

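In this example, we return a list of the k largest values. Because Python's heapq module implements a min heap, each input value is pushed as its negation, so the k smallest entries on the heap correspond to the k largest inputs; finish negates them back.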
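The negation trick can be seen in isolation with the standard-library heapq module: pushing negated values onto a min heap and popping k of them yields the k largest originals, largest first. heapq.nlargest gives the same answer and serves as a cross-check. The sample values are illustrative.

```python
import heapq

values = [5, 1, 9, 7, 10, 3]
k = 3

heap = []
for v in values:
    # Store negated values so the min heap surfaces the largest originals first.
    heapq.heappush(heap, -v)

top_k = [-heapq.heappop(heap) for _ in range(min(k, len(heap)))]
print(top_k)  # [10, 9, 7]
print(heapq.nlargest(k, values))  # [10, 9, 7]
```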

create or replace aggregate function pythonTopK(input int, k int)
returns array
language python
runtime_version = 3.8
handler = 'PythonTopK'
AS $$
import heapq
from dataclasses import dataclass
from typing import List

@dataclass
class AggState:
    minheap: List[int]
    k: int

class PythonTopK:
    def __init__(self):
        self._agg_state = AggState([], 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    @staticmethod
    def get_top_k_items(minheap, k):
        # Return k smallest elements if there are more than k elements on the min heap.
        if len(minheap) > k:
            return [heapq.heappop(minheap) for _ in range(k)]
        return minheap

    def accumulate(self, input, k):
        self._agg_state.k = k

        # Store the input as negative value, as heapq is a min heap.
        heapq.heappush(self._agg_state.minheap, -input)

        # Store only top k items on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def merge(self, other_agg_state):
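        k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k
        # Remember k so that a later merge with an empty state (k == 0) cannot truncate the heap.
        self._agg_state.k = k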

        # Merge two min heaps by popping off elements from one and pushing them onto another.
        while other_agg_state.minheap:
            heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))

        # Store only k elements on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def finish(self):
        return [-x for x in self._agg_state.minheap]
$$;


create or replace table numbers_table(num_column int);
insert into numbers_table select 5 from table(generator(rowcount => 10));
insert into numbers_table select 1 from table(generator(rowcount => 10));
insert into numbers_table select 9 from table(generator(rowcount => 10));
insert into numbers_table select 7 from table(generator(rowcount => 10));
insert into numbers_table select 10 from table(generator(rowcount => 10));
insert into numbers_table select 3 from table(generator(rowcount => 10));

-- Return top 15 largest values from numbers_table.
select pythonTopK(num_column, 15) from numbers_table;
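The handler can also be checked locally against the data in numbers_table. This sketch uses a copy of PythonTopK whose merge also stores k in the state, so that merging a state from a child that saw no rows (where k is still 0) cannot truncate the heap; the empty middle partition below exercises that case. The three-way split of the rows is an arbitrary illustration, not how Snowflake distributes work.

```python
import heapq
import pickle
from dataclasses import dataclass
from typing import List


@dataclass
class AggState:
    minheap: List[int]
    k: int


class PythonTopK:
    def __init__(self):
        self._agg_state = AggState([], 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    @staticmethod
    def get_top_k_items(minheap, k):
        # Keep only the k smallest (negated) values when the heap is too large.
        if len(minheap) > k:
            return [heapq.heappop(minheap) for _ in range(k)]
        return minheap

    def accumulate(self, input, k):
        self._agg_state.k = k
        heapq.heappush(self._agg_state.minheap, -input)
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def merge(self, other_agg_state):
        k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k
        # Store k so a later merge with an empty state (k == 0) cannot wipe the heap.
        self._agg_state.k = k
        while other_agg_state.minheap:
            heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def finish(self):
        return [-x for x in self._agg_state.minheap]


# Same rows as numbers_table: each value appears 10 times.
rows = [v for v in (5, 1, 9, 7, 10, 3) for _ in range(10)]
partitions = [rows[:30], [], rows[30:]]  # the middle child sees no rows

child_states = []
for part in partitions:
    child = PythonTopK()
    for value in part:
        child.accumulate(value, 15)
    child_states.append(pickle.dumps(child.aggregate_state))

parent = PythonTopK()
for blob in child_states:
    parent.merge(pickle.loads(blob))

top_15 = parent.finish()
print(top_15)  # [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 9, 9, 9, 9, 9]
```

Without the k assignment in merge, the parent's k would stay 0 after the first merge, and merging the empty child's state would call get_top_k_items with k = 0 and discard every value collected so far.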