예제 #1
0
    def test_la_viterbi_parsing_2(self):
        """Latent Viterbi parsing of ``a a a`` with a hand-built latent annotation.

        The latent annotation is built over the grammar from
        ``build_paper_grammar`` and must be proper. After parsing in latent
        Viterbi mode, the first spanned range of every node in the Viterbi
        derivation must form exactly the expected set of spans.
        """
        paper_grammar = self.build_paper_grammar()
        word = ["a"] * 3

        nonterminal_map = Enumerator()
        grammar_info = PyGrammarInfo(paper_grammar, nonterminal_map)
        storage = PyStorageManager()
        # Debug output: the indices the enumerator reports for "S" and "B".
        for label in ("S", "B"):
            print(nonterminal_map.object_index(label))

        # NOTE(review): presumably (splits per nonterminal, root weights,
        # per-rule weight tables) -- confirm against build_PyLatentAnnotation.
        first_arg = [2, 1]
        second_arg = [1.0]
        third_arg = [[0.25, 1.0],
                     [1.0, 0.0],
                     [0.0, 0.5, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0]]
        annotation = build_PyLatentAnnotation(first_arg, second_arg, third_arg,
                                              grammar_info, storage)
        self.assertTrue(annotation.is_proper())

        parser = DiscodopKbestParser(paper_grammar,
                                     la=annotation,
                                     nontMap=nonterminal_map,
                                     grammarInfo=grammar_info,
                                     latent_viterbi_mode=True)
        parser.set_input(word)
        parser.parse()
        self.assertTrue(parser.recognized())

        derivation = parser.latent_viterbi_derivation(True)
        print(derivation)
        expected_spans = {(0, 3), (0, 2), (0, 1), (1, 2), (2, 3)}
        observed_spans = set()
        for node in derivation.ids():
            observed_spans.add(derivation.spanned_ranges(node)[0])
        self.assertSetEqual(expected_spans, observed_spans)
예제 #2
0
    def test_la_viterbi_parsing_3(self):
        """Latent Viterbi and k-best parsing with an initial latent annotation.

        The test builds a five-rule LCFRS with start symbol "S" and makes it
        proper. It then parses ``a a a`` in latent Viterbi mode using the
        annotation returned by ``build_PyLatentAnnotation_initial``. It prints
        the Viterbi derivation, every (weight, tree) pair yielded by
        ``k_best_derivation_trees``, and finally the first tree yielded. Only
        recognition is asserted.
        """
        grammar = LCFRS("S")

        # Each entry is (lhs nonterminal, the single lhs argument,
        # rhs nonterminals, weight). Rules are added in list order, so the
        # comment index matches the rule index.
        rule_table = [
            ("B", ["a"], [], 0.25),                                        # rule 0
            ("A", ["a"], [], 0.5),                                         # rule 1
            ("S", [LCFRS_var(0, 0)], ["B"], 1.0),                          # rule 2
            ("A", [LCFRS_var(0, 0), LCFRS_var(1, 0)], ["A", "B"], 0.5),    # rule 3
            ("B", [LCFRS_var(0, 0), LCFRS_var(1, 0)], ["A", "B"], 0.75),   # rule 4
        ]
        for nont, argument, rhs, weight in rule_table:
            lhs = LCFRS_lhs(nont)
            lhs.add_arg(argument)
            grammar.add_rule(lhs, rhs, weight)

        grammar.make_proper()

        word = ["a"] * 3

        nonterminal_map = Enumerator()
        grammar_info = PyGrammarInfo(grammar, nonterminal_map)
        storage = PyStorageManager()
        # Debug output: the indices the enumerator reports for "S" and "B".
        for label in ("S", "B"):
            print(nonterminal_map.object_index(label))

        annotation = build_PyLatentAnnotation_initial(grammar, grammar_info, storage)
        parser = DiscodopKbestParser(grammar,
                                     la=annotation,
                                     nontMap=nonterminal_map,
                                     grammarInfo=grammar_info,
                                     latent_viterbi_mode=True)
        parser.set_input(word)
        parser.parse()
        self.assertTrue(parser.recognized())
        print(parser.latent_viterbi_derivation(True))

        # Remember the first tree the k-best enumeration yields, printing all.
        first_tree = None
        for weight, tree in parser.k_best_derivation_trees():
            if first_tree is None:
                first_tree = tree
            print(weight, tree)

        print(first_tree)
예제 #3
0
def linearize(grammar,
              nonterminal_labeling,
              terminal_labeling,
              file,
              delimiter='::',
              nonterminal_encoder=None):
    """
    Write a textual linearization of ``grammar`` to ``file``.

    The output has five parts, in this order:

    * a header with the two labelings;
    * for each rule, its RTG, WEIGHT, sDCP and LCFRS lines;
    * the terminal index;
    * per-nonterminal statistics with their histograms;
    * the index of the initial nonterminal.

    :type grammar: LCFRS
    :param nonterminal_labeling: printed in the header line "Nonterminal Labeling: "
    :param terminal_labeling: printed in the header line "Terminal Labeling: "
    :param file: file handle to write to
    :type delimiter: str
    :param delimiter: string used to join terminal symbol with edge label symbol
    :type nonterminal_encoder: Enumerator
    :param nonterminal_encoder: mapping that assigns unique non-negative integer to each nonterminal;
        a fresh ``Enumerator()`` is used when None (the given encoder is extended in place)
    :return: tuple ``(nonterminals, terminals)`` of the Enumerators populated while writing
    """
    print("Nonterminal Labeling: ", nonterminal_labeling, file=file)
    print("Terminal Labeling: ", terminal_labeling, file=file)
    print(file=file)

    terminals = Enumerator(first_index=1)
    nonterminals = Enumerator() if nonterminal_encoder is None else nonterminal_encoder
    # Both dicts are keyed by nonterminal index. Values are overwritten by each
    # rule with that lhs, so the last rule with a given lhs wins.
    num_inherited_args = {}
    num_synthesized_args = {}

    for rule in grammar.rules():
        _linearize_rule(rule, nonterminals, terminals, num_inherited_args,
                        num_synthesized_args, delimiter, file)

    print("Terminals: ", file=file)
    terminals.print_index(to_file=file)
    print(file=file)

    _print_nonterminal_statistics(nonterminals, grammar, num_inherited_args,
                                  num_synthesized_args, file)

    print("Initial nonterminal: ",
          nonterminals.object_index(grammar.start()),
          file=file)
    print(file=file)
    return nonterminals, terminals


def _linearize_rule(rule, nonterminals, terminals, num_inherited_args,
                    num_synthesized_args, delimiter, file):
    """Write the RTG, WEIGHT, sDCP and LCFRS lines of one rule and record the
    inherited/synthesized argument counts of its lhs nonterminal."""
    rid = 'r%i' % (rule.get_idx() + 1)
    # Look up the lhs index once. It is reused for the RTG line and both stats dicts.
    lhs_index = nonterminals.object_index(rule.lhs().nont())
    print(rid, 'RTG   ', lhs_index, '->', file=file, end=" ")
    print([nonterminals.object_index(nont) for nont in rule.rhs()], ';', file=file)

    print(rid, 'WEIGHT', rule.weight(), ';', file=file)

    # sync_index is shared by all DcpPrinters of this rule. It is read back
    # afterwards when the LCFRS terminals are annotated.
    sync_index = {}
    # Count of sDCP lhs attributes per rhs member. Key -1 is set below to the
    # lhs-variable count reported by CountLHSVars.
    inh_args = defaultdict(int)
    lhs_var_counter = CountLHSVars()
    synthesized_attributes = 0

    dcp_ordered = sorted(rule.dcp(),
                         key=lambda x: (x.lhs().mem(), x.lhs().arg()))

    for dcp in dcp_ordered:
        if dcp.lhs().mem() != -1:
            inh_args[dcp.lhs().mem()] += 1
        else:
            synthesized_attributes += 1
        lhs_var_counter.evaluate_list(dcp.rhs())
    num_inherited_args[lhs_index] = inh_args[-1] = lhs_var_counter.get_number()
    num_synthesized_args[lhs_index] = synthesized_attributes

    for dcp in dcp_ordered:
        printer = DcpPrinter(terminals.object_index,
                             rule,
                             sync_index,
                             inh_args,
                             delimiter=delimiter)
        printer.evaluate_list(dcp.rhs())
        var = dcp.lhs()
        if var.mem() == -1:
            # lhs attributes are numbered after the lhs's inherited arguments.
            var_string = 's<0,%i>' % (var.arg() + 1 - inh_args[-1])
        else:
            var_string = 's<%i,%i>' % (var.mem() + 1, var.arg() + 1)
        print('%s sDCP   %s == %s ;' % (rid, var_string, printer.string),
              file=file)

    _print_lcfrs_lines(rid, rule, terminals, sync_index, file)
    print(file=file)


def _print_lcfrs_lines(rid, rule, terminals, sync_index, file):
    """Write one LCFRS line per lhs argument.

    Each terminal gets a position, counted left to right across all
    arguments. A terminal whose position appears in ``sync_index`` has its
    sync index appended as ``^{i}``.
    """
    terminal_position = 0
    for j, arg in enumerate(rule.lhs().args()):
        print(rid, 'LCFRS  s<0,%i> == [' % (j + 1), end=' ', file=file)
        for k, a in enumerate(arg):
            if k > 0:
                print(",", end=' ', file=file)
            if isinstance(a, LCFRS_var):
                print("x<%i,%i>" % (a.mem + 1, a.arg + 1), end=' ', file=file)
            else:
                symbol = str(terminals.object_index(a))
                if terminal_position in sync_index:
                    symbol += '^{%i}' % sync_index[terminal_position]
                print(symbol, end=' ', file=file)
                terminal_position += 1
        print('] ;', file=file)


def _print_nonterminal_statistics(nonterminals, grammar, num_inherited_args,
                                  num_synthesized_args, file):
    """Write the nonterminal index with its stats, the maxima, and one
    histogram section each for fanout, inh, syn and args."""
    print("Nonterminal ID, nonterminal name, fanout, #inh, #synth: ",
          file=file)
    max_fanout, max_inh, max_syn, max_args, fanouts, inherits, synths, args \
        = print_index_and_stats(nonterminals, grammar, num_inherited_args, num_synthesized_args, file=file)
    print(file=file)
    print("max fanout:", max_fanout, file=file)
    print("max inh:", max_inh, file=file)
    print("max synth:", max_syn, file=file)
    print("max args:", max_args, file=file)
    print(file=file)
    for stat_name, histogram, maximum in [('fanout', fanouts, max_fanout),
                                          ('inh', inherits, max_inh),
                                          ('syn', synths, max_syn),
                                          ('args', args, max_args)]:
        for i in range(maximum + 1):
            print('# the number of nonterminals with %s = %i is %i' %
                  (stat_name, i, histogram[i]),
                  file=file)
        print(file=file)
    print(file=file)