# Generates sample data for the inventory forecasting collaboration template.
# Creates two tables, one used by the publisher, one used by the advertiser.
# Run this in both accounts to try out the sample SQL worksheets.
# Upload this Python notebook into your publisher and advertiser accounts to generate sample data.
# Set 'Handler = main' (the function defined at the bottom of this file) and 'Return type = String' in the worksheet settings.
# For details about the features and use cases of this template, see the Snowflake documentation
# topic: https://docs.snowflake.com/user-guide/cleanrooms/collab-inventory-forecasting
# Replace values in angle brackets with your own values.

import pandas as pd
import numpy as np
from datetime import datetime, timedelta

import snowflake.snowpark

# Set DATABASE_NAME and SCHEMA_NAME to a database and schema where you have write privileges.
DATABASE_NAME = "<database_name>"
SCHEMA_NAME = "<schema_name>"
# Names of the two tables that main() creates (or overwrites) in DATABASE_NAME.SCHEMA_NAME.
PUBLISHER_TABLE = "INVENTORY_PUBLISHER_SALES_HISTORY"
ADVERTISER_TABLE = "INVENTORY_ADVERTISER_STOCK_LEVELS"

# Size of the generated data set. The publisher table gets
# NUM_PRODUCTS * NUM_STORES * HISTORY_DAYS rows; the advertiser table gets
# NUM_PRODUCTS * NUM_STORES rows.
NUM_PRODUCTS = 500  # product IDs PROD_0001 .. PROD_<NUM_PRODUCTS>
NUM_STORES = 10  # store IDs STORE_001 .. STORE_<NUM_STORES>
HISTORY_DAYS = 365 * 2  # consecutive daily records, ending the day before the run date
# Mean of the normal distribution (standard deviation 10) used for daily unit sales.
# Current inventory is drawn around 7x this value (standard deviation 20).
BASE_SALES_MEAN = 50
# Multiplier applied to a day's sales when the day is a promotion day.
PROMO_LIFT_MULTIPLIER = 3.5
# Multiplier applied to a day's sales on Saturday/Sunday when the day is not a promotion day.
WEEKEND_LIFT_MULTIPLIER = 1.5

