"""
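Run program_v2 on every mixes*.h5 dataset under data/, score the predictions
against the matching ground-truth file, and print a comparison with the
recorded v1 scores.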
"""
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[1] / "src"))
sys.path.append(str(Path(__file__).resolve().parents[1]))

import numpy as np
import pandas as pd
import h5py
import sys
import time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
data_dir = Path("data")

# Read the reference profiles (bulk RNA and methylation) from data/ref.h5
# into two DataFrames that share the same row labels.
print("Loading reference...", flush=True)
with h5py.File(data_dir / "ref.h5", 'r') as f:
    ref_rna_grp = f['ref_bulkRNA']
    ref_met_grp = f['ref_met']

    # Labels may be stored as bytes; turn every label into a str.
    genes = [
        label.decode('utf-8') if isinstance(label, bytes) else str(label)
        for label in ref_rna_grp['genes'][:]
    ]
    cell_types_ref = [
        label.decode('utf-8') if isinstance(label, bytes) else str(label)
        for label in ref_rna_grp['cell_types'][:]
    ]
    cpg_sites = [
        label.decode('utf-8') if isinstance(label, bytes) else str(label)
        for label in ref_met_grp['CpG_sites'][:]
    ]

    rna_data = ref_rna_grp['data'][:]
    met_data = ref_met_grp['data'][:]

    # Both tables are indexed by the cell types stored under 'ref_bulkRNA'.
    ref_bulkRNA = pd.DataFrame(rna_data, index=cell_types_ref, columns=genes)
    ref_met = pd.DataFrame(met_data, index=cell_types_ref, columns=cpg_sites)

from program_v2 import program
from local_scoring import read_hdf5, scoring_function

def _decode_labels(raw_labels):
    """Return *raw_labels* as a list of str, decoding bytes entries as UTF-8."""
    return [
        label.decode('utf-8') if isinstance(label, bytes) else str(label)
        for label in raw_labels
    ]


def _split_samples(values, default_names=None):
    """Return ``(sample_names, matrix)`` with one matrix row per sample.

    A structured array yields one sample per field, named after the field.
    A plain array is returned unchanged; its names are *default_names* when
    given, otherwise ``s0, s1, ...`` generated from the number of rows.
    """
    field_names = values.dtype.names
    if field_names is not None:
        sample_names = list(field_names)
        return sample_names, np.array([values[field] for field in sample_names])
    if default_names is None:
        default_names = [f"s{i}" for i in range(values.shape[0])]
    return default_names, values


def load_mix(mix_file):
    """Load the RNA and methylation mixtures stored in an HDF5 mix file.

    Args:
        mix_file: Path of an HDF5 file containing the groups ``mix_rna``
            (datasets ``data`` and ``genes``) and ``mix_met`` (datasets
            ``data`` and ``CpG_sites``).

    Returns:
        A ``(mix_rna, mix_met_data)`` tuple of DataFrames. Rows are genes
        (respectively CpG sites) and columns are samples.

    Raises:
        KeyError: If an expected group or dataset is missing from the file.
        ValueError: Raised by pandas when the labels do not match the data
            shape, e.g. when a plain methylation array does not have as many
            rows as there are RNA sample names.
    """
    with h5py.File(mix_file, 'r') as f:
        rna_grp = f['mix_rna']
        mix_genes = _decode_labels(rna_grp['genes'][:])
        samples, rna_matrix = _split_samples(rna_grp['data'][:])
        # The matrix is samples x features; transpose to features x samples.
        mix_rna = pd.DataFrame(rna_matrix.T, index=mix_genes, columns=samples)

        met_grp = f['mix_met']
        mix_cpg = _decode_labels(met_grp['CpG_sites'][:])
        # A plain methylation array carries no names, so it reuses the RNA
        # sample names.
        met_samples, met_matrix = _split_samples(
            met_grp['data'][:], default_names=samples
        )
        mix_met_data = pd.DataFrame(met_matrix.T, index=mix_cpg, columns=met_samples)
    return mix_rna, mix_met_data

# Baseline aggregate scores reported as "v1" in the comparison, keyed by the
# dataset name derived from the mix file name.
v1_scores = {
    'insilicodirichletNoDep4CTsource': 0.7757,
    'insilicodirichletNoDep': 0.6915,
    'insilicopseudobulk': 0.7125,
    'invitro': 0.7971,
    'invivo': 0.7890,
}

all_scores = {}
mix_files = sorted(data_dir.glob("mixes*.h5"))

for mix_file in mix_files:
    name = mix_file.stem.replace("mixes1_", "").replace("_pdac", "")
    gt_file = data_dir / mix_file.name.replace("mixes1_", "groundtruth1_")
    # The glob is wider than the "mixes1_" prefix: without that prefix the
    # replace above is a no-op and gt_file would be the mix file itself, so
    # the mixture would be read back as its own ground truth.
    if gt_file == mix_file or not gt_file.exists():
        print(f"Skipping {mix_file.name}: no matching ground-truth file", flush=True)
        continue

    print(f"\n{'='*60}", flush=True)
    print(f"Dataset: {name}", flush=True)

    mix_rna, mix_met_data = load_mix(mix_file)

    # Time only the prediction; the references are copied so that program()
    # cannot modify the shared DataFrames between datasets.
    start = time.time()
    pred = program(mix_rna=mix_rna, ref_bulkRNA=ref_bulkRNA.copy(),
                   mix_met=mix_met_data, ref_met=ref_met.copy())
    elapsed = time.time() - start
    print(f"  Time: {elapsed:.1f}s | Shape: {pred.shape}", flush=True)

    # Use the 'groundtruth' entry when present, otherwise the first entry.
    gt_data = read_hdf5(gt_file)
    gt_df = gt_data.get('groundtruth', gt_data[list(gt_data.keys())[0]])

    # Compare only the cell types present in both tables, in ground-truth order.
    # NOTE(review): columns are compared positionally; this assumes gt_df and
    # pred list the samples in the same order - confirm.
    common_ct = [ct for ct in gt_df.index if ct in pred.index]
    A_real = gt_df.loc[common_ct].values
    A_pred = pred.loc[common_ct].values

    # Rescale each column to sum to 1; all-zero columns are left as zeros
    # instead of dividing by zero.
    col_sums = A_pred.sum(axis=0, keepdims=True)
    col_sums[col_sums == 0] = 1
    A_pred = A_pred / col_sums

    is_partial = 'invivo' in name.lower() or set(common_ct) == {'basal', 'classic'}
    scores = scoring_function(A_real, A_pred, partial_gt=is_partial)
    all_scores[name] = scores

    # Datasets without a recorded baseline are reported with v1 = nan.
    v1 = v1_scores.get(name, float('nan'))
    v2 = scores['score_aggreg']
    diff = v2 - v1
    print(f"  Score: {v2:.4f} (v1={v1:.4f}, diff={diff:+.4f})", flush=True)

# Per-dataset comparison table followed by median/mean summaries.
print(f"\n{'='*60}")
print("COMPARISON: v1 -> v2")
print(f"{'='*60}")
for name, scores in all_scores.items():
    v1 = v1_scores.get(name, float('nan'))
    v2 = scores['score_aggreg']
    # Any comparison with NaN is False, so a missing baseline (or a NaN
    # score) must be handled explicitly rather than shown as a regression.
    if np.isnan(v1) or np.isnan(v2):
        arrow = "?"
    elif v2 > v1:
        arrow = "↑"
    elif v2 < v1:
        arrow = "↓"
    else:
        arrow = "="
    print(f"  {name:<40}: {v1:.4f} -> {v2:.4f}  ({v2-v1:+.4f} {arrow})")

# Summarise v2 over every dataset with a usable score, and v1 over the
# baselines of those same datasets, so both sides describe the same runs
# even when only a subset of the datasets was scored.
valid_scores = {
    ds_name: ds_scores['score_aggreg']
    for ds_name, ds_scores in all_scores.items()
    if not np.isnan(ds_scores['score_aggreg'])
}
agg = list(valid_scores.values())
v1_agg = [v1_scores[ds_name] for ds_name in valid_scores if ds_name in v1_scores]

# np.median / np.mean warn on empty input; report nan directly instead.
v1_median = np.median(v1_agg) if v1_agg else float('nan')
v1_mean = np.mean(v1_agg) if v1_agg else float('nan')
v2_median = np.median(agg) if agg else float('nan')
v2_mean = np.mean(agg) if agg else float('nan')

print(f"\n  Median: {v1_median:.4f} -> {v2_median:.4f}")
print(f"  Mean:   {v1_mean:.4f} -> {v2_mean:.4f}")
