Ejemplo n.º 1
0
def _replace_fsts(outer_fst: fst.Fst,
                  replacements: Dict[int, fst.Fst],
                  eps="<eps>") -> fst.Fst:
    """Splice replacement FSTs into ``outer_fst`` in place of labelled arcs.

    Every arc of ``outer_fst`` whose output label is a key of ``replacements``
    is replaced by a fresh copy of the corresponding FST. The copy is joined
    to the arc's source state by an epsilon arc into the copy's start state.
    Each of the copy's final states is joined to the arc's destination state
    by an epsilon arc. All other arcs are copied unchanged, relabelled into
    the new symbol tables.

    Args:
        outer_fst: FST whose arcs may reference replacements via ``olabel``.
        replacements: maps an output label (int key in ``outer_fst``'s output
            symbol table) to the FST to inline in place of arcs with that
            label.
        eps: epsilon symbol string used on the connecting arcs.

    Returns:
        A new FST with merged input/output symbol tables.

    Note:
        Every arc in the result is created with weight One. Final states are
        set via ``set_final`` without an explicit weight. Weights from
        ``outer_fst`` and from the replacement FSTs are therefore not carried
        over.
    """
    input_symbol_map: Dict[int, int] = {}
    output_symbol_map: Dict[int, int] = {}
    # Outer states are keyed by their int id. States of an inlined copy are
    # keyed by (replacement label, state id in the replacement FST).
    state_map: Dict[Union[int, Tuple[int, int]], int] = {}

    def _copy_symbols(source, target, symbol_map):
        """Add every symbol of ``source`` to ``target``, recording old->new keys."""
        for nth in range(source.num_symbols()):
            key = source.get_nth_key(nth)
            symbol_map[key] = target.add_symbol(source.find(key))

    # Create new FST
    new_fst = fst.Fst()
    new_input_symbols = fst.SymbolTable()
    new_output_symbols = fst.SymbolTable()

    weight_one = fst.Weight.One(new_fst.weight_type())
    outer_zero = fst.Weight.Zero(outer_fst.weight_type())

    # Copy symbols, then ensure the epsilon symbol exists in both tables.
    # OpenFst's add_symbol returns the existing key for a known symbol.
    _copy_symbols(outer_fst.input_symbols(), new_input_symbols,
                  input_symbol_map)
    _copy_symbols(outer_fst.output_symbols(), new_output_symbols,
                  output_symbol_map)

    in_eps = new_input_symbols.add_symbol(eps)
    out_eps = new_output_symbols.add_symbol(eps)

    # Copy states. A state is final when its final weight is not Zero.
    for outer_state in outer_fst.states():
        new_state = new_fst.add_state()
        state_map[outer_state] = new_state

        if outer_fst.final(outer_state) != outer_zero:
            new_fst.set_final(new_state)

    # Set start state
    new_fst.set_start(state_map[outer_fst.start()])

    # Copy arcs
    for outer_state in outer_fst.states():
        new_state = state_map[outer_state]
        for outer_arc in outer_fst.arcs(outer_state):
            next_state = state_map[outer_arc.nextstate]
            replace_fst = replacements.get(outer_arc.olabel)

            if replace_fst is not None:
                # Replace in-line. Every such outer arc gets its own copy. The
                # state_map entries for label r are overwritten by the newest
                # copy, which is the one the arcs below refer to.
                r = outer_arc.olabel
                replace_zero = fst.Weight.Zero(replace_fst.weight_type())
                replace_input_symbols = replace_fst.input_symbols()
                replace_output_symbols = replace_fst.output_symbols()

                # Copy states
                for replace_state in replace_fst.states():
                    state_map[(r, replace_state)] = new_fst.add_state()

                    # Final states of the copy exit to the outer arc's target
                    if replace_fst.final(replace_state) != replace_zero:
                        new_fst.add_arc(
                            state_map[(r, replace_state)],
                            fst.Arc(in_eps, out_eps, weight_one, next_state),
                        )

                # Copy arcs, re-keying labels through the symbol strings
                for replace_state in replace_fst.states():
                    for replace_arc in replace_fst.arcs(replace_state):
                        new_fst.add_arc(
                            state_map[(r, replace_state)],
                            fst.Arc(
                                new_input_symbols.add_symbol(
                                    replace_input_symbols.find(
                                        replace_arc.ilabel)),
                                new_output_symbols.add_symbol(
                                    replace_output_symbols.find(
                                        replace_arc.olabel)),
                                weight_one,
                                state_map[(r, replace_arc.nextstate)],
                            ),
                        )

                # Create arc into start state of the copy
                new_fst.add_arc(
                    new_state,
                    fst.Arc(in_eps, out_eps, weight_one,
                            state_map[(r, replace_fst.start())]),
                )
            else:
                # Copy arc as-is (labels remapped into the new tables)
                new_fst.add_arc(
                    new_state,
                    fst.Arc(
                        input_symbol_map[outer_arc.ilabel],
                        output_symbol_map[outer_arc.olabel],
                        weight_one,
                        next_state,
                    ),
                )

    # Fix symbol tables
    new_fst.set_input_symbols(new_input_symbols)
    new_fst.set_output_symbols(new_output_symbols)

    return new_fst
Ejemplo n.º 2
0
def fromTransitions(
        transition_weights, init_weights=None, final_weights=None,
        arc_type='standard', transition_ids=None):
    """ Instantiate a state machine from state transitions.

    Parameters
    ----------
    transition_weights : array-like
        Square matrix (presumably a numpy array; ``.shape[0]`` is read).
        ``transition_weights[i][j]`` is the weight of the transition from
        state ``i`` to state ``j``. Entries equal to the semiring zero
        produce no arc.
    init_weights : sequence of float, optional
        Weight of the arc from the added start state into each state.
        Defaults to the semiring one for every state.
    final_weights : sequence of float, optional
        Final weight of each state. Defaults to the semiring one for every
        state.
    arc_type : str, optional
        OpenFst arc type passed to ``openfst.VectorFst``.
    transition_ids : dict, optional
        Maps ``(s_cur, s_next)`` to an integer id for each transition, and
        ``(-1, s)`` for the initial transition into state ``s``. The arc
        label for a transition is ``id + 1``. Only transitions whose weight
        is not zero are looked up, so the mapping may be sparse. Defaults to
        a dense numbering of all pairs followed by all initial transitions.

    Returns
    -------
    fst : openfst.VectorFst
        One state per row of ``transition_weights``, plus an extra start
        state (created first) with arcs into every state whose initial
        weight is non-zero. Input and output symbol tables map each label to
        ``str(transition)``.

    Raises
    ------
    openfst.FstError
        If ``fst.verify()`` fails on the constructed machine.
    """

    num_states = transition_weights.shape[0]

    if transition_ids is None:
        transition_ids = {}
        for s_cur in range(num_states):
            for s_next in range(num_states):
                transition_ids[(s_cur, s_next)] = len(transition_ids)
        for s in range(num_states):
            transition_ids[(-1, s)] = len(transition_ids)

    def makeSymbolTable():
        # Transition ids are shifted up by one, presumably so they cannot
        # collide with the EPSILON key (assumed to be 0 -- confirm).
        table = openfst.SymbolTable()
        table.add_symbol(EPSILON_STRING, key=EPSILON)
        for transition, index in transition_ids.items():
            table.add_symbol(str(transition), key=index + 1)
        return table

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(makeSymbolTable())
    fst.set_output_symbols(makeSymbolTable())

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    if init_weights is None:
        init_weights = tuple(float(one) for __ in range(num_states))

    if final_weights is None:
        final_weights = tuple(float(one) for __ in range(num_states))

    fst.set_start(fst.add_state())

    def makeState(i):
        state = fst.add_state()

        # Initial weights become epsilon-input arcs out of the start state
        initial_weight = openfst.Weight(fst.weight_type(), init_weights[i])
        if initial_weight != zero:
            transition = transition_ids[-1, i] + 1
            arc = openfst.Arc(EPSILON, transition, initial_weight, state)
            fst.add_arc(fst.start(), arc)

        final_weight = openfst.Weight(fst.weight_type(), final_weights[i])
        if final_weight != zero:
            fst.set_final(state, final_weight)

        return state

    states = tuple(makeState(i) for i in range(num_states))
    for i_cur, row in enumerate(transition_weights):
        cur_state = states[i_cur]
        for i_next, tx_weight in enumerate(row):
            weight = openfst.Weight(fst.weight_type(), tx_weight)
            if weight != zero:
                # Look the id up only for arcs that are actually created, as
                # makeState does, so a sparse transition_ids is sufficient.
                transition = transition_ids[i_cur, i_next] + 1
                arc = openfst.Arc(transition, transition, weight,
                                  states[i_next])
                fst.add_arc(cur_state, arc)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
