import streamlit as st
from snowflake.core import Root
from snowflake.cortex import Complete
from snowflake.snowpark.context import get_active_session
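# NOTE: this app assumes it runs where an active Snowpark session is available
# (e.g. as a Streamlit in Snowflake app); see the __main__ block at the bottom.

# Cortex model availability varies by account and region; adjust this list to
# the models enabled for your account.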
MODELS = [
"llama3.1-8b",
"llama3.1-70b",
"llama3.1-405b"
]
def init_messages():
"""Initialize session state messages if not present or if we need to clear."""
if st.session_state.get("clear_conversation") or "messages" not in st.session_state:
st.session_state.messages = []
st.session_state.clear_conversation = False
def init_service_metadata():
"""Load or refresh cortex search services from Snowflake."""
services = session.sql("SHOW CORTEX SEARCH SERVICES IN ACCOUNT;").collect()
service_metadata = []
if services:
for s in services:
svc_name = s["name"]
svc_schema = s["schema_name"]
svc_db = s["database_name"]
svc_search_col = session.sql(
f"DESC CORTEX SEARCH SERVICE {svc_db}.{svc_schema}.{svc_name};"
).collect()[0]["search_column"]
service_metadata.append(
{
"name": svc_name,
"search_column": svc_search_col,
"db": svc_db,
"schema": svc_schema,
}
)
st.session_state.service_metadata = service_metadata
if "selected_cortex_search_service" not in st.session_state and service_metadata:
st.session_state.selected_cortex_search_service = service_metadata[0]["name"]
selected_entry = st.session_state.get("selected_cortex_search_service")
if selected_entry:
selected_service_metadata = next(
(svc for svc in st.session_state.service_metadata if svc["name"] == selected_entry),
None
)
if selected_service_metadata:
st.session_state.selected_schema = selected_service_metadata["schema"]
st.session_state.selected_db = selected_service_metadata["db"]
elif st.session_state.get("debug", False):
st.write("No matching service found for:", selected_entry)
def init_config_options():
if "service_metadata" not in st.session_state or not st.session_state.service_metadata:
st.sidebar.warning("No Cortex Knowledge Extensions available")
return
st.sidebar.selectbox(
"Select Cortex Knowledge Extension",
[s["name"] for s in st.session_state.service_metadata],
key="selected_cortex_search_service",
)
if st.sidebar.button("Clear conversation"):
st.session_state.clear_conversation = True
st.sidebar.checkbox("Debug", key="debug", value=False)
st.sidebar.checkbox("Use chat history", key="use_chat_history", value=True)
with st.sidebar.expander("Advanced options"):
st.selectbox("Select model:", MODELS, key="model_name")
st.number_input(
"Select number of context chunks",
value=5,
key="num_retrieved_chunks",
min_value=1,
max_value=10,
)
st.number_input(
"Select number of messages to use in chat history",
value=5,
key="num_chat_messages",
min_value=1,
max_value=10,
)
st.sidebar.expander("Session State").write(st.session_state)
def get_chat_history():
"""Get the last N messages from session state."""
start_index = max(
0, len(st.session_state.messages) - st.session_state.num_chat_messages
)
return st.session_state.messages[start_index : len(st.session_state.messages) - 1]
def complete(model, prompt):
    """Use the chosen Snowflake cortex model to complete a prompt."""
    # Escape dollar signs so Streamlit's markdown renderer does not treat them
    # as LaTeX delimiters.
    return Complete(model=model, prompt=prompt).replace("$", "\\$")
def make_chat_history_summary(chat_history, question):
"""
Summarize the chat history plus the question using your LLM,
to refine the final search query.
"""
prompt = f"""
[INST]
Based on the chat history below and the question, generate a query that extends the question
with the chat history provided. The query should be in natural language.
Answer with only the query. Do not add any explanation.
<chat_history>
{chat_history}
</chat_history>
<question>
{question}
</question>
[/INST]
"""
summary = complete(st.session_state.model_name, prompt)
    if st.session_state.get("debug", False):
st.sidebar.text_area("Chat history summary", summary.replace("$", "\\$"), height=150)
return summary
def query_cortex_search_service(query, columns=None, filter=None):
    """
    Query the selected cortex search service with the given query and retrieve context documents.
    """
    # Avoid mutable default arguments; fall back to fresh empty containers.
    columns = columns or []
    filter = filter or {}
db = st.session_state.get("selected_db")
schema = st.session_state.get("selected_schema")
if st.session_state.get("debug", False):
st.sidebar.write("Query:", query)
st.sidebar.write("DB:", db)
st.sidebar.write("Schema:", schema)
st.sidebar.write("Service:", st.session_state.selected_cortex_search_service)
cortex_search_service = (
root.databases[db]
.schemas[schema]
.cortex_search_services[st.session_state.selected_cortex_search_service]
)
context_documents = cortex_search_service.search(
query,
columns=columns,
filter=filter,
limit=st.session_state.num_retrieved_chunks
)
results = context_documents.results
if st.session_state.get("debug", False):
st.sidebar.write("Search Results:", results)
service_metadata = st.session_state.service_metadata
search_col = [
s["search_column"] for s in service_metadata
if s["name"] == st.session_state.selected_cortex_search_service
][0].lower()
context_str = ""
context_str_template = (
"Source: {source_url}\n"
"Source ID: {id}\n"
"Excerpt: {chunk}\n\n\n"
)
for i, r in enumerate(results):
context_str += context_str_template.format(
id=i+1,
chunk=r[search_col],
source_url=r["source_url"],
title=r["document_title"],
)
if st.session_state.debug:
st.sidebar.text_area("Context documents", context_str, height=500)
return context_str, results
def create_prompt(user_question):
    """
    Combine user question, context from the search service, and chat history
    to create a final prompt for the LLM.
    """
    chat_history = ""
    search_query = user_question
    if st.session_state.use_chat_history:
        chat_history = get_chat_history()
        if chat_history:
            # Fold the conversation into the search query so retrieval also
            # works for follow-up questions.
            search_query = make_chat_history_summary(chat_history, user_question)
    prompt_context, results = query_cortex_search_service(
        search_query, columns=["chunk", "source_url", "document_title"]
    )
    prompt = f"""
You are a helpful AI assistant with RAG capabilities. When a user asks you a question, you will also be given excerpts from relevant documentation to help answer the question accurately. Please use the context provided and cite your sources using the citation format provided.
Chat history:
{chat_history}
Context from documentation:
{prompt_context}
User question:
{user_question}
OUTPUT:
"""
if st.session_state.get("debug", False):
st.sidebar.text_area("Complete Prompt", prompt, height=300)
return prompt, results
def post_process_citations(generated_response, results):
    """
    Replace {.StartCitation}X{.EndCitation} markers with bracketed markdown
    links to the cited source URLs.
    NOTE: If the model references chunks out of range (like 4 if only 2 exist),
    consider adding logic to remap or drop invalid references.
    """
    used_results = set()
    for i, ref in enumerate(results):
        old_str = f"{{.StartCitation}}{i+1}{{.EndCitation}}"
        # Render the citation as a markdown link, e.g. [1](https://...).
        replacement = f"[{i+1}]({ref['source_url']})"
        new_resp = generated_response.replace(old_str, replacement)
        if new_resp != generated_response:
            used_results.add(i)
            generated_response = new_resp
    return generated_response, used_results
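# A minimal sketch of the cleanup suggested in the NOTE above: drop any
# citation marker whose index has no matching search result. The helper name
# drop_invalid_citations and the regex are illustrative additions, not part of
# the original app; call it on the model output before post_process_citations.
import re  # imported here to keep the sketch self-contained

def drop_invalid_citations(generated_response, num_results):
    """Remove {.StartCitation}N{.EndCitation} markers whose N is out of range."""
    pattern = re.compile(r"\{\.StartCitation\}(\d+)\{\.EndCitation\}")

    def _keep_or_drop(match):
        index = int(match.group(1))
        # Keep in-range markers for post_process_citations to rewrite; drop the rest.
        return match.group(0) if 1 <= index <= num_results else ""

    return pattern.sub(_keep_or_drop, generated_response)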
def main():
st.set_page_config(
page_title="Cortex Knowledge Extension Chat Tester",
layout="wide",
)
custom_css = """
<style>
[data-testid="stChatMessage"] {
border-radius: 8px;
margin-bottom: 1rem;
padding: 10px;
}
</style>
"""
st.markdown(custom_css, unsafe_allow_html=True)
st.subheader("Cortex Knowledge Extension Chat Tester")
init_service_metadata()
init_config_options()
init_messages()
icons = {"assistant": "❄️", "user": "👤"}
for message in st.session_state.messages:
with st.chat_message(message["role"], avatar=icons[message["role"]]):
st.markdown(message["content"])
disable_chat = (
"service_metadata" not in st.session_state
or len(st.session_state.service_metadata) == 0
)
if question := st.chat_input("Ask a question...", disabled=disable_chat):
st.session_state.messages.append({"role": "user", "content": question})
with st.chat_message("user", avatar=icons["user"]):
st.markdown(question.replace("$", "\\$"))
with st.chat_message("assistant", avatar=icons["assistant"]):
message_placeholder = st.empty()
            # Strip single quotes, which can otherwise break the SQL statements
            # that the Cortex calls execute under the hood.
            question_safe = question.replace("'", "")
prompt, results = create_prompt(question_safe)
with st.spinner("Thinking..."):
generated_response = complete(st.session_state.model_name, prompt)
post_processed_response, used_results = post_process_citations(generated_response, results)
if results:
markdown_table = "\n\n###### References \n\n| Index | Title | Source |\n%------%-------%--------%\n"
for i, ref in enumerate(results):
markdown_table += (
f"| {i+1} | {ref.get('document_title', 'N/A')} | "
f"{ref.get('source_url', 'N/A')} |\n"
)
else:
markdown_table = "\n\n*No references found*"
message_placeholder.markdown(post_processed_response + markdown_table)
st.session_state.messages.append(
{"role": "assistant", "content": post_processed_response + markdown_table}
)
if __name__ == "__main__":
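    # session and root are module-level globals used by the helper functions
    # above; Streamlit reruns this script as __main__ on each interaction.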
session = get_active_session()
root = Root(session)
main()