"""
Exporta muestras (head 10) de los datos HDF5 a archivos de texto.
"""
import h5py
import numpy as np
from pathlib import Path

# HDF5 inputs are read from DATA_DIR and text previews are written to
# OUTPUT_DIR; both are resolved relative to the current working directory.
DATA_DIR, OUTPUT_DIR = Path("data"), Path("encoded_data_format")
# Created at import time so every export_* function can write into it.
OUTPUT_DIR.mkdir(exist_ok=True)

def decode_if_bytes(arr, encoding='utf-8'):
    """Convert an array read from HDF5 into a list of native Python values.

    Byte strings are decoded to ``str``. This covers fixed-length byte strings
    (dtype kind ``'S'``) and also variable-length strings, which h5py 3.x
    returns as object arrays holding ``bytes``. All other elements go through
    ``ndarray.tolist()``, so numpy scalars become plain ``int``/``float``/``str``
    and print cleanly when the list is formatted into the report.

    Args:
        arr: array-like, typically a dataset slice such as ``group['genes'][:]``.
        encoding: text encoding used to decode byte strings.

    Returns:
        A list of Python values. Arrays with ndim > 1 yield nested lists, and a
        0-d array yields a one-element list.
    """
    def _convert(value):
        # Recurse into nested lists so multi-dimensional byte arrays decode too.
        if isinstance(value, bytes):
            return value.decode(encoding)
        if isinstance(value, list):
            return [_convert(v) for v in value]
        return value

    values = np.asarray(arr).tolist()
    if not isinstance(values, list):  # 0-d array: tolist() returns a bare scalar
        values = [values]
    return [_convert(v) for v in values]

def export_reference_data():
    """Write a preview of ``data/ref.h5`` to ``reference-data-format.txt``.

    For every top-level entry of the reference file, the report lists:

    - its keys;
    - the shape and dtype of its ``data`` dataset, with a 10 x 10 head;
    - the first 10 row names (``genes``);
    - the first 10 column names from ``cell_types``, ``celltypes``,
      ``samples`` or ``colnames``, whichever exist.

    Top-level entries that are plain datasets are described by shape and dtype
    only. Only the previewed slices are read from disk, so large matrices are
    never loaded in full.
    """
    ref_path = DATA_DIR / "ref.h5"
    output_path = OUTPUT_DIR / "reference-data-format.txt"

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write("=" * 80 + "\n")
        f.write("DATOS DE REFERENCIA (ref.h5)\n")
        f.write("Contiene perfiles de expresion por tipo celular\n")
        f.write("=" * 80 + "\n\n")

        with h5py.File(ref_path, 'r') as h5:
            for ref_type, group in h5.items():
                f.write(f"\n{'='*60}\n")
                f.write(f"[{ref_type}]\n")
                f.write(f"{'='*60}\n")

                # A top-level dataset has no sub-keys: describe it and move on
                # instead of failing on ``group.keys()``.
                if not isinstance(group, h5py.Group):
                    f.write(f"Shape: {group.shape}\n")
                    f.write(f"dtype: {group.dtype}\n\n")
                    continue

                f.write(f"Keys disponibles: {list(group.keys())}\n\n")

                # Main matrix
                if 'data' in group:
                    dset = group['data']
                    f.write(f"Shape de 'data': {dset.shape}\n")
                    f.write(f"dtype: {dset.dtype}\n\n")

                    f.write("Head 10 filas x 10 columnas:\n")
                    # Slice the h5py Dataset directly so only the head is read;
                    # scalar datasets cannot be sliced and are read with [()].
                    if dset.ndim == 0:
                        head = dset[()]
                    elif dset.ndim == 2:
                        head = dset[:10, :10]
                    else:
                        head = dset[:10]
                    f.write(np.array2string(np.asarray(head), precision=4, suppress_small=True))
                    f.write("\n\n")

                # Row names (genes): decode only the 10 shown, count via len().
                if 'genes' in group:
                    genes_dset = group['genes']
                    genes = decode_if_bytes(genes_dset[:10])
                    f.write(f"Genes (primeros 10): {genes}\n")
                    f.write(f"Total genes: {len(genes_dset)}\n\n")

                # Column names (cell types or samples), under whichever key exists.
                for col_key in ('cell_types', 'celltypes', 'samples', 'colnames'):
                    if col_key in group:
                        col_dset = group[col_key]
                        cols = decode_if_bytes(col_dset[:10])
                        f.write(f"{col_key}: {cols}\n")
                        f.write(f"Total {col_key}: {len(col_dset)}\n\n")

    print(f"[OK] Guardado: {output_path}")

def export_ground_truth():
    """Write a preview of every ``data/groundtruth*.h5`` file to ``ground-truth.txt``.

    Each file's ``groundtruth`` group is summarised; the file root is used when
    that group is absent or is not a group. The summary contains:

    - its keys;
    - the shape and dtype of ``data``, with its first 10 rows;
    - the per-row sums of those rows, when ``data`` is numeric and has ndim >= 2;
    - the name lists stored under ``genes``, ``cell_types`` or ``celltypes``;
    - the first 10 ``samples``.

    Only the first 10 rows of ``data`` are read from disk.
    """
    output_path = OUTPUT_DIR / "ground-truth.txt"

    gt_files = sorted(DATA_DIR.glob("groundtruth*.h5"))

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write("=" * 80 + "\n")
        f.write("GROUND TRUTH (Proporciones reales)\n")
        f.write("Cada columna es un tipo celular, cada fila una muestra\n")
        f.write("Los valores son proporciones (deben sumar 1 por fila)\n")
        f.write("=" * 80 + "\n\n")

        # Make an empty report explicit instead of silently writing only a header.
        if not gt_files:
            f.write(f"(No se encontraron archivos groundtruth*.h5 en {DATA_DIR})\n")

        for gt_path in gt_files:
            f.write(f"\n{'='*60}\n")
            f.write(f"[{gt_path.name}]\n")
            f.write(f"{'='*60}\n")

            with h5py.File(gt_path, 'r') as h5:
                # Proportions usually live under /groundtruth/; fall back to the
                # file root when that entry is missing or is not a group.
                group = h5.get('groundtruth')
                if not isinstance(group, h5py.Group):
                    group = h5

                f.write(f"Keys: {list(group.keys())}\n\n")

                if 'data' in group:
                    dset = group['data']
                    f.write(f"Shape: {dset.shape} (muestras x tipos_celulares)\n")
                    f.write(f"dtype: {dset.dtype}\n\n")

                    # Read only the first 10 rows; [:10] works for any ndim >= 1,
                    # and scalar datasets are read with [()].
                    f.write("Head 10 filas:\n")
                    head = np.asarray(dset[()] if dset.ndim == 0 else dset[:10])
                    f.write(np.array2string(head, precision=4, suppress_small=True))
                    f.write("\n\n")

                    # Row sums of the shown rows (each should be close to 1);
                    # only defined for numeric data with a column axis.
                    if head.ndim >= 2 and np.issubdtype(head.dtype, np.number):
                        row_sums = head.sum(axis=1)
                        f.write(f"Suma por fila (primeras 10): {row_sums}\n\n")

                # Cell-type names (stored under one of these keys)
                for key in ['genes', 'cell_types', 'celltypes']:
                    if key in group:
                        names = decode_if_bytes(group[key][:])
                        f.write(f"{key}: {names}\n")

                # Sample names: only the 10 shown are read.
                if 'samples' in group:
                    samples = decode_if_bytes(group['samples'][:10])
                    f.write(f"samples (primeras 10): {samples}\n")

    print(f"[OK] Guardado: {output_path}")

def _write_mix_head(f, dset):
    """Write up to 10 rows x 5 columns of the h5py dataset ``dset`` to ``f``.

    Numeric 1-D/2-D data is formatted as fixed-width floats. Other
    non-scalar data is listed via ``decode_if_bytes``, and scalar datasets
    are printed as a single value. Only the previewed slice is read from disk.
    """
    f.write("  Head 10 filas x 5 columnas:\n")
    is_numeric = np.issubdtype(dset.dtype, np.number)
    if dset.ndim == 2 and is_numeric:
        for row in dset[:10, :5]:
            f.write("    " + " ".join(f"{x:10.4f}" for x in row) + "\n")
    elif dset.ndim == 1 and is_numeric:
        f.write("    " + " ".join(f"{x:10.4f}" for x in dset[:10]) + "\n")
    elif dset.ndim >= 1:
        # Strings or other non-numeric values: list them rather than format as floats.
        f.write(f"    {decode_if_bytes(dset[:10])}\n")
    else:
        # Scalar datasets cannot be sliced; [()] reads the single value.
        f.write(f"    Valor: {dset[()]}\n")
    f.write("\n")


def _write_mix_group(f, group):
    """Describe one modality group of a mixes file (sub-keys, data head, genes, samples)."""
    f.write(f"  Sub-keys: {list(group.keys())}\n")

    if 'data' in group:
        dset = group['data']
        f.write(f"  Shape: {dset.shape} (genes x muestras)\n")
        f.write(f"  dtype: {dset.dtype}\n")
        _write_mix_head(f, dset)

    # Genes: decode only the 10 shown, count via len().
    if 'genes' in group:
        genes_dset = group['genes']
        f.write(f"  genes (primeros 10): {decode_if_bytes(genes_dset[:10])}\n")
        f.write(f"  Total genes: {len(genes_dset)}\n")

    if 'samples' in group:
        f.write(f"  samples: {decode_if_bytes(group['samples'][:10])}\n")

    f.write("\n")


def export_mixes():
    """Write a preview of every ``data/mixes*.h5`` file to ``convulsioned.txt``.

    For each file, every top-level modality (e.g. ``mix_rna``, ``mix_met``) is
    described. Groups get their sub-keys, a small ``data`` head, and their
    gene/sample names. Plain datasets get their shape and dtype only.
    """
    # NOTE(review): 'convulsioned' looks like a typo for 'convolved'; the name is
    # kept because other tools may read this path.
    output_path = OUTPUT_DIR / "convulsioned.txt"

    mix_files = sorted(DATA_DIR.glob("mixes*.h5"))

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write("=" * 80 + "\n")
        f.write("DATOS A DECONVOLUCIONAR (mixes)\n")
        f.write("Cada columna es una muestra bulk, cada fila un gen\n")
        f.write("Tu modelo debe estimar las proporciones de tipos celulares\n")
        f.write("=" * 80 + "\n\n")

        # Make an empty report explicit instead of silently writing only a header.
        if not mix_files:
            f.write(f"(No se encontraron archivos mixes*.h5 en {DATA_DIR})\n")

        for mix_path in mix_files:
            f.write(f"\n{'='*60}\n")
            f.write(f"[{mix_path.name}]\n")
            f.write(f"{'='*60}\n")

            with h5py.File(mix_path, 'r') as h5:
                f.write(f"Keys principales: {list(h5.keys())}\n\n")

                for modal_key, member in h5.items():
                    f.write(f"--- {modal_key} ---\n")

                    if isinstance(member, h5py.Group):
                        _write_mix_group(f, member)
                    elif isinstance(member, h5py.Dataset):
                        f.write(f"  Shape: {member.shape}\n")
                        f.write(f"  dtype: {member.dtype}\n\n")

    print(f"[OK] Guardado: {output_path}")

def main():
    """Run the reference, ground-truth and mixes exporters in order, with progress banners."""
    separator = "-" * 40
    print("Exportando muestras de datos...")
    print(separator)

    for exporter in (export_reference_data, export_ground_truth, export_mixes):
        exporter()

    print(separator)
    print("Completado!")


if __name__ == "__main__":
    main()
