Ejemplo n.º 1
0
def single_state_transducer(transition_weights,
                            row_vocab,
                            col_vocab,
                            input_symbols=None,
                            output_symbols=None,
                            arc_type='standard'):
    """ Build a one-state transducer whose arcs are all self-loops.

    The single state is both the start state and a final state (final weight
    is the semiring one). For every entry ``transition_weights[i][j]`` whose
    weight is not the semiring zero, a self-loop labelled
    ``row_vocab[i] : col_vocab[j]`` is added with that weight.

    Parameters
    ----------
    transition_weights : iterable of iterables
        Row ``i`` holds the weights for input symbol ``row_vocab[i]``; entry
        ``j`` of that row is the weight for output symbol ``col_vocab[j]``.
    row_vocab, col_vocab : sequence
        Symbol strings looked up in the attached input / output tables.
    input_symbols, output_symbols : openfst.SymbolTable, optional
        Attached to the machine. Labels are found by calling ``.find`` on the
        attached tables, so presumably these must not be None in practice.
    arc_type : str, optional
        OpenFst arc type, default ``'standard'``.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If ``fst.verify()`` reports an ill-formed machine.
    """
    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    weight_type = fst.weight_type()
    zero = openfst.Weight.zero(weight_type)
    one = openfst.Weight.one(weight_type)

    hub = fst.add_state()
    fst.set_start(hub)
    fst.set_final(hub, one)

    isyms = fst.input_symbols()
    osyms = fst.output_symbols()
    for row_index, weight_row in enumerate(transition_weights):
        for col_index, raw_weight in enumerate(weight_row):
            weight = openfst.Weight(weight_type, raw_weight)
            ilabel = isyms.find(row_vocab[row_index])
            olabel = osyms.find(col_vocab[col_index])
            if weight == zero:
                # Semiring-zero arcs would be dead; leave them out.
                continue
            fst.add_arc(hub, openfst.Arc(ilabel, olabel, weight, hub))

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Ejemplo n.º 2
0
def fromSequence(seq, arc_type='standard', symbol_table=None):
    """ Build a linear-chain machine that spells out ``seq``.

    Each element ``label`` of ``seq`` becomes one arc labelled
    ``label + 1 : label + 1`` with weight one. The shift presumably keeps
    label 0 (epsilon in OpenFst) free. The last state of the chain is final
    with weight one; an empty ``seq`` yields a single start-and-final state.

    Parameters
    ----------
    seq : iterable of int
        Zero-based labels.
    arc_type : str, optional
        OpenFst arc type, default ``'standard'``.
    symbol_table : openfst.SymbolTable, optional
        If given, attached as both the input and the output symbol table.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If ``fst.verify()`` reports an ill-formed machine.
    """
    fst = openfst.VectorFst(arc_type=arc_type)
    one = openfst.Weight.one(fst.weight_type())

    if symbol_table is not None:
        fst.set_input_symbols(symbol_table)
        fst.set_output_symbols(symbol_table)

    state = fst.add_state()
    fst.set_start(state)

    for label in seq:
        successor = fst.add_state()
        symbol = label + 1
        fst.add_arc(state, openfst.Arc(symbol, symbol, one, successor))
        state = successor

    fst.set_final(state, one)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Ejemplo n.º 3
