import urllib.request
import itertools
from itertools import combinations

file_location = "https://phparis.net/slides/algo_4_ds/week1/groceries.csv"
# To use a local copy instead, point this at a file:// URL
# (e.g. "file:///path/to/groceries.csv"); urlopen expects a URL, not a bare path.

# Build the in-memory dataset: one list of item names per CSV line.
# The `with` block guarantees the response/file handle is closed after reading.
baskets = []
with urllib.request.urlopen(file_location) as infile:
    for line in infile:
        # NOTE(review): `str(line)` produces the bytes repr ("b'...\\n'"), and
        # `[1:-1]` discards the first and last comma-separated fields. This
        # appears to assume a layout with a leading non-item column (e.g. an
        # item count) and trailing padding columns — confirm against the
        # actual CSV, otherwise genuine items are being dropped.
        line = str(line).strip().split(",")[1:-1]
        # Skip empty fields produced by consecutive commas.
        baskets.append([x for x in line if x != ""])

print(f"Number of baskets: {len(baskets)}")

def apriori_algorithm(baskets, min_support):
    """Find all frequent item pairs with the two-pass A-Priori algorithm.

    Pass 1 counts, for every item, the number of baskets that contain it and
    keeps the items whose support reaches ``min_support``. Pass 2 counts only
    pairs made of two such frequent items: by the A-Priori (monotonicity)
    property a pair cannot be frequent unless both of its items are.

    Parameters
    ----------
    baskets : list of list
        The dataset, where each element is a basket (list of hashable items).
        An item repeated inside one basket is counted once for that basket.
    min_support : int
        The minimum support threshold (number of baskets an itemset must appear
        in to be considered frequent).

    Returns
    -------
    frequent_pairs : dict
        Mapping ``(item1, item2) -> support`` for every pair contained in at
        least ``min_support`` baskets. Within each key, ``item1`` is the item
        first encountered earlier when scanning ``baskets`` in order, and keys
        are listed in the same order ``itertools.combinations`` would produce
        over the frequent items.
    """
    # Pass 1: per-basket support of single items. dict.fromkeys de-duplicates a
    # basket while keeping its order, so insertion order of item_support is the
    # deterministic first-seen order of items.
    item_support = {}
    for basket in baskets:
        for item in dict.fromkeys(basket):
            item_support[item] = item_support.get(item, 0) + 1

    frequent_items = [item for item, count in item_support.items() if count >= min_support]

    print("====== FIRST A PRIORI REDUCTION =====")
    print(f"====== Number of singletons above treshold = {len(frequent_items)}  ====")

    # Position of each frequent item in first-seen order. Sorting a basket's
    # frequent items by rank gives every pair one canonical orientation,
    # regardless of the order in which the basket lists them.
    rank = {item: i for i, item in enumerate(frequent_items)}

    # Pass 2: count only pairs whose two items are both frequent.
    pairs_count = {}
    for basket in baskets:
        present = sorted({item for item in basket if item in rank}, key=rank.__getitem__)
        for pair in itertools.combinations(present, 2):
            pairs_count[pair] = pairs_count.get(pair, 0) + 1

    # Keep the pairs that reach the threshold, ordered like
    # itertools.combinations(frequent_items, 2).
    frequent_pairs = {
        pair: count
        for pair, count in sorted(
            pairs_count.items(), key=lambda kv: (rank[kv[0][0]], rank[kv[0][1]])
        )
        if count >= min_support
    }

    print("====== SECOND A PRIORI REDUCTION =====")
    print(f"====== Number of pairs above treshold = {len(frequent_pairs)}  ====")

    return frequent_pairs


def generate_association_rules(frequent_pairs, baskets, min_support, min_confidence):
    """Generate single-item association rules from frequent pairs.

    For every pair ``(x, y)`` two candidate rules are considered, ``x -> y``
    and ``y -> x``, with confidence ``supp({x, y}) / supp(antecedent)``.

    Parameters
    ----------
    frequent_pairs : dict
        Mapping ``(item1, item2) -> support count`` (number of baskets that
        contain both items), e.g. as returned by ``apriori_algorithm``.
    baskets : list of list
        The dataset, where each element is a basket (list of items). Used to
        compute single-item support (number of baskets containing the item;
        repeated items within one basket count once).
    min_support : int
        The minimum support threshold: pairs whose support count is below it
        produce no rules.
    min_confidence : float
        The minimum confidence threshold (between 0 and 1).

    Returns
    -------
    rules : list
        List of association rules as tuples ``(antecedent, consequent,
        confidence)``. For each pair, the ``item1 -> item2`` rule (if kept)
        precedes the ``item2 -> item1`` rule.

    Raises
    ------
    KeyError
        If a pair mentions an item that does not occur in ``baskets``.
    """
    # Per-basket support of single items, counted the same way as pair
    # support so that confidence is a true conditional frequency (<= 1).
    item_support = {}
    for basket in baskets:
        for item in set(basket):
            item_support[item] = item_support.get(item, 0) + 1

    rules = []
    for (item1, item2), pair_count in frequent_pairs.items():
        if pair_count < min_support:
            continue
        # Try both directions: item1 -> item2, then item2 -> item1.
        for antecedent, consequent in ((item1, item2), (item2, item1)):
            confidence = pair_count / item_support[antecedent]
            if confidence >= min_confidence:
                rules.append((antecedent, consequent, confidence))

    return rules

# Run the pipeline: mine frequent pairs, derive association rules from them,
# and print the resulting (antecedent, consequent, confidence) tuples.
frequent_pairs = apriori_algorithm(baskets, min_support=5)
print(
    generate_association_rules(
        frequent_pairs, baskets, min_support=10, min_confidence=0.6
    )
)