import random
import re
import string
import sys
import urllib.request
from collections import defaultdict
from itertools import combinations

import numpy as np

# Location of the corpus, one document (tweet) per line. urlopen needs a URL:
# to use a local copy, give it as a file URL, e.g. 'file:///path/to/tweets.txt'.
file_location = 'https://phparis.net/slides/algo_4_ds/week2/tweets.txt'

# Keep every document in memory. urlopen yields raw bytes per line, so decode
# them to text first: str(bytes) would produce the repr ("b'...'" with \x..
# escapes for non-ASCII), adding a spurious 'b' and escape digits to every
# document's shingles. Assumes the file is UTF-8 (undecodable bytes are
# replaced rather than aborting the load).
with urllib.request.urlopen(file_location) as infile:
    docs = [line.decode('utf-8', errors='replace').strip().lower() for line in infile]

# Print a sample: every 100th document.
for doc in docs[::100]:
    print(doc + "\n")

print("Number of documents: %d" % len(docs))

k = 5  # shingle size: number of consecutive characters per shingle

shingle_id = {}     # shingle string -> integer id
id_shingle = []     # integer id -> shingle string (inverse of shingle_id)
m = []              # m[j] = set of shingle ids occurring in docs[j]
ids = 0             # next unused shingle id
total_shingles = 0  # shingle occurrences over all docs, repeats included

for doc in docs:
    # Keep only alphanumeric characters: whitespace AND punctuation are dropped
    compact = ''.join(filter(str.isalnum, doc))
    # Every window of k consecutive characters (none if the text is shorter than k)
    windows = [compact[start:start + k] for start in range(len(compact) - k + 1)]
    total_shingles += len(windows)

    doc_ids = set()
    for window in windows:
        # setdefault hands out the next id `ids` only to a window never seen before;
        # already-known windows get back their existing (smaller) id
        assigned = shingle_id.setdefault(window, ids)
        if assigned == ids:
            id_shingle.append(window)
            ids += 1
        doc_ids.add(assigned)
    m.append(doc_ids)

print(f"Unique shingles: {len(id_shingle)}")
print(f"Total shingles: {total_shingles}")

# Jaccard similarity: the fraction of the distinct shingles found in either document that appear in both (|A ∩ B| / |A ∪ B|).
# The arguments are not raw documents but the sets of shingle ids of each document.
def jaccard_similarity(doc1, doc2):
  """Return the exact Jaccard similarity of two shingle-id sets.

  The similarity is the size of the intersection divided by the size of the
  union. By convention this returns 0.0 as soon as either set is empty.
  """
  if not len(doc1) or not len(doc2):
    return 0.0
  shared = len(doc1.intersection(doc2))
  combined = len(doc1.union(doc2))
  return shared / combined

#example

print(jaccard_similarity(m[0],m[1]))
# Min-hash of one document under one permutation `perm` of the shingle ids
def min_hash(doc, perm):
    """Return the first shingle id in ``perm`` that belongs to ``doc``.

    Args:
        doc: set of shingle ids of one document.
        perm: an ordering (permutation) of shingle ids.

    Returns:
        The id of ``doc`` that appears earliest in ``perm``, or
        ``float('inf')`` when no id of ``perm`` is in ``doc`` (e.g. an
        empty document).
    """
    return next((candidate for candidate in perm if candidate in doc), float('inf'))

# Demo: one random permutation of every shingle id 0 .. len(id_shingle) - 1,
# and the min-hash of the first document under it. (The value was previously
# computed twice, with the first result thrown away.)
perm = list(range(len(id_shingle)))
random.shuffle(perm)

print(min_hash(m[0], perm))

print(len(m))           # number of documents
print(len(id_shingle))  # number of distinct shingles

