snowflake.snowpark.DataFrameAIFunctions.classify

DataFrameAIFunctions.classify(input_column: Union[snowflake.snowpark.column.Column, str], categories: Union[List[str], Column], *, output_column: Optional[str] = None, **kwargs) → snowflake.snowpark.DataFrame

Classify text or images into specified categories using AI.

This method applies AI-based classification to each row, assigning one or more categories from the provided list based on the input content.

Parameters:
  • input_column – The column (Column object or column name as string) containing the text or image data to classify.

  • categories – List of category strings or a Column containing an array of categories. Must contain at least 2 and no more than 100 categories.

  • output_column – The name of the output column to be appended. If not provided, a column named AI_CLASSIFY_OUTPUT is appended.

  • **kwargs –

    Configuration settings specified as key/value pairs. Supported keys:

    • task_description: An explanation of the classification task, 50 words or fewer. It gives the model context for the task and can improve accuracy.

    • output_mode: Set to "multi" for multi-label classification. The default, "single", performs single-label classification.

    • examples: A list of example objects for few-shot learning. Each example must include:

      • input: Example text to classify.

      • labels: List of correct categories for the input.

      • explanation: Explanation of why the input maps to those categories.
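The limits listed above can be checked on the client before a query is sent. The sketch below is a hypothetical helper (validate_classify_args is not part of Snowpark) that mirrors only the constraints stated in this reference: 2 to 100 categories, a task_description of at most 50 words, an output_mode of "single" or "multi", and the three required keys on each example. It is an assumption-based convenience, not a copy of the server's validation, and it only handles a Python list of categories, since a Column cannot be inspected client-side.

```python
from typing import Any, Dict, List, Optional


def validate_classify_args(
    categories: List[str],
    task_description: Optional[str] = None,
    output_mode: str = "single",
    examples: Optional[List[Dict[str, Any]]] = None,
) -> None:
    """Hypothetical client-side check mirroring the documented limits.

    Raises ValueError on the first violation; returns None otherwise.
    """
    # Categories: at least 2 and no more than 100 (per this reference).
    if not 2 <= len(categories) <= 100:
        raise ValueError(f"expected 2-100 categories, got {len(categories)}")

    # task_description: 50 words or fewer, counted here as whitespace-separated tokens.
    if task_description is not None and len(task_description.split()) > 50:
        raise ValueError("task_description must be 50 words or fewer")

    # output_mode: "single" (the default) or "multi".
    if output_mode not in ("single", "multi"):
        raise ValueError(f"unsupported output_mode: {output_mode!r}")

    # Each few-shot example must include input, labels, and explanation.
    required = {"input", "labels", "explanation"}
    for i, example in enumerate(examples or []):
        missing = required - set(example)
        if missing:
            raise ValueError(f"example {i} is missing keys: {sorted(missing)}")


# Passes silently for arguments shaped like the multi-label example in this page.
validate_classify_args(
    ["travel", "cooking", "reading", "sports"],
    task_description="Identify all topics mentioned in the text",
    output_mode="multi",
    examples=[{"input": "...", "labels": ["travel"], "explanation": "..."}],
)
```

Word counting by whitespace is a simplifying assumption; the service may count words differently.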

Returns:

A new DataFrame with the output column appended. Each value in that column is a JSON object whose labels field contains the list of assigned categories.
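Because the examples below read each collected result with json.loads, the value arrives as a JSON string. A minimal sketch of a hypothetical helper (extract_labels, not part of Snowpark) that returns the labels list from one such value; accepting an already-parsed dict is an added convenience, not documented behavior.

```python
import json
from typing import Any, Dict, List, Union


def extract_labels(value: Union[str, Dict[str, Any]]) -> List[str]:
    """Return the labels list from one classification result.

    value is either the JSON string collected from the output column
    or a dict that has already been parsed.
    """
    parsed = json.loads(value) if isinstance(value, str) else value
    labels = parsed.get("labels", [])
    # Defensive: wrap a non-list value so callers always get a list
    # (the documented output uses a list).
    return labels if isinstance(labels, list) else [labels]


# Illustrative input shaped like the documented output, not real model output.
print(extract_labels('{"labels": ["hiking"]}'))  # ['hiking']
```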

Examples:

>>> # Simple text classification with list of categories
>>> from snowflake.snowpark.functions import col
>>> import json
>>> df = session.create_dataframe(
...     [
...         ["I love hiking in the mountains"],
...         ["My favorite dish is pasta carbonara"],
...         ["Just finished reading a great book"],
...     ],
...     schema=["text"]
... )
>>> result_df = df.ai.classify(
...     input_column="text",
...     categories=["hiking", "cooking", "reading"],
...     output_column="category"
... )
>>> result_df.columns
['TEXT', 'CATEGORY']
>>> results = result_df.collect()
>>> json.loads(results[0]["CATEGORY"])["labels"][0]
'hiking'

>>> # Image classification with Column containing categories
>>> from snowflake.snowpark.functions import to_file
>>> # Upload images to a stage first
>>> _ = session.sql("CREATE OR REPLACE TEMP STAGE mystage ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')").collect()
>>> _ = session.file.put("tests/resources/dog.jpg", "@mystage", auto_compress=False)
>>> _ = session.file.put("tests/resources/cat.jpeg", "@mystage", auto_compress=False)
>>> _ = session.file.put("tests/resources/kitchen.png", "@mystage", auto_compress=False)
>>> # Create DataFrame with image paths and possible categories for each image
>>> df = session.create_dataframe(
...     [
...         ["@mystage/dog.jpg", ["cat", "dog", "bird", "fish"]],
...         ["@mystage/cat.jpeg", ["cat", "dog", "rabbit", "hamster"]],
...         ["@mystage/kitchen.png", ["kitchen", "bedroom", "bathroom", "living room"]],
...     ],
...     schema=["image_path", "categories"]
... )
>>> # Classify images using their respective category options
>>> result_df = df.ai.classify(
...     input_column=to_file(col("image_path")),
...     categories=col("categories"),
...     output_column="classification"
... )
>>> result_df.columns
['IMAGE_PATH', 'CATEGORIES', 'CLASSIFICATION']
>>> results = result_df.collect()
>>> # Verify the dog image is classified as 'dog'
>>> dog_result = [r for r in results if 'dog.jpg' in r["IMAGE_PATH"]][0]
>>> json.loads(dog_result["CLASSIFICATION"])["labels"][0]
'dog'

>>> # Multi-label classification with advanced configuration
>>> df = session.create_dataframe(
...     [
...         ["I enjoy traveling and trying local cuisines"],
...         ["Reading books while on a flight"],
...         ["Cooking recipes from different countries"],
...     ],
...     schema=["text"]
... )
>>> result_df = df.ai.classify(
...     input_column="text",
...     categories=["travel", "cooking", "reading", "sports"],
...     output_column="topics",
...     task_description="Identify all topics mentioned in the text",
...     output_mode="multi",
...     examples=[{
...         "input": "I love reading cookbooks during my travels",
...         "labels": ["travel", "cooking", "reading"],
...         "explanation": "The text mentions traveling, cookbooks (cooking), and reading"
...     }]
... )
>>> result_df.columns
['TEXT', 'TOPICS']
>>> results = result_df.collect()
>>> len(json.loads(results[0]["TOPICS"])["labels"]) >= 1  # Multi-label can have multiple labels
True
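The multi-label doctest above only checks that at least one label came back. Since the method assigns categories from the provided list, a slightly stricter check can also confirm that every returned label is one of the supplied categories. The sketch below uses a hypothetical helper (labels_are_valid, not part of Snowpark), and the sample JSON strings are illustrative stand-ins for collected values, not real results.

```python
import json
from typing import Iterable


def labels_are_valid(result_json: str, categories: Iterable[str]) -> bool:
    """Return True if a result has at least one label and every label is a supplied category."""
    labels = json.loads(result_json).get("labels", [])
    allowed = set(categories)
    # An empty label list fails; otherwise each label must be one of the supplied categories.
    return bool(labels) and all(label in allowed for label in labels)


categories = ["travel", "cooking", "reading", "sports"]
# Illustrative stand-ins for values collected from the TOPICS column, not real output.
collected = ['{"labels": ["travel", "cooking"]}', '{"labels": ["reading"]}']
print(all(labels_are_valid(v, categories) for v in collected))  # True
```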

This function or method is experimental since 1.39.0.