Example #1
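These snippets are excerpted test methods and helpers from the cyac project and omit their imports (from cyac import Trie, plus stdlib modules such as sys, os, pickle, mmap, and itertools where used). This first one checks that insert() assigns sequential ids starting at 0 and that items() yields the stored (word, id) pairs.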
 def test_init2(self):
     trie = Trie()
     ids = [trie.insert(w) for w in [u'Ruby', u'ruby', u'rb', u'XX']]
     self.assertEqual([(k, v) for k, v in trie.items()], [(u"Ruby", 0),
                                                          (u"ruby", 1),
                                                          (u"rb", 2),
                                                          (u"XX", 3)])
Example #2
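With ignore_case=True, predict() matches prefixes case-insensitively, including Unicode case folding (Turkish dotted İ folds to i̇); the check is skipped on Python 2.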
 def test_ignore_case_predict(self):
     if sys.version_info.major < 3:
         return
     trie = Trie(ignore_case=True, ordered=True)
     ids = {w: trie.insert(w) for w in [u"aaİ", u"aİİ", u"aai̇", u"aai̇bİ"]}
     predicts = list(trie.predict(u"aaİ"))
     self.assertEqual(predicts, [ids[u'aai̇'], ids[u'aai̇bİ']])
Example #3
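Basic lifecycle: insert() returns sequential ids, membership works with the in operator, and remove() returns the removed word's id, or -1 if the word is absent.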
 def test_init(self):
     trie = Trie()
     ids = [trie.insert(w) for w in [u'Ruby', u'ruby', u'rb']]
     self.assertEqual(ids, [0, 1, 2])
     self.assertEqual(u"ruby" in trie, True)
     self.assertEqual(u"rubyx" in trie, False)
     self.assertEqual(trie.remove(u"ruby"), 1)
     self.assertEqual(trie.remove(u"ruby"), -1)
Example #4
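The trie is subscriptable in both directions: trie[id] returns the word and trie[word] returns the id.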
 def test_get_element(self):
     trie = Trie()
     ids = {
         w: trie.insert(w)
         for w in [u"ruby", u"rubx", u"rab", u"rub", u"rb"]
     }
     for w, id_ in ids.items():
         self.assertEqual(trie[id_], w)
         self.assertEqual(trie[w], id_)
Example #5
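predict(prefix) yields the ids of all stored words beginning with the prefix; with ordered=True they come back in lexicographic order.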
 def test_predict(self):
     trie = Trie(ordered=True)
     ids = {
         w: trie.insert(w)
         for w in [u"ruby", u"rubx", u"rab", u"rub", u"rb"]
     }
     predicts = list(trie.predict(u"r"))
     self.assertEqual(
         predicts,
         [ids[u"rb"], ids[u"rab"], ids[u"rub"], ids[u"rubx"], ids[u"ruby"]])
Example #6
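prefix(text) yields (id, length) pairs for every stored word that is a prefix of the text; under ignore_case the length is counted in the input's code points, so İ (one code point) and i̇ (two) give different lengths for the same key.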
 def test_ignore_case_prefix(self):
     if sys.version_info.major < 3:
         return
     txt = u"aaİbİc"
     trie = Trie(ignore_case=True)
     ids = {w: trie.insert(w) for w in [u"aİ", u"aİİ", u"aai̇", u"aai̇bİ"]}
     prefixes = list(trie.prefix(txt))
     self.assertEqual(prefixes, [(ids[u"aai̇"], 3), (ids[u"aai̇bİ"], 5)])
     txt = u"aai̇bİc"
     prefixes = list(trie.prefix(txt))
     self.assertEqual(prefixes, [(ids[u"aai̇"], 4), (ids[u"aai̇bİ"], 6)])
Example #7
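match_longest(text) scans the text and yields (id, start, end) for each longest case-insensitive match; passing a separator set restricts matches to whole tokens.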
 def test_ignore_case_match_longest(self):
     if sys.version_info.major < 3:
         return
     trie = Trie(ignore_case=True)
     ids = {w: trie.insert(w) for w in [u"aİİ", u"aai̇", u"aai̇bİ"]}
     matches = list(trie.match_longest(u"aaİ aai̇bİaa"))
     self.assertEqual(matches,
                      [(ids[u"aai̇"], 0, len(u"aaİ")),
                       (ids[u"aai̇bİ"], len(u"aaİ "), len(u"aaİ aai̇bİ"))])
     sep = set([ord(" ")])  # space as separator
     matches = list(trie.match_longest(u"aaİ aai̇bİaa", sep))
     self.assertEqual(matches, [
         (ids[u"aai̇"], 0, len(u"aaİ")),
     ])
Example #8
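A method of WikiPageDetector (full class in Example #21): it builds a case-insensitive trie over normalized page titles and a tuple mapping each trie id back to the original titles.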
 def build(self, pages: Iterable[str]):
     key2titles = {}
     for page in pages:
         if not page:
             continue
         key = _clean_title(page).lower()
         if not key:
             key = page
         titles = key2titles.setdefault(key, [])
         titles.append(page)
     mapping = {}
     self._trie = Trie(ignore_case=True)
     for key in key2titles:
         id_ = self._trie.insert(key)
         mapping.setdefault(id_, tuple(key2titles[key]))
     self._map = tuple([mapping.get(i) for i in range(max(mapping) + 1)])
Example #9
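Stress test: every word in a word list is inserted and then removed, three times over, exercising id recycling across rounds.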
 def test_remove_words(self):
     dir_ = os.path.dirname(__file__)
     trie = Trie()
     for i in range(3):
         ids = []
         words = []
         with open(os.path.join(dir_, "../bench/words.txt")) as fi:
             for l in fi:
                 l = l.strip()
                 if isinstance(l, bytes):
                     l = l.decode("utf8")
                 if len(l) > 0:
                     words.append(l)
                     ids.append(trie.insert(l))
         for id_, w in zip(ids, words):
             self.assertEqual(id_, trie.remove(w))
Example #10
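Ids freed by remove() are recycled: in this test the most recently freed id is handed out first, and fresh ids continue past the previous maximum.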
 def test_reuse_id(self):
     trie = Trie()
     ids = {w: trie.insert(w) for w in [u"abc", u"abd", u"abe"]}
     trie.remove(u"abc")
     trie.remove(u"abe")
     v = trie.insert(u"abf")
     self.assertEqual(v, ids[u"abe"])
     v = trie.insert(u"abg")
     self.assertEqual(v, ids[u"abc"])
     v = trie.insert(u"abh")
     self.assertEqual(v, 3)
     v = trie.insert(u"abi")
     self.assertEqual(v, 4)
Example #11
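replace_longest(text, callback) rewrites each longest match with the callback's return value; with a separator set, only whole-token matches are replaced.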
 def test_ignore_case_replace_longest(self):
     if sys.version_info.major < 3:
         return
     trie = Trie(ignore_case=True)
     ids = {w: trie.insert(w) for w in [u"aİİ", u"aai̇", u"aai̇bİ"]}
     replaced = {
         ids[u"aİİ"]: u"a",
         ids[u"aai̇"]: u"b",
         ids[u"aai̇bİ"]: u"c",
     }
     res = trie.replace_longest(u"aaİ aai̇bİaa",
                                lambda x, start, end: replaced[x])
     self.assertEqual(res, u"b caa")
     sep = set([ord(" ")])  # space as separator
     res = trie.replace_longest(u"aaİ aai̇bİaa",
                                lambda x, start, end: replaced[x], sep)
     self.assertEqual(res, u"b aai̇bİaa")
Example #12
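Longest-match scanning without ignore_case; the separator set changes which match wins ("City is" gives way to "City" when space is a separator, because a match may no longer end mid-token).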
 def test_match_longest(self):
     trie = Trie()
     ids = {
         w: trie.insert(w)
         for w in
         [u"New York", u"New", u"York", u"York City", u"City", u"City is"]
     }
     matches = list(trie.match_longest(u"New York City isA"))
     self.assertEqual(
         matches,
         [(ids[u"New York"], 0, len(u"New York")),
          (ids[u"City is"], len(u"New York "), len(u"New York City is"))])
     sep = set([ord(" ")])  # space as separator
     matches = list(trie.match_longest(u"New York City isA", sep))
     self.assertEqual(
         matches,
         [(ids[u"New York"], 0, len(u"New York")),
          (ids[u"City"], len(u"New York "), len(u"New York City"))])
Example #13
File: test_pickle.py · Project: yyht/cyac
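A trie survives a pickle round-trip: the reloaded trie produces the same replace_longest results as the original.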
 def test_pickle_trie(self):
     trie = Trie(ignore_case=True)
     ids = {w: trie.insert(w) for w in [u"aİİ", u"aai̇", u"aai̇bİ"]}
     with open("trie.pkl", "wb") as fo:
         pickle.dump(trie, fo)
     with open("trie.pkl", "rb") as fi:
         trie = pickle.load(fi)
     replaced = {
         ids[u"aİİ"]: u"a",
         ids[u"aai̇"]: u"b",
         ids[u"aai̇bİ"]: u"c",
     }
     res = trie.replace_longest(u"aaİ aai̇bİaa",
                                lambda x, start, end: replaced[x])
     self.assertEqual(res, u"b caa")
     sep = set([ord(" ")])  # space as separator
     res = trie.replace_longest(u"aaİ aai̇bİaa",
                                lambda x, start, end: replaced[x], sep)
     self.assertEqual(res, u"b aai̇bİaa")
Example #14
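replace_longest over a whole newline-separated word file maps every word to its id.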
 def test_replace_words(self):
     dir_ = os.path.dirname(__file__)
     trie = Trie()
     ids = []
     with open(os.path.join(dir_, "../bench/words.txt")) as fi:
         for l in fi:
             l = l.strip()
             if isinstance(l, bytes):
                 l = l.decode("utf8")
             if len(l) > 0:
                 ids.append(trie.insert(l))
     with open(os.path.join(dir_, "../bench/words.txt")) as fi:
         txt = fi.read()
         if isinstance(txt, bytes):
             txt = txt.decode("utf8")
     sep = set([ord("\n")])
     ret = trie.replace_longest(txt, lambda v, start, end: str(v),
                                sep).strip()
     self.assertEqual(ret, "\n".join([str(i) for i in ids]))
Example #15
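match_longest over the same word file recovers every inserted word, and trie[id] round-trips back to the matched text.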
 def test_match_words(self):
     dir_ = os.path.dirname(__file__)
     trie = Trie()
     ids = []
     with open(os.path.join(dir_, "../bench/words.txt")) as fi:
         for l in fi:
             l = l.strip()
             if isinstance(l, bytes):
                 l = l.decode("utf8")
             if len(l) > 0:
                 ids.append(trie.insert(l))
     with open(os.path.join(dir_, "../bench/words.txt")) as fi:
         txt = fi.read()
         if isinstance(txt, bytes):
             txt = txt.decode("utf8")
     sep = set([ord("\n")])
     matched = []
     for v, start, end in trie.match_longest(txt, sep):
         matched.append(v)
         self.assertEqual(txt[start:end], trie[v])
     self.assertEqual(matched, ids)
Example #16
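A practical use: the trie serves as a fast membership set of page titles while joining category links from a wiki dump to page ids.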
def _get_category_links(cat2id, id2page, **kwargs):
    trie = Trie()
    for page in id2page.values():
        trie.insert(page)
    category_links = []
    for cl_type, source_id, target_title in dt.iter_categorylinks_dump_data(
        **kwargs
    ):
        if (
            # only categories with a page
            target_title not in trie
            # only allowed pages
            or cl_type == dt.WIKI_CL_TYPE_PAGE
            and source_id not in id2page
        ):
            continue
        try:
            target_id = cat2id[f"Category:{target_title}"]
        except KeyError:
            continue
        category_links.append((source_id, target_id))
    return category_links
Example #17
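Callback-based replacement with and without a separator set; compare the two expected outputs to see how separators constrain matching.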
    def test_replace_longest(self):
        trie = Trie()
        ids = {
            w: trie.insert(w)
            for w in
            [u"New York", u"New", u"York", u"York City", u"City", u"City is"]
        }
        replaced = {
            ids[u"New York"]: u"Beijing",
            ids[u"New"]: u"Old",
            ids[u"York"]: u"Yark",
            ids[u"York City"]: u"Yerk Town",
            ids[u"City"]: u"Country",
            ids[u"City is"]: u"Province are"
        }
        res = trie.replace_longest(u"New York  City isA",
                                   lambda x, start, end: replaced[x])
        self.assertEqual(res, u"Beijing  Province areA")

        sep = set([ord(" ")])  # space as separator
        res = trie.replace_longest(u"New York  City isA",
                                   lambda x, start, end: replaced[x], sep)
        self.assertEqual(res, u"Beijing  Country isA")
Example #18
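prefix(text) yields (id, length) for each stored word that is a prefix of the text, shortest first.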
 def test_prefix(self):
     trie = Trie()
     ids = {w: trie.insert(w) for w in [u"ruby", u"rubx", u"rab", u"rub"]}
     prefixes = list(trie.prefix(u"ruby on rails"))
     self.assertEqual(prefixes, [(ids[u"rub"], 3), (ids[u"ruby"], 4)])
Example #19
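Inserting an empty key is rejected: insert(u"") returns -1.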
 def test_insert_zero_len_key(self):
     trie = Trie()
     self.assertEqual(trie.insert(u""), -1)
Example #20
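Zero-copy loading (a staticmethod of WikiPageDetector, Example #21): the saved trie file is memory-mapped and wrapped with Trie.from_buff(..., copy=False), so the mapping backs the trie directly.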
 def load(path: Path):
     wpd = WikiPageDetector()
     wpd._map = pickle_load(path / "wpd_map.gz")
     with (path / "wpd_trie").open("r+b") as bf:
         wpd._trie = Trie.from_buff(mmap(bf.fileno(), 0), copy=False)
     return wpd
Example #21
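The complete WikiPageDetector class: build() (also shown as Example #8), load() (Example #20), dump(), and find_pages(), which re-scans inside each longest match to surface overlapping sub-matches. A usage sketch follows the class.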
class WikiPageDetector:
    def __init__(self, pages: Optional[Iterable[str]] = None):
        self._map = None
        self._trie = None
        if pages is not None:
            self.build(pages)

    @staticmethod
    def load(path: Path):
        wpd = WikiPageDetector()
        wpd._map = pickle_load(path / "wpd_map.gz")
        with (path / "wpd_trie").open("r+b") as bf:
            wpd._trie = Trie.from_buff(mmap(bf.fileno(), 0), copy=False)
        return wpd

    def dump(self, path: Path):
        self._trie.save(str(path / "wpd_trie"))
        pickle_dump(self._map, path / "wpd_map.gz", compress=True)

    def build(self, pages: Iterable[str]):
        key2titles = {}
        for page in pages:
            if not page:
                continue
            key = _clean_title(page).lower()
            if not key:
                key = page
            titles = key2titles.setdefault(key, [])
            titles.append(page)
        mapping = {}
        self._trie = Trie(ignore_case=True)
        for key in key2titles:
            id_ = self._trie.insert(key)
            mapping.setdefault(id_, tuple(key2titles[key]))
        self._map = tuple([mapping.get(i) for i in range(max(mapping) + 1)])

    def find_pages(self, text: str):
        def iter_matches(source):
            ac_seps = set([ord(p) for p in _XP_SEPS.findall(source)])
            for id_, start_idx, end_idx in self._trie.match_longest(
                source, ac_seps
            ):
                yield (start_idx, end_idx, self._map[id_])

        for match in iter_matches(text):
            yield match
            match_text = text[match[0] : match[1]]
            seps = list(_XP_SEPS.finditer(match_text))
            if len(seps) < 1:
                continue
            tokens = []
            last_end = 0
            for sep in seps:
                token = match_text[last_end : sep.start()]
                start = last_end
                last_end = sep.end()
                if len(token) < 2 and not token.isalnum():
                    continue
                tokens.append((start, token))
            tokens.append((last_end, match_text[last_end:]))
            num_tokens = len(tokens)
            for s, e in combinations(range(num_tokens + 1), 2):
                if s == 0 and e == num_tokens:
                    continue
                e -= 1
                submatches = set()
                start = tokens[s][0]
                end = tokens[e][0] + len(tokens[e][1])
                subtext = match_text[start:end]
                start += match[0]
                for sidx, eidx, pages in iter_matches(subtext):
                    coords = (sidx + start, eidx + start)
                    if coords in submatches:
                        continue
                    submatches.add(coords)
                    yield (*coords, pages)
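A hedged usage sketch for the class above (a minimal illustration; the module-level helpers _clean_title and _XP_SEPS and the pickle_dump/pickle_load utilities are assumed from the surrounding project, and the titles are made up):

detector = WikiPageDetector(pages=[u"New York City", u"New York"])
for start, end, titles in detector.find_pages(u"we flew to new york city"):
    # each hit: (start offset, end offset, tuple of original page titles)
    print(start, end, titles)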
Example #22
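With ignore_case=True, "Ruby" and "ruby" share one id, and removing either form removes the single underlying entry.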
 def test_ignore_case(self):
     trie = Trie(ignore_case=True)
     ids = [trie.insert(w) for w in [u'Ruby', u'ruby', u'rb']]
     self.assertEqual(ids, [0, 0, 1])
     self.assertEqual(trie.remove(u"ruby"), 0)
     self.assertEqual(trie.remove(u"Ruby"), -1)
Example #23
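Serialization invariants: save() writes exactly buff_size() bytes, to_buff() produces the same bytes, and from_buff() restores an equivalent trie with or without copying.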
    def test_buff_ac(self):
        trie = Trie(ignore_case=True)
        ids = {w : trie.insert(w) for w in [u"aİİ", u"aai̇", u"aai̇bİ"]}
        trie.save("trie.bin")
        with open("trie.bin", "rb") as fi:
            bs = bytearray(fi.read())
        self.assertEqual(len(bs), trie.buff_size())
        bs2 = bytearray(trie.buff_size())
        trie.to_buff(bs2)
        self.assertEqual(bs2, bs)

        self._check_trie_correct(Trie.from_buff(bs2, copy=True), ids)
        self._check_trie_correct(Trie.from_buff(bs2, copy=False), ids)
Example #24
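A small benchmark helper: build a trie from the first size words of a list.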
def init_trie(words, size):
    trie = Trie()
    for i in range(size):
        trie.insert(words[i])
    return trie
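Taken together, the examples cover the core Trie surface. A minimal end-to-end sketch combining only calls demonstrated above (assumes the cyac package is installed):

from cyac import Trie

trie = Trie(ordered=True)
ids = {w: trie.insert(w) for w in [u"New", u"New York", u"York"]}

# membership and two-way lookup (Examples #3 and #4)
assert u"New York" in trie
assert trie[ids[u"New York"]] == u"New York"

# ids of all words starting with a prefix, lexicographic under ordered=True (Example #5)
assert [trie[i] for i in trie.predict(u"New")] == [u"New", u"New York"]

# stored words that prefix a text, shortest first (Example #18)
assert list(trie.prefix(u"New York City")) == [(ids[u"New"], 3), (ids[u"New York"], 8)]

# longest non-overlapping matches as (id, start, end) (Example #12)
for id_, start, end in trie.match_longest(u"New York and York"):
    print(trie[id_], start, end)

# remove() returns the freed id, which later inserts may reuse (Examples #3 and #10)
assert trie.remove(u"York") == ids[u"York"]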