"""
Diagnostic analysis: identify per-cell-type errors and prediction biases.
"""
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[1] / "src"))
sys.path.append(str(Path(__file__).resolve().parents[1]))

import numpy as np
import pandas as pd
import h5py
from pathlib import Path
from local_scoring import read_hdf5, scoring_function

from scipy import stats


def _close_columns(mat):
    """Rescale each column (sample) of ``mat`` so that it sums to 1.

    Columns that sum to 0 are divided by 1 instead, so they stay all-zero
    rather than becoming NaN.
    """
    sums = mat.sum(axis=0, keepdims=True).astype(float)
    sums[sums == 0] = 1
    return mat / sums


def _align_samples(gt_df, pred_df):
    """Return ``(gt_df, pred_df)`` restricted to, and ordered by, shared sample columns.

    Samples are matched by column label, so a prediction whose columns are
    stored in a different order is still compared against the right ground
    truth. If the frames share no column labels, they are compared by
    position (as before) provided they have the same number of columns.
    Otherwise ``None`` is returned.
    """
    shared = [s for s in gt_df.columns if s in pred_df.columns]
    if shared:
        n_gt_only = gt_df.shape[1] - len(shared)
        n_pred_only = pred_df.shape[1] - len(shared)
        if n_gt_only or n_pred_only:
            print(f"  WARNING: dropping {n_gt_only} ground-truth-only and "
                  f"{n_pred_only} prediction-only samples")
        return gt_df[shared], pred_df[shared]
    if gt_df.shape[1] == pred_df.shape[1]:
        print("  WARNING: no shared sample labels; comparing samples by position")
        return gt_df, pred_df
    return None


data_dir = Path("data")
predictions_file = Path("submissions/prediction.h5")

print("Loading predictions...")
predictions = read_hdf5(predictions_file)

# Every ground-truth file in data_dir is analysed against the matching prediction.
gt_files = sorted(data_dir.glob("groundtruth*.h5"))

for gt_file in gt_files:
    # Derive the dataset key used in the predictions file from the GT file name.
    name = gt_file.stem.replace("groundtruth1_", "").replace("groundtruth_", "")
    name = name.replace("_pdac", "")

    print(f"\n{'='*70}")
    print(f"DATASET: {name}")
    print(f"{'='*70}")

    # Check for a prediction first so we don't read GT files we cannot use.
    if name not in predictions:
        print(f"  No prediction for {name}")
        continue
    pred_df = predictions[name]

    gt_data = read_hdf5(gt_file)
    # Prefer an explicit 'groundtruth' entry; otherwise take the first stored key.
    if 'groundtruth' in gt_data:
        gt_df = gt_data['groundtruth']
    else:
        gt_df = gt_data[list(gt_data.keys())[0]]

    # Rows are cell types. Compare only those present on both sides, but
    # report the rest so a dropped cell type is not silently ignored.
    common_ct = [ct for ct in gt_df.index if ct in pred_df.index]
    missing_ct = [ct for ct in gt_df.index if ct not in pred_df.index]
    extra_ct = [ct for ct in pred_df.index if ct not in gt_df.index]
    if missing_ct:
        print(f"  WARNING: cell types missing from prediction: {missing_ct}")
    if extra_ct:
        print(f"  WARNING: predicted cell types absent from ground truth: {extra_ct}")
    if not common_ct:
        print("  No cell types in common; skipping")
        continue

    # Columns are samples. Align them by label, not by position.
    aligned = _align_samples(gt_df, pred_df)
    if aligned is None:
        print(f"  Sample counts differ ({gt_df.shape[1]} vs {pred_df.shape[1]}) "
              f"and no sample labels match; skipping")
        continue
    gt_df, pred_df = aligned
    if gt_df.shape[1] == 0:
        print("  No samples to compare; skipping")
        continue

    # Raw per-sample totals over ALL predicted cell types, taken before any
    # renormalisation, so the sum check below can actually reveal a bias.
    raw_sample_sums = pred_df.values.sum(axis=0)

    # Restrict both sides to the shared cell types and re-close each sample,
    # so GT and prediction are compositions over the same set of parts.
    A_real = _close_columns(gt_df.loc[common_ct].values)
    A_pred = _close_columns(pred_df.loc[common_ct].values)

    print(f"\nCell types: {common_ct}")
    print(f"Shape: {A_real.shape}")

    # Per-cell-type summary statistics across samples.
    print(f"\n{'Cell Type':<15} {'GT Mean':>10} {'Pred Mean':>10} {'GT Std':>10} {'Pred Std':>10} {'MAE':>10} {'Pearson':>10}")
    print("-" * 81)
    for ct, gt_row, pred_row in zip(common_ct, A_real, A_pred):
        mae = np.mean(np.abs(gt_row - pred_row))
        # Pearson is undefined for a constant row; report NaN instead.
        if np.std(pred_row) > 0 and np.std(gt_row) > 0:
            corr, _ = stats.pearsonr(gt_row, pred_row)
        else:
            corr = float('nan')
        print(f"{str(ct):<15} {gt_row.mean():10.4f} {pred_row.mean():10.4f} {gt_row.std():10.4f} {pred_row.std():10.4f} {mae:10.4f} {corr:10.4f}")

    print("\nOverall prediction sum per sample (should be ~1.0, before renormalization):")
    print(f"  Range: [{raw_sample_sums.min():.4f}, {raw_sample_sums.max():.4f}], Mean: {raw_sample_sums.mean():.4f}")

    # Show a few sample comparisons.
    sample_labels = list(gt_df.columns)
    print("\nSample-by-sample comparison (first 5):")
    for j in range(min(5, A_real.shape[1])):
        print(f"  Sample {j} ({sample_labels[j]}):")
        for ct, gt_val, pred_val in zip(common_ct, A_real[:, j], A_pred[:, j]):
            diff = pred_val - gt_val
            print(f"    {str(ct):<12}: GT={gt_val:.4f}, Pred={pred_val:.4f}, Diff={diff:+.4f}")

    # Per-sample Aitchison distances, only computed when more than two cell types are shared.
    if len(common_ct) > 2:
        from local_scoring import aitchison_distance
        print("\nPer-sample Aitchison distances:")
        dists = np.asarray(
            [aitchison_distance(A_real[:, j], A_pred[:, j]) for j in range(A_real.shape[1])],
            dtype=float,
        )
        # Non-finite distances (possibly from zero proportions) would poison the
        # summary stats and sort as "worst", so report them separately.
        finite = np.isfinite(dists)
        if not finite.all():
            bad = np.flatnonzero(~finite)
            print(f"  {bad.size} sample(s) gave non-finite distances: {bad.tolist()}")
        if finite.any():
            ok = dists[finite]
            print(f"  Mean: {ok.mean():.4f}, Std: {ok.std():.4f}")
            print(f"  Range: [{ok.min():.4f}, {ok.max():.4f}]")
            # Worst first: largest finite distances in descending order.
            worst_idx = np.flatnonzero(finite)[np.argsort(ok)[::-1][:5]]
            print(f"  Worst 5 samples: {worst_idx.tolist()} with distances {dists[worst_idx].tolist()}")
