def split_merge_training(grammar,
                         term_labelling,
                         corpus,
                         cycles,
                         em_epochs,
                         init="rfe",
                         tie_breaking=False,
                         sigma=0.005,
                         seed=0,
                         merge_threshold=0.5,
                         debug=False,
                         rule_pruning=exp(-100)):
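    """Generator: iteratively refine `grammar` by split/merge training.

    After computing reducts of `corpus` and a pre-training EM phase, this
    runs `cycles` split/merge cycles and yields the smoothed grammar
    obtained after each cycle.
    """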
    print("creating trace", file=stderr)
    trace = PySDCPTraceManager(grammar, term_labelling, debug=debug)
    print("computing reducts", file=stderr)
    trace.compute_reducts(corpus)
    print("pre em-training", file=stderr)
    emTrainer = PyEMTrainer(trace)
    emTrainer.em_training(grammar, em_epochs, init, tie_breaking, sigma, seed)
    print("starting actual split/merge training", file=stderr)
    grammarInfo = PyGrammarInfo(grammar, trace.get_nonterminal_map())
    storageManager = PyStorageManager()
    las = [
        build_PyLatentAnnotation_initial(grammar, grammarInfo, storageManager)
    ]

    trainer = PySplitMergeTrainerBuilder(trace, grammarInfo).set_em_epochs(
        em_epochs).set_threshold_merger(merge_threshold).build()

    for i in range(cycles):
        las.append(trainer.split_merge_cycle(las[i]))
        smGrammar = build_sm_grammar(las[-1], grammar, grammarInfo,
                                     rule_pruning)
        yield smGrammar
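
A minimal usage sketch for the generator above; my_grammar, my_term_labelling,
and my_corpus are hypothetical placeholders for objects constructed elsewhere:

for cycle, sm_grammar in enumerate(
        split_merge_training(my_grammar, my_term_labelling, my_corpus,
                             cycles=4, em_epochs=20)):
    print("cycle", cycle, "rules:", len(sm_grammar.rules()))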
    def create_initial_la(self):
        # randomize initial weights and do em training
        la_no_splits = build_PyLatentAnnotation_initial(
            self.base_grammar, self.organizer.grammarInfo,
            self.organizer.storageManager)
        la_no_splits.add_random_noise(seed=self.organizer.seed)
        return la_no_splits
Example #3
    def test_la_viterbi_parsing_3(self):
        grammar = LCFRS("S")
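        # toy LCFRS over the one-letter alphabet {"a"}: S rewrites to B,
        # while rules 3 and 4 recursively concatenate A- and B-spans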

        # rule 0
        lhs = LCFRS_lhs("B")
        lhs.add_arg(["a"])
        grammar.add_rule(lhs, [], 0.25)

        # rule 1
        lhs = LCFRS_lhs("A")
        lhs.add_arg(["a"])
        grammar.add_rule(lhs, [], 0.5)

        # rule 2
        lhs = LCFRS_lhs("S")
        lhs.add_arg([LCFRS_var(0, 0)])
        grammar.add_rule(lhs, ["B"], 1.0)

        # rule 3
        lhs = LCFRS_lhs("A")
        lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
        grammar.add_rule(lhs, ["A", "B"], 0.5)

        # rule 4
        lhs = LCFRS_lhs("B")
        lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
        grammar.add_rule(lhs, ["A", "B"], 0.75)

        grammar.make_proper()

        inp = ["a"] * 3

        nontMap = Enumerator()
        gi = PyGrammarInfo(grammar, nontMap)
        sm = PyStorageManager()
        print(nontMap.object_index("S"))
        print(nontMap.object_index("B"))

        la = build_PyLatentAnnotation_initial(grammar, gi, sm)
        parser = DiscodopKbestParser(grammar,
                                     la=la,
                                     nontMap=nontMap,
                                     grammarInfo=gi,
                                     latent_viterbi_mode=True)
        parser.set_input(inp)
        parser.parse()
        self.assertTrue(parser.recognized())
        der = parser.latent_viterbi_derivation(True)
        print(der)

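        # compare against the best derivation from the k-best list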
        der2 = None

        for w, der_ in parser.k_best_derivation_trees():
            if der2 is None:
                der2 = der_
            print(w, der_)

        print(der2)
Example #4
    def test_projection_based_parser_k_best_hack(self):
        grammar = LCFRS("S")

        # rule 0
        lhs = LCFRS_lhs("B")
        lhs.add_arg(["a"])
        grammar.add_rule(lhs, [], 0.25)

        # rule 1
        lhs = LCFRS_lhs("A")
        lhs.add_arg(["a"])
        grammar.add_rule(lhs, [], 0.5)

        # rule 2
        lhs = LCFRS_lhs("S")
        lhs.add_arg([LCFRS_var(0, 0)])
        grammar.add_rule(lhs, ["B"], 1.0)

        # rule 3
        lhs = LCFRS_lhs("A")
        lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
        grammar.add_rule(lhs, ["A", "B"], 0.5)

        # rule 4
        lhs = LCFRS_lhs("B")
        lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
        grammar.add_rule(lhs, ["A", "B"], 0.75)

        grammar.make_proper()

        inp = ["a"] * 3
        nontMap = Enumerator()
        gi = PyGrammarInfo(grammar, nontMap)
        sm = PyStorageManager()
        la = build_PyLatentAnnotation_initial(grammar, gi, sm)

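        # coarse-to-fine: parse with the base k-best parser, then rescore
        # the candidate derivations with the latent annotation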
        parser = Coarse_to_fine_parser(grammar,
                                       la,
                                       gi,
                                       nontMap,
                                       base_parser_type=GFParser_k_best)
        parser.set_input(inp)
        parser.parse()
        self.assertTrue(parser.recognized())
        der = parser.max_rule_product_derivation()
        print(der)

        der = parser.best_derivation_tree()
        print(der)

        for node in der.ids():
            print(der.getRule(node), der.spanned_ranges(node))
Example #5
    def test_la_viterbi_parsing(self):
        grammar = self.build_grammar()
        inp = ["a"] * 3
        nontMap = Enumerator()
        gi = PyGrammarInfo(grammar, nontMap)
        sm = PyStorageManager()
        la = build_PyLatentAnnotation_initial(grammar, gi, sm)

        parser = DiscodopKbestParser(grammar, la=la, nontMap=nontMap, grammarInfo=gi, latent_viterbi_mode=True)
        parser.set_input(inp)
        parser.parse()
        self.assertTrue(parser.recognized())
        der = parser.best_derivation_tree()
        print(der)

        for node in der.ids():
            print(node, der.getRule(node), der.spanned_ranges(node))
Example #6
    def test_projection_based_parser_k_best_hack(self):
        grammar = self.build_grammar()
        inp = ["a"] * 3
        nontMap = Enumerator()
        gi = PyGrammarInfo(grammar, nontMap)
        sm = PyStorageManager()
        la = build_PyLatentAnnotation_initial(grammar, gi, sm)

        parser = Coarse_to_fine_parser(grammar,
                                       la,
                                       gi,
                                       nontMap,
                                       base_parser_type=GFParser_k_best)
        parser.set_input(inp)
        parser.parse()
        self.assertTrue(parser.recognized())
        der = parser.max_rule_product_derivation()
        print(der)

        der = parser.best_derivation_tree()
        print(der)

        for node in der.ids():
            print(der.getRule(node), der.spanned_ranges(node))
Example #7
def main():
    # induce grammar from a corpus
    trees = parse_conll_corpus(train, False, limit_train)
    nonterminal_labelling = the_labeling_factory(
    ).create_simple_labeling_strategy("childtop", "deprel")
    term_labelling = the_terminal_labeling_factory().get_strategy('pos')
    start = 'START'
    recursive_partitioning = [cfg]
    _, grammar = induce_grammar(trees, nonterminal_labelling,
                                term_labelling.token_label,
                                recursive_partitioning, start)

    # compute some derivations
    derivations = obtain_derivations(grammar, term_labelling)

    # create derivation manager and add derivations
    manager = PyDerivationManager(grammar)
    manager.convert_derivations_to_hypergraphs(derivations)
    manager.serialize(b"/tmp/derivations.txt")

    # build and configure split/merge trainer and supplementary objects

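    # map each rule to the indices of its LHS and RHS nonterminals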
    rule_to_nonterminals = []
    for i in range(0, len(grammar.rule_index())):
        rule = grammar.rule_index(i)
        nonts = [
            manager.get_nonterminal_map().object_index(rule.lhs().nont())
        ] + [
            manager.get_nonterminal_map().object_index(nont)
            for nont in rule.rhs()
        ]
        rule_to_nonterminals.append(nonts)

    grammarInfo = PyGrammarInfo(grammar, manager.get_nonterminal_map())
    storageManager = PyStorageManager()
    builder = PySplitMergeTrainerBuilder(manager, grammarInfo)
    builder.set_em_epochs(20)
    builder.set_percent_merger(60.0)

    splitMergeTrainer = builder.build()

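    # initial latent annotation: a single substate per nonterminal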
    latentAnnotation = [
        build_PyLatentAnnotation_initial(grammar, grammarInfo, storageManager)
    ]

    for i in range(max_cycles + 1):
        latentAnnotation.append(
            splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
        # pickle.dump(map(lambda la: la.serialize(), latentAnnotation), open(sm_info_path, 'wb'))
        smGrammar = build_sm_grammar(latentAnnotation[i],
                                     grammar,
                                     grammarInfo,
                                     rule_pruning=0.0001,
                                     rule_smoothing=0.01)
        print("Cycle: ", i, "Rules: ", len(smGrammar.rules()))

        if parsing:
            parser = GFParser(smGrammar)

            trees = parse_conll_corpus(test, False, limit_test)
            for tree in trees:
                parser.set_input(
                    term_labelling.prepare_parser_input(tree.token_yield()))
                parser.parse()
                if parser.recognized():
                    print(
                        derivation_to_hybrid_tree(
                            parser.best_derivation_tree(),
                            [token.pos() for token in tree.token_yield()],
                            [token.form() for token in tree.token_yield()],
                            construct_constituent_token))
Example #8
    def test_individual_parsing_stages(self):
        grammar = self.build_grammar()

        for r in transform_grammar(grammar):
            pprint(r)

        rule_list = list(transform_grammar(grammar))
        pprint(rule_list)
        disco_grammar = Grammar(rule_list, start=grammar.start())
        print(disco_grammar)

        inp = ["a"] * 3
        estimates = 'SXlrgaps', getestimates(disco_grammar, 40, grammar.start())
        print(type(estimates))
        chart, msg = parse(inp, disco_grammar, estimates=estimates)
        print(chart)
        print(msg)
        chart.filter()
        print("filtered chart")
        print(disco_grammar.nonterminals)
        print(type(disco_grammar.nonterminals))

        print(chart)
        # print(help(chart))

        root = chart.root()
        print("root", root, type(root))
        print(chart.indices(root))
        print(chart.itemstr(root))
        print(chart.stats())
        print("root label", chart.label(root))
        print(root, chart.itemid1(chart.label(root), chart.indices(root)))
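        # enumerate all chart items with their labels, spans, and edges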
        for i in range(1, chart.numitems() + 1):
            print(i, chart.label(i), chart.indices(i), chart.numedges(i))
            if True or len(chart.indices(i)) > 1:
                for edge_num in range(chart.numedges(i)):
                    edge = chart.getEdgeForItem(i, edge_num)
                    if isinstance(edge, tuple):
                        print("\t", disco_grammar.nonterminalstr(chart.label(i)) + "[" + str(i) + "]", "->", ' '.join([disco_grammar.nonterminalstr(chart.label(j)) + "[" + str(j) + "]" for j in [edge[1], edge[2]] if j != 0]))
                    else:
                        print("\t", disco_grammar.nonterminalstr(chart.label(i)) + "[" + str(i) + "]", "->", inp[edge])
        print(chart.getEdgeForItem(root, 0))
        # print(lazykbest(chart, 5))

        manager = PyDerivationManager(grammar)
        manager.convert_chart_to_hypergraph(chart, disco_grammar, debug=True)

        file = tempfile.mktemp()
        print(file)
        manager.serialize(bytes(file, encoding="utf-8"))

        gi = PyGrammarInfo(grammar, manager.get_nonterminal_map())
        sm = PyStorageManager()
        la = build_PyLatentAnnotation_initial(grammar, gi, sm)

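        # project latent-annotation weights onto the hypergraph edges,
        # once with the variational criterion and once without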
        vec = py_edge_weight_projection(la, manager, variational=True, debug=True, log_mode=False)
        print(vec)
        self.assertEqual([1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.25, 0.25, 0.25, 0.25, 1.0], vec)

        vec = py_edge_weight_projection(la, manager, variational=False, debug=True, log_mode=False)
        print(vec)
        self.assertEqual([1.0, 1.0, 1.0, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 1.0], vec)

        der = manager.viterbi_derivation(0, vec, grammar)
        print(der)

        # print(disco_grammar.rulenos)
        # print(disco_grammar.numrules)
        # print(disco_grammar.lexicalbylhs)
        # print(disco_grammar.lexicalbyword)
        # print(disco_grammar.lexicalbynum)
        # print(disco_grammar.origrules, type(disco_grammar.origrules))
        # print(disco_grammar.numbinary)
        # print(disco_grammar.numunary)
        # print(disco_grammar.toid)
        # print(disco_grammar.tolabel)
        # print(disco_grammar.bitpar)
        # striplabelre = re.compile(r'-\d+$')
        # msg = disco_grammar.getmapping(None, None)
        # disco_grammar.getrulemapping(disco_grammar, striplabelre)
        # mapping = disco_grammar.rulemapping
        # print(mapping)
        # for idx, group in enumerate(mapping):
        #     print("Index", idx)
        #     for elem in group:
        #         print(grammar.rule_index(elem))

        # for _, item in zip(range(20), chart.parseforest):
        #     edge = chart.parseforest[item]
        #     print(item, item.binrepr(), item.__repr__(), item.lexidx())
        #     print(type(edge))
        for _ in range(5):
            vec2 = py_edge_weight_projection(la, manager, debug=True, log_mode=True)
            print(vec2)
Example #9
def main():
    # # induce or load grammar
    # if not os.path.isfile(grammar_path):
    #     grammar = LCFRS('START')
    #     for tree in train_corpus:
    #         if not tree.complete() or tree.empty_fringe():
    #             continue
    #         part = recursive_partitioning(tree)
    #         tree_grammar = fringe_extract_lcfrs(tree, part, naming='child', term_labeling=terminal_labeling)
    #         grammar.add_gram(tree_grammar)
    #     grammar.make_proper()
    #     pickle.dump(grammar, open(grammar_path, 'wb'))
    # else:
    #     grammar = pickle.load(open(grammar_path, 'rb'))

    grammar = LCFRS('START')
    for tree in train_corpus:
        if not tree.complete() or tree.empty_fringe():
            continue
        part = recursive_partitioning(tree)
        tree_grammar = fringe_extract_lcfrs(tree,
                                            part,
                                            naming='child',
                                            term_labeling=terminal_labeling)
        grammar.add_gram(tree_grammar)
    grammar.make_proper()

    # # compute or load reducts
    # if not os.path.isfile(reduct_path):
    #     traceTrain = compute_reducts(grammar, train_corpus, terminal_labeling)
    #     traceTrain.serialize(reduct_path)
    # else:
    #     traceTrain = PySDCPTraceManager(grammar, terminal_labeling)
    #     traceTrain.load_traces_from_file(reduct_path)

    traceTrain = compute_reducts(grammar, train_corpus, terminal_labeling)
    traceValidationGenetic = compute_reducts(grammar,
                                             validation_genetic_corpus,
                                             terminal_labeling)
    traceValidation = compute_reducts(grammar, validation_corpus,
                                      terminal_labeling)

    # prepare EM training
    grammarInfo = PyGrammarInfo(grammar, traceTrain.get_nonterminal_map())
    if not grammarInfo.check_for_consistency():
        print("[Genetic] GrammarInfo is not consistent!")

    storageManager = PyStorageManager()

    em_builder = PySplitMergeTrainerBuilder(traceTrain, grammarInfo)
    em_builder.set_em_epochs(em_epochs)
    em_builder.set_simple_expector(threads=threads)
    emTrainer = em_builder.build()

    # randomize initial weights and do em training
    la_no_splits = build_PyLatentAnnotation_initial(grammar, grammarInfo,
                                                    storageManager)
    la_no_splits.add_random_noise(seed=seed)
    emTrainer.em_train(la_no_splits)
    la_no_splits.project_weights(grammar, grammarInfo)

    # emTrainerOld = PyEMTrainer(traceTrain)
    # emTrainerOld.em_training(grammar, 30, "rfe", tie_breaking=True)

    # compute parses for validation set
    baseline_parser = GFParser_k_best(grammar, k=k_best)
    validator = build_score_validator(grammar, grammarInfo,
                                      traceTrain.get_nonterminal_map(),
                                      storageManager, terminal_labeling,
                                      baseline_parser, validation_corpus,
                                      validationMethod)
    del baseline_parser

    # prepare SM training
    builder = PySplitMergeTrainerBuilder(traceTrain, grammarInfo)
    builder.set_em_epochs(em_epochs)
    builder.set_split_randomization(1.0, seed + 1)
    builder.set_simple_expector(threads=threads)
    builder.set_score_validator(validator, validationDropIterations)
    builder.set_smoothing_factor(smoothingFactor=smoothing_factor)
    builder.set_split_randomization(percent=split_randomization)
    splitMergeTrainer = builder.set_scc_merger(threshold=scc_merger_threshold,
                                               threads=threads).build()

    splitMergeTrainer.setMaxDrops(validationDropIterations, mode="smoothing")
    splitMergeTrainer.setEMepochs(em_epochs, mode="smoothing")

    # set initial latent annotation
    latentAnnotations = []
    for i in range(0, genetic_initial):
        splitMergeTrainer.reset_random_seed(seed + i + 1)
        la = splitMergeTrainer.split_merge_cycle(la_no_splits)
        if not la.check_for_validity():
            print('[Genetic] Initial LA', i,
                  'is not consistent! (See details before)')
        if not la.is_proper():
            print('[Genetic] Initial LA', i, 'is not proper!')
        heapq.heappush(
            latentAnnotations,
            (evaluate_la(grammar, grammarInfo, la, traceValidationGenetic,
                         validation_genetic_corpus), i, la))
        print('[Genetic]    added initial LA', i)
    (fBest, idBest, laBest) = min(latentAnnotations)
    validation_score = evaluate_la(grammar, grammarInfo, laBest,
                                   traceValidation, test_corpus)
    print("[Genetic] Started with best F-Score (Test) of", validation_score,
          "from Annotation ", idBest)

    geneticCount = genetic_initial
    random.seed(seed)
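    # genetic-style search: in each round, refine population members by
    # split/merge and keep the best-scoring latent annotations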
    for round in range(1, genetic_cycles + 1):
        print("[Genetic] Starting Recombination Round ", round)
        # newpopulation = list(latentAnnotations)
        newpopulation = []
        # Cross all candidates!
        for leftIndex in range(0, len(latentAnnotations)):
            (fLeft, idLeft, left) = latentAnnotations[leftIndex]
            # TODO: How to determine NTs to keep?

            # do SM-Training
            print("[Genetic] do SM-training on", idLeft, "and create LA",
                  geneticCount)
            la = splitMergeTrainer.split_merge_cycle(left)
            if not la.check_for_validity():
                print(
                    '[Genetic] Split/Merge introduced invalid weights into LA',
                    geneticCount)
            if not la.is_proper():
                print(
                    '[Genetic] Split/Merge introduced problems with properness of LA',
                    geneticCount)

            fscore = evaluate_la(grammar, grammarInfo, la,
                                 traceValidationGenetic,
                                 validation_genetic_corpus)
            print("[Genetic] LA", geneticCount, "has F-score: ", fscore)
            heapq.heappush(newpopulation, (fscore, geneticCount, la))
            geneticCount += 1
        heapq.heapify(newpopulation)
        latentAnnotations = heapq.nsmallest(
            genetic_population, heapq.merge(latentAnnotations, newpopulation))
        heapq.heapify(latentAnnotations)
        (fBest, idBest, laBest) = min(latentAnnotations)
        validation_score = evaluate_la(grammar, grammarInfo, laBest,
                                       traceValidation, test_corpus)
        print("[Genetic] Best LA", idBest, "has F-Score (Test) of ",
              validation_score)
Example #10
def main(limit=3000,
         test_limit=sys.maxsize,
         max_length=sys.maxsize,
         dir=dir,
         train='../res/negra-dep/negra-lower-punct-train.conll',
         test='../res/negra-dep/negra-lower-punct-test.conll',
         recursive_partitioning='cfg',
         nonterminal_labeling='childtop-deprel',
         terminal_labeling='form-unk-30/pos',
         emEpochs=20,
         emTieBreaking=True,
         emInit="rfe",
         splitRandomization=1.0,
         mergePercentage=85.0,
         smCycles=6,
         rule_pruning=0.0001,
         rule_smoothing=0.01,
         validation=True,
         validationMethod='likelihood',
         validationCorpus=None,
         validationSplit=20,
         validationDropIterations=6,
         seed=1337,
         discr=False,
         maxScaleDiscr=10,
         recompileGrammar="True",
         retrain=False,
         parsing=True,
         reparse=False,
         parser="CFG",
         k_best=50,
         minimum_risk=False,
         oracle_parse=False):

    # set various parameters
    recompileGrammar = recompileGrammar == "True"

    # print(recompileGrammar)

    def result(gram, add=None):
        if add is not None:
            return os.path.join(
                dir, gram + '_experiment_parse_results_' + add + '.conll')
        else:
            return os.path.join(dir, gram + '_experiment_parse_results.conll')

    recursive_partitioning = grammar.induction.recursive_partitioning.the_recursive_partitioning_factory(
    ).get_partitioning(recursive_partitioning)
    top_level, low_level = tuple(nonterminal_labeling.split('-'))
    nonterminal_labeling = d_l.the_labeling_factory(
    ).create_simple_labeling_strategy(top_level, low_level)

    if parser == "CFG":
        assert all([
            rp.__name__
            in ["left_branching", "right_branching", "cfg", "fanout_1"]
            for rp in recursive_partitioning
        ])
        parser = CFGParser
    elif parser == "GF":
        parser = GFParser
    elif parser == "GF-k-best":
        parser = GFParser_k_best
    elif parser == "CoarseToFine":
        parser = Coarse_to_fine_parser
    elif parser == "FST":
        # by this point, recursive_partitioning is a list of strategies
        rp_names = [rp.__name__ for rp in recursive_partitioning]
        if rp_names == ["left_branching"]:
            parser = LeftBranchingFSTParser
        elif rp_names == ["right_branching"]:
            parser = RightBranchingFSTParser
        else:
            assert False, "expect left/right branching recursive partitioning for FST parsing"

    if validation:
        if validationCorpus is not None:
            corpus_validation = Corpus(validationCorpus)
            train_limit = limit
        else:
            train_limit = int(limit * (100.0 - validationSplit) / 100.0)
            corpus_validation = Corpus(train, start=train_limit, end=limit)
    else:
        train_limit = limit

    corpus_induce = Corpus(train, end=limit)
    corpus_train = Corpus(train, end=train_limit)
    corpus_test = Corpus(test, end=test_limit)

    match = re.match(r'^form-unk-(\d+)-morph.*$', terminal_labeling)
    if match:
        unk_threshold = int(match.group(1))
        term_labelling = grammar.induction.terminal_labeling.FormPosTerminalsUnkMorph(
            corpus_induce.get_trees(),
            unk_threshold,
            pos_filter=["NE", "CARD"],
            add_morph={
                'NN': ['case', 'number', 'gender']
                # , 'NE': ['case', 'number', 'gender']
                # , 'VMFIN': ['number', 'person']
                # , 'VVFIN': ['number', 'person']
                # , 'VAFIN': ['number', 'person']
            })
    else:
        match = re.match(r'^form-unk-(\d+).*$', terminal_labeling)
        if match:
            unk_threshold = int(match.group(1))
            term_labelling = grammar.induction.terminal_labeling.FormPosTerminalsUnk(
                corpus_induce.get_trees(),
                unk_threshold,
                pos_filter=["NE", "CARD"])
        else:
            term_labelling = grammar.induction.terminal_labeling.the_terminal_labeling_factory(
            ).get_strategy(terminal_labeling)

    if not os.path.isdir(dir):
        os.makedirs(dir)

    # start actual training
    # we use the training corpus until limit for grammar induction (i.e., also the validation section)
    print("Computing baseline id: ")
    baseline_id = grammar_id(corpus_induce, nonterminal_labeling,
                             term_labelling, recursive_partitioning)
    print(baseline_id)
    baseline_path = compute_grammar_name(dir, baseline_id, "baseline")

    if recompileGrammar or not os.path.isfile(baseline_path):
        print("Inducing grammar from corpus")
        (n_trees, baseline_grammar) = d_i.induce_grammar(
            corpus_induce.get_trees(), nonterminal_labeling,
            term_labelling.token_label, recursive_partitioning, start)
        print("Induced grammar using", n_trees, "trees.")
        pickle.dump(baseline_grammar, open(baseline_path, 'wb'))
    else:
        print("Loading grammar from file")
        baseline_grammar = pickle.load(open(baseline_path, 'rb'))

    print("Rules: ", len(baseline_grammar.rules()))

    if parsing:
        parser_ = GFParser_k_best if parser == Coarse_to_fine_parser else parser
        baseline_parser = do_parsing(baseline_grammar,
                                     corpus_test,
                                     term_labelling,
                                     result,
                                     baseline_id,
                                     parser_,
                                     k_best=k_best,
                                     minimum_risk=minimum_risk,
                                     oracle_parse=oracle_parse,
                                     recompile=recompileGrammar,
                                     dir=dir,
                                     reparse=reparse)

    if True:
        em_trained = pickle.load(open(baseline_path, 'rb'))
        reduct_path = compute_reduct_name(dir, baseline_id, corpus_train)
        if recompileGrammar or not os.path.isfile(reduct_path):
            trace = compute_reducts(em_trained, corpus_train.get_trees(),
                                    term_labelling)
            trace.serialize(reduct_path)
        else:
            print("loading trace")
            trace = PySDCPTraceManager(em_trained, term_labelling)
            trace.load_traces_from_file(reduct_path)

        if discr:
            reduct_path_discr = compute_reduct_name(dir, baseline_id,
                                                    corpus_train, '_discr')
            if recompileGrammar or not os.path.isfile(reduct_path_discr):
                trace_discr = compute_LCFRS_reducts(
                    em_trained,
                    corpus_train.get_trees(),
                    terminal_labelling=term_labelling,
                    nonterminal_map=trace.get_nonterminal_map())
                trace_discr.serialize(reduct_path_discr)
            else:
                print("loading trace discriminative")
                trace_discr = PyLCFRSTraceManager(em_trained,
                                                  trace.get_nonterminal_map())
                trace_discr.load_traces_from_file(reduct_path_discr)

        # todo refactor EM training, to use the LA version (but without any splits)
        """
        em_trained_path_ = em_trained_path(dir, grammar_id, n_epochs=emEpochs, init=emInit, tie_breaking=emTieBreaking, seed=seed)

        if recompileGrammar or retrain or not os.path.isfile(em_trained_path_):
            emTrainer = PyEMTrainer(trace)
            emTrainer.em_training(em_trained, n_epochs=emEpochs, init=emInit, tie_breaking=emTieBreaking, seed=seed)
            pickle.dump(em_trained, open(em_trained_path_, 'wb'))
        else:
            em_trained = pickle.load(open(em_trained_path_, 'rb'))

        if parsing:
            do_parsing(em_trained, test_limit, ignore_punctuation, term_labelling, recompileGrammar or retrain, [dir, "em_trained_gf_grammar"])
        """

        grammarInfo = PyGrammarInfo(baseline_grammar,
                                    trace.get_nonterminal_map())
        storageManager = PyStorageManager()

        builder = PySplitMergeTrainerBuilder(trace, grammarInfo)
        builder.set_em_epochs(emEpochs)
        builder.set_smoothing_factor(rule_smoothing)
        builder.set_split_randomization(splitRandomization, seed + 1)
        if discr:
            builder.set_discriminative_expector(trace_discr,
                                                maxScale=maxScaleDiscr,
                                                threads=1)
        else:
            builder.set_simple_expector(threads=1)
        if validation:
            if validationMethod == "likelihood":
                reduct_path_validation = compute_reduct_name(
                    dir, baseline_id, corpus_validation)
                if recompileGrammar or not os.path.isfile(
                        reduct_path_validation):
                    validation_trace = compute_reducts(
                        em_trained, corpus_validation.get_trees(),
                        term_labelling)
                    validation_trace.serialize(reduct_path_validation)
                else:
                    print("loading trace validation")
                    validation_trace = PySDCPTraceManager(
                        em_trained, term_labelling)
                    validation_trace.load_traces_from_file(
                        reduct_path_validation)
                builder.set_simple_validator(validation_trace,
                                             maxDrops=validationDropIterations,
                                             threads=1)
            else:
                validator = build_score_validator(
                    baseline_grammar, grammarInfo, trace.get_nonterminal_map(),
                    storageManager, term_labelling, baseline_parser,
                    corpus_validation, validationMethod)
                builder.set_score_validator(validator,
                                            validationDropIterations)
        splitMergeTrainer = builder.set_percent_merger(mergePercentage).build()
        if validation:
            splitMergeTrainer.setMaxDrops(1, mode="smoothing")
            splitMergeTrainer.setEMepochs(1, mode="smoothing")

        sm_info_path = compute_sm_info_path(dir, baseline_id, emEpochs,
                                            rule_smoothing, splitRandomization,
                                            seed, discr, validation,
                                            corpus_validation, emInit)

        if (not recompileGrammar) and (
                not retrain) and os.path.isfile(sm_info_path):
            print("Loading splits and weights of LA rules")
            latentAnnotation = [
                build_PyLatentAnnotation(t[0], t[1], t[2], grammarInfo,
                                         storageManager)
                for t in pickle.load(open(sm_info_path, 'rb'))
            ]
        else:
            # latentAnnotation = [build_PyLatentAnnotation_initial(em_trained, grammarInfo, storageManager)]
            latentAnnotation = [
                build_PyLatentAnnotation_initial(baseline_grammar, grammarInfo,
                                                 storageManager)
            ]

        for cycle in range(smCycles + 1):
            if cycle < len(latentAnnotation):
                smGrammar = latentAnnotation[cycle].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=rule_pruning
                    # , rule_smoothing=rule_smoothing
                )
            else:
                # setting the seed to achieve reproducibility in case of continued training
                splitMergeTrainer.reset_random_seed(seed + cycle + 1)
                latentAnnotation.append(
                    splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
                pickle.dump([la.serialize() for la in latentAnnotation],
                            open(sm_info_path, 'wb'))
                smGrammar = latentAnnotation[cycle].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=rule_pruning
                    # , rule_smoothing=rule_smoothing
                )
            print("Cycle: ", cycle, "Rules: ", len(smGrammar.rules()))
            if parsing:
                grammar_identifier = compute_sm_grammar_id(
                    baseline_id, emEpochs, rule_smoothing, splitRandomization,
                    seed, discr, validation, corpus_validation, emInit, cycle)
                if parser == Coarse_to_fine_parser:
                    opt = {
                        'latentAnnotation': latentAnnotation[:cycle + 1],  # or: [cycle]
                        'grammarInfo': grammarInfo,
                        'nontMap': trace.get_nonterminal_map()
                    }
                    do_parsing(baseline_grammar,
                               corpus_test,
                               term_labelling,
                               result,
                               grammar_identifier,
                               parser,
                               k_best=k_best,
                               minimum_risk=minimum_risk,
                               oracle_parse=oracle_parse,
                               recompile=recompileGrammar,
                               dir=dir,
                               reparse=reparse,
                               opt=opt)
                else:
                    do_parsing(smGrammar,
                               corpus_test,
                               term_labelling,
                               result,
                               grammar_identifier,
                               parser,
                               k_best=k_best,
                               minimum_risk=minimum_risk,
                               oracle_parse=oracle_parse,
                               recompile=recompileGrammar,
                               dir=dir,
                               reparse=reparse)
Example #11
def main(limit=300,
         ignore_punctuation=False,
         baseline_path=baseline_path,
         recompileGrammar=True,
         retrain=True,
         parsing=True,
         seed=1337):
    max_length = 20
    trees = length_limit(parse_conll_corpus(train, False, limit), max_length)

    if recompileGrammar or not os.path.isfile(baseline_path):
        (n_trees,
         baseline_grammar) = d_i.induce_grammar(trees, empty_labelling,
                                                term_labelling.token_label,
                                                recursive_partitioning, start)
        pickle.dump(baseline_grammar, open(baseline_path, 'wb'))
    else:
        baseline_grammar = pickle.load(open(baseline_path, 'rb'))

    test_limit = 10000
    print("Rules: ", len(baseline_grammar.rules()))

    if parsing:
        do_parsing(baseline_grammar, test_limit, ignore_punctuation,
                   recompileGrammar, [dir, "baseline_gf_grammar"])

    em_trained = pickle.load(open(baseline_path, 'rb'))
    if recompileGrammar or not os.path.isfile(reduct_path):
        trees = length_limit(parse_conll_corpus(train, False, limit),
                             max_length)
        trace = compute_reducts(em_trained, trees, term_labelling)
        trace.serialize(reduct_path)
    else:
        print("loading trace")
        trace = PySDCPTraceManager(em_trained, term_labelling)
        trace.load_traces_from_file(reduct_path)

    discr = False
    if discr:
        if recompileGrammar or not os.path.isfile(reduct_path_discr):
            trees = length_limit(parse_conll_corpus(train, False, limit),
                                 max_length)
            trace_discr = compute_LCFRS_reducts(
                em_trained,
                trees,
                term_labelling,
                nonterminal_map=trace.get_nonterminal_map())
            trace_discr.serialize(reduct_path_discr)
        else:
            print("loading trace discriminative")
            trace_discr = PyLCFRSTraceManager(em_trained,
                                              trace.get_nonterminal_map())
            trace_discr.load_traces_from_file(reduct_path_discr)

    n_epochs = 20
    init = "rfe"
    tie_breaking = True
    em_trained_path_ = em_trained_path(n_epochs, init, tie_breaking)

    if recompileGrammar or retrain or not os.path.isfile(em_trained_path_):
        emTrainer = PyEMTrainer(trace)
        emTrainer.em_training(em_trained,
                              n_epochs=n_epochs,
                              init=init,
                              tie_breaking=tie_breaking,
                              seed=seed)
        pickle.dump(em_trained, open(em_trained_path_, 'wb'))
    else:
        em_trained = pickle.load(open(em_trained_path_, 'rb'))

    if parsing:
        do_parsing(em_trained, test_limit, ignore_punctuation, recompileGrammar
                   or retrain, [dir, "em_trained_gf_grammar"])

    grammarInfo = PyGrammarInfo(baseline_grammar, trace.get_nonterminal_map())
    storageManager = PyStorageManager()

    builder = PySplitMergeTrainerBuilder(trace, grammarInfo)
    builder.set_em_epochs(n_epochs)
    builder.set_split_randomization(1.0, seed + 1)
    if discr:
        builder.set_discriminative_expector(trace_discr,
                                            maxScale=10,
                                            threads=1)
    else:
        builder.set_simple_expector(threads=1)
    splitMergeTrainer = builder.set_percent_merger(65.0).build()

    if (not recompileGrammar) and (
            not retrain) and os.path.isfile(sm_info_path):
        print("Loading splits and weights of LA rules")
        latentAnnotation = [
            build_PyLatentAnnotation(t[0], t[1], t[2], grammarInfo,
                                     storageManager)
            for t in pickle.load(open(sm_info_path, 'rb'))
        ]
    else:
        latentAnnotation = [
            build_PyLatentAnnotation_initial(em_trained, grammarInfo,
                                             storageManager)
        ]

    max_cycles = 4
    reparse = False
    # parsing = False
    for i in range(max_cycles + 1):
        if i < len(latentAnnotation):
            if reparse:
                smGrammar = latentAnnotation[i].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=0.0001,
                    rule_smoothing=0.01)
                print("Cycle: ", i, "Rules: ", len(smGrammar.rules()))
                do_parsing(smGrammar, test_limit, ignore_punctuation,
                           recompileGrammar or retrain,
                           [dir, "sm_cycles" + str(i) + "_gf_grammar"])
        else:
            # setting the seed to achieve reproducibility in case of continued training
            splitMergeTrainer.reset_random_seed(seed + i + 1)
            latentAnnotation.append(
                splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
            pickle.dump([la.serialize() for la in latentAnnotation],
                        open(sm_info_path, 'wb'))
            smGrammar = latentAnnotation[i].build_sm_grammar(
                baseline_grammar,
                grammarInfo,
                rule_pruning=0.0001,
                rule_smoothing=0.1)
            print("Cycle: ", i, "Rules: ", len(smGrammar.rules()))
            if parsing:
                do_parsing(smGrammar, test_limit, ignore_punctuation,
                           recompileGrammar or retrain,
                           [dir, "sm_cycles" + str(i) + "_gf_grammar"])
Example #12
def run_experiment(rec_part_strategy,
                   nonterminal_labeling,
                   exp,
                   reorder_children,
                   binarize=True):
    start = 1
    stop = 7000

    test_start = 7001
    test_stop = 7200

    # path = "res/tiger/tiger_release_aug07.corrected.16012013.utf8.xml"
    corpus_path = "res/tiger/tiger_8000.xml"
    exclude = []
    train_dsgs = sentence_names_to_deep_syntax_graphs(
        ['s' + str(i) for i in range(start, stop + 1) if i not in exclude],
        corpus_path,
        hold=False,
        reorder_children=reorder_children)
    test_dsgs = sentence_names_to_deep_syntax_graphs(
        [
            's' + str(i)
            for i in range(test_start, test_stop + 1) if i not in exclude
        ],
        corpus_path,
        hold=False,
        reorder_children=reorder_children)

    # Grammar induction
    term_labeling_token = PosTerminals()

    def term_labeling(token):
        if isinstance(token, ConstituentTerminal):
            return term_labeling_token.token_label(token)
        else:
            return token

    if binarize:

        def modify_token(token):
            if isinstance(token, ConstituentCategory):
                token_new = deepcopy(token)
                token_new.set_category(token.category() + '-BAR')
                return token_new
            elif isinstance(token, str):
                return token + '-BAR'
            else:
                assert False, "unexpected token type: %r" % token

        train_dsgs = [
            dsg.binarize(bin_modifier=modify_token) for dsg in train_dsgs
        ]

        def is_bin(token):
            if isinstance(token, ConstituentCategory):
                if token.category().endswith('-BAR'):
                    return True
            elif isinstance(token, str):
                if token.endswith('-BAR'):
                    return True
            return False

        def debinarize(dsg):
            return dsg.debinarize(is_bin=is_bin)

    else:
        # identity: the builtin id() would return an object's address here
        def debinarize(dsg):
            return dsg

    grammar = induction_on_a_corpus(train_dsgs, rec_part_strategy,
                                    nonterminal_labeling, term_labeling)
    grammar.make_proper()

    print("Nonterminals", len(grammar.nonts()), "Rules", len(grammar.rules()))

    parser = GFParser_k_best(grammar, k=500)
    return do_parsing(parser,
                      test_dsgs,
                      term_labeling_token,
                      oracle=True,
                      debinarize=debinarize)
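    # NOTE: the early return above makes the reduct computation and
    # split/merge training below unreachable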

    # Compute reducts, i.e., intersect grammar with each training dsg
    basedir = path.join('/tmp/dog_experiments', 'exp' + str(exp))
    reduct_dir = path.join(basedir, 'reduct_grammars')

    terminal_map = Enumerator()
    if not os.path.isdir(basedir):
        os.makedirs(basedir)
    data = export_dog_grammar_to_json(grammar, terminal_map)
    grammar_path = path.join(basedir, 'grammar.json')
    with open(grammar_path, 'w') as file:
        json.dump(data, file)

    corpus_path = path.join(basedir, 'corpus.json')
    with open(corpus_path, 'w') as file:
        json.dump(
            export_corpus_to_json(train_dsgs,
                                  terminal_map,
                                  terminal_labeling=term_labeling), file)

    with open(path.join(basedir, 'enumerator.enum'), 'w') as file:
        terminal_map.print_index(file)

    if os.path.isdir(reduct_dir):
        shutil.rmtree(reduct_dir)
    os.makedirs(reduct_dir)
    p = subprocess.Popen([
        ' '.join([
            "java", "-jar",
            os.path.join("util", SCHICK_PARSER_JAR), 'dog-reduct', '-g',
            grammar_path, '-t', corpus_path, "-o", reduct_dir
        ])
    ],
                         shell=True,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.STDOUT)

    while True:
        nextline = p.stdout.readline()
        if not nextline and p.poll() is not None:
            break
        # p.stdout delivers bytes; decode before echoing
        sys.stdout.write(nextline.decode("utf-8"))
        sys.stdout.flush()

    p.wait()
    p.stdout.close()

    rtgs = []
    for i in range(1, len(train_dsgs) + 1):
        rtgs.append(read_rtg(path.join(reduct_dir, str(i) + '.gra')))

    derivation_manager = PyDerivationManager(grammar)
    derivation_manager.convert_rtgs_to_hypergraphs(rtgs)
    derivation_manager.serialize(path.join(basedir, 'reduct_manager.trace'))

    # Training
    ## prepare EM training
    em_epochs = 20
    seed = 0
    smoothing_factor = 0.01
    split_randomization = 0.01
    sm_cycles = 2
    merge_percentage = 50.0
    grammarInfo = PyGrammarInfo(grammar,
                                derivation_manager.get_nonterminal_map())
    storageManager = PyStorageManager()

    em_builder = PySplitMergeTrainerBuilder(derivation_manager, grammarInfo)
    em_builder.set_em_epochs(em_epochs)
    em_builder.set_simple_expector(threads=THREADS)
    emTrainer = em_builder.build()

    # randomize initial weights and do em training
    la_no_splits = build_PyLatentAnnotation_initial(grammar, grammarInfo,
                                                    storageManager)
    la_no_splits.add_random_noise(seed=seed)
    emTrainer.em_train(la_no_splits)
    la_no_splits.project_weights(grammar, grammarInfo)

    do_parsing(CFGParser(grammar), test_dsgs, term_labeling_token)
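    # NOTE: early return; the split/merge training below is not executed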
    return
    ## prepare SM training
    builder = PySplitMergeTrainerBuilder(derivation_manager, grammarInfo)
    builder.set_em_epochs(em_epochs)
    builder.set_split_randomization(1.0, seed + 1)
    builder.set_simple_expector(threads=THREADS)
    builder.set_smoothing_factor(smoothingFactor=smoothing_factor)
    builder.set_split_randomization(percent=split_randomization)
    # builder.set_scc_merger(-0.2)
    builder.set_percent_merger(merge_percentage)
    splitMergeTrainer = builder.build()

    # splitMergeTrainer.setMaxDrops(validationDropIterations, mode="smoothing")
    splitMergeTrainer.setEMepochs(em_epochs, mode="smoothing")

    # set initial latent annotation
    latentAnnotation = [la_no_splits]

    # carry out split/merge training and do parsing
    parsing_method = "filter-ctf"
    # parsing_method = "single-best-annotation"
    k_best = 50
    for i in range(1, sm_cycles + 1):
        splitMergeTrainer.reset_random_seed(seed + i + 1)
        latentAnnotation.append(
            splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
        print("Cycle: ", i)
        if parsing_method == "single-best-annotation":
            smGrammar = latentAnnotation[i].build_sm_grammar(
                grammar, grammarInfo, rule_pruning=0.0001, rule_smoothing=0.1)
            print("Rules in smoothed grammar: ", len(smGrammar.rules()))
            parser = GFParser(smGrammar)
        elif parsing_method == "filter-ctf":
            latentAnnotation[-1].project_weights(grammar, grammarInfo)
            parser = Coarse_to_fine_parser(
                grammar,
                latentAnnotation[-1],
                grammarInfo,
                derivation_manager.get_nonterminal_map(),
                base_parser_type=GFParser_k_best,
                k=k_best)
        else:
            raise Exception("unknown parsing method: " + parsing_method)
        do_parsing(parser, test_dsgs, term_labeling_token)
        del parser