def read_symbols(fname):
    """Load a symbol table from a whitespace-separated text file.

    Every non-blank line must hold exactly two fields: a symbol followed by
    its integer id.

    :param fname: Path of the symbol file to read.
    :return: The table created with ``fst.SymbolTable('eps')`` after every
        listed symbol has been assigned its id.
    :raises ValueError: If a non-blank line does not have exactly two fields
        or its second field is not an integer.
    """
    syms = fst.SymbolTable('eps')
    with open(fname) as sf:
        for line in sf:
            fields = line.split()
            if not fields:
                # A blank line (typically a trailing newline at the end of
                # the file) carries no entry; unpacking it would raise.
                continue
            s, i = fields
            syms[s] = int(i)
    return syms
def create_dt_fst():
    """Build a transducer from the arc list stored in ``DUTCH_FST_FILE``.

    The file is read as UTF-8. Byte-order marks are stripped from every line,
    and only lines with exactly four whitespace-separated fields
    (source state, destination state, input label, output label) produce an
    arc; all other lines are ignored. The output label ``'ks'`` is rewritten
    to ``'k s'`` before the arc is added.

    :return: The populated ``fst.Transducer`` with states 1 and 2 marked final.
    """
    dt_fst = fst.Transducer(isyms=fst.SymbolTable('eps'), osyms=fst.SymbolTable('eps'))
    # The context manager guarantees the handle is closed even if a malformed
    # line makes int() or add_arc() raise.
    with codecs.open(DUTCH_FST_FILE, 'r', encoding='utf-8') as fst_file:
        for l in fst_file:
            entry = l.replace(u'\ufeff', '').split()
            if len(entry) != 4:
                continue
            src, dst, ilabel, olabel = entry
            if olabel == 'ks':
                olabel = 'k s'
            dt_fst.add_arc(int(src), int(dst), ilabel, olabel)
    dt_fst[1].final = True
    dt_fst[2].final = True
    return dt_fst
def get_transducer_symbol_table(self):
    """Return a fresh symbol table filled from the segments and bracket symbols.

    The string fed to ``fst.linear_chain`` is the concatenation of everything
    returned by ``self.get_segments_symbols()`` followed by the three left
    brackets and then the three right brackets (application, center,
    identity). The chain itself is discarded; only the table passed as
    ``syms`` is returned.
    """
    table = fst.SymbolTable()
    segments = "".join(self.get_segments_symbols())
    left_brackets = LEFT_APPLICATION_BRACKET + LEFT_CENTER_BRACKET + LEFT_IDENTITY_BRACKET
    right_brackets = RIGHT_APPLICATION_BRACKET + RIGHT_CENTER_BRACKET + RIGHT_IDENTITY_BRACKET
    # Building the chain is done purely for its effect on `table`.
    fst.linear_chain(segments + left_brackets + right_brackets, syms=table)
    return table
def make_edit(sigma):
    """Make an edit distance transducer and return a distance function.

    A single-state machine is built with one arc for every ordered pair of
    symbols in ``sigma`` (plus ``'<eps>'``), except the ``<eps>:<eps>`` pair.
    Arcs whose two symbols are equal get weight 0, all others weight 1.

    NOTE: ``'<eps>'`` is added to ``sigma`` in place, so the caller's
    collection is modified.

    :param sigma: Mutable collection of symbols supporting ``add`` (e.g. a set).
    :return: ``distance(a, b)`` returning ``(dist, align)`` where ``dist`` is
        the integer taken from the reverse shortest distance of state 0 and
        ``align`` is a list of ``(input symbol, output symbol)`` pairs read
        from the shortest path, with ``'<eps>'`` shown as ``'-'``.
    """
    # Create transducer
    syms = fst.SymbolTable()
    sigma.add('<eps>')
    edit = fst.StdVectorFst()
    edit.start = edit.add_state()
    edit[0].final = True
    for x in sigma:
        for y in sigma:
            if x == y == '<eps>':
                continue
            edit.add_arc(0, 0, syms[x], syms[y], (0 if x == y else 1))

    # Define edit distance
    def distance(a, b):
        # Compose a o edit transducer o b
        comp = make_input(a, syms) >> edit >> make_input(b, syms)
        # Compute distance
        distances = comp.shortest_distance(reverse=True)
        dist = int(distances[0])
        # Find best alignment
        alignment = comp.shortest_path()
        # Re-order states
        alignment.top_sort()
        # Replace "<eps>" -> "-"
        dash = syms['-']
        eps = syms['<eps>']
        alignment.relabel(ipairs=[(eps, dash)], opairs=[(eps, dash)])
        # Read the first arc of each state in order. The walk stops at the
        # first state without an outgoing arc (presumably the final state).
        # An explicit default is used because a StopIteration escaping from
        # next() inside a generator expression is turned into RuntimeError on
        # Python 3.7+ (PEP 479) instead of silently ending the iteration.
        align = []
        for state in alignment:
            arc = next(iter(state), None)
            if arc is None:
                break
            align.append((syms.find(arc.ilabel), syms.find(arc.olabel)))
        return dist, align

    return distance
def create_root_fst(label, int_coverage_cells):
    """
    Build the root FST: one arc from state 0 to final state 1 carrying a
    nonterminal output label.

    Each output symbol is registered under its own integer value, first for
    ``label`` and afterwards for every key of ``int_coverage_cells``.

    :param label: Nonterminal transition label
    :param int_coverage_cells: Dictionary of integer coverages and associated FSTs
    :return: Root FST
    """
    root_fst = fst.Transducer(isyms=fst.SymbolTable(), osyms=fst.SymbolTable())
    root_fst.osyms[label] = int(label)

    # The input side of the arc is whatever symbol the input table stores
    # under id 0 (the epsilon entry).
    epsilon_input = root_fst.isyms.find(0)
    root_fst.add_arc(0, 1, epsilon_input, label)
    root_fst[1].final = True

    # Register every coverage key in the output symbol table; the cell FSTs
    # themselves are not needed here.
    for int_coverage, _cell_fst in int_coverage_cells.items():
        root_fst.osyms[int_coverage] = int(int_coverage)

    return root_fst
def test_replace():
    """Compare ``Acceptor.replace`` output against a hand-built acceptor.

    Four acceptors sharing one symbol table are built; the ``$name``,
    ``$firstname`` and ``$lastname`` labels are replaced by the corresponding
    acceptors with ``epsilon=True``. After epsilon removal the result must
    equal the expected machine, which is written out with explicit epsilon
    arcs and has its epsilons removed as well.
    """
    syms = fst.SymbolTable()

    def acceptor_from(arcs, final_state):
        # Build an acceptor over the shared table from (src, dst, label)
        # triples and mark a single final state.
        machine = fst.Acceptor(syms)
        for src, dst, arc_label in arcs:
            machine.add_arc(src, dst, arc_label)
        machine[final_state].final = True
        return machine

    a1 = acceptor_from(
        [
            (0, 1, 'dial'),
            (1, 2, 'google'),
            (1, 2, '$name'),
            (2, 3, 'please'),
        ],
        3,
    )
    a2 = acceptor_from(
        [
            (0, 1, 'michael'),
            (1, 2, 'riley'),
            (0, 1, '$firstname'),
            (1, 2, '$lastname'),
        ],
        2,
    )
    a3 = acceptor_from([(0, 1, 'johan')], 1)
    a4 = acceptor_from([(0, 1, 'schalkwyk')], 1)

    substitutions = {
        '$name': a2,
        '$firstname': a3,
        '$lastname': a4,
    }
    result = a1.replace(substitutions, epsilon=True)
    result.remove_epsilon()

    expected = acceptor_from(
        [
            (0, 1, 'dial'),
            (1, 2, 'google'),
            (1, 3, fst.EPSILON),
            (3, 5, 'michael'),
            (3, 6, fst.EPSILON),
            (6, 9, 'johan'),
            (9, 5, fst.EPSILON),
            (5, 7, 'riley'),
            (5, 8, fst.EPSILON),
            (8, 10, 'schalkwyk'),
            (10, 7, fst.EPSILON),
            (7, 2, fst.EPSILON),
            (2, 4, 'please'),
        ],
        4,
    )
    expected.remove_epsilon()

    eq_(result, expected)
def test_merge():
    """Exercise ``SymbolTable.merge`` on one compatible and two clashing tables."""
    # Good merge: looking up unseen symbols on a fresh table is expected to
    # hand out ids 1 and 2.
    base = fst.SymbolTable()
    eq_(base['a'], 1)
    eq_(base['b'], 2)

    other = fst.SymbolTable()
    other['a'] = 1
    other['c'] = 3

    base.merge(other)
    merged_items = [
        (fst.EPSILON, fst.EPSILON_ID),
        ('a', 1),
        ('b', 2),
        ('c', 3),
    ]
    eq_(list(base.items()), merged_items)

    # Bad merge (value conflict: a -> 1 vs a -> 2)
    value_clash = fst.SymbolTable()
    value_clash['a'] = 2
    assert_raises(ValueError, other.merge, value_clash)

    # Bad merge (symbol conflict: a -> 1 vs b -> 1)
    symbol_clash = fst.SymbolTable()
    symbol_clash['b'] = 1
    assert_raises(ValueError, other.merge, symbol_clash)
def genBigGraph(label_prob, symbols, seq_len, label='x'):
    """Build a left-to-right transducer weighted by negative log probabilities.

    For every step ``j`` in ``range(seq_len)`` one arc per symbol is added
    from state ``j`` to state ``j + 1``. Its input label is
    ``label + str(j)``, its output label is the symbol converted to ``str``,
    and its weight is ``-math.log(label_prob[j][i])`` where ``i`` is the
    symbol's position in ``symbols``.

    :param label_prob: Indexable as ``label_prob[j][i]``; each value is passed
        to ``math.log``, so it must be positive.
    :param symbols: Iterable of output symbols; each is converted with ``str``.
    :param seq_len: Number of steps. NOTE(review): ``seq_len == 0`` leaves the
        transducer without arcs before state 0 is marked final - confirm that
        callers never pass 0.
    :param label: Prefix of the per-step input label.
    :return: The transducer, with ``final`` of state ``seq_len`` set to ``-1``.
    """
    t = fst.Transducer()
    # A real list is needed: the symbols are traversed once per step, which a
    # one-shot Python 3 ``map`` iterator cannot do (nor can it be indexed or
    # measured with len()).
    symbols = [str(s) for s in symbols]
    for j in range(seq_len):
        step_label = str(label + str(j))
        for i, symbol in enumerate(symbols):
            prob = label_prob[j][i]
            t.add_arc(j, j + 1, step_label, symbol, -math.log(prob))
    t[seq_len].final = -1
    return t
def test_syms():
    """Walk through the basic ``SymbolTable`` protocol on a fresh table."""
    table = fst.SymbolTable()

    # A new table holds only the epsilon entry.
    eq_(len(table), 1)  # __len__
    ok_(fst.EPSILON in table)  # __contains__
    eq_(table[fst.EPSILON], fst.EPSILON_ID)  # __getitem__
    eq_(table.find(fst.EPSILON_ID), fst.EPSILON)  # find(int)
    eq_(table.find(fst.EPSILON), fst.EPSILON_ID)  # find(str)
    eq_(table, table.copy())  # __richcmp__

    # Explicit assignment is visible through every lookup direction.
    table['a'] = 2  # __setitem__
    eq_(table.find('a'), 2)
    eq_(table.find(2), 'a')
    eq_(table['a'], 2)

    # Unknown symbols and unused ids are reported as KeyError by find().
    assert_raises(KeyError, table.find, 'x')
    assert_raises(KeyError, table.find, 1)
    ok_('x' not in table)

    expected_items = [(fst.EPSILON, fst.EPSILON_ID), ('a', 2)]
    eq_(list(table.items()), expected_items)
def make_edit(sigma):
    """
    Make an edit distance transducer with operations:
    - deletion:     x:<epsilon>/1
    - insertion:    <epsilon>:x/1
    - substitution: x:x/0 and x:y/1

    :param sigma: Iterable of symbols; it is traversed more than once, so it
        must be re-iterable (not a one-shot iterator).
    :return: ``distance(a, b)`` returning ``(dist, align)`` where ``dist`` is
        the integer taken from the reverse shortest distance of state 0 and
        ``align`` is a list of ``(input symbol, output symbol)`` pairs read
        from the shortest path, with epsilon shown as ``'-'``.
    """
    # Create common symbol table
    syms = fst.SymbolTable()

    # Create transducer
    edit = fst.Transducer(syms, syms)
    edit[0].final = True
    for x in sigma:
        edit.add_arc(0, 0, x, fst.EPSILON, 1)
        edit.add_arc(0, 0, fst.EPSILON, x, 1)
        for y in sigma:
            edit.add_arc(0, 0, x, y, (0 if x == y else 1))

    # Define edit distance
    def distance(a, b):
        # Compose a o edit transducer o b
        composed = fst.linear_chain(a, syms) >> edit >> fst.linear_chain(b, syms)
        # Compute distance
        distances = composed.shortest_distance(reverse=True)
        dist = int(distances[0])
        # Find best alignment
        alignment = composed.shortest_path()
        # Re-order states
        alignment.top_sort()
        # Replace <epsilon> -> "-"
        alignment.relabel({fst.EPSILON: '-'}, {fst.EPSILON: '-'})
        # Read alignment on the arcs of the transducer. The walk stops at the
        # first state without an outgoing arc (presumably the final state).
        # An explicit default is used because a StopIteration escaping from
        # next() inside a generator expression is turned into RuntimeError on
        # Python 3.7+ (PEP 479) instead of silently ending the iteration.
        align = []
        for state in alignment:
            arc = next(state.arcs, None)
            if arc is None:
                break
            align.append((alignment.isyms.find(arc.ilabel),
                          alignment.osyms.find(arc.olabel)))
        return dist, align

    return distance
def gen_utt_graph(labels, symdict):
    """Build a chain transducer with a self-loop on every state it enters.

    For each label, the entries of ``symdict[label]`` are converted to ``str``
    and each one extends the chain by one state: an arc ``x -> x + 1`` reading
    the symbol, followed by a self-loop on ``x + 1`` reading the same symbol.
    The first forward arc of a label outputs ``label + "/(" + symbol + ")"``;
    every other arc outputs the id-0 entry of a fresh ``fst.SymbolTable``
    followed by ``"(" + symbol + ")"``.

    :param labels: Iterable of keys into ``symdict``; each key is concatenated
        with ``str``, so keys are expected to be strings.
    :param symdict: Mapping from label to an iterable of symbols.
    :return: The transducer, with the last state of the chain marked final.
    """
    t2 = fst.Transducer()
    sym = fst.SymbolTable()
    x = 0
    for l in labels:
        # A real list is required: Python 3 ``map`` returns an iterator that
        # supports neither len() nor indexing.
        symbols = [str(s) for s in symdict[l]]
        for i, symbol in enumerate(symbols):
            if i == 0:
                # The first arc of each label carries the label itself.
                forward_out = str(l + "/" + "(" + symbol + ")")
            else:
                forward_out = str(sym.find(0) + "(" + symbol + ")")
            t2.add_arc(x, x + 1, symbol, forward_out)
            # Self-loop on the state just entered.
            t2.add_arc(x + 1, x + 1, symbol, str(sym.find(0) + "(" + symbol + ")"))
            x += 1
    t2[x].final = True
    return t2
#!/usr/bin/env python # -*- coding: utf-8 -*- import itertools import math import fst, operator import alphabet import sys from functools import reduce abc = alphabet.Alphabet() syms = fst.SymbolTable() semiring = 'tropical' def Transducer(isyms=None, osyms=None, semiring=semiring): global syms if isyms is None: isyms = syms if osyms is None: osyms = syms return fst.Transducer(isyms=isyms, osyms=osyms, semiring=semiring) def GetPaths(t, return_full_path_in_ostring=False): if len(t) == 0: raise StopIteration seen_paths = set() for path in t.paths():
def main(argv):
    """Expand the pattern rules of a domain into concrete, segmented rules.

    Steps performed:

    1. Parse ``--domain``, ``--weight`` and ``--test`` from the command line.
    2. Read ``rules.txt`` of the chosen domain; lines without ``=>`` define
       macros, lines with ``=>`` are rules of the form ``pattern => node_info``.
    3. For every rule, build a transducer from its chunk list (handling
       ``(``, ``)``, ``)?``, ``|`` and a trailing ``?`` on words), then
       determinize and minimize it between two inversions.
    4. Enumerate the transducer's paths, shuffle them, resize the list
       (``extract_proper_num`` outside test mode, at most two paths in test
       mode), re-segment each resulting sentence with ``Segmentor`` and write
       ``"<words> => <node_info>"`` to the release file.

    :param argv: Not used by the visible code; ``parser.parse_args()`` is
        called without arguments.
    """
    parser = argparse.ArgumentParser(description='...')
    parser.add_argument('-d','--domain',default='AISpeech',action='store',help='which domain: AISpeech or SpeechLab')
    parser.add_argument('-w','--weight',default=-1,action='store',metavar='number',type=float,help='weight number')
    parser.add_argument('--test',action='store_true')
    args = parser.parse_args()
    # NOTE(review): neither lex_file nor out_lex_file is closed in this function.
    lex_file = open(os.path.join(PATH_TO_DATA[args.domain],'rules.txt'), 'r')
    weight = args.weight  # not used further in this function
    if not args.test:
        out_lex_file = open(os.path.join(PATH_TO_DATA[args.domain],'rules.release.txt'), 'w')
    else:
        out_lex_file = open(os.path.join(PATH_TO_DATA[args.domain],'rules.test.release.txt'), 'w')
    cws_model_path = PATH_TO_SPLIT_WORDS  # path to the word-segmentation model; the model file is named `cws.model`
    dict_path = os.path.join(PATH_TO_DATA[args.domain], 'dict.txt')  # domain-specific dictionary used to assist word segmentation
    segmentor = Segmentor()  # initialize the instance
    segmentor.load_with_lexicon(cws_model_path,dict_path)  # load the model
    # Reset the containers defined outside this function before processing.
    if concept_fst_dict!={}:
        concept_fst_dict.clear()
    if constraints_names!={}:
        constraints_names.clear()
    macro_patterns = {}
    all_patterns = []
    for line in lex_file:
        line=line.strip()
        # Skip blank lines and lines starting with '%'.
        if line == '' or line.startswith('%'):
            continue
        if '=>' not in line:
            # rule macro: "name=pattern;" stored under the key "${name}"
            pat_name, pat = line.strip(';').split('=')
            macro_patterns['${'+pat_name+'}'] = extract_simple_rules(pat.strip(), macro_patterns)
        else:
            # regular rule: "pattern => node_info"
            pattern, node_info = line.split('=>')
            chunk_list = extract_simple_rules(pattern.strip(), macro_patterns)
            all_patterns.append((chunk_list, node_info))
    isyms = ["<eps>"]
    label_voc = {}  # not used further in this function
    osyms = ["<eps>", "<unk>"]
    word_voc = {} #["<unk>"] #<unk> should be defined manually
    # Collect every word (without '?' markers) that is not a grouping token.
    for chunk_list,_ in all_patterns:
        for word in chunk_list:
            if word[0] not in ['(', ')', '|']:
                word = word.strip('?')
                word_voc[word] = 1
    osyms = osyms + list(word_voc)
    osyms_table = fst.SymbolTable()
    for idx,val in enumerate(osyms):
        osyms_table[val] = idx
    isyms_table = fst.SymbolTable()
    for idx,val in enumerate(isyms):
        isyms_table[val] = idx
    for pattern_idx, (pattern_chunk_list, node_info) in enumerate(all_patterns):
        unique_rules = set()  # not used further in this function
        replace_mapping_dict = {}  # not used further in this function
        concept_fst = fst.StdTransducer(isyms=isyms_table, osyms=osyms_table)
        # One stack entry per open group; entry 0 describes the whole pattern.
        # An 'end_of_this_segment' of 0 means "end state not assigned yet".
        segment_stack = [{'start_of_this_segment':0, 'end_of_this_segment':0}]
        segment_stack[0]['value'] = '<eps>'
        # cursor_head is the state arcs currently leave from; cursor_tail is
        # the state the next arc goes to.
        cursor_head, cursor_tail = 0, 1
        argument_count = 0
        # print('Processing rule',pattern_chunk_list,'=>',node_info)
        for word in pattern_chunk_list:
            if word == '(':
                # Open a group: remember the arc leading into it and start
                # the group at a fresh state.
                argument_count += 1
                segment_stack.append({'start_of_this_segment':cursor_tail, 'end_of_this_segment':0, 'value':segment_stack[-1]['value']})
                segment_stack[-1]['head_arc'] = [cursor_head, cursor_tail]
                cursor_tail += 1
                cursor_head = cursor_tail - 1
            elif word[0] == ')':
                # Close a group: join the current branch to the group's end state.
                if segment_stack[-1]['end_of_this_segment'] == 0:
                    segment_stack[-1]['end_of_this_segment'] = cursor_head
                else:
                    concept_fst.add_arc(cursor_head, segment_stack[-1]['end_of_this_segment'], '<eps>', '<eps>')
                cursor_head = segment_stack[-1]['end_of_this_segment']
                if word == ')?':
                    # Optional group: besides the epsilon arc into the group,
                    # add an epsilon arc from its start to its end to bypass it.
                    concept_fst.add_arc(segment_stack[-1]['head_arc'][0], segment_stack[-1]['head_arc'][1], '<eps>', '<eps>')
                    concept_fst.add_arc(segment_stack[-1]['start_of_this_segment'], segment_stack[-1]['end_of_this_segment'], '<eps>', '<eps>')
                else:
                    concept_fst.add_arc(segment_stack[-1]['head_arc'][0], segment_stack[-1]['head_arc'][1], '<eps>', '<eps>')
                segment_stack.pop()
            elif word == '|':
                # Alternative: close the current branch and restart from the
                # start state of the enclosing group.
                if segment_stack[-1]['end_of_this_segment'] == 0:
                    segment_stack[-1]['end_of_this_segment'] = cursor_head
                else:
                    concept_fst.add_arc(cursor_head, segment_stack[-1]['end_of_this_segment'], '<eps>', '<eps>')
                cursor_head = segment_stack[-1]['start_of_this_segment']
            else:
                if word[-1] == '?':
                    # Optional word: an epsilon arc lets the path skip it.
                    concept_fst.add_arc(cursor_head, cursor_tail, '<eps>', '<eps>')
                    word = word[:-1]
                else:
                    pass
                # add_arc is a helper defined outside this function; its return
                # value is used as the next free target state.
                next_state = add_arc(concept_fst, cursor_head, cursor_tail, word, segment_stack[-1]['value'])
                cursor_head = cursor_tail
                cursor_tail = next_state
        # Close the outermost segment and mark its end state final.
        if segment_stack[-1]['end_of_this_segment'] == 0:
            segment_stack[-1]['end_of_this_segment'] = cursor_head
        else:
            concept_fst.add_arc(cursor_head, segment_stack[-1]['end_of_this_segment'], '<eps>', '<eps>')
        final_state_idx = segment_stack[-1]['end_of_this_segment']
        concept_fst[final_state_idx].final = True
        # Determinize and minimize on the inverted machine, then invert back.
        concept_fst = concept_fst.inverse()
        concept_fst = concept_fst.determinize()
        concept_fst.minimize()
        concept_fst = concept_fst.inverse()
        t = concept_fst
        paths=list(t.paths())
        random.shuffle(paths)
        if not args.test:
            # Repeat or truncate the shuffled paths so that exactly
            # extract_proper_num(len(paths)) of them remain.
            # NOTE(review): an empty `paths` list would divide by zero in the
            # first branch if extract_proper_num(0) > 0 - confirm.
            if extract_proper_num(len(paths))>len(paths):
                paths=paths*(extract_proper_num(len(paths))//len(paths))+paths[:extract_proper_num(len(paths))%len(paths)]
            else:
                paths=paths[:extract_proper_num(len(paths))]
        else:
            # Test mode keeps at most two paths.
            paths=paths[:2] if len(paths)>=2 else paths
        for output in paths:
            # Each arc becomes a (output symbol, input symbol) pair.
            raw_path = []
            for arc in output:
                raw_path.append((t.osyms.find(arc.olabel), t.isyms.find(arc.ilabel)))
            path = raw_path
            input_seq = []
            output_seq = []  # filled below but not read afterwards
            for word, label in path:
                if word not in ['<eps>', u"ε"]:
                    input_seq.append(word)
                if label not in ['<eps>', u"ε"]:
                    if label == '_' and word not in ['<eps>', u"ε"]:
                        output_seq.append(word)
                    elif label != '_':
                        output_seq.append(label)
            pattern = input_seq
            # Items starting with '$' are replaced by ',' before segmentation
            # and restored afterwards, one per ',' token, in order.
            sentence = [item if item[0] != '$' else ',' for item in pattern]
            tags = [item for item in pattern if item[0] == '$']
            sentence = ''.join(sentence)
            words = segmentor.segment(sentence)
            new_words = []
            tag_idx = 0
            for word in words:
                word = word
                if word == ',':
                    word = tags[tag_idx]
                    tag_idx += 1
                new_words.append(word)
            new_rule_simple = ' '.join(new_words)+' => '+node_info
            out_lex_file.write(new_rule_simple+'\n')