def test_la_viterbi_parsing_2(self):
    """
    Parse the input ``a a a`` with the grammar from ``build_paper_grammar``
    under a hand-built latent annotation and check which spans are covered
    by the latent Viterbi derivation.
    """
    grammar = self.build_paper_grammar()
    sentence = ["a"] * 3

    nontMap = Enumerator()
    gi = PyGrammarInfo(grammar, nontMap)
    sm = PyStorageManager()

    # Debug output: the integer ids assigned to the nonterminals S and B.
    for nonterminal in ("S", "B"):
        print(nontMap.object_index(nonterminal))

    # NOTE(review): the three list arguments presumably are the number of
    # latent splits per nonterminal, the root weights and the per-rule
    # weights - confirm against build_PyLatentAnnotation.
    splits = [2, 1]
    root_weights = [1.0]
    rule_weights = [
        [0.25, 1.0],
        [1.0, 0.0],
        [0.0, 0.5, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0],
    ]
    la = build_PyLatentAnnotation(splits, root_weights, rule_weights, gi, sm)
    self.assertTrue(la.is_proper())

    parser = DiscodopKbestParser(
        grammar,
        la=la,
        nontMap=nontMap,
        grammarInfo=gi,
        latent_viterbi_mode=True,
    )
    parser.set_input(sentence)
    parser.parse()
    self.assertTrue(parser.recognized())

    der = parser.latent_viterbi_derivation(True)
    print(der)

    # Collect the first spanned range of every node in the derivation.
    ranges = set()
    for idx in der.ids():
        ranges.add(der.spanned_ranges(idx)[0])

    expected = {(0, 3), (0, 2), (0, 1), (1, 2), (2, 3)}
    self.assertSetEqual(expected, ranges)
def test_la_viterbi_parsing_3(self):
    """
    Build a small five-rule LCFRS over the nonterminals S, A and B, parse
    ``a a a`` with an initial latent annotation, and print the latent
    Viterbi derivation next to the k-best derivation trees.
    """
    grammar = LCFRS("S")

    def add_rule(nont, arg, rhs, weight):
        # Add one rule with a single-component left-hand side.
        lhs = LCFRS_lhs(nont)
        lhs.add_arg(arg)
        grammar.add_rule(lhs, rhs, weight)

    # rule 0
    add_rule("B", ["a"], [], 0.25)
    # rule 1
    add_rule("A", ["a"], [], 0.5)
    # rule 2
    add_rule("S", [LCFRS_var(0, 0)], ["B"], 1.0)
    # rule 3
    add_rule("A", [LCFRS_var(0, 0), LCFRS_var(1, 0)], ["A", "B"], 0.5)
    # rule 4
    add_rule("B", [LCFRS_var(0, 0), LCFRS_var(1, 0)], ["A", "B"], 0.75)

    grammar.make_proper()

    sentence = ["a"] * 3

    nontMap = Enumerator()
    gi = PyGrammarInfo(grammar, nontMap)
    sm = PyStorageManager()

    # Debug output: the integer ids assigned to the nonterminals S and B.
    for nonterminal in ("S", "B"):
        print(nontMap.object_index(nonterminal))

    la = build_PyLatentAnnotation_initial(grammar, gi, sm)

    parser = DiscodopKbestParser(
        grammar,
        la=la,
        nontMap=nontMap,
        grammarInfo=gi,
        latent_viterbi_mode=True,
    )
    parser.set_input(sentence)
    parser.parse()
    self.assertTrue(parser.recognized())

    der = parser.latent_viterbi_derivation(True)
    print(der)

    # Remember the first k-best derivation while printing all of them.
    der2 = None
    for weight, candidate in parser.k_best_derivation_trees():
        if der2 is None:
            der2 = candidate
        print(weight, candidate)

    print(der2)
