import random
from string import ascii_lowercase
import math

# Simulation parameters
m = 100              # number of different mails to generate
g = 10               # number of mails in the "good" sample
stream_size = 10000  # number of items drawn in each simulated stream
n = 512              # reassigned to 128 further down, before its first use


def _random_word(length=5):
    """Return a string of `length` random lowercase ASCII letters."""
    return ''.join(random.choice(ascii_lowercase) for _ in range(length))


# Build m random mail-like strings of the form xxxxx@yyyyy (5 + 1 + 5 chars).
# The local part is drawn before the domain part for every mail.
D = [_random_word() + '@' + _random_word() for _ in range(m)]

print(D)
"""
D is the dictionary of mail addresses (stored as a list); the streams draw their items from it.
"""
# Size of the bit array B. This replaces the n = 512 assigned in the
# parameters section, and it happens before n is used for the first time.
n = 128

# Baseline hash function
def h(x, n):
    """Map a hashable item to a slot index of an n-slot array.

    The item is passed through Python's built-in hash() and the result is
    reduced modulo n, so for a positive n the index lies in [0, n).
    """
    slot = hash(x) % n
    return slot

# Reference set of valid mails (the first g entries of D). It is only used to
# label stream items when counting TP / FP / TN / FN.
good_set = set(D[:g])
print(good_set)  # a set is unordered, so the printed order may not follow D

# Bit array of the filter: n slots, all initially 0
B = n * [0]

# Fill the bit array: set the slot selected by the hash of every valid mail
for mail in D[:g]:
    B[h(mail, n)] = 1

print(B)

# Outcome counters:
#   tp - good items passing      fp - bad items passing
#   tn - bad items discarded     fn - good items discarded
tp = fp = tn = fn = 0

# Simulate a stream: draw stream_size random mails and filter each one
for _ in range(stream_size):
    s = random.choice(D)

    # The filter accepts an item when the slot given by its hash is set in B
    accepted = B[h(s, n)] == 1
    is_good = s in good_set

    if accepted and is_good:
        tp += 1
    elif accepted:
        fp += 1
    elif is_good:
        fn += 1
    else:
        tn += 1

print('False positive rate: %f'%(float(fp)/float(tn+fp)))

# Parameters of the hash family h(x) = ((a*hash(x) + b) mod p) mod n.
# NOTE(review): p is presumably meant to be a large prime - confirm.
p = 1223543677

# The multiplier a is drawn from [1, p-1] and the offset b from [0, p-1].
# a == 0 is excluded on purpose: it would cancel the input and turn the
# function into the constant (b mod p) mod n.
a = random.randrange(1, p)
b = random.randrange(p)

# New hash function, parameterised by a, b and p
def h(x, a, b, p, n):
    """Return the slot index ((a * hash(x) + b) mod p) mod n for item x.

    The built-in hash() is applied to x first, so any hashable Python value
    (e.g. strings, tuples) can be used rather than only integers.
    """
    scrambled = (a * hash(x) + b) % p
    return scrambled % n

# Start again from an all-zero bit array to try out the new hash function
B = n * [0]

# Offline phase: mark the slot of every valid mail before any stream is seen
for mail in D[:g]:
    B[h(mail, a, b, p, n)] = 1

print(B)

"""
### 3. **TASK** - Bloom Filters

Your task is to implement the Bloom filters as described in the class lecture. For this, you have to:
1. generate $k$ random pairwise independent hash functions (_hint_: use the example shown above)
2. initialize $B$, by setting $1$ in each $h_i(x)$, $i\in\{1,\dots,k\}$, for all items $x$ in the good set
3. an item $s$ in the stream is considered good if, for all $i\in\{1,\dots,k\}$, we have $B[h_i(s)]=1$

Measure the true positive and false positive rate for various values of $k$ and compare to the values obtained when setting $k=\frac{n}{m}\ln 2$ (rounded to the nearest integer), where $n$ is the size of $B$ and $m$ is the number of items in the good set. What do you notice?

Rates:

$
  \text{false positive rate} = \frac{FP}{FP+TN}
$

$
  \text{true positive rate} = \frac{TP}{TP+FN}
$
"""

# Generate the k random hash functions of the Bloom filter
p = 1223543677  # modulus shared by every function of the family

# As mentioned in the presentation, the best number of hash functions is
#   k = (n/m) * ln(2)
# where n = size of array B and m = number of valid mails (g in this script).
# The task asks for the nearest integer, so round() is used instead of int():
# int() truncates, turning (128/10) * ln(2) = 8.87 into 8 rather than 9.
# max(1, ...) keeps at least one hash function if the formula rounds to 0.
n_hashes = max(1, round((len(B) / g) * math.log(2)))

# Each hash function is identified by its (a, b) pair. The pairs are stored so
# that exactly the same functions can be applied again later.
# a is drawn from [1, p-1] because a == 0 would give a constant function;
# b is drawn from [0, p-1].
hash_functions = [
    (random.randrange(1, p), random.randrange(p)) for _ in range(n_hashes)
]

print(hash_functions)  # each (a, b) combination that will be used

# Fresh all-zero bit array for the Bloom filter
B = n * [0]

# For every valid mail, set the slot selected by each of the k hash functions
for mail in D[:g]:
    for a, b in hash_functions:
        B[h(mail, a, b, p, n)] = 1

# Reset the outcome counters for the Bloom-filter run:
#   tp - good items passing      fp - bad items passing
#   tn - bad items discarded     fn - good items discarded
tp = fp = tn = fn = 0

# Simulate the stream through the Bloom filter
for _ in range(stream_size):
    s = random.choice(D)

    # An item passes only when EVERY hash function points at a set bit; a
    # single 0 bit is enough to reject it. all() stops at the first 0 found.
    passes_filter = all(B[h(s, a, b, p, n)] != 0 for a, b in hash_functions)
    is_good = s in good_set

    if passes_filter:
        if is_good:
            tp += 1
        else:
            fp += 1
    elif is_good:
        fn += 1
    else:
        tn += 1

print(f'Number of hash functions (k): {n_hashes}')
print(f'Optimal k = (n/g) * ln(2) = ({n}/{g}) * ln(2) = {n_hashes}')
print(f'False positive rate: {float(fp)/float(tn+fp):.6f}')
print(f'True positive rate: {float(tp)/float(tp+fn):.6f}')
