Пример #1
0
    def test_cfg_approximation_conversion(self):
        """Exercise a two-stage (coarse CFG approximation, then fine) parse.

        The grammar from ``self.build_nm_grammar()`` is converted twice: once
        with ``transform_grammar_cfg_approx`` (coarse stage) and once with
        ``transform_grammar`` (fine stage).  The coarse chart is pruned into a
        whitelist which then restricts the fine parse.

        The test contains no assertions; it prints the intermediate results
        and only fails if one of the stages raises.
        """
        grammar = self.build_nm_grammar()

        # Coarse stage: grammar built from the CFG approximation rules.
        disco_grammar_rules = list(transform_grammar_cfg_approx(grammar))
        print(disco_grammar_rules)
        disco_grammar = Grammar(disco_grammar_rules, start=grammar.start())
        print(disco_grammar)

        # Input of the form a^n b^m c^n d^m.
        n = 2
        m = 3
        inp = ["a"] * n + ["b"] * m + ["c"] * n + ["d"] * m

        chart, msg = parse(inp, disco_grammar, beam_beta=exp(-4))
        chart.filter()
        print(chart)
        print(msg)

        # Fine stage: grammar built from the regular transformation.
        fine_grammar_rules = list(transform_grammar(grammar))
        fine = Grammar(fine_grammar_rules, start=grammar.start())

        # The pattern matches a literal '*' followed by digits at the end of a
        # label.  Raw string: '\*' is not a valid escape in a normal literal.
        # NOTE(review): the positional arguments (None, True, True) are passed
        # through unchanged; confirm their meaning against Grammar.getmapping.
        fine.getmapping(disco_grammar, re.compile(r'\*[0-9]+$'), None, True, True)

        whitelist, msg = prunechart(chart, fine, k=10000, splitprune=True, markorigin=True, finecfg=False)
        print(msg)
        print(whitelist)

        chart2, msg = parse(inp, fine, whitelist=whitelist, splitprune=True, markorigin=True)
        print(msg)
        print(chart2)
Пример #2
0
 def __init__(self,
              grammar,
              input=None,
              save_preprocessing=None,
              load_preprocessing=None,
              k=50,
              heuristics=None,
              la=None,
              variational=False,
              sum_op=False,
              nontMap=None,
              cfg_ctf=False,
              beam_beta=0.0,
              beam_delta=50,
              pruning_k=10000,
              grammarInfo=None,
              projection_mode=False,
              latent_viterbi_mode=False,
              secondaries=None
              ):
     """Store the parser configuration and build the ``Grammar`` object(s).

     :param grammar: source grammar; converted with ``transform_grammar``
         and queried for its start symbol via ``grammar.start()``.
     :param input: stored as ``self.input``.
     :param save_preprocessing: accepted but not used by this constructor.
     :param load_preprocessing: accepted but not used by this constructor.
     :param k: stored as ``self.k``.
     :param heuristics: accepted but not used by this constructor.
     :param la: latent annotation, either a ``PyLatentAnnotation`` or an
         iterable of objects providing ``check_rule_split_alignment()``;
         only inspected when ``grammarInfo`` is given.
     :param variational: stored as ``self.variational``.
     :param sum_op: if true ``self.op`` is ``add``, otherwise ``prod``.
     :param nontMap: stored as ``self.nontMap``.
     :param cfg_ctf: if true, additionally build ``self.disco_cfg_grammar``
         from ``transform_grammar_cfg_approx`` and map the main grammar
         onto it.  Also stored as ``self.cfg_approx``.
     :param beam_beta: beam pruning factor, between 0.0 and 1.0;
         0.0 to disable.
     :param beam_delta: maximum span length to which ``beam_beta`` is
         applied.
     :param pruning_k: stored as ``self.pruning_k``.
     :param grammarInfo: stored as ``self.grammarInfo``; when not ``None``
         the rule/split alignment of ``la`` is asserted.
     :param projection_mode: stored as ``self.projection_mode``.
     :param latent_viterbi_mode: stored as ``self.latent_viterbi_mode``.
     :param secondaries: stored as ``self.secondaries``; ``None`` becomes
         a fresh empty list.
     :raises AssertionError: if an alignment check on ``la`` fails.
     """
     rule_list = list(transform_grammar(grammar))
     self.disco_grammar = Grammar(rule_list, start=grammar.start())
     self.chart = None
     self.input = input
     self.grammar = grammar
     self.k = k
     self.beam_beta = beam_beta  # beam pruning factor, between 0.0 and 1.0; 0.0 to disable.
     self.beam_delta = beam_delta  # maximum span length to which beam_beta is applied
     self.counter = 0
     self.la = la
     self.nontMap = nontMap
     self.variational = variational
     self.op = add if sum_op else prod
     self.debug = False
     self.log_mode = True
     self.estimates = None
     self.cfg_approx = cfg_ctf
     self.pruning_k = pruning_k
     self.grammarInfo = grammarInfo
     self.projection_mode = projection_mode
     self.latent_viterbi_mode = latent_viterbi_mode
     self.secondaries = [] if secondaries is None else secondaries
     self.secondary_mode = "DEFAULT"
     self.k_best_reranker = None

     # Sanity-check the latent annotation(s) once grammar info is available.
     if grammarInfo is not None:
         if isinstance(self.la, PyLatentAnnotation):
             assert self.la.check_rule_split_alignment()
         else:
             for l in self.la:
                 assert l.check_rule_split_alignment()

     # The coarse grammar attribute is only created when cfg_ctf is enabled.
     if cfg_ctf:
         cfg_rule_list = list(transform_grammar_cfg_approx(grammar))
         self.disco_cfg_grammar = Grammar(cfg_rule_list, start=grammar.start())
         # The pattern matches a literal '*' followed by digits at the end of
         # a label.  Raw string: '\*' is an invalid escape in a plain literal.
         self.disco_grammar.getmapping(self.disco_cfg_grammar, re.compile(r'\*[0-9]+$'), None, True, True)
Пример #3
0
    def test_individual_parsing_stages(self):
        """Step through the parsing pipeline and check edge weight projections.

        Stages exercised, in order: grammar conversion, parsing ``a a a``
        with estimates, chart inspection (printed), conversion of the chart
        into a hypergraph, serialization to a temporary file, and edge weight
        projection with an initial latent annotation.  The projected weight
        vectors for ``variational=True`` and ``variational=False`` are
        compared against fixed expected values.
        """
        # Local import: only needed to close the descriptor from mkstemp.
        import os

        grammar = self.build_grammar()

        for r in transform_grammar(grammar):
            pprint(r)

        rule_list = list(transform_grammar(grammar))
        pprint(rule_list)
        disco_grammar = Grammar(rule_list, start=grammar.start())
        print(disco_grammar)

        inp = ["a"] * 3
        estimates = ('SXlrgaps', getestimates(disco_grammar, 40, grammar.start()))
        print(type(estimates))
        chart, msg = parse(inp, disco_grammar, estimates=estimates)
        print(chart)
        print(msg)
        chart.filter()
        print("filtered chart")
        print(disco_grammar.nonterminals)
        print(type(disco_grammar.nonterminals))

        print(chart)
        # print(help(chart))

        def item_str(item):
            """Format a chart item as ``<nonterminal>[<item id>]``."""
            return disco_grammar.nonterminalstr(chart.label(item)) + "[" + str(item) + "]"

        root = chart.root()
        print("root", root, type(root))
        print(chart.indices(root))
        print(chart.itemstr(root))
        print(chart.stats())
        print("root label", chart.label(root))
        print(root, chart.itemid1(chart.label(root), chart.indices(root)))

        # Dump every item (ids start at 1) together with its incoming edges.
        for i in range(1, chart.numitems() + 1):
            print(i, chart.label(i), chart.indices(i), chart.numedges(i))
            for edge_num in range(chart.numedges(i)):
                edge = chart.getEdgeForItem(i, edge_num)
                head = item_str(i)
                if isinstance(edge, tuple):
                    # Child ids equal to 0 are skipped
                    # (presumably 0 marks an absent child -- TODO confirm).
                    children = ' '.join([item_str(j) for j in [edge[1], edge[2]] if j != 0])
                    print("\t", head, "->", children)
                else:
                    # A non-tuple edge is used as an index into the input.
                    print("\t", head, "->", inp[edge])
        print(chart.getEdgeForItem(root, 0))
        # print(lazykbest(chart, 5))

        manager = PyDerivationManager(grammar)
        manager.convert_chart_to_hypergraph(chart, disco_grammar, debug=True)

        # mkstemp creates the file atomically; the deprecated mktemp only
        # returned a name and was open to a race.  The file is intentionally
        # left in place (its path is printed) as in the original test.
        fd, hypergraph_path = tempfile.mkstemp()
        os.close(fd)
        print(hypergraph_path)
        manager.serialize(bytes(hypergraph_path, encoding="utf-8"))

        gi = PyGrammarInfo(grammar, manager.get_nonterminal_map())
        sm = PyStorageManager()
        la = build_PyLatentAnnotation_initial(grammar, gi, sm)

        vec = py_edge_weight_projection(la, manager, variational=True, debug=True, log_mode=False)
        print(vec)
        self.assertEqual([1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.25, 0.25, 0.25, 0.25, 1.0], vec)

        vec = py_edge_weight_projection(la, manager, variational=False, debug=True, log_mode=False)
        print(vec)
        self.assertEqual([1.0, 1.0, 1.0, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 1.0], vec)

        der = manager.viterbi_derivation(0, vec, grammar)
        print(der)

        # print(disco_grammar.rulenos)
        # print(disco_grammar.numrules)
        # print(disco_grammar.lexicalbylhs)
        # print(disco_grammar.lexicalbyword)
        # print(disco_grammar.lexicalbynum)
        # print(disco_grammar.origrules, type(disco_grammar.origrules))
        # print(disco_grammar.numbinary)
        # print(disco_grammar.numunary)
        # print(disco_grammar.toid)
        # print(disco_grammar.tolabel)
        # print(disco_grammar.bitpar)
        # striplabelre = re.compile(r'-\d+$')
        # msg = disco_grammar.getmapping(None, None)
        # disco_grammar.getrulemapping(disco_grammar, striplabelre)
        # mapping = disco_grammar.rulemapping
        # print(mapping)
        # for idx, group in enumerate(mapping):
        #     print("Index", idx)
        #     for elem in group:
        #         print(grammar.rule_index(elem))

        # for _, item in zip(range(20), chart.parseforest):
        #     edge = chart.parseforest[item]
        #     print(item, item.binrepr(), item.__repr__(), item.lexidx())
        #     print(type(edge))

        # Repeat the log-mode projection several times and print each result.
        for _ in range(5):
            vec2 = py_edge_weight_projection(la, manager, debug=True, log_mode=True)
            print(vec2)