"""
HADACA3 - XGBoost Deconvolution Model
======================================
Modelo de deconvolución celular basado en XGBoost.

Estrategia:
- Multi-output regression: un modelo por tipo celular
- Combina RNA + Metilación como features
- Normalización post-hoc para asegurar suma = 1
- Robusto al ruido por gradient boosting

Ventajas de XGBoost para este problema:
- Maneja alta dimensionalidad (16k+ features)
- Robusto a la colinealidad 
- Regularización L1/L2 incorporada
- Rápido entrenamiento con GPU opcional
"""

import numpy as np
import pandas as pd
import pickle
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from dataclasses import dataclass, field
import warnings

# XGBoost
try:
    import xgboost as xgb
    XGB_AVAILABLE = True
except ImportError:
    XGB_AVAILABLE = False
    warnings.warn("XGBoost no instalado. Instalar con: pip install xgboost")

# Sklearn utilities
from sklearn.model_selection import KFold, cross_val_score
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.preprocessing import StandardScaler
from joblib import Parallel, delayed

# Local imports
from data_loader import (
    get_or_load_dataset, 
    get_training_data,
    DeconvolutionDataset,
    CELL_TYPES
)


# =============================================================================
# CONFIGURACIÓN
# =============================================================================

# Directory (next to this file) where trained models are stored.
# NOTE: it is created as a side effect of importing this module.
MODEL_DIR = Path(__file__).parent / "models"
MODEL_DIR.mkdir(exist_ok=True)

# Default hyper-parameters (balanced to generalize without losing accuracy).
# Tuned after comparing modelo1 vs modelo2; callers can override any key
# through the ``xgb_params`` argument of ``XGBDeconvolutionModel``.
DEFAULT_XGB_PARAMS = {
    'objective': 'reg:squarederror',
    'eval_metric': 'rmse',
    'max_depth': 5,            # Compromise between 4 and 6
    'learning_rate': 0.08,     # Compromise between 0.05 and 0.1
    'n_estimators': 150,       # Compromise between 100 and 200
    'subsample': 0.75,         # Moderate regularization
    'colsample_bytree': 0.7,   # Compromise for high-dimensional input
    'colsample_bylevel': 0.7,  # Additional regularization
    'reg_alpha': 0.2,          # Moderate L1
    'reg_lambda': 1.5,         # Moderate L2
    'min_child_weight': 4,     # Compromise value
    'gamma': 0.05,             # Minimum loss reduction required for a split
    'random_state': 42,
    'n_jobs': -1,
    'verbosity': 0,
}


# =============================================================================
# DATA CLASSES
# =============================================================================

@dataclass
class DeconvolutionResult:
    """Deconvolution result for a single sample."""
    sample_id: str
    proportions: Dict[str, float]  # {cell_type: proportion}
    raw_proportions: Dict[str, float]  # Same mapping, before normalization
    confidence: Optional[float] = None  # Confidence metric (optional)


@dataclass
class ModelMetrics:
    """Evaluation metrics of the model."""
    rmse: float  # Root mean squared error
    mae: float  # Mean absolute error
    r2: float  # Coefficient of determination
    # RMSE broken down per cell type: {cell_type: rmse}
    per_cell_type_rmse: Dict[str, float] = field(default_factory=dict)
    

# =============================================================================
# MODELO DE DECONVOLUCIÓN
# =============================================================================

class XGBDeconvolutionModel:
    """
    Cell-type deconvolution model backed by XGBoost.

    One independent regressor is trained per cell type (multi-output by
    composition). Predictions can be clipped at zero and rescaled so that the
    proportions of every sample add up to 1.
    """

    def __init__(
        self,
        cell_types: List[str] = None,
        use_rna: bool = True,
        use_methylation: bool = True,
        xgb_params: Dict = None,
        normalize_features: bool = True,
        balance_features: bool = False,
        max_met_features: Optional[int] = None,
        feature_selection_method: str = 'variance'
    ):
        """
        Initialize the model.

        Args:
            cell_types: Cell types to predict (defaults to ``CELL_TYPES``)
            use_rna: Whether to use the RNA features
            use_methylation: Whether to use the methylation features
            xgb_params: XGBoost parameters overriding ``DEFAULT_XGB_PARAMS``
            normalize_features: Whether to standardize features with StandardScaler
            balance_features: If True, limit the methylation features to the
                number of RNA features
            max_met_features: Maximum number of methylation features (None = no limit)
            feature_selection_method: 'variance' or 'random', used to pick
                methylation features when a limit applies
        """
        self.cell_types = cell_types or CELL_TYPES
        self.use_rna = use_rna
        self.use_methylation = use_methylation
        self.normalize_features = normalize_features
        self.balance_features = balance_features
        self.max_met_features = max_met_features
        self.feature_selection_method = feature_selection_method

        # XGBoost parameters: defaults first, caller overrides on top
        self.xgb_params = DEFAULT_XGB_PARAMS.copy()
        if xgb_params:
            self.xgb_params.update(xgb_params)

        # Fitted regressors, one per cell type
        self.models: Dict[str, xgb.XGBRegressor] = {}

        # Scalers (created during fit when normalize_features is True)
        self.rna_scaler: Optional[StandardScaler] = None
        self.met_scaler: Optional[StandardScaler] = None

        # Names of the columns of the combined feature matrix
        self.feature_names: List[str] = []

        # Methylation columns chosen during training (re-applied at inference)
        self.selected_met_columns: Optional[List[str]] = None

        # State flag
        self.is_fitted = False

    @staticmethod
    def _normalize_rows(pred_df: pd.DataFrame) -> pd.DataFrame:
        """
        Rescale every row of ``pred_df`` so that it sums to 1.

        Rows whose sum is exactly 0 (e.g. every prediction was clipped to 0)
        cannot be rescaled; instead of producing NaN through a 0/0 division
        they fall back to a uniform distribution over the columns.
        """
        row_sums = pred_df.sum(axis=1)
        zero_rows = (row_sums == 0).to_numpy()

        if pred_df.shape[1] == 0 or not zero_rows.any():
            return pred_df.div(row_sums, axis=0)

        # Work positionally on numpy arrays so duplicated index labels or
        # narrow float dtypes cannot interfere with the fallback assignment.
        safe_sums = row_sums.to_numpy(dtype=np.float64).copy()
        safe_sums[zero_rows] = 1.0
        values = pred_df.to_numpy(dtype=np.float64) / safe_sums[:, None]
        values[zero_rows, :] = 1.0 / pred_df.shape[1]
        return pd.DataFrame(values, index=pred_df.index, columns=pred_df.columns)

    def _select_met_features(
        self,
        X_met: pd.DataFrame,
        n_features: int
    ) -> List[str]:
        """
        Pick the methylation features to keep.

        Args:
            X_met: Methylation DataFrame
            n_features: Number of features to select

        Returns:
            List with the selected column names

        Raises:
            ValueError: If ``feature_selection_method`` is not supported
        """
        if self.feature_selection_method == 'variance':
            # Keep the columns with the largest variance
            return X_met.var().nlargest(n_features).index.tolist()

        if self.feature_selection_method == 'random':
            # Random pick (comparison baseline). A private RandomState yields
            # the same draw as seeding the global RNG with 42, without
            # resetting the global RNG state for the rest of the program.
            rng = np.random.RandomState(42)
            all_cols = X_met.columns.tolist()
            n_pick = min(n_features, len(all_cols))
            return list(rng.choice(all_cols, size=n_pick, replace=False))

        raise ValueError(f"Método de selección no soportado: {self.feature_selection_method}")

    def _prepare_features(
        self,
        X_rna: Optional[pd.DataFrame],
        X_met: Optional[pd.DataFrame],
        fit_scalers: bool = False
    ) -> np.ndarray:
        """
        Build the combined RNA + methylation feature matrix.

        The output array is preallocated and each modality is copied into its
        own column range (RNA first, then methylation).

        Args:
            X_rna: RNA expression DataFrame (samples x genes)
            X_met: Methylation DataFrame (samples x cpg_sites)
            fit_scalers: If True, (re)fit the scalers and the methylation
                column selection on this data (training only)

        Returns:
            Combined feature array. ``self.feature_names`` is updated to match
            its columns.

        Raises:
            ValueError: If neither modality is available
        """
        has_rna = self.use_rna and X_rna is not None
        has_met = self.use_methylation and X_met is not None

        if not has_rna and not has_met:
            raise ValueError("Debe usar al menos RNA o metilación")

        # Number of samples comes from whichever modality is present
        n_samples = X_rna.shape[0] if has_rna else X_met.shape[0]

        # ---- Methylation feature selection (training only) ----
        if has_met and fit_scalers:
            n_rna = X_rna.shape[1] if has_rna else 0

            if self.balance_features and has_rna:
                # Balance: use as many methylation features as RNA features
                target_met_features = n_rna
                print(f"   ⚖️  Balanceando features: limitando MET a {target_met_features} (igual que RNA)")
            elif self.max_met_features is not None:
                # Explicit limit
                target_met_features = min(self.max_met_features, X_met.shape[1])
                print(f"   📉 Limitando MET a {target_met_features} features")
            else:
                # No limit
                target_met_features = X_met.shape[1]

            if target_met_features < X_met.shape[1]:
                self.selected_met_columns = self._select_met_features(X_met, target_met_features)
                print(f"   🔍 Selección por {self.feature_selection_method}: {len(self.selected_met_columns)} features MET")
            else:
                self.selected_met_columns = X_met.columns.tolist()

        # Apply the stored methylation column selection (training and inference)
        if has_met and self.selected_met_columns is not None:
            X_met = X_met[self.selected_met_columns]

        # Final dimensions
        n_rna_features = X_rna.shape[1] if has_rna else 0
        n_met_features = X_met.shape[1] if has_met else 0
        total_features = n_rna_features + n_met_features

        # Preallocated output
        result = np.empty((n_samples, total_features), dtype=np.float64)

        feature_names = []
        col_idx = 0

        # RNA block
        if has_rna:
            rna_values = X_rna.values

            if self.normalize_features:
                if fit_scalers:
                    self.rna_scaler = StandardScaler()
                    rna_values = self.rna_scaler.fit_transform(rna_values)
                elif self.rna_scaler is not None:
                    rna_values = self.rna_scaler.transform(rna_values)

            result[:, col_idx:col_idx + n_rna_features] = rna_values
            col_idx += n_rna_features
            feature_names.extend([f"rna_{g}" for g in X_rna.columns])

        # Methylation block
        if has_met:
            met_values = X_met.values

            if self.normalize_features:
                if fit_scalers:
                    self.met_scaler = StandardScaler()
                    met_values = self.met_scaler.fit_transform(met_values)
                elif self.met_scaler is not None:
                    met_values = self.met_scaler.transform(met_values)

            result[:, col_idx:col_idx + n_met_features] = met_values
            feature_names.extend([f"met_{c}" for c in X_met.columns])

        self.feature_names = feature_names
        return result

    def fit(
        self,
        X_rna: pd.DataFrame,
        X_met: pd.DataFrame,
        y: pd.DataFrame,
        verbose: bool = True,
        n_jobs: int = -1
    ) -> 'XGBDeconvolutionModel':
        """
        Train the deconvolution model.

        Args:
            X_rna: RNA expression (samples x genes)
            X_met: Methylation (samples x cpg_sites)
            y: True proportions (samples x cell_types)
            verbose: Whether to print progress
            n_jobs: Number of parallel jobs (-1 = all cores)

        Returns:
            self

        Raises:
            ImportError: If XGBoost is not installed
            ValueError: If a configured cell type is missing from ``y`` or the
                number of rows in ``y`` differs from the feature matrix
        """
        if not XGB_AVAILABLE:
            raise ImportError("XGBoost no instalado. Ejecutar: pip install xgboost")

        # Validate the targets before any state is touched, so a bad call
        # cannot leave the model with refitted scalers but stale regressors.
        for cell_type in self.cell_types:
            if cell_type not in y.columns:
                raise ValueError(f"Tipo celular '{cell_type}' no encontrado en y. Columnas: {list(y.columns)}")

        if verbose:
            print("=" * 60)
            print("🚀 Entrenando modelo XGBoost de deconvolución")
            print("=" * 60)

        # From here on scalers/selection are replaced; the model only counts
        # as fitted again once every regressor has been retrained.
        self.is_fitted = False
        X = self._prepare_features(X_rna, X_met, fit_scalers=True)

        if len(y) != X.shape[0]:
            raise ValueError(
                f"X e y tienen distinto número de muestras: {X.shape[0]} vs {len(y)}"
            )

        n_rna = sum(1 for f in self.feature_names if f.startswith('rna_'))
        n_met = len(self.feature_names) - n_rna

        if verbose:
            print(f"📊 Features: {X.shape[1]:,} ({n_rna:,} RNA + {n_met:,} MET)")
            print(f"📊 Muestras de entrenamiento: {X.shape[0]}")
            print(f"📊 Tipos celulares: {self.cell_types}")

        def train_single_model(cell_type: str):
            """Fit one regressor and report its training RMSE."""
            model = xgb.XGBRegressor(**self.xgb_params)
            model.fit(X, y[cell_type].values)
            train_pred = model.predict(X)
            train_rmse = np.sqrt(mean_squared_error(y[cell_type], train_pred))
            return cell_type, model, train_rmse

        if verbose:
            print(f"\n   🔧 Entrenando {len(self.cell_types)} modelos en paralelo...")

        # NOTE(review): each XGBRegressor also uses xgb_params['n_jobs'], so
        # the outer thread pool may oversubscribe cores — confirm if tuning.
        results = Parallel(n_jobs=n_jobs, prefer="threads")(
            delayed(train_single_model)(ct) for ct in self.cell_types
        )

        # Replace (rather than update) the regressors so nothing stale survives
        fitted_models = {}
        for cell_type, model, train_rmse in results:
            fitted_models[cell_type] = model
            if verbose:
                print(f"      {cell_type}: RMSE (train) = {train_rmse:.4f}")
        self.models = fitted_models

        self.is_fitted = True

        if verbose:
            print(f"\n✅ Modelo entrenado exitosamente")

        return self

    def predict(
        self,
        X_rna: pd.DataFrame,
        X_met: pd.DataFrame,
        normalize: bool = True,
        clip_negative: bool = True
    ) -> pd.DataFrame:
        """
        Predict cell-type proportions.

        Args:
            X_rna: RNA expression (samples x genes)
            X_met: Methylation (samples x cpg_sites)
            normalize: Whether to rescale each row to sum to 1. Rows that sum
                to exactly 0 become a uniform distribution instead of NaN.
            clip_negative: Whether to clip negative predictions to 0

        Returns:
            DataFrame of predicted proportions (samples x cell_types)

        Raises:
            ValueError: If the model has not been fitted
        """
        if not self.is_fitted:
            raise ValueError("Modelo no entrenado. Ejecutar fit() primero.")

        X = self._prepare_features(X_rna, X_met, fit_scalers=False)

        # Sample ids come from RNA when given, otherwise from methylation
        index_source = X_rna if X_rna is not None else X_met
        if hasattr(index_source, 'index'):
            sample_ids = list(index_source.index)
        else:
            sample_ids = [f"sample_{i}" for i in range(X.shape[0])]

        predictions = {
            cell_type: model.predict(X)
            for cell_type, model in self.models.items()
        }
        pred_df = pd.DataFrame(predictions, index=sample_ids)

        # Post-processing
        if clip_negative:
            pred_df = pred_df.clip(lower=0)

        if normalize:
            pred_df = self._normalize_rows(pred_df)

        return pred_df

    def evaluate(
        self,
        X_rna: pd.DataFrame,
        X_met: pd.DataFrame,
        y_true: pd.DataFrame,
        verbose: bool = True
    ) -> 'ModelMetrics':
        """
        Evaluate the model against ground truth.

        Rows of ``y_true`` are compared positionally with the predictions.

        Args:
            X_rna: RNA expression
            X_met: Methylation
            y_true: True proportions
            verbose: Whether to print the metrics

        Returns:
            ModelMetrics with global RMSE/MAE/R² and per-cell-type RMSE
        """
        y_pred = self.predict(X_rna, X_met)

        # Only score the cell types present on both sides, in model order
        common_cols = [c for c in self.cell_types if c in y_true.columns]
        y_true_aligned = y_true[common_cols]
        y_pred_aligned = y_pred[common_cols]

        # Global metrics over all (sample, cell type) entries
        true_flat = y_true_aligned.values.flatten()
        pred_flat = y_pred_aligned.values.flatten()
        rmse = np.sqrt(mean_squared_error(true_flat, pred_flat))
        mae = mean_absolute_error(true_flat, pred_flat)
        r2 = r2_score(true_flat, pred_flat)

        # Per-cell-type RMSE
        per_cell_rmse = {}
        for cell_type in common_cols:
            cell_rmse = np.sqrt(mean_squared_error(y_true_aligned[cell_type], y_pred_aligned[cell_type]))
            per_cell_rmse[cell_type] = cell_rmse

        metrics = ModelMetrics(
            rmse=rmse,
            mae=mae,
            r2=r2,
            per_cell_type_rmse=per_cell_rmse
        )

        if verbose:
            print("\n" + "=" * 60)
            print("📊 MÉTRICAS DE EVALUACIÓN")
            print("=" * 60)
            print(f"   RMSE global: {rmse:.4f}")
            print(f"   MAE global:  {mae:.4f}")
            print(f"   R² global:   {r2:.4f}")
            print(f"\n   RMSE por tipo celular:")
            for cell_type, cell_rmse in per_cell_rmse.items():
                print(f"      {cell_type}: {cell_rmse:.4f}")

        return metrics

    def cross_validate(
        self,
        X_rna: pd.DataFrame,
        X_met: pd.DataFrame,
        y: pd.DataFrame,
        n_folds: int = 5,
        verbose: bool = True,
        n_jobs: int = -1
    ) -> Dict[str, float]:
        """
        K-fold cross-validation of the model (folds run in parallel).

        The fitted state of this instance (scalers, selected methylation
        columns, feature names, regressors) is left untouched.

        Args:
            X_rna: RNA expression
            X_met: Methylation
            y: True proportions; must contain every configured cell type
            n_folds: Number of folds
            verbose: Whether to print results
            n_jobs: Number of parallel jobs

        Returns:
            Dict with 'mean_rmse', 'std_rmse' and 'fold_rmses'

        Raises:
            ImportError: If XGBoost is not installed
            KeyError: If a configured cell type is missing from ``y``
        """
        if not XGB_AVAILABLE:
            raise ImportError("XGBoost no instalado. Ejecutar: pip install xgboost")

        # Score exactly the columns that are predicted, in the same order;
        # extra or reordered columns in `y` must not leak into the RMSE.
        y_targets = y[self.cell_types]

        if verbose:
            print(f"\n🔄 Validación cruzada ({n_folds} folds, paralelo)...")

        # Preparing features with fit_scalers=True overwrites scalers and the
        # column selection; restore them so an already-fitted model keeps
        # predicting with the preprocessing it was trained with.
        saved_state = (
            self.rna_scaler,
            self.met_scaler,
            self.selected_met_columns,
            self.feature_names,
        )
        try:
            # NOTE(review): scaling and feature selection are fitted on all
            # rows before splitting, so folds share preprocessing statistics.
            X = self._prepare_features(X_rna, X_met, fit_scalers=True)
        finally:
            (
                self.rna_scaler,
                self.met_scaler,
                self.selected_met_columns,
                self.feature_names,
            ) = saved_state

        kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)
        folds_list = list(kf.split(X))

        def process_fold(fold_data):
            """Train on one fold's training rows and return its validation RMSE."""
            fold_idx, (train_idx, val_idx) = fold_data
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y_targets.iloc[train_idx], y_targets.iloc[val_idx]

            fold_preds = {}
            for cell_type in self.cell_types:
                model = xgb.XGBRegressor(**self.xgb_params)
                model.fit(X_train, y_train[cell_type].values)
                fold_preds[cell_type] = model.predict(X_val)

            # Same post-processing as predict(): clip, then normalize rows
            pred_df = pd.DataFrame(fold_preds, index=y_val.index)
            pred_df = self._normalize_rows(pred_df.clip(lower=0))

            fold_rmse = np.sqrt(mean_squared_error(y_val.values.flatten(), pred_df.values.flatten()))
            return fold_idx, fold_rmse

        results = Parallel(n_jobs=n_jobs, prefer="threads")(
            delayed(process_fold)((i, fold)) for i, fold in enumerate(folds_list)
        )

        # Report folds in order
        results.sort(key=lambda x: x[0])
        fold_metrics = [rmse for _, rmse in results]

        if verbose:
            for fold_idx, rmse in results:
                print(f"   Fold {fold_idx + 1}: RMSE = {rmse:.4f}")

        avg_rmse = np.mean(fold_metrics)
        std_rmse = np.std(fold_metrics)

        if verbose:
            print(f"\n   📊 RMSE promedio: {avg_rmse:.4f} (±{std_rmse:.4f})")

        return {
            'mean_rmse': avg_rmse,
            'std_rmse': std_rmse,
            'fold_rmses': fold_metrics
        }

    def save(self, filepath: Optional[Path] = None) -> Path:
        """
        Pickle the model to disk.

        Args:
            filepath: Destination file (defaults to
                ``MODEL_DIR / "xgb_deconvolution_model.pkl"``). Missing parent
                directories are created.

        Returns:
            The path that was written
        """
        if filepath is None:
            filepath = MODEL_DIR / "xgb_deconvolution_model.pkl"

        print(f"💾 Guardando modelo en: {filepath}")
        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'wb') as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

        return filepath

    @classmethod
    def load(cls, filepath: Optional[Path] = None) -> 'XGBDeconvolutionModel':
        """
        Load a model previously written by ``save``.

        Args:
            filepath: Source file (defaults to
                ``MODEL_DIR / "xgb_deconvolution_model.pkl"``)

        Returns:
            The unpickled object
        """
        if filepath is None:
            filepath = MODEL_DIR / "xgb_deconvolution_model.pkl"

        print(f"📂 Cargando modelo desde: {filepath}")
        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # Only load model files that come from a trusted source.
        with open(filepath, 'rb') as f:
            model = pickle.load(f)

        return model

    def get_feature_importance(self, top_n: int = 20) -> pd.DataFrame:
        """
        Return the most important features, averaged over all per-cell-type models.

        Args:
            top_n: Number of features to return

        Returns:
            DataFrame with columns 'feature' and 'importance', sorted by
            decreasing importance

        Raises:
            ValueError: If the model has not been fitted
        """
        if not self.is_fitted:
            raise ValueError("Modelo no entrenado")

        # Sum the importances of every regressor, then average
        total_importance = np.zeros(len(self.feature_names))
        for model in self.models.values():
            total_importance += model.feature_importances_
        total_importance /= len(self.models)

        importance_df = pd.DataFrame({
            'feature': self.feature_names,
            'importance': total_importance
        }).sort_values('importance', ascending=False)

        return importance_df.head(top_n)


# =============================================================================
# ANÁLISIS DE DATOS
# =============================================================================

def analyze_class_distribution(
    y: pd.DataFrame,
    verbose: bool = True
) -> Dict[str, Dict[str, float]]:
    """
    Summarize how the proportions of each cell type are distributed.

    Handy for spotting imbalance or unusually high variability.

    Args:
        y: DataFrame of proportions (samples x cell_types)
        verbose: Whether to print the report

    Returns:
        Dict mapping each cell type to its 'mean', 'std', 'min', 'max' and
        'cv_percent' (coefficient of variation in percent, 0 when the mean
        is not positive)
    """
    def summarize(column: pd.Series) -> Dict[str, float]:
        """Compute the summary statistics of one column."""
        avg = column.mean()
        spread = column.std()
        lowest = column.min()
        highest = column.max()
        if avg > 0:
            variation = spread / avg * 100
        else:
            variation = 0
        return {
            'mean': avg,
            'std': spread,
            'min': lowest,
            'max': highest,
            'cv_percent': variation
        }

    stats = {}

    if verbose:
        # Table header
        print("\n" + "=" * 60)
        print("📊 ANÁLISIS DE DISTRIBUCIÓN DE PROPORCIONES")
        print("=" * 60)
        print(f"\n   {'Tipo':<10} {'Media':>8} {'Std':>8} {'Min':>8} {'Max':>8} {'CV%':>8}")
        print("   " + "-" * 52)

    for cell_type in y.columns:
        summary = summarize(y[cell_type])
        stats[cell_type] = summary

        if verbose:
            print(f"   {cell_type:<10} {summary['mean']:>8.4f} {summary['std']:>8.4f} {summary['min']:>8.4f} {summary['max']:>8.4f} {summary['cv_percent']:>7.1f}%")

    if not verbose:
        return stats

    # Flag potentially problematic cell types
    print("\n   🔍 Detección de problemas:")
    checks = (
        (lambda s: s['cv_percent'] > 100, "alta variabilidad"),
        (lambda s: s['mean'] < 0.05, "baja prevalencia"),
        (lambda s: s['std'] < 0.01, "poca variación"),
    )
    for cell_type, summary in stats.items():
        issues = [label for triggered, label in checks if triggered(summary)]
        if issues:
            print(f"      {cell_type}: {', '.join(issues)}")

    # Report strongly correlated pairs of cell types
    print("\n   Correlación entre tipos celulares:")
    corr_matrix = y.corr()
    for position, first in enumerate(y.columns):
        for second in y.columns[position + 1:]:
            corr = corr_matrix.loc[first, second]
            if abs(corr) > 0.5:
                print(f"      {first} <-> {second}: {corr:.3f}")

    return stats


# =============================================================================
# DATA AUGMENTATION: MIXUP COMPOSICIONAL
# =============================================================================

def mixup_compositional(
    X_rna: np.ndarray,
    X_met: np.ndarray,
    y: np.ndarray,
    n_synthetic: int = 50,
    alpha_range: Tuple[float, float] = (0.2, 0.8),
    seed: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate synthetic samples by vectorized interpolation (Mixup).

    Every synthetic sample is ``alpha * a + (1 - alpha) * b`` for two sampled
    rows ``a`` and ``b``; the same ``alpha`` is applied to RNA, methylation and
    targets, so targets that sum to 1 keep summing to 1.

    Args:
        X_rna: RNA array (n_samples, n_rna_features)
        X_met: Methylation array (n_samples, n_met_features)
        y: Proportions array (n_samples, n_cell_types)
        n_synthetic: Number of synthetic samples to generate
        alpha_range: Interpolation range (min, max)
        seed: Seed for reproducibility

    Returns:
        Tuple (X_rna_aug, X_met_aug, y_aug) with original rows followed by the
        synthetic ones

    Raises:
        ValueError: If the three inputs do not have the same number of rows
    """
    n_samples = X_rna.shape[0]
    if X_met.shape[0] != n_samples or y.shape[0] != n_samples:
        raise ValueError(
            "X_rna, X_met e y deben tener el mismo número de muestras: "
            f"{n_samples}, {X_met.shape[0]}, {y.shape[0]}"
        )

    # A private RandomState produces the same stream as seeding the global
    # RNG, but does not reset the global RNG state for the rest of the program.
    rng = np.random.RandomState(seed)

    # Random index pairs (vectorized)
    idx_a = rng.randint(0, n_samples, size=n_synthetic)
    idx_b = rng.randint(0, n_samples, size=n_synthetic)

    # Avoid mixing a sample with itself: shift the partner to the next row
    same_row = idx_a == idx_b
    idx_b[same_row] = (idx_b[same_row] + 1) % n_samples

    # One mixing coefficient per synthetic sample, broadcast over columns
    alphas = rng.uniform(alpha_range[0], alpha_range[1], size=(n_synthetic, 1))

    X_rna_synth = alphas * X_rna[idx_a] + (1 - alphas) * X_rna[idx_b]
    X_met_synth = alphas * X_met[idx_a] + (1 - alphas) * X_met[idx_b]
    y_synth = alphas * y[idx_a] + (1 - alphas) * y[idx_b]

    # Originals first, synthetic rows appended
    X_rna_aug = np.vstack([X_rna, X_rna_synth])
    X_met_aug = np.vstack([X_met, X_met_synth])
    y_aug = np.vstack([y, y_synth])

    return X_rna_aug, X_met_aug, y_aug


def mixup_from_dataframes(
    X_rna: pd.DataFrame,
    X_met: pd.DataFrame,
    y: pd.DataFrame,
    n_synthetic: int = 50,
    alpha_range: Tuple[float, float] = (0.2, 0.8),
    seed: int = 42,
    verbose: bool = True
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    DataFrame front-end for ``mixup_compositional``.

    The inputs are converted to arrays, augmented, and rebuilt as DataFrames
    that keep the original columns. Synthetic rows are labelled
    ``synth_0 .. synth_{n_synthetic-1}`` and appended after the original index
    of ``X_rna``.

    Returns:
        Tuple (X_rna_aug, X_met_aug, y_aug) as DataFrames
    """
    sources = (X_rna, X_met, y)

    # Run the vectorized mixup on the raw arrays
    augmented_arrays = mixup_compositional(
        *(frame.values for frame in sources),
        n_synthetic=n_synthetic,
        alpha_range=alpha_range,
        seed=seed
    )

    # Row labels: original ids followed by the synthetic ones
    n_original = len(X_rna.index)
    row_labels = list(X_rna.index)
    row_labels.extend(f"synth_{i}" for i in range(n_synthetic))

    # Rebuild one DataFrame per modality with its original columns
    rebuilt = tuple(
        pd.DataFrame(values, index=row_labels, columns=template.columns)
        for values, template in zip(augmented_arrays, sources)
    )

    if verbose:
        print(f"   🔀 Mixup: {n_original} originales + {n_synthetic} sintéticas = {len(row_labels)} muestras")

    return rebuilt


# =============================================================================
# ENSEMBLE DE MODELOS
# =============================================================================

def train_ensemble_model(
    X_rna: pd.DataFrame,
    X_met: pd.DataFrame,
    y: pd.DataFrame,
    n_models: int = 5,
    seeds: List[int] = None,
    balance_features: bool = True,
    feature_selection_method: str = 'variance',
    n_jobs: int = -1,
    verbose: bool = True
) -> List['XGBDeconvolutionModel']:
    """
    Train an ensemble of models that differ only in their random seed.

    Args:
        X_rna: RNA data
        X_met: Methylation data
        y: Target proportions
        n_models: Number of models in the ensemble (ignored when ``seeds``
            is given)
        seeds: Explicit list of seeds, one model per seed (if None, seeds are
            generated deterministically)
        balance_features: Whether to balance RNA/MET feature counts
        feature_selection_method: Feature selection method
        n_jobs: Number of parallel jobs
        verbose: Whether to print progress

    Returns:
        List of trained models, in seed order

    Raises:
        ValueError: If ``seeds`` is None and ``n_models`` is negative
    """
    if seeds is None:
        if n_models < 0:
            raise ValueError(f"n_models debe ser >= 0, recibido: {n_models}")
        # The first five seeds are fixed for backward compatibility; larger
        # ensembles continue with consecutive seeds instead of being silently
        # truncated to five models.
        base_seeds = [42, 123, 456, 789, 1024]
        extra = [base_seeds[-1] + 1 + k for k in range(n_models - len(base_seeds))]
        seeds = base_seeds[:n_models] + extra

    if verbose:
        print(f"   🎲 Entrenando ensemble de {len(seeds)} modelos en paralelo...")

    def train_single_model_with_seed(seed: int) -> 'XGBDeconvolutionModel':
        """Train one model whose XGBoost random_state is ``seed``."""
        model = XGBDeconvolutionModel(
            balance_features=balance_features,
            feature_selection_method=feature_selection_method,
            xgb_params={'random_state': seed}
        )
        model.fit(X_rna, X_met, y, verbose=False)
        return model

    # NOTE(review): each model.fit also parallelizes internally, so the
    # outer pool may oversubscribe cores — confirm if tuning.
    models = Parallel(n_jobs=n_jobs, prefer="threads")(
        delayed(train_single_model_with_seed)(seed) for seed in seeds
    )

    if verbose:
        print(f"   ✅ Ensemble entrenado: {len(models)} modelos")

    return models


def predict_ensemble(
    models: List['XGBDeconvolutionModel'],
    X_rna: pd.DataFrame,
    X_met: pd.DataFrame,
    method: str = 'mean'
) -> pd.DataFrame:
    """
    Predict with an ensemble of models (members run in parallel).

    Each member returns clipped, un-normalized predictions; these are
    aggregated across members and then every row is rescaled to sum to 1.
    Rows whose aggregate is all zeros become a uniform distribution instead
    of NaN.

    Args:
        models: List of trained models
        X_rna: RNA data
        X_met: Methylation data
        method: 'mean' or 'median' aggregation

    Returns:
        DataFrame with the aggregated predictions (samples x cell_types)

    Raises:
        ValueError: If ``models`` is empty or ``method`` is not supported
    """
    # Fail fast, before any prediction work is done
    if not models:
        raise ValueError("Se requiere al menos un modelo para predecir con el ensemble")
    if method not in ('mean', 'median'):
        raise ValueError(f"Método no soportado: {method}")

    def predict_single(model):
        """Raw (clipped, un-normalized) predictions of one member."""
        return model.predict(X_rna, X_met, normalize=False, clip_negative=True).values

    predictions = Parallel(n_jobs=-1, prefer="threads")(
        delayed(predict_single)(m) for m in models
    )

    # Shape: (n_models, n_samples, n_cell_types)
    predictions_array = np.stack(predictions, axis=0)

    if method == 'mean':
        aggregated = np.mean(predictions_array, axis=0)
    else:
        aggregated = np.median(predictions_array, axis=0)

    # Normalize rows to sum to 1, guarding against all-zero rows
    row_sums = aggregated.sum(axis=1, keepdims=True)
    zero_rows = row_sums[:, 0] == 0
    if zero_rows.any() and aggregated.shape[1] > 0:
        row_sums[zero_rows] = 1.0
        aggregated = aggregated / row_sums
        aggregated[zero_rows] = 1.0 / aggregated.shape[1]
    else:
        aggregated = aggregated / row_sums

    # Sample ids come from RNA when given, otherwise from methylation
    index_source = X_rna if X_rna is not None else X_met
    if hasattr(index_source, 'index'):
        sample_ids = list(index_source.index)
    else:
        sample_ids = [f"sample_{i}" for i in range(aggregated.shape[0])]
    cell_types = models[0].cell_types

    return pd.DataFrame(aggregated, index=sample_ids, columns=cell_types)


# =============================================================================
# FUNCIONES DE CONVENIENCIA
# =============================================================================

def train_and_evaluate(
    dataset: DeconvolutionDataset = None,
    mix_names: List[str] = None,
    use_rna: bool = True,
    use_methylation: bool = True,
    test_size: float = 0.2,
    verbose: bool = True,
    balance_features: bool = False,
    max_met_features: Optional[int] = None,
    feature_selection_method: str = 'variance'
) -> Tuple[XGBDeconvolutionModel, ModelMetrics]:
    """
    Train a deconvolution model on a random split and evaluate it on the rest.

    The data of all requested mixes is concatenated row-wise, shuffled with
    a fixed seed (42), and split into train/test partitions; the model is
    fitted on the train partition and scored on the test partition.

    Args:
        dataset: Loaded dataset; when None it is obtained through
            ``get_or_load_dataset()``.
        mix_names: Names of the mixes to use; when None every key of
            ``dataset.mixes`` is used.
        use_rna: Forwarded to the model constructor.
        use_methylation: Forwarded to the model constructor.
        test_size: Fraction of samples held out for testing
            (``int(n_samples * test_size)`` rows).
        verbose: Print progress; also forwarded to ``fit`` and ``evaluate``.
        balance_features: Forwarded to the model constructor.
        max_met_features: Forwarded to the model constructor.
        feature_selection_method: Forwarded to the model constructor.

    Returns:
        Tuple ``(trained model, metrics returned by model.evaluate)``.
    """
    # Resolve defaults: load the dataset and/or take every available mix.
    dataset = get_or_load_dataset() if dataset is None else dataset
    mix_names = list(dataset.mixes.keys()) if mix_names is None else mix_names

    # One (X_rna, X_met, y) triple per mix, then concatenate each component.
    per_mix = [get_training_data(dataset, mix_name) for mix_name in mix_names]
    X_rna_combined = pd.concat([rna for rna, _, _ in per_mix], axis=0)
    X_met_combined = pd.concat([met for _, met, _ in per_mix], axis=0)
    y_combined = pd.concat([target for _, _, target in per_mix], axis=0)

    if verbose:
        print(f"📊 Dataset combinado: {X_rna_combined.shape[0]} muestras")

    # Seeded shuffle of row positions: the first n_test go to test,
    # the remainder to train.
    n_samples = len(y_combined)
    n_test = int(n_samples * test_size)
    np.random.seed(42)
    shuffled = np.random.permutation(n_samples)
    held_out = shuffled[:n_test]
    kept = shuffled[n_test:]

    # Positional selection, so duplicated index labels across mixes are harmless.
    X_rna_train, X_rna_test = X_rna_combined.iloc[kept], X_rna_combined.iloc[held_out]
    X_met_train, X_met_test = X_met_combined.iloc[kept], X_met_combined.iloc[held_out]
    y_train, y_test = y_combined.iloc[kept], y_combined.iloc[held_out]

    # Build the model from the caller's options and fit it on the train split.
    model_options = {
        'use_rna': use_rna,
        'use_methylation': use_methylation,
        'balance_features': balance_features,
        'max_met_features': max_met_features,
        'feature_selection_method': feature_selection_method,
    }
    model = XGBDeconvolutionModel(**model_options)
    model.fit(X_rna_train, X_met_train, y_train, verbose=verbose)

    # Score on the held-out split.
    metrics = model.evaluate(X_rna_test, X_met_test, y_test, verbose=verbose)
    return model, metrics


# =============================================================================
# MAIN
# =============================================================================

if __name__ == "__main__":
    # Script entry point: runs the "Modelo 5" experiment end to end
    # (load data, split, mixup augmentation, ensemble training, evaluation,
    # feature-importance report, and pickling of the trained ensemble).

    # Start the wall-clock timer for the whole run
    start_time = time.time()
    
    print("=" * 60)
    print("HADACA3 - XGBoost Deconvolution Model (Modelo 5)")
    print("Mixup Composicional + Ensemble Paralelo")
    print("=" * 60)
    
    # Load the dataset (original note: served from cache when one exists)
    dataset = get_or_load_dataset()
    
    # Gather the (RNA, methylation, target) data of every mix for analysis
    all_rna = []
    all_met = []
    all_y = []
    
    for mix_name in dataset.mixes.keys():
        X_rna, X_met, y = get_training_data(dataset, mix_name)
        all_rna.append(X_rna)
        all_met.append(X_met)
        all_y.append(y)
    
    # Row-wise concatenation across mixes
    X_rna_all = pd.concat(all_rna, axis=0)
    X_met_all = pd.concat(all_met, axis=0)
    y_all = pd.concat(all_y, axis=0)
    
    print(f"\n📊 Dataset original: {len(y_all)} muestras")
    
    # Class-distribution analysis (the returned value is not used again in this block)
    class_stats = analyze_class_distribution(y_all, verbose=True)
    
    # ================================================================
    # MODEL 5: Mixup + Ensemble
    # ================================================================
    print("\n" + "=" * 60)
    print("🚀 MODELO 5: Mixup Composicional + Ensemble")
    print("=" * 60)
    
    # ---- Experiment parameters ----
    N_SYNTHETIC = 75  # synthetic samples requested from mixup; original note says this doubles the dataset — TODO confirm
    N_ENSEMBLE = 5    # number of ensemble members
    ENSEMBLE_SEEDS = [42, 123, 456, 789, 1024]
    
    # ---- 1. Train/test split first (before mixup) ----
    print("\n📊 Paso 1: Split train/test")
    n_original = len(y_all)
    n_test = int(n_original * 0.2)
    
    # Seeded permutation of row positions: first n_test -> test, rest -> train
    np.random.seed(42)
    original_indices = np.random.permutation(n_original)
    test_idx = original_indices[:n_test]
    train_idx_original = original_indices[n_test:]
    
    # Positional .iloc selection (original note: avoids problems with duplicated
    # index labels); reset_index(drop=True) gives each split a fresh index.
    # Train: original samples only (test rows excluded)
    X_rna_train_orig = X_rna_all.iloc[train_idx_original].reset_index(drop=True)
    X_met_train_orig = X_met_all.iloc[train_idx_original].reset_index(drop=True)
    y_train_orig = y_all.iloc[train_idx_original].reset_index(drop=True)
    
    # Test: original samples only
    X_rna_test = X_rna_all.iloc[test_idx].reset_index(drop=True)
    X_met_test = X_met_all.iloc[test_idx].reset_index(drop=True)
    y_test = y_all.iloc[test_idx].reset_index(drop=True)
    
    # ---- Step 2. Apply mixup to the training data only ----
    print("\n📐 Paso 2: Data Augmentation (Mixup)")
    X_rna_train, X_met_train, y_train = mixup_from_dataframes(
        X_rna_train_orig, X_met_train_orig, y_train_orig,
        n_synthetic=N_SYNTHETIC,
        alpha_range=(0.25, 0.75),
        seed=42,
        verbose=True
    )
    
    print(f"   Train final: {len(y_train)} muestras ({len(y_train_orig)} orig + {N_SYNTHETIC} synth)")
    print(f"   Test: {len(y_test)} muestras (solo originales)")
    
    # ---- Step 3. Train the ensemble ----
    print("\n🎲 Paso 3: Entrenando Ensemble")
    ensemble_models = train_ensemble_model(
        X_rna_train, X_met_train, y_train,
        n_models=N_ENSEMBLE,
        seeds=ENSEMBLE_SEEDS,
        balance_features=True,
        feature_selection_method='variance',
        verbose=True
    )
    
    # ---- Step 4. Predict with the ensemble (mean aggregation) ----
    print("\n🔮 Paso 4: Predicción con Ensemble")
    y_pred = predict_ensemble(ensemble_models, X_rna_test, X_met_test, method='mean')
    
    # ---- Step 5. Evaluate ----
    print("\n" + "=" * 60)
    print("📊 MÉTRICAS DE EVALUACIÓN (Modelo 5)")
    print("=" * 60)
    
    # Global metrics, computed over the flattened target/prediction matrices
    rmse = np.sqrt(mean_squared_error(y_test.values.flatten(), y_pred.values.flatten()))
    mae = mean_absolute_error(y_test.values.flatten(), y_pred.values.flatten())
    r2 = r2_score(y_test.values.flatten(), y_pred.values.flatten())
    
    print(f"   RMSE global: {rmse:.4f}")
    print(f"   MAE global:  {mae:.4f}")
    print(f"   R² global:   {r2:.4f}")
    
    # Per-cell-type RMSE (one value per column of y_test)
    print(f"\n   RMSE por tipo celular:")
    for cell_type in y_test.columns:
        cell_rmse = np.sqrt(mean_squared_error(y_test[cell_type], y_pred[cell_type]))
        print(f"      {cell_type}: {cell_rmse:.4f}")
    
    # ---- Step 6. Correlations (original note: important for scoring) ----
    print("\n" + "=" * 60)
    print("📈 CORRELACIONES (1/3 del score)")
    print("=" * 60)
    
    # Imported here, so scipy is only needed when the module runs as a script
    from scipy.stats import pearsonr, spearmanr
    
    # Overall correlation on the flattened matrices
    pearson_total, _ = pearsonr(y_test.values.flatten(), y_pred.values.flatten())
    spearman_total, _ = spearmanr(y_test.values.flatten(), y_pred.values.flatten())
    print(f"   Total matrix - Pearson: {pearson_total:.4f}, Spearman: {spearman_total:.4f}")
    
    # Per-sample correlation: one Pearson value per row of y_test / y_pred.
    # NOTE(review): the original note labels samples as "columns", presumably
    # the orientation of the challenge's score matrix — confirm.
    sample_correlations = []
    for i in range(len(y_test)):
        corr, _ = pearsonr(y_test.iloc[i], y_pred.iloc[i])
        sample_correlations.append(corr)
    print(f"   Por muestras - Pearson medio: {np.mean(sample_correlations):.4f} (±{np.std(sample_correlations):.4f})")
    
    # Per-cell-type correlation: one Pearson value per column of y_test / y_pred
    # (labelled "rows" in the original note — same orientation caveat as above).
    cell_correlations = []
    for cell_type in y_test.columns:
        corr, _ = pearsonr(y_test[cell_type], y_pred[cell_type])
        cell_correlations.append(corr)
        print(f"      {cell_type}: {corr:.4f}")
    print(f"   Por cell types - Pearson medio: {np.mean(cell_correlations):.4f}")
    
    # ---- 7. Feature importance of the first ensemble member ----
    print("\n" + "=" * 60)
    print("🔍 Top 20 Features más importantes (modelo base)")
    print("=" * 60)
    importance = ensemble_models[0].get_feature_importance(top_n=20)
    print(importance.to_string())
    
    # Count RNA vs MET features in the top 20, using the 'rna_' / 'met_' name prefixes
    top20 = importance['feature'].tolist()
    n_rna_top = sum(1 for f in top20 if f.startswith('rna_'))
    n_met_top = sum(1 for f in top20 if f.startswith('met_'))
    print(f"\n   📊 Distribución Top 20: {n_rna_top} RNA, {n_met_top} MET")
    
    # ---- 8. Save the ensemble (whole list of models pickled to MODEL_DIR) ----
    print("\n💾 Guardando ensemble...")
    ensemble_path = MODEL_DIR / "xgb_ensemble_v5.pkl"
    with open(ensemble_path, 'wb') as f:
        pickle.dump(ensemble_models, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"   Guardado en: {ensemble_path}")
    
    # Elapsed time, reported as whole minutes plus fractional seconds
    elapsed_time = time.time() - start_time
    minutes = int(elapsed_time // 60)
    seconds = elapsed_time % 60
    
    print("\n" + "=" * 60)
    print(f"✅ Proceso completado")
    print(f"⏱️  Tiempo de ejecución: {minutes}m {seconds:.2f}s")
    print("=" * 60)

