示例#1
0
    def read_fst(filename, symbol_tables=False):
        """Load an FST from *filename* with ``fst.read_std``.

        When ``symbol_tables`` is falsy (the default), the ``isyms`` and
        ``osyms`` attributes of the loaded FST are reset to ``None`` before
        it is returned; otherwise the FST is returned as loaded.
        """
        loaded = fst.read_std(filename)
        if symbol_tables:
            return loaded

        # Caller did not ask for symbol tables: detach both of them.
        loaded.isyms = None
        loaded.osyms = None
        return loaded
示例#2
0
 def _prepare_resource(self,dir_to_tagger,dir_to_phrase):
     '''
     [description]
         Build the tagger and constraint dictionaries from the FST files
         found in the two given directories. Both results are also stored
         on the instance as self.tagger_dict and self.constraint_dict.
     Arguments:
         dir_to_tagger {string} -- directory containing 'isyms.fst', 'osyms.fst'
             and at least one other '*.fst' file (the first one listed is used
             as the tagger lexicon)
         dir_to_phrase {string} -- directory containing 'isyms.fst', 'osyms.fst'
             and the phrase '*.fst' files; each file is keyed by the first
             character of its file name
     Returns:
         tagger_dict -- [description] tagger_dict['${.concept}']=string list, each string is a path
         constraint_dict -- [description] constraint_dict['${@constraint}']=list of (string path, mapped value)
     Raises:
         ValueError -- a lexicon path does not have exactly one non-zero input label
         IndexError -- dir_to_tagger holds no '*.fst' file besides the symbol tables
     '''
     symbol_files=('isyms.fst','osyms.fst')

     def is_model_file(filename):
         # Any '*.fst' file that is not one of the two symbol tables.
         return filename not in symbol_files and filename.endswith('.fst')

     # deal with entities(tagger)
     tagger_files=os.listdir(dir_to_tagger)
     tagger_isyms=fst.read_symbols(os.path.join(dir_to_tagger,'isyms.fst'))
     tagger_osyms=fst.read_symbols(os.path.join(dir_to_tagger,'osyms.fst'))
     lexicon_candidates=[fname for fname in tagger_files if is_model_file(fname)]
     lexicon=fst.read_std(os.path.join(dir_to_tagger,lexicon_candidates[0]))
     lexicon.isyms=tagger_isyms
     lexicon.osyms=tagger_osyms
     self.tagger_dict=defaultdict(list)
     for lexicon_path in lexicon.paths():
         # Label 0 is skipped on both the input and the output side.
         in_symbols=[lexicon.isyms.find(arc.ilabel) for arc in lexicon_path if arc.ilabel != 0]
         if len(in_symbols)!=1:
             raise ValueError('[Error]:error in resolving tagger name!')
         out_symbols=[lexicon.osyms.find(arc.olabel) for arc in lexicon_path if arc.olabel != 0]
         self.tagger_dict[in_symbols[0]].append(reverse_preproc(out_symbols))

     # deal with constraints
     phrase_files=os.listdir(dir_to_phrase)
     phrase_isyms=fst.read_symbols(os.path.join(dir_to_phrase,'isyms.fst'))
     phrase_osyms=fst.read_symbols(os.path.join(dir_to_phrase,'osyms.fst'))
     level_fsts={}
     for fname in phrase_files:
         if not is_model_file(fname):
             continue
         level_fst=fst.read_std(os.path.join(dir_to_phrase,fname))
         level_fsts[fname[0]]=level_fst
         level_fst.isyms=phrase_isyms
         level_fst.osyms=phrase_osyms
     self.constraint_dict=defaultdict(list)
     # hierarchical phrase FSTs are processed in 0-1-2-... order
     for level in sorted(level_fsts):
         level_fst=level_fsts[level]
         for phrase_path in level_fst.paths():
             name,item_list=self._get_path_and_mapped_value(phrase_path,level_fst)
             self.constraint_dict[name].extend(item_list)
     return (self.tagger_dict,self.constraint_dict)
示例#3
0
 def setUp(self):
     """Load the shortest-path test FST, compiling it first if it is missing.

     'test_shortest.fst' is looked up next to this file. When it does not
     exist, it is built from 'test_shortest.txt' by running ``fstcompile``.
     The loaded FST is stored in ``self.s`` and the reference values in
     ``self.s_result``.

     Raises:
         RuntimeError -- ``fstcompile`` exited with a non-zero status.
         Exception -- any other error raised while generating the FST is
             reported and re-raised unchanged.
     """
     here = os.path.dirname(__file__)
     shortest_txt = os.path.join(here, 'test_shortest.txt')
     shortest_fst = os.path.join(here, 'test_shortest.fst')
     try:
         if not os.path.exists(shortest_fst):
             # call() only returns the exit status; it does not raise on
             # failure, so the status has to be checked explicitly.
             status = call(['fstcompile', shortest_txt, shortest_fst])
             if status != 0:
                 raise RuntimeError(
                     'fstcompile exited with status %d' % status)
     except Exception:
         print('Failed to generate testing fst')
         # Bare raise keeps the original traceback.
         raise
     self.s = fst.read_std(shortest_fst)
     # NOTE(review): presumably the expected (weight, labels) pairs for the
     # shortest-path tests -- confirm against the test methods.
     self.s_result = [(110.40000001341105, [1, 3, 4]),
                      (110.6000000089407, [2, 3, 4]), (1000.2000000029802, [2])]
示例#4
0
 def __init__(self,
              vecfname,
              lmfname,
              onmt_dir,
              model_dir,
              kenlm_loc,
              maxtypes=0):
     """Load the word vectors and language-model FST, then pick a rescorer.

     Arguments:
         vecfname -- passed to ``wordvecutil.word_vectors`` together with
             ``maxtypes``; the result is stored as ``self.vecs``.
         lmfname -- file read with ``fst.read_std`` into ``self.lmfst``.
         onmt_dir -- stored as ``self.onmt_dir``.
         model_dir -- stored as ``self.onmt_model``; when it is not ``''``
             it is converted to an absolute path and ``sent_rescore_onmt``
             becomes ``self.sent_rescore``.
         kenlm_loc -- stored as ``self.kenlm_loc``; when ``model_dir`` is
             ``''`` and this is not, ``sent_rescore_kenlm`` is used.
         maxtypes -- stored as ``self.maxtypes`` (default 0).

     If both ``model_dir`` and ``kenlm_loc`` are ``''``,
     ``sent_rescore_dummy`` is used.
     """
     self.vecs = wordvecutil.word_vectors(vecfname, maxtypes)
     self.lmfst = fst.read_std(lmfname)
     self.maxtypes = maxtypes
     self.onmt_dir = onmt_dir
     self.onmt_model = model_dir
     self.kenlm_loc = kenlm_loc

     # Rescorer priority: OpenNMT model, then KenLM, then the dummy one.
     if self.onmt_model != '':
         self.sent_rescore = self.sent_rescore_onmt
         self.onmt_model = os.path.abspath(self.onmt_model)
         return
     if self.kenlm_loc != '':
         self.sent_rescore = self.sent_rescore_kenlm
         return
     self.sent_rescore = self.sent_rescore_dummy
示例#5
0
        path_ostring = [uni(f.osyms.find(arc.olabel)) for arc in path if f.osyms.find(arc.olabel) != fst.EPSILON]

        strings = []

        arcs = zip(path_istring, path_ostring)
        for (ia, oa) in arcs:
            if oa == "<s>" or oa == "</s>":
                pass
            elif not oa == "<unk>":
                strings.append(oa)
            else:
                # there is an <unk> in the output
                # hack: look for one of the 7 possible unks in the input and insert it
                t = list(set(path_istring).intersection(unk))
                strings.append(t[0])
        return ' '.join(strings)


if __name__ == '__main__':
    # Script entry point: read the FST named by the first command-line
    # argument and the whitespace-separated tokens in 'data/unk', remove
    # epsilons from the FST, and print the result of getpaths().
    #symi = fst.read_symbols(sys.argv[2].strip())
    #symo = fst.read_symbols(sys.argv[3].strip())
    try:
        f = fst.read_std(sys.argv[1].strip())
        # Context manager so the file handle is closed right after reading.
        with open('data/unk') as unk_file:
            unk = set(unk_file.read().split())
    except Exception:
        # Exception rather than a bare except, so KeyboardInterrupt and
        # SystemExit are not swallowed and reported as a usage error.
        print('useage: python printFinal.py [in.fst]\n')  # takes 1-best
        # sys.exit() instead of the site-provided exit(); status unchanged.
        sys.exit()
    f.remove_epsilon()
    paths = getpaths(f, unk)
    print(paths)