import sys
import time
import random
import math
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict

# Import functions from the main recommendation script
try:
    from Recomendation import (
        load_ratings, 
        cosine_similarity, 
        compute_movie_similarities,
        predict_rating,
        recommend_for_user
    )
except ImportError:
    print("Error: Could not import from Recomendation.py")
    print("Make sure Recomendation.py is in the same directory")
    sys.exit(1)


def split_train_test(users, test_ratio=0.2, min_ratings=5, rng=None):
    """
    Split each user's ratings into disjoint train and test sets.

    dict, float, int -> dict, dict, dict

    Users with fewer than ``min_ratings`` ratings are dropped entirely (they
    appear in none of the returned dicts). For every kept user the ratings
    are shuffled, the first ``int(n * (1 - test_ratio))`` go to train and the
    rest go to test.

    Args:
        users: mapping user_id -> {movie_id: rating}.
        test_ratio: fraction of each user's ratings to hold out; must be in
            [0, 1].
        min_ratings: minimum number of ratings a user needs to be kept.
        rng: optional object with a ``shuffle`` method (e.g. a
            ``random.Random`` instance). When omitted, the global ``random``
            module is used, exactly as before.

    Returns:
        (train_users, test_users, train_movies). train_users and test_users
        map user_id -> {movie_id: rating}. train_movies is the inverted index
        movie_id -> {user_id: rating}, built from training ratings only.

    Raises:
        ValueError: if ``test_ratio`` is outside [0, 1].
    """
    if not 0.0 <= test_ratio <= 1.0:
        # A ratio above 1 makes split_point negative, and slicing would then
        # silently count from the end of the list instead of failing.
        raise ValueError(f"test_ratio must be in [0, 1], got {test_ratio!r}")

    shuffle = random.shuffle if rng is None else rng.shuffle

    train_users = {}
    test_users = {}
    train_movies = defaultdict(dict)

    for user_id, ratings in users.items():
        if len(ratings) < min_ratings:
            continue

        user_ratings = list(ratings.items())
        shuffle(user_ratings)

        split_point = int(len(user_ratings) * (1 - test_ratio))
        train_ratings = user_ratings[:split_point]
        test_ratings = user_ratings[split_point:]

        train_users[user_id] = dict(train_ratings)
        test_users[user_id] = dict(test_ratings)

        # Index only the training ratings by movie, so anything computed from
        # train_movies never sees held-out data.
        for movie_id, rating in train_ratings:
            train_movies[movie_id][user_id] = rating

    return train_users, test_users, dict(train_movies)


def compute_rmse_mae(users, test_users, movies, similarities, k=10):
    """
    Measure prediction error on held-out ratings.

    dict, dict, dict, dict, int -> float, float

    For every (user, movie, rating) triple in ``test_users`` whose user also
    appears in ``users``, ``predict_rating`` is asked for an estimate, with
    ``k`` passed through. Triples for which it returns None are skipped.
    ``movies`` is accepted but not used by this function.

    Returns (RMSE, MAE) over the predicted triples, or (None, None) when no
    triple could be predicted.
    """
    residuals = []
    for uid, held_out in test_users.items():
        if uid not in users:
            continue
        for mid, actual in held_out.items():
            estimate = predict_rating(uid, mid, users, similarities, k)
            if estimate is None:
                continue
            residuals.append(actual - estimate)

    if not residuals:
        return None, None

    count = len(residuals)
    rmse = math.sqrt(sum(r ** 2 for r in residuals) / count)
    mae = sum(abs(r) for r in residuals) / count
    return rmse, mae


def evaluate_accuracy_vs_k(train_users, test_users, train_movies, similarities, k_values):
    """
    Sweep the neighbourhood size and record RMSE/MAE for each value.

    Calls ``compute_rmse_mae`` once per k in ``k_values``, in order. A k for
    which RMSE comes back as None is omitted, so the returned lists always
    have equal length.

    Returns a dict of parallel lists: {'k': [...], 'rmse': [...], 'mae': [...]}.
    """
    ks, rmses, maes = [], [], []

    print("\nEvaluating accuracy vs k...")
    for k in k_values:
        print(f"Testing k={k}...")
        rmse, mae = compute_rmse_mae(train_users, test_users, train_movies, similarities, k)
        if rmse is None:
            continue
        ks.append(k)
        rmses.append(rmse)
        maes.append(mae)

    return {'k': ks, 'rmse': rmses, 'mae': maes}


def evaluate_time_vs_threshold(movies, thresholds):
    """
    Time ``compute_movie_similarities`` once for each similarity threshold.

    Args:
        movies: passed unchanged as the first argument of
            ``compute_movie_similarities``.
        thresholds: iterable of thresholds, tried in order.

    Returns:
        {'threshold': [...], 'time': [...]} as parallel lists. 'time' holds
        elapsed seconds measured with ``time.perf_counter``.
    """
    results = {'threshold': [], 'time': []}

    print("\nEvaluating time vs threshold...")
    for threshold in thresholds:
        print(f"Testing threshold={threshold}...")
        # perf_counter is monotonic and high-resolution. time.time() can jump
        # with system clock adjustments and is coarse for short runs.
        start_time = time.perf_counter()
        similarities = compute_movie_similarities(movies, threshold)
        elapsed = time.perf_counter() - start_time
        # Release the result outside the timed region. Otherwise it is freed
        # when the name is rebound inside the NEXT iteration's timed call,
        # which charges its deallocation to the next threshold.
        del similarities

        results['threshold'].append(threshold)
        results['time'].append(elapsed)

    return results


