"""
HADACA3 Submission - Calibrated Multi-Scale Deconvolution v3
=============================================================
Fundamentally different approach from v1/v2:
  - NO XGBoost (was slow and destroyed per-sample variation)
  - Multi-scale NNLS: run NNLS at different gene set sizes & transforms, average
  - Bias calibration: generate synthetic mixtures, learn systematic NNLS bias, correct
  - RNA + methylation deconvolution: separate NNLS fits blended with fixed weights
  - Gene selection by discriminant power (cell-type specificity); genes are selected, not weighted
  - Nearly zero shrinkage to preserve correlations
  - Fast: ~5-10s per dataset instead of 90s+
"""


def program(mix_rna=None, ref_bulkRNA=None, mix_met=None, ref_met=None, **kwargs):
    """
    Calibrated multi-scale deconvolution for HADACA3.

    Pipeline: multi-scale RNA NNLS -> synthetic-mixture bias calibration
    (blended with the uncalibrated estimate) -> optional blend with a
    methylation NNLS estimate -> optional basal/classic re-split ->
    compositional clean-up.

    Parameters
    ----------
    mix_rna : pandas.DataFrame
        Bulk RNA mixtures, genes x samples. Transposed when it has fewer rows
        than columns and more than 20 columns.
    ref_bulkRNA : pandas.DataFrame
        RNA reference, cell types x genes. Transposed when it has more rows
        than columns and more than 20 rows.
    mix_met : pandas.DataFrame or None
        Methylation mixtures, CpGs x samples (same orientation heuristic as
        ``mix_rna``).
    ref_met : pandas.DataFrame or None
        Methylation reference, cell types x CpGs (same orientation heuristic
        as ``ref_bulkRNA``).
    **kwargs
        Ignored; accepted for interface compatibility.

    Returns
    -------
    pandas.DataFrame
        Proportions with cell types as rows (``ref_bulkRNA`` order) and
        samples as columns (``mix_rna`` order); every column sums to 1.
    """

    # =========================================================================
    # IMPORTS
    # =========================================================================
    import numpy as np
    import pandas as pd
    import os
    import json
    import hashlib
    from scipy.optimize import nnls
    from scipy import stats

    # =========================================================================
    # CONFIGURATION
    # =========================================================================

    # Multi-scale: average NNLS across these gene-set sizes
    FEATURE_SCALES = [500, 1000, 2000, 4000, 8000]

    MIN_PROPORTION = 0.001
    SHRINKAGE = 0.01    # Nearly zero — preserve sample-level variation

    # RNA result = CALIBRATED_WEIGHT * calibrated + RAW_WEIGHT * uncalibrated;
    # keeping part of the raw estimate guards against a poor calibration.
    CALIBRATED_WEIGHT = 0.60
    RAW_WEIGHT = 0.40

    # Methylation integration weights (used when methylation is usable)
    RNA_WEIGHT = 0.85
    MET_WEIGHT = 0.15

    # Calibration
    N_CALIBRATION = 500

    # Basal/Classic markers
    BASAL_MARKERS = [
        'KRT5', 'TNC', 'LY6D', 'KRT6A', 'KRT13', 'CRABP2', 'ALDH1A3',
        'PLAU', 'KRT23', 'KLK6', 'SERPINE1', 'SNCG', 'KLK7', 'FSTL1',
        'THBS1', 'ANXA1', 'S100A4', 'FLNA'
    ]
    CLASSIC_MARKERS = [
        'CLDN18', 'REG4', 'TFF1', 'TFF2', 'TFF3', 'CTSE', 'HEPH',
        'DMBT1', 'SPINK4', 'AGR2', 'LYZ', 'TSPAN8', 'DPCR1', 'MUC17',
        'ALDOB', 'ANXA13'
    ]
    MARKER_VARIANCE_THRESHOLD = 0.1
    ANTICORR_THRESHOLD = -0.5
    N_SYNTH_PHASE2 = 200

    try:
        _dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        _dir = os.getcwd()
    CACHE_FILE = os.path.join(_dir, 'dataset_analysis_cache.json')

    # =========================================================================
    # HELPER FUNCTIONS
    # =========================================================================

    def first_positions(labels):
        """Map each label to the position of its first occurrence in `labels`."""
        positions = {}
        for pos, label in enumerate(labels):
            positions.setdefault(label, pos)
        return positions

    def align_positions(labels, wanted):
        """
        Positions that reorder `labels` to follow `wanted`.

        Matches by label when every wanted label is present. Otherwise it
        falls back to the identity order if the lengths agree, which is the
        historical positional behaviour. Returns None when neither applies.
        """
        pos = first_positions(labels)
        if all(w in pos for w in wanted):
            return [pos[w] for w in wanted]
        if len(labels) == len(wanted):
            return list(range(len(wanted)))
        return None

    def ridge_fit(X, Y, alpha):
        """
        Closed-form ridge regression with an unpenalised intercept.

        Solves (Xc'Xc + alpha*I) W = Xc'Yc on mean-centred X and Y. This is
        the same estimator as scikit-learn's Ridge(alpha, fit_intercept=True)
        on dense input. `Y` may be 1-D (one target) or 2-D (one column per
        target).

        Returns (coef, intercept); predict with X_new @ coef + intercept.
        """
        X = np.asarray(X, dtype=float)
        Y = np.asarray(Y, dtype=float)
        x_mean = X.mean(axis=0)
        y_mean = Y.mean(axis=0)
        Xc = X - x_mean
        gram = Xc.T @ Xc + alpha * np.eye(X.shape[1])
        coef = np.linalg.solve(gram, Xc.T @ (Y - y_mean))
        intercept = y_mean - x_mean @ coef
        return coef, intercept

    def select_discriminant_genes(ref_rna, n_top):
        """
        Select genes with the highest between-cell-type discrimination.

        Score = (max - min across types) * CV * log1p(mean), computed only
        over genes whose mean reference expression exceeds 0.1.
        Returns up to `n_top` gene labels, best first.
        """
        means = ref_rna.mean(axis=0)
        stds = ref_rna.std(axis=0)
        gene_range = ref_rna.max(axis=0) - ref_rna.min(axis=0)
        cv = stds / (means + 1e-10)

        # High range (discriminant), high CV (variable), moderate expression
        score = gene_range * cv * np.log1p(means)

        # Filter out genes with near-zero expression in all cell types
        score = score[means > 0.1]

        return score.nlargest(min(n_top, len(score))).index.tolist()

    def select_marker_genes(ref_rna, n_per_type=150):
        """
        Union of the `n_per_type` genes with the highest log2 fold-change of
        each cell type versus the mean of the other types.

        Insertion order is kept (not set order), so the result does not
        depend on Python's hash randomisation.
        """
        markers = {}
        for ct in ref_rna.index.tolist():
            this_expr = ref_rna.loc[ct]
            other_expr = ref_rna.drop(ct).mean(axis=0)
            fc = np.log2((this_expr + 1) / (other_expr + 1))
            markers.update(dict.fromkeys(fc.nlargest(n_per_type).index))
        return list(markers)

    def run_nnls(A, b):
        """NNLS solution vector; retries without `maxiter` on SciPy versions lacking it."""
        try:
            return nnls(A, b, maxiter=10000)[0]
        except TypeError:
            return nnls(A, b)[0]

    def solve_nnls_single(A, b, n_types):
        """
        NNLS for one sample, normalised to sum to 1.

        Returns a uniform vector if the solver fails (RuntimeError or
        ValueError) or the solution is all zeros.
        """
        try:
            x = run_nnls(A, b)
        except (RuntimeError, ValueError):
            return np.ones(n_types) / n_types

        total = x.sum()
        if total > 0:
            return x / total
        return np.ones(n_types) / n_types

    def solve_nnls_batch(ref_matrix, mix_matrix, transform='log'):
        """
        NNLS deconvolution of every sample under one transform.

        ref_matrix: (n_types, n_genes); mix_matrix: (n_samples, n_genes).
        transform: 'log' (log2(x+1)), 'sqrt' (sqrt(x+1)) or anything else
        for no transform. Returns (n_types, n_samples).
        """
        n_types = ref_matrix.shape[0]
        n_samples = mix_matrix.shape[0]

        ref_work = ref_matrix.copy()
        mix_work = mix_matrix.copy()

        if transform == 'log':
            ref_work = np.log2(ref_work + 1)
            mix_work = np.log2(mix_work + 1)
        elif transform == 'sqrt':
            ref_work = np.sqrt(ref_work + 1)
            mix_work = np.sqrt(mix_work + 1)
        # 'raw' — no transform

        A = ref_work.T  # (n_genes, n_types)

        proportions = np.zeros((n_types, n_samples))
        for i in range(n_samples):
            proportions[:, i] = solve_nnls_single(A, mix_work[i], n_types)

        return proportions

    def multi_scale_nnls(ref_rna_df, mix_rna_values, scales, transforms=None):
        """
        Average NNLS estimates over several gene-set sizes and transforms.

        For each size in `scales` (and once for the marker-gene set), sets
        with at least 50 genes are solved under every transform. All
        resulting (n_types, n_samples) estimates are averaged. If no set
        qualifies, falls back to log-NNLS on all genes.
        """
        if transforms is None:
            transforms = ['log', 'sqrt', 'raw']

        # One label -> first-position map replaces an O(n_genes) list.index
        # scan per selected gene.
        col_pos = first_positions(ref_rna_df.columns)
        ref_values = ref_rna_df.values

        all_predictions = []
        # Sizes above the number of expressed genes select the identical gene
        # list; reuse those solutions (each scale still counts in the mean).
        solved = {}

        def add_gene_set(gene_idx):
            key = tuple(gene_idx)
            if key not in solved:
                ref_sub = ref_values[:, gene_idx]
                mix_sub = mix_rna_values[:, gene_idx]
                solved[key] = [solve_nnls_batch(ref_sub, mix_sub, transform=t)
                               for t in transforms]
            all_predictions.extend(solved[key])

        for n_genes in scales:
            top_genes = select_discriminant_genes(ref_rna_df, n_genes)
            gene_idx = [col_pos[g] for g in top_genes if g in col_pos]
            if len(gene_idx) >= 50:
                add_gene_set(gene_idx)

        # Also run with marker genes
        marker_genes = select_marker_genes(ref_rna_df, n_per_type=150)
        marker_idx = [col_pos[g] for g in marker_genes if g in col_pos]
        if len(marker_idx) >= 50:
            add_gene_set(marker_idx)

        if not all_predictions:
            # Fallback: use all genes
            return solve_nnls_batch(ref_values, mix_rna_values, transform='log')

        return np.mean(all_predictions, axis=0)

    def calibrate_predictions(raw_preds, ref_matrix, n_synthetic=500,
                              scales=None, ref_df=None):
        """
        Bias calibration using synthetic mixtures.

        1. Draw Dirichlet proportions over several concentration levels
           (fixed seed 42) and build noise-free mixtures from `ref_matrix`.
        2. Deconvolve them with multi-scale NNLS (or log-NNLS when `scales`
           or `ref_df` is missing).
        3. Fit a ridge regression (alpha=0.1, with intercept) from all
           predicted proportions to each true proportion, capturing
           cross-talk between types.
        4. Apply it to `raw_preds` (n_types, n_samples), clip at 0 and
           renormalise each column. All-zero columns are left at zero.
        """
        n_types = ref_matrix.shape[0]

        rng = np.random.RandomState(42)
        alpha_configs = [0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]
        n_per = max(5, n_synthetic // len(alpha_configs))
        true_props = np.vstack([
            rng.dirichlet(np.ones(n_types) * alpha_scale, size=n_per)
            for alpha_scale in alpha_configs
        ])  # (n_synth, n_types)
        # Very small alphas can underflow to 0/0 = NaN rows on older NumPy.
        true_props = true_props[np.isfinite(true_props).all(axis=1)]

        synth_mix = true_props @ ref_matrix  # (n_synth, n_genes)

        if ref_df is not None and scales is not None:
            synth_preds = multi_scale_nnls(ref_df, synth_mix, scales)
        else:
            synth_preds = solve_nnls_batch(ref_matrix, synth_mix, transform='log')

        # One multi-output fit == an independent ridge per cell type.
        coef, intercept = ridge_fit(synth_preds.T, true_props, alpha=0.1)
        calibrated = (raw_preds.T @ coef + intercept).T

        # Ensure valid proportions
        calibrated = np.clip(calibrated, 0, None)
        col_sums = calibrated.sum(axis=0, keepdims=True)
        col_sums[col_sums == 0] = 1
        return calibrated / col_sums

    def met_nnls(ref_met_vals, mix_met_vals):
        """
        Methylation NNLS on untransformed values.

        ref_met_vals: (n_types, n_cpg); mix_met_vals: (n_samples, n_cpg).
        Returns (n_types, n_samples), each column normalised to sum to 1.
        """
        n_types = ref_met_vals.shape[0]
        n_samples = mix_met_vals.shape[0]

        A = ref_met_vals.T  # (n_cpg, n_types)
        proportions = np.zeros((n_types, n_samples))
        for i in range(n_samples):
            proportions[:, i] = solve_nnls_single(A, mix_met_vals[i], n_types)

        return proportions

    def methylation_predictions(ref_met_df, mix_met_df, cell_types, sample_names):
        """
        Methylation estimate aligned to the RNA output, or None.

        Reference rows are matched to `cell_types` and mixture columns to
        `sample_names`, by label when possible (see align_positions). CpGs
        with any non-finite value in the aligned data are dropped. The 5000
        CpGs with the highest between-type variance are deconvolved.

        Returns None when alignment fails or <= 100 usable CpGs remain.
        """
        type_pos = align_positions(ref_met_df.index, cell_types)
        sample_pos = align_positions(mix_met_df.columns, sample_names)
        if type_pos is None or sample_pos is None:
            return None

        common_cpg = mix_met_df.index.intersection(ref_met_df.columns)
        ref_aligned = ref_met_df.iloc[type_pos][common_cpg].astype(float)
        mix_aligned = mix_met_df.iloc[:, sample_pos].loc[common_cpg].T.astype(float)

        usable = (np.isfinite(ref_aligned.values).all(axis=0)
                  & np.isfinite(mix_aligned.values).all(axis=0))
        ref_aligned = ref_aligned.loc[:, usable]
        mix_aligned = mix_aligned.loc[:, usable]
        if ref_aligned.shape[1] <= 100:
            return None

        met_var = ref_aligned.var(axis=0)
        top_cpg = met_var.nlargest(min(5000, ref_aligned.shape[1])).index.tolist()
        return met_nnls(ref_aligned[top_cpg].values, mix_aligned[top_cpg].values)

    def compositional_postprocess(props, delta=0.001, shrinkage=0.005):
        """
        Gentle clean-up into valid compositions.

        Per column: clip at 0 and normalise (all-zero -> uniform). Exact
        zeros get `delta` with the rest rescaled. Everything is floored at
        delta/2 and renormalised. Finally each column moves `shrinkage` of
        the way toward the across-sample mean composition and is
        renormalised.
        """
        props = np.clip(props, 0, None)
        n_types, n_samples = props.shape

        for j in range(n_samples):
            col = props[:, j]
            total = col.sum()
            if total == 0:
                col[:] = 1.0 / n_types  # writes through the view into props
                continue

            col = col / total

            # Multiplicative replacement for zeros
            zeros = col == 0
            n_zeros = zeros.sum()
            if 0 < n_zeros < n_types:
                col[zeros] = delta
                nonzero_sum = col[~zeros].sum()
                target_sum = 1.0 - n_zeros * delta
                col[~zeros] = col[~zeros] * (target_sum / nonzero_sum)

            # Ensure minimum
            col = np.maximum(col, delta / 2)
            col = col / col.sum()

            props[:, j] = col

        # Very minimal shrinkage — just a tiny nudge toward mean
        if shrinkage > 0:
            mean_comp = props.mean(axis=1, keepdims=True)
            props = (1 - shrinkage) * props + shrinkage * mean_comp
            props = props / props.sum(axis=0, keepdims=True)

        return props

    # Phase 2 helpers
    def compute_dataset_fingerprint(mix_rna_df):
        """MD5 of the frame's shape plus its first 100 values (4-digit precision)."""
        shape_str = str(mix_rna_df.shape)
        flat = mix_rna_df.values.flatten()
        sample_vals = flat[:min(100, len(flat))]
        val_str = np.array2string(sample_vals, precision=4)
        return hashlib.md5((shape_str + val_str).encode()).hexdigest()

    def load_cache(cache_file):
        """Return the JSON cache dict, or {} if missing, unreadable or malformed."""
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                cache = json.load(f)
        except (OSError, ValueError):  # ValueError covers JSON/Unicode decode errors
            return {}
        return cache if isinstance(cache, dict) else {}

    def save_cache(cache_file, cache):
        """Best-effort cache write; OS errors (e.g. read-only dir) are ignored."""
        try:
            with open(cache_file, 'w', encoding='utf-8') as f:
                json.dump(cache, f, indent=2)
        except OSError:
            pass

    def detect_phase2_needed(mix_rna_df, basal_markers, classic_markers,
                             var_threshold, anticorr_threshold):
        """
        Decide whether the basal/classic re-split should run.

        Sums basal and classic marker expression per sample (genes x samples
        input). Returns (False, details) when there are fewer than 5 samples,
        fewer than 3 markers of either kind, or either summed score has a
        coefficient of variation below `var_threshold`.

        Returns (True, details) when the two scores are Pearson-correlated
        below `anticorr_threshold` with p < 0.05. Otherwise returns False.
        `details` records the measurements and a 'reason'.
        """
        available_genes = set(mix_rna_df.index)
        n_samples = mix_rna_df.shape[1]
        details = {'n_samples': n_samples, 'n_genes': len(available_genes)}

        if n_samples < 5:
            details['reason'] = 'too_few_samples'
            details['apply_phase2'] = False
            return False, details

        basal_score = np.zeros(n_samples)
        classic_score = np.zeros(n_samples)
        n_basal_found = n_classic_found = 0

        for m in basal_markers:
            if m in available_genes:
                basal_score += mix_rna_df.loc[m].values.astype(float)
                n_basal_found += 1
        for m in classic_markers:
            if m in available_genes:
                classic_score += mix_rna_df.loc[m].values.astype(float)
                n_classic_found += 1

        details['n_basal_markers'] = n_basal_found
        details['n_classic_markers'] = n_classic_found

        if n_basal_found < 3 or n_classic_found < 3:
            details['reason'] = 'insufficient_markers'
            details['apply_phase2'] = False
            return False, details

        basal_cov = np.std(basal_score) / (np.mean(basal_score) + 1e-10)
        classic_cov = np.std(classic_score) / (np.mean(classic_score) + 1e-10)
        details['basal_cov'] = float(basal_cov)
        details['classic_cov'] = float(classic_cov)

        if min(basal_cov, classic_cov) < var_threshold:
            details['reason'] = 'low_marker_variance'
            details['apply_phase2'] = False
            return False, details

        if np.std(basal_score) > 0 and np.std(classic_score) > 0:
            corr, pval = stats.pearsonr(basal_score, classic_score)
        else:
            corr, pval = 0.0, 1.0

        details['marker_anticorr'] = float(corr)
        details['marker_pval'] = float(pval)

        if corr < anticorr_threshold and pval < 0.05:
            details['reason'] = 'strong_discriminant_signal'
            details['apply_phase2'] = True
            return True, details

        details['reason'] = 'weak_discriminant_signal'
        details['apply_phase2'] = False
        return False, details

    def apply_phase2_correction(pred_array, cell_types, mix_rna_df,
                                ref_rna_aligned, basal_markers, classic_markers,
                                n_synth=200):
        """
        Re-split each sample's basal+classic total using marker scores.

        Each feature is a marker-sum divided by the same markers' sum in the
        pure reference profile (classic markers vs. classic reference, basal
        markers vs. basal reference). A ridge model (alpha=1.0) is fitted on
        synthetic basal/classic blends of the reference. The same marker
        list is used for the reference features and the mixture scores.

        Other cell types are untouched. Returns `pred_array` unchanged when
        fewer than 3 usable markers of either kind exist.
        """
        basal_idx = cell_types.index('basal')
        classic_idx = cell_types.index('classic')
        n_samples = pred_array.shape[1]

        tumoral_total = pred_array[basal_idx] + pred_array[classic_idx]

        ref_basal = ref_rna_aligned.values[basal_idx]
        ref_classic = ref_rna_aligned.values[classic_idx]

        gene_pos = first_positions(ref_rna_aligned.columns)
        available_genes = set(mix_rna_df.index)
        c_markers = [g for g in classic_markers if g in gene_pos and g in available_genes]
        b_markers = [g for g in basal_markers if g in gene_pos and g in available_genes]

        if len(c_markers) < 3 or len(b_markers) < 3:
            return pred_array

        c_marker_idx = [gene_pos[g] for g in c_markers]
        b_marker_idx = [gene_pos[g] for g in b_markers]
        classic_norm = ref_classic[c_marker_idx].sum() + 1e-10
        basal_norm = ref_basal[b_marker_idx].sum() + 1e-10

        rng = np.random.RandomState(42)
        fracs_basal = np.concatenate([
            np.linspace(0.0, 1.0, n_synth // 2),
            rng.beta(0.5, 0.5, size=n_synth // 2)
        ])

        # Features: marker sums relative to the pure-type reference sums
        synth_features = np.zeros((len(fracs_basal), 2))
        for i, alpha in enumerate(fracs_basal):
            synth = alpha * ref_basal + (1 - alpha) * ref_classic
            synth_features[i, 0] = synth[c_marker_idx].sum() / classic_norm
            synth_features[i, 1] = synth[b_marker_idx].sum() / basal_norm

        coef, intercept = ridge_fit(synth_features, 1.0 - fracs_basal, alpha=1.0)

        classic_score = np.zeros(n_samples)
        basal_score_mix = np.zeros(n_samples)
        for m in c_markers:
            classic_score += mix_rna_df.loc[m].values.astype(float)
        for m in b_markers:
            basal_score_mix += mix_rna_df.loc[m].values.astype(float)

        test_features = np.column_stack([
            classic_score / classic_norm,
            basal_score_mix / basal_norm
        ])

        classic_frac = np.clip(test_features @ coef + intercept, 0.01, 0.99)
        basal_frac = 1.0 - classic_frac

        corrected = pred_array.copy()
        corrected[basal_idx] = tumoral_total * basal_frac
        corrected[classic_idx] = tumoral_total * classic_frac

        return corrected

    # =========================================================================
    # MAIN LOGIC
    # =========================================================================

    # --- Data orientation (heuristic) ---
    if ref_bulkRNA.shape[0] > ref_bulkRNA.shape[1] and ref_bulkRNA.shape[0] > 20:
        ref_bulkRNA = ref_bulkRNA.T

    if mix_rna.shape[0] < mix_rna.shape[1] and mix_rna.shape[1] > 20:
        mix_rna = mix_rna.T

    if ref_met is not None:
        if ref_met.shape[0] > ref_met.shape[1] and ref_met.shape[0] > 20:
            ref_met = ref_met.T

    if mix_met is not None:
        if mix_met.shape[0] < mix_met.shape[1] and mix_met.shape[1] > 20:
            mix_met = mix_met.T

    cell_types = list(ref_bulkRNA.index)
    sample_names = list(mix_rna.columns)

    # Align RNA genes
    common_genes = mix_rna.index.intersection(ref_bulkRNA.columns)
    ref_rna_aligned = ref_bulkRNA[common_genes]
    mix_rna_aligned = mix_rna.loc[common_genes].T  # (n_samples, n_genes)

    # --- STEP 1: Multi-scale RNA NNLS ---
    rna_preds = multi_scale_nnls(
        ref_rna_aligned, mix_rna_aligned.values,
        scales=FEATURE_SCALES,
        transforms=['log', 'sqrt', 'raw']
    )

    # --- STEP 2: Calibrate RNA predictions, blended with the raw estimate ---
    rna_calibrated = calibrate_predictions(
        rna_preds,
        ref_rna_aligned.values,
        n_synthetic=N_CALIBRATION,
        scales=FEATURE_SCALES,
        ref_df=ref_rna_aligned
    )
    rna_final = CALIBRATED_WEIGHT * rna_calibrated + RAW_WEIGHT * rna_preds

    # --- STEP 3: Methylation NNLS (if available and alignable) ---
    pred_array = rna_final
    if mix_met is not None and ref_met is not None:
        met_preds = methylation_predictions(ref_met, mix_met, cell_types, sample_names)
        if met_preds is not None:
            combined = RNA_WEIGHT * rna_final + MET_WEIGHT * met_preds
            col_sums = combined.sum(axis=0, keepdims=True)
            col_sums[col_sums == 0] = 1
            pred_array = combined / col_sums

    # --- STEP 4: Phase 2 — basal/classic correction ---
    apply_phase2 = False
    if 'basal' in cell_types and 'classic' in cell_types:
        fingerprint = compute_dataset_fingerprint(mix_rna)
        cache = load_cache(CACHE_FILE)
        entry = cache.get(fingerprint)

        if isinstance(entry, dict):
            apply_phase2 = entry.get('apply_phase2', False)
        else:
            apply_phase2, details = detect_phase2_needed(
                mix_rna, BASAL_MARKERS, CLASSIC_MARKERS,
                MARKER_VARIANCE_THRESHOLD, ANTICORR_THRESHOLD
            )
            cache[fingerprint] = details
            save_cache(CACHE_FILE, cache)

    if apply_phase2:
        pred_array = apply_phase2_correction(
            pred_array, cell_types, mix_rna,
            ref_rna_aligned, BASAL_MARKERS, CLASSIC_MARKERS,
            n_synth=N_SYNTH_PHASE2
        )

    # --- STEP 5: Minimal post-processing ---
    final_props = compositional_postprocess(pred_array, delta=MIN_PROPORTION,
                                            shrinkage=SHRINKAGE)

    # =========================================================================
    # OUTPUT
    # =========================================================================

    clean_index = [x.decode() if isinstance(x, (bytes, np.bytes_)) else str(x)
                   for x in cell_types]
    clean_columns = [x.decode() if isinstance(x, (bytes, np.bytes_)) else str(x)
                     for x in sample_names]

    if len(clean_columns) != final_props.shape[1]:
        if len(clean_columns) > final_props.shape[1]:
            clean_columns = clean_columns[:final_props.shape[1]]
        else:
            clean_columns = [f"sample_{i}" for i in range(final_props.shape[1])]

    return pd.DataFrame(final_props, index=clean_index, columns=clean_columns)