def generate_data(
    *,
    num_products=None,
    num_stores=None,
    history_days=None,
    base_sales_mean=None,
    promo_lift=None,
    weekend_lift=None,
):
    """Generate related publisher sales history and advertiser stock levels.

    Every argument is optional and keyword-only; an argument left as ``None``
    falls back to the matching module-level constant, so ``generate_data()``
    produces the same kind of data set as before.

    Args:
        num_products: Number of product IDs (``PROD_0001`` ...). Defaults to
            ``NUM_PRODUCTS``.
        num_stores: Number of store IDs (``STORE_001`` ...). Defaults to
            ``NUM_STORES``.
        history_days: Number of consecutive daily sales records per
            product/store pair, ending the day before today. Defaults to
            ``HISTORY_DAYS``.
        base_sales_mean: Mean of the normal distribution (standard deviation
            10) for daily unit sales; current inventory is drawn around seven
            times this value (standard deviation 20). Defaults to
            ``BASE_SALES_MEAN``.
        promo_lift: Sales multiplier on promotion days (the 20th of each
            month). Defaults to ``PROMO_LIFT_MULTIPLIER``.
        weekend_lift: Sales multiplier on Saturdays and Sundays that are not
            promotion days. Defaults to ``WEEKEND_LIFT_MULTIPLIER``.

    Returns:
        A ``(publisher_df, advertiser_df)`` tuple of pandas DataFrames.
        ``publisher_df`` has one row per product, store and day with columns
        PRODUCT_ID, STORE_ID, SALES_DATE (ISO date string), UNITS_SOLD and
        WAS_ON_PROMOTION. ``advertiser_df`` has one row per product and store
        with columns PRODUCT_ID, STORE_ID, CURRENT_INVENTORY and
        UPCOMING_PROMOTION_DATE (ISO date string, 14 days from today).
    """
    num_products = NUM_PRODUCTS if num_products is None else num_products
    num_stores = NUM_STORES if num_stores is None else num_stores
    history_days = HISTORY_DAYS if history_days is None else history_days
    base_sales_mean = BASE_SALES_MEAN if base_sales_mean is None else base_sales_mean
    promo_lift = PROMO_LIFT_MULTIPLIER if promo_lift is None else promo_lift
    weekend_lift = WEEKEND_LIFT_MULTIPLIER if weekend_lift is None else weekend_lift

    print(f"Generating for {num_products} products across {num_stores} stores...")

    product_ids = [f'PROD_{i:04d}' for i in range(1, num_products + 1)]
    store_ids = [f'STORE_{i:03d}' for i in range(1, num_stores + 1)]
    num_pairs = len(product_ids) * len(store_ids)

    # One entry per (product, store) pair, product-major: every store of the
    # first product, then every store of the second product, and so on.
    pair_products = np.repeat(np.array(product_ids, dtype=object), len(store_ids))
    pair_stores = np.tile(np.array(store_ids, dtype=object), len(product_ids))

    # Read the clock once so both tables are anchored to the same day even if
    # the run crosses midnight.
    today = datetime.now().date()

    print(f"Generating {history_days} days of sales history...")
    base_date = today - timedelta(days=history_days)
    date_range = [base_date + timedelta(days=x) for x in range(history_days)]
    num_days = len(date_range)

    # Per-day attributes are the same for every pair, so compute them once.
    iso_dates = np.array([d.isoformat() for d in date_range], dtype=object)
    on_promo = np.array([d.day % 20 == 0 for d in date_range], dtype=bool)
    on_weekend = np.array([d.weekday() >= 5 for d in date_range], dtype=bool)
    # The promotion lift takes precedence over the weekend lift.
    daily_lift = np.where(on_promo, promo_lift, np.where(on_weekend, weekend_lift, 1.0))

    # Draw every day's sales in one call (rows = pairs, columns = days)
    # instead of one Python-level draw, dict and DataFrame per record.
    # astype(int64) truncates toward zero exactly like int() does.
    raw_sales = np.random.normal(base_sales_mean, 10, size=(num_pairs, num_days))
    base_units = np.maximum(raw_sales.astype(np.int64), 0)
    units_sold = (base_units * daily_lift).astype(np.int64)

    publisher_df = pd.DataFrame({
        'PRODUCT_ID': np.repeat(pair_products, num_days),
        'STORE_ID': np.repeat(pair_stores, num_days),
        'SALES_DATE': np.tile(iso_dates, num_pairs),
        'UNITS_SOLD': units_sold.ravel(),
        'WAS_ON_PROMOTION': np.tile(on_promo, num_pairs),
    })
    print(f"Generated {len(publisher_df)} historical sales records.")

    print("Generating current advertiser inventory and promo plan...")
    promo_date = today + timedelta(days=14)
    raw_inventory = np.random.normal(base_sales_mean * 7, 20, size=num_pairs)
    # Clamp at zero, as is done for UNITS_SOLD: a negative stock level is
    # meaningless, and could occur if base_sales_mean is configured small.
    current_inventory = np.maximum(raw_inventory.astype(np.int64), 0)

    advertiser_df = pd.DataFrame({
        'PRODUCT_ID': pair_products,
        'STORE_ID': pair_stores,
        'CURRENT_INVENTORY': current_inventory,
        'UPCOMING_PROMOTION_DATE': np.full(num_pairs, promo_date.isoformat(), dtype=object),
    })
    print(f"Generated {len(advertiser_df)} advertiser inventory records.")

    return publisher_df, advertiser_df

def main(session: snowflake.snowpark.Session) -> str:
    """Worksheet handler: generate the sample data and write both tables.

    Creates (or overwrites) PUBLISHER_TABLE and ADVERTISER_TABLE in
    DATABASE_NAME.SCHEMA_NAME using ``session.write_pandas``.

    Args:
        session: The Snowpark session used to write the tables.

    Returns:
        A success message naming both tables, or ``"Failed with error: ..."``
        if anything went wrong. Errors are reported through the return value
        rather than raised.
    """
    try:
        # Fail fast if the placeholders were never replaced; otherwise the
        # full data set is generated before the write fails on the bad name.
        for setting_name, value in (
            ("DATABASE_NAME", DATABASE_NAME),
            ("SCHEMA_NAME", SCHEMA_NAME),
        ):
            if value.startswith("<") and value.endswith(">"):
                raise ValueError(
                    f"{setting_name} is still set to the placeholder {value!r}; "
                    "replace it with a real name before running."
                )

        publisher_df, advertiser_df = generate_data()

        # Both tables are written the same way; only the label, table name
        # and DataFrame differ.
        for role, table_name, frame in (
            ("publisher", PUBLISHER_TABLE, publisher_df),
            ("advertiser", ADVERTISER_TABLE, advertiser_df),
        ):
            print(f"Writing {role} data to {DATABASE_NAME}.{SCHEMA_NAME}.{table_name}...")
            session.write_pandas(
                frame,
                table_name,
                database=DATABASE_NAME,
                schema=SCHEMA_NAME,
                auto_create_table=True,
                overwrite=True
            )
            print(f"Success! Created {table_name}.")

        return f"Successfully created tables: {DATABASE_NAME}.{SCHEMA_NAME}.{PUBLISHER_TABLE} and {DATABASE_NAME}.{SCHEMA_NAME}.{ADVERTISER_TABLE}"

    except Exception as e:
        # Top-level boundary of the worksheet: surface the failure in the
        # output and the returned string instead of raising.
        print("\n--- ERROR during Snowpark execution ---")
        print(e)
        return f"Failed with error: {e}"
