"""
HADACA3 Submission - Adaptive Hybrid Deconvolution
====================================================
Phase 1: Hybrid NNLS + XGBoost ensemble with 5 cell types
    - NNLS for compositional baseline (natural sum-to-1)
    - XGBoost ensemble for correction/refinement
    - Dirichlet augmentation for diverse training proportions
    - Feature selection (top variable genes)
    - Parallelized via joblib

Detection (fast cascade):
    1. Cache check: serialized decisions
    2. Marker variance + anti-correlation check

Phase 2 (conditional): Redistribute basal/classic using classic markers
"""


def program(mix_rna=None, ref_bulkRNA=None, mix_met=None, ref_met=None, **kwargs):
    """
    Adaptive deconvolution for HADACA3 challenge.
    Hybrid NNLS + XGBoost with conditional basal/classic correction.

    Args:
        mix_rna: RNA expression of the mixtures to deconvolve (pandas
            DataFrame, genes x samples). A table with fewer rows than
            columns and more than 20 columns is transposed.
        ref_bulkRNA: RNA reference profiles (cell types x genes). A table
            with more rows than columns and more than 20 rows is transposed.
        mix_met: Optional methylation of the mixtures; same orientation
            handling as ``mix_rna``.
        ref_met: Optional methylation reference; same orientation handling
            as ``ref_bulkRNA``.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        pandas.DataFrame of estimated proportions, indexed by cell type with
        one column per mixture sample.
    """

    # =========================================================================
    # IMPORTS
    # =========================================================================
    # Imports live inside the function so the submission is self-contained
    # when only ``program`` is extracted and executed.
    import subprocess
    import sys
    import numpy as np
    import pandas as pd
    import json
    import hashlib
    import os
    from sklearn.preprocessing import StandardScaler
    from sklearn.linear_model import Ridge
    from scipy.optimize import nnls
    from scipy import stats

    # joblib is optional: without it, training/prediction run sequentially.
    try:
        from joblib import Parallel, delayed
        HAS_JOBLIB = True
    except ImportError:
        HAS_JOBLIB = False

    # xgboost is optional: try a plain import first, then a best-effort
    # ``pip install``; if both fail the Ridge fallback is used instead.
    HAS_XGB = False
    try:
        import xgboost as xgb
        HAS_XGB = True
    except ImportError:
        try:
            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", "xgboost"],
                stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
            )
            import xgboost as xgb
            HAS_XGB = True
        except Exception:
            # Deliberately broad: any install/import failure (no network,
            # no pip, broken wheel) must not abort the deconvolution.
            pass

    # =========================================================================
    # CONFIGURATION
    # =========================================================================

    XGB_PARAMS = {
        'objective': 'reg:squarederror',
        'eval_metric': 'rmse',
        'max_depth': 4,
        'learning_rate': 0.05,
        'n_estimators': 150,
        'subsample': 0.8,
        'colsample_bytree': 0.7,
        'colsample_bylevel': 0.7,
        'reg_alpha': 0.5,
        'reg_lambda': 3.0,
        'min_child_weight': 3,
        'gamma': 0.05,
        'n_jobs': -1,
        'verbosity': 0,
    }

    ENSEMBLE_SEEDS = [42, 123, 456]
    N_DIRICHLET = 200        # reduced for speed
    N_TOP_GENES = 2000       # number of reference genes kept as ML features
    NNLS_WEIGHT = 0.5        # weight of NNLS prediction in hybrid
    XGB_WEIGHT = 0.5         # weight of XGBoost prediction in hybrid
    MIN_PROPORTION = 0.005   # pseudo-count for compositional smoothing

    # Discriminant markers
    BASAL_MARKERS = [
        'KRT5', 'TNC', 'LY6D', 'KRT6A', 'KRT13', 'CRABP2', 'ALDH1A3',
        'PLAU', 'KRT23', 'KLK6', 'SERPINE1', 'SNCG', 'KLK7', 'FSTL1',
        'THBS1', 'ANXA1', 'S100A4', 'FLNA'
    ]
    CLASSIC_MARKERS = [
        'CLDN18', 'REG4', 'TFF1', 'TFF2', 'TFF3', 'CTSE', 'HEPH',
        'DMBT1', 'SPINK4', 'AGR2', 'LYZ', 'TSPAN8', 'DPCR1', 'MUC17',
        'ALDOB', 'ANXA13'
    ]

    # Detection thresholds
    MARKER_VARIANCE_THRESHOLD = 0.1   # minimum coefficient of variation of a marker score
    ANTICORR_THRESHOLD = -0.5         # basal/classic correlation must be below this
    N_SYNTH_PHASE2 = 200              # synthetic basal/classic blends used for calibration

    # The decision cache sits next to this file; ``__file__`` is undefined
    # when the code is exec'd from a string, hence the cwd fallback.
    try:
        _dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        _dir = os.getcwd()
    CACHE_FILE = os.path.join(_dir, 'dataset_analysis_cache.json')

    # =========================================================================
    # HELPER FUNCTIONS
    # =========================================================================

    def select_top_variable_genes(ref_rna, n_top):
        """
        Rank genes by how strongly they separate the reference cell types.

        The score of a gene is its variance across reference rows multiplied
        by its range (max - min) across the same rows; the ``n_top`` highest
        scoring column labels are returned, best first.
        """
        spread = ref_rna.max(axis=0) - ref_rna.min(axis=0)
        score = ref_rna.var(axis=0) * spread
        ranked = score.nlargest(n_top)
        return ranked.index.tolist()

    def dirichlet_augmentation(ref_rna, ref_met, n_cell_types, n_synthetic,
                               seeds=None):
        """
        Build a synthetic training set of mixtures with known proportions.

        Three families of samples are produced, in this order:
          1. For every seed and every Dirichlet concentration level,
             ``max(1, n_synthetic // (len(seeds) * 6))`` random mixtures.
          2. One pure sample per cell type (the reference rows themselves).
          3. Five "dominant type" mixtures per cell type, drawn with a fixed
             seed (999), where one type is forced to 0.70-0.95 before
             renormalisation.

        Returns a tuple ``(X_rna, X_met, y)`` of arrays: mixed RNA profiles,
        mixed methylation profiles, and the proportions used to mix them.
        """
        seed_list = [42] if seeds is None else seeds
        concentration_levels = (0.1, 0.3, 0.5, 1.0, 2.0, 5.0)

        rna_rows, met_rows, targets = [], [], []

        def _record(weights):
            # A mixture is the proportion-weighted sum of the reference rows.
            rna_rows.append(weights @ ref_rna)
            met_rows.append(weights @ ref_met)
            targets.append(weights)

        # 1) Dirichlet mixtures over several concentration levels
        for seed in seed_list:
            sampler = np.random.RandomState(seed)
            per_level = max(1, n_synthetic // (len(seed_list) * 6))
            for level in concentration_levels:
                concentration = np.full(n_cell_types, level)
                for weights in sampler.dirichlet(concentration, size=per_level):
                    _record(weights)

        # 2) Pure cell-type references
        identity = np.eye(n_cell_types)
        for k in range(n_cell_types):
            rna_rows.append(ref_rna[k])
            met_rows.append(ref_met[k])
            targets.append(identity[k])

        # 3) Extreme mixtures with one dominant type
        sampler = np.random.RandomState(999)
        sparse_concentration = np.full(n_cell_types, 0.05)
        for k in range(n_cell_types):
            for _ in range(5):
                weights = sampler.dirichlet(sparse_concentration)
                weights[k] = 0.7 + sampler.uniform(0, 0.25)
                _record(weights / weights.sum())

        return np.array(rna_rows), np.array(met_rows), np.array(targets)

    def solve_nnls(ref_matrix, mix_matrix, use_log=True):
        """
        Deconvolve every mixture with non-negative least squares.

        ref_matrix: (n_cell_types, n_genes) reference profiles.
        mix_matrix: (n_samples, n_genes) mixtures.
        use_log:    when True both matrices are mapped through log2(x + 1)
                    first, which damps the influence of high-expression genes.

        Returns an (n_cell_types, n_samples) array whose columns sum to 1;
        a sample whose NNLS solution is all zeros gets uniform proportions.
        """
        signatures = ref_matrix.copy()
        mixtures = mix_matrix.copy()
        if use_log:
            signatures = np.log2(signatures + 1)
            mixtures = np.log2(mixtures + 1)

        # NNLS solves design @ x ~= observed with design of shape (genes, cell types).
        design = signatures.T
        n_types = ref_matrix.shape[0]
        estimates = np.zeros((n_types, mix_matrix.shape[0]))

        for sample_pos, observed in enumerate(mixtures):
            weights = nnls(design, observed)[0]
            mass = weights.sum()
            if mass > 0:
                estimates[:, sample_pos] = weights / mass
            else:
                estimates[:, sample_pos] = 1.0 / n_types

        return estimates

    def solve_nnls_multi(ref_matrix, mix_matrix):
        """
        Average the NNLS estimates obtained on raw and on log2(x + 1) data.

        Combining both preprocessing strategies with equal weight gives a
        more robust estimate than either one alone.
        """
        raw_estimate, log_estimate = (
            solve_nnls(ref_matrix, mix_matrix, use_log=log_flag)
            for log_flag in (False, True)
        )
        return 0.5 * raw_estimate + 0.5 * log_estimate

    def prepare_features(X_rna, X_met, rna_scaler, met_scaler, fit=False):
        """
        Scale the RNA and methylation blocks and join them column-wise.

        With ``fit=True`` each scaler is fitted on its block before
        transforming; otherwise the scalers are only applied. If the RNA
        block has no columns the methylation block is returned untouched;
        if the methylation block has no columns only the scaled RNA block
        is returned.
        """
        if X_rna.shape[1] == 0:
            return X_met

        def _scaled(scaler, block):
            return scaler.fit_transform(block) if fit else scaler.transform(block)

        rna_part = _scaled(rna_scaler, X_rna)
        if X_met.shape[1] == 0:
            return rna_part

        met_part = _scaled(met_scaler, X_met)
        return np.hstack([rna_part, met_part])

    def _train_single_model(X, y_col, params):
        """Fit one XGBoost regressor on ``X`` for a single target column."""
        regressor = xgb.XGBRegressor(**params)
        regressor.fit(X, y_col)
        return regressor

    def train_ensemble(X, y, cell_types, seeds, xgb_params):
        """
        Fit one XGBoost regressor per (seed, cell type) pair.

        Column ``i`` of ``y`` is the target for ``cell_types[i]``. Training
        runs through joblib threads when joblib is available and
        sequentially otherwise. Returns ``{cell_type: [model, ...]}`` with
        one model per seed, in seed order.
        """
        # One task per seed and cell type; each task owns its parameter dict.
        tasks = [
            (ct, {**xgb_params, 'random_state': seed}, y[:, col])
            for seed in seeds
            for col, ct in enumerate(cell_types)
        ]

        if HAS_JOBLIB:
            fitted = Parallel(n_jobs=-1, prefer='threads')(
                delayed(_train_single_model)(X, target, config)
                for (_, config, target) in tasks
            )
        else:
            fitted = [_train_single_model(X, target, config)
                      for (_, config, target) in tasks]

        models = {ct: [] for ct in cell_types}
        for (ct, _, _), model in zip(tasks, fitted):
            models[ct].append(model)
        return models

    def predict_ensemble(models, X, cell_types):
        """
        Predict with every ensemble member and average per cell type.

        ``models`` maps each cell type to its list of fitted models. The
        result maps each cell type to the element-wise mean of its members'
        predictions on ``X``.
        """
        members = [(ct, member) for ct in cell_types for member in models[ct]]

        def _predict(member, features):
            return member.predict(features)

        if HAS_JOBLIB:
            outputs = Parallel(n_jobs=-1, prefer='threads')(
                delayed(_predict)(member, X) for (_, member) in members
            )
        else:
            outputs = [member.predict(X) for (_, member) in members]

        # Regroup the flat list of outputs by the cell type that produced it.
        grouped = {ct: [] for ct in cell_types}
        for (ct, _), output in zip(members, outputs):
            grouped[ct].append(output)

        for ct in cell_types:
            grouped[ct] = np.mean(grouped[ct], axis=0)
        return grouped

    def compute_dataset_fingerprint(mix_rna_df):
        """
        Compute a hash fingerprint for cache lookup.

        The fingerprint is the MD5 hex digest of the table's shape followed
        by its first (up to) 100 values in row-major order, printed with 4
        digits of precision. It is a cheap identity check, not a
        cryptographic one: tables sharing shape and leading values collide.
        """
        n_cols = mix_rna_df.shape[1]
        if n_cols:
            # Only the leading rows can contribute to the first 100 row-major
            # values, so materialise just those instead of copying the whole
            # matrix.
            rows_needed = -(-100 // n_cols)  # ceil(100 / n_cols)
            head = mix_rna_df.iloc[:rows_needed].values
        else:
            head = mix_rna_df.values
        sample_vals = head.ravel()[:100]

        payload = str(mix_rna_df.shape) + np.array2string(sample_vals, precision=4)
        return hashlib.md5(payload.encode()).hexdigest()

    def load_cache(cache_file):
        """
        Load the analysis cache from disk.

        Always returns a dict. A missing or unreadable file, invalid JSON,
        or JSON whose top level is not an object all yield an empty cache,
        so callers can index and assign into the result unconditionally.
        """
        try:
            # Open directly instead of checking existence first: avoids the
            # check-then-open race and covers the missing-file case as well.
            with open(cache_file, 'r') as f:
                cache = json.load(f)
        except (OSError, ValueError):
            # OSError: missing/unreadable file. ValueError: covers
            # json.JSONDecodeError and undecodable bytes.
            return {}
        return cache if isinstance(cache, dict) else {}

    def save_cache(cache_file, cache):
        """
        Save the analysis cache to disk (best effort, never raises).

        The cache is serialised before the file is opened, so a value that
        cannot be encoded as JSON leaves any existing cache file intact
        instead of truncating it. Write failures (read-only directory,
        missing folder, ...) are ignored because the cache is only an
        optimisation.
        """
        try:
            payload = json.dumps(cache, indent=2)
        except (TypeError, ValueError):
            # Non-serialisable content or circular reference: skip caching.
            return

        try:
            with open(cache_file, 'w') as f:
                f.write(payload)
        except OSError:
            pass

    def detect_phase2_needed(mix_rna_df, basal_markers, classic_markers,
                             var_threshold, anticorr_threshold):
        """
        Decide whether the basal/classic redistribution should run.

        Each sample gets a basal and a classic score (sum of the marker rows
        present in ``mix_rna_df``). Phase 2 is requested only when all of
        the following hold:
          * at least 5 samples,
          * at least 3 basal and 3 classic markers are present,
          * both scores have a coefficient of variation >= ``var_threshold``,
          * the Pearson correlation of the two scores is below
            ``anticorr_threshold`` with p-value < 0.05.

        Returns ``(decision, details)`` where ``details`` records the
        measurements, the ``reason`` for the decision and ``apply_phase2``.
        """
        gene_pool = set(mix_rna_df.index)
        sample_count = mix_rna_df.shape[1]
        report = {
            'n_samples': sample_count,
            'n_genes': len(gene_pool),
        }

        def _verdict(reason, decision):
            # Every exit path stamps the same two fields on the report.
            report['reason'] = reason
            report['apply_phase2'] = decision
            return decision, report

        def _marker_signal(markers):
            # Per-sample sum of the marker rows that exist in the mixture.
            signal = np.zeros(sample_count)
            hits = 0
            for gene in markers:
                if gene not in gene_pool:
                    continue
                signal += mix_rna_df.loc[gene].values.astype(float)
                hits += 1
            return signal, hits

        if sample_count < 5:
            return _verdict('too_few_samples', False)

        basal_signal, basal_hits = _marker_signal(basal_markers)
        classic_signal, classic_hits = _marker_signal(classic_markers)
        report['n_basal_markers'] = basal_hits
        report['n_classic_markers'] = classic_hits

        if basal_hits < 3 or classic_hits < 3:
            return _verdict('insufficient_markers', False)

        basal_sd = np.std(basal_signal)
        classic_sd = np.std(classic_signal)
        # Coefficient of variation; the epsilon guards an all-zero score.
        basal_cv = basal_sd / (np.mean(basal_signal) + 1e-10)
        classic_cv = classic_sd / (np.mean(classic_signal) + 1e-10)
        report['basal_cov'] = float(basal_cv)
        report['classic_cov'] = float(classic_cv)

        if min(basal_cv, classic_cv) < var_threshold:
            return _verdict('low_marker_variance', False)

        if basal_sd > 0 and classic_sd > 0:
            corr, pval = stats.pearsonr(basal_signal, classic_signal)
        else:
            corr, pval = 0.0, 1.0
        report['marker_anticorr'] = float(corr)
        report['marker_pval'] = float(pval)

        if corr < anticorr_threshold and pval < 0.05:
            return _verdict('strong_discriminant_signal', True)
        return _verdict('weak_discriminant_signal', False)

    def apply_phase2_correction(pred_array, cell_types, mix_rna_df,
                                ref_rna_aligned, basal_markers, classic_markers,
                                n_synth=200):
        """
        Redistribute basal/classic proportions using classic markers only.

        The tumoral total (basal + classic) predicted by phase 1 is kept for
        every sample; only its split between the two subtypes is replaced.
        The split comes from a ridge calibrator fitted on synthetic
        basal/classic blends of the reference profiles. It maps "classic
        marker signal relative to the pure classic reference" to the
        classic fraction, clipped to [0.01, 0.99].

        Args:
            pred_array: (n_cell_types, n_samples) phase-1 proportions.
            cell_types: row labels of ``pred_array`` and ``ref_rna_aligned``;
                must contain 'basal' and 'classic'.
            mix_rna_df: mixture expression, genes x samples.
            ref_rna_aligned: reference expression, cell types x genes.
            basal_markers: unused; kept for a symmetric call signature.
            classic_markers: gene names used to measure the classic signal.
            n_synth: number of synthetic blends used for calibration.

        Returns:
            A corrected copy of ``pred_array``, or ``pred_array`` itself
            when fewer than 3 classic markers are usable.
        """
        basal_idx = cell_types.index('basal')
        classic_idx = cell_types.index('classic')

        # First position of every reference gene: one pass instead of a
        # linear ``list.index`` scan per marker.
        column_pos = {}
        for pos, gene in enumerate(ref_rna_aligned.columns):
            column_pos.setdefault(gene, pos)

        # Calibration (reference) and scoring (mixture) must use the same
        # markers, otherwise the ratio fed to the calibrator is measured on
        # a different gene set than the one it was fitted on.
        mix_genes = set(mix_rna_df.index)
        shared_markers = [g for g in classic_markers
                          if g in column_pos and g in mix_genes]
        if len(shared_markers) < 3:
            return pred_array

        marker_cols = [column_pos[g] for g in shared_markers]
        ref_values = ref_rna_aligned.values
        basal_marker_total = ref_values[basal_idx][marker_cols].sum()
        classic_marker_total = ref_values[classic_idx][marker_cols].sum()
        denom = classic_marker_total + 1e-10

        # Local generator: same draws as seeding the global NumPy RNG with
        # 42, but without clobbering the caller's global random state.
        rng = np.random.RandomState(42)
        half = max(2, n_synth // 2)
        fracs_basal = np.concatenate([
            np.linspace(0.0, 1.0, half),          # even coverage of [0, 1]
            rng.beta(0.5, 0.5, size=half),        # extra mass near 0 and 1
        ])

        # Marker signal of a blend is linear in the blend fraction, so the
        # synthetic ratios can be computed without building each profile.
        synth_classic_ratios = (
            fracs_basal * basal_marker_total
            + (1.0 - fracs_basal) * classic_marker_total
        ) / denom

        calibrator = Ridge(alpha=1.0)
        calibrator.fit(synth_classic_ratios.reshape(-1, 1), 1.0 - fracs_basal)

        classic_score = np.zeros(pred_array.shape[1])
        for gene in shared_markers:
            classic_score += mix_rna_df.loc[gene].values.astype(float)

        classic_frac = calibrator.predict((classic_score / denom).reshape(-1, 1))
        classic_frac = np.clip(classic_frac, 0.01, 0.99)

        tumoral_total = pred_array[basal_idx] + pred_array[classic_idx]
        corrected = pred_array.copy()
        corrected[basal_idx] = tumoral_total * (1.0 - classic_frac)
        corrected[classic_idx] = tumoral_total * classic_frac
        return corrected

    def compositional_postprocess(props, delta=0.005, shrinkage=0.1):
        """
        Post-process proportions so they are strictly positive and sum to 1.

        Steps:
          1. Negative values are clipped to 0 and each column is normalised;
             an all-zero column becomes uniform.
          2. Multiplicative zero replacement: zeros are set to ``delta`` and
             the non-zero parts are rescaled so the column still sums to 1.
          3. Every value is floored at ``delta / 2`` and the column is
             renormalised.
          4. If ``shrinkage > 0``, each column is pulled toward the mean
             composition. The pull is stronger for low-entropy
             (concentrated) columns and is clipped to [0.02, 0.25].

        Args:
            props: (n_cell_types, n_samples) array-like; not modified.
            delta: replacement value for zero parts.
            shrinkage: base strength of the pull toward the mean composition.

        Returns:
            A new float array of the same shape.
        """
        # Work in float: with an integer input the fractional results would
        # be truncated to 0 when written back into the array.
        props = np.clip(np.asarray(props, dtype=float), 0, None)
        n_types, n_samples = props.shape
        floor = delta / 2

        for j in range(n_samples):
            col = props[:, j]
            total = col.sum()
            if total == 0:
                props[:, j] = 1.0 / n_types
                continue

            col = col / total

            # Multiplicative replacement for zeros
            zeros = col == 0
            n_zeros = zeros.sum()
            if 0 < n_zeros < n_types:
                col[zeros] = delta
                nonzero_sum = col[~zeros].sum()
                col[~zeros] = col[~zeros] * ((1.0 - n_zeros * delta) / nonzero_sum)

            # Ensure minimum value
            col = np.maximum(col, floor)
            props[:, j] = col / col.sum()

        if shrinkage > 0:
            # Mean composition is taken once, before any column is shrunk.
            mean_comp = props.mean(axis=1, keepdims=True)
            entropy = -np.sum(props * np.log(props + 1e-15), axis=0)
            max_entropy = np.log(n_types)
            if max_entropy > 0:
                rel_entropy = entropy / max_entropy
            else:
                # Single cell type: entropy is undefined (log(1) == 0); treat
                # the column as maximally spread instead of dividing by 0.
                rel_entropy = np.ones(n_samples)
            # Low entropy -> high shrinkage; high entropy -> low shrinkage
            strength = np.clip(shrinkage * (1.5 - rel_entropy), 0.02, 0.25)
            props = (1 - strength) * props + strength * mean_comp
            props = props / props.sum(axis=0, keepdims=True)

        return props

    # =========================================================================
    # PHASE 1: HYBRID NNLS + XGBOOST
    # =========================================================================

    # --- Robust Input Preprocessing ---
    # Ensure references are (n_cell_types, n_features)
    # If rows > columns and rows > 20, assume it's (n_features, n_cell_types) and transpose
    if ref_bulkRNA.shape[0] > ref_bulkRNA.shape[1] and ref_bulkRNA.shape[0] > 20:
        ref_bulkRNA = ref_bulkRNA.T
    
    if mix_rna.shape[0] < mix_rna.shape[1] and mix_rna.shape[1] > 20:
        # Assume it came as (n_samples, n_genes) -> transpose to (n_genes, n_samples)
        mix_rna = mix_rna.T

    # Handle optional Met inputs
    if ref_met is not None:
        if ref_met.shape[0] > ref_met.shape[1] and ref_met.shape[0] > 20:
            ref_met = ref_met.T

    if mix_met is not None:
        if mix_met.shape[0] < mix_met.shape[1] and mix_met.shape[1] > 20:
            mix_met = mix_met.T

    cell_types = list(ref_bulkRNA.index)
    sample_names = list(mix_rna.columns)
    n_samples = len(sample_names)
    n_cell_types = len(cell_types)

    # Align genes
    common_genes = mix_rna.index.intersection(ref_bulkRNA.columns)
    ref_rna_aligned = ref_bulkRNA[common_genes]
    mix_rna_aligned = mix_rna.loc[common_genes].T

    # Feature selection: top variable genes
    top_genes = select_top_variable_genes(ref_rna_aligned, N_TOP_GENES)
    # Keep common_genes for NNLS (uses all genes) but selected for XGBoost
    ref_rna_selected = ref_rna_aligned[top_genes]
    mix_rna_selected = mix_rna_aligned[top_genes]

    # --- NNLS Prediction (multi-strategy) ---
    nnls_props = solve_nnls_multi(ref_rna_aligned.values, mix_rna_aligned.values)

    # Handle methylation
    use_methylation = mix_met is not None and ref_met is not None

    if use_methylation:
        common_cpg = mix_met.index.intersection(ref_met.columns)
        ref_met_aligned = ref_met[common_cpg]
        mix_met_aligned = mix_met.loc[common_cpg].T

        # Select top variable CpGs
        n_select = min(len(top_genes), len(common_cpg))
        if len(common_cpg) > n_select:
            variances = ref_met_aligned.var(axis=0)
            selected_cpg = variances.nlargest(n_select).index.tolist()
            ref_met_aligned = ref_met_aligned[selected_cpg]
            mix_met_aligned = mix_met_aligned[selected_cpg]

    # --- XGBoost or Ridge Training with Dirichlet augmentation ---
    X_train_rna = ref_rna_selected.values

    if use_methylation:
        X_train_met = ref_met_aligned.values
    else:
        X_train_met = np.zeros((n_cell_types, 1))

    y_train = np.eye(n_cell_types)

    # Dirichlet augmentation
    X_aug_rna, X_aug_met, y_aug = dirichlet_augmentation(
        X_train_rna, X_train_met, n_cell_types, N_DIRICHLET,
        seeds=[42, 123, 456]
    )

    # Scale and prepare features
    rna_scaler = StandardScaler()
    met_scaler = StandardScaler()
    X_train = prepare_features(X_aug_rna, X_aug_met, rna_scaler, met_scaler,
                               fit=True)

    # Test features
    X_test_rna = mix_rna_selected.values
    if use_methylation:
        X_test_met = mix_met_aligned.values
    else:
        X_test_met = np.zeros((n_samples, 1))

    X_test = prepare_features(X_test_rna, X_test_met, rna_scaler, met_scaler,
                              fit=False)

    if HAS_XGB:
        # Train XGBoost ensemble
        models = train_ensemble(X_train, y_aug, cell_types, ENSEMBLE_SEEDS,
                                XGB_PARAMS)
        xgb_predictions = predict_ensemble(models, X_test, cell_types)
        ml_props = np.array([xgb_predictions[ct] for ct in cell_types])
    else:
        # Fallback: Ridge regression ensemble
        from sklearn.linear_model import Ridge as RidgeReg
        ml_props = np.zeros((n_cell_types, n_samples))
        for i, ct in enumerate(cell_types):
            ridge = RidgeReg(alpha=1.0)
            ridge.fit(X_train, y_aug[:, i])
            ml_props[i] = ridge.predict(X_test)

    # Clip negative values and normalize
    ml_props = np.clip(ml_props, 0, None)
    ml_sums = ml_props.sum(axis=0, keepdims=True)
    ml_sums[ml_sums == 0] = 1
    ml_props = ml_props / ml_sums

    # --- Hybrid: weighted combination ---
    pred_array = NNLS_WEIGHT * nnls_props + XGB_WEIGHT * ml_props

    # Normalize
    pred_sums = pred_array.sum(axis=0, keepdims=True)
    pred_sums[pred_sums == 0] = 1
    pred_array = pred_array / pred_sums

    # =========================================================================
    # DETECTION: SHOULD WE APPLY PHASE 2?
    # =========================================================================

    apply_phase2 = False

    if 'basal' in cell_types and 'classic' in cell_types:
        fingerprint = compute_dataset_fingerprint(mix_rna)

        cache = load_cache(CACHE_FILE)

        if fingerprint in cache:
            cached = cache[fingerprint]
            apply_phase2 = cached.get('apply_phase2', False)
        else:
            apply_phase2, details = detect_phase2_needed(
                mix_rna, BASAL_MARKERS, CLASSIC_MARKERS,
                MARKER_VARIANCE_THRESHOLD, ANTICORR_THRESHOLD
            )
            cache[fingerprint] = details
            save_cache(CACHE_FILE, cache)

    # =========================================================================
    # PHASE 2 (CONDITIONAL): REDISTRIBUTE BASAL/CLASSIC
    # =========================================================================

    if apply_phase2:
        pred_array = apply_phase2_correction(
            pred_array, cell_types, mix_rna,
            ref_rna_aligned, BASAL_MARKERS, CLASSIC_MARKERS,
            n_synth=N_SYNTH_PHASE2
        )

    # =========================================================================
    # POST-PROCESSING
    # =========================================================================

    final_props = compositional_postprocess(pred_array, delta=MIN_PROPORTION,
                                             shrinkage=0.10)

    # =========================================================================
    # OUTPUT
    # =========================================================================

    # Sanitize index and columns to ensure strings (not bytes) for R compatibility
    clean_index = [x.decode() if isinstance(x, (bytes, np.bytes_)) else str(x)
                   for x in cell_types]
    clean_columns = [x.decode() if isinstance(x, (bytes, np.bytes_)) else str(x)
                     for x in sample_names]

    # Ensure shape match (safety check)
    if len(clean_columns) != final_props.shape[1]:
        # Fallback: slice if too many names, or generate generic if too few
        if len(clean_columns) > final_props.shape[1]:
            clean_columns = clean_columns[:final_props.shape[1]]
        else:
            clean_columns = [f"sample_{i}" for i in range(final_props.shape[1])]

    result = pd.DataFrame(final_props, index=clean_index, columns=clean_columns)

    return result                                                                                                     