def makeState(i):
    """Create the FST state for model state ``i`` and hook it to the start.

    Closure body: ``fst``, ``init_weights``, ``final_weights``, ``zero``,
    ``transition_ids`` and ``EPSILON`` are resolved from the enclosing scope.

    If ``init_weights[i]`` is not the semiring zero, an arc with input
    ``EPSILON`` and output ``transition_ids[-1, i] + 1`` is added from the
    start state to the new state. If ``final_weights[i]`` is not zero, the
    new state becomes final with that weight.

    Returns
    -------
    int
        Id of the newly added state.
    """
    new_state = fst.add_state()

    entry_weight = openfst.Weight(fst.weight_type(), init_weights[i])
    if entry_weight != zero:
        # The +1 shift presumably keeps label 0 reserved for epsilon; confirm
        # against the symbol table built by the caller.
        entry_label = transition_ids[-1, i] + 1
        entry_arc = openfst.Arc(EPSILON, entry_label, entry_weight, new_state)
        fst.add_arc(fst.start(), entry_arc)

    exit_weight = openfst.Weight(fst.weight_type(), final_weights[i])
    if exit_weight != zero:
        fst.set_final(new_state, exit_weight)

    return new_state
def single_state_transducer(transition_weights, row_vocab, col_vocab,
                            input_symbols=None, output_symbols=None,
                            arc_type='standard'):
    """Build a one-state transducer whose self-loops encode a weight matrix.

    The single state is both start and final (weight one). Every entry
    ``transition_weights[r][c]`` that is not the semiring zero becomes a
    self-loop with input symbol ``row_vocab[r]`` and output symbol
    ``col_vocab[c]``, looked up in ``input_symbols`` / ``output_symbols``.

    Raises
    ------
    openfst.FstError
        If the constructed machine fails ``verify()``.
    """
    machine = openfst.VectorFst(arc_type=arc_type)
    machine.set_input_symbols(input_symbols)
    machine.set_output_symbols(output_symbols)
    zero = openfst.Weight.zero(machine.weight_type())
    one = openfst.Weight.one(machine.weight_type())

    hub = machine.add_state()
    machine.set_start(hub)
    machine.set_final(hub, one)

    for row_idx, row in enumerate(transition_weights):
        for col_idx, raw_weight in enumerate(row):
            weight = openfst.Weight(machine.weight_type(), raw_weight)
            ilabel = machine.input_symbols().find(row_vocab[row_idx])
            olabel = machine.output_symbols().find(col_vocab[col_idx])
            if weight != zero:
                machine.add_arc(hub, openfst.Arc(ilabel, olabel, weight, hub))

    if not machine.verify():
        raise openfst.FstError("fst.verify() returned False")
    return machine
def gen_trigram_graph(ngram_to_class_file, net_vocab_file, token_file,
                      out_file, add_final_space=False,
                      use_contextual_blanks=False, prevent_epsilons=False,
                      determinize=True):
    """Build a CTC trigram decoding FST, relabel its outputs, and write it.

    Parameters
    ----------
    ngram_to_class_file : str
        Text file with one whitespace-separated integer tuple per line; the
        tuples are passed to ``build_ctc_trigram_decoding_fst_v2``.
    net_vocab_file : str
        Network vocabulary, read with ``read_net_vocab``.
    token_file : str
        Text symbol table, one ``<symbol> <id>`` pair per line. Output labels
        are remapped to these ids and the table is attached as the output
        symbol table. Blank lines are skipped.
    out_file : str
        Path the resulting FST is written to.
    add_final_space : bool
        If True, a new final state is added and every previously final state
        gets an arc (input label 0, output ``<spc>``) into it, allowing an
        optional trailing space.
    use_contextual_blanks, prevent_epsilons, determinize : bool
        Forwarded to ``build_ctc_trigram_decoding_fst_v2``
        (``use_contextual_blanks`` as ``use_context_blanks``).
    """
    net_vocab = read_net_vocab(net_vocab_file)
    print("net vocab", net_vocab)
    N = len(net_vocab)
    with open(ngram_to_class_file, 'r') as f:
        trigrams = [tuple([int(n) for n in line.split()]) for line in f]

    CTC = build_ctc_trigram_decoding_fst_v2(
        N,
        trigrams,
        arc_type='standard',
        use_context_blanks=use_contextual_blanks,
        prevent_epsilons=prevent_epsilons,
        determinize=determinize,
        add_syms=False)
    assert CTC.weight_type() == 'tropical'

    # Emitted symbols need to be remapped from net_vocab to token symbols
    # net_vocab[:5] : ['<pad>', '<unk>', '<spc>', 'E', 'T']
    # tokens[:5] : ['<eps> 0', '<spc> 1', '<pad> 2', '<unk> 3', 'E 4']
    # <pad> is unused and gets mapped to eps, <unk> and <spc> change ids,
    # the rest is roughly shifted by 1.
    tokens = {}
    # Use a context manager so the token file is closed deterministically.
    with open(token_file, 'r') as f:
        for line in f:
            parts = line.split()
            if parts:  # tolerate blank lines instead of raising IndexError
                tokens[parts[0]] = int(parts[1])

    net_vocab_dict = {t: i for i, t in enumerate(net_vocab)}
    osym_map = [(i, 0 if t == '<pad>' else tokens[t])
                for t, i in net_vocab_dict.items()]
    CTC.relabel_pairs(ipairs=None, opairs=osym_map)
    print(osym_map)

    CTC_os = fst.SymbolTable.read_text(token_file)
    CTC.set_output_symbols(CTC_os)
    os_eps = CTC_os.find('<eps>')
    assert os_eps == 0

    weight_one = fst.Weight.One('tropical')
    if add_final_space:
        infinity = fst.Weight(CTC.weight_type(), 'infinity')
        # Collect the existing final states BEFORE creating the new one;
        # otherwise the new final state would itself receive an
        # epsilon-input <spc> self-loop (an epsilon cycle).
        final_states = [s for s in CTC.states() if CTC.final(s) != infinity]
        final_space = CTC.add_state()
        CTC.set_final(final_space)
        final_space_arc = fst.Arc(0, CTC_os.find('<spc>'), weight_one,
                                  final_space)
        for s in final_states:
            CTC.add_arc(s, final_space_arc)
    CTC.arcsort('olabel')
    CTC.write(out_file)
def __call__(self, x):
    """Convert a batch of network outputs into one linear-chain FST each.

    After ``transform_output`` (and optional ``log_softmax`` over dim 2) the
    tensor is permuted with ``permute(1, 0, 2)``; this presumably turns a
    (time, batch, classes) tensor into (batch, time, classes) — TODO confirm
    against ``transform_output``. For utterance ``b`` with length ``L`` the
    FST has states ``0..L``; between states ``t`` and ``t + 1`` there is one
    arc per class ``j`` with labels ``j + 1`` and weight ``-logpost[t, j]``.
    State ``L`` is final with weight one.

    Returns
    -------
    list
        The FSTs, also stored in ``self._output``.
    """
    x, xs = transform_output(x)
    # Normalize log-posterior matrices, if necessary
    if self._normalize:
        x = log_softmax(x, dim=2)
    x = x.permute(1, 0, 2).cpu()
    self._output = []
    for logpost, length in zip(x, xs):
        length = int(length)
        # One bulk conversion to Python floats instead of a tensor index +
        # float() per element (same values, far less overhead).
        frames = logpost[:length].tolist()
        f = fst.Fst()
        f.set_start(f.add_state())
        for t, frame in enumerate(frames):
            f.add_state()
            for j, logp in enumerate(frame):
                weight = fst.Weight(f.weight_type(), -logp)
                f.add_arc(
                    t,
                    fst.Arc(
                        j + 1,  # input label
                        j + 1,  # output label
                        weight,  # -logpost[t, j]
                        t + 1,  # nextstate
                    ),
                )
        f.set_final(length, fst.Weight.One(f.weight_type()))
        f.verify()
        self._output.append(f)
    return self._output
def add_arc(fst_in, from_word, to_word, weight):
    """
    Add an arc from the state of ``from_word`` to the state of ``to_word``.

    Note: the change is made **IN PLACE** on ``fst_in``; what is returned is
    whatever ``fst_in.add_arc`` returns (pywrapfst returns the FST itself),
    so copy the original FST first if it must be preserved.

    :param fst_in: <openfst.Fst> to modify
    :param from_word: <str> word whose state the arc leaves
    :param to_word: <str> word whose state the arc enters; its id from the
        FST's input symbol table is used as both input and output label
    :param weight: <float> tropical weight of the new arc
    :return: updated <openfst.Fst>
    """
    # index_fst also yields a state -> word map, which is not needed here.
    word_index, _ = index_fst(fst_in)
    symbols = fst_in.input_symbols()

    source = word_index[from_word]["state_id"]
    target = word_index[to_word]["state_id"]

    new_arc = openfst.Arc(
        lookup_word(to_word, symbols),  # input label
        lookup_word(to_word, symbols),  # output label
        openfst.Weight("tropical", weight),
        target,
    )
    return fst_in.add_arc(source, new_arc)
def makeState(i):
    """Create the FST state for state ``i`` and attach its BOS/EOS arcs.

    Closure body: ``fst``, ``init_weights``, ``final_weights``, ``zero``,
    ``col_vocab``, ``bos_index``, ``eos_index``, ``eps_index`` and
    ``final_state`` are resolved from the enclosing scope.

    If ``init_weights[i]`` is not the semiring zero, an arc
    ``bos_index : <output id of col_vocab[i]>`` is added from the start
    state. If ``final_weights[i]`` is not zero, an arc
    ``eos_index : eps_index`` with that weight leads to ``final_state``.

    Returns
    -------
    int
        Id of the newly added state.
    """
    new_state = fst.add_state()

    entry_weight = openfst.Weight(fst.weight_type(), init_weights[i])
    if entry_weight != zero:
        entry_label = fst.output_symbols().find(col_vocab[i])
        fst.add_arc(fst.start(),
                    openfst.Arc(bos_index, entry_label, entry_weight, new_state))

    exit_weight = openfst.Weight(fst.weight_type(), final_weights[i])
    if exit_weight != zero:
        fst.add_arc(new_state,
                    openfst.Arc(eos_index, eps_index, exit_weight, final_state))

    return new_state
def normalize_fst(f):
    """Return a copy of ``f`` rescaled so its total path weight is one.

    The machine is mapped to the log semiring, and the reverse shortest
    distance of the start state is taken as the total weight ``z`` (the
    log-sum over all successful paths). ``z`` is then subtracted from every
    final weight of the copy, which shifts each complete path by ``-z``.

    If ``f`` has no start state, an unmodified copy is returned.
    """
    f2 = f.copy()
    start = f.start()
    if start < 0:
        # No start state: nothing to normalise (and no distance to index).
        return f2
    distances = fst.shortestdistance(fst.arcmap(f, map_type="to_log"),
                                     reverse=True)
    # Index by the actual start state rather than assuming it is state 0.
    z = float(distances[start])
    weight_type = f2.weight_type()
    for s in f2.states():
        w = f2.final(s)
        f2.set_final(s, fst.Weight(weight_type, float(w) - z))
    return f2
def remove_arc(fst_in, from_word, to_word):
    """
    Removes the arc(s) from the state of ``from_word`` to the state of
    ``to_word``.

    All arcs leaving ``from_word``'s state are deleted and every arc except
    the matching one(s) is added back unchanged: original input label,
    output label and weight (in the FST's own semiring, at full precision)
    are preserved. A message is printed for each removed arc.

    Note: Despite returning an updated FST, this method makes the changes
    **IN PLACE**, so you may want to make a copy of the original FST before
    updating the weights

    :param fst_in: <openfst.Fst> to modify
    :param from_word: <str>
    :param to_word: <str>
    :return: updated <openfst.Fst>
    """
    # make a dict and node_2_word from index_fst()
    fst_dict, node_2_word = index_fst(fst_in)
    from_state = fst_dict[from_word]["state_id"]

    arcs_to_keep = []
    for arc in fst_in.arcs(from_state):
        arc_from_word = node_2_word[from_state]
        arc_to_word = node_2_word[arc.nextstate]
        if arc_from_word == from_word and arc_to_word == to_word:
            print("removing: from_state:{} -> arc:{} -> to_state:{}".format(
                from_state, arc.ilabel, arc.nextstate))
        else:
            # Keep a copy of the arc itself instead of rebuilding it from
            # (ilabel, float(weight)): rebuilding dropped the output label,
            # forced a tropical weight and lost weight precision.
            arcs_to_keep.append(arc.copy())

    # delete all arcs from from_state, then add back the kept ones
    fst_in.delete_arcs(from_state)
    for arc in arcs_to_keep:
        fst_in = fst_in.add_arc(from_state, arc)
    return fst_in
def makelattice(self, fst, startstate, symtable, cost, firstword):
    """Write this subtree into ``fst`` starting at state ``startstate``.

    One arc, labelled ``symtable[self.value]`` on both tapes and weighted by
    ``cost(self.value, firstword)``, jumps from ``startstate`` to
    ``startstate + self.nleafnodes()``. Each child is then written
    recursively; a child's start state is offset by the leaf counts of the
    children before it, and children are always called with
    ``firstword=False``.
    """
    span = self.nleafnodes()
    label = symtable[self.value]
    arc_weight = wfst.Weight(fst.weight_type(), cost(self.value, firstword))
    fst.add_arc(startstate, wfst.Arc(label, label, arc_weight, startstate + span))

    child_start = startstate
    for child in self.nodes:
        child.makelattice(fst, child_start, symtable, cost, firstword=False)
        child_start += child.nleafnodes()
def toFst(self):
    """Convert the HMM graph to an OpenFst object.

    You need to have installed the OpenFst python extension to use this
    method.

    Returns
    -------
    graph : pywrapfst.Fst
        The FST representation of the HMM graph in the log semiring. All
        arcs carry label 0 on both tapes; each HMM transition weight ``w``
        becomes an arc weight ``-w``. A super initial state and a super final
        state, which are not present in the HMM, are added and connected to
        the HMM's initial/final states by arcs of weight one.
    """
    import pywrapfst as fst

    graph = fst.Fst('log')
    super_start = graph.add_state()
    graph.set_start(super_start)
    super_final = graph.add_state()
    graph.set_final(super_final)

    # HMM state id -> FST state id
    to_fst_id = {state.state_id: graph.add_state() for state in self.states}

    for state in self.states:
        for next_id, weight in state.next_states.items():
            arc = fst.Arc(0, 0, fst.Weight('log', -weight), to_fst_id[next_id])
            graph.add_arc(to_fst_id[state.state_id], arc)

    for state in self.init_states:
        entry = fst.Arc(0, 0, fst.Weight.One('log'), to_fst_id[state.state_id])
        graph.add_arc(super_start, entry)

    for state in self.final_states:
        exit_arc = fst.Arc(0, 0, fst.Weight.One('log'), super_final)
        graph.add_arc(to_fst_id[state.state_id], exit_arc)

    return graph
def durationFst(label_str, dur_internal_str, dur_final_str, final_weights):
    """
    Add a left-to-right duration chain for one segment label to ``fst``.

    Closure body: ``fst``, ``init_state``, ``final_state``, ``zero`` and
    ``one`` are resolved from the enclosing scope.

    The chain is entered from ``init_state`` by an epsilon arc. Every chain
    arc consumes ``label_str``; from chain state ``i`` it either continues to
    chain state ``i + 1`` (emitting ``dur_internal_str``) or, when the
    duration's entry in ``final_weights`` is not zero, jumps to a
    segment-final state (emitting ``dur_final_str``). That state leads to
    ``final_state`` by an epsilon arc.

    Parameters
    ----------
    label_str : str
        Input symbol consumed on every chain arc.
    dur_internal_str, dur_final_str : str
        Output symbols for "segment continues" / "segment ends".
    final_weights : numpy.ndarray
        Indexed by segment duration in frames: a segment may end after ``d``
        frames iff ``final_weights[d]`` is not the semiring zero. Index 0
        (a zero-length segment) is never used. Only the zero/non-zero
        pattern matters; the exit arcs carry weight ``one``.

    Returns
    -------
    (int, int)
        The first chain state and the segment-final state.

    Raises
    ------
    AssertionError
        If the largest index with a non-zero weight is < 1.
    """
    input_label = fst.input_symbols().find(label_str)
    output_label_int = fst.output_symbols().find(dur_internal_str)
    output_label_ext = fst.output_symbols().find(dur_final_str)

    # Longest duration (in frames) that may end a segment.
    max_dur = np.nonzero(final_weights != float(zero))[0].max()
    if max_dur < 1:
        raise AssertionError(f"max_dur = {max_dur}, but should be >= 1)")

    # states[i] is reached after i frames, so max_dur chain states allow
    # segments of 1..max_dur frames.
    states = tuple(fst.add_state() for __ in range(max_dur))
    seg_final_state = fst.add_state()
    fst.add_arc(init_state, openfst.Arc(0, 0, one, states[0]))
    fst.add_arc(seg_final_state, openfst.Arc(0, 0, one, final_state))

    for i, cur_state in enumerate(states):
        # Leaving states[i] consumes frame i + 1, i.e. duration i + 1.
        # (Reading final_weights[i] here made duration max_dur unreachable.)
        final_weight = openfst.Weight(fst.weight_type(), final_weights[i + 1])
        if final_weight != zero:
            arc = openfst.Arc(input_label, output_label_ext, one, seg_final_state)
            fst.add_arc(cur_state, arc)
        if i + 1 < len(states):
            arc = openfst.Arc(input_label, output_label_int, one, states[i + 1])
            fst.add_arc(cur_state, arc)

    return states[0], seg_final_state
def normalize_fst(in_fst):
    """Return a copy of ``in_fst`` whose outgoing arcs sum to one per state.

    Arc weights must print (``to_string()``) as ``"<cost> <start> <end>"``.
    The cost is treated as a negative log probability and renormalised over
    each state's outgoing arcs; the two integer fields are copied through.
    Labels, next states, final weights and the start state are copied from
    ``in_fst``. Final weights are not included in the per-state sum.

    If ``in_fst`` (or the result) fails ``verify()``, an error is printed
    and ``in_fst`` is returned unchanged.
    """
    if not in_fst.verify():
        print("ERROR WRONG FST PASSED FOR NORMALIZATION")
        return in_fst

    out_fst = fst.Fst(in_fst.weight_type())
    for state in in_fst.states():
        out_fst.add_state()

        arcs = []
        for arc in in_fst.arcs(state):
            cost, st_t, en_t = arc.weight.to_string().split()
            arcs.append((arc.ilabel, arc.olabel, arc.nextstate,
                         float(cost), int(st_t), int(en_t)))

        if arcs:
            costs = np.array([a[3] for a in arcs])
            best = costs.min()
            # -log(sum_k exp(-cost_k)), shifted by the smallest cost so that
            # large costs do not underflow exp() to 0 (which would make the
            # normaliser 0 and divide by zero).
            log_norm = best - np.log(np.sum(np.exp(best - costs)))
            for ilabel, olabel, nextstate, cost, st_t, en_t in arcs:
                new_weight = fst.Weight(
                    out_fst.weight_type(),
                    str(float(cost - log_norm)) + ' ' + str(st_t) + ' ' +
                    str(en_t))
                out_fst.add_arc(state,
                                fst.Arc(ilabel, olabel, new_weight, nextstate))

        # Copy finality for every state, not only the last one.
        out_fst.set_final(state, in_fst.final(state))

    if in_fst.start() >= 0:
        out_fst.set_start(in_fst.start())

    if out_fst.verify():
        return out_fst
    print("NORM ERROR")
    fst_printout(out_fst)
    return in_fst
def fromArray(weights, row_vocab, col_vocab, final_weight=None, arc_type=None,
              input_symbols=None, output_symbols=None):
    """
    Instantiate a linear-chain state machine from an array of weights.

    Row ``t`` of ``weights`` supplies the arcs from chain state ``t`` to
    chain state ``t + 1``: every entry ``weights[t, i]`` that is not the
    semiring zero becomes an arc with input symbol ``row_vocab[t]`` and
    output symbol ``col_vocab[i]``.

    Parameters
    ----------
    weights : array_like, shape (num_samples, num_outputs)
        Needs to implement `.ndim`/`.shape`, so it should be a numpy array or
        a torch tensor.
    row_vocab : sequence of str
        Input-symbol string of each row, looked up in ``input_symbols``.
    col_vocab : sequence of str
        Output-symbol string of each column, looked up in ``output_symbols``.
    final_weight : optional
        Final weight of the last chain state. Default is the semiring one.
    arc_type : {'standard', 'log'}, optional
        Default is 'standard' (ie the tropical arc_type)
    input_symbols, output_symbols : openfst.SymbolTable
        Attached to the machine and used for the label lookups.

    Returns
    -------
    fst : openfst.VectorFst
        For an empty ``weights`` the machine is the single (final) initial
        state.

    Raises
    ------
    AssertionError
        If ``weights`` is not 2-D.
    openfst.FstError
        If the constructed machine fails ``verify()``.
    """
    if weights.ndim != 2:
        raise AssertionError(
            f"weights have unrecognized shape {weights.shape}")

    if arc_type is None:
        arc_type = 'standard'

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())
    if final_weight is None:
        final_weight = one
    else:
        final_weight = openfst.Weight(fst.weight_type(), final_weight)

    init_state = fst.add_state()
    fst.set_start(init_state)

    prev_state = init_state
    for sample_index, row in enumerate(weights):
        cur_state = fst.add_state()
        # Same input symbol for the whole row: look it up once per row.
        input_label_index = fst.input_symbols().find(row_vocab[sample_index])
        for i, raw_weight in enumerate(row):
            weight = openfst.Weight(fst.weight_type(), raw_weight)
            if weight != zero:
                output_label_index = fst.output_symbols().find(col_vocab[i])
                arc = openfst.Arc(input_label_index, output_label_index,
                                  weight, cur_state)
                fst.add_arc(prev_state, arc)
        prev_state = cur_state
    # prev_state is the last chain state, or init_state for empty weights
    # (previously an UnboundLocalError on cur_state).
    fst.set_final(prev_state, final_weight)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
def fromTransitions(transition_weights, row_vocab, col_vocab,
                    init_weights=None, final_weights=None,
                    input_symbols=None, output_symbols=None,
                    bos_str='<BOS>', eos_str='<EOS>',
                    eps_str=libfst.EPSILON_STRING, arc_type='standard',
                    transition_ids=None):
    """
    Instantiate a state machine from state transitions.

    One FST state is created per model state, plus a start state and a
    single final state (final weight one). Arcs (omitted when their weight
    is the semiring zero):

    * start --(bos_str : col_vocab[i], init_weights[i])--> state i
    * state i --(col_vocab[j] : col_vocab[j], transition_weights[i, j])--> state j
    * state i --(eos_str : eps_str, final_weights[i])--> final state

    Parameters
    ----------
    transition_weights : array, shape (num_states, num_states)
    row_vocab : sequence of str
        Not used by this function; kept for signature compatibility.
    col_vocab : sequence of str
        Symbol string of each state, looked up in the output symbol table.
    init_weights, final_weights : sequence, optional
        Per-state entry / exit weights. Default: the semiring one.
    input_symbols, output_symbols : openfst.SymbolTable
        ``bos_str``/``eos_str`` are looked up in the input table;
        ``eps_str`` and ``col_vocab`` in the output table.
    bos_str, eos_str, eps_str : str
    arc_type : str
        OpenFst arc type, default 'standard'.
    transition_ids : optional
        Not used by this function; kept for signature compatibility.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If the constructed machine fails ``verify()``.
    """
    num_states = transition_weights.shape[0]

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if init_weights is None:
        init_weights = tuple(float(one) for __ in range(num_states))

    if final_weights is None:
        final_weights = tuple(float(one) for __ in range(num_states))

    fst.set_start(fst.add_state())
    final_state = fst.add_state()
    fst.set_final(final_state, one)

    bos_index = fst.input_symbols().find(bos_str)
    eos_index = fst.input_symbols().find(eos_str)
    eps_index = fst.output_symbols().find(eps_str)
    # Output-table id of each state's symbol, looked up once per state
    # instead of once per arc (previously O(num_states^2) lookups).
    state_labels = [fst.output_symbols().find(col_vocab[i])
                    for i in range(num_states)]

    def makeState(i):
        state = fst.add_state()

        initial_weight = openfst.Weight(fst.weight_type(), init_weights[i])
        if initial_weight != zero:
            arc = openfst.Arc(bos_index, state_labels[i], initial_weight, state)
            fst.add_arc(fst.start(), arc)

        final_weight = openfst.Weight(fst.weight_type(), final_weights[i])
        if final_weight != zero:
            arc = openfst.Arc(eos_index, eps_index, final_weight, final_state)
            fst.add_arc(state, arc)

        return state

    states = tuple(makeState(i) for i in range(num_states))
    for i_cur, row in enumerate(transition_weights):
        cur_state = states[i_cur]
        for i_next, tx_weight in enumerate(row):
            weight = openfst.Weight(fst.weight_type(), tx_weight)
            if weight != zero:
                # NOTE(review): the input label is an OUTPUT-table id; this is
                # only meaningful if both symbol tables agree — confirm.
                label = state_labels[i_next]
                arc = openfst.Arc(label, label, weight, states[i_next])
                fst.add_arc(cur_state, arc)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
def durationFst(
        label, num_states, transition_weights=None, self_weights=None,
        arc_type='standard', symbol_table=None):
    r"""
    Construct a left-to-right duration WFST for a single label.

    Parameters
    ----------
    label : int
        Zero-based label index. Arcs carry ``label + 1`` (presumably so that
        label 0 stays reserved for ``EPSILON``).
    num_states : int
        Number of self-looping chain states; must be >= 1.
    transition_weights : sequence, optional
        ``transition_weights[i]`` weights the arc from chain state ``i`` to
        its right neighbour. Default: the semiring one.
    self_weights : sequence, optional
        ``self_weights[i]`` weights the self-loop on chain state ``i``.
        Default: the semiring one.
    arc_type : str
        OpenFst arc type, default 'standard'.
    symbol_table : openfst.SymbolTable, optional
        Attached as both input and output symbol table when given.

    Returns
    -------
    fst : openfst.VectorFst
        A linear-chain weighted finite-state transducer. The arc out of
        [START] emits ``label + 1`` (epsilon input); every chain arc consumes
        ``label + 1`` (epsilon output). Each chain state has one
        self-transition and one transition to its right neighbor, i.e. the
        topology looks like this:

                     __     __     __
                     \/     \/     \/
            [START] --> s1 --> s2 --> s3 --> [END]

        Arcs whose weight is the semiring zero are omitted. [END] is final
        with weight one.

    Raises
    ------
    AssertionError
        If ``num_states < 1``.
    openfst.FstError
        If the constructed machine fails ``verify()``.
    """
    if num_states < 1:
        raise AssertionError(f"num_states = {num_states}, but should be >= 1)")

    fst = openfst.VectorFst(arc_type=arc_type)
    one = openfst.Weight.one(fst.weight_type())
    zero = openfst.Weight.zero(fst.weight_type())

    # Plain floats, so openfst.Weight(...) below always receives a numeric
    # value rather than a Weight object.
    if transition_weights is None:
        transition_weights = [float(one) for __ in range(num_states)]

    if self_weights is None:
        self_weights = [float(one) for __ in range(num_states)]

    if symbol_table is not None:
        fst.set_input_symbols(symbol_table)
        fst.set_output_symbols(symbol_table)

    init_state = fst.add_state()
    fst.set_start(init_state)

    cur_state = fst.add_state()
    arc = openfst.Arc(EPSILON, label + 1, one, cur_state)
    fst.add_arc(init_state, arc)

    for i in range(num_states):
        next_state = fst.add_state()

        transition_weight = openfst.Weight(fst.weight_type(), transition_weights[i])
        if transition_weight != zero:
            arc = openfst.Arc(label + 1, EPSILON, transition_weight, next_state)
            fst.add_arc(cur_state, arc)

        self_weight = openfst.Weight(fst.weight_type(), self_weights[i])
        if self_weight != zero:
            arc = openfst.Arc(label + 1, EPSILON, self_weight, cur_state)
            fst.add_arc(cur_state, arc)

        cur_state = next_state

    fst.set_final(cur_state, one)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
def subtractLogWeights(lhs, rhs):
    """Subtract two weights in probability space.

    ``lhs`` and ``rhs`` are -log probabilities. The result is the weight
    ``-log(exp(-lhs) - exp(-rhs))``, of the same type as ``lhs``. It is
    computed as ``lhs - log(-expm1(lhs - rhs))``, which stays accurate when
    both weights are large (where exp() would underflow).

    Returns the semiring zero when the two weights are equal. Subtracting an
    infinite (zero-probability) ``rhs`` returns ``lhs``.

    Raises
    ------
    ValueError
        If ``lhs`` > ``rhs``, i.e. the difference would be a negative
        probability.
    """
    weight_type = lhs.type()
    lhs_value = float(lhs)
    rhs_value = float(rhs)
    if lhs_value > rhs_value:
        raise ValueError(
            f"cannot subtract a larger probability: lhs={lhs_value} > rhs={rhs_value}")
    if lhs_value == rhs_value:
        # Equal probabilities cancel exactly (also covers inf - inf).
        return openfst.Weight.zero(weight_type)
    # lhs - rhs < 0 here, so -expm1(...) lies in (0, 1].
    difference = lhs_value - np.log(-np.expm1(lhs_value - rhs_value))
    return openfst.Weight(weight_type, float(difference))
def make_event_to_assembly_fst(weights, input_vocab, output_vocab,
                               input_parts_to_str, output_parts_to_str,
                               init_weights=None, final_weights=None,
                               input_symbols=None, output_symbols=None,
                               state_tx_in_input=False, eps_str='ε',
                               dur_internal_str='I', dur_final_str='F',
                               bos_str='<BOS>', eos_str='<EOS>',
                               arc_type='standard'):
    """Build a segmental transducer over (input, output) segment pairs.

    Each (input ``a``, output ``s``) pair gets a "seg-internal" and a
    "seg-final" state. Arc labels are symbol strings looked up in the symbol
    tables. For a part key ``(a, s_cur, s_next, dur)`` the output string is
    ``output_parts_to_str[a, s_cur, s_next, dur]``. The input string is
    ``input_parts_to_str[a, s_cur, s_next, dur]`` when ``state_tx_in_input``
    is True, else ``input_parts_to_str[a, dur]``.

    Per pair (skipped entirely when ``-logsumexp(-weights[a, s, :])`` is the
    semiring zero):

    * start --(bos_str : bos_str, init weight)--> seg-internal
    * seg-internal --key (a, s, s, I), weight one--> seg-internal
    * seg-internal --key (a, s, s, F), weight one--> seg-final
    * seg-final --(eos_str : eos_str, final weight)--> final state

    Between pairs, for every non-zero ``weights[a, s, s_next]`` and every
    next input ``a_next``:

    * seg-final(a, s) --key (a_next, s, s_next, I|F)-->
      seg-internal | seg-final of (a_next, s_next)

    Parameters
    ----------
    weights : array, shape (num_inputs, num_outputs, num_outputs)
    input_vocab, output_vocab : sequence of str
    input_parts_to_str, output_parts_to_str : mapping
    init_weights, final_weights : array, shape (num_inputs, num_outputs), optional
        Default: the semiring one.
    input_symbols, output_symbols : openfst.SymbolTable
    state_tx_in_input : bool
    eps_str : str
        Not used by this function.
    dur_internal_str, dur_final_str, bos_str, eos_str : str
    arc_type : str

    Raises
    ------
    openfst.FstError
        If the constructed machine fails ``verify()``.
    """
    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    init_state = fst.add_state()
    final_state = fst.add_state()
    fst.set_start(init_state)
    fst.set_final(final_state, one)

    # Combined (log-sum) outgoing transition weight of each (input, output)
    W_a_s = -scipy.special.logsumexp(-weights, axis=-1)

    def arc_strs(istr, cur_ostr, next_ostr, dur_str):
        """Return the (input, output) symbol strings for one part key."""
        if state_tx_in_input:
            arc_istr = input_parts_to_str[istr, cur_ostr, next_ostr, dur_str]
        else:
            arc_istr = input_parts_to_str[istr, dur_str]
        arc_ostr = output_parts_to_str[istr, cur_ostr, next_ostr, dur_str]
        return arc_istr, arc_ostr

    def add_labeled_arc(from_state, arc_istr, arc_ostr, weight, to_state):
        """Look up both symbol strings and add the resulting arc."""
        arc = openfst.Arc(fst.input_symbols().find(arc_istr),
                          fst.output_symbols().find(arc_ostr),
                          weight, to_state)
        fst.add_arc(from_state, arc)

    def make_states(i_input, i_output):
        seg_internal_state = fst.add_state()
        seg_final_state = fst.add_state()

        if openfst.Weight(fst.weight_type(), W_a_s[i_input, i_output]) == zero:
            return [seg_internal_state, seg_final_state]

        state_istr = input_vocab[i_input]
        state_ostr = output_vocab[i_output]

        # Initial state -> seg internal state
        if init_weights is None:
            init_weight = one
        else:
            init_weight = openfst.Weight(fst.weight_type(),
                                         init_weights[i_input, i_output])
        if init_weight != zero:
            add_labeled_arc(init_state, bos_str, bos_str, init_weight,
                            seg_internal_state)

        # (in, I) : (out, I), weight one, self transition
        arc_istr, arc_ostr = arc_strs(state_istr, state_ostr, state_ostr,
                                      dur_internal_str)
        add_labeled_arc(seg_internal_state, arc_istr, arc_ostr, one,
                        seg_internal_state)

        # (in, F) : (out, F), weight one, transition into final state
        arc_istr, arc_ostr = arc_strs(state_istr, state_ostr, state_ostr,
                                      dur_final_str)
        add_labeled_arc(seg_internal_state, arc_istr, arc_ostr, one,
                        seg_final_state)

        # seg final state -> final_state
        if final_weights is None:
            final_weight = one
        else:
            final_weight = openfst.Weight(fst.weight_type(),
                                          final_weights[i_input, i_output])
        if final_weight != zero:
            add_labeled_arc(seg_final_state, eos_str, eos_str, final_weight,
                            final_state)

        return [seg_internal_state, seg_final_state]

    # Build segmental backbone: states[i_input][i_output] = [internal, final]
    states = [[
        make_states(i_input, i_output)
        for i_output, _ in enumerate(output_vocab)
    ] for i_input, _ in enumerate(input_vocab)]

    # Add transitions from final (action, assembly) to initial (action, assembly)
    dur_strs = (dur_internal_str, dur_final_str)
    for i_input_cur, arr in enumerate(weights):
        for i_output_cur, row in enumerate(arr):
            from_state = states[i_input_cur][i_output_cur][1]
            # The transition leaves output state i_output_cur; the original
            # code mistakenly took cur_ostr from i_output_next as well.
            cur_ostr = output_vocab[i_output_cur]
            for i_output_next, tx_weight in enumerate(row):
                weight = openfst.Weight(fst.weight_type(), tx_weight)
                if weight == zero:
                    continue  # no arcs; skip the label lookups entirely
                next_ostr = output_vocab[i_output_next]
                for i_input_next in range(len(weights)):
                    istr = input_vocab[i_input_next]
                    for i_dur, dur_str in enumerate(dur_strs):
                        # From Seg-Final to Seg-Final or Seg-Internal
                        to_state = states[i_input_next][i_output_next][i_dur]
                        arc_istr, arc_ostr = arc_strs(istr, cur_ostr,
                                                      next_ostr, dur_str)
                        add_labeled_arc(from_state, arc_istr, arc_ostr,
                                        weight, to_state)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
def get_best_path_SIMP(t_latt, j_latt, join_already_compiled=False,
                       add_path_of_last_resort=False):
    '''
    Compose target lattice ``t_latt`` with join lattice ``j_latt`` and
    return ``get_shortest_path`` of the composition.

    t_latt and j_latt: FST objects in memory

    join_already_compiled: must be True in practice. The False branch
        (compiling the join lattice with the external fstcompile tool) has
        not been ported and ends the process via sys.exit().
    add_path_of_last_resort: if True, j_latt is modified IN PLACE. Arcs with
        a very large weight are added along the shortest path of t_latt
        (shifted by +1 to FST state indexing) wherever j_latt lacks that
        transition, so that composition does not remove all paths. The added
        arcs are not removed afterwards.

    Debug output is printed along the way.
    '''
    if not join_already_compiled:
        comm('%s/fstcompile %s %s.bin' % (tool, j_latt, j_latt))
        sys.exit('complete porting to WRAP case wsoivnsovbnsfb34598t3h')

    if add_path_of_last_resort:
        ## only makes sense if join_already_compiled
        ## In order that composition with precomputed join won't remove all paths, add
        ## emergency path of last resort, which we arbitrarily take to be the best
        ## path through target lattice without regard to join cost:
        path_of_last_resort = get_shortest_path(t_latt)
        # Single-argument print(...) produces the same output on Python 2
        # and 3 (the old print statements were a SyntaxError on Python 3).
        print('')
        print('----------------POLR----------------')
        print(path_of_last_resort)
        print('------------------------------------')
        print('')
        print('convert back to FST indexing')
        path_of_last_resort = [val + 1 for val in path_of_last_resort]
        print('edit J to add emergency arcs')
        POLR_transitions = {}
        for fro, to in zip(path_of_last_resort[:-1], path_of_last_resort[1:]):
            POLR_transitions.setdefault(fro, []).append(to)
        print(POLR_transitions)
        print('here b')
        # Drop transitions that j_latt already has.
        for from_state in j_latt.states():
            if from_state in POLR_transitions:
                for arc in j_latt.arcs(from_state):
                    to_state = arc.nextstate
                    if to_state in POLR_transitions[from_state]:
                        POLR_transitions[from_state].remove(to_state)
        ## what is left must be added:
        print(POLR_transitions)
        # Rebuild instead of deleting while iterating (RuntimeError on Py3).
        POLR_transitions = dict((fro, to_list)
                                for (fro, to_list) in POLR_transitions.items()
                                if to_list != [])
        print(POLR_transitions)
        print(j_latt.weight_type())
        BIGWEIGHT = openfst.Weight(j_latt.weight_type(), 500000000)
        for from_state, to_list in POLR_transitions.items():
            for to_state in to_list:
                # NOTE(review): arcs are labelled with the source state id,
                # presumably matching j_latt's labelling scheme — confirm.
                j_latt.add_arc(
                    from_state,
                    openfst.Arc(from_state, from_state, BIGWEIGHT, to_state))
        assert j_latt.verify()
        ### TODO -- remove added arcs after search to resuse J!!!!!!

    c_latt = openfst.compose(t_latt, j_latt)
    ## TODO check if comp is empty and report nicely
    shortest_path = get_shortest_path(c_latt)
    return shortest_path
def make_states(i_input, i_output):
    """Create the seg-internal / seg-final state pair for one (input, output).

    Closure body: ``fst``, ``W_a_s``, ``zero``, ``one``, ``input_vocab``,
    ``output_vocab``, ``input_parts_to_str``, ``output_parts_to_str``,
    ``init_weights``, ``final_weights``, ``bos_str``, ``eos_str``,
    ``dur_internal_str``, ``dur_final_str``, ``state_tx_in_input``,
    ``init_state`` and ``final_state`` come from the enclosing scope.

    Both states are always created. Unless ``W_a_s[i_input, i_output]`` is
    the semiring zero, the pair is also wired up:

    * start -> internal (BOS, init weight; skipped when zero),
    * an internal self-loop and an internal -> final arc (duration symbols
      ``dur_internal_str`` / ``dur_final_str``, weight one),
    * final -> ``final_state`` (EOS, final weight; skipped when zero).

    Returns
    -------
    list of int
        ``[seg_internal_state, seg_final_state]``.
    """
    internal = fst.add_state()
    seg_end = fst.add_state()

    if openfst.Weight(fst.weight_type(), W_a_s[i_input, i_output]) == zero:
        return [internal, seg_end]

    in_str = input_vocab[i_input]
    out_str = output_vocab[i_output]

    def link(src, istr, ostr, weight, dst):
        fst.add_arc(src, openfst.Arc(fst.input_symbols().find(istr),
                                     fst.output_symbols().find(ostr),
                                     weight, dst))

    def duration_strs(dur_str):
        # The input key repeats the (unchanged) output state only on request.
        if state_tx_in_input:
            istr = input_parts_to_str[in_str, out_str, out_str, dur_str]
        else:
            istr = input_parts_to_str[in_str, dur_str]
        return istr, output_parts_to_str[in_str, out_str, out_str, dur_str]

    entry_weight = (one if init_weights is None else
                    openfst.Weight(fst.weight_type(),
                                   init_weights[i_input, i_output]))
    if entry_weight != zero:
        link(init_state, bos_str, bos_str, entry_weight, internal)

    loop_istr, loop_ostr = duration_strs(dur_internal_str)
    link(internal, loop_istr, loop_ostr, one, internal)

    end_istr, end_ostr = duration_strs(dur_final_str)
    link(internal, end_istr, end_ostr, one, seg_end)

    exit_weight = (one if final_weights is None else
                   openfst.Weight(fst.weight_type(),
                                  final_weights[i_input, i_output]))
    if exit_weight != zero:
        link(seg_end, eos_str, eos_str, exit_weight, final_state)

    return [internal, seg_end]
def fromArray(
        weights, final_weight=None, arc_type=None,
        input_symbols=None, output_symbols=None):
    """
    Instantiate a linear-chain state machine from an array of weights.

    The machine has an initial state plus one state per sample; sample ``t``
    supplies the arcs from chain state ``t`` to chain state ``t + 1``.
    Entries equal to the semiring zero produce no arc. Labels are indices
    shifted by one, so label 0 is left free (presumably for epsilon).

    Parameters
    ----------
    weights : array_like
        Needs to implement `.ndim`/`.shape`, so it should be a numpy array or
        a torch tensor.

        * 2-D, shape (num_samples, num_outputs) -- a lattice: entry
          ``weights[t, j]`` becomes an arc with input label ``t + 1`` and
          output label ``j + 1``.
        * 3-D, shape (num_samples, num_inputs, num_outputs): entry
          ``weights[t, i, j]`` becomes an arc with input label ``i + 1`` and
          output label ``j + 1``.
    final_weight : optional
        Final weight of the last chain state. Default is the semiring one.
    arc_type : {'standard', 'log'}, optional
        Default is 'standard' (ie the tropical arc_type)
    input_symbols, output_symbols : openfst.SymbolTable, optional
        Attached to the machine as-is; labels are not looked up in them.

    Returns
    -------
    fst : openfst.VectorFst
        For an empty ``weights`` the machine is the single (final) initial
        state.

    Raises
    ------
    AssertionError
        If ``weights`` is neither 2-D nor 3-D.
    openfst.FstError
        If the constructed machine fails ``verify()``.
    """
    if weights.ndim == 3:
        is_lattice = False
    elif weights.ndim == 2:
        is_lattice = True
    else:
        raise AssertionError(f"weights have unrecognized shape {weights.shape}")

    if arc_type is None:
        arc_type = 'standard'

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())
    if final_weight is None:
        final_weight = one
    else:
        final_weight = openfst.Weight(fst.weight_type(), final_weight)

    def sample_entries(sample_index, sample):
        """Yield (input_label, output_label, raw_weight) for one sample."""
        if is_lattice:
            for j, raw_weight in enumerate(sample):
                yield sample_index + 1, j + 1, raw_weight
        else:
            for i, outputs in enumerate(sample):
                for j, raw_weight in enumerate(outputs):
                    yield i + 1, j + 1, raw_weight

    init_state = fst.add_state()
    fst.set_start(init_state)

    prev_state = init_state
    for sample_index, sample in enumerate(weights):
        cur_state = fst.add_state()
        for input_label_index, output_label_index, raw_weight in sample_entries(
                sample_index, sample):
            weight = openfst.Weight(fst.weight_type(), raw_weight)
            if weight != zero:
                arc = openfst.Arc(
                    input_label_index, output_label_index,
                    weight, cur_state
                )
                fst.add_arc(prev_state, arc)
        prev_state = cur_state
    # prev_state is the last chain state, or init_state for empty weights
    # (previously an UnboundLocalError on cur_state).
    fst.set_final(prev_state, final_weight)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
def readHtkLattice(lattfile, ac_weight=1., lm_weight=1., gzipped=True):
    """Read a HTK lattice file and represent it as a OpenFst object.

    The ``N=`` header line creates the states; node 0 becomes the start state
    and node ``N - 1`` the (only) final state. Every line starting with
    ``J`` becomes an arc from node ``S`` to node ``E``, labelled on both
    tapes with an id for its word ``W``. The arc is weighted
    ``-(ac_weight * a + lm_weight * l)`` in the log semiring. Arc fields are
    read by name (short or long SLF names), so their order and optional
    fields such as ``v=`` do not matter. A missing acoustic or LM score
    counts as 0. All other lines, including blank ones, are ignored.

    Parameters
    ----------
    lattfile : string
        Path to the HTK lattice file.
    ac_weight : float
        Acoustic weight (default: 1).
    lm_weight : float
        Language model weight (default : 1).
    gzipped : boolean
        If the lattice file is compressed with gzip (default: True).

    Returns
    -------
    fst_lattice : pywrapfst.Fst
        The lattice as an OpenFst object.
    id2label : dict
        Mapping from arc label id to word, relative to this lattice only
        (id 0 is ``"!NULL"``).
    """
    # Make the import here only as some people may not have openfst installed.
    import pywrapfst as fst

    def field(fields, names, default=None):
        """Return the first present value among ``names`` (SLF aliases)."""
        for name in names:
            if name in fields:
                return fields[name]
        if default is None:
            raise KeyError(f"missing field {names[0]}= in lattice arc")
        return default

    # Output fst.
    fst_lattice = fst.Fst("log")

    # The mapping id -> label (and the reverse one) are relative to the
    # lattice only. It avoids to have some global mapping.
    label2id = {"!NULL": 0}
    id2label = {0: "!NULL"}
    identifier = 0

    # If the lattice is compressed with gzip choose the correct function.
    if gzipped:
        my_open = gzip.open
        mode = 'rt'
    else:
        my_open = open
        mode = 'r'

    with my_open(lattfile, mode) as f:
        for line in f:
            line = line.strip()

            # Get the number of nodes. If the field is not specified then
            # the function will fail badly.
            if line.startswith("N="):
                n_nodes = int(line.split()[0].split("=")[-1])

                # Create all the nodes of the fst lattice.
                for _ in range(n_nodes):
                    fst_lattice.add_state()
                fst_lattice.set_start(0)
                fst_lattice.set_final(n_nodes - 1)

            # startswith() instead of line[0]: blank lines raised IndexError.
            elif line.startswith("J"):
                # Load the HTK arc definition as name=value pairs.
                fields = {}
                for item in line.split():
                    name, _, value = item.partition("=")
                    fields[name] = value
                start_node = int(field(fields, ("S", "START")))
                end_node = int(field(fields, ("E", "END")))
                label = field(fields, ("W", "WORD"))

                # Update the mapping id -> label and its reverse.
                try:
                    label_id = label2id[label]
                except KeyError:
                    identifier += 1
                    label_id = identifier
                    label2id[label] = label_id
                    id2label[label_id] = label

                # Add the arc to the FST lattice.
                ac_score = float(field(fields, ("a", "acoustic"), "0"))
                lm_score = float(field(fields, ("l", "language"), "0"))
                score = ac_weight * ac_score + lm_weight * lm_score
                # weight_type must be CALLED; passing the bound method was a bug.
                weight = fst.Weight(fst_lattice.weight_type(), -score)
                arc = fst.Arc(label_id, label_id, weight, end_node)
                fst_lattice.add_arc(start_node, arc)

    return fst_lattice, id2label
def single_seg_transducer(weights, input_vocab, output_vocab,
                          input_parts_to_str, output_parts_to_str,
                          input_symbols=None, output_symbols=None,
                          eps_str='ε', bos_str='<BOS>', eos_str='<EOS>',
                          dur_internal_str='I', dur_final_str='F',
                          arc_type='standard', pass_input=False):
    """Build a transducer that maps one segment at a time around a hub state.

    The hub state is both start and final (weight one). For every
    ``(i_input, i_output)`` whose weight is not the semiring zero, a new
    state ``io`` is added with:

    * hub --(in, I : out, I ; one)--> io, plus the same arc as a self-loop on io;
    * io --(in, F : out, F ; weights[i_input][i_output])--> hub.

    Input strings come from ``input_parts_to_str[in, dur]``. Output strings
    come from ``output_parts_to_str[in, out, dur]`` when ``pass_input`` is
    True, else ``output_parts_to_str[out, dur]``. ``eps_str``, ``bos_str``
    and ``eos_str`` are not used.

    Raises
    ------
    openfst.FstError
        If the constructed machine fails ``verify()``.
    """
    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)
    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    hub = fst.add_state()
    fst.set_start(hub)
    fst.set_final(hub, one)

    def out_str(state_istr, state_ostr, dur_str):
        if pass_input:
            return output_parts_to_str[state_istr, state_ostr, dur_str]
        return output_parts_to_str[state_ostr, dur_str]

    def labeled_arc(arc_istr, arc_ostr, weight, dst):
        return openfst.Arc(fst.input_symbols().find(arc_istr),
                           fst.output_symbols().find(arc_ostr),
                           weight, dst)

    def make_state(i_input, i_output, weight):
        io_state = fst.add_state()
        state_istr = input_vocab[i_input]
        state_ostr = output_vocab[i_output]

        # CASE 1: (in, I) : (out, I), weight one -- enter io_state and loop
        enter_arc = labeled_arc(
            input_parts_to_str[state_istr, dur_internal_str],
            out_str(state_istr, state_ostr, dur_internal_str),
            one, io_state)
        fst.add_arc(hub, enter_arc)
        fst.add_arc(io_state, enter_arc.copy())

        # CASE 2: (in, F) : (out, F), weight tx_weight -- back to the hub
        exit_arc = labeled_arc(
            input_parts_to_str[state_istr, dur_final_str],
            out_str(state_istr, state_ostr, dur_final_str),
            weight, hub)
        fst.add_arc(io_state, exit_arc)

    for i_input, row in enumerate(weights):
        for i_output, tx_weight in enumerate(row):
            weight = openfst.Weight(fst.weight_type(), tx_weight)
            if weight != zero:
                make_state(i_input, i_output, weight)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")
    return fst
utt_index = fst.Fst(arc_type=b'TripleTropicalWeight') print("utt name " + utt_name) elif line.strip(): #line not empty if len(line.split()) > 1: if line.find(",") > -1: utt_index.add_state() delimeters = list(find_all(line.split()[4], ",")) weight_string = line.split( )[4][:delimeters[0]] + " " + line.split( )[4][delimeters[0] + 1:delimeters[1]] + " " + line.split()[4][delimeters[1] + 1:] utt_index.add_arc( int(line.split()[0]), fst.Arc(int(line.split()[2]), int(line.split()[3]), fst.Weight(utt_index.weight_type(), weight_string), int(line.split()[1]))) else: utt_index.add_state() utt_index.add_arc( int(line.split()[0]), fst.Arc(int(line.split()[2]), int(line.split()[3]), fst.Weight.One(utt_index.weight_type()), int(line.split()[1]))) else: utt_index.add_state() utt_index.set_final(int(line.split()[0]), fst.Weight.One(utt_index.weight_type())) else: #empty line, end of lattice, do search ## masking lattice with phoneme mask ## utt_index.set_start(0)
def fromTransitions(
        transition_weights, init_weights=None, final_weights=None,
        arc_type='standard', transition_ids=None):
    """
    Instantiate a state machine from state transitions.

    Every model state becomes an FST state. Each transition ``(i, j)`` with a
    non-zero weight is an arc from state ``i`` to state ``j``; both its labels
    are ``transition_ids[i, j] + 1``. Entering state ``i`` from the start
    state emits ``transition_ids[-1, i] + 1`` with epsilon input. States whose
    final weight is not zero are final.

    Parameters
    ----------
    transition_weights : array, shape (num_states, num_states)
    init_weights, final_weights : sequence, optional
        Per-state entry and final weights. Default: the semiring one.
    arc_type : str
        OpenFst arc type, default 'standard'.
    transition_ids : dict, optional
        Maps ``(cur, next)`` pairs, and ``(-1, state)`` for entry transitions,
        to zero-based ids. By default all ``(cur, next)`` pairs are numbered
        row-major, followed by the ``(-1, state)`` pairs.

    Returns
    -------
    fst : openfst.VectorFst
        Its input and output symbol tables map ``EPSILON_STRING`` to
        ``EPSILON`` and ``str(pair)`` to ``id + 1``.

    Raises
    ------
    openfst.FstError
        If the constructed machine fails ``verify()``.
    """
    num_states = transition_weights.shape[0]
    if transition_ids is None:
        pairs = [(s_cur, s_next)
                 for s_cur in range(num_states)
                 for s_next in range(num_states)]
        pairs.extend((-1, s) for s in range(num_states))
        transition_ids = {pair: index for index, pair in enumerate(pairs)}

    def build_symbol_table():
        table = openfst.SymbolTable()
        table.add_symbol(EPSILON_STRING, key=EPSILON)
        for transition, index in transition_ids.items():
            table.add_symbol(str(transition), key=index + 1)
        return table

    output_table = build_symbol_table()
    input_table = build_symbol_table()

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_table)
    fst.set_output_symbols(output_table)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if init_weights is None:
        init_weights = (float(one),) * num_states
    if final_weights is None:
        final_weights = (float(one),) * num_states

    fst.set_start(fst.add_state())
    start = fst.start()

    def makeState(i):
        state = fst.add_state()

        entry_weight = openfst.Weight(fst.weight_type(), init_weights[i])
        if entry_weight != zero:
            entry_label = transition_ids[-1, i] + 1
            fst.add_arc(start,
                        openfst.Arc(EPSILON, entry_label, entry_weight, state))

        exit_weight = openfst.Weight(fst.weight_type(), final_weights[i])
        if exit_weight != zero:
            fst.set_final(state, exit_weight)

        return state

    states = tuple(makeState(i) for i in range(num_states))

    for i_cur, row in enumerate(transition_weights):
        for i_next, tx_weight in enumerate(row):
            src, dst = states[i_cur], states[i_next]
            weight = openfst.Weight(fst.weight_type(), tx_weight)
            label = transition_ids[i_cur, i_next] + 1
            if weight != zero:
                fst.add_arc(src, openfst.Arc(label, label, weight, dst))

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst