def __init__(self, records):
    """Eagerly build the search index over *records*.

    NOTE(review): this top-level ``def`` appears to be a stray duplicate of
    ``Index.__init__`` defined below — confirm before removing.
    """
    self.records = records
    # Fuzzy-match structures: Levenshtein BK-tree for general words,
    # Hamming BK-tree for the length-3 prefixes.
    self.edits_lev = BKTree(levenshtein)
    self.edits_3 = BKTree(hamming)
    self.word_trie = Trie()
    self.word_set = set()
    self.index = {}
    self._index(records)
def test_trie_lookup(words: List[str]):
    """Every word inserted into a Trie must be found again by search()."""
    trie = Trie()
    for entry in words:
        trie.insert(entry)
    assert all(trie.search(entry) for entry in words)
class Index:
    """In-memory inverted index with exact and fuzzy phrase search.

    Each token is stored with its character/token positions per document.
    Fuzzy lookup combines a BK-tree over Levenshtein distance (general
    words), a BK-tree over Hamming distance (length-3 prefixes), and a trie
    of indexed words used for prefix expansion of the last query token.
    """

    def __init__(self, records):
        """Index the sequence of record strings eagerly."""
        self.records = records
        self.index = {}                       # token -> [(doc_id, [RecordPosition, ...])]
        self.edits_lev = BKTree(levenshtein)  # fuzzy lookup for general words/prefixes
        self.edits_3 = BKTree(hamming)        # fuzzy lookup for length-3 prefixes
        self.word_trie = Trie()
        self.word_set = set()                 # NOTE(review): never written in this file — confirm use
        self._index(records)

    def tokenize(self, record):
        """Split *record* into word and digit-run tokens; separators are dropped."""
        #return re.split(r"\W+", record)
        # The capture group makes re.split keep digit runs as tokens;
        # None (non-matching group) and empty pieces are filtered out.
        return [w for w in re.split(r"[-_/.?+&:\W]+|(\d+)", record) if w]

    def filter(self, token):
        """Normalize a token for indexing (case folding only)."""
        return token.lower()

    def add_token_offsets(self, record, tokens):
        """Pair each token with its (char_position, token_position) in *record*.

        *tokens* must occur in *record* in order; the scan advances a char
        cursor to locate each token's start before normalizing it.
        """
        terms = []
        char_i = 0
        for tok_i, token in enumerate(tokens):
            assert char_i < len(record)
            # Advance to the next occurrence of the (un-normalized) token.
            while not record[char_i:].startswith(token):
                char_i += 1
            token = self.filter(token)
            terms.append((token, RecordPosition(char_i, tok_i)))
            char_i += len(token)
        return terms

    def search(self, query, topn=10, fuzzy=False):
        """Yield up to *topn* ``(edit_distance, min_dist, highlighted_record)``.

        Results are ordered by total edit distance, then by token proximity.
        NOTE(review): assumes Candidate provides a default ``min_dist`` for
        single-token queries — confirm in its definition.
        """
        if fuzzy:
            candidates = self._find_phrase_fuzzy(query)
        else:
            candidates = self._find_phrase(query)
        candidates.sort(key=lambda c: (c.edit_distance, c.min_dist))
        candidates = candidates[:topn]
        for c in candidates:
            record = self.records[c.doc_id]
            highlights = self._merge_highlights(c.highlights)
            highlighted_record = self._highlight_record(record, highlights)
            yield (c.edit_distance, c.min_dist, highlighted_record)

    def _group_occurrences(self, occurrences):
        """Group ``(word, RecordPosition)`` pairs into word -> [positions]."""
        d = collections.defaultdict(list)
        for word, record_position in occurrences:
            d[word].append(record_position)
        return d

    def _index(self, records):
        """Populate the inverted index, trie and BK-trees from *records*."""
        for doc_id, record in enumerate(records):
            tokens = self.tokenize(record)
            occurrences = self.add_token_offsets(record, tokens)
            for token, record_positions in self._group_occurrences(occurrences).items():
                # Register every prefix of length >= 2 once in the BK-trees.
                # NOTE(review): nesting reconstructed from collapsed source —
                # edits_3 insertion is taken to be guarded by the same
                # "new prefix" check; confirm against history.
                for i in range(1, len(token)):
                    prefix = token[:i + 1]
                    if not self.word_trie.is_prefix(prefix):
                        self.edits_lev.insert(prefix)
                        if len(prefix) == 3:
                            self.edits_3.insert(prefix)
                if token not in self.index:
                    self.word_trie.insert(token)
                    self.index[token] = []
                self.index[token].append((doc_id, record_positions))

    def _find_derived_words(self, word, is_prefix):
        """Yield ``(distance, indexed_word)`` pairs fuzzily matching *word*.

        Distance budget grows with word length (<=4: 1, <=7: 2, else 3);
        length-3 prefixes use the Hamming BK-tree instead. When *is_prefix*,
        matches are expanded to all indexed words they prefix, keeping the
        minimum distance per expansion.
        """
        use_levenshtein = True
        if is_prefix:
            if len(word) <= 2:
                # Too short for useful edit distance: pure prefix expansion.
                return ((0, w) for w in self.word_trie.descendants_or_self(word))
            if len(word) == 3:
                derived_words = self.edits_3.find(word, 1)
                use_levenshtein = False
        if use_levenshtein:
            if len(word) <= 4:
                d = 1
            elif len(word) <= 7:
                d = 2
            else:
                d = 3
            derived_words = self.edits_lev.find(word, d)
            # TODO check case: d == 1
        if is_prefix:
            # Expand each fuzzy match to the indexed words it prefixes,
            # keeping the smallest distance seen per expanded word.
            min_dists = {}
            for d, w in derived_words:
                for desc in self.word_trie.descendants_or_self(w):
                    if desc not in min_dists:
                        min_dists[desc] = d
                    else:
                        min_dists[desc] = min(min_dists[desc], d)
            derived_words = ((d, desc) for desc, d in min_dists.items())
        return derived_words

    def _find_one(self, word, prefix, edit_distance):
        """Exact-lookup *word*; highlight spans use len(*prefix*) (the query form)."""
        docs_found = self.index.get(word, [])
        result = []
        for doc_id, record_positions in docs_found:
            highlights = [
                (rp.char_position, rp.char_position + len(prefix))
                for rp in record_positions
            ]
            result.append(Candidate(doc_id, edit_distance, record_positions, highlights))
        return result

    def _find_one_fuzzy(self, word, is_prefix=False):
        """Fuzzy-lookup *word*, merging per-document candidates; sorted by doc_id."""
        derived_words = self._find_derived_words(word, is_prefix)
        result = []
        for dist, w in derived_words:
            result.extend(self._find_one(w, word, dist))
        groups = collections.defaultdict(list)
        for cnd in result:
            groups[cnd.doc_id].append(cnd)  # group candidates by doc_id
        result = []
        for doc_id, cnds in groups.items():
            edit_distance = min(c.edit_distance for c in cnds)
            last_occurrences = sum([c.last_occurrences for c in cnds], [])
            # TODO remove duplicate occurrences? are there any?
            highlights = sum([c.highlights for c in cnds], [])
            result.append(Candidate(doc_id, edit_distance, last_occurrences, highlights))
        result.sort(key=lambda cnd: cnd.doc_id)
        return result

    def _find_phrase(self, query):
        """Exact multi-token search: intersect per-token candidate lists."""
        tokens = self.tokenize(query)
        if not tokens:
            return []
        candidates = self._find_one(tokens[0], tokens[0], 0)
        for token in tokens[1:]:
            candidates = self._merge(candidates, self._find_one(token, token, 0))
        return candidates

    def _find_phrase_fuzzy(self, query):
        """Fuzzy multi-token search; the last token is treated as a prefix
        unless the query ends with a space (the user finished the word)."""
        tokens = self.tokenize(query)
        if not tokens:
            return []
        is_last_prefix = query[-1] != " "
        candidates = self._find_one_fuzzy(
            tokens[0], is_prefix=is_last_prefix and len(tokens) == 1)
        for i, token in enumerate(tokens[1:]):
            new_candidates = self._find_one_fuzzy(
                token, is_prefix=is_last_prefix and i == len(tokens) - 2)
            candidates = self._merge(candidates, new_candidates)
        return candidates

    def _merge(self, xs, ys):
        """Intersect two doc_id-sorted candidate lists (sorted-merge join).

        Edit distances add; ``min_dist`` accumulates the token proximity
        between the previous token's occurrences and the new one's.
        """
        cs = []
        ix = iy = 0
        while ix < len(xs) and iy < len(ys):
            cndx, cndy = xs[ix], ys[iy]
            if cndx.doc_id < cndy.doc_id:
                ix += 1
            elif cndx.doc_id > cndy.doc_id:
                iy += 1
            else:
                xpositions = cndx.last_occurrences
                ypositions = cndy.last_occurrences
                edit_distance = cndx.edit_distance + cndy.edit_distance
                c = Candidate(cndx.doc_id, edit_distance, ypositions,
                              cndx.highlights + cndy.highlights)
                c.min_dist = cndx.min_dist + min_dist(xpositions, ypositions)
                cs.append(c)
                ix, iy = ix + 1, iy + 1
        return cs

    def _merge_highlights(self, highlights):
        """Merge overlapping/adjacent ``(start, end)`` spans after sorting.

        Returns at least one span; an empty input yields ``[(0, 0)]``.
        """
        highlights = sorted(highlights)
        result = []
        hlstart, hlend = 0, 0
        for start, end in highlights:
            if start <= hlend:
                # BUGFIX: take the max so a contained interval (e.g. (2,3)
                # inside (0,10)) cannot shrink the merged span.
                hlend = max(hlend, end)
            else:
                if hlend > hlstart:
                    result.append((hlstart, hlend))
                hlstart, hlend = start, end
        # also include the last one
        result.append((hlstart, hlend))
        return result

    def _highlight_record(self, record, highlights):
        """Return *record* with each span wrapped in ANSI yellow-background codes.

        *highlights* must be sorted and non-overlapping (see _merge_highlights).
        """
        yellow_back = "\033[103m"
        normal_back = "\033[49m"
        result = []
        last = 0
        for start, end in highlights:
            result.append(record[last:start])
            result.append(yellow_back + record[start:end] + normal_back)
            last = end
        # BUGFIX: use `last` (0 when highlights is empty) — `end` is unbound
        # there and raised NameError; identical otherwise since last == end.
        result.append(record[last:])
        return "".join(result)

    def _highlight_candidates(self, candidates):
        """Render each candidate as ``(edit_distance, min_dist, highlighted_record)``."""
        r = []
        for candidate in candidates:
            # BUGFIX: this assignment was commented out, leaving `record`
            # unbound (NameError on any non-empty candidate list).
            record = self.records[candidate.doc_id]
            highlights = self._merge_highlights(candidate.highlights)
            highlighted_record = self._highlight_record(record, highlights)
            r.append((candidate.edit_distance, candidate.min_dist, highlighted_record))
        return r