Example #1
def test_prefixes():
    trie = marisa_trie.Trie(['foo', 'f', 'foobar', 'bar'])
    assert trie.prefixes('foobar') == ['f', 'foo', 'foobar']
    assert trie.prefixes('foo') == ['f', 'foo']
    assert trie.prefixes('bar') == ['bar']
    assert trie.prefixes('b') == []

    assert list(trie.iter_prefixes('foobar')) == ['f', 'foo', 'foobar']
Example #2
def test_iterkeys():
    keys = get_random_words(1000)
    trie = marisa_trie.Trie(keys)
    assert trie.keys() == list(trie.iterkeys())

    for key in keys:
        prefix = key[:5]
        assert trie.keys(prefix) == list(trie.iterkeys(prefix))
Example #3
    def build_from_wikipedia(
        dump_db: DumpDB,
        tokenizer,
        normalizer,
        out_file,
        max_candidate_size,
        min_mention_count,
        max_mention_length,
        pool_size,
        chunk_size,
    ):
        logger.info("Extracting all entity names...")

        title_dict = defaultdict(Counter)
        with tqdm(total=dump_db.page_size(), mininterval=0.5) as pbar:
            initargs = (dump_db, tokenizer, normalizer, max_mention_length)
            with closing(
                    Pool(pool_size,
                         initializer=EntityDB._initialize_worker,
                         initargs=initargs)) as pool:
                for ret in pool.imap_unordered(
                        EntityDB._extract_name_entity_pairs,
                        dump_db.titles(),
                        chunksize=chunk_size):
                    for (name, title) in ret:
                        title_dict[title][name] += 1
                    pbar.update()

        logger.info("Building DB...")

        mentions = frozenset([
            mention for mention_counter in title_dict.values()
            for mention in mention_counter.keys()
        ])
        title_trie = frozenset(title_dict.keys())
        mention_trie = marisa_trie.Trie(mentions)

        def item_generator():
            for (title, mention_counter) in title_dict.items():
                for (mention, mention_count
                     ) in mention_counter.most_common()[:max_candidate_size]:
                    if mention_count < min_mention_count:
                        continue
                    yield (title, (mention_trie[mention], mention_count))

        data_trie = marisa_trie.RecordTrie("<II", item_generator())

        joblib.dump(
            dict(
                title_trie=title_trie,
                mention_trie=mention_trie,
                data_trie=data_trie,
                tokenizer=tokenizer,
                normalizer=normalizer,
                max_mention_length=max_mention_length,
            ),
            out_file,
        )
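To make the "<II" record layout above concrete, here is a minimal sketch (the joblib path and the page title are assumptions) of reading the dumped structures back and listing a page's mentions: data_trie stores (mention_id, count) records per title, and mention_trie.restore_key maps the id back to the mention string.

import joblib

# Hypothetical path; build_from_wikipedia writes to whatever `out_file` was given.
db = joblib.load("entity_db.joblib")
mention_trie, data_trie = db["mention_trie"], db["data_trie"]

# Hypothetical title; each record is (mention_id, mention_count), per the "<II" format.
for mention_id, count in data_trie.get("Python (programming language)", []):
    print(mention_trie.restore_key(mention_id), count)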
Example #4
def build_trie(word_list_path):
    """
    Returns a filled Marisa Trie object from a word list file.
    @param word_list_path Path to a word list file, one word per line
    @return A Trie object containing all words
    """
    with open(word_list_path) as words:
        word_list = [word.strip() for word in words]  # Removes trailing \n
    return marisa_trie.Trie(word_list)
Example #5
def read_pw_nh_graph(fname, q=-1, _N=-1):
    """Reads the typo trie file and the neighborhood map created by
    `create_pw_nh_graph` function.

    Returns: (M, A, typo_trie)
    M is the rpw -> Neighborhood information
      - M[i][0] is the rpw_id, of i-th most probable password
      - M[i][1:] is the neighborhood, truncted to MAX_NH_SIZE (500)
    A is the weight of the balls of all the typos we collected
      - A[i] = Total sum of frequencies of all the rpw in the ball
               of i-th password in trie. (see typo_trie)
    typo_trie is a maping from typo_id to typos, so, to retrieve
    the i-th typo in A[i], use typo_trie.restore_key(i).
    typo_trie is not required for computing the total success of
    an attacker.
    q: Prune the typo list based on q value, so that don't worry
       about typos that are very low in the tail, for example, a
       typo with total ball weight < 10*q-th most probable typo, is
       most likely useless. Where assume the average ball size is 10.
    """
    # N = 1000
    global N
    if _N > 0:
        N = _N
    typodir = '{}/typodir'.format(pwd)
    pwm = Passwords(fname, max_pass_len=25, min_pass_len=5)
    N = min(N, len(pwm))
    tpw_trie_fname = '{}/{}__{}_{}_typo.trie'\
                     .format(typodir, pwm.fbasename, 0, N)
    rpw_nh_graph = '{}/{}__{}_{}_rpw_nh_graph.npz'\
                   .format(typodir, pwm.fbasename, 0, N)

    typo_trie = marisa_trie.Trie()
    typo_trie.load(tpw_trie_fname)
    M = np.load(rpw_nh_graph)['M']
    ## Extra fix ##
    M[M == 0] = -1
    d = len(typo_trie)
    A = np.zeros(len(typo_trie))
    for i in xrange(M.shape[0]):
        if M[i, 0] <= 0:
            continue
        p_rpw = pwm.pw2freq(typo_trie.restore_key(M[i, 0]))
        A[M[i, M[i] >= 0]] += p_rpw

    print("Done creating the 'A' array. Size={}".format(A.shape))
    # # Prune the typos, Not all typos are useful, any typo with
    # # frequency less than i_th most probable password will never be
    # # queried.
    # b = (M>0).sum() / float(A.shape[0])   # average ball size
    # print("Average ball size: {}".format(b))
    # bq_th_pw_f = pwm.id2freq(M[int(b*q)][0])
    # useful_typos = (A>=bq_th_pw_f)
    # print("Useful typos (> {}): {}/{}".format(
    #     bq_th_pw_f, useful_typos.sum(), A.shape[0]
    # ))
    return M, A, typo_trie, pwm
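A quick sketch (the input file name is hypothetical) of consuming the structures returned above: A is indexed by typo key id, so typo_trie.restore_key turns the index of a heavy ball back into the typo string.

import numpy as np

M, A, typo_trie, pwm = read_pw_nh_graph('passwords.txt.gz')  # hypothetical input
top = np.argsort(A)[::-1][:10]  # ids of the 10 heaviest typo balls
for typo_id in top:
    print(typo_trie.restore_key(int(typo_id)), A[typo_id])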
Example #6
def test_build():
    keys = get_random_words(1000)
    trie = marisa_trie.Trie(keys)

    for key in keys:
        assert key in trie

    non_key = '2135'
    assert non_key not in trie
Example #7
def init_mysql_keyword(cfg):
    db_pool = make_db_pool(cfg['mysql'], 2)
    conn = db_pool.connection()
    cursor = conn.cursor()
    cursor.execute('''select trim(keyword) from task_priority_config''')
    key_word_list = [row[0].decode('utf8').lower() for row in cursor.fetchall()]
    gv.keyword_trie = marisa_trie.Trie(key_word_list)
    cursor.close()
    conn.close()
Example #8
def test_dumps_loads(keys):
    trie = marisa_trie.Trie(keys)
    data = pickle.dumps(trie)

    trie2 = pickle.loads(data)

    for key in keys:
        assert key in trie2
        assert trie2.key_id(key) == trie.key_id(key)
Example #9
 def _freeze_vocabulary(self, X=None):
     if not self.fixed_vocabulary_:
         frozen = marisa_trie.Trie(six.iterkeys(self.vocabulary_))
         if X is not None:
             X = self._reorder_features(X, self.vocabulary_, frozen)
         self.vocabulary_ = frozen
         self.fixed_vocabulary_ = True
         del self.stop_words_
     return X
Example #10
def load_french_dictionary():
    with open('/Users/macbook/Desktop/PYTHONPROJET/francais.txt',
              encoding='iso-8859-1') as f:
        content = f.read()
    letters = set()
    for word in content.split('\n'):
        for letter in word:
            letters.add(letter)
    return letters, marisa_trie.Trie(content.split('\n'))
Example #11
 def load(self, load_dir):
     self._max_value = load_json_file(
         filename=os.path.join(load_dir, "max_value.json"))
     self._stoi = marisa_trie.Trie().mmap(
         os.path.join(load_dir, f"vocabulary_trie.marisa"))
     self._itos = lambda x: self._stoi.restore_key(x)
     self._record_trie = marisa_trie.RecordTrie(
         self._get_fmt_string(self._max_value)).mmap(
             os.path.join(load_dir, f"record_trie.marisa"))
Example #12
        def load__nltk_abc_data__func():
            # `pip install marisa_trie`
            import re
            import nltk
            import marisa_trie
            from .regexp_utils import regexp
            words = nltk.corpus.abc.words()

            two_length_words_data = marisa_trie.Trie(
                sorted(
                    list(
                        set([
                            w1 for w1 in words if (len(w1) == 2)
                            and re.compile("[a-z]", re.IGNORECASE).match(w1[0])
                        ]))))

            regular_words_data = marisa_trie.Trie(
                [unicode(w1) for w1 in words if regexp.word.match(w1)])

            return [two_length_words_data, regular_words_data]
Example #13
 def write(
     self,
     filename,
 ):
     import marisa_trie
     with indir(filename, create=True):
         yield from self.write_groups()
         words = self._words
         trie = marisa_trie.Trie(words)
         trie.save(self.WORDS_FILE_NAME)
Example #14
def test_pickling():
    words = get_random_words(1000)
    trie = marisa_trie.Trie(words)

    data = pickle.dumps(trie)
    trie2 = pickle.loads(data)

    for word in words:
        assert word in trie2
        assert trie2.key_id(word) == trie.key_id(word)
Example #15
def create_meta_filter(langs, filter_bys):
    # If any class does not use names as a filter, then all names should be considered
    for filter_by in filter_bys:
        if filter_by == 'institute':
            return {}
    # Otherwise, only the names in each lang should be considered
    names = []
    for lang in langs:
        names += LAST_NAMES[lang]
    return {'authors': marisa_trie.Trie(names)}
Example #16
def get_all_aliases(alias2qidcands, logger):
    # Load alias2qids
    global alias2qids
    alias2qids = {}
    logger.info("Loading candidate mapping...")
    for al in tqdm(alias2qidcands):
        alias2qids[al] = [c[0] for c in alias2qidcands[al]]
    logger.info(f"Loaded candidate mapping with {len(alias2qids)} aliases.")
    all_aliases = marisa_trie.Trie(alias2qids.keys())
    return all_aliases
Example #17
 def __init__(self):
     self.max_candidates = 3
     # Used if we need to do any string searching for aliases. This keeps track of the largest n-gram needed.
     self.max_alias_len = 1
     self._qid2title = {"Q1": "a b c d e", "Q2": "f", "Q3": "dd a b", "Q4": "x y z"}
     self._qid2eid = {"Q1" : 1, "Q2": 2, "Q3": 3, "Q4": 4}
     self._alias2qids = {"alias1": [["Q1", 10], ["Q2", 3]], "alias2": [["Q3", 100]], "alias3": [["Q1", 15], ["Q4", 5]]}
     self._alias_trie = marisa_trie.Trie(self._alias2qids.keys())
     self.num_entities = len(self._qid2eid)
     self.num_entities_with_pad_and_nocand = self.num_entities + 2
Example #18
 def __init__(self):
     """
     Initialize necessary resources.
     """
     self.dictionary_file = open(
         os.path.join(os.path.dirname(__file__), 'data/ml_rootwords.txt'))
     self.dictionary = self.dictionary_file.readlines()
     self.dictionary_file.close()
     try:
         self.dictionary = marisa_trie.Trie(
             [x.strip().decode('utf-8') for x in self.dictionary])
     except:
         self.dictionary = marisa_trie.Trie(
             [x.strip() for x in self.dictionary])
     self.stemmer = Stemmer()
     self.inflector = inflector.Inflector(lang='ml')
     self.soundex = Soundex()
     self.syllabalizer = Syllabifier()
     self.ngrammer = Ngram()
Example #19
 def makeTrie(self):
     L = []
     with self.con:
         cur = self.con.cursor()
         cur.execute("SELECT city FROM CITIES")
         for i in range(cur.rowcount):
             row = cur.fetchone()
             L.append(row[0].lower())
     # Print list
     # print '[%s]' % ', '.join(map(str, L))
     self.trie = marisa_trie.Trie(L)
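The city trie built above can also drive prefix completion; a small sketch (hypothetical helper), since Trie.keys(prefix) returns every stored key that starts with the prefix.

def complete_city(trie, prefix, limit=5):
    # All cities starting with `prefix`, e.g. complete_city(self.trie, "san")
    return trie.keys(prefix.lower())[:limit]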
Example #20
    def get_trie(self):
        current_path = os.path.abspath(__file__)
        father_path = os.path.abspath(os.path.dirname(current_path))
        file_path = father_path + "/files/keyword.txt"
        words = []

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f.readlines():
                words.append(line.strip('\n'))
        trie = marisa_trie.Trie(words)
        return trie
Example #21
    def add_all(self, signs):
        new_signs = [s for s in signs if s not in self.sign_trie]
        if len(new_signs) > 0:
            old_signs = self.sign_trie.keys()
            n_prev_ids = len(self.sign_trie)

            # the key ids are preserved
            signs = old_signs + new_signs

            # build the new trie
            self.sign_trie = marisa_trie.Trie(signs)
Example #22
    def setup_entity_values(self, entities):
        self.logger.info("Setting up value entities'%s'", entities)
        for entity_name, entity_values in entities.items():
            # This can be done more concisely, expanded for clarity
            updated_words = []
            for word in entity_values:
                lower = word.lower()
                temp_word = lower.strip(self.punctuation)
                updated_words.append(temp_word)

            self.entity_tries[entity_name] = marisa_trie.Trie(updated_words)
Example #23
def marisa_test():
    inn = []
    for i in range(1000000):
        # if i/1000==round(i/1000,0):
        #   print i
        # inn.append((u"%s"%j.tools.hash.md5_string(str(i)),(1,5)))
        inn.append(u"%s" % j.tools.hash.md5_string(str(i)))
    #trie = marisa_trie.RecordTrie("<HH",inn)
    trie = marisa_trie.Trie(inn)
    trie.save("data")
    del inn
Example #24
    def load_dt(self, input_dir):
        """ Loads a pre-built distributional thesaurus structure. """

        tic = time()
        self.keys = marisa_trie.Trie()
        self.keys.load(self.keys_fpath)
        print("Loaded %d keys: %s" %
              (len(list(self.keys.items())), self.keys_fpath))
        self.sims = joblib.load(self.sims_fpath)
        print("Loaded %d scores: %s" % (self.sims.size, self.sims_fpath))
        print("Loading DT took", time() - tic, "sec.")
Example #25
    def __init__(self, dictionary_path=DICT_PATH):
        self._dictionary = []
        self._dictionary_path = dictionary_path

        for file in os.listdir(self._dictionary_path):
            if file.endswith('.txt'):
                with open(self._dictionary_path + file) as file:
                    for line in file:
                        self._dictionary.append(line.rstrip())

        self._trie = marisa_trie.Trie(self._dictionary)
Example #26
 def __init__(self):
     super(TrainedGrammar, self).__init__()
     self.term_files = {}
     self.g_struc = GrammarStructure()
     for k, f in self.g_struc.getTermFiles().items():
         sys.path.append(hny_config.GRAMMAR_DIR)
         X = __import__('%s' % f)
         self.term_files[k] = {
             'trie' : marisa_trie.Trie().load(hny_config.GRAMMAR_DIR+f+'.tri'),
             'arr' : eval("X.%s"%k),
             'trie_fl' : hny_config.GRAMMAR_DIR+f+'.tri'
             }
Example #27
def get_trie():
    '''Create a Trie, add the sensitive words to it, and cache it.'''

    words = []

    with open('words.txt', 'r', encoding='utf-8') as f:
        for line in f.readlines():
            words.append(line.strip('\n'))

    trie = marisa_trie.Trie(words)

    return trie
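A minimal sketch (hypothetical helper) of using the cached keyword trie to scan a text: at each position, Trie.prefixes returns every stored keyword that begins there.

def find_keywords(trie, text):
    hits = []
    for i in range(len(text)):
        for word in trie.prefixes(text[i:]):
            hits.append((i, word))  # (start offset, matched keyword)
    return hits

# e.g. find_keywords(get_trie(), document_text)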
Example #28
 def save_to_trie(self, filename, outfile):
     """
     Saves the dict into a trie.
     """
     keys = []
     with open(filename) as f:
         for line in f.readlines():
             line = line.strip("\n\t ")
             line = line.split("\t")
             keys.append(line[0].encode("UTF-8") + "@" +
                         str(self.types[line[1]]))
     t = marisa_trie.Trie(keys)
     t.save(outfile)
Example #29
 def create_data_structure(self, pass_file):
     passwords = {}
     for w, c in helper.open_get_line(pass_file):
         passwords[unicode(w)] = c
     self.T = marisa_trie.Trie(passwords.keys())
     self.freq_list = [0 for _ in passwords]
     for k in self.T.iterkeys():
         self.freq_list[self.T.key_id(k)] = passwords[k]
     with open(self.e_pass_file_trie, 'wb') as f:
         self.T.write(f)
     with gzip.open(self.e_pass_file_freq, 'w') as f:
         for n in self.freq_list:
             f.write('%d\n' % n)
Example #30
 def __init__(self, sentences, stop_token):
     """Construct Trie.
     
     Args:
         sentences (list[unicode]): a list of sentences, where each sentence is a string. Tokens in the sentence
             should be space delimited.
         stop_token (unicode)
     """
     sentences_formatted = [
         u'{} {}'.format(s.strip().lower(), stop_token) for s in sentences
     ]
     self._trie = marisa_trie.Trie(sentences_formatted)
     self._stop_token = stop_token