Example #1
0
    def makeState(i):
        """Add the FST state for HMM state ``i`` and wire its entry/exit.

        When ``init_weights[i]`` is not the semiring zero, an arc is added from
        the FST start state into the new state with input ``EPSILON``, output
        ``transition_ids[-1, i] + 1`` and that weight. When
        ``final_weights[i]`` is not zero, the new state is made final with it.

        Returns the id of the new state.
        """
        new_state = fst.add_state()
        weight_type = fst.weight_type()

        start_weight = openfst.Weight(weight_type, init_weights[i])
        if start_weight != zero:
            # Transition ids are shifted by one, presumably so that label 0
            # stays reserved for epsilon — TODO confirm.
            out_label = transition_ids[-1, i] + 1
            entry_arc = openfst.Arc(EPSILON, out_label, start_weight, new_state)
            fst.add_arc(fst.start(), entry_arc)

        stop_weight = openfst.Weight(weight_type, final_weights[i])
        if stop_weight != zero:
            fst.set_final(new_state, stop_weight)

        return new_state
Example #2
0
def single_state_transducer(transition_weights,
                            row_vocab,
                            col_vocab,
                            input_symbols=None,
                            output_symbols=None,
                            arc_type='standard'):
    """Build a one-state transducer from a matrix of arc weights.

    The single state is both the start state and a final state (weight one).
    Every cell ``transition_weights[i_input][i_output]`` whose weight is not
    the semiring zero becomes a self-loop on that state. Its input label is
    the id of ``row_vocab[i_input]`` in ``input_symbols``, and its output
    label is the id of ``col_vocab[i_output]`` in ``output_symbols``.

    Raises
    ------
    openfst.FstError
        If the finished FST fails ``verify()``.
    """
    machine = openfst.VectorFst(arc_type=arc_type)
    machine.set_input_symbols(input_symbols)
    machine.set_output_symbols(output_symbols)

    w_type = machine.weight_type()
    zero = openfst.Weight.zero(w_type)
    one = openfst.Weight.one(w_type)

    hub = machine.add_state()
    machine.set_start(hub)
    machine.set_final(hub, one)

    in_table = machine.input_symbols()
    out_table = machine.output_symbols()
    for in_idx, weight_row in enumerate(transition_weights):
        for out_idx, raw_weight in enumerate(weight_row):
            arc_weight = openfst.Weight(w_type, raw_weight)
            in_label = in_table.find(row_vocab[in_idx])
            out_label = out_table.find(col_vocab[out_idx])
            if arc_weight == zero:
                continue
            machine.add_arc(hub, openfst.Arc(in_label, out_label, arc_weight, hub))

    if not machine.verify():
        raise openfst.FstError("fst.verify() returned False")

    return machine
Example #3
0
def gen_trigram_graph(ngram_to_class_file,
                      net_vocab_file,
                      token_file,
                      out_file,
                      add_final_space=False,
                      use_contextual_blanks=False,
                      prevent_epsilons=False,
                      determinize=True):
    """Build a CTC trigram decoding FST and write it to ``out_file``.

    Parameters
    ----------
    ngram_to_class_file : path
        Text file with one whitespace-separated tuple of ints per line. Each
        tuple is passed to ``build_ctc_trigram_decoding_fst_v2`` as a trigram.
    net_vocab_file : path
        Network vocabulary, read with ``read_net_vocab``.
    token_file : path
        Text file of ``<symbol> <id>`` lines. It is used both to remap output
        labels and as the output symbol table.
    out_file : path
        Where the resulting FST is written.
    add_final_space : bool
        If True, every final state gets an optional epsilon-input arc that
        emits ``<spc>`` into a new final state.
    use_contextual_blanks, prevent_epsilons, determinize : bool
        Forwarded to ``build_ctc_trigram_decoding_fst_v2``.
    """
    net_vocab = read_net_vocab(net_vocab_file)
    print("net vocab", net_vocab)
    N = len(net_vocab)

    with open(ngram_to_class_file, 'r') as f:
        trigrams = [tuple([int(n) for n in line.split()]) for line in f]

    CTC = build_ctc_trigram_decoding_fst_v2(
        N,
        trigrams,
        arc_type='standard',
        use_context_blanks=use_contextual_blanks,
        prevent_epsilons=prevent_epsilons,
        determinize=determinize,
        add_syms=False)

    assert CTC.weight_type() == 'tropical'

    # Emitted symbols need to be remapped from net_vocab to token symbols
    #   net_vocab[:5] : ['<pad>', '<unk>', '<spc>', 'E', 'T']
    #   tokens[:5]    : ['<eps> 0', '<spc> 1', '<pad> 2', '<unk> 3', 'E 4']
    # <pad> is unused and gets mapped to eps, <unk> and <spc> change ids,
    # the rest is roughly shifted by 1.
    tokens = {}
    with open(token_file, 'r') as f:
        for line in f:
            fields = line.split()
            if fields:  # tolerate blank lines (e.g. a trailing newline)
                tokens[fields[0]] = int(fields[1])
    net_vocab_dict = {t: i for i, t in enumerate(net_vocab)}
    osym_map = [(i, 0 if t == '<pad>' else tokens[t])
                for t, i in net_vocab_dict.items()]
    CTC.relabel_pairs(ipairs=None, opairs=osym_map)
    print(osym_map)

    CTC_os = fst.SymbolTable.read_text(token_file)
    CTC.set_output_symbols(CTC_os)
    os_eps = CTC_os.find('<eps>')
    assert os_eps == 0

    weight_one = fst.Weight.One('tropical')

    if add_final_space:
        infinity = fst.Weight(CTC.weight_type(), 'infinity')
        # Collect the final states BEFORE adding final_space: final_space is
        # itself final, and visiting it would give it an epsilon-input
        # self-loop that emits an unbounded number of spaces.
        final_states = [s for s in CTC.states() if CTC.final(s) != infinity]
        final_space = CTC.add_state()
        CTC.set_final(final_space)
        final_space_arc = fst.Arc(0, CTC_os.find('<spc>'), weight_one,
                                  final_space)
        for s in final_states:
            CTC.add_arc(s, final_space_arc)

    CTC.arcsort('olabel')
    CTC.write(out_file)
Example #4
0
 def __call__(self, x):
     """Convert a batch of network outputs into linear-chain FSTs.

     ``x`` goes through ``transform_output`` to obtain the score tensor and
     the per-sequence lengths. If ``self._normalize`` is set, the scores are
     log-softmax normalized over dim 2. After ``permute(1, 0, 2)`` each item
     is indexed as ``logpost[t, j]``; presumably the raw layout is
     (time, batch, labels) — TODO confirm against ``transform_output``.

     For each sequence, a chain of ``length + 1`` states is built. Between
     states ``t`` and ``t + 1`` there is one arc per label ``j``, with
     input/output label ``j + 1`` and weight ``-logpost[t, j]``. The last
     state is final with weight one.

     Returns the list of FSTs, which is also stored in ``self._output``.
     """
     x, xs = transform_output(x)
     # Normalize log-posterior matrices, if necessary
     if self._normalize:
         x = log_softmax(x, dim=2)
     x = x.permute(1, 0, 2).cpu()
     self._output = []
     for logpost, length in zip(x, xs):
         # ``length`` may be a 0-d tensor; FST state ids must be plain ints.
         length = int(length)
         # Convert the scores to Python floats once, instead of indexing the
         # tensor element by element inside the inner loop.
         rows = logpost.tolist()
         f = fst.Fst()
         weight_type = f.weight_type()
         f.set_start(f.add_state())
         for t in range(length):
             f.add_state()
             for j, score in enumerate(rows[t]):
                 weight = fst.Weight(weight_type, -score)
                 f.add_arc(t, fst.Arc(j + 1, j + 1, weight, t + 1))
         f.set_final(length, fst.Weight.One(weight_type))
         # NOTE(review): the verify() result is not checked here.
         f.verify()
         self._output.append(f)
     return self._output
Example #5
0
def add_arc(fst_in, from_word, to_word, weight):
    """
    Adds an arc to a given FST.

    Note: Despite returning an updated FST, this method makes the changes
    **IN PLACE**, so you may want to make a copy of the original
    FST before updating the weights.

    :param fst_in: <openfst.Fst> to modify
    :param from_word: <str> word whose state the arc leaves
    :param to_word: <str> word whose state the arc enters; its label (looked
        up in the input symbol table) is used as both input and output label
    :param weight: <float> arc weight, interpreted in ``fst_in``'s own
        weight type
    :return: updated <openfst.Fst> (the same object as ``fst_in``)
    """
    # index_fst() maps words to their state ids; node_2_word is not needed.
    fst_dict, _ = index_fst(fst_in)

    lookup = fst_in.input_symbols()
    from_state = fst_dict[from_word]["state_id"]
    to_state = fst_dict[to_word]["state_id"]

    label = lookup_word(to_word, lookup)
    arc = openfst.Arc(label, label,
                      openfst.Weight(fst_in.weight_type(), weight), to_state)
    # Do not rely on add_arc's return value: older pywrapfst versions
    # return None rather than the FST itself.
    fst_in.add_arc(from_state, arc)

    return fst_in
Example #6
0
    def makeState(i):
        """Create the FST state for state ``i`` and connect it to start/final.

        If ``init_weights[i]`` is not zero, an arc from the FST start state
        enters the new state. It reads ``bos_index``, emits the output-symbol
        id of ``col_vocab[i]``, and carries that weight. If
        ``final_weights[i]`` is not zero, an arc reading ``eos_index`` and
        emitting ``eps_index`` leads from the new state to ``final_state``.

        Returns the id of the new state.
        """
        new_state = fst.add_state()

        entry_weight = openfst.Weight(fst.weight_type(), init_weights[i])
        if entry_weight != zero:
            entry_label = fst.output_symbols().find(col_vocab[i])
            fst.add_arc(
                fst.start(),
                openfst.Arc(bos_index, entry_label, entry_weight, new_state))

        exit_weight = openfst.Weight(fst.weight_type(), final_weights[i])
        if exit_weight != zero:
            fst.add_arc(
                new_state,
                openfst.Arc(eos_index, eps_index, exit_weight, final_state))

        return new_state
Example #7
0
 def normalize_fst(f):
     """Return a copy of ``f`` whose final weights are shifted by the total weight.

     The total log-semiring weight of ``f`` is computed with a reverse
     shortest-distance pass over ``f`` mapped to the log semiring, read at the
     start state. That amount is subtracted from every final weight of a copy
     of ``f``. ``f`` itself is not modified.
     """
     f2 = f.copy()
     start = f.start()
     if start < 0:
         # No start state: the machine accepts nothing; nothing to normalize.
         return f2
     distances = fst.shortestdistance(fst.arcmap(f, map_type="to_log"),
                                      reverse=True)
     # Reverse distances are indexed by state; the total weight is the entry
     # for the start state, which is not necessarily state 0.
     z = float(distances[start])
     for s in f2.states():
         w = f2.final(s)
         nw = fst.Weight(f2.weight_type(), float(w) - z)
         f2.set_final(s, nw)
     return f2
Example #8
0
def remove_arc(fst_in, from_word, to_word):
    """
    Removes an arc from a given FST.

    Note: Despite returning an updated FST, this method makes the changes
    **IN PLACE**, so you may want to make a copy of the original
    FST before updating the weights.

    Every arc leaving ``from_word``'s state whose destination state maps to
    ``to_word`` is removed. All other arcs of that state are re-added with
    their original input label, output label, weight and destination.

    :param fst_in: <openfst.Fst> to modify
    :param from_word: <str>
    :param to_word: <str>
    :return: updated <openfst.Fst> (the same object as ``fst_in``)
    """
    # make a dict and node_2_word from index_fst()
    fst_dict, node_2_word = index_fst(fst_in)

    from_state = fst_dict[from_word]["state_id"]
    arc_from_word = node_2_word[from_state]

    arcs_to_keep = []
    for arc in fst_in.arcs(from_state):
        arc_to_word = node_2_word[arc.nextstate]
        if arc_from_word == from_word and arc_to_word == to_word:
            print("removing: from_state:{} -> arc:{} -> to_state:{}".format(
                from_state, arc.ilabel, arc.nextstate))
        else:
            # Rebuild the arc from its own fields so the output label and the
            # exact weight (including its type) survive. The previous version
            # reset olabel to ilabel and round-tripped the weight through a
            # string into a tropical weight.
            arcs_to_keep.append(
                openfst.Arc(arc.ilabel, arc.olabel, arc.weight,
                            arc.nextstate))

    fst_in.delete_arcs(from_state)

    # Do not rely on add_arc's return value (None in older pywrapfst).
    for kept in arcs_to_keep:
        fst_in.add_arc(from_state, kept)

    return fst_in
Example #9
0
 def makelattice(self, fst, startstate, symtable, cost, firstword):
     """Lay out this node and its children as arcs in ``fst``.

     One arc labelled ``symtable[self.value]`` (on both input and output) is
     added from ``startstate`` to ``startstate + self.nleafnodes()``, with
     weight ``cost(self.value, firstword)``. Each child in ``self.nodes`` is
     then laid out recursively. Child ``k`` starts where the leaf span of
     children ``0..k-1`` ends, and children always get ``firstword=False``.
     """
     span = self.nleafnodes()
     label = symtable[self.value]
     arc_weight = wfst.Weight(fst.weight_type(), cost(self.value, firstword))
     fst.add_arc(startstate,
                 wfst.Arc(label, label, arc_weight, startstate + span))

     child_start = startstate
     for child in self.nodes:
         child.makelattice(fst, child_start, symtable, cost, firstword=False)
         child_start += child.nleafnodes()
Example #10
0
    def toFst(self):
        """Convert the HMM graph to an OpenFst object.

        You need to have installed the OpenFst python extension to use
        this method.

        Returns
        -------
        graph : pywrapfst.Fst
            A log-semiring FST with one state per HMM state, plus a super
            initial state and a super final state that are not present in the
            HMM. All arcs have label 0. Each HMM transition becomes an arc
            whose weight is the negated transition weight. The super initial
            state links to every initial HMM state, and every final HMM state
            links to the super final state, both with weight One.

        """

        import pywrapfst as fst

        graph = fst.Fst('log')

        super_start = graph.add_state()
        graph.set_start(super_start)
        super_end = graph.add_state()
        graph.set_final(super_end)

        fst_ids = {hmm_state.state_id: graph.add_state()
                   for hmm_state in self.states}

        for hmm_state in self.states:
            src = fst_ids[hmm_state.state_id]
            for dst_id, w in hmm_state.next_states.items():
                dst = fst_ids[dst_id]
                graph.add_arc(src, fst.Arc(0, 0, fst.Weight('log', -w), dst))

        for hmm_state in self.init_states:
            dst = fst_ids[hmm_state.state_id]
            graph.add_arc(super_start,
                          fst.Arc(0, 0, fst.Weight.One('log'), dst))

        for hmm_state in self.final_states:
            src = fst_ids[hmm_state.state_id]
            graph.add_arc(src,
                          fst.Arc(0, 0, fst.Weight.One('log'), super_end))

        return graph
Example #11
0
    def durationFst(label_str, dur_internal_str, dur_final_str, final_weights):
        """Add a left-to-right duration chain for one label to ``fst``.

        Parameters
        ----------
        label_str : str
            Input symbol consumed by every arc in the chain.
        dur_internal_str : str
            Output symbol on arcs that move on to the next chain state.
        dur_final_str : str
            Output symbol on arcs that leave the chain for the segment-final
            state.
        final_weights : numpy.ndarray
            Per-position weights. Chain position ``i`` gets an exit arc only
            if ``final_weights[i]`` is not the semiring zero.

        Returns
        -------
        first_state, seg_final_state : int
            The first chain state, which is entered by an epsilon arc from
            ``init_state``. The segment-final state, which has an epsilon arc
            to ``final_state``.

        Raises
        ------
        AssertionError
            If the largest index with a non-zero final weight is below 1.
        """

        input_label = fst.input_symbols().find(label_str)
        output_label_int = fst.output_symbols().find(dur_internal_str)
        output_label_ext = fst.output_symbols().find(dur_final_str)

        # Index of the last non-zero entry of final_weights.
        max_dur = np.nonzero(final_weights != float(zero))[0].max()
        if max_dur < 1:
            raise AssertionError(f"max_dur = {max_dur}, but should be >= 1")

        # NOTE(review): only ``max_dur`` chain states are created, so the
        # non-zero entry at index ``max_dur`` itself never gets an exit arc.
        # Confirm whether ``range(max_dur + 1)`` was intended.
        states = tuple(fst.add_state() for __ in range(max_dur))
        seg_final_state = fst.add_state()
        fst.add_arc(init_state, openfst.Arc(0, 0, one, states[0]))
        fst.add_arc(seg_final_state, openfst.Arc(0, 0, one, final_state))

        for i, cur_state in enumerate(states):
            final_weight = openfst.Weight(fst.weight_type(), final_weights[i])
            if final_weight != zero:
                # NOTE(review): the exit arc carries ``one``, not
                # ``final_weight``, so final_weights only gates which exits
                # exist. Confirm that this is intended.
                arc = openfst.Arc(input_label, output_label_ext, one,
                                  seg_final_state)
                fst.add_arc(cur_state, arc)

            if i + 1 < len(states):
                next_state = states[i + 1]
                arc = openfst.Arc(input_label, output_label_int, one,
                                  next_state)
                fst.add_arc(cur_state, arc)

        return states[0], seg_final_state
Example #12
0
def normalize_fst(in_fst):
    """Return a copy of ``in_fst`` whose outgoing arcs are normalized per state.

    Arc weights are expected to serialize (``Weight.to_string()``) as
    ``"<cost> <start> <end>"``. ``<cost>`` is a negative-log score and
    ``<start>``/``<end>`` are integers, presumably time stamps of a lattice
    weight type (TODO confirm). For every state, the costs of its outgoing
    arcs are rescaled so that their probabilities ``exp(-cost)`` sum to one.
    The two integer fields are copied unchanged.

    Labels, destination states, the start state and all final weights are
    copied from ``in_fst``. If ``in_fst`` or the result fails ``verify()``,
    an error is printed and ``in_fst`` is returned unchanged.
    """
    if not in_fst.verify():
        print("ERROR WRONG FST PASSED FOR NORMALIZATION")
        return in_fst

    out_fst = fst.Fst(in_fst.weight_type())
    weight_type = out_fst.weight_type()
    for state in in_fst.states():
        # States are added in iteration order, so the new id equals ``state``.
        out_fst.add_state()

        parsed = []
        for arc in in_fst.arcs(state):
            raw = arc.weight.to_string()
            # Older pywrapfst returns bytes here, newer versions return str.
            if isinstance(raw, bytes):
                raw = raw.decode()
            cost, st_t, en_t = raw.split(' ', 2)
            parsed.append((arc, float(cost), int(st_t), int(en_t)))

        if parsed:
            # log(sum_k exp(-cost_k)), shifted by the maximum so that large
            # costs do not underflow to a zero sum.
            neg_costs = -np.array([item[1] for item in parsed])
            peak = neg_costs.max()
            log_total = peak + np.log(np.exp(neg_costs - peak).sum())
            for arc, cost, st_t, en_t in parsed:
                # -log(exp(-cost) / total) == cost + log(total)
                weight_log = float(cost + log_total)
                new_weight = fst.Weight(
                    weight_type,
                    str(weight_log) + ' ' + str(st_t) + ' ' + str(en_t))
                out_fst.add_arc(
                    state,
                    fst.Arc(arc.ilabel, arc.olabel, new_weight,
                            arc.nextstate))

        # Copy every state's final weight (non-final states copy Zero).
        out_fst.set_final(state, in_fst.final(state))

    start = in_fst.start()
    if start >= 0:  # a negative id means "no start state"
        out_fst.set_start(start)

    if out_fst.verify():
        return out_fst
    print("NORM ERROR")
    fst_printout(out_fst)
    return in_fst
Example #13
0
def fromArray(weights,
              row_vocab,
              col_vocab,
              final_weight=None,
              arc_type=None,
              input_symbols=None,
              output_symbols=None):
    """ Instantiate a linear-chain state machine from an array of weights.

    Parameters
    ----------
    weights : array_like, shape (num_rows, num_cols)
        Must implement ``.ndim`` and ``.shape``, so it should be a numpy array
        or a torch tensor. Row ``r`` holds the weights of the arcs from chain
        state ``r`` to chain state ``r + 1``.
    row_vocab : sequence
        ``row_vocab[r]`` is the input symbol of every arc in row ``r``.
    col_vocab : sequence
        ``col_vocab[c]`` is the output symbol of the arc for column ``c``.
    final_weight : optional
        Final weight of the last chain state. Default is the semiring one.
    arc_type : {'standard', 'log'}, optional
        Default is 'standard' (ie the tropical arc_type)
    input_symbols, output_symbols : symbol tables
        Used to map ``row_vocab`` / ``col_vocab`` entries to label ids.

    Returns
    -------
    fst : openfst.VectorFst
        A chain with ``num_rows + 1`` states. Cells whose weight is the
        semiring zero produce no arc.

    Raises
    ------
    AssertionError
        If ``weights`` is not two-dimensional.
    openfst.FstError
        If the finished FST fails ``verify()``.
    """

    if weights.ndim != 2:
        raise AssertionError(
            f"weights have unrecognized shape {weights.shape}")

    if arc_type is None:
        arc_type = 'standard'

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if final_weight is None:
        final_weight = one
    else:
        final_weight = openfst.Weight(fst.weight_type(), final_weight)

    init_state = fst.add_state()
    fst.set_start(init_state)

    prev_state = init_state
    for sample_index, row in enumerate(weights):
        cur_state = fst.add_state()
        # The input label depends only on the row, so look it up once.
        input_label_index = fst.input_symbols().find(row_vocab[sample_index])
        for i, weight in enumerate(row):
            output_label_index = fst.output_symbols().find(col_vocab[i])
            weight = openfst.Weight(fst.weight_type(), weight)
            if weight != zero:
                arc = openfst.Arc(input_label_index, output_label_index,
                                  weight, cur_state)
                fst.add_arc(prev_state, arc)
        prev_state = cur_state

    # prev_state is the last chain state. It is init_state when ``weights``
    # has no rows; that case previously raised NameError on ``cur_state``.
    fst.set_final(prev_state, final_weight)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Example #14
0
def fromTransitions(transition_weights,
                    row_vocab,
                    col_vocab,
                    init_weights=None,
                    final_weights=None,
                    input_symbols=None,
                    output_symbols=None,
                    bos_str='<BOS>',
                    eos_str='<EOS>',
                    eps_str=libfst.EPSILON_STRING,
                    arc_type='standard',
                    transition_ids=None):
    """ Instantiate a state machine from state transitions.

    Parameters
    ----------
    transition_weights : array_like, shape (num_states, num_states)
        ``transition_weights[i, j]`` weights the move from state ``i`` to
        state ``j``. Zero weights produce no arc.
    row_vocab : sequence
        Not used by this function (kept for interface compatibility).
    col_vocab : sequence
        ``col_vocab[j]`` is the symbol written when entering state ``j``.
    init_weights, final_weights : sequence, optional
        Per-state weights of the ``bos_str`` arc out of the start state and
        of the ``eos_str`` arc into the final state. Both default to one.
        States with a zero weight get no such arc.
    input_symbols, output_symbols : symbol tables
    bos_str, eos_str, eps_str : str
        Symbols used on the entry arcs, exit arcs and exit-arc outputs.
    arc_type : str
        Arc type of the resulting ``openfst.VectorFst``.
    transition_ids : optional
        Not used by this function (kept for interface compatibility).

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If the finished FST fails ``verify()``.
    """

    num_states = transition_weights.shape[0]

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if init_weights is None:
        init_weights = tuple(float(one) for __ in range(num_states))

    if final_weights is None:
        final_weights = tuple(float(one) for __ in range(num_states))

    fst.set_start(fst.add_state())
    final_state = fst.add_state()
    fst.set_final(final_state, one)

    bos_index = fst.input_symbols().find(bos_str)
    eos_index = fst.input_symbols().find(eos_str)
    eps_index = fst.output_symbols().find(eps_str)

    def makeState(i):
        # Add state i, its bos entry arc from the start state and its eos exit
        # arc to final_state (each only when the corresponding weight != 0).
        state = fst.add_state()

        initial_weight = openfst.Weight(fst.weight_type(), init_weights[i])
        if initial_weight != zero:
            next_state_str = col_vocab[i]
            next_state_index = fst.output_symbols().find(next_state_str)
            arc = openfst.Arc(bos_index, next_state_index, initial_weight,
                              state)
            fst.add_arc(fst.start(), arc)

        final_weight = openfst.Weight(fst.weight_type(), final_weights[i])
        if final_weight != zero:
            arc = openfst.Arc(eos_index, eps_index, final_weight, final_state)
            fst.add_arc(state, arc)

        return state

    states = tuple(makeState(i) for i in range(num_states))
    for i_cur, row in enumerate(transition_weights):
        cur_state = states[i_cur]
        for i_next, tx_weight in enumerate(row):
            weight = openfst.Weight(fst.weight_type(), tx_weight)
            if weight != zero:
                # NOTE(review): the input label is looked up in the OUTPUT
                # symbol table. This presumably relies on both tables sharing
                # ids for state symbols — confirm.
                next_state_index = fst.output_symbols().find(col_vocab[i_next])
                arc = openfst.Arc(next_state_index, next_state_index, weight,
                                  states[i_next])
                fst.add_arc(cur_state, arc)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Example #15
0
def durationFst(
        label, num_states, transition_weights=None, self_weights=None,
        arc_type='standard', symbol_table=None):
    """ Construct a left-to-right duration WFST for a single label.

    The machine is usually used to align a segment-level label with
    sample-level scores.

    Parameters
    ----------
    label : int
        Zero-based label id. Arcs use ``label + 1``.
    num_states : int
        Number of chain states (must be >= 1).
    transition_weights : sequence, optional
        ``transition_weights[i]`` weights the move from chain state ``i`` to
        its right neighbour. Defaults to one. Zero weights produce no arc.
    self_weights : sequence, optional
        ``self_weights[i]`` weights the self-loop on chain state ``i``.
        Defaults to one. Zero weights produce no arc.
    arc_type : str
        Arc type of the resulting ``openfst.VectorFst``.
    symbol_table : optional
        If given, it is used as both the input and output symbol table.

    Returns
    -------
    fst : openfst.Fst
        A linear-chain weighted finite-state transducer. An arc from the
        start state reads EPSILON and writes ``label + 1``. Each chain state
        then has one self-transition and one transition to its right
        neighbour, both reading ``label + 1`` and writing EPSILON. The
        topology looks like this:
                        __     __     __
                        \\/     \\/     \\/
            [START] --> s1 --> s2 --> s3 --> [END]

    Raises
    ------
    AssertionError
        If ``num_states < 1``.
    openfst.FstError
        If the finished FST fails ``verify()``.
    """

    if num_states < 1:
        raise AssertionError(f"num_states = {num_states}, but should be >= 1")

    fst = openfst.VectorFst(arc_type=arc_type)
    one = openfst.Weight.one(fst.weight_type())
    zero = openfst.Weight.zero(fst.weight_type())

    if transition_weights is None:
        transition_weights = [one for __ in range(num_states)]

    if self_weights is None:
        self_weights = [one for __ in range(num_states)]

    if symbol_table is not None:
        fst.set_input_symbols(symbol_table)
        fst.set_output_symbols(symbol_table)

    init_state = fst.add_state()
    fst.set_start(init_state)

    cur_state = fst.add_state()
    arc = openfst.Arc(EPSILON, label + 1, one, cur_state)
    fst.add_arc(init_state, arc)

    for i in range(num_states):
        next_state = fst.add_state()

        transition_weight = openfst.Weight(fst.weight_type(), transition_weights[i])
        if transition_weight != zero:
            arc = openfst.Arc(label + 1, EPSILON, transition_weight, next_state)
            fst.add_arc(cur_state, arc)

        self_weight = openfst.Weight(fst.weight_type(), self_weights[i])
        if self_weight != zero:
            arc = openfst.Arc(label + 1, EPSILON, self_weight, cur_state)
            fst.add_arc(cur_state, arc)

        cur_state = next_state

    fst.set_final(cur_state, one)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Example #16
0
 def subtractLogWeights(lhs, rhs):
     """Subtract two negative-log weights in probability space.

     Returns ``-log(exp(-lhs) - exp(-rhs))`` as a weight of ``lhs``'s type.
     The previous version stored the raw probability difference
     ``exp(-lhs) - exp(-rhs)`` as the weight value, which is not a
     negative-log weight.

     Raises
     ------
     ValueError
         If ``rhs < lhs``, i.e. the difference would be a negative
         probability.
     """
     a = float(lhs)
     b = float(rhs)
     if b < a:
         raise ValueError(
             f"cannot subtract a larger probability: lhs={a}, rhs={b}")
     if b == a:
         # exp(-a) - exp(-b) == 0, which is the semiring zero (+infinity).
         value = float('inf')
     else:
         # -log(e^-a - e^-b) = a - log(1 - e^(a - b)); log1p keeps precision.
         value = a - float(np.log1p(-np.exp(a - b)))
     return openfst.Weight(lhs.type(), value)
Example #17
0
def make_event_to_assembly_fst(weights,
                               input_vocab,
                               output_vocab,
                               input_parts_to_str,
                               output_parts_to_str,
                               init_weights=None,
                               final_weights=None,
                               input_symbols=None,
                               output_symbols=None,
                               state_tx_in_input=False,
                               eps_str='ε',
                               dur_internal_str='I',
                               dur_final_str='F',
                               bos_str='<BOS>',
                               eos_str='<EOS>',
                               arc_type='standard'):
    """Build a segmental transducer over (input, output) state pairs.

    Each ``(i_input, i_output)`` pair gets a segment-internal state and a
    segment-final state. A pair whose ``W_a_s`` weight is zero gets both
    states but no arcs. For every other pair:

    - the segment-internal state is entered from the start state with
      ``bos_str``, weighted by ``init_weights[i_input, i_output]`` (default
      one);
    - it has a ``dur_internal_str`` self-loop and a ``dur_final_str`` arc to
      the segment-final state;
    - the segment-final state reaches the final state with ``eos_str``,
      weighted by ``final_weights[i_input, i_output]`` (default one).

    ``weights[i_input_cur, i_output_cur, i_output_next]`` weights the arcs
    from the segment-final state of ``(i_input_cur, i_output_cur)`` into both
    states of every ``(i_input_next, i_output_next)`` pair. Zero weights add
    no arc. Arc strings are built with ``input_parts_to_str`` /
    ``output_parts_to_str`` and mapped to labels through the symbol tables.
    ``eps_str`` is accepted but not used here.

    Raises
    ------
    openfst.FstError
        If the finished FST fails ``verify()``.
    """

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    init_state = fst.add_state()
    final_state = fst.add_state()
    fst.set_start(init_state)
    fst.set_final(final_state, one)

    # Per-(input, output) total over the last axis, in -log space.
    W_a_s = -scipy.special.logsumexp(-weights, axis=-1)

    def make_states(i_input, i_output):
        # Create the [segment-internal, segment-final] states for one pair and
        # their bos / self-loop / internal->final / eos arcs.
        seg_internal_state = fst.add_state()
        seg_final_state = fst.add_state()

        if openfst.Weight(fst.weight_type(), W_a_s[i_input, i_output]) == zero:
            return [seg_internal_state, seg_final_state]

        state_istr = input_vocab[i_input]
        state_ostr = output_vocab[i_output]

        # Initial state -> seg internal state
        if init_weights is None:
            init_weight = one
        else:
            init_weight = openfst.Weight(fst.weight_type(),
                                         init_weights[i_input, i_output])
        if init_weight != zero:
            arc_istr = bos_str
            arc_ostr = bos_str
            arc = openfst.Arc(fst.input_symbols().find(arc_istr),
                              fst.output_symbols().find(arc_ostr), init_weight,
                              seg_internal_state)
            fst.add_arc(init_state, arc)

        # (in, I) : (out, I), weight one, self transition
        if state_tx_in_input:
            arc_istr = input_parts_to_str[state_istr, state_ostr, state_ostr,
                                          dur_internal_str]
        else:
            arc_istr = input_parts_to_str[state_istr, dur_internal_str]
        arc_ostr = output_parts_to_str[state_istr, state_ostr, state_ostr,
                                       dur_internal_str]
        arc = openfst.Arc(fst.input_symbols().find(arc_istr),
                          fst.output_symbols().find(arc_ostr), one,
                          seg_internal_state)
        fst.add_arc(seg_internal_state, arc)

        # (in, F) : (out, F), weight one, transition into final state
        if state_tx_in_input:
            arc_istr = input_parts_to_str[state_istr, state_ostr, state_ostr,
                                          dur_final_str]
        else:
            arc_istr = input_parts_to_str[state_istr, dur_final_str]
        arc_ostr = output_parts_to_str[state_istr, state_ostr, state_ostr,
                                       dur_final_str]
        arc = openfst.Arc(fst.input_symbols().find(arc_istr),
                          fst.output_symbols().find(arc_ostr), one,
                          seg_final_state)
        fst.add_arc(seg_internal_state, arc)

        # seg final state -> final_state
        if final_weights is None:
            final_weight = one
        else:
            final_weight = openfst.Weight(fst.weight_type(),
                                          final_weights[i_input, i_output])
        if final_weight != zero:
            arc_istr = eos_str
            arc_ostr = eos_str
            arc = openfst.Arc(fst.input_symbols().find(arc_istr),
                              fst.output_symbols().find(arc_ostr),
                              final_weight, final_state)
            fst.add_arc(seg_final_state, arc)

        return [seg_internal_state, seg_final_state]

    # Build segmental backbone
    states = [[
        make_states(i_input, i_output)
        for i_output, _ in enumerate(output_vocab)
    ] for i_input, _ in enumerate(input_vocab)]

    # Add transitions from final (action, assembly) to initial (action, assembly)
    for i_input_cur, arr in enumerate(weights):
        for i_output_cur, row in enumerate(arr):
            # Every arc for this row leaves the segment-final state of the
            # current pair.
            from_state = states[i_input_cur][i_output_cur][1]
            # The CURRENT output state. Previously this was taken from
            # i_output_next, so cur_ostr always equalled next_ostr.
            cur_ostr = output_vocab[i_output_cur]
            for i_output_next, tx_weight in enumerate(row):
                weight = openfst.Weight(fst.weight_type(), tx_weight)
                if weight == zero:
                    # Skip before the string lookups: no arc is added anyway.
                    continue
                next_ostr = output_vocab[i_output_next]
                for i_input_next, __ in enumerate(weights):
                    istr = input_vocab[i_input_next]
                    for i_dur, dur_str in enumerate(
                            (dur_internal_str, dur_final_str)):
                        # From Seg-Final to Seg-Internal (i_dur=0) or
                        # Seg-Final (i_dur=1) of the next pair
                        to_state = states[i_input_next][i_output_next][i_dur]
                        if state_tx_in_input:
                            arc_istr = input_parts_to_str[istr, cur_ostr,
                                                          next_ostr, dur_str]
                        else:
                            arc_istr = input_parts_to_str[istr, dur_str]
                        arc_ostr = output_parts_to_str[istr, cur_ostr,
                                                       next_ostr, dur_str]

                        arc = openfst.Arc(
                            fst.input_symbols().find(arc_istr),
                            fst.output_symbols().find(arc_ostr), weight,
                            to_state)
                        fst.add_arc(from_state, arc)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
def get_best_path_SIMP(t_latt,
                       j_latt,
                       join_already_compiled=False,
                       add_path_of_last_resort=False):
    '''
    Compose the target lattice with the join lattice and return the best path.

    t_latt and j_latt: FST objects in memory

    If ``join_already_compiled`` is False, ``fstcompile`` is run on ``j_latt``
    and the process exits, because that path has not been ported yet.

    If ``add_path_of_last_resort`` is True, ``j_latt`` is modified IN PLACE.
    Arcs with a very large weight are added so that the shortest path through
    ``t_latt`` (ignoring join costs) always survives composition.

    Returns the result of ``get_shortest_path`` on ``compose(t_latt, j_latt)``.
    '''
    if not join_already_compiled:
        comm('%s/fstcompile %s %s.bin' % (tool, j_latt, j_latt))
        sys.exit('complete porting to WRAP case wsoivnsovbnsfb34598t3h')

    if add_path_of_last_resort:
        # In order that composition with precomputed join won't remove all
        # paths, add an emergency path of last resort. It is arbitrarily taken
        # to be the best path through the target lattice, without regard to
        # join cost.
        path_of_last_resort = get_shortest_path(t_latt)

        print()
        print('----------------POLR----------------')
        print(path_of_last_resort)
        print('------------------------------------')
        print()

        print('convert back to FST indexing')
        path_of_last_resort = [val + 1 for val in path_of_last_resort]

        print('edit J to add emergency arcs')
        POLR_transitions = {}
        for fro, to in zip(path_of_last_resort[:-1],
                           path_of_last_resort[1:]):
            POLR_transitions.setdefault(fro, []).append(to)

        print(POLR_transitions)
        print('here b')

        # Drop the transitions J already has; only the missing ones are added.
        for from_state in j_latt.states():
            if from_state in POLR_transitions:
                for arc in j_latt.arcs(from_state):
                    to_state = arc.nextstate
                    if to_state in POLR_transitions[from_state]:
                        POLR_transitions[from_state].remove(to_state)

        print(POLR_transitions)
        # Rebuild rather than deleting while iterating, which raises
        # RuntimeError in Python 3.
        POLR_transitions = {fro: to_list
                            for fro, to_list in POLR_transitions.items()
                            if to_list}

        print(POLR_transitions)

        print(j_latt.weight_type())
        BIGWEIGHT = openfst.Weight(j_latt.weight_type(), 500000000)
        for from_state, to_list in POLR_transitions.items():
            for to_state in to_list:
                # NOTE(review): the source state id is used as the arc label,
                # presumably because J's labels equal its state ids — confirm.
                j_latt.add_arc(
                    from_state,
                    openfst.Arc(from_state, from_state, BIGWEIGHT, to_state))

        # Explicit check instead of ``assert`` so it is not stripped under -O.
        if not j_latt.verify():
            raise AssertionError("j_latt.verify() failed after adding POLR arcs")

        # TODO: remove the added arcs after search so J can be reused.

    # TODO: check whether the composition is empty and report it nicely.
    c_latt = openfst.compose(t_latt, j_latt)

    shortest_path = get_shortest_path(c_latt)

    return shortest_path
Example #19
0
    def make_states(i_input, i_output):
        """Create the segment states for one (input, output) pair.

        Always adds two states, a segment-internal one and a segment-final
        one, and returns them as ``[internal, final]``. If the pair's
        ``W_a_s`` weight is the semiring zero, no arcs are added. Otherwise:

        - the start state enters the internal state with ``bos_str``;
        - the internal state loops on itself with the ``dur_internal_str``
          symbol and moves to the segment-final state with the
          ``dur_final_str`` symbol;
        - the segment-final state reaches ``final_state`` with ``eos_str``.

        The bos/eos arcs are weighted by ``init_weights`` / ``final_weights``
        (default one) and are skipped when that weight is zero.
        """
        internal = fst.add_state()
        seg_end = fst.add_state()
        pair_states = [internal, seg_end]

        if openfst.Weight(fst.weight_type(), W_a_s[i_input, i_output]) == zero:
            return pair_states

        def link(src, in_str, out_str, arc_weight, dst):
            # Map both strings through the symbol tables and add src -> dst.
            fst.add_arc(src, openfst.Arc(fst.input_symbols().find(in_str),
                                         fst.output_symbols().find(out_str),
                                         arc_weight, dst))

        in_sym = input_vocab[i_input]
        out_sym = output_vocab[i_output]

        if init_weights is None:
            enter_weight = one
        else:
            enter_weight = openfst.Weight(fst.weight_type(),
                                          init_weights[i_input, i_output])
        if enter_weight != zero:
            link(init_state, bos_str, bos_str, enter_weight, internal)

        # Internal self-loop (dur_internal_str), then internal -> segment
        # final (dur_final_str); both carry weight one.
        for dur_str, dst in ((dur_internal_str, internal),
                             (dur_final_str, seg_end)):
            if state_tx_in_input:
                in_str = input_parts_to_str[in_sym, out_sym, out_sym, dur_str]
            else:
                in_str = input_parts_to_str[in_sym, dur_str]
            out_str = output_parts_to_str[in_sym, out_sym, out_sym, dur_str]
            link(internal, in_str, out_str, one, dst)

        if final_weights is None:
            leave_weight = one
        else:
            leave_weight = openfst.Weight(fst.weight_type(),
                                          final_weights[i_input, i_output])
        if leave_weight != zero:
            link(seg_end, eos_str, eos_str, leave_weight, final_state)

        return pair_states
Example #20
0
def fromArray(
        weights, final_weight=None, arc_type=None,
        input_symbols=None, output_symbols=None):
    """ Instantiate a state machine from an array of weights.

    The result is a linear chain: an initial state followed by one state per
    sample (first axis of `weights`). Every nonzero weight of sample `t`
    becomes an arc from chain state `t` to chain state `t + 1`, and the last
    state of the chain is final. If `weights` has no samples, the initial
    state itself is final, so the FST accepts only the empty string.

    Parameters
    ----------
    weights : array_like, shape (num_samples, num_outputs) or
              (num_samples, num_inputs, num_outputs)
        Needs to implement `.ndim` and `.shape`, so it should be a numpy array
        or a torch tensor.
        - 2-D ("lattice"): the arcs of sample `t` have input label `t + 1`
          and output label `j + 1` for column `j`.
        - 3-D: the arcs of sample `t` have input label `i + 1` and output
          label `j + 1` for entry `[t, i, j]`.
        Labels are shifted by one, presumably to keep label 0 free for
        epsilon.
    final_weight : optional
        Any value accepted by `openfst.Weight` for this arc type. Default is
        the semiring `one`.
    arc_type : {'standard', 'log'}, optional
        Default is 'standard' (ie the tropical arc_type)
    input_symbols, output_symbols : optional
        Symbol tables attached to the FST as-is.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    AssertionError
        If `weights` is neither 2-D nor 3-D.
    openfst.FstError
        If the constructed FST fails `verify()`.
    """
    if weights.ndim == 3:
        is_lattice = False
    elif weights.ndim == 2:
        is_lattice = True
    else:
        raise AssertionError(f"weights have unrecognized shape {weights.shape}")

    if arc_type is None:
        arc_type = 'standard'

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if final_weight is None:
        final_weight = one
    else:
        final_weight = openfst.Weight(fst.weight_type(), final_weight)

    def sample_arcs(sample_index, sample):
        """Yield (input_label, output_label, raw_weight) for one sample."""
        if is_lattice:
            for j, raw_weight in enumerate(sample):
                yield sample_index + 1, j + 1, raw_weight
        else:
            for i, outputs in enumerate(sample):
                for j, raw_weight in enumerate(outputs):
                    yield i + 1, j + 1, raw_weight

    init_state = fst.add_state()
    fst.set_start(init_state)

    prev_state = init_state
    for sample_index, sample in enumerate(weights):
        cur_state = fst.add_state()
        for input_label, output_label, raw_weight in sample_arcs(sample_index,
                                                                 sample):
            weight = openfst.Weight(fst.weight_type(), raw_weight)
            if weight != zero:
                arc = openfst.Arc(input_label, output_label, weight, cur_state)
                fst.add_arc(prev_state, arc)
        prev_state = cur_state

    # prev_state is the last state of the chain. It is still init_state when
    # `weights` has no samples, so this never refers to an unbound name.
    fst.set_final(prev_state, final_weight)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Example #21
0
def readHtkLattice(lattfile, ac_weight=1., lm_weight=1., gzipped=True):
    """Read a HTK lattice file and represent it as a OpenFst object.

    Only the ``N=`` size line and the ``J=`` arc lines are used. Every other
    line (header fields, node lines, blank lines) is ignored. Arc fields are
    looked up by their ``key=`` prefix rather than by position, so optional
    fields may be present or absent.

    Parameters
    ----------
    lattfile : string
        Path to the HTK lattice file.
    ac_weight : float
        Acoustic weight (default: 1).
    lm_weight : float
        Language model weight (default : 1).
    gzipped : boolean
        If the lattice file is compressed with gzip (default: True).

    Returns
    -------
    fst_lattice : pywrapfst.Fst
        The lattice as a log-semiring OpenFst object. State 0 is the start
        state and state ``N - 1`` is final. Each arc carries the word id as
        both input and output label, with weight
        ``-(ac_weight * a + lm_weight * l)``.
    id2label : dict
        Maps the integer arc labels back to the HTK word labels. Id 0 is
        ``"!NULL"``.

    Raises
    ------
    KeyError
        If an arc line lacks one of the ``S=``, ``E=``, ``W=``, ``a=`` or
        ``l=`` fields.
    """
    # Make the import here only as some people may not have openfst installed.
    import pywrapfst as fst

    # Output fst.
    fst_lattice = fst.Fst("log")

    # The id <-> label mappings are relative to this lattice only, which
    # avoids needing a global vocabulary.
    label2id = {"!NULL": 0}
    id2label = {0: "!NULL"}

    # If the lattice is compressed with gzip choose the correct function.
    if gzipped:
        my_open = gzip.open
        mode = 'rt'
    else:
        my_open = open
        mode = 'r'

    with my_open(lattfile, mode) as f:
        for line in f:
            line = line.strip()

            if line.startswith("N="):
                # Node count. This line must precede every arc line;
                # otherwise add_arc refers to states that do not exist yet.
                n_nodes = int(line.split()[0].split("=")[-1])
                for _ in range(n_nodes):
                    fst_lattice.add_state()
                fst_lattice.set_start(0)
                fst_lattice.set_final(n_nodes - 1)

            elif line.startswith("J="):
                # Arc line: whitespace-separated "key=value" tokens.
                fields = {}
                for token in line.split():
                    key, _, value = token.partition("=")
                    fields[key] = value
                start_node = int(fields["S"])
                end_node = int(fields["E"])
                label = fields["W"]

                label_id = label2id.get(label)
                if label_id is None:
                    # New label: ids are handed out consecutively after !NULL.
                    label_id = len(label2id)
                    label2id[label] = label_id
                    id2label[label_id] = label

                score = (ac_weight * float(fields["a"])
                         + lm_weight * float(fields["l"]))
                # weight_type() must be called: Weight expects the type name,
                # not the bound method. The score is negated because OpenFst
                # log weights are negative log values.
                weight = fst.Weight(fst_lattice.weight_type(), -score)
                arc = fst.Arc(label_id, label_id, weight, end_node)
                fst_lattice.add_arc(start_node, arc)
    return fst_lattice, id2label
Example #22
0
def single_seg_transducer(weights,
                          input_vocab,
                          output_vocab,
                          input_parts_to_str,
                          output_parts_to_str,
                          input_symbols=None,
                          output_symbols=None,
                          eps_str='ε',
                          bos_str='<BOS>',
                          eos_str='<EOS>',
                          dur_internal_str='I',
                          dur_final_str='F',
                          arc_type='standard',
                          pass_input=False):
    """Build a segmental transducer around a single hub state.

    The hub state is both the start state and a final state (weight one). For
    every ``(i, j)`` entry of `weights` whose weight is not the semiring zero,
    a dedicated segment state is added with three arcs:

    * hub -> segment on ``(in, I) : (out, I)``, weight one;
    * segment -> segment (self-loop) on the same labels, weight one;
    * segment -> hub on ``(in, F) : (out, F)``, carrying the entry's weight.

    Here ``in = input_vocab[i]``, ``out = output_vocab[j]``, and ``I``/``F``
    are `dur_internal_str` / `dur_final_str`. Arc labels are the ids that the
    FST's symbol tables give to the strings looked up in
    `input_parts_to_str` / `output_parts_to_str`. When `pass_input` is true,
    the output string key also includes the input vocabulary item.

    `eps_str`, `bos_str` and `eos_str` are accepted but not used here.

    Raises openfst.FstError if the finished FST fails ``verify()``.
    """
    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    weight_type = fst.weight_type()
    zero = openfst.Weight.zero(weight_type)
    one = openfst.Weight.one(weight_type)

    hub = fst.add_state()
    fst.set_start(hub)
    fst.set_final(hub, one)

    def label_ids(in_sym, out_sym, dur_sym):
        """Return the (input_id, output_id) arc labels for one duration tag."""
        in_str = input_parts_to_str[in_sym, dur_sym]
        if pass_input:
            out_str = output_parts_to_str[in_sym, out_sym, dur_sym]
        else:
            out_str = output_parts_to_str[out_sym, dur_sym]
        return (fst.input_symbols().find(in_str),
                fst.output_symbols().find(out_str))

    def add_segment(in_index, out_index, seg_weight):
        """Attach one segment state (entry, self-loop, exit) to the hub."""
        segment = fst.add_state()
        in_sym = input_vocab[in_index]
        out_sym = output_vocab[out_index]

        enter_in, enter_out = label_ids(in_sym, out_sym, dur_internal_str)
        enter_arc = openfst.Arc(enter_in, enter_out, one, segment)
        fst.add_arc(hub, enter_arc)
        fst.add_arc(segment, enter_arc.copy())

        exit_in, exit_out = label_ids(in_sym, out_sym, dur_final_str)
        fst.add_arc(segment, openfst.Arc(exit_in, exit_out, seg_weight, hub))

    for in_index, row in enumerate(weights):
        for out_index, raw_weight in enumerate(row):
            seg_weight = openfst.Weight(weight_type, raw_weight)
            if seg_weight == zero:
                continue
            add_segment(in_index, out_index, seg_weight)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Example #23
0
     utt_index = fst.Fst(arc_type=b'TripleTropicalWeight')
     print("utt name " + utt_name)
 elif line.strip():  #line not empty
     if len(line.split()) > 1:
         if line.find(",") > -1:
             utt_index.add_state()
             delimeters = list(find_all(line.split()[4], ","))
             weight_string = line.split(
             )[4][:delimeters[0]] + " " + line.split(
             )[4][delimeters[0] +
                  1:delimeters[1]] + " " + line.split()[4][delimeters[1] +
                                                           1:]
             utt_index.add_arc(
                 int(line.split()[0]),
                 fst.Arc(int(line.split()[2]), int(line.split()[3]),
                         fst.Weight(utt_index.weight_type(), weight_string),
                         int(line.split()[1])))
         else:
             utt_index.add_state()
             utt_index.add_arc(
                 int(line.split()[0]),
                 fst.Arc(int(line.split()[2]), int(line.split()[3]),
                         fst.Weight.One(utt_index.weight_type()),
                         int(line.split()[1])))
     else:
         utt_index.add_state()
         utt_index.set_final(int(line.split()[0]),
                             fst.Weight.One(utt_index.weight_type()))
 else:  #empty line, end of lattice, do search
     ## masking lattice with phoneme mask ##
     utt_index.set_start(0)
Example #24
0
def fromTransitions(
        transition_weights, init_weights=None, final_weights=None,
        arc_type='standard', transition_ids=None):
    """ Instantiate a state machine from state transitions.

    Parameters
    ----------
    transition_weights : array_like, shape (num_states, num_states)
        ``transition_weights[i][j]`` is the weight of moving from state ``i``
        to state ``j``. Entries equal to the semiring zero produce no arc.
        Must implement ``.shape``.
    init_weights : sequence of length num_states, optional
        Weight of the arc from the start state into each state. Zero-weight
        entries produce no arc. Defaults to the semiring one for every state.
    final_weights : sequence of length num_states, optional
        Final weight of each state. States with zero weight are not made
        final. Defaults to the semiring one for every state.
    arc_type : {'standard', 'log'}, optional
        Default is 'standard'.
    transition_ids : dict, optional
        Maps ``(cur, next)`` state pairs, and ``(-1, s)`` for the initial
        transition into state ``s``, to a 0-based id. The arc label used is
        ``id + 1``. Every ``(cur, next)`` pair must be present, even those
        with zero weight. By default pairs are numbered row-major, followed by
        the initial transitions.

    Returns
    -------
    fst : openfst.VectorFst
        A dedicated start state plus one state per model state. Initial arcs
        have input label EPSILON and the ``(-1, s)`` label on the output side.
        Transition arcs carry the ``(cur, next)`` label on both sides. Input
        and output symbol tables name each label by ``str(transition)``.

    Raises
    ------
    openfst.FstError
        If the constructed FST fails ``verify()``.
    """

    num_states = transition_weights.shape[0]

    if transition_ids is None:
        transition_ids = {}
        for s_cur in range(num_states):
            for s_next in range(num_states):
                transition_ids[(s_cur, s_next)] = len(transition_ids)
        for s in range(num_states):
            transition_ids[(-1, s)] = len(transition_ids)

    def make_symbol_table():
        """Build a table naming EPSILON and every transition label (id + 1)."""
        table = openfst.SymbolTable()
        table.add_symbol(EPSILON_STRING, key=EPSILON)
        for transition, index in transition_ids.items():
            table.add_symbol(str(transition), key=index + 1)
        return table

    # Input and output tables have identical contents; build one each.
    output_table = make_symbol_table()
    input_table = make_symbol_table()

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_table)
    fst.set_output_symbols(output_table)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if init_weights is None:
        init_weights = tuple(float(one) for __ in range(num_states))

    if final_weights is None:
        final_weights = tuple(float(one) for __ in range(num_states))

    fst.set_start(fst.add_state())

    def make_state(i):
        """Add model state ``i`` with its initial arc and final weight."""
        state = fst.add_state()

        initial_weight = openfst.Weight(fst.weight_type(), init_weights[i])
        if initial_weight != zero:
            transition = transition_ids[-1, i] + 1
            arc = openfst.Arc(EPSILON, transition, initial_weight, state)
            fst.add_arc(fst.start(), arc)

        final_weight = openfst.Weight(fst.weight_type(), final_weights[i])
        if final_weight != zero:
            fst.set_final(state, final_weight)

        return state

    states = tuple(make_state(i) for i in range(num_states))
    for i_cur, row in enumerate(transition_weights):
        for i_next, tx_weight in enumerate(row):
            cur_state = states[i_cur]
            next_state = states[i_next]
            weight = openfst.Weight(fst.weight_type(), tx_weight)
            # Looked up before the zero check, so every pair must have an id.
            transition = transition_ids[i_cur, i_next] + 1
            if weight != zero:
                arc = openfst.Arc(transition, transition, weight, next_state)
                fst.add_arc(cur_state, arc)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst