Example #1
    def prefixes(self, text):
        # Return every key stored in the trie that is a prefix of `text`.
        agent = marisa.Agent()
        agent.set_query(text)
        out = []
        while self._trie.common_prefix_search(agent):
            out.append(agent.key().str())
        return out
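
A minimal sketch of how the trie behind self._trie might be built and this lookup exercised (the word list and wiring below are illustrative assumptions, not part of the original class; the SWIG binding is assumed to accept Python strings directly, as it does for set_query above):

import marisa

# Build a trie over a few sample keys (stand-in for whatever self._trie holds).
keyset = marisa.Keyset()
for word in ('a', 'ab', 'abc', 'b'):
    keyset.push_back(word)
trie = marisa.Trie()
trie.build(keyset)

# Standalone equivalent of prefixes('abcd'): every stored key that prefixes the query.
agent = marisa.Agent()
agent.set_query('abcd')
while trie.common_prefix_search(agent):
    print(agent.key().str())   # 'a', then 'ab', then 'abc'
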
Example #2
    def __getitem__(self, item):
        agent = marisa.Agent()
        agent.set_query(item)

        out = []
        if not self._trie.lookup(agent):
            return out

        key = agent.key()
        indexes = self._key_id_to_index[key.id()]

        for index in indexes:
            # This is the index into dict_mm.
            # Read that line.
            end_index = self._dict_mm.find(b'\n', index)
            line = self._dict_mm[index:end_index]
            line = str(line, 'utf-8')
            out.append(Record.from_line(line))

        return out
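
The excerpt relies on self._key_id_to_index, which is built elsewhere; a plausible construction (the helper name, tab-separated line layout, and file structure are assumptions) is a single scan over the mmapped dictionary that records each line's byte offset under the key id of its head word:

import marisa

def build_key_id_to_index(trie, dict_mm):
    # Hypothetical helper: map marisa key ids to the byte offsets of the lines
    # in dict_mm that start with that key, so __getitem__ can seek to them.
    key_id_to_index = {}
    agent = marisa.Agent()
    offset = 0
    while offset < len(dict_mm):
        end = dict_mm.find(b'\n', offset)
        if end == -1:
            break
        line = str(dict_mm[offset:end], 'utf-8')
        key = line.split('\t', 1)[0]          # assumed: the key leads each line
        agent.set_query(key)
        if trie.lookup(agent):
            key_id_to_index.setdefault(agent.key().id(), []).append(offset)
        offset = end + 1
    return key_id_to_index
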
Example #3
    def __read_ngrams(self):
        self.__infile.seek(0)
        for n in range(1, NGRAM + 1):
            # Skip lines until the "\%d-grams:" section header for this order.
            while True:
                line = self.__infile.readline()
                if line == "":
                    break
                if line.startswith("\\%s-grams:" % n):
                    break

            # Read this section's entries until a blank line or EOF ends it.
            while True:
                line = self.__infile.readline()
                if line == "":
                    break
                line = line.strip('\n')
                if line == "":
                    break
                match = self.__ngram_line_regex.match(line)
                if not match:
                    continue
                strv = match.groups()
                ngram = strv[1].split(" ")
                ids = []
                for word in ngram:
                    agent = marisa.Agent()
                    agent.set_query(word)
                    if not self.__vocab_trie.lookup(agent):
                        continue
                    ids.append(agent.key_id())
                cost = float(strv[0])
                if cost != -99 and cost < self.__min_cost:
                    self.__min_cost = cost
                backoff = 0.0
                if strv[2]:
                    backoff = float(strv[2])
                self.__ngram_entries[n - 1][tuple(ids)] = (cost, backoff)
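
The __ngram_line_regex used above is not part of the excerpt. ARPA-format n-gram lines carry a log10 probability, the n-gram itself, and an optional back-off weight, so a pattern along the following lines (an assumption, not the original regex) yields the three groups the loop consumes:

import re

# Illustrative pattern for ARPA n-gram lines: "<logprob>\t<w1 w2 ...>[\t<backoff>]".
ngram_line_regex = re.compile(r'^(-?\d+(?:\.\d+)?)\t([^\t]+)(?:\t(-?\d+(?:\.\d+)?))?$')

print(ngram_line_regex.match('-0.2553\tHello World\t-0.3010').groups())
# ('-0.2553', 'Hello World', '-0.3010'); the back-off group is None when absent.
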
Example #4
def main():
    def output_ngram(f_out, n, context_dict, context_id_dicts):
        f_out.write('#%d\t%d\n' % (n, len(context_dict)))

        for context_key in sorted(context_dict.keys()):
            context_id_dicts[n][context_key] = len(context_id_dicts[n])
            first = context_key[0]
            if len(context_key) > 1:
                next_context_id = context_id_dicts[n - 1][tuple(
                    context_key[1:])]
            else:
                next_context_id = 0

            f_out.write('%d\t' % (context_id_dicts[n][context_key]) +
                        '%d\t%d\t' % (context_key[0], next_context_id) +
                        '%f\t%f\t%s\n' % context_dict[context_key])

        return

    try:
        opts, args = getopt.getopt(sys.argv[1:], 'hi:o:t:')
    except getopt.GetoptError:
        Usage()
        sys.exit(2)

    input_filename = ''
    output_filename_prefix = ''

    for k, v in opts:
        if k == '-h':
            Usage()
            sys.exit()
        elif k == '-i':
            input_filename = v
        elif k == '-o':
            output_filename_prefix = v

    if input_filename == '' or output_filename_prefix == '':
        Usage()

    f_in = codecs.open(input_filename, 'r', 'utf-8')
    f_out_text = codecs.open(output_filename_prefix + EXT_TEXT, 'w', 'utf-8')

    cur_n = 0
    pair_id = 1
    pair_dict = {}
    context_dict = {}
    context_id_dicts = [{}]
    keyset_key = marisa.Keyset()
    keyset_pair = marisa.Keyset()
    trie_key = marisa.Trie()
    trie_pair = marisa.Trie()
    agent = marisa.Agent()
    min_score = 99.0
    max_backoff = -99.0

    for line in f_in:
        line = line.rstrip('\n')
        if not line:
            continue

        if line[0] == '\\':
            m = re.search(r'^\\(\d+)-grams:', line)
            if (cur_n and (m or re.search(r'^\\end\\', line))):
                if cur_n == 1:
                    trie_key.build(keyset_key)
                    trie_key.save(output_filename_prefix + EXT_KEY)
                    trie_pair.build(keyset_pair)
                    trie_pair.save(output_filename_prefix + EXT_PAIR)
                    for k, v in pair_dict.iteritems():
                        context_dict[(to_id(trie_pair, agent, k), )] = v

                output_ngram(f_out_text, cur_n, context_dict, context_id_dicts)

            if m:
                cur_n = int(m.group(1))
                context_dict = {}
                context_id_dicts.append({})
                print 'Processing %d-gram...' % cur_n

            continue

        if cur_n == 0:
            continue

        fields = line.split('\t')
        if len(fields) < 2:
            continue

        if len(fields) == 2:
            fields.append('-99')

        score = float(fields[0])
        backoff = float(fields[2])
        if score < min_score and score > -99:
            min_score = score
        if backoff > max_backoff:
            max_backoff = backoff

        if cur_n == 1:
            k = fields[1].encode('utf-8')
            keyset_pair.push_back(k)
            pair = k.split(PAIR_SEPARATOR, 1)
            keyset_key.push_back(pair[0])
            pair_dict[k] = (float(fields[0]), float(fields[2]), fields[1])
        else:
            ngram = [
                to_id(trie_pair, agent, x.encode('utf-8'))
                for x in reversed(fields[1].split(' '))
            ]
            context_dict[tuple(ngram)] = (float(fields[0]), float(fields[2]),
                                          fields[1])

    f_in.close()
    f_out_text.close()
    print 'Done.'
    print 'min_score = %f, max_backoff = %f' % (min_score, max_backoff)
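
The to_id helper and Usage are defined elsewhere in the script. A minimal sketch of what to_id presumably does (query trie_pair and return the marisa key id) is below; whether the real script offsets the id or handles misses differently is unknown, so treat this as an assumption:

def to_id(trie, agent, key):
    # `key` is a UTF-8 byte string, matching what was pushed into keyset_pair.
    agent.set_query(key)
    if trie.lookup(agent):
        return agent.key_id()
    return -1
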
Example #5
print("\nSum of 1-gram:", scount)
if factor > 1:
    scount = int(scount/factor)
    print("Normalized sum of 1-gram:", scount)
    
if scount > 2**31:
    print("Trouble: sum of 1-grams doesn't fit INT32. Please normalize the data manually or automatically by increasing threshold for counts")
    sys.exit(-1)

print()

# save ngrams
print('Saving in Marisa format')
trie = marisa.Trie()
trie.build(keyset)
trie.save(os.path.join(args.output, "ngrams.trie"))

print("Keys: ", trie.num_keys(), "\n")

arr = np.zeros(trie.num_keys()+1, dtype=np.int32)
arr[0] = scount
agent = marisa.Agent()
for k in data:
    agent.set_query(k)
    trie.lookup(agent)
    arr[agent.key_id() + 1] = int(data[k] / factor)

with open(os.path.join(args.output, 'ngrams.counts'), 'wb') as binwrite:
    arr.tofile(binwrite)
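
Reader-side code is not shown in the excerpt; a sketch of how the two files written above could be consumed (the query word and output directory below are illustrative):

import os
import marisa
import numpy as np

out_dir = 'output'                      # wherever args.output pointed above

trie = marisa.Trie()
trie.load(os.path.join(out_dir, 'ngrams.trie'))
counts = np.fromfile(os.path.join(out_dir, 'ngrams.counts'), dtype=np.int32)

total = int(counts[0])                  # normalized sum of 1-grams, stored at slot 0
agent = marisa.Agent()
agent.set_query('the')                  # illustrative query key
if trie.lookup(agent):
    print(counts[agent.key_id() + 1], 'occurrences out of', total)
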