Example #1
def spellout_machine(wrdfname, ltr2wrdfst):

    lm = fst.Fst.read(ltr2wrdfst)
    s_in = lm.output_symbols()
    s_out = lm.input_symbols()

    letter = fst.Fst()
    letter.set_input_symbols(s_in)
    letter.set_output_symbols(s_out)
    letter.add_state()

    for word in open(wrdfname, "r").readlines():
        word = word.strip()
        orig = copy.copy(word)
        #        word = list(word)
        word += "#"
        #word = dig2word(word)
        nletter = fst.Fst()
        nletter.set_input_symbols(s_in)
        nletter.set_output_symbols(s_out)
        nletter.add_state()
        for i, ltr in enumerate(word):
            nletter.add_state()
            code2 = s_out.find(ltr)
            if i == 0:
                nletter.set_start(0)
                code1 = s_in.find(orig)
                nletter.add_arc(i, fst.Arc(code1, code2, None, i + 1))
            else:
                code1 = s_in.find("<epsilon>")
                nletter.add_arc(i, fst.Arc(code1, code2, None, i + 1))
        nletter.set_final(i + 1)
        letter.union(nletter)
    letter.rmepsilon()
    letter.write("spellout.fst")
Example #2
    def make_state(i_input, i_output, weight):
        io_state = fst.add_state()

        state_istr = input_vocab[i_input]
        state_ostr = output_vocab[i_output]

        # CASE 1: (in, I) : (out, I), weight one, transition into io state
        arc_istr = input_parts_to_str[state_istr, dur_internal_str]
        if pass_input:
            arc_ostr = output_parts_to_str[state_istr, state_ostr,
                                           dur_internal_str]
        else:
            arc_ostr = output_parts_to_str[state_ostr, dur_internal_str]
        arc = openfst.Arc(fst.input_symbols().find(arc_istr),
                          fst.output_symbols().find(arc_ostr), one, io_state)
        fst.add_arc(state, arc)
        fst.add_arc(io_state, arc.copy())

        # CASE 2: (in, F) : (out, F), weight tx_weight
        arc_istr = input_parts_to_str[state_istr, dur_final_str]
        if pass_input:
            arc_ostr = output_parts_to_str[state_istr, state_ostr,
                                           dur_final_str]
        else:
            arc_ostr = output_parts_to_str[state_ostr, dur_final_str]
        arc = openfst.Arc(fst.input_symbols().find(arc_istr),
                          fst.output_symbols().find(arc_ostr), weight, state)
        fst.add_arc(io_state, arc)
Example #3
File: oclm_c.py  Project: shiranD/oclm
    def __weighted_union(self, left, right, left_prob, right_prob):
        '''
        Union the FSTs left, right with a weight.
        '''
        # left hand side part.
        left_w = -math.log(left_prob)
        lhs = fst.Fst()
        lhs.set_input_symbols(left.input_symbols())
        lhs.set_output_symbols(left.output_symbols())
        lhs.add_state()
        lhs.set_start(0)
        lhs.add_state()
        lhs.add_arc(0, fst.Arc(0, 0, left_w, 1))
        lhs.set_final(1)
        lhs.concat(left)

        # prefix part.
        right_w = -math.log(right_prob)
        rhs = fst.Fst()
        rhs.set_input_symbols(right.input_symbols())
        rhs.set_output_symbols(right.output_symbols())
        rhs.add_state()
        rhs.set_start(0)
        rhs.add_state()
        rhs.add_arc(0, fst.Arc(0, 0, right_w, 1))
        rhs.set_final(1)
        rhs.concat(right)

        lhs.union(rhs)
        return lhs
Example #4
def add_endpoints(fst, bos_str='<BOS>', eos_str='<EOS>'):
    one = openfst.Weight.one(fst.weight_type())
    zero = openfst.Weight.zero(fst.weight_type())

    # add pre-initial state accepting BOS
    i_bos_in = fst.input_symbols().find(bos_str)
    i_bos_out = fst.output_symbols().find(bos_str)
    old_start = fst.start()
    new_start = fst.add_state()
    fst.set_start(new_start)
    init_arc = openfst.Arc(i_bos_in, i_bos_out, one, old_start)
    fst.add_arc(new_start, init_arc)

    # add superfinal state accepting EOS
    i_eos_in = fst.input_symbols().find(eos_str)
    i_eos_out = fst.output_symbols().find(eos_str)
    new_final = fst.add_state()
    for state in fst.states():
        w_final = fst.final(state)
        if w_final != zero:
            fst.set_final(state, zero)
            final_arc = openfst.Arc(i_eos_in, i_eos_out, w_final, new_final)
            fst.add_arc(state, final_arc)
    fst.set_final(new_final, one)

    return fst
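
A minimal usage sketch for add_endpoints. Everything below (the pywrapfst import alias, the toy two-word acceptor, the extra symbols) is an illustrative assumption rather than code from the original project, and it presumes the OpenFst 1.8-style bindings implied by the lowercase Weight.one above.

import pywrapfst as openfst  # assumed alias matching the snippet above

syms = openfst.SymbolTable()
for s in ("<eps>", "<BOS>", "<EOS>", "hello", "world"):
    syms.add_symbol(s)

toy = openfst.VectorFst()
toy.set_input_symbols(syms)
toy.set_output_symbols(syms)
one = openfst.Weight.one(toy.weight_type())
s0, s1, s2 = toy.add_state(), toy.add_state(), toy.add_state()
toy.set_start(s0)
toy.add_arc(s0, openfst.Arc(syms.find("hello"), syms.find("hello"), one, s1))
toy.add_arc(s1, openfst.Arc(syms.find("world"), syms.find("world"), one, s2))
toy.set_final(s2)

toy = add_endpoints(toy)
print(toy.num_states())  # 5: the original 3 states plus the BOS/EOS endpoint states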
Example #5
    def enterAlternative(self, ctx):
        anchor_state = self.alt_states[self.group_depth]

        if self.group_depth not in self.alt_ends:
            # Patch start of alternative
            next_state = self.fst.add_state()
            for arc in self.fst.arcs(anchor_state):
                self.fst.add_arc(next_state, arc)

            self.fst.delete_arcs(anchor_state)
            self.fst.add_arc(
                anchor_state,
                fst.Arc(self.in_eps, self.out_eps, self.weight_one,
                        next_state),
            )

            # Create shared end state for alternatives
            self.alt_ends[self.group_depth] = self.fst.add_state()

        # Close previous alternative
        last_state = self.last_states[self.rule_name]
        end_state = self.alt_ends[self.group_depth]
        if last_state != end_state:
            self.fst.add_arc(
                last_state,
                fst.Arc(self.in_eps, self.out_eps, self.weight_one, end_state),
            )

        # Add new intermediary state
        next_state = self.fst.add_state()
        self.fst.add_arc(
            anchor_state,
            fst.Arc(self.in_eps, self.out_eps, self.weight_one, next_state),
        )
        self.last_states[self.rule_name] = next_state
Example #6
def _compile_cg(ifar_path: str, ofar_path: str, insertions: bool,
                deletions: bool) -> str:
  """Compiles the covering grammar from the input and output FARs.

  Args:
    ifar_path: path to the input FAR.
    ofar_path: path to the output FAR.
    insertions: should insertions be permitted?
    deletions: should deletions be permitted?

  Returns:
    The path to the CG FST.
  """
  ilabels = _get_far_labels(ifar_path)
  olabels = _get_far_labels(ofar_path)
  cg = pywrapfst.VectorFst()
  state = cg.add_state()
  cg.set_start(state)
  one = pywrapfst.Weight.one(cg.weight_type())
  for ilabel, olabel in itertools.product(ilabels, olabels):
    cg.add_arc(state, pywrapfst.Arc(ilabel, olabel, one, state))
  # Handles epsilons, carefully avoiding adding a useless 0:0 label.
  if insertions:
    for olabel in olabels:
      cg.add_arc(state, pywrapfst.Arc(0, olabel, one, state))
  if deletions:
    for ilabel in ilabels:
      cg.add_arc(state, pywrapfst.Arc(ilabel, 0, one, state))
  cg.set_final(state)
  assert cg.verify(), "Label acceptor is ill-formed"
  cg_path = _mktemp("cg.fst")
  cg.write(cg_path)
  return cg_path
Example #7
 def _lexicon_covering(self) -> None:
     """Builds covering grammar and lexicon FARs."""
     # Sets of labels for the covering grammar.
     with open(os.path.join(self.working_log_directory,
                            "covering_grammar.log"),
               "w",
               encoding="utf8") as log_file:
         com = [
             thirdparty_binary("farcompilestrings"),
             "--fst_type=compact",
         ]
         if self.input_token_type != "utf8":
             com.append("--token_type=symbol")
             com.append(f"--symbols={self.input_token_type}", )
             com.append("--unknown_symbol=<unk>")
         else:
             com.append("--token_type=utf8")
         com.extend([self.input_path, self.input_far_path])
         print(" ".join(com), file=log_file)
         subprocess.check_call(com,
                               env=os.environ,
                               stderr=log_file,
                               stdout=log_file)
         com = [
             thirdparty_binary("farcompilestrings"),
             "--fst_type=compact",
             "--token_type=symbol",
             f"--symbols={self.phone_symbol_table_path}",
             self.output_path,
             self.output_far_path,
         ]
         print(" ".join(com), file=log_file)
         subprocess.check_call(com,
                               env=os.environ,
                               stderr=log_file,
                               stdout=log_file)
         ilabels = _get_far_labels(self.input_far_path)
         print(ilabels, file=log_file)
         olabels = _get_far_labels(self.output_far_path)
         print(olabels, file=log_file)
         cg = pywrapfst.VectorFst()
         state = cg.add_state()
         cg.set_start(state)
         one = pywrapfst.Weight.one(cg.weight_type())
         for ilabel, olabel in itertools.product(ilabels, olabels):
             cg.add_arc(state, pywrapfst.Arc(ilabel, olabel, one, state))
         # Handles epsilons, carefully avoiding adding a useless 0:0 label.
         if self.insertions:
             for olabel in olabels:
                 cg.add_arc(state, pywrapfst.Arc(0, olabel, one, state))
         if self.deletions:
             for ilabel in ilabels:
                 cg.add_arc(state, pywrapfst.Arc(ilabel, 0, one, state))
         cg.set_final(state)
         assert cg.verify(), "Label acceptor is ill-formed"
         cg.write(self.cg_path)
Example #8
def replace_and_patch(
    outer_fst: fst.Fst,
    outer_start_state: int,
    outer_final_state: int,
    inner_fst: fst.Fst,
    label_sym: int,
    eps: int = 0,
) -> None:
    """Copies an inner FST into an outer FST, creating states and mapping symbols.
    Creates arcs from outer start/final states to inner start/final states."""

    in_symbols = outer_fst.input_symbols()
    out_symbols = outer_fst.output_symbols()
    inner_zero = fst.Weight.Zero(inner_fst.weight_type())
    outer_one = fst.Weight.One(outer_fst.weight_type())

    state_map = {}
    in_symbol_map = {}
    out_symbol_map = {}

    for i in range(inner_fst.output_symbols().num_symbols()):
        sym_str = inner_fst.output_symbols().find(i).decode()
        out_symbol_map[i] = out_symbols.find(sym_str)

    for i in range(inner_fst.input_symbols().num_symbols()):
        sym_str = inner_fst.input_symbols().find(i).decode()
        in_symbol_map[i] = in_symbols.find(sym_str)

    # Create states in outer FST
    for inner_state in inner_fst.states():
        state_map[inner_state] = outer_fst.add_state()

    # Create arcs in outer FST
    for inner_state in inner_fst.states():
        if inner_state == inner_fst.start():
            outer_fst.add_arc(
                outer_start_state,
                fst.Arc(eps, label_sym, outer_one, state_map[inner_state]),
            )

        for inner_arc in inner_fst.arcs(inner_state):
            outer_fst.add_arc(
                state_map[inner_state],
                fst.Arc(
                    in_symbol_map[inner_arc.ilabel],
                    out_symbol_map[inner_arc.olabel],
                    outer_one,
                    state_map[inner_arc.nextstate],
                ),
            )

            if inner_fst.final(inner_arc.nextstate) != inner_zero:
                outer_fst.add_arc(
                    state_map[inner_arc.nextstate],
                    fst.Arc(eps, eps, outer_one, outer_final_state),
                )
Example #9
def make_kleeneplus(s, graphs, stoi):
    """one-or-more-graphs"""
    fst = wfst.Fst()
    start = fst.add_state()
    end = fst.add_state()
    fst.set_start(start)
    fst.set_final(end, wfst.Weight.One(fst.weight_type()))
    for g in graphs:
        fst.add_arc(start, wfst.Arc(stoi[g], stoi[g], wfst.Weight.One(fst.weight_type()), end))
        fst.add_arc(end, wfst.Arc(stoi[g], stoi[g], wfst.Weight.One(fst.weight_type()), end))
    return fst
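
A tiny usage sketch, assuming `import pywrapfst as wfst` (the pre-1.8 bindings implied by the wfst.Fst() call above) and a made-up two-letter alphabet with its symbol-to-id map.

import pywrapfst as wfst  # assumed alias for the snippet above

stoi = {"a": 1, "b": 2}            # hypothetical symbol-to-id map
plus = make_kleeneplus("[ab]+", ["a", "b"], stoi)
print(plus.num_states())           # 2: a start state and a final loop state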
Example #10
def make_lexicon_fst(input_words, words):
    compiler = fst.Compiler()
    lexicon_fst = fst.Fst()
    start = lexicon_fst.add_state()
    lexicon_fst.set_start(start)

    last = lexicon_fst.add_state()

    # this line projects space to epsilon
    lexicon_fst.add_arc(
        start, fst.Arc(32, 0, fst.Weight.One(lexicon_fst.weight_type()), last))
    lexicon_fst.set_final(last)
    for i, w in enumerate(words):
        w = w.strip()
        index = i + 1
        last = lexicon_fst.add_state()
        lexicon_fst.add_arc(
            start,
            fst.Arc(ord(w[0]), index,
                    fst.Weight.One(lexicon_fst.weight_type()), last))
        for c in w[1:]:
            this = lexicon_fst.add_state()
            lexicon_fst.add_arc(
                last,
                fst.Arc(ord(c), 0, fst.Weight.One(lexicon_fst.weight_type()),
                        this))
            last = this
        lexicon_fst.set_final(last, 0)
    lexicon_fst = fst.determinize(lexicon_fst).minimize().closure()

    with open('words.syms', 'w') as f:
        f.write('<eps> 0\n')
        for i, w in enumerate(words + input_words):  # we put word symbol here
            f.write('{} {}'.format(w, str(i + 1)))
            f.write('\n')
        f.write('<SPACE> {}\n'.format(str(32)))

    epsilon_fst = fst.Fst()
    start = epsilon_fst.add_state()
    end = epsilon_fst.add_state()
    for i, w in enumerate(words):
        index = i + 1
        epsilon_fst.add_arc(
            start,
            fst.Arc(0, index, fst.Weight.One(epsilon_fst.weight_type()), end))

    epsilon_fst.add_arc(
        start, fst.Arc(0, 32, fst.Weight.One(epsilon_fst.weight_type()), end))
    epsilon_fst.set_final(end, 0)
    epsilon_fst.set_start(start)
    epsilon_fst = epsilon_fst.closure()

    return lexicon_fst, epsilon_fst
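
A usage sketch (assumes `import pywrapfst as fst` with the pre-1.8 API used above; note that the call also writes a small words.syms file to the working directory).

lexicon, epsilons = make_lexicon_fst(["foo"], ["hello", "world"])
print(lexicon.num_states(), epsilons.num_states())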
Example #11
def search(root, lat2_paths):
    # root is a node in lat which defines the current history
    # lat2_paths are weights which could be added to the current
    # path in lat1 if the key does not get discarded in the 
    # future
    global lat, lat2, visited
    eps_paths = {node : [] for node in lat2_paths}
    open_paths = dict(eps_paths)
    # Add paths to eps-reachable nodes: follow epsilon output arcs in lat2,
    # accumulating the weights along the way. (The original snippet is
    # unfinished from here on; this loop is a best-effort completion of that
    # stated intent.)
    while open_paths:
        next_open = {}
        for node, path in open_paths.items():
            for arc in lat2.arcs(node):
                if arc.olabel != 0:
                    continue
                if arc.nextstate not in eps_paths:
                    eps_paths[arc.nextstate] = path + [arc.weight]
                    next_open[arc.nextstate] = eps_paths[arc.nextstate]
        open_paths = next_open

idx = 0
while True:
    idx += 1
    input_path1 = get_path(args.input1, idx)
    if not input_path1 or not os.path.isfile(input_path1):
        break
    input_path2 = get_path(args.input2, idx)
    lat = fst.Fst.read(input_path1)
    lat.rmepsilon()
    lat = fst.determinize(lat)
    lat.minimize()
    lat2 = fst.Fst.read(input_path2)
    visited = {}
    search(lat.start(), [lat2.start()])
    lat.write(get_path(args.output, idx))
Example #12
def make_compounder(syms, word_ids):
    c = fst.Fst()
    start_state = c.add_state()
    assert (start_state == 0)
    c.set_start(start_state)
    space_id = syms["<space>"]
    c.add_arc(0, fst.Arc(space_id, syms["<eps>"], 1, 0))
    c.add_arc(0, fst.Arc(space_id, syms["+C+"], 1, 0))
    c.add_arc(0, fst.Arc(space_id, syms["+D+"], 1, 0))
    for word_id in word_ids:
        c.add_arc(0, fst.Arc(word_id, word_id, 1, 0))
    c.set_final(0, 1)
    return c
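
A usage sketch for make_compounder; the symbol ids below are invented for illustration, and the plain integer arc weights rely on the same pywrapfst binding used above, which accepts numbers as weights.

syms = {"<eps>": 0, "<space>": 1, "+C+": 2, "+D+": 3, "foo": 4, "bar": 5}
compounder = make_compounder(syms, word_ids=[4, 5])
print(compounder.num_arcs(0))  # 5: three <space> rewrites plus one loop per word id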
Example #13
File: oclm_c.py  Project: shiranD/oclm
    def update(self, ch_dist):
        '''
        Update the history with the new likelihood array, given in the
        correct scale (negative log space).
        '''
        new_ch = fst.Fst()
        new_ch.set_input_symbols(self.ch_syms)
        new_ch.set_output_symbols(self.ch_syms)
        new_ch.add_state()
        new_ch.set_start(0)
        new_ch.add_state()
        new_ch.set_final(1)
        space_code = -1
        space_pr = 0.
        for ch, pr in ch_dist:
            code = self.ch_syms.find(ch)
            if ch == '#':  # Adds space after we finish updating trailing chars.
                space_code = code
                space_pr = pr
                continue
            new_ch.add_arc(0, fst.Arc(code, code, pr, 1))
        new_ch.arcsort(sort_type="olabel")

        # Adds the trailing characters to existing binned history.
        for words_bin in self.prefix_words:
            if words_bin[2] >= 10:  # We discard the whole trail in this case (TODO)
                continue
            # Unless we are testing a straight line machine, this normally
            # doesn't happen in practice.
            if new_ch.num_arcs(0) == 0:
                continue
            words_bin[1].concat(new_ch).rmepsilon()
            words_bin[2] += 1

        # Continues updating the history and adds back the space if necessary.
        if space_code >= 0:
            new_ch.add_arc(0, fst.Arc(space_code, space_code, space_pr, 1))
        self.history_fst.concat(new_ch).rmepsilon()

        # Respectively update the binned history
        if space_code >= 0:  # If there is a space
            # Finishes the prefix words in current position
            word_lattice = fst.compose(self.history_fst, self.ltr2wrd)
            word_lattice.project(project_output=True).rmepsilon()
            word_lattice = fst.determinize(word_lattice)
            word_lattice.minimize()
            if word_lattice.num_states() == 0:
                word_lattice = self.create_empty_fst(self.wd_syms, self.wd_syms)
            trailing_chars = self.create_empty_fst(self.ch_syms, self.ch_syms)
            self.prefix_words.append([word_lattice, trailing_chars, 0])
Example #14
def make_intent_fst(grammar_fsts: Dict[str, fst.Fst],
                    eps: str = "<eps>") -> fst.Fst:
    """Merges grammar FSTs created with grammar_to_fsts into a single acceptor FST."""
    input_symbols = fst.SymbolTable()
    output_symbols = fst.SymbolTable()

    in_eps: int = input_symbols.add_symbol(eps)
    out_eps: int = output_symbols.add_symbol(eps)

    intent_fst = fst.Fst()
    weight_one = fst.Weight.One(intent_fst.weight_type())

    # Create start/final states
    start_state = intent_fst.add_state()
    intent_fst.set_start(start_state)

    final_state = intent_fst.add_state()
    intent_fst.set_final(final_state)

    replacements: Dict[int, fst.Fst] = {}

    for intent, grammar_fst in grammar_fsts.items():
        intent_label = f"__label__{intent}"
        out_label = output_symbols.add_symbol(intent_label)

        # --[__label__INTENT]-->
        intent_start = intent_fst.add_state()
        intent_fst.add_arc(
            start_state, fst.Arc(in_eps, out_label, weight_one, intent_start))

        # --[__replace__INTENT]-->
        intent_end = intent_fst.add_state()
        replace_symbol = f"__replace__{intent}"
        out_replace = output_symbols.add_symbol(replace_symbol)
        intent_fst.add_arc(
            intent_start, fst.Arc(in_eps, out_replace, weight_one, intent_end))

        # --[eps]-->
        intent_fst.add_arc(intent_end,
                           fst.Arc(in_eps, out_eps, weight_one, final_state))

        replacements[out_replace] = grammar_fst

    # Fix symbol tables
    intent_fst.set_input_symbols(input_symbols)
    intent_fst.set_output_symbols(output_symbols)

    # Do replacements

    return _replace_fsts(intent_fst, replacements, eps=eps)
Example #15
def dfs(root, hist):
    global hist_len, visited, hist2node, lat
    if root in visited:
        return
    visited[root] = True
    for arc in lat.arcs(root):
        dfs(arc.nextstate, hist + [str(arc.ilabel)])
    key = ' '.join(hist[-hist_len:])
    if key in hist2node:  # connect with it
        arc1 = fst.Arc(0, 0, fst.Weight.One(lat.weight_type()), hist2node[key])
        arc2 = fst.Arc(0, 0, fst.Weight.One(lat.weight_type()), root)
        lat.add_arc(root, arc1)
        lat.add_arc(hist2node[key], arc2)
    else:
        hist2node[key] = root
Example #16
def build_ctc_mono_decoding_fst(S, arc_type='log', add_syms=False):
    """
    Build a monophone CTC decoding fst.
    Args:
        S - number of monophones
        arc_type - log or standard. Gives the interpretation of the FST.
    Returns:
        an FST that accepts all sequences over [1,..,S]^* and returns
        shorter ones with duplicates and blanks removed.

        The input labels are shifted by one, so that there are no epsilon
        transitions.
        The output labels are not (blank is zero), allowing one to read out
        the label sequence easily.
    """
    CTC = fst.Fst(arc_type=arc_type)
    weight_one = fst.Weight.One(CTC.weight_type())

    for s in range(S):
        s1 = CTC.add_state()
        assert s == s1
        CTC.set_final(s1)
    CTC.set_start(0)

    for s in range(S):
        # transitions out of symbol s
        # self-loop, don't emit
        CTC.add_arc(s, fst.Arc(s + 1, 0, weight_one, s))
        for s_next in range(S):
            if s_next == s:
                continue
            # transition to next symbol
            CTC.add_arc(s, fst.Arc(s_next + 1, s_next, weight_one, s_next))
    CTC.arcsort('olabel')

    if add_syms:
        in_syms = fst.SymbolTable()
        in_syms.add_symbol('<eps>', 0)
        in_syms.add_symbol('B', 1)
        for s in range(1, S):
            in_syms.add_symbol(chr(ord('a') + s - 1), s + 1)
        out_syms = fst.SymbolTable()
        out_syms.add_symbol('<eps>', 0)
        for s in range(1, S):
            out_syms.add_symbol(chr(ord('a') + s - 1), s)
        CTC.set_input_symbols(in_syms)
        CTC.set_output_symbols(out_syms)
    return CTC
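
A short usage sketch, assuming `import pywrapfst as fst` with the pre-1.8 fst.Fst() API used above.

ctc = build_ctc_mono_decoding_fst(3, arc_type='standard', add_syms=True)
print(ctc.num_states())  # 3: the blank state plus one state per phone
# Composing an olabel-sorted frame-label acceptor with `ctc` (e.g. via
# fst.compose(frames_fsa, ctc)) collapses repeats and blanks on the output side.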
Example #17
    def addArc(self, ilabels, olabels=None, start_state=None, is_loop=False):
        '''
        Add a chain of arcs labeled by ilabels/olabels starting from
        start_state, padding the shorter label sequence with epsilons;
        returns the last state of the chain (start_state again if is_loop).
        '''
        isyms = self._label_to_sym(ilabels, self.fst.mutable_input_symbols())
        if olabels is None:
            #             osyms = isyms[:] # create copy
            osyms = self._label_to_sym(ilabels,
                                       self.fst.mutable_output_symbols())
        else:
            osyms = self._label_to_sym(olabels,
                                       self.fst.mutable_output_symbols())

        maxix = max(len(isyms), len(osyms))
        if len(isyms) != len(osyms):
            isyms += [0] * (maxix - len(isyms))
            osyms += [0] * (maxix - len(osyms))

        if start_state is None:
            start_state = self.STATE_START

        q0 = start_state
        for i in range(maxix):
            if is_loop and i == maxix - 1:
                q1 = start_state
            else:
                q1 = self.fst.add_state()
            self.fst.add_arc(
                q0,
                pywrapfst.Arc(isyms[i], osyms[i],
                              pywrapfst.Weight.One(self.fst.weight_type()),
                              q1))
            q0 = q1
        return q1
Example #18
 def __call__(self, x):
     x, xs = transform_output(x)
     # Normalize log-posterior matrices, if necessary
     if self._normalize:
         x = log_softmax(x, dim=2)
     x = x.permute(1, 0, 2).cpu()
     self._output = []
     D = x.size(2)
     for logpost, length in zip(x, xs):
         f = fst.Fst()
         f.set_start(f.add_state())
         for t in range(length):
             f.add_state()
             for j in range(D):
                 weight = fst.Weight(f.weight_type(), float(-logpost[t, j]))
                 f.add_arc(
                     t,
                     fst.Arc(
                         j + 1,  # input label
                         j + 1,  # output label
                         weight,  # -logpost[t, j]
                         t + 1,  # nextstate
                     ),
                 )
         f.set_final(length, fst.Weight.One(f.weight_type()))
         f.verify()
         self._output.append(f)
     return self._output
Example #19
File: oclm_c.py  Project: shiranD/oclm
 def normalize(self, anfst):
     '''
     produce a normalized fst
     '''
     # possibly there's a shorter way
     # that keeps all in fst land
     dist = []
     labels = []
     syms = anfst.input_symbols()
     state = anfst.start()
     for arc in anfst.arcs(state):
         label = syms.find(arc.ilabel)
         pr = float(arc.weight)
         dist.append(BitWeight(pr))  # BitWeight gets -log(pr) only
         labels.append(label)
     sum_value = sum(dist, BitWeight(1e6)) # will sum in log domain (log-add)
     norm_dist = [(prob/sum_value).loge() for prob in dist]
     del anfst
     # construct a norm fst
     output = fst.Fst()
     output.set_input_symbols(syms)
     output.set_output_symbols(syms)
     output.add_state()
     output.add_state()
     for (pr, label) in zip(norm_dist,labels):
         code = syms.find(label)
         output.add_arc(0, fst.Arc(code, code, pr, 1))
     output.set_start(0)
     output.set_final(1)
     return output
Example #20
def add_arc(fst_in, from_word, to_word, weight):
    """
	Adds an arc to a given FST
	Note: Despite returning an updated FST, this  method makes the changes
	**IN PLACE**, so you may want to make a copy of the original
	FST before updating the weights
	:param fst_in: <openfst.Fst> to modify
	:param from_word: <str>
	:param to_word: <str>
	:param weight: <float>
	:return: updated <openfst.Fst>
	"""
    # make a dict and node_2_word from index_fst()
    fst_dict, node_2_word = index_fst(fst_in)

    # get a lookup table
    lookup = fst_in.input_symbols()

    # set from state as idx
    from_state = fst_dict[from_word]["state_id"]

    # set to state as idx
    to_state = fst_dict[to_word]["state_id"]

    fst_in = fst_in.add_arc(
        from_state,
        openfst.Arc(lookup_word(to_word, lookup), lookup_word(to_word, lookup),
                    openfst.Weight("tropical", weight), to_state))

    return fst_in
Example #21
def single_state_transducer(transition_weights,
                            row_vocab,
                            col_vocab,
                            input_symbols=None,
                            output_symbols=None,
                            arc_type='standard'):

    fst = openfst.VectorFst(arc_type=arc_type)
    fst.set_input_symbols(input_symbols)
    fst.set_output_symbols(output_symbols)

    zero = openfst.Weight.zero(fst.weight_type())
    one = openfst.Weight.one(fst.weight_type())

    state = fst.add_state()
    fst.set_start(state)
    fst.set_final(state, one)

    for i_input, row in enumerate(transition_weights):
        for i_output, tx_weight in enumerate(row):
            weight = openfst.Weight(fst.weight_type(), tx_weight)
            input_id = fst.input_symbols().find(row_vocab[i_input])
            output_id = fst.output_symbols().find(col_vocab[i_output])
            if weight != zero:
                arc = openfst.Arc(input_id, output_id, weight, state)
                fst.add_arc(state, arc)

    if not fst.verify():
        raise openfst.FstError("fst.verify() returned False")

    return fst
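
A usage sketch with a 2x2 cost matrix; the vocabularies, symbol table, and weights are illustrative assumptions, presuming the OpenFst 1.8-style bindings implied by the VectorFst call above.

import pywrapfst as openfst  # assumed alias for the snippet above

row_vocab = ["x0", "x1"]
col_vocab = ["y0", "y1"]
symbols = openfst.SymbolTable()
symbols.add_symbol("<eps>")
for sym in row_vocab + col_vocab:
    symbols.add_symbol(sym)

weights = [[0.0, 1.5], [2.0, 0.5]]  # tropical costs; none equal to Zero
fst_out = single_state_transducer(
    weights, row_vocab, col_vocab,
    input_symbols=symbols, output_symbols=symbols)
print(fst_out.num_arcs(fst_out.start()))  # 4: one self-loop per (row, col) pair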
Example #22
def build_chain_fst(labels, arc_type='log', vocab=None):
    """
    Build an acceptor for string given by elements of labels.

    Args:
        labels - a sequence of labels in the range 1..S
        arc_type - fst arc type (standard or log)
    Returns:
        FST consuming symbols in the range 1..S.

    Notes:
        Elements of labels are assumed to be greater than zero
        (which maps to blank)!
    """
    C = fst.Fst(arc_type=arc_type)
    weight_one = fst.Weight.One(C.weight_type())
    s = C.add_state()
    C.set_start(s)
    for l in labels:
        s_next = C.add_state()
        C.add_arc(s, fst.Arc(l, l, weight_one, s_next))
        s = s_next
    C.set_final(s)
    C.arcsort('ilabel')
    return C
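
A usage sketch that pairs this chain acceptor with the monophone CTC decoder from Example #16 (assumes `import pywrapfst as fst`, the alias used throughout these snippets).

C = build_chain_fst([1, 2, 1], arc_type='log')
CTC = build_ctc_mono_decoding_fst(3, arc_type='log')
graph = fst.compose(CTC, C)  # CTC's collapsed outputs must spell out 1, 2, 1
print(graph.num_states())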
Example #23
def gen_trigram_graph(ngram_to_class_file,
                      net_vocab_file,
                      token_file,
                      out_file,
                      add_final_space=False,
                      use_contextual_blanks=False,
                      prevent_epsilons=False,
                      determinize=True):

    net_vocab = read_net_vocab(net_vocab_file)
    print("net vocab", net_vocab)
    N = len(net_vocab)

    with open(ngram_to_class_file, 'r') as f:
        trigrams = [tuple([int(n) for n in line.split()]) for line in f]

    CTC = build_ctc_trigram_decoding_fst_v2(
        N,
        trigrams,
        arc_type='standard',
        use_context_blanks=use_contextual_blanks,
        prevent_epsilons=prevent_epsilons,
        determinize=determinize,
        add_syms=False)

    assert CTC.weight_type() == 'tropical'

    # Emitted symbols need to be remapped from net_vocab to token symbols
    #   net_vocab[:5] : ['<pad>', '<unk>', '<spc>', 'E', 'T']
    #   tokens[:5]    : ['<eps> 0', '<spc> 1', '<pad> 2', '<unk> 3', 'E 4']
    # <pad> is unused and gets mapped to eps, <unk> and <spc> change ids,
    # the rest is roughly shifted by 1.
    tokens = {t.split()[0]: int(t.split()[1]) for t in open(token_file, 'r')}
    net_vocab_dict = {t: i for i, t in enumerate(net_vocab)}
    osym_map = []
    for t, i in net_vocab_dict.items():
        osym_map.append((i, 0 if t == '<pad>' else tokens[t]))
    CTC.relabel_pairs(ipairs=None, opairs=osym_map)
    print(osym_map)

    CTC_os = fst.SymbolTable.read_text(token_file)
    CTC.set_output_symbols(CTC_os)
    os_eps = CTC_os.find('<eps>')
    assert os_eps == 0

    weight_one = fst.Weight.One('tropical')

    if add_final_space:
        is_final = lambda s: CTC.final(s) != fst.Weight(
            CTC.weight_type(), 'infinity')
        final_space = CTC.add_state()
        CTC.set_final(final_space)
        final_space_arc = fst.Arc(0, CTC_os.find('<spc>'), weight_one,
                                  final_space)
        for s in CTC.states():
            if is_final(s):
                CTC.add_arc(s, final_space_arc)

    CTC.arcsort('olabel')
    CTC.write(out_file)
Example #24
def addArcLinear(fst, start_state, labels, olabels=None, is_loop=False):
    if not isinstance(labels, list):
        raise ValueError("Label argument must be a list")

    isym_tbl = fst.mutable_input_symbols()
    osym_tbl = fst.mutable_output_symbols()

    if olabels is None:
        sym_len = len(labels)
        isyms = [isym_tbl.add_symbol(lab) for lab in labels]
        osyms = [osym_tbl.add_symbol(lab) for lab in labels]
    else:
        sym_len = max(len(labels), len(olabels))
        isyms = [isym_tbl.add_symbol(lab)
                 for lab in labels] + [0] * (sym_len - len(labels))
        osyms = [osym_tbl.add_symbol(lab)
                 for lab in olabels] + [0] * (sym_len - len(olabels))

    q0 = start_state
    for i in range(sym_len):
        if is_loop and i == sym_len - 1:
            q1 = start_state
        else:
            q1 = fst.add_state()
        fst.add_arc(
            q0,
            pywrapfst.Arc(isyms[i], osyms[i],
                          pywrapfst.Weight.One(fst.weight_type()), q1))
        q0 = q1
    return q1
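
A usage sketch: build a linear "hello world" path on a fresh FST. The setup (the pywrapfst import, the empty symbol tables seeded with <eps>, the pre-1.8 API implied by the Weight.One spelling above) is assumed, not taken from the original project.

import pywrapfst

f = pywrapfst.Fst()
isyms = pywrapfst.SymbolTable()
isyms.add_symbol("<eps>")
osyms = pywrapfst.SymbolTable()
osyms.add_symbol("<eps>")
f.set_input_symbols(isyms)
f.set_output_symbols(osyms)
start = f.add_state()
f.set_start(start)
last = addArcLinear(f, start, ["hello", "world"])
f.set_final(last)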
Example #25
def addArcFlower(fst, q0, q1, ilabels, olabels=None, weight=None):
    '''
    Adding
    q0 arc(s) origin
    q1 arc(s) target
    '''
    if not isinstance(ilabels, list):
        raise ValueError("input label argument must be a list")
    if olabels is None:
        olabels = ilabels
    if not isinstance(olabels, list):
        raise ValueError("Output label argument must be a list")

    isym_tbl = fst.mutable_input_symbols()
    osym_tbl = fst.mutable_output_symbols()

    for label in ilabels + olabels:
        isym = isym_tbl.add_symbol(label)
        osym = osym_tbl.add_symbol(label)

    if weight is None:
        weight = pywrapfst.Weight.One(fst.weight_type())

    for i in range(len(ilabels)):
        isym = isym_tbl.add_symbol(ilabels[i])
        osym = osym_tbl.add_symbol(olabels[i])
        fst.add_arc(q0, pywrapfst.Arc(isym, osym, weight, q1))
Example #26
def string_to_fsa(input_string, sym):
    '''build an FSA for a given input string using the symbol table, sym'''

    # first make sure all chars can be converted
    input_list = list(input_string)
    for i in input_list:
        if sym.find(i) == -1:
            raise ValueError('Input character not found')

    # build the FSA

    f = pywrapfst.VectorFst()
    one = pywrapfst.Weight.one(f.weight_type())
    f.set_input_symbols(sym)
    f.set_output_symbols(sym)
    s = f.add_state()
    f.set_start(s)
    for i in input_list:
        n = f.add_state()
        f.add_arc(s, pywrapfst.Arc(sym.find(i), sym.find(i), one, n))
        s = n
    f.set_final(n, one)

    # verify
    if not f.verify():
        raise ValueError('FSA failed to verify')
    return (f)
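
A usage sketch; the symbol table below is a stand-in built for illustration, presuming the OpenFst 1.8-style bindings implied by the VectorFst call above.

import pywrapfst

sym = pywrapfst.SymbolTable()
sym.add_symbol("<eps>")
for ch in "abc":
    sym.add_symbol(ch)

fsa = string_to_fsa("abba", sym)
print(fsa.num_states())  # 5: the start state plus one state per character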
Example #27
    def toFst(self):
        """Convert the HMM graph to an OpenFst object.

        You need to have installed the OpenFst python extension to use
        this method.

        Returns
        -------
        graph : pywrapfst.Fst
            The FST representation of the HMM graph. An super initial
            state and a super final state will be added though they are
            not present in the HMM.

        """

        import pywrapfst as fst

        f = fst.Fst('log')

        start_state = f.add_state()
        f.set_start(start_state)
        end_state = f.add_state()
        f.set_final(end_state)

        state_fstid = {}
        for state in self.states:
            fstid = f.add_state()
            state_fstid[state.state_id] = fstid

        for state in self.states:
            for next_state_id, weight in state.next_states.items():
                fstid = state_fstid[state.state_id]
                next_fstid = state_fstid[next_state_id]
                arc = fst.Arc(0, 0, fst.Weight('log', -weight), next_fstid)
                f.add_arc(fstid, arc)

        for state in self.init_states:
            fstid = state_fstid[state.state_id]
            arc = fst.Arc(0, 0, fst.Weight.One('log'), fstid)
            f.add_arc(start_state, arc)

        for state in self.final_states:
            fstid = state_fstid[state.state_id]
            arc = fst.Arc(0, 0, fst.Weight.One('log'), end_state)
            f.add_arc(fstid, arc)

        return f
Example #28
    def makeState(i):
        state = fst.add_state()

        initial_weight = openfst.Weight(fst.weight_type(), init_weights[i])
        if initial_weight != zero:
            next_state_str = col_vocab[i]
            next_state_index = fst.output_symbols().find(next_state_str)
            arc = openfst.Arc(bos_index, next_state_index, initial_weight,
                              state)
            fst.add_arc(fst.start(), arc)

        final_weight = openfst.Weight(fst.weight_type(), final_weights[i])
        if final_weight != zero:
            arc = openfst.Arc(eos_index, eps_index, final_weight, final_state)
            fst.add_arc(state, arc)

        return state
Example #29
def _make_slot_fst(state: int, intent_fst: fst.Fst,
                   slot_to_fst: Dict[str, fst.Fst]):
    out_symbols = intent_fst.output_symbols()
    one_weight = fst.Weight.One(intent_fst.weight_type())

    for arc in intent_fst.arcs(state):
        label = out_symbols.find(arc.olabel).decode()
        if label.startswith("__begin__"):
            slot_name = label[9:]

            # Big assumption here that each instance of a slot (e.g., location)
            # will produce the same FST, and therefore doesn't need to be
            # processed again.
            if slot_name in slot_to_fst:
                continue  # skip duplicate slots

            end_label = f"__end__{slot_name}"

            # Create new FST
            slot_fst = fst.Fst()
            slot_fst.set_input_symbols(intent_fst.input_symbols())
            slot_fst.set_output_symbols(intent_fst.output_symbols())

            start_state = slot_fst.add_state()
            slot_fst.set_start(start_state)
            q = [arc.nextstate]
            state_map = {arc.nextstate: start_state}

            # Copy states/arcs from intent FST until __end__ is found
            while len(q) > 0:
                q_state = q.pop()
                for q_arc in intent_fst.arcs(q_state):
                    slot_arc_label = out_symbols.find(q_arc.olabel).decode()
                    if slot_arc_label != end_label:
                        if not q_arc.nextstate in state_map:
                            state_map[q_arc.nextstate] = slot_fst.add_state()

                        # Create arc
                        slot_fst.add_arc(
                            state_map[q_state],
                            fst.Arc(
                                q_arc.ilabel,
                                q_arc.olabel,
                                one_weight,
                                state_map[q_arc.nextstate],
                            ),
                        )

                        # Continue copy
                        q.append(q_arc.nextstate)
                    else:
                        # Mark previous state as final
                        slot_fst.set_final(state_map[q_state])

            slot_to_fst[slot_name] = minimize_fst(slot_fst)

        # Recurse
        _make_slot_fst(arc.nextstate, intent_fst, slot_to_fst)
Example #30
    def enterTagBody(self, ctx):
        # Get the original text *with* whitespace from ANTLR
        input_stream = ctx.start.getInputStream()
        start = ctx.start.start
        stop = ctx.stop.stop
        tag_text = input_stream.getText(start, stop)

        # Patch start of tag
        anchor_state = self.exp_states[self.group_depth]
        next_state = self.fst.add_state()

        # --[__begin__TAG]-->
        begin_symbol = "__begin__" + tag_text
        input_idx = self.input_symbols.add_symbol(begin_symbol)
        output_idx = self.output_symbols.add_symbol(begin_symbol)

        self.tag_input_symbols.add(input_idx)

        # Move outgoing anchor arcs
        for arc in self.fst.arcs(anchor_state):
            self.fst.add_arc(
                next_state,
                fst.Arc(arc.ilabel, arc.olabel, arc.weight, arc.nextstate))

        # Patch anchor
        self.fst.delete_arcs(anchor_state)
        self.fst.add_arc(
            anchor_state,
            fst.Arc(input_idx, output_idx, self.weight_one, next_state))

        # Patch end of tag
        last_state = self.last_states[self.rule_name]
        next_state = self.fst.add_state()

        # --[__end__TAG]-->
        end_symbol = "__end__" + tag_text
        input_idx = self.input_symbols.add_symbol(end_symbol)
        output_idx = self.output_symbols.add_symbol(end_symbol)

        self.tag_input_symbols.add(input_idx)

        self.fst.add_arc(
            last_state,
            fst.Arc(input_idx, output_idx, self.weight_one, next_state))
        self.last_states[self.rule_name] = next_state