Python 사용자 정의 집계 함수

사용자 정의 집계 함수(UDAFs)는 하나 이상의 행을 입력으로 받아 단일 행의 출력을 생성합니다. UDAF는 여러 행의 값을 연산하여 합계, 평균, 계산, 최소값 또는 최대값 찾기, 표준 편차, 추정과 같은 수학적 계산과 일부 비수학적 연산을 수행합니다.

Python UDAFs는 Snowflake 시스템 정의 SQL 집계 함수 와 유사한 자체 집계 함수를 작성할 수 있는 방법을 제공합니다.

Python에서 DataFrames용 사용자 정의 함수(UDAFs) 만들기 에 설명된 대로 Snowpark APIs를 사용하여 자체 UDAFs를 생성할 수도 있습니다.

제한 사항

  • aggregate_state 는 직렬화된 버전에서 최대 크기가 8MB이므로 집계 상태의 크기를 조정합니다.

  • UDAF를 윈도우 함수 로 호출할 수 없습니다(즉, OVER 절이 있는 경우).

  • IMMUTABLE은 집계 함수에서 지원되지 않습니다(AGGREGATE 매개 변수를 사용하는 경우). 따라서 모든 집계 함수는 기본적으로 VOLATILE입니다.

  • 사용자 정의 집계 함수는 WITHIN GROUP 절과 함께 사용할 수 없습니다. 쿼리가 실행되지 않습니다.

집계 함수 핸들러를 위한 인터페이스

집계 함수는 하위 노드의 상태를 집계한 다음, 이 집계 상태를 직렬화하여 상위 노드로 전송하고, 상위 노드에서 병합하여 최종 결과를 계산합니다.

집계 함수를 정의하려면 Snowflake가 런타임에 호출하는 메서드를 포함하는 Python 클래스(함수의 핸들러)를 정의해야 합니다. 이러한 메서드는 아래 테이블에 설명되어 있습니다. 이 항목의 다른 곳에 있는 예제를 참조하십시오.

메서드

요구 사항

설명

__init__

필수

집계의 내부 상태를 초기화합니다.

aggregate_state

필수

집계의 현재 상태를 반환합니다.

  • 메서드에는 @property decorator 가 있어야 합니다.

  • 집계 상태 오브젝트는 Python 피클 라이브러리 에서 직렬화할 수 있는 모든 Python 데이터 타입이 될 수 있습니다.

  • 단순 집계 상태의 경우 기본 Python 데이터 타입을 사용합니다. 더 복잡한 집계 상태의 경우 Python 데이터 클래스 를 사용합니다.

accumulate

필수

새로운 입력 행을 기준으로 집계 상태를 누적합니다.

merge

필수

두 개의 중간 집계 상태를 결합합니다.

finish

필수

집계된 상태를 기반으로 최종 결과를 생성합니다.

입력 값이 하위 노드에 누적된 다음 상위 노드로 전송되어 병합되어 최종 결과를 생성하는 과정을 보여 주는 다이어그램입니다.

예: 합계 계산

다음 예제의 코드는 숫자 값의 합계를 반환하는 python_sum 사용자 정의 집계 함수(UDAF)를 정의합니다.

  1. UDAF를 만듭니다.

    CREATE OR REPLACE AGGREGATE FUNCTION PYTHON_SUM(a INT)
    RETURNS INT
    LANGUAGE PYTHON
    RUNTIME_VERSION = 3.9
    handler = 'PythonSum'
    AS $$
    class PythonSum:
      def __init__(self):
        # This aggregate state is a primitive Python data type.
        self._partial_sum = 0
    
      @property
      def aggregate_state(self):
        return self._partial_sum
    
      def accumulate(self, input_value):
        self._partial_sum += input_value
    
      def merge(self, other_partial_sum):
        self._partial_sum += other_partial_sum
    
      def finish(self):
        return self._partial_sum
    $$;
    
    Copy
  2. 테스트 데이터 테이블을 만듭니다.

    CREATE OR REPLACE TABLE sales(item STRING, price INT);
    
    INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
    
    SELECT * FROM sales;
    
    Copy
  3. python_sum UDAF를 호출합니다.

    SELECT python_sum(price) FROM sales;
    
    Copy
  4. 결과를 Snowflake 시스템 정의 SQL 함수인 SUM 의 출력과 비교하여 결과가 동일한지 확인합니다.

    SELECT sum(col) FROM sales;
    
    Copy
  5. 판매 테이블에서 품목 유형별 합계 값으로 그룹화합니다.

    SELECT item, python_sum(price) FROM sales GROUP BY item;
    
    Copy

예: 평균 계산

