ローカルテストフレームワーク

このトピックでは、Snowparkライブラリを使用する際にローカルでコードをテストする方法について説明します。

このトピックの内容:

Snowpark Pythonローカルテストフレームワークを使用すると、Snowflakeアカウントに接続することなく、ローカルでSnowpark Python DataFrames を作成および操作できます。ローカルテストフレームワークを使用すると、コードの変更をアカウントに展開する前に、開発マシンまたは CI (継続的統合)パイプラインで DataFrame 操作をローカルにテストできます。API は同じであるため、コードを変更せずに、テストをローカルで実行することも、Snowflakeアカウントに対して実行することもできます。

前提条件

ローカルテストフレームワークを使用するには

  • オプションで依存関係 pandas があるバージョン1.11.1以上のSnowpark Pythonライブラリを使用する必要があります。 pip install "snowflake-snowpark-python[pandas]" を実行してインストールします。

  • サポートされているPythonのバージョンは次のとおりです。

    • 3.8

    • 3.9

    • 3.10

    • 3.11

セッションの作成とローカルテストの有効化

開始するには、Snowpark Session を作成し、ローカルテスト構成を True に設定します。

from snowflake.snowpark import Session

session = Session.builder.config('local_testing', True).create()
Copy

セッションが作成されたら、それを使って DataFrames の作成と操作ができます。

df = session.create_dataframe([[1,2],[3,4]],['a','b'])
df.with_column('c', df['a']+df['b']).show()
Copy

データのロード

Pythonのプリミティブ、ファイル、Pandas DataFrames からSnowpark DataFrames を作成できます。これは、テストケースの入力と予想される出力を指定するのに役立ちます。このようにすることで、データがソース管理内に置かれ、テストデータとテストケースの同期の維持が簡単になります。

CSV データの読み込み

まず Session.file.put() を呼び出してメモリ内ステージにファイルを読み込み、次に Session.read() を使用してコンテンツを読み取ることで、 CSV ファイルをSnowpark DataFrame に読み込むことができます。ファイル data.csv があり、そのファイルには次のようなコンテンツがあるとします。

col1,col2,col3,col4
1,a,true,1.23
2,b,false,4.56
Copy

次のコードを使って data.csv をSnowpark DataFrame に読み込むことができます。まずファイルをステージに配置する必要があります。そうしないと、ファイルが見つからないというエラーが表示されます。

from snowflake.snowpark.types import StructType, StructField, IntegerType, BooleanType, StringType, DoubleType


# Put file onto stage
session.file.put("data.csv", "@mystage", auto_compress=False)
schema = StructType(
    [
        StructField("col1", IntegerType()),
        StructField("col2", StringType()),
        StructField("col3", BooleanType()),
        StructField("col4", DoubleType()),
    ]
)

# with option SKIP_HEADER set to 1, the header will be skipped when the csv file is loaded
dataframe = session.read.schema(schema).option("SKIP_HEADER", 1).csv("@mystage/data.csv")
dataframe.show()
Copy

dataframe.show() の出力は次のとおりです。

-------------------------------------
|"COL1"  |"COL2"  |"COL3"  |"COL4"  |
-------------------------------------
|1       |a       |True    |1.23    |
|2       |b       |False   |4.56    |
-------------------------------------

Pandasデータの読み込み

Pandas DataFrame からSnowpark Python DataFrame を作成するには、 create_dataframe メソッドを呼び出し、データをPandas DataFrame として渡します。

import pandas as pd

pandas_df = pd.DataFrame(
    data={
        "col1": pd.Series(["value1", "value2"]),
        "col2": pd.Series([1.23, 4.56]),
        "col3": pd.Series([123, 456]),
        "col4": pd.Series([True, False]),
    }
)

dataframe = session.create_dataframe(data=pandas_df)
dataframe.show()
Copy

dataframe.show() の出力は次のとおりです。

-------------------------------------
|"col1"  |"col2"  |"col3"  |"col4"  |
-------------------------------------
|value1  |1.23    |123     |True    |
|value2  |4.56    |456     |False   |
-------------------------------------

Snowpark Python DataFrame は、 DataFrame の to_pandas メソッドを呼び出すことで、Pandas DataFrame に変換することもできます。

from snowflake.snowpark.types import StructType, StructField, StringType, DoubleType, LongType, BooleanType

