import argparse
import json
from pathlib import Path

import h5py
import numpy as np
import pandas as pd


def _read_csv(path: Path) -> np.ndarray:
    """Load a header-less CSV file into a NumPy array.

    Every line of the file, including the first one, is treated as data.
    """
    frame = pd.read_csv(path, header=None)
    return frame.to_numpy()


def _infer_reference_dir(input_dir: Path) -> Path:
    """Guess where ``reference_data`` lives relative to ``input_dir``.

    When ``input_dir`` is itself named ``input_data``, ``reference_data`` is
    taken to be its sibling; otherwise it is a subdirectory of ``input_dir``.
    """
    base = input_dir.parent if input_dir.name == "input_data" else input_dir
    return base / "reference_data"


def _resolve_paths(input_dir: Path, reference_dir: Path):
    """Build the expected file paths and check that the required ones exist.

    Only ``train.csv`` and ``train_labels.csv`` are mandatory; the adaptation,
    test and test-label paths are returned whether or not they exist.

    Raises:
        FileNotFoundError: Listing every missing mandatory file.
    """
    paths = dict(
        train_csv=input_dir / "train.csv",
        train_labels_csv=input_dir / "train_labels.csv",
        adapt_csv=input_dir / "train_DA.csv",
        test_csv=input_dir / "test.csv",
        test_labels_csv=reference_dir / "test_labels.csv",
    )

    absent = []
    for key in ("train_csv", "train_labels_csv"):
        if not paths[key].exists():
            absent.append(key)

    if absent:
        listing = ", ".join(str(paths[key]) for key in absent)
        raise FileNotFoundError("No se encontraron archivos requeridos: " + listing)
    return paths


def _normalize_with_train_stats(
    X_train: np.ndarray,
    X_adapt: np.ndarray,
    X_test: np.ndarray,
):
    """Standardize every split with the per-feature mean/std of the train split.

    Statistics are computed only from ``X_train``, reducing along axis 0, so
    nothing from the adaptation or test splits leaks into the normalization.

    Args:
        X_train: Training features (rows are samples).
        X_adapt: Adaptation features. Returned unchanged when empty.
        X_test: Test features. Returned unchanged when empty.

    Returns:
        Tuple ``(X_train_n, X_adapt_n, X_test_n, mean, std)``. ``mean`` and
        ``std`` have the dtype of ``X_train`` (float64 for non-float input);
        ``std`` is 1.0 for constant or near-zero-variance features.

    Raises:
        ValueError: If ``X_train`` has no rows, or if a non-empty ``X_adapt``
            / ``X_test`` has a per-row shape different from ``X_train``.
    """
    if X_train.shape[0] == 0:
        raise ValueError(
            "X_train no tiene filas; no se pueden calcular estadisticas de normalizacion."
        )
    # A mismatched column count would otherwise broadcast silently
    # (e.g. an (n, 1) array against a (k,) mean yields an (n, k) result).
    for name, X in (("X_adapt", X_adapt), ("X_test", X_test)):
        if X.size and X.shape[1:] != X_train.shape[1:]:
            raise ValueError(
                f"{name} tiene forma {X.shape}, incompatible con X_train {X_train.shape}."
            )

    # Keep the same result dtype as before: float inputs stay in their own
    # precision, integer inputs are promoted to float64.
    out_dtype = (
        X_train.dtype
        if np.issubdtype(X_train.dtype, np.floating)
        else np.dtype(np.float64)
    )

    # Accumulate in float64. With float32 accumulation, a constant column can
    # get a mean that is off by a few ulps and hence a tiny non-zero std that
    # slips past the 1e-12 threshold, turning rounding noise into values of
    # magnitude ~1 after division.
    mean = X_train.mean(axis=0, dtype=np.float64).astype(out_dtype)
    std = X_train.std(axis=0, dtype=np.float64).astype(out_dtype)
    # Exactly-constant features are detected directly, independent of the threshold.
    constant = np.ptp(X_train, axis=0) == 0
    std = np.where((std < 1e-12) | constant, 1.0, std).astype(out_dtype)

    X_train_n = (X_train - mean) / std
    X_adapt_n = (X_adapt - mean) / std if X_adapt.size else X_adapt
    X_test_n = (X_test - mean) / std if X_test.size else X_test
    return X_train_n, X_adapt_n, X_test_n, mean, std


def _read_optional_csv(path: Path, empty_shape: tuple) -> np.ndarray:
    """Read ``path`` as float32 if it exists, else return an empty float32 array of ``empty_shape``."""
    if path.exists():
        return _read_csv(path).astype(np.float32)
    return np.empty(empty_shape, dtype=np.float32)


def main():
    """Load the challenge CSVs, normalize them with train statistics and save an H5 file.

    ``train.csv`` and ``train_labels.csv`` are required. ``train_DA.csv``,
    ``test.csv`` and ``test_labels.csv`` are optional and become empty arrays
    when absent. Features are normalized with ``_normalize_with_train_stats``.
    All arrays plus the normalization mean/std are written gzip-compressed to
    ``--output-h5``, and a JSON summary of the split sizes is printed to stdout.

    Raises:
        FileNotFoundError: If a required training file is missing.
        ValueError: If a label file's length does not match the number of rows
            of the corresponding feature file.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Carga CSV del reto, normaliza con estadisticas de train "
            "y guarda datasets en un archivo H5."
        )
    )
    parser.add_argument(
        "--input-dir",
        type=str,
        default="data/dev_phase/input_data",
        help="Directorio con train.csv, train_labels.csv, train_DA.csv, test.csv.",
    )
    parser.add_argument(
        "--reference-dir",
        type=str,
        default="",
        help=(
            "Directorio con test_labels.csv. "
            "Si se omite, se infiere automaticamente."
        ),
    )
    parser.add_argument(
        "--output-h5",
        type=str,
        default="data/dev_phase/processed/dev_normalized.h5",
        help="Ruta de salida H5.",
    )
    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    # An empty --reference-dir means "infer it from the input directory".
    reference_dir = (
        Path(args.reference_dir)
        if args.reference_dir
        else _infer_reference_dir(input_dir)
    )
    paths = _resolve_paths(input_dir, reference_dir)

    X_train = _read_csv(paths["train_csv"]).astype(np.float32)
    y_train = _read_csv(paths["train_labels_csv"]).ravel().astype(np.float32)
    n_features = X_train.shape[1]

    # Optional splits fall back to empty arrays so the H5 layout is always the same.
    X_adapt = _read_optional_csv(paths["adapt_csv"], (0, n_features))
    X_test = _read_optional_csv(paths["test_csv"], (0, n_features))
    y_test = _read_optional_csv(paths["test_labels_csv"], (0,)).ravel()

    # Labels are flattened with ravel(), so a count mismatch means the label
    # file does not line up with its feature file; fail instead of writing an
    # inconsistent H5.
    if y_train.shape[0] != X_train.shape[0]:
        raise ValueError(
            f"{paths['train_labels_csv']} tiene {y_train.shape[0]} etiquetas, "
            f"pero {paths['train_csv']} tiene {X_train.shape[0]} filas."
        )
    if X_test.shape[0] and y_test.shape[0] and y_test.shape[0] != X_test.shape[0]:
        raise ValueError(
            f"{paths['test_labels_csv']} tiene {y_test.shape[0]} etiquetas, "
            f"pero {paths['test_csv']} tiene {X_test.shape[0]} filas."
        )

    X_train, X_adapt, X_test, mean, std = _normalize_with_train_stats(
        X_train, X_adapt, X_test
    )

    output_h5 = Path(args.output_h5)
    output_h5.parent.mkdir(parents=True, exist_ok=True)

    with h5py.File(output_h5, "w") as h5f:
        h5f.create_dataset("X_train", data=X_train, compression="gzip")
        h5f.create_dataset("y_train", data=y_train, compression="gzip")
        h5f.create_dataset("X_adapt", data=X_adapt, compression="gzip")
        h5f.create_dataset("X_test", data=X_test, compression="gzip")
        h5f.create_dataset("y_test", data=y_test, compression="gzip")
        h5f.create_dataset("feature_mean", data=mean.astype(np.float32))
        h5f.create_dataset("feature_std", data=std.astype(np.float32))
        h5f.attrs["normalized"] = True
        h5f.attrs["normalization_source"] = "train"
        h5f.attrs["input_dir"] = str(input_dir)
        h5f.attrs["reference_dir"] = str(reference_dir)

    summary = {
        "output_h5": str(output_h5),
        "input_dir": str(input_dir),
        "reference_dir": str(reference_dir),
        "n_train": int(X_train.shape[0]),
        "n_adapt": int(X_adapt.shape[0]),
        "n_test": int(X_test.shape[0]),
        "n_test_labels": int(y_test.shape[0]),
        "n_features": int(X_train.shape[1]),
    }
    print(json.dumps(summary, indent=2))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
