Beispiel #1
0
def threshold_lattice_to_dfa(lattice: pynini.Fst,
                             threshold: float = 1.0,
                             state_multiplier: int = 2) -> pynini.Fst:
    """Constructs a (possibly pruned) weighted DFA of output strings.

    Given an epsilon-free lattice of output strings (such as produced by
    rewrite_lattice), determinizes it while pruning with a weight threshold.
    This is valid only in a semiring with the path property.

    To prevent unexpected blowup during determinization, a state threshold is
    also used and a warning is logged if this exact threshold is reached. The
    threshold is a multiplier of the size of the input lattice (by default,
    2), plus a constant of 256. This is intended as a sensible default and is
    not an inherently meaningful value in and of itself.

    Parameters
    ----------
    lattice: :class:`~pynini.Fst`
        Epsilon-free non-deterministic finite acceptor.
    threshold: float
        Value wrapped in a :class:`~pynini.Weight` of the lattice's weight
        type and passed to ``pynini.determinize`` as its ``weight`` pruning
        threshold.
        NOTE(review): earlier documentation described 1.0 as "optimal only"
        and 0 as "all paths"; how a given value prunes depends on the
        semiring, so confirm against the ``pynini.determinize`` documentation.
    state_multiplier: int
        Max ratio for the number of states in the DFA lattice to the NFA
        lattice; if the resulting state limit is hit, a warning is logged.

    Returns
    -------
    :class:`~pynini.Fst`
        Epsilon-free deterministic finite acceptor.
    """
    # Imported locally so this function does not depend on the surrounding
    # module already importing ``logging``.
    import logging

    weight_type = lattice.weight_type()
    weight_threshold = pynini.Weight(weight_type, threshold)
    state_threshold = 256 + state_multiplier * lattice.num_states()
    lattice = pynini.determinize(lattice,
                                 nstate=state_threshold,
                                 weight=weight_threshold)
    # Landing exactly on the limit suggests determinization was cut short by
    # the state threshold rather than finishing on its own.
    if lattice.num_states() == state_threshold:
        logging.warning(
            "Unexpected hit state threshold; consider a higher value "
            "for state_multiplier")
    return lattice
Beispiel #2
0
def verify_if_single_path(input_str_fsa: pynini.Fst, fst: pynini.Fst):
    """Checks that the given string FST has a single path; raises otherwise.

    ``fst.string()`` raises ``pynini.FstOpError`` when it cannot return a
    single string. That exception is converted into an ``AssertionError``
    whose message names the input string and lists every output string of
    ``fst``.

    Args:
      input_str_fsa: Input string FSA specific to the FST. Used only in the
          exception message.
      fst: FST to be verified. Note that it is optimized in place while the
          error message is being built.

    Raises:
      AssertionError: If ``fst.string()`` raised ``pynini.FstOpError``.
    """
    try:
        fst.string()
    except pynini.FstOpError as e:
        source_text = input_str_fsa.string()
        quoted_outputs = [
            f"`{ostring}`" for ostring in fst.optimize().paths().ostrings()
        ]
        template = ("Expected FST to be functional but input string `{input}`"
                    " produced multiple output strings: {outputs}")
        raise AssertionError(
            template.format(input=source_text,
                            outputs=", ".join(quoted_outputs))) from e
Beispiel #3
0
def verify_identity(input_str_fsa: pynini.Fst, fst: pynini.Fst):
    """Verifies that the FST's minimum-weight output is only the input string.

    All paths of ``fst`` are grouped by weight; the output strings of the
    lowest-weight group must consist of the input string and nothing else.
    Raises AssertionError with a detailed message on verification failure;
    otherwise does nothing.

    Args:
      input_str_fsa: Input string FSA to be compared with the minimum weight
          path(s) in the FST.
      fst: FST to be verified. Note that it is optimized in place.

    Raises:
      AssertionError: If the verification has failed; that is, if the given
          FST has no paths at all, or produces anything other than the input
          string at its minimum weight.
    """
    input_str = input_str_fsa.string()
    out_weights = collections.defaultdict(set)
    for _, out, weight in fst.optimize().paths().items():
        # Parse as float rather than int so that fractional (and infinite)
        # path weights are grouped instead of raising ValueError.
        out_weights[float(str(weight))].add(out)

    if not out_weights:
        # An FST without any path cannot reproduce the input; report it as a
        # verification failure instead of letting min() raise ValueError.
        raise AssertionError(
            f'Expected FST to be identity but input `{input_str}`'
            ' produced no output strings')

    outstrs = out_weights[min(out_weights)]

    if outstrs - {input_str}:
        # Sort so the message is deterministic regardless of set ordering.
        raise AssertionError(
            f'Expected FST to be identity but input `{input_str}`'
            f' produced output string(s): `{", ".join(sorted(outstrs))}`')
Beispiel #4
0
def lattice_to_dfa(lattice: pynini.Fst,
                   optimal_only: bool,
                   state_multiplier: int = 4) -> pynini.Fst:
  """Determinizes a lattice of output strings, optionally pruning it.

  The input is expected to be an epsilon-free lattice of output strings (such
  as produced by rewrite_lattice). When optimal_only is true the pruning
  weight is the semiring's one, otherwise the semiring's zero. This is valid
  only in a semiring with the path property.

  Determinization is capped at 256 + state_multiplier * (number of input
  states) states to guard against blowup; a warning is logged when the result
  has exactly that many states. The cap is meant as a sensible default and is
  not an inherently meaningful value in and of itself.

  Args:
    lattice: Epsilon-free non-deterministic finite acceptor.
    optimal_only: Should we only preserve optimal paths?
    state_multiplier: Max ratio for the number of states in the DFA lattice to
      the NFA lattice; if exceeded, a warning is logged.

  Returns:
    Epsilon-free deterministic finite acceptor.
  """
  semiring = lattice.weight_type()
  if optimal_only:
    prune_weight = pynini.Weight.one(semiring)
  else:
    prune_weight = pynini.Weight.zero(semiring)
  max_states = 256 + state_multiplier * lattice.num_states()
  dfa = pynini.determinize(lattice, nstate=max_states, weight=prune_weight)
  hit_limit = dfa.num_states() == max_states
  if hit_limit:
    logging.warning("Unexpected hit state threshold; consider a higher value "
                    "for state_multiplier")
  return dfa
Beispiel #5
0
 def assertFstCompliesWithProperties(
         self, fst: pynini.Fst,
         expected_props: pynini.FstProperties) -> None:
     """Asserts that `fst` has all of the properties in `expected_props`.

     Args:
       fst: FST whose properties are checked.
       expected_props: Property mask the FST is expected to satisfy.

     Raises:
       AssertionError: If testing `fst` against the mask does not yield the
           mask itself; the message shows the FST's full property set.
     """
     if fst.properties(expected_props, True) == expected_props:
         return
     # Only query the full property set when building the failure message.
     all_props = fst.properties(pynini.FST_PROPERTIES, True)
     template = "Expected {actual} to contain the property {expected}"
     raise AssertionError(
         template.format(expected=expected_props, actual=all_props))
Beispiel #6
0
 def all_substrings(self, fsa: pynini.Fst) -> pynini.Fst:
     """Returns an optimized copy of `fsa` that can start and stop anywhere.

     On a copy of the input, every state is made final and receives an
     incoming arc from the start state labelled 0:0 (epsilon in OpenFst's
     convention) with the semiring's one weight. The input FSA itself is not
     modified.

     Args:
       fsa: Source acceptor.

     Returns:
       The optimized, modified copy.
     """
     result = fsa.copy()
     origin = result.start()
     free = pynini.Weight.one(result.weight_type())
     for state in result.states():
         result.set_final(state)
         result.add_arc(origin, pynini.Arc(0, 0, free, state))
     return result.optimize()