dataframe = session.create_dataframe(
    data=[
        ["value1", 1.23, 123, True],
        ["value2", 4.56, 456, False],
    ],
    schema=StructType([
        StructField("col1", StringType()),
        StructField("col2", DoubleType()),
        StructField("col3", LongType()),
        StructField("col4", BooleanType()),
    ])
)

pandas_dataframe = dataframe.to_pandas()
print(pandas_dataframe.to_string())
Copy

print(pandas_dataframe.to_string()) の呼び出しの出力は次のとおりです。

    COL1  COL2  COL3   COL4
0  value1  1.23   123   True
1  value2  4.56   456  False

セッションの PyTest フィクスチャの作成

PyTest フィクスチャ は、テスト(またはテストのモジュール)の前に実行される関数で、通常はテストにデータや接続を提供します。このケースでは、Snowpark Session オブジェクトを返すフィクスチャを作成します。まず、 test ディレクトリがない場合は、それを作成します。次に、 test ディレクトリに次のコンテンツのファイル conftest.py を作成します。 connection_parameters は、Snowflakeアカウントの認証情報があるディクショナリです。ディクショナリフォーマットの詳細については、 セッションの作成 をご参照ください。

# test/conftest.py
import pytest
from snowflake.snowpark.session import Session

def pytest_addoption(parser):
    parser.addoption("--snowflake-session", action="store", default="live")

@pytest.fixture(scope='module')
def session(request) -> Session:
    if request.config.getoption('--snowflake-session') == 'local':
        return Session.builder.config('local_testing', True).create()
    else:
        return Session.builder.configs(CONNECTION_PARAMETERS).create()
Copy

pytest_addoption の呼び出しにより、 pytest コマンドに snowflake-session というコマンドラインオプションが追加されます。 Session フィクスチャはこのコマンドラインオプションをチェックし、その値に応じてローカルかライブの Session を作成します。これにより、ローカルモードとライブモードを簡単に切り替えてテストすることができます。

# Using local mode:
pytest --snowflake-session local

# Using live mode
pytest
Copy

SQL 操作

Session.sql(...) はローカルテストフレームワークではサポートされていません。可能な限りSnowparkの DataFrame APIs を使用します。 Session.sql(...) を使用しなければならない場合は、Pythonの unittest.mock.patch を使用して表形式の戻り値をモックし、指定された Session.sql() 呼び出しから予想される応答をパッチすることができます。

以下の例では、 mock_sql() は SQL クエリテキストを DataFrame 応答にマッピングします。次の条件ステートメントは、現在のセッションがローカルテストを使っているかどうかをチェックし、そうであれば Session.sql() メソッドにパッチを適用します。

from unittest import mock
from functools import partial

def test_something(pytestconfig, session):

    def mock_sql(session, sql_string):  # patch for SQL operations
        if sql_string == "select 1,2,3":
            return session.create_dataframe([[1,2,3]])
        else:
            raise RuntimeError(f"Unexpected query execution: {sql_string}")

    if pytestconfig.getoption('--snowflake-session') == 'local':
        with mock.patch.object(session, 'sql', wraps=partial(mock_sql, session)): # apply patch for SQL operations
            assert session.sql("select 1,2,3").collect() == [Row(1,2,3)]
    else:
        assert session.sql("select 1,2,3").collect() == [Row(1,2,3)]
Copy

ローカルテストが有効な場合、 DataFrame.save_as_table() で作成されたすべてのテーブルは仮テーブルとしてメモリに保存され、 Session.table() を使用して取得することができます。サポートされている DataFrame 操作は、通常通りテーブル上で使用できます。

組み込み関数のパッチ

snowflake.snowpark.functions の組み込み関数のすべてが、ローカルテストフレームワークでサポートされているわけではありません。サポートされていない関数を使用する場合は、 snowflake.snowpark.mock@patch デコレーターを使用してパッチを作成する必要があります。

パッチされた関数を定義し実装するには、署名(パラメーターリスト)が組み込み関数のパラメーターと一致している必要があります。ローカルテストフレームワークは、次のルールを使ってパッチされた関数にパラメーターを渡します。

  • 組み込み関数の署名の ColumnOrName 型のパラメーターでは、 ColumnEmulator がパッチされた関数のパラメーターとして渡されます。 ColumnEmulator は列データを含む pandas.Series オブジェクトに似ています。

  • 組み込み関数の署名の LiteralType 型のパラメーターについては、リテラル値がパッチされた関数のパラメーターとして渡されます。

  • そうでない場合は、生の値がパッチされた関数のパラメーターとして渡されます。

パッチされた関数の戻り型に関しては、 ColumnEmulator のインスタンスを返すことが、組み込み関数の Column の戻り型に対して予想されます。

たとえば、組み込み関数 to_timestamp() は次のようにパッチすることができます。

import datetime
from snowflake.snowpark.mock import patch, ColumnEmulator, ColumnType
from snowflake.snowpark.functions import to_timestamp
from snowflake.snowpark.types import TimestampType

@patch(to_timestamp)
def mock_to_timestamp(column: ColumnEmulator, format = None) -> ColumnEmulator:
    ret_column = ColumnEmulator(data=[datetime.datetime.strptime(row, '%Y-%m-%dT%H:%M:%S%z') for row in column])
    ret_column.sf_type = ColumnType(TimestampType(), True)
    return ret_column
Copy

テストケースのスキップ

PyTest テストスイートにローカルテストでうまくサポートされないテストケースが含まれている場合は、 PyTest の mark.skipif デコレーターを使用してそれらのケースをスキップすることができます。次の例は、前述のようにセッションとパラメーターが構成されていると想定します。条件では local_testing_modelocal に設定されているかどうかをチェックし、設定されている場合はテストケースをスキップし、その理由を説明するメッセージを表示します。

import pytest

@pytest.mark.skipif(
    condition="config.getvalue('local_testing_mode') == 'local'",
reason="Test case disabled for local testing"
)
def test_case(session):
    ...
Copy

制限事項

次の機能はサポートされていません。

  • 生の SQL 文字列と、 SQL 文字列の解析を必要とする操作。たとえば、 session.sqlDataFrame.filter("col1 > 12") はサポートされていません。

  • UDFs、 UDTFs、およびストアドプロシージャ

  • テーブル関数。

  • AsyncJobs。

  • ウェアハウス、スキーマ、その他のセッションプロパティの変更などのセッション操作。

  • GeometryGeography のデータ型。

  • ウィンドウ関数の集約。

    # Selecting window function expressions is supported
    df.select("key", "value", sum_("value").over(), avg("value").over())
    
    # Aggregating window function expressions is NOT supported
    df.group_by("key").agg([sum_("value"), sum_(sum_("value")).over(window) - sum_("value")])
    
    Copy

その他の制限事項は次のとおりです。

  • VariantArrayObject のデータ型は、標準の JSON エンコードとデコードでのみサポートされています。{1,2,,3,} のような式は、Snowflakeでは有効な JSON とみなされますが、Pythonの組み込み JSON 関数が使用されるローカルテストでは無効です。モジュールレベルの変数 snowflake.snowpark.mock.CUSTOM_JSON_ENCODERsnowflake.snowpark.mock.CUSTOM_JSON_DECODER を指定して、デフォルト設定を上書きすることができます。

  • Snowflakeの関数のサブセット(ウィンドウ関数を含む)のみが実装されます。独自の関数定義を注入する方法については、 組み込み関数のパッチ をご参照ください。

  • ランク関連関数のパッチは現在サポートされていません。

  • 同じ名前の列を選択すると、1つの列のみが返されます。回避策として、 Column.alias を使用して列の名前を変更します。

    df.select(lit(1), lit(1)).show() # col("a"), col("a")
    #---------
    #|"'1'"  |
    #---------
    #|1      |
    #|...    |
    #---------
    
    # Workaround: Column.alias
    DataFrame.select(lit(1).alias("col1_1"), lit(1).alias("col1_2"))
    # "col1_1", "col1_2"
    
    Copy
  • Column.cast を使用した明示的な型キャストには、フォーマット文字列が以下に対してサポートされていないという制限があります。入力: to_decimalto_numberto_numericto_doubleto_dateto_timeto_timestamp および出力: to_charto_varcharto_binary

  • JSON VariantType に格納された文字列は Datetime 型に変換できません。

  • Table.mergeTable.update については、セッションパラメーター ERROR_ON_NONDETERMINISTIC_UPDATEERROR_ON_NONDETERMINISTIC_MERGEFalse に設定されているときの動作のみをサポートする実装です。これは、複数結合の場合、マッチした行の1つを更新することを意味します。

サポートされた APIs のリスト

Snowparkセッション

Session.createDataFrame

Session.create_dataframe

Session.flatten

Session.range

Session.table

入力/出力

DataFrameReader.csv

DataFrameReader.table

DataFrameWriter.saveAsTable

DataFrameWriter.save_as_table

DataFrame

DataFrame.agg

DataFrame.cache_result

DataFrame.col

DataFrame.collect

DataFrame.collect_nowait

DataFrame.copy_into_table

DataFrame.count

DataFrame.createOrReplaceTempView

DataFrame.createOrReplaceView

DataFrame.create_or_replace_temp_view

DataFrame.create_or_replace_view

DataFrame.crossJoin

DataFrame.cross_join

DataFrame.distinct

DataFrame.drop

DataFrame.dropDuplicates

DataFrame.drop_duplicates

DataFrame.dropna

DataFrame.except_

DataFrame.explain

DataFrame.fillna

DataFrame.filter

DataFrame.first

DataFrame.groupBy

DataFrame.group_by

DataFrame.intersect

DataFrame.join

DataFrame.limit

DataFrame.minus

DataFrame.natural_join

DataFrame.orderBy

DataFrame.order_by

DataFrame.rename

DataFrame.replace

DataFrame.rollup

DataFrame.sample

DataFrame.select

DataFrame.show

DataFrame.sort

DataFrame.subtract

DataFrame.take

DataFrame.toDF

DataFrame.toLocalIterator

DataFrame.toPandas

DataFrame.to_df

DataFrame.to_local_iterator

DataFrame.to_pandas

DataFrame.to_pandas_batches

DataFrame.union

DataFrame.unionAll

DataFrame.unionAllByName

DataFrame.unionByName

DataFrame.union_all

DataFrame.union_all_by_name

DataFrame.union_by_name

DataFrame.unpivot

DataFrame.where

DataFrame.withColumn

DataFrame.withColumnRenamed

DataFrame.with_column

DataFrame.with_column_renamed

DataFrame.with_columns

DataFrameNaFunctions.drop

DataFrameNaFunctions.fill

DataFrameNaFunctions.replace

Column.alias

Column.as_

Column.asc

Column.asc_nulls_first

Column.asc_nulls_last

Column.astype

Column.between

Column.bitand

Column.bitor

Column.bitwiseAnd

Column.bitwiseOR

Column.bitwiseXOR

Column.bitxor

Column.cast

Column.collate

Column.desc

Column.desc_nulls_first

Column.desc_nulls_last

Column.endswith

Column.eqNullSafe

Column.equal_nan

Column.equal_null

Column.getItem

Column.getName

Column.get_name

Column.in_

Column.isNotNull

Column.isNull

Column.is_not_null

Column.is_null

Column.isin

Column.like

Column.name

Column.over

Column.regexp

Column.rlike

Column.startswith

Column.substr

Column.substring

Column.try_cast

Column.within_group

CaseExpr.when

CaseExpr.otherwise

データ型

ArrayType

BinaryType

BooleanType

ByteType

ColumnIdentifier

DataType

DateType

DecimalType

DoubleType

FloatType

IntegerType

LongType

MapType

NullType

ShortType

StringType

StructField

StructType

Timestamp

TimestampType

TimeType

Variant

VariantType

Row.asDict

Row.as_dict

Row.count

Row.index

関数

abs

avg

coalesce

contains

count

count_distinct

covar_pop

endswith

first_value

iff

lag

last_value

lead

list_agg

max

median

min

parse_json

row_number

startswith

substring

sum

to_array

to_binary

to_boolean

to_char

to_date

to_decimal

to_double

to_object

to_time

to_timestamp

to_variant

Window

Window.orderBy

Window.order_by

Window.partitionBy

Window.partition_by

Window.rangeBetween

Window.range_between

Window.rowsBetween

Window.rows_between

WindowSpec.orderBy

WindowSpec.order_by

WindowSpec.partitionBy

WindowSpec.partition_by

WindowSpec.rangeBetween

WindowSpec.range_between

WindowSpec.rowsBetween

WindowSpec.rows_between

グループ化

RelationalGroupedDataFrame.agg

RelationalGroupedDataFrame.apply_in_pandas

RelationalGroupedDataFrame.applyInPandas

RelationalGroupedDataFrame.avg

RelationalGroupedDataFrame.builtin

RelationalGroupedDataFrame.count

RelationalGroupedDataFrame.function

RelationalGroupedDataFrame.max

RelationalGroupedDataFrame.mean

RelationalGroupedDataFrame.median

RelationalGroupedDataFrame.min

RelationalGroupedDataFrame.sum

テーブル

Table.delete

Table.drop_table

Table.merge

Table.sample

Table.update

WhenMatchedClause.delete

WhenMatchedClause.update

WhenNotMatchedClause.insert