def evaluate_time_vs_dataset_size(users, movies, fractions, similarity_threshold=0.0):
    """
    Time ``compute_movie_similarities`` on random user subsets of varying size.

    For each fraction f, ``int(len(users) * f)`` distinct users are drawn with
    the global ``random`` module. Their ratings are re-indexed by movie, and
    only the similarity computation on that subset is timed.

    Args:
        users: mapping user_id -> {movie_id: rating}.
        movies: not used by this function; kept for call compatibility.
        fractions: iterable of fractions in [0, 1], tried in order.
        similarity_threshold: threshold forwarded to
            ``compute_movie_similarities`` (defaults to the previously
            hard-coded 0.0).

    Returns:
        dict of parallel lists keyed 'fraction', 'n_users', 'n_movies' and
        'time'. 'time' is elapsed seconds from ``time.perf_counter``.

    Raises:
        ValueError: if a fraction lies outside [0, 1].
    """
    results = {'fraction': [], 'n_users': [], 'n_movies': [], 'time': []}

    print("\nEvaluating time vs dataset size...")
    user_ids = list(users.keys())

    for fraction in fractions:
        if not 0.0 <= fraction <= 1.0:
            raise ValueError(f"fraction must be in [0, 1], got {fraction!r}")
        print(f"Testing fraction={fraction}...")
        n_users = int(len(user_ids) * fraction)
        # random.sample already returns distinct ids. Iterating this list
        # (rather than a set built from it) keeps the subset order fixed by
        # random.seed, independent of string-hash randomization.
        sample_user_ids = random.sample(user_ids, n_users)

        # Build the movie -> {user: rating} index outside the timed region.
        subset_movies = defaultdict(dict)
        for uid in sample_user_ids:
            for mid, rating in users[uid].items():
                subset_movies[mid][uid] = rating
        subset_movies = dict(subset_movies)

        start_time = time.perf_counter()
        similarities = compute_movie_similarities(subset_movies, similarity_threshold)
        elapsed = time.perf_counter() - start_time
        # Free the result now, before the next iteration starts its timer.
        del similarities

        results['fraction'].append(fraction)
        results['n_users'].append(n_users)
        results['n_movies'].append(len(subset_movies))
        results['time'].append(elapsed)

    return results


def plot_accuracy_vs_k(results):
    """
    Save side-by-side RMSE and MAE line charts over the tested k values.

    ``results`` needs the parallel lists 'k', 'rmse' and 'mae'. The figure
    is written to ``accuracy_vs_k.png`` in the working directory.
    """
    ks = results['k']
    # (results key, y-axis label, title, extra plot kwargs) for each panel.
    panels = (
        ('rmse', 'RMSE', 'RMSE vs k-Neighbors', {}),
        ('mae', 'MAE', 'MAE vs k-Neighbors', {'color': 'orange'}),
    )

    _, axes = plt.subplots(1, 2, figsize=(12, 4))
    for ax, (key, label, title, extra) in zip(axes, panels):
        ax.plot(ks, results[key], 'o-', linewidth=2, markersize=8, **extra)
        ax.set_xlabel('k (Number of Neighbors)', fontsize=12)
        ax.set_ylabel(label, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('accuracy_vs_k.png', dpi=300, bbox_inches='tight')
    print("Saved: accuracy_vs_k.png")
    plt.close()


def plot_time_vs_threshold(results):
    """
    Save a line chart of execution time against similarity threshold.

    ``results`` needs the parallel lists 'threshold' and 'time'. The figure
    is written to ``time_vs_threshold.png`` in the working directory.
    """
    ax = plt.figure(figsize=(8, 5)).gca()
    ax.plot(results['threshold'], results['time'], 'o-',
            linewidth=2, markersize=8, color='green')
    ax.set_xlabel('Similarity Threshold', fontsize=12)
    ax.set_ylabel('Execution Time (seconds)', fontsize=12)
    ax.set_title('Execution Time vs Similarity Threshold', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('time_vs_threshold.png', dpi=300, bbox_inches='tight')
    print("Saved: time_vs_threshold.png")
    plt.close()


def plot_time_vs_size(results):
    """
    Save a dual-axis chart of timing and movie count against dataset fraction.

    Left axis (blue, solid line): results['time'] per results['fraction'].
    Right axis (red, dashed line): results['n_movies'] per results['fraction'].
    The figure is written to ``time_vs_size.png`` in the working directory.
    """
    x = results['fraction']
    time_color = 'tab:blue'
    movies_color = 'tab:red'

    _, time_ax = plt.subplots(figsize=(10, 5))
    time_ax.set_xlabel('Dataset Fraction', fontsize=12)
    time_ax.set_ylabel('Execution Time (seconds)', color=time_color, fontsize=12)
    time_ax.plot(x, results['time'], 'o-', color=time_color, linewidth=2, markersize=8)
    time_ax.tick_params(axis='y', labelcolor=time_color)
    time_ax.grid(True, alpha=0.3)

    # Second y-axis sharing the x-axis so movie counts can be overlaid.
    movies_ax = time_ax.twinx()
    movies_ax.set_ylabel('Number of Movies', color=movies_color, fontsize=12)
    movies_ax.plot(x, results['n_movies'], 's--', color=movies_color, linewidth=2, markersize=6)
    movies_ax.tick_params(axis='y', labelcolor=movies_color)

    plt.title('Execution Time vs Dataset Size', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig('time_vs_size.png', dpi=300, bbox_inches='tight')
    print("Saved: time_vs_size.png")
    plt.close()


def generate_summary_table(accuracy_results, best_k):
    """
    Print a plain-text summary of the accuracy-vs-k experiment.

    Args:
        accuracy_results: dict of parallel lists with keys 'k', 'rmse' and
            'mae'.
        best_k: the k value to highlight. It must be present in
            ``accuracy_results['k']``.

    Raises:
        ValueError: if ``best_k`` is not one of the evaluated k values
            (raised by ``list.index``).
    """
    ks = accuracy_results['k']
    rmses = accuracy_results['rmse']
    maes = accuracy_results['mae']

    print("\n" + "="*60)
    print("EXPERIMENTAL EVALUATION SUMMARY")
    print("="*60)

    print(f"\nOptimal k value: {best_k}")
    idx = ks.index(best_k)
    print(f"RMSE at k={best_k}: {rmses[idx]:.4f}")
    print(f"MAE at k={best_k}: {maes[idx]:.4f}")

    print("\n" + "-"*60)
    print("Accuracy Metrics for Different k Values:")
    print("-"*60)
    print(f"{'k':<10} {'RMSE':<15} {'MAE':<15}")
    print("-"*60)
    # Walk the three parallel lists in lockstep instead of indexing by position.
    for k, rmse, mae in zip(ks, rmses, maes):
        print(f"{k:<10} {rmse:<15.4f} {mae:<15.4f}")
    print("="*60)


# Main execution
def main(argv=None):
    """
    Run the full experimental evaluation on a ratings file.

    Usage: python experimental_evaluation.py <ratings_file>

    Steps (the global RNG is seeded with 42 first, for reproducibility):
      1. load ratings and split each user's ratings 80/20 into train/test;
      2. compute movie similarities on the training split and measure
         RMSE/MAE for several k (saved as accuracy_vs_k.png);
      3. time the similarity computation on the full data for several
         thresholds (saved as time_vs_threshold.png);
      4. time it on random user subsets (saved as time_vs_size.png);
      5. print a summary table for the k with the lowest RMSE.

    Args:
        argv: argument vector; defaults to ``sys.argv``.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        print("Usage: python experimental_evaluation.py <ratings_file>")
        sys.exit(1)

    ratings_file = argv[1]
    random.seed(42)  # For reproducibility

    # Load data
    print("Loading ratings...")
    users, movies = load_ratings(ratings_file)
    print(f"Loaded {len(users)} users and {len(movies)} movies")

    # Split train/test
    print("\nSplitting data into train/test...")
    train_users, test_users, train_movies = split_train_test(users, test_ratio=0.2)
    print(f"Train: {len(train_users)} users, Test: {len(test_users)} users")

    # Compute similarities on training data only, so held-out ratings do not
    # leak into the model being evaluated.
    print("\nComputing similarities on training data...")
    similarities = compute_movie_similarities(train_movies, similarity_threshold=0.0)

    # Experiment 1: Accuracy vs k (reduced k values)
    k_values = [5, 10, 20, 30]
    accuracy_results = evaluate_accuracy_vs_k(train_users, test_users, train_movies, similarities, k_values)
    plot_accuracy_vs_k(accuracy_results)

    # Pick the k with the lowest RMSE (the first one wins on ties). If no
    # test rating could be predicted for any k, the list is empty and min()
    # would raise, so remember that and skip the summary later instead.
    best_k = None
    if accuracy_results['rmse']:
        best_idx = accuracy_results['rmse'].index(min(accuracy_results['rmse']))
        best_k = accuracy_results['k'][best_idx]

    # Experiment 2: Time vs threshold (reduced thresholds)
    thresholds = [0.0, 0.2, 0.4]
    time_threshold_results = evaluate_time_vs_threshold(movies, thresholds)
    plot_time_vs_threshold(time_threshold_results)

    # Experiment 3: Time vs dataset size (reduced fractions)
    fractions = [0.2, 0.5, 1.0]
    time_size_results = evaluate_time_vs_dataset_size(users, movies, fractions)
    plot_time_vs_size(time_size_results)

    # Generate summary
    if best_k is None:
        print("\nNo test ratings could be predicted; skipping accuracy summary.")
    else:
        generate_summary_table(accuracy_results, best_k)

    print("\n All plots generated successfully!")
    print("Check the current directory for PNG files")


if __name__ == "__main__":
    main()