def prepare_split_merge_trainer(self):
    """Configure and build the split/merge (SM) trainer on ``self.organizer``.

    Reads all hyper-parameters (epochs, thread count, validator type,
    smoothing, randomization, merger choice) from ``self.organizer``,
    lets subclasses hook in via ``custom_sm_options``, and stores the
    finished trainer in ``self.organizer.splitMergeTrainer``.
    """
    # prepare SM training
    builder = PySplitMergeTrainerBuilder(self.organizer.training_reducts, self.organizer.grammarInfo)
    builder.set_em_epochs(self.organizer.em_epochs_sm)
    builder.set_simple_expector(threads=self.organizer.threads)
    # choose the validation strategy; any other validator_type gets no validator
    if self.organizer.validator_type == "SCORE":
        builder.set_score_validator(self.organizer.validator, self.organizer.validationDropIterations)
    elif self.organizer.validator_type == "SIMPLE":
        builder.set_simple_validator(self.organizer.validation_reducts, self.organizer.validationDropIterations)
    builder.set_smoothing_factor(smoothingFactor=self.organizer.smoothing_factor, smoothingFactorUnary=self.organizer.smoothing_factor_unary)
    # seed + 1: presumably offsets the split seed from other seeded components — TODO confirm
    builder.set_split_randomization(percent=self.organizer.split_randomization, seed=self.organizer.seed + 1)
    # set merger
    if self.organizer.merge_type == "SCC":
        builder.set_scc_merger(self.organizer.merge_threshold)
    elif self.organizer.merge_type == "THRESHOLD":
        builder.set_threshold_merger(self.organizer.merge_threshold)
    else:
        # fallback: merge a fixed percentage of splits
        builder.set_percent_merger(self.organizer.merge_percentage)
    # subclass hook for additional builder options
    self.custom_sm_options(builder)
    self.organizer.splitMergeTrainer = builder.build()
    # post-build tuning of the trainer's smoothing phase
    if self.organizer.validator_type in ["SCORE", "SIMPLE"]:
        self.organizer.splitMergeTrainer.setMaxDrops(self.organizer.validationDropIterations, mode="smoothing")
    self.organizer.splitMergeTrainer.setMinEpochs(self.organizer.min_epochs)
    self.organizer.splitMergeTrainer.setMinEpochs(self.organizer.min_epochs_smoothing, mode="smoothing")
    self.organizer.splitMergeTrainer.setIgnoreFailures(self.organizer.ignore_failures_smoothing, mode="smoothing")
    self.organizer.splitMergeTrainer.setEMepochs(self.organizer.em_epochs_sm, mode="smoothing")
