"""Test program_v3 against all datasets, compare with v2 scores."""
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[1] / "src"))
sys.path.append(str(Path(__file__).resolve().parents[1]))

import numpy as np
import time
from local_scoring import read_hdf5, scoring_function

# Baseline v2 aggregate scores recorded in an earlier run, keyed by the dataset
# name derived from each mixes*.h5 file stem. Datasets absent from this table
# have no baseline to compare against.
V2_SCORES = dict(
    insilicodirichletNoDep4CTsource=0.8284,
    insilicodirichletNoDep=0.7422,
    insilicopseudobulk=0.7276,
    invitro=0.8282,
    invivo=0.7957,
)

# Inputs (reference, mixtures, ground truths) are read from the "data"
# directory that sits beside the directory containing this script.
data_dir = Path(__file__).resolve().parents[1] / "data"

print("Loading reference...")
ref_data = read_hdf5(data_dir / "ref.h5")

# The program under test is imported only after the reference has been loaded.
from program_v4 import program

# Per-dataset results accumulate here; mixture files are visited in sorted order.
results = {}
mix_files = list(data_dir.glob("mixes*.h5"))
mix_files.sort()

# Label for the program under test, derived from the module `program` was
# imported from (e.g. "program_v4" -> "v4"), so printed labels follow the
# import instead of going stale when it is bumped.
CANDIDATE = (getattr(program, "__module__", None) or "candidate").rsplit("_", 1)[-1]


def _dataset_name(stem, prefixes):
    """Strip each of *prefixes*, then the ``_pdac`` suffix, from file *stem*."""
    for prefix in prefixes:
        stem = stem.replace(prefix, "")
    return stem.replace("_pdac", "")


def _find_groundtruth(name):
    """Return the ground-truth file for dataset *name*, or None if there is none.

    ``groundtruth*{name}*.h5`` is a substring glob: a short dataset name such as
    ``insilicodirichletNoDep`` also matches the file of a longer one
    (``insilicodirichletNoDep4CTsource``), and glob result order is not
    specified. Prefer the file whose stem normalizes to exactly *name*; else
    fall back deterministically to the shortest (most specific) candidate.
    """
    candidates = sorted(data_dir.glob(f"groundtruth*{name}*.h5"))
    for path in candidates:
        if _dataset_name(path.stem, ("groundtruth1_", "groundtruth_")) == name:
            return path
    if not candidates:
        return None
    return min(candidates, key=lambda p: (len(p.stem), p.name))


# Score the candidate program on every mixture that has a ground truth.
for mix_file in mix_files:
    name = _dataset_name(mix_file.stem, ("mixes1_", "mixes_"))
    gt_file = _find_groundtruth(name)
    if gt_file is None:
        continue

    print(f"\n{'='*60}")
    print(f"Dataset: {name}")

    # Timing covers loading the mixture file as well as running the program.
    t0 = time.perf_counter()
    mix_data = read_hdf5(mix_file)

    pred = program(
        mix_rna=mix_data.get("mix_rna"),
        ref_bulkRNA=ref_data.get("ref_bulkRNA"),
        mix_met=mix_data.get("mix_met"),
        ref_met=ref_data.get("ref_met"),
    )
    elapsed = time.perf_counter() - t0

    gt_data = read_hdf5(gt_file)
    gt_key = 'groundtruth' if 'groundtruth' in gt_data else list(gt_data.keys())[0]
    gt = gt_data[gt_key]

    print(f"  GT index: {list(gt.index)}")
    print(f"  Pred index: {list(pred.index)}")

    # Score only the cell types (rows) present in both GT and prediction.
    common_ct = [ct for ct in gt.index if ct in pred.index]
    if not common_ct:
        print("  ERROR: No common cell types found! Skipping...")
        continue
    missing_ct = [ct for ct in gt.index if ct not in pred.index]
    if missing_ct:
        # Surface GT rows that would otherwise be dropped from scoring silently.
        print(f"  WARNING: GT cell types missing from prediction (not scored): {missing_ct}")

    A_real = gt.loc[common_ct].values
    A_pred = pred.loc[common_ct].values

    # Renormalize each column (presumably one sample) of the restricted
    # prediction to sum to 1; all-zero columns stay zero instead of dividing
    # by zero.
    col_sums = A_pred.sum(axis=0, keepdims=True)
    col_sums[col_sums == 0] = 1
    A_pred = A_pred / col_sums

    # In vivo datasets are scored with partial_gt=True.
    is_partial = 'invivo' in name.lower()
    scores = scoring_function(A_real, A_pred, partial_gt=is_partial)

    score = scores['score_aggreg']
    baseline = V2_SCORES.get(name)  # None when no v2 score was recorded

    print(f"  Time: {elapsed:.1f}s | Shape: {pred.shape}")
    if baseline is None:
        print(f"  Score: {score:.4f} (v2=n/a, no baseline recorded)")
    else:
        print(f"  Score: {score:.4f} (v2={baseline:.4f}, diff={score - baseline:+.4f})")

    if not is_partial:
        print(f"  RMSE={scores['rmse']:.4f}  MAE={scores['mae']:.4f}  "
              f"Aitchison={scores['aitchison']:.4f}")
        print(f"  Pearson: tot={scores['pearson_tot']:.4f}  "
              f"col={scores['pearson_col']:.4f}  row={scores['pearson_row']:.4f}")
        print(f"  Spearman: tot={scores['spearman_tot']:.4f}  "
              f"col={scores['spearman_col']:.4f}  row={scores['spearman_row']:.4f}")
    else:
        print(f"  Pearson row={scores['pearson_row']:.4f}  "
              f"Spearman row={scores['spearman_row']:.4f}")

    results[name] = {'candidate': score, 'v2': baseline}

# Summary
print(f"\n{'='*60}")
print(f"COMPARISON: v2 -> {CANDIDATE}")
print(f"{'='*60}")
for name, r in sorted(results.items()):
    if r['v2'] is None:
        print(f"  {name:<40}:    n/a -> {r['candidate']:.4f}  (no v2 baseline)")
        continue
    delta = r['candidate'] - r['v2']
    arrow = "+" if delta > 0 else "-" if delta < 0 else "="
    print(f"  {name:<40}: {r['v2']:.4f} -> {r['candidate']:.4f}  ({delta:+.4f} {arrow})")

# Aggregate only over datasets that have a v2 baseline, so a missing baseline
# does not drag the v2 statistics toward 0 (and an empty run does not print nan).
paired = [r for r in results.values() if r['v2'] is not None]
if paired:
    v2_all = [r['v2'] for r in paired]
    new_all = [r['candidate'] for r in paired]
    print(f"\n  v2 median: {np.median(v2_all):.4f}  mean: {np.mean(v2_all):.4f}")
    print(f"  {CANDIDATE} median: {np.median(new_all):.4f}  mean: {np.mean(new_all):.4f}")
    print(f"  Median change: {np.median(new_all) - np.median(v2_all):+.4f}")
else:
    print("\n  No scored datasets with a v2 baseline; nothing to aggregate.")
if len(paired) < len(results):
    print(f"  ({len(results) - len(paired)} dataset(s) without a v2 baseline "
          f"excluded from the aggregates)")
