Example #1
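# Full experiment pipeline: induce an LCFRS/sDCP hybrid grammar from a
# training corpus, optionally parse a test corpus, and record both stages
# in an experiment database.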
def run_experiment(db_file, training_corpus, test_corpus, do_parse,
                   ignore_punctuation, length_limit, labeling,
                   terminal_labeling, partitioning, root_default_deprel,
                   disconnected_default_deprel, max_training, max_test):
    labeling_choices = labeling.split('-')
    if len(labeling_choices) == 2:
        nont_labelling = label.the_labeling_factory(
        ).create_simple_labeling_strategy(labeling_choices[0],
                                          labeling_choices[1])
    elif len(labeling_choices) > 2:
        nont_labelling = label.the_labeling_factory(
        ).create_complex_labeling_strategy(labeling_choices)
        # labeling == 'strict-pos-leaf:dep':
        # labeling == 'child-pos-leaf:dep':
    else:
        print("Error: Invalid labeling strategy: " + labeling)
        exit(1)

    rec_par = the_recursive_partitioning_factory().get_partitioning(
        partitioning)
    if rec_par is None:
        print("Error: Invalid recursive partitioning strategy: " +
              partitioning)
        exit(1)

    term_labeling_strategy = the_terminal_labeling_factory().get_strategy(
        terminal_labeling)
    if term_labeling_strategy is None:
        print("Error: Invalid terminal labeling strategy: " +
              terminal_labeling)
        exit(1)

    parser_type = the_parser_factory().getParser(partitioning)
    if parser_type is None:
        print("Error: Invalid parser type: " + partitioning)
        exit(1)

    connection = experiment_database.initialize_database(db_file)
    grammar, experiment = induce_grammar_from_file(
        training_corpus, connection, nont_labelling, term_labeling_strategy,
        rec_par, max_training, False, 'START', ignore_punctuation)
    if do_parse:
        parse_sentences_from_file(grammar, parser_type, experiment, connection,
                                  test_corpus,
                                  term_labeling_strategy.prepare_parser_input,
                                  length_limit, max_test, False,
                                  ignore_punctuation, root_default_deprel,
                                  disconnected_default_deprel)
    experiment_database.finalize_database(connection)
Example #2
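    # Induce grammars from a multi-rooted dependency tree under several
    # labeling and partitioning strategies, and check that parsing the POS
    # yield reconstructs the original tree.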
    def test_multiroot(self):
        tree = multi_dep_tree()
        term_pos = the_terminal_labeling_factory().get_strategy(
            'pos').token_label
        fanout_1 = the_recursive_partitioning_factory().get_partitioning(
            'fanout-1')
        for top_level_labeling_strategy in ['strict', 'child']:
            labeling_strategy = the_labeling_factory(
            ).create_simple_labeling_strategy(top_level_labeling_strategy,
                                              'pos+deprel')
            for recursive_partitioning in [[direct_extraction], fanout_1,
                                           [left_branching]]:
                (_, grammar) = induce_grammar([tree], labeling_strategy,
                                              term_pos, recursive_partitioning,
                                              'START')
                print(grammar)

                parser = LCFRS_parser(grammar, 'pA pB pC pD pE'.split(' '))
                print(parser.best_derivation_tree())

                cleaned_tokens = copy.deepcopy(tree.full_token_yield())
                for token in cleaned_tokens:
                    token.set_edge_label('_')
                hybrid_tree = HybridTree()
                hybrid_tree = parser.dcp_hybrid_tree_best_derivation(
                    hybrid_tree, cleaned_tokens, True, construct_conll_token)
                print(hybrid_tree)
                self.assertEqual(tree, hybrid_tree)
Example #3
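    # Collect the 50 best derivations per sentence, convert each to a hybrid
    # tree, and compare the minimum-risk tree against the first (Viterbi)
    # tree.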
    def test_minimum_risk_parsing(self):
        limit_train = 20
        limit_test = 10
        train = 'res/dependency_conll/german/tiger/train/german_tiger_train.conll'
        test = train
        parser_type = GFParser_k_best
        # test = '../../res/dependency_conll/german/tiger/test/german_tiger_test.conll'
        trees = parse_conll_corpus(train, False, limit_train)
        primary_labelling = the_labeling_factory(
        ).create_simple_labeling_strategy("childtop", "deprel")
        term_labelling = the_terminal_labeling_factory().get_strategy('pos')
        start = 'START'
        recursive_partitioning = [cfg]

        (n_trees, grammar_prim) = induce_grammar(trees, primary_labelling,
                                                 term_labelling.token_label,
                                                 recursive_partitioning, start)

        parser_type.preprocess_grammar(grammar_prim)
        tree_yield = term_labelling.prepare_parser_input

        trees = parse_conll_corpus(test, False, limit_test)

        for i, tree in enumerate(trees):
            print("Parsing sentence ", i, file=stderr)

            # print >>stderr, tree

            parser = parser_type(grammar_prim,
                                 tree_yield(tree.token_yield()),
                                 k=50)

            self.assertTrue(parser.recognized())

            derivations = list(parser.k_best_derivation_trees())
            print("# derivations: ", len(derivations), file=stderr)
            h_trees = []
            current_weight = 0
            weights = []
            derivation_list = []
            for weight, der in derivations:

                self.assertNotIn(der, derivation_list)

                derivation_list.append(der)

                dcp = DCP_evaluator(der).getEvaluation()
                h_tree = HybridTree()
                cleaned_tokens = copy.deepcopy(tree.full_token_yield())
                dcp_to_hybridtree(h_tree, dcp, cleaned_tokens, False,
                                  construct_conll_token)

                h_trees.append(h_tree)
                weights.append(weight)

            min_risk_tree = compute_minimum_risk_tree(h_trees, weights)
            if min_risk_tree != h_trees[0]:
                print(h_trees[0])
                print(min_risk_tree)
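
    # Split/merge (latent-annotation) training on a grammar induced from two
    # small example trees.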
    def test_basic_split_merge(self):
        tree = hybrid_tree_1()
        tree2 = hybrid_tree_2()
        terminal_labeling = the_terminal_labeling_factory().get_strategy('pos')

        (_, grammar) = induce_grammar(
            [tree, tree2],
            the_labeling_factory().create_simple_labeling_strategy(
                'empty', 'pos'), terminal_labeling.token_label, [cfg], 'START')

        for rule in grammar.rules():
            print(rule, file=stderr)

        print("call S/M Training", file=stderr)

        new_grammars = split_merge_training(grammar,
                                            terminal_labeling, [tree, tree2],
                                            3,
                                            5,
                                            merge_threshold=0.5,
                                            debug=False)

        for new_grammar in new_grammars:
            for i, rule in enumerate(new_grammar.rules()):
                print(i, rule, file=stderr)
            print(file=stderr)

        print("finished S/M Training", file=stderr)
Example #5
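# Sweep punctuation handling, nonterminal labeling, and recursive
# partitioning; evaluate each induced grammar on the CoNLL test corpus and
# log everything to the experiment database.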
def test_conll_grammar_induction():
    db_connection = experiment_database.initialize_database(SAMPLE_DB)

    root_default_deprel = 'ROOT'
    disconnected_default_deprel = 'PUNC'

    terminal_labeling_strategy = the_terminal_labeling_factory().get_strategy(
        'pos')

    for ignore_punctuation in [True, False]:
        for top_level, node_to_string in itertools.product(['strict', 'child'],
                                                           ['pos', 'deprel']):
            nont_labelling = label.the_labeling_factory(
            ).create_simple_labeling_strategy(top_level, node_to_string)
            for rec_par_s in [
                    'direct_extraction', 'left_branching', 'right_branching',
                    'fanout-1', 'fanout_2'
            ]:
                rec_par = grammar.induction.recursive_partitioning.the_recursive_partitioning_factory(
                ).get_partitioning(rec_par_s)
                # bind the induced grammar to its own name so that it does not
                # shadow the 'grammar' module used for the factory lookup above
                induced_grammar, experiment = induce_grammar_from_file(
                    CONLL_TRAIN, db_connection, nont_labelling,
                    terminal_labeling_strategy.token_label, rec_par,
                    sys.maxsize, False, 'START', ignore_punctuation)
                print()
                parse_sentences_from_file(
                    induced_grammar, experiment, db_connection, CONLL_TEST,
                    terminal_labeling_strategy.prepare_parser_input, 20,
                    sys.maxsize, False, ignore_punctuation,
                    root_default_deprel, disconnected_default_deprel)

    experiment_database.finalize_database(db_connection)
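
    # Compute sDCP reducts, serialize them, reload them, and verify that the
    # serialize/load round trip is lossless.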
    def test_trace_serialization(self):
        tree = hybrid_tree_1()
        tree2 = hybrid_tree_2()
        terminal_labeling = the_terminal_labeling_factory().get_strategy('pos')

        (_, grammar) = induce_grammar(
            [tree, tree2],
            the_labeling_factory().create_simple_labeling_strategy(
                'empty', 'pos'), terminal_labeling.token_label, [cfg], 'START')

        for rule in grammar.rules():
            print(rule, file=stderr)

        trace = compute_reducts(grammar, [tree, tree2], terminal_labeling)
        trace.serialize(b"/tmp/reducts.p")

        grammar_load = grammar
        trace2 = PySDCPTraceManager(grammar_load, terminal_labeling)
        trace2.load_traces_from_file(b"/tmp/reducts.p")
        trace2.serialize(b"/tmp/reducts2.p")

        with open(b"/tmp/reducts.p", "r") as f1, open(b"/tmp/reducts2.p",
                                                      "r") as f2:
            for e1, e2 in zip(f1, f2):
                self.assertEqual(e1, e2)
Example #7
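    # Corpus-scale split/merge training on the first 100 TiGer sentences,
    # with tie breaking and an 'equal' initialization.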
    def test_corpus_split_merge_training(self):
        train = 'res/dependency_conll/german/tiger/train/german_tiger_train.conll'
        limit_train = 100
        test = train
        # test = '../../res/dependency_conll/german/tiger/test/german_tiger_test.conll'
        trees = parse_conll_corpus(train, False, limit_train)
        primary_labelling = the_labeling_factory().create_simple_labeling_strategy("childtop", "deprel")
        term_labelling = the_terminal_labeling_factory().get_strategy('pos')
        start = 'START'
        recursive_partitioning = [cfg]

        (n_trees, grammar_prim) = induce_grammar(trees, primary_labelling, term_labelling.token_label,
                                                 recursive_partitioning, start)

        # for rule in grammar.rules():
        #    print >>stderr, rule

        trees = parse_conll_corpus(train, False, limit_train)
        print("call S/M Training", file=stderr)

        new_grammars = split_merge_training(grammar_prim, term_labelling, trees, 4, 10, tie_breaking=True, init="equal",
                                            sigma=0.05, seed=50, merge_threshold=0.1)

        print("finished S/M Training", file=stderr)

        for new_grammar in new_grammars:
            for i, rule in enumerate(new_grammar.rules()):
                print(i, rule, file=stderr)
            print(file=stderr)
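
    # EM training on reducts computed from two small example trees.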
    def test_basic_em_training(self):
        tree = hybrid_tree_1()
        tree2 = hybrid_tree_2()
        terminal_labeling = the_terminal_labeling_factory().get_strategy('pos')

        (_, grammar) = induce_grammar(
            [tree, tree2],
            the_labeling_factory().create_simple_labeling_strategy(
                'empty', 'pos'), terminal_labeling.token_label, [cfg], 'START')

        for rule in grammar.rules():
            print(rule, file=stderr)

        print("compute reducts", file=stderr)

        trace = compute_reducts(grammar, [tree, tree2], terminal_labeling)

        print("call em Training", file=stderr)
        emTrainer = PyEMTrainer(trace)
        emTrainer.em_training(grammar, n_epochs=10)

        print("finished em Training", file=stderr)

        for rule in grammar.rules():
            print(rule, file=stderr)
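
    # EM training (20 epochs) on reducts computed from a 200-sentence sample
    # of the TiGer corpus.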
    def test_corpus_em_training(self):
        train = 'res/dependency_conll/german/tiger/train/german_tiger_train.conll'
        limit_train = 200
        test = train
        # test = '../../res/dependency_conll/german/tiger/test/german_tiger_test.conll'
        trees = parse_conll_corpus(train, False, limit_train)
        primary_labelling = the_labeling_factory(
        ).create_simple_labeling_strategy("childtop", "deprel")
        term_labelling = the_terminal_labeling_factory().get_strategy('pos')
        start = 'START'
        recursive_partitioning = [cfg]

        (n_trees, grammar_prim) = induce_grammar(trees, primary_labelling,
                                                 term_labelling.token_label,
                                                 recursive_partitioning, start)

        # for rule in grammar.rules():
        #    print >>stderr, rule

        trees = parse_conll_corpus(train, False, limit_train)

        print("compute reducts", file=stderr)

        trace = compute_reducts(grammar_prim, trees, term_labelling)

        print("call em Training", file=stderr)
        emTrainer = PyEMTrainer(trace)
        emTrainer.em_training(grammar_prim,
                              20,
                              tie_breaking=True,
                              init="equal",
                              sigma=0.05,
                              seed=50)

        print("finished em Training", file=stderr)
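
# Induce a grammar from one CoNLL file and parse a modified variant of it,
# printing the input tree, the parsed tokens, and the parser output in CoNLL
# format.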
def test_conll_grammar_induction():
    ignore_punctuation = True
    trees = parse_conll_corpus(TEST_FILE, False)
    trees = disconnect_punctuation(trees)
    terminal_labeling = the_terminal_labeling_factory().get_strategy('pos')
    nonterminal_labeling = the_labeling_factory(
    ).create_simple_labeling_strategy('child', 'pos')
    (_, grammar) = d_i.induce_grammar(trees, nonterminal_labeling,
                                      terminal_labeling.token_label,
                                      [direct_extraction], 'START')

    trees2 = parse_conll_corpus(TEST_FILE_MODIFIED, False)
    trees2 = disconnect_punctuation(trees2)

    for tree in trees2:
        parser = LCFRS_parser(
            grammar,
            terminal_labeling.prepare_parser_input(tree.token_yield()))
        cleaned_tokens = copy.deepcopy(tree.full_token_yield())
        for token in cleaned_tokens:
            token.set_edge_label('_')
        h_tree = HybridTree()
        h_tree = parser.dcp_hybrid_tree_best_derivation(
            h_tree, cleaned_tokens, ignore_punctuation,
            construct_conll_token)
        # print(h_tree)
        print('input -> hybrid-tree -> output')
        print(tree_to_conll_str(tree))
        print('parsed tokens')
        print(list(map(str, h_tree.full_token_yield())))
        print('test_parser output')
        print(tree_to_conll_str(h_tree))
Example #11
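    # Export an induced grammar to Grammatical Framework format, compile it,
    # parse a POS string with GFParser, and convert the best derivation back
    # into a hybrid tree via DCP evaluation.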
    def test_grammar_export(self):
        tree = hybrid_tree_1()
        tree2 = hybrid_tree_2()
        terminal_labeling = the_terminal_labeling_factory().get_strategy('pos')

        _, grammar = induce_grammar(
            [tree, tree2],
            the_labeling_factory().create_simple_labeling_strategy(
                'empty', 'pos'),
            # the_labeling_factory().create_simple_labeling_strategy('child', 'pos+deprel'),
            terminal_labeling.token_label,
            [direct_extraction],
            'START')
        print(max([grammar.fanout(nont) for nont in grammar.nonts()]))
        print(grammar)

        prefix = '/tmp/'
        name = 'tmpGrammar'

        name_ = export(grammar, prefix, name)

        self.assertEqual(0, compile_gf_grammar(prefix, name_))

        GFParser.preprocess_grammar(grammar)

        string = ["NP", "N", "V", "V", "V"]

        parser = GFParser(grammar, string)

        self.assertTrue(parser.recognized())

        der = parser.best_derivation_tree()
        self.assertTrue(
            der.check_integrity_recursive(der.root_id(), grammar.start()))

        print(der)

        print(
            derivation_to_hybrid_tree(der, string,
                                      "Piet Marie helpen lezen leren".split(),
                                      construct_conll_token))

        dcp = DCP_evaluator(der).getEvaluation()

        h_tree_2 = HybridTree()
        token_sequence = [
            construct_conll_token(form, lemma)
            for form, lemma in zip('Piet Marie helpen lezen leren'.split(' '),
                                   'NP N V V V'.split(' '))
        ]
        dcp_to_hybridtree(h_tree_2, dcp, token_sequence, False,
                          construct_conll_token)

        print(h_tree_2)
Example #12
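    # Induce a dependency grammar from a single tree using a custom
    # recursive partitioning built with PartitionBuilder.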
    def _test_dependency_induction(self):
        tree = self.get_single_tree()
        grammar = dependency_induce_grammar(
            trees=[tree],
            nont_labelling=the_labeling_factory(
            ).create_simple_labeling_strategy('empty', 'pos'),
            term_labelling=the_terminal_labeling_factory().get_strategy(
                'form').token_label,
            recursive_partitioning=[
                lambda tree: PartitionBuilder(
                    choice_function=min,
                    split_function=monadic_split)(tree)
            ])
        print(grammar)
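
    # Parse a POS string, turn each successful root item into a derivation,
    # and build hybrid trees both directly from the derivation and from its
    # DCP evaluation.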
    def test_dcp_evaluation_with_induced_dependency_grammar(self):
        tree = hybrid_tree_1()

        print(tree)

        tree2 = hybrid_tree_2()

        print(tree2)
        # print tree.recursive_partitioning()

        labeling = the_labeling_factory().create_simple_labeling_strategy(
            'child', 'pos')
        term_pos = the_terminal_labeling_factory().get_strategy(
            'pos').token_label
        (_, grammar) = induce_grammar([tree, tree2], labeling, term_pos,
                                      [direct_extraction], 'START')

        # print grammar

        self.assertEqual(grammar.well_formed(), None)
        self.assertEqual(grammar.ordered()[0], True)
        # print max([grammar.fanout(nont) for nont in grammar.nonts()])
        print(grammar)

        parser = Parser(grammar, 'NP N V V'.split(' '))

        self.assertEqual(parser.recognized(), True)

        for item in parser.successful_root_items():
            der = Derivation()
            derivation_tree(der, item, None)
            print(der)

            hybrid_tree = derivation_to_hybrid_tree(
                der, 'NP N V V'.split(' '),
                'Piet Marie helpen lezen'.split(' '),
                construct_constituent_token)
            print(hybrid_tree)

            dcp = DCP_evaluator(der).getEvaluation()
            h_tree_2 = HybridTree()
            token_sequence = [
                construct_conll_token(form, lemma)
                for form, lemma in zip('Piet Marie helpen lezen'.split(' '),
                                       'NP N V V'.split(' '))
            ]
            dcp_to_hybridtree(h_tree_2, dcp, token_sequence, False,
                              construct_conll_token)
Example #14
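    # Parse the same input with both LCFRS_parser and CFGParser and convert
    # the best derivation into a hybrid tree.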
    def test_cfg_parser(self):
        tree = hybrid_tree_1()
        tree2 = hybrid_tree_2()
        terminal_labeling = the_terminal_labeling_factory().get_strategy('pos')

        (_, grammar) = induce_grammar(
            [tree, tree2],
            the_labeling_factory().create_simple_labeling_strategy(
                'empty', 'pos'), terminal_labeling.token_label, [cfg], 'START')

        for parser_class in [LCFRS_parser, CFGParser]:

            parser_class.preprocess_grammar(grammar)

            string = ["NP", "N", "V", "V", "V"]

            parser = parser_class(grammar, string)

            self.assertTrue(parser.recognized())

            der = parser.best_derivation_tree()
            self.assertTrue(
                der.check_integrity_recursive(der.root_id(), grammar.start()))

            print(der)

            print(
                derivation_to_hybrid_tree(
                    der, string, "Piet Marie helpen lezen leren".split(),
                    construct_conll_token))

            dcp = DCP_evaluator(der).getEvaluation()

            h_tree_2 = HybridTree()
            token_sequence = [
                construct_conll_token(form, lemma) for form, lemma in zip(
                    'Piet Marie helpen lezen leren'.split(' '),
                    'NP N V V V'.split(' '))
            ]
            dcp_to_hybridtree(h_tree_2, dcp, token_sequence, False,
                              construct_conll_token)

            print(h_tree_2)
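
    # Parse a hybrid tree directly with LCFRS_sDCP_Parser and enumerate all
    # derivation trees.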
    def test_basic_sdcp_parsing_dependency(self):
        tree1 = hybrid_tree_1()
        tree2 = hybrid_tree_2()

        terminal_labeling = the_terminal_labeling_factory().get_strategy('pos')

        (_, grammar) = induce_grammar(
            [tree1, tree2],
            the_labeling_factory().create_simple_labeling_strategy(
                'empty', 'pos'), terminal_labeling.token_label, [cfg], 'START')

        print("grammar induced. Printing rules...", file=stderr)

        for rule in grammar.rules():
            print(rule, file=stderr)

        parser_type = LCFRS_sDCP_Parser

        print("preprocessing grammar", file=stderr)

        parser_type.preprocess_grammar(grammar, terminal_labeling)

        print("invoking parser", file=stderr)

        parser = parser_type(grammar, tree1)

        print("listing derivations", file=stderr)

        for der in parser.all_derivation_trees():
            print(der)
            output_tree = HybridTree()
            tokens = tree1.token_yield()
            dcp_to_hybridtree(output_tree,
                              DCP_evaluator(der).getEvaluation(), tokens,
                              False, construct_conll_token)
            print(tree1)
            print(output_tree)

        print("completed test", file=stderr)
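
    # Shared driver: induce a grammar from TiGer, parse each tree with the
    # given parser type, check every derivation for integrity and pairwise
    # difference, and count sentences without a complete match.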
    def generic_parsing_test(self, parser_type, limit_train, limit_test,
                             compare_order):
        def filter_by_id(n, trees):
            for j, tree in enumerate(trees):
                if j in n:
                    yield tree

        # params
        train = 'res/dependency_conll/german/tiger/train/german_tiger_train.conll'
        test = train
        # test = 'res/dependency_conll/german/tiger/test/german_tiger_test.conll'
        trees = parse_conll_corpus(train, False, limit_train)
        primary_labelling = the_labeling_factory(
        ).create_simple_labeling_strategy("childtop", "deprel")
        term_labelling = the_terminal_labeling_factory().get_strategy('pos')
        start = 'START'
        recursive_partitioning = [cfg]

        (n_trees, grammar_prim) = induce_grammar(trees, primary_labelling,
                                                 term_labelling.token_label,
                                                 recursive_partitioning, start)

        parser_type.preprocess_grammar(grammar_prim, term_labelling)

        trees = parse_conll_corpus(test, False, limit_test)

        count_derivs = {}
        no_complete_match = 0

        for i, tree in enumerate(trees):
            print("Parsing tree for ", i, file=stderr)

            print(tree, file=stderr)

            parser = parser_type(grammar_prim, tree)
            self.assertTrue(parser.recognized())
            count_derivs[i] = 0

            print("Found derivations for ", i, file=stderr)
            j = 0

            derivations = []

            for der in parser.all_derivation_trees():
                self.assertTrue(
                    der.check_integrity_recursive(der.root_id(), start))

                print(count_derivs[i], file=stderr)
                print(der, file=stderr)

                output_tree = HybridTree()
                tokens = tree.token_yield()

                the_yield = der.compute_yield()
                # print >>stderr, the_yield
                tokens2 = list(
                    map(lambda pos: construct_conll_token('_', pos),
                        the_yield))

                dcp_to_hybridtree(output_tree,
                                  DCP_evaluator(der).getEvaluation(),
                                  tokens2,
                                  False,
                                  construct_conll_token,
                                  reorder=False)
                print(tree, file=stderr)
                print(output_tree, file=stderr)

                self.compare_hybrid_trees(tree, output_tree, compare_order)
                count_derivs[i] += 1
                derivations.append(der)

            self.assertTrue(
                sDCPParserTest.pairwise_different(
                    derivations, sDCPParserTest.compare_derivations))
            self.assertEqual(len(derivations), count_derivs[i])

            if count_derivs[i] == 0:
                no_complete_match += 1

        for key in count_derivs:
            print(key, count_derivs[key])

        print("# trees with no complete match:", no_complete_match)
Example #17
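    # Induce a grammar from two example trees, parse a POS string, rebuild a
    # hybrid tree from the best derivation, and finally linearize a grammar.
    # The long comment block below preserves earlier experiments with
    # top/bottom label computation and manual rule construction.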
    def test_single_root_induction(self):
        tree = hybrid_tree_1()
        # print tree.children("v")
        # print tree
        #
        # for id_set in ['v v1 v2 v21'.split(' '), 'v1 v2'.split(' '),
        # 'v v21'.split(' '), ['v'], ['v1'], ['v2'], ['v21']]:
        # print id_set, 'top:', top(tree, id_set), 'bottom:', bottom(tree, id_set)
        # print id_set, 'top_max:', max(tree, top(tree, id_set)), 'bottom_max:', max(tree, bottom(tree, id_set))
        #
        # print "some rule"
        # for mem, arg in [(-1, 0), (0,0), (1,0)]:
        # print create_DCP_rule(mem, arg, top_max(tree, ['v','v1','v2','v21']), bottom_max(tree, ['v','v1','v2','v21']),
        # [(top_max(tree, l), bottom_max(tree, l)) for l in [['v1', 'v2'], ['v', 'v21']]])
        #
        #
        # print "some other rule"
        # for mem, arg in [(-1,1),(1,0)]:
        # print create_DCP_rule(mem, arg, top_max(tree, ['v1','v2']), bottom_max(tree, ['v1','v2']),
        # [(top_max(tree, l), bottom_max(tree, l)) for l in [['v1'], ['v2']]])
        #
        # print 'strict:' , strict_labeling(tree, top_max(tree, ['v','v21']), bottom_max(tree, ['v','v21']))
        # print 'child:' , child_labeling(tree, top_max(tree, ['v','v21']), bottom_max(tree, ['v','v21']))
        # print '---'
        # print 'strict: ', strict_labeling(tree, top_max(tree, ['v1','v21']), bottom_max(tree, ['v1','v21']))
        # print 'child: ', child_labeling(tree, top_max(tree, ['v1','v21']), bottom_max(tree, ['v1','v21']))
        # print '---'
        # print 'strict:' , strict_labeling(tree, top_max(tree, ['v','v1', 'v21']), bottom_max(tree, ['v','v1', 'v21']))
        # print 'child:' , child_labeling(tree, top_max(tree, ['v','v1', 'v21']), bottom_max(tree, ['v','v1', 'v21']))

        tree2 = hybrid_tree_2()

        # print tree2.children("v")
        # print tree2
        #
        # print 'siblings v211', tree2.siblings('v211')
        # print top(tree2, ['v','v1', 'v211'])
        # print top_max(tree2, ['v','v1', 'v211'])
        #
        # print '---'
        # print 'strict:' , strict_labeling(tree2, top_max(tree2, ['v','v1', 'v211']), bottom_max(tree2, ['v','v11', 'v211']))
        # print 'child:' , child_labeling(tree2, top_max(tree2, ['v','v1', 'v211']), bottom_max(tree2, ['v','v11', 'v211']))

        # rec_par = ('v v1 v2 v21'.split(' '),
        # [('v1 v2'.split(' '), [(['v1'],[]), (['v2'],[])])
        #                ,('v v21'.split(' '), [(['v'],[]), (['v21'],[])])
        #            ])
        #
        # grammar = LCFRS(nonterminal_str(tree, top_max(tree, rec_par[0]), bottom_max(tree, rec_par[0]), 'strict'))
        #
        # add_rules_to_grammar_rec(tree, rec_par, grammar, 'child')
        #
        # grammar.make_proper()
        # print grammar

        print(tree.recursive_partitioning())

        terminal_labeling = the_terminal_labeling_factory().get_strategy('pos')

        (_, grammar) = induce_grammar(
            [tree, tree2],
            the_labeling_factory().create_simple_labeling_strategy(
                'empty', 'pos'),
            # the_labeling_factory().create_simple_labeling_strategy('child', 'pos+deprel'),
            terminal_labeling.token_label,
            [direct_extraction],
            'START')
        print(max([grammar.fanout(nont) for nont in grammar.nonts()]))
        print(grammar)

        parser = LCFRS_parser(grammar, 'NP N V V'.split(' '))
        print(parser.best_derivation_tree())

        tokens = [
            construct_conll_token(form, pos) for form, pos in zip(
                'Piet Marie helpen lezen'.split(' '), 'NP N V V'.split(' '))
        ]
        hybrid_tree = HybridTree()
        hybrid_tree = parser.dcp_hybrid_tree_best_derivation(
            hybrid_tree, tokens, True, construct_conll_token)
        print(list(map(str, hybrid_tree.full_token_yield())))
        print(hybrid_tree)

        string = "foo"
        dcp_string = DCP_string(string)
        dcp_string.set_edge_label("bar")
        print(dcp_string, dcp_string.edge_label())

        linearize(
            grammar,
            the_labeling_factory().create_simple_labeling_strategy(
                'child', 'pos+deprel'),
            the_terminal_labeling_factory().get_strategy('pos'), sys.stdout)
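
# Induce a grammar with the given labeling and partitioning strategy,
# optionally count derivation trees and parse the test corpus, and append
# attachment scores (computed via eval.pl) and parse times to results.txt.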
def trainAndEval(strategy,
                 labelling1,
                 labelling2,
                 fanout,
                 parser_type,
                 train,
                 test,
                 cDT,
                 parseStrings,
                 ignore_punctuation=False):
    file = open('results.txt', 'a')
    term_labelling = the_terminal_labeling_factory().get_strategy('pos')
    recursive_partitioning = d_i.the_recursive_partitioning_factory(
    ).get_partitioning('fanout-' + str(fanout) + strategy)
    primary_labelling = d_l.the_labeling_factory(
    ).create_simple_labeling_strategy(labelling1, labelling2)

    trees = parse_conll_corpus(train, False, train_limit)
    if ignore_punctuation:
        trees = disconnect_punctuation(trees)
    (n_trees, grammar) = d_i.induce_grammar(trees, primary_labelling,
                                            term_labelling.token_label,
                                            recursive_partitioning, start)

    # write current transformation strategy and hyperparameters to results.txt
    if strategy == '':
        file.write('rtl ' + labelling1 + ' ' + labelling2 +
                   '    maximal fanout:' + str(fanout))
    else:
        splitList = strategy.split('-')
        if splitList[1] == 'left':
            file.write('ltr ' + labelling1 + ' ' + labelling2 +
                       '    maximal fanout:' + str(fanout))
        elif splitList[1] == 'random':
            file.write('random seed:' + splitList[2] + ' ' + labelling1 + ' ' +
                       labelling2 + ' maximal fanout:' + str(fanout))
        elif splitList[1] == 'no':
            if splitList[4] == 'random':
                file.write('nnont fallback:random seed:' + splitList[5] + ' ' +
                           labelling1 + ' ' + labelling2 + ' maximal fanout:' +
                           str(fanout))
            elif splitList[4] == 'ltr':
                file.write('nnont fallback:ltr' + ' ' + labelling1 + ' ' +
                           labelling2 + ' maximal fanout:' + str(fanout))
            elif splitList[4] == 'rtl':
                file.write('nnont fallback:rtl' + ' ' + labelling1 + ' ' +
                           labelling2 + ' maximal fanout:' + str(fanout))
            else:
                file.write('nnont fallback:argmax' + ' ' + labelling1 + ' ' +
                           labelling2 + ' maximal fanout:' + str(fanout))
        else:  # argmax
            file.write('argmax ' + labelling1 + ' ' + labelling2 +
                       ' maximal fanout:' + str(fanout))
    file.write('\n')

    res = ''

    res += '#nonts:' + str(len(grammar.nonts()))
    res += ' #rules:' + str(len(grammar.rules()))

    file.write(res)
    res = ''

    # The following code counts the number of derivation trees per input
    # hypergraph (requires the tree parser)
    if cDT:
        tree_parser.preprocess_grammar(grammar, term_labelling)

        trees = parse_conll_corpus(train, False, train_limit)
        if ignore_punctuation:
            trees = disconnect_punctuation(trees)

        derCount = 0
        derMax = 0
        for tree in trees:
            parser = tree_parser(grammar, tree)  # if tree parser is used
            der = parser.count_derivation_trees()
            if der > derMax:
                derMax = der
            derCount += der

        res += "\n#derivation trees:  average: " + str(
            1.0 * derCount / n_trees)
        res += " maximal: " + str(derMax)
    file.write(res)

    res = ''
    total_time = 0.0

    # The following code evaluates the grammar with a string parser
    if parseStrings:
        parser_type.preprocess_grammar(grammar)

        trees = parse_conll_corpus(test, False, test_limit)
        if ignore_punctuation:
            trees = disconnect_punctuation(trees)

        i = 0
        with open(result, 'w') as result_file:
            failures = 0
            for tree in trees:
                # time.clock() was removed in Python 3.8; use perf_counter()
                time_stamp = time.perf_counter()
                i += 1
                # if i % 100 == 0:
                #     print('.', end='')
                #     sys.stdout.flush()

                parser = parser_type(grammar, tree_yield(tree.token_yield()))

                time_stamp = time.perf_counter() - time_stamp
                total_time += time_stamp

                cleaned_tokens = copy.deepcopy(tree.full_token_yield())
                for token in cleaned_tokens:
                    token.set_edge_label('_')
                h_tree = HybridTree(tree.sent_label())
                h_tree = parser.dcp_hybrid_tree_best_derivation(
                    h_tree, cleaned_tokens, ignore_punctuation,
                    construct_conll_token)

                if h_tree:
                    result_file.write(tree_to_conll_str(h_tree))
                    result_file.write('\n\n')
                else:
                    failures += 1
                    forms = [token.form() for token in tree.full_token_yield()]
                    poss = [token.pos() for token in tree.full_token_yield()]
                    result_file.write(
                        tree_to_conll_str(
                            fall_back_left_branching_token(cleaned_tokens)))
                    result_file.write('\n\n')

        res += "\nattachment scores:\nno punctuation: "
        out = subprocess.check_output(
            ["perl", "../util/eval.pl", "-g", test, "-s", result,
             "-q"]).decode('utf-8')  # check_output returns bytes
        match = re.search(r'[^=]*= (\d+\.\d+)[^=]*= (\d+\.\d+).*', out)
        res += ' labelled:' + match.group(1)  # labeled attachment score
        res += ' unlabelled:' + match.group(2)  # unlabeled attachment score
        res += "\npunctuation: "
        out = subprocess.check_output(
            ["perl", "../util/eval.pl", "-g", test, "-s", result, "-q",
             "-p"]).decode('utf-8')
        match = re.search(r'[^=]*= (\d+\.\d+)[^=]*= (\d+\.\d+).*', out)
        res += ' labelled:' + match.group(1)
        res += ' unlabelled:' + match.group(2)

        res += "\nparse time: " + str(total_time)

    file.write(res)
    file.write('\n\n\n')
    file.close()
    def get_trees(self):
        # NOTE: the opening of this method was truncated in the source; the
        # cached branch below is a plausible reconstruction.
        if self._trees is not None:
            for tree in self._trees:
                yield tree
        else:
            self._trees = []
            for tree in length_limit(
                    parse_conll_corpus(self._path,
                                       False,
                                       limit=self._end,
                                       start=self._start), self._max_length):
                self._trees.append(tree)
                yield tree


# term_labelling =  #d_i.the_terminal_labeling_factory().get_strategy('pos')
# recursive_partitioning = d_i.the_recursive_partitioning_factory().getPartitioning('fanout-1')
# recursive_partitioning = d_i.the_recursive_partitioning_factory().getPartitioning('left-branching')
primary_labelling = d_l.the_labeling_factory().create_simple_labeling_strategy(
    'child', 'pos+deprel')
secondary_labelling = d_l.the_labeling_factory(
).create_simple_labeling_strategy('strict', 'deprel')
ternary_labelling = d_l.the_labeling_factory().create_simple_labeling_strategy(
    'child', 'deprel')
child_top_labelling = d_l.the_labeling_factory(
).create_simple_labeling_strategy('childtop', 'deprel')
empty_labelling = d_l.the_labeling_factory().create_simple_labeling_strategy(
    'empty', 'pos')

ignore_punctuation = False


@plac.annotations(
    recompileGrammar=('force (repeated) grammar induction and compilation',
                      'option', None, str),
from hybridtree.monadic_tokens import construct_conll_token
import dependency.induction as d_i
import dependency.labeling as d_l
import time
import parser.parser_factory
import copy
import subprocess

TEST = 'res/negra-dep/negra-lower-punct-test.conll'
TRAIN = 'res/negra-dep/negra-lower-punct-train.conll'
RESULT = '.tmp/cascade-parse-results.conll'
START = 'START'
TERMINAL_LABELLING = grammar.induction.terminal_labeling.the_terminal_labeling_factory().get_strategy('pos')
RECURSIVE_PARTITIONING = grammar.induction.recursive_partitioning.the_recursive_partitioning_factory(). \
    get_partitioning('fanout-1')
PRIMARY_LABELLING = d_l.the_labeling_factory().create_simple_labeling_strategy('child', 'pos+deprel')
SECONDARY_LABELLING = d_l.the_labeling_factory().create_simple_labeling_strategy('child', 'pos')
TERNARY_LABELLING = d_l.the_labeling_factory().create_simple_labeling_strategy('child', 'deprel')

PARSER_TYPE = parser.parser_factory.the_parser_factory().getParser("fanout-1")
TREE_YIELD = TERMINAL_LABELLING.prepare_parser_input


def main(limit=100000, ignore_punctuation=False):
    if PARSER_TYPE.__name__ != 'GFParser':
        print('GFParser not found, using', PARSER_TYPE.__name__, 'instead!')
        print('Please install grammatical framework to reproduce experiments.')

    test_limit = 10000
    trees = parse_conll_corpus(TRAIN, False, limit)
    if ignore_punctuation:
def main():
    # induce grammar from a corpus
    trees = parse_conll_corpus(train, False, limit_train)
    nonterminal_labelling = the_labeling_factory(
    ).create_simple_labeling_strategy("childtop", "deprel")
    term_labelling = the_terminal_labeling_factory().get_strategy('pos')
    start = 'START'
    recursive_partitioning = [cfg]
    _, grammar = induce_grammar(trees, nonterminal_labelling,
                                term_labelling.token_label,
                                recursive_partitioning, start)

    # compute some derivations
    derivations = obtain_derivations(grammar, term_labelling)

    # create derivation manager and add derivations
    manager = PyDerivationManager(grammar)
    manager.convert_derivations_to_hypergraphs(derivations)
    manager.serialize(b"/tmp/derivations.txt")

    # build and configure split/merge trainer and supplementary objects

    rule_to_nonterminals = []
    for i in range(0, len(grammar.rule_index())):
        rule = grammar.rule_index(i)
        nonts = [
            manager.get_nonterminal_map().object_index(rule.lhs().nont())
        ] + [
            manager.get_nonterminal_map().object_index(nont)
            for nont in rule.rhs()
        ]
        rule_to_nonterminals.append(nonts)

    grammarInfo = PyGrammarInfo(grammar, manager.get_nonterminal_map())
    storageManager = PyStorageManager()
    builder = PySplitMergeTrainerBuilder(manager, grammarInfo)
    builder.set_em_epochs(20)
    builder.set_percent_merger(60.0)

    splitMergeTrainer = builder.build()

    latentAnnotation = [
        build_PyLatentAnnotation_initial(grammar, grammarInfo, storageManager)
    ]

    for i in range(max_cycles + 1):
        latentAnnotation.append(
            splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
        # pickle.dump(map(lambda la: la.serialize(), latentAnnotation), open(sm_info_path, 'wb'))
        smGrammar = build_sm_grammar(latentAnnotation[i],
                                     grammar,
                                     grammarInfo,
                                     rule_pruning=0.0001,
                                     rule_smoothing=0.01)
        print("Cycle: ", i, "Rules: ", len(smGrammar.rules()))

        if parsing:
            parser = GFParser(smGrammar)

            trees = parse_conll_corpus(test, False, limit_test)
            for tree in trees:
                parser.set_input(
                    term_labelling.prepare_parser_input(tree.token_yield()))
                parser.parse()
                if parser.recognized():
                    print(
                        derivation_to_hybrid_tree(
                            parser.best_derivation_tree(),
                            [token.pos() for token in tree.token_yield()],
                            [token.form() for token in tree.token_yield()],
                            construct_constituent_token))
Example #22
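    # Shares the k-best parsing scaffolding of the minimum-risk test above,
    # but prints a matrix showing which of the 50 best derivations map to the
    # same hybrid tree.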
    def test_k_best_parsing(self):
        limit_train = 20
        limit_test = 10
        train = 'res/dependency_conll/german/tiger/train/german_tiger_train.conll'
        test = train
        parser_type = GFParser_k_best
        # test = '../../res/dependency_conll/german/tiger/test/german_tiger_test.conll'
        trees = parse_conll_corpus(train, False, limit_train)
        primary_labelling = the_labeling_factory(
        ).create_simple_labeling_strategy("childtop", "deprel")
        term_labelling = the_terminal_labeling_factory().get_strategy('pos')
        start = 'START'
        recursive_partitioning = [cfg]

        (n_trees, grammar_prim) = induce_grammar(trees, primary_labelling,
                                                 term_labelling.token_label,
                                                 recursive_partitioning, start)

        parser_type.preprocess_grammar(grammar_prim)
        tree_yield = term_labelling.prepare_parser_input

        trees = parse_conll_corpus(test, False, limit_test)

        for i, tree in enumerate(trees):
            print("Parsing sentence ", i, file=stderr)

            # print >>stderr, tree

            parser = parser_type(grammar_prim,
                                 tree_yield(tree.token_yield()),
                                 k=50)

            self.assertTrue(parser.recognized())

            derivations = list(parser.k_best_derivation_trees())
            print("# derivations: ", len(derivations), file=stderr)
            h_trees = []
            current_weight = 0
            weights = []
            derivation_list = []
            for weight, der in derivations:
                # print >>stderr, exp(-weight)
                # print >>stderr, der

                self.assertNotIn(der, derivation_list)

                derivation_list.append(der)

                # TODO this should hold, but it looks like a GF bug!
                # self.assertGreaterEqual(weight, current_weight)
                current_weight = weight

                dcp = DCP_evaluator(der).getEvaluation()
                h_tree = HybridTree()
                cleaned_tokens = copy.deepcopy(tree.full_token_yield())
                dcp_to_hybridtree(h_tree, dcp, cleaned_tokens, False,
                                  construct_conll_token)

                h_trees.append(h_tree)
                weights.append(weight)

                # print >>stderr, h_tree

            # print a matrix indicating which derivations result
            # in the same hybrid tree
            for i, h_tree1 in enumerate(h_trees):
                for h_tree2 in h_trees:
                    if h_tree1 == h_tree2:
                        print("x", end=' ', file=stderr)
                    else:
                        print("", end=' ', file=stderr)
                print(weights[i], file=stderr)
            print(file=stderr)
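
# End-to-end split/merge experiment: grammar induction, optional EM and
# discriminative reducts, validation-guided split/merge cycles, and parsing
# with several parser backends.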
def main(limit=3000,
         test_limit=sys.maxsize,  # sys.maxint no longer exists in Python 3
         max_length=sys.maxsize,
         dir=dir,
         train='../res/negra-dep/negra-lower-punct-train.conll',
         test='../res/negra-dep/negra-lower-punct-test.conll',
         recursive_partitioning='cfg',
         nonterminal_labeling='childtop-deprel',
         terminal_labeling='form-unk-30/pos',
         emEpochs=20,
         emTieBreaking=True,
         emInit="rfe",
         splitRandomization=1.0,
         mergePercentage=85.0,
         smCycles=6,
         rule_pruning=0.0001,
         rule_smoothing=0.01,
         validation=True,
         validationMethod='likelihood',
         validationCorpus=None,
         validationSplit=20,
         validationDropIterations=6,
         seed=1337,
         discr=False,
         maxScaleDiscr=10,
         recompileGrammar="True",
         retrain=False,
         parsing=True,
         reparse=False,
         parser="CFG",
         k_best=50,
         minimum_risk=False,
         oracle_parse=False):

    # set various parameters
    recompileGrammar = recompileGrammar == "True"

    # print(recompileGrammar)

    def result(gram, add=None):
        if add is not None:
            return os.path.join(
                dir, gram + '_experiment_parse_results_' + add + '.conll')
        else:
            return os.path.join(dir, gram + '_experiment_parse_results.conll')

    recursive_partitioning = grammar.induction.recursive_partitioning.the_recursive_partitioning_factory(
    ).get_partitioning(recursive_partitioning)
    top_level, low_level = tuple(nonterminal_labeling.split('-'))
    nonterminal_labeling = d_l.the_labeling_factory(
    ).create_simple_labeling_strategy(top_level, low_level)

    if parser == "CFG":
        assert all([
            rp.__name__
            in ["left_branching", "right_branching", "cfg", "fanout_1"]
            for rp in recursive_partitioning
        ])
        parser = CFGParser
    elif parser == "GF":
        parser = GFParser
    elif parser == "GF-k-best":
        parser = GFParser_k_best
    elif parser == "CoarseToFine":
        parser = Coarse_to_fine_parser
    elif parser == "FST":
        # recursive_partitioning was resolved to a list of functions above,
        # so compare by name instead of against the raw string
        rp_names = [rp.__name__ for rp in recursive_partitioning]
        if rp_names == ["left_branching"]:
            parser = LeftBranchingFSTParser
        elif rp_names == ["right_branching"]:
            parser = RightBranchingFSTParser
        else:
            assert False, "expected left/right-branching recursive partitioning for FST parsing"

    if validation:
        if validationCorpus is not None:
            corpus_validation = Corpus(validationCorpus)
            train_limit = limit
        else:
            train_limit = int(limit * (100.0 - validationSplit) / 100.0)
            corpus_validation = Corpus(train, start=train_limit, end=limit)
    else:
        train_limit = limit

    corpus_induce = Corpus(train, end=limit)
    corpus_train = Corpus(train, end=train_limit)
    corpus_test = Corpus(test, end=test_limit)

    match = re.match(r'^form-unk-(\d+)-morph.*$', terminal_labeling)
    if match:
        unk_threshold = int(match.group(1))
        term_labelling = grammar.induction.terminal_labeling.FormPosTerminalsUnkMorph(
            corpus_induce.get_trees(),
            unk_threshold,
            pos_filter=["NE", "CARD"],
            add_morph={
                'NN': ['case', 'number', 'gender']
                # , 'NE': ['case', 'number', 'gender']
                # , 'VMFIN': ['number', 'person']
                # , 'VVFIN': ['number', 'person']
                # , 'VAFIN': ['number', 'person']
            })
    else:
        match = re.match(r'^form-unk-(\d+).*$', terminal_labeling)
        if match:
            unk_threshold = int(match.group(1))
            term_labelling = grammar.induction.terminal_labeling.FormPosTerminalsUnk(
                corpus_induce.get_trees(),
                unk_threshold,
                pos_filter=["NE", "CARD"])
        else:
            term_labelling = grammar.induction.terminal_labeling.the_terminal_labeling_factory(
            ).get_strategy(terminal_labeling)

    if not os.path.isdir(dir):
        os.makedirs(dir)

    # start actual training
    # we use the training corpus until limit for grammar induction (i.e., also the validation section)
    print("Computing baseline id: ")
    baseline_id = grammar_id(corpus_induce, nonterminal_labeling,
                             term_labelling, recursive_partitioning)
    print(baseline_id)
    baseline_path = compute_grammar_name(dir, baseline_id, "baseline")

    if recompileGrammar or not os.path.isfile(baseline_path):
        print("Inducing grammar from corpus")
        (n_trees, baseline_grammar) = d_i.induce_grammar(
            corpus_induce.get_trees(), nonterminal_labeling,
            term_labelling.token_label, recursive_partitioning, start)
        print("Induced grammar using", n_trees, "trees.")
        pickle.dump(baseline_grammar, open(baseline_path, 'wb'))
    else:
        print("Loading grammar from file")
        baseline_grammar = pickle.load(open(baseline_path, 'rb'))

    print("Rules: ", len(baseline_grammar.rules()))

    if parsing:
        parser_ = GFParser_k_best if parser == Coarse_to_fine_parser else parser
        baseline_parser = do_parsing(baseline_grammar,
                                     corpus_test,
                                     term_labelling,
                                     result,
                                     baseline_id,
                                     parser_,
                                     k_best=k_best,
                                     minimum_risk=minimum_risk,
                                     oracle_parse=oracle_parse,
                                     recompile=recompileGrammar,
                                     dir=dir,
                                     reparse=reparse)

    if True:
        em_trained = pickle.load(open(baseline_path, 'rb'))
        reduct_path = compute_reduct_name(dir, baseline_id, corpus_train)
        if recompileGrammar or not os.path.isfile(reduct_path):
            trace = compute_reducts(em_trained, corpus_train.get_trees(),
                                    term_labelling)
            trace.serialize(reduct_path)
        else:
            print("loading trace")
            trace = PySDCPTraceManager(em_trained, term_labelling)
            trace.load_traces_from_file(reduct_path)

        if discr:
            reduct_path_discr = compute_reduct_name(dir, baseline_id,
                                                    corpus_train, '_discr')
            if recompileGrammar or not os.path.isfile(reduct_path_discr):
                trace_discr = compute_LCFRS_reducts(
                    em_trained,
                    corpus_train.get_trees(),
                    terminal_labelling=term_labelling,
                    nonterminal_map=trace.get_nonterminal_map())
                trace_discr.serialize(reduct_path_discr)
            else:
                print("loading trace discriminative")
                trace_discr = PyLCFRSTraceManager(em_trained,
                                                  trace.get_nonterminal_map())
                trace_discr.load_traces_from_file(reduct_path_discr)

        # todo refactor EM training, to use the LA version (but without any splits)
        """
        em_trained_path_ = em_trained_path(dir, grammar_id, n_epochs=emEpochs, init=emInit, tie_breaking=emTieBreaking, seed=seed)

        if recompileGrammar or retrain or not os.path.isfile(em_trained_path_):
            emTrainer = PyEMTrainer(trace)
            emTrainer.em_training(em_trained, n_epochs=emEpochs, init=emInit, tie_breaking=emTieBreaking, seed=seed)
            pickle.dump(em_trained, open(em_trained_path_, 'wb'))
        else:
            em_trained = pickle.load(open(em_trained_path_, 'rb'))

        if parsing:
            do_parsing(em_trained, test_limit, ignore_punctuation, term_labelling, recompileGrammar or retrain, [dir, "em_trained_gf_grammar"])
        """

        grammarInfo = PyGrammarInfo(baseline_grammar,
                                    trace.get_nonterminal_map())
        storageManager = PyStorageManager()

        builder = PySplitMergeTrainerBuilder(trace, grammarInfo)
        builder.set_em_epochs(emEpochs)
        builder.set_smoothing_factor(rule_smoothing)
        builder.set_split_randomization(splitRandomization, seed + 1)
        if discr:
            builder.set_discriminative_expector(trace_discr,
                                                maxScale=maxScaleDiscr,
                                                threads=1)
        else:
            builder.set_simple_expector(threads=1)
        if validation:
            if validationMethod == "likelihood":
                reduct_path_validation = compute_reduct_name(
                    dir, baseline_id, corpus_validation)
                if recompileGrammar or not os.path.isfile(
                        reduct_path_validation):
                    validation_trace = compute_reducts(
                        em_trained, corpus_validation.get_trees(),
                        term_labelling)
                    validation_trace.serialize(reduct_path_validation)
                else:
                    print("loading trace validation")
                    validation_trace = PySDCPTraceManager(
                        em_trained, term_labelling)
                    validation_trace.load_traces_from_file(
                        reduct_path_validation)
                builder.set_simple_validator(validation_trace,
                                             maxDrops=validationDropIterations,
                                             threads=1)
            else:
                validator = build_score_validator(
                    baseline_grammar, grammarInfo, trace.get_nonterminal_map(),
                    storageManager, term_labelling, baseline_parser,
                    corpus_validation, validationMethod)
                builder.set_score_validator(validator,
                                            validationDropIterations)
        splitMergeTrainer = builder.set_percent_merger(mergePercentage).build()
        if validation:
            splitMergeTrainer.setMaxDrops(1, mode="smoothing")
            splitMergeTrainer.setEMepochs(1, mode="smoothing")

        sm_info_path = compute_sm_info_path(dir, baseline_id, emEpochs,
                                            rule_smoothing, splitRandomization,
                                            seed, discr, validation,
                                            corpus_validation, emInit)

        if (not recompileGrammar) and (
                not retrain) and os.path.isfile(sm_info_path):
            print("Loading splits and weights of LA rules")
            # wrap in list(): map() is lazy in Python 3 and is indexed below
            latentAnnotation = list(
                map(
                    lambda t: build_PyLatentAnnotation(
                        t[0], t[1], t[2], grammarInfo, storageManager),
                    pickle.load(open(sm_info_path, 'rb'))))
        else:
            # latentAnnotation = [build_PyLatentAnnotation_initial(em_trained, grammarInfo, storageManager)]
            latentAnnotation = [
                build_PyLatentAnnotation_initial(baseline_grammar, grammarInfo,
                                                 storageManager)
            ]

        for cycle in range(smCycles + 1):
            if cycle < len(latentAnnotation):
                smGrammar = latentAnnotation[cycle].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=rule_pruning
                    # , rule_smoothing=rule_smoothing
                )
            else:
                # setting the seed to achieve reproducibility in case of continued training
                splitMergeTrainer.reset_random_seed(seed + cycle + 1)
                latentAnnotation.append(
                    splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
                pickle.dump([la.serialize() for la in latentAnnotation],
                            open(sm_info_path, 'wb'))
                smGrammar = latentAnnotation[cycle].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=rule_pruning
                    # , rule_smoothing=rule_smoothing
                )
            print("Cycle: ", cycle, "Rules: ", len(smGrammar.rules()))
            if parsing:
                grammar_identifier = compute_sm_grammar_id(
                    baseline_id, emEpochs, rule_smoothing, splitRandomization,
                    seed, discr, validation, corpus_validation, emInit, cycle)
                if parser == Coarse_to_fine_parser:
                    opt = {
                        # all annotations up to the current cycle
                        # (alternatively just latentAnnotation[cycle])
                        'latentAnnotation': latentAnnotation[:cycle + 1],
                        'grammarInfo': grammarInfo,
                        'nontMap': trace.get_nonterminal_map()
                    }
                    do_parsing(baseline_grammar,
                               corpus_test,
                               term_labelling,
                               result,
                               grammar_identifier,
                               parser,
                               k_best=k_best,
                               minimum_risk=minimum_risk,
                               oracle_parse=oracle_parse,
                               recompile=recompileGrammar,
                               dir=dir,
                               reparse=reparse,
                               opt=opt)
                else:
                    do_parsing(smGrammar,
                               corpus_test,
                               term_labelling,
                               result,
                               grammar_identifier,
                               parser,
                               k_best=k_best,
                               minimum_risk=minimum_risk,
                               oracle_parse=oracle_parse,
                               recompile=recompileGrammar,
                               dir=dir,
                               reparse=reparse)
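The persistence logic above (serialize the list of latent annotations after every split/merge cycle, reload it on restart) is the part most worth isolating. Here is a minimal sketch of that cache-or-recompute pattern with a plain dict standing in for a latent annotation; load_or_init, run_cycles, and the /tmp path are hypothetical stand-ins, since PyLatentAnnotation and its serialization format are project-specific.

import os
import pickle

def load_or_init(path, init_fn):
    # reload cached annotations if present, else start a fresh list
    if os.path.isfile(path):
        with open(path, 'rb') as f:
            return pickle.load(f)
    return [init_fn()]

def run_cycles(path, n_cycles, train_fn, init_fn):
    annotations = load_or_init(path, init_fn)
    for cycle in range(n_cycles + 1):
        if cycle >= len(annotations):
            annotations.append(train_fn(annotations[-1]))
            # dump a real list (not a lazy map) so it pickles cleanly
            with open(path, 'wb') as f:
                pickle.dump(annotations, f)
    return annotations

# toy usage: each "annotation" just records its cycle number
print(run_cycles('/tmp/sm_info.pkl', 3,
                 train_fn=lambda prev: {'cycle': prev['cycle'] + 1},
                 init_fn=lambda: {'cycle': 0}))

Rerunning the toy example loads the pickled list and skips training, mirroring the os.path.isfile(sm_info_path) check above.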
Example #24
    def test_fst_compilation_right(self):
        if not test_pynini:
            return
        tree = hybrid_tree_1()
        tree2 = hybrid_tree_2()
        terminal_labeling = the_terminal_labeling_factory().get_strategy('pos')

        (_, grammar) = induce_grammar(
            [tree, tree2],
            the_labeling_factory().create_simple_labeling_strategy(
                'empty', 'pos'), terminal_labeling.token_label,
            [right_branching], 'START')

        a, rules = compile_wfst_from_right_branching_grammar(grammar)

        print(repr(a))

        symboltable = a.input_symbols()

        string = 'NP N V V V'.split(' ')

        token_sequence = [
            construct_conll_token(form, lemma) for form, lemma in zip(
                'Piet Marie helpen leren lezen'.split(' '), string)
        ]

        fsa = fsa_from_list_of_symbols(string, symboltable)
        self.assertEqual(
            '0\t1\tNP\tNP\n1\t2\tN\tN\n2\t3\tV\tV\n3\t4\tV\tV\n4\t5\tV\tV\n5\n',
            fsa.text().decode('utf-8'))

        b = compose(fsa, a)

        print(b.input_symbols())
        for i in b.input_symbols():
            print(i)

        print("Input Composition")
        print(b.text(symboltable, symboltable).decode('utf-8'))

        for i, path in enumerate(paths(b)):
            print(i, "th path:", path, end=' ')
            r = list(map(rules.index_object, path))
            d = PolishDerivation(r[1::])
            dcp = DCP_evaluator(d).getEvaluation()
            h = HybridTree()
            dcp_to_hybridtree(h, dcp, token_sequence, False,
                              construct_conll_token)
            h.reorder()
            if h == tree2:
                print("correct")
            else:
                print("incorrect")

        stats = defaultdict(int)
        local_rule_stats(b, stats, 15)

        print(stats)

        print("Shortest path probability")
        best = shortestpath(b)
        best.topsort()
        self.assertAlmostEqual(1.80844898756e-05,
                               pow(e, -float(shortestdistance(best)[-1])))
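        # the shortest distance is a negative log probability in the
        # tropical semiring, hence pow(e, -weight) recovers the probability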
        print(best.text())

        polish_rules = retrieve_rules(best)
        self.assertSequenceEqual(polish_rules, [8, 7, 1, 6, 2, 5, 3, 10, 3, 3])

        polish_rules = list(map(rules.index_object, polish_rules))

        print(polish_rules)

        der = PolishDerivation(polish_rules[1::])

        print(der)

        print(
            derivation_to_hybrid_tree(der, string,
                                      "Piet Marie helpen lezen leren".split(),
                                      construct_conll_token))

        dcp = DCP_evaluator(der).getEvaluation()

        h_tree_2 = HybridTree()
        dcp_to_hybridtree(h_tree_2, dcp, token_sequence, False,
                          construct_conll_token)

        print(h_tree_2)
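The pipeline shape in this test (symbol list, linear FSA over a symbol table, composition with the grammar automaton) can be reproduced in miniature with pynini alone. A sketch written against the same pynini vintage as the test above (newer releases rename acceptor to accep and change the text() return type); fsa_from_symbols is a hypothetical helper approximating fsa_from_list_of_symbols:

import pynini

def fsa_from_symbols(symbols):
    # linear acceptor over a freshly built symbol table
    table = pynini.SymbolTable()
    table.add_symbol('<eps>', 0)
    for s in symbols:
        table.add_symbol(s)
    fsa = pynini.acceptor(' '.join(symbols), token_type=table)
    fsa.set_input_symbols(table)
    fsa.set_output_symbols(table)
    return fsa

fsa = fsa_from_symbols('NP N V V V'.split())
print(fsa.text().decode('utf-8'))    # same arc-per-line format as asserted above
comp = pynini.compose(fsa, fsa)      # same call shape as compose(fsa, a)
print(comp.num_states())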
Example #25
    def test_best_trees(self):
        limit_train = 5000
        limit_test = 100
        train = 'res/dependency_conll/german/tiger/train/german_tiger_train.conll'
        test = train
        parser_type = GFParser_k_best
        # test = '../../res/dependency_conll/german/tiger/test/german_tiger_test.conll'
        trees = parse_conll_corpus(train, False, limit_train)
        primary_labelling = the_labeling_factory(
        ).create_simple_labeling_strategy("child", "pos+deprel")
        term_labelling = the_terminal_labeling_factory().get_strategy('pos')
        start = 'START'
        recursive_partitioning = [cfg]

        (n_trees, grammar_prim) = induce_grammar(trees, primary_labelling,
                                                 term_labelling.token_label,
                                                 recursive_partitioning, start)

        parser_type.preprocess_grammar(grammar_prim)
        tree_yield = term_labelling.prepare_parser_input

        trees = parse_conll_corpus(test, False, limit_test)

        for i, tree in enumerate(trees):
            print("Parsing sentence ", i, file=stderr)

            parser = parser_type(grammar_prim,
                                 tree_yield(tree.token_yield()),
                                 k=200)

            self.assertTrue(parser.recognized())

            viterbi_weight = parser.viterbi_weight()
            viterbi_deriv = parser.viterbi_derivation()

            def der_to_tree(der):
                return dcp_to_hybridtree(
                    HybridTree(),
                    DCP_evaluator(der).getEvaluation(),
                    copy.deepcopy(tree.full_token_yield()), False,
                    construct_conll_token)

            viterbi_tree = der_to_tree(viterbi_deriv)

            ordered_parse_trees = parser.best_trees(der_to_tree)

            best_tree, best_weight, best_witnesses = ordered_parse_trees[0]

            # use a fresh name to avoid shadowing the outer loop variable i
            for rank, (parsed_tree, _, _) in enumerate(ordered_parse_trees):
                if parsed_tree == tree:
                    print("Gold tree is ",
                          rank + 1,
                          " in best tree list",
                          file=stderr)
                    break

            if (viterbi_tree != best_tree
                    and viterbi_weight != best_weight):
                print("viterbi and k-best tree differ", file=stderr)
                print("viterbi: ", viterbi_weight, file=stderr)
                print("k-best: ", best_weight, best_witnesses, file=stderr)
                if False:  # flip to True for verbose tree and CoNLL dumps
                    print(viterbi_tree, file=stderr)
                    print(tree_to_conll_str(viterbi_tree), file=stderr)
                    print(best_tree, file=stderr)
                    print(tree_to_conll_str(best_tree), file=stderr)
                    print("gold tree", file=stderr)
                    print(tree, file=stderr)
                    print(tree_to_conll_str(tree), file=stderr)
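The inner search for the gold tree in the k-best list generalizes to a small helper. A sketch; gold_rank is a hypothetical name, and the triples mirror the (tree, weight, witnesses) tuples that best_trees yields above:

def gold_rank(gold, kbest):
    # return the 1-based rank of the gold tree in a k-best list, or None
    for rank, (tree, _weight, _witnesses) in enumerate(kbest, start=1):
        if tree == gold:
            return rank
    return None

# toy usage with strings standing in for trees
assert gold_rank('t2', [('t1', -1.0, []), ('t2', -1.5, [])]) == 2
assert gold_rank('t9', [('t1', -1.0, [])]) is None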
Example #26
    def test_fst_compilation_left(self):
        if not test_pynini:
            return
        tree = hybrid_tree_1()
        tree2 = hybrid_tree_2()
        terminal_labeling = the_terminal_labeling_factory().get_strategy('pos')

        (_, grammar) = induce_grammar(
            [tree, tree2],
            the_labeling_factory().create_simple_labeling_strategy(
                'empty', 'pos'), terminal_labeling.token_label,
            [left_branching], 'START')

        fst, rules = compile_wfst_from_left_branching_grammar(grammar)

        print(repr(fst))

        symboltable = fst.input_symbols()

        string = ["NP", "N", "V", "V", "V"]

        fsa = fsa_from_list_of_symbols(string, symboltable)
        self.assertEqual(
            fsa.text().decode('utf-8'),
            '0\t1\tNP\tNP\n1\t2\tN\tN\n2\t3\tV\tV\n3\t4\tV\tV\n4\t5\tV\tV\n5\n'
        )

        b = compose(fsa, fst)

        print(b.text(symboltable, symboltable))

        print("Shortest path probability", end=' ')
        best = shortestpath(b)
        best.topsort()
        # self.assertAlmostEquals(pow(e, -float(shortestdistance(best)[-1])), 1.80844898756e-05)
        print(best.text())

        polish_rules = retrieve_rules(best)
        self.assertSequenceEqual(polish_rules, [1, 2, 3, 4, 5, 4, 9, 4, 7, 8])

        polish_rules = list(map(rules.index_object, polish_rules))

        for rule in polish_rules:
            print(rule)
        print()

        der = ReversePolishDerivation(polish_rules[0:-1])
        self.assertTrue(der.check_integrity_recursive(der.root_id()))

        print(der)

        LeftBranchingFSTParser.preprocess_grammar(grammar)
        parser = LeftBranchingFSTParser(grammar, string)
        der_ = parser.best_derivation_tree()

        print(der_)
        self.assertTrue(der_.check_integrity_recursive(der_.root_id()))

        print(
            derivation_to_hybrid_tree(der, string,
                                      "Piet Marie helpen lezen leren".split(),
                                      construct_conll_token))

        print(
            derivation_to_hybrid_tree(der_, string,
                                      "Piet Marie helpen lezen leren".split(),
                                      construct_conll_token))

        dcp = DCP_evaluator(der).getEvaluation()

        h_tree_2 = HybridTree()
        token_sequence = [
            construct_conll_token(form, lemma)
            for form, lemma in zip('Piet Marie helpen lezen leren'.split(' '),
                                   'NP N V V V'.split(' '))
        ]
        dcp_to_hybridtree(h_tree_2, dcp, token_sequence, False,
                          construct_conll_token)

        print(h_tree_2)
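Both FST tests rebuild derivations from rule sequences in (reverse) Polish notation. The stack discipline behind ReversePolishDerivation can be illustrated independently of the grammar machinery; in this sketch each "rule" is a hypothetical (label, arity) pair rather than a real LCFRS rule:

def tree_from_reverse_polish(rules):
    # rebuild a tree bottom-up from a post-order (reverse Polish) sequence
    stack = []
    for label, arity in rules:
        children = [stack.pop() for _ in range(arity)]
        children.reverse()  # pops come out right-to-left
        stack.append((label, children))
    assert len(stack) == 1, "sequence must describe exactly one tree"
    return stack[0]

# 'NP' and 'V' are leaves (arity 0); 'S' combines them
print(tree_from_reverse_polish([('NP', 0), ('V', 0), ('S', 2)]))
# -> ('S', [('NP', []), ('V', [])])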