Ejemplo n.º 3
0
def slots_to_fsts(slots_dir: Path,
                  slot_names: Optional[Set[str]] = None,
                  eps: str = "<eps>") -> Dict[str, fst.Fst]:
    """Transform slot values into FSTs.

    Each regular file directly in ``slots_dir`` is one slot, named after the
    file. Every non-blank line is one value, compiled as the body of a JSGF
    rule. The values are alternatives. The slot FST has one placeholder arc
    per value from a start state to a single final state. Each placeholder is
    then expanded in-line with the FST compiled from its line.

    Args:
        slots_dir: directory containing one UTF-8 text file per slot.
        slot_names: if given, only slots whose file name is in this set are
            converted.
        eps: epsilon symbol string. It is used both here and when expanding
            the placeholders.

    Returns:
        Mapping from ``"$" + slot_name`` to the slot's FST.
    """
    slot_fsts: Dict[str, fst.Fst] = {}

    for slot_path in slots_dir.glob("*"):
        # Skip directories
        if not slot_path.is_file():
            continue

        slot_name = slot_path.name

        # Skip slots not in include list
        if (slot_names is not None) and (slot_name not in slot_names):
            continue

        slot_fst = fst.Fst()
        weight_one = fst.Weight.One(slot_fst.weight_type())
        slot_start = slot_fst.add_state()
        slot_fst.set_start(slot_start)

        slot_end = slot_fst.add_state()
        slot_fst.set_final(slot_end)

        input_symbols = fst.SymbolTable()
        in_eps = input_symbols.add_symbol(eps)

        # Reserve the epsilon symbol first in the output table as well
        output_symbols = fst.SymbolTable()
        output_symbols.add_symbol(eps)

        # Keyed by the integer output label of each placeholder arc, which is
        # what _replace_fsts looks up.
        replacements: Dict[int, fst.Fst] = {}

        with open(slot_path, "r", encoding="utf-8") as slot_file:
            # Process each line independently to avoid recursion limit
            for line in slot_file:
                line = line.strip()
                if not line:
                    continue

                replace_symbol = f"__replace__{len(replacements)}"
                out_replace = output_symbols.add_symbol(replace_symbol)

                # Wrap the value in a one-rule JSGF grammar
                line_grammar = (
                    "#JSGF v1.0;\n"
                    f"grammar {slot_name};\n"
                    f"public <{slot_name}> = ({line});\n"
                )
                line_fst = grammar_to_fsts(line_grammar).grammar_fst

                slot_fst.add_arc(
                    slot_start,
                    fst.Arc(in_eps, out_replace, weight_one, slot_end))

                replacements[out_replace] = line_fst

        # ---------------------------------------------------------------------

        # Fix symbol tables
        slot_fst.set_input_symbols(input_symbols)
        slot_fst.set_output_symbols(output_symbols)

        # Replace slot values. Forward eps so the connecting arcs use the same
        # epsilon symbol as the rest of the slot FST.
        slot_fsts["$" + slot_name] = _replace_fsts(slot_fst, replacements,
                                                   eps=eps)

    return slot_fsts
Ejemplo n.º 4
0
def gen_unigram_graph(net_vocab_file,
                      token_file,
                      out_file,
                      add_final_space=False,
                      allow_nonblank_selfloops=True,
                      use_contextual_blanks=None,
                      loop_using_symbol_repetitions=None):
    """Build a CTC-style token graph and write it to ``out_file``.

    There is one state for the blank ``'<pad>'`` (the start state) and one
    state per other network symbol. All states are final. Entering a
    symbol's state from a different state consumes that symbol and emits it,
    mapped through the output symbol table read from ``token_file``. Blank
    transitions emit epsilon. Self-loops consume a repeated symbol without
    emitting. Non-blank self-loops are only added when
    ``allow_nonblank_selfloops`` is true.

    Args:
        net_vocab_file: passed to ``read_net_vocab`` to obtain the network's
            symbols, which become input labels 1..N.
        token_file: text symbol table used as the output symbols. It must
            contain ``'<eps>'``, and ``'<spc>'`` when ``add_final_space`` is
            set.
        out_file: path the arc-sorted FST is written to.
        add_final_space: if true, every state gets an epsilon-input arc that
            emits ``'<spc>'`` into an extra final state.
        allow_nonblank_selfloops: add self-loops on non-blank states.
        use_contextual_blanks: unused; accepted for interface compatibility.
        loop_using_symbol_repetitions: unused; accepted for interface
            compatibility.
    """
    del use_contextual_blanks  # unused, does not apply to this model
    del loop_using_symbol_repetitions  # unused, does not apply to this model
    net_vocab = read_net_vocab(net_vocab_file)
    print("net vocab", net_vocab)

    CTC = fst.Fst(arc_type='standard')
    CTC_os = fst.SymbolTable.read_text(token_file)
    CTC_is = fst.SymbolTable()
    CTC_is.add_symbol('<eps>', 0)
    for i, s in enumerate(net_vocab):
        CTC_is.add_symbol(s, i + 1)

    CTC.set_input_symbols(CTC_is)
    CTC.set_output_symbols(CTC_os)

    # The blank ('<pad>') state doubles as the start state
    after_blank = CTC.add_state()
    CTC.set_start(after_blank)
    CTC.set_final(after_blank)

    l2s = {'<pad>': after_blank}

    for nth in range(CTC_is.num_symbols()):
        key = CTC_is.get_nth_key(nth)
        let = CTC_is.find(key)
        # Blank already has its state; epsilon gets none.
        if let in ('<pad>', '<eps>'):
            continue
        l2s[let] = CTC.add_state()
        CTC.set_final(l2s[let])

    weight_one = fst.Weight.One('tropical')

    final_space_arc = None
    if add_final_space:
        final_space = CTC.add_state()
        CTC.set_final(final_space)
        final_space_arc = fst.Arc(CTC_is.find('<eps>'), CTC_os.find('<spc>'),
                                  weight_one, final_space)

    os_eps = CTC_os.find('<eps>')

    for let, s in l2s.items():
        in_label = CTC_is.find(let)
        out_label = os_eps if let == '<pad>' else CTC_os.find(let)

        # Self-loop, don't emit
        if let == '<pad>' or allow_nonblank_selfloops:
            CTC.add_arc(s, fst.Arc(in_label, os_eps, weight_one, s))

        # Transition from another state - this emits
        for l2, s2 in l2s.items():
            if let == l2:
                continue
            CTC.add_arc(s2, fst.Arc(in_label, out_label, weight_one, s))

        # Optional transition to emit the final space
        if final_space_arc is not None:
            CTC.add_arc(s, final_space_arc)

    CTC.arcsort('olabel')
    CTC.write(out_file)