예제 #1
0
파일: oclm_c.py 프로젝트: shiranD/oclm
    def __weighted_union(self, left, right, left_prob, right_prob):
        '''
        Return the weighted union of the FSTs ``left`` and ``right``.

        Each operand is placed behind a single <eps>:<eps> arc whose weight is
        ``-log(prob)`` (``left_prob`` for ``left``, ``right_prob`` for
        ``right``), and the two resulting machines are unioned. ``left`` and
        ``right`` are only read: they are concatenated onto fresh FSTs.

        Raises ValueError (from ``math.log``) if a probability is <= 0.
        '''
        def weighted_branch(machine, prob):
            # start --<eps>:<eps> / -log(prob)--> final, followed by `machine`.
            branch = fst.Fst()
            branch.set_input_symbols(machine.input_symbols())
            branch.set_output_symbols(machine.output_symbols())
            head = branch.add_state()
            branch.set_start(head)
            tail = branch.add_state()
            branch.add_arc(head, fst.Arc(0, 0, -math.log(prob), tail))
            branch.set_final(tail)
            branch.concat(machine)
            return branch

        combined = weighted_branch(left, left_prob)
        combined.union(weighted_branch(right, right_prob))
        return combined
예제 #2
0
def spellout_machine(wrdfname, ltr2wrdfst):
    """Build a word-to-letters "spell-out" transducer and write it to
    ``spellout.fst`` in the current directory.

    For every non-blank line (one word per line) of ``wrdfname``, a linear
    path is added:
    - Its first arc maps the word's symbol to the word's first letter.
    - Each following arc maps ``<epsilon>`` to the next letter.
    - A trailing ``#`` is spelled out after the last letter.

    Args:
        wrdfname: path of a text file listing one word per line.
        ltr2wrdfst: path of an FST whose output symbols become this
            machine's input symbols and whose input symbols become its
            output symbols.
    """
    lm = fst.Fst.read(ltr2wrdfst)
    s_in = lm.output_symbols()
    s_out = lm.input_symbols()
    eps_code = s_in.find("<epsilon>")  # same for every word; look up once

    letter = fst.Fst()
    letter.set_input_symbols(s_in)
    letter.set_output_symbols(s_out)
    # NOTE(review): `letter` gets no start state here; presumably union()
    # adopts the first word's start state. State 0 stays unused.
    letter.add_state()

    with open(wrdfname, "r") as words:
        for line in words:
            word = line.strip()
            if not word:
                # A blank line has no word symbol to map from.
                continue
            spelled = word + "#"

            nletter = fst.Fst()
            nletter.set_input_symbols(s_in)
            nletter.set_output_symbols(s_out)
            nletter.add_state()
            nletter.set_start(0)
            for i, ltr in enumerate(spelled):
                nletter.add_state()
                # Only the first arc consumes the word symbol.
                code1 = s_in.find(word) if i == 0 else eps_code
                code2 = s_out.find(ltr)
                nletter.add_arc(i, fst.Arc(code1, code2, None, i + 1))
            nletter.set_final(len(spelled))
            letter.union(nletter)
    letter.rmepsilon()
    letter.write("spellout.fst")
예제 #3
0
def make_lexicon_fst(input_words, words):
    """Build a letter-to-word lexicon FST and an epsilon-to-word FST.

    Side effect: writes the word symbol table to ``words.syms`` in the
    current directory.

    Args:
        input_words: extra word symbols listed in ``words.syms`` after
            ``words``. They get no paths in either returned FST.
        words: lexicon words. The word at position ``i`` gets symbol id
            ``i + 1``, and its letters become input labels ``ord(c)``.
            Blank entries keep their id but get no lexicon path.

    Returns:
        ``(lexicon_fst, epsilon_fst)``:
        - ``lexicon_fst`` is the determinized, minimized closure of a
          transducer. It reads each word's letters and emits the word id on
          the first letter; it also maps a space (code 32) to epsilon.
        - ``epsilon_fst`` is the closure of a transducer that emits any word
          id, or 32, while consuming epsilon.
    """
    lexicon_fst = fst.Fst()
    lex_one = fst.Weight.One(lexicon_fst.weight_type())
    start = lexicon_fst.add_state()
    lexicon_fst.set_start(start)

    last = lexicon_fst.add_state()

    # this line projects space to epsilon
    lexicon_fst.add_arc(start, fst.Arc(32, 0, lex_one, last))
    lexicon_fst.set_final(last)
    for i, w in enumerate(words):
        w = w.strip()
        if not w:
            # Nothing to spell; ids stay aligned with positions in `words`.
            continue
        index = i + 1
        last = lexicon_fst.add_state()
        # The word id is emitted on the first letter; later letters emit eps.
        lexicon_fst.add_arc(start, fst.Arc(ord(w[0]), index, lex_one, last))
        for c in w[1:]:
            this = lexicon_fst.add_state()
            lexicon_fst.add_arc(last, fst.Arc(ord(c), 0, lex_one, this))
            last = this
        lexicon_fst.set_final(last, 0)
    lexicon_fst = fst.determinize(lexicon_fst).minimize().closure()

    with open('words.syms', 'w') as f:
        f.write('<eps> 0\n')
        for i, w in enumerate(words + input_words):  # we put word symbol here
            # Strip like the lexicon loop does, so a trailing newline in an
            # entry cannot corrupt the symbol table.
            f.write('{} {}\n'.format(str(w).strip(), i + 1))
        # NOTE(review): id 32 collides with a word id once there are 32 or
        # more symbols above; confirm callers stay below that.
        f.write('<SPACE> {}\n'.format(32))

    epsilon_fst = fst.Fst()
    eps_one = fst.Weight.One(epsilon_fst.weight_type())
    start = epsilon_fst.add_state()
    end = epsilon_fst.add_state()
    for i, _ in enumerate(words):
        epsilon_fst.add_arc(start, fst.Arc(0, i + 1, eps_one, end))

    epsilon_fst.add_arc(start, fst.Arc(0, 32, eps_one, end))
    epsilon_fst.set_final(end, 0)
    epsilon_fst.set_start(start)
    epsilon_fst = epsilon_fst.closure()

    return lexicon_fst, epsilon_fst
예제 #4
0
    def clear(self):
        """
        Reset every internal field to a fresh, empty value.

        Installs a new symbol table (``syms``), three new empty FSTs (``E``,
        ``Ig``, ``Ip``) and an empty compiled regex (``Ip_r``), and sets
        ``status`` back to 0.
        """

        self.syms = fst.SymbolTable()
        self.E, self.Ig, self.Ip = fst.Fst(), fst.Fst(), fst.Fst()
        self.Ip_r = re.compile(u"")

        self.status = 0
예제 #5
0
 def decompound(self, word):
     """Split ``word`` into its best sequence of component words.

     Steps:
     - Build a split tree for ``word`` with ``self._split``.
     - Lay the tree out as a weighted lattice over ``nleafnodes + 1``
       states. ``self.wordcost`` is passed to the lattice builder.
     - Return the label sequence of the single shortest path.
     """
     tree = Tree(word)
     self._split(word, tree)
     nleafnodes = tree.nleafnodes()
     # Symbol ids are positions in the sorted symbol list.
     symtablel = sorted(tree.getsyms())
     symtable = {s: i for i, s in enumerate(symtablel)}

     lattice = wfst.Fst()
     for _ in range(nleafnodes + 1):
         lattice.add_state()
     lattice.set_final(nleafnodes, wfst.Weight.One(lattice.weight_type()))
     lattice.set_start(0)
     tree.makelattice(lattice, 0, symtable, self.wordcost, firstword=True)

     best = wfst.shortestpath(lattice, nshortest=1)
     return label_seq(best, symtablel)
예제 #6
0
파일: oclm_c.py 프로젝트: shiranD/oclm
 def normalize(self, anfst):
     '''
     Produce a normalized two-state FST from the arcs leaving ``anfst``'s
     start state.

     Each start-state arc weight is wrapped in ``BitWeight``, which per the
     original comments holds -log(pr). It is divided by the log-add sum of
     all those weights. The result has one arc per original arc, from state
     0 (start) to state 1 (final), with the original input label on both
     sides and the normalized weight ``.loge()``. It uses ``anfst``'s input
     symbols on both sides.
     '''
     syms = anfst.input_symbols()
     codes = []
     dist = []
     for arc in anfst.arcs(anfst.start()):
         codes.append(arc.ilabel)
         dist.append(BitWeight(float(arc.weight)))  # BitWeight gets -log(pr)
     # Log-add sum. The start value BitWeight(1e6) is presumably a near-zero
     # probability acting as the additive identity.
     sum_value = sum(dist, BitWeight(1e6))
     norm_dist = [(prob / sum_value).loge() for prob in dist]

     output = fst.Fst()
     output.set_input_symbols(syms)
     output.set_output_symbols(syms)
     output.add_state()
     output.add_state()
     for code, pr in zip(codes, norm_dist):
         output.add_arc(0, fst.Arc(code, code, pr, 1))
     output.set_start(0)
     output.set_final(1)
     return output
예제 #7
0
    def enterRuleBody(self, ctx):
        """Start a fresh FST for the rule body being entered.

        - Create a new FST with a single start state.
        - Record that state as the rule's last state.
        - Cache the FST under ``self.rule_name``.
        - Reset the per-rule group bookkeeping.

        A public rule named ``<grammar_name>.<grammar_name>`` also becomes
        ``self.grammar_fst``.
        """
        super().enterRuleBody(ctx)

        rule_fst = fst.Fst()
        self.fst = rule_fst
        self.start_state = rule_fst.add_state()
        rule_fst.set_start(self.start_state)
        self.last_states[self.rule_name] = self.start_state
        self.weight_one = fst.Weight.One(rule_fst.weight_type())

        # Only a public rule can be the grammar's main rule.
        if self.is_public and \
                self.rule_name == self.grammar_name + "." + self.grammar_name:
            self.grammar_fst = rule_fst

        self.fsts[self.rule_name] = rule_fst

        self.group_depth = 0
        self.opt_states, self.alt_states, self.tag_states = {}, {}, {}
        self.exp_states, self.alt_ends = {}, {}

        # The rule's current last state anchors alternatives at depth 0.
        self.alt_states[self.group_depth] = self.last_states[self.rule_name]
