snowflake.snowpark.functions.ai_classify

snowflake.snowpark.functions.ai_classify(expr: Union[Column, str], list_of_categories: Union[Column, List[str]], **kwargs) → Column

Classifies text or images into categories that you specify.

Parameters:
  • expr – The string, image, or SQL object from prompt() that you’re classifying. Text input is case-sensitive, so different capitalization might produce different results.

  • list_of_categories – An array of strings that represents the different categories. Categories are case-sensitive. The array must contain at least 2 and no more than 100 categories. If the requirements aren’t met, the function returns an error.

  • **kwargs –

    Configuration settings specified as key/value pairs. Supported keys:

    • task_description: An explanation of the classification task that is 50 words or fewer.

      This can help the model understand the context of the classification task and improve accuracy.

    • output_mode: Set to multi for multi-label classification. Defaults to single for single-label classification.

    • examples: A list of example objects for few-shot learning. Each example must include:

      • input: Example text to classify.

      • labels: List of correct categories for the input.

      • explanation: Explanation of why the input maps to those categories.
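The constraints above can be checked on the client before a query is submitted. The following is a minimal sketch, not part of Snowpark: check_categories and build_classify_kwargs are hypothetical helper names, and the checks only mirror the limits stated in this section (2 to 100 categories, a task description of 50 words or fewer, the single and multi output modes, and the three fields of each example). The function itself still performs its own validation and returns an error for invalid arguments.

```python
from typing import Any, Dict, List, Optional

# Fields this section says every few-shot example must include.
REQUIRED_EXAMPLE_KEYS = {"input", "labels", "explanation"}


def check_categories(categories: List[str]) -> List[str]:
    """Hypothetical helper: enforce the documented 2..100 category limit."""
    if not 2 <= len(categories) <= 100:
        raise ValueError(f"expected 2 to 100 categories, got {len(categories)}")
    return categories


def build_classify_kwargs(
    task_description: Optional[str] = None,
    output_mode: Optional[str] = None,
    examples: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: assemble the keyword arguments described above."""
    kwargs: Dict[str, Any] = {}
    if task_description is not None:
        # Whitespace-separated word count; an approximation of "50 words".
        if len(task_description.split()) > 50:
            raise ValueError("task_description must be 50 words or fewer")
        kwargs["task_description"] = task_description
    if output_mode is not None:
        if output_mode not in ("single", "multi"):
            raise ValueError("output_mode must be 'single' or 'multi'")
        kwargs["output_mode"] = output_mode
    if examples is not None:
        for example in examples:
            missing = REQUIRED_EXAMPLE_KEYS - set(example)
            if missing:
                raise ValueError(f"example is missing keys: {sorted(missing)}")
        kwargs["examples"] = examples
    return kwargs
```

With these helpers, a call could look like ai_classify(col("data"), check_categories(categories), **build_classify_kwargs(output_mode="multi")), assuming an active session and a DataFrame with a data column.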

Returns:

A serialized object. The object’s labels field is an array of strings that specifies the category or categories to which the input belongs. If you specify invalid values for the arguments, an error is returned.
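Once a result has been collected, the serialized object can be decoded on the client. The sketch below assumes the value arrives in Python as a JSON string shaped like the output shown in the examples on this page; extract_labels is a hypothetical helper, not a Snowpark function, and the sample string is copied from that output rather than fetched from a live session.

```python
import json
from typing import List


def extract_labels(serialized: str) -> List[str]:
    """Hypothetical helper: pull the labels array out of a serialized result."""
    obj = json.loads(serialized)
    # Fall back to an empty list if the field is absent.
    return list(obj.get("labels", []))


# Sample value shaped like the multi-label output shown on this page.
sample = '{\n  "labels": [\n    "cooking",\n    "travel"\n  ]\n}'
print(extract_labels(sample))  # prints ['cooking', 'travel']
```

In single-label mode the array holds one element, so extract_labels(value)[0] gives the chosen category.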

Examples:

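>>> # assumes an active Snowpark session bound to the name `session`
>>> from snowflake.snowpark.functions import ai_classify, col, prompt, to_file
>>> # for text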
>>> session.range(1).select(ai_classify('One day I will see the world', ['travel', 'cooking']).alias("answer")).show()
-----------------
|"ANSWER"       |
-----------------
|{              |
|  "labels": [  |
|    "travel"   |
|  ]            |
|}              |
-----------------

>>> df = session.create_dataframe([
...     ['France', ['North America', 'Europe', 'Asia']],
...     ['Singapore', ['North America', 'Europe', 'Asia']],
...     ['one day I will see the world', ['travel', 'cooking', 'dancing']],
...     ['my lobster bisque is second to none', ['travel', 'cooking', 'dancing']]
... ], schema=["data", "category"])
>>> df.select("data", ai_classify(col("data"), col("category"))["labels"][0].alias("class")).sort("data").show()
---------------------------------------------------
|"DATA"                               |"CLASS"    |
---------------------------------------------------
|France                               |"Europe"   |
|Singapore                            |"Asia"     |
|my lobster bisque is second to none  |"cooking"  |
|one day I will see the world         |"travel"   |
---------------------------------------------------


>>> # using kwargs for advanced configuration
>>> session.range(1).select(
...     ai_classify(
...         'One day I will see the world and learn to cook my favorite dishes',
...         ['travel', 'cooking', 'reading', 'driving'],
...         task_description='Determine topics related to the given text',
...         output_mode='multi',
...         examples=[{
...             'input': 'i love traveling with a good book',
...             'labels': ['travel', 'reading'],
...             'explanation': 'the text mentions traveling and a good book which relates to reading'
...         }]
...     ).alias("answer")
... ).show()
------------------
|"ANSWER"        |
------------------
|{               |
|  "labels": [   |
|    "cooking",  |
|    "travel"    |
|  ]             |
|}               |
------------------


>>> # for image
>>> _ = session.sql("CREATE OR REPLACE TEMP STAGE mystage ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')").collect()
>>> _ = session.file.put("tests/resources/dog.jpg", "@mystage", auto_compress=False)
>>> df = session.range(1).select(
...     ai_classify(
...         prompt("Please help me classify the dog within this image {0}", to_file("@mystage/dog.jpg")),
...         ["French Bulldog", "Golden Retriever", "Bichon", "Cavapoo", "Beagle"]
...     ).alias("classes")
... )
>>> df.show()
-----------------
|"CLASSES"      |
-----------------
|{              |
|  "labels": [  |
|    "Cavapoo"  |
|  ]            |
|}              |
-----------------