def test_init2(self):
    trie = Trie()
    ids = [trie.insert(w) for w in [u'Ruby', u'ruby', u'rb', u'XX']]
    self.assertEqual([(k, v) for k, v in trie.items()],
                     [(u"Ruby", 0), (u"ruby", 1), (u"rb", 2), (u"XX", 3)])
def test_ignore_case_predict(self):
    if sys.version_info.major < 3:
        return
    trie = Trie(ignore_case=True, ordered=True)
    ids = {w: trie.insert(w) for w in [u"aaİ", u"aİİ", u"aai̇", u"aai̇bİ"]}
    predicts = list(trie.predict(u"aaİ"))
    self.assertEqual(predicts, [ids[u'aai̇'], ids[u'aai̇bİ']])
def test_init(self):
    trie = Trie()
    ids = [trie.insert(w) for w in [u'Ruby', u'ruby', u'rb']]
    self.assertEqual(ids, [0, 1, 2])
    self.assertEqual(u"ruby" in trie, True)
    self.assertEqual(u"rubyx" in trie, False)
    self.assertEqual(trie.remove(u"ruby"), 1)
    self.assertEqual(trie.remove(u"ruby"), -1)
def test_get_element(self):
    trie = Trie()
    ids = {w: trie.insert(w) for w in [u"ruby", u"rubx", u"rab", u"rub", u"rb"]}
    for w, id_ in ids.items():
        self.assertEqual(trie[id_], w)
        self.assertEqual(trie[w], id_)
def test_predict(self):
    trie = Trie(ordered=True)
    ids = {w: trie.insert(w) for w in [u"ruby", u"rubx", u"rab", u"rub", u"rb"]}
    predicts = list(trie.predict(u"r"))
    self.assertEqual(
        predicts,
        [ids[u"rb"], ids[u"rab"], ids[u"rub"], ids[u"rubx"], ids[u"ruby"]])
def test_ignore_case_prefix(self):
    if sys.version_info.major < 3:
        return
    txt = u"aaİbİc"
    trie = Trie(ignore_case=True)
    ids = {w: trie.insert(w) for w in [u"aİ", u"aİİ", u"aai̇", u"aai̇bİ"]}
    prefixes = list(trie.prefix(txt))
    self.assertEqual(prefixes, [(ids[u"aai̇"], 3), (ids[u"aai̇bİ"], 5)])
    txt = u"aai̇bİc"
    prefixes = list(trie.prefix(txt))
    self.assertEqual(prefixes, [(ids[u"aai̇"], 4), (ids[u"aai̇bİ"], 6)])
def test_ignore_case_match_longest(self):
    if sys.version_info.major < 3:
        return
    trie = Trie(ignore_case=True)
    ids = {w: trie.insert(w) for w in [u"aİİ", u"aai̇", u"aai̇bİ"]}
    matches = list(trie.match_longest(u"aaİ aai̇bİaa"))
    self.assertEqual(matches,
                     [(ids[u"aai̇"], 0, len(u"aaİ")),
                      (ids[u"aai̇bİ"], len(u"aaİ "), len(u"aaİ aai̇bİ"))])
    sep = set([ord(" ")])  # space as separator
    matches = list(trie.match_longest(u"aaİ aai̇bİaa", sep))
    self.assertEqual(matches, [
        (ids[u"aai̇"], 0, len(u"aaİ")),
    ])
def test_remove_words(self):
    dir_ = os.path.dirname(__file__)
    trie = Trie()
    for i in range(3):
        ids = []
        words = []
        with open(os.path.join(dir_, "../bench/words.txt")) as fi:
            for l in fi:
                l = l.strip()
                if isinstance(l, bytes):
                    l = l.decode("utf8")
                if len(l) > 0:
                    words.append(l)
                    ids.append(trie.insert(l))
        for id_, w in zip(ids, words):
            self.assertEqual(id_, trie.remove(w))
def test_reuse_id(self):
    trie = Trie()
    ids = {w: trie.insert(w) for w in [u"abc", u"abd", u"abe"]}
    trie.remove(u"abc")
    trie.remove(u"abe")
    v = trie.insert(u"abf")
    self.assertEqual(v, ids[u"abe"])
    v = trie.insert(u"abg")
    self.assertEqual(v, ids[u"abc"])
    v = trie.insert(u"abh")
    self.assertEqual(v, 3)
    v = trie.insert(u"abi")
    self.assertEqual(v, 4)
def test_ignore_case_replace_longest(self):
    if sys.version_info.major < 3:
        return
    trie = Trie(ignore_case=True)
    ids = {w: trie.insert(w) for w in [u"aİİ", u"aai̇", u"aai̇bİ"]}
    replaced = {
        ids[u"aİİ"]: u"a",
        ids[u"aai̇"]: u"b",
        ids[u"aai̇bİ"]: u"c",
    }
    res = trie.replace_longest(u"aaİ aai̇bİaa",
                               lambda x, start, end: replaced[x])
    self.assertEqual(res, u"b caa")
    sep = set([ord(" ")])  # space as separator
    res = trie.replace_longest(u"aaİ aai̇bİaa",
                               lambda x, start, end: replaced[x], sep)
    self.assertEqual(res, u"b aai̇bİaa")
def test_match_longest(self):
    trie = Trie()
    ids = {
        w: trie.insert(w)
        for w in [u"New York", u"New", u"York", u"York City", u"City", u"City is"]
    }
    matches = list(trie.match_longest(u"New York City isA"))
    self.assertEqual(
        matches,
        [(ids[u"New York"], 0, len(u"New York")),
         (ids[u"City is"], len(u"New York "), len(u"New York City is"))])
    sep = set([ord(" ")])  # space as separator
    matches = list(trie.match_longest(u"New York City isA", sep))
    self.assertEqual(
        matches,
        [(ids[u"New York"], 0, len(u"New York")),
         (ids[u"City"], len(u"New York "), len(u"New York City"))])
def test_pickle_trie(self):
    trie = Trie(ignore_case=True)
    ids = {w: trie.insert(w) for w in [u"aİİ", u"aai̇", u"aai̇bİ"]}
    with open("trie.pkl", "wb") as fo:
        pickle.dump(trie, fo)
    with open("trie.pkl", "rb") as fi:
        trie = pickle.load(fi)
    replaced = {
        ids[u"aİİ"]: u"a",
        ids[u"aai̇"]: u"b",
        ids[u"aai̇bİ"]: u"c",
    }
    res = trie.replace_longest(u"aaİ aai̇bİaa",
                               lambda x, start, end: replaced[x])
    self.assertEqual(res, u"b caa")
    sep = set([ord(" ")])  # space as separator
    res = trie.replace_longest(u"aaİ aai̇bİaa",
                               lambda x, start, end: replaced[x], sep)
    self.assertEqual(res, u"b aai̇bİaa")
def test_replace_words(self):
    dir_ = os.path.dirname(__file__)
    trie = Trie()
    ids = []
    with open(os.path.join(dir_, "../bench/words.txt")) as fi:
        for l in fi:
            l = l.strip()
            if isinstance(l, bytes):
                l = l.decode("utf8")
            if len(l) > 0:
                ids.append(trie.insert(l))
    with open(os.path.join(dir_, "../bench/words.txt")) as fi:
        txt = fi.read()
    if isinstance(txt, bytes):
        txt = txt.decode("utf8")
    sep = set([ord("\n")])
    ret = trie.replace_longest(txt, lambda v, start, end: str(v), sep).strip()
    self.assertEqual(ret, "\n".join([str(i) for i in ids]))
def test_match_words(self):
    dir_ = os.path.dirname(__file__)
    trie = Trie()
    ids = []
    with open(os.path.join(dir_, "../bench/words.txt")) as fi:
        for l in fi:
            l = l.strip()
            if isinstance(l, bytes):
                l = l.decode("utf8")
            if len(l) > 0:
                ids.append(trie.insert(l))
    with open(os.path.join(dir_, "../bench/words.txt")) as fi:
        txt = fi.read()
    if isinstance(txt, bytes):
        txt = txt.decode("utf8")
    sep = set([ord("\n")])
    matched = []
    for v, start, end in trie.match_longest(txt, sep):
        matched.append(v)
        self.assertEqual(txt[start:end], trie[v])
    self.assertEqual(matched, ids)
def _get_category_links(cat2id, id2page, **kwargs):
    trie = Trie()
    for page in id2page.values():
        trie.insert(page)
    category_links = []
    for cl_type, source_id, target_title in dt.iter_categorylinks_dump_data(
        **kwargs
    ):
        if (
            # only categories with a page
            target_title not in trie
            # only allowed pages
            or (cl_type == dt.WIKI_CL_TYPE_PAGE and source_id not in id2page)
        ):
            continue
        try:
            target_id = cat2id[f"Category:{target_title}"]
        except KeyError:
            continue
        category_links.append((source_id, target_id))
    return category_links
def test_replace_longest(self):
    trie = Trie()
    ids = {
        w: trie.insert(w)
        for w in [u"New York", u"New", u"York", u"York City", u"City", u"City is"]
    }
    replaced = {
        ids[u"New York"]: u"Beijing",
        ids[u"New"]: u"Old",
        ids[u"York"]: u"Yark",
        ids[u"York City"]: u"Yerk Town",
        ids[u"City"]: u"Country",
        ids[u"City is"]: u"Province are",
    }
    res = trie.replace_longest(u"New York City isA",
                               lambda x, start, end: replaced[x])
    self.assertEqual(res, u"Beijing Province areA")
    sep = set([ord(" ")])  # space as separator
    res = trie.replace_longest(u"New York City isA",
                               lambda x, start, end: replaced[x], sep)
    self.assertEqual(res, u"Beijing Country isA")
def test_prefix(self):
    trie = Trie()
    ids = {w: trie.insert(w) for w in [u"ruby", u"rubx", u"rab", u"rub"]}
    prefixes = list(trie.prefix(u"ruby on rails"))
    self.assertEqual(prefixes, [(ids[u"rub"], 3), (ids[u"ruby"], 4)])
def test_insert_zero_len_key(self):
    trie = Trie()
    self.assertEqual(trie.insert(u""), -1)
class WikiPageDetector:
    def __init__(self, pages: Iterable[str] = None):
        self._map = None
        self._trie = None
        if pages is not None:
            self.build(pages)

    @staticmethod
    def load(path: Path):
        wpd = WikiPageDetector()
        wpd._map = pickle_load(path / "wpd_map.gz")
        # memory-map the serialized trie and load it without copying
        with (path / "wpd_trie").open("r+b") as bf:
            wpd._trie = Trie.from_buff(mmap(bf.fileno(), 0), copy=False)
        return wpd

    def dump(self, path: Path):
        self._trie.save(str(path / "wpd_trie"))
        pickle_dump(self._map, path / "wpd_map.gz", compress=True)

    def build(self, pages: Iterable[str]):
        # group page titles under a normalized, lower-cased key
        key2titles = {}
        for page in pages:
            if not page:
                continue
            key = _clean_title(page).lower()
            if not key:
                key = page
            titles = key2titles.setdefault(key, [])
            titles.append(page)
        # insert each key into a case-insensitive trie and map the
        # resulting id back to the original titles
        mapping = {}
        self._trie = Trie(ignore_case=True)
        for key in key2titles:
            id_ = self._trie.insert(key)
            mapping.setdefault(id_, tuple(key2titles[key]))
        self._map = tuple([mapping.get(i) for i in range(max(mapping) + 1)])

    def find_pages(self, text: str):
        def iter_matches(source):
            ac_seps = set([ord(p) for p in _XP_SEPS.findall(source)])
            for id_, start_idx, end_idx in self._trie.match_longest(
                source, ac_seps
            ):
                yield (start_idx, end_idx, self._map[id_])

        for match in iter_matches(text):
            yield match
            # re-scan sub-spans of each longest match so shorter,
            # overlapping pages are reported as well
            match_text = text[match[0]:match[1]]
            seps = list(_XP_SEPS.finditer(match_text))
            if len(seps) < 1:
                continue
            tokens = []
            last_end = 0
            for sep in seps:
                token = match_text[last_end:sep.start()]
                start = last_end
                last_end = sep.end()
                if len(token) < 2 and not token.isalnum():
                    continue
                tokens.append((start, token))
            tokens.append((last_end, match_text[last_end:]))
            num_tokens = len(tokens)
            for s, e in combinations(range(num_tokens + 1), 2):
                if s == 0 and e == num_tokens:
                    continue  # the full span was already yielded above
                e -= 1
                submatches = set()
                start = tokens[s][0]
                end = tokens[e][0] + len(tokens[e][1])
                subtext = match_text[start:end]
                start += match[0]
                for sidx, eidx, pages in iter_matches(subtext):
                    coords = (sidx + start, eidx + start)
                    if coords in submatches:
                        continue
                    submatches.add(coords)
                    yield (*coords, pages)
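
# A minimal usage sketch for WikiPageDetector, not part of the class above.
# The page titles and input text are hypothetical; it assumes the module-level
# helpers used by the class (_clean_title, _XP_SEPS) are defined in this file.
if __name__ == "__main__":
    detector = WikiPageDetector([u"New York", u"New York City", u"York"])
    for start, end, pages in detector.find_pages(u"flights to new york city"):
        # each hit is (start offset, end offset, tuple of matching page titles)
        print(start, end, pages)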
def test_ignore_case(self):
    trie = Trie(ignore_case=True)
    ids = [trie.insert(w) for w in [u'Ruby', u'ruby', u'rb']]
    self.assertEqual(ids, [0, 0, 1])
    self.assertEqual(trie.remove(u"ruby"), 0)
    self.assertEqual(trie.remove(u"Ruby"), -1)
def test_buff_ac(self):
    trie = Trie(ignore_case=True)
    ids = {w: trie.insert(w) for w in [u"aİİ", u"aai̇", u"aai̇bİ"]}
    trie.save("trie.bin")
    with open("trie.bin", "rb") as fi:
        bs = bytearray(fi.read())
    self.assertEqual(len(bs), trie.buff_size())
    bs2 = bytearray(trie.buff_size())
    trie.to_buff(bs2)
    self.assertEqual(bs2, bs)
    self._check_trie_correct(Trie.from_buff(bs2, copy=True), ids)
    self._check_trie_correct(Trie.from_buff(bs2, copy=False), ids)
def init_trie(words, size):
    trie = Trie()
    for i in range(size):
        trie.insert(words[i])
    return trie
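
# A brief timing sketch built on init_trie. The words.txt path mirrors the
# one used by the tests above and is an assumption here, as is running this
# file directly.
if __name__ == "__main__":
    import time

    with open("bench/words.txt") as fi:
        words = [l.strip() for l in fi if l.strip()]
    t0 = time.time()
    trie = init_trie(words, len(words))
    print("inserted %d words in %.3fs" % (len(words), time.time() - t0))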