"""
X-Ray Detector Calibration using ML - Solution
==============================================
Main model file for the competition.
"""

import numpy as np
import pandas as pd
from skada import make_da_pipeline, CORALAdapter
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler

class Model:
    """Regressor built from a skada domain-adaptation pipeline.

    The pipeline chains a ``StandardScaler``, a ``CORALAdapter`` and an
    ``MLPRegressor``. ``fit`` accepts optional unlabeled target-domain
    samples that are handed to the pipeline together with the labeled
    source samples.
    """

    def __init__(self):
        """
        Initialize the model.
        Using Scikit-Learn's MLPRegressor coupled with Skada's Domain Adaptation (CORAL).
        A StandardScaler is added to guarantee stable training.
        """
        # Hybrid pipeline: domain adapter (CORAL) + dense neural network (MLP)
        self.model = make_da_pipeline(
            StandardScaler(),
            CORALAdapter(),
            MLPRegressor(
                hidden_layer_sizes=(256, 128, 64),
                activation='relu',
                solver='adam',
                batch_size=512,
                learning_rate_init=0.001,
                max_iter=200,
                early_stopping=True,
                validation_fraction=0.1,
                random_state=42
            )
        )

    def fit(self, X_train, y_train, X_adapt=None):
        """
        Train the model with Unsupervised Domain Adaptation (UDA).

        Args:
            X_train (np.ndarray): Training data features (source domain)
            y_train (np.ndarray): Training data labels. A single-column 2-D
                array of shape (n, 1) is flattened to 1-D before use.
            X_adapt (np.ndarray, optional): Domain adaptation data (target domain).
                When None or empty, only the source samples are used.

        Returns:
            Model: this instance, so calls can be chained.
        """
        y_source = np.asarray(y_train)
        # A column vector (n, 1) cannot be concatenated with 1-D padding;
        # flatten it so both label layouts are accepted.
        if y_source.ndim == 2 and y_source.shape[1] == 1:
            y_source = y_source.ravel()

        if X_adapt is not None and len(X_adapt) > 0:
            n_target = len(X_adapt)
            # Stack source and target samples into one array for skada
            X = np.vstack((X_train, X_adapt))
            # Target-domain labels are unknown: fill them with NaN, using the
            # same trailing shape as the source labels so concatenation works.
            y_target = np.full((n_target,) + y_source.shape[1:], np.nan)
            y = np.concatenate((y_source, y_target))
            # sample_domain: positive values mark source, negative values mark target
            sample_domain = np.concatenate(
                (np.ones(len(X_train)), -np.ones(n_target))
            )

            self.model.fit(X, y, sample_domain=sample_domain)
        else:
            # No adaptation data supplied (e.g. a plain validation run):
            # every sample is marked as source.
            # NOTE(review): the pipeline still contains CORALAdapter, which is
            # fitted here without any target samples - confirm skada accepts this.
            sample_domain = np.ones(len(X_train))
            self.model.fit(X_train, y_source, sample_domain=sample_domain)

        return self

    def predict(self, X_test):
        """
        Predict on test data.

        Args:
            X_test (np.ndarray): Test data features (target domain)

        Returns:
            np.ndarray: Predicted labels
        """
        return self.model.predict(X_test)
