def spellout_machine(wrdfname, ltr2wrdfst):
    """Build ``spellout.fst``, which maps each word to its letters plus ``#``.

    For every word in ``wrdfname`` a linear machine is built.  Its first arc
    reads the whole word symbol and writes the first letter.  Every later
    arc reads ``<epsilon>`` and writes the next letter.  The last letter is
    the ``#`` terminator.  All word machines are unioned, epsilon-removed and
    written to ``spellout.fst`` in the current directory.

    Args:
        wrdfname: path to a text file with one word per line.
        ltr2wrdfst: path to the letter-to-word FST whose symbol tables are
            reused (swapped, since this machine runs in the opposite
            direction).
    """
    lm = fst.Fst.read(ltr2wrdfst)
    # Spell-out is the inverse direction of letter->word, so the tables swap.
    s_in = lm.output_symbols()
    s_out = lm.input_symbols()
    letter = fst.Fst()
    letter.set_input_symbols(s_in)
    letter.set_output_symbols(s_out)
    letter.add_state()

    # Read and close the word list before building anything.
    with open(wrdfname, "r") as word_file:
        words = [line.strip() for line in word_file]

    eps_code = s_in.find("<epsilon>")
    for word in words:
        if not word:
            # A blank line would look up "" and yield the invalid label -1.
            continue
        spelled = word + "#"
        nletter = fst.Fst()
        nletter.set_input_symbols(s_in)
        nletter.set_output_symbols(s_out)
        nletter.add_state()
        nletter.set_start(0)
        for i, ltr in enumerate(spelled):
            nletter.add_state()
            # Only the first arc consumes the word symbol; the rest are epsilon.
            code1 = s_in.find(word) if i == 0 else eps_code
            nletter.add_arc(i, fst.Arc(code1, s_out.find(ltr), None, i + 1))
        nletter.set_final(len(spelled))
        letter.union(nletter)
    letter.rmepsilon()
    letter.write("spellout.fst")
def make_state(i_input, i_output, weight):
    """Attach an (input, output) sub-state to the hub state ``state``.

    Closure helper: relies on ``fst``, ``state``, ``one``, ``pass_input``,
    ``input_vocab``, ``output_vocab``, ``input_parts_to_str``,
    ``output_parts_to_str``, ``dur_internal_str`` and ``dur_final_str`` from
    the enclosing scope.

    A new ``io_state`` is created and wired up as follows:
      * ``state -> io_state`` and a self-loop on ``io_state``, both labelled
        with the "internal" part strings and weighted ``one``;
      * ``io_state -> state`` labelled with the "final" part strings and
        weighted ``weight``.
    When ``pass_input`` is true, the output label string also includes the
    input vocabulary item.
    """
    io_state = fst.add_state()
    in_item = input_vocab[i_input]
    out_item = output_vocab[i_output]

    def label_ids(dur_str):
        # Resolve the (input, output) label ids for one duration tag.
        in_text = input_parts_to_str[in_item, dur_str]
        if pass_input:
            out_text = output_parts_to_str[in_item, out_item, dur_str]
        else:
            out_text = output_parts_to_str[out_item, dur_str]
        return (fst.input_symbols().find(in_text),
                fst.output_symbols().find(out_text))

    # CASE 1: (in, I) : (out, I), weight one, into io_state (+ self-loop)
    internal_in, internal_out = label_ids(dur_internal_str)
    enter_arc = openfst.Arc(internal_in, internal_out, one, io_state)
    fst.add_arc(state, enter_arc)
    fst.add_arc(io_state, enter_arc.copy())

    # CASE 2: (in, F) : (out, F), weight tx_weight, back to the hub
    final_in, final_out = label_ids(dur_final_str)
    fst.add_arc(io_state, openfst.Arc(final_in, final_out, weight, state))
def __weighted_union(self, left, right, left_prob, right_prob):
    '''Union the FSTs ``left`` and ``right``, weighting each branch.

    Each operand is prefixed with a single epsilon arc whose weight is
    ``-log(prob)``, and the two prefixed machines are unioned.

    Args:
        left, right: FSTs to combine; each keeps its own symbol tables.
        left_prob, right_prob: branch probabilities; must be > 0.

    Returns:
        A new FST (the weighted left branch, with the right branch unioned in).

    Raises:
        ValueError: if a probability is not strictly positive.
    '''
    def _weighted_prefix(machine, prob):
        # -log(prob) is undefined for prob <= 0; fail with a clear message
        # (math.log would raise ValueError anyway, just less helpfully).
        if prob <= 0:
            raise ValueError(
                "branch probability must be positive, got {!r}".format(prob))
        prefixed = fst.Fst()
        prefixed.set_input_symbols(machine.input_symbols())
        prefixed.set_output_symbols(machine.output_symbols())
        head = prefixed.add_state()
        prefixed.set_start(head)
        tail = prefixed.add_state()
        prefixed.add_arc(head, fst.Arc(0, 0, -math.log(prob), tail))
        prefixed.set_final(tail)
        prefixed.concat(machine)
        return prefixed

    lhs = _weighted_prefix(left, left_prob)
    rhs = _weighted_prefix(right, right_prob)
    lhs.union(rhs)
    return lhs
def add_endpoints(fst, bos_str='<BOS>', eos_str='<EOS>'):
    """Wrap ``fst`` in a leading BOS arc and a trailing EOS arc (in place).

    A new start state is added with a ``bos_str:bos_str`` arc (weight one)
    into the previous start state.  Every state with a non-zero final weight
    becomes non-final.  In its place it gets an ``eos_str:eos_str`` arc,
    carrying that final weight, into a new superfinal state.  The superfinal
    state is the only final state, with weight one.

    Returns:
        The same ``fst`` object, mutated.
    """
    weight_type = fst.weight_type()
    one = openfst.Weight.one(weight_type)
    zero = openfst.Weight.zero(weight_type)

    # Pre-initial state accepting BOS.
    bos_in = fst.input_symbols().find(bos_str)
    bos_out = fst.output_symbols().find(bos_str)
    previous_start = fst.start()
    bos_state = fst.add_state()
    fst.set_start(bos_state)
    fst.add_arc(bos_state, openfst.Arc(bos_in, bos_out, one, previous_start))

    # Superfinal state accepting EOS.
    eos_in = fst.input_symbols().find(eos_str)
    eos_out = fst.output_symbols().find(eos_str)
    superfinal = fst.add_state()
    for s in fst.states():
        exit_weight = fst.final(s)
        if exit_weight != zero:
            # Move the final weight onto the EOS arc.
            fst.set_final(s, zero)
            fst.add_arc(s, openfst.Arc(eos_in, eos_out, exit_weight, superfinal))
    fst.set_final(superfinal, one)
    return fst
def enterAlternative(self, ctx):
    """Grammar-listener hook: begin a new alternative in the current group.

    On the first alternative at ``self.group_depth``, the anchor state's
    outgoing arcs are moved behind a fresh state reached by epsilon.  A
    shared end state for the group is then created in ``self.alt_ends``.
    For every alternative, the previous branch (``last_states``) is joined
    to that end state by epsilon (unless it already is the end state).
    A new branch state is then hung off the anchor and recorded as the
    rule's last state.
    """
    depth = self.group_depth
    anchor = self.alt_states[depth]

    def eps_to(dest):
        # Epsilon:epsilon arc with weight one.
        return fst.Arc(self.in_eps, self.out_eps, self.weight_one, dest)

    if depth not in self.alt_ends:
        # Patch start of alternative: relocate existing anchor arcs.
        relocated = self.fst.add_state()
        for existing in self.fst.arcs(anchor):
            self.fst.add_arc(relocated, existing)
        self.fst.delete_arcs(anchor)
        self.fst.add_arc(anchor, eps_to(relocated))
        # Shared end state for all alternatives of this group.
        self.alt_ends[depth] = self.fst.add_state()

    # Close the previous alternative.
    previous_tail = self.last_states[self.rule_name]
    group_end = self.alt_ends[depth]
    if previous_tail != group_end:
        self.fst.add_arc(previous_tail, eps_to(group_end))

    # Open the new alternative.
    branch_head = self.fst.add_state()
    self.fst.add_arc(anchor, eps_to(branch_head))
    self.last_states[self.rule_name] = branch_head
def _compile_cg(ifar_path: str, ofar_path: str, insertions: bool, deletions: bool) -> str:
    """Compiles the covering grammar from the input and output FARs.

    The grammar is a single-state, final, looping transducer.  It has one
    arc for every (input label, output label) pair.  It also has ``0:o``
    arcs when insertions are allowed and ``i:0`` arcs when deletions are
    allowed.  A ``0:0`` arc is never added.

    Args:
        ifar_path: path to the input FAR.
        ofar_path: path to the output FAR.
        insertions: should insertions be permitted?
        deletions: should deletions be permitted?

    Returns:
        The path to the CG FST.
    """
    in_labels = _get_far_labels(ifar_path)
    out_labels = _get_far_labels(ofar_path)
    grammar = pywrapfst.VectorFst()
    hub = grammar.add_state()
    grammar.set_start(hub)
    one = pywrapfst.Weight.one(grammar.weight_type())

    label_pairs = list(itertools.product(in_labels, out_labels))
    if insertions:
        label_pairs.extend((0, o) for o in out_labels)
    if deletions:
        label_pairs.extend((i, 0) for i in in_labels)
    for ilabel, olabel in label_pairs:
        grammar.add_arc(hub, pywrapfst.Arc(ilabel, olabel, one, hub))

    grammar.set_final(hub)
    assert grammar.verify(), "Label acceptor is ill-formed"
    out_path = _mktemp("cg.fst")
    grammar.write(out_path)
    return out_path
def _lexicon_covering(self) -> None:
    """Builds covering grammar and lexicon FARs.

    Runs ``farcompilestrings`` on the input and output corpora (logging
    each command and its output to ``covering_grammar.log``).  It then
    builds a single-state covering grammar over the observed labels and
    writes it to ``self.cg_path``.
    """
    log_path = os.path.join(self.working_log_directory, "covering_grammar.log")
    with open(log_path, "w", encoding="utf8") as log_file:

        def run_logged(command):
            # Echo the command line, then run it with output captured in the log.
            print(" ".join(command), file=log_file)
            subprocess.check_call(command, env=os.environ, stderr=log_file, stdout=log_file)

        input_cmd = [thirdparty_binary("farcompilestrings"), "--fst_type=compact"]
        if self.input_token_type == "utf8":
            input_cmd.append("--token_type=utf8")
        else:
            input_cmd += [
                "--token_type=symbol",
                f"--symbols={self.input_token_type}",
                "--unknown_symbol=<unk>",
            ]
        input_cmd += [self.input_path, self.input_far_path]
        run_logged(input_cmd)

        run_logged([
            thirdparty_binary("farcompilestrings"),
            "--fst_type=compact",
            "--token_type=symbol",
            f"--symbols={self.phone_symbol_table_path}",
            self.output_path,
            self.output_far_path,
        ])

        # Sets of labels for the covering grammar.
        ilabels = _get_far_labels(self.input_far_path)
        print(ilabels, file=log_file)
        olabels = _get_far_labels(self.output_far_path)
        print(olabels, file=log_file)

    grammar = pywrapfst.VectorFst()
    hub = grammar.add_state()
    grammar.set_start(hub)
    one = pywrapfst.Weight.one(grammar.weight_type())
    pairs = list(itertools.product(ilabels, olabels))
    # Epsilon arcs, carefully avoiding a useless 0:0 label.
    if self.insertions:
        pairs.extend((0, o) for o in olabels)
    if self.deletions:
        pairs.extend((i, 0) for i in ilabels)
    for ilabel, olabel in pairs:
        grammar.add_arc(hub, pywrapfst.Arc(ilabel, olabel, one, hub))
    grammar.set_final(hub)
    assert grammar.verify(), "Label acceptor is ill-formed"
    grammar.write(self.cg_path)
def replace_and_patch(
    outer_fst: fst.Fst,
    outer_start_state: int,
    outer_final_state: int,
    inner_fst: fst.Fst,
    label_sym: int,
    eps: int = 0,
) -> None:
    """Copies an inner FST into an outer FST, creating states and mapping symbols.

    Creates arcs from outer start/final states to inner start/final states:
      * ``outer_start_state --eps:label_sym--> copy(inner start)``
      * ``copy(f) --eps:eps--> outer_final_state`` exactly once for every
        inner state ``f`` with a non-zero final weight.

    Inner symbol ids are translated to outer ids by symbol text.  Every
    copied arc gets weight one in the outer semiring.  Inner arc weights and
    final weights are not carried over.
    """
    in_symbols = outer_fst.input_symbols()
    out_symbols = outer_fst.output_symbols()
    inner_zero = fst.Weight.Zero(inner_fst.weight_type())
    outer_one = fst.Weight.One(outer_fst.weight_type())

    # Translate inner symbol ids to outer ids via the symbol text.
    out_symbol_map = {}
    for i in range(inner_fst.output_symbols().num_symbols()):
        sym_str = inner_fst.output_symbols().find(i).decode()
        out_symbol_map[i] = out_symbols.find(sym_str)

    in_symbol_map = {}
    for i in range(inner_fst.input_symbols().num_symbols()):
        sym_str = inner_fst.input_symbols().find(i).decode()
        in_symbol_map[i] = in_symbols.find(sym_str)

    # Create states in outer FST
    state_map = {}
    for inner_state in inner_fst.states():
        state_map[inner_state] = outer_fst.add_state()

    # Create arcs in outer FST
    inner_start = inner_fst.start()
    for inner_state in inner_fst.states():
        outer_state = state_map[inner_state]
        if inner_state == inner_start:
            outer_fst.add_arc(
                outer_start_state,
                fst.Arc(eps, label_sym, outer_one, outer_state),
            )

        for inner_arc in inner_fst.arcs(inner_state):
            outer_fst.add_arc(
                outer_state,
                fst.Arc(
                    in_symbol_map[inner_arc.ilabel],
                    out_symbol_map[inner_arc.olabel],
                    outer_one,
                    state_map[inner_arc.nextstate],
                ),
            )

        # One exit arc per final state.  Previously this was emitted per
        # incoming arc, which duplicated paths and skipped a final start state.
        if inner_fst.final(inner_state) != inner_zero:
            outer_fst.add_arc(
                outer_state,
                fst.Arc(eps, eps, outer_one, outer_final_state),
            )
def make_kleeneplus(s, graphs, stoi):
    """Build an acceptor for one or more symbols drawn from ``graphs``.

    Two states: start and a final end state (weight one).  For every
    grapheme ``g`` there is an arc ``start -> end`` and a self-loop on
    ``end``, both labelled ``stoi[g]:stoi[g]`` with weight one.  The ``s``
    argument is unused.
    """
    machine = wfst.Fst()
    head = machine.add_state()
    tail = machine.add_state()
    machine.set_start(head)
    one = wfst.Weight.One(machine.weight_type())
    machine.set_final(tail, one)
    for grapheme in graphs:
        label = stoi[grapheme]
        for origin in (head, tail):
            machine.add_arc(origin, wfst.Arc(label, label, one, tail))
    return machine
def make_lexicon_fst(input_words, words):
    """Build a character->word lexicon FST and a word-insertion FST.

    ``lexicon_fst`` reads Unicode code points (``ord(c)``).  For the i-th
    word it emits word id ``i + 1`` on the first character and epsilon on
    the rest.  A single space (code point 32) maps to epsilon.  The machine
    is determinized, minimized and closed (Kleene star).

    ``epsilon_fst`` reads nothing (epsilon input).  It emits any word id
    ``1..len(words)`` or 32, and it is also closed.

    Side effect:
        Writes ``words.syms`` to the current directory.  The file lists
        ``<eps> 0``, then ``words + input_words`` numbered from 1, then
        ``<SPACE> 32``.

    NOTE(review): word ids are ``i + 1`` while space uses the fixed id 32,
    so with 32 or more entries a word id collides with the space id.

    Returns:
        ``(lexicon_fst, epsilon_fst)``.
    """
    lexicon_fst = fst.Fst()
    one = fst.Weight.One(lexicon_fst.weight_type())
    start = lexicon_fst.add_state()
    lexicon_fst.set_start(start)
    last = lexicon_fst.add_state()
    # this arc projects space to epsilon
    lexicon_fst.add_arc(start, fst.Arc(32, 0, one, last))
    lexicon_fst.set_final(last)
    for i, w in enumerate(words):
        w = w.strip()
        index = i + 1
        last = lexicon_fst.add_state()
        lexicon_fst.add_arc(start, fst.Arc(ord(w[0]), index, one, last))
        for c in w[1:]:
            nxt = lexicon_fst.add_state()
            lexicon_fst.add_arc(last, fst.Arc(ord(c), 0, one, nxt))
            last = nxt
        lexicon_fst.set_final(last, 0)
    lexicon_fst = fst.determinize(lexicon_fst).minimize().closure()

    with open('words.syms', 'w') as f:
        f.write('<eps> 0\n')
        for i, w in enumerate(words + input_words):
            # Strip like the lexicon does; unstripped words (e.g. from
            # readlines()) would embed newlines and corrupt the table.
            f.write('{} {}\n'.format(w.strip(), i + 1))
        f.write('<SPACE> {}\n'.format(32))

    epsilon_fst = fst.Fst()
    eps_one = fst.Weight.One(epsilon_fst.weight_type())
    start = epsilon_fst.add_state()
    end = epsilon_fst.add_state()
    for i, _ in enumerate(words):
        epsilon_fst.add_arc(start, fst.Arc(0, i + 1, eps_one, end))
    # A single epsilon:space arc (not one per word).
    epsilon_fst.add_arc(start, fst.Arc(0, 32, eps_one, end))
    epsilon_fst.set_final(end, 0)
    epsilon_fst.set_start(start)
    epsilon_fst = epsilon_fst.closure()
    return lexicon_fst, epsilon_fst
# NOTE(review): this line is corrupted and is not valid Python as written.
# It merges a truncated `search` function with a copy of `dfs` and a
# top-level driver loop.  The broken parts include a dangling `visited =`
# assignment, `while open_nodes:` with an undefined name, and Python 2
# `iteritems()`.  The driver also tests `input_path`, which is never defined
# here (probably meant `input_path1` — TODO confirm).  It needs to be
# reconstructed from version control before it can run.
def search(root, lat2_paths): # root is a node in lat which defines the current history # lat2_paths are weights which could be added to the current # path in lat1 if the key does not get discarded in the # future global lat, lat2, visited eps_paths = {node : [] for node in lat2_paths} open_paths = dict(eps_paths) while open_paths: next_open = {} for node,path in open_paths.iteritems(): if node in eps_paths: continue for arc in lat2.arcs(node): if arc.olabel == 0: if node in next_open: # Add paths to eps reachable nodes while open_nodes: visited = if root in visited: return visited[root] = True for arc in lat.arcs(root): dfs(arc.nextstate, hist + [str(arc.ilabel)]) key = ' '.join(hist[-hist_len:]) if key in hist2node: # connect with it arc1 = fst.Arc(0, 0, fst.Weight.One(lat.weight_type()), hist2node[key]) arc2 = fst.Arc(0, 0, fst.Weight.One(lat.weight_type()), root) lat.add_arc(root, arc1) lat.add_arc(hist2node[key], arc2) else: hist2node[key] = root idx = 0 while True: idx += 1 input_path1 = get_path(args.input1, idx) if not input_path or not os.path.isfile(input_path1): break input_path2 = get_path(args.input2, idx) lat = fst.Fst.read(input_path1) lat.rmepsilon() lat.determinize() lat.minimize() lat2 = fst.Fst.read(input_path2) visited = {} search(lat.start(), [lat2.start()]) lat.write(get_path(args.output, idx))
def make_compounder(syms, word_ids):
    """Build a single-state compounding transducer.

    On its only state (start and final) it has:
      * ``<space>`` rewritten to ``<eps>``, ``+C+`` or ``+D+``;
      * an identity self-loop for every id in ``word_ids``.

    NOTE(review): every arc and the final state use the literal weight ``1``.
    In the default tropical semiring that is a cost of 1, not
    ``Weight.One`` (which is 0).  Confirm this per-symbol penalty is intended.
    """
    compounder = fst.Fst()
    start_state = compounder.add_state()
    assert (start_state == 0)
    compounder.set_start(start_state)
    space_id = syms["<space>"]
    for marker in ("<eps>", "+C+", "+D+"):
        compounder.add_arc(0, fst.Arc(space_id, syms[marker], 1, 0))
    for word_id in word_ids:
        compounder.add_arc(0, fst.Arc(word_id, word_id, 1, 0))
    compounder.set_final(0, 1)
    return compounder
def update(self, ch_dist):
    '''Append one time step of character likelihoods to the history.

    Args:
        ch_dist: iterable of (character, score) pairs.  Scores are used
            directly as arc weights and are expected in the correct scale
            (negative log space).  The character '#' is the word separator
            (space).

    Side effects:
        - Concatenates the non-space characters onto the trailing-character
          FST of every bin in ``self.prefix_words`` whose count is below 10,
          and increments that count.
        - Concatenates the step (including the space arc, if any) onto
          ``self.history_fst``.
        - If a space was present, composes the history with
          ``self.ltr2wrd`` into a word lattice and appends a new
          ``[word_lattice, empty trailing-char FST, 0]`` bin.
    '''
    # Two-state machine 0 -> 1 holding this step's character alternatives.
    new_ch = fst.Fst()
    new_ch.set_input_symbols(self.ch_syms)
    new_ch.set_output_symbols(self.ch_syms)
    new_ch.add_state()
    new_ch.set_start(0)
    new_ch.add_state()
    new_ch.set_final(1)
    space_code = -1
    space_pr = 0.
    for ch, pr in ch_dist:
        code = self.ch_syms.find(ch)
        if ch == '#':
            # Adds space after we finish updating trailing chars.
            space_code = code
            space_pr = pr
            continue
        new_ch.add_arc(0, fst.Arc(code, code, pr, 1))
    new_ch.arcsort(sort_type="olabel")
    # Adds the trailing characters to existing binned history.
    for words_bin in self.prefix_words:
        # words_bin layout: [word_lattice, trailing_chars_fst, trail_count]
        if words_bin[2] >= 10:
            # We discard the whole trail in this case (TODO)
            continue
        # Unless we are testing a straight line machine, this normally
        # doesn't happen in practice.
        if new_ch.num_arcs(0) == 0:
            continue
        words_bin[1].concat(new_ch).rmepsilon()
        words_bin[2] += 1
    # Continues updating the history and adds back the space if necessary.
    # (The space arc is added after the trailing-char concat above, so the
    # per-bin trails never contain the separator.)
    if space_code >= 0:
        new_ch.add_arc(0, fst.Arc(space_code, space_code, space_pr, 1))
    self.history_fst.concat(new_ch).rmepsilon()
    # Respectively update the binned history
    if space_code >= 0:  # If there is a space
        # Finishes the prefix words in current position
        word_lattice = fst.compose(self.history_fst, self.ltr2wrd)
        word_lattice.project(project_output=True).rmepsilon()
        word_lattice = fst.determinize(word_lattice)
        word_lattice.minimize()
        if word_lattice.num_states() == 0:
            word_lattice = self.create_empty_fst(self.wd_syms, self.wd_syms)
        trailing_chars = self.create_empty_fst(self.ch_syms, self.ch_syms)
        self.prefix_words.append([word_lattice, trailing_chars, 0])
def make_intent_fst(grammar_fsts: Dict[str, fst.Fst], eps: str = "<eps>") -> fst.Fst:
    """Merges grammar FSTs created with grammar_to_fsts into a single acceptor FST.

    Each intent gets the branch::

        start --eps:__label__INTENT--> a --eps:__replace__INTENT--> b --eps:eps--> final

    The ``__replace__INTENT`` label is then substituted with that intent's
    grammar FST by ``_replace_fsts``.
    """
    input_symbols = fst.SymbolTable()
    output_symbols = fst.SymbolTable()
    in_eps: int = input_symbols.add_symbol(eps)
    out_eps: int = output_symbols.add_symbol(eps)

    intent_fst = fst.Fst()
    weight_one = fst.Weight.One(intent_fst.weight_type())

    def eps_input_arc(olabel: int, dest: int) -> fst.Arc:
        # Every arc built here reads epsilon and has weight one.
        return fst.Arc(in_eps, olabel, weight_one, dest)

    # Global start/final states
    start_state = intent_fst.add_state()
    intent_fst.set_start(start_state)
    final_state = intent_fst.add_state()
    intent_fst.set_final(final_state)

    replacements: Dict[int, fst.Fst] = {}
    for intent, grammar_fst in grammar_fsts.items():
        label_id = output_symbols.add_symbol(f"__label__{intent}")
        branch_head = intent_fst.add_state()
        intent_fst.add_arc(start_state, eps_input_arc(label_id, branch_head))

        branch_tail = intent_fst.add_state()
        replace_id = output_symbols.add_symbol(f"__replace__{intent}")
        intent_fst.add_arc(branch_head, eps_input_arc(replace_id, branch_tail))

        intent_fst.add_arc(branch_tail, eps_input_arc(out_eps, final_state))
        replacements[replace_id] = grammar_fst

    # Attach symbol tables
    intent_fst.set_input_symbols(input_symbols)
    intent_fst.set_output_symbols(output_symbols)

    # Splice in the per-intent grammars
    return _replace_fsts(intent_fst, replacements, eps=eps)
def dfs(root, hist):
    """Post-order DFS over ``lat`` that links states sharing a label history.

    ``hist`` holds the (stringified) input labels on the path to ``root``.
    After all successors are visited, the last ``hist_len`` labels form a
    key.  The first state seen with a key is recorded in ``hist2node``.
    Each later state with the same key is joined to it by a pair of
    epsilon arcs (weight one), one in each direction.

    Uses module-level ``hist_len``, ``visited``, ``hist2node`` and ``lat``.
    NOTE(review): with ``hist_len == 0`` the slice ``hist[-0:]`` is the whole
    history, not an empty one.
    """
    global hist_len, visited, hist2node, lat
    if root in visited:
        return
    visited[root] = True
    for arc in lat.arcs(root):
        dfs(arc.nextstate, hist + [str(arc.ilabel)])

    key = ' '.join(hist[-hist_len:])
    if key not in hist2node:
        hist2node[key] = root
        return
    # connect with the earlier state that has the same history
    partner = hist2node[key]
    forward = fst.Arc(0, 0, fst.Weight.One(lat.weight_type()), partner)
    backward = fst.Arc(0, 0, fst.Weight.One(lat.weight_type()), root)
    lat.add_arc(root, forward)
    lat.add_arc(partner, backward)
def build_ctc_mono_decoding_fst(S, arc_type='log', add_syms=False):
    """
    Build a monophone CTC decoding fst.
    Args:
        S - number of monophones
        arc_type - log or standard. Gives the interpretation of the FST.
        add_syms - if true, attach symbol tables ('B' for blank, then
            'a', 'b', ... for the remaining symbols).
    Returns:
        an FST that accepts all sequences over [1,..,S]^* and returns
        shorter ones with duplicates and blanks removed.

        State s means "last symbol read was s" (state 0 = blank, the start
        state); all states are final.  From s, reading s again loops and
        emits nothing.  Reading any other symbol t moves to t and emits t
        (blank emits epsilon).

        The input labels are shifted by one, so that there are no epsilon
        transitions. The output labels are not (blank is zero), allowing one
        to read out the label sequence easily.
    """
    CTC = fst.Fst(arc_type=arc_type)
    weight_one = fst.Weight.One(CTC.weight_type())

    for expected_id in range(S):
        new_state = CTC.add_state()
        assert expected_id == new_state
        CTC.set_final(new_state)
    CTC.set_start(0)

    for src in range(S):
        # repeat of the current symbol: stay, emit nothing
        CTC.add_arc(src, fst.Arc(src + 1, 0, weight_one, src))
        # any other symbol: move there and emit it
        for dst in (d for d in range(S) if d != src):
            CTC.add_arc(src, fst.Arc(dst + 1, dst, weight_one, dst))

    CTC.arcsort('olabel')

    if add_syms:
        letters = [chr(ord('a') + k) for k in range(S - 1)]
        in_syms = fst.SymbolTable()
        in_syms.add_symbol('<eps>', 0)
        in_syms.add_symbol('B', 1)
        for k, letter in enumerate(letters, start=1):
            in_syms.add_symbol(letter, k + 1)
        out_syms = fst.SymbolTable()
        out_syms.add_symbol('<eps>', 0)
        for k, letter in enumerate(letters, start=1):
            out_syms.add_symbol(letter, k)
        CTC.set_input_symbols(in_syms)
        CTC.set_output_symbols(out_syms)
    return CTC
def addArc(self, ilabels, olabels=None, start_state=None, is_loop=False):
    '''Add a linear chain of arcs to ``self.fst`` and return its last state.

    Args:
        ilabels: input labels, converted to ids via ``self._label_to_sym``
            against the FST's (mutable) input symbol table.
        olabels: output labels; defaults to ``ilabels``.  Whichever side is
            shorter is padded with epsilon (0).
        start_state: state the chain starts from; defaults to
            ``self.STATE_START``.
        is_loop: if true, the last arc returns to ``start_state`` instead of
            creating a new state.

    Returns:
        The state the chain ends in.  When there are no labels at all, no
        arcs are added and ``start_state`` is returned (previously this
        raised UnboundLocalError).
    '''
    isyms = self._label_to_sym(ilabels, self.fst.mutable_input_symbols())
    out_source = ilabels if olabels is None else olabels
    osyms = self._label_to_sym(out_source, self.fst.mutable_output_symbols())

    n_arcs = max(len(isyms), len(osyms))
    # Pad with epsilon; build new lists rather than mutating the helper's.
    isyms = list(isyms) + [0] * (n_arcs - len(isyms))
    osyms = list(osyms) + [0] * (n_arcs - len(osyms))

    if start_state is None:
        start_state = self.STATE_START
    one = pywrapfst.Weight.One(self.fst.weight_type())

    src = start_state
    dest = start_state
    for i in range(n_arcs):
        if is_loop and i == n_arcs - 1:
            dest = start_state
        else:
            dest = self.fst.add_state()
        self.fst.add_arc(src, pywrapfst.Arc(isyms[i], osyms[i], one, dest))
        src = dest
    return dest
def __call__(self, x):
    """Convert a batch of per-frame scores into linear lattice FSTs.

    ``transform_output`` splits ``x`` into scores and per-item lengths.
    Scores are log-softmax-normalized over dim 2 when ``self._normalize`` is
    set.  They are then permuted so the batch axis comes first; presumably
    ``x`` is (time, batch, classes) on input — TODO confirm.

    For each item of length T, the FST has states 0..T.  For every frame t
    and class j there is an arc ``t -> t+1`` labelled ``j+1:j+1`` with
    weight ``-logpost[t, j]``.  State T is final with weight one.

    Returns:
        The list of FSTs, also stored in ``self._output``.
    """
    x, xs = transform_output(x)
    # Normalize log-posterior matrices, if necessary
    if self._normalize:
        x = log_softmax(x, dim=2)
    x = x.permute(1, 0, 2).cpu()
    self._output = []
    for logpost, length in zip(x, xs):
        length = int(length)  # pywrapfst needs a plain int state id
        f = fst.Fst()
        weight_type = f.weight_type()
        f.set_start(f.add_state())
        # One bulk conversion instead of T*D tensor element accesses.
        for t, frame in enumerate(logpost[:length].tolist()):
            f.add_state()
            for j, logp in enumerate(frame):
                f.add_arc(
                    t,
                    fst.Arc(
                        j + 1,  # input label
                        j + 1,  # output label
                        fst.Weight(weight_type, -logp),  # -logpost[t, j]
                        t + 1,  # nextstate
                    ),
                )
        f.set_final(length, fst.Weight.One(weight_type))
        f.verify()
        self._output.append(f)
    return self._output
def normalize(self, anfst):
    '''Produce a normalized two-state FST from the arcs leaving the start state.

    Each arc weight leaving ``anfst``'s start state is wrapped in a
    ``BitWeight``.  The weights are summed starting from ``BitWeight(1e6)``.
    Each one is then divided by that total and converted with ``loge()``.
    The result is an FST ``0 -> 1`` with one ``label:label`` arc per original
    arc, carrying the normalized weight, using ``anfst``'s input symbols on
    both sides.
    '''
    syms = anfst.input_symbols()
    labels = []
    dist = []
    for arc in anfst.arcs(anfst.start()):
        labels.append(syms.find(arc.ilabel))
        # BitWeight stores -log(pr) only
        dist.append(BitWeight(float(arc.weight)))

    # Log-domain sum of all weights.
    total = sum(dist, BitWeight(1e6))
    norm_dist = [(w / total).loge() for w in dist]

    output = fst.Fst()
    output.set_input_symbols(syms)
    output.set_output_symbols(syms)
    output.add_state()
    output.add_state()
    for weight, label in zip(norm_dist, labels):
        code = syms.find(label)
        output.add_arc(0, fst.Arc(code, code, weight, 1))
    output.set_start(0)
    output.set_final(1)
    return output
def add_arc(fst_in, from_word, to_word, weight):
    """
    Adds an arc to a given FST

    Note: Despite returning an updated FST, this method makes the changes
    **IN PLACE**, so you may want to make a copy of the original FST before
    updating the weights

    The arc goes from the state of ``from_word`` to the state of ``to_word``
    (as resolved by ``index_fst``).  It is labelled with ``to_word`` on
    both sides.

    :param fst_in: <openfst.Fst> to modify
    :param from_word: <str>
    :param to_word: <str>
    :param weight: <float> interpreted in ``fst_in``'s own semiring
    :return: updated <openfst.Fst> (the same object as ``fst_in``)
    """
    # make a dict from index_fst() (node_2_word is not needed here)
    fst_dict, _node_2_word = index_fst(fst_in)
    # get a lookup table
    lookup = fst_in.input_symbols()
    from_state = fst_dict[from_word]["state_id"]
    to_state = fst_dict[to_word]["state_id"]

    label = lookup_word(to_word, lookup)
    # Use the FST's own weight type rather than assuming "tropical".
    arc_weight = openfst.Weight(fst_in.weight_type(), weight)
    # Don't rely on add_arc's return value; return the mutated FST explicitly.
    fst_in.add_arc(from_state, openfst.Arc(label, label, arc_weight, to_state))
    return fst_in
def single_state_transducer(transition_weights, row_vocab, col_vocab,
                            input_symbols=None, output_symbols=None,
                            arc_type='standard'):
    """Build a one-state transducer from a weight matrix.

    The single state is both start and final (weight one).  For every cell
    ``transition_weights[i][j]`` whose weight is not zero, a self-loop
    ``row_vocab[i]:col_vocab[j]`` carrying that weight is added.

    Raises:
        openfst.FstError: if the resulting FST fails verification.
    """
    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)
    weight_type = fst.weight_type()
    zero = openfst.Weight.zero(weight_type)
    one = openfst.Weight.one(weight_type)

    hub = fst.add_state()
    fst.set_start(hub)
    fst.set_final(hub, one)

    for row_idx, row in enumerate(transition_weights):
        for col_idx, raw_weight in enumerate(row):
            cell_weight = openfst.Weight(weight_type, raw_weight)
            ilabel = fst.input_symbols().find(row_vocab[row_idx])
            olabel = fst.output_symbols().find(col_vocab[col_idx])
            if cell_weight != zero:
                fst.add_arc(hub, openfst.Arc(ilabel, olabel, cell_weight, hub))

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")
    return fst
def build_chain_fst(labels, arc_type='log', vocab=None):
    """
    Build an acceptor for string given by elements of labels.

    Args:
        labels - a sequence of labels in the range 1..S
        arc_type - fst arc type (standard or log)
        vocab - unused
    Returns:
        FST consuming symbols in the range 1..S: a linear chain with one
        weight-one ``l:l`` arc per label, final at the end, arc-sorted by
        input label.

    Notes:
        Elements of labels are assumed to be greater than zero
        (which maps to blank)!
    """
    chain = fst.Fst(arc_type=arc_type)
    one = fst.Weight.One(chain.weight_type())
    current = chain.add_state()
    chain.set_start(current)
    for label in labels:
        following = chain.add_state()
        chain.add_arc(current, fst.Arc(label, label, one, following))
        current = following
    chain.set_final(current)
    chain.arcsort('ilabel')
    return chain
def gen_trigram_graph(ngram_to_class_file, net_vocab_file, token_file, out_file,
                      add_final_space=False, use_contextual_blanks=False,
                      prevent_epsilons=False, determinize=True):
    """Build a trigram CTC decoding graph and write it to ``out_file``.

    Steps:
      1. Read the network vocabulary and the trigram list (one tuple of ints
         per line of ``ngram_to_class_file``).
      2. Build the CTC graph with ``build_ctc_trigram_decoding_fst_v2``
         (tropical semiring).
      3. Remap output labels from network-vocab ids to ``token_file`` ids
         (``<pad>`` -> epsilon) and attach the token symbol table.
      4. Optionally add an ``eps:<spc>`` arc from every final state to a new
         final state, so a trailing space may be emitted.
      5. Arc-sort by output label and write.
    """
    net_vocab = read_net_vocab(net_vocab_file)
    print("net vocab", net_vocab)
    N = len(net_vocab)
    with open(ngram_to_class_file, 'r') as f:
        trigrams = [tuple([int(n) for n in line.split()]) for line in f]

    CTC = build_ctc_trigram_decoding_fst_v2(
        N, trigrams, arc_type='standard',
        use_context_blanks=use_contextual_blanks,
        prevent_epsilons=prevent_epsilons,
        determinize=determinize,
        add_syms=False)
    assert CTC.weight_type() == 'tropical'

    # Emitted symbols need to be remapped from net_vocab to token symbols
    # net_vocab[:5] : ['<pad>', '<unk>', '<spc>', 'E', 'T']
    # tokens[:5] : ['<eps> 0', '<spc> 1', '<pad> 2', '<unk> 3', 'E 4']
    # <pad> is unused and gets mapped to eps, <unk> and <spc> change ids,
    # the rest is roughly shifted by 1.
    tokens = {}
    with open(token_file, 'r') as f:  # close the handle (was leaked)
        for line in f:
            fields = line.split()
            if fields:  # tolerate blank lines
                tokens[fields[0]] = int(fields[1])
    net_vocab_dict = {t: i for i, t in enumerate(net_vocab)}
    osym_map = []
    for t, i in net_vocab_dict.items():
        osym_map.append((i, 0 if t == '<pad>' else tokens[t]))
    CTC.relabel_pairs(ipairs=None, opairs=osym_map)
    print(osym_map)

    CTC_os = fst.SymbolTable.read_text(token_file)
    CTC.set_output_symbols(CTC_os)
    os_eps = CTC_os.find('<eps>')
    assert os_eps == 0

    weight_one = fst.Weight.One('tropical')
    if add_final_space:
        zero = fst.Weight(CTC.weight_type(), 'infinity')
        # Snapshot the final states BEFORE adding final_space.  Otherwise
        # final_space (itself final) would get an eps:<spc> self-loop, an
        # epsilon-input cycle emitting unbounded spaces.
        final_states = [s for s in CTC.states() if CTC.final(s) != zero]
        final_space = CTC.add_state()
        CTC.set_final(final_space)
        final_space_arc = fst.Arc(0, CTC_os.find('<spc>'), weight_one, final_space)
        for s in final_states:
            CTC.add_arc(s, final_space_arc)
    CTC.arcsort('olabel')
    CTC.write(out_file)
def addArcLinear(fst, start_state, labels, olabels=None, is_loop=False):
    """Add a linear chain of arcs starting at ``start_state``.

    Labels are registered in the FST's mutable symbol tables via
    ``add_symbol``.  When ``olabels`` is given, the shorter of the two label
    lists is padded with epsilon (0).

    Args:
        fst: mutable pywrapfst FST with input/output symbol tables.
        start_state: state to start the chain from.
        labels: list of input labels (also used as output labels when
            ``olabels`` is None).
        olabels: optional list of output labels.
        is_loop: if true, the last arc returns to ``start_state``.

    Returns:
        The state the chain ends in.  For empty label lists no arcs are added
        and ``start_state`` is returned (previously UnboundLocalError).

    Raises:
        ValueError: if ``labels`` is not a list.
    """
    if not isinstance(labels, list):
        raise ValueError("Label argument must be a list")
    isym_tbl = fst.mutable_input_symbols()
    osym_tbl = fst.mutable_output_symbols()
    if olabels is None:
        sym_len = len(labels)
        isyms = [isym_tbl.add_symbol(lab) for lab in labels]
        osyms = [osym_tbl.add_symbol(lab) for lab in labels]
    else:
        sym_len = max(len(labels), len(olabels))
        isyms = [isym_tbl.add_symbol(lab) for lab in labels] + [0] * (sym_len - len(labels))
        osyms = [osym_tbl.add_symbol(lab) for lab in olabels] + [0] * (sym_len - len(olabels))

    one = pywrapfst.Weight.One(fst.weight_type())
    src = start_state
    dest = start_state
    for i in range(sym_len):
        if is_loop and i == sym_len - 1:
            dest = start_state
        else:
            dest = fst.add_state()
        fst.add_arc(src, pywrapfst.Arc(isyms[i], osyms[i], one, dest))
        src = dest
    return dest
def addArcFlower(fst, q0, q1, ilabels, olabels=None, weight=None):
    '''Add parallel arcs from ``q0`` to ``q1``, one per label pair.

    Args:
        fst: mutable pywrapfst FST with input/output symbol tables.
        q0: origin state of the arc(s).
        q1: target state of the arc(s).
        ilabels: list of input labels.
        olabels: list of output labels, paired position-wise with
            ``ilabels``; defaults to ``ilabels``.
        weight: arc weight; defaults to the semiring's One.

    Raises:
        ValueError: if either label argument is not a list, or the two lists
            differ in length (previously this raised IndexError or silently
            dropped the extra output labels).
    '''
    if not isinstance(ilabels, list):
        raise ValueError("input label argument must be a list")
    if olabels is None:
        olabels = ilabels
    if not isinstance(olabels, list):
        raise ValueError("Output label argument must be a list")
    if len(olabels) != len(ilabels):
        raise ValueError("Input and output label lists must have the same length")

    isym_tbl = fst.mutable_input_symbols()
    osym_tbl = fst.mutable_output_symbols()
    # Register every label in both tables, in the same order (presumably to
    # keep the two tables' ids aligned — TODO confirm).
    for label in ilabels + olabels:
        isym_tbl.add_symbol(label)
        osym_tbl.add_symbol(label)

    if weight is None:
        weight = pywrapfst.Weight.One(fst.weight_type())
    for ilabel, olabel in zip(ilabels, olabels):
        isym = isym_tbl.add_symbol(ilabel)
        osym = osym_tbl.add_symbol(olabel)
        fst.add_arc(q0, pywrapfst.Arc(isym, osym, weight, q1))
def string_to_fsa(input_string, sym):
    '''Build a linear FSA accepting exactly ``input_string``.

    Each character becomes one ``c:c`` arc with weight one, labelled via the
    symbol table ``sym`` (also attached to both sides of the result).  The
    last state is final with weight one.  Previously the final weight was
    the literal ``1``, which is a cost of 1 in the tropical semiring.  An
    empty string yields a single start state that is final.

    Raises:
        ValueError: if a character is missing from ``sym``, or the result
            fails verification.
    '''
    # first make sure all chars can be converted (and look each up once)
    codes = []
    for ch in input_string:
        code = sym.find(ch)
        if code == -1:
            raise ValueError('Input character not found')
        codes.append(code)

    # build the FSA
    f = pywrapfst.VectorFst()
    one = pywrapfst.Weight.one(f.weight_type())
    f.set_input_symbols(sym)
    f.set_output_symbols(sym)
    state = f.add_state()
    f.set_start(state)
    for code in codes:
        nxt = f.add_state()
        f.add_arc(state, pywrapfst.Arc(code, code, one, nxt))
        state = nxt
    f.set_final(state, one)

    # verify
    if not f.verify():
        raise ValueError('FSA failed to verify')
    return f
def toFst(self):
    """Convert the HMM graph to an OpenFst object.

    You need to have installed the OpenFst python extension to use
    this method.

    All arcs are epsilon:epsilon in the log semiring.  A transition with
    weight ``w`` becomes an arc with weight ``-w``.  Super initial/final
    states connect to the init/final states with weight one.

    Returns
    -------
    graph : pywrapfst.Fst
        The FST representation of the HMM graph. A super initial state
        and a super final state will be added though they are not
        present in the HMM.

    """
    import pywrapfst as fst

    graph = fst.Fst('log')
    super_initial = graph.add_state()
    graph.set_start(super_initial)
    super_final = graph.add_state()
    graph.set_final(super_final)

    # One FST state per HMM state, in iteration order.
    fst_ids = {state.state_id: graph.add_state() for state in self.states}
    log_one = fst.Weight.One('log')

    for state in self.states:
        src = fst_ids[state.state_id]
        for next_id, weight in state.next_states.items():
            graph.add_arc(src, fst.Arc(0, 0, fst.Weight('log', -weight), fst_ids[next_id]))

    for state in self.init_states:
        graph.add_arc(super_initial, fst.Arc(0, 0, log_one, fst_ids[state.state_id]))

    for state in self.final_states:
        graph.add_arc(fst_ids[state.state_id], fst.Arc(0, 0, log_one, super_final))

    return graph
def makeState(i):
    """Add the FST state for vocabulary item ``i`` and return its id.

    Closure helper: relies on ``fst``, ``zero``, ``init_weights``,
    ``final_weights``, ``col_vocab``, ``bos_index``, ``eos_index``,
    ``eps_index`` and ``final_state`` from the enclosing scope.

    If ``init_weights[i]`` is non-zero, an arc ``start -> state`` labelled
    ``bos_index:col_vocab[i]`` is added.  If ``final_weights[i]`` is
    non-zero, an arc ``state -> final_state`` labelled ``eos_index:eps_index``
    is added.
    """
    state = fst.add_state()

    entry_weight = openfst.Weight(fst.weight_type(), init_weights[i])
    if entry_weight != zero:
        target_label = fst.output_symbols().find(col_vocab[i])
        fst.add_arc(fst.start(), openfst.Arc(bos_index, target_label, entry_weight, state))

    exit_weight = openfst.Weight(fst.weight_type(), final_weights[i])
    if exit_weight != zero:
        fst.add_arc(state, openfst.Arc(eos_index, eps_index, exit_weight, final_state))

    return state
def _make_slot_fst(state: int, intent_fst: fst.Fst, slot_to_fst: Dict[str, fst.Fst], _visited=None):
    """Extract one sub-FST per slot from ``intent_fst`` into ``slot_to_fst``.

    Walks ``intent_fst`` from ``state``.  When it meets an arc whose output
    label is ``__begin__NAME`` (and NAME is not already in ``slot_to_fst``),
    it copies everything reachable from that arc's target up to the
    ``__end__NAME`` arcs.  States just before an ``__end__NAME`` arc become
    final.  The copy is minimized and stored as ``slot_to_fst[NAME]``.
    Copied arcs get weight one.

    Args:
        state: state to start walking from.
        intent_fst: source FST (output symbol table required).
        slot_to_fst: output dict, filled in place.
        _visited: internal; states already walked.  It keeps the recursion
            linear and terminating on FSTs with shared sub-paths or cycles.
    """
    if _visited is None:
        _visited = set()
    if state in _visited:
        return
    _visited.add(state)

    begin_prefix = "__begin__"
    out_symbols = intent_fst.output_symbols()
    one_weight = fst.Weight.One(intent_fst.weight_type())
    for arc in intent_fst.arcs(state):
        label = out_symbols.find(arc.olabel).decode()
        if label.startswith(begin_prefix):
            slot_name = label[len(begin_prefix):]

            # Big assumption here that each instance of a slot (e.g., location)
            # will produce the same FST, and therefore doesn't need to be
            # processed again.
            if slot_name in slot_to_fst:
                continue  # skip duplicate slots

            end_label = f"__end__{slot_name}"

            # Create new FST
            slot_fst = fst.Fst()
            slot_fst.set_input_symbols(intent_fst.input_symbols())
            slot_fst.set_output_symbols(intent_fst.output_symbols())

            start_state = slot_fst.add_state()
            slot_fst.set_start(start_state)
            q = [arc.nextstate]
            state_map = {arc.nextstate: start_state}

            # Copy states/arcs from intent FST until __end__ is found.
            # Each source state is queued exactly once (when first mapped),
            # so shared sub-paths don't produce duplicate arcs and cycles
            # don't loop forever.
            while q:
                q_state = q.pop()
                for q_arc in intent_fst.arcs(q_state):
                    slot_arc_label = out_symbols.find(q_arc.olabel).decode()
                    if slot_arc_label == end_label:
                        # Mark previous state as final
                        slot_fst.set_final(state_map[q_state])
                        continue

                    if q_arc.nextstate not in state_map:
                        state_map[q_arc.nextstate] = slot_fst.add_state()
                        q.append(q_arc.nextstate)

                    slot_fst.add_arc(
                        state_map[q_state],
                        fst.Arc(
                            q_arc.ilabel,
                            q_arc.olabel,
                            one_weight,
                            state_map[q_arc.nextstate],
                        ),
                    )

            slot_to_fst[slot_name] = minimize_fst(slot_fst)

        # Recurse
        _make_slot_fst(arc.nextstate, intent_fst, slot_to_fst, _visited)
def enterTagBody(self, ctx):
    """Grammar-listener hook: wrap the current expression in tag markers.

    The raw tag text (with original whitespace) is read from the ANTLR input
    stream.  The anchor state's outgoing arcs are moved behind a new state
    reached by a ``__begin__TAG`` arc.  A ``__end__TAG`` arc is appended
    after the rule's last state, and its target becomes the new last state.
    Both marker ids are recorded in ``self.tag_input_symbols``.
    """
    # Get the original text *with* whitespace from ANTLR
    stream = ctx.start.getInputStream()
    tag_text = stream.getText(ctx.start.start, ctx.stop.stop)

    def register_marker(symbol):
        # Add the marker to both symbol tables and remember its input id.
        in_id = self.input_symbols.add_symbol(symbol)
        out_id = self.output_symbols.add_symbol(symbol)
        self.tag_input_symbols.add(in_id)
        return in_id, out_id

    # Patch start of tag: --[__begin__TAG]-->
    anchor_state = self.exp_states[self.group_depth]
    body_head = self.fst.add_state()
    begin_in, begin_out = register_marker("__begin__" + tag_text)
    for moved in self.fst.arcs(anchor_state):
        self.fst.add_arc(
            body_head,
            fst.Arc(moved.ilabel, moved.olabel, moved.weight, moved.nextstate))
    self.fst.delete_arcs(anchor_state)
    self.fst.add_arc(
        anchor_state,
        fst.Arc(begin_in, begin_out, self.weight_one, body_head))

    # Patch end of tag: --[__end__TAG]-->
    body_tail = self.last_states[self.rule_name]
    after_tag = self.fst.add_state()
    end_in, end_out = register_marker("__end__" + tag_text)
    self.fst.add_arc(
        body_tail,
        fst.Arc(end_in, end_out, self.weight_one, after_tag))
    self.last_states[self.rule_name] = after_tag