Exemple #1
0
def string_to_fsa(input_string, sym):
    '''Build a linear-chain FSA for ``input_string`` using symbol table ``sym``.

    Every character is looked up in ``sym``. The machine has one arc per
    character (identical input and output label, weight one) and its last
    state is marked final. An empty string yields a single state that is
    both the start state and the final state.

    Args:
        input_string: the string to encode, one symbol per character.
        sym: symbol table used for label lookup; it is also attached to the
            FSA as both the input and the output symbol table.

    Returns:
        The constructed ``pywrapfst.VectorFst``.

    Raises:
        ValueError: if a character is not found in ``sym`` or if the built
            machine fails ``verify()``.
    '''
    # Resolve every label up front so nothing is built for invalid input
    # and each character is looked up only once.
    labels = []
    for char in input_string:
        label = sym.find(char)
        if label == -1:
            raise ValueError('Input character not found')
        labels.append(label)

    # build the FSA
    f = pywrapfst.VectorFst()
    one = pywrapfst.Weight.one(f.weight_type())
    f.set_input_symbols(sym)
    f.set_output_symbols(sym)
    s = f.add_state()
    f.set_start(s)
    for label in labels:
        n = f.add_state()
        f.add_arc(s, pywrapfst.Arc(label, label, one, n))
        s = n
    # `s` is the last state of the chain (the start state for empty input).
    # Using the loop variable here would be unbound when there are no arcs.
    # NOTE(review): the final weight is the literal 1 rather than `one`;
    # kept unchanged -- confirm this is intended.
    f.set_final(s, 1)

    # verify
    if not f.verify():
        raise ValueError('FSA failed to verify')
    return f
Exemple #2
0
def _compile_cg(ifar_path: str, ofar_path: str, insertions: bool,
                deletions: bool) -> str:
  """Compiles the covering grammar from the input and output FARs.

  The grammar is a single state that is both start and final, with one
  self-loop for every (input label, output label) pair, plus optional
  epsilon-input (insertion) and epsilon-output (deletion) loops.

  Args:
    ifar_path: path to the input FAR.
    ofar_path: path to the output FAR.
    insertions: should insertions be permitted?
    deletions: should deletions be permitted?

  Returns:
    The path to the CG FST.

  Raises:
    AssertionError: if the compiled grammar fails verification.
  """
  ilabels = _get_far_labels(ifar_path)
  olabels = _get_far_labels(ofar_path)
  cg = pywrapfst.VectorFst()
  state = cg.add_state()
  cg.set_start(state)
  one = pywrapfst.Weight.one(cg.weight_type())
  for ilabel, olabel in itertools.product(ilabels, olabels):
    cg.add_arc(state, pywrapfst.Arc(ilabel, olabel, one, state))
  # Handles epsilons, carefully avoiding adding a useless 0:0 label.
  if insertions:
    for olabel in olabels:
      cg.add_arc(state, pywrapfst.Arc(0, olabel, one, state))
  if deletions:
    for ilabel in ilabels:
      cg.add_arc(state, pywrapfst.Arc(ilabel, 0, one, state))
  cg.set_final(state)
  # Explicit raise instead of `assert`: an assert statement is removed when
  # Python runs with -O, which would silently skip verification. The
  # exception type is unchanged for existing callers.
  if not cg.verify():
    raise AssertionError("Label acceptor is ill-formed")
  cg_path = _mktemp("cg.fst")
  cg.write(cg_path)
  return cg_path