def create_minhash_signature(docs, n_permutations, n_shingles=None, rng=None):
    """
    OBJ: Create the Min-Hashing signature matrix.

    Entry [p, d] is the Min-Hash of document d under permutation p: the
    shingle id of d that comes first in permutation p, or inf when no id of
    d occurs in that permutation (e.g. d is empty).

    Args:
        docs: list of sets (each set contains the shingle IDs of one document)
        n_permutations: number of hash functions (permutations) to use
        n_shingles: size of the shingle-id universe; each permutation is a
            shuffle of range(n_shingles). Defaults to len(id_shingle), the
            global vocabulary built by the shingling step.
        rng: optional random.Random instance used for shuffling, for
            reproducible signatures; defaults to the global `random` module.

    Returns:
        signature_matrix: numpy float array of shape (n_permutations, len(docs))
        permutations: the list of permutations used, one per signature row
    """
    if n_shingles is None:
        n_shingles = len(id_shingle)
    shuffle = random.shuffle if rng is None else rng.shuffle

    # Generate the permutations with the same shuffle calls as before, so a
    # given random state still yields the same permutations.
    permutations = []
    for _ in range(n_permutations):
        perm = list(range(n_shingles))
        shuffle(perm)
        permutations.append(perm)

    signature_matrix = np.full((n_permutations, len(docs)), float('inf'))
    for perm_idx, perm in enumerate(permutations):
        # Position of every shingle id in this permutation. The doc id with the
        # smallest position is exactly what scanning `perm` until hitting a
        # member of the doc returns, but this costs O(|doc|) per document
        # instead of up to O(n_shingles).
        position = {shingle: pos for pos, shingle in enumerate(perm)}
        for doc_idx, doc in enumerate(docs):
            present = [s for s in doc if s in position]
            if present:  # otherwise the entry stays inf
                signature_matrix[perm_idx, doc_idx] = min(present, key=position.__getitem__)

    return signature_matrix, permutations

def minhash_similarity(sig1, sig2):
    """
    OBJ: Estimate Jaccard similarity using Min-Hash signatures.

    The estimate is the fraction of positions where the two signatures agree.
    Positions where both hold inf (no shingle found, i.e. an empty document)
    do not count as agreements. Without that rule, two empty documents would
    score 1.0 while jaccard_similarity gives them 0.0.

    Args:
        sig1, sig2: 1-D signature vectors of equal length (numpy arrays or any
            sequence of numbers).

    Returns:
        float: estimated similarity in [0, 1]; 0.0 for empty signatures.

    Raises:
        ValueError: if the vectors have different lengths.
    """
    # asarray makes `==` element-wise even for plain lists (list == list would
    # compare the whole lists and yield a single bool).
    sig1 = np.asarray(sig1)
    sig2 = np.asarray(sig2)
    if len(sig1) != len(sig2):
        raise ValueError("Signature vectors must have same length")
    if len(sig1) == 0:
        return 0.0

    matches = np.count_nonzero((sig1 == sig2) & np.isfinite(sig1))
    return float(matches) / len(sig1)

# Example usage: build signatures for every document, then compare the
# Min-Hash estimate with the exact Jaccard similarity for one pair.
n_permutations = 100  # Number of hash functions (rows of the signature matrix)

print(f"\nCreating Min-Hash signatures with {n_permutations} permutations...")
signature_matrix, permutations = create_minhash_signature(m, n_permutations)
print(f"Signature matrix shape: {signature_matrix.shape}")

# Pair of documents to compare
doc1_idx, doc2_idx = 0, 1
actual_jaccard = jaccard_similarity(m[doc1_idx], m[doc2_idx])
# Column j of the signature matrix is the signature of document j
sig_first = signature_matrix[:, doc1_idx]
sig_second = signature_matrix[:, doc2_idx]
estimated_jaccard = minhash_similarity(sig_first, sig_second)

print(f"\nSimilarity between documents {doc1_idx} and {doc2_idx}:")
for label, value in (
    ("Actual Jaccard similarity", actual_jaccard),
    ("Min-Hash estimated similarity", estimated_jaccard),
    ("Estimation error", abs(actual_jaccard - estimated_jaccard)),
):
    print(f"{label}: {value:.4f}")