Beispiel #7
0
def decode_lattice(lattice: pynini.Fst, lm: pynini.Fst,
                   sym: pynini.SymbolTable) -> str:
    """Composes the lattice with `lm` and returns the shortest path's string.

    Args:
      lattice: Lattice FST to decode.
      lm: FST the lattice is composed with.
      sym: Symbol table used to render the resulting path as a string.

    Returns:
      The string of the epsilon-removed shortest path through the
      composition.

    Raises:
      AssertionError: If the composition has no start state.
    """
    composed = pynini.compose(lattice, lm)
    assert composed.start() != pynini.NO_STATE_ID, "composition failure"
    best_path = pynini.shortestpath(composed).rmepsilon()
    # Pynini joins the path's labels into a single string for us.
    return best_path.string(sym)
Beispiel #8
0
def lattice_to_one_top_string(lattice: pynini.Fst,
                              token_type: Optional[pynini.TokenType] = None
                             ) -> str:
  """Extracts the single top string of a lattice; errors out on a tie.

  The lattice is expected to be a pruned DFA of output strings (such as
  produced by lattice_to_dfa with optimal_only). The first path's output
  string is returned, provided that no second path exists.

  Args:
    lattice: Epsilon-free deterministic finite acceptor.
    token_type: Optional output token type, or symbol table.

  Returns:
    The top string.

  Raises:
    Error: Multiple top rewrites found.
  """
  path_iter = lattice.paths(output_token_type=token_type)
  top = path_iter.ostring()
  path_iter.next()
  if path_iter.done():
    return top
  # A second path is present, so the top rewrite is ambiguous.
  raise Error("Multiple top rewrites found: "
              f"{top} and {path_iter.ostring()} (weight: {path_iter.weight()})")
Beispiel #9
0
    def check_wellformed_lattice(lattice: pynini.Fst) -> None:
        """Ensures the lattice is non-empty, i.e. that it has a start state.

        Args:
          lattice: A lattice FST.

        Raises:
          Error: Lattice is empty.
        """
        has_start_state = lattice.start() != pynini.NO_STATE_ID
        if not has_start_state:
            raise Error("Lattice is empty")
Beispiel #10
0
def lattice_to_strings(
    lattice: pynini.Fst,
    token_type: Optional[pynini.TokenType] = None) -> List[str]:
  """Collects the output string of every path in the lattice.

  Args:
    lattice: Epsilon-free acyclic WFSA.
    token_type: Optional output token type, or symbol table.

  Returns:
    A list of output strings, in path-iteration order.
  """
  path_iter = lattice.paths(output_token_type=token_type)
  return list(path_iter.ostrings())
Beispiel #11
0
    def _flip_lemmatizer_feature_labels(self,
                                        lemmatizer: pynini.Fst) -> pynini.Fst:
        """Moves the lemmatizer's feature labels to the opposite arc side.

        Collects every non-zero input label found on the arcs of
        ``self.category.feature_labels``; then, for each lemmatizer arc whose
        output label is one of them, swaps that arc's input and output
        labels. The lemmatizer's input symbol table is replaced by its output
        symbol table beforehand.

        Destructive operation: ``lemmatizer`` is modified in place.

        Args:
          lemmatizer: FST representing a partially constructed lemmatizer.

        Returns:
          The same lemmatizer object, modified.
        """
        label_fst = self.category.feature_labels
        flippable = set()
        for state in label_fst.states():
            arc_iter = label_fst.arcs(state)
            while not arc_iter.done():
                ilabel = arc_iter.value().ilabel
                if ilabel:
                    flippable.add(ilabel)
                arc_iter.next()

        lemmatizer.set_input_symbols(lemmatizer.output_symbols())
        for state in lemmatizer.states():
            mutable_iter = lemmatizer.mutable_arcs(state)
            while not mutable_iter.done():
                current = mutable_iter.value()
                if current.olabel in flippable:
                    # This assertion should always be true by construction.
                    assert current.ilabel == 0, (
                        f"ilabel = "
                        f"{lemmatizer.input_symbols().find(current.ilabel)},"
                        f" olabel = "
                        f"{lemmatizer.output_symbols().find(current.olabel)}")
                    # Write the swapped arc back at the current position
                    # before the iterator is advanced.
                    mutable_iter.set_value(
                        pynini.Arc(current.olabel, current.ilabel,
                                   current.weight, current.nextstate))
                mutable_iter.next()
        return lemmatizer