def linearize(grammar, nonterminal_labeling, terminal_labeling, file,
              delimiter='::', nonterminal_encoder=None):
    """
    Write a line-oriented textual dump of ``grammar`` to ``file``.

    The output consists of the two labeling descriptions, then for each rule
    an ``RTG`` line, a ``WEIGHT`` line, one ``sDCP`` line per sDCP equation
    and one ``LCFRS`` line per left-hand-side argument, followed by the
    terminal index, the nonterminal index with statistics, and the id of the
    initial nonterminal.

    :type grammar: LCFRS
    :param nonterminal_labeling: printed verbatim in the header
    :param terminal_labeling: printed verbatim in the header
    :param file: file handle to write to
    :type delimiter: str
    :param delimiter: string used to join terminal symbol with edge label symbol
    :type nonterminal_encoder: Enumerator
    :param nonterminal_encoder: mapping that assigns unique non-negative integer
        to each nonterminal; a fresh Enumerator is created if None
    :return: pair (nonterminals, terminals) of the enumerators that were used
    """
    print("Nonterminal Labeling: ", nonterminal_labeling, file=file)
    print("Terminal Labeling: ", terminal_labeling, file=file)
    print(file=file)

    terminals = Enumerator(first_index=1)
    nonterminals = Enumerator() if nonterminal_encoder is None else nonterminal_encoder

    num_inherited_args = {}
    num_synthesized_args = {}

    def equation_order(equation):
        # Order sDCP equations by (member, argument) of their left-hand side.
        target = equation.lhs()
        return target.mem(), target.arg()

    for rule in grammar.rules():
        rid = 'r%i' % (rule.get_idx() + 1)

        # The left-hand side is enumerated before the right-hand side so that
        # fresh ids are handed out in the same order as they are printed.
        lhs_id = nonterminals.object_index(rule.lhs().nont())
        print(rid, 'RTG ', lhs_id, '->', file=file, end=" ")
        print([nonterminals.object_index(nont) for nont in rule.rhs()], ';', file=file)
        print(rid, 'WEIGHT', rule.weight(), ';', file=file)

        sync_index = {}
        inh_args = defaultdict(int)
        lhs_var_counter = CountLHSVars()
        synthesized_attributes = 0
        dcp_ordered = sorted(rule.dcp(), key=equation_order)

        # First pass: count attributes per member (-1 denotes the rule's lhs).
        for dcp in dcp_ordered:
            member = dcp.lhs().mem()
            if member == -1:
                synthesized_attributes += 1
            else:
                inh_args[member] += 1
            lhs_var_counter.evaluate_list(dcp.rhs())

        inherited_total = lhs_var_counter.get_number()
        num_inherited_args[lhs_id] = inherited_total
        inh_args[-1] = inherited_total
        num_synthesized_args[lhs_id] = synthesized_attributes

        # Second pass: emit one sDCP line per equation.
        for dcp in dcp_ordered:
            printer = DcpPrinter(terminals.object_index, rule, sync_index, inh_args,
                                 delimiter=delimiter)
            printer.evaluate_list(dcp.rhs())
            var = dcp.lhs()
            if var.mem() == -1:
                # Synthesized attributes of the lhs are numbered after its
                # inherited ones, hence the offset.
                position = (0, var.arg() + 1 - inh_args[-1])
            else:
                position = (var.mem() + 1, var.arg() + 1)
            var_string = 's<%i,%i>' % position
            print('%s sDCP %s == %s ;' % (rid, var_string, printer.string), file=file)

        # Running position of terminals across all lhs arguments of the rule;
        # used as the lookup key into sync_index.
        terminal_position = 0
        for j, arg in enumerate(rule.lhs().args()):
            print(rid, 'LCFRS s<0,%i> == [' % (j + 1), end=' ', file=file)
            for offset, a in enumerate(arg):
                if offset > 0:
                    print(",", end=' ', file=file)
                if isinstance(a, LCFRS_var):
                    print("x<%i,%i>" % (a.mem + 1, a.arg + 1), end=' ', file=file)
                    continue
                symbol = str(terminals.object_index(a))
                if terminal_position in sync_index:
                    symbol += '^{%i}' % sync_index[terminal_position]
                print(symbol, end=' ', file=file)
                terminal_position += 1
            print('] ;', file=file)
        print(file=file)

    print("Terminals: ", file=file)
    terminals.print_index(to_file=file)
    print(file=file)

    print("Nonterminal ID, nonterminal name, fanout, #inh, #synth: ", file=file)
    stats = print_index_and_stats(nonterminals, grammar, num_inherited_args,
                                  num_synthesized_args, file=file)
    max_fanout, max_inh, max_syn, max_args, fanouts, inherits, synths, args = stats
    print(file=file)

    maxima = (
        ("max fanout:", max_fanout),
        ("max inh:", max_inh),
        ("max synth:", max_syn),
        ("max args:", max_args),
    )
    for caption, value in maxima:
        print(caption, value, file=file)
    print(file=file)

    histograms = (
        ('fanout', fanouts, max_fanout),
        ('inh', inherits, max_inh),
        ('syn', synths, max_syn),
        ('args', args, max_args),
    )
    for label, histogram, maximum in histograms:
        for value in range(maximum + 1):
            print('# the number of nonterminals with %s = %i is %i'
                  % (label, value, histogram[value]), file=file)
        print(file=file)
    print(file=file)

    print("Initial nonterminal: ", nonterminals.object_index(grammar.start()), file=file)
    print(file=file)

    return nonterminals, terminals