Exemple #3
0
def single_state_transducer(transition_weights,
                            row_vocab,
                            col_vocab,
                            input_symbols=None,
                            output_symbols=None,
                            arc_type='standard'):
    """Build a one-state transducer from a table of weights.

    The only state is both start and final (final weight one). Each entry
    ``transition_weights[i][j]`` that differs from the semiring zero becomes
    a self-loop whose input label is ``row_vocab[i]`` looked up in the input
    symbol table and whose output label is ``col_vocab[j]`` looked up in the
    output symbol table.

    Raises
    ------
    openfst.FstError
        If the finished machine fails ``verify()``.
    """
    machine = openfst.VectorFst(arc_type=arc_type)
    machine.set_input_symbols(input_symbols)
    machine.set_output_symbols(output_symbols)

    weight_type = machine.weight_type()
    zero = openfst.Weight.zero(weight_type)
    one = openfst.Weight.one(weight_type)

    hub = machine.add_state()
    machine.set_start(hub)
    machine.set_final(hub, one)

    for row_index, weight_row in enumerate(transition_weights):
        for col_index, raw_weight in enumerate(weight_row):
            arc_weight = openfst.Weight(machine.weight_type(), raw_weight)
            ilabel = machine.input_symbols().find(row_vocab[row_index])
            olabel = machine.output_symbols().find(col_vocab[col_index])
            # Entries equal to the semiring zero contribute no arc.
            if arc_weight == zero:
                continue
            machine.add_arc(hub, openfst.Arc(ilabel, olabel, arc_weight, hub))

    if machine.verify():
        return machine
    raise openfst.FstError("fst.verify() returned False")
 def _lexicon_covering(self, ) -> None:
     """Compile the lexicon FARs and write the covering grammar FST.

     Runs ``farcompilestrings`` once for the input side and once for the
     output side (logging each command line and the tool output to
     ``covering_grammar.log``), then builds a single-state grammar with a
     self-loop for every input/output label pair and writes it to
     ``self.cg_path``.
     """
     log_path = os.path.join(self.working_log_directory,
                             "covering_grammar.log")
     with open(log_path, "w", encoding="utf8") as log_file:

         def run_logged(command):
             # Echo the command line, then send both output streams of the
             # tool to the same log file.
             print(" ".join(command), file=log_file)
             subprocess.check_call(command,
                                   env=os.environ,
                                   stderr=log_file,
                                   stdout=log_file)

         # Input-side FAR: token handling depends on the configured type.
         input_command = [
             thirdparty_binary("farcompilestrings"),
             "--fst_type=compact",
         ]
         if self.input_token_type == "utf8":
             input_command.append("--token_type=utf8")
         else:
             input_command.append("--token_type=symbol")
             input_command.append(f"--symbols={self.input_token_type}")
             input_command.append("--unknown_symbol=<unk>")
         input_command.extend([self.input_path, self.input_far_path])
         run_logged(input_command)

         # Output-side FAR: always symbol tokens from the phone table.
         output_command = [
             thirdparty_binary("farcompilestrings"),
             "--fst_type=compact",
             "--token_type=symbol",
             f"--symbols={self.phone_symbol_table_path}",
             self.output_path,
             self.output_far_path,
         ]
         run_logged(output_command)

         # Sets of labels for the covering grammar.
         ilabels = _get_far_labels(self.input_far_path)
         print(ilabels, file=log_file)
         olabels = _get_far_labels(self.output_far_path)
         print(olabels, file=log_file)

         cg = pywrapfst.VectorFst()
         hub = cg.add_state()
         cg.set_start(hub)
         one = pywrapfst.Weight.one(cg.weight_type())

         def add_loop(ilabel, olabel):
             cg.add_arc(hub, pywrapfst.Arc(ilabel, olabel, one, hub))

         for ilabel, olabel in itertools.product(ilabels, olabels):
             add_loop(ilabel, olabel)
         # Epsilon arcs are added per side so that no 0:0 label is created.
         if self.insertions:
             for olabel in olabels:
                 add_loop(0, olabel)
         if self.deletions:
             for ilabel in ilabels:
                 add_loop(ilabel, 0)
         cg.set_final(hub)
         assert cg.verify(), "Label acceptor is ill-formed"
         cg.write(self.cg_path)
Exemple #5
0
def easyUnion(*fsts, disambiguate=False):
    """Union several FSTs after giving them a common pair of symbol tables.

    The input and output symbol tables of all arguments are merged, and the
    merged tables are installed both on the result and on every argument
    (the arguments are modified). With ``disambiguate=True`` the symbols
    ``seq0``, ``seq1``, ... are added to both tables and the arcs leaving the
    start state of the union are relabelled, in iteration order, with
    ``seq<k>`` on both sides.

    Returns
    -------
    union_fst : openfst.VectorFst
    """
    first = fsts[0]
    result = openfst.VectorFst(arc_type=first.arc_type())
    result.set_start(result.add_state())

    # Fold every machine's tables into one shared input/output pair.
    in_table = first.input_symbols()
    out_table = first.output_symbols()
    for member in fsts:
        in_table = openfst.merge_symbol_table(in_table,
                                              member.input_symbols())
        out_table = openfst.merge_symbol_table(out_table,
                                               member.output_symbols())

    result.set_input_symbols(in_table)
    result.set_output_symbols(out_table)
    for member in fsts:
        member.set_input_symbols(in_table)
        member.set_output_symbols(out_table)

    result.union(*fsts)

    if not disambiguate:
        return result

    tags = [f"seq{k}" for k in range(len(fsts))]
    for tag in tags:
        result.mutable_input_symbols().add_symbol(tag)
    for tag in tags:
        result.mutable_output_symbols().add_symbol(tag)

    # Relabel the arcs out of the start state with consecutive seq tags.
    arcs = result.mutable_arcs(result.start())
    position = 0
    while not arcs.done():
        tag = f"seq{position}"
        arc = arcs.value()
        arc.ilabel = result.input_symbols().find(tag)
        arc.olabel = result.output_symbols().find(tag)
        arcs.set_value(arc)
        arcs.next()
        position += 1

    return result
Exemple #6
0
def fromSequence(seq, arc_type='standard', symbol_table=None):
    """Build a linear-chain acceptor from a sequence of integer labels.

    One arc is added per element of ``seq``; its input and output label are
    both ``label + 1`` and its weight is the semiring one. The state reached
    after the last element (the start state when ``seq`` is empty) is final.

    Raises
    ------
    openfst.FstError
        If the finished machine fails ``verify()``.
    """
    fst = openfst.VectorFst(arc_type=arc_type)
    one = openfst.Weight.one(fst.weight_type())

    if symbol_table is not None:
        fst.set_input_symbols(symbol_table)
        fst.set_output_symbols(symbol_table)

    tail = fst.add_state()
    fst.set_start(tail)

    for label in seq:
        head = fst.add_state()
        shifted = label + 1
        fst.add_arc(tail, openfst.Arc(shifted, shifted, one, head))
        tail = head

    fst.set_final(tail, one)

    if fst.verify():
        return fst
    raise openfst.FstError("fst.verify() returned False")
