"""Retail media campaign scoring for ML Jobs in Data Clean Rooms.

This script loads the trained measurement model from the cleanroom table
and scores the full shopper population, producing per-user campaign impact
scores (the difference in predicted purchase probability with and without
ad exposure).

Usage:
  Staged to a Snowflake internal stage and referenced in the ML Jobs code spec.
  Not intended to be run directly.
"""

import argparse
import json
import codecs
import pickle

import numpy as np
import pandas as pd

from snowflake.snowpark.context import get_active_session


_MODEL_TABLE = "cleanroom.rmn_measurement_model"
_RESULTS_TABLE = "cleanroom.rmn_scored_results"


def _resolve_source_table(args):
    """Return the exposure table name taken from the parsed job arguments.

    ``args["source_table"]`` is expected to be a non-empty list of table
    names; a bare string is also accepted and treated as a single table.
    Only the first entry is used.

    Raises:
        ValueError: If ``args`` is not a JSON object or names no table.
    """
    if not isinstance(args, dict):
        raise ValueError(
            "--args must be a JSON object containing a source_table entry, "
            f"got {type(args).__name__}"
        )

    source_tables = args.get("source_table") or []
    if isinstance(source_tables, str):
        # Indexing a bare string would silently pick its first character.
        source_tables = [source_tables]
    if not isinstance(source_tables, (list, tuple)) or not source_tables:
        raise ValueError(
            "No source_table provided; expected --args "
            '\'{"source_table": ["<exposure_table>"]}\''
        )
    return source_tables[0]


def _load_model_package(session):
    """Read and unpickle the model package stored in the model table.

    The first row's ``MODEL_DATA`` column is decoded from base64 text and
    unpickled.

    Raises:
        ValueError: If the model table has no rows.
    """
    model_rows = session.table(_MODEL_TABLE).to_pandas()
    if model_rows.empty:
        raise ValueError(f"{_MODEL_TABLE} contains no trained model")

    model_data = model_rows.iloc[0]["MODEL_DATA"]
    # SECURITY: pickle.loads executes arbitrary code embedded in the payload.
    # This is only acceptable while the model table is written exclusively by
    # a trusted job; do not point this at data other parties can modify.
    return pickle.loads(codecs.decode(model_data.encode(), "base64"))


def _summarize_impact(campaign_impact):
    """Return ``(positive_count, mean_impact, top_decile_mean)``.

    ``campaign_impact`` must be a non-empty 1-D numpy array. The top decile
    holds ``ceil(n / 10)`` scores, so it always contains at least one value.
    """
    total = len(campaign_impact)
    positive_count = int((campaign_impact > 0).sum())
    top_n = (total + 9) // 10  # ceil(total / 10) without float arithmetic
    top_decile = np.sort(campaign_impact)[-top_n:]
    return (
        positive_count,
        float(np.mean(campaign_impact)),
        float(np.mean(top_decile)),
    )


def main():
    """Score every shopper in the exposure table and persist the results.

    Reads a JSON object from ``--args``; its ``source_table`` entry names the
    exposure table (the first entry is used). The stored model package
    supplies two classifiers (``model_exposed`` / ``model_control``) and the
    feature column list. Campaign impact is the difference between the two
    predicted positive-class probabilities. Results overwrite
    ``cleanroom.rmn_scored_results`` and a JSON summary is printed last.

    Raises:
        ValueError: If no source table is supplied, the model table is empty,
            or the exposure table has no rows.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--args", type=str, default="{}")
    parsed = parser.parse_args()

    args = json.loads(parsed.args) if parsed.args else {}
    # Validate arguments before touching Snowflake so bad input fails fast
    # instead of after the model has been downloaded and unpickled.
    source_table = _resolve_source_table(args)
    session = get_active_session()

    print("Loading trained model from cleanroom.rmn_measurement_model...")
    model_package = _load_model_package(session)

    model_exposed = model_package["model_exposed"]
    model_control = model_package["model_control"]
    feature_cols = model_package["feature_cols"]

    print(f"Model loaded. Best AUUC: {model_package['best_auuc']:.4f}")
    print(f"Features: {feature_cols}")

    print(f"Loading campaign exposures from: {source_table}")
    exposure_df = session.table(source_table).to_pandas()
    if exposure_df.empty:
        # Without rows the summary statistics are undefined (NaN / division
        # by zero) and the JSON summary would not be valid JSON.
        raise ValueError(f"No shoppers to score: {source_table} is empty")
    print(f"Scoring {len(exposure_df)} shoppers...")

    # Non-numeric or missing feature values are scored as 0.
    X = (
        exposure_df[feature_cols]
        .apply(pd.to_numeric, errors="coerce")
        .fillna(0)
        .values
    )

    p_exposed = model_exposed.predict_proba(X)[:, 1]
    p_control = model_control.predict_proba(X)[:, 1]
    campaign_impact = p_exposed - p_control

    results = pd.DataFrame({
        "USER_ID": exposure_df["USER_ID"].values,
        "CAMPAIGN_IMPACT": np.round(campaign_impact, 4),
        "P_PURCHASE_EXPOSED": np.round(p_exposed, 4),
        "P_PURCHASE_CONTROL": np.round(p_control, 4),
    })

    result_df = session.create_dataframe(results)
    result_df.write.save_as_table(_RESULTS_TABLE, mode="overwrite")

    total_scored = len(results)
    positive_impact, avg_impact, top_decile_impact = _summarize_impact(
        campaign_impact
    )
    positive_pct = 100 * positive_impact / total_scored

    print("Scoring complete:")
    print(f"  Total shoppers scored: {total_scored}")
    print(
        f"  Shoppers with positive campaign impact: {positive_impact} "
        f"({positive_pct:.1f}%)"
    )
    print(f"  Average campaign impact: {avg_impact:.4f}")
    print(f"  Top decile avg impact: {top_decile_impact:.4f}")

    output = {
        "status": "completed",
        "total_scored": total_scored,
        "positive_impact_shoppers": positive_impact,
        "avg_campaign_impact": avg_impact,
        "top_decile_impact": top_decile_impact,
    }
    print(json.dumps(output))


# Run the scoring job only when this file is executed as a script; importing
# the module defines main() without running it.
if __name__ == "__main__":
    main()