다음 예제의 코드는 숫자 값의 평균을 반환하는 python_avg 사용자 정의 집계 함수를 정의합니다.

  1. 함수를 만듭니다.

    CREATE OR REPLACE AGGREGATE FUNCTION python_avg(a INT)
    RETURNS FLOAT
    LANGUAGE PYTHON
    RUNTIME_VERSION = 3.9
    HANDLER = 'PythonAvg'
    AS $$
    from dataclasses import dataclass
    
    @dataclass
    class AvgAggState:
        count: int
        sum: int
    
    class PythonAvg:
        def __init__(self):
            # This aggregate state is an object data type.
            self._agg_state = AvgAggState(0, 0)
    
        @property
        def aggregate_state(self):
            return self._agg_state
    
        def accumulate(self, input_value):
            sum = self._agg_state.sum
            count = self._agg_state.count
    
            self._agg_state.sum = sum + input_value
            self._agg_state.count = count + 1
    
        def merge(self, other_agg_state):
            sum = self._agg_state.sum
            count = self._agg_state.count
    
            other_sum = other_agg_state.sum
            other_count = other_agg_state.count
    
            self._agg_state.sum = sum + other_sum
            self._agg_state.count = count + other_count
    
        def finish(self):
            sum = self._agg_state.sum
            count = self._agg_state.count
            return sum / count
    $$;
    
    Copy
  2. 테스트 데이터 테이블을 만듭니다.

    CREATE OR REPLACE TABLE sales(item STRING, price INT);
    INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
    
    Copy
  3. python_avg 사용자 정의 함수를 호출합니다.

    SELECT python_avg(price) FROM sales;
    
    Copy
  4. 결과를 Snowflake 시스템 정의 SQL 함수인 AVG 의 출력과 비교하여 결과가 동일한지 확인합니다.

    SELECT avg(price) FROM sales;
    
    Copy
  5. 판매 테이블에서 품목 유형별로 평균값을 그룹화합니다.

    SELECT item, python_avg(price) FROM sales GROUP BY item;
    
    Copy

예: 고유 값만 반환

다음 예제의 코드는 배열을 받아서 고유 값만 포함된 배열을 반환합니다.

CREATE OR REPLACE AGGREGATE FUNCTION pythonGetUniqueValues(input ARRAY)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonGetUniqueValues'
AS $$
class PythonGetUniqueValues:
    def __init__(self):
        self._agg_state = set()

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        self._agg_state.update(input)

    def merge(self, other_agg_state):
        self._agg_state.update(other_agg_state)

    def finish(self):
        return list(self._agg_state)
$$;


CREATE OR REPLACE TABLE array_table(x array) AS
SELECT ARRAY_CONSTRUCT(0, 1, 2, 3, 4, 'foo', 'bar', 'snowflake') UNION ALL
SELECT ARRAY_CONSTRUCT(1, 3, 5, 7, 9, 'foo', 'barbar', 'snowpark') UNION ALL
SELECT ARRAY_CONSTRUCT(0, 2, 4, 6, 8, 'snow');

SELECT * FROM array_table;

SELECT pythonGetUniqueValues(x) FROM array_table;
Copy

예: 문자열 개수 반환

다음 예제의 코드는 오브젝트에서 문자열의 모든 인스턴스 수를 반환합니다.

CREATE OR REPLACE AGGREGATE FUNCTION pythonMapCount(input STRING)
RETURNS OBJECT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonMapCount'
AS $$
from collections import defaultdict

class PythonMapCount:
    def __init__(self):
        self._agg_state = defaultdict(int)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        # Increment count of lowercase input
        self._agg_state[input.lower()] += 1

    def merge(self, other_agg_state):
        for item, count in other_agg_state.items():
            self._agg_state[item] += count

    def finish(self):
        return dict(self._agg_state)
$$;

CREATE OR REPLACE TABLE string_table(x STRING);
INSERT INTO string_table SELECT 'foo' FROM TABLE(GENERATOR(ROWCOUNT => 1000));
INSERT INTO string_table SELECT 'bar' FROM TABLE(GENERATOR(ROWCOUNT => 2000));
INSERT INTO string_table SELECT 'snowflake' FROM TABLE(GENERATOR(ROWCOUNT => 50));
INSERT INTO string_table SELECT 'snowpark' FROM TABLE(GENERATOR(ROWCOUNT => 123));
INSERT INTO string_table SELECT 'SnOw' FROM TABLE(GENERATOR(ROWCOUNT => 1));
INSERT INTO string_table SELECT 'snow' FROM TABLE(GENERATOR(ROWCOUNT => 4));

SELECT pythonMapCount(x) FROM string_table;
Copy

예: 상위 k개의 가장 큰 값 반환

다음 예제의 코드는 k 에 대해 가장 큰 값의 목록을 반환합니다. 이 코드는 최소 힙에 음수 입력값을 누적한 다음 가장 큰 값의 상위 k 를 반환합니다.

CREATE OR REPLACE AGGREGATE FUNCTION pythonTopK(input INT, k INT)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.9
HANDLER = 'PythonTopK'
AS $$
import heapq
from dataclasses import dataclass
import itertools
from typing import List

@dataclass
class AggState:
    minheap: List[int]
    k: int

class PythonTopK:
    def __init__(self):
        self._agg_state = AggState([], 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    @staticmethod
    def get_top_k_items(minheap, k):
      # Return k smallest elements if there are more than k elements on the min heap.
      if (len(minheap) > k):
        return [heapq.heappop(minheap) for i in range(k)]
      return minheap

    def accumulate(self, input, k):
        self._agg_state.k = k

        # Store the input as negative value, as heapq is a min heap.
        heapq.heappush(self._agg_state.minheap, -input)

        # Store only top k items on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def merge(self, other_agg_state):
        k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k

        # Merge two min heaps by popping off elements from one and pushing them onto another.
        while(len(other_agg_state.minheap) > 0):
            heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))

        # Store only k elements on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def finish(self):
        return [-x for x in self._agg_state.minheap]
$$;


CREATE OR REPLACE TABLE numbers_table(num_column INT);
INSERT INTO numbers_table SELECT 5 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 1 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 9 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 7 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 10 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 3 FROM TABLE(GENERATOR(ROWCOUNT => 10));

-- Return top 15 largest values from numbers_table.
SELECT pythonTopK(num_column, 15) FROM numbers_table;
Copy