Exemple #7
0
import logging
import numpy as np

from PIL import Image
from bidi.algorithm import get_display

from kraken import rpred
from kraken.lib.exceptions import KrakenInputException, KrakenEncodeException
from kraken.lib.segmentation import compute_polygon_section

logger = logging.getLogger('kraken')

# Select the FST bindings: prefer pywrapfst and fall back to openfst_python.
# Either way the module is bound to the name `fst`, and `_get_fst` /
# `_get_best_weight` hide the API differences between the two.
try:
    import pywrapfst as fst

    def _get_fst():
        """Return a new, empty mutable FST."""
        return fst.VectorFst()

    def _get_best_weight(x):
        """Return the semiring one for the weight type of ``x``.

        The original body read ``f.weight_type()``, but no name ``f`` is
        defined here, so any call raised NameError; the argument ``x`` is
        used instead (assumed to be the FST being built -- verify against
        callers).
        """
        return fst.Weight.one(x.weight_type())
except ImportError:
    logger.info('pywrapfst not available. Falling back to openfst_python.')
    try:
        import openfst_python as fst

        def _get_fst():
            """Return a new, empty FST."""
            return fst.Fst()

        def _get_best_weight(x):
            """Return 0 as the best weight; the argument is ignored."""
            return 0
    except ImportError:
        logger.error(
            'Neither pywrapfst nor openfst_python bindings available.')
        raise


def _get_arc(input_label, output_label, weight, state):
    """Create an arc with whichever binding is imported as ``fst``."""
    arc = fst.Arc(input_label, output_label, weight, state)
    return arc
Exemple #8
0
def make_event_to_assembly_fst(weights,
                               input_vocab,
                               output_vocab,
                               input_parts_to_str,
                               output_parts_to_str,
                               init_weights=None,
                               final_weights=None,
                               input_symbols=None,
                               output_symbols=None,
                               state_tx_in_input=False,
                               eps_str='ε',
                               dur_internal_str='I',
                               dur_final_str='F',
                               bos_str='<BOS>',
                               eos_str='<EOS>',
                               arc_type='standard'):
    """Build a segmental transducer over (input, output) vocabulary pairs.

    For every pair ``(i_input, i_output)`` two states are created: a
    segment-internal state and a segment-final state. Unless the pair is
    masked out (see ``W_a_s`` below) it is wired with a BOS arc from the
    initial state, an internal self-loop, an arc into the segment-final
    state and an EOS arc to the global final state. Afterwards arcs are
    added from segment-final states into the states of other pairs.

    Parameters
    ----------
    weights :
        Iterated as ``weights[i_input][i_output_cur][i_output_next]``; each
        innermost value becomes an arc weight. ``-weights`` is passed to
        ``scipy.special.logsumexp``, so it must support negation.
    input_vocab, output_vocab :
        Indexed by ``i_input`` / ``i_output`` to obtain the key parts used
        with ``input_parts_to_str`` / ``output_parts_to_str``.
    input_parts_to_str, output_parts_to_str :
        Mappings from tuples of key parts to symbol strings.
    init_weights, final_weights : optional
        Indexed as ``[i_input, i_output]``; the semiring one is used when
        omitted.
    input_symbols, output_symbols :
        Symbol tables installed on the FST and used for all label lookups.
    state_tx_in_input : bool
        Selects the four-part key form for input symbols instead of the
        two-part ``(input, duration)`` form.
    eps_str :
        Not referenced in this function body.
    dur_internal_str, dur_final_str, bos_str, eos_str :
        Strings used as key parts / symbols for the corresponding arcs.
    arc_type : str
        Passed to ``openfst.VectorFst``.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If the finished machine fails ``verify()``.
    """

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    init_state = fst.add_state()
    final_state = fst.add_state()
    fst.set_start(init_state)
    fst.set_final(final_state, one)

    # Negative logsumexp of the negated weights over the last axis; it is
    # only used below to decide whether a pair gets any arcs at all.
    W_a_s = -scipy.special.logsumexp(-weights, axis=-1)

    def make_states(i_input, i_output):
        # Both states are always created, so the `states` table built below
        # has an entry for every pair even when the pair is masked out.
        seg_internal_state = fst.add_state()
        seg_final_state = fst.add_state()

        # Masked pair: leave the two states without any arcs.
        if openfst.Weight(fst.weight_type(), W_a_s[i_input, i_output]) == zero:
            return [seg_internal_state, seg_final_state]

        state_istr = input_vocab[i_input]
        state_ostr = output_vocab[i_output]

        # Initial state -> seg internal state
        if init_weights is None:
            init_weight = one
        else:
            init_weight = openfst.Weight(fst.weight_type(),
                                         init_weights[i_input, i_output])
        if init_weight != zero:
            arc_istr = bos_str
            arc_ostr = bos_str
            arc = openfst.Arc(fst.input_symbols().find(arc_istr),
                              fst.output_symbols().find(arc_ostr), init_weight,
                              seg_internal_state)
            fst.add_arc(init_state, arc)

        # (in, I) : (out, I), weight one, self transition
        if state_tx_in_input:
            arc_istr = input_parts_to_str[state_istr, state_ostr, state_ostr,
                                          dur_internal_str]
        else:
            arc_istr = input_parts_to_str[state_istr, dur_internal_str]
        arc_ostr = output_parts_to_str[state_istr, state_ostr, state_ostr,
                                       dur_internal_str]
        arc = openfst.Arc(fst.input_symbols().find(arc_istr),
                          fst.output_symbols().find(arc_ostr), one,
                          seg_internal_state)
        fst.add_arc(seg_internal_state, arc)

        # (in, F) : (out, F), weight one, transition into final state
        if state_tx_in_input:
            arc_istr = input_parts_to_str[state_istr, state_ostr, state_ostr,
                                          dur_final_str]
        else:
            arc_istr = input_parts_to_str[state_istr, dur_final_str]
        arc_ostr = output_parts_to_str[state_istr, state_ostr, state_ostr,
                                       dur_final_str]
        arc = openfst.Arc(fst.input_symbols().find(arc_istr),
                          fst.output_symbols().find(arc_ostr), one,
                          seg_final_state)
        fst.add_arc(seg_internal_state, arc)

        # seg final state -> final_state
        if final_weights is None:
            final_weight = one
        else:
            final_weight = openfst.Weight(fst.weight_type(),
                                          final_weights[i_input, i_output])
        if final_weight != zero:
            arc_istr = eos_str
            arc_ostr = eos_str
            arc = openfst.Arc(fst.input_symbols().find(arc_istr),
                              fst.output_symbols().find(arc_ostr),
                              final_weight, final_state)
            fst.add_arc(seg_final_state, arc)

        return [seg_internal_state, seg_final_state]

    # Build segmental backbone
    # states[i_input][i_output] == [seg_internal_state, seg_final_state]
    states = [[
        make_states(i_input, i_output)
        for i_output, _ in enumerate(output_vocab)
    ] for i_input, _ in enumerate(input_vocab)]

    # Add transitions from final (action, assembly) to initial (action, assembly)
    for i_input_cur, arr in enumerate(weights):
        for i_output_cur, row in enumerate(arr):
            for i_output_next, tx_weight in enumerate(row):
                weight = openfst.Weight(fst.weight_type(), tx_weight)
                for i_input_next, __, in enumerate(weights):
                    # i_dur selects the target: 0 -> seg-internal state,
                    # 1 -> seg-final state of the next pair.
                    for i_dur, dur_str in enumerate(
                        (dur_internal_str, dur_final_str)):
                        # From Seg-Final to Seg-Final or Seg-Internal
                        from_state = states[i_input_cur][i_output_cur][1]
                        to_state = states[i_input_next][i_output_next][i_dur]

                        istr = input_vocab[i_input_next]
                        # NOTE(review): cur_ostr is read with i_output_next,
                        # making it identical to next_ostr; i_output_cur may
                        # have been intended -- confirm before changing.
                        cur_ostr = output_vocab[i_output_next]
                        next_ostr = output_vocab[i_output_next]
                        if state_tx_in_input:
                            arc_istr = input_parts_to_str[istr, cur_ostr,
                                                          next_ostr, dur_str]
                        else:
                            arc_istr = input_parts_to_str[istr, dur_str]
                        arc_ostr = output_parts_to_str[istr, cur_ostr,
                                                       next_ostr, dur_str]

                        if weight != zero:
                            arc = openfst.Arc(
                                fst.input_symbols().find(arc_istr),
                                fst.output_symbols().find(arc_ostr), weight,
                                to_state)
                            fst.add_arc(from_state, arc)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Exemple #9
