snowflake.snowpark.functions.ai_classify
- snowflake.snowpark.functions.ai_classify(expr: Union[Column, str], list_of_categories: Union[Column, List[str]], **kwargs) -> Column
Classifies text or images into categories that you specify.
- Parameters:
expr – The string, image, or SQL object from prompt() that you’re classifying. If you’re classifying text, the input string is case sensitive. You might get different results if you use different capitalization.
list_of_categories – An array of strings that represents the different categories. Categories are case sensitive. The array must contain at least 2 and no more than 100 categories. If the requirements aren’t met, the function returns an error.
**kwargs –
Configuration settings specified as key/value pairs. Supported keys:
- task_description: An explanation of the classification task that is 50 words or fewer. This can help the model understand the context of the classification task and improve accuracy.
- output_mode: Set to multi for multi-label classification. Defaults to single for single-label classification.
- examples: A list of example objects for few-shot learning. Each example must include:
  - input: Example text to classify.
  - labels: List of correct categories for the input.
  - explanation: Explanation of why the input maps to those categories.
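The limits stated for list_of_categories and the keyword arguments can be checked on the client before a query is sent. The following is a minimal sketch, not part of Snowpark: check_classify_args is a hypothetical helper, and counting words by splitting on whitespace is an assumption about how the 50-word limit is measured.

```python
from typing import Any, Dict, List, Optional

# Limits quoted from the parameter descriptions above.
MIN_CATEGORIES = 2
MAX_CATEGORIES = 100
MAX_TASK_DESCRIPTION_WORDS = 50
REQUIRED_EXAMPLE_KEYS = {"input", "labels", "explanation"}


def check_classify_args(
    categories: List[str],
    task_description: Optional[str] = None,
    output_mode: str = "single",
    examples: Optional[List[Dict[str, Any]]] = None,
) -> None:
    """Hypothetical pre-flight check; raises ValueError on the first violation."""
    if not MIN_CATEGORIES <= len(categories) <= MAX_CATEGORIES:
        raise ValueError(
            f"expected {MIN_CATEGORIES}-{MAX_CATEGORIES} categories, "
            f"got {len(categories)}"
        )

    if task_description is not None:
        # Assumption: a "word" is a whitespace-separated token.
        word_count = len(task_description.split())
        if word_count > MAX_TASK_DESCRIPTION_WORDS:
            raise ValueError(
                f"task_description has {word_count} words, "
                f"limit is {MAX_TASK_DESCRIPTION_WORDS}"
            )

    if output_mode not in ("single", "multi"):
        raise ValueError(f"output_mode must be 'single' or 'multi', got {output_mode!r}")

    for index, example in enumerate(examples or []):
        missing = REQUIRED_EXAMPLE_KEYS - set(example)
        if missing:
            raise ValueError(f"example {index} is missing keys: {sorted(missing)}")
```

Calling check_classify_args(['travel', 'cooking'], output_mode='multi') returns None, while a single-element category list raises ValueError. The function itself still performs its own validation; a check like this only surfaces the documented limits earlier.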
- Returns:
A serialized object. The object’s labels field is an array of strings that specifies the category or categories to which the input belongs, as the examples show. If you specify invalid values for the arguments, an error is returned.
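Once the returned value reaches client code, the categories can be pulled out of it with the standard library. This is a minimal sketch under two assumptions: the serialized object arrives as JSON text (as the printed output in the examples suggests) or as an already-parsed dict, and extract_labels is a hypothetical helper name, not a Snowpark function.

```python
import json
from typing import Any, Dict, List, Union


def extract_labels(result: Union[str, Dict[str, Any]]) -> List[str]:
    """Return the list under the "labels" key, or an empty list if it is absent."""
    # Assumption: a str input is JSON text; a dict input is already parsed.
    payload = json.loads(result) if isinstance(result, str) else result
    return list(payload.get("labels", []))
```

For single-label output the list holds one entry, so extract_labels(value)[0] gives the category; with output_mode set to multi, the list may hold several.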
Examples:
>>> # for text
>>> session.range(1).select(ai_classify('One day I will see the world', ['travel', 'cooking']).alias("answer")).show()
-----------------
|"ANSWER"       |
-----------------
|{              |
|  "labels": [  |
|    "travel"   |
|  ]            |
|}              |
-----------------

>>> df = session.create_dataframe([
...     ['France', ['North America', 'Europe', 'Asia']],
...     ['Singapore', ['North America', 'Europe', 'Asia']],
...     ['one day I will see the world', ['travel', 'cooking', 'dancing']],
...     ['my lobster bisque is second to none', ['travel', 'cooking', 'dancing']]
... ], schema=["data", "category"])
>>> df.select("data", ai_classify(col("data"), col("category"))["labels"][0].alias("class")).sort("data").show()
---------------------------------------------------
|"DATA"                               |"CLASS"    |
---------------------------------------------------
|France                               |"Europe"   |
|Singapore                            |"Asia"     |
|my lobster bisque is second to none  |"cooking"  |
|one day I will see the world         |"travel"   |
---------------------------------------------------

>>> # using kwargs for advanced configuration
>>> session.range(1).select(
...     ai_classify(
...         'One day I will see the world and learn to cook my favorite dishes',
...         ['travel', 'cooking', 'reading', 'driving'],
...         task_description='Determine topics related to the given text',
...         output_mode='multi',
...         examples=[{
...             'input': 'i love traveling with a good book',
...             'labels': ['travel', 'reading'],
...             'explanation': 'the text mentions traveling and a good book which relates to reading'
...         }]
...     ).alias("answer")
... ).show()
------------------
|"ANSWER"        |
------------------
|{               |
|  "labels": [   |
|    "cooking",  |
|    "travel"    |
|  ]             |
|}               |
------------------

>>> # for image
>>> _ = session.sql("CREATE OR REPLACE TEMP STAGE mystage ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')").collect()
>>> _ = session.file.put("tests/resources/dog.jpg", "@mystage", auto_compress=False)
>>> df = session.range(1).select(
...     ai_classify(
...         prompt("Please help me classify the dog within this image {0}", to_file("@mystage/dog.jpg")),
...         ["French Bulldog", "Golden Retriever", "Bichon", "Cavapoo", "Beagle"]
...     ).alias("classes")
... )
>>> df.show()
-----------------
|"CLASSES"      |
-----------------
|{              |
|  "labels": [  |
|    "Cavapoo"  |
|  ]            |
|}              |
-----------------