"""
HADACA3 - Data Loader & Normalizer
===================================
Módulo para cargar y normalizar los datos HDF5 para deconvolución celular.

Tipos de datos:
- Reference: Perfiles de expresión por tipo celular (firmas)
- Mixes: Datos bulk a deconvolucionar  
- Ground Truth: Proporciones reales (para entrenamiento/validación)

Nota: Se excluyen los datos 'invivo' que solo tienen 2 tipos celulares 
y valores atípicos que podrían afectar el modelo.
"""

import h5py
import numpy as np
import pandas as pd
import pickle
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


# =============================================================================
# CONFIGURACIÓN
# =============================================================================

DATA_DIR = Path(__file__).parent / "data"

# Tipos celulares esperados (5 tipos)
CELL_TYPES = ['endo', 'fibro', 'immune', 'classic', 'basal']

# Archivos a excluir por tener diferente número de tipos celulares
EXCLUDED_FILES = [
    # Solo 2 tipos celulares (basal, classic)
    'groundtruth1_invivo_pdac.h5', 
    'mixes1_invivo_pdac.h5',
    # Solo 4 tipos celulares (sin basal)
    'groundtruth1_insilicodirichletNoDep4CTsource_pdac.h5',
    'mixes1_insilicodirichletNoDep4CTsource_pdac.h5',
]


# =============================================================================
# DATA CLASSES
# =============================================================================

@dataclass
class ReferenceData:
    """Datos de referencia (firmas de expresión por tipo celular)."""
    bulk_rna: pd.DataFrame  # (n_cell_types x n_genes) - Expresión bulk RNA
    methylation: pd.DataFrame  # (n_cell_types x n_cpg_sites) - Metilación
    cell_types: List[str]
    genes: List[str]
    cpg_sites: List[str]
    

@dataclass
class MixData:
    """Datos de mezcla a deconvolucionar."""
    rna: pd.DataFrame  # (n_samples x n_genes)
    methylation: pd.DataFrame  # (n_samples x n_cpg_sites)
    sample_ids: List[str]
    source: str  # Nombre del archivo origen
    

@dataclass
class GroundTruth:
    """Proporciones reales de tipos celulares."""
    proportions: pd.DataFrame  # (n_samples x n_cell_types)
    sample_ids: Optional[List[str]]
    source: str


@dataclass
class DeconvolutionDataset:
    """Dataset completo para deconvolución."""
    reference: ReferenceData
    mixes: Dict[str, MixData]
    ground_truths: Dict[str, GroundTruth]
    

# =============================================================================
# FUNCIONES DE CARGA
# =============================================================================

def _decode_strings(arr: np.ndarray) -> List[str]:
    """Decodifica arrays de bytes a strings."""
    if arr.dtype.kind == 'S' or arr.dtype.kind == 'O':
        return [x.decode('utf-8') if isinstance(x, bytes) else str(x) for x in arr]
    return [str(x) for x in arr]


def load_reference_data(filepath: Optional[Path] = None) -> ReferenceData:
    """
    Carga los datos de referencia (firmas de expresión por tipo celular).
    
    Returns:
        ReferenceData con perfiles de RNA y metilación normalizados.
    """
    if filepath is None:
        filepath = DATA_DIR / "ref.h5"
    
    print(f"📂 Cargando referencia: {filepath.name}")
    
    with h5py.File(filepath, 'r') as f:
        # ---- Bulk RNA ----
        bulk_rna_grp = f['ref_bulkRNA']
        rna_data = bulk_rna_grp['data'][:]  # (5, 15908)
        genes = _decode_strings(bulk_rna_grp['genes'][:])
        cell_types = _decode_strings(bulk_rna_grp['cell_types'][:])
        
        # ---- Metilación ----
        met_grp = f['ref_met']
        met_data = met_grp['data'][:]  # (5, 23724)
        cpg_sites = _decode_strings(met_grp['CpG_sites'][:])
        
    # Crear DataFrames
    bulk_rna_df = pd.DataFrame(rna_data, index=cell_types, columns=genes)
    met_df = pd.DataFrame(met_data, index=cell_types, columns=cpg_sites)
    
    print(f"   ✓ RNA: {bulk_rna_df.shape} | Metilación: {met_df.shape}")
    print(f"   ✓ Tipos celulares: {cell_types}")
    
    return ReferenceData(
        bulk_rna=bulk_rna_df,
        methylation=met_df,
        cell_types=cell_types,
        genes=genes,
        cpg_sites=cpg_sites
    )


def load_mix_data(filepath: Path) -> MixData:
    """
    Carga datos de mezcla (bulk) a deconvolucionar.
    
    Returns:
        MixData con expresión RNA y metilación por muestra.
    """
    print(f"📂 Cargando mix: {filepath.name}")
    
    with h5py.File(filepath, 'r') as f:
        # ---- RNA ----
        rna_grp = f['mix_rna']
        rna_data = rna_grp['data'][:]
        genes = _decode_strings(rna_grp['genes'][:])
        
        # Obtener samples si existe
        if 'samples' in rna_grp:
            samples = _decode_strings(rna_grp['samples'][:])
        else:
            samples = [f"sample_{i}" for i in range(rna_data.shape[0])]
        
        # ---- Metilación ----
        met_grp = f['mix_met']
        met_data = met_grp['data'][:]
        cpg_sites = _decode_strings(met_grp['CpG_sites'][:])
    
    # Manejar datos estructurados (numpy structured array)
    if rna_data.dtype.names is not None:
        # Es un structured array - convertir a matriz regular
        # Cada columna del structured array es una muestra, cada fila un gen
        # Resultado: (samples x genes)
        samples = list(rna_data.dtype.names)
        rna_matrix = np.array([rna_data[name] for name in samples])  # (samples, genes)
        rna_data = rna_matrix
        
    if met_data.dtype.names is not None:
        met_samples = list(met_data.dtype.names)
        met_matrix = np.array([met_data[name] for name in met_samples])  # (samples, cpg_sites)
        met_data = met_matrix
        samples = met_samples  # Usar los samples de metilación si RNA era structured
    
    # Crear DataFrames (samples x features)
    rna_df = pd.DataFrame(rna_data, index=samples, columns=genes)
    met_df = pd.DataFrame(met_data, index=samples, columns=cpg_sites)
    
    print(f"   ✓ RNA: {rna_df.shape} | Metilación: {met_df.shape}")
    print(f"   ✓ Muestras: {len(samples)}")
    
    return MixData(
        rna=rna_df,
        methylation=met_df,
        sample_ids=samples,
        source=filepath.stem
    )


def load_ground_truth(filepath: Path) -> GroundTruth:
    """
    Carga ground truth (proporciones reales de tipos celulares).
    
    Returns:
        GroundTruth con proporciones normalizadas (suman 1 por fila).
    """
    print(f"📂 Cargando ground truth: {filepath.name}")
    
    with h5py.File(filepath, 'r') as f:
        # Detectar estructura: puede ser f['data'] o f['groundtruth']['data']
        if 'groundtruth' in f:
            # Nueva estructura: datos dentro del grupo 'groundtruth'
            grp = f['groundtruth']
        else:
            # Estructura antigua: datos en la raíz
            grp = f
        
        data = grp['data'][:]
        cell_types = _decode_strings(grp['genes'][:])  # En GT, 'genes' son los tipos celulares
        
        # Samples si existen
        samples = None
        if 'samples' in grp:
            samples = _decode_strings(grp['samples'][:])
        else:
            samples = [f"sample_{i}" for i in range(data.shape[0])]
    
    # Crear DataFrame
    proportions_df = pd.DataFrame(data, index=samples, columns=cell_types)
    
    # Verificar que suman 1
    row_sums = proportions_df.sum(axis=1)
    if not np.allclose(row_sums, 1.0, atol=1e-3):
        print(f"   ⚠️ ADVERTENCIA: Las proporciones no suman 1. Rango: [{row_sums.min():.4f}, {row_sums.max():.4f}]")
    else:
        print(f"   ✓ Proporciones válidas (suman 1.0)")
    
    print(f"   ✓ Shape: {proportions_df.shape}")
    print(f"   ✓ Tipos celulares: {cell_types}")
    
    return GroundTruth(
        proportions=proportions_df,
        sample_ids=samples,
        source=filepath.stem
    )


# =============================================================================
# NORMALIZACIÓN
# =============================================================================

def normalize_rna(df: pd.DataFrame, method: str = 'cpm') -> pd.DataFrame:
    """
    Normaliza datos de expresión RNA.
    
    Args:
        df: DataFrame con expresión (samples x genes)
        method: 'cpm' (counts per million), 'log' (log2 + 1), 'zscore'
    
    Returns:
        DataFrame normalizado
    """
    if method == 'cpm':
        # Counts per million: divide por suma de fila y multiplica por 1e6
        row_sums = df.sum(axis=1)
        normalized = df.div(row_sums, axis=0) * 1e6
    elif method == 'log':
        # Log2 transform (añadir 1 para evitar log(0))
        normalized = np.log2(df + 1)
    elif method == 'zscore':
        # Z-score por gen (columna)
        normalized = (df - df.mean()) / df.std()
    else:
        raise ValueError(f"Método no soportado: {method}")
    
    return normalized


