Example #1
def main(limit=3000,
         test_limit=sys.maxsize,
         max_length=sys.maxsize,
         dir=dir,
         train='../res/negra-dep/negra-lower-punct-train.conll',
         test='../res/negra-dep/negra-lower-punct-test.conll',
         recursive_partitioning='cfg',
         nonterminal_labeling='childtop-deprel',
         terminal_labeling='form-unk-30/pos',
         emEpochs=20,
         emTieBreaking=True,
         emInit="rfe",
         splitRandomization=1.0,
         mergePercentage=85.0,
         smCycles=6,
         rule_pruning=0.0001,
         rule_smoothing=0.01,
         validation=True,
         validationMethod='likelihood',
         validationCorpus=None,
         validationSplit=20,
         validationDropIterations=6,
         seed=1337,
         discr=False,
         maxScaleDiscr=10,
         recompileGrammar="True",
         retrain=False,
         parsing=True,
         reparse=False,
         parser="CFG",
         k_best=50,
         minimum_risk=False,
         oracle_parse=False):

    # set various parameters
    recompileGrammar = (recompileGrammar == "True")

    # print(recompileGrammar)

    # path of the CoNLL file holding parse results for a given grammar identifier
    def result(gram, add=None):
        if add is not None:
            return os.path.join(
                dir, gram + '_experiment_parse_results_' + add + '.conll')
        else:
            return os.path.join(dir, gram + '_experiment_parse_results.conll')

    recursive_partitioning = grammar.induction.recursive_partitioning.the_recursive_partitioning_factory(
    ).get_partitioning(recursive_partitioning)
    top_level, low_level = tuple(nonterminal_labeling.split('-'))
    nonterminal_labeling = d_l.the_labeling_factory(
    ).create_simple_labeling_strategy(top_level, low_level)

    if parser == "CFG":
        assert all([
            rp.__name__
            in ["left_branching", "right_branching", "cfg", "fanout_1"]
            for rp in recursive_partitioning
        ])
        parser = CFGParser
    elif parser == "GF":
        parser = GFParser
    elif parser == "GF-k-best":
        parser = GFParser_k_best
    elif parser == "CoarseToFine":
        parser = Coarse_to_fine_parser
    elif parser == "FST":
        if recursive_partitioning == "left_branching":
            parser = LeftBranchingFSTParser
        elif recursive_partitioning == "right_branching":
            parser = RightBranchingFSTParser
        else:
            assert False and "expect left/right branching recursive partitioning for FST parsing"

    if validation:
        if validationCorpus is not None:
            corpus_validation = Corpus(validationCorpus)
            train_limit = limit
        else:
            train_limit = int(limit * (100.0 - validationSplit) / 100.0)
            corpus_validation = Corpus(train, start=train_limit, end=limit)
    else:
        train_limit = limit

    corpus_induce = Corpus(train, end=limit)
    corpus_train = Corpus(train, end=train_limit)
    corpus_test = Corpus(test, end=test_limit)

    match = re.match(r'^form-unk-(\d+)-morph.*$', terminal_labeling)
    if match:
        unk_threshold = int(match.group(1))
        term_labelling = grammar.induction.terminal_labeling.FormPosTerminalsUnkMorph(
            corpus_induce.get_trees(),
            unk_threshold,
            pos_filter=["NE", "CARD"],
            add_morph={
                'NN': ['case', 'number', 'gender']
                # , 'NE': ['case', 'number', 'gender']
                # , 'VMFIN': ['number', 'person']
                # , 'VVFIN': ['number', 'person']
                # , 'VAFIN': ['number', 'person']
            })
    else:
        match = re.match(r'^form-unk-(\d+).*$', terminal_labeling)
        if match:
            unk_threshold = int(match.group(1))
            term_labelling = grammar.induction.terminal_labeling.FormPosTerminalsUnk(
                corpus_induce.get_trees(),
                unk_threshold,
                pos_filter=["NE", "CARD"])
        else:
            term_labelling = grammar.induction.terminal_labeling.the_terminal_labeling_factory(
            ).get_strategy(terminal_labeling)

    if not os.path.isdir(dir):
        os.makedirs(dir)

    # start actual training
    # we use the training corpus until limit for grammar induction (i.e., also the validation section)
    print("Computing baseline id: ")
    baseline_id = grammar_id(corpus_induce, nonterminal_labeling,
                             term_labelling, recursive_partitioning)
    print(baseline_id)
    baseline_path = compute_grammar_name(dir, baseline_id, "baseline")

    if recompileGrammar or not os.path.isfile(baseline_path):
        print("Inducing grammar from corpus")
        (n_trees, baseline_grammar) = d_i.induce_grammar(
            corpus_induce.get_trees(), nonterminal_labeling,
            term_labelling.token_label, recursive_partitioning, start)
        print("Induced grammar using", n_trees, ".")
        pickle.dump(baseline_grammar, open(baseline_path, 'wb'))
    else:
        print("Loading grammar from file")
        baseline_grammar = pickle.load(open(baseline_path, 'rb'))

    print("Rules: ", len(baseline_grammar.rules()))

    if parsing:
        parser_ = GFParser_k_best if parser == Coarse_to_fine_parser else parser
        baseline_parser = do_parsing(baseline_grammar,
                                     corpus_test,
                                     term_labelling,
                                     result,
                                     baseline_id,
                                     parser_,
                                     k_best=k_best,
                                     minimum_risk=minimum_risk,
                                     oracle_parse=oracle_parse,
                                     recompile=recompileGrammar,
                                     dir=dir,
                                     reparse=reparse)

    if True:  # reduct computation and split/merge training
        em_trained = pickle.load(open(baseline_path, 'rb'))
        reduct_path = compute_reduct_name(dir, baseline_id, corpus_train)
        if recompileGrammar or not os.path.isfile(reduct_path):
            trace = compute_reducts(em_trained, corpus_train.get_trees(),
                                    term_labelling)
            trace.serialize(reduct_path)
        else:
            print("loading trace")
            trace = PySDCPTraceManager(em_trained, term_labelling)
            trace.load_traces_from_file(reduct_path)

        if discr:
            reduct_path_discr = compute_reduct_name(dir, baseline_id,
                                                    corpus_train, '_discr')
            if recompileGrammar or not os.path.isfile(reduct_path_discr):
                trace_discr = compute_LCFRS_reducts(
                    em_trained,
                    corpus_train.get_trees(),
                    terminal_labelling=term_labelling,
                    nonterminal_map=trace.get_nonterminal_map())
                trace_discr.serialize(reduct_path_discr)
            else:
                print("loading trace discriminative")
                trace_discr = PyLCFRSTraceManager(em_trained,
                                                  trace.get_nonterminal_map())
                trace_discr.load_traces_from_file(reduct_path_discr)

        # todo refactor EM training, to use the LA version (but without any splits)
        """
        em_trained_path_ = em_trained_path(dir, grammar_id, n_epochs=emEpochs, init=emInit, tie_breaking=emTieBreaking, seed=seed)

        if recompileGrammar or retrain or not os.path.isfile(em_trained_path_):
            emTrainer = PyEMTrainer(trace)
            emTrainer.em_training(em_trained, n_epochs=emEpochs, init=emInit, tie_breaking=emTieBreaking, seed=seed)
            pickle.dump(em_trained, open(em_trained_path_, 'wb'))
        else:
            em_trained = pickle.load(open(em_trained_path_, 'rb'))

        if parsing:
            do_parsing(em_trained, test_limit, ignore_punctuation, term_labelling, recompileGrammar or retrain, [dir, "em_trained_gf_grammar"])
        """

        grammarInfo = PyGrammarInfo(baseline_grammar,
                                    trace.get_nonterminal_map())
        storageManager = PyStorageManager()

        builder = PySplitMergeTrainerBuilder(trace, grammarInfo)
        builder.set_em_epochs(emEpochs)
        builder.set_smoothing_factor(rule_smoothing)
        builder.set_split_randomization(splitRandomization, seed + 1)
        if discr:
            builder.set_discriminative_expector(trace_discr,
                                                maxScale=maxScaleDiscr,
                                                threads=1)
        else:
            builder.set_simple_expector(threads=1)
        if validation:
            if validationMethod == "likelihood":
                reduct_path_validation = compute_reduct_name(
                    dir, baseline_id, corpus_validation)
                if recompileGrammar or not os.path.isfile(
                        reduct_path_validation):
                    validation_trace = compute_reducts(
                        em_trained, corpus_validation.get_trees(),
                        term_labelling)
                    validation_trace.serialize(reduct_path_validation)
                else:
                    print("loading trace validation")
                    validation_trace = PySDCPTraceManager(
                        em_trained, term_labelling)
                    validation_trace.load_traces_from_file(
                        reduct_path_validation)
                builder.set_simple_validator(validation_trace,
                                             maxDrops=validationDropIterations,
                                             threads=1)
            else:
                validator = build_score_validator(
                    baseline_grammar, grammarInfo, trace.get_nonterminal_map(),
                    storageManager, term_labelling, baseline_parser,
                    corpus_validation, validationMethod)
                builder.set_score_validator(validator,
                                            validationDropIterations)
        splitMergeTrainer = builder.set_percent_merger(mergePercentage).build()
        if validation:
            splitMergeTrainer.setMaxDrops(1, mode="smoothing")
            splitMergeTrainer.setEMepochs(1, mode="smoothing")

        sm_info_path = compute_sm_info_path(dir, baseline_id, emEpochs,
                                            rule_smoothing, splitRandomization,
                                            seed, discr, validation,
                                            corpus_validation, emInit)

        if (not recompileGrammar) and (
                not retrain) and os.path.isfile(sm_info_path):
            print("Loading splits and weights of LA rules")
            latentAnnotation = [
                build_PyLatentAnnotation(t[0], t[1], t[2], grammarInfo,
                                         storageManager)
                for t in pickle.load(open(sm_info_path, 'rb'))
            ]
        else:
            # latentAnnotation = [build_PyLatentAnnotation_initial(em_trained, grammarInfo, storageManager)]
            latentAnnotation = [
                build_PyLatentAnnotation_initial(baseline_grammar, grammarInfo,
                                                 storageManager)
            ]

        for cycle in range(smCycles + 1):
            if cycle < len(latentAnnotation):
                smGrammar = latentAnnotation[cycle].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=rule_pruning
                    # , rule_smoothing=rule_smoothing
                )
            else:
                # setting the seed to achieve reproducibility in case of continued training
                splitMergeTrainer.reset_random_seed(seed + cycle + 1)
                latentAnnotation.append(
                    splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
                pickle.dump([la.serialize() for la in latentAnnotation],
                            open(sm_info_path, 'wb'))
                smGrammar = latentAnnotation[cycle].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=rule_pruning
                    # , rule_smoothing=rule_smoothing
                )
            print("Cycle: ", cycle, "Rules: ", len(smGrammar.rules()))
            if parsing:
                grammar_identifier = compute_sm_grammar_id(
                    baseline_id, emEpochs, rule_smoothing, splitRandomization,
                    seed, discr, validation, corpus_validation, emInit, cycle)
                if parser == Coarse_to_fine_parser:
                    opt = {
                        # pass the annotations of all cycles up to the current one
                        'latentAnnotation': latentAnnotation[:cycle + 1],
                        'grammarInfo': grammarInfo,
                        'nontMap': trace.get_nonterminal_map()
                    }
                    do_parsing(baseline_grammar,
                               corpus_test,
                               term_labelling,
                               result,
                               grammar_identifier,
                               parser,
                               k_best=k_best,
                               minimum_risk=minimum_risk,
                               oracle_parse=oracle_parse,
                               recompile=recompileGrammar,
                               dir=dir,
                               reparse=reparse,
                               opt=opt)
                else:
                    do_parsing(smGrammar,
                               corpus_test,
                               term_labelling,
                               result,
                               grammar_identifier,
                               parser,
                               k_best=k_best,
                               minimum_risk=minimum_risk,
                               oracle_parse=oracle_parse,
                               recompile=recompileGrammar,
                               dir=dir,
                               reparse=reparse)
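
Every experiment setting is exposed as a keyword argument of main, so a run can be configured entirely from the call site. Below is a minimal sketch of such a call; the keyword names come from the signature above, but the values are illustrative assumptions rather than part of the original script.

if __name__ == '__main__':
    # Hypothetical invocation of the experiment entry point above.
    main(limit=500,             # induce the grammar from 500 training trees
         smCycles=2,            # run only two split/merge cycles
         k_best=10,             # shorter k-best lists speed up parsing
         parser="GF-k-best",    # one of "CFG", "GF", "GF-k-best", "CoarseToFine", "FST"
         validation=True,
         validationSplit=10)    # hold out 10% of the training slice for validation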
Example #2
def main(limit=300,
         ignore_punctuation=False,
         baseline_path=baseline_path,
         recompileGrammar=True,
         retrain=True,
         parsing=True,
         seed=1337):
    max_length = 20
    trees = length_limit(parse_conll_corpus(train, False, limit), max_length)

    if recompileGrammar or not os.path.isfile(baseline_path):
        (n_trees,
         baseline_grammar) = d_i.induce_grammar(trees, empty_labelling,
                                                term_labelling.token_label,
                                                recursive_partitioning, start)
        pickle.dump(baseline_grammar, open(baseline_path, 'wb'))
    else:
        baseline_grammar = pickle.load(open(baseline_path, 'rb'))

    test_limit = 10000
    print("Rules: ", len(baseline_grammar.rules()))

    if parsing:
        do_parsing(baseline_grammar, test_limit, ignore_punctuation,
                   recompileGrammar, [dir, "baseline_gf_grammar"])

    em_trained = pickle.load(open(baseline_path, 'rb'))
    if recompileGrammar or not os.path.isfile(reduct_path):
        trees = length_limit(parse_conll_corpus(train, False, limit),
                             max_length)
        trace = compute_reducts(em_trained, trees, term_labelling)
        trace.serialize(reduct_path)
    else:
        print("loading trace")
        trace = PySDCPTraceManager(em_trained, term_labelling)
        trace.load_traces_from_file(reduct_path)

    discr = False
    if discr:
        if recompileGrammar or not os.path.isfile(reduct_path_discr):
            trees = length_limit(parse_conll_corpus(train, False, limit),
                                 max_length)
            trace_discr = compute_LCFRS_reducts(
                em_trained,
                trees,
                term_labelling,
                nonterminal_map=trace.get_nonterminal_map())
            trace_discr.serialize(reduct_path_discr)
        else:
            print("loading trace discriminative")
            trace_discr = PyLCFRSTraceManager(em_trained,
                                              trace.get_nonterminal_map())
            trace_discr.load_traces_from_file(reduct_path_discr)

    n_epochs = 20
    init = "rfe"
    tie_breaking = True
    em_trained_path_ = em_trained_path(n_epochs, init, tie_breaking)

    if recompileGrammar or retrain or not os.path.isfile(em_trained_path_):
        emTrainer = PyEMTrainer(trace)
        emTrainer.em_training(em_trained,
                              n_epochs=n_epochs,
                              init=init,
                              tie_breaking=tie_breaking,
                              seed=seed)
        pickle.dump(em_trained, open(em_trained_path_, 'wb'))
    else:
        em_trained = pickle.load(open(em_trained_path_, 'rb'))

    if parsing:
        do_parsing(em_trained, test_limit, ignore_punctuation, recompileGrammar
                   or retrain, [dir, "em_trained_gf_grammar"])

    grammarInfo = PyGrammarInfo(baseline_grammar, trace.get_nonterminal_map())
    storageManager = PyStorageManager()

    builder = PySplitMergeTrainerBuilder(trace, grammarInfo)
    builder.set_em_epochs(n_epochs)
    builder.set_split_randomization(1.0, seed + 1)
    if discr:
        builder.set_discriminative_expector(trace_discr,
                                            maxScale=10,
                                            threads=1)
    else:
        builder.set_simple_expector(threads=1)
    splitMergeTrainer = builder.set_percent_merger(65.0).build()

    if (not recompileGrammar) and (
            not retrain) and os.path.isfile(sm_info_path):
        print("Loading splits and weights of LA rules")
        latentAnnotation = [
            build_PyLatentAnnotation(t[0], t[1], t[2], grammarInfo,
                                     storageManager)
            for t in pickle.load(open(sm_info_path, 'rb'))
        ]
    else:
        latentAnnotation = [
            build_PyLatentAnnotation_initial(em_trained, grammarInfo,
                                             storageManager)
        ]

    max_cycles = 4
    reparse = False
    # parsing = False
    for i in range(max_cycles + 1):
        if i < len(latentAnnotation):
            if reparse:
                smGrammar = latentAnnotation[i].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=0.0001,
                    rule_smoothing=0.01)
                print("Cycle: ", i, "Rules: ", len(smGrammar.rules()))
                do_parsing(smGrammar, test_limit, ignore_punctuation,
                           recompileGrammar or retrain,
                           [dir, "sm_cycles" + str(i) + "_gf_grammar"])
        else:
            # setting the seed to achieve reproducibility in case of continued training
            splitMergeTrainer.reset_random_seed(seed + i + 1)
            latentAnnotation.append(
                splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
            pickle.dump([la.serialize() for la in latentAnnotation],
                        open(sm_info_path, 'wb'))
            smGrammar = latentAnnotation[i].build_sm_grammar(
                baseline_grammar,
                grammarInfo,
                rule_pruning=0.0001,
                rule_smoothing=0.1)
            print("Cycle: ", i, "Rules: ", len(smGrammar.rules()))
            if parsing:
                do_parsing(smGrammar, test_limit, ignore_punctuation,
                           recompileGrammar or retrain,
                           [dir, "sm_cycles" + str(i) + "_gf_grammar"])