    def test_trace_serialization(self):
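        # Induce a grammar from two hybrid trees, compute reducts, serialize them,
        # reload them through a fresh PySDCPTraceManager, serialize again, and
        # check that both dumps are identical line by line.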
        tree = hybrid_tree_1()
        tree2 = hybrid_tree_2()
        terminal_labeling = the_terminal_labeling_factory().get_strategy('pos')

        (_, grammar) = induce_grammar(
            [tree, tree2],
            the_labeling_factory().create_simple_labeling_strategy(
                'empty', 'pos'), terminal_labeling.token_label, [cfg], 'START')

        for rule in grammar.rules():
            print(rule, file=stderr)

        trace = compute_reducts(grammar, [tree, tree2], terminal_labeling)
        trace.serialize(b"/tmp/reducts.p")

        grammar_load = grammar
        trace2 = PySDCPTraceManager(grammar_load, terminal_labeling)
        trace2.load_traces_from_file(b"/tmp/reducts.p")
        trace2.serialize(b"/tmp/reducts2.p")

        with open(b"/tmp/reducts.p", "r") as f1, open(b"/tmp/reducts2.p",
                                                      "r") as f2:
            for e1, e2 in zip(f1, f2):
                self.assertEqual(e1, e2)

    def test_basic_em_training(self):
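        # Induce a grammar from two hybrid trees, compute reducts, and run
        # ten epochs of EM training; the rules are printed before and after.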
        tree = hybrid_tree_1()
        tree2 = hybrid_tree_2()
        terminal_labeling = the_terminal_labeling_factory().get_strategy('pos')

        (_, grammar) = induce_grammar(
            [tree, tree2],
            the_labeling_factory().create_simple_labeling_strategy(
                'empty', 'pos'), terminal_labeling.token_label, [cfg], 'START')

        for rule in grammar.rules():
            print(rule, file=stderr)

        print("compute reducts", file=stderr)

        trace = compute_reducts(grammar, [tree, tree2], terminal_labeling)

        print("call em Training", file=stderr)
        emTrainer = PyEMTrainer(trace)
        emTrainer.em_training(grammar, n_epochs=10)

        print("finished em Training", file=stderr)

        for rule in grammar.rules():
            print(rule, file=stderr)

    def test_corpus_em_training(self):
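        # Run EM training on reducts computed from the first 200 sentences of
        # the TIGER training corpus (CoNLL format), using equal initialization
        # with tie breaking and a small random perturbation.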
        train = 'res/dependency_conll/german/tiger/train/german_tiger_train.conll'
        limit_train = 200
        test = train
        # test = '../../res/dependency_conll/german/tiger/test/german_tiger_test.conll'
        trees = parse_conll_corpus(train, False, limit_train)
        primary_labelling = the_labeling_factory(
        ).create_simple_labeling_strategy("childtop", "deprel")
        term_labelling = the_terminal_labeling_factory().get_strategy('pos')
        start = 'START'
        recursive_partitioning = [cfg]

        (n_trees, grammar_prim) = induce_grammar(trees, primary_labelling,
                                                 term_labelling.token_label,
                                                 recursive_partitioning, start)

        # for rule in grammar_prim.rules():
        #     print(rule, file=stderr)

        trees = parse_conll_corpus(train, False, limit_train)

        print("compute reducts", file=stderr)

        trace = compute_reducts(grammar_prim, trees, term_labelling)

        print("call em Training", file=stderr)
        emTrainer = PyEMTrainer(trace)
        emTrainer.em_training(grammar_prim,
                              20,
                              tie_breaking=True,
                              init="equal",
                              sigma=0.05,
                              seed=50)

        print("finished em Training", file=stderr)
Example #4
    def compute_reducts(self, resource):
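        # Restrict the corpus to valid trees, compute reducts for the base grammar
        # (reusing the parser and nonterminal map of the training reducts, if any),
        # and, if backoff is enabled, add reducts a second time in backoff mode.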

        # print_grammar(self.base_grammar)
        # for rule in self.base_grammar.rules():
        #     print(rule.get_idx(), rule)
        # sys.stdout.flush()

        training_corpus = list(
            filter(self.__valid_tree, self.read_corpus(resource)))
        parser = self.organizer.training_reducts.get_parser(
        ) if self.organizer.training_reducts is not None else None
        nonterminal_map = self.organizer.nonterminal_map
        frequency = self.backoff_factor if self.backoff else 1.0
        trace = compute_reducts(self.base_grammar,
                                training_corpus,
                                self.induction_settings.terminal_labeling,
                                parser=parser,
                                nont_map=nonterminal_map,
                                debug=False,
                                frequency=frequency)
        if self.backoff:
            self.terminal_labeling.backoff_mode = True
            trace.compute_reducts(training_corpus, frequency=1.0)
            self.terminal_labeling.backoff_mode = False
        print("computed trace")
        return trace
Example #5
    def compute_reducts(self, resource):
        corpus = self.read_corpus(resource)
        if self.strip_vroot:
            for tree in corpus:
                tree.strip_vroot()
        parser = self.organizer.training_reducts.get_parser() if self.organizer.training_reducts is not None else None
        nonterminal_map = self.organizer.nonterminal_map
        frequency = self.backoff_factor if self.backoff else 1.0
        trace = compute_reducts(self.base_grammar, corpus, self.induction_settings.terminal_labeling,
                                parser=parser, nont_map=nonterminal_map, frequency=frequency)
        if self.backoff:
            self.terminal_labeling.backoff_mode = True
            trace.compute_reducts(corpus, frequency=1.0)
            self.terminal_labeling.backoff_mode = False
        return trace
Example #6
def main():
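    # Induce an LCFRS from the training corpus via fringe extraction, run EM on a
    # noise-initialized latent annotation, and then search for better annotations
    # with a genetic-style loop of split/merge cycles, keeping the best candidates
    # (ranked by their evaluate_la score on a validation set) on a heap.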
    # # induce or load grammar
    # if not os.path.isfile(grammar_path):
    #     grammar = LCFRS('START')
    #     for tree in train_corpus:
    #         if not tree.complete() or tree.empty_fringe():
    #             continue
    #         part = recursive_partitioning(tree)
    #         tree_grammar = fringe_extract_lcfrs(tree, part, naming='child', term_labeling=terminal_labeling)
    #         grammar.add_gram(tree_grammar)
    #     grammar.make_proper()
    #     pickle.dump(grammar, open(grammar_path, 'wb'))
    # else:
    #     grammar = pickle.load(open(grammar_path, 'rb'))

    grammar = LCFRS('START')
    for tree in train_corpus:
        if not tree.complete() or tree.empty_fringe():
            continue
        part = recursive_partitioning(tree)
        tree_grammar = fringe_extract_lcfrs(tree,
                                            part,
                                            naming='child',
                                            term_labeling=terminal_labeling)
        grammar.add_gram(tree_grammar)
    grammar.make_proper()

    # # compute or load reducts
    # if not os.path.isfile(reduct_path):
    #     traceTrain = compute_reducts(grammar, train_corpus, terminal_labeling)
    #     traceTrain.serialize(reduct_path)
    # else:
    #     traceTrain = PySDCPTraceManager(grammar, terminal_labeling)
    #     traceTrain.load_traces_from_file(reduct_path)

    traceTrain = compute_reducts(grammar, train_corpus, terminal_labeling)
    traceValidationGenetic = compute_reducts(grammar,
                                             validation_genetic_corpus,
                                             terminal_labeling)
    traceValidation = compute_reducts(grammar, validation_corpus,
                                      terminal_labeling)

    # prepare EM training
    grammarInfo = PyGrammarInfo(grammar, traceTrain.get_nonterminal_map())
    if not grammarInfo.check_for_consistency():
        print("[Genetic] GrammarInfo is not consistent!")

    storageManager = PyStorageManager()

    em_builder = PySplitMergeTrainerBuilder(traceTrain, grammarInfo)
    em_builder.set_em_epochs(em_epochs)
    em_builder.set_simple_expector(threads=threads)
    emTrainer = em_builder.build()

    # randomize initial weights and do em training
    la_no_splits = build_PyLatentAnnotation_initial(grammar, grammarInfo,
                                                    storageManager)
    la_no_splits.add_random_noise(seed=seed)
    emTrainer.em_train(la_no_splits)
    la_no_splits.project_weights(grammar, grammarInfo)

    # emTrainerOld = PyEMTrainer(traceTrain)
    # emTrainerOld.em_training(grammar, 30, "rfe", tie_breaking=True)

    # compute parses for validation set
    baseline_parser = GFParser_k_best(grammar, k=k_best)
    validator = build_score_validator(grammar, grammarInfo,
                                      traceTrain.get_nonterminal_map(),
                                      storageManager, terminal_labeling,
                                      baseline_parser, validation_corpus,
                                      validationMethod)
    del baseline_parser

    # prepare SM training
    builder = PySplitMergeTrainerBuilder(traceTrain, grammarInfo)
    builder.set_em_epochs(em_epochs)
    builder.set_split_randomization(1.0, seed + 1)
    builder.set_simple_expector(threads=threads)
    builder.set_score_validator(validator, validationDropIterations)
    builder.set_smoothing_factor(smoothingFactor=smoothing_factor)
    builder.set_split_randomization(percent=split_randomization)
    splitMergeTrainer = builder.set_scc_merger(threshold=scc_merger_threshold,
                                               threads=threads).build()

    splitMergeTrainer.setMaxDrops(validationDropIterations, mode="smoothing")
    splitMergeTrainer.setEMepochs(em_epochs, mode="smoothing")

    # set initial latent annotation
    latentAnnotations = []
    for i in range(0, genetic_initial):
        splitMergeTrainer.reset_random_seed(seed + i + 1)
        la = splitMergeTrainer.split_merge_cycle(la_no_splits)
        if not la.check_for_validity():
            print('[Genetic] Initial LA', i,
                  'is not consistent! (See details before)')
        if not la.is_proper():
            print('[Genetic] Initial LA', i, 'is not proper!')
        heapq.heappush(
            latentAnnotations,
            (evaluate_la(grammar, grammarInfo, la, traceValidationGenetic,
                         validation_genetic_corpus), i, la))
        print('[Genetic]    added initial LA', i)
    (fBest, idBest, laBest) = min(latentAnnotations)
    validation_score = evaluate_la(grammar, grammarInfo, laBest,
                                   traceValidation, test_corpus)
    print("[Genetic] Started with best F-Score (Test) of", validation_score,
          "from Annotation ", idBest)

    geneticCount = genetic_initial
    random.seed(seed)
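    # Each round runs one split/merge cycle per surviving candidate, scores the
    # resulting annotations, and keeps the genetic_population best ones.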
    for round in range(1, genetic_cycles + 1):
        print("[Genetic] Starting Recombination Round ", round)
        # newpopulation = list(latentAnnotations)
        newpopulation = []
        # Cross all candidates!
        for leftIndex in range(0, len(latentAnnotations)):
            (fLeft, idLeft, left) = latentAnnotations[leftIndex]
            # TODO: How to determine NTs to keep?

            # do SM-Training
            print("[Genetic] do SM-training on", idLeft, "and create LA",
                  geneticCount)
            la = splitMergeTrainer.split_merge_cycle(left)  # run a split/merge cycle on the current candidate
            if not la.check_for_validity():
                print(
                    '[Genetic] Split/Merge introduced invalid weights into LA',
                    geneticCount)
            if not la.is_proper():
                print(
                    '[Genetic] Split/Merge introduced problems with properness of LA',
                    geneticCount)

            fscore = evaluate_la(grammar, grammarInfo, la,
                                 traceValidationGenetic,
                                 validation_genetic_corpus)
            print("[Genetic] LA", geneticCount, "has F-score: ", fscore)
            heapq.heappush(newpopulation, (fscore, geneticCount, la))
            geneticCount += 1
        heapq.heapify(newpopulation)
        latentAnnotations = heapq.nsmallest(
            genetic_population, heapq.merge(latentAnnotations, newpopulation))
        heapq.heapify(latentAnnotations)
        (fBest, idBest, laBest) = min(latentAnnotations)
        validation_score = evaluate_la(grammar, grammarInfo, laBest,
                                       traceValidation, test_corpus)
        print("[Genetic] Best LA", idBest, "has F-Score (Test) of ",
              validation_score)
Example #7
def main(limit=3000,
         test_limit=sys.maxsize,
         max_length=sys.maxsize,
         dir=dir,
         train='../res/negra-dep/negra-lower-punct-train.conll',
         test='../res/negra-dep/negra-lower-punct-test.conll',
         recursive_partitioning='cfg',
         nonterminal_labeling='childtop-deprel',
         terminal_labeling='form-unk-30/pos',
         emEpochs=20,
         emTieBreaking=True,
         emInit="rfe",
         splitRandomization=1.0,
         mergePercentage=85.0,
         smCycles=6,
         rule_pruning=0.0001,
         rule_smoothing=0.01,
         validation=True,
         validationMethod='likelihood',
         validationCorpus=None,
         validationSplit=20,
         validationDropIterations=6,
         seed=1337,
         discr=False,
         maxScaleDiscr=10,
         recompileGrammar="True",
         retrain=False,
         parsing=True,
         reparse=False,
         parser="CFG",
         k_best=50,
         minimum_risk=False,
         oracle_parse=False):
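    # Induce a baseline grammar, optionally parse the test corpus with it, then
    # train latent annotations by split/merge cycles (with a simple or
    # discriminative expector and likelihood- or score-based validation),
    # parsing again after each cycle if requested.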

    # set various parameters
    recompileGrammar = (recompileGrammar == "True")

    # print(recompileGrammar)

    def result(gram, add=None):
        if add is not None:
            return os.path.join(
                dir, gram + '_experiment_parse_results_' + add + '.conll')
        else:
            return os.path.join(dir, gram + '_experiment_parse_results.conll')

    recursive_partitioning = grammar.induction.recursive_partitioning.the_recursive_partitioning_factory(
    ).get_partitioning(recursive_partitioning)
    top_level, low_level = tuple(nonterminal_labeling.split('-'))
    nonterminal_labeling = d_l.the_labeling_factory(
    ).create_simple_labeling_strategy(top_level, low_level)

    if parser == "CFG":
        assert all([
            rp.__name__
            in ["left_branching", "right_branching", "cfg", "fanout_1"]
            for rp in recursive_partitioning
        ])
        parser = CFGParser
    elif parser == "GF":
        parser = GFParser
    elif parser == "GF-k-best":
        parser = GFParser_k_best
    elif parser == "CoarseToFine":
        parser = Coarse_to_fine_parser
    elif parser == "FST":
        # recursive_partitioning has already been resolved to a list of
        # partitioning functions, so compare by function name
        if recursive_partitioning[0].__name__ == "left_branching":
            parser = LeftBranchingFSTParser
        elif recursive_partitioning[0].__name__ == "right_branching":
            parser = RightBranchingFSTParser
        else:
            assert False, "expect left/right branching recursive partitioning for FST parsing"

    if validation:
        if validationCorpus is not None:
            corpus_validation = Corpus(validationCorpus)
            train_limit = limit
        else:
            train_limit = int(limit * (100.0 - validationSplit) / 100.0)
            corpus_validation = Corpus(train, start=train_limit, end=limit)
    else:
        train_limit = limit

    corpus_induce = Corpus(train, end=limit)
    corpus_train = Corpus(train, end=train_limit)
    corpus_test = Corpus(test, end=test_limit)

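    # Select the terminal labeling: form/POS with an unknown-word threshold
    # (optionally enriched with morphological features), or a named strategy
    # from the terminal labeling factory.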
    match = re.match(r'^form-unk-(\d+)-morph.*$', terminal_labeling)
    if match:
        unk_threshold = int(match.group(1))
        term_labelling = grammar.induction.terminal_labeling.FormPosTerminalsUnkMorph(
            corpus_induce.get_trees(),
            unk_threshold,
            pos_filter=["NE", "CARD"],
            add_morph={
                'NN': ['case', 'number', 'gender']
                # , 'NE': ['case', 'number', 'gender']
                # , 'VMFIN': ['number', 'person']
                # , 'VVFIN': ['number', 'person']
                # , 'VAFIN': ['number', 'person']
            })
    else:
        match = re.match(r'^form-unk-(\d+).*$', terminal_labeling)
        if match:
            unk_threshold = int(match.group(1))
            term_labelling = grammar.induction.terminal_labeling.FormPosTerminalsUnk(
                corpus_induce.get_trees(),
                unk_threshold,
                pos_filter=["NE", "CARD"])
        else:
            term_labelling = grammar.induction.terminal_labeling.the_terminal_labeling_factory(
            ).get_strategy(terminal_labeling)

    if not os.path.isdir(dir):
        os.makedirs(dir)

    # start actual training
    # we use the training corpus until limit for grammar induction (i.e., also the validation section)
    print("Computing baseline id: ")
    baseline_id = grammar_id(corpus_induce, nonterminal_labeling,
                             term_labelling, recursive_partitioning)
    print(baseline_id)
    baseline_path = compute_grammar_name(dir, baseline_id, "baseline")

    if recompileGrammar or not os.path.isfile(baseline_path):
        print("Inducing grammar from corpus")
        (n_trees, baseline_grammar) = d_i.induce_grammar(
            corpus_induce.get_trees(), nonterminal_labeling,
            term_labelling.token_label, recursive_partitioning, start)
        print("Induced grammar from", n_trees, "trees.")
        pickle.dump(baseline_grammar, open(baseline_path, 'wb'))
    else:
        print("Loading grammar from file")
        baseline_grammar = pickle.load(open(baseline_path, 'rb'))

    print("Rules: ", len(baseline_grammar.rules()))

    if parsing:
        parser_ = GFParser_k_best if parser == Coarse_to_fine_parser else parser
        baseline_parser = do_parsing(baseline_grammar,
                                     corpus_test,
                                     term_labelling,
                                     result,
                                     baseline_id,
                                     parser_,
                                     k_best=k_best,
                                     minimum_risk=minimum_risk,
                                     oracle_parse=oracle_parse,
                                     recompile=recompileGrammar,
                                     dir=dir,
                                     reparse=reparse)

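    # Compute (or load) reducts and run split/merge training on the baseline grammar.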
    if True:
        em_trained = pickle.load(open(baseline_path, 'rb'))
        reduct_path = compute_reduct_name(dir, baseline_id, corpus_train)
        if recompileGrammar or not os.path.isfile(reduct_path):
            trace = compute_reducts(em_trained, corpus_train.get_trees(),
                                    term_labelling)
            trace.serialize(reduct_path)
        else:
            print("loading trace")
            trace = PySDCPTraceManager(em_trained, term_labelling)
            trace.load_traces_from_file(reduct_path)

        if discr:
            reduct_path_discr = compute_reduct_name(dir, baseline_id,
                                                    corpus_train, '_discr')
            if recompileGrammar or not os.path.isfile(reduct_path_discr):
                trace_discr = compute_LCFRS_reducts(
                    em_trained,
                    corpus_train.get_trees(),
                    terminal_labelling=term_labelling,
                    nonterminal_map=trace.get_nonterminal_map())
                trace_discr.serialize(reduct_path_discr)
            else:
                print("loading trace discriminative")
                trace_discr = PyLCFRSTraceManager(em_trained,
                                                  trace.get_nonterminal_map())
                trace_discr.load_traces_from_file(reduct_path_discr)

        # todo refactor EM training, to use the LA version (but without any splits)
        """
        em_trained_path_ = em_trained_path(dir, grammar_id, n_epochs=emEpochs, init=emInit, tie_breaking=emTieBreaking, seed=seed)

        if recompileGrammar or retrain or not os.path.isfile(em_trained_path_):
            emTrainer = PyEMTrainer(trace)
            emTrainer.em_training(em_trained, n_epochs=emEpochs, init=emInit, tie_breaking=emTieBreaking, seed=seed)
            pickle.dump(em_trained, open(em_trained_path_, 'wb'))
        else:
            em_trained = pickle.load(open(em_trained_path_, 'rb'))

        if parsing:
            do_parsing(em_trained, test_limit, ignore_punctuation, term_labelling, recompileGrammar or retrain, [dir, "em_trained_gf_grammar"])
        """

        grammarInfo = PyGrammarInfo(baseline_grammar,
                                    trace.get_nonterminal_map())
        storageManager = PyStorageManager()

        builder = PySplitMergeTrainerBuilder(trace, grammarInfo)
        builder.set_em_epochs(emEpochs)
        builder.set_smoothing_factor(rule_smoothing)
        builder.set_split_randomization(splitRandomization, seed + 1)
        if discr:
            builder.set_discriminative_expector(trace_discr,
                                                maxScale=maxScaleDiscr,
                                                threads=1)
        else:
            builder.set_simple_expector(threads=1)
        if validation:
            if validationMethod == "likelihood":
                reduct_path_validation = compute_reduct_name(
                    dir, baseline_id, corpus_validation)
                if recompileGrammar or not os.path.isfile(
                        reduct_path_validation):
                    validation_trace = compute_reducts(
                        em_trained, corpus_validation.get_trees(),
                        term_labelling)
                    validation_trace.serialize(reduct_path_validation)
                else:
                    print("loading trace validation")
                    validation_trace = PySDCPTraceManager(
                        em_trained, term_labelling)
                    validation_trace.load_traces_from_file(
                        reduct_path_validation)
                builder.set_simple_validator(validation_trace,
                                             maxDrops=validationDropIterations,
                                             threads=1)
            else:
                validator = build_score_validator(
                    baseline_grammar, grammarInfo, trace.get_nonterminal_map(),
                    storageManager, term_labelling, baseline_parser,
                    corpus_validation, validationMethod)
                builder.set_score_validator(validator,
                                            validationDropIterations)
        splitMergeTrainer = builder.set_percent_merger(mergePercentage).build()
        if validation:
            splitMergeTrainer.setMaxDrops(1, mode="smoothing")
            splitMergeTrainer.setEMepochs(1, mode="smoothing")

        sm_info_path = compute_sm_info_path(dir, baseline_id, emEpochs,
                                            rule_smoothing, splitRandomization,
                                            seed, discr, validation,
                                            corpus_validation, emInit)

        if (not recompileGrammar) and (
                not retrain) and os.path.isfile(sm_info_path):
            print("Loading splits and weights of LA rules")
            # materialize as a list: the annotations are indexed and extended below
            latentAnnotation = list(map(
                lambda t: build_PyLatentAnnotation(t[0], t[1], t[2],
                                                   grammarInfo, storageManager),
                pickle.load(open(sm_info_path, 'rb'))))
        else:
            # latentAnnotation = [build_PyLatentAnnotation_initial(em_trained, grammarInfo, storageManager)]
            latentAnnotation = [
                build_PyLatentAnnotation_initial(baseline_grammar, grammarInfo,
                                                 storageManager)
            ]

        for cycle in range(smCycles + 1):
            if cycle < len(latentAnnotation):
                smGrammar = latentAnnotation[cycle].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=rule_pruning
                    # , rule_smoothing=rule_smoothing
                )
            else:
                # setting the seed to achieve reproducibility in case of continued training
                splitMergeTrainer.reset_random_seed(seed + cycle + 1)
                latentAnnotation.append(
                    splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
                pickle.dump([la.serialize() for la in latentAnnotation],
                            open(sm_info_path, 'wb'))
                smGrammar = latentAnnotation[cycle].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=rule_pruning
                    # , rule_smoothing=rule_smoothing
                )
            print("Cycle: ", cycle, "Rules: ", len(smGrammar.rules()))
            if parsing:
                grammar_identifier = compute_sm_grammar_id(
                    baseline_id, emEpochs, rule_smoothing, splitRandomization,
                    seed, discr, validation, corpus_validation, emInit, cycle)
                if parser == Coarse_to_fine_parser:
                    opt = {
                        'latentAnnotation': latentAnnotation[:cycle + 1],  # [cycle]
                        'grammarInfo': grammarInfo,
                        'nontMap': trace.get_nonterminal_map()
                    }
                    do_parsing(baseline_grammar,
                               corpus_test,
                               term_labelling,
                               result,
                               grammar_identifier,
                               parser,
                               k_best=k_best,
                               minimum_risk=minimum_risk,
                               oracle_parse=oracle_parse,
                               recompile=recompileGrammar,
                               dir=dir,
                               reparse=reparse,
                               opt=opt)
                else:
                    do_parsing(smGrammar,
                               corpus_test,
                               term_labelling,
                               result,
                               grammar_identifier,
                               parser,
                               k_best=k_best,
                               minimum_risk=minimum_risk,
                               oracle_parse=oracle_parse,
                               recompile=recompileGrammar,
                               dir=dir,
                               reparse=reparse)
Example #8
def main(limit=300,
         ignore_punctuation=False,
         baseline_path=baseline_path,
         recompileGrammar=True,
         retrain=True,
         parsing=True,
         seed=1337):
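    # Induce or load a baseline grammar, train it with EM, and refine it by
    # split/merge cycles, optionally parsing the test section after each step.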
    max_length = 20
    trees = length_limit(parse_conll_corpus(train, False, limit), max_length)

    if recompileGrammar or not os.path.isfile(baseline_path):
        (n_trees,
         baseline_grammar) = d_i.induce_grammar(trees, empty_labelling,
                                                term_labelling.token_label,
                                                recursive_partitioning, start)
        pickle.dump(baseline_grammar, open(baseline_path, 'wb'))
    else:
        baseline_grammar = pickle.load(open(baseline_path, 'rb'))

    test_limit = 10000
    print("Rules: ", len(baseline_grammar.rules()))

    if parsing:
        do_parsing(baseline_grammar, test_limit, ignore_punctuation,
                   recompileGrammar, [dir, "baseline_gf_grammar"])

    em_trained = pickle.load(open(baseline_path, 'rb'))
    if recompileGrammar or not os.path.isfile(reduct_path):
        trees = length_limit(parse_conll_corpus(train, False, limit),
                             max_length)
        trace = compute_reducts(em_trained, trees, term_labelling)
        trace.serialize(reduct_path)
    else:
        print("loading trace")
        trace = PySDCPTraceManager(em_trained, term_labelling)
        trace.load_traces_from_file(reduct_path)

    discr = False
    if discr:
        if recompileGrammar or not os.path.isfile(reduct_path_discr):
            trees = length_limit(parse_conll_corpus(train, False, limit),
                                 max_length)
            trace_discr = compute_LCFRS_reducts(
                em_trained,
                trees,
                term_labelling,
                nonterminal_map=trace.get_nonterminal_map())
            trace_discr.serialize(reduct_path_discr)
        else:
            print("loading trace discriminative")
            trace_discr = PyLCFRSTraceManager(em_trained,
                                              trace.get_nonterminal_map())
            trace_discr.load_traces_from_file(reduct_path_discr)

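    # Plain EM training of the baseline grammar (or load a previously trained pickle).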
    n_epochs = 20
    init = "rfe"
    tie_breaking = True
    em_trained_path_ = em_trained_path(n_epochs, init, tie_breaking)

    if recompileGrammar or retrain or not os.path.isfile(em_trained_path_):
        emTrainer = PyEMTrainer(trace)
        emTrainer.em_training(em_trained,
                              n_epochs=n_epochs,
                              init=init,
                              tie_breaking=tie_breaking,
                              seed=seed)
        pickle.dump(em_trained, open(em_trained_path_, 'wb'))
    else:
        em_trained = pickle.load(open(em_trained_path_, 'rb'))

    if parsing:
        do_parsing(em_trained, test_limit, ignore_punctuation, recompileGrammar
                   or retrain, [dir, "em_trained_gf_grammar"])

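    # Prepare split/merge training.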
    grammarInfo = PyGrammarInfo(baseline_grammar, trace.get_nonterminal_map())
    storageManager = PyStorageManager()

    builder = PySplitMergeTrainerBuilder(trace, grammarInfo)
    builder.set_em_epochs(n_epochs)
    builder.set_split_randomization(1.0, seed + 1)
    if discr:
        builder.set_discriminative_expector(trace_discr,
                                            maxScale=10,
                                            threads=1)
    else:
        builder.set_simple_expector(threads=1)
    splitMergeTrainer = builder.set_percent_merger(65.0).build()

    if (not recompileGrammar) and (
            not retrain) and os.path.isfile(sm_info_path):
        print("Loading splits and weights of LA rules")
        # materialize as a list: the annotations are indexed and extended below
        latentAnnotation = list(map(
            lambda t: build_PyLatentAnnotation(t[0], t[1], t[2], grammarInfo,
                                               storageManager),
            pickle.load(open(sm_info_path, 'rb'))))
    else:
        latentAnnotation = [
            build_PyLatentAnnotation_initial(em_trained, grammarInfo,
                                             storageManager)
        ]

    max_cycles = 4
    reparse = False
    # parsing = False
    for i in range(max_cycles + 1):
        if i < len(latentAnnotation):
            if reparse:
                smGrammar = latentAnnotation[i].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=0.0001,
                    rule_smoothing=0.01)
                print("Cycle: ", i, "Rules: ", len(smGrammar.rules()))
                do_parsing(smGrammar, test_limit, ignore_punctuation,
                           recompileGrammar or retrain,
                           [dir, "sm_cycles" + str(i) + "_gf_grammar"])
        else:
            # setting the seed to achieve reproducibility in case of continued training
            splitMergeTrainer.reset_random_seed(seed + i + 1)
            latentAnnotation.append(
                splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
            pickle.dump([la.serialize() for la in latentAnnotation],
                        open(sm_info_path, 'wb'))
            smGrammar = latentAnnotation[i].build_sm_grammar(
                baseline_grammar,
                grammarInfo,
                rule_pruning=0.0001,
                rule_smoothing=0.1)
            print("Cycle: ", i, "Rules: ", len(smGrammar.rules()))
            if parsing:
                do_parsing(smGrammar, test_limit, ignore_punctuation,
                           recompileGrammar or retrain,
                           [dir, "sm_cycles" + str(i) + "_gf_grammar"])