0
def make_duration_fst(final_weights,
                      class_vocab,
                      class_dur_to_str,
                      dur_internal_str='I',
                      dur_final_str='F',
                      input_symbols=None,
                      output_symbols=None,
                      allow_self_transitions=True,
                      arc_type='standard'):
    """Build a duration transducer with one left-to-right chain per class.

    Parameters
    ----------
    final_weights : 2-D array, one row per class
        Row ``i`` decides at which chain positions class ``i`` may end: a
        position whose entry differs from the semiring zero gets an exit
        arc. The values act only as a mask; every arc has weight one.
    class_vocab :
        ``class_vocab[i]`` is the input symbol string of class ``i``.
    class_dur_to_str :
        Maps ``(class_vocab[i], dur_internal_str)`` and
        ``(class_vocab[i], dur_final_str)`` to output symbol strings.
    input_symbols, output_symbols :
        Symbol tables installed on the FST and used for label lookups.
    allow_self_transitions : bool
        When false, no epsilon arc links a chain's exit to its own entry.
    arc_type : str
        Passed to ``openfst.VectorFst``.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    AssertionError
        If the largest non-zero index of a class's weights is below 1.
    openfst.FstError
        If the finished machine fails ``verify()``.
    """
    num_classes, _ = final_weights.shape

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    one = openfst.Weight.one(fst.weight_type())
    zero = openfst.Weight.zero(fst.weight_type())

    init_state = fst.add_state()
    final_state = fst.add_state()
    fst.set_start(init_state)
    fst.set_final(final_state, one)

    def build_chain(class_str, internal_str, final_str, class_weights):
        """Add the chain for one class; return its (entry, exit) states."""
        ilabel = fst.input_symbols().find(class_str)
        olabel_internal = fst.output_symbols().find(internal_str)
        olabel_final = fst.output_symbols().find(final_str)

        # The chain length is the largest index holding a non-zero weight.
        max_dur = np.nonzero(class_weights != float(zero))[0].max()
        if max_dur < 1:
            raise AssertionError(f"max_dur = {max_dur}, but should be >= 1)")

        chain = tuple(fst.add_state() for __ in range(max_dur))
        chain_exit = fst.add_state()
        # Epsilon arcs tie the chain to the global start and final states.
        fst.add_arc(init_state, openfst.Arc(0, 0, one, chain[0]))
        fst.add_arc(chain_exit, openfst.Arc(0, 0, one, final_state))

        last_index = len(chain) - 1
        for i, cur_state in enumerate(chain):
            # Exit arc where ending at this position is allowed.
            if openfst.Weight(fst.weight_type(), class_weights[i]) != zero:
                fst.add_arc(
                    cur_state,
                    openfst.Arc(ilabel, olabel_final, one, chain_exit))

            # Internal arc to the right-hand neighbour, if there is one.
            if i < last_index:
                fst.add_arc(
                    cur_state,
                    openfst.Arc(ilabel, olabel_internal, one, chain[i + 1]))

        return chain[0], chain_exit

    endpoints = []
    for i in range(num_classes):
        endpoints.append(
            build_chain(
                class_vocab[i],
                class_dur_to_str[class_vocab[i], dur_internal_str],
                class_dur_to_str[class_vocab[i], dur_final_str],
                final_weights[i],
            ))
    endpoints = tuple(endpoints)

    # Link every chain exit to every chain entry with an epsilon arc.
    for i, (__, exit_state) in enumerate(endpoints):
        for j, (entry_state, __) in enumerate(endpoints):
            if i == j and not allow_self_transitions:
                continue
            fst.add_arc(exit_state, openfst.Arc(0, 0, one, entry_state))

    if fst.verify():
        return fst
    raise openfst.FstError("fst.verify() returned False")
Exemple #10
0
def fromArray(weights,
              row_vocab,
              col_vocab,
              final_weight=None,
              arc_type=None,
              input_symbols=None,
              output_symbols=None):
    """ Instantiate a linear-chain state machine from an array of weights.

    Row ``t`` of ``weights`` becomes the set of arcs between state ``t`` and
    state ``t + 1``: one arc per column whose weight differs from the
    semiring zero. An array with no rows yields a single state that is both
    start and final.

    Parameters
    ----------
    weights : array_like, shape (num_rows, num_cols)
        Needs to implement `.ndim` and `.shape`, so it should be a numpy
        array or a torch tensor.
    row_vocab :
        ``row_vocab[t]`` is the input symbol string used for every arc
        created from row ``t``.
    col_vocab :
        ``col_vocab[i]`` is the output symbol string for column ``i``.
    final_weight : optional
        Weight of the last state. Default is the semiring one.
    arc_type : {'standard', 'log'}, optional
        Default is 'standard' (ie the tropical arc_type)
    input_symbols :
        Symbol table installed on the FST; used to look up ``row_vocab``.
    output_symbols :
        Symbol table installed on the FST; used to look up ``col_vocab``.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    AssertionError
        If ``weights`` is not two-dimensional.
    openfst.FstError
        If the finished machine fails ``verify()``.
    """

    if weights.ndim != 2:
        raise AssertionError(
            f"weights have unrecognized shape {weights.shape}")

    if arc_type is None:
        arc_type = 'standard'

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if final_weight is None:
        final_weight = one
    else:
        final_weight = openfst.Weight(fst.weight_type(), final_weight)

    init_state = fst.add_state()
    fst.set_start(init_state)

    prev_state = init_state
    for sample_index, row in enumerate(weights):
        cur_state = fst.add_state()
        for i, weight in enumerate(row):
            input_label = row_vocab[sample_index]
            output_label = col_vocab[i]
            input_label_index = fst.input_symbols().find(input_label)
            output_label_index = fst.output_symbols().find(output_label)
            weight = openfst.Weight(fst.weight_type(), weight)
            if weight != zero:
                arc = openfst.Arc(input_label_index, output_label_index,
                                  weight, cur_state)
                fst.add_arc(prev_state, arc)
        prev_state = cur_state
    # prev_state is the last state of the chain; unlike the loop variable it
    # is also bound (to the start state) when `weights` has no rows.
    fst.set_final(prev_state, final_weight)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Exemple #11
0
def fromTransitions(transition_weights,
                    row_vocab,
                    col_vocab,
                    init_weights=None,
                    final_weights=None,
                    input_symbols=None,
                    output_symbols=None,
                    bos_str='<BOS>',
                    eos_str='<EOS>',
                    eps_str=libfst.EPSILON_STRING,
                    arc_type='standard',
                    transition_ids=None):
    """ Instantiate a state machine from state transitions.

    The machine has a start state, a single final state and one state per
    row of ``transition_weights``. Arcs whose weight equals the semiring
    zero are omitted.

    Parameters
    ----------
    transition_weights : array_like, shape (num_states, num_states)
        ``transition_weights[i][j]`` is the weight of moving from state
        ``i`` to state ``j``; that arc carries the label of
        ``col_vocab[j]`` on both sides.
    row_vocab :
        Not referenced in this function body.
    col_vocab :
        ``col_vocab[i]`` is the symbol string of state ``i``, looked up in
        the output symbol table.
    init_weights, final_weights : sequence, optional
        Per-state weights of the BOS arc out of the start state and of the
        EOS arc into the final state. Default is the semiring one for
        every state.
    input_symbols, output_symbols :
        Symbol tables installed on the FST and used for label lookups.
    bos_str, eos_str :
        Input symbols of the entry and exit arcs.
    eps_str :
        Output symbol of the exit arcs.
    arc_type : str
        Passed to ``openfst.VectorFst``.
    transition_ids : optional
        Accepted for interface compatibility; this implementation does not
        read it.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If the finished machine fails ``verify()``.
    """

    num_states = transition_weights.shape[0]

    # A default `transition_ids` table used to be built here in O(n^2) and
    # then never read; that dead work has been removed.

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if init_weights is None:
        init_weights = tuple(float(one) for __ in range(num_states))

    if final_weights is None:
        final_weights = tuple(float(one) for __ in range(num_states))

    fst.set_start(fst.add_state())
    final_state = fst.add_state()
    fst.set_final(final_state, one)

    bos_index = fst.input_symbols().find(bos_str)
    eos_index = fst.input_symbols().find(eos_str)
    eps_index = fst.output_symbols().find(eps_str)

    def makeState(i):
        """Add state ``i`` together with its BOS entry and EOS exit arcs."""
        state = fst.add_state()

        initial_weight = openfst.Weight(fst.weight_type(), init_weights[i])
        if initial_weight != zero:
            next_state_str = col_vocab[i]
            next_state_index = fst.output_symbols().find(next_state_str)
            arc = openfst.Arc(bos_index, next_state_index, initial_weight,
                              state)
            fst.add_arc(fst.start(), arc)

        final_weight = openfst.Weight(fst.weight_type(), final_weights[i])
        if final_weight != zero:
            arc = openfst.Arc(eos_index, eps_index, final_weight, final_state)
            fst.add_arc(state, arc)

        return state

    states = tuple(makeState(i) for i in range(num_states))
    for i_cur, row in enumerate(transition_weights):
        for i_next, tx_weight in enumerate(row):
            cur_state = states[i_cur]
            next_state = states[i_next]
            weight = openfst.Weight(fst.weight_type(), tx_weight)
            if weight != zero:
                next_state_str = col_vocab[i_next]
                # The same index (from the output table) labels both sides.
                next_state_index = fst.output_symbols().find(next_state_str)
                arc = openfst.Arc(next_state_index, next_state_index, weight,
                                  next_state)
                fst.add_arc(cur_state, arc)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Exemple #12
0
def single_seg_transducer(weights,
                          input_vocab,
                          output_vocab,
                          input_parts_to_str,
                          output_parts_to_str,
                          input_symbols=None,
                          output_symbols=None,
                          eps_str='ε',
                          bos_str='<BOS>',
                          eos_str='<EOS>',
                          dur_internal_str='I',
                          dur_final_str='F',
                          arc_type='standard',
                          pass_input=False):
    """Build a hub-and-spoke segment transducer.

    A single hub state is both start and final. For every entry
    ``weights[i][j]`` that differs from the semiring zero, one extra state
    is added with three arcs: an internal-duration arc from the hub into
    it, an internal-duration self-loop, and a final-duration arc back to
    the hub carrying the entry's weight. When ``pass_input`` is true the
    output symbol key also includes the input vocabulary item.

    ``eps_str``, ``bos_str`` and ``eos_str`` are not referenced in this
    function body.

    Raises
    ------
    openfst.FstError
        If the finished machine fails ``verify()``.
    """

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    hub = fst.add_state()
    fst.set_start(hub)
    fst.set_final(hub, one)

    def label_pair(in_str, out_str, dur_str):
        """Look up the (input, output) label ids for one duration tag."""
        ilabel_str = input_parts_to_str[in_str, dur_str]
        if pass_input:
            olabel_str = output_parts_to_str[in_str, out_str, dur_str]
        else:
            olabel_str = output_parts_to_str[out_str, dur_str]
        return (fst.input_symbols().find(ilabel_str),
                fst.output_symbols().find(olabel_str))

    def add_spoke(i_input, i_output, weight):
        spoke = fst.add_state()

        in_str = input_vocab[i_input]
        out_str = output_vocab[i_output]

        # Internal duration: weight one, hub -> spoke plus a spoke self-loop.
        ilabel, olabel = label_pair(in_str, out_str, dur_internal_str)
        enter_arc = openfst.Arc(ilabel, olabel, one, spoke)
        fst.add_arc(hub, enter_arc)
        fst.add_arc(spoke, enter_arc.copy())

        # Final duration: carries the transition weight, spoke -> hub.
        ilabel, olabel = label_pair(in_str, out_str, dur_final_str)
        fst.add_arc(spoke, openfst.Arc(ilabel, olabel, weight, hub))

    for i_input, weight_row in enumerate(weights):
        for i_output, raw_weight in enumerate(weight_row):
            weight = openfst.Weight(fst.weight_type(), raw_weight)
            if weight == zero:
                continue
            add_spoke(i_input, i_output, weight)

    if fst.verify():
        return fst
    raise openfst.FstError("fst.verify() returned False")
Exemple #13
0
 def setUpClass(cls):
   """Creates the shared fixture: a one-state epsilon machine."""
   cls.f = pywrapfst.VectorFst()
   # The only state is both the start state and a final state.
   only_state = cls.f.add_state()
   cls.f.set_start(only_state)
   cls.f.set_final(only_state)
0
def durationFst(
        label, num_states, transition_weights=None, self_weights=None,
        arc_type='standard', symbol_table=None):
    """ Construct a left-to-right duration WFST for a single label.

    Parameters
    ----------
    label : int
        Arcs use ``label + 1`` as their label id.
    num_states : int
        Number of chain positions; must be at least 1.
    transition_weights : sequence, optional
        ``transition_weights[i]`` weights the arc from position ``i`` to its
        right neighbor. Defaults to the semiring one everywhere.
    self_weights : sequence, optional
        ``self_weights[i]`` weights the self-loop at position ``i``.
        Defaults to the semiring one everywhere.
    arc_type : str
        Passed to ``openfst.VectorFst``.
    symbol_table : optional
        If given, installed as both the input and the output symbol table.

    Returns
    -------
    fst : openfst.VectorFst
        A linear-chain weighted finite-state transducer. The first arc reads
        epsilon and emits the label; every later arc reads the label and
        emits epsilon. Arcs whose weight equals the semiring zero are left
        out. The topology looks like this:
                        __     __     __
                        \/     \/     \/
            [START] --> s1 --> s2 --> s3 --> [END]

    Raises
    ------
    AssertionError
        If ``num_states`` is smaller than 1.
    openfst.FstError
        If the finished machine fails ``verify()``.
    """

    if num_states < 1:
        raise AssertionError(f"num_states = {num_states}, but should be >= 1)")

    fst = openfst.VectorFst(arc_type=arc_type)
    one = openfst.Weight.one(fst.weight_type())
    zero = openfst.Weight.zero(fst.weight_type())

    if transition_weights is None:
        transition_weights = [one] * num_states

    if self_weights is None:
        self_weights = [one] * num_states

    if symbol_table is not None:
        fst.set_input_symbols(symbol_table)
        fst.set_output_symbols(symbol_table)

    init_state = fst.add_state()
    fst.set_start(init_state)

    # Entry arc: consumes nothing and emits the (shifted) label once.
    position = fst.add_state()
    shifted = label + 1
    fst.add_arc(init_state, openfst.Arc(EPSILON, shifted, one, position))

    for i in range(num_states):
        successor = fst.add_state()

        advance_weight = openfst.Weight(fst.weight_type(),
                                        transition_weights[i])
        if advance_weight != zero:
            fst.add_arc(
                position,
                openfst.Arc(shifted, EPSILON, advance_weight, successor))

        stay_weight = openfst.Weight(fst.weight_type(), self_weights[i])
        if stay_weight != zero:
            fst.add_arc(
                position,
                openfst.Arc(shifted, EPSILON, stay_weight, position))

        position = successor

    fst.set_final(position, one)

    if fst.verify():
        return fst
    raise openfst.FstError("fst.verify() returned False")
