"""Retail media campaign measurement with distributed HPO for ML Jobs in Data Clean Rooms.

This script runs inside a container on a compute pool. It performs a
hyperparameter optimization (HPO) search for a campaign measurement model
(currently a sequential random search within this job; the Snowflake Tuner
API is imported but not yet used), then trains the final XGBoost model with
the best-scoring configuration.

The model estimates the purchase impact of sponsored product campaigns by
training separate classifiers on the exposed and control groups (T-Learner
approach); each shopper's campaign impact score is the difference between the
two models' predicted purchase probabilities.

Usage:
  Staged to a Snowflake internal stage and referenced in the ML Jobs code spec.
  Not intended to be run directly.
"""

import argparse
import json
import codecs
import pickle

import numpy as np
import pandas as pd
import xgboost
from sklearn.model_selection import train_test_split

from snowflake.snowpark import Session
from snowflake.snowpark.context import get_active_session
from snowflake.ml.modeling import tune
from snowflake.ml.modeling.tune import TunerConfig, Tuner, TunerContext
from snowflake.ml.modeling.tune.search import BayesOpt


def compute_auuc(y_true, impact_scores):
    """Compute an Area Under the Uplift Curve (AUUC)-style score.

    Rows are ranked by descending ``impact_scores``. The cumulative sum of
    ``y_true`` along that ranking is compared against a random-ordering
    baseline (a straight line reaching the same total), and the average gap
    is returned. Larger values mean positive outcomes are concentrated among
    the highest-scored rows; an ordering no better than random scores about 0.

    Note: only outcomes are used. The treatment indicator is not taken into
    account, so this measures how well the scores rank positive outcomes
    rather than a treatment-vs-control uplift curve.

    Args:
        y_true: 1-D array-like of outcomes (e.g. 0/1 purchase flags).
        impact_scores: 1-D array-like of scores, same length as ``y_true``.

    Returns:
        The score as a Python float; 0.0 for empty input.

    Raises:
        ValueError: if ``y_true`` and ``impact_scores`` differ in length.
    """
    # asarray lets callers pass plain lists; the original relied on ndarray
    # negation and fancy indexing.
    outcomes = np.asarray(y_true, dtype=float).ravel()
    scores = np.asarray(impact_scores, dtype=float).ravel()
    if outcomes.shape != scores.shape:
        raise ValueError(
            f"y_true and impact_scores must have the same length, "
            f"got {outcomes.size} and {scores.size}"
        )

    n = outcomes.size
    if n == 0:
        # No rows: avoid cumulative[-1] IndexError and division by zero.
        return 0.0

    order = np.argsort(-scores)
    cumulative = np.cumsum(outcomes[order])
    random_baseline = np.arange(1, n + 1) * (cumulative[-1] / n)
    return float(np.sum(cumulative - random_baseline) / n)


def prepare_data(session, source_tables):
    """Read exposures and transactions, then attach per-user purchase stats.

    Args:
        session: Snowpark session whose ``table(...).to_pandas()`` is used to
            load both inputs.
        source_tables: Sequence whose element 0 names the campaign-exposure
            table and element 1 names the shopper-transaction table.

    Returns:
        pandas DataFrame with every exposure row plus PURCHASED (int 0/1),
        TOTAL_PURCHASES (count of non-null AMOUNT values) and TOTAL_SPEND
        (sum of AMOUNT). Users with no transactions get 0 in all three.
    """
    print(f"Loading campaign exposures from: {source_tables[0]}")
    exposures = session.table(source_tables[0]).to_pandas()

    print(f"Loading shopper transactions from: {source_tables[1]}")
    transactions = session.table(source_tables[1]).to_pandas()

    print(f"Campaign exposures: {len(exposures)} rows, Transactions: {len(transactions)} rows")

    # Every transaction row is a purchase, so this flag's per-user max is 1.
    transactions["PURCHASED"] = 1
    stats_by_user = (
        transactions
        .groupby("USER_ID")
        .agg(
            PURCHASED=("PURCHASED", "max"),
            TOTAL_PURCHASES=("AMOUNT", "count"),
            TOTAL_SPEND=("AMOUNT", "sum"),
        )
        .reset_index()
    )

    # Left join keeps shoppers who never purchased; their stats become 0.
    combined = exposures.merge(stats_by_user, on="USER_ID", how="left")
    combined["PURCHASED"] = combined["PURCHASED"].fillna(0).astype(int)
    for stat_col in ("TOTAL_PURCHASES", "TOTAL_SPEND"):
        combined[stat_col] = combined[stat_col].fillna(0)

    print(f"Joined dataset: {len(combined)} shoppers")
    exposed_count = (combined["EXPOSED"] == 1).sum()
    control_count = (combined["EXPOSED"] == 0).sum()
    print(f"  Exposed to campaign: {exposed_count}, Control: {control_count}")
    print(f"  Purchased: {combined['PURCHASED'].sum()}")

    return combined