def lsh_hash_band(band_signature):
    """Turn one band of a signature into a bucket key: the repr of its tuple."""
    key = tuple(band_signature)
    return repr(key)

def lsh(signature_matrix, b, r):
    """
    Find candidate pairs with banded locality-sensitive hashing.

    The signature matrix is cut into `b` bands of `r` consecutive rows. Within
    each band, documents whose `r` signature values are identical land in the
    same bucket. Every pair of documents sharing a bucket in at least one band
    becomes a candidate.

    Args:
        signature_matrix: 2-D Min-Hash signature matrix, one column per document
        b: number of bands
        r: number of rows per band (normally b * r = n_permutations; rows beyond
            b * r are ignored)

    Returns:
        candidate_pairs: set of (i, j) document-index tuples, i < j, that hash
        to the same bucket in at least one band

    Raises:
        ValueError: if r < 1, b < 0, or b * r exceeds the number of signature rows
    """
    n_rows, n_docs = signature_matrix.shape
    if r < 1 or b < 0 or b * r > n_rows:
        raise ValueError(
            f"need r >= 1, b >= 0 and b*r <= {n_rows} signature rows (got b={b}, r={r})"
        )

    candidate_pairs = set()
    for band in range(b):
        band_rows = signature_matrix[band * r:(band + 1) * r]

        # Hash table for this band: band signature -> document indices
        buckets = defaultdict(list)
        for doc in range(n_docs):
            band_sig = band_rows[:, doc]
            # An inf entry marks a document with no shingle in the permutation
            # (an empty document in this pipeline). All such documents would
            # collide with each other although their Jaccard similarity is 0,
            # so keep them out of the buckets.
            if not np.all(np.isfinite(band_sig)):
                continue
            buckets[tuple(band_sig)].append(doc)

        # Documents were appended in increasing order, so each pair is (i, j) with i < j
        for bucket_docs in buckets.values():
            candidate_pairs.update(combinations(bucket_docs, 2))

    return candidate_pairs

# LSH parameters: b bands of r rows each, with b * r = 100 = n_permutations
# so that every signature row is used.
b, r = 20, 5

# Approximate similarity threshold of the banding scheme
t = pow(1.0 / b, 1.0 / r)
print(f"LSH with b={b} bands, r={r} rows per band")
print(f"Similarity threshold t = (1/b)^(1/r) = {t:.4f}")

# Run LSH over the full signature matrix
candidate_pairs = lsh(signature_matrix, b, r)
print(f"LSH found {len(candidate_pairs)} candidate pairs")

# Compare a sample of LSH candidates with their true Jaccard similarity.
# (The header previously read "\SH ...": an invalid escape sequence that printed
# a literal backslash instead of a newline followed by "LSH".)
print("\nLSH Candidate pairs vs True Jaccard similarity:")
print("Doc1\tDoc2\tMin-Hash\tTrue Jaccard\tAbove threshold?")
print("-" * 60)

candidate_list = list(candidate_pairs)
for doc1, doc2 in candidate_list[:10]:  # show at most 10 pairs
    minhash_sim = minhash_similarity(signature_matrix[:, doc1], signature_matrix[:, doc2])
    true_jaccard = jaccard_similarity(m[doc1], m[doc2])
    above_threshold = "Yes" if true_jaccard >= t else "No"

    print(f"{doc1}\t{doc2}\t{minhash_sim:.4f}\t\t{true_jaccard:.4f}\t\t{above_threshold}")

# Statistics: precision of the candidate set, i.e. the share of candidate
# pairs whose exact Jaccard similarity reaches the threshold t.
if candidate_pairs:
    true_positives = sum(
        1
        for first, second in candidate_pairs
        if jaccard_similarity(m[first], m[second]) >= t
    )
    n_candidates = len(candidate_pairs)
    precision = true_positives / n_candidates
    print("\nLSH Statistics:")
    print(f"Candidates above threshold: {true_positives}/{n_candidates}")
    print(f"Precision: {precision:.4f}")