0
def make_event_to_assembly_fst(weights,
                               input_vocab,
                               output_vocab,
                               input_parts_to_str,
                               output_parts_to_str,
                               init_weights=None,
                               final_weights=None,
                               input_symbols=None,
                               output_symbols=None,
                               state_tx_in_input=False,
                               eps_str='ε',
                               dur_internal_str='I',
                               dur_final_str='F',
                               bos_str='<BOS>',
                               eos_str='<EOS>',
                               arc_type='standard'):
    """ Build a segmental transducer over (input, output) segments.

    Each pair ``(i_input, i_output)`` gets a segment-internal state and a
    segment-final state:

    * the internal state has a self-loop labelled with ``dur_internal_str``
      and an arc labelled with ``dur_final_str`` into the segment-final state;
    * segments are entered from the start state on ``bos_str`` and left to
      the (single) final state on ``eos_str``;
    * the segment-final state of ``(i_input_cur, i_output_cur)`` is connected
      to both states of every ``(i_input_next, i_output_next)`` with weight
      ``weights[i_input_cur, i_output_cur, i_output_next]``. Those arcs are
      labelled with the output transition ``cur -> next``.

    Segments whose weights to every next output are all the semiring zero
    get their two states but no internal, BOS or EOS arcs.

    Parameters
    ----------
    weights : array_like, shape (num_inputs, num_outputs, num_outputs)
        Segment-to-segment transition weights indexed
        ``[i_input_cur, i_output_cur, i_output_next]``. Also negated and passed
        to ``scipy.special.logsumexp``, so presumably a numpy array of
        negative-log weights -- confirm.
    input_vocab, output_vocab : sequence
        Input / output segment labels.
    input_parts_to_str : mapping
        Maps ``(istr, dur_str)`` -- or, when ``state_tx_in_input`` is true,
        ``(istr, cur_ostr, next_ostr, dur_str)`` -- to an input symbol string.
    output_parts_to_str : mapping
        Maps ``(istr, cur_ostr, next_ostr, dur_str)`` to an output symbol
        string.
    init_weights, final_weights : array_like, optional
        Indexed ``[i_input, i_output]``; weights of the BOS / EOS arcs of each
        segment. Default is the semiring one for every segment.
    input_symbols, output_symbols : openfst.SymbolTable, optional
        Attached to the machine and used for every label lookup.
    state_tx_in_input : bool, optional
        If true, input labels also encode the output transition.
    eps_str : str, optional
        Not used; kept for interface compatibility.
    dur_internal_str, dur_final_str : str, optional
        Duration markers for segment-internal and segment-final frames.
    bos_str, eos_str : str, optional
        Symbols on the arcs leaving the start state / entering the final state.
    arc_type : str, optional
        OpenFst arc type, default ``'standard'``.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If ``fst.verify()`` reports an ill-formed machine.
    """

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    init_state = fst.add_state()
    final_state = fst.add_state()
    fst.set_start(init_state)
    fst.set_final(final_state, one)

    # Total outgoing weight of each (input, output) segment, combined over the
    # next output with logsumexp. A segment whose total is the semiring zero
    # has no outgoing transitions, so its internal arcs are not built.
    W_a_s = -scipy.special.logsumexp(-weights, axis=-1)

    def io_arc(istr, cur_ostr, next_ostr, dur_str, weight, next_state):
        """Arc labelled by an (input, output transition, duration) tuple."""
        if state_tx_in_input:
            arc_istr = input_parts_to_str[istr, cur_ostr, next_ostr, dur_str]
        else:
            arc_istr = input_parts_to_str[istr, dur_str]
        arc_ostr = output_parts_to_str[istr, cur_ostr, next_ostr, dur_str]
        return openfst.Arc(fst.input_symbols().find(arc_istr),
                           fst.output_symbols().find(arc_ostr), weight,
                           next_state)

    def boundary_arc(symbol, weight, next_state):
        """Arc labelled ``symbol : symbol`` (used for the BOS / EOS arcs)."""
        return openfst.Arc(fst.input_symbols().find(symbol),
                           fst.output_symbols().find(symbol), weight,
                           next_state)

    def make_states(i_input, i_output):
        """Add one segment's [internal, final] state pair and its arcs."""
        seg_internal_state = fst.add_state()
        seg_final_state = fst.add_state()

        if openfst.Weight(fst.weight_type(), W_a_s[i_input, i_output]) == zero:
            return [seg_internal_state, seg_final_state]

        state_istr = input_vocab[i_input]
        state_ostr = output_vocab[i_output]

        # Initial state -> seg internal state
        if init_weights is None:
            init_weight = one
        else:
            init_weight = openfst.Weight(fst.weight_type(),
                                         init_weights[i_input, i_output])
        if init_weight != zero:
            fst.add_arc(init_state,
                        boundary_arc(bos_str, init_weight, seg_internal_state))

        # (in, I) : (out, I), weight one, self transition
        fst.add_arc(seg_internal_state,
                    io_arc(state_istr, state_ostr, state_ostr,
                           dur_internal_str, one, seg_internal_state))

        # (in, F) : (out, F), weight one, transition into seg final state
        fst.add_arc(seg_internal_state,
                    io_arc(state_istr, state_ostr, state_ostr,
                           dur_final_str, one, seg_final_state))

        # seg final state -> final_state
        if final_weights is None:
            final_weight = one
        else:
            final_weight = openfst.Weight(fst.weight_type(),
                                          final_weights[i_input, i_output])
        if final_weight != zero:
            fst.add_arc(seg_final_state,
                        boundary_arc(eos_str, final_weight, final_state))

        return [seg_internal_state, seg_final_state]

    # Build segmental backbone: states[i_input][i_output] = [internal, final]
    states = [[
        make_states(i_input, i_output)
        for i_output, _ in enumerate(output_vocab)
    ] for i_input, _ in enumerate(input_vocab)]

    # Add transitions from seg-final (input, output) into both states of every
    # (input, output) segment. The weight depends on the current input, the
    # current output and the next output, but not on the next input.
    num_inputs = len(weights)
    dur_strs = (dur_internal_str, dur_final_str)
    for i_input_cur, arr in enumerate(weights):
        for i_output_cur, row in enumerate(arr):
            from_state = states[i_input_cur][i_output_cur][1]
            # The label records the output transition cur -> next, so the
            # "current" part must come from i_output_cur (it previously
            # duplicated the next output).
            cur_ostr = output_vocab[i_output_cur]
            for i_output_next, tx_weight in enumerate(row):
                weight = openfst.Weight(fst.weight_type(), tx_weight)
                if weight == zero:
                    continue
                next_ostr = output_vocab[i_output_next]
                for i_input_next in range(num_inputs):
                    istr = input_vocab[i_input_next]
                    # i_dur 0 -> seg internal state, 1 -> seg final state
                    for i_dur, dur_str in enumerate(dur_strs):
                        to_state = states[i_input_next][i_output_next][i_dur]
                        fst.add_arc(from_state,
                                    io_arc(istr, cur_ostr, next_ostr,
                                           dur_str, weight, to_state))

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Ejemplo n.º 4
0
def make_duration_fst(final_weights,
                      class_vocab,
                      class_dur_to_str,
                      dur_internal_str='I',
                      dur_final_str='F',
                      input_symbols=None,
                      output_symbols=None,
                      allow_self_transitions=True,
                      arc_type='standard'):
    """ Build a transducer that tags each class-labelled frame with a duration marker.

    For every class ``c`` a left-to-right chain is built. Each frame of a
    segment reads ``class_vocab[c]`` and writes either the class's "internal"
    duration symbol (the segment continues) or its "final" duration symbol
    (the segment ends). Segments are chained through epsilon (label 0) arcs.
    When ``allow_self_transitions`` is false, a segment cannot be followed
    directly by another segment of the same class.

    Parameters
    ----------
    final_weights : numpy.ndarray, shape (num_classes, num_durations)
        ``final_weights[c, d]`` is compared against the semiring zero to
        decide whether a segment of class ``c`` may last exactly ``d`` frames.
        Index 0 (a zero-length segment) is never used. Only the zero/non-zero
        pattern matters: the exit arcs themselves carry weight one.
    class_vocab : sequence of str
        Input symbol string of each class.
    class_dur_to_str : mapping
        Maps ``(class_str, dur_str)`` to an output symbol string.
    dur_internal_str, dur_final_str : str, optional
        Duration markers combined with the class string via
        ``class_dur_to_str``.
    input_symbols, output_symbols : openfst.SymbolTable, optional
        Attached to the machine and used for every label lookup.
    allow_self_transitions : bool, optional
        Whether a class segment may follow a segment of the same class.
    arc_type : str, optional
        OpenFst arc type, default ``'standard'``.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    AssertionError
        If some class has no duration ``>= 1`` with a non-zero final weight.
    openfst.FstError
        If ``fst.verify()`` reports an ill-formed machine.
    """
    # Unpacking (rather than indexing) keeps the implicit 2-D shape check.
    num_classes, _ = final_weights.shape

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    one = openfst.Weight.one(fst.weight_type())
    zero = openfst.Weight.zero(fst.weight_type())

    init_state = fst.add_state()
    final_state = fst.add_state()
    fst.set_start(init_state)
    fst.set_final(final_state, one)

    def add_class_chain(label_str, dur_internal_str, dur_final_str,
                        final_weights):
        """ Add the duration chain for one class.

        Parameters
        ----------
        label_str : str
            Input symbol read on every frame of the segment.
        dur_internal_str, dur_final_str : str
            Output symbols written on continuing / ending frames.
        final_weights : numpy.ndarray, 1-D
            Entry ``d`` is non-zero iff a duration of ``d`` frames is allowed.

        Returns
        -------
        (entry_state, exit_state) : tuple of int
        """
        input_label = fst.input_symbols().find(label_str)
        output_label_int = fst.output_symbols().find(dur_internal_str)
        output_label_ext = fst.output_symbols().find(dur_final_str)

        allowed_durs = np.nonzero(final_weights != float(zero))[0]
        # An all-zero row would otherwise fail inside .max() on an empty array.
        max_dur = allowed_durs.max() if allowed_durs.size else -1
        if max_dur < 1:
            raise AssertionError(f"max_dur = {max_dur}, but should be >= 1)")

        # states[k] is reached after k frames of the segment have been read.
        states = tuple(fst.add_state() for __ in range(max_dur))
        seg_final_state = fst.add_state()
        fst.add_arc(init_state, openfst.Arc(0, 0, one, states[0]))
        fst.add_arc(seg_final_state, openfst.Arc(0, 0, one, final_state))

        for k, cur_state in enumerate(states):
            # Leaving states[k] reads frame k + 1, so this exit ends a segment
            # of duration k + 1 and is gated by final_weights[k + 1]. The
            # index is at most max_dur, which is a valid index by construction.
            dur_weight = openfst.Weight(fst.weight_type(), final_weights[k + 1])
            if dur_weight != zero:
                arc = openfst.Arc(input_label, output_label_ext, one,
                                  seg_final_state)
                fst.add_arc(cur_state, arc)

            if k + 1 < len(states):
                arc = openfst.Arc(input_label, output_label_int, one,
                                  states[k + 1])
                fst.add_arc(cur_state, arc)

        return states[0], seg_final_state

    endpoints = tuple(
        add_class_chain(
            class_vocab[i],
            class_dur_to_str[class_vocab[i], dur_internal_str],
            class_dur_to_str[class_vocab[i], dur_final_str],
            final_weights[i],
        ) for i in range(num_classes))

    # Epsilon arcs from the end of every class segment to the start of every
    # (optionally different) class segment.
    for i, (_, s_cur_last) in enumerate(endpoints):
        for j, (s_next_first, _) in enumerate(endpoints):
            if not allow_self_transitions and i == j:
                continue
            fst.add_arc(s_cur_last, openfst.Arc(0, 0, one, s_next_first))

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Ejemplo n.º 5
0
def fromArray(weights,
              row_vocab,
              col_vocab,
              final_weight=None,
              arc_type=None,
              input_symbols=None,
              output_symbols=None):
    """ Instantiate a linear-chain state machine from an array of weights.

    State ``t`` is connected to state ``t + 1`` by one arc per non-zero entry
    of ``weights[t]``. The arc for entry ``j`` is labelled
    ``row_vocab[t] : col_vocab[j]`` and carries ``weights[t, j]``.

    Parameters
    ----------
    weights : array_like, shape (num_rows, num_cols)
        Must implement ``.ndim`` and ``.shape`` (e.g. a numpy array).
        Entries equal to the semiring zero produce no arc.
    row_vocab : sequence
        ``row_vocab[t]`` is the input symbol on every arc leaving state ``t``.
    col_vocab : sequence
        ``col_vocab[j]`` is the output symbol of the arc for column ``j``.
    final_weight : optional
        Final weight of the last state. Default is the semiring one.
    arc_type : {'standard', 'log'}, optional
        Default is 'standard' (ie the tropical semiring).
    input_symbols, output_symbols : openfst.SymbolTable, optional
        Attached to the machine and used to look up the vocab entries.

    Returns
    -------
    fst : openfst.VectorFst
        A chain of ``num_rows + 1`` states. An empty ``weights`` gives a
        single start state that is also final.

    Raises
    ------
    AssertionError
        If ``weights`` is not two-dimensional.
    openfst.FstError
        If ``fst.verify()`` reports an ill-formed machine.
    """

    if weights.ndim != 2:
        raise AssertionError(
            f"weights have unrecognized shape {weights.shape}")

    if arc_type is None:
        arc_type = 'standard'

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if final_weight is None:
        final_weight = one
    else:
        final_weight = openfst.Weight(fst.weight_type(), final_weight)

    state = fst.add_state()
    fst.set_start(state)

    for sample_index, row in enumerate(weights):
        next_state = fst.add_state()
        # Every arc leaving this state shares the same input label.
        input_label_index = fst.input_symbols().find(row_vocab[sample_index])
        for i, raw_weight in enumerate(row):
            weight = openfst.Weight(fst.weight_type(), raw_weight)
            if weight == zero:
                continue
            output_label_index = fst.output_symbols().find(col_vocab[i])
            arc = openfst.Arc(input_label_index, output_label_index,
                              weight, next_state)
            fst.add_arc(state, arc)
        state = next_state

    # `state` is the last state added, or the start state when there were no
    # rows (the original left the loop variable unbound in that case).
    fst.set_final(state, final_weight)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Ejemplo n.º 6
0
def fromTransitions(transition_weights,
                    row_vocab,
                    col_vocab,
                    init_weights=None,
                    final_weights=None,
                    input_symbols=None,
                    output_symbols=None,
                    bos_str='<BOS>',
                    eos_str='<EOS>',
                    eps_str=libfst.EPSILON_STRING,
                    arc_type='standard',
                    transition_ids=None):
    """ Instantiate a state machine from state transitions.

    One FST state is created per row of ``transition_weights``, plus a start
    state and a dedicated final state (final weight one). The arcs are:

    * start -> state ``i``, labelled ``bos_str : col_vocab[i]``, weighted by
      ``init_weights[i]``;
    * state ``i`` -> final, labelled ``eos_str : eps_str``, weighted by
      ``final_weights[i]``;
    * state ``i`` -> state ``j``, labelled ``col_vocab[j] : col_vocab[j]``,
      weighted by ``transition_weights[i, j]``.

    Arcs whose weight is the semiring zero are omitted.

    Parameters
    ----------
    transition_weights : array_like
        Must implement ``.shape``; rows are indexed by the current state and
        entries by the next state (presumably square -- confirm).
    row_vocab : sequence
        Not used; kept for interface compatibility.
    col_vocab : sequence
        Label string of each state, looked up in the output symbol table.
    init_weights, final_weights : sequence, optional
        Per-state weights of the BOS / EOS arcs. Default is one everywhere.
    input_symbols, output_symbols : openfst.SymbolTable, optional
        Attached to the machine and used for every label lookup.
    bos_str, eos_str : str, optional
        Input symbols on the BOS / EOS arcs.
    eps_str : str, optional
        Output symbol on the EOS arcs.
    arc_type : str, optional
        OpenFst arc type, default ``'standard'``.
    transition_ids : optional
        Not used; kept for interface compatibility.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If ``fst.verify()`` reports an ill-formed machine.
    """
    # `transition_ids` used to be filled in with an O(num_states**2) dict
    # that was never read; that dead work has been removed.
    num_states = transition_weights.shape[0]

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if init_weights is None:
        init_weights = tuple(float(one) for __ in range(num_states))

    if final_weights is None:
        final_weights = tuple(float(one) for __ in range(num_states))

    fst.set_start(fst.add_state())
    final_state = fst.add_state()
    fst.set_final(final_state, one)

    bos_index = fst.input_symbols().find(bos_str)
    eos_index = fst.input_symbols().find(eos_str)
    eps_index = fst.output_symbols().find(eps_str)

    def state_label(i):
        """Output-symbol id of state ``i``'s label."""
        return fst.output_symbols().find(col_vocab[i])

    def makeState(i):
        """Add state ``i`` together with its BOS and EOS arcs."""
        state = fst.add_state()

        initial_weight = openfst.Weight(fst.weight_type(), init_weights[i])
        if initial_weight != zero:
            arc = openfst.Arc(bos_index, state_label(i), initial_weight,
                              state)
            fst.add_arc(fst.start(), arc)

        final_weight = openfst.Weight(fst.weight_type(), final_weights[i])
        if final_weight != zero:
            arc = openfst.Arc(eos_index, eps_index, final_weight, final_state)
            fst.add_arc(state, arc)

        return state

    states = tuple(makeState(i) for i in range(num_states))
    for i_cur, row in enumerate(transition_weights):
        cur_state = states[i_cur]
        for i_next, tx_weight in enumerate(row):
            weight = openfst.Weight(fst.weight_type(), tx_weight)
            if weight == zero:
                continue
            # NOTE(review): the input label is an id from the *output* symbol
            # table; this is only correct if both tables assign the same ids
            # to col_vocab entries -- confirm.
            label = state_label(i_next)
            fst.add_arc(cur_state,
                        openfst.Arc(label, label, weight, states[i_next]))

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Ejemplo n.º 7
0
def single_seg_transducer(weights,
                          input_vocab,
                          output_vocab,
                          input_parts_to_str,
                          output_parts_to_str,
                          input_symbols=None,
                          output_symbols=None,
                          eps_str='ε',
                          bos_str='<BOS>',
                          eos_str='<EOS>',
                          dur_internal_str='I',
                          dur_final_str='F',
                          arc_type='standard',
                          pass_input=False):
    """ Build a transducer in which every segment is a loop through one hub.

    The hub state is both the start state and a final state (weight one).
    For every ``(i_input, i_output)`` whose weight in ``weights`` is not the
    semiring zero, a segment state is added with:

    * an arc hub -> segment labelled ``(in, I) : (out, I)``, weight one;
    * a self-loop on the segment state with the same labels and weight;
    * an arc segment -> hub labelled ``(in, F) : (out, F)`` carrying
      ``weights[i_input][i_output]``.

    Parameters
    ----------
    weights : iterable of iterables
        ``weights[i_input][i_output]`` is the weight of the segment's exit arc.
    input_vocab, output_vocab : sequence
        Segment input / output labels.
    input_parts_to_str : mapping
        Maps ``(istr, dur_str)`` to an input symbol string.
    output_parts_to_str : mapping
        Maps ``(istr, ostr, dur_str)`` when ``pass_input`` is true, otherwise
        ``(ostr, dur_str)``, to an output symbol string.
    input_symbols, output_symbols : openfst.SymbolTable, optional
        Attached to the machine and used for every label lookup.
    eps_str, bos_str, eos_str : str, optional
        Not used.
    dur_internal_str, dur_final_str : str, optional
        Duration markers for continuing / ending frames.
    arc_type : str, optional
        OpenFst arc type, default ``'standard'``.
    pass_input : bool, optional
        Whether output symbols also encode the input label.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If ``fst.verify()`` reports an ill-formed machine.
    """
    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    weight_type = fst.weight_type()
    zero = openfst.Weight.zero(weight_type)
    one = openfst.Weight.one(weight_type)

    hub = fst.add_state()
    fst.set_start(hub)
    fst.set_final(hub, one)

    def labels_for(istr, ostr, dur_str):
        """Return the (input id, output id) pair for one duration marker."""
        in_str = input_parts_to_str[istr, dur_str]
        out_key = (istr, ostr, dur_str) if pass_input else (ostr, dur_str)
        out_str = output_parts_to_str[out_key]
        return (fst.input_symbols().find(in_str),
                fst.output_symbols().find(out_str))

    for i_input, row in enumerate(weights):
        for i_output, tx_weight in enumerate(row):
            seg_weight = openfst.Weight(weight_type, tx_weight)
            if seg_weight == zero:
                continue

            seg_state = fst.add_state()
            istr = input_vocab[i_input]
            ostr = output_vocab[i_output]

            # Enter the segment and stay in it: (in, I) : (out, I), weight one.
            ilabel, olabel = labels_for(istr, ostr, dur_internal_str)
            enter_arc = openfst.Arc(ilabel, olabel, one, seg_state)
            fst.add_arc(hub, enter_arc)
            fst.add_arc(seg_state, enter_arc.copy())

            # Leave the segment: (in, F) : (out, F), with the segment weight.
            ilabel, olabel = labels_for(istr, ostr, dur_final_str)
            fst.add_arc(seg_state,
                        openfst.Arc(ilabel, olabel, seg_weight, hub))

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Ejemplo n.º 8
0
def durationFst(
        label, num_states, transition_weights=None, self_weights=None,
        arc_type='standard', symbol_table=None):
    r""" Construct a left-to-right duration WFST for a single label.

    The machine first writes ``label + 1`` on an input-epsilon arc. It then
    reads ``label + 1`` (writing epsilon) on every arc of a chain of
    ``num_states`` states, each of which may also loop on itself. The state
    after the last chain state is final with weight one. The input is usually
    a sequence of segment-level labels, and this machine is used to align
    labels with sample-level scores.

    Parameters
    ----------
    label : int
        Zero-based label; ``label + 1`` is used on the arcs.
    num_states : int
        Number of chain states; must be >= 1.
    transition_weights : sequence, length num_states, optional
        ``transition_weights[i]`` weights the arc from chain state ``i`` to
        its right neighbour; a semiring-zero entry omits the arc. Default is
        one everywhere.
    self_weights : sequence, length num_states, optional
        ``self_weights[i]`` weights the self-loop on chain state ``i``; a
        semiring-zero entry omits the loop. Default is one everywhere.
    arc_type : str, optional
        OpenFst arc type, default ``'standard'``.
    symbol_table : openfst.SymbolTable, optional
        If given, attached as both the input and the output symbol table.

    Returns
    -------
    fst : openfst.VectorFst
        A linear-chain weighted finite-state transducer. Each chain state
        has one self-transition and one transition to its right neighbor,
        i.e. the topology looks like this:
                        __     __     __
                        \/     \/     \/
            [START] --> s1 --> s2 --> s3 --> [END]

    Raises
    ------
    AssertionError
        If ``num_states < 1``.
    openfst.FstError
        If ``fst.verify()`` reports an ill-formed machine.
    """

    if num_states < 1:
        raise AssertionError(f"num_states = {num_states}, but should be >= 1)")

    fst = openfst.VectorFst(arc_type=arc_type)
    one = openfst.Weight.one(fst.weight_type())
    zero = openfst.Weight.zero(fst.weight_type())

    # The entries are converted with openfst.Weight(weight_type, value) below,
    # so default to plain floats rather than Weight objects.
    if transition_weights is None:
        transition_weights = [float(one)] * num_states

    if self_weights is None:
        self_weights = [float(one)] * num_states

    if symbol_table is not None:
        fst.set_input_symbols(symbol_table)
        fst.set_output_symbols(symbol_table)

    init_state = fst.add_state()
    fst.set_start(init_state)

    cur_state = fst.add_state()
    fst.add_arc(init_state, openfst.Arc(EPSILON, label + 1, one, cur_state))

    for i in range(num_states):
        next_state = fst.add_state()

        transition_weight = openfst.Weight(fst.weight_type(),
                                           transition_weights[i])
        if transition_weight != zero:
            arc = openfst.Arc(label + 1, EPSILON, transition_weight, next_state)
            fst.add_arc(cur_state, arc)

        self_weight = openfst.Weight(fst.weight_type(), self_weights[i])
        if self_weight != zero:
            arc = openfst.Arc(label + 1, EPSILON, self_weight, cur_state)
            fst.add_arc(cur_state, arc)

        cur_state = next_state

    fst.set_final(cur_state, one)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Ejemplo n.º 9
