def single_state_transducer(transition_weights, row_vocab, col_vocab, input_symbols=None, output_symbols=None, arc_type='standard'):
    """Build a one-state transducer whose self-loops encode a weight matrix.

    The machine has a single state that is both start and final (with the
    semiring one). Every entry ``transition_weights[r][c]`` that is not the
    semiring zero becomes a self-loop on that state mapping the input symbol
    ``row_vocab[r]`` to the output symbol ``col_vocab[c]``.

    Parameters
    ----------
    transition_weights : iterable of iterables
        Row-major weights; each value is passed to ``openfst.Weight``.
    row_vocab : sequence of str
        Input symbol string for each row.
    col_vocab : sequence of str
        Output symbol string for each column.
    input_symbols, output_symbols : openfst.SymbolTable, optional
        Tables used to turn the vocab strings into integer labels.
    arc_type : str, optional
        OpenFst arc type (default 'standard').

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If the constructed machine fails ``fst.verify()``.
    """
    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    semiring_zero = openfst.Weight.zero(fst.weight_type())
    semiring_one = openfst.Weight.one(fst.weight_type())

    hub = fst.add_state()
    fst.set_start(hub)
    fst.set_final(hub, semiring_one)

    for row_idx, row in enumerate(transition_weights):
        for col_idx, raw_weight in enumerate(row):
            arc_weight = openfst.Weight(fst.weight_type(), raw_weight)
            in_label = fst.input_symbols().find(row_vocab[row_idx])
            out_label = fst.output_symbols().find(col_vocab[col_idx])
            if arc_weight != semiring_zero:
                fst.add_arc(hub, openfst.Arc(in_label, out_label, arc_weight, hub))

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")
    return fst
def fromSequence(seq, arc_type='standard', symbol_table=None):
    """Build a linear-chain machine that accepts exactly the sequence ``seq``.

    Each element ``x`` of ``seq`` becomes one arc labelled ``x + 1`` on both
    the input and output side (the +1 shift presumably keeps label 0 free
    for epsilon — confirm against the symbol tables used by callers).

    Parameters
    ----------
    seq : iterable of int
        Labels to chain together, in order.
    arc_type : str, optional
        OpenFst arc type (default 'standard').
    symbol_table : openfst.SymbolTable, optional
        If given, installed as both the input and output symbol table.

    Returns
    -------
    fst : openfst.VectorFst
        States ``0 .. len(seq)``; the last one is final with weight one.

    Raises
    ------
    openfst.FstError
        If the constructed machine fails ``fst.verify()``.
    """
    fst = openfst.VectorFst(arc_type=arc_type)
    unit = openfst.Weight.one(fst.weight_type())
    if symbol_table is not None:
        fst.set_input_symbols(symbol_table)
        fst.set_output_symbols(symbol_table)

    tail = fst.add_state()
    fst.set_start(tail)
    for label in seq:
        head = fst.add_state()
        shifted = label + 1
        fst.add_arc(tail, openfst.Arc(shifted, shifted, unit, head))
        tail = head
    fst.set_final(tail, unit)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")
    return fst
def make_event_to_assembly_fst(weights, input_vocab, output_vocab, input_parts_to_str, output_parts_to_str, init_weights=None, final_weights=None, input_symbols=None, output_symbols=None, state_tx_in_input=False, eps_str='ε', dur_internal_str='I', dur_final_str='F', bos_str='<BOS>', eos_str='<EOS>', arc_type='standard'):
    """Build a segmental transducer over (input, output) segment pairs.

    For every pair ``(i_input, i_output)`` two states are created: a
    segment-internal state and a segment-final state.

    - The internal state has a weight-one self-loop labelled with the
      ``dur_internal_str`` parts.
    - A weight-one arc labelled with the ``dur_final_str`` parts leads from
      the internal state to the segment-final state.
    - The global start state enters each internal state with a BOS:BOS arc
      weighted by ``init_weights``.
    - Each segment-final state leaves to the global final state with an
      EOS:EOS arc weighted by ``final_weights``.
    - Segment-final states connect to both states of every other pair with
      arcs weighted by ``weights[i_input_cur, i_output_cur, i_output_next]``.

    Parameters
    ----------
    weights : numpy.ndarray, shape (num_inputs, num_outputs, num_outputs)
        Transition weights from (input, current output) to next output. The
        weight does not depend on the next input.
    input_vocab, output_vocab : sequence of str
        Strings for each input / output index.
    input_parts_to_str : dict
        Maps ``(istr, cur_ostr, next_ostr, dur_str)`` (when
        ``state_tx_in_input``) or ``(istr, dur_str)`` to an input symbol.
    output_parts_to_str : dict
        Maps ``(istr, cur_ostr, next_ostr, dur_str)`` to an output symbol.
    init_weights, final_weights : array, shape (num_inputs, num_outputs), optional
        Entry / exit weights per pair. Default is the semiring one.
    input_symbols, output_symbols : openfst.SymbolTable, optional
    state_tx_in_input : bool, optional
        Whether the input label also encodes the output transition.
    eps_str : str, optional
        Currently unused; kept for interface compatibility.
    dur_internal_str, dur_final_str : str, optional
        Duration markers for segment-internal / segment-final frames.
    bos_str, eos_str : str, optional
        Labels for the entry and exit arcs.
    arc_type : str, optional
        OpenFst arc type (default 'standard').

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If the constructed machine fails ``fst.verify()``.
    """
    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)
    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    init_state = fst.add_state()
    final_state = fst.add_state()
    fst.set_start(init_state)
    fst.set_final(final_state, one)

    # -logsumexp(-w) over the next-output axis: the total outgoing transition
    # weight of each (input, output) pair. Pairs whose total is the semiring
    # zero are left without any arcs of their own.
    W_a_s = -scipy.special.logsumexp(-weights, axis=-1)

    def _add_labelled_arc(src, arc_istr, arc_ostr, weight, dst):
        # Resolve label strings through the symbol tables and add the arc.
        arc = openfst.Arc(
            fst.input_symbols().find(arc_istr),
            fst.output_symbols().find(arc_ostr),
            weight, dst)
        fst.add_arc(src, arc)

    def _arc_strings(istr, cur_ostr, next_ostr, dur_str):
        # Look up the (input, output) symbol strings for one frame.
        if state_tx_in_input:
            arc_istr = input_parts_to_str[istr, cur_ostr, next_ostr, dur_str]
        else:
            arc_istr = input_parts_to_str[istr, dur_str]
        arc_ostr = output_parts_to_str[istr, cur_ostr, next_ostr, dur_str]
        return arc_istr, arc_ostr

    def make_states(i_input, i_output):
        # Create [seg_internal_state, seg_final_state] for one pair and wire
        # its entry, self-loop, segment-final and exit arcs.
        seg_internal_state = fst.add_state()
        seg_final_state = fst.add_state()
        if openfst.Weight(fst.weight_type(), W_a_s[i_input, i_output]) == zero:
            return [seg_internal_state, seg_final_state]

        state_istr = input_vocab[i_input]
        state_ostr = output_vocab[i_output]

        # Initial state -> seg internal state
        if init_weights is None:
            init_weight = one
        else:
            init_weight = openfst.Weight(fst.weight_type(), init_weights[i_input, i_output])
        if init_weight != zero:
            _add_labelled_arc(init_state, bos_str, bos_str, init_weight, seg_internal_state)

        # (in, I) : (out, I), weight one, self transition
        arc_istr, arc_ostr = _arc_strings(state_istr, state_ostr, state_ostr, dur_internal_str)
        _add_labelled_arc(seg_internal_state, arc_istr, arc_ostr, one, seg_internal_state)

        # (in, F) : (out, F), weight one, transition into seg final state
        arc_istr, arc_ostr = _arc_strings(state_istr, state_ostr, state_ostr, dur_final_str)
        _add_labelled_arc(seg_internal_state, arc_istr, arc_ostr, one, seg_final_state)

        # seg final state -> final_state
        if final_weights is None:
            final_weight = one
        else:
            final_weight = openfst.Weight(fst.weight_type(), final_weights[i_input, i_output])
        if final_weight != zero:
            _add_labelled_arc(seg_final_state, eos_str, eos_str, final_weight, final_state)

        return [seg_internal_state, seg_final_state]

    # Build segmental backbone: states[i_input][i_output] = [internal, final]
    states = [
        [make_states(i_input, i_output) for i_output, _ in enumerate(output_vocab)]
        for i_input, _ in enumerate(input_vocab)
    ]

    # Add transitions from seg-final (input, output) to seg-internal or
    # seg-final of every (input, output) pair.
    dur_strs = (dur_internal_str, dur_final_str)
    for i_input_cur, arr in enumerate(weights):
        for i_output_cur, row in enumerate(arr):
            from_state = states[i_input_cur][i_output_cur][1]
            # The arc records the output transition cur -> next; previously
            # both ends were taken from i_output_next, which discarded the
            # current output.
            cur_ostr = output_vocab[i_output_cur]
            for i_output_next, tx_weight in enumerate(row):
                weight = openfst.Weight(fst.weight_type(), tx_weight)
                if weight == zero:
                    continue
                next_ostr = output_vocab[i_output_next]
                for i_input_next, _ in enumerate(weights):
                    istr = input_vocab[i_input_next]
                    for i_dur, dur_str in enumerate(dur_strs):
                        to_state = states[i_input_next][i_output_next][i_dur]
                        arc_istr, arc_ostr = _arc_strings(istr, cur_ostr, next_ostr, dur_str)
                        _add_labelled_arc(from_state, arc_istr, arc_ostr, weight, to_state)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")
    return fst
def make_duration_fst(final_weights, class_vocab, class_dur_to_str, dur_internal_str='I', dur_final_str='F', input_symbols=None, output_symbols=None, allow_self_transitions=True, arc_type='standard'):
    """Build a transducer that maps class-label runs to duration markers.

    For each class a left-to-right chain of duration states is built. Each
    step inside the chain reads the class's input symbol and writes the
    ``dur_internal_str`` symbol for that class. The step that leaves the
    chain reads the class symbol and writes the ``dur_final_str`` symbol.
    Chains are joined by label-0 arcs: from the global start state into each
    chain, from each chain's end into the global final state, and from each
    chain's end into every chain's start (excluding itself when
    ``allow_self_transitions`` is False).

    Parameters
    ----------
    final_weights : numpy.ndarray, shape (num_classes, num_states)
        Per-class, per-position exit weights. Only entries that differ from
        the semiring zero matter: they decide which exit arcs exist.
    class_vocab : sequence of str
        Input symbol string for each class.
    class_dur_to_str : dict
        Maps ``(class_str, dur_str)`` to an output symbol string.
    dur_internal_str, dur_final_str : str, optional
        Duration markers for internal / final frames.
    input_symbols, output_symbols : openfst.SymbolTable, optional
    allow_self_transitions : bool, optional
        If False, a chain's end is not connected back to its own start.
    arc_type : str, optional
        OpenFst arc type (default 'standard').

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    AssertionError
        If a class's largest nonzero ``final_weights`` index is below 1.
    openfst.FstError
        If the constructed machine fails ``fst.verify()``.
    """
    # num_states is unpacked but not used below.
    num_classes, num_states = final_weights.shape
    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)
    one = openfst.Weight.one(fst.weight_type())
    zero = openfst.Weight.zero(fst.weight_type())
    init_state = fst.add_state()
    final_state = fst.add_state()
    fst.set_start(init_state)
    fst.set_final(final_state, one)

    def durationFst(label_str, dur_internal_str, dur_final_str, final_weights):
        """Add one class's duration chain to the enclosing ``fst``.

        Parameters
        ----------
        label_str : str
            Input symbol consumed on every arc of the chain.
        dur_internal_str, dur_final_str : str
            Output symbols for non-final and final steps.
        final_weights : numpy.ndarray, 1-D
            This class's row of exit weights; nonzero entries enable exits.

        Returns
        -------
        (first_state, seg_final_state) : tuple of int
            Entry state of the chain and the state reached after its exit arc.
        """
        input_label = fst.input_symbols().find(label_str)
        output_label_int = fst.output_symbols().find(dur_internal_str)
        output_label_ext = fst.output_symbols().find(dur_final_str)
        # Largest index whose weight is not the semiring zero. Raises
        # ValueError if every entry is zero (``.max()`` of an empty array).
        max_dur = np.nonzero(final_weights != float(zero))[0].max()
        if max_dur < 1:
            raise AssertionError(f"max_dur = {max_dur}, but should be >= 1)")
        # NOTE(review): only indices 0..max_dur-1 get a state, so
        # final_weights[max_dur] (the entry that defined max_dur) is never
        # consulted. Exiting from states[i] consumes i + 1 symbols but checks
        # final_weights[i]. This looks like an off-by-one between "index" and
        # "duration"; confirm the intended indexing.
        states = tuple(fst.add_state() for __ in range(max_dur))
        seg_final_state = fst.add_state()
        # Label-0 arcs (presumably epsilon) link the chain to the global
        # start and final states.
        fst.add_arc(init_state, openfst.Arc(0, 0, one, states[0]))
        fst.add_arc(seg_final_state, openfst.Arc(0, 0, one, final_state))
        for i, cur_state in enumerate(states):
            cur_state = states[i]
            final_weight = openfst.Weight(fst.weight_type(), final_weights[i])
            if final_weight != zero:
                # NOTE(review): final_weight only decides whether this exit
                # exists; the arc itself carries `one`, so the numeric value
                # is discarded. Confirm this is intended.
                arc = openfst.Arc(input_label, output_label_ext, one, seg_final_state)
                fst.add_arc(cur_state, arc)
            if i + 1 < len(states):
                # Advance one position in the chain, emitting the internal marker.
                next_state = states[i + 1]
                arc = openfst.Arc(input_label, output_label_int, one, next_state)
                fst.add_arc(cur_state, arc)
        return states[0], seg_final_state

    endpoints = tuple(
        durationFst(
            class_vocab[i],
            class_dur_to_str[class_vocab[i], dur_internal_str],
            class_dur_to_str[class_vocab[i], dur_final_str],
            final_weights[i],
        )
        for i in range(num_classes))

    # Connect the end of every chain to the start of every (other) chain.
    for i, (s_cur_first, s_cur_last) in enumerate(endpoints):
        for j, (s_next_first, s_next_last) in enumerate(endpoints):
            if not allow_self_transitions and i == j:
                continue
            arc = openfst.Arc(0, 0, one, s_next_first)
            fst.add_arc(s_cur_last, arc)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")
    return fst
def fromArray(weights, row_vocab, col_vocab, final_weight=None, arc_type=None, input_symbols=None, output_symbols=None):
    """Instantiate a left-to-right lattice from a 2-D array of weights.

    Row ``t`` of ``weights`` becomes the arcs from state ``t`` to state
    ``t + 1``. There is one arc per column ``c`` whose weight is not the
    semiring zero, labelled ``row_vocab[t] : col_vocab[c]``.

    Parameters
    ----------
    weights : array_like, shape (num_rows, num_cols)
        Must implement ``.ndim`` / ``.shape`` (e.g. a numpy array or torch
        tensor). Each entry is passed to ``openfst.Weight``.
    row_vocab : sequence of str
        Input symbol string for each row.
    col_vocab : sequence of str
        Output symbol string for each column.
    final_weight : float or str, optional
        Weight of the last state, passed to ``openfst.Weight``. Default is
        the semiring one.
    arc_type : {'standard', 'log'}, optional
        Default is 'standard' (i.e. the tropical semiring).
    input_symbols, output_symbols : openfst.SymbolTable, optional
        Tables used to turn vocab strings into integer labels.

    Returns
    -------
    fst : openfst.VectorFst
        With zero rows, the single start state is also final, so the machine
        accepts only the empty string.

    Raises
    ------
    AssertionError
        If ``weights`` is not 2-D.
    openfst.FstError
        If the constructed machine fails ``fst.verify()``.
    """
    if weights.ndim != 2:
        raise AssertionError(
            f"weights have unrecognized shape {weights.shape}")

    if arc_type is None:
        arc_type = 'standard'

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)
    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if final_weight is None:
        final_weight = one
    else:
        final_weight = openfst.Weight(fst.weight_type(), final_weight)

    prev_state = fst.add_state()
    fst.set_start(prev_state)
    for sample_index, row in enumerate(weights):
        cur_state = fst.add_state()
        for i, raw_weight in enumerate(row):
            weight = openfst.Weight(fst.weight_type(), raw_weight)
            if weight == zero:
                continue
            input_label_index = fst.input_symbols().find(row_vocab[sample_index])
            output_label_index = fst.output_symbols().find(col_vocab[i])
            arc = openfst.Arc(input_label_index, output_label_index, weight, cur_state)
            fst.add_arc(prev_state, arc)
        prev_state = cur_state

    # prev_state is the last state created, or the start state when weights
    # has no rows. That case previously raised on an unbound cur_state.
    fst.set_final(prev_state, final_weight)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")
    return fst
def fromTransitions(transition_weights, row_vocab, col_vocab, init_weights=None, final_weights=None, input_symbols=None, output_symbols=None, bos_str='<BOS>', eos_str='<EOS>', eps_str=libfst.EPSILON_STRING, arc_type='standard', transition_ids=None):
    """Instantiate a state machine from state transitions.

    One FST state is created per model state, plus a global start and a
    global final state.

    - The start state enters model state ``s`` with an arc labelled
      ``bos_str : col_vocab[s]`` weighted by ``init_weights[s]``.
    - Model state ``s`` leaves to the final state with an arc labelled
      ``eos_str : eps_str`` weighted by ``final_weights[s]``.
    - Each nonzero ``transition_weights[s, t]`` becomes an arc ``s -> t``
      labelled ``col_vocab[t] : col_vocab[t]``.

    Parameters
    ----------
    transition_weights : array_like, shape (num_states, num_states)
        Must implement ``.shape``; entries are passed to ``openfst.Weight``.
    row_vocab : sequence of str
        Not used by this function.
    col_vocab : sequence of str
        Symbol string for each state (looked up in the output table).
    init_weights, final_weights : sequence of float, optional
        Entry / exit weight per state. Default is the semiring one.
    input_symbols, output_symbols : openfst.SymbolTable, optional
    bos_str, eos_str : str, optional
        Input symbols on the entry / exit arcs.
    eps_str : str, optional
        Output symbol on the exit arcs.
    arc_type : str, optional
        OpenFst arc type (default 'standard').
    transition_ids : dict, optional
        Not used by this function; kept for interface compatibility.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If the constructed machine fails ``fst.verify()``.
    """
    # NOTE: the original built a default `transition_ids` dict here that was
    # never read; that dead computation has been dropped.
    num_states = transition_weights.shape[0]

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)
    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if init_weights is None:
        init_weights = tuple(float(one) for __ in range(num_states))
    if final_weights is None:
        final_weights = tuple(float(one) for __ in range(num_states))

    fst.set_start(fst.add_state())
    final_state = fst.add_state()
    fst.set_final(final_state, one)

    bos_index = fst.input_symbols().find(bos_str)
    eos_index = fst.input_symbols().find(eos_str)
    eps_index = fst.output_symbols().find(eps_str)

    def makeState(i):
        # Create model state i with its BOS entry arc and EOS exit arc.
        state = fst.add_state()
        initial_weight = openfst.Weight(fst.weight_type(), init_weights[i])
        if initial_weight != zero:
            next_state_index = fst.output_symbols().find(col_vocab[i])
            arc = openfst.Arc(bos_index, next_state_index, initial_weight, state)
            fst.add_arc(fst.start(), arc)
        final_weight = openfst.Weight(fst.weight_type(), final_weights[i])
        if final_weight != zero:
            arc = openfst.Arc(eos_index, eps_index, final_weight, final_state)
            fst.add_arc(state, arc)
        return state

    states = tuple(makeState(i) for i in range(num_states))

    for i_cur, row in enumerate(transition_weights):
        cur_state = states[i_cur]
        for i_next, tx_weight in enumerate(row):
            weight = openfst.Weight(fst.weight_type(), tx_weight)
            if weight == zero:
                continue
            # NOTE(review): both labels come from col_vocab via the *output*
            # table, and row_vocab is unused. Confirm this is intended when the
            # input and output symbol tables differ.
            next_state_index = fst.output_symbols().find(col_vocab[i_next])
            arc = openfst.Arc(next_state_index, next_state_index, weight, states[i_next])
            fst.add_arc(cur_state, arc)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")
    return fst
def single_seg_transducer(weights, input_vocab, output_vocab, input_parts_to_str, output_parts_to_str, input_symbols=None, output_symbols=None, eps_str='ε', bos_str='<BOS>', eos_str='<EOS>', dur_internal_str='I', dur_final_str='F', arc_type='standard', pass_input=False):
    """Build a hub-and-spoke transducer with one segment loop per pair.

    A single hub state is both start and final. For every ``(i_input,
    i_output)`` whose weight is not the semiring zero, a spoke state is added
    with three arcs:

    - hub -> spoke, labelled with the internal-duration parts, weight one;
    - spoke -> spoke, the same labels and weight one;
    - spoke -> hub, labelled with the final-duration parts, weighted by
      ``weights[i_input][i_output]``.

    Parameters
    ----------
    weights : iterable of iterables
        Per-(input, output) segment weight, passed to ``openfst.Weight``.
    input_vocab, output_vocab : sequence of str
    input_parts_to_str : dict
        Maps ``(istr, dur_str)`` to an input symbol string.
    output_parts_to_str : dict
        Maps ``(istr, ostr, dur_str)`` when ``pass_input`` else
        ``(ostr, dur_str)`` to an output symbol string.
    input_symbols, output_symbols : openfst.SymbolTable, optional
    eps_str, bos_str, eos_str : str, optional
        Not used by this function.
    dur_internal_str, dur_final_str : str, optional
        Duration markers for internal / final frames.
    arc_type : str, optional
        OpenFst arc type (default 'standard').
    pass_input : bool, optional
        Whether output symbols are keyed on the input string as well.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If the constructed machine fails ``fst.verify()``.
    """
    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)
    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    hub = fst.add_state()
    fst.set_start(hub)
    fst.set_final(hub, one)

    def _out_key(in_str, out_str, dur_str):
        # Output-table key, optionally including the input string.
        return (in_str, out_str, dur_str) if pass_input else (out_str, dur_str)

    for i_in, row in enumerate(weights):
        for i_out, tx_weight in enumerate(row):
            seg_weight = openfst.Weight(fst.weight_type(), tx_weight)
            if not seg_weight != zero:
                continue

            spoke = fst.add_state()
            in_str = input_vocab[i_in]
            out_str = output_vocab[i_out]

            # Internal frame: enter the spoke, then loop on it, at weight one.
            in_sym = input_parts_to_str[in_str, dur_internal_str]
            out_sym = output_parts_to_str[_out_key(in_str, out_str, dur_internal_str)]
            entry_arc = openfst.Arc(
                fst.input_symbols().find(in_sym),
                fst.output_symbols().find(out_sym),
                one, spoke)
            fst.add_arc(hub, entry_arc)
            fst.add_arc(spoke, entry_arc.copy())

            # Final frame: return to the hub, carrying the segment weight.
            in_sym = input_parts_to_str[in_str, dur_final_str]
            out_sym = output_parts_to_str[_out_key(in_str, out_str, dur_final_str)]
            exit_arc = openfst.Arc(
                fst.input_symbols().find(in_sym),
                fst.output_symbols().find(out_sym),
                seg_weight, hub)
            fst.add_arc(spoke, exit_arc)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")
    return fst
def durationFst(
        label, num_states, transition_weights=None, self_weights=None,
        arc_type='standard', symbol_table=None):
    """Construct a left-to-right duration WFST for a single label.

    The machine is usually used to align a segment-level label with
    sample-level scores. Its input side reads one or more copies of the
    label, and its output side writes the label once.

    The topology is a linear chain. Each of the ``num_states`` chain states
    has a self-transition and a transition to its right neighbour::

                   __      __      __
                   \\/      \\/      \\/
        [START] --> s1 --> s2 --> s3 --> [END]

    - START -> s1 is labelled ``EPSILON : label + 1``.
    - Every other arc is labelled ``label + 1 : EPSILON``.
    - Arcs whose weight is the semiring zero are omitted.

    Parameters
    ----------
    label : int
        Label index; it is written on arcs as ``label + 1``.
    num_states : int
        Number of chain states (>= 1).
    transition_weights : sequence, optional
        Weight of the forward arc out of each chain state. Each entry is
        passed to ``openfst.Weight``. Default is the semiring one.
    self_weights : sequence, optional
        Weight of each chain state's self-loop. Default is the semiring one.
    arc_type : str, optional
        OpenFst arc type (default 'standard').
    symbol_table : openfst.SymbolTable, optional
        If given, installed as both the input and output symbol table.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    AssertionError
        If ``num_states < 1``.
    openfst.FstError
        If the constructed machine fails ``fst.verify()``.
    """
    if num_states < 1:
        raise AssertionError(f"num_states = {num_states}, but should be >= 1)")

    fst = openfst.VectorFst(arc_type=arc_type)
    one = openfst.Weight.one(fst.weight_type())
    zero = openfst.Weight.zero(fst.weight_type())

    # Defaults are plain numbers because each entry is re-wrapped with
    # openfst.Weight(weight_type, value) below. Other builders in this file
    # likewise default to float(one).
    if transition_weights is None:
        transition_weights = [float(one) for __ in range(num_states)]
    if self_weights is None:
        self_weights = [float(one) for __ in range(num_states)]

    if symbol_table is not None:
        fst.set_input_symbols(symbol_table)
        fst.set_output_symbols(symbol_table)

    init_state = fst.add_state()
    fst.set_start(init_state)

    cur_state = fst.add_state()
    fst.add_arc(init_state, openfst.Arc(EPSILON, label + 1, one, cur_state))

    for i in range(num_states):
        next_state = fst.add_state()

        transition_weight = openfst.Weight(fst.weight_type(), transition_weights[i])
        if transition_weight != zero:
            fst.add_arc(cur_state, openfst.Arc(label + 1, EPSILON, transition_weight, next_state))

        self_weight = openfst.Weight(fst.weight_type(), self_weights[i])
        if self_weight != zero:
            fst.add_arc(cur_state, openfst.Arc(label + 1, EPSILON, self_weight, cur_state))

        cur_state = next_state

    fst.set_final(cur_state, one)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")
    return fst
def fromTransitions(
        transition_weights, init_weights=None, final_weights=None,
        arc_type='standard', transition_ids=None):
    """Instantiate a state machine whose arcs are labelled with transition ids.

    One FST state is created per model state, plus a start state.

    - The start state enters state ``s`` with an arc labelled
      ``EPSILON : id(-1, s) + 1`` weighted by ``init_weights[s]``.
    - State ``s`` is final with weight ``final_weights[s]``.
    - Each nonzero ``transition_weights[s, t]`` becomes an arc ``s -> t``
      labelled ``id(s, t) + 1`` on both sides.

    Both symbol tables map ``EPSILON_STRING`` to ``EPSILON``, and map
    ``str(transition)`` to ``id + 1``. The +1 offset presumably avoids
    colliding with the EPSILON key; confirm.

    Parameters
    ----------
    transition_weights : array_like, shape (num_states, num_states)
        Must implement ``.shape``; entries are passed to ``openfst.Weight``.
    init_weights, final_weights : sequence of float, optional
        Entry / final weight per state. Default is the semiring one.
    arc_type : str, optional
        OpenFst arc type (default 'standard').
    transition_ids : dict, optional
        Maps ``(s_cur, s_next)`` (with ``s_cur == -1`` for initial
        transitions) to an integer id. By default all ``num_states ** 2``
        pairs are numbered first, then the ``(-1, s)`` entries. Only pairs
        with nonzero weight need to be present.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If the constructed machine fails ``fst.verify()``.
    """
    num_states = transition_weights.shape[0]
    if transition_ids is None:
        transition_ids = {}
        for s_cur in range(num_states):
            for s_next in range(num_states):
                transition_ids[(s_cur, s_next)] = len(transition_ids)
        for s in range(num_states):
            transition_ids[(-1, s)] = len(transition_ids)

    def _make_symbol_table():
        # Fresh table: epsilon plus one symbol per transition id (shifted by 1).
        table = openfst.SymbolTable()
        table.add_symbol(EPSILON_STRING, key=EPSILON)
        for transition, index in transition_ids.items():
            table.add_symbol(str(transition), key=index + 1)
        return table

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(_make_symbol_table())
    fst.set_output_symbols(_make_symbol_table())
    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if init_weights is None:
        init_weights = tuple(float(one) for __ in range(num_states))
    if final_weights is None:
        final_weights = tuple(float(one) for __ in range(num_states))

    fst.set_start(fst.add_state())

    def makeState(i):
        # Create state i, its initial arc from the start state, and its final weight.
        state = fst.add_state()
        initial_weight = openfst.Weight(fst.weight_type(), init_weights[i])
        if initial_weight != zero:
            transition = transition_ids[-1, i] + 1
            arc = openfst.Arc(EPSILON, transition, initial_weight, state)
            fst.add_arc(fst.start(), arc)
        final_weight = openfst.Weight(fst.weight_type(), final_weights[i])
        if final_weight != zero:
            fst.set_final(state, final_weight)
        return state

    states = tuple(makeState(i) for i in range(num_states))

    for i_cur, row in enumerate(transition_weights):
        cur_state = states[i_cur]
        for i_next, tx_weight in enumerate(row):
            weight = openfst.Weight(fst.weight_type(), tx_weight)
            if weight == zero:
                continue
            # Look up the id only for arcs that exist, so a sparse
            # transition_ids mapping does not raise KeyError.
            transition = transition_ids[i_cur, i_next] + 1
            arc = openfst.Arc(transition, transition, weight, states[i_next])
            fst.add_arc(cur_state, arc)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")
    return fst
def fromArray(
        weights, final_weight=None, arc_type=None,
        input_symbols=None, output_symbols=None):
    """Instantiate a left-to-right state machine from an array of weights.

    Sample ``t`` of ``weights`` becomes the arcs from state ``t`` to state
    ``t + 1``. Integer labels are array indices shifted by one; they are not
    looked up in the symbol tables.

    - 2-D ``weights`` (a lattice), shape ``(num_samples, num_outputs)``: each
      nonzero ``weights[t, j]`` gives an arc labelled ``t + 1 : j + 1``.
    - 3-D ``weights``, shape ``(num_samples, num_inputs, num_outputs)``: each
      nonzero ``weights[t, i, j]`` gives an arc labelled ``i + 1 : j + 1``.

    Parameters
    ----------
    weights : array_like
        Must implement ``.ndim`` / ``.shape`` (e.g. a numpy array or torch
        tensor). Entries are passed to ``openfst.Weight``.
    final_weight : float or str, optional
        Weight of the last state, passed to ``openfst.Weight``. Default is
        the semiring one.
    arc_type : {'standard', 'log'}, optional
        Default is 'standard' (i.e. the tropical semiring).
    input_symbols, output_symbols : openfst.SymbolTable, optional
        Installed on the machine as-is.

    Returns
    -------
    fst : openfst.VectorFst
        With zero samples, the single start state is also final.

    Raises
    ------
    AssertionError
        If ``weights`` is neither 2-D nor 3-D.
    openfst.FstError
        If the constructed machine fails ``fst.verify()``.
    """
    if weights.ndim == 3:
        is_lattice = False
    elif weights.ndim == 2:
        is_lattice = True
    else:
        raise AssertionError(f"weights have unrecognized shape {weights.shape}")

    if arc_type is None:
        arc_type = 'standard'

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)
    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if final_weight is None:
        final_weight = one
    else:
        final_weight = openfst.Weight(fst.weight_type(), final_weight)

    prev_state = fst.add_state()
    fst.set_start(prev_state)

    for sample_index, sample in enumerate(weights):
        # Normalize both layouts to (input label, row of output weights) pairs.
        if is_lattice:
            rows = ((sample_index + 1, sample),)
        else:
            rows = ((i + 1, outputs) for i, outputs in enumerate(sample))

        cur_state = fst.add_state()
        for input_label_index, outputs in rows:
            for j, raw_weight in enumerate(outputs):
                weight = openfst.Weight(fst.weight_type(), raw_weight)
                if weight != zero:
                    arc = openfst.Arc(input_label_index, j + 1, weight, cur_state)
                    fst.add_arc(prev_state, arc)
        prev_state = cur_state

    # prev_state is the last state created, or the start state when weights
    # is empty. That case previously raised on an unbound cur_state.
    fst.set_final(prev_state, final_weight)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")
    return fst