Example #1
 def expand_rtn(self, func):
     """This method expands the RTN as far as necessary. This means
     that the RTN is expanded s.t. we can build the posterior for 
     ``cur_history``. In practice, this means that we follow all 
     epsilon edges and replace all NT edges until all paths with
     the prefix ``cur_history`` in the RTN have at least one more 
     terminal token. Then, we apply ``func`` to all reachable nodes.
     """
     updated = True
     while updated:
         updated = False
         label_fst_map = {}
         self.visited_nodes = {}
         self.cur_fst.arcsort(sort_type="olabel")
         self.add_to_label_fst_map_recursive(label_fst_map, {},
                                             self.cur_fst.start(), 0.0,
                                             self.cur_history, func)
         if label_fst_map:
             logging.debug("Replace %d NT arcs for history %s" %
                           (len(label_fst_map), self.cur_history))
             # First in the list is the root FST and label
             replaced_fst = fst.replace(
                 [(len(label_fst_map) + 2000000000, self.cur_fst)] +
                 [(nt_label, f)
                  for (nt_label, f) in label_fst_map.items()],
                 epsilon_on_replace=True)
             self.cur_fst = replaced_fst
             updated = True
     if self.rmeps or self.minimize_rtns:
         self.cur_fst.rmepsilon()
     if self.minimize_rtns:
         tmp = fst.determinize(self.cur_fst)
         self.cur_fst = tmp
         self.cur_fst.minimize()
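
The heavy lifting in this loop is done by ``fst.replace``: every arc whose label appears in ``label_fst_map`` is spliced out and replaced by the corresponding sub-FST, and the first (label, fst) pair in the list designates the root. Below is a minimal, self-contained sketch of that pattern, assuming ``import pywrapfst as fst`` and a pywrapfst version where ``fst.Fst()`` constructs a mutable FST (newer releases use ``fst.VectorFst()``); the labels and the ``_linear_fst`` helper are illustrative values, not part of the original code.

import pywrapfst as fst

def _linear_fst(labels):
    """Build a single-path acceptor over the given integer labels."""
    f = fst.Fst()
    one = fst.Weight.One(f.weight_type())
    state = f.add_state()
    f.set_start(state)
    for label in labels:
        nxt = f.add_state()
        f.add_arc(state, fst.Arc(label, label, one, nxt))
        state = nxt
    f.set_final(state)
    return f

NT = 2000000001                  # hypothetical NT label (cf. the 2000000000 offset above)
root = _linear_fst([1, NT, 2])   # terminal 1, NT call, terminal 2
sub = _linear_fst([3, 4])        # expansion of NT

# As in expand_rtn, the first pair is the root; with epsilon_on_replace=True
# the call/return arcs become epsilon arcs instead of keeping the NT label.
expanded = fst.replace([(NT + 1, root), (NT, sub)], epsilon_on_replace=True)
# After epsilon removal, "expanded" accepts exactly the label sequence 1 3 4 2.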
Example #2
    def replace_fsts(rule_name):
        rule_fst = replaced_fsts.get(rule_name)
        if rule_fst is not None:
            return rule_fst

        listener = listeners[rule_name]

        rule_fst = rule_fsts[rule_name]
        for ref_name in listener.rule_references[rule_name]:
            ref_fst = replace_fsts(ref_name)

            # Replace rule in grammar FST
            replace_symbol = "__replace__" + ref_name
            replace_idx = input_symbols.find(replace_symbol)
            if replace_idx >= 0:
                logger.debug(f"Replacing rule {ref_name} in {rule_name}")
                rule_fst = fst.replace([(-1, rule_fst),
                                        (replace_idx, ref_fst)],
                                       epsilon_on_replace=True)

        replaced_fsts[rule_name] = rule_fst
        return rule_fst
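
Note that ``replace_fsts`` memoizes finished rules in ``replaced_fsts``, so a rule referenced by several other rules is expanded only once. A usage sketch, mirroring the driver loop in Example #4 below (the grammar name ``Light`` is hypothetical):

# Resolve every grammar's main rule; recursion through rule_references
# bottoms out at rules with no references, which are already cached.
for grammar_name in list(grammar_fsts.keys()):
    main_rule_name = grammar_name + "." + grammar_name  # e.g. "Light.Light"
    grammar_fsts[grammar_name] = replace_fsts(main_rule_name)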
Example #3
    def __init__(self, dcg, descr):
        #print("Morphparse_DCG.__init__()", file=sys.stderr)
        #print("dcg[nonterminals]: {}".format(pprint.pformat(dcg["nonterminals"])), file=sys.stderr)
        ###Make symbol tables
        othersyms = set()
        for pos in descr["renamesyms"]:
            othersyms.update([e[1] for e in descr["renamesyms"][pos]])
        self.bounds = descr["bounds"]
        self.itos, self.stoi = make_symmaps(dcg, descr["graphs"], othersyms)

        # #DEBUG DUMP SYMTABLES
        # with codecs.open("tmp/stoi.pickle", "w", encoding="utf-8") as outfh:
        #     pickle.dump(self.stoi, outfh)
        # with codecs.open("tmp/itos.pickle", "w", encoding="utf-8") as outfh:
        #     pickle.dump(self.itos, outfh)

        termfsts = make_termfsts(dcg, descr["graphs"], self.stoi)
        # #DEBUG DUMP FST
        # for k in termfsts:
        #     print("DEBUG dumping:", k, file=sys.stderr)
        #     save_dot(termfsts[k], self.stoi, "tmp/termfst_"+k+".dot")
        #     termfsts[k].write("tmp/termfst_"+k+".fst")
            
        self.fsts = {}
        ###Expand/make non-terminal FSTs for each POS category
        for pos in descr["pos"]:
            print("Making/expanding non-terminal fst for POS:", pos, file=sys.stderr)
            fstcoll = make_rtn(pos, dcg["nonterminals"], self.stoi, {})
            # print("__init__(): fstcoll: {}".format(fstcoll.keys()), file=sys.stderr)
            # for sym in fstcoll:
            #     #DEBUG DUMP FST
            #     save_dot(fstcoll[sym], self.stoi, "tmp/"+pos+"_orig_"+sym+".dot")
            #     fstcoll[sym].write("tmp/"+pos+"_orig_"+sym+".fst")

            #replace non-terminals
            replace_pairs = [(self.stoi[pos], fstcoll.pop(pos))]
            for k, v in fstcoll.items():
                replace_pairs.append((self.stoi[k], v))
            fst = wfst.replace(replace_pairs, call_arc_labeling="both")
            fst.rmepsilon()
            fst = wfst.determinize(fst)
            fst.minimize()
            # #DEBUG DUMP FST
            # save_dot(fst, self.stoi, "tmp/"+pos+"_expanded.dot")
            # fst.write("tmp/"+pos+"_expanded.fst")
            # if True: #DEBUGGING
            #     fst2 = fst.copy()
            #     #rename symbols (simplify) 
            #     if pos in descr["renamesyms"] and descr["renamesyms"][pos]:
            #         labpairs = map(lambda x: (self.stoi[x[0]], self.stoi[x[1]]), descr["renamesyms"][pos])
            #         fst2.relabel_pairs(opairs=labpairs, ipairs=labpairs)
            #     fst2.rmepsilon()
            #     fst2 = wfst.determinize(fst2)
            #     fst2.minimize()            
            #     #DEBUG DUMP FST
            #     save_dot(fst2, self.stoi, "tmp/"+pos+"_expandedsimple.dot")
            #     fst2.write("tmp/"+pos+"_expandedsimple.fst")            

            #replace terminals
            replace_pairs = [(self.stoi[pos], fst)]
            for k, v in termfsts.items():
                replace_pairs.append((self.stoi[k], v))
            fst = wfst.replace(replace_pairs, call_arc_labeling="both")
            fst.rmepsilon()
            fst = wfst.determinize(fst)
            fst.minimize()
            # #DEBUG DUMP FST
            # save_dot(fst, self.stoi, "tmp/"+pos+"_expanded2.dot")
            # fst.write("tmp/"+pos+"_expanded2.fst")

            #rename symbols (simplify) JUST FOR DEBUGGING
            if pos in descr["renamesyms"] and descr["renamesyms"][pos]:
                # Use a list (not a lazy map) since labpairs is consumed twice
                labpairs = [(self.stoi[x[0]], self.stoi[x[1]])
                            for x in descr["renamesyms"][pos]]
                fst.relabel_pairs(opairs=labpairs, ipairs=labpairs)
            fst.rmepsilon()
            fst = wfst.determinize(fst)
            fst.minimize()            
            # #DEBUG DUMP FST
            # save_dot(fst, self.stoi, "tmp/"+pos+"_prefinal.dot")
            # fst.write("tmp/"+pos+"_prefinal.fst")

            #Convert into transducer:
            #split I/O symbols by convention here: input symbols are single characters:
            #Input syms (relabel outputs to EPS):
            syms = [k for k in self.stoi if len(k) == 1]
            labpairs = [(self.stoi[x], self.stoi[EPS]) for x in syms]
            fst.relabel_pairs(opairs=labpairs)
            #Output syms (relabel inputs to EPS):
            syms = [k for k in self.stoi if len(k) != 1]
            labpairs = [(self.stoi[x], self.stoi[EPS]) for x in syms]
            fst.relabel_pairs(ipairs=labpairs)
            # #DEBUG DUMP FST
            # save_dot(fst, self.stoi, "tmp/"+pos+"_final.dot")
            # fst.write("tmp/"+pos+"_final.fst")
            self.fsts[pos] = fst
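
The last relabeling step above converts the expanded acceptor into a transducer purely by naming convention: single-character symbols are graphemes and stay on the input tape, while longer symbols are tags and stay on the output tape. A minimal sketch of that split for any mutable pywrapfst FST; the symbol map here is a toy example, not the real ``stoi``:

EPS = "<eps>"
# Toy symbol map: single-character entries are graphemes (input side),
# longer entries are tags (output side).
stoi = {EPS: 0, "a": 1, "b": 2, "NOUN": 3}

def acceptor_to_transducer(f, stoi, eps=EPS):
    """Erase graphemes from the output tape and tags from the input tape."""
    graph_pairs = [(stoi[k], stoi[eps]) for k in stoi if len(k) == 1]
    f.relabel_pairs(opairs=graph_pairs)
    # The eps -> eps pair below is a harmless no-op, as in the original.
    tag_pairs = [(stoi[k], stoi[eps]) for k in stoi if len(k) != 1]
    f.relabel_pairs(ipairs=tag_pairs)
    return f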
Example #4
def jsgf2fst(
    grammar_paths: Union[Path, List[Path]],
    slots: Optional[Dict[str, List[str]]] = None,
    eps: str = "<eps>",
) -> Union[fst.Fst, Dict[str, fst.Fst]]:
    """Converts JSGF grammars to FSTs.
    Returns a dictionary mapping grammar names to FSTs, or a single FST
    when a single grammar path is given."""
    slots = slots or {}

    is_list = isinstance(grammar_paths, list)
    if not is_list:
        grammar_paths = [grammar_paths]

    # grammar name -> fst
    grammar_fsts: Dict[str, fst.Fst] = {}

    # rule name -> fst
    rule_fsts: Dict[str, fst.Fst] = {}

    # rule name -> fst
    replaced_fsts: Dict[str, fst.Fst] = {}

    # grammar name -> listener
    listeners: Dict[str, FSTListener] = {}

    # Share symbol tables between all FSTs
    input_symbols = fst.SymbolTable()
    output_symbols = fst.SymbolTable()
    input_symbols.add_symbol(eps)
    output_symbols.add_symbol(eps)

    # Set of all input symbols that are __begin__ or __end__
    tag_input_symbols: Set[int] = set()

    # Set of all slot names that were used
    slots_to_replace: Set[str] = set()

    # Process each grammar
    for grammar_path in grammar_paths:
        logger.debug(f"Processing {grammar_path}")

        with open(grammar_path, "r") as grammar_file:
            # Tokenize
            input_stream = antlr4.InputStream(grammar_file.read())
            lexer = JsgfLexer(input_stream)
            tokens = antlr4.CommonTokenStream(lexer)

            # Parse
            parser = JsgfParser(tokens)

            # Transform to FST
            context = parser.r()
            walker = antlr4.ParseTreeWalker()

            # Create FST and symbol tables
            grammar_fst = fst.Fst()

            start = grammar_fst.add_state()
            grammar_fst.set_start(start)

            listener = FSTListener(grammar_fst, input_symbols, output_symbols,
                                   start)
            walker.walk(listener, context)

            # Merge with set of all tag input symbols
            tag_input_symbols.update(listener.tag_input_symbols)

            # Merge with set of all used slots
            slots_to_replace.update(listener.slot_references)

            # Save FSTs for all rules
            for rule_name, rule_fst in listener.fsts.items():
                rule_fsts[rule_name] = rule_fst
                listeners[rule_name] = listener

                # Record FSTs that have no rule references
                if len(listener.rule_references[rule_name]) == 0:
                    replaced_fsts[rule_name] = rule_fst

            # Save for later
            grammar_fsts[listener.grammar_name] = grammar_fst

    # -------------------------------------------------------------------------

    # Recursively resolve rule references, memoizing finished rules in
    # replaced_fsts so each shared rule is expanded only once
    def replace_fsts(rule_name):
        rule_fst = replaced_fsts.get(rule_name)
        if rule_fst is not None:
            return rule_fst

        listener = listeners[rule_name]

        rule_fst = rule_fsts[rule_name]
        for ref_name in listener.rule_references[rule_name]:
            ref_fst = replace_fsts(ref_name)

            # Replace rule in grammar FST
            replace_symbol = "__replace__" + ref_name
            replace_idx = input_symbols.find(replace_symbol)
            if replace_idx >= 0:
                logger.debug(f"Replacing rule {ref_name} in {rule_name}")
                rule_fst = fst.replace([(-1, rule_fst),
                                        (replace_idx, ref_fst)],
                                       epsilon_on_replace=True)

        replaced_fsts[rule_name] = rule_fst
        return rule_fst

    # Do rule replacements
    for grammar_name in list(grammar_fsts.keys()):
        main_rule_name = grammar_name + "." + grammar_name
        grammar_fsts[grammar_name] = replace_fsts(main_rule_name)

    # -------------------------------------------------------------------------

    # Do slot replacements
    slot_fsts: Dict[str, fst.Fst] = {}
    for grammar_name, grammar_fst in grammar_fsts.items():
        main_rule_name = grammar_name + "." + grammar_name
        listener = listeners[main_rule_name]

        for slot_name in slots_to_replace:
            if slot_name not in slot_fsts:
                # Create FST for slot values
                logger.debug(f"Creating FST for slot {slot_name}")

                slot_fst = fst.Fst()
                start = slot_fst.add_state()
                slot_fst.set_start(start)

                # Create a single slot grammar
                with io.StringIO() as text_file:
                    print("#JSGF v1.0;", file=text_file)
                    print(f"grammar {slot_name};", file=text_file)
                    print("", file=text_file)

                    # Each slot value becomes one parenthesized alternative
                    choices = " | ".join([
                        "(" + v + ")"
                        for v in slots.get(slot_name, [])
                    ])

                    # All slot values
                    print(f"public <{slot_name}> = ({choices});",
                          file=text_file)
                    text_file.seek(0)

                    # Tokenize
                    input_stream = antlr4.InputStream(text_file.getvalue())
                    lexer = JsgfLexer(input_stream)
                    tokens = antlr4.CommonTokenStream(lexer)

                    # Parse
                    parser = JsgfParser(tokens)

                    # Transform to FST
                    context = parser.r()
                    walker = antlr4.ParseTreeWalker()

                    # Fill in slot_fst
                    slot_listener = FSTListener(slot_fst, input_symbols,
                                                output_symbols, start)
                    walker.walk(slot_listener, context)

                # Cache for other grammars
                slot_fsts[slot_name] = slot_fst

            # -----------------------------------------------------------------

            # Replace slot in grammar FST
            replace_symbol = "__replace__$" + slot_name
            replace_idx = input_symbols.find(replace_symbol)
            if replace_idx >= 0:
                logger.debug(f"Replacing slot {slot_name} in {main_rule_name}")
                grammar_fst = fst.replace(
                    [(-1, grammar_fst), (replace_idx, slot_fst)],
                    epsilon_on_replace=True,
                )

                grammar_fsts[grammar_name] = grammar_fst

    # -------------------------------------------------------------------------

    # Remove tag start symbols.
    # TODO: Only do this for FSTs that actually have tags.
    for grammar_name, grammar_fst in grammar_fsts.items():
        main_rule_name = grammar_name + "." + grammar_name
        listener = listeners[main_rule_name]

        # Create a copy of the grammar FST with __begin__ and __end__ input
        # labels replaced by <eps>. For some reason, fstreplace fails when this
        # is done beforehand, whining about cyclic dependencies.
        in_eps = input_symbols.find(eps)
        old_fst = grammar_fst
        grammar_fst = fst.Fst()
        state_map: Dict[int, int] = {}
        weight_zero = fst.Weight.Zero(old_fst.weight_type())

        # Copy states with final status
        for old_state in old_fst.states():
            new_state = grammar_fst.add_state()
            state_map[old_state] = new_state
            if old_fst.final(old_state) != weight_zero:
                grammar_fst.set_final(new_state)

        # Start state
        grammar_fst.set_start(state_map[old_fst.start()])

        # Copy arcs
        for old_state, new_state in state_map.items():
            for old_arc in old_fst.arcs(old_state):
                # Replace tag input labels with <eps>
                input_idx = (in_eps if old_arc.ilabel in tag_input_symbols else
                             old_arc.ilabel)

                grammar_fst.add_arc(
                    new_state,
                    fst.Arc(
                        input_idx,
                        old_arc.olabel,
                        fst.Weight.One(grammar_fst.weight_type()),
                        state_map[old_arc.nextstate],
                    ),
                )

        grammar_fst.set_input_symbols(input_symbols)
        grammar_fst.set_output_symbols(output_symbols)

        # Replace FST
        grammar_fsts[grammar_name] = grammar_fst

    # -------------------------------------------------------------------------

    if not is_list:
        # Single input, single output
        return next(iter(grammar_fsts.values()))

    return grammar_fsts
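
A hedged usage sketch for ``jsgf2fst``; the grammar paths and slot values are hypothetical:

from pathlib import Path

# A single (non-list) path returns the FST itself.
light_fst = jsgf2fst(Path("grammars/Light.gram"))

# A list of grammars sharing a slot definition returns name -> FST.
grammar_fsts = jsgf2fst(
    [Path("grammars/Light.gram"), Path("grammars/Time.gram")],
    slots={"color": ["red", "green", "blue"]},
)
for name, grammar_fst in grammar_fsts.items():
    grammar_fst.write(f"{name}.fst")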