def fsa_from_list_of_symbols(input, symbol_table):
    """
    Build a linear-chain acceptor for a sequence of tokens.

    One state is created per token boundary and one arc per token; each arc
    carries the token's label on both the input and the output side with
    weight 0. The state after the last token is final.

    :param input: The tokens to accept, in order.
    :type input: iterable of symbols (presumably str -- TODO confirm)
    :param symbol_table: Table used to map tokens to arc labels.
    :type symbol_table: SymbolTable
    :return: An acceptor for the given list of tokens.
    :rtype: Fst
    The symbol table gets extended, if new tokens occur in the input.
    """
    fsa = Fst()
    state = fsa.add_state()
    fsa.set_start(state)
    for x in input:
        next_state = fsa.add_state()
        # Depending on the OpenFst version, ``find`` signals an unknown
        # symbol either by raising KeyError or by returning -1; handle both
        # so that an unknown token never ends up as a -1 arc label.
        try:
            label = symbol_table.find(x)
        except KeyError:
            label = -1
        if label == -1:
            # Add the symbol once and reuse the key for both arc sides.
            label = symbol_table.add_symbol(x)
        fsa.add_arc(state, Arc(label, label, 0, next_state))
        state = next_state
    fsa.set_final(state)
    # Attach the tables only after the loop so that the acceptor's tables
    # include any symbols that were added while processing the input.
    fsa.set_input_symbols(symbol_table)
    fsa.set_output_symbols(symbol_table)
    return fsa
Beispiel #13
0
 def total_weight(self, fst: pynini.Fst) -> float:
     """Returns the reverse shortest distance at `fst`'s start state as a float."""
     reverse_distances = pynini.shortestdistance(fst, reverse=True)
     return float(reverse_distances[fst.start()])
Beispiel #14
0
 def is_subset(self, string: pynini.FstLike, fsa: pynini.Fst) -> bool:
     """Tests whether adding `string` to `fsa` leaves `fsa` unchanged.

     A determinized, epsilon-free copy of `fsa` is compared for equivalence
     with the determinized, epsilon-free union of that copy and `string`.

     Args:
       string: String (or FST) to test for membership.
       fsa: Acceptor to test against; it is copied, not modified.

     Returns:
       True if the union is equivalent to the normalized `fsa`.
     """
     reference = pynini.determinize(fsa.copy().rmepsilon())
     merged = pynini.union(reference, string).rmepsilon()
     return pynini.equivalent(pynini.determinize(merged), reference)
Beispiel #15
0
 def all_prefixes(self, fsa: pynini.Fst) -> pynini.Fst:
     """Returns an optimized copy of `fsa` in which every state is final.

     The input FSA itself is not modified.
     """
     prefix_fsa = fsa.copy()
     for state in prefix_fsa.states():
         prefix_fsa.set_final(state)
     return prefix_fsa.optimize()
