Example #1
def induce_grammar_from(dsg, rec_par, decomp, labeling=(lambda x, y: str(x)),
                        terminal_labeling=(lambda x: x),  # identity; the builtin `id` would yield object ids
                        terminal_labeling_lcfrs=None, start="START",
                        normalize=True, enforce_outputs=True):
    if terminal_labeling_lcfrs is None:
        terminal_labeling_lcfrs = terminal_labeling
    lcfrs = LCFRS(start=start)
    ordered_nodes = dsg.dog.ordered_nodes()
    rhs_nont = induce_grammar_rec(lcfrs, dsg, rec_par, decomp, labeling, terminal_labeling,
                                  terminal_labeling_lcfrs, normalize, enforce_outputs,
                                  ordered_nodes=ordered_nodes)
    rhs_top = dsg.dog.top(decomp[0])

    # construct a chain rule from START to initial nonterminal of decomposition
    # LCFRS part
    lcfrs_lhs = LCFRS_lhs(start)
    lcfrs_lhs.add_arg([LCFRS_var(0, 0)])

    # DOG part
    dog = DirectedOrderedGraph()
    assert len(dsg.dog.inputs) == 0
    assert not enforce_outputs or len(dsg.dog.outputs) > 0
    for i in range(len(rhs_top)):
        dog.add_node(i)
    for output in dsg.dog.outputs:
        dog.add_to_outputs(rhs_top.index(output))
    dog.add_nonterminal_edge([], list(range(len(rhs_top))), enforce_outputs)

    # no sync
    sync = []
    lcfrs.add_rule(lcfrs_lhs, [rhs_nont], weight=1.0, dcp=[dog, sync])

    return lcfrs
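
The chain rule above (an LHS with the single argument [LCFRS_var(0, 0)] and a unary right-hand side) is the pattern every example in this listing uses to connect a fresh start symbol to an induced grammar. Below is a minimal, self-contained sketch of that LCFRS pattern, using only primitives shown in this listing; the import path is an assumption, not taken from the listing.

# Sketch only: the module path is an assumption.
from grammar.lcfrs import LCFRS, LCFRS_lhs, LCFRS_var

def add_start_chain_rule(lcfrs, start, rhs_nont):
    """Add START(x) -> RHS(x): one argument that copies the single
    component of a fan-out-1 right-hand-side nonterminal."""
    lhs = LCFRS_lhs(start)
    lhs.add_arg([LCFRS_var(0, 0)])  # x_{0,0}: component 0 of RHS member 0
    lcfrs.add_rule(lhs, [rhs_nont], weight=1.0)
    return lcfrs

grammar = add_start_chain_rule(LCFRS(start="START"), "START", "S")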
Example #2
def induce_grammar_rec(lcfrs, dsg, rec_par, decomp, labeling, terminal_labeling, terminal_labeling_lcfrs, normalize,
                       enforce_outputs=True, ordered_nodes=None):
    lhs_nont = labeling(decomp[0], dsg)

    # build lcfrs part
    lcfrs_lhs = LCFRS_lhs(lhs_nont)
    rhs_sent_pos = [child[0] for child in rec_par[1]]
    generated_sent_positions = fill_lcfrs_lhs(lcfrs_lhs, rec_par[0], rhs_sent_pos, dsg.sentence, terminal_labeling_lcfrs)

    # build dog part
    rhs_nodes = [child[0] for child in decomp[1]]
    dog = dsg.dog.extract_dog(decomp[0], rhs_nodes, enforce_outputs, ordered_nodes=ordered_nodes)
    for edge in dog.terminal_edges:
        edge.label = terminal_labeling(edge.label)
    if normalize:
        node_renaming = dog.compress_node_names()
    else:
        node_renaming = {}

    # build terminal synchronization
    sync = [[node_renaming.get(node, int(node)) for node in dsg.get_graph_position(sent_position)]
            for sent_position in generated_sent_positions]

    # recursively compute rules for rhs
    rhs_nonts = []
    for child_rec_par, child_decomp in zip(rec_par[1], decomp[1]):
        rhs_nonts.append(
            induce_grammar_rec(lcfrs, dsg, child_rec_par, child_decomp, labeling, terminal_labeling,
                               terminal_labeling_lcfrs, normalize, enforce_outputs, ordered_nodes=ordered_nodes))

    # create rule
    lcfrs.add_rule(lcfrs_lhs, rhs_nonts, weight=1.0, dcp=[dog, sync])
    return lhs_nont
Example #3
def create_leaf_lcfrs_lhs(tree, node_ids, t_max, b_max, nont_labelling, term_labelling):
    """
    Create LCFRS_lhs for a leaf of the recursive partitioning,
    i.e. this LCFRS creates (consumes) exactly one terminal symbol.
    :param tree: HybridTree
    :param node_ids: list of string
    :param t_max: top_max of node_ids
    :param b_max: bottom_max of node_ids
    :type nont_labelling: AbstractLabeling
    :param term_labelling: HybridTree, node_id -> string
    :return: LCFRS_lhs
    """

    # Build LHS
    lhs = LCFRS_lhs(nont_labelling.label_nonterminal(tree, node_ids, t_max, b_max, 1))
    node_id = node_ids[0]
    arg = [term_labelling(tree.node_token(node_id))]
    lhs.add_arg(arg)
    return lhs
Example #4
def create_lcfrs_lhs(tree, node_ids, t_max, b_max, children, nont_labelling):
    """
    Create the LCFRS_lhs of some LCFRS-DCP hybrid rule.
    :rtype: LCFRS_lhs
    :param tree:     HybridTree
    :param node_ids: list of string (node in an recursive partitioning)
    :param t_max:    top_max of node_ids
    :param b_max:    bottom_max of node ids
    :param children: list of pairs of list of list of string
#                    (pairs of top_max / bottom_max of child nodes in recursive partitioning)
    :type nont_labelling: AbstractLabeling
    :return: LCFRS_lhs :raise Exception:
    """
    positions = map(tree.node_index, node_ids)
    spans = join_spans(positions)

    children_spans = list(map(join_spans, [map(tree.node_index, ids) for (ids, _) in children]))

    lhs = LCFRS_lhs(nont_labelling.label_nonterminal(tree, node_ids, t_max, b_max, len(spans)))
    # For each LHS span, stitch its argument together from the child spans:
    # scan left to right and emit LCFRS_var(mem, mem_arg) for the child span
    # that starts at the current position i.
    for (low, high) in spans:
        arg = []
        i = low
        while i <= high:
            mem = 0
            match = False
            while mem < len(children_spans) and not match:
                child_spans = children_spans[mem]
                mem_arg = 0
                while mem_arg < len(child_spans) and not match:
                    child_span = child_spans[mem_arg]
                    if child_span[0] == i:
                        arg.append(LCFRS_var(mem, mem_arg))
                        i = child_span[1] + 1
                        match = True
                    mem_arg += 1
                mem += 1
            # Sanity check
            if not match:
                raise Exception('Expected ingredient for LCFRS argument was not found.')
        lhs.add_arg(arg)

    return lhs
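
To make the stitching loop concrete, here is a standalone re-implementation on plain (low, high) tuples, with a worked call; stitch_argument is an illustrative helper, not part of the library.

def stitch_argument(span, children_spans):
    """Standalone version of the loop above: scan the parent span left to
    right and emit (mem, mem_arg) for the child span starting there."""
    low, high = span
    arg = []
    i = low
    while i <= high:
        match = False
        for mem, child_spans in enumerate(children_spans):
            for mem_arg, (child_low, child_high) in enumerate(child_spans):
                if child_low == i:
                    arg.append((mem, mem_arg))  # stands in for LCFRS_var(mem, mem_arg)
                    i = child_high + 1
                    match = True
                    break
            if match:
                break
        if not match:
            raise Exception('Expected ingredient for LCFRS argument was not found.')
    return arg

# Child 0 covers (0, 1) and (3, 4); child 1 covers (2, 2):
assert stitch_argument((0, 4), [[(0, 1), (3, 4)], [(2, 2)]]) == [(0, 0), (1, 0), (0, 1)]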
Example #5
def induce_grammar(trees, nont_labelling, term_labelling, recursive_partitioning, start_nont='START'):
    """
    :rtype: LCFRS
    :param trees: corpus of HybridTree (i.e. list (or Generator for lazy IO))
    :type trees: __generator[HybridTree]
    :type nont_labelling: AbstractLabeling
    :param term_labelling: HybridTree, NodeId -> str
    :param recursive_partitioning: HybridTree -> RecursivePartitioning
    :type start_nont: str
    :rtype: int, LCFRS

    Top level method to induce an LCFRS/DCP-hybrid grammar for dependency parsing.
    """
    grammar = LCFRS(start_nont)
    n_trees = 0
    for tree in trees:
        n_trees += 1
        for rec_par in recursive_partitioning:
            if 'no_new_nont' in rec_par.__name__:
                rec_par_int = rec_par(tree, grammar.nonts(), nont_labelling)
            else:
                rec_par_int = rec_par(tree)

            rec_par_nodes = tree.node_id_rec_par(rec_par_int)

            (_, _, nont_name) = add_rules_to_grammar_rec(tree, rec_par_nodes, grammar, nont_labelling, term_labelling)

            # Add rule from top start symbol to top most nonterminal for the hybrid tree
            lhs = LCFRS_lhs(start_nont)
            lhs.add_arg([LCFRS_var(0, 0)])
            rhs = [nont_name]
            dcp_rule = DCP_rule(DCP_var(-1, 0), [DCP_var(0, 0)])

            grammar.add_rule(lhs, rhs, 1.0, [dcp_rule])

    grammar.make_proper()
    return n_trees, grammar
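
A hedged usage sketch for this entry point; every name except induce_grammar is an assumption about the surrounding codebase (the corpus and the labelling objects are constructed elsewhere).

# Assumed setup: `corpus` is an iterable of HybridTree, `nont_labelling` is an
# AbstractLabeling, `term_labelling` maps (HybridTree, NodeId) -> str.
n_trees, grammar = induce_grammar(
    trees=corpus,
    nont_labelling=nont_labelling,
    term_labelling=term_labelling,
    recursive_partitioning=[rec_par],  # a list: the loop above iterates over it
    start_nont='START')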
Example #6
def direct_extract_lcfrs_from_prebinarized_corpus(tree,
                                                  term_labeling=PosTerminals(),
                                                  nont_labeling=BasicNonterminalLabeling(),
                                                  isolate_pos=True):
    gram = LCFRS(start=START)
    root = tree.root[0]
    if root in tree.full_yield():
        lhs = LCFRS_lhs(START)
        label = term_labeling.token_label(tree.node_token(root))
        lhs.add_arg([label])
        dcp_rule = DCP_rule(DCP_var(-1, 0), [DCP_term(DCP_index(0, edge_label=tree.node_token(root).edge()), [])])
        gram.add_rule(lhs, [], dcp=[dcp_rule])
    else:
        first, _, _ = direct_extract_lcfrs_prebinarized_recur(tree, root, gram, term_labeling, nont_labeling, isolate_pos)
        lhs = LCFRS_lhs(START)
        lhs.add_arg([LCFRS_var(0, 0)])
        dcp_rule = DCP_rule(DCP_var(-1, 0), [DCP_var(0, 0)])
        gram.add_rule(lhs, [first], dcp=[dcp_rule])
    return gram
Example #7
def direct_extract_lcfrs_prebinarized_recur(tree, idx, gram, term_labeling,
                                            nont_labeling, isolate_pos):
    assert isinstance(tree, HybridDag)
    fringe = tree.fringe(idx)
    spans = join_spans(fringe)
    nont_fanout = len(spans)

    _bot = list(bottom(tree, [idx] + tree.descendants(idx)))
    _top = list(top(tree, [idx] + tree.descendants(idx)))

    nont = nont_labeling.label_nont(tree, idx) + '/' + '/'.join(
        map(str, [nont_fanout, len(_bot), len(_top)]))

    lhs = LCFRS_lhs(nont)

    if idx in tree.full_yield():
        label = term_labeling.token_label(tree.node_token(idx))
        lhs.add_arg([label])
        dcp_rule = DCP_rule(DCP_var(-1, 0), [
            DCP_term(DCP_index(0, edge_label=tree.node_token(idx).edge()), [])
        ])
        gram.add_rule(lhs, [], dcp=[dcp_rule])
        return lhs.nont(), _bot, _top

    if len(tree.children(idx)) > 2:
        raise ValueError("Tree is not prebinarized!", tree, idx)

    children = [(child, join_spans(tree.fringe(child)))
                for child in tree.children(idx)]
    edge_labels = []
    for (low, high) in spans:
        arg = []
        pos = low
        while pos <= high:
            child_num = 0
            for child, child_spans in children:
                for j, (child_low, child_high) in enumerate(child_spans):
                    if pos == child_low:
                        if child in tree.full_yield() and not isolate_pos:
                            arg += [
                                term_labeling.token_label(
                                    tree.node_token(child))
                            ]
                            edge_labels += [tree.node_token(child).edge()]
                        else:
                            arg += [LCFRS_var(child_num, j)]
                        pos = child_high + 1
                if child not in tree.full_yield() or isolate_pos:
                    child_num += 1
        lhs.add_arg(arg)

    dcp_term_args = []
    rhs = []
    nont_counter = 0
    term_counter = 0

    cbots = []
    ctops = []

    for (child, child_spans) in children:

        if child not in tree.full_yield() or isolate_pos:
            c_nont, _cbot, _ctop = direct_extract_lcfrs_prebinarized_recur(
                tree, child, gram, term_labeling, nont_labeling, isolate_pos)
            rhs.append(c_nont)
            cbots.append(_cbot)
            ctops.append(_ctop)
            dcp_term_args.append(
                DCP_var(nont_counter,
                        len(_cbot) + _ctop.index(child)))
            nont_counter += 1
        else:
            dcp_term_args.append(
                DCP_term(
                    DCP_index(term_counter,
                              edge_label=edge_labels[term_counter]), []))
            term_counter += 1

    for sec, sec_child in enumerate(tree.sec_children(idx)):
        if sec_child not in tree.descendants(idx):
            print(idx, "has external", sec_child)
            assert sec_child in _bot
            dcp_term_args.append(
                DCP_term(DCP_string("SECEDGE"),
                         [DCP_var(-1, _bot.index(sec_child))]))

        else:
            print(idx, "has internal", sec_child)
            # internal secondary children are not supported by this extraction
            assert False

    # DCP variables pack bottom variables first, then top variables;
    # member -1 addresses the rule's own LHS.
    dcp_lhs = DCP_var(-1, len(_bot) + _top.index(idx))

    label = tree.node_token(idx).category()
    if re.match(r'.*\|<.*>', label):
        dcp_term = dcp_term_args
    else:
        dcp_term = [
            DCP_term(DCP_string(label, edge_label=tree.node_token(idx).edge()),
                     dcp_term_args)
        ]
    dcp_rule = DCP_rule(dcp_lhs, dcp_term)

    dcp_rules = [dcp_rule]

    for top_idx in _top:
        if top_idx != idx:
            # must be in some child
            rule = None

            for nont_counter, _ctop in enumerate(ctops):
                if top_idx in _ctop:
                    rule = DCP_rule(
                        DCP_var(-1,
                                len(_bot) + _top.index(top_idx)), [
                                    DCP_var(
                                        nont_counter,
                                        len(cbots[nont_counter]) +
                                        _ctop.index(top_idx))
                                ])

                    break
            assert rule is not None
            dcp_rules.append(rule)

    for nont_counter, _cbot in enumerate(cbots):
        for bot_idx in _cbot:
            rule = None
            rule_lhs = DCP_var(nont_counter, _cbot.index(bot_idx))

            if bot_idx in _bot:
                rule = DCP_rule(rule_lhs, [DCP_var(-1, _bot.index(bot_idx))])
            else:
                for nont_counter2, _ctop in enumerate(ctops):
                    if bot_idx in _ctop:
                        rule = DCP_rule(rule_lhs, [
                            DCP_var(
                                nont_counter2,
                                len(cbots[nont_counter2]) +
                                _ctop.index(bot_idx))
                        ])
                        break
            assert rule is not None
            dcp_rules.append(rule)

    gram.add_rule(lhs, rhs, dcp=dcp_rules)

    return nont, _bot, _top
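
The index arithmetic above encodes a single convention: the DCP variables of a nonterminal pack its bottom variables first, then its top variables, and member -1 addresses the rule's own LHS. An illustrative helper (not part of the library) restating that arithmetic:

def dcp_var_position(_bot, _top, node):
    """Packed DCP variable index for node: bottom variables occupy
    positions 0..len(_bot)-1, top variables follow."""
    if node in _bot:
        return _bot.index(node)
    return len(_bot) + _top.index(node)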
Example #8
    def build_nm_grammar():
        grammar = LCFRS("START")
        # rule 0
        lhs = LCFRS_lhs("START")
        lhs.add_arg([LCFRS_var(0, 0)])
        grammar.add_rule(lhs, ["S"])

        # rule 1
        lhs = LCFRS_lhs("S")
        lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0), LCFRS_var(0, 1), LCFRS_var(1, 1)])
        grammar.add_rule(lhs, ["N", "M"])

        for nont, term in [("A", "a"), ("B", "b"), ("C", "c"), ("D", "d")]:
            # rule 2
            lhs = LCFRS_lhs(nont)
            lhs.add_arg([term])
            grammar.add_rule(lhs, [])

        for nont, nont_, c1, c2 in [("N", "N'", "A", "C"), ("M", "M'", "B", "D")]:
            # rule 3
            lhs = LCFRS_lhs(nont)
            lhs.add_arg([LCFRS_var(0, 0)])
            lhs.add_arg([LCFRS_var(1, 0)])
            grammar.add_rule(lhs, [c1, c2])

            # rule 4
            lhs = LCFRS_lhs(nont)
            lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
            lhs.add_arg([LCFRS_var(0, 1)])
            grammar.add_rule(lhs, [nont_, c1])

            # rule 5
            lhs = LCFRS_lhs(nont_)
            lhs.add_arg([LCFRS_var(0, 0)])
            lhs.add_arg([LCFRS_var(0, 1), LCFRS_var(1, 0)])
            grammar.add_rule(lhs, [nont, c2])

        grammar.make_proper()
        return grammar
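
This grammar is a textbook fan-out-2 LCFRS: N derives the component pair (a^k, c^k), M derives (b^m, d^m), and rule 1 interleaves them, so the string language is { a^k b^m c^k d^m : k, m >= 1 }. A hedged recognition check, reusing the LCFRS_parser pattern from Example #13 (its availability in this scope is an assumption):

grammar = build_nm_grammar()
parser = LCFRS_parser(grammar)  # parser class as used in Example #13
parser.set_input(["a", "a", "b", "c", "c", "d"])  # a^2 b c^2 d
parser.parse()
assert parser.recognized()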
Example #9
    def test_la_viterbi_parsing_3(self):
        grammar = LCFRS("S")

        # rule 0
        lhs = LCFRS_lhs("B")
        lhs.add_arg(["a"])
        grammar.add_rule(lhs, [], 0.25)

        # rule 1
        lhs = LCFRS_lhs("A")
        lhs.add_arg(["a"])
        grammar.add_rule(lhs, [], 0.5)

        # rule 2
        lhs = LCFRS_lhs("S")
        lhs.add_arg([LCFRS_var(0, 0)])
        grammar.add_rule(lhs, ["B"], 1.0)

        # rule 3
        lhs = LCFRS_lhs("A")
        lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
        grammar.add_rule(lhs, ["A", "B"], 0.5)

        # rule 4
        lhs = LCFRS_lhs("B")
        lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
        grammar.add_rule(lhs, ["A", "B"], 0.75)

        grammar.make_proper()

        inp = ["a"] * 3

        nontMap = Enumerator()
        gi = PyGrammarInfo(grammar, nontMap)
        sm = PyStorageManager()
        print(nontMap.object_index("S"))
        print(nontMap.object_index("B"))

        la = build_PyLatentAnnotation_initial(grammar, gi, sm)
        parser = DiscodopKbestParser(grammar, la=la, nontMap=nontMap, grammarInfo=gi, latent_viterbi_mode=True)
        parser.set_input(inp)
        parser.parse()
        self.assertTrue(parser.recognized())
        der = parser.latent_viterbi_derivation(True)
        print(der)

        der2 = None  # will hold the first (best-weighted) k-best derivation

        for w, der_ in parser.k_best_derivation_trees():
            if der2 is None:
                der2 = der_
            print(w, der_)

        print(der2)
Example #10
    def build_grammar():
        grammar = LCFRS("START")
        # rule 0
        lhs = LCFRS_lhs("START")
        lhs.add_arg([LCFRS_var(0, 0)])
        grammar.add_rule(lhs, ["S"])

        # rule 1
        lhs = LCFRS_lhs("S")
        lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
        grammar.add_rule(lhs, ["S", "S"])

        # rule 1.5
        lhs = LCFRS_lhs("S")
        lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
        grammar.add_rule(lhs, ["S", "S"], dcp=["1.5"])

        # rule 2
        lhs = LCFRS_lhs("S")
        lhs.add_arg(["a"])
        grammar.add_rule(lhs, [])

        # rule 3
        lhs = LCFRS_lhs("S")
        lhs.add_arg(["b"])
        grammar.add_rule(lhs, [], weight=2.0)

        # rule 4
        lhs = LCFRS_lhs("S")
        lhs.add_arg(["b"])
        grammar.add_rule(lhs, [], dcp=["4"])

        # rule 5
        lhs = LCFRS_lhs("A")
        lhs.add_arg(["a"])
        grammar.add_rule(lhs, [])

        grammar.make_proper()
        return grammar
Example #11
    def build_paper_grammar():
        grammar = LCFRS("S")
        # rule 0
        lhs = LCFRS_lhs("B")
        lhs.add_arg(["a"])
        grammar.add_rule(lhs, [])

        # rule 1
        lhs = LCFRS_lhs("S")
        lhs.add_arg([LCFRS_var(0, 0)])
        grammar.add_rule(lhs, ["B"])

        # rule 2
        lhs = LCFRS_lhs("B")
        lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
        grammar.add_rule(lhs, ["B", "B"])

        grammar.make_proper()
        return grammar
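
Here S -> B and B -> B B | a, so the grammar derives exactly a^n for n >= 1, and make_proper redistributes the weights so each nonterminal's rules sum to one. A hedged recognition check in the same style (LCFRS_parser availability is an assumption):

grammar = build_paper_grammar()
parser = LCFRS_parser(grammar)
parser.set_input(["a"] * 3)
parser.parse()
assert parser.recognized()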
Example #12
def __rec_induce(tree, grammar, string_partition, terminal_labeling,
                 nonterminal_counts):
    positions, children_partitions = string_partition
    tree_positions = srp_to_trp(tree=tree, positions=positions)
    top = sum(top_max(tree=tree, id_set=tree_positions), [])

    children_positions = []
    children_nonterminals = []
    for child_partition in children_partitions:
        child_positions = child_partition[0]
        children_positions.append(child_positions)
        children_nonterminals.append(
            __rec_induce(tree=tree,
                         grammar=grammar,
                         string_partition=child_partition,
                         terminal_labeling=terminal_labeling,
                         nonterminal_counts=nonterminal_counts))

    children_spans = []
    children_tops = []
    children_tree_positions = []
    for child_positions in children_positions:
        children_spans.append(join_spans(child_positions))
        children_tops.append(
            sum(
                top_max(tree=tree,
                        id_set=srp_to_trp(tree=tree,
                                          positions=child_positions)), []))
        children_tree_positions.append(
            srp_to_trp(tree=tree, positions=child_positions))

    spans = join_spans(positions)
    arguments = []
    term_to_pos = {}
    for span in spans:
        arguments.append(
            span_to_arg(span=span,
                        children=children_spans,
                        tree=tree,
                        term_to_pos=term_to_pos,
                        term_labeling=terminal_labeling))

    # TODO idea: since we now have the top-set,
    #  we can name nonterminals based on them
    #  - also there might be a mistake: Ausgerechnet creates S(…)
    nonterminal = nonterminal_labeling_strict(tree=tree,
                                              spans=spans,
                                              children_spans=children_spans)
    lhs = LCFRS_lhs(nonterminal)
    for argument in arguments:
        lhs.add_arg(argument)

    present_tree_positions = tree_positions[:]
    for child_tree_positions in children_tree_positions:
        for position in child_tree_positions:
            if position in present_tree_positions:
                present_tree_positions.remove(position)
            else:
                print(f'{position} not in {present_tree_positions}')

    # print(f'pos: {positions}\ntos: {tree_positions}\ntop: {top}')
    # print(f'pst: {present_tree_positions}')
    # print()

    id_to_pos = {
        tree.index_node(index=index + 1): term_to_pos[index]
        for index in term_to_pos
    }
    dcp_rhs = [
        create_dcp_rhs(tree=tree,
                       root=t_id,
                       present_tree_positions=present_tree_positions,
                       string_positions=tree.id_yield(),
                       id_to_pos=id_to_pos,
                       terminal_labelling=terminal_labeling,
                       children_tops=children_tops) for t_id in top
    ]
    dcp = [DCP_rule(lhs=DCP_var(-1, 0), rhs=dcp_rhs)]

    grammar.add_rule(lhs, children_nonterminals, 1, dcp)
    nonterminal_counts[nonterminal] = nonterminal_counts.get(nonterminal, 0) + 1
    return nonterminal
Example #13
    def test_stanford_unking_scheme(self):
        naming = 'child'

        def rec_part(tree):
            return left_branching_partitioning(len(tree.id_yield()))

        tree = self.tree
        tree.add_to_root("VP1")

        print(tree)

        terminal_labeling = StanfordUNKing([tree])

        grammar = fringe_extract_lcfrs(tree,
                                       rec_part(tree),
                                       naming=naming,
                                       isolate_pos=True,
                                       term_labeling=terminal_labeling)
        print(grammar)

        parser = LCFRS_parser(grammar)
        parser.set_input([token.form() for token in tree.token_yield()])
        parser.parse()
        self.assertTrue(parser.recognized())
        derivation = parser.best_derivation_tree()
        e = DCP_evaluator(derivation)
        dcp_term = e.getEvaluation()
        print(str(dcp_term[0]))
        t = ConstituentTree()
        dcp_to_hybridtree(
            t,
            dcp_term, [
                construct_constituent_token(token.form(), '--', True)
                for token in tree.token_yield()
            ],
            ignore_punctuation=False,
            construct_token=construct_constituent_token)
        print(t)
        self.assertEqual(len(tree.token_yield()), len(t.token_yield()))
        for tok1, tok2 in zip(tree.token_yield(), t.token_yield()):
            self.assertEqual(tok1.form(), tok2.form())
            self.assertEqual(tok1.pos(), tok2.pos())

        rules = terminal_labeling.create_smoothed_rules()
        print(rules)

        new_rules = {}

        for rule in grammar.rules():
            if not rule.rhs():
                assert len(rule.dcp()) == 1
                dcp = rule.dcp()[0]
                assert len(dcp.rhs()) == 1
                term = dcp.rhs()[0]
                head = term.head()
                pos = head.pos()

                for tag, form in rules:
                    if tag == pos:
                        lhs = LCFRS_lhs(rule.lhs().nont())
                        lhs.add_arg([form])
                        new_rules[lhs, dcp] = rules[tag, form]

        for lhs, dcp in new_rules:
            print(str(lhs), str(dcp), new_rules[(lhs, dcp)])

        tokens = [
            construct_constituent_token('hat', '--', True),
            construct_constituent_token('HAT', '--', True)
        ]
        self.assertEqual(terminal_labeling.token_label(tokens[0]), 'hat')
        self.assertEqual(terminal_labeling.token_label(tokens[1]), '_UNK')
        terminal_labeling.test_mode = True
        self.assertEqual(terminal_labeling.token_label(tokens[0]), 'hat')
        self.assertEqual(terminal_labeling.token_label(tokens[1]), 'hat')
Example #14
    def build_grammar(self):
        grammar = LCFRS("S")

        lhs1 = LCFRS_lhs("S")
        lhs1.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
        rule_1 = grammar.add_rule(lhs1, ["S", "S"])

        lhs2 = LCFRS_lhs("S")
        lhs2.add_arg(["a"])
        rule_2 = grammar.add_rule(lhs2, [])

        lhs3 = LCFRS_lhs("A")
        lhs3.add_arg(["a"])
        rule_3 = grammar.add_rule(lhs3, [])

        return grammar, rule_1.get_idx(), rule_2.get_idx()
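
Note that add_rule returns the rule object, so the test can keep stable rule indices via get_idx(). A hedged sketch of consuming them with rule_index, as the projection test below does:

grammar, idx1, idx2 = self.build_grammar()
# Rules can be recovered by index; rule_index/weight usage as in __test_projection.
print(grammar.rule_index(idx1).weight(), grammar.rule_index(idx2).weight())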
Example #15
    def __test_projection(self,
                          split_weights,
                          goal_weights,
                          merge_method=False):
        grammar = LCFRS("S")
        # rule 0
        lhs = LCFRS_lhs("S")
        lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
        grammar.add_rule(lhs, ["A", "A"])

        # rule 1
        lhs = LCFRS_lhs("A")
        lhs.add_arg(["a"])
        grammar.add_rule(lhs, [])

        lhs = LCFRS_lhs("A")
        lhs.add_arg(["b"])
        grammar.add_rule(lhs, [], weight=2.0)

        grammar.make_proper()
        # print(grammar)

        nonterminal_map = Enumerator()
        grammarInfo = PyGrammarInfo(grammar, nonterminal_map)
        storageManager = PyStorageManager()

        la = build_PyLatentAnnotation([1, 2], [1.0], split_weights,
                                      grammarInfo, storageManager)

        # parser = LCFRS_parser(grammar)
        # parser.set_input(["a", "b"])
        # parser.parse()
        # der = parser.best_derivation_tree()

        # print(la.serialize())
        if merge_method:
            la.project_weights(grammar, grammarInfo)
        else:
            splits, _, _ = la.serialize()
            merge_sources = [[list(range(splits[nont_idx]))]
                             for nont_idx in range(nonterminal_map.get_counter())]

            # print("Projecting to fine grammar LA", file=self.logger)
            coarse_la = la.project_annotation_by_merging(grammarInfo,
                                                         merge_sources,
                                                         debug=False)
            coarse_la.project_weights(grammar, grammarInfo)

        # print(grammar)
        for i in range(3):
            self.assertAlmostEqual(
                grammar.rule_index(i).weight(), goal_weights[i])
Example #16
    def test_projection_based_parser_k_best_hack(self):
        grammar = LCFRS("S")

        # rule 0
        lhs = LCFRS_lhs("B")
        lhs.add_arg(["a"])
        grammar.add_rule(lhs, [], 0.25)

        # rule 1
        lhs = LCFRS_lhs("A")
        lhs.add_arg(["a"])
        grammar.add_rule(lhs, [], 0.5)

        # rule 2
        lhs = LCFRS_lhs("S")
        lhs.add_arg([LCFRS_var(0, 0)])
        grammar.add_rule(lhs, ["B"], 1.0)

        # rule 3
        lhs = LCFRS_lhs("A")
        lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
        grammar.add_rule(lhs, ["A", "B"], 0.5)

        # rule 4
        lhs = LCFRS_lhs("B")
        lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
        grammar.add_rule(lhs, ["A", "B"], 0.75)

        grammar.make_proper()

        inp = ["a"] * 3
        nontMap = Enumerator()
        gi = PyGrammarInfo(grammar, nontMap)
        sm = PyStorageManager()
        la = build_PyLatentAnnotation_initial(grammar, gi, sm)

        parser = Coarse_to_fine_parser(grammar,
                                       la,
                                       gi,
                                       nontMap,
                                       base_parser_type=GFParser_k_best)
        parser.set_input(inp)
        parser.parse()
        self.assertTrue(parser.recognized())
        der = parser.max_rule_product_derivation()
        print(der)

        der = parser.best_derivation_tree()
        print(der)

        for node in der.ids():
            print(der.getRule(node), der.spanned_ranges(node))