def main(limit=3000,
         test_limit=sys.maxint,
         max_length=sys.maxint,
         dir=dir,
         train='../res/negra-dep/negra-lower-punct-train.conll',
         test='../res/negra-dep/negra-lower-punct-test.conll',
         recursive_partitioning='cfg',
         nonterminal_labeling='childtop-deprel',
         terminal_labeling='form-unk-30/pos',
         emEpochs=20,
         emTieBreaking=True,
         emInit="rfe",
         splitRandomization=1.0,
         mergePercentage=85.0,
         smCycles=6,
         rule_pruning=0.0001,
         rule_smoothing=0.01,
         validation=True,
         validationMethod='likelihood',
         validationCorpus=None,
         validationSplit=20,
         validationDropIterations=6,
         seed=1337,
         discr=False,
         maxScaleDiscr=10,
         recompileGrammar="True",
         retrain=False,
         parsing=True,
         reparse=False,
         parser="CFG",
         k_best=50,
         minimum_risk=False,
         oracle_parse=False):
    """Run the grammar-induction / split-merge-training / parsing experiment.

    Induces a baseline LCFRS/sDCP grammar from a CoNLL training corpus,
    optionally EM-trains it with split/merge cycles (plus likelihood- or
    score-based validation), and parses the test corpus after each cycle.
    Grammars, reducts, and latent annotations are cached on disk under
    ``dir`` and only recomputed when ``recompileGrammar``/``retrain`` ask
    for it or the cache file is missing.

    Selected parameters (the rest mirror the attribute names they feed):
      limit / test_limit    -- number of training / test sentences used
      dir                   -- output/cache directory
                               (NOTE(review): the default ``dir=dir`` captures
                               whatever module-level ``dir`` is at import time —
                               presumably a path constant; confirm)
      recompileGrammar      -- string flag "True"/other, converted to bool below
      parser                -- "CFG", "GF", "GF-k-best", "CoarseToFine", "FST";
                               replaced by the parser class
      validationMethod      -- 'likelihood' uses reduct-based validation,
                               anything else a score validator
    ``max_length`` is currently unused in this function body.
    """
    # set various parameters
    recompileGrammar = recompileGrammar == "True"

    def result(gram, add=None):
        # Path of the parse-results file for grammar `gram` (optional suffix `add`).
        if add is not None:
            return os.path.join(dir, gram + '_experiment_parse_results_' + add + '.conll')
        else:
            return os.path.join(dir, gram + '_experiment_parse_results.conll')

    # resolve string arguments into strategy objects
    recursive_partitioning = grammar.induction.recursive_partitioning.the_recursive_partitioning_factory(
    ).get_partitioning(recursive_partitioning)
    top_level, low_level = tuple(nonterminal_labeling.split('-'))
    nonterminal_labeling = d_l.the_labeling_factory(
    ).create_simple_labeling_strategy(top_level, low_level)

    # map the parser name to a parser class
    if parser == "CFG":
        assert all([
            rp.__name__ in ["left_branching", "right_branching", "cfg", "fanout_1"]
            for rp in recursive_partitioning
        ])
        parser = CFGParser
    elif parser == "GF":
        parser = GFParser
    elif parser == "GF-k-best":
        parser = GFParser_k_best
    elif parser == "CoarseToFine":
        parser = Coarse_to_fine_parser
    elif parser == "FST":
        # NOTE(review): at this point recursive_partitioning is the factory's
        # result (a list, see the assert above), so these string comparisons
        # can never match and FST selection always hits the assert — verify.
        if recursive_partitioning == "left_branching":
            parser = LeftBranchingFSTParser
        elif recursive_partitioning == "right_branching":
            parser = RightBranchingFSTParser
        else:
            # was `assert False and "..."`, which discarded the message
            assert False, "expect left/right branching recursive partitioning for FST parsing"

    # split off a validation section (either a separate corpus or the tail
    # of the training corpus); corpus_validation stays None without validation
    corpus_validation = None
    if validation:
        if validationCorpus is not None:
            corpus_validation = Corpus(validationCorpus)
            train_limit = limit
        else:
            train_limit = int(limit * (100.0 - validationSplit) / 100.0)
            corpus_validation = Corpus(train, start=train_limit, end=limit)
    else:
        train_limit = limit

    corpus_induce = Corpus(train, end=limit)
    corpus_train = Corpus(train, end=train_limit)
    corpus_test = Corpus(test, end=test_limit)

    # pick the terminal labelling: unknown-word threshold with morphology,
    # plain unknown-word threshold, or a named factory strategy
    match = re.match(r'^form-unk-(\d+)-morph.*$', terminal_labeling)
    if match:
        unk_threshold = int(match.group(1))
        term_labelling = grammar.induction.terminal_labeling.FormPosTerminalsUnkMorph(
            corpus_induce.get_trees(),
            unk_threshold,
            pos_filter=["NE", "CARD"],
            add_morph={
                'NN': ['case', 'number', 'gender']
                # , 'NE': ['case', 'number', 'gender']
                # , 'VMFIN': ['number', 'person']
                # , 'VVFIN': ['number', 'person']
                # , 'VAFIN': ['number', 'person']
            })
    else:
        match = re.match(r'^form-unk-(\d+).*$', terminal_labeling)
        if match:
            unk_threshold = int(match.group(1))
            term_labelling = grammar.induction.terminal_labeling.FormPosTerminalsUnk(
                corpus_induce.get_trees(), unk_threshold, pos_filter=["NE", "CARD"])
        else:
            term_labelling = grammar.induction.terminal_labeling.the_terminal_labeling_factory(
            ).get_strategy(terminal_labeling)

    if not os.path.isdir(dir):
        os.makedirs(dir)

    # start actual training
    # we use the training corpus until limit for grammar induction (i.e., also the validation section)
    print("Computing baseline id: ")
    baseline_id = grammar_id(corpus_induce, nonterminal_labeling, term_labelling, recursive_partitioning)
    print(baseline_id)
    baseline_path = compute_grammar_name(dir, baseline_id, "baseline")

    if recompileGrammar or not os.path.isfile(baseline_path):
        print("Inducing grammar from corpus")
        (n_trees, baseline_grammar) = d_i.induce_grammar(
            corpus_induce.get_trees(), nonterminal_labeling,
            term_labelling.token_label, recursive_partitioning, start)
        print("Induced grammar using", n_trees, ".")
        # close the handle explicitly instead of leaking it
        with open(baseline_path, 'wb') as f:
            pickle.dump(baseline_grammar, f)
    else:
        print("Loading grammar from file")
        # 'rb' matches the 'wb' used when the grammar was dumped
        with open(baseline_path, 'rb') as f:
            baseline_grammar = pickle.load(f)

    print("Rules: ", len(baseline_grammar.rules()))

    if parsing:
        # Coarse_to_fine_parser needs a k-best base parser for the baseline run
        parser_ = GFParser_k_best if parser == Coarse_to_fine_parser else parser
        baseline_parser = do_parsing(baseline_grammar, corpus_test, term_labelling, result,
                                     baseline_id, parser_, k_best=k_best,
                                     minimum_risk=minimum_risk, oracle_parse=oracle_parse,
                                     recompile=recompileGrammar, dir=dir, reparse=reparse)
    # NOTE(review): with parsing=False, baseline_parser is unbound; the score
    # validator branch below would then raise NameError — verify intended usage.

    if True:
        # reload a fresh grammar copy for EM training
        with open(baseline_path, 'rb') as f:
            em_trained = pickle.load(f)

        # compute or load the sDCP reducts (parse traces) of the training corpus
        reduct_path = compute_reduct_name(dir, baseline_id, corpus_train)
        if recompileGrammar or not os.path.isfile(reduct_path):
            trace = compute_reducts(em_trained, corpus_train.get_trees(), term_labelling)
            trace.serialize(reduct_path)
        else:
            print("loading trace")
            trace = PySDCPTraceManager(em_trained, term_labelling)
            trace.load_traces_from_file(reduct_path)

        if discr:
            # additional LCFRS reducts for the discriminative expector
            reduct_path_discr = compute_reduct_name(dir, baseline_id, corpus_train, '_discr')
            if recompileGrammar or not os.path.isfile(reduct_path_discr):
                trace_discr = compute_LCFRS_reducts(
                    em_trained,
                    corpus_train.get_trees(),
                    terminal_labelling=term_labelling,
                    nonterminal_map=trace.get_nonterminal_map())
                trace_discr.serialize(reduct_path_discr)
            else:
                print("loading trace discriminative")
                trace_discr = PyLCFRSTraceManager(em_trained, trace.get_nonterminal_map())
                trace_discr.load_traces_from_file(reduct_path_discr)

        # todo refactor EM training, to use the LA version (but without any splits)
        """
        em_trained_path_ = em_trained_path(dir, grammar_id, n_epochs=emEpochs, init=emInit, tie_breaking=emTieBreaking, seed=seed)
        if recompileGrammar or retrain or not os.path.isfile(em_trained_path_):
            emTrainer = PyEMTrainer(trace)
            emTrainer.em_training(em_trained, n_epochs=emEpochs, init=emInit, tie_breaking=emTieBreaking, seed=seed)
            pickle.dump(em_trained, open(em_trained_path_, 'wb'))
        else:
            em_trained = pickle.load(open(em_trained_path_, 'rb'))
        if parsing:
            do_parsing(em_trained, test_limit, ignore_punctuation, term_labelling, recompileGrammar or retrain, [dir, "em_trained_gf_grammar"])
        """

        # assemble the split/merge trainer
        grammarInfo = PyGrammarInfo(baseline_grammar, trace.get_nonterminal_map())
        storageManager = PyStorageManager()
        builder = PySplitMergeTrainerBuilder(trace, grammarInfo)
        builder.set_em_epochs(emEpochs)
        builder.set_smoothing_factor(rule_smoothing)
        builder.set_split_randomization(splitRandomization, seed + 1)
        if discr:
            builder.set_discriminative_expector(trace_discr, maxScale=maxScaleDiscr, threads=1)
        else:
            builder.set_simple_expector(threads=1)

        if validation:
            # was `validationMethod is "likelihood"` — identity comparison with
            # a string literal relies on CPython interning; use equality
            if validationMethod == "likelihood":
                reduct_path_validation = compute_reduct_name(dir, baseline_id, corpus_validation)
                if recompileGrammar or not os.path.isfile(reduct_path_validation):
                    validation_trace = compute_reducts(em_trained, corpus_validation.get_trees(), term_labelling)
                    validation_trace.serialize(reduct_path_validation)
                else:
                    print("loading trace validation")
                    validation_trace = PySDCPTraceManager(em_trained, term_labelling)
                    validation_trace.load_traces_from_file(reduct_path_validation)
                builder.set_simple_validator(validation_trace,
                                             maxDrops=validationDropIterations,
                                             threads=1)
            else:
                validator = build_score_validator(
                    baseline_grammar, grammarInfo, trace.get_nonterminal_map(),
                    storageManager, term_labelling, baseline_parser,
                    corpus_validation, validationMethod)
                builder.set_score_validator(validator, validationDropIterations)

        splitMergeTrainer = builder.set_percent_merger(mergePercentage).build()
        if validation:
            splitMergeTrainer.setMaxDrops(1, mode="smoothing")
            splitMergeTrainer.setEMepochs(1, mode="smoothing")

        # load cached latent annotations (splits + weights) if allowed,
        # otherwise start from the trivial annotation of the baseline grammar
        sm_info_path = compute_sm_info_path(dir, baseline_id, emEpochs, rule_smoothing,
                                            splitRandomization, seed, discr, validation,
                                            corpus_validation, emInit)
        if (not recompileGrammar) and (not retrain) and os.path.isfile(sm_info_path):
            print("Loading splits and weights of LA rules")
            with open(sm_info_path, 'rb') as f:
                serialized = pickle.load(f)
            latentAnnotation = [
                build_PyLatentAnnotation(t[0], t[1], t[2], grammarInfo, storageManager)
                for t in serialized
            ]
        else:
            # latentAnnotation = [build_PyLatentAnnotation_initial(em_trained, grammarInfo, storageManager)]
            latentAnnotation = [
                build_PyLatentAnnotation_initial(baseline_grammar, grammarInfo, storageManager)
            ]

        for cycle in range(smCycles + 1):
            if cycle < len(latentAnnotation):
                # annotation for this cycle already available (loaded from cache)
                smGrammar = latentAnnotation[cycle].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=rule_pruning
                    # , rule_smoothing=rule_smoothing
                )
            else:
                # setting the seed to achieve reproducibility in case of continued training
                splitMergeTrainer.reset_random_seed(seed + cycle + 1)
                latentAnnotation.append(splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
                # persist all annotations computed so far
                with open(sm_info_path, 'wb') as f:
                    pickle.dump([la.serialize() for la in latentAnnotation], f)
                smGrammar = latentAnnotation[cycle].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=rule_pruning
                    # , rule_smoothing=rule_smoothing
                )
            print("Cycle: ", cycle, "Rules: ", len(smGrammar.rules()))

            if parsing:
                grammar_identifier = compute_sm_grammar_id(baseline_id, emEpochs, rule_smoothing,
                                                           splitRandomization, seed, discr,
                                                           validation, corpus_validation, emInit,
                                                           cycle)
                if parser == Coarse_to_fine_parser:
                    opt = {
                        'latentAnnotation': latentAnnotation[:cycle + 1]  # [cycle]
                        , 'grammarInfo': grammarInfo
                        , 'nontMap': trace.get_nonterminal_map()
                    }
                    do_parsing(baseline_grammar, corpus_test, term_labelling, result,
                               grammar_identifier, parser, k_best=k_best,
                               minimum_risk=minimum_risk, oracle_parse=oracle_parse,
                               recompile=recompileGrammar, dir=dir, reparse=reparse, opt=opt)
                else:
                    do_parsing(smGrammar, corpus_test, term_labelling, result,
                               grammar_identifier, parser, k_best=k_best,
                               minimum_risk=minimum_risk, oracle_parse=oracle_parse,
                               recompile=recompileGrammar, dir=dir, reparse=reparse)