from snowflake.ml.registry import registry
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
from snowflake.snowpark import Session
from sentence_transformers import SentenceTransformer
# Create a Snowpark session from the login options of the connection named
# "connection_name", then open the model registry in the given database/schema.
session = Session.builder.configs(
    SnowflakeLoginOptions("connection_name")
).create()
reg = registry.Registry(
    session=session,
    database_name='my_registry_db',
    schema_name='my_registry_schema',
)
# Use an example sentence transformer from Hugging Face: all-MiniLM-L6-v2.
embed_model = SentenceTransformer(
    'sentence-transformers/all-MiniLM-L6-v2'
)
# Sample sentences. They are passed to log_model as sample input (used to
# determine the model signature) and later sent to the deployed service.
input_data = [
    "This is the first sentence.",
    "Here's another sentence for testing.",
    "The quick brown fox jumps over the lazy dog.",
    "I love coding and programming.",
    "Machine learning is an exciting field.",
    "Python is a popular programming language.",
    "I enjoy working with data.",
    "Deep learning models are powerful.",
    "Natural language processing is fascinating.",
    "I want to improve my NLP skills.",
]
# Log version "pip" of the model, declaring its dependencies as pip
# requirements. If you want to run this model in a warehouse, declare the
# dependencies with conda_dependencies instead.
pip_model = reg.log_model(
    embed_model,
    model_name="sentence_transformer_minilm",
    version_name="pip",
    # Needed for determining the signature of the model.
    sample_input_data=input_data,
    pip_requirements=["sentence-transformers", "torch", "transformers"],
)
# Log version "conda_forge_force" with conda dependencies. Taking any package
# from the conda-forge channel is enough for Snowflake to know the model can't
# run in a warehouse, so it does not try to check warehouse compatibility.
conda_forge_model = reg.log_model(
    embed_model,
    model_name="sentence_transformer_minilm",
    version_name="conda_forge_force",
    sample_input_data=input_data,
    conda_dependencies=[
        "sentence-transformers",
        "conda-forge::pytorch",  # the conda-forge channel marks it non-warehouse
        "transformers",
    ],
)
# Name of the SPCS service; defined once so create/run/delete always agree.
minilm_service_name = "my_minilm_service"

# Deploy the "pip" model version to Snowpark Container Services (SPCS).
pip_model.create_service(
    service_name=minilm_service_name,
    service_compute_pool="my_gpu_pool",  # Using GPU_NV_S - smallest GPU node that can run the model
    ingress_enabled=True,
    gpu_requests="1",  # Model fits in GPU memory; only needed for GPU pool
    max_instances=4,  # 4 instances were able to run 10M inferences from an XS warehouse
)

# Delete the service in `finally` so it is torn down even if listing or
# inference raises; otherwise an error would leave the service running.
try:
    # See all services running this model.
    print(pip_model.list_services())
    # Run the model's `encode` function on the SPCS service.
    embeddings = pip_model.run(
        input_data,
        function_name="encode",
        service_name=minilm_service_name,
    )
    print(embeddings)
finally:
    pip_model.delete_service(minilm_service_name)