def normalize_methylation(df: pd.DataFrame) -> pd.DataFrame:
    """
    Normaliza datos de metilación (ya deberían estar en [0, 1]).
    Aplica clipping para asegurar el rango.
    """
    return df.clip(0, 1)


# =============================================================================
# CARGADOR PRINCIPAL
# =============================================================================

def load_all_data(
    data_dir: Optional[Path] = None,
    exclude_invivo: bool = True,
    normalize: bool = True,
    rna_norm_method: str = 'log'
) -> DeconvolutionDataset:
    """
    Carga todos los datos necesarios para deconvolución.
    
    Args:
        data_dir: Directorio con archivos HDF5
        exclude_invivo: Si True, excluye datos 'invivo' (solo 2 tipos celulares)
        normalize: Si True, normaliza los datos
        rna_norm_method: Método de normalización RNA ('cpm', 'log', 'zscore')
    
    Returns:
        DeconvolutionDataset con todos los datos cargados y normalizados
    """
    if data_dir is None:
        data_dir = DATA_DIR
    
    print("=" * 60)
    print("HADACA3 - Cargando Dataset de Deconvolución")
    print("=" * 60)
    
    # ---- 1. Cargar referencia ----
    reference = load_reference_data(data_dir / "ref.h5")
    
    # ---- 2. Cargar mixes y ground truths ----
    mixes = {}
    ground_truths = {}
    
    for h5_file in sorted(data_dir.glob("*.h5")):
        filename = h5_file.name
        
        # Excluir archivos no deseados
        if exclude_invivo and filename in EXCLUDED_FILES:
            print(f"⏭️  Excluyendo: {filename} (solo 2 tipos celulares)")
            continue
        
        # Cargar ground truths
        if filename.startswith("groundtruth"):
            gt = load_ground_truth(h5_file)
            ground_truths[gt.source] = gt
        
        # Cargar mixes
        elif filename.startswith("mixes"):
            mix = load_mix_data(h5_file)
            mixes[mix.source] = mix
    
    # ---- 3. Normalización ----
    if normalize:
        print("\n" + "-" * 60)
        print("🔄 Normalizando datos...")
        
        # Normalizar referencia
        reference.bulk_rna = normalize_rna(reference.bulk_rna, method=rna_norm_method)
        reference.methylation = normalize_methylation(reference.methylation)
        
        # Normalizar mixes
        for key, mix in mixes.items():
            mixes[key].rna = normalize_rna(mix.rna, method=rna_norm_method)
            mixes[key].methylation = normalize_methylation(mix.methylation)
        
        print(f"   ✓ RNA normalizado con: {rna_norm_method}")
        print(f"   ✓ Metilación clipped a [0, 1]")
    
    # ---- 4. Resumen ----
    print("\n" + "=" * 60)
    print("📊 RESUMEN DEL DATASET")
    print("=" * 60)
    print(f"Referencia:")
    print(f"   - RNA: {reference.bulk_rna.shape}")
    print(f"   - Metilación: {reference.methylation.shape}")
    print(f"   - Tipos celulares: {reference.cell_types}")
    
    print(f"\nMixes cargados: {len(mixes)}")
    for name, mix in mixes.items():
        print(f"   - {name}: {mix.rna.shape[0]} muestras")
    
    print(f"\nGround truths cargados: {len(ground_truths)}")
    for name, gt in ground_truths.items():
        print(f"   - {name}: {gt.proportions.shape}")
    
    return DeconvolutionDataset(
        reference=reference,
        mixes=mixes,
        ground_truths=ground_truths
    )


# =============================================================================
# SERIALIZACIÓN
# =============================================================================

CACHE_FILE = DATA_DIR / "dataset_cache.pkl"


def save_dataset(dataset: DeconvolutionDataset, filepath: Optional[Path] = None) -> Path:
    """
    Serializa el dataset a un archivo pickle para carga rápida.
    
    Args:
        dataset: Dataset a serializar
        filepath: Ruta del archivo (por defecto: data/dataset_cache.pkl)
    
    Returns:
        Path del archivo guardado
    """
    if filepath is None:
        filepath = CACHE_FILE
    
    print(f"💾 Guardando dataset en: {filepath}")
    with open(filepath, 'wb') as f:
        pickle.dump(dataset, f, protocol=pickle.HIGHEST_PROTOCOL)
    
    size_mb = filepath.stat().st_size / (1024 * 1024)
    print(f"   ✓ Guardado ({size_mb:.2f} MB)")
    
    return filepath


def load_dataset(filepath: Optional[Path] = None) -> Optional[DeconvolutionDataset]:
    """
    Carga el dataset desde un archivo pickle serializado.
    
    Args:
        filepath: Ruta del archivo (por defecto: data/dataset_cache.pkl)
    
    Returns:
        DeconvolutionDataset si existe el archivo, None si no
    """
    if filepath is None:
        filepath = CACHE_FILE
    
    if not filepath.exists():
        print(f"⚠️ Cache no encontrado: {filepath}")
        return None
    
    print(f"📂 Cargando dataset desde cache: {filepath}")
    with open(filepath, 'rb') as f:
        dataset = pickle.load(f)
    
    print(f"   ✓ Cargado desde cache")
    return dataset


def get_or_load_dataset(
    force_reload: bool = False,
    normalize: bool = True,
    rna_norm_method: str = 'log'
) -> DeconvolutionDataset:
    """
    Obtiene el dataset desde cache si existe, o lo carga y serializa.
    
    Args:
        force_reload: Si True, ignora el cache y recarga desde HDF5
        normalize: Si True, normaliza los datos
        rna_norm_method: Método de normalización RNA
    
    Returns:
        DeconvolutionDataset listo para usar
    """
    if not force_reload:
        dataset = load_dataset()
        if dataset is not None:
            return dataset
    
    # Cargar desde HDF5 y serializar
    dataset = load_all_data(normalize=normalize, rna_norm_method=rna_norm_method)
    save_dataset(dataset)
    
    return dataset


# =============================================================================
# FUNCIONES DE UTILIDAD
# =============================================================================

def align_samples(mix: MixData, ground_truth: GroundTruth) -> Tuple[MixData, GroundTruth]:
    """
    Alinea muestras entre mix y ground truth por ID.
    Útil cuando hay muestras que no coinciden.
    """
    common_samples = list(set(mix.sample_ids) & set(ground_truth.sample_ids or []))
    
    if not common_samples:
        # Si no hay samples explícitos, asumir orden coincidente
        return mix, ground_truth
    
    # Filtrar a samples comunes
    mix.rna = mix.rna.loc[common_samples]
    mix.methylation = mix.methylation.loc[common_samples]
    mix.sample_ids = common_samples
    
    ground_truth.proportions = ground_truth.proportions.loc[common_samples]
    ground_truth.sample_ids = common_samples
    
    return mix, ground_truth


def get_training_data(
    dataset: DeconvolutionDataset,
    mix_name: str
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Obtiene datos de entrenamiento (X_rna, X_met, y) para un mix específico.
    
    Returns:
        X_rna: Expresión RNA (samples x genes)
        X_met: Metilación (samples x cpg_sites)  
        y: Proporciones reales (samples x cell_types)
    """
    # Encontrar el ground truth correspondiente
    gt_name = mix_name.replace("mixes1", "groundtruth1")
    
    if mix_name not in dataset.mixes:
        raise ValueError(f"Mix '{mix_name}' no encontrado. Disponibles: {list(dataset.mixes.keys())}")
    
    if gt_name not in dataset.ground_truths:
        raise ValueError(f"Ground truth '{gt_name}' no encontrado. Disponibles: {list(dataset.ground_truths.keys())}")
    
    mix = dataset.mixes[mix_name]
    gt = dataset.ground_truths[gt_name]
    
    return mix.rna, mix.methylation, gt.proportions


# =============================================================================
# MAIN (para testing)
# =============================================================================

if __name__ == "__main__":
    # Cargar todos los datos
    dataset = load_all_data(normalize=True, rna_norm_method='log')
    
    print("\n" + "=" * 60)
    print("TEST: Obteniendo datos de entrenamiento")
    print("=" * 60)
    
    # Ejemplo: obtener datos para insilico
    for mix_name in dataset.mixes.keys():
        try:
            X_rna, X_met, y = get_training_data(dataset, mix_name)
            print(f"\n{mix_name}:")
            print(f"   X_rna shape: {X_rna.shape}")
            print(f"   X_met shape: {X_met.shape}")
            print(f"   y shape: {y.shape}")
            print(f"   y columnas: {list(y.columns)}")
            print(f"   y suma por fila (primeras 5): {y.sum(axis=1).head().values}")
        except ValueError as e:
            print(f"   ⚠️ {e}")