예제 #8
0
def build_chain_fst(labels, arc_type='log', vocab=None):
    """
    Build a linear acceptor for the label sequence ``labels``.

    Args:
        labels - a sequence of labels in the range 1..S
        arc_type - fst arc type (standard or log)
        vocab - accepted for interface compatibility; not used
    Returns:
        FST consuming symbols in the range 1..S. All arc weights are One,
        and arcs are sorted by input label.

    Notes:
        Elements of labels are assumed to be greater than zero
        (zero maps to blank)!
    """
    chain = fst.Fst(arc_type=arc_type)
    one = fst.Weight.One(chain.weight_type())
    state = chain.add_state()
    chain.set_start(state)
    for label in labels:
        nxt = chain.add_state()
        chain.add_arc(state, fst.Arc(label, label, one, nxt))
        state = nxt
    chain.set_final(state)
    chain.arcsort('ilabel')
    return chain
예제 #9
0
def _build_transliterator():
    """Assemble the transliteration transducer.

    Judging by the katakana long-vowel table, this is presumably
    romaji -> katakana. The single initial state is both start and final.
    Each ``_build_*`` helper adds its own paths; they are called in a fixed
    order, and all share the same table of end states.
    """
    td = fst.Fst()
    initial_state = td.add_state()
    td.set_start(initial_state)
    td.set_final(initial_state)

    # Vowel -> katakana mark(s) used to lengthen it ('o' has two options).
    long_vowel_possibilities = {
        'u': 'ウ',
        'i': 'イ',
        'e': 'エ',
        'o': ['オ', 'ウ'],
        'a': 'ア'
    }
    long_vowel_states = {}
    for vowel, marks in long_vowel_possibilities.items():
        long_vowel_states[vowel] = _long_vowel_mark_state(
            td, initial_state, marks)

    end_states = {
        'n': initial_state,
        'y': _build_small_y_state(td, long_vowel_states),
        **long_vowel_states
    }

    builders = (
        _build_sjsh, _build_vowels, _build_tdch, _build_big_y, _build_hpb,
        _build_kg, _build_r, _build_m, _build_n, _build_w,
    )
    for build in builders:
        build(td, initial_state, end_states)
    return td
예제 #10
0
 def __call__(self, x):
     """Turn a batch of network outputs into one linear-chain FST per sequence.

     ``transform_output`` unpacks ``x`` into scores and per-sequence lengths.
     If ``self._normalize`` is set, the scores are log-softmaxed over dim 2.
     After ``permute(1, 0, 2)``, each element ``logpost`` is indexed as
     ``logpost[t, j]`` (presumably frame t, class j; confirm with
     transform_output).

     For a sequence of length T the FST has states 0..T. From state t there
     are D arcs to t + 1, labelled ``j + 1`` on both sides, each weighted
     ``-logpost[t, j]``. State T is final with weight One. The FSTs are
     stored in ``self._output`` and returned.
     """
     x, xs = transform_output(x)
     # Normalize log-posterior matrices, if necessary
     if self._normalize:
         x = log_softmax(x, dim=2)
     x = x.permute(1, 0, 2).cpu()
     self._output = []
     D = x.size(2)
     for logpost, length in zip(x, xs):
         length = int(length)
         # One bulk conversion instead of a tensor index per (t, j) in the
         # inner loop. The values are identical Python floats.
         rows = logpost[:length].tolist()
         f = fst.Fst()
         weight_type = f.weight_type()
         f.set_start(f.add_state())
         for t in range(length):
             f.add_state()
             row = rows[t]
             for j in range(D):
                 weight = fst.Weight(weight_type, -row[j])  # -logpost[t, j]
                 f.add_arc(t, fst.Arc(j + 1, j + 1, weight, t + 1))
         f.set_final(length, fst.Weight.One(weight_type))
         f.verify()
         self._output.append(f)
     return self._output
예제 #11
0
    def enterRuleBody(self, ctx):
        """Prepare per-rule state for the rule body being entered.

        - A public rule keeps the current ``self.fst`` and
          ``self.start_state`` (the main start state).
        - Any other rule gets a fresh FST with its own start state.

        Either way, the FST in use is cached under the rule name and the
        group bookkeeping is reset.
        """
        self.in_rule = True

        if not self.is_public:
            # Private rule: build a new FST
            self.fst = fst.Fst()
            self.start_state = self.fst.add_state()
            self.fst.set_start(self.start_state)
        self.last_states[self.rule_name] = self.start_state

        self.fsts[self.rule_name] = self.fst

        self.group_depth = 0
        self.opt_states, self.alt_states, self.tag_states = {}, {}, {}
        self.exp_states, self.alt_ends = {}, {}

        # The rule's current last state anchors alternatives at depth 0.
        self.alt_states[self.group_depth] = self.last_states[self.rule_name]
