snowflake.snowpark.DataFrameAIFunctions.classify
- DataFrameAIFunctions.classify(input_column: Union[snowflake.snowpark.column.Column, str], categories: Union[List[str], Column], *, output_column: Optional[str] = None, **kwargs) → snowflake.snowpark.DataFrame
Classify text or images into specified categories using AI.
This method applies AI-based classification to each row, assigning one or more categories from the provided list based on the input content.
- Parameters:
input_column – The column (Column object or column name as string) containing the text or image data to classify.
categories – List of category strings or a Column containing an array of categories. Must contain at least 2 and no more than 100 categories.
output_column – The name of the output column to append. If not provided, a column named AI_CLASSIFY_OUTPUT is appended.
**kwargs – Configuration settings specified as key/value pairs. Supported keys:
  task_description: An explanation of the classification task, no longer than 50 words. It can help the model understand the context of the task and improve accuracy.
  output_mode: Set to "multi" for multi-label classification. Defaults to "single" for single-label classification.
  examples: A list of example objects for few-shot learning. Each example must include:
    input: Example text to classify.
    labels: List of correct categories for the input.
    explanation: Explanation of why the input maps to those categories.
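The limits listed above (2 to 100 categories, a task_description of at most 50 words, an output_mode of "single" or "multi", and the three required keys in every few-shot example) can be checked on the client before a query is sent. The following is a minimal sketch of such a pre-check. validate_classify_args is a hypothetical helper, not part of Snowpark; it only applies when categories is a Python list, because a Column of arrays can only be inspected server-side. Counting whitespace-separated tokens is an assumption about how the 50-word limit is measured.

```python
from typing import Any, Dict, List, Optional

# Limits as documented for DataFrameAIFunctions.classify.
MIN_CATEGORIES = 2
MAX_CATEGORIES = 100
MAX_TASK_DESCRIPTION_WORDS = 50
VALID_OUTPUT_MODES = {"single", "multi"}
REQUIRED_EXAMPLE_KEYS = {"input", "labels", "explanation"}


def validate_classify_args(
    categories: List[str],
    task_description: Optional[str] = None,
    output_mode: str = "single",
    examples: Optional[List[Dict[str, Any]]] = None,
) -> List[str]:
    """Return a list of problems; an empty list means the arguments look valid.

    Hypothetical client-side pre-check, not part of the Snowpark API.
    Snowflake still performs its own validation when the query runs.
    """
    problems: List[str] = []
    if not MIN_CATEGORIES <= len(categories) <= MAX_CATEGORIES:
        problems.append(
            f"expected {MIN_CATEGORIES}-{MAX_CATEGORIES} categories, got {len(categories)}"
        )
    if task_description is not None:
        # Assumption: words are approximated as whitespace-separated tokens.
        word_count = len(task_description.split())
        if word_count > MAX_TASK_DESCRIPTION_WORDS:
            problems.append(
                f"task_description has {word_count} words (max {MAX_TASK_DESCRIPTION_WORDS})"
            )
    if output_mode not in VALID_OUTPUT_MODES:
        problems.append(f"output_mode must be 'single' or 'multi', got {output_mode!r}")
    for i, example in enumerate(examples or []):
        # Every few-shot example needs input, labels, and explanation.
        missing = REQUIRED_EXAMPLE_KEYS - set(example)
        if missing:
            problems.append(f"example {i} is missing {sorted(missing)}")
    return problems


if __name__ == "__main__":
    # Arguments mirroring the multi-label example in this section.
    print(validate_classify_args(
        ["travel", "cooking", "reading", "sports"],
        task_description="Identify all topics mentioned in the text",
        output_mode="multi",
        examples=[{
            "input": "I love reading cookbooks during my travels",
            "labels": ["travel", "cooking", "reading"],
            "explanation": "The text mentions traveling, cookbooks (cooking), and reading",
        }],
    ))
```

Returning a list of messages rather than raising on the first problem lets a caller report every issue at once.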
- Returns:
A new DataFrame with an appended output column containing classification results. The output is a JSON object with a labels field containing the assigned categories.
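In the examples, collected values of the output column are JSON strings that are decoded with json.loads before the labels field is read. A small helper along the following lines keeps that decoding in one place. extract_labels is hypothetical: it assumes the value is either such a JSON string or an already-decoded dict, maps a NULL result to None, and returns an empty list when labels is absent (a defensive choice, not documented behavior).

```python
import json
from typing import List, Optional, Union


def extract_labels(raw: Union[str, dict, None]) -> Optional[List[str]]:
    """Return the ``labels`` list from one classification result.

    Hypothetical helper: accepts the JSON string found in collected rows
    or an already-parsed dict shaped like {"labels": [...]}.
    """
    if raw is None:
        # A NULL result in the output column.
        return None
    parsed = json.loads(raw) if isinstance(raw, str) else raw
    # Defensive default: treat a missing "labels" field as no labels.
    return [str(label) for label in parsed.get("labels", [])]


if __name__ == "__main__":
    print(extract_labels('{"labels": ["hiking"]}'))
```

With a live session this could be applied as extract_labels(row["CATEGORY"]) for each row returned by result_df.collect().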
Examples:
>>> # Simple text classification with list of categories
>>> from snowflake.snowpark.functions import col
>>> import json
>>> df = session.create_dataframe(
...     [
...         ["I love hiking in the mountains"],
...         ["My favorite dish is pasta carbonara"],
...         ["Just finished reading a great book"],
...     ],
...     schema=["text"]
... )
>>> result_df = df.ai.classify(
...     input_column="text",
...     categories=["hiking", "cooking", "reading"],
...     output_column="category"
... )
>>> result_df.columns
['TEXT', 'CATEGORY']
>>> results = result_df.collect()
>>> json.loads(results[0]["CATEGORY"])["labels"][0]
'hiking'
>>> # Image classification with Column containing categories
>>> from snowflake.snowpark.functions import to_file
>>> # Upload images to a stage first
>>> _ = session.sql("CREATE OR REPLACE TEMP STAGE mystage ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')").collect()
>>> _ = session.file.put("tests/resources/dog.jpg", "@mystage", auto_compress=False)
>>> _ = session.file.put("tests/resources/cat.jpeg", "@mystage", auto_compress=False)
>>> _ = session.file.put("tests/resources/kitchen.png", "@mystage", auto_compress=False)
>>> # Create DataFrame with image paths and possible categories for each image
>>> df = session.create_dataframe(
...     [
...         ["@mystage/dog.jpg", ["cat", "dog", "bird", "fish"]],
...         ["@mystage/cat.jpeg", ["cat", "dog", "rabbit", "hamster"]],
...         ["@mystage/kitchen.png", ["kitchen", "bedroom", "bathroom", "living room"]],
...     ],
...     schema=["image_path", "categories"]
... )
>>> # Classify images using their respective category options
>>> result_df = df.ai.classify(
...     input_column=to_file(col("image_path")),
...     categories=col("categories"),
...     output_column="classification"
... )
>>> result_df.columns
['IMAGE_PATH', 'CATEGORIES', 'CLASSIFICATION']
>>> results = result_df.collect()
>>> # Verify the dog image is classified as 'dog'
>>> dog_result = [r for r in results if 'dog.jpg' in r["IMAGE_PATH"]][0]
>>> json.loads(dog_result["CLASSIFICATION"])["labels"][0]
'dog'
>>> # Multi-label classification with advanced configuration
>>> df = session.create_dataframe(
...     [
...         ["I enjoy traveling and trying local cuisines"],
...         ["Reading books while on a flight"],
...         ["Cooking recipes from different countries"],
...     ],
...     schema=["text"]
... )
>>> result_df = df.ai.classify(
...     input_column="text",
...     categories=["travel", "cooking", "reading", "sports"],
...     output_column="topics",
...     task_description="Identify all topics mentioned in the text",
...     output_mode="multi",
...     examples=[{
...         "input": "I love reading cookbooks during my travels",
...         "labels": ["travel", "cooking", "reading"],
...         "explanation": "The text mentions traveling, cookbooks (cooking), and reading"
...     }]
... )
>>> result_df.columns
['TEXT', 'TOPICS']
>>> results = result_df.collect()
>>> len(json.loads(results[0]["TOPICS"])["labels"]) >= 1  # Multi-label can have multiple labels
True
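With output_mode="multi", one row may carry several labels, so a natural follow-up is to tally how often each category was assigned across the collected rows. The sketch below does this client-side with collections.Counter. count_labels is a hypothetical helper; it assumes each row can be indexed by the output column name, as the Row objects returned by collect() are in the examples, and that each value is a JSON string with a labels field. The demo rows are plain dicts standing in for Rows, and their label values are illustrative, not real model output.

```python
import json
from collections import Counter
from typing import Any, Iterable


def count_labels(rows: Iterable[Any], output_column: str) -> Counter:
    """Count how often each label appears across collected rows.

    Hypothetical helper: each row must support row[output_column] and hold
    a JSON string shaped like {"labels": [...]}. NULL results are skipped.
    """
    counts: Counter = Counter()
    for row in rows:
        raw = row[output_column]
        if raw is None:
            continue
        counts.update(json.loads(raw)["labels"])
    return counts


if __name__ == "__main__":
    # Plain dicts standing in for collected Rows; label values are illustrative only.
    sample_rows = [
        {"TOPICS": '{"labels": ["travel", "cooking"]}'},
        {"TOPICS": '{"labels": ["reading", "travel"]}'},
        {"TOPICS": None},  # a NULL result is skipped
    ]
    print(count_labels(sample_rows, "TOPICS").most_common())
```

With a live session the call would be count_labels(result_df.collect(), "TOPICS"). For large results, aggregating server-side avoids pulling every row to the client.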
This function or method is experimental since 1.39.0.