def _compute_tries(self):
    ordered_city_tiles = self._load_city()
    dataset = {}
    with open(self.bssid_sobol_obfuscated_csv) as file_in:
        reader = csv.reader(file_in)
        last_bssid = None
        bssid_locations = None
        for row in reader:
            bssid, tile_x, tile_y, zlevel = row
            bssid = unicode(bssid)
            tile_key = (int(tile_x), int(tile_y))
            tile_id = ordered_city_tiles[tile_key]
            if bssid != last_bssid:
                if last_bssid is not None:
                    # push bssid-locations into dataset
                    dataset[last_bssid] = bssid_locations
                bssid_locations = []
                if len(dataset) % 10000 == 0:
                    print "Constructing trie with record: %d" % len(dataset)
            bssid_locations.append(tile_id)
            last_bssid = bssid

    # Copy the last batch into the dataset
    if bssid_locations:
        dataset[last_bssid] = bssid_locations
        bssid_locations = []

    print "Constructing trie"
    trie = RecordTrie(self.fmt, dataset.items())
    trie.save(self.output_trie_fname)
    print "trie saved!"
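For reference, here is a minimal, self-contained sketch of the build/save/load round trip that _compute_tries relies on. The '<i' record format, the keys, and the file name are illustrative assumptions, not taken from the snippet above (which packs records with whatever self.fmt holds).

import marisa_trie

# Illustrative data: a key may map to more than one fixed-size record.
items = [
    (u'aa:bb:cc:dd:ee:ff', (12,)),
    (u'aa:bb:cc:dd:ee:ff', (98,)),
    (u'11:22:33:44:55:66', (7,)),
]

trie = marisa_trie.RecordTrie('<i', items)
trie.save('tiles.record_trie')            # placeholder file name

reloaded = marisa_trie.RecordTrie('<i')
reloaded.load('tiles.record_trie')
print(reloaded[u'aa:bb:cc:dd:ee:ff'])     # prints both records stored under this key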
def load(in_file, mmap=True):
    title_dict = Trie()
    redirect_dict = RecordTrie('<I')

    title_dict.mmap(in_file + '_title.trie')
    redirect_dict.mmap(in_file + '_redirect.trie')
    inlink_arr = np.load(in_file + '_prior.npy', mmap_mode='r')

    return EntityDB(title_dict, redirect_dict, inlink_arr)
def load(in_file, mmap_mode='r'):
    obj = joblib.load(in_file, mmap_mode=mmap_mode)

    title_dict = Trie()
    redirect_dict = RecordTrie('<I')
    title_dict.frombytes(obj['title_dict'])
    redirect_dict.frombytes(obj['redirect_dict'])

    return EntityDB(title_dict, redirect_dict, obj['inlink_arr'])
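The two load() variants above differ only in how the tries come back: memory-mapped from companion files versus rebuilt from byte strings stored inside a joblib archive. A small sketch of the tobytes()/frombytes() round trip; the titles and dictionary keys are made up.

import marisa_trie

title_dict = marisa_trie.Trie([u'Paris', u'London'])
redirect_dict = marisa_trie.RecordTrie(
    '<I', [(u'Paris, France', (title_dict[u'Paris'],))])

# Serialize both tries to raw bytes, e.g. for embedding in a joblib/pickle archive.
obj = {'title_dict': title_dict.tobytes(),
       'redirect_dict': redirect_dict.tobytes()}

# Restore them in memory, as the second load() above does.
restored_titles = marisa_trie.Trie()
restored_titles.frombytes(obj['title_dict'])
restored_redirects = marisa_trie.RecordTrie('<I')
restored_redirects.frombytes(obj['redirect_dict'])

assert restored_redirects[u'Paris, France'][0][0] == restored_titles[u'Paris']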
def __init__(self, path, header="@dd", order=3, unk='<unk>'):
    """Load language model from file"""
    io_utils.check_file_readable(path)
    self.logger = logging.getLogger(__name__)
    self.order = order
    self.model = RecordTrie(header)
    self.model.load(path)
    self.unk = unk
def __init__(self, path):
    """Build trie on ARPA n-grams"""
    io_utils.check_file_readable(path)
    self.logger = logging.getLogger(__name__)
    self.logger.info("Load ARPA model from {}".format(path))
    self.order = None
    self.total = {}
    self.trie = RecordTrie("@dd", self.load_ngram_tuples(path))
    self.logger.info(
        "Loaded a {}-gram LM with {} counts".format(self.order, self.total))
def build(dump_file, pool_size, chunk_size):
    dump_reader = WikiDumpReader(dump_file)

    global _extractor
    _extractor = WikiExtractor()

    titles = []
    redirects = {}
    title_counter = Counter()

    with closing(Pool(pool_size)) as pool:
        for (page, links) in pool.imap_unordered(_process_page, dump_reader,
                                                 chunksize=chunk_size):
            titles.append(normalize(page.title))
            if page.is_redirect:
                redirects[normalize(page.title)] = page.redirect

            for link_obj in links:
                title_counter[normalize(link_obj.title)] += 1

    title_dict = Trie(titles)

    redirect_items = []
    for (title, dest_title) in redirects.items():
        if dest_title in title_dict:
            redirect_items.append((title, (title_dict[dest_title],)))

    redirect_dict = RecordTrie('<I', redirect_items)

    delete_keys = []
    keys = list(title_counter.keys())
    for key in keys:
        title = key
        count = title_counter[key]
        dest_obj = redirect_dict.get(title)
        if dest_obj is not None:
            title_counter[title_dict.restore_key(dest_obj[0][0])] += count
            del title_counter[title]

    inlink_arr = np.zeros(len(title_dict), dtype=np.int)
    for (title, count) in title_counter.items():
        title_index = title_dict.get(title)
        if title_index is not None:
            inlink_arr[title_index] = count

    return EntityDB(title_dict, redirect_dict, inlink_arr)
def load(target, device, mmap=True):
    word_dict = Trie()
    entity_dict = Trie()
    redirect_dict = RecordTrie("<I")

    if not isinstance(target, dict):
        if mmap:
            target = joblib.load(target, mmap_mode="r")
        else:
            target = joblib.load(target)

    word_dict.frombytes(target["word_dict"])
    entity_dict.frombytes(target["entity_dict"])
    redirect_dict.frombytes(target["redirect_dict"])

    word_stats = target["word_stats"]
    entity_stats = target["entity_stats"]
    if not isinstance(word_stats, np.ndarray):
        word_stats = np.frombuffer(
            word_stats,
            dtype=np.int32,
        ).reshape(-1, 2)
        word_stats = torch.tensor(
            word_stats,
            device=device,
            requires_grad=False,
        )
        entity_stats = np.frombuffer(
            entity_stats,
            dtype=np.int32,
        ).reshape(-1, 2)
        entity_stats = torch.tensor(
            entity_stats,
            device=device,
            requires_grad=False,
        )

    return Wikipedia2VecDict(
        word_dict,
        entity_dict,
        redirect_dict,
        word_stats,
        entity_stats,
        **target["meta"],
    )
def build(dump_file, pool_size, chunk_size):
    dump_reader = WikiDumpReader(dump_file)

    global _extractor
    _extractor = WikiExtractor()

    titles = []
    redirects = {}
    title_counter = Counter()

    with closing(Pool(pool_size)) as pool:
        for (page, links) in pool.imap_unordered(
            _process_page, dump_reader, chunksize=chunk_size
        ):
            titles.append(page.title)
            if page.is_redirect:
                redirects[page.title] = page.redirect

            for link_obj in links:
                title_counter[link_obj.title] += 1

    title_dict = Trie(titles)

    redirect_items = []
    for (title, dest_title) in redirects.iteritems():
        if dest_title in title_dict:
            redirect_items.append((title, (title_dict[dest_title],)))

    redirect_dict = RecordTrie('<I', redirect_items)

    for (title, count) in title_counter.items():
        dest_obj = redirect_dict.get(title)
        if dest_obj is not None:
            title_counter[title_dict.restore_key(dest_obj[0][0])] += count
            del title_counter[title]

    inlink_arr = np.zeros(len(title_dict), dtype=np.int)
    for (title, count) in title_counter.items():
        title_index = title_dict.get(title)
        if title_index is not None:
            inlink_arr[title_index] = count

    return EntityDB(title_dict, redirect_dict, inlink_arr)
def __init__(self, args):
    self.args = args
    self.all_titles = self._all_titles_collector()
    self.redirects = _extract_pages(self.args.path_for_raw_xml)
    self.nlp = nlp_returner(args=self.args)
    self.entity_dict = Trie(self.all_titles)
    self.redirect_dict = RecordTrie(
        '<I', [(title, (self.entity_dict[dest_title],))
               for (title, dest_title) in self.redirects
               if dest_title in self.entity_dict])
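A short sketch of how such a '<I' redirect trie is typically resolved back to a canonical title via restore_key; the titles here are invented for illustration.

import marisa_trie

entity_dict = marisa_trie.Trie([u'United Kingdom', u'France'])
redirect_dict = marisa_trie.RecordTrie(
    '<I', [(u'UK', (entity_dict[u'United Kingdom'],))])

record = redirect_dict.get(u'UK')        # [(entity_id,)] or None
if record is not None:
    canonical_title = entity_dict.restore_key(record[0][0])
    print(canonical_title)               # 'United Kingdom'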
def SCFG(filename):
    sys.stderr.write("Reading grammar from %s...\n" % (filename,))
    keys = []
    items = []
    #m = 0
    for i, line in enumerate(open(filename)):
        pos, es, en, feats = line.strip().split(" ||| ")
        #m = max(m, len(es))
        feats = map(lambda x: float(x), feats.strip().split())
        keys.append(es.decode("utf-8", "replace"))
        items.append((pos, es, en, feats[4], feats[5], feats[0], feats[1],
                      feats[3], feats[2],))
    #print m
    return RecordTrie("=14s79s73sffffff", zip(keys, items))  # hack
def search_in_label(label_obj: DomainLabel, trie: marisa_trie.RecordTrie, special_filter,
                    location_match_queue: mp.Queue) \
        -> typing.DefaultDict[LocationCodeType, int]:
    """returns all matches for this label"""
    ids = set()
    type_count = collections.defaultdict(int)
    location_hint_tuples = []
    for o_label in label_obj.sub_labels:
        label = o_label[:]
        blacklisted = []

        while label:
            matching_keys = trie.prefixes(label)
            matching_keys.sort(key=len, reverse=True)

            for key in matching_keys:
                if [black_word for black_word in blacklisted if key in black_word]:
                    continue
                if key in special_filter and \
                        [black_word for black_word in special_filter[key]
                         if black_word in o_label]:
                    continue

                matching_locations = trie[key]
                if [code_type for _, code_type in matching_locations
                        if code_type == -1]:
                    blacklisted.append(key)
                    continue

                for location_id, code_type in matching_locations:
                    real_code_type = LocationCodeType(code_type)
                    if location_id in ids:
                        continue
                    location_hint_tuples.append(
                        (location_id.decode(), key, code_type, label_obj.id))
                    type_count[real_code_type] += 1

            label = label[1:]

    location_match_queue.put(location_hint_tuples)
    return type_count
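The loop above hinges on RecordTrie.prefixes(), which returns every stored key that is a prefix of the query string. A toy illustration with invented (location_id, code_type) records packed as two ints; the real trie used by the function above stores a different record layout.

import marisa_trie

trie = marisa_trie.RecordTrie('<ii', [
    (u'ber', (10, 1)),        # invented (location_id, code_type) records
    (u'berlin', (11, 2)),
])

label = u'berlin01.example'
matching_keys = trie.prefixes(label)       # ['ber', 'berlin']
matching_keys.sort(key=len, reverse=True)  # longest match first, as above

for key in matching_keys:
    for location_id, code_type in trie[key]:
        print(key, location_id, code_type)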
def test_load():
    keys = [u'foo', u'bar']
    expected_values = [(1, 2, 3), (4, 5, 6)]
    fmt = ">iii"

    rtrie = RecordTrie(fmt)
    rtrie.load('tests/demo.record_trie')
    for i, k in enumerate(keys):
        assert [expected_values[i]] == rtrie.get(k)
        print "Got: %s" % (rtrie.get(k))
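The test reads a pre-built fixture; a compatible tests/demo.record_trie could plausibly be produced like this (the fixture path and values simply mirror the test's expectations).

from marisa_trie import RecordTrie

keys = [u'foo', u'bar']
values = [(1, 2, 3), (4, 5, 6)]
RecordTrie('>iii', zip(keys, values)).save('tests/demo.record_trie')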
all_old_entities = [ent.title for ent in dictionary.entities()]
all_old_entities_set = set(all_old_entities)
all_new_entities = sorted([
    ent for ent in all_needed_entities if not ent in all_old_entities_set
])
joint_entity_stats = np.concatenate([
    old_entity_stats,
    np.array([[5, 5] for _ in all_new_entities]).astype(old_entity_stats.dtype)
])
new_entity_dict = Trie(all_old_entities + all_new_entities)
new_redirect_dict = RecordTrie(
    '<I', [(title, (new_entity_dict[dest_title],))
           for (title, dest_title) in dumpdb.redirects()
           if dest_title in new_entity_dict])
new_dictionary = Dictionary(
    uuid=dictionary.uuid,
    word_dict=old_word_dict,
    entity_dict=new_entity_dict,
    redirect_dict=new_redirect_dict,
    word_stats=old_word_stats,
    entity_stats=joint_entity_stats,
    min_paragraph_len=dictionary.min_paragraph_len,
    language=dictionary.language,
    lowercase=dictionary.lowercase,
    build_params=dictionary.build_params)

for entity in all_needed_entities_raw:
class LanguageModel:
    """
    Use the trie-based n-gram language model
    (code inspired from pynlpl's library)

    Focus:
        - fast load the n-gram trie from binary file
        - compute the log-probabilities of n-grams
        - reorder a sequence of n-grams by their log-probabilities
    """

    def __init__(self, path, header="@dd", order=3, unk='<unk>'):
        """Load language model from file"""
        io_utils.check_file_readable(path)
        self.logger = logging.getLogger(__name__)
        self.order = order
        self.model = RecordTrie(header)
        self.model.load(path)
        self.unk = unk

    def _prob(self, ngram):
        """Return probability of given ngram tuple"""
        return self.model[ngram][0][0]

    def _backoff(self, ngram):
        """Return backoff value of a given ngram tuple"""
        return self.model[ngram][0][1]

    def has_word(self, word):
        """Check if given word is known by the model"""
        if word in self.model:
            return True
        return False

    def score_word(self, word):
        """Get the unigram log-probability of given word"""
        try:
            return self._prob(word)
        except KeyError:
            try:
                return self._prob(self.unk)
            except KeyError:
                raise KeyError(
                    "Word {} not found (model has no UNK token)".format(word))

    def _score(self, word, history=None):
        """Get the n-gram's log probability"""
        if not history:
            return self.score_word(word)
        # constrain sequence length up to 'order' words
        lookup = history + (word, )
        if len(lookup) > self.order:
            lookup = lookup[-self.order:]
        try:
            return self._prob(' '.join(lookup))
        except KeyError:
            # not found, back off
            try:
                backoffweight = self._backoff(' '.join(history))
            except KeyError:
                backoffweight = 0
            return backoffweight + self._score(word, history[1:])

    def score_sequence(self, data):
        """
        Compute the log-probability of a given word sequence

        When using a 3-gram language model
            score('manger une pomme') = prob('manger')
                                      + prob('manger une')
                                      + prob('manger une pomme')
        If the 3-gram 'manger une pomme' is not known by the model
            score('manger une pomme') = prob('manger')
                                      + prob('manger une')
                                      + backoff('manger une')
                                      + prob('une pomme')
        """
        if isinstance(data, str):
            data = tuple(data.split())
        if len(data) == 1:
            return self.score_word(data[0])
        result = 0
        history = ()
        for word in data:
            result += self._score(word, history)
            history += (word, )
        return result

    def score_sentence(self, data, bos='<s>', eos='</s>'):
        """Compute the sum of log-probabilities for given sentence"""
        data = bos + ' ' + data + ' ' + eos
        return self.score_sequence(data)

    def order_sequences(self, sequences):
        """Order the sequences by their n-gram log probabilities"""
        scores = {seq: self.score_sequence(seq) for seq in sequences}
        return sorted(scores.keys(), key=lambda k: scores[k], reverse=True)

    def __getitem__(self, ngram):
        """Allow easy access to n-gram's log probability"""
        return self.score_sequence(ngram)
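A minimal usage sketch for the class above, assuming it is importable together with its io_utils helper: build a tiny "@dd" trie by hand (keys are space-joined n-grams, values are made-up (log-probability, back-off) pairs), save it, and score a sequence.

from marisa_trie import RecordTrie

entries = [
    (u'<unk>', (-3.0, 0.0)),
    (u'manger', (-1.2, -0.3)),
    (u'manger une', (-0.8, -0.2)),
    (u'manger une pomme', (-0.5, 0.0)),
]
RecordTrie('@dd', entries).save('toy_lm.trie')   # placeholder file name

lm = LanguageModel('toy_lm.trie', header='@dd', order=3)
print(lm.score_sequence('manger une pomme'))     # -1.2 + -0.8 + -0.5 = -2.5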
def __init__(self):
    vals = []
    with open('./en_wordlist.combined') as infile:
        for l in infile:
            items = l.split(',')
            # print(items)
            w = items[0].split('=')[1]
            if items[0].startswith(' '):
                print(l)
                continue
            f = int(items[1].split('=')[1])
            vals.append(w)
    # self.dict = [word.strip() for word in infile.readlines()]
    self.dict = vals
    self.vectors = {}
    self.mapping = {}

    self.score_trie = RecordTrie(fmt='<H')
    with open('trie.marisa') as ftrie:
        self.score_trie.read(ftrie)

    print(self.get_vec_mapping(layout))
    self.get_vectors_from_dict()
    # print(self.vectors)
    allvecs = [arr for arr in self.vectors.itervalues()]
    allvecsarr = np.array(allvecs)
    self.kdtree = KDTree(data=allvecsarr)

    print("#####\n\n\n")
    print(self.vectors['test'])
    print("####")
    pprint(self.match('test'))
    pprint(self.score_and_sort_matches(self.match('teat')[1], 'nice'))
    pprint(self.match('tedt')[1])
    pprint(self.match('perspecitev')[1])
    pprint(self.match('angle')[1])
    pprint(self.match('angel')[1])

    print("#####\n\n\n")
    print(self.vectors['news'])
    print("####")
    pprint(self.match('newa')[1])
    pprint(self.match('newr')[1])
    pprint(self.match('newst')[1])
    ms = self.match('newst')[1]
    v = self.get_vector_for_word('newst')
    pprint(sorted(ms, key=lambda x: abs(x[1][2] - v[2])))
    pprint(self.match('obascure')[1])
    pprint(self.match('obscure')[1])
    pprint(self.match('relativetiy')[1])
    pprint(self.match('absolutealy')[1])
    pprint(self.match('porcealitn')[1])

    print(allvecsarr)
    plt.scatter(allvecsarr[:, 0], allvecsarr[:, 1])
    plt.show()
def __load_rules(cls, filename, lm, config):
    '''
    Load rule table from filename

    Args:
        filename: the name of the file that stores the rules
        lm: language model
        config: an instance of Config

    Return:
        a RuleTable
    '''
    feature_num = config.get_feature_num()
    glue_rule_index = config.order['glue-rule-count']
    max_rule_num = config.rule_beamsize

    table = RuleTable()
    keys = []
    ranges = []
    idx = 0

    # glue rules
    # S -> X
    # type 1 glue rule is not counted
    features = [0] * feature_num
    glue_rule1 = Rule('|0', ['|0'], [0], features, cls.GLUE_RULE1_GLOBAL_ID)
    table._rules.append(glue_rule1)
    keys.append(cls.GLUE_RULE1.decode('utf-8'))
    ranges.append((idx, idx))
    idx += 1

    # S -> S X
    features = [0] * feature_num
    features[glue_rule_index] = 1
    glue_rule2 = Rule('|0 |1', ['|0', '|1'], [0, 1], features,
                      cls.GLUE_RULE2_GLOBAL_ID)
    glue_rule2.score = config.weights[config.order['glue-rule-count']]
    table._rules.append(glue_rule2)
    idx += 1

    if config.enable_type3_glue_rule:
        # S -> <S X; X S>
        features = [0] * feature_num
        features[glue_rule_index] = 1
        glue_rule3 = Rule('|0 |1', ['|1', '|0'], [1, 0], features,
                          cls.GLUE_RULE3_GLOBAL_ID)
        glue_rule3.score = config.weights[config.order['glue-rule-count']]
        table._rules.append(glue_rule3)
        idx += 1

    keys.append(cls.GLUE_RULE2.decode('utf-8'))
    ranges.append((1, idx - 1))
    table.glue_rule_ids = tuple(i for i in range(idx))

    # normal rules
    with Reader(filename) as reader:
        last_src = None
        current_rules = []
        for rule_str in reader:
            parts = rule_str.strip().split(' ||| ')
            src = parts[0]
            tgt = parts[1].split(' ')
            nonterminal_pos = []
            for tword, pos in zip(tgt, range(len(tgt))):
                if tword[0] == '|':
                    if len(nonterminal_pos) == 0:
                        nonterminal_pos.append(pos)
                    else:
                        index = int(tword[1:])
                        nonterminal_pos.insert(index, pos)
            features = [float(f) for f in parts[2].split(' ')]
            features.append(len(tgt) - len(nonterminal_pos))  # word number
            features.append(1)  # rule count
            features.append(0)  # glue rule count
            if len(parts) >= 4:
                global_rule_id = int(parts[3])
                rule = Rule(src, tgt, nonterminal_pos, features, global_rule_id)
            else:
                rule = Rule(src, tgt, nonterminal_pos, features, idx)
            lmscore, hlmscore = cls.__get_lm_scores(rule, lm)
            features.append(lmscore)  # lm score
            rule.hlmscore = hlmscore

            if last_src == None or src == last_src:
                current_rules.append(rule)
                last_src = src
            else:
                cls.__update_table(table, keys, ranges, last_src,
                                   current_rules, config, max_rule_num)
                current_rules = [rule]
                last_src = src
            idx += 1
        cls.__update_table(table, keys, ranges, last_src, current_rules,
                           config, max_rule_num)

    table._idranges = RecordTrie('<II', zip(keys, ranges))
    del keys
    del ranges
    gc.collect()
    return table
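The '<II' trie built at the end maps each source-side string to a (start, end) index range into table._rules. A hedged lookup sketch, assuming the ranges are stored inclusively as the glue-rule entries above suggest; the source string is invented.

# Hypothetical lookup of the rules sharing one source side.
src = u'ne |0 pas'
hit = table._idranges.get(src)
if hit is not None:
    start, end = hit[0]
    candidate_rules = table._rules[start:end + 1]   # assumes inclusive ranges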
class Correct:
    def get_vec_mapping(self, layout):
        max_line_length = max([len(line) for line in layout.split()])
        max_height = len(layout.split())
        mapping = {}
        y_pos = max_height  # start at top
        x_increment = 2
        for line in layout.split():
            # x_pos = -max_line_length / 2  # start left
            x_pos = 0
            for letter in line.strip():
                mapping[letter] = np.array([x_pos, y_pos])
                if letter == '.':
                    x_pos += 1
                else:
                    x_pos += x_increment
            y_pos -= 1
        print(mapping)
        self.mapping = mapping

    def get_vector_for_word(self, word):
        # mlist = [1, 3, 5, 7, 11, 13, 17, 19, 23, 27, 29, 31, 37, 41, 43, 47, 49, 53]
        multiplier = 1
        vector = np.zeros(3)
        prev_letter = None
        prevprev_letter = None
        basefct = math.sin

        def get_base_vec(idx, letter):
            letterindx = ord(letter) - ord('a')
            i1 = (idx + letterindx - 1) / 10.0
            i2 = (idx + letterindx) / 10.0
            cos = math.cos
            sin = math.sin
            return np.array([cos(i2) - cos(i1), sin(i2) - sin(i1)])

        for idx, letter in enumerate(word):
            try:
                if idx == 0:
                    # multiplier = idx
                    sep = 100
                else:
                    # multiplier = math.e ** (float(idx) / 8.0)
                    # multiplier = 1
                    # sep = math.e ** (float(idx) / 8.0)
                    sep = 1
                # z = np.linalg.norm(self.mapping[letter]) * (idx + 1)
                if prevprev_letter:
                    # z = np.abs(np.arctan2(self.mapping[letter][1] - self.mapping[prev_letter][1],
                    #                       self.mapping[letter][0] - self.mapping[prev_letter][0]))
                    # z = np.dot(self.mapping[prev_letter], self.mapping[letter])
                    v1 = self.mapping[prev_letter] - self.mapping[prevprev_letter]
                    v2 = self.mapping[prev_letter] - self.mapping[letter]
                    z = np.arccos(np.dot(v1, v2) /
                                  (np.linalg.norm(v1) * np.linalg.norm(v2)))
                    if np.isnan(z):
                        z = 0
                    # z += math.pi
                else:
                    z = 0.0
                vector += np.hstack((self.mapping[letter] * sep +
                                     get_base_vec(idx, letter) * 3, z))
                # multiplier = 1
                prevprev_letter = prev_letter
                prev_letter = letter
            except KeyError as e:
                print("%s doesn't exist in mapping :(" % letter)
        return vector

    def get_vectors_from_dict(self):
        self.vectors = {}
        for word in self.dict:
            try:
                self.vectors[word] = self.get_vector_for_word(word)
            except KeyError as e:
                print("%s can't be used as key :(" % word)
            # multiplier = 1
            # for letter in word:
            #     try:
            #         self.vectors[word] += self.mapping[letter] * multiplier
            #         multiplier += 0.33
            #     except KeyError as e:
            #         print("%s doesn't exist in mapping :(" % letter)

    def match(self, word):
        matchvec = self.get_vector_for_word(word)  # - np.array([-1, 0])
        print("\n\nTesting %s\nVector: %r" % (word, matchvec))
        match_dists, match_idxs = self.kdtree.query(matchvec, k=50)
        results = []
        for idx in match_idxs:
            results.append(self.vectors.items()[idx])
        return match_idxs, results

    def score(self, word, context=None):
        # score word
        # context is previous words in order!
        wscore = self.score_trie.get(word)
        if not wscore:
            wscore = 0
        gram2score = 1
        gram3score = 1
        if context:
            context = context.split()
            if len(context) > 1:
                gram2score = self.score_trie.get(word + " " + context[-1])
            if len(context) > 2:
                gram3score = self.score_trie.get(word + " " + context[-1] + " " + context[-2])
        if wscore:
            wscore = wscore[0]
        if type(gram2score) == tuple:
            wscore = gram2score[0]
        if type(gram3score) == tuple:
            wscore = gram3score[0]
        return wscore * gram2score * gram3score

    def score_and_sort_matches(self, matches, context):
        scores = []
        for m in matches:
            scores.append(self.score(m[0], context))
        res = zip(matches, scores)
        return sorted(res, key=lambda x: x[1])

    def __init__(self):
        vals = []
        with open('./en_wordlist.combined') as infile:
            for l in infile:
                items = l.split(',')
                # print(items)
                w = items[0].split('=')[1]
                if items[0].startswith(' '):
                    print(l)
                    continue
                f = int(items[1].split('=')[1])
                vals.append(w)
        # self.dict = [word.strip() for word in infile.readlines()]
        self.dict = vals
        self.vectors = {}
        self.mapping = {}

        self.score_trie = RecordTrie(fmt='<H')
        with open('trie.marisa') as ftrie:
            self.score_trie.read(ftrie)

        print(self.get_vec_mapping(layout))
        self.get_vectors_from_dict()
        # print(self.vectors)
        allvecs = [arr for arr in self.vectors.itervalues()]
        allvecsarr = np.array(allvecs)
        self.kdtree = KDTree(data=allvecsarr)

        print("#####\n\n\n")
        print(self.vectors['test'])
        print("####")
        pprint(self.match('test'))
        pprint(self.score_and_sort_matches(self.match('teat')[1], 'nice'))
        pprint(self.match('tedt')[1])
        pprint(self.match('perspecitev')[1])
        pprint(self.match('angle')[1])
        pprint(self.match('angel')[1])

        print("#####\n\n\n")
        print(self.vectors['news'])
        print("####")
        pprint(self.match('newa')[1])
        pprint(self.match('newr')[1])
        pprint(self.match('newst')[1])
        ms = self.match('newst')[1]
        v = self.get_vector_for_word('newst')
        pprint(sorted(ms, key=lambda x: abs(x[1][2] - v[2])))
        pprint(self.match('obascure')[1])
        pprint(self.match('obscure')[1])
        pprint(self.match('relativetiy')[1])
        pprint(self.match('absolutealy')[1])
        pprint(self.match('porcealitn')[1])

        print(allvecsarr)
        plt.scatter(allvecsarr[:, 0], allvecsarr[:, 1])
        plt.show()
class ArpaLanguageModel:
    """
    Change format of ARPA language model
    (code inspired from pynlpl's library)

    Focus:
        - load pre-computed back-off language model from file in ARPA format
        - build a trie on (ngram, (logprob, backoff)) entries (use marisa_trie)
        - store the trie model in binary file (useful for faster reload)

    Note:
        - might require a lot of memory depending on the size of the language model
        - use only once
    """

    def __init__(self, path):
        """Build trie on ARPA n-grams"""
        io_utils.check_file_readable(path)
        self.logger = logging.getLogger(__name__)
        self.logger.info("Load ARPA model from {}".format(path))
        self.order = None
        self.total = {}
        self.trie = RecordTrie("@dd", self.load_ngram_tuples(path))
        self.logger.info(
            "Loaded a {}-gram LM with {} counts".format(self.order, self.total))

    def load_ngram_tuples(self, path):
        """
        Process ARPA language model.
        Yield (ngram, (logprob, backoff)) entries.
        """
        order = None
        self.total = {}
        with open(path, 'rt', encoding='utf-8') as istream:
            for line in istream:
                line = line.strip()
                if line:
                    if line == '\\data\\':
                        order = 0
                    elif line == '\\end\\':
                        break
                    elif line.startswith('\\') and line.endswith('-grams:'):
                        order = int(re.findall(r"\d+", line)[0])
                        self.logger.info("Processing {}-grams".format(order))
                    elif order == 0 and line.startswith('ngram'):
                        n = int(line[6])
                        count = int(line[8:])
                        self.total[n] = count
                    elif order > 0:
                        fields = line.split('\t')
                        ngram = ' '.join(fields[1].split())
                        logprob = float(fields[0])
                        # handle absent/present backoff
                        backoffprob = 0.0
                        if len(fields) > 2:
                            backoffprob = float(fields[2])
                        # handle the prob(<s>) = -99 case
                        if ngram == '<s>' and logprob == -99:
                            logprob = 0.0
                        yield (ngram, (logprob, backoffprob))
        self.order = order

    def save_trie(self, output):
        """Store trie-based n-grams under binary format"""
        self.trie.save(output)
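Taken together with the LanguageModel class earlier in this collection, the intended workflow appears to be: convert the ARPA file once, save the trie, then reload the compact binary on later runs. A sketch with placeholder paths.

# One-off conversion (slow, memory-hungry).
arpa = ArpaLanguageModel('model.arpa')     # placeholder path
arpa.save_trie('model.trie')

# Fast reload afterwards.
lm = LanguageModel('model.trie', header='@dd', order=arpa.order)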
def build(dump_db, tokenizer, out_file, min_link_prob, min_prior_prob, min_link_count,
          max_mention_length, pool_size, chunk_size):
    name_dict = defaultdict(Counter)

    logger.info('Iteration 1/2: Extracting all entity names...')

    with tqdm(total=dump_db.page_size(), mininterval=0.5) as pbar:
        initargs = (dump_db, tokenizer, max_mention_length)
        with closing(Pool(pool_size, initializer=EntityLinker._initialize_worker,
                          initargs=initargs)) as pool:
            for ret in pool.imap_unordered(EntityLinker._extract_name_entity_pairs,
                                           dump_db.titles(), chunksize=chunk_size):
                for text, title in ret:
                    name_dict[text][title] += 1
                pbar.update()

    name_counter = Counter()
    disambi_matcher = re.compile(r'\s\(.*\)$')
    for title in dump_db.titles():
        text = normalize_text(disambi_matcher.sub('', title))
        name_dict[text][title] += 1
        name_counter[text] += 1

    for src, dest in dump_db.redirects():
        text = normalize_text(disambi_matcher.sub('', src))
        name_dict[text][dest] += 1
        name_counter[text] += 1

    logger.info('Iteration 2/2: Counting occurrences of entity names...')

    with tqdm(total=dump_db.page_size(), mininterval=0.5) as pbar:
        initargs = (dump_db, tokenizer, max_mention_length, Trie(name_dict.keys()))
        with closing(Pool(pool_size, initializer=EntityLinker._initialize_worker,
                          initargs=initargs)) as pool:
            for names in pool.imap_unordered(EntityLinker._extract_name_occurrences,
                                             dump_db.titles(), chunksize=chunk_size):
                name_counter.update(names)
                pbar.update()

    logger.info('Step 4/4: Building DB...')

    titles = frozenset([title for entity_counter in name_dict.values()
                        for title in entity_counter.keys()])
    title_trie = Trie(titles)

    def item_generator():
        for name, entity_counter in name_dict.items():
            doc_count = name_counter[name]
            total_link_count = sum(entity_counter.values())

            if doc_count == 0:
                continue

            link_prob = total_link_count / doc_count
            if link_prob < min_link_prob:
                continue

            for title, link_count in entity_counter.items():
                if link_count < min_link_count:
                    continue

                prior_prob = link_count / total_link_count
                if prior_prob < min_prior_prob:
                    continue

                yield name, (title_trie[title], link_count, total_link_count, doc_count)

    data_trie = RecordTrie('<IIII', item_generator())
    mention_trie = Trie(data_trie.keys())

    joblib.dump(dict(title_trie=title_trie, mention_trie=mention_trie, data_trie=data_trie,
                     tokenizer=tokenizer, max_mention_length=max_mention_length), out_file)
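A hedged sketch of querying the resulting DB after reloading the joblib dump; the file name and the mention string are placeholders.

import joblib

data = joblib.load('entity_linker.joblib')     # placeholder for out_file
data_trie = data['data_trie']
title_trie = data['title_trie']

# Each mention maps to (title_id, link_count, total_link_count, doc_count) records.
for title_id, link_count, total_link_count, doc_count in data_trie.get(u'some mention', []):
    prior_prob = link_count / total_link_count
    print(title_trie.restore_key(title_id), prior_prob)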