"""
HADACA3 - Local Scoring
========================
Evaluates predictions against ground truth using the same metrics as the challenge.

Metrics:
- RMSE (Root Mean Squared Error)
- MAE (Mean Absolute Error)
- Aitchison distance (compositional)
- Pearson correlations (total, column/sample, row/cell_type)
- Spearman correlations (total, column/sample, row/cell_type)

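Final Score:
- Weighted geometric mean of normalized metrics (each scaled to [0, 1], 1 = best)
- Weights: RMSE+MAE (1/3), Aitchison (1/3), Correlations (1/3)
- Partial ground truth (in vivo): geometric mean of the Pearson and Spearman
  row correlations, each rescaled from [-1, 1] to [0, 1]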
"""

import numpy as np
import pandas as pd
import h5py
from pathlib import Path
from scipy import stats


def weighted_geometric_mean(values, weights):
    """Return the weighted geometric mean of ``values``, ignoring NaN entries.

    Each NaN value is discarded together with its weight. The remaining values
    are floored at 1e-10 before taking logarithms, so a zero contributes a very
    small factor instead of log(0). Returns NaN when no non-NaN value remains
    (including empty input).
    """
    vals = np.array(values)
    wts = np.array(weights)

    keep = ~np.isnan(vals)
    vals, wts = vals[keep], wts[keep]
    if vals.size == 0:
        return np.nan

    # Floor at a tiny positive value so a zero score does not produce log(0).
    log_vals = np.log(np.clip(vals, 1e-10, None))
    return np.exp(np.sum(wts * log_vals) / np.sum(wts))


def aitchison_distance(x, y, min_val=1e-9):
    """Return the Aitchison distance between compositions ``x`` and ``y``.

    Both inputs are floored at ``min_val`` (so zeros can be logged), closed to
    sum to 1, and mapped through the centred log-ratio (CLR) transform. The
    result is the Euclidean distance between the two CLR vectors.
    """
    def _clr(composition):
        # Floor, close to the simplex, then centre the logs.
        closed = np.clip(composition, min_val, None)
        closed = closed / closed.sum()
        logs = np.log(closed)
        return logs - np.mean(logs)

    delta = _clr(x) - _clr(y)
    return np.sqrt(np.sum(delta ** 2))


def eval_aitchison(A_real, A_pred, min_val=1e-9):
    """Average, over samples (columns), the Aitchison distance between truth and prediction."""
    per_sample = [
        aitchison_distance(A_real[:, col], A_pred[:, col], min_val)
        for col in range(A_real.shape[1])
    ]
    return np.mean(per_sample)


def eval_rmse(A_real, A_pred):
    """Root of the mean squared element-wise difference between truth and prediction."""
    residual = A_real - A_pred
    return np.sqrt(np.mean(np.square(residual)))


def eval_mae(A_real, A_pred):
    """Mean of the element-wise absolute difference between truth and prediction."""
    abs_err = np.abs(A_real - A_pred)
    return np.mean(abs_err)


def correlation_total(A_real, A_pred, method='pearson'):
    """Correlate the two matrices after flattening each of them to 1-D.

    ``method == 'pearson'`` selects Pearson; any other value selects Spearman.
    If the flattened prediction has zero variance, -1 (the worst possible
    correlation) is returned without calling scipy.
    """
    pred_vec = A_pred.flatten()
    if np.var(pred_vec) == 0:
        return -1

    real_vec = A_real.flatten()
    corr_fn = stats.pearsonr if method == 'pearson' else stats.spearmanr
    corr, _ = corr_fn(real_vec, pred_vec)
    return corr


def correlation_column(A_real, A_pred, method='pearson'):
    """Average per-sample (column-wise) correlation between truth and prediction.

    Columns where either vector has zero (or NaN) standard deviation are left
    out. Returns -1 when no column qualifies. ``method == 'pearson'`` selects
    Pearson; any other value selects Spearman.
    """
    corr_fn = stats.pearsonr if method == 'pearson' else stats.spearmanr
    column_pairs = [(A_real[:, j], A_pred[:, j]) for j in range(A_real.shape[1])]
    per_sample = [
        corr_fn(real, pred)[0]
        for real, pred in column_pairs
        if np.std(pred) > 0 and np.std(real) > 0
    ]
    return np.mean(per_sample) if per_sample else -1


def correlation_row(A_real, A_pred, method='pearson'):
    """Average per-cell-type (row-wise) correlation between truth and prediction.

    Rows where either vector has zero (or NaN) standard deviation are left
    out. Returns -1 when no row qualifies. ``method == 'pearson'`` selects
    Pearson; any other value selects Spearman.
    """
    corr_fn = stats.pearsonr if method == 'pearson' else stats.spearmanr
    kept = []
    for row in range(A_real.shape[0]):
        real_vals, pred_vals = A_real[row, :], A_pred[row, :]
        # Correlation is undefined for a constant vector; only keep varying rows.
        if np.std(pred_vals) > 0 and np.std(real_vals) > 0:
            r, _ = corr_fn(real_vals, pred_vals)
            kept.append(r)

    if not kept:
        return -1
    return np.mean(kept)


def normalize_scores(scores, best, worst):
    """Map each raw metric linearly onto [0, 1], where 1 means ``best``.

    For each metric name, the value is rescaled so that ``worst[name]`` maps to
    0 and ``best[name]`` maps to 1, then clipped to [0, 1]. NaN scores stay NaN
    (``best``/``worst`` are not consulted for them). If best and worst
    coincide, the result is 1.0 on an exact match and 0.0 otherwise.
    Output keys follow the order of ``scores``.
    """
    def _scale(name, value):
        if np.isnan(value):
            return np.nan
        top, bottom = best[name], worst[name]
        if top == bottom:
            return 1.0 if value == top else 0.0
        return np.clip((value - bottom) / (top - bottom), 0, 1)

    return {name: _scale(name, value) for name, value in scores.items()}


def scoring_function(A_real, A_pred, partial_gt=False):
    """
    Evaluate a prediction with every challenge metric and combine them into one score.

    Args:
        A_real: Ground truth (cell_types x samples)
        A_pred: Predictions (cell_types x samples)
        partial_gt: If True, only use row correlations (for invivo)

    Returns:
        Dictionary with keys rmse, mae, aitchison, pearson_tot, pearson_col,
        pearson_row, spearman_tot, spearman_col, spearman_row and score_aggreg.
        In partial mode, every entry except the two row correlations and
        score_aggreg is NaN.
    """
    metric_names = (
        'rmse', 'mae', 'aitchison',
        'pearson_tot', 'pearson_col', 'pearson_row',
        'spearman_tot', 'spearman_col', 'spearman_row',
    )
    corr_names = metric_names[3:]

    if partial_gt:
        # Partial (in vivo) ground truth: only per-cell-type correlations are used.
        row_p = correlation_row(A_real, A_pred, 'pearson')
        row_s = correlation_row(A_real, A_pred, 'spearman')
        partial = dict.fromkeys(metric_names, np.nan)
        partial['pearson_row'] = row_p
        partial['spearman_row'] = row_s
        # Rescale each correlation from [-1, 1] to [0, 1], then take their geometric mean.
        partial['score_aggreg'] = np.sqrt(((row_p + 1) / 2) * ((row_s + 1) / 2))
        return partial

    # Raw metrics, in the same order as metric_names.
    metrics = {
        'rmse': eval_rmse(A_real, A_pred),
        'mae': eval_mae(A_real, A_pred),
        'aitchison': eval_aitchison(A_real, A_pred),
    }
    for key, corr_fn, method in (
        ('pearson_tot', correlation_total, 'pearson'),
        ('pearson_col', correlation_column, 'pearson'),
        ('pearson_row', correlation_row, 'pearson'),
        ('spearman_tot', correlation_total, 'spearman'),
        ('spearman_col', correlation_column, 'spearman'),
        ('spearman_row', correlation_row, 'spearman'),
    ):
        metrics[key] = corr_fn(A_real, A_pred, method)

    # Ideal values: zero error, perfect correlation.
    best = {
        key: (0.0 if key in ('rmse', 'mae', 'aitchison') else 1.0)
        for key in metric_names
    }

    # Worst-case reference prediction: in every sample, weight 1.0 on the
    # least abundant true cell type and 1e-9 on all the others.
    fake_worst = np.zeros_like(A_real) + 1e-9
    for col in range(A_real.shape[1]):
        fake_worst[np.argmin(A_real[:, col]), col] = 1.0

    worst = {
        'rmse': eval_rmse(A_real, fake_worst),
        'mae': eval_mae(A_real, fake_worst),
        'aitchison': eval_aitchison(A_real, fake_worst),
        **dict.fromkeys(corr_names, -1.0),
    }

    normalized = normalize_scores(metrics, best, worst)

    # Weights (as per challenge scoring.R): errors 1/3, Aitchison 1/3,
    # correlations 1/3. Insertion order fixes the summation order.
    weights = {'rmse': 1/6, 'mae': 1/6, 'aitchison': 1/3, **dict.fromkeys(corr_names, 1/18)}

    metrics['score_aggreg'] = weighted_geometric_mean(
        [normalized[key] for key in weights],
        list(weights.values()),
    )
    return metrics


def read_hdf5(filepath):
    """Read an HDF5 file and return a dictionary of DataFrames, one per group.

    Only top-level groups that contain a ``data`` dataset are read. Other
    top-level entries (plain datasets, groups without ``data``) are skipped.
    Axis labels come from the optional ``genes``, ``CpG_sites``,
    ``cell_types`` and ``samples`` datasets. If ``data`` is a compound
    (structured) dataset, its field names are used as sample names.

    Orientation of the returned frames:
      * ground truth (file path or group name contains "groundtruth"):
        cell_types x samples;
      * other groups: features x samples (or features x cell_types), when the
        stored names are enough to decide the orientation.
    If the shape matches neither orientation, the raw matrix is returned with
    default integer labels (transposed, for ground truth). If the shape
    matches both (square), the "stored samples-first" reading wins.

    Args:
        filepath: Path (str or pathlib.Path) of the HDF5 file.

    Returns:
        dict mapping group name -> pandas.DataFrame.
    """
    data = {}

    def decode_strings(arr):
        # String datasets may come back as bytes ('S') or object arrays of bytes/str.
        if arr.dtype.kind in ('S', 'O'):
            return [x.decode('utf-8') if isinstance(x, bytes) else str(x) for x in arr]
        return [str(x) for x in arr]

    def read_names(grp, key):
        # Decode the label dataset ``key`` of ``grp`` into a list of str.
        return decode_strings(grp[key][:])

    def labelled(values, row_names, col_names, fallback):
        # Build a (row_names x col_names) frame, transposing if the matrix was
        # stored as (cols x rows). That reading is tried first. Otherwise
        # delegate to ``fallback``.
        if values.shape == (len(col_names), len(row_names)):
            return pd.DataFrame(values.T, index=row_names, columns=col_names)
        if values.shape == (len(row_names), len(col_names)):
            return pd.DataFrame(values, index=row_names, columns=col_names)
        return fallback(values)

    with h5py.File(filepath, 'r') as f:
        for group_name, grp in f.items():
            # Membership tests are only meaningful on groups. On a top-level
            # dataset, `'data' in grp` would iterate over the array contents.
            if not isinstance(grp, h5py.Group) or 'data' not in grp:
                continue

            values = grp['data'][:]

            # Compound datasets store one field per sample: stack the fields as rows.
            if values.dtype.names is not None:
                samples = list(values.dtype.names)
                values = np.array([values[name] for name in samples])
            else:
                samples = None

            if 'groundtruth' in str(filepath).lower() or 'groundtruth' in group_name.lower():
                # Ground truth: returned as (cell_types x samples).
                if 'genes' in grp:
                    # 'genes' dataset actually contains cell types in GT files
                    row_names = read_names(grp, 'genes')
                elif 'cell_types' in grp:
                    row_names = read_names(grp, 'cell_types')
                else:
                    # Assume (samples x cell_types) storage: cell types on axis 1.
                    row_names = [f"cell_{i}" for i in range(values.shape[1])]

                if 'samples' in grp:
                    col_names = read_names(grp, 'samples')
                elif samples:
                    col_names = samples
                else:
                    col_names = [f"sample_{i}" for i in range(values.shape[0])]

                df = labelled(values, row_names, col_names, lambda v: pd.DataFrame(v.T))
            else:
                # Mixture / reference data: rows are features (genes / CpG sites)
                # or, failing those, cell types.
                row_key = next(
                    (k for k in ('genes', 'CpG_sites', 'cell_types') if k in grp), None
                )
                row_names = read_names(grp, row_key) if row_key is not None else None

                if 'samples' in grp:
                    col_names = read_names(grp, 'samples')
                elif samples:
                    col_names = samples
                elif 'cell_types' in grp and row_key != 'cell_types':
                    # References: (features x cell_types). Do not reuse cell_types
                    # for the columns when it already labels the rows. That would
                    # mislabel the columns or, for non-square data, drop all labels.
                    col_names = read_names(grp, 'cell_types')
                else:
                    col_names = None

                if row_names and col_names:
                    df = labelled(values, row_names, col_names, pd.DataFrame)
                elif row_names:
                    # Only the row axis is named; the other axis is taken as samples.
                    if values.shape[1] == len(row_names):
                        sample_names = [f"sample_{i}" for i in range(values.shape[0])]
                        df = pd.DataFrame(values.T, index=row_names, columns=sample_names)
                    elif values.shape[0] == len(row_names):
                        sample_names = [f"sample_{i}" for i in range(values.shape[1])]
                        df = pd.DataFrame(values, index=row_names, columns=sample_names)
                    else:
                        df = pd.DataFrame(values)
                else:
                    # No usable names: keep the raw matrix with default labels.
                    df = pd.DataFrame(values)

            data[group_name] = df

    return data


# =============================================================================
# MAIN
# =============================================================================

def main():
    """Score the local submission against every available ground-truth file.

    Reads ``submissions/prediction.h5`` and each ``data/groundtruth*.h5``. For
    each dataset it aligns cell types (rows) and samples (columns), scores it
    with ``scoring_function``, and prints per-dataset metrics plus a summary.
    Exits with status 1 if the prediction file is missing, and with status 0
    if no ground-truth file is found.
    """
    import sys

    print("=" * 60)
    print("HADACA3 - Local Scoring")
    print("=" * 60)

    data_dir = Path("data")

    # Load predictions
    predictions_file = Path("submissions/prediction.h5")

    if not predictions_file.exists():
        print("ERROR: No prediction file found. Run submission_script.py first.")
        sys.exit(1)

    print(f"\nLoading predictions: {predictions_file}")
    predictions = read_hdf5(predictions_file)

    # Find ground truth files
    gt_files = sorted(data_dir.glob("groundtruth*.h5"))

    if not gt_files:
        print("\nWARNING: No ground truth files found in data directory.")
        print("Cannot compute local scores without ground truth.")
        print("\nPrediction shapes:")
        for name, pred in predictions.items():
            print(f"  {name}: {pred.shape}")
        sys.exit(0)

    print(f"\nFound {len(gt_files)} ground truth files")

    # Score each dataset
    all_scores = {}

    for gt_file in gt_files:
        # Dataset name = file stem without the ground-truth prefix and "_pdac" tag.
        name = gt_file.stem.replace("groundtruth1_", "").replace("groundtruth_", "")
        name = name.replace("_pdac", "")

        print(f"\n{'='*60}")
        print(f"Dataset: {name}")
        print("=" * 60)

        # Load ground truth: prefer the 'groundtruth' group, else the first group.
        gt_data = read_hdf5(gt_file)
        if not gt_data:
            print(f"  WARNING: No ground truth table found in {gt_file}")
            continue
        gt_df = gt_data['groundtruth'] if 'groundtruth' in gt_data else next(iter(gt_data.values()))

        # Get prediction
        if name not in predictions:
            print(f"  WARNING: No prediction for dataset {name}")
            continue
        pred_df = predictions[name]

        cell_types = list(gt_df.index)
        pred_cell_types = list(pred_df.index)
        if set(pred_cell_types) != set(cell_types):
            print("  WARNING: Cell types mismatch")
            print(f"    Predicted: {pred_cell_types}")
            print(f"    Ground truth: {cell_types}")

        print(f"  Shape - GT: {gt_df.shape}, Pred: {pred_df.shape}")

        # Rows: keep cell types present on both sides, in ground-truth order.
        common_ct = [ct for ct in gt_df.index if ct in pred_df.index]

        if len(common_ct) < len(gt_df.index):
            print(f"  Note: Aligning {len(common_ct)}/{len(gt_df.index)} cell types")

        if not common_ct:
            print("  ERROR: No common cell types found!")
            continue

        # Columns: when both sides carry the same (unique) sample labels, reorder
        # the prediction to the ground-truth order. Otherwise samples are matched
        # by position, which requires equal sample counts.
        if (gt_df.columns.is_unique and pred_df.columns.is_unique
                and set(gt_df.columns) == set(pred_df.columns)):
            pred_df = pred_df.reindex(columns=gt_df.columns)
        elif pred_df.shape[1] != gt_df.shape[1]:
            print(f"  ERROR: Sample count mismatch (GT: {gt_df.shape[1]}, Pred: {pred_df.shape[1]})")
            continue

        # Align both matrices
        A_real_aligned = gt_df.loc[common_ct].values
        A_pred_aligned = pred_df.loc[common_ct].values

        # Renormalize predictions to sum to 1 after alignment; all-zero columns
        # are left as zeros.
        # NOTE(review): for partial ground truth this rescales the retained cell
        # types to sum to 1, which changes the quantity compared with the GT
        # rows. Confirm this matches the challenge scoring.
        col_sums = A_pred_aligned.sum(axis=0, keepdims=True)
        col_sums[col_sums == 0] = 1
        A_pred_aligned = A_pred_aligned / col_sums

        print(f"  Aligned shapes - GT: {A_real_aligned.shape}, Pred: {A_pred_aligned.shape}")

        # Detect if this is invivo (partial ground truth)
        is_partial_gt = 'invivo' in name.lower() or set(common_ct) == {'basal', 'classic'}

        if is_partial_gt:
            print("  Note: Using partial ground truth scoring (row correlations only)")

        # Compute scores
        scores = scoring_function(A_real_aligned, A_pred_aligned, partial_gt=is_partial_gt)
        all_scores[name] = scores

        print(f"\n  Metrics:")
        if not is_partial_gt:
            print(f"    RMSE:          {scores['rmse']:.4f}")
            print(f"    MAE:           {scores['mae']:.4f}")
            print(f"    Aitchison:     {scores['aitchison']:.4f}")
            print(f"    Pearson tot:   {scores['pearson_tot']:.4f}")
            print(f"    Pearson col:   {scores['pearson_col']:.4f}")
        print(f"    Pearson row:   {scores['pearson_row']:.4f}")
        if not is_partial_gt:
            print(f"    Spearman tot:  {scores['spearman_tot']:.4f}")
            print(f"    Spearman col:  {scores['spearman_col']:.4f}")
        print(f"    Spearman row:  {scores['spearman_row']:.4f}")
        print(f"\n    AGGREGATE SCORE: {scores['score_aggreg']:.4f}")

    # Summary
    if all_scores:
        print("\n" + "=" * 60)
        print("SUMMARY")
        print("=" * 60)

        agg_scores = [s['score_aggreg'] for s in all_scores.values() if not np.isnan(s['score_aggreg'])]

        print(f"\nDataset scores:")
        for name, scores in all_scores.items():
            print(f"  {name}: {scores['score_aggreg']:.4f}")

        if agg_scores:
            print(f"\nMedian performance: {np.median(agg_scores):.4f}")
            print(f"Mean performance:   {np.mean(agg_scores):.4f}")


if __name__ == "__main__":
    main()