def train_measurement_model(params, X_exposed, y_exposed, X_control, y_control):
    """Fit the exposed-arm and control-arm classifiers of a T-Learner.

    Both arms share one XGBoost configuration built from ``params``. Any key
    missing from ``params`` falls back to the default shown below.

    Args:
        params: Mapping of hyperparameter overrides (max_depth, learning_rate,
            subsample, colsample_bytree, min_child_weight, gamma, n_estimators).
        X_exposed, y_exposed: Features and labels for the exposed group.
        X_control, y_control: Features and labels for the control group.

    Returns:
        Tuple ``(model_exposed, model_control)`` of fitted XGBClassifier objects.
    """
    lookup = params.get
    shared_kwargs = dict(
        objective="binary:logistic",
        eval_metric="logloss",
        max_depth=int(lookup("max_depth", 6)),
        learning_rate=lookup("learning_rate", 0.1),
        subsample=lookup("subsample", 0.8),
        colsample_bytree=lookup("colsample_bytree", 0.8),
        min_child_weight=int(lookup("min_child_weight", 5)),
        gamma=lookup("gamma", 0.1),
        nthread=2,  # XGBoost thread count per model
        n_estimators=int(lookup("n_estimators", 100)),
    )

    # Each trial trains on a single node. If the dataset is large enough to require
    # multi-node parallelism, replace XGBClassifier with XGBoostDistributor from
    # snowflake.ml.modeling.distributors and set num_instances > 1 in the template.
    fitted = []
    for features, labels in ((X_exposed, y_exposed), (X_control, y_control)):
        classifier = xgboost.XGBClassifier(**shared_kwargs)
        classifier.fit(features, labels, verbose=False)
        fitted.append(classifier)

    model_exposed, model_control = fitted
    return model_exposed, model_control


def _sample_trial_params(rng):
    """Draw one random-search candidate from the XGBoost search space.

    Integer ranges exclude their upper bound. learning_rate is sampled
    log-uniformly in [0.01, 0.3).
    """
    return {
        "max_depth": int(rng.integers(3, 10)),
        "learning_rate": float(np.exp(rng.uniform(np.log(0.01), np.log(0.3)))),
        "subsample": float(rng.uniform(0.6, 1.0)),
        "colsample_bytree": float(rng.uniform(0.6, 1.0)),
        "n_estimators": int(rng.integers(50, 200)),
        "min_child_weight": int(rng.integers(1, 10)),
        "gamma": float(rng.uniform(0.0, 0.5)),
    }


def _rounded(params, ndigits=4):
    """Return a copy of *params* with float values rounded for display/JSON."""
    return {k: round(v, ndigits) if isinstance(v, float) else v for k, v in params.items()}


def main():
    """Load data, run the hyperparameter search, then train and save the final model.

    Reads a JSON object from ``--args`` with keys:
        source_table: list of two table names (campaign exposures, transactions).
        num_trials: number of random-search trials (default 10; must be >= 1).
        random_seed: seed for the search and the train/validation splits
            (default 42, matching the split seed previously hard-coded).
        output_table: table the serialized model is written to
            (default "cleanroom.rmn_measurement_model").
        model_id: MODEL_ID stored alongside the model (default "rmn_model_v1").

    Raises:
        ValueError: if ``--args`` is not a JSON object, fewer than two source
            tables are given, ``num_trials`` < 1, or either treatment arm has
            fewer than 2 rows (too few to split into train/validation).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--args", type=str, default="{}")
    parsed = parser.parse_args()

    args = json.loads(parsed.args) if parsed.args else {}
    if not isinstance(args, dict):
        raise ValueError(f"--args must be a JSON object, got {type(args).__name__}")
    session = get_active_session()

    source_tables = args.get("source_table", [])
    if isinstance(source_tables, str):
        # A bare string would otherwise be indexed character by character.
        source_tables = [source_tables]
    if len(source_tables) < 2:
        raise ValueError(
            f"Expected 2 source tables (campaign exposures + transactions), got {len(source_tables)}"
        )

    n_trials = int(args.get("num_trials", 10))
    if n_trials < 1:
        # With zero trials best_auuc would stay -inf: the final model would be
        # trained on unvalidated defaults and json.dumps would emit -Infinity.
        raise ValueError(f"num_trials must be >= 1, got {n_trials}")
    seed = int(args.get("random_seed", 42))
    output_table = args.get("output_table", "cleanroom.rmn_measurement_model")
    model_id = args.get("model_id", "rmn_model_v1")

    joined = prepare_data(session, source_tables)

    excluded_cols = {"USER_ID", "EXPOSED", "PURCHASED", "TOTAL_PURCHASES", "TOTAL_SPEND"}
    feature_cols = [c for c in joined.columns if c not in excluded_cols]
    # Non-numeric feature values are coerced to NaN and then zero-filled.
    X = joined[feature_cols].apply(pd.to_numeric, errors="coerce").fillna(0).values
    treatment = joined["EXPOSED"].values
    y = joined["PURCHASED"].values

    exposed_mask = treatment == 1
    control_mask = treatment == 0
    for arm_name, arm_mask in (("exposed", exposed_mask), ("control", control_mask)):
        arm_size = int(arm_mask.sum())
        if arm_size < 2:
            # train_test_split(test_size=0.2) needs >= 2 rows to leave a non-empty train set.
            raise ValueError(
                f"Need at least 2 {arm_name} shoppers to build train/validation splits, got {arm_size}"
            )

    X_exp, y_exp = X[exposed_mask], y[exposed_mask]
    X_ctrl, y_ctrl = X[control_mask], y[control_mask]

    X_exp_train, X_exp_val, y_exp_train, y_exp_val = train_test_split(
        X_exp, y_exp, test_size=0.2, random_state=seed
    )
    X_ctrl_train, X_ctrl_val, y_ctrl_train, y_ctrl_val = train_test_split(
        X_ctrl, y_ctrl, test_size=0.2, random_state=seed
    )
    # The pooled validation set is identical for every trial; build it once.
    all_val_X = np.concatenate([X_exp_val, X_ctrl_val])
    all_val_y = np.concatenate([y_exp_val, y_ctrl_val])

    print("\n--- Starting Distributed HPO ---")
    print("Search space: max_depth, learning_rate, subsample, colsample_bytree, n_estimators, min_child_weight, gamma")

    # NOTE: trials run sequentially in this process as a random search. A
    # dedicated Generator keeps the search reproducible and independent of
    # global NumPy random state.
    rng = np.random.default_rng(seed)
    best_auuc = -float("inf")
    best_params = {}

    for trial in range(n_trials):
        trial_params = _sample_trial_params(rng)

        m_exp, m_ctrl = train_measurement_model(
            trial_params, X_exp_train, y_exp_train, X_ctrl_train, y_ctrl_train
        )

        # Impact score: P(purchase | exposed model) - P(purchase | control model).
        impact_val = m_exp.predict_proba(all_val_X)[:, 1] - m_ctrl.predict_proba(all_val_X)[:, 1]
        auuc = compute_auuc(all_val_y, impact_val)

        print(f"  Trial {trial+1}/{n_trials}: AUUC={auuc:.4f} | depth={trial_params['max_depth']}, lr={trial_params['learning_rate']:.4f}, n_est={trial_params['n_estimators']}")

        if auuc > best_auuc:
            best_auuc = auuc
            best_params = trial_params

    display_params = _rounded(best_params)
    print("\n--- HPO Complete ---")
    print(f"Best AUUC: {best_auuc:.4f}")
    print(f"Best params: {json.dumps(display_params)}")

    print("\nTraining final model with best hyperparameters...")
    # The final fit uses all rows of each arm, including the validation split.
    final_model_exp, final_model_ctrl = train_measurement_model(
        best_params, X_exp, y_exp, X_ctrl, y_ctrl
    )

    model_package = {
        "model_exposed": final_model_exp,
        "model_control": final_model_ctrl,
        "feature_cols": feature_cols,
        "best_params": best_params,
        "best_auuc": best_auuc,
    }
    # NOTE(review): the payload is a pickle; whatever loads it back must only
    # read this table from a trusted source, since unpickling can execute code.
    serialized = codecs.encode(pickle.dumps(model_package), "base64").decode()

    model_df = session.create_dataframe(
        [{"MODEL_ID": model_id, "MODEL_DATA": serialized}]
    )
    model_df.write.save_as_table(output_table, mode="overwrite")
    print(f"Model saved to {output_table}")

    result = {
        "status": "completed",
        "num_shoppers": len(joined),
        "num_exposed": int(exposed_mask.sum()),
        "num_control": int(control_mask.sum()),
        "num_trials": n_trials,
        "best_auuc": best_auuc,
        "best_params": display_params,
    }
    print(json.dumps(result))


# Run the job only when this file is executed as a script (as the ML Jobs
# code spec does); importing the module has no side effects beyond definitions.
if __name__ == "__main__":
    main()
