"""Quick test: NNLS-only vs Hybrid vs Current model performance."""
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[1] / "src"))
sys.path.append(str(Path(__file__).resolve().parents[1]))

import numpy as np
import h5py
import pandas as pd
from pathlib import Path
from scipy.optimize import nnls
import sys
sys.path.insert(0, '.')
from local_scoring import eval_aitchison, scoring_function

# Input directory, resolved against the current working directory (not the
# script location), so the script is meant to be launched from the folder
# that contains ``data/``.
data_dir = Path('.', 'data')

def load_matrix(filepath, group_name):
    """Load the ``data`` matrix and ``genes`` labels from an HDF5 group.

    The matrix is oriented so that the axis whose length matches the label
    list becomes axis 1. For the reference file this yields
    ``(n_cell_types, n_genes)``.

    Parameters
    ----------
    filepath : str or Path
        HDF5 file to read.
    group_name : str
        Name of the group holding the ``data`` and ``genes`` datasets.

    Returns
    -------
    vals : numpy.ndarray
        The ``data`` matrix, transposed if needed so labels run along axis 1.
    genes_or_ct : list
        Labels read from ``genes`` (gene names, or cell-type names in
        ground-truth files), decoded to ``str`` when stored as bytes.
    """
    with h5py.File(filepath, 'r') as f:
        grp = f[group_name]
        vals = grp['data'][:]
        genes_or_ct = [x.decode() if isinstance(x, bytes) else x
                       for x in grp['genes'][:]]

    if vals.ndim == 2:
        n_labels = len(genes_or_ct)
        rows_match = vals.shape[0] == n_labels
        cols_match = vals.shape[1] == n_labels
        if rows_match != cols_match:
            # The label count identifies the label axis unambiguously; move
            # it to axis 1 (this also works with more than 10 cell types).
            if rows_match:
                vals = vals.T
        elif vals.shape[0] > 10:
            # Ambiguous (square matrix, or labels match neither axis): keep
            # the previous heuristic that a long first axis is the label axis.
            vals = vals.T

    return vals, genes_or_ct

# Load reference profiles; load_matrix puts the gene axis on axis 1.
ref_vals, ref_genes = load_matrix(data_dir / 'ref.h5', 'ref_bulkRNA')
with h5py.File(data_dir / 'ref.h5', 'r') as f:
    # Decode the same way as every other label list (values may be bytes or str).
    cell_types = [x.decode() if isinstance(x, bytes) else x
                  for x in f['ref_bulkRNA']['cell_types'][:]]
# ref_vals shape: (n_cell_types, n_genes), ref_genes = gene names.
# Rows of ref_vals are matched to cell_types by position later on, so a
# mismatch here would silently mislabel every NNLS proportion.
if ref_vals.shape[0] != len(cell_types):
    raise ValueError(
        f"Reference matrix has {ref_vals.shape[0]} rows but "
        f"{len(cell_types)} cell-type labels; check the orientation of "
        f"ref_bulkRNA/data."
    )

# (dataset label, mixture file, ground-truth file), all read from data_dir.
test_datasets = [
    ('dirichlet4CT', 'mixes1_insilicodirichletNoDep4CTsource_pdac.h5', 'groundtruth1_insilicodirichletNoDep4CTsource_pdac.h5'),
    ('dirichletNoDep', 'mixes1_insilicodirichletNoDep_pdac.h5', 'groundtruth1_insilicodirichletNoDep_pdac.h5'),
    ('pseudobulk', 'mixes1_insilicopseudobulk_pdac.h5', 'groundtruth1_insilicopseudobulk_pdac.h5'),
    ('invitro', 'mixes1_invitro_pdac.h5', 'groundtruth1_invitro_pdac.h5'),
]

# Loop-invariant lookups: reference gene -> column, reference cell type -> row.
ref_gene_idx = {g: i for i, g in enumerate(ref_genes)}
ct_idx = {ct: i for i, ct in enumerate(cell_types)}
# Floor for predicted proportions before scoring; presumably keeps log-ratio
# scores (e.g. Aitchison) finite -- confirm against local_scoring.
NNLS_MIN_PROP = 0.002

for name, mix_file, gt_file in test_datasets:
    print(f"\n{'='*60}")
    print(f"Dataset: {name}")

    # Load mix RNA; expected layout is (n_samples, n_genes).
    with h5py.File(data_dir / mix_file, 'r') as f:
        grp = f['mix_rna']
        mix_vals = grp['data'][:]
        mix_genes = [x.decode() if isinstance(x, bytes) else x for x in grp['genes'][:]]
    # Transpose only when the gene labels clearly sit on axis 0.
    if mix_vals.shape[1] != len(mix_genes) and mix_vals.shape[0] == len(mix_genes):
        mix_vals = mix_vals.T

    # Load GT proportions; in these files 'genes' holds the cell-type labels.
    with h5py.File(data_dir / gt_file, 'r') as f:
        grp = f['groundtruth']
        gt_vals = grp['data'][:]
        gt_ct = [x.decode() if isinstance(x, bytes) else x for x in grp['genes'][:]]
    # Orient GT as (n_ct, n_samples). The label count decides; the
    # "more samples than cell types" guess is only a fallback for the
    # ambiguous case (square matrix, or labels match neither axis).
    rows_are_ct = gt_vals.shape[0] == len(gt_ct)
    cols_are_ct = gt_vals.shape[1] == len(gt_ct)
    if rows_are_ct != cols_are_ct:
        if cols_are_ct:
            gt_vals = gt_vals.T
    elif gt_vals.shape[0] > gt_vals.shape[1]:
        gt_vals = gt_vals.T

    # Genes present in both reference and mixtures, in reference order.
    mix_gene_idx = {g: i for i, g in enumerate(mix_genes)}
    common = [g for g in ref_genes if g in mix_gene_idx]
    if not common:
        print("  skipped: no genes shared between reference and mixtures")
        continue
    ref_idx = [ref_gene_idx[g] for g in common]
    mix_idx = [mix_gene_idx[g] for g in common]

    # A = ref (genes x cell_types), b = one mixture sample (genes)
    A = ref_vals[:, ref_idx].T  # (n_common_genes, n_cell_types)
    B = mix_vals[:, mix_idx]     # (n_samples, n_common_genes)

    n_samples = B.shape[0]
    n_ct = A.shape[1]

    # NNLS per sample, rescaled to sum to 1; an all-zero fit stays a zero column.
    nnls_props = np.zeros((n_ct, n_samples))
    for i in range(n_samples):
        x, _ = nnls(A, B[i])
        total = x.sum()
        if total > 0:
            nnls_props[:, i] = x / total

    # Align with GT cell types (GT order). GT types absent from the reference
    # are dropped; NOTE(review): gt_aligned is not renormalized afterwards.
    common_ct = [ct for ct in gt_ct if ct in ct_idx]
    missing_ct = [ct for ct in gt_ct if ct not in ct_idx]
    if missing_ct:
        print(f"  note: GT cell types not in reference (dropped): {missing_ct}")
    if not common_ct:
        print("  skipped: no cell types shared between reference and ground truth")
        continue
    gt_rows = [gt_ct.index(ct) for ct in common_ct]
    nnls_rows = [ct_idx[ct] for ct in common_ct]

    gt_aligned = gt_vals[gt_rows]
    nnls_aligned = nnls_props[nnls_rows]
    if gt_aligned.shape[1] != n_samples:
        raise ValueError(
            f"{name}: ground truth has {gt_aligned.shape[1]} samples but "
            f"mixtures have {n_samples}"
        )

    # Renormalize over the shared cell types (all-zero columns stay zero).
    col_sums = nnls_aligned.sum(axis=0, keepdims=True)
    col_sums[col_sums == 0] = 1
    nnls_aligned = nnls_aligned / col_sums

    # Pseudocount floor, then renormalize so each sample sums to 1 again.
    nnls_aligned = np.clip(nnls_aligned, NNLS_MIN_PROP, None)
    nnls_aligned = nnls_aligned / nnls_aligned.sum(axis=0, keepdims=True)

    scores = scoring_function(gt_aligned, nnls_aligned)
    agg = scores['score_aggreg']
    print(f"  NNLS-only:  score={agg:.4f}  aitchison={scores['aitchison']:.4f}  "
          f"rmse={scores['rmse']:.4f}  p_row={scores['pearson_row']:.4f}")