예제 #12
0
def genStrFst(ilabels, olabels=None):
    """Build an FST over ``ilabels`` (input) and ``olabels`` (output).

    ``olabels`` defaults to ``ilabels``, which gives an acceptor. The FST is
    set up by ``initFst``. ``addArcLinear`` then adds the arcs from state 0
    with ``is_loop=False``; presumably this is a straight line, see that
    helper.
    """
    out_labels = ilabels if olabels is None else olabels
    str_fst = pywrapfst.Fst()
    initFst(str_fst)
    addArcLinear(str_fst, 0, ilabels, out_labels, is_loop=False)
    return str_fst
예제 #13
0
def _make_slot_fst(state: int, intent_fst: fst.Fst,
                   slot_to_fst: Dict[str, fst.Fst], _visited=None):
    """Collect one minimized FST per slot reachable from ``state``.

    A slot starts at an arc whose output label is ``__begin__NAME``:
    - The states and arcs after it are copied into a new FST, with all
      weights replaced by One.
    - Copying stops at arcs labelled ``__end__NAME``; their source states
      become final.
    - The copy is passed through ``minimize_fst`` and stored in
      ``slot_to_fst[NAME]``.
    Only the first occurrence of each slot name is extracted.

    Args:
        state: state of ``intent_fst`` where the search starts.
        intent_fst: FST whose output symbol table contains the slot markers.
        slot_to_fst: mapping updated in place with the extracted slot FSTs.
        _visited: internal; intent states already searched. Each state is
            searched once, so shared sub-graphs are not re-walked and
            cyclic FSTs terminate.
    """
    if _visited is None:
        _visited = set()
    if state in _visited:
        return
    _visited.add(state)

    out_symbols = intent_fst.output_symbols()
    one_weight = fst.Weight.One(intent_fst.weight_type())
    begin_prefix = "__begin__"

    for arc in intent_fst.arcs(state):
        label = out_symbols.find(arc.olabel).decode()
        if label.startswith(begin_prefix):
            slot_name = label[len(begin_prefix):]

            # Big assumption here that each instance of a slot (e.g., location)
            # will produce the same FST, and therefore doesn't need to be
            # processed again.
            if slot_name in slot_to_fst:
                continue  # skip duplicate slots

            end_label = f"__end__{slot_name}"

            # Create new FST
            slot_fst = fst.Fst()
            slot_fst.set_input_symbols(intent_fst.input_symbols())
            slot_fst.set_output_symbols(intent_fst.output_symbols())

            start_state = slot_fst.add_state()
            slot_fst.set_start(start_state)
            q = [arc.nextstate]
            state_map = {arc.nextstate: start_state}

            # Copy states/arcs from the intent FST until __end__ is found.
            # A state is queued only when first mapped, so its arcs are
            # copied exactly once (no duplicate arcs, no endless loop on
            # cycles).
            while q:
                q_state = q.pop()
                for q_arc in intent_fst.arcs(q_state):
                    slot_arc_label = out_symbols.find(q_arc.olabel).decode()
                    if slot_arc_label == end_label:
                        # Mark previous state as final
                        slot_fst.set_final(state_map[q_state])
                        continue

                    if q_arc.nextstate not in state_map:
                        state_map[q_arc.nextstate] = slot_fst.add_state()
                        q.append(q_arc.nextstate)

                    slot_fst.add_arc(
                        state_map[q_state],
                        fst.Arc(
                            q_arc.ilabel,
                            q_arc.olabel,
                            one_weight,
                            state_map[q_arc.nextstate],
                        ),
                    )

            slot_to_fst[slot_name] = minimize_fst(slot_fst)

        # Recurse
        _make_slot_fst(arc.nextstate, intent_fst, slot_to_fst, _visited)
예제 #14
0
def acceptor_for_strings(strings: List[str], weights: List[float]) -> fst.Fst:
    """Create an acceptor for strings with weights.

    ``strings[i]`` is paired with ``weights[i]``. The pairs are sorted and
    then handed to ``_build_acceptor_recursive``. If there are no pairs, the
    result has only a non-final start state, so it accepts nothing.
    """
    pairs = sorted(zip(strings, weights))
    td = fst.Fst()
    start_state = td.add_state()
    td.set_start(start_state)
    if not pairs:
        # `zip(*[])` yields nothing, so the unpacking below would fail.
        return td
    strings, weights = zip(*pairs)
    _build_acceptor_recursive(td, strings, weights, start_state, 0, 0,
                              len(strings))
    return td
예제 #15
0
파일: oclm_c.py 프로젝트: shiranD/oclm
 def create_empty_fst(self, input_sym, output_sym):
     '''
     Return a one-state FST whose only state is both start and final.

     It has no arcs, so it accepts just the empty string. Its input and
     output symbol tables are ``input_sym`` and ``output_sym``.
     '''
     f = fst.Fst()
     f.set_input_symbols(input_sym)
     f.set_output_symbols(output_sym)
     only = f.add_state()
     f.set_start(only)
     f.set_final(only)
     return f
예제 #16
0
def make_input(chars, stoi):
    """Build a linear acceptor for ``chars``.

    There is one arc per character, labelled ``stoi[c]`` on both sides.
    Every arc weight and the final weight is One.
    """
    chain = wfst.Fst()
    current = chain.add_state()
    chain.set_start(current)
    one = wfst.Weight.One(chain.weight_type())
    for ch in chars:
        following = chain.add_state()
        label = stoi[ch]
        chain.add_arc(current, wfst.Arc(label, label, one, following))
        current = following
    chain.set_final(current, one)
    return chain
예제 #17
0
def make_kleeneplus(s, graphs, stoi):
    """Build an acceptor for one-or-more symbols drawn from ``graphs``.

    ``stoi`` maps each graph to its label id. ``s`` is unused. All weights
    are One.
    """
    plus = wfst.Fst()
    start = plus.add_state()
    end = plus.add_state()
    plus.set_start(start)
    one = wfst.Weight.One(plus.weight_type())
    plus.set_final(end, one)
    for g in graphs:
        label = stoi[g]
        # The first symbol enters the final state; later ones loop on it.
        plus.add_arc(start, wfst.Arc(label, label, one, end))
        plus.add_arc(end, wfst.Arc(label, label, one, end))
    return plus
예제 #18
0
    def __init__(self):
        '''
        Create an FST with one start state, ``STATE_START``.

        Its input and output symbol tables are new, and each contains only
        ``LAB_EPS`` at key ``SYM_EPS``.
        '''
        machine = pywrapfst.Fst()
        self.fst = machine
        self.STATE_START = machine.add_state()
        machine.set_start(self.STATE_START)

        machine.set_input_symbols(pywrapfst.SymbolTable())
        machine.set_output_symbols(pywrapfst.SymbolTable())
        # Register epsilon through the FST's mutable symbol-table accessors.
        machine.mutable_input_symbols().add_symbol(LAB_EPS, key=SYM_EPS)
        machine.mutable_output_symbols().add_symbol(LAB_EPS, key=SYM_EPS)
예제 #19
0
def make_compounder(syms, word_ids):
    """Build a one-state transducer; state 0 is both start and final.

    - Every id in ``word_ids`` is copied through unchanged.
    - ``<space>`` is rewritten to ``<eps>``, ``+C+`` or ``+D+``.

    All arc weights and the final weight are the literal 1.
    NOTE(review): under the default tropical semiring this is a cost of 1,
    not Weight.One; confirm that is intended.
    """
    compounder = fst.Fst()
    start_state = compounder.add_state()
    assert (start_state == 0)
    compounder.set_start(start_state)
    space_id = syms["<space>"]
    for marker in ("<eps>", "+C+", "+D+"):
        compounder.add_arc(start_state,
                           fst.Arc(space_id, syms[marker], 1, start_state))
    for word_id in word_ids:
        compounder.add_arc(start_state,
                           fst.Arc(word_id, word_id, 1, start_state))
    compounder.set_final(start_state, 1)
    return compounder
예제 #20
0
파일: oclm_c.py 프로젝트: shiranD/oclm
 def append_eeg_evidence(self, ch_dist):
     """Append one step of character evidence to ``self.history_fst``.

     ``ch_dist`` yields ``(character, weight)`` pairs. Each pair becomes a
     parallel arc from state 0 to state 1, labelled with the character's
     code in ``self.ch_syms``. The weight is used as-is (presumably already
     negative-log; confirm with callers).

     The two-state machine is sorted by output label and concatenated onto
     the history. Epsilons are then removed from the history in place.
     """
     step = fst.Fst()
     step.set_input_symbols(self.ch_syms)
     step.set_output_symbols(self.ch_syms)
     src = step.add_state()
     step.set_start(src)
     dst = step.add_state()
     step.set_final(dst)
     for ch, pr in ch_dist:
         code = self.ch_syms.find(ch)
         step.add_arc(src, fst.Arc(code, code, pr, dst))
     step.arcsort(sort_type="olabel")
     self.history_fst.concat(step).rmepsilon()
예제 #21
0
def get_trivial_fst(word_index):
    """Build a two-state FST with a single arc.

    The arc reads ``word_index``, writes epsilon (0) and has weight One.
    The second state is final with weight 0.
    """
    trivial = fst.Fst()
    src = trivial.add_state()
    dst = trivial.add_state()
    trivial.set_start(src)
    trivial.set_final(dst, 0)
    one = fst.Weight.One(trivial.weight_type())
    trivial.add_arc(src, fst.Arc(word_index, 0, one, dst))
    return trivial
예제 #22
0
파일: oclm_c.py 프로젝트: shiranD/oclm
    def update(self, ch_dist):
        '''
        Update the history with the new likelihood array in the correct scale
        (negative log space).

        ``ch_dist`` yields ``(character, weight)`` pairs.

        - Every character except '#' becomes an arc of a two-state machine
          (state 0 -> state 1).
        - That machine is appended to each eligible prefix-word bin's
          trailing characters, and then to ``self.history_fst``.
        - The '#' (space) entry is added to the machine only after the bins
          were extended. So it reaches the history but not the bins.
        - When a space is present, a new bin is opened. Its word lattice is
          the history composed with ``self.ltr2wrd``, projected onto output
          labels, epsilon-removed, determinized and minimized.
        '''
        new_ch = fst.Fst()
        new_ch.set_input_symbols(self.ch_syms)
        new_ch.set_output_symbols(self.ch_syms)
        new_ch.add_state()
        new_ch.set_start(0)
        new_ch.add_state()
        new_ch.set_final(1)
        # Sentinels: kept when ch_dist has no '#' entry. A '#' whose code
        # is negative (presumably not in ch_syms) is also treated as no space.
        space_code = -1
        space_pr = 0.
        for ch, pr in ch_dist:
            code = self.ch_syms.find(ch)
            if ch == '#':  # Adds space after we finish updating trailing chars.
                space_code = code
                space_pr = pr
                continue
            new_ch.add_arc(0, fst.Arc(code, code, pr, 1))
        new_ch.arcsort(sort_type="olabel")

        # Adds the trailing characters to existing binned history.
        # Each bin is [word_lattice, trailing_chars_fst, trail_length].
        for words_bin in self.prefix_words:
            if words_bin[2] >= 10:  # We discard the whole trail in this case (TODO)
                continue
            # Unless we are testing a straight line machine, this normally
            # doesn't happen in practice.
            if new_ch.num_arcs(0) == 0:
                continue
            words_bin[1].concat(new_ch).rmepsilon()
            words_bin[2] += 1

        # Continues updating the history and adds back the space if necessary.
        # The space arc must be added only here, after the bins above were
        # extended without it.
        if space_code >= 0:
            new_ch.add_arc(0, fst.Arc(space_code, space_code, space_pr, 1))
        self.history_fst.concat(new_ch).rmepsilon()

        # Respectively update the binned history
        if space_code >= 0:  # If there is a space
            # Finishes the prefix words in current position
            word_lattice = fst.compose(self.history_fst, self.ltr2wrd)
            word_lattice.project(project_output=True).rmepsilon()
            word_lattice = fst.determinize(word_lattice)
            word_lattice.minimize()
            # An empty lattice is replaced by a one-state, final-only FST.
            if word_lattice.num_states() == 0:
                word_lattice = self.create_empty_fst(self.wd_syms, self.wd_syms)
            trailing_chars = self.create_empty_fst(self.ch_syms, self.ch_syms)
            self.prefix_words.append([word_lattice, trailing_chars, 0])
예제 #23
0
def make_termfst(s, paths, stoi):
    """Build a determinized, minimized acceptor for the union of ``paths``.

    Each path is a sequence of graphs, mapped to labels through ``stoi``.
    The last state of every path is final. An empty path makes the start
    state final, so the empty string is accepted. ``s`` is unused.
    """
    term = wfst.Fst()
    one = wfst.Weight.One(term.weight_type())
    start = term.add_state()
    term.set_start(start)
    for path in paths:
        state = start
        for g in path:
            nxt = term.add_state()
            term.add_arc(state, wfst.Arc(stoi[g], stoi[g], one, nxt))
            state = nxt
        # `state` is this path's last state (the start for an empty path).
        term.set_final(state, one)
    term = wfst.determinize(term)
    term.minimize()
    return term
예제 #24
0
def make_intent_fst(grammar_fsts: Dict[str, fst.Fst],
                    eps: str = "<eps>") -> fst.Fst:
    """Merges grammar FSTs created with grammar_to_fsts into a single acceptor FST.

    For every intent, a three-arc branch is added from a shared start
    state. All three arcs have ``eps`` input:
    start -> ``__label__INTENT`` -> ``__replace__INTENT`` -> ``eps`` -> final.
    The result and the map {``__replace__INTENT`` output id: grammar FST}
    are handed to ``_replace_fsts``.
    """
    in_syms = fst.SymbolTable()
    out_syms = fst.SymbolTable()

    in_eps: int = in_syms.add_symbol(eps)
    out_eps: int = out_syms.add_symbol(eps)

    merged = fst.Fst()
    one = fst.Weight.One(merged.weight_type())

    start = merged.add_state()
    merged.set_start(start)
    final = merged.add_state()
    merged.set_final(final)

    replacements: Dict[int, fst.Fst] = {}

    for intent, grammar_fst in grammar_fsts.items():
        label_id = out_syms.add_symbol(f"__label__{intent}")
        after_label = merged.add_state()
        merged.add_arc(start, fst.Arc(in_eps, label_id, one, after_label))

        after_replace = merged.add_state()
        replace_id = out_syms.add_symbol(f"__replace__{intent}")
        merged.add_arc(after_label,
                       fst.Arc(in_eps, replace_id, one, after_replace))

        merged.add_arc(after_replace, fst.Arc(in_eps, out_eps, one, final))

        replacements[replace_id] = grammar_fst

    # Symbol tables are attached once every symbol has been added.
    merged.set_input_symbols(in_syms)
    merged.set_output_symbols(out_syms)

    return _replace_fsts(merged, replacements, eps=eps)
예제 #25
0
def longest_path(the_fst: fst.Fst, eps: str = "<eps>") -> fst.Fst:
    """Return a linear FST for the output labels of one "longest" path.

    NOTE(review): the search is breadth-first.
    - It records, for each state, the first path that reaches it. That is a
      fewest-arc path.
    - It keeps the longest of those recorded paths.
    So this is the path to a state farthest from the start (in arcs). It is
    not necessarily the longest path in the FST, and it need not end in a
    final state.

    Each arc of the result has the original output label on the output
    side. On the input side it has the same symbol, registered in a fresh
    input symbol table that begins with ``eps``. The result reuses
    ``the_fst``'s output symbol table.
    """
    output_symbols = the_fst.output_symbols()
    visited_states: Set[int] = set()
    best_path: List[int] = []
    state_queue: Deque[Tuple[int, List[int]]] = deque()
    state_queue.append((the_fst.start(), []))

    # Breadth-first search: paths come off the queue in non-decreasing length.
    while state_queue:
        state, path = state_queue.popleft()
        if state in visited_states:
            continue

        visited_states.add(state)

        if len(path) > len(best_path):
            best_path = path

        for arc in the_fst.arcs(state):
            # Visited states would be skipped on pop anyway; don't queue them.
            if arc.nextstate not in visited_states:
                state_queue.append((arc.nextstate, path + [arc.olabel]))

    # Create FST with longest path
    path_fst = fst.Fst()

    input_symbols = fst.SymbolTable()
    input_symbols.add_symbol(eps)
    path_fst.set_output_symbols(output_symbols)
    weight_one = fst.Weight.One(path_fst.weight_type())

    state = path_fst.add_state()
    path_fst.set_start(state)

    for olabel in best_path:
        osym = output_symbols.find(olabel).decode()
        next_state = path_fst.add_state()
        path_fst.add_arc(
            state,
            fst.Arc(input_symbols.add_symbol(osym), olabel, weight_one,
                    next_state),
        )
        state = next_state

    path_fst.set_final(state)
    path_fst.set_input_symbols(input_symbols)

    return path_fst
예제 #26
0
def _make_input_fst(string):
    """Build a linear acceptor that accepts exactly ``string``.

    Arcs are added one character at a time by ``_char_arc``.

    Raises:
        ValueError: if ``_char_arc`` raises KeyError for a character, i.e.
            the character is not in the input symbol table. The KeyError is
            chained as the cause.
    """
    # Input is passed as acceptors that accept a single string
    td = fst.Fst()
    curr = td.add_state()
    td.set_start(curr)
    for c in string:
        nxt = td.add_state()
        try:
            _char_arc(td, curr, c, c, nxt)
        except KeyError as err:
            raise ValueError(
                'Character {} not in input symbol table'.format(c)) from err
        curr = nxt
    td.set_final(curr)
    return td
예제 #27
0
def build_ctc_mono_decoding_fst(S, arc_type='log', add_syms=False):
    """
    Build a monophone CTC decoding fst.
    Args:
        S - number of monophones (symbol 0 is the blank)
        arc_type - log or standard. Gives the interpretation of the FST.
        add_syms - if true, attach symbol tables: 'B' for the blank, then
            'a', 'b', ... for symbols 1..S-1.
    Returns:
        an FST that accepts all sequences over [1,..,S]^* and returns
        shorter ones with duplicates and blanks removed.

        State s means "the last symbol read was s". Every state is final,
        and state 0 is the start.
        The input labels are shifted by one, so that there are no epsilon
        transitions.
        The output labels are not (blank is zero), allowing one to read out
        the label sequence easily.
    """
    CTC = fst.Fst(arc_type=arc_type)
    one = fst.Weight.One(CTC.weight_type())

    for expected in range(S):
        state = CTC.add_state()
        assert expected == state
        CTC.set_final(state)
    CTC.set_start(0)

    for src in range(S):
        # Repeating the current symbol emits nothing.
        CTC.add_arc(src, fst.Arc(src + 1, 0, one, src))
        # Switching to another symbol emits it (blank 0 is epsilon).
        for dst in (d for d in range(S) if d != src):
            CTC.add_arc(src, fst.Arc(dst + 1, dst, one, dst))
    CTC.arcsort('olabel')

    if add_syms:
        letters = [chr(ord('a') + k - 1) for k in range(1, S)]
        in_syms = fst.SymbolTable()
        in_syms.add_symbol('<eps>', 0)
        in_syms.add_symbol('B', 1)
        out_syms = fst.SymbolTable()
        out_syms.add_symbol('<eps>', 0)
        for k, letter in enumerate(letters, start=1):
            in_syms.add_symbol(letter, k + 1)
            out_syms.add_symbol(letter, k)
        CTC.set_input_symbols(in_syms)
        CTC.set_output_symbols(out_syms)
    return CTC
예제 #28
0
def make_intent_fst(grammar_fsts: Dict[str, fst.Fst], eps=0) -> fst.Fst:
    """Merges grammar FSTs created with jsgf2fst into a single acceptor FST.

    - The input and output symbol tables are the union of every grammar's
      tables, plus ``<eps>`` at key ``eps``.
    - The output table also gets one ``__label__INTENT`` symbol per intent.
    - Each grammar is spliced between a shared start state and a shared
      final state by ``replace_and_patch``.
    - The result is minimized by piping it through the external
      ``fstminimize --allow_nondet`` command.

    Raises:
        subprocess.CalledProcessError: if ``fstminimize`` exits non-zero.
        OSError: if ``fstminimize`` cannot be executed.
    """
    intent_fst = fst.Fst()
    all_in_symbols = fst.SymbolTable()
    all_out_symbols = fst.SymbolTable()
    all_in_symbols.add_symbol("<eps>", eps)
    all_out_symbols.add_symbol("<eps>", eps)

    def merge_symbols(target, source):
        # NOTE(review): assumes source keys are dense 0..num_symbols()-1.
        for key in range(source.num_symbols()):
            target.add_symbol(source.find(key).decode())

    for grammar_fst in grammar_fsts.values():
        merge_symbols(all_in_symbols, grammar_fst.input_symbols())
        merge_symbols(all_out_symbols, grammar_fst.output_symbols())

    for intent_name in grammar_fsts:
        all_out_symbols.add_symbol(f"__label__{intent_name}")

    intent_fst.set_input_symbols(all_in_symbols)
    intent_fst.set_output_symbols(all_out_symbols)

    start_state = intent_fst.add_state()
    intent_fst.set_start(start_state)
    final_state = intent_fst.add_state()
    intent_fst.set_final(final_state)

    for intent_name, grammar_fst in grammar_fsts.items():
        replace_and_patch(intent_fst,
                          start_state,
                          final_state,
                          grammar_fst,
                          all_out_symbols.find(f"__label__{intent_name}"),
                          eps=eps)

    # BUG: Fst.minimize does not pass allow_nondet through, so we have to
    # call out to the command-line tool instead.
    minimized = subprocess.check_output(["fstminimize", "--allow_nondet"],
                                        input=intent_fst.write_to_string())
    return fst.Fst.read_from_string(minimized)
예제 #29
0
파일: hmm_graph.py 프로젝트: tglarner/amdtk
    def toFst(self):
        """Convert the HMM graph to an OpenFst object.

        You need to have installed the OpenFst python extension to use
        this method.

        Returns
        -------
        graph : pywrapfst.Fst
            A log-semiring FST with the HMM states and two extra states,
            which are not present in the HMM:
            - a super initial state (id 0, the start);
            - a super final state (id 1, the only final state).
            HMM transitions are stored with weight ``-weight``. Arcs out of
            the super initial state and into the super final state have
            weight One. Every arc carries label 0 (epsilon).

        """

        import pywrapfst as fst

        f = fst.Fst('log')

        start_state = f.add_state()
        f.set_start(start_state)
        end_state = f.add_state()
        f.set_final(end_state)

        # HMM state id -> FST state id, allocated in self.states order.
        state_fstid = {state.state_id: f.add_state() for state in self.states}

        for state in self.states:
            src = state_fstid[state.state_id]
            for next_state_id, weight in state.next_states.items():
                dst = state_fstid[next_state_id]
                f.add_arc(src, fst.Arc(0, 0, fst.Weight('log', -weight), dst))

        one = fst.Weight.One('log')
        for state in self.init_states:
            f.add_arc(start_state,
                      fst.Arc(0, 0, one, state_fstid[state.state_id]))

        for state in self.final_states:
            f.add_arc(state_fstid[state.state_id],
                      fst.Arc(0, 0, one, end_state))

        return f
예제 #30
0
    def __align_fst(self, g, p):
        '''
        Align a grapheme sequence ``g`` with a phoneme sequence ``p``.

        Both sides are turned into FSTs, by ``self.segment(g)`` and
        ``self.expand(p)``, and projected onto their output labels. The
        grapheme FST is composed with ``self.E`` and then with the phoneme
        FST. An empty ``fst.Fst()`` is returned instead when the phoneme
        FST has no start state or no arcs leaving it.
        '''
        graphemes = self.segment(g)
        graphemes.project(project_output=True)

        phonemes = self.expand(p)
        phonemes.project(project_output=True)

        phon_start = phonemes.start()
        if phon_start == -1 or phonemes.num_arcs(phon_start) == 0:
            return fst.Fst()

        return fst.compose(fst.compose(graphemes, self.E), phonemes)