# Generates sample data for the last touch attribution custom template.
# Creates two tables, one used by the provider, one used by the consumer.
# Run this in both accounts to try out the sample SQL worksheets.
# Upload this Python worksheet into your provider and consumer accounts to generate sample data to use
# with the last touch attribution provider and consumer code.
# Set 'Handler = Main' and 'Return type = String' in the worksheet settings.
# For details about the features and use cases of this template, see the [Snowflake documentation
# topic for this template](https://docs.snowflake.com/user-guide/cleanrooms/last-touch-template).
# Add values for placeholders where indicated.

import random
import hashlib
import base64
import pandas as pd
from datetime import datetime, timedelta

import snowflake.snowpark

# Set DATABASE_NAME and SCHEMA_NAME to a database and schema where you have write privileges.
DATABASE_NAME = "<database_name>"
SCHEMA_NAME = "<schema_name>"
# Names of the two tables that main() writes into DATABASE_NAME.SCHEMA_NAME.
CONSUMER_TABLE = "LTA_CONSUMER_TOUCHPOINTS"
PROVIDER_TABLE = "LTA_PROVIDER_CONVERSIONS"

# Number of distinct users (hashed phone values) that touchpoints are generated for.
NUM_USERS = 10000
# Each user gets between 1 and this many touchpoints (inclusive).
MAX_TOUCHPOINTS_PER_USER = 8
# Probability that a given user gets a conversion row.
CONVERSION_RATE = 0.4
# Reference time, captured once when this module is loaded; each user's click
# timeline is anchored 30-90 days before it.
BASE_TIME = datetime.now()
# Values randomly assigned to the CHANNEL column of each touchpoint.
CHANNELS = ['Email', 'Social', 'Search', 'Display', 'Affiliate']

def generate_hashed_phone(unique_id):
    """
    Return a deterministic base64 string that stands in for a hashed phone number.

    The result is the base64-encoded SHA-256 digest of a fixed salt joined to
    ``unique_id`` with an underscore, so the same ``unique_id`` always maps to
    the same string.
    """
    salt = "my_super_secret_salt_for_samooha_lta_gen"
    hasher = hashlib.sha256()
    hasher.update(f"{salt}_{unique_id}".encode('utf-8'))
    return base64.b64encode(hasher.digest()).decode('utf-8')

def generate_data():
    """
    Generate related touchpoint (consumer) and conversion (provider) sample data.

    For each of NUM_USERS users, builds between 1 and MAX_TOUCHPOINTS_PER_USER
    touchpoints whose click times strictly increase, anchored 30-90 days before
    BASE_TIME. Each user then converts with probability CONVERSION_RATE; a
    conversion is timestamped 1-24 hours plus 1-60 minutes after that user's
    last click.

    Returns:
        A ``(consumer_df, provider_df)`` tuple of pandas DataFrames.
        ``consumer_df`` has columns HASHED_PHONE, CLICK_TIME, CHANNEL;
        ``provider_df`` has columns HASHED_PHONE, TRANSACTION_TIME,
        AMOUNT_SPENT. The time columns hold ISO-format strings. Both frames
        carry their columns even when they contain no rows.
    """
    # Declaring the columns up front keeps the schema stable when a list is
    # empty (NUM_USERS == 0, or no user happened to convert); without it the
    # frame would have no columns and the CLICK_TIME lookup below would fail.
    touchpoint_columns = ['HASHED_PHONE', 'CLICK_TIME', 'CHANNEL']
    conversion_columns = ['HASHED_PHONE', 'TRANSACTION_TIME', 'AMOUNT_SPENT']

    print(f"Generating data for {NUM_USERS} users...")
    touchpoints_list = []

    for i in range(NUM_USERS):
        user_id = generate_hashed_phone(i)
        num_touchpoints = random.randint(1, MAX_TOUCHPOINTS_PER_USER)

        # Starting point of this user's timeline; every touchpoint moves it forward.
        click_time = BASE_TIME - timedelta(days=random.randint(30, 90))

        for _ in range(num_touchpoints):
            click_time = click_time + timedelta(hours=random.randint(1, 48))
            touchpoints_list.append({
                'HASHED_PHONE': user_id,
                'CLICK_TIME': click_time,
                'CHANNEL': random.choice(CHANNELS)
            })

    consumer_df = pd.DataFrame(touchpoints_list, columns=touchpoint_columns)
    print(f"Generated {len(consumer_df)} total touchpoints.")

    # Latest click per user; only the timestamp is needed to place the conversion.
    last_clicks = (
        consumer_df.sort_values('CLICK_TIME')
        .groupby('HASHED_PHONE')['CLICK_TIME']
        .last()
    )

    conversions_list = []
    for user_id, last_click_time in last_clicks.items():
        if random.random() < CONVERSION_RATE:
            # The conversion always happens after the user's final touchpoint.
            transaction_time = last_click_time + timedelta(
                hours=random.randint(1, 24),
                minutes=random.randint(1, 60)
            )

            conversions_list.append({
                'HASHED_PHONE': user_id,
                'TRANSACTION_TIME': transaction_time,
                'AMOUNT_SPENT': random.randint(50, 1500)
            })

    provider_df = pd.DataFrame(conversions_list, columns=conversion_columns)
    print(f"Generated {len(provider_df)} total conversions.")

    print("Converting datetime columns to ISO string format for Snowflake ingestion...")
    if not consumer_df.empty:
        consumer_df['CLICK_TIME'] = consumer_df['CLICK_TIME'].apply(lambda x: x.isoformat())

    if not provider_df.empty:
        provider_df['TRANSACTION_TIME'] = provider_df['TRANSACTION_TIME'].apply(lambda x: x.isoformat())

    return consumer_df, provider_df

def main(session: snowflake.snowpark.Session):
    """
    Worksheet entry point: generate the sample data and write both tables.

    The touchpoint frame is written to CONSUMER_TABLE and the conversion frame
    to PROVIDER_TABLE, both in DATABASE_NAME.SCHEMA_NAME, using
    ``write_pandas`` with ``auto_create_table=True`` and ``overwrite=True``.

    Args:
        session: Snowpark session used for both ``write_pandas`` calls.

    Returns:
        A success message naming both tables, or a ``"Failed with error: ..."``
        string if any exception is raised (the exception is printed rather
        than re-raised).
    """

    def _store(frame, table_name):
        # Both tables are written with the same options; only data and name differ.
        session.write_pandas(
            frame,
            table_name,
            database=DATABASE_NAME,
            schema=SCHEMA_NAME,
            auto_create_table=True,
            overwrite=True,
        )

    try:
        touchpoints, conversions = generate_data()

        # Consumer side: the touchpoint table.
        print(f"Writing consumer data to {DATABASE_NAME}.{SCHEMA_NAME}.{CONSUMER_TABLE}...")
        _store(touchpoints, CONSUMER_TABLE)
        print(f"Success! Created {CONSUMER_TABLE}.")

        # Provider side: the conversion table.
        print(f"Writing provider data to {DATABASE_NAME}.{SCHEMA_NAME}.{PROVIDER_TABLE}...")
        _store(conversions, PROVIDER_TABLE)
        print(f"Success! Created {PROVIDER_TABLE}.")

        # The returned string is what the worksheet shows as its result.
        return f"Successfully created tables: {DATABASE_NAME}.{SCHEMA_NAME}.{CONSUMER_TABLE} and {DATABASE_NAME}.{SCHEMA_NAME}.{PROVIDER_TABLE}"

    except Exception as e:
        # Report the failure through the return value instead of raising.
        print(f"\n--- ERROR during execution ---")
        print(e)
        return f"Failed with error: {e}"

