"""
Test program_v2 locally against all datasets with ground truth.
Uses the official data_processing module for correct HDF5 loading.
"""
import sys
from pathlib import Path
# Make the directory one level above this script's folder, and its "src"
# subfolder, importable. These are appended, so they are searched after every
# entry already on sys.path.
sys.path.append(str(Path(__file__).resolve().parents[1] / "src"))
sys.path.append(str(Path(__file__).resolve().parents[1]))

import numpy as np
import pandas as pd
import h5py
import sys
import time
from pathlib import Path

# Prepend this script's own folder and the bootcamp ingestion_program folder.
# Both use insert(0, ...), so the ingestion_program path (inserted last) ends
# up first on sys.path and wins any module-name clash.
sys.path.insert(0, str(Path(__file__).parent))
sys.path.insert(0, str(Path(__file__).parent.parent / "hadaca3" / "Hadaca3_bootcamp" / "ingestion_program"))

# NOTE(review): presumably data_processing lives in ingestion_program and
# local_scoring sits beside this script -- confirm against the repo layout.
import data_processing as dp
from local_scoring import read_hdf5, scoring_function

# Import program from program_v2.
# The module is loaded by explicit file path (program_v2.py next to this
# script) so the code under test is selected by location, not by sys.path order.
import importlib.util

program_path = Path(__file__).parent / "program_v2.py"
spec = importlib.util.spec_from_file_location("program_v2", program_path)
# spec_from_file_location may return None and a spec may carry no loader;
# fail with a clear message instead of an AttributeError on None.
if spec is None or spec.loader is None:
    raise ImportError(f"Cannot build an import spec for {program_path}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
# Entry point under test; invoked once per dataset further down.
program = mod.program

# Every HDF5 input (reference, mixes, ground truth) is looked up under
# "data", a relative path resolved against the current working directory.
data_dir = Path("data")

# Read the reference through data_processing's HDF5 reader (the original note
# says this mirrors what the ingestion program does) and pull out both tables.
print("Loading reference data...")
reference_data = dp.read_hdf5(str(data_dir / "ref.h5"))
ref_bulkRNA, ref_met = (reference_data[key] for key in ("ref_bulkRNA", "ref_met"))

# Sanity dump: shapes, then the first five row and column labels of the RNA
# reference so the orientation can be checked by eye.
print(f"Reference RNA: {ref_bulkRNA.shape}, MET: {ref_met.shape}")
print(f"ref_bulkRNA index (genes?): {list(ref_bulkRNA.index[:5])}...")
print(f"ref_bulkRNA columns (cell types?): {list(ref_bulkRNA.columns[:5])}...")

# Find mix files
mix_files = sorted(data_dir.glob("mixes*.h5"))

# Scores per dataset, keyed by the shortened dataset name; read by the summary.
all_scores = {}

# Metric rows in display order: (label, key in the scores dict, full-GT only).
# Rows flagged full-GT only are not printed when the ground truth is partial.
metric_rows = (
    ("RMSE", "rmse", True),
    ("MAE", "mae", True),
    ("Aitchison", "aitchison", True),
    ("Pearson tot", "pearson_tot", True),
    ("Pearson col", "pearson_col", True),
    ("Pearson row", "pearson_row", False),
    ("Spearman tot", "spearman_tot", True),
    ("Spearman col", "spearman_col", True),
    ("Spearman row", "spearman_row", False),
)

for mix_file in mix_files:
    # Short dataset name: file stem without the "mixes1_" / "_pdac" markers.
    name_parts = mix_file.stem.replace("mixes1_", "").replace("_pdac", "")

    gt_name = mix_file.name.replace("mixes1_", "groundtruth1_")
    gt_file = data_dir / gt_name
    # The glob accepts any "mixes*.h5". If the name has no "mixes1_" part the
    # replace is a no-op and gt_file would be the mix file itself, which
    # exists -- so that case must be treated as "no ground truth" explicitly.
    if gt_name == mix_file.name or not gt_file.exists():
        print(f"\nNo ground truth for {name_parts}, skipping")
        continue

    print(f"\n{'='*60}")
    print(f"Dataset: {name_parts}")
    print(f"{'='*60}")

    # Load mix data using same method as ingestion
    mixes_data = dp.read_hdf5(str(mix_file))
    mix_rna = mixes_data["mix_rna"]
    mix_met_data = mixes_data.get("mix_met")  # None when the file has no methylation mix

    print(f"  Mix RNA: {mix_rna.shape}")
    if mix_met_data is not None:
        print(f"  Mix MET: {mix_met_data.shape}")

    # Run program (same invocation as ingestion). The references are copied
    # so the program cannot alter them for the next dataset.
    start_time = time.time()
    pred = program(mix_rna, ref_bulkRNA.copy(), mix_met=mix_met_data, ref_met=ref_met.copy())
    elapsed = time.time() - start_time
    print(f"  Time: {elapsed:.2f}s")
    print(f"  Prediction shape: {pred.shape}")
    print(f"  Prediction index: {list(pred.index)}")

    # Load ground truth; fall back to the first stored entry when there is no
    # "groundtruth" key.
    gt_data = read_hdf5(gt_file)
    if 'groundtruth' in gt_data:
        gt_df = gt_data['groundtruth']
    else:
        gt_df = gt_data[next(iter(gt_data))]

    print(f"  GT index: {list(gt_df.index)}")

    # Align rows on the labels present in both tables, in ground-truth order.
    common_ct = [ct for ct in gt_df.index if ct in pred.index]
    if not common_ct:
        # Nothing to compare; scoring empty matrices would fail or give NaN.
        print("  No cell types shared between prediction and ground truth, skipping")
        continue
    # NOTE(review): columns are compared positionally (.values); this assumes
    # pred and gt_df list the samples in the same order -- confirm.
    A_real = gt_df.loc[common_ct].values
    A_pred = pred.loc[common_ct].values

    # Renormalize each column over the retained rows so it sums to 1;
    # all-zero columns are left as zeros (divide by 1 instead of 0).
    col_sums = A_pred.sum(axis=0, keepdims=True)
    col_sums[col_sums == 0] = 1
    A_pred = A_pred / col_sums

    # Partial ground truth: in-vivo datasets, or only basal/classic available.
    is_partial_gt = 'invivo' in name_parts.lower() or set(common_ct) == {'basal', 'classic'}

    scores = scoring_function(A_real, A_pred, partial_gt=is_partial_gt)
    all_scores[name_parts] = scores

    print("\n  Metrics:")
    for label, key, full_only in metric_rows:
        if full_only and is_partial_gt:
            continue
        heading = label + ":"
        # Label column is 15 characters wide so the values line up.
        print(f"    {heading:<15}{scores[key]:.4f}")
    print(f"\n    AGGREGATE SCORE: {scores['score_aggreg']:.4f}")

# Summary: compare each dataset's aggregate score with the recorded v1 value.
if all_scores:
    print(f"\n{'='*60}")
    print("COMPARISON: v1 -> v2")
    print(f"{'='*60}")

    # Aggregate scores recorded for v1, keyed by the shortened dataset name.
    v1_scores = {
        'insilicodirichletNoDep4CTsource': 0.7757,
        'insilicodirichletNoDep': 0.6915,
        'insilicopseudobulk': 0.7125,
        'invitro': 0.7971,
        'invivo': 0.7890,
    }

    for name, scores in all_scores.items():
        v1 = v1_scores.get(name, float('nan'))  # NaN when v1 has no entry
        v2 = scores['score_aggreg']
        diff = v2 - v1
        # A NaN diff (no v1 baseline, or NaN v2 score) compares False with
        # everything, so it must not be reported as "=" (no change).
        if np.isnan(diff):
            arrow = "?"
        elif diff > 0:
            arrow = "↑"
        elif diff < 0:
            arrow = "↓"
        else:
            arrow = "="
        print(f"  {name:<40}: {v1:.4f} -> {v2:.4f}  ({diff:+.4f} {arrow})")

    # Only finite aggregate scores enter the v2 median / mean.
    agg_scores = [s['score_aggreg'] for s in all_scores.values() if not np.isnan(s['score_aggreg'])]
    # Derived from the table above so the two cannot drift apart (0.7757 today).
    # It covers all v1 datasets, even ones not evaluated in this run.
    v1_median = float(np.median(list(v1_scores.values())))
    if agg_scores:
        v2_median = np.median(agg_scores)
        print(f"\n  Median: {v1_median:.4f} -> {v2_median:.4f}  ({v2_median - v1_median:+.4f})")
        print(f"  Mean:   {np.mean(agg_scores):.4f}")
    else:
        # np.median / np.mean on an empty list warn and return NaN.
        print("\n  No finite aggregate scores to summarise")
