def __init__(self, n, c, err, set_size, err1k):
    """
    Set up a sandwiched learned Bloom filter.

    n: number of letters in each string
    c: size of the alphabet
    err: total error rate of the sandwich
    set_size: expected number of elements for the first filter
    err1k: multiplier applied to err to get the first filter's error rate
    """
    self.n = n
    self.c = c
    self.err = err
    self.err1 = err * err1k
    # Learned oracle; the decision threshold tau gets tuned after training.
    self.model = WordNet(n, c)
    self.tau = 0.5
    self.alpha = 0.618503137801576  # 2 ** -log(2)
    # AMQs: the first filter can be sized now; the second one is sized
    # only after training, once the model's false negatives are known.
    self.amq1 = WordBloom(Bloom.init_ne(set_size, self.err1))
    self.amq2 = None
def train(self, xs, ys, epochs):
    """
    Fit the model on (xs, ys) for the given number of epochs, tune the
    decision threshold, and back up the model's false negatives with a
    Bloom filter so no key is ever rejected.
    """
    # The torch dataloader handles shuffling, so none is needed here.
    self.model.train(xs, ys, epochs)
    self.tau = self._choose_tau(xs, ys)
    # True positives that the tuned model still rejects.
    false_negs = [
        x for x, y in zip(xs, ys)
        if y and not (self.model(x) > self.tau)
    ]
    # Cover them with a filter built at half the total error budget.
    if false_negs:
        self.amq = WordBloom(Bloom.init_ne(len(false_negs), self.err / 2))
        self.amq.add_set(false_negs)
def bloom_test(xs, ys, num_pos, num_neg, n, c, e):
    """
    Build a Bloom filter over the positive examples and print its
    empirical false-positive rate, false-negative rate, and accuracy.

    xs, ys: examples and labels (truthy label = member of the set)
    num_pos, num_neg: counts of positive / negative examples in xs
    n, c: unused here; kept for interface parity with other test helpers
    e: target false-positive rate used to size the filter
    """
    bloom = WordBloom(Bloom.init_ne(num_pos, e))
    positives = [x for x, y in zip(xs, ys) if y]
    bloom.add_set(positives)

    false_pos = false_neg = 0
    for x, y in zip(xs, ys):
        filter_contains = bloom.contains(x)
        # bools add as 0/1, so these are running error counts
        false_pos += not y and filter_contains
        false_neg += y and not filter_contains

    # Guard the divisions so an empty positive or negative set does not
    # raise ZeroDivisionError; report a 0.0 rate for an empty class.
    fpr = (false_pos / num_neg) if num_neg else 0.0
    fnr = (false_neg / num_pos) if num_pos else 0.0
    total = num_pos + num_neg
    accuracy = (1 - (false_pos + false_neg) / total) if total else 1.0

    print(bloom)
    print("fpr: {}, fnr: {}, correct%: {}".format(fpr, fnr, accuracy))
def train(self, xs, ys, epochs):
    """
    Train the model and set up the two AMQs of the sandwich.

    xs, ys: examples and labels (truthy label = key in the set)
    epochs: number of epochs for training the neural net
    """
    # TODO: make more efficient (don't necessarily need to materialize
    # the positives here)
    positives = [x for x, y in zip(xs, ys) if y]

    # Setup first filter over all keys; a Bloom filter has no false
    # negatives, so every key reaches the model stage below.
    self.amq1.add_set(positives)

    # Train the neural net only on examples the first filter reports
    # positive (all keys plus amq1's false positives) — the same
    # distribution the model sees at query time.
    amq1_pos_indices = [
        i for i, x in enumerate(xs) if self.amq1.contains(x)
    ]
    amq1_pos_xs = [xs[i] for i in amq1_pos_indices]
    amq1_pos_ys = [ys[i] for i in amq1_pos_indices]
    self.model.train(amq1_pos_xs, amq1_pos_ys, epochs)

    # Tune tau; fpr/fnr are the model's empirical rates on this sample.
    self.tau, fpr, fnr = self._choose_tau(amq1_pos_xs, amq1_pos_ys)

    # The backup filter must hold every *key* the model rejects, or the
    # sandwich would have false negatives. Only true positives belong
    # here: previously rejected negatives were also inserted, which
    # inflated amq2's size and false-positive rate.
    model_false_negs = [
        x for x, y in zip(amq1_pos_xs, amq1_pos_ys)
        if y and not (self.model(x) > self.tau)
    ]
    num_model_false_negs = len(model_false_negs)

    # Setup second filter if we have false negs
    if num_model_false_negs > 0 and fnr > 0:
        # Optimal bits-per-key ratio for the second filter (sandwiched
        # learned Bloom filter analysis); base conversion via
        # alpha = 2 ** -log(2), hence the -log2(.) / log(2) form.
        inside = fpr / ((1 - fpr) * (1 / fnr - 1))
        m2 = int(0 if inside == 0 else -log2(inside) / log(2))
        if m2 <= 0:
            # Degenerate ratio (inside >= 1 yields m2 <= 0, which would
            # be an invalid size): fall back to sizing by error rate.
            self.amq2 = WordBloom(
                Bloom.init_ne(num_model_false_negs, self.err))
        else:
            self.amq2 = WordBloom(Bloom.init_nm(num_model_false_negs, m2))
        self.amq2.add_set(model_false_negs)