def __weighted_union(self, left, right, left_prob, right_prob):
    '''
    Union the FSTs left, right with a weight.
    '''
    # Left-hand side part.
    left_w = -math.log(left_prob)
    lhs = fst.Fst()
    lhs.set_input_symbols(left.input_symbols())
    lhs.set_output_symbols(left.output_symbols())
    lhs.add_state()
    lhs.set_start(0)
    lhs.add_state()
    # Epsilon arc carrying the branch cost -log(left_prob).
    lhs.add_arc(0, fst.Arc(0, 0, left_w, 1))
    lhs.set_final(1)
    lhs.concat(left)
    # Right-hand side part.
    right_w = -math.log(right_prob)
    rhs = fst.Fst()
    rhs.set_input_symbols(right.input_symbols())
    rhs.set_output_symbols(right.output_symbols())
    rhs.add_state()
    rhs.set_start(0)
    rhs.add_state()
    rhs.add_arc(0, fst.Arc(0, 0, right_w, 1))
    rhs.set_final(1)
    rhs.concat(right)
    lhs.union(rhs)
    return lhs
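# Hedged usage sketch (the names below are illustrative, not from the
# original source). It shows the same pattern standalone: prepend an epsilon
# arc weighted -log(p) to each branch, then union; with p = 0.9 vs. 0.1,
# shortestpath over the tropical semiring follows the 0.9 branch.
import math
import pywrapfst as fst

def _one_arc_acceptor(label):
    a = fst.Fst()
    a.add_state()
    a.add_state()
    a.set_start(0)
    a.set_final(1)
    a.add_arc(0, fst.Arc(label, label, fst.Weight.One(a.weight_type()), 1))
    return a

def weighted_union_sketch(left, right, p_left, p_right):
    def with_prefix(m, p):
        pre = fst.Fst()
        pre.add_state()
        pre.add_state()
        pre.set_start(0)
        pre.set_final(1)
        pre.add_arc(0, fst.Arc(0, 0, -math.log(p), 1))  # branch cost
        pre.concat(m)
        return pre
    out = with_prefix(left, p_left)
    out.union(with_prefix(right, p_right))
    return out

u = weighted_union_sketch(_one_arc_acceptor(1), _one_arc_acceptor(2), 0.9, 0.1)
best = fst.shortestpath(u)  # single path through the -log(0.9) branch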
def spellout_machine(wrdfname, ltr2wrdfst):
    lm = fst.Fst.read(ltr2wrdfst)
    s_in = lm.output_symbols()
    s_out = lm.input_symbols()
    letter = fst.Fst()
    letter.set_input_symbols(s_in)
    letter.set_output_symbols(s_out)
    letter.add_state()
    with open(wrdfname, "r") as wrdfile:
        for word in wrdfile:
            word = word.strip()
            orig = word  # keep the full word; str is immutable, no copy needed
            # word = list(word)
            word += "#"
            # word = dig2word(word)
            nletter = fst.Fst()
            nletter.set_input_symbols(s_in)
            nletter.set_output_symbols(s_out)
            nletter.add_state()
            for i, ltr in enumerate(word):
                nletter.add_state()
                code2 = s_out.find(ltr)
                if i == 0:
                    # First arc consumes the whole word and emits its first letter.
                    nletter.set_start(0)
                    code1 = s_in.find(orig)
                else:
                    # Remaining arcs emit the rest of the letters on epsilon input.
                    code1 = s_in.find("<epsilon>")
                nletter.add_arc(i, fst.Arc(code1, code2, None, i + 1))  # None -> semiring One
            nletter.set_final(i + 1)
            letter.union(nletter)
    letter.rmepsilon()
    letter.write("spellout.fst")
def make_lexicon_fst(input_words, words):
    lexicon_fst = fst.Fst()
    start = lexicon_fst.add_state()
    lexicon_fst.set_start(start)
    last = lexicon_fst.add_state()
    # This arc maps the space character (32) to epsilon.
    lexicon_fst.add_arc(
        start,
        fst.Arc(32, 0, fst.Weight.One(lexicon_fst.weight_type()), last))
    lexicon_fst.set_final(last)
    for i, w in enumerate(words):
        w = w.strip()
        index = i + 1
        last = lexicon_fst.add_state()
        # The first character of a word emits its word index ...
        lexicon_fst.add_arc(
            start,
            fst.Arc(ord(w[0]), index,
                    fst.Weight.One(lexicon_fst.weight_type()), last))
        # ... and the remaining characters emit epsilon.
        for c in w[1:]:
            this = lexicon_fst.add_state()
            lexicon_fst.add_arc(
                last,
                fst.Arc(ord(c), 0,
                        fst.Weight.One(lexicon_fst.weight_type()), this))
            last = this
        lexicon_fst.set_final(last, 0)
    lexicon_fst = fst.determinize(lexicon_fst).minimize().closure()
    with open('words.syms', 'w') as f:
        f.write('<eps> 0\n')
        for i, w in enumerate(words + input_words):
            # Word symbols start at index 1.
            f.write('{} {}\n'.format(w, i + 1))
        f.write('<SPACE> 32\n')
    epsilon_fst = fst.Fst()
    start = epsilon_fst.add_state()
    end = epsilon_fst.add_state()
    for i, w in enumerate(words):
        index = i + 1
        epsilon_fst.add_arc(
            start,
            fst.Arc(0, index, fst.Weight.One(epsilon_fst.weight_type()), end))
    epsilon_fst.add_arc(
        start,
        fst.Arc(0, 32, fst.Weight.One(epsilon_fst.weight_type()), end))
    epsilon_fst.set_final(end, 0)
    epsilon_fst.set_start(start)
    epsilon_fst = epsilon_fst.closure()
    return lexicon_fst, epsilon_fst
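# Hedged usage sketch (vocabulary and test string assumed): the lexicon maps
# byte labels to word indices, so composing a byte acceptor for "hi there"
# with the closed lexicon yields a word-level lattice. Note the function also
# writes 'words.syms' to the working directory as a side effect.
lexicon_fst, epsilon_fst = make_lexicon_fst([], ["hi", "there"])
lexicon_fst.arcsort(sort_type="ilabel")  # required sort for composition

chars = fst.Fst()
state = chars.add_state()
chars.set_start(state)
for ch in "hi there":
    nxt = chars.add_state()
    chars.add_arc(state, fst.Arc(ord(ch), ord(ch),
                                 fst.Weight.One(chars.weight_type()), nxt))
    state = nxt
chars.set_final(state)
word_lattice = fst.compose(chars, lexicon_fst)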
def clear(self):
    """
    Clears all internal data.
    """
    self.syms = fst.SymbolTable()
    self.E = fst.Fst()
    self.Ig = fst.Fst()
    self.Ip = fst.Fst()
    self.Ip_r = re.compile(u"")
    self.status = 0
def decompound(self, word):
    tree = Tree(word)
    self._split(word, tree)
    #print(tree)
    nleafnodes = tree.nleafnodes()
    #print("Number of leaf nodes:", nleafnodes)
    symtablel = sorted(tree.getsyms())
    symtable = {s: i for i, s in enumerate(symtablel)}
    #print("Symbols:")
    #print(symtable)
    fst = wfst.Fst()
    for _ in range(nleafnodes + 1):
        fst.add_state()
    fst.set_final(nleafnodes, wfst.Weight.One(fst.weight_type()))
    fst.set_start(0)
    tree.makelattice(fst, 0, symtable, self.wordcost, firstword=True)
    # output fst for debugging
    # fstsymtable = wfst.SymbolTable(b"default")
    # for i, sym in enumerate(symtablel):
    #     fstsymtable.add_symbol(sym.encode("utf-8"), i)
    # fst.set_input_symbols(fstsymtable)
    # fst.set_output_symbols(fstsymtable)
    # fst.write("/tmp/debug.fst")
    best = wfst.shortestpath(fst, nshortest=1)
    wordseq = label_seq(best, symtablel)
    return wordseq
def normalize(self, anfst):
    '''
    Produce a normalized fst.
    '''
    # Possibly there's a shorter way that keeps everything in fst land.
    dist = []
    labels = []
    syms = anfst.input_symbols()
    state = anfst.start()
    for arc in anfst.arcs(state):
        label = syms.find(arc.ilabel)
        pr = float(arc.weight)
        dist.append(BitWeight(pr))  # BitWeight stores -log(pr)
        labels.append(label)
    # BitWeight(1e6) acts as the zero element; the sum happens in the
    # log domain (log-add).
    sum_value = sum(dist, BitWeight(1e6))
    norm_dist = [(prob / sum_value).loge() for prob in dist]
    del anfst
    # Construct the normalized fst.
    output = fst.Fst()
    output.set_input_symbols(syms)
    output.set_output_symbols(syms)
    output.add_state()
    output.add_state()
    for pr, label in zip(norm_dist, labels):
        code = syms.find(label)
        output.add_arc(0, fst.Arc(code, code, pr, 1))
    output.set_start(0)
    output.set_final(1)
    return output
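# For reference, the same normalization with plain floats (a hedged sketch;
# BitWeight is an external log-domain weight class not shown in this
# excerpt): given branch probabilities p_i, the normalized arc costs are
# -log(p_i / sum_j p_j).
import math

def normalize_neg_log(probs):
    z = sum(probs)
    return [-math.log(p / z) for p in probs]

assert abs(normalize_neg_log([0.5, 0.5])[0] - math.log(2.0)) < 1e-12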
def enterRuleBody(self, ctx):
    super().enterRuleBody(ctx)

    # Create new FST for rule
    self.fst = fst.Fst()
    self.start_state = self.fst.add_state()
    self.fst.set_start(self.start_state)
    self.last_states[self.rule_name] = self.start_state
    self.weight_one = fst.Weight.One(self.fst.weight_type())

    if self.is_public:
        # Check if this is the main rule of the grammar
        grammar_rule = self.grammar_name + "." + self.grammar_name
        if self.rule_name == grammar_rule:
            self.grammar_fst = self.fst

    # Cache FST
    self.fsts[self.rule_name] = self.fst

    # Reset state
    self.group_depth = 0
    self.opt_states = {}
    self.alt_states = {}
    self.tag_states = {}
    self.exp_states = {}
    self.alt_ends = {}

    # Save anchor state
    self.alt_states[self.group_depth] = self.last_states[self.rule_name]
def build_chain_fst(labels, arc_type='log', vocab=None):
    """
    Build an acceptor for the string given by the elements of labels.

    Args:
        labels - a sequence of labels in the range 1..S
        arc_type - fst arc type (standard or log)
    Returns:
        FST consuming symbols in the range 1..S.
    Notes:
        Elements of labels are assumed to be greater than zero
        (zero maps to blank)!
    """
    C = fst.Fst(arc_type=arc_type)
    weight_one = fst.Weight.One(C.weight_type())
    s = C.add_state()
    C.set_start(s)
    for l in labels:
        s_next = C.add_state()
        C.add_arc(s, fst.Arc(l, l, weight_one, s_next))
        s = s_next
    C.set_final(s)
    C.arcsort('ilabel')
    return C
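# Illustrative check (assumed here, not part of the source): the chain FST
# for [1, 2, 3] is a straight line with one arc per label, so it has four
# states and accepts exactly that sequence.
chain = build_chain_fst([1, 2, 3], arc_type='standard')
assert chain.num_states() == 4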
def _build_transliterator():
    td = fst.Fst()
    initial_state = td.add_state()
    td.set_start(initial_state)
    td.set_final(initial_state)

    long_vowel_possibilities = {
        'u': 'ウ',
        'i': 'イ',
        'e': 'エ',
        'o': ['オ', 'ウ'],
        'a': 'ア'
    }
    long_vowel_states = {
        k: _long_vowel_mark_state(td, initial_state, v)
        for k, v in long_vowel_possibilities.items()
    }
    end_states = {
        'n': initial_state,
        'y': _build_small_y_state(td, long_vowel_states),
        **long_vowel_states
    }

    _build_sjsh(td, initial_state, end_states)
    _build_vowels(td, initial_state, end_states)
    _build_tdch(td, initial_state, end_states)
    _build_big_y(td, initial_state, end_states)
    _build_hpb(td, initial_state, end_states)
    _build_kg(td, initial_state, end_states)
    _build_r(td, initial_state, end_states)
    _build_m(td, initial_state, end_states)
    _build_n(td, initial_state, end_states)
    _build_w(td, initial_state, end_states)
    return td
def __call__(self, x):
    x, xs = transform_output(x)
    # Normalize log-posterior matrices, if necessary
    if self._normalize:
        x = log_softmax(x, dim=2)
    x = x.permute(1, 0, 2).cpu()
    self._output = []
    D = x.size(2)
    for logpost, length in zip(x, xs):
        f = fst.Fst()
        f.set_start(f.add_state())
        for t in range(length):
            f.add_state()
            for j in range(D):
                weight = fst.Weight(f.weight_type(), float(-logpost[t, j]))
                f.add_arc(
                    t,
                    fst.Arc(
                        j + 1,  # input label
                        j + 1,  # output label
                        weight,  # -logpost[t, j]
                        t + 1,  # nextstate
                    ),
                )
        f.set_final(length, fst.Weight.One(f.weight_type()))
        f.verify()
        self._output.append(f)
    return self._output
def enterRuleBody(self, ctx):
    self.in_rule = True

    if self.is_public:
        # Use main start state
        self.last_states[self.rule_name] = self.start_state
    else:
        # Create new FST
        self.fst = fst.Fst()
        self.start_state = self.fst.add_state()
        self.fst.set_start(self.start_state)
        self.last_states[self.rule_name] = self.start_state

    self.fsts[self.rule_name] = self.fst

    # Reset
    self.group_depth = 0
    self.opt_states = {}
    self.alt_states = {}
    self.tag_states = {}
    self.exp_states = {}
    self.alt_ends = {}

    # Save anchor state
    self.alt_states[self.group_depth] = self.last_states[self.rule_name]
def genStrFst(ilabels, olabels=None):
    if olabels is None:
        olabels = ilabels
    fst = pywrapfst.Fst()
    initFst(fst)
    addArcLinear(fst, 0, ilabels, olabels, is_loop=False)
    return fst
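# initFst and addArcLinear are not shown in this excerpt. Below are hedged
# sketches of what they plausibly do, inferred only from the call site above
# (state 0 as start, a linear chain of (ilabel, olabel) arcs, no loop):
def initFst_sketch(f):
    # Create a single start state (state 0).
    f.set_start(f.add_state())

def addArcLinear_sketch(f, state, ilabels, olabels, is_loop=False):
    one = pywrapfst.Weight.One(f.weight_type())
    for il, ol in zip(ilabels, olabels):
        nxt = f.add_state()
        f.add_arc(state, pywrapfst.Arc(il, ol, one, nxt))
        state = nxt
    if is_loop:
        # Hypothetical: close the chain back to its first state.
        f.add_arc(state, pywrapfst.Arc(0, 0, one, 0))
    else:
        f.set_final(state)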
def _make_slot_fst(state: int, intent_fst: fst.Fst, slot_to_fst: Dict[str, fst.Fst]):
    out_symbols = intent_fst.output_symbols()
    one_weight = fst.Weight.One(intent_fst.weight_type())

    for arc in intent_fst.arcs(state):
        label = out_symbols.find(arc.olabel).decode()
        if label.startswith("__begin__"):
            slot_name = label[9:]  # strip the 9-character "__begin__" prefix

            # Big assumption here that each instance of a slot (e.g., location)
            # will produce the same FST, and therefore doesn't need to be
            # processed again.
            if slot_name in slot_to_fst:
                continue  # skip duplicate slots

            end_label = f"__end__{slot_name}"

            # Create new FST
            slot_fst = fst.Fst()
            slot_fst.set_input_symbols(intent_fst.input_symbols())
            slot_fst.set_output_symbols(intent_fst.output_symbols())

            start_state = slot_fst.add_state()
            slot_fst.set_start(start_state)

            q = [arc.nextstate]
            state_map = {arc.nextstate: start_state}

            # Copy states/arcs from intent FST until __end__ is found
            while len(q) > 0:
                q_state = q.pop()
                for q_arc in intent_fst.arcs(q_state):
                    slot_arc_label = out_symbols.find(q_arc.olabel).decode()
                    if slot_arc_label != end_label:
                        if q_arc.nextstate not in state_map:
                            state_map[q_arc.nextstate] = slot_fst.add_state()

                        # Create arc
                        slot_fst.add_arc(
                            state_map[q_state],
                            fst.Arc(
                                q_arc.ilabel,
                                q_arc.olabel,
                                one_weight,
                                state_map[q_arc.nextstate],
                            ),
                        )

                        # Continue copy
                        q.append(q_arc.nextstate)
                    else:
                        # Mark previous state as final
                        slot_fst.set_final(state_map[q_state])

            slot_to_fst[slot_name] = minimize_fst(slot_fst)

        # Recurse
        _make_slot_fst(arc.nextstate, intent_fst, slot_to_fst)
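# minimize_fst is not defined in this excerpt; a hedged stand-in consistent
# with how it is used (clean up the extracted slot acceptor):
def minimize_fst_sketch(f: fst.Fst) -> fst.Fst:
    f.rmepsilon()
    out = fst.determinize(f)
    out.minimize()
    return out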
def acceptor_for_strings(strings: List[str], weights: List[float]) -> fst.Fst:
    """Create an acceptor for strings with weights"""
    # Sort so that strings sharing a prefix are adjacent.
    strings, weights = zip(*sorted(zip(strings, weights)))
    td = fst.Fst()
    start_state = td.add_state()
    td.set_start(start_state)
    _build_acceptor_recursive(td, strings, weights, start_state, 0, 0,
                              len(strings))
    return td
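# _build_acceptor_recursive is not shown in this excerpt. A hedged sketch of
# one plausible implementation (byte labels via ord() are an assumption):
# because `strings` are sorted, every distinct character at position `pos`
# covers a contiguous slice [lo, hi), so we add one arc per distinct
# character and recurse on its slice; strings ending at `pos` make the
# current state final with their weight.
def _build_acceptor_recursive_sketch(td, strings, weights, state, pos, lo, hi):
    i = lo
    while i < hi:
        if len(strings[i]) == pos:
            td.set_final(state, weights[i])
            i += 1
            continue
        ch = strings[i][pos]
        j = i
        while j < hi and len(strings[j]) > pos and strings[j][pos] == ch:
            j += 1
        nxt = td.add_state()
        td.add_arc(state, fst.Arc(ord(ch), ord(ch),
                                  fst.Weight.One(td.weight_type()), nxt))
        _build_acceptor_recursive_sketch(td, strings, weights, nxt, pos + 1, i, j)
        i = j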
def create_empty_fst(self, input_sym, output_sym):
    '''
    Create an empty fst (a single state that is both start and final).
    '''
    f = fst.Fst()
    f.set_input_symbols(input_sym)
    f.set_output_symbols(output_sym)
    f.add_state()
    f.set_start(0)
    f.set_final(0)
    return f
def make_input(chars, stoi):
    fst = wfst.Fst()
    s0 = fst.add_state()
    fst.set_start(s0)
    cs = s0
    for c in chars:
        ns = fst.add_state()
        fst.add_arc(cs, wfst.Arc(stoi[c], stoi[c],
                                 wfst.Weight.One(fst.weight_type()), ns))
        cs = ns
    fst.set_final(cs, wfst.Weight.One(fst.weight_type()))
    return fst
def make_kleeneplus(s, graphs, stoi):
    """one-or-more-graphs"""
    fst = wfst.Fst()
    start = fst.add_state()
    end = fst.add_state()
    fst.set_start(start)
    fst.set_final(end, wfst.Weight.One(fst.weight_type()))
    for g in graphs:
        fst.add_arc(start, wfst.Arc(stoi[g], stoi[g],
                                    wfst.Weight.One(fst.weight_type()), end))
        fst.add_arc(end, wfst.Arc(stoi[g], stoi[g],
                                  wfst.Weight.One(fst.weight_type()), end))
    return fst
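# Hedged usage sketch combining make_input and make_kleeneplus (the symbol
# map below is assumed; `s` is unused by make_kleeneplus and passed as None):
# accept one or more of "a"/"b" and test whether "abba" matches.
stoi = {"a": 1, "b": 2}
inp = make_input("abba", stoi)
plus = make_kleeneplus(None, ["a", "b"], stoi)
inp.arcsort(sort_type="olabel")  # sort for composition
hit = wfst.compose(inp, plus)
accepted = wfst.shortestpath(hit).num_states() > 0
print(accepted)  # True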
def __init__(self):
    '''
    Constructor
    '''
    self.fst = pywrapfst.Fst()

    self.STATE_START = self.fst.add_state()
    self.fst.set_start(self.STATE_START)

    self.fst.set_input_symbols(pywrapfst.SymbolTable())
    self.fst.set_output_symbols(pywrapfst.SymbolTable())
    self.fst.mutable_input_symbols().add_symbol(LAB_EPS, key=SYM_EPS)
    self.fst.mutable_output_symbols().add_symbol(LAB_EPS, key=SYM_EPS)
def make_compounder(syms, word_ids):
    c = fst.Fst()
    start_state = c.add_state()
    assert start_state == 0
    c.set_start(start_state)
    space_id = syms["<space>"]
    # A space either disappears or is rewritten as a compounding marker.
    c.add_arc(0, fst.Arc(space_id, syms["<eps>"], 1, 0))
    c.add_arc(0, fst.Arc(space_id, syms["+C+"], 1, 0))
    c.add_arc(0, fst.Arc(space_id, syms["+D+"], 1, 0))
    # Words pass through unchanged.
    for word_id in word_ids:
        c.add_arc(0, fst.Arc(word_id, word_id, 1, 0))
    c.set_final(0, 1)
    return c
def append_eeg_evidence(self, ch_dist):
    new_ch = fst.Fst()
    new_ch.set_input_symbols(self.ch_syms)
    new_ch.set_output_symbols(self.ch_syms)
    new_ch.add_state()
    new_ch.set_start(0)
    new_ch.add_state()
    new_ch.set_final(1)
    for ch, pr in ch_dist:
        code = self.ch_syms.find(ch)
        new_ch.add_arc(0, fst.Arc(code, code, pr, 1))
    new_ch.arcsort(sort_type="olabel")
    self.history_fst.concat(new_ch).rmepsilon()
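# Hedged usage sketch (assumes an instance of the surrounding class with
# `ch_syms` and `history_fst` initialized): per the related update() method
# below, the evidence costs are already negative log probabilities.
import math
evidence = [("a", -math.log(0.8)), ("b", -math.log(0.2))]
# decoder.append_eeg_evidence(evidence)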
def get_trivial_fst(word_index):
    trivial_word_fst = fst.Fst()
    start = trivial_word_fst.add_state()
    end = trivial_word_fst.add_state()
    trivial_word_fst.set_start(start)
    trivial_word_fst.set_final(end, 0)
    trivial_word_fst.add_arc(
        start,
        fst.Arc(word_index, 0,
                fst.Weight.One(trivial_word_fst.weight_type()), end))
    return trivial_word_fst
def update(self, ch_dist):
    '''
    Update the history with the new likelihood array, given in the
    correct scale (negative log space).
    '''
    new_ch = fst.Fst()
    new_ch.set_input_symbols(self.ch_syms)
    new_ch.set_output_symbols(self.ch_syms)
    new_ch.add_state()
    new_ch.set_start(0)
    new_ch.add_state()
    new_ch.set_final(1)
    space_code = -1
    space_pr = 0.
    for ch, pr in ch_dist:
        code = self.ch_syms.find(ch)
        if ch == '#':
            # Adds the space only after we finish updating trailing chars.
            space_code = code
            space_pr = pr
            continue
        new_ch.add_arc(0, fst.Arc(code, code, pr, 1))
    new_ch.arcsort(sort_type="olabel")

    # Adds the trailing characters to the existing binned history.
    for words_bin in self.prefix_words:
        if words_bin[2] >= 10:
            # We discard the whole trail in this case (TODO)
            continue
        # Unless we are testing a straight-line machine, this normally
        # doesn't happen in practice.
        if new_ch.num_arcs(0) == 0:
            continue
        words_bin[1].concat(new_ch).rmepsilon()
        words_bin[2] += 1

    # Continues updating the history and adds back the space if necessary.
    if space_code >= 0:
        new_ch.add_arc(0, fst.Arc(space_code, space_code, space_pr, 1))
    self.history_fst.concat(new_ch).rmepsilon()

    # Updates the binned history accordingly.
    if space_code >= 0:  # If there is a space
        # Finishes the prefix words at the current position.
        word_lattice = fst.compose(self.history_fst, self.ltr2wrd)
        word_lattice.project(project_output=True).rmepsilon()
        word_lattice = fst.determinize(word_lattice)
        word_lattice.minimize()
        if word_lattice.num_states() == 0:
            word_lattice = self.create_empty_fst(self.wd_syms, self.wd_syms)
        trailing_chars = self.create_empty_fst(self.ch_syms, self.ch_syms)
        self.prefix_words.append([word_lattice, trailing_chars, 0])
def make_termfst(s, paths, stoi):
    fst = wfst.Fst()
    start = fst.add_state()
    fst.set_start(start)
    for path in paths:
        a = start
        for g in path:
            b = fst.add_state()
            fst.add_arc(a, wfst.Arc(stoi[g], stoi[g],
                                    wfst.Weight.One(fst.weight_type()), b))
            a = b
        fst.set_final(b, wfst.Weight.One(fst.weight_type()))
    fst = wfst.determinize(fst)
    fst.minimize()
    return fst
def make_intent_fst(grammar_fsts: Dict[str, fst.Fst], eps: str = "<eps>") -> fst.Fst:
    """Merges grammar FSTs created with grammar_to_fsts into a single acceptor FST."""
    input_symbols = fst.SymbolTable()
    output_symbols = fst.SymbolTable()
    in_eps: int = input_symbols.add_symbol(eps)
    out_eps: int = output_symbols.add_symbol(eps)

    intent_fst = fst.Fst()
    weight_one = fst.Weight.One(intent_fst.weight_type())

    # Create start/final states
    start_state = intent_fst.add_state()
    intent_fst.set_start(start_state)
    final_state = intent_fst.add_state()
    intent_fst.set_final(final_state)

    replacements: Dict[int, fst.Fst] = {}

    for intent, grammar_fst in grammar_fsts.items():
        intent_label = f"__label__{intent}"
        out_label = output_symbols.add_symbol(intent_label)

        # --[__label__INTENT]-->
        intent_start = intent_fst.add_state()
        intent_fst.add_arc(
            start_state, fst.Arc(in_eps, out_label, weight_one, intent_start))

        # --[__replace__INTENT]-->
        intent_end = intent_fst.add_state()
        replace_symbol = f"__replace__{intent}"
        out_replace = output_symbols.add_symbol(replace_symbol)
        intent_fst.add_arc(
            intent_start, fst.Arc(in_eps, out_replace, weight_one, intent_end))

        # --[eps]-->
        intent_fst.add_arc(intent_end,
                           fst.Arc(in_eps, out_eps, weight_one, final_state))

        replacements[out_replace] = grammar_fst

    # Fix symbol tables
    intent_fst.set_input_symbols(input_symbols)
    intent_fst.set_output_symbols(output_symbols)

    # Do replacements
    return _replace_fsts(intent_fst, replacements, eps=eps)
def longest_path(the_fst: fst.Fst, eps: str = "<eps>") -> fst.Fst:
    output_symbols = the_fst.output_symbols()
    visited_states: Set[int] = set()
    best_path: List[int] = []
    state_queue: Deque[Tuple[int, List[int]]] = deque()
    state_queue.append((the_fst.start(), []))

    # Determine longest path
    while len(state_queue) > 0:
        state, path = state_queue.popleft()
        if state in visited_states:
            continue
        visited_states.add(state)

        if len(path) > len(best_path):
            best_path = path

        for arc in the_fst.arcs(state):
            next_path = list(path)
            next_path.append(arc.olabel)
            state_queue.append((arc.nextstate, next_path))

    # Create FST with longest path
    path_fst = fst.Fst()
    input_symbols = fst.SymbolTable()
    input_symbols.add_symbol(eps)
    path_fst.set_output_symbols(output_symbols)
    weight_one = fst.Weight.One(path_fst.weight_type())

    state = path_fst.add_state()
    path_fst.set_start(state)

    for olabel in best_path:
        osym = output_symbols.find(olabel).decode()
        next_state = path_fst.add_state()
        path_fst.add_arc(
            state,
            fst.Arc(input_symbols.add_symbol(osym), olabel, weight_one,
                    next_state),
        )
        state = next_state

    path_fst.set_final(state)
    path_fst.set_input_symbols(input_symbols)

    return path_fst
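# Illustrative check (names assumed, not from the source): with two branches
# out of the start state, longest_path keeps the branch with more arcs.
g = fst.Fst()
syms = fst.SymbolTable()
syms.add_symbol("<eps>")
turn, on, light = (syms.add_symbol(w) for w in ("turn", "on", "light"))
g.set_output_symbols(syms)
one = fst.Weight.One(g.weight_type())
s0, s1, s2, s3 = (g.add_state() for _ in range(4))
g.set_start(s0)
g.set_final(s2)
g.set_final(s3)
g.add_arc(s0, fst.Arc(turn, turn, one, s1))
g.add_arc(s1, fst.Arc(on, on, one, s2))       # two-arc branch
g.add_arc(s0, fst.Arc(light, light, one, s3))  # one-arc branch
longest = longest_path(g)  # keeps the "turn on" branch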
def _make_input_fst(string):
    # Input is passed as acceptors that accept a single string
    td = fst.Fst()
    curr = td.add_state()
    td.set_start(curr)
    for c in string:
        nxt = td.add_state()
        try:
            _char_arc(td, curr, c, c, nxt)
        except KeyError:
            raise ValueError(
                'Character {} not in input symbol table'.format(c))
        curr = nxt
    td.set_final(curr)
    return td
def build_ctc_mono_decoding_fst(S, arc_type='log', add_syms=False):
    """
    Build a monophone CTC decoding fst.

    Args:
        S - number of monophones
        arc_type - log or standard. Gives the interpretation of the FST.
    Returns:
        an FST that accepts all sequences over [1,..,S]^* and returns
        shorter ones with duplicates and blanks removed. The input labels
        are shifted by one, so that there are no epsilon transitions.
        The output labels are not (blank is zero), allowing one to read
        out the label sequence easily.
    """
    CTC = fst.Fst(arc_type=arc_type)
    weight_one = fst.Weight.One(CTC.weight_type())
    for s in range(S):
        s1 = CTC.add_state()
        assert s == s1
        CTC.set_final(s1)
    CTC.set_start(0)

    for s in range(S):
        # transitions out of symbol s

        # self-loop, don't emit
        CTC.add_arc(s, fst.Arc(s + 1, 0, weight_one, s))

        for s_next in range(S):
            if s_next == s:
                continue
            # transition to next symbol
            CTC.add_arc(s, fst.Arc(s_next + 1, s_next, weight_one, s_next))

    CTC.arcsort('olabel')

    if add_syms:
        in_syms = fst.SymbolTable()
        in_syms.add_symbol('<eps>', 0)
        in_syms.add_symbol('B', 1)
        for s in range(1, S):
            in_syms.add_symbol(chr(ord('a') + s - 1), s + 1)
        out_syms = fst.SymbolTable()
        out_syms.add_symbol('<eps>', 0)
        for s in range(1, S):
            out_syms.add_symbol(chr(ord('a') + s - 1), s)
        CTC.set_input_symbols(in_syms)
        CTC.set_output_symbols(out_syms)

    return CTC
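# Hedged end-to-end sketch tying this to build_chain_fst above (assumed
# usage, not from the source): composing the CTC decoder with the chain FST
# for [1, 2] gives a lattice of all frame-level label sequences (with blanks
# and repeats) that collapse to "1 2". Both FSTs are already arcsorted on
# the labels composition matches against.
S = 4  # blank + 3 monophones
ctc = build_ctc_mono_decoding_fst(S, arc_type='standard')
chain = build_chain_fst([1, 2], arc_type='standard')
lattice = fst.compose(ctc, chain)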
def make_intent_fst(grammar_fsts: Dict[str, fst.Fst], eps=0) -> fst.Fst:
    """Merges grammar FSTs created with jsgf2fst into a single acceptor FST."""
    intent_fst = fst.Fst()
    all_in_symbols = fst.SymbolTable()
    all_out_symbols = fst.SymbolTable()
    all_in_symbols.add_symbol("<eps>", eps)
    all_out_symbols.add_symbol("<eps>", eps)

    # Merge symbols from all FSTs
    for grammar_fst in grammar_fsts.values():
        in_symbols = grammar_fst.input_symbols()
        for i in range(in_symbols.num_symbols()):
            all_in_symbols.add_symbol(in_symbols.find(i).decode())

        out_symbols = grammar_fst.output_symbols()
        for i in range(out_symbols.num_symbols()):
            all_out_symbols.add_symbol(out_symbols.find(i).decode())

    # Add __label__ for each intent
    for intent_name in grammar_fsts.keys():
        all_out_symbols.add_symbol(f"__label__{intent_name}")

    intent_fst.set_input_symbols(all_in_symbols)
    intent_fst.set_output_symbols(all_out_symbols)

    # Create start/final states
    start_state = intent_fst.add_state()
    intent_fst.set_start(start_state)
    final_state = intent_fst.add_state()
    intent_fst.set_final(final_state)

    # Merge FSTs in
    for intent_name, grammar_fst in grammar_fsts.items():
        label_sym = all_out_symbols.find(f"__label__{intent_name}")
        replace_and_patch(intent_fst, start_state, final_state, grammar_fst,
                          label_sym, eps=eps)

    # BUG: Fst.minimize does not pass allow_nondet through, so we have to
    # call out to the command-line
    minimize_cmd = ["fstminimize", "--allow_nondet"]
    return fst.Fst.read_from_string(
        subprocess.check_output(minimize_cmd,
                                input=intent_fst.write_to_string()))
def toFst(self):
    """Convert the HMM graph to an OpenFst object.

    You need to have installed the OpenFst python extension to use this
    method.

    Returns
    -------
    graph : pywrapfst.Fst
        The FST representation of the HMM graph. A super initial state
        and a super final state are added even though they are not
        present in the HMM.

    """
    import pywrapfst as fst

    f = fst.Fst('log')
    start_state = f.add_state()
    f.set_start(start_state)
    end_state = f.add_state()
    f.set_final(end_state)

    state_fstid = {}
    for state in self.states:
        fstid = f.add_state()
        state_fstid[state.state_id] = fstid

    for state in self.states:
        for next_state_id, weight in state.next_states.items():
            fstid = state_fstid[state.state_id]
            next_fstid = state_fstid[next_state_id]
            # The log semiring stores costs, hence the negated weight.
            arc = fst.Arc(0, 0, fst.Weight('log', -weight), next_fstid)
            f.add_arc(fstid, arc)

    for state in self.init_states:
        fstid = state_fstid[state.state_id]
        arc = fst.Arc(0, 0, fst.Weight.One('log'), fstid)
        f.add_arc(start_state, arc)

    for state in self.final_states:
        fstid = state_fstid[state.state_id]
        arc = fst.Arc(0, 0, fst.Weight.One('log'), end_state)
        f.add_arc(fstid, arc)

    return f
def __align_fst(self, g, p):
    '''
    Creates an alignment of a grapheme and phoneme sequence pair
    encoded as an fst.
    '''
    t3 = self.segment(g)
    t3.project(project_output=True)
    t4 = self.expand(p)
    t4.project(project_output=True)
    if t4.start() == -1 or t4.num_arcs(t4.start()) == 0:
        return fst.Fst()
    t5 = fst.compose(t3, self.E)
    t6 = fst.compose(t5, t4)
    return t6