def compile_wfst_from_left_branching_grammar(grammar):
    """
    Create a FST from a left-branching hybrid grammar.

    The FST has one state for INITIAL plus one state per nonterminal of the
    grammar; the state of the grammar's start nonterminal is final. Arcs read
    terminal symbols and emit rule identifiers, so that the output of the FST
    is a rule tree in `reverse polish notation
    <https://en.wikipedia.org/wiki/Reverse_Polish_notation>`_.

    :param grammar: The grammar to compile. Only rules of rank 0, 1 and 2 are
        handled; the shape of each rule is checked with ``assert``.
    :type grammar: LCFRS
    :return: The optimized FST and the enumerator that maps rules to the
        indices used in the FST's output labels.
    :rtype: Fst, Enumerator
    """
    myfst = Fst()

    # ``nonterminals`` maps nonterminal names to the FST state that
    # represents them; it is queried with ``find`` to obtain state ids below.
    nonterminals = SymbolTable()
    sid = myfst.add_state()
    # NOTE(review): no explicit key is passed here, yet the result of
    # ``nonterminals.find(INITIAL)`` is later used as a state id -- presumably
    # the automatically assigned key equals ``sid``; confirm.
    nonterminals.add_symbol(INITIAL)
    myfst.set_start(sid)

    # One state per nonterminal, registered under the state's id.
    for nont in grammar.nonts():
        sid = myfst.add_state()
        nonterminals.add_symbol(nont, sid)
        if nont == grammar.start():
            myfst.set_final(sid)

    # Register every rule with the enumerator up front (presumably this
    # assigns indices starting at 1 in grammar order -- TODO confirm against
    # Enumerator).
    rules = Enumerator(first_index=1)
    for rule in grammar.rules():
        rules.object_index(rule)

    # Shared input/output alphabet of the FST; label 0 is '<epsilon>'.
    terminals = SymbolTable()
    terminals.add_symbol('<epsilon>', 0)

    for rule in grammar.rules():
        if rule.rank() == 2:
            assert len(rule.lhs().arg(0)) == 2
            # Pair the binary rule with every rule for its second right-hand
            # side nonterminal that has an empty right-hand side.
            for rule2 in grammar.lhs_nont_to_rules(rule.rhs_nont(1)):
                if len(rule2.rhs()) == 0:
                    # Input: first element of rule2's first lhs argument.
                    # Output: "<index of rule2>-<index of rule>".
                    # Weight: product of both rule weights via make_weight.
                    arc = Arc(terminals.add_symbol(rule2.lhs().args()[0][0]),
                              terminals.add_symbol(
                                  str(rules.object_index(rule2)) + '-' +
                                  str(rules.object_index(rule))),
                              make_weight(rule.weight() * rule2.weight()),
                              nonterminals.find(rule.lhs().nont()))
                    # From the state of the first rhs nonterminal to the
                    # state of the rule's lhs nonterminal.
                    myfst.add_arc(nonterminals.find(rule.rhs_nont(0)), arc)
        elif rule.rank() == 0:
            assert len(rule.lhs().arg(0)) == 1
            # Rank-0 rule: arc from INITIAL to the lhs nonterminal's state
            # that reads the rule's symbol and emits the rule's index.
            arc = Arc(terminals.add_symbol(rule.lhs().args()[0][0]),
                      terminals.add_symbol(str(rules.object_index(rule))), make_weight(rule.weight()),
                      nonterminals.find(rule.lhs().nont()))
            myfst.add_arc(nonterminals.find(INITIAL), arc)
        else:
            assert rule.rank() == 1
            assert rule.lhs().nont() == grammar.start()
            assert len(rule.lhs().arg(0)) == 1
            # Rank-1 rule (only accepted for the start nonterminal): arc with
            # input label 0 from the rhs nonterminal's state to the start
            # nonterminal's state, emitting the rule's index.
            arc = Arc(0, terminals.add_symbol(str(rules.object_index(rule))), make_weight(rule.weight()),
                      nonterminals.find(grammar.start())
                      )
            myfst.add_arc(nonterminals.find(rule.rhs_nont(0)), arc)

    # Attach the alphabet once all symbols have been added.
    myfst.set_input_symbols(terminals)
    myfst.set_output_symbols(terminals)

    myfst.optimize(True)

    return myfst, rules
Beispiel #17
0
 def assertIsFsa(self, fsa: pynini.Fst) -> None:
     """Raises AssertionError unless `fsa` has the ACCEPTOR property."""
     is_acceptor = fsa.properties(pynini.ACCEPTOR, True) == pynini.ACCEPTOR
     if not is_acceptor:
         raise AssertionError(f"Expected {fsa} to be an acceptor")
def _narcs(f: pynini.Fst) -> int:
    """Returns the total arc count of an FST, summed over all of its states."""
    total = 0
    for state in f.states():
        total += f.num_arcs(state)
    return total
Beispiel #19
0
def _olabels_iter(f: pynini.Fst) -> Iterator[List[int]]:
    """Yields the output labels of each path in `f`, one path at a time."""
    path_iter = f.paths()
    while True:
        if path_iter.done():
            return
        yield path_iter.olabels()
        path_iter.next()