0
def fromTransitions(
        transition_weights, init_weights=None, final_weights=None,
        arc_type='standard', transition_ids=None):
    """ Instantiate a state machine from state transitions.

    Every arc is labelled with ``transition_ids[transition] + 1`` (0 is kept
    for epsilon), where a transition is either ``(cur, next)`` or
    ``(-1, next)`` for entering state ``next`` from the start state.
    Entering arcs have an epsilon input label; state-to-state arcs use the
    transition label on both sides. Both attached symbol tables map each
    transition's ``str()`` to its label.

    Parameters
    ----------
    transition_weights : array_like
        Must implement ``.shape``; ``transition_weights[i][j]`` weights the
        arc from state ``i`` to state ``j``. Semiring-zero entries produce
        no arc.
    init_weights : sequence, optional
        Weight of the arc entering each state from the start state; default
        one everywhere.
    final_weights : sequence, optional
        Final weight of each state (states with a zero weight are not final);
        default one everywhere.
    arc_type : str, optional
        OpenFst arc type, default ``'standard'``.
    transition_ids : dict, optional
        Maps transition tuples to zero-based ids. By default the
        ``(cur, next)`` pairs are numbered in row-major order, followed by
        the ``(-1, s)`` entering transitions.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If ``fst.verify()`` reports an ill-formed machine.
    """
    num_states = transition_weights.shape[0]

    if transition_ids is None:
        pairs = [(s_cur, s_next)
                 for s_cur in range(num_states)
                 for s_next in range(num_states)]
        pairs.extend((-1, s) for s in range(num_states))
        transition_ids = {pair: index for index, pair in enumerate(pairs)}

    def make_symbol_table():
        """Epsilon plus one symbol per transition, keyed by its id + 1."""
        table = openfst.SymbolTable()
        table.add_symbol(EPSILON_STRING, key=EPSILON)
        for transition, index in transition_ids.items():
            table.add_symbol(str(transition), key=index + 1)
        return table

    output_table = make_symbol_table()
    input_table = make_symbol_table()

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_table)
    fst.set_output_symbols(output_table)

    weight_type = fst.weight_type()
    zero = openfst.Weight.zero(weight_type)
    one = openfst.Weight.one(weight_type)

    if init_weights is None:
        init_weights = (float(one),) * num_states

    if final_weights is None:
        final_weights = (float(one),) * num_states

    start = fst.add_state()
    fst.set_start(start)

    states = []
    for i in range(num_states):
        state = fst.add_state()

        entry_weight = openfst.Weight(weight_type, init_weights[i])
        if entry_weight != zero:
            entry_label = transition_ids[-1, i] + 1
            fst.add_arc(start,
                        openfst.Arc(EPSILON, entry_label, entry_weight, state))

        exit_weight = openfst.Weight(weight_type, final_weights[i])
        if exit_weight != zero:
            fst.set_final(state, exit_weight)

        states.append(state)

    for i_cur, row in enumerate(transition_weights):
        for i_next, tx_weight in enumerate(row):
            source, target = states[i_cur], states[i_next]
            weight = openfst.Weight(weight_type, tx_weight)
            label = transition_ids[i_cur, i_next] + 1
            if weight != zero:
                fst.add_arc(source, openfst.Arc(label, label, weight, target))

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Ejemplo n.º 10
0
def fromArray(
        weights, final_weight=None, arc_type=None,
        input_symbols=None, output_symbols=None):
    """ Instantiate a linear-chain state machine from an array of weights.

    State ``t`` is connected to state ``t + 1`` by one arc per entry of
    ``weights[t]`` that is not the semiring zero. Labels come from array
    indices shifted by one (0 is left for epsilon):

    * 2-D ``weights[t, j]``: arc labelled ``(t + 1) : (j + 1)``;
    * 3-D ``weights[t, i, j]``: arc labelled ``(i + 1) : (j + 1)``.

    Parameters
    ----------
    weights : array_like, shape (num_samples, num_outputs) or
              (num_samples, num_inputs, num_outputs)
        Needs to implement ``.ndim`` and ``.shape`` (e.g. a numpy array).
    final_weight : optional
        Final weight of the last state. Default is the semiring one.
    arc_type : {'standard', 'log'}, optional
        Default is 'standard' (ie the tropical semiring).
    input_symbols, output_symbols : openfst.SymbolTable, optional
        Attached to the machine as-is; labels are computed from indices, not
        looked up in these tables.

    Returns
    -------
    fst : openfst.VectorFst
        A chain of ``num_samples + 1`` states. An empty ``weights`` gives a
        single start state that is also final.

    Raises
    ------
    AssertionError
        If ``weights`` is neither 2-D nor 3-D.
    openfst.FstError
        If ``fst.verify()`` reports an ill-formed machine.
    """
    if weights.ndim == 3:
        is_lattice = False
    elif weights.ndim == 2:
        is_lattice = True
    else:
        raise AssertionError(f"weights have unrecognized shape {weights.shape}")

    if arc_type is None:
        arc_type = 'standard'

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if final_weight is None:
        final_weight = one
    else:
        final_weight = openfst.Weight(fst.weight_type(), final_weight)

    def sample_arcs(sample_index, sample):
        """Yield (input_label, output_label, raw_weight) for one time step."""
        if is_lattice:
            for j, raw_weight in enumerate(sample):
                yield sample_index + 1, j + 1, raw_weight
        else:
            for i, outputs in enumerate(sample):
                for j, raw_weight in enumerate(outputs):
                    yield i + 1, j + 1, raw_weight

    state = fst.add_state()
    fst.set_start(state)

    for sample_index, sample in enumerate(weights):
        next_state = fst.add_state()
        for input_label, output_label, raw_weight in sample_arcs(sample_index,
                                                                 sample):
            weight = openfst.Weight(fst.weight_type(), raw_weight)
            if weight != zero:
                arc = openfst.Arc(input_label, output_label, weight,
                                  next_state)
                fst.add_arc(state, arc)
        state = next_state

    # `state` is the last state added, or the start state when `weights` is
    # empty (the original left its loop variable unbound in that case).
    fst.set_final(state, final_weight)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst