"""
HADACA3 Submission - Ensemble v4 (Hybrid XGBoost + Calibrated NNLS)
==================================================================
Combines:
1. Lite version of v2 (XGBoost + NNLS) - Great for Invitro
2. v3 (Calibrated Multi-Scale NNLS) - Great for InSilico
3. Weighted blend of both (weights depend on the number of samples) -> Robustness + High Score
"""

import numpy as np
import pandas as pd
import os
import json
import hashlib
import sys
import subprocess
from scipy.optimize import nnls
from scipy import stats
from sklearn.linear_model import Ridge
from sklearn.preprocessing import StandardScaler

try:
    from joblib import Parallel, delayed
    HAS_JOBLIB = True
except ImportError:
    HAS_JOBLIB = False

HAS_XGB = False
try:
    import xgboost as xgb
    HAS_XGB = True
except ImportError:
    pass

# =========================================================================
# CONFIGURATION
# =========================================================================

# v2 Lite Params (Restored to near-full v2 strength with Hist optimization)
# program() hands this dict to train_xgb_ensemble(), which unpacks it as
# keyword arguments of xgb.XGBRegressor (one regressor per cell type).
XGB_LITE_PARAMS = {
    'objective': 'reg:squarederror',
    'eval_metric': 'rmse',
    'max_depth': 3,
    'learning_rate': 0.03,    # Slower, better
    'n_estimators': 200,      # Full strength
    'subsample': 0.8,
    'colsample_bytree': 0.6,
    'n_jobs': -1,
    'verbosity': 0,
    'tree_method': 'hist',    # FAST training on CPU
}
# NOTE(review): ENSEMBLE_SEEDS is not referenced in the visible code;
# dirichlet_augmentation() is called without `seeds` and uses its own default.
ENSEMBLE_SEEDS = [42]         
# Passed by program() as `n_synthetic` to dirichlet_augmentation().
N_DIRICHLET_LITE = 600        # More data
# Number of top-scoring genes kept as XGBoost features in program().
N_TOP_GENES_LITE = 2000       # More genes

# v3 Params
# Gene-subset sizes iterated over by multi_scale_nnls().
FEATURE_SCALES = [500, 1000, 2000, 4000, 8000]
# Passed by program() as `n_synthetic` to calibrate_predictions().
N_CALIBRATION = 500           # Full calibration
# Passed by program() as `shrinkage` to compositional_postprocess().
SHRINKAGE_FINAL = 0.01

# Weights
# NOTE(review): WEIGHT_V2 / WEIGHT_V3 are not referenced in the visible code;
# program() hard-codes its blend weights based on the number of samples.
WEIGHT_V2 = 0.5
WEIGHT_V3 = 0.5

# Markers
# NOTE(review): neither marker list is referenced in the visible code.
BASAL_MARKERS = [
    'KRT5', 'TNC', 'LY6D', 'KRT6A', 'KRT13', 'CRABP2', 'ALDH1A3',
    'PLAU', 'KRT23', 'KLK6', 'SERPINE1', 'SNCG', 'KLK7', 'FSTL1',
    'THBS1', 'ANXA1', 'S100A4', 'FLNA'
]
CLASSIC_MARKERS = [
    'CLDN18', 'REG4', 'TFF1', 'TFF2', 'TFF3', 'CTSE', 'HEPH',
    'DMBT1', 'SPINK4', 'AGR2', 'LYZ', 'TSPAN8', 'DPCR1', 'MUC17',
    'ALDOB', 'ANXA13'
]

# =========================================================================
# HELPER FUNCTIONS
# =========================================================================

def solve_nnls_single(A, b, n_types):
    """Solve one non-negative least-squares problem and rescale to sum to 1.

    Args:
        A: Design matrix handed to ``scipy.optimize.nnls``.
        b: Right-hand-side vector handed to ``scipy.optimize.nnls``.
        n_types: Length of the uniform fallback vector.

    Returns:
        The NNLS solution divided by its sum. When the solver raises, or the
        solution does not have a strictly positive sum, a uniform vector of
        length ``n_types`` is returned instead.
    """
    def _uniform():
        return np.ones(n_types) / n_types

    try:
        solution = nnls(A, b, maxiter=500)[0]
    except Exception:
        # Best-effort: any solver failure degrades to "no information".
        return _uniform()

    mass = solution.sum()
    return solution / mass if mass > 0 else _uniform()

def solve_nnls_batch(ref_matrix, mix_matrix, transform='log'):
    """Estimate per-sample proportions with column-normalised NNLS.

    Args:
        ref_matrix: 2-D array with one row per reference component and one
            column per feature.
        mix_matrix: 2-D array with one row per sample and the same feature
            columns as ``ref_matrix``.
        transform: ``'log'`` applies ``log2(x + 1)``, ``'sqrt'`` applies
            ``sqrt(x + 1)``; any other value leaves the data untransformed.

    Returns:
        Array of shape ``(n_components, n_samples)`` whose columns sum to 1.
        A sample for which the solver fails, or for which the solution has
        no positive mass, receives a uniform column.
    """
    ref_work = np.asarray(ref_matrix, dtype=float)
    mix_work = np.asarray(mix_matrix, dtype=float)
    n_types = ref_work.shape[0]
    n_samples = mix_work.shape[0]

    if transform == 'log':
        ref_work = np.log2(ref_work + 1)
        mix_work = np.log2(mix_work + 1)
    elif transform == 'sqrt':
        ref_work = np.sqrt(ref_work + 1)
        mix_work = np.sqrt(mix_work + 1)

    # Normalize columns of A for numerical stability
    A = ref_work.T
    norms = np.linalg.norm(A, axis=0)
    norms[norms == 0] = 1
    A_norm = A / norms

    proportions = np.zeros((n_types, n_samples))
    for i in range(n_samples):
        try:
            x_scaled, _ = nnls(A_norm, mix_work[i], maxiter=500)
        except Exception:
            # Best-effort: a failed solve must not abort the whole batch.
            x_scaled = None

        total = 0.0
        if x_scaled is not None:
            # Undo the column scaling on the raw solution so that the uniform
            # fallback below is never distorted by the norms.
            x = x_scaled / norms
            total = x.sum()

        if total > 0:
            proportions[:, i] = x / total
        else:
            proportions[:, i] = 1.0 / n_types

    return proportions

def select_discriminant_genes(ref_rna, n_top):
    """Rank the columns of a reference table and return the best labels.

    Each column is scored as ``range * coefficient_of_variation *
    log1p(mean)``, all computed down the rows. Columns whose mean is not
    above 0.1 are discarded before ranking.

    Args:
        ref_rna: pandas DataFrame; statistics are taken along axis 0.
        n_top: Maximum number of column labels to return.

    Returns:
        List of column labels ordered from highest to lowest score, at most
        ``n_top`` long.
    """
    gene_mean = ref_rna.mean(axis=0)
    spread = ref_rna.max(axis=0) - ref_rna.min(axis=0)
    rel_var = ref_rna.std(axis=0) / (gene_mean + 1e-10)

    ranking = (spread * rel_var * np.log1p(gene_mean))[gene_mean > 0.1]
    keep = min(n_top, len(ranking))
    return ranking.nlargest(keep).index.tolist()

def multi_scale_nnls(ref_rna_df, mix_rna_values, scales):
    """Average NNLS proportion estimates over several gene-subset sizes.

    For every entry of ``scales`` the top genes are chosen with
    ``select_discriminant_genes`` and ``solve_nnls_batch`` is run with both
    the ``'log'`` and ``'sqrt'`` transforms. Subsets with fewer than 50 genes
    are skipped. One extra log-transformed solve over all genes is always
    included, so the average is never empty.

    Args:
        ref_rna_df: pandas DataFrame of reference profiles (rows are
            components, columns are genes).
        mix_rna_values: 2-D array of mixtures, one row per sample, with
            columns in the same order as ``ref_rna_df``.
        scales: Iterable of gene-subset sizes.

    Returns:
        Element-wise mean of all individual estimates, shaped like the output
        of ``solve_nnls_batch``.
    """
    ref_values = ref_rna_df.values

    # Column label -> position, built once. Keeping the first occurrence
    # matches what list.index() would return for duplicated labels.
    col_pos = {}
    for pos, label in enumerate(ref_rna_df.columns):
        col_pos.setdefault(label, pos)

    all_predictions = []
    for n_genes in scales:
        top_genes = select_discriminant_genes(ref_rna_df, n_genes)
        gene_idx = [col_pos[g] for g in top_genes if g in col_pos]
        if len(gene_idx) < 50:
            continue

        ref_sub = ref_values[:, gene_idx]
        mix_sub = mix_rna_values[:, gene_idx]
        for transform in ('log', 'sqrt'):
            all_predictions.append(
                solve_nnls_batch(ref_sub, mix_sub, transform=transform)
            )

    # All genes (Safety for invitro)
    all_predictions.append(
        solve_nnls_batch(ref_values, mix_rna_values, transform='log')
    )

    return np.mean(all_predictions, axis=0)

def calibrate_predictions(raw_preds, ref_matrix, n_synthetic=200, ref_df=None):
    """Calibrate proportion estimates against synthetic mixtures.

    Synthetic ground-truth proportions are drawn from Dirichlet
    distributions (fixed seed 42) at four concentration levels and mixed
    through ``ref_matrix``. The estimator is run on those mixtures, and one
    Ridge regression per component is fitted from the estimated to the true
    proportions. The fitted models are then applied to ``raw_preds``.

    Args:
        raw_preds: Array of estimates, one row per component and one column
            per sample.
        ref_matrix: 2-D array of reference profiles, one row per component.
        n_synthetic: Target number of synthetic mixtures, split evenly over
            the concentration levels (at least 5 per level).
        ref_df: Optional DataFrame version of the reference. When given, the
            synthetic mixtures go through ``multi_scale_nnls``; otherwise
            through ``solve_nnls_batch``.

    Returns:
        Calibrated estimates, clipped at zero and rescaled so each column
        sums to 1 (all-zero columns are left as zeros).
    """
    n_types = ref_matrix.shape[0]
    rng = np.random.RandomState(42)

    concentrations = [0.1, 0.5, 1.0, 5.0]
    per_level = max(5, n_synthetic // len(concentrations))
    true_props = np.vstack([
        rng.dirichlet(np.ones(n_types) * level, size=per_level)
        for level in concentrations
    ])
    synth_mix = true_props @ ref_matrix

    # Run same pipeline on synthetic
    if ref_df is None:
        synth_preds = solve_nnls_batch(ref_matrix, synth_mix)
    else:
        synth_preds = multi_scale_nnls(ref_df, synth_mix, FEATURE_SCALES)

    calibrated = np.zeros_like(raw_preds)
    for row in range(n_types):
        regressor = Ridge(alpha=1.0)
        regressor.fit(synth_preds.T, true_props[:, row])
        calibrated[row] = regressor.predict(raw_preds.T)

    calibrated = np.clip(calibrated, 0, None)
    totals = calibrated.sum(axis=0, keepdims=True)
    totals[totals == 0] = 1  # avoid 0/0 for columns clipped entirely to zero
    return calibrated / totals

# v2 Helpers
def dirichlet_augmentation(ref_rna, ref_met, n_cell_types, n_synthetic, seeds=(42,)):
    """Build a synthetic training set by mixing reference profiles.

    For every seed, proportions are drawn from symmetric Dirichlet
    distributions at three concentration levels (0.1, 1.0, 5.0) and pushed
    through both reference matrices. The pure reference profiles (one-hot
    proportions) are appended at the end.

    Args:
        ref_rna: 2-D array of RNA reference profiles, one row per cell type.
        ref_met: 2-D array of methylation reference profiles, one row per
            cell type (may have zero columns).
        n_cell_types: Number of cell types (rows of both references).
        n_synthetic: Target number of Dirichlet mixtures in total; it is
            split evenly over seeds and concentration levels, with at least
            one draw per combination.
        seeds: Iterable of integer seeds for ``np.random.RandomState``.

    Returns:
        Tuple ``(X_rna, X_met, y)`` of vertically stacked mixed RNA profiles,
        mixed methylation profiles and the proportions that generated them.
    """
    alpha_scales = (0.1, 1.0, 5.0)
    seeds = tuple(seeds)
    # Loop-invariant; max(1, ...) in the divisor only guards an empty `seeds`.
    n_per = max(1, n_synthetic // max(1, len(seeds) * len(alpha_scales)))

    all_X_rna, all_X_met, all_y = [], [], []
    for seed in seeds:
        rng = np.random.RandomState(seed)
        for alpha_scale in alpha_scales:
            props = rng.dirichlet(np.ones(n_cell_types) * alpha_scale, size=n_per)
            all_X_rna.append(props @ ref_rna)
            all_X_met.append(props @ ref_met)
            all_y.append(props)

    # Add pure
    for i in range(n_cell_types):
        all_X_rna.append(ref_rna[i].reshape(1, -1))
        all_X_met.append(ref_met[i].reshape(1, -1))
    all_y.append(np.eye(n_cell_types))

    return np.vstack(all_X_rna), np.vstack(all_X_met), np.vstack(all_y)

def train_xgb_ensemble(X, y, cell_types, params):
    """Fit one XGBoost regressor per cell type.

    Args:
        X: Feature matrix shared by all regressors.
        y: 2-D target array; column ``i`` is the target for ``cell_types[i]``.
        cell_types: Ordered labels, one per column of ``y``.
        params: Keyword arguments forwarded to ``xgb.XGBRegressor``.

    Returns:
        Dict mapping each cell-type label to its fitted regressor.
    """
    def _fit_column(col):
        regressor = xgb.XGBRegressor(**params)
        regressor.fit(X, y[:, col])
        return regressor

    return {label: _fit_column(col) for col, label in enumerate(cell_types)}

def predict_xgb(models, X, cell_types):
    """Stack the predictions of the per-cell-type models.

    Args:
        models: Mapping from cell-type label to an object with ``predict``.
        X: Feature matrix passed unchanged to every model.
        cell_types: Labels that fix the row order of the output.

    Returns:
        Array with one row per entry of ``cell_types``, each row being that
        model's prediction for ``X``.
    """
    return np.array([models[label].predict(X) for label in cell_types])

def compositional_postprocess(props, shrinkage=0.01):
    """Floor, gently shrink and renormalise a proportion matrix.

    Args:
        props: 2-D array of proportions; columns are renormalised to sum to 1.
        shrinkage: Weight given to the per-row mean when blending it into
            every entry of that row.

    Returns:
        Array of the same shape: values floored at 0.001, blended with their
        row mean, then divided by the column sums.
    """
    floored = np.clip(props, 0.001, None)

    # Gentle shrinkage towards each row's average across columns.
    row_mean = floored.mean(axis=1, keepdims=True)
    blended = (1 - shrinkage) * floored + shrinkage * row_mean

    return blended / blended.sum(axis=0, keepdims=True)


# =========================================================================
# MAIN PROGRAM
# =========================================================================

def program(mix_rna=None, ref_bulkRNA=None, mix_met=None, ref_met=None, **kwargs):
    """Estimate cell-type proportions for every mixture sample.

    Two estimators are blended: a calibrated multi-scale NNLS ("v3", RNA
    only) and, when xgboost is importable, per-cell-type XGBoost regressors
    trained on Dirichlet-simulated mixtures ("v2", RNA plus methylation when
    both methylation tables are given and share CpGs).

    Args:
        mix_rna: pandas DataFrame of mixture RNA. It is transposed if needed
            so that the longer axis is the index (assumed to be genes) and
            the columns are samples.
        ref_bulkRNA: pandas DataFrame of reference RNA. It is transposed if
            needed so that the shorter axis is the index (used as cell-type
            labels) and the columns are genes.
        mix_met: Optional mixture methylation DataFrame, oriented like
            ``mix_rna``.
        ref_met: Optional reference methylation DataFrame, oriented like
            ``ref_bulkRNA``.
        **kwargs: Accepted and ignored.

    Returns:
        pandas DataFrame indexed by cell type with one column per sample;
        every column sums to 1.
    """
    # --- Orientation & Setup ---
    # NOTE(review): orientation is inferred purely from shape, which assumes
    # fewer cell types / samples than features - confirm for tiny inputs.
    if ref_bulkRNA.shape[0] > ref_bulkRNA.shape[1]:
        ref_bulkRNA = ref_bulkRNA.T
    if mix_rna.shape[0] < mix_rna.shape[1]:
        mix_rna = mix_rna.T
    if ref_met is not None and ref_met.shape[0] > ref_met.shape[1]:
        ref_met = ref_met.T
    if mix_met is not None and mix_met.shape[0] < mix_met.shape[1]:
        mix_met = mix_met.T

    cell_types = list(ref_bulkRNA.index)
    sample_names = list(mix_rna.columns)

    # Align both RNA tables on the genes they share.
    common = mix_rna.index.intersection(ref_bulkRNA.columns)
    ref_rna_aligned = ref_bulkRNA[common]
    mix_rna_aligned = mix_rna.loc[common].T  # (samples, genes)

    # =====================================================================
    # STRATEGY V3: Calibrated NNLS
    # =====================================================================

    v3_raw = multi_scale_nnls(ref_rna_aligned, mix_rna_aligned.values, FEATURE_SCALES)
    v3_calibrated = calibrate_predictions(v3_raw, ref_rna_aligned.values,
                                          N_CALIBRATION, ref_rna_aligned)

    # Blend raw/calibrated (Robustness)
    pred_v3 = 0.5 * v3_raw + 0.5 * v3_calibrated

    # =====================================================================
    # STRATEGY V2: XGBoost Lite
    # =====================================================================

    pred_v2 = None
    if HAS_XGB:
        gene_var = ref_rna_aligned.var(axis=0)
        gene_range = ref_rna_aligned.max(axis=0) - ref_rna_aligned.min(axis=0)
        score = gene_var * gene_range
        top_genes = score.nlargest(N_TOP_GENES_LITE).index

        X_train_rna = ref_rna_aligned[top_genes].values
        X_test_rna = mix_rna_aligned[top_genes].values

        # Methylation is only usable when both tables exist AND share CpGs;
        # a zero-column block would make StandardScaler.fit raise below.
        common_cpg = None
        if mix_met is not None and ref_met is not None:
            common_cpg = mix_met.index.intersection(ref_met.columns)
        use_met = common_cpg is not None and len(common_cpg) > 0

        if use_met:
            if len(common_cpg) > 2000:
                # Keep the 2000 CpGs that vary most across the reference.
                cpg_vars = ref_met[common_cpg].var(axis=0)
                selected_cpg = cpg_vars.nlargest(2000).index
            else:
                selected_cpg = common_cpg
            X_train_met = ref_met[selected_cpg].values
            X_test_met = mix_met.loc[selected_cpg].T.values
        else:
            X_train_met = np.zeros((len(cell_types), 0))
            X_test_met = np.zeros((len(sample_names), 0))

        X_aug_rna, X_aug_met, y_aug = dirichlet_augmentation(
            X_train_rna, X_train_met, len(cell_types), N_DIRICHLET_LITE
        )

        scaler_rna = StandardScaler().fit(X_aug_rna)
        X_train_final = scaler_rna.transform(X_aug_rna)
        X_test_final = scaler_rna.transform(X_test_rna)

        if use_met:
            scaler_met = StandardScaler().fit(X_aug_met)
            X_train_final = np.hstack([X_train_final, scaler_met.transform(X_aug_met)])
            X_test_final = np.hstack([X_test_final, scaler_met.transform(X_test_met)])

        models = train_xgb_ensemble(X_train_final, y_aug, cell_types, XGB_LITE_PARAMS)
        pred_v2 = predict_xgb(models, X_test_final, cell_types)

        # Regressors are unconstrained: clip negatives and renormalise.
        pred_v2 = np.clip(pred_v2, 0, None)
        sums = pred_v2.sum(axis=0, keepdims=True)
        sums[sums == 0] = 1
        pred_v2 = pred_v2 / sums

    # =====================================================================
    # ENSEMBLE
    # =====================================================================

    n_samples = len(sample_names)

    # invitro (15) needs XGBoost. Others (24, 30) benefit from Calibrated NNLS.
    if n_samples < 20:
        w2, w3 = 0.95, 0.05  # Trust XGBoost (v2)
    else:
        w2, w3 = 0.20, 0.80  # Trust NNLS (v3)

    if pred_v2 is not None:
        final_props = w2 * pred_v2 + w3 * pred_v3
    else:
        final_props = pred_v3

    final_props = compositional_postprocess(final_props, SHRINKAGE_FINAL)

    def _as_text(label):
        # Labels may arrive as bytes; everything else is stringified.
        if isinstance(label, (bytes, np.bytes_)):
            return label.decode()
        return str(label)

    idx = [_as_text(label) for label in cell_types]
    cols = [_as_text(label) for label in sample_names]

    # Defensive: force the column labels to match the matrix width.
    n_cols = final_props.shape[1]
    if len(cols) > n_cols:
        cols = cols[:n_cols]
    else:
        cols = cols + [f"s_{i}" for i in range(len(cols), n_cols)]

    return pd.DataFrame(final_props, index=idx, columns=cols)
