def create_priors(priors, isym, osym, code):
    """Create a linear FST from a ';'-separated prior string and append a
    <sigma> (wildcard) symbol at the end as a placeholder."""
    priors = priors.split(";")
    # initialise a transducer
    f = fst.Fst()
    f.set_input_symbols(isym)
    f.set_output_symbols(osym)
    s0 = f.add_state()
    f.set_start(s0)
    old = s0
    sig = "<sigma>"
    # add one arc per prior symbol
    for j in range(len(priors)):
        new = f.add_state()
        f.add_arc(old, fst.Arc(code[priors[j]], code[priors[j]],
                               fst.Weight(f.weight_type(), 1.0), new))
        old = new
    new = f.add_state()
    # add the <sigma> placeholder, with a self-loop on the final state
    f.add_arc(old, fst.Arc(code[sig], code[sig],
                           fst.Weight(f.weight_type(), 1.0), new))
    f.add_arc(new, fst.Arc(code[sig], code[sig],
                           fst.Weight(f.weight_type(), 1.0), new))
    return f, new

def generate_phone_sequence_recognition_wfst(n, state_table, phone_table):
    """ generate an HMM to recognise any single phone sequence in the lexicon

    Args:
        n (int): states per phone HMM

    Returns:
        the constructed WFST
    """
    f = fst.Fst()

    # create a single start state
    start_state = f.add_state()
    f.set_start(start_state)

    # collect the unique phones in the (global) lexicon
    phone_set = set()
    for pronunciation in lex.values():
        phone_set = phone_set.union(pronunciation)

    for phone in phone_set:
        current_state = f.add_state()
        f.add_arc(start_state, fst.Arc(0, 0, None, current_state))
        end_state = generate_phone_wfst(f, current_state, phone, n,
                                        state_table, phone_table)
        # loop back to the start so that sequences of phones are accepted
        f.add_arc(end_state, fst.Arc(0, 0, None, start_state))
        f.set_final(end_state)

    return f

def generate_phone_sequence_recognition_wfst(n):
    """ generate an HMM to recognise any single phone sequence in the lexicon

    Args:
        n (int): states per phone HMM

    Returns:
        the constructed WFST
    """
    f = fst.Fst('log')

    # create a single start state
    start_state = f.add_state()
    f.set_start(start_state)

    for i, phone in phone_table:
        if phone != '<eps>':
            tmp_state = f.add_state()
            # uniform prior over phones: cost = -log(1/N)
            weight = fst.Weight('log', -math.log(1.0 / phone_table.num_symbols()))
            f.add_arc(start_state, fst.Arc(0, 0, weight, tmp_state))

            last_state = generate_phone_wfst(f, tmp_state, phone, n)
            f.set_final(last_state)

            # return to the start with probability 1 (cost -log 1 = 0)
            weight = fst.Weight('log', -math.log(1))
            f.add_arc(last_state, fst.Arc(0, 0, weight, start_state))

    return f

def generate_word_sequence_recognition_wfst(n):
    """ generate an HMM to recognise any single word sequence for words in the lexicon

    Args:
        n (int): states per phone HMM

    Returns:
        the constructed WFST
    """
    f = fst.Fst('log')

    # create a single start state
    start_state = f.add_state()
    f.set_start(start_state)

    for _, word in word_table:
        if word != '<eps>':
            tmp_state = f.add_state()
            # uniform prior over words: cost = -log(1/N)
            weight = fst.Weight('log', -math.log(1.0 / word_table.num_symbols()))
            f.add_arc(start_state, fst.Arc(0, 0, weight, tmp_state))

            word_wfst = generate_word_wfst(f, tmp_state, word, n)
            # loop from the word's last state back to the start (cost 0)
            weight = fst.Weight('log', -math.log(1.0))
            f.add_arc(list(word_wfst.states())[-1],
                      fst.Arc(0, 0, weight, start_state))

    return f

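# A quick sanity check on the -log probability convention used by the two
# builders above (a minimal standalone sketch; numbers are illustrative only):
# a uniform choice over N symbols has probability 1/N, so its log-semiring
# cost must be -log(1/N) = log(N), which is positive for N > 1.
import math

N = 10
cost = -math.log(1.0 / N)
assert cost > 0 and abs(cost - math.log(N)) < 1e-12
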
def generate_phone_wfst(f, start_state, phone, n):
    """ Generate a WFST representing an n-state left-to-right phone HMM.

    Args:
        f (fst.Fst()): an FST object, assumed to exist already
        start_state (int): the index of the first state, assumed to exist already
        phone (str): the phone label
        n (int): number of states of the HMM excluding start and end

    Returns:
        the final state of the FST
    """
    current_state = start_state
    out_label = phone_table.find(phone)

    for i in range(1, n + 1):
        in_label = state_table.find('{}_{}'.format(phone, i))

        # self-loop with probability 0.1 (cost -log 0.1)
        weight = fst.Weight('log', -math.log(0.1))
        f.add_arc(current_state, fst.Arc(in_label, 0, weight, current_state))

        # forward transition with probability 0.9, emitting the phone label
        new_state = f.add_state()
        weight = fst.Weight('log', -math.log(0.9))
        f.add_arc(current_state, fst.Arc(in_label, out_label, weight, new_state))

        current_state = new_state

    return current_state

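# A minimal usage sketch for generate_phone_wfst above. The symbol-table
# layout here is an assumption: state_table holds one entry per HMM state
# "phone_i" and phone_table one entry per phone, with id 0 kept for <eps>;
# `fst` is assumed to be the openfst_python/pywrapfst binding used throughout.
state_table = fst.SymbolTable()
phone_table = fst.SymbolTable()
state_table.add_symbol('<eps>')  # id 0 is conventionally epsilon
phone_table.add_symbol('<eps>')
phone_table.add_symbol('a')
for i in range(1, 4):
    state_table.add_symbol('a_{}'.format(i))

f = fst.Fst('log')
start = f.add_state()
f.set_start(start)
final = generate_phone_wfst(f, start, 'a', 3)  # 3-state HMM for phone 'a'
f.set_final(final)
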
def generate_parallel_path_wfst(f, start_state, n):
    """ Generate a WFST representing an n-state parallel-path left-to-right HMM

    Args:
        f (fst.Fst()): an FST object, assumed to exist already
        start_state (int): the index of the first state, assumed to exist already
        n (int): number of states of the HMM excluding start and end

    Returns:
        the final state of the FST
    """
    current_state = start_state

    for i in range(n):
        # self-loop
        f.add_arc(current_state, fst.Arc(0, 0, None, current_state))
        # skip arc from two states back, creating the parallel path
        if current_state - 2 >= 0:
            f.add_arc(current_state - 2, fst.Arc(0, 0, None, current_state))
        new_state = f.add_state()
        f.add_arc(current_state, fst.Arc(0, 0, None, new_state))
        current_state = new_state

    f.add_arc(current_state - 2, fst.Arc(0, 0, None, current_state))
    # f.set_final(current_state)

    return current_state

def build_refiner(isyms_fname, refiner_fname):
    """Build the refiner FST, which helps extract the last two states
    (i.e. the one last arc) of a machine."""
    # read the input symbol table and build a symbol -> id map
    input_syms = fst.SymbolTable.read_text(isyms_fname)
    code = {}
    for ltr, c in input_syms:
        code[c] = ltr

    # build the refiner
    refiner = fst.Fst()
    refiner.set_input_symbols(input_syms)
    refiner.set_output_symbols(input_syms)
    s0 = refiner.add_state()
    s1 = refiner.add_state()
    for c, ltr in code.items():
        if ltr == 0:
            continue
        if ltr < 100:
            # either delete the symbol (stay in s0) or keep it and accept (s1)
            refiner.add_arc(s0, fst.Arc(code[c], code["<epsilon>"],
                                        fst.Weight(refiner.weight_type(), 1.0), s0))
            refiner.add_arc(s0, fst.Arc(code[c], code[c],
                                        fst.Weight(refiner.weight_type(), 1.0), s1))
    refiner.set_start(s0)
    refiner.set_final(s1)

    # save the refiner
    refiner.write(refiner_fname)

def process_unigram(self, gram):
    """Process unigram in arpa file"""
    parts = re.split(r'\s+', gram)
    boff = '0.0'
    if len(parts) == 3:
        prob, word, boff = parts
    elif len(parts) == 2:
        prob, word = parts
    else:
        raise NotImplementedError
    if word not in self.words_table:
        return
    weight = convert_weight(prob)
    boff = convert_weight(boff)

    if word == '</s>':
        src = self.unigram2state['<start>']
        self.grammar_fst.set_final(src, weight)
    elif word == '<s>':
        src = self.unigram2state['<s>']
        des = self.unigram2state['<start>']
        self.grammar_fst.add_arc(
            src, fst.Arc(self.sid('#0'), self.sid('<eps>'), boff, des))
    else:
        src = self.unigram2state['<start>']
        if word in self.unigram2state:
            des = self.unigram2state[word]
        else:
            des = self.grammar_fst.add_state()
            self.unigram2state[word] = des
        self.grammar_fst.add_arc(
            src, fst.Arc(self.sid(word), self.sid(word), weight, des))
        self.grammar_fst.add_arc(
            des, fst.Arc(self.sid('#0'), self.sid('<eps>'), boff, src))

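# For reference, a hedged sketch of the ARPA lines the process_* methods
# parse (values illustrative): each line is "<log10 prob> <ngram...>
# [<back-off>]", whitespace-separated, with the back-off column optional.
import re

unigram_line = '-1.0000\thello\t-0.3010'  # prob, word, back-off
prob, word, boff = re.split(r'\s+', unigram_line)
bigram_line = '-0.2553\thello world'      # prob, hist, word (no back-off)
prob, hist, word = re.split(r'\s+', bigram_line)
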
def process_trigram(self, gram):
    """Process trigram in arpa file"""
    prob, hist1, hist2, word = re.split(r'\s+', gram)
    if (hist1 not in self.words_table
            or hist2 not in self.words_table
            or word not in self.words_table):
        return
    boff = '0.0'
    weight = convert_weight(prob)
    boff = convert_weight(boff)

    bigram1 = hist1 + '/' + hist2
    if bigram1 not in self.bigram2state:
        logging.info(
            '[{} {} {} {}] skipped: no parent (n-1)-gram exists'.format(
                prob, hist1, hist2, word))
        return
    bigram2 = hist2 + '/' + word
    src = self.bigram2state[bigram1]
    if word == '</s>':
        self.grammar_fst.set_final(src, weight)
    else:
        if bigram2 in self.bigram2state:
            des = self.bigram2state[bigram2]
        else:
            des = self.grammar_fst.add_state()
            self.bigram2state[bigram2] = des
            if word in self.unigram2state:
                boff_state = self.unigram2state[word]
            else:
                boff_state = self.unigram2state['<start>']
            self.grammar_fst.add_arc(
                des, fst.Arc(self.sid('#0'), self.sid('<eps>'), boff,
                             boff_state))
        self.grammar_fst.add_arc(
            src, fst.Arc(self.sid(word), self.sid(word), weight, des))

def generate_WFST_final_probability(n, lex, weight_fwd, weight_self,
                                    weights_final, original=False):
    """ generate an HMM to recognise any single word sequence for words in the lexicon

    Args:
        n (int): states per phone HMM
        original (bool): True/False - original/optimized lexicon
        weight_fwd (int): weight value
        weight_self (int): weight value of self node
        weights_final (dict): word -> probability of final state

    Returns:
        the constructed WFST
    """
    f = fst.Fst('log')
    none_weight = fst.Weight('log', -math.log(1))

    lex = parse_lexicon(lex, original)
    word_table, phone_table, state_table = generate_symbols_table(lex, 3)
    output_table = generate_output_table(word_table, phone_table)

    # create a single start state
    start_state = f.add_state()
    f.set_start(start_state)

    for word, phone_list in lex.items():
        for phones in phone_list:
            initial_state = f.add_state()
            f.add_arc(start_state,
                      fst.Arc(0, output_table.find(word), none_weight,
                              initial_state))
            current_state = initial_state
            for phone in phones:
                current_state = generate_phone_wfst(f, current_state, phone, n,
                                                    state_table, output_table,
                                                    weight_fwd, weight_self)
            f.add_arc(current_state, fst.Arc(0, 0, none_weight, start_state))

            # the final word state is the current state; weight it with the
            # word's final-state probability
            prob = weights_final[word]
            weight = fst.Weight('log', -math.log(prob))
            f.set_final(current_state, weight)

    f.set_input_symbols(state_table)
    f.set_output_symbols(output_table)
    return f, word_table

def set_final(self, f, output_table, word, start_state):
    # add dummy state which outputs word
    final_state = f.add_state()
    # add arc which outputs the word and connects to the end state
    f.add_arc(self.end_state,
              fst.Arc(0, output_table.find(word), None, final_state))
    f.set_final(final_state)
    # add arc back to the start state
    f.add_arc(final_state, fst.Arc(0, 0, None, start_state))

def build_lm(dev_fname, isyms_fname, constraints, lattice_output, refiner_fname):
    """ Make a lattice that maps lemmas and constraints (or priors) to an
    inflected version """
    # rewrite constraints
    constraints = constraints.replace("_", ";")

    # read the input symbol table and build a symbol -> id map
    input_syms = fst.SymbolTable.read_text(isyms_fname)
    s_fin = '</s>'
    code = {}
    for ltr, c in input_syms:
        code[c] = ltr

    # initialise the lattice
    f_big = fst.Fst()
    f_big.set_input_symbols(input_syms)
    f_big.set_output_symbols(input_syms)

    for line in open(dev_fname, 'r').readlines():
        cns, lemma, inflection = line.split()[-3:]
        if cns == constraints:
            print(cns, lemma, inflection)
            # find the index at which the two strings diverge
            idx = 0
            for i, (lm, flc) in enumerate(zip(lemma, inflection)):
                if lm != flc:
                    idx = i
                    break

            f, old = create_priors(cns, input_syms, input_syms, code)
            keep = old
            for j in range(idx, len(lemma)):
                new = f.add_state()
                f.add_arc(old, fst.Arc(code[lemma[j]], code[lemma[j]],
                                       fst.Weight(f.weight_type(), 1.0), new))
                old = new
            new = f.add_state()
            # the residual of the lemma is mapped (indirectly) to the
            # inflection residual
            sym = lemma[idx:] + "_" + inflection[idx:]
            print(lemma, inflection, sym)
            f.add_arc(old, fst.Arc(code[sym], code[s_fin],
                                   fst.Weight(f.weight_type(), 1.0), new))
            # f.add_arc(old, fst.Arc(code[inflection[idx:]], code[s_fin], fst.Weight(f.weight_type(), 1.0), new))
            # f.add_arc(old, fst.Arc(code[s_fin], code[inflection[idx:]], fst.Weight(f.weight_type(), 1.0), new))
            f.set_final(new)
            f_big.union(f)

    f_big = fst.determinize(f_big.rmepsilon())

    # expand the <sigma> placeholder with a self-loop over the alphabet
    for c, ltr in code.items():
        if int(ltr) > 1 and int(ltr) < 36:  # (hard-coded) symbols of Russian + 2 more
            f_big.add_arc(keep, fst.Arc(code[c], code[c],
                                        fst.Weight(f_big.weight_type(), 1.0), keep))
    f_big.invert()

    # save the lattice
    f_big.write(lattice_output)

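# A hedged sketch of the two-column text format that fst.SymbolTable.read_text
# expects above. The file name, symbols, and id layout here are hypothetical;
# the real table also carries the lemma_inflection composite symbols that
# build_lm looks up via code[sym].
with open('isyms_example.txt', 'w') as fh:
    fh.write('<epsilon> 0\n')
    fh.write('</s> 1\n')
    fh.write('а 2\n')  # Cyrillic letters would occupy the low ids
    fh.write('<sigma> 100\n')
syms = fst.SymbolTable.read_text('isyms_example.txt')
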
def test_simple(self):
    f = fst.Fst()
    s0 = f.add_state()
    s1 = f.add_state()
    s2 = f.add_state()
    f.add_arc(s0, fst.Arc(1, 1, fst.Weight(f.weight_type(), 3.0), s1))
    f.add_arc(s0, fst.Arc(1, 1, fst.Weight.One(f.weight_type()), s2))
    f.set_start(s0)
    f.set_final(s2, fst.Weight(f.weight_type(), 1.5))

    # Test fst
    self.assertEqual(f.num_states(), 3)
    self.assertAlmostEqual(float(f.final(s2)), 1.5)

def OpenFST_Automata_Example():
    f = fst.Fst()
    s0 = f.add_state()
    s1 = f.add_state()
    s2 = f.add_state()
    f.add_arc(s0, fst.Arc(1, 2, fst.Weight(f.weight_type(), 3.0), s1))
    f.add_arc(s0, fst.Arc(1, 3, fst.Weight.One(f.weight_type()), s2))
    f.add_arc(s1, fst.Arc(2, 1, fst.Weight(f.weight_type(), 1.0), s2))
    f.set_start(s0)
    f.set_final(s2, fst.Weight(f.weight_type(), 1.5))
    print(s0, s1, s2)
    print(f)

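# A hedged follow-up to the example above: the same machine can be queried
# for its cheapest accepting path with shortestpath (the default Fst uses the
# tropical semiring, where path costs add and the minimum wins).
f = fst.Fst()
s0, s1, s2 = f.add_state(), f.add_state(), f.add_state()
f.add_arc(s0, fst.Arc(1, 2, fst.Weight(f.weight_type(), 3.0), s1))
f.add_arc(s0, fst.Arc(1, 3, fst.Weight.One(f.weight_type()), s2))
f.add_arc(s1, fst.Arc(2, 1, fst.Weight(f.weight_type(), 1.0), s2))
f.set_start(s0)
f.set_final(s2, fst.Weight(f.weight_type(), 1.5))
print(fst.shortestpath(f))  # keeps the direct s0 -> s2 arc: 0 + 1.5 < 3 + 1 + 1.5
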
def make_token_fst(self, blank):
    """ make token fst and map disambiguation symbols to <eps> """
    start_state = self.token_fst.add_state()
    blank_start = self.token_fst.add_state()
    blank_end = self.token_fst.add_state()
    self.token_fst.set_start(start_state)
    self.token_fst.set_final(start_state, 0.0)

    blank_id = self.graphemes_table[blank]
    eps_id = self.graphemes_table['<eps>']
    assert eps_id == 0
    # 0->-1
    # TODO: change eps_id
    self.token_fst.add_arc(start_state, fst.Arc(0, 0, 0.0, blank_start))
    self.token_fst.add_arc(blank_start, fst.Arc(blank_id, 0, 0.0, blank_start))
    self.token_fst.add_arc(blank_end, fst.Arc(blank_id, 0, 0.0, blank_end))
    self.token_fst.add_arc(blank_end, fst.Arc(0, 0, 0.0, start_state))

    for token, idx in self.graphemes_table.items():
        if token == blank or token == '<eps>':
            continue
        # disambig symbols start with '#'
        if token[0] == '#':
            self.token_fst.add_arc(start_state, fst.Arc(0, idx, 0.0, start_state))
        else:
            node = self.token_fst.add_state()
            self.token_fst.add_arc(blank_start, fst.Arc(idx, idx, 0.0, node))
            self.token_fst.add_arc(node, fst.Arc(idx, 0, 0.0, node))
            self.token_fst.add_arc(node, fst.Arc(0, 0, 0.0, blank_end))

def generate_phone_wfst(f, start_state, phone, n, state_table, phone_table,
                        weight_fwd, weight_self):
    """ Generate a WFST representing an n-state left-to-right phone HMM.

    Args:
        f (fst.Fst()): an FST object, assumed to exist already
        start_state (int): the index of the first state, assumed to exist already
        phone (str): the phone label
        n (int): number of states of the HMM
        weight_fwd (int): weight value
        weight_self (int): weight value of self node

    Returns:
        the final state of the FST
    """
    current_state = start_state

    for i in range(1, n + 1):
        in_label = state_table.find('{}_{}'.format(phone, i))
        # weight for the self-loop
        sl_weight = None if weight_self is None else fst.Weight(
            'log', -math.log(weight_self))
        # weight for the forward transition
        next_weight = None if weight_fwd is None else fst.Weight(
            'log', -math.log(weight_fwd))

        # self-loop back to current state
        f.add_arc(current_state, fst.Arc(in_label, 0, sl_weight, current_state))

        # transition to next state:
        # we want to output the phone label on the final state
        # note: if outputting words instead, this code should be modified
        if i == n:
            out_label = phone_table.find(phone)
        else:
            out_label = 0  # output empty <eps> label

        next_state = f.add_state()
        f.add_arc(current_state,
                  fst.Arc(in_label, out_label, next_weight, next_state))

        current_state = next_state

    return current_state

def generate_word_sequence_recognition_wfst(n, lex, original=False,
                                            weight_fwd=None, weight_self=None):
    """ generate an HMM to recognise any single word sequence for words in the lexicon

    Args:
        n (int): states per phone HMM
        original (bool): True/False - original/optimized lexicon
        weight_fwd (int): weight value
        weight_self (int): weight value of self node

    Returns:
        the constructed WFST
    """
    if weight_fwd is not None and weight_self is not None:
        f = fst.Fst('log')
        none_weight = fst.Weight('log', -math.log(1))
    else:
        f = fst.Fst()
        none_weight = None

    lex = parse_lexicon(lex, original)
    word_table, phone_table, state_table = generate_symbols_table(lex, 3)
    output_table = generate_output_table(word_table, phone_table)

    # create a single start state
    start_state = f.add_state()
    f.set_start(start_state)

    # build the fst
    for word, phone_list in lex.items():
        for phones in phone_list:
            initial_state = f.add_state()
            f.add_arc(start_state,
                      fst.Arc(0, output_table.find(word), none_weight,
                              initial_state))
            current_state = initial_state
            for phone in phones:
                current_state = generate_phone_wfst(f, current_state, phone, n,
                                                    state_table, output_table,
                                                    weight_fwd, weight_self)
            f.set_final(current_state)
            f.add_arc(current_state, fst.Arc(0, 0, none_weight, start_state))

    f.set_input_symbols(state_table)
    f.set_output_symbols(output_table)
    return f, word_table

def __init__(self, phone, state, f, state_table, output_table):
    self.children = []
    self.phone = phone
    self.start_state = f.add_state()
    f.add_arc(state, fst.Arc(0, 0, None, self.start_state))
    # --
    self.generate_phone_seq(f, state_table, output_table)

def make_input_fst(query, pysym):
    f = fst.Fst()
    start = f.add_state()
    end = f.add_state()
    f.set_start(start)
    f.set_final(end, fst.Weight(f.weight_type(), 0.0))

    prev_state = start
    for ch in query:
        n = f.add_state()
        label = pysym[ch]
        f.add_arc(prev_state,
                  fst.Arc(label, label, fst.Weight(f.weight_type(), 0.0), n))
        prev_state = n
    f.add_arc(prev_state,
              fst.Arc(pysym['<eps>'], pysym['<eps>'],
                      fst.Weight(f.weight_type(), 0.0), end))
    f.write('input.fst')
    return f

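# A hypothetical driver for make_input_fst: `pysym` is assumed to map each
# pinyin token (and <eps>) to its integer label, mirroring the tables used
# by make_fst below. Note the function also writes input.fst as a side effect.
pysym = {'<eps>': 0, 'ni3': 1, 'hao3': 2}
query_fst = make_input_fst(['ni3', 'hao3'], pysym)
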
def creation_automata():
    transitions = {"s0": {"1:1:0": ["s0", "s1"], "2:3:1": ["s0", "s2"]}}
    # Iterating over the dictionary decomposes its levels of nesting: the
    # first level maps each source-state label to its arcs; each arc maps an
    # "in:out:weight" label to its list of destination states. The destination
    # states are held in a plain list, so they are indexed by integers rather
    # than keys and need no further dict iteration.
    # first level: keys are source-state labels, values are their arcs
    for src_state_label, arcs in transitions.items():
        add_automate_state(src_state_label)
        # second level: keys are arc labels, values are destination-state lists
        for arc_label, set_dsts_states in arcs.items():
            # third level: the destination states themselves
            for dst_state_label in set_dsts_states:
                add_automate_state(dst_state_label)

    for state_label, arcs in transitions.items():
        for arc_label, set_dsts_states in arcs.items():
            chars = arc_label.split(':')
            for dst_state_label in set_dsts_states:
                automate.add_arc(
                    automate_states[state_label],
                    fst.Arc(int(chars[0]), int(chars[1]),
                            fst.Weight(automate.weight_type(), int(chars[2])),
                            automate_states[dst_state_label]))

    automate.set_start(automate_states['s0'])
    automate.set_final(automate_states['s2'],
                       fst.Weight(automate.weight_type(), 1.5))
    print(automate)

    # Generate LaTeX code in GraphViz format
    # print the nodes with their labels
    i = 0
    print("digraph G {")
    for state_label, state in automate_states.items():
        index = state_label.split("s")[1]
        display_node = index + " [label = \"" + state_label + "\"]"
        i += 1
        print(display_node)
    # print the edges with their labels
    for src_state_label, arcs in transitions.items():
        src_index = src_state_label.split("s")[1]
        for arc_label, set_dsts_states in arcs.items():
            for dst_state_label in set_dsts_states:
                dst_index = dst_state_label.split("s")[1]
                display_edge = (src_index + "->" + dst_index +
                                " [label = \"" + arc_label + "\"]")
                print(display_edge)
    print("}")
    return automate

def make_fst(word_sym, phone_sym, pydict_file):
    with open(pydict_file, 'r') as rp:
        f = fst.Fst()
        start = f.add_state()
        end = f.add_state()
        f.set_start(start)
        f.add_arc(start, fst.Arc(phone_sym['<eps>'], word_sym['<s>'],
                                 fst.Weight(f.weight_type(), 0.0), start))  # self-loop
        f.add_arc(end, fst.Arc(phone_sym['<eps>'], word_sym['</s>'],
                               fst.Weight(f.weight_type(), 0.0), end))  # self-loop
        f.add_arc(end, fst.Arc(phone_sym['<eps>'], word_sym['<eps>'],
                               fst.Weight(f.weight_type(), 0.0), start))  # 1 --> 0
        f.set_final(end, fst.Weight(f.weight_type(), 0.0))

        for l in rp.readlines():
            items = l.strip().split(' ')
            prev_state = start
            ilabel = phone_sym['<eps>']
            olabel = word_sym['<eps>']
            for i in range(len(items[0])):
                n = f.add_state()
                pych = items[0][i]
                chch = items[1]
                ilabel = phone_sym[pych]
                if i == 0:
                    olabel = word_sym[chch]
                else:
                    olabel = word_sym['<eps>']
                f.add_arc(prev_state,
                          fst.Arc(ilabel, olabel,
                                  fst.Weight(f.weight_type(), 0.0), n))
                prev_state = n
            # connect the last state with the end node
            f.add_arc(prev_state,
                      fst.Arc(phone_sym['<eps>'], olabel,
                              fst.Weight(f.weight_type(), 0.0), end))
        return f

def make_lexicon_fst(self, sil_symbol, sil_prob):
    """Convert the lexicon to WFST format

    There is always a disambig symbol after sil_symbol; the special disambig
    symbols have been added in self.create_disambig_graphemes_table

    Args:
        sil_prob: probability from the end of a word to the sil symbol
        sil_symbol: 'SIL' for phone-based ASR; '<space>' for grapheme-based ASR
    """
    sil_cost = -1.0 * math.log(sil_prob)
    no_sil_cost = -1.0 * math.log(1.0 - sil_prob)
    sil_disambig_id = self.disambig_graphemes['#' + str(self.max_disambig)]

    start_state = self.lexicon_fst.add_state()
    loop_state = self.lexicon_fst.add_state()
    sil_state = self.lexicon_fst.add_state()
    disambig_state = self.lexicon_fst.add_state()
    self.lexicon_fst.set_start(start_state)
    self.lexicon_fst.add_arc(start_state,
                             fst.Arc(self.disambig_graphemes['<eps>'],
                                     self.words['<eps>'],
                                     no_sil_cost, loop_state))
    self.lexicon_fst.add_arc(start_state,
                             fst.Arc(self.disambig_graphemes[sil_symbol],
                                     self.words['<eps>'],
                                     sil_cost, disambig_state))
    self.lexicon_fst.add_arc(sil_state,
                             fst.Arc(self.disambig_graphemes[sil_symbol],
                                     self.words['<eps>'],
                                     0.0, disambig_state))
    self.lexicon_fst.add_arc(disambig_state,
                             fst.Arc(sil_disambig_id,
                                     self.words['<eps>'],
                                     0.0, loop_state))

    for word, grapheme_seq in self.lexicons:
        word_id = self.words[word]
        grapheme_id_seq = [self.disambig_graphemes[grapheme]
                           for grapheme in grapheme_seq]
        eps_id = self.words['<eps>']
        src = loop_state
        # emit the word id on the first grapheme arc, <eps> afterwards
        for pos, grapheme_id in enumerate(grapheme_id_seq[:-1]):
            des = self.lexicon_fst.add_state()
            if pos == 0:
                self.lexicon_fst.add_arc(src, fst.Arc(grapheme_id, word_id,
                                                      0.0, des))
            else:
                self.lexicon_fst.add_arc(src, fst.Arc(grapheme_id, eps_id,
                                                      0.0, des))
            src = des
        # the last grapheme either returns directly to the loop state or
        # goes via silence
        last_grapheme_id = grapheme_id_seq[-1]
        self.lexicon_fst.add_arc(src, fst.Arc(last_grapheme_id, eps_id,
                                              no_sil_cost, loop_state))
        self.lexicon_fst.add_arc(src, fst.Arc(last_grapheme_id, eps_id,
                                              sil_cost, sil_state))

    self.lexicon_fst.set_final(loop_state, 0.0)
    self.lexicon_fst.add_arc(loop_state, fst.Arc(self.disambig_graphemes['#0'],
                                                 self.words['#0'],
                                                 0.0, loop_state))
    self.lexicon_fst.arcsort(sort_type='olabel')

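# For reference, a hedged sketch of the input make_lexicon_fst iterates over
# (entries hypothetical): self.lexicons pairs each word with its grapheme
# sequence, to which a disambiguation symbol may have been appended when one
# spelling is a prefix of, or identical to, another.
lexicons = [
    ('abc', ['a', 'b', 'c', '#1']),
    ('ab',  ['a', 'b', '#2']),
]
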
def process_bigram(self, gram):
    """Process bigram in arpa file"""
    parts = re.split(r'\s+', gram)
    boff = '0.0'
    if len(parts) == 4:
        prob, hist, word, boff = parts
    elif len(parts) == 3:
        prob, hist, word = parts
    else:
        raise NotImplementedError
    if hist not in self.words_table or word not in self.words_table:
        return
    weight = convert_weight(prob)
    boff = convert_weight(boff)
    if hist not in self.unigram2state:
        logging.info(
            '[{} {} {}] skipped: no parent (n-1)-gram exists'.format(
                prob, hist, word))
        return

    if word == '</s>':
        src = self.unigram2state[hist]
        self.grammar_fst.set_final(src, weight)
    else:
        src = self.unigram2state[hist]
        bigram = hist + '/' + word
        if bigram in self.bigram2state:
            des = self.bigram2state[bigram]
        else:
            des = self.grammar_fst.add_state()
            self.bigram2state[bigram] = des
            if word in self.unigram2state:
                boff_state = self.unigram2state[word]
            else:
                boff_state = self.unigram2state['<start>']
            self.grammar_fst.add_arc(
                des, fst.Arc(self.sid('#0'), self.sid('<eps>'), boff,
                             boff_state))
        self.grammar_fst.add_arc(
            src, fst.Arc(self.sid(word), self.sid(word), weight, des))

def make_query(self, cns, lemma):
    cns = cns.split(";")
    lemma = list(lemma)
    q = cns + ["<sigma>"] + lemma + ["</s>"]

    f = fst.Fst()
    f.set_input_symbols(self.input_syms)
    f.set_output_symbols(self.input_syms)
    s0 = f.add_state()
    f.set_start(s0)
    old = s0
    for j in range(len(q)):
        new = f.add_state()
        f.add_arc(old, fst.Arc(self.code[q[j]], self.code[q[j]],
                               fst.Weight(f.weight_type(), 1.0), new))
        old = new
    f.set_final(old)
    return f

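# make_query produces a linear acceptor that is typically applied to the
# lattice from build_lm by composition. A minimal self-contained sketch of
# that pattern (the two FSTs here are tiny stand-ins, not the real query
# and lattice):
a = fst.Fst()
s, t = a.add_state(), a.add_state()
a.set_start(s)
a.set_final(t)
a.add_arc(s, fst.Arc(1, 1, fst.Weight.One(a.weight_type()), t))
b = fst.Fst()
u, v = b.add_state(), b.add_state()
b.set_start(u)
b.set_final(v)
b.add_arc(u, fst.Arc(1, 2, fst.Weight.One(b.weight_type()), v))
a.arcsort(sort_type='olabel')  # composition needs one side arc-sorted
print(fst.compose(a, b))       # maps input 1 to output 2
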
def Automata_Building(ref_string, levenshtein_distance, output_weight):
    dict_automata = Levenshtein_Automata_Dico(ref_string, levenshtein_distance)
    # print(dict_automata)
    label_initial_state = "0;0"
    label_final_state = str(len(ref_string)) + ";" + str(levenshtein_distance)

    # Once the automaton is represented as a dictionary, build it with the
    # openfst library functions.
    # Create all the states of the automaton (source and destination states
    # alike). add_automate_state fills the automate_states dictionary, whose
    # keys are state labels and whose values are the states created with the
    # openfst state-creation function.
    state_index = 1
    for src_label, set_arcs in dict_automata.items():
        state_index = add_automate_state(src_label, state_index)
        for arc_label, dst_states in set_arcs.items():
            for dst_label in dst_states:
                state_index = add_automate_state(dst_label, state_index)
    # print(automate_states)

    # Create the arcs of the automaton
    for src_label, set_arcs in dict_automata.items():
        for arc_label, dst_states in set_arcs.items():
            label_info = arc_label.split("::")
            transmitted_char = int(convertSymToLabel(label_info[0]))
            consumed_char = int(convertSymToLabel(label_info[1]))
            weight = int(label_info[2])
            src_state_index = automate_states[src_label][1]
            print(transmitted_char, consumed_char, weight)
            for dst_label in dst_states:
                # print(dst_label)
                dst_state_index = automate_states[dst_label][1]
                automate.add_arc(
                    src_state_index,
                    fst.Arc(transmitted_char, consumed_char,
                            fst.Weight(automate.weight_type(), weight),
                            dst_state_index))

    automate.set_start(automate_states[label_initial_state][1])
    automate.set_final(automate_states[label_final_state][1],
                       fst.Weight(automate.weight_type(), output_weight))
    automate.draw("automata.dot")
    print(automate)
    return automate

def SimpleAutomata():
    src_state_label = "0;0"
    src_state_index = automate.add_state()
    dst_state_label = "0;1"
    dst_state_index = automate.add_state()
    arc_label = "2:4:1"
    label_string = arc_label.split(":")
    consumed_char = 2     # int(label_string[0])
    transmitted_char = 4  # int(label_string[1])
    weight = 1            # int(label_string[2])
    automate.add_arc(
        src_state_index,
        fst.Arc(transmitted_char, consumed_char,
                fst.Weight(automate.weight_type(), weight), dst_state_index))
    print(automate)

def Automata_Building(ref_string, levenshtein_distance, output_weight):
    levenshtein_automata = {}
    levenshtein_automata = Levenshtein_Automata_Dico(ref_string,
                                                     levenshtein_distance)
    # print(levenshtein_automata)
    label_initial_state = "0;0"
    label_final_state = str(len(ref_string)) + ";" + str(levenshtein_distance)

    # Once the automaton is represented as a dictionary, build it with the
    # openfst library functions.
    # Create all the states of the automaton (source and destination states
    # alike). add_automate_state fills the automate_states dictionary, whose
    # keys are state labels and whose values are the states created with the
    # openfst state-creation function.
    for src_label, set_arcs in levenshtein_automata.items():
        add_automate_state(src_label)
        for arc_label, set_dsts in set_arcs.items():
            for dst_label in set_dsts:
                add_automate_state(dst_label)
    print(automate)

    # Create the arcs of the automaton
    for src_label, set_arcs in levenshtein_automata.items():
        for arc_label, set_dsts in set_arcs.items():
            transmitted_char = arc_label.split(":")[0]
            consumed_char = arc_label.split(":")[1]
            weight = arc_label.split(":")[2]
            print(transmitted_char, consumed_char, weight)
            for dst_label in set_dsts:
                automate.add_arc(
                    automate_states[src_label],
                    fst.Arc(int(convertSymToLabel(transmitted_char)),
                            int(convertSymToLabel(consumed_char)),
                            fst.Weight(automate.weight_type(), int(weight)),
                            automate_states[dst_label]))

    automate.set_start(automate_states[label_initial_state])
    automate.set_final(automate_states[label_final_state],
                       fst.Weight(automate.weight_type(), output_weight))
    automate.draw("automata.dot")
    print(automate)
    return automate

def add_word(word):
    # note: here `fst` is a module-level mutable Fst instance (it shadows the
    # usual module name) and `openfst` is the binding itself
    i_words = tokenizer.token2idx(word) + [tokenizer.space_idx]
    if not fst.num_states():
        initial_state = fst.add_state()
        assert initial_state == 0
        fst.set_start(initial_state)

    source_state = fst.start()
    dest_state = None
    for i in i_words:
        # The initial state of the FST is state 0, hence the index of chars in
        # the FST should start from 1 to avoid conflict with the initial
        # state; otherwise wrong decoding results would be given.
        i += 1
        dest_state = fst.add_state()
        fst.add_arc(source_state, openfst.Arc(i, i, 0, dest_state))
        source_state = dest_state
    fst.set_final(dest_state, openfst.Weight.One('tropical'))

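# A hypothetical driver for add_word. Everything below is a stand-in setup:
# the binding import, the toy tokenizer, and the word list are assumptions,
# chosen only to make the snippet self-contained.
import openfst_python as openfst  # assumption: the binding used elsewhere

class _ToyTokenizer:
    # hypothetical stand-in: maps 'a'..'z' to 0..25, with 26 as the space id
    space_idx = 26
    def token2idx(self, word):
        return [ord(c) - ord('a') for c in word]

tokenizer = _ToyTokenizer()
fst = openfst.Fst()  # shadows the module name, as in the original snippet
add_word('hello')
add_word('world')
print(fst.num_states())
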
def add_arc_to_automate(src_state_label, dst_state_label, arc_label,
                        automate, states_dict):
    src_state_index = get_index(src_state_label, automate, states_dict)
    dst_state_index = get_index(dst_state_label, automate, states_dict)
    label_string = arc_label.split(":")
    # print(label_string[0], label_string[1], label_string[2])
    consumed_char = convertSymToLabel(label_string[0])
    # print(consumed_char)
    transmitted_char = convertSymToLabel(label_string[1])
    # print(transmitted_char)
    weight = int(label_string[2])
    # print(weight)
    automate.add_arc(
        src_state_index,
        fst.Arc(transmitted_char, consumed_char,
                fst.Weight(automate.weight_type(), weight), dst_state_index))

def generate_phone_recognition_wfst(n, state_table, phone_table):
    """ generate an HMM to recognise any single phone in the lexicon

    Args:
        n (int): states per phone HMM

    Returns:
        the constructed WFST
    """
    f = fst.Fst()

    # create a single start state
    start_state = f.add_state()
    f.set_start(start_state)

    # get a list of all the phones in the lexicon
    # there are lots of ways to do this. Here, we use a set() object,
    # which will contain all unique phones in the lexicon
    phone_set = set()
    for pronunciation in lex.values():
        phone_set = phone_set.union(pronunciation)

    for phone in phone_set:
        # we need to add an empty arc from the start state to where the actual
        # phone HMM will begin. If you can't see why this is needed, try without it!
        current_state = f.add_state()
        f.add_arc(start_state, fst.Arc(0, 0, None, current_state))

        end_state = generate_phone_wfst(f, current_state, phone, n,
                                        state_table, phone_table)
        f.set_final(end_state)

    return f