"""
HADACA3 - Submission Script
============================
Generates prediction files and submission ZIP for Codabench.

Usage:
    python submission_script.py

Output:
    - ../submissions/program_YYYY_MM_DD_HH_MM_SS.zip (code submission)
    - ../submissions/results_YYYY_MM_DD_HH_MM_SS.zip (prediction submission)
"""

import subprocess
import sys
import importlib
import os
import zipfile
import shutil
import time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

# Enable finding modules in parent directory
sys.path.append(str(Path(__file__).resolve().parent.parent))
# Also include src if needed by local scripts
sys.path.append(str(Path(__file__).resolve().parent.parent / "src"))

# Prefer the improved program (program_v4) for submission and fall back to the
# default program.py when it cannot be imported.
try:
    from program_v4 import program
    PROGRAM_FILE = "program_v4.py"
    print(f"Using improved program: {PROGRAM_FILE}")
except ImportError as exc:
    # A ModuleNotFoundError naming program_v4 itself means the file is simply
    # absent: that is the expected, quiet fallback. Any other ImportError means
    # program_v4 exists but is broken (missing dependency, no `program` name).
    # Falling back silently would submit different code than intended, so say so.
    if not (isinstance(exc, ModuleNotFoundError) and exc.name == "program_v4"):
        print(f"WARNING: program_v4 exists but failed to import ({exc!r}); "
              f"falling back to program.py")
    from program import program
    PROGRAM_FILE = "program.py"
    print(f"Using default program: {PROGRAM_FILE}")

import numpy as np
import pandas as pd
import h5py

# =============================================================================
# DATA I/O FUNCTIONS
# =============================================================================

def read_hdf5(filepath):
    """Read an HDF5 file and return a dictionary of DataFrames.

    Every top-level *group* containing a ``data`` dataset becomes one DataFrame,
    keyed by the group name. Other top-level objects (plain datasets, groups
    without ``data``) are skipped.

    Layout handling:

    * If ``data`` is a structured (compound) array, each field name is taken
      as a sample label. The fields are stacked into a 2-D array with one row
      per field.
    * Groups whose name contains ``'ref'`` keep the rows of ``data`` as the
      index (labels from ``cell_types``). Columns are labelled from ``genes``
      or ``CpG_sites``.
    * All other (mix) groups treat ``data`` as (samples, features) and are
      transposed. Rows are features (``genes`` / ``CpG_sites``) and columns
      are samples (structured field names, else ``samples``).

    Missing label datasets fall back to ``feat_i`` / ``cell_i`` / ``sample_i``
    placeholders.

    Args:
        filepath: Path to the HDF5 file (opened read-only).

    Returns:
        dict mapping group name -> pandas.DataFrame.
    """

    def decode_strings(arr):
        """Convert an array of labels to a list of ``str`` (bytes decoded as UTF-8)."""
        if arr.dtype.kind in ('S', 'O'):
            return [x.decode('utf-8') if isinstance(x, bytes) else str(x) for x in arr]
        return [str(x) for x in arr]

    def feature_labels(grp, n_features):
        """Return labels from ``genes`` or ``CpG_sites`` (in that order), else ``feat_i``."""
        for key in ('genes', 'CpG_sites'):
            if key in grp:
                return decode_strings(grp[key][:])
        return [f"feat_{i}" for i in range(n_features)]

    data = {}
    with h5py.File(filepath, 'r') as f:
        for group_name, grp in f.items():
            # Only groups can hold 'data' plus label datasets. Probing a
            # top-level Dataset with `in` would iterate its rows instead.
            if not isinstance(grp, h5py.Group) or 'data' not in grp:
                continue

            values = grp['data'][:]

            # Structured arrays (mix data): one field per sample -> stack the
            # fields into a 2-D array with one row per sample.
            if values.dtype.names is not None:
                samples = list(values.dtype.names)
                values = np.array([values[name] for name in samples])
            else:
                samples = None

            if 'ref' in group_name:
                # Reference data: rows of `values` are cell types, columns are features.
                col_names = feature_labels(grp, values.shape[1])
                if 'cell_types' in grp:
                    row_names = decode_strings(grp['cell_types'][:])
                else:
                    row_names = [f"cell_{i}" for i in range(values.shape[0])]
                df = pd.DataFrame(values, index=row_names, columns=col_names)
            else:
                # Mix data: `values` is treated as (samples, features) and transposed.
                # values.T has values.shape[1] rows in BOTH the structured and the
                # plain layout, so the feature fallback must use shape[1].
                row_names = feature_labels(grp, values.shape[1])
                if samples is not None:
                    col_names = samples
                elif 'samples' in grp:
                    col_names = decode_strings(grp['samples'][:])
                else:
                    col_names = [f"sample_{i}" for i in range(values.shape[0])]
                df = pd.DataFrame(values.T, index=row_names, columns=col_names)

            data[group_name] = df

    return data


def write_hdf5(filepath, data_dict):
    """Write a dictionary of DataFrames to an HDF5 file.

    Each entry becomes a top-level group named after its key, containing:

    * ``data``    -- ``df.values`` stored as-is (no transpose): rows follow
      ``df.index``, columns follow ``df.columns``.
    * ``genes``   -- the row labels (``df.index``) as UTF-8 byte strings.
    * ``samples`` -- the column labels (``df.columns``) as UTF-8 byte strings.

    This orientation is not the exact inverse of ``read_hdf5``, which transposes
    ``data`` for non-"ref" groups. It is kept because it is the layout the
    submission has been producing.
    NOTE(review): confirm against the Codabench scoring program.

    Args:
        filepath: Destination path. An existing file is overwritten (mode ``'w'``).
        data_dict: Mapping of group name -> pandas DataFrame.
    """

    def encode_labels(labels):
        # str() first so non-string labels (ints, numpy scalars) don't crash
        # .encode(). dtype=bytes keeps an empty label list a valid string dataset.
        return np.array([str(label).encode('utf-8') for label in labels], dtype=bytes)

    with h5py.File(filepath, 'w') as f:
        for name, df in data_dict.items():
            grp = f.create_group(name)
            # Create 'data' exactly once: a second create_dataset() with the
            # same name on this group raises ValueError (name already exists).
            grp.create_dataset('data', data=df.values)
            grp.create_dataset('genes', data=encode_labels(df.index))
            grp.create_dataset('samples', data=encode_labels(df.columns))


# =============================================================================
# MAIN EXECUTION
# =============================================================================

if __name__ == "__main__":
    print("=" * 60)
    print("HADACA3 - Generating Submission from Organized Structure")
    print("=" * 60)

    # Root directory (parent of submission/)
    root_dir = Path(__file__).resolve().parent.parent
    data_dir = root_dir / "data"
    submissions_dir = root_dir / "submissions"
    submissions_dir.mkdir(exist_ok=True)

    if not data_dir.exists():
        print(f"ERROR: Data directory '{data_dir}' not found.")
        sys.exit(1)

    # Load reference
    ref_file = data_dir / "ref.h5"
    if not ref_file.exists():
        print(f"ERROR: Reference file '{ref_file}' not found.")
        sys.exit(1)

    print(f"\nLoading reference: {ref_file}")
    reference_data = read_hdf5(ref_file)

    # Find all mix files
    datasets_list = list(data_dir.glob("mixes*.h5"))

    if not datasets_list:
        print("ERROR: No mix files found in data directory.")
        sys.exit(1)

    print(f"Found {len(datasets_list)} mix files")

    # Generate predictions
    def process_single_dataset(mix_file_path, reference_data):
        """Load one mix file, run `program` on it and return (dataset_name, prediction)."""
        mix_file = Path(mix_file_path)
        t0 = time.time()
        print(f"\n  [START] {mix_file.name}")
        mixes_data = read_hdf5(mix_file)
        dataset_name = mix_file.stem.replace("mixes1_", "").replace("mixes_", "").replace("_pdac", "")

        pred = program(
            mix_rna=mixes_data.get("mix_rna"),
            ref_bulkRNA=reference_data.get("ref_bulkRNA"),
            mix_met=mixes_data.get("mix_met"),
            ref_met=reference_data.get("ref_met")
        )
        elapsed = time.time() - t0
        print(f"  [DONE]  {mix_file.name} -> {pred.shape} ({elapsed:.1f}s)")
        return dataset_name, pred

    def collect_predictions(results):
        """Build {dataset_name: prediction} in iteration order, warning on name collisions."""
        collected = {}
        for dataset_name, pred in results:
            # Different files can strip to the same name (e.g. mixes1_X_pdac / mixes_X).
            if dataset_name in collected:
                print(f"WARNING: dataset name '{dataset_name}' produced by more than "
                      f"one mix file; keeping the last one.")
            collected[dataset_name] = pred
        return collected

    sorted_files = sorted(datasets_list)
    n_workers = min(len(sorted_files), 3)

    try:
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            futures = [
                executor.submit(process_single_dataset, str(f), reference_data)
                for f in sorted_files
            ]
            # Read results in submission (sorted) order, not completion order,
            # so the group order in prediction.h5 is reproducible across runs.
            predictions = collect_predictions(future.result() for future in futures)
    except Exception as e:
        # Best-effort: e.g. a `program` that is not thread-safe. Redo everything serially.
        print(f"\nParallel processing failed ({e}), falling back to sequential...")
        predictions = collect_predictions(
            process_single_dataset(mix_file, reference_data) for mix_file in sorted_files
        )

    # ==========================================================================
    # CREATE SUBMISSIONS
    # ==========================================================================

    timestamp = pd.Timestamp.now().strftime("%Y_%m_%d_%H_%M_%S")

    # --- Code Submission ---
    print("\n" + "=" * 60)
    print("Creating Code Submission")
    print("=" * 60)

    # We copy the actual program file content instead of using inspect.getsource
    # to ensure all imports and global variables are preserved.
    # We rename it to program.py inside the submission folder.
    # Use the file of the module that was actually imported (it may live in src/
    # or another sys.path entry), so the code submitted is the code that produced
    # the predictions. Fall back to the project root when no file is recorded.
    program_module = sys.modules.get(Path(PROGRAM_FILE).stem)
    module_file = getattr(program_module, "__file__", None)
    src_program_path = Path(module_file) if module_file else root_dir / PROGRAM_FILE
    dst_program_path = submissions_dir / "program.py"

    if not src_program_path.exists():
        print(f"ERROR: Source program file {src_program_path} not found!")
        sys.exit(1)

    shutil.copy(src_program_path, dst_program_path)
    print(f"Copied {src_program_path.name} to {dst_program_path}")

    # Create ZIP
    zip_program = submissions_dir / f"program_{timestamp}.zip"
    with zipfile.ZipFile(zip_program, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        zipf.write(dst_program_path, arcname="program.py")

    print(f"Code submission: {zip_program}")

    # --- Results Submission ---
    print("\n" + "=" * 60)
    print("Creating Results Submission")
    print("=" * 60)

    prediction_file = submissions_dir / "prediction.h5"
    write_hdf5(prediction_file, predictions)

    zip_results = submissions_dir / f"results_{timestamp}.zip"
    with zipfile.ZipFile(zip_results, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        zipf.write(prediction_file, arcname="prediction.h5")

    print(f"Results submission: {zip_results}")

    # ==========================================================================
    # SUMMARY
    # ==========================================================================

    print("\n" + "=" * 60)
    print("SUBMISSION COMPLETE")
    print("=" * 60)
    print("\nFiles created:")
    print(f"  - {zip_program}")
    print(f"  - {zip_results}")
