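"""
Quick smoke test of program_v2 on a single dataset (insilicodirichletNoDep).

Loads the reference and one mix from ./data (relative to the current working
directory), runs program_v2, scores the prediction against the ground truth,
and prints the scores next to the recorded v1 baseline.
"""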
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[1] / "src"))
sys.path.append(str(Path(__file__).resolve().parents[1]))

import numpy as np
import pandas as pd
import h5py
import sys
import time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))

data_dir = Path("data")


def _labels(dataset):
    """Return every entry of an HDF5 label dataset as a str (bytes are UTF-8 decoded)."""
    return [item.decode('utf-8') if isinstance(item, bytes) else str(item) for item in dataset[:]]


# Load reference directly
print("Loading reference...", flush=True)
with h5py.File(data_dir / "ref.h5", 'r') as f:
    rna_group = f['ref_bulkRNA']
    met_group = f['ref_met']

    # Row labels for BOTH reference matrices come from ref_bulkRNA/cell_types.
    # NOTE(review): this assumes ref_met lists its cell types in the same order
    # as ref_bulkRNA — confirm against the file layout.
    ref_cell_types = _labels(rna_group['cell_types'])

    ref_bulkRNA = pd.DataFrame(
        rna_group['data'][:],
        index=ref_cell_types,
        columns=_labels(rna_group['genes']),
    )
    ref_met = pd.DataFrame(
        met_group['data'][:],
        index=ref_cell_types,
        columns=_labels(met_group['CpG_sites']),
    )

print(f"Ref RNA: {ref_bulkRNA.shape}, Ref MET: {ref_met.shape}", flush=True)

# Load a single mix dataset (insilicodirichletNoDep - worst performing)
mix_file = data_dir / "mixes1_insilicodirichletNoDep_pdac.h5"
print(f"\nLoading mix: {mix_file.name}...", flush=True)


def _mix_group_to_frame(grp, feature_key, fallback_samples=None):
    """Read one mix group into a DataFrame with features as index, samples as columns.

    Args:
        grp: Group (e.g. ``f['mix_rna']``) holding a ``data`` dataset plus a
            ``feature_key`` dataset of feature labels.
        feature_key: Name of the label dataset, e.g. ``'genes'`` or ``'CpG_sites'``.
        fallback_samples: Column names to use when ``data`` is a plain array
            without named fields. When None, ``sample_{i}`` names are generated.

    Returns:
        DataFrame of shape (n_features, n_samples).

    Raises:
        ValueError: if a plain ``data`` array matches the number of feature
            labels on neither axis, or ``fallback_samples`` has the wrong length.
    """
    values = grp['data'][:]
    features = [x.decode('utf-8') if isinstance(x, bytes) else str(x) for x in grp[feature_key][:]]

    if values.dtype.names is not None:
        # Compound dataset: one named field per sample, each a vector over features.
        sample_names = list(values.dtype.names)
        matrix = np.array([values[name] for name in sample_names]).T
        return pd.DataFrame(matrix, index=features, columns=sample_names)

    # Plain array. The (samples, features) layout is tried first (the historical
    # assumption); (features, samples) is accepted only when it is the sole
    # orientation whose axis length matches the feature labels.
    matrix = np.asarray(values)
    if matrix.shape[-1] == len(features):
        matrix = matrix.T
    elif matrix.shape[0] != len(features):
        raise ValueError(
            f"'{feature_key}' has {len(features)} labels but data has shape {matrix.shape}"
        )

    n_samples = matrix.shape[1]
    if fallback_samples is None:
        sample_names = [f"sample_{i}" for i in range(n_samples)]
    else:
        sample_names = list(fallback_samples)
        if len(sample_names) != n_samples:
            raise ValueError(
                f"data has {n_samples} samples but {len(sample_names)} fallback sample names were given"
            )
    return pd.DataFrame(matrix, index=features, columns=sample_names)


with h5py.File(mix_file, 'r') as f:
    mix_rna = _mix_group_to_frame(f['mix_rna'], 'genes')
    # A plain (non-compound) methylation array carries no sample names, so reuse
    # the RNA ones to keep both modalities' columns aligned.
    mix_met_data = _mix_group_to_frame(f['mix_met'], 'CpG_sites', fallback_samples=mix_rna.columns)

print(f"Mix RNA: {mix_rna.shape}, Mix MET: {mix_met_data.shape}", flush=True)

# Import and run program_v2
print("\nRunning program_v2...", flush=True)
from program_v2 import program

# perf_counter() is monotonic and high-resolution; time.time() follows the wall
# clock and can jump if the system time is adjusted during the run.
start = time.perf_counter()
# The references are passed as copies, presumably so any in-place changes made
# by program() cannot leak back into the module-level frames.
pred = program(mix_rna=mix_rna, ref_bulkRNA=ref_bulkRNA.copy(),
               mix_met=mix_met_data, ref_met=ref_met.copy())
elapsed = time.perf_counter() - start
print(f"Done in {elapsed:.1f}s", flush=True)
print(f"Prediction shape: {pred.shape}", flush=True)
print(f"Prediction index: {list(pred.index)}", flush=True)

# Score it
from local_scoring import read_hdf5, scoring_function

# Dataset label shared by the ground-truth file name and the results header;
# the v1 baseline numbers printed below were recorded for this dataset.
dataset = "insilicodirichletNoDep"
gt_file = data_dir / f"groundtruth1_{dataset}_pdac.h5"
gt_data = read_hdf5(gt_file)
# Prefer the 'groundtruth' entry; otherwise fall back to the first entry returned.
gt_df = gt_data['groundtruth'] if 'groundtruth' in gt_data else gt_data[list(gt_data.keys())[0]]

# Score only cell types present in both frames, and report any that get dropped
# instead of silently ignoring them.
common_ct = [ct for ct in gt_df.index if ct in pred.index]
if not common_ct:
    raise ValueError(
        f"No cell types shared by ground truth {list(gt_df.index)} "
        f"and prediction {list(pred.index)}"
    )
unscored_gt = [ct for ct in gt_df.index if ct not in pred.index]
unscored_pred = [ct for ct in pred.index if ct not in gt_df.index]
if unscored_gt:
    print(f"WARNING: ground-truth cell types missing from prediction: {unscored_gt}", flush=True)
if unscored_pred:
    print(f"WARNING: predicted cell types absent from ground truth (dropped): {unscored_pred}", flush=True)

# .values discards labels, so make sure sample columns line up first. When both
# frames carry the same sample labels, reorder the prediction to the ground
# truth's column order; otherwise fall back to positional matching.
if set(gt_df.columns) == set(pred.columns):
    pred_aligned = pred.loc[:, list(gt_df.columns)]
elif gt_df.shape[1] == pred.shape[1]:
    pred_aligned = pred
else:
    raise ValueError(
        f"Sample count mismatch: ground truth has {gt_df.shape[1]} columns, "
        f"prediction has {pred.shape[1]}"
    )

A_real = gt_df.loc[common_ct].values
A_pred = pred_aligned.loc[common_ct].values

# Renormalize each sample (column) to sum to 1 over the scored cell types.
# All-zero columns are divided by 1 so they stay zero instead of becoming NaN.
col_sums = A_pred.sum(axis=0, keepdims=True)
col_sums[col_sums == 0] = 1
A_pred = A_pred / col_sums

scores = scoring_function(A_real, A_pred)

# (label, scores key, v1 value) — output format matches the original table.
metric_rows = [
    ("RMSE", "rmse", 0.1561),
    ("MAE", "mae", 0.1355),
    ("Aitchison", "aitchison", 2.1893),
    ("Pearson tot", "pearson_tot", -0.0106),
    ("Pearson col", "pearson_col", -0.0393),
    ("Pearson row", "pearson_row", 0.7212),
    ("Spearman tot", "spearman_tot", -0.1071),
    ("Spearman col", "spearman_col", -0.1833),
    ("Spearman row", "spearman_row", 0.6839),
]
print(f"\nResults for {dataset}:")
for label, key, v1_value in metric_rows:
    print(f"  {label + ':':<15}{scores[key]:.4f}  (v1: {v1_value:.4f})")
print(f"\n  {'AGGREGATE:':<15}{scores['score_aggreg']:.4f}  (v1: 0.6915)")