Exemple #15
0
def fromTransitions(
        transition_weights, init_weights=None, final_weights=None,
        arc_type='standard', transition_ids=None):
    """ Instantiate a state machine whose arc labels identify transitions.

    Parameters
    ----------
    transition_weights : array_like, shape (num_states, num_states)
        ``transition_weights[i][j]`` is the weight of the arc from state
        ``i`` to state ``j``; entries equal to the semiring zero add no arc.
    init_weights, final_weights : sequence, optional
        Per-state weight of the arc out of the start state and per-state
        final weight. Default is the semiring one for every state.
    arc_type : str
        Passed to ``openfst.VectorFst``.
    transition_ids : dict, optional
        Maps ``(s_cur, s_next)`` pairs, and ``(-1, s)`` for entries from the
        start state, to integer ids; label ids are ``id + 1``. By default
        ids are assigned consecutively: all ``(s_cur, s_next)`` pairs in
        row-major order, then the ``(-1, s)`` pairs.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    openfst.FstError
        If the finished machine fails ``verify()``.
    """

    num_states = transition_weights.shape[0]

    if transition_ids is None:
        pairs = [
            (s_cur, s_next)
            for s_cur in range(num_states)
            for s_next in range(num_states)
        ]
        pairs.extend((-1, s) for s in range(num_states))
        transition_ids = {pair: index for index, pair in enumerate(pairs)}

    def build_table():
        """Create a symbol table holding epsilon and every transition."""
        table = openfst.SymbolTable()
        table.add_symbol(EPSILON_STRING, key=EPSILON)
        for transition, index in transition_ids.items():
            table.add_symbol(str(transition), key=index + 1)
        return table

    # Two separate tables with identical contents.
    output_table = build_table()
    input_table = build_table()

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_table)
    fst.set_output_symbols(output_table)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if init_weights is None:
        init_weights = tuple(float(one) for __ in range(num_states))

    if final_weights is None:
        final_weights = tuple(float(one) for __ in range(num_states))

    fst.set_start(fst.add_state())

    def makeState(i):
        """Add state ``i`` with its entry arc and final weight."""
        new_state = fst.add_state()

        entry_weight = openfst.Weight(fst.weight_type(), init_weights[i])
        if entry_weight != zero:
            entry_label = transition_ids[-1, i] + 1
            fst.add_arc(
                fst.start(),
                openfst.Arc(EPSILON, entry_label, entry_weight, new_state))

        exit_weight = openfst.Weight(fst.weight_type(), final_weights[i])
        if exit_weight != zero:
            fst.set_final(new_state, exit_weight)

        return new_state

    states = tuple(makeState(i) for i in range(num_states))
    for i_cur, weight_row in enumerate(transition_weights):
        source = states[i_cur]
        for i_next, raw_weight in enumerate(weight_row):
            target = states[i_next]
            weight = openfst.Weight(fst.weight_type(), raw_weight)
            # The id is looked up before the zero check, as before.
            label = transition_ids[i_cur, i_next] + 1
            if weight == zero:
                continue
            fst.add_arc(source, openfst.Arc(label, label, weight, target))

    if fst.verify():
        return fst
    raise openfst.FstError("fst.verify() returned False")
Exemple #16
0
def fromArray(
        weights, final_weight=None, arc_type=None,
        input_symbols=None, output_symbols=None):
    """ Instantiate a linear-chain state machine from an array of weights.

    Slice ``t`` of ``weights`` (along the first axis) becomes the set of
    arcs between state ``t`` and state ``t + 1``; entries equal to the
    semiring zero add no arc. An array with no slices yields a single state
    that is both start and final.

    Parameters
    ----------
    weights : array_like, 2-D or 3-D
        Needs to implement `.ndim` and `.shape`, so it should be a numpy
        array or a torch tensor.
        2-D: entry ``[t, i]`` gives an arc with input label ``t + 1`` and
        output label ``i + 1``.
        3-D: entry ``[t, i, j]`` gives an arc with input label ``i + 1`` and
        output label ``j + 1``.
    final_weight : optional
        Weight of the last state. Default is the semiring one.
    arc_type : {'standard', 'log'}, optional
        Default is 'standard' (ie the tropical arc_type)
    input_symbols, output_symbols :
        Symbol tables installed on the FST. Label ids are computed from
        array indices, not looked up in these tables.

    Returns
    -------
    fst : openfst.VectorFst

    Raises
    ------
    AssertionError
        If ``weights`` is neither 2-D nor 3-D.
    openfst.FstError
        If the finished machine fails ``verify()``.
    """
    if weights.ndim == 3:
        is_lattice = False
    elif weights.ndim == 2:
        is_lattice = True
    else:
        raise AssertionError(f"weights have unrecognized shape {weights.shape}")

    if arc_type is None:
        arc_type = 'standard'

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if final_weight is None:
        final_weight = one
    else:
        final_weight = openfst.Weight(fst.weight_type(), final_weight)

    init_state = fst.add_state()
    fst.set_start(init_state)

    prev_state = init_state
    for sample_index, sample in enumerate(weights):
        cur_state = fst.add_state()

        # Enumerate (input label, output label, raw weight) for this slice;
        # label ids are shifted by one because 0 is reserved for epsilon.
        if is_lattice:
            labeled_weights = (
                (sample_index + 1, i + 1, raw)
                for i, raw in enumerate(sample)
            )
        else:
            labeled_weights = (
                (i + 1, j + 1, raw)
                for i, outputs in enumerate(sample)
                for j, raw in enumerate(outputs)
            )

        for input_label_index, output_label_index, raw in labeled_weights:
            weight = openfst.Weight(fst.weight_type(), raw)
            if weight != zero:
                arc = openfst.Arc(
                    input_label_index, output_label_index,
                    weight, cur_state
                )
                fst.add_arc(prev_state, arc)
        prev_state = cur_state

    # prev_state is the last state of the chain; unlike the loop variable it
    # is also bound (to the start state) when `weights` is empty.
    fst.set_final(prev_state, final_weight)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst