def test_la_viterbi_parsing_2(self):
    grammar = self.build_paper_grammar()
    inp = ["a"] * 3
    nontMap = Enumerator()
    gi = PyGrammarInfo(grammar, nontMap)
    sm = PyStorageManager()
    print(nontMap.object_index("S"))
    print(nontMap.object_index("B"))

    la = build_PyLatentAnnotation(
        [2, 1], [1.0],
        [[0.25, 1.0], [1.0, 0.0], [0.0, 0.5, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0]],
        gi, sm)
    self.assertTrue(la.is_proper())

    parser = DiscodopKbestParser(grammar, la=la, nontMap=nontMap, grammarInfo=gi,
                                 latent_viterbi_mode=True)
    parser.set_input(inp)
    parser.parse()
    self.assertTrue(parser.recognized())
    der = parser.latent_viterbi_derivation(True)
    print(der)

    ranges = {der.spanned_ranges(idx)[0] for idx in der.ids()}
    self.assertSetEqual({(0, 3), (0, 2), (0, 1), (1, 2), (2, 3)}, ranges)

def create_initial_la(self):
    if self.induction_settings.feature_la:
        print("building initial LA from features", file=self.logger)
        nonterminal_splits, rootWeights, ruleWeights, split_id \
            = build_nont_splits_dict(self.base_grammar,
                                     self.feature_log,
                                     self.organizer.nonterminal_map,
                                     feat_function=self.induction_settings.feat_function)
        print("number of nonterminals:", len(nonterminal_splits), file=self.logger)
        print("total splits", sum(nonterminal_splits), file=self.logger)

        max_splits = max(nonterminal_splits)
        max_splits_index = nonterminal_splits.index(max_splits)
        max_splits_nont = self.organizer.nonterminal_map.index_object(max_splits_index)
        print("max. nonterminal splits", max_splits, "at index", max_splits_index,
              "i.e.,", max_splits_nont, file=self.logger)
        for key in split_id[max_splits_nont]:
            print(key, file=self.logger)

        print("splits for NE/1", file=self.logger)
        for key in split_id["NE/1"]:
            print(key, file=self.logger)
        for rule in self.base_grammar.lhs_nont_to_rules("NE/1"):
            print(rule, ruleWeights[rule.get_idx()], file=self.logger)

        print("number of rules", len(ruleWeights), file=self.logger)
        print("total split rules", sum(map(len, ruleWeights)), file=self.logger)
        print("number of split rules with 0 prob.",
              sum(1 for xs in ruleWeights for x in xs if x == 0.0),
              file=self.logger)

        la = build_PyLatentAnnotation(nonterminal_splits, rootWeights, ruleWeights,
                                      self.organizer.grammarInfo, self.organizer.storageManager)
        la.add_random_noise(seed=self.organizer.seed)
        self.split_id = split_id
        return la
    else:
        return super(ConstituentSMExperiment, self).create_initial_la()

def read_stage_file(self):
    # super(SplitMergeExperiment, self).read_stage_file()
    if "latent_annotations" in self.stage_dict:
        # this is a workaround
        if self.organizer.training_reducts is None:
            self.update_reducts(self.compute_reducts(self.resources[TRAINING]), type=TRAINING)
            self.write_stage_file()
        # this was a workaround

        self.initialize_training_environment()

        las = self.stage_dict["latent_annotations"]
        for key in las:
            with open(las[key], "rb") as f:
                splits, rootWeights, ruleWeights = pickle.load(f)
            # print(key)
            # print(len(splits), len(rootWeights), len(ruleWeights))
            # print(len(self.base_grammar.nonts()))
            la = build_PyLatentAnnotation(splits, rootWeights, ruleWeights,
                                          self.organizer.grammarInfo,
                                          self.organizer.storageManager)
            self.organizer.latent_annotations[int(key)] = la

    if "last_sm_cycle" in self.stage_dict:
        self.organizer.last_sm_cycle = int(self.stage_dict["last_sm_cycle"])

def __test_projection(self, split_weights, goal_weights, merge_method=False):
    grammar = LCFRS("S")

    # rule 0
    lhs = LCFRS_lhs("S")
    lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
    grammar.add_rule(lhs, ["A", "A"])

    # rule 1
    lhs = LCFRS_lhs("A")
    lhs.add_arg(["a"])
    grammar.add_rule(lhs, [])

    # rule 2
    lhs = LCFRS_lhs("A")
    lhs.add_arg(["b"])
    grammar.add_rule(lhs, [], weight=2.0)

    grammar.make_proper()
    # print(grammar)

    nonterminal_map = Enumerator()
    grammarInfo = PyGrammarInfo(grammar, nonterminal_map)
    storageManager = PyStorageManager()

    la = build_PyLatentAnnotation([1, 2], [1.0], split_weights, grammarInfo, storageManager)

    # parser = LCFRS_parser(grammar)
    # parser.set_input(["a", "b"])
    # parser.parse()
    # der = parser.best_derivation_tree()

    # print(la.serialize())
    if merge_method:
        la.project_weights(grammar, grammarInfo)
    else:
        splits, _, _ = la.serialize()
        merge_sources = [[[split for split in range(0, splits[nont_idx])]]
                         for nont_idx in range(0, nonterminal_map.get_counter())]

        # print("Projecting to fine grammar LA", file=self.logger)
        coarse_la = la.project_annotation_by_merging(grammarInfo, merge_sources, debug=False)
        coarse_la.project_weights(grammar, grammarInfo)

    # print(grammar)
    for i in range(3):
        self.assertAlmostEqual(grammar.rule_index(i).weight(), goal_weights[i])

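# A small, hedged refactoring sketch (not part of the test suite): the else branch
# above builds an "identity" merge structure that puts all current splits of every
# nonterminal into a single group, so that project_annotation_by_merging collapses
# the latent annotation back to one annotation per nonterminal.  The helper name is
# hypothetical; it only repackages the list comprehension used in __test_projection.
def all_splits_merge_sources(la, nonterminal_map):
    splits, _, _ = la.serialize()
    return [[[split for split in range(splits[nont_idx])]]
            for nont_idx in range(nonterminal_map.get_counter())]
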
def load_secondary_latent_annotations(self, paths):
    for path in paths:
        if path == '':
            continue
        with open(path, "rb") as f:
            splits, rootWeights, ruleWeights = pickle.load(f)

        # very basic tests to avoid incompatible LAs
        assert len(splits) == len(self.base_grammar.nonts())
        assert len(ruleWeights) == len(self.base_grammar.rules())

        la = build_PyLatentAnnotation(splits, rootWeights, ruleWeights,
                                      self.organizer.grammarInfo,
                                      self.organizer.storageManager)
        self.organizer.secondary_latent_annotations.append(la)

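# Hedged sketch of the writing counterpart to the loader above, assuming that
# la.serialize() returns exactly the (splits, rootWeights, ruleWeights) triple that
# load_secondary_latent_annotations unpickles (as `splits, _, _ = la.serialize()`
# elsewhere in this code suggests).  The function name is hypothetical and relies on
# the module's existing pickle import.
def dump_latent_annotation(la, path):
    with open(path, "wb") as f:
        pickle.dump(la.serialize(), f)
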
def main(limit=3000,
         test_limit=sys.maxsize,
         max_length=sys.maxsize,
         dir=dir,
         train='../res/negra-dep/negra-lower-punct-train.conll',
         test='../res/negra-dep/negra-lower-punct-test.conll',
         recursive_partitioning='cfg',
         nonterminal_labeling='childtop-deprel',
         terminal_labeling='form-unk-30/pos',
         emEpochs=20,
         emTieBreaking=True,
         emInit="rfe",
         splitRandomization=1.0,
         mergePercentage=85.0,
         smCycles=6,
         rule_pruning=0.0001,
         rule_smoothing=0.01,
         validation=True,
         validationMethod='likelihood',
         validationCorpus=None,
         validationSplit=20,
         validationDropIterations=6,
         seed=1337,
         discr=False,
         maxScaleDiscr=10,
         recompileGrammar="True",
         retrain=False,
         parsing=True,
         reparse=False,
         parser="CFG",
         k_best=50,
         minimum_risk=False,
         oracle_parse=False):
    # set various parameters
    recompileGrammar = True if recompileGrammar == "True" else False
    # print(recompileGrammar)

    def result(gram, add=None):
        if add is not None:
            return os.path.join(dir, gram + '_experiment_parse_results_' + add + '.conll')
        else:
            return os.path.join(dir, gram + '_experiment_parse_results.conll')

    recursive_partitioning = grammar.induction.recursive_partitioning \
        .the_recursive_partitioning_factory().get_partitioning(recursive_partitioning)
    top_level, low_level = tuple(nonterminal_labeling.split('-'))
    nonterminal_labeling = d_l.the_labeling_factory() \
        .create_simple_labeling_strategy(top_level, low_level)

    if parser == "CFG":
        assert all([rp.__name__ in ["left_branching", "right_branching", "cfg", "fanout_1"]
                    for rp in recursive_partitioning])
        parser = CFGParser
    elif parser == "GF":
        parser = GFParser
    elif parser == "GF-k-best":
        parser = GFParser_k_best
    elif parser == "CoarseToFine":
        parser = Coarse_to_fine_parser
    elif parser == "FST":
        if recursive_partitioning == "left_branching":
            parser = LeftBranchingFSTParser
        elif recursive_partitioning == "right_branching":
            parser = RightBranchingFSTParser
        else:
            assert False, "expect left/right branching recursive partitioning for FST parsing"

    if validation:
        if validationCorpus is not None:
            corpus_validation = Corpus(validationCorpus)
            train_limit = limit
        else:
            train_limit = int(limit * (100.0 - validationSplit) / 100.0)
            corpus_validation = Corpus(train, start=train_limit, end=limit)
    else:
        train_limit = limit

    corpus_induce = Corpus(train, end=limit)
    corpus_train = Corpus(train, end=train_limit)
    corpus_test = Corpus(test, end=test_limit)

    match = re.match(r'^form-unk-(\d+)-morph.*$', terminal_labeling)
    if match:
        unk_threshold = int(match.group(1))
        term_labelling = grammar.induction.terminal_labeling.FormPosTerminalsUnkMorph(
            corpus_induce.get_trees(), unk_threshold, pos_filter=["NE", "CARD"],
            add_morph={'NN': ['case', 'number', 'gender']
                       # , 'NE': ['case', 'number', 'gender']
                       # , 'VMFIN': ['number', 'person']
                       # , 'VVFIN': ['number', 'person']
                       # , 'VAFIN': ['number', 'person']
                       })
    else:
        match = re.match(r'^form-unk-(\d+).*$', terminal_labeling)
        if match:
            unk_threshold = int(match.group(1))
            term_labelling = grammar.induction.terminal_labeling.FormPosTerminalsUnk(
                corpus_induce.get_trees(), unk_threshold, pos_filter=["NE", "CARD"])
        else:
            term_labelling = grammar.induction.terminal_labeling \
                .the_terminal_labeling_factory().get_strategy(terminal_labeling)

    if not os.path.isdir(dir):
        os.makedirs(dir)

    # start actual training
    # we use the training corpus until limit for grammar induction (i.e., also the validation section)
    print("Computing baseline id:")
    baseline_id = grammar_id(corpus_induce, nonterminal_labeling, term_labelling, recursive_partitioning)
    print(baseline_id)
    baseline_path = compute_grammar_name(dir, baseline_id, "baseline")

    if recompileGrammar or not os.path.isfile(baseline_path):
        print("Inducing grammar from corpus")
        (n_trees, baseline_grammar) = d_i.induce_grammar(corpus_induce.get_trees(),
                                                         nonterminal_labeling,
                                                         term_labelling.token_label,
                                                         recursive_partitioning,
                                                         start)
        print("Induced grammar using", n_trees, "trees.")
        pickle.dump(baseline_grammar, open(baseline_path, 'wb'))
    else:
        print("Loading grammar from file")
        baseline_grammar = pickle.load(open(baseline_path, 'rb'))

    print("Rules: ", len(baseline_grammar.rules()))

    if parsing:
        parser_ = GFParser_k_best if parser == Coarse_to_fine_parser else parser
        baseline_parser = do_parsing(baseline_grammar, corpus_test, term_labelling, result, baseline_id,
                                     parser_, k_best=k_best, minimum_risk=minimum_risk,
                                     oracle_parse=oracle_parse, recompile=recompileGrammar, dir=dir,
                                     reparse=reparse)

    if True:
        em_trained = pickle.load(open(baseline_path, 'rb'))
        reduct_path = compute_reduct_name(dir, baseline_id, corpus_train)
        if recompileGrammar or not os.path.isfile(reduct_path):
            trace = compute_reducts(em_trained, corpus_train.get_trees(), term_labelling)
            trace.serialize(reduct_path)
        else:
            print("loading trace")
            trace = PySDCPTraceManager(em_trained, term_labelling)
            trace.load_traces_from_file(reduct_path)

        if discr:
            reduct_path_discr = compute_reduct_name(dir, baseline_id, corpus_train, '_discr')
            if recompileGrammar or not os.path.isfile(reduct_path_discr):
                trace_discr = compute_LCFRS_reducts(em_trained, corpus_train.get_trees(),
                                                    terminal_labelling=term_labelling,
                                                    nonterminal_map=trace.get_nonterminal_map())
                trace_discr.serialize(reduct_path_discr)
            else:
                print("loading trace discriminative")
                trace_discr = PyLCFRSTraceManager(em_trained, trace.get_nonterminal_map())
                trace_discr.load_traces_from_file(reduct_path_discr)

        # todo refactor EM training, to use the LA version (but without any splits)
        """
        em_trained_path_ = em_trained_path(dir, grammar_id, n_epochs=emEpochs, init=emInit,
                                           tie_breaking=emTieBreaking, seed=seed)
        if recompileGrammar or retrain or not os.path.isfile(em_trained_path_):
            emTrainer = PyEMTrainer(trace)
            emTrainer.em_training(em_trained, n_epochs=emEpochs, init=emInit,
                                  tie_breaking=emTieBreaking, seed=seed)
            pickle.dump(em_trained, open(em_trained_path_, 'wb'))
        else:
            em_trained = pickle.load(open(em_trained_path_, 'rb'))

        if parsing:
            do_parsing(em_trained, test_limit, ignore_punctuation, term_labelling,
                       recompileGrammar or retrain, [dir, "em_trained_gf_grammar"])
        """

        grammarInfo = PyGrammarInfo(baseline_grammar, trace.get_nonterminal_map())
        storageManager = PyStorageManager()

        builder = PySplitMergeTrainerBuilder(trace, grammarInfo)
        builder.set_em_epochs(emEpochs)
        builder.set_smoothing_factor(rule_smoothing)
        builder.set_split_randomization(splitRandomization, seed + 1)
        if discr:
            builder.set_discriminative_expector(trace_discr, maxScale=maxScaleDiscr, threads=1)
        else:
            builder.set_simple_expector(threads=1)

        if validation:
            if validationMethod == "likelihood":
                reduct_path_validation = compute_reduct_name(dir, baseline_id, corpus_validation)
                if recompileGrammar or not os.path.isfile(reduct_path_validation):
                    validation_trace = compute_reducts(em_trained, corpus_validation.get_trees(),
                                                       term_labelling)
                    validation_trace.serialize(reduct_path_validation)
                else:
                    print("loading trace validation")
                    validation_trace = PySDCPTraceManager(em_trained, term_labelling)
                    validation_trace.load_traces_from_file(reduct_path_validation)
                builder.set_simple_validator(validation_trace, maxDrops=validationDropIterations, threads=1)
            else:
                validator = build_score_validator(baseline_grammar, grammarInfo,
                                                  trace.get_nonterminal_map(), storageManager,
                                                  term_labelling, baseline_parser, corpus_validation,
                                                  validationMethod)
                builder.set_score_validator(validator, validationDropIterations)

        splitMergeTrainer = builder.set_percent_merger(mergePercentage).build()

        if validation:
            splitMergeTrainer.setMaxDrops(1, mode="smoothing")
            splitMergeTrainer.setEMepochs(1, mode="smoothing")

        sm_info_path = compute_sm_info_path(dir, baseline_id, emEpochs, rule_smoothing,
                                            splitRandomization, seed, discr, validation,
                                            corpus_validation, emInit)

        if (not recompileGrammar) and (not retrain) and os.path.isfile(sm_info_path):
            print("Loading splits and weights of LA rules")
            latentAnnotation = list(map(lambda t: build_PyLatentAnnotation(t[0], t[1], t[2],
                                                                           grammarInfo, storageManager),
                                        pickle.load(open(sm_info_path, 'rb'))))
        else:
            # latentAnnotation = [build_PyLatentAnnotation_initial(em_trained, grammarInfo, storageManager)]
            latentAnnotation = [build_PyLatentAnnotation_initial(baseline_grammar, grammarInfo,
                                                                 storageManager)]

        for cycle in range(smCycles + 1):
            if cycle < len(latentAnnotation):
                smGrammar = latentAnnotation[cycle].build_sm_grammar(baseline_grammar,
                                                                     grammarInfo,
                                                                     rule_pruning=rule_pruning
                                                                     # , rule_smoothing=rule_smoothing
                                                                     )
            else:
                # setting the seed to achieve reproducibility in case of continued training
                splitMergeTrainer.reset_random_seed(seed + cycle + 1)
                latentAnnotation.append(splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
                pickle.dump(list(map(lambda la: la.serialize(), latentAnnotation)),
                            open(sm_info_path, 'wb'))
                smGrammar = latentAnnotation[cycle].build_sm_grammar(baseline_grammar,
                                                                     grammarInfo,
                                                                     rule_pruning=rule_pruning
                                                                     # , rule_smoothing=rule_smoothing
                                                                     )
            print("Cycle: ", cycle, "Rules: ", len(smGrammar.rules()))

            if parsing:
                grammar_identifier = compute_sm_grammar_id(baseline_id, emEpochs, rule_smoothing,
                                                           splitRandomization, seed, discr, validation,
                                                           corpus_validation, emInit, cycle)
                if parser == Coarse_to_fine_parser:
                    opt = {'latentAnnotation': latentAnnotation[:cycle + 1],  # [cycle]
                           'grammarInfo': grammarInfo,
                           'nontMap': trace.get_nonterminal_map()}
                    do_parsing(baseline_grammar, corpus_test, term_labelling, result, grammar_identifier,
                               parser, k_best=k_best, minimum_risk=minimum_risk, oracle_parse=oracle_parse,
                               recompile=recompileGrammar, dir=dir, reparse=reparse, opt=opt)
                else:
                    do_parsing(smGrammar, corpus_test, term_labelling, result, grammar_identifier, parser,
                               k_best=k_best, minimum_risk=minimum_risk, oracle_parse=oracle_parse,
                               recompile=recompileGrammar, dir=dir, reparse=reparse)

def main(limit=300, ignore_punctuation=False, baseline_path=baseline_path, recompileGrammar=True,
         retrain=True, parsing=True, seed=1337):
    max_length = 20
    trees = length_limit(parse_conll_corpus(train, False, limit), max_length)

    if recompileGrammar or not os.path.isfile(baseline_path):
        (n_trees, baseline_grammar) = d_i.induce_grammar(trees, empty_labelling,
                                                         term_labelling.token_label,
                                                         recursive_partitioning, start)
        pickle.dump(baseline_grammar, open(baseline_path, 'wb'))
    else:
        baseline_grammar = pickle.load(open(baseline_path, 'rb'))

    test_limit = 10000
    print("Rules: ", len(baseline_grammar.rules()))

    if parsing:
        do_parsing(baseline_grammar, test_limit, ignore_punctuation, recompileGrammar,
                   [dir, "baseline_gf_grammar"])

    em_trained = pickle.load(open(baseline_path, 'rb'))
    if recompileGrammar or not os.path.isfile(reduct_path):
        trees = length_limit(parse_conll_corpus(train, False, limit), max_length)
        trace = compute_reducts(em_trained, trees, term_labelling)
        trace.serialize(reduct_path)
    else:
        print("loading trace")
        trace = PySDCPTraceManager(em_trained, term_labelling)
        trace.load_traces_from_file(reduct_path)

    discr = False

    if discr:
        if recompileGrammar or not os.path.isfile(reduct_path_discr):
            trees = length_limit(parse_conll_corpus(train, False, limit), max_length)
            trace_discr = compute_LCFRS_reducts(em_trained, trees, term_labelling,
                                                nonterminal_map=trace.get_nonterminal_map())
            trace_discr.serialize(reduct_path_discr)
        else:
            print("loading trace discriminative")
            trace_discr = PyLCFRSTraceManager(em_trained, trace.get_nonterminal_map())
            trace_discr.load_traces_from_file(reduct_path_discr)

    n_epochs = 20
    init = "rfe"
    tie_breaking = True
    em_trained_path_ = em_trained_path(n_epochs, init, tie_breaking)

    if recompileGrammar or retrain or not os.path.isfile(em_trained_path_):
        emTrainer = PyEMTrainer(trace)
        emTrainer.em_training(em_trained, n_epochs=n_epochs, init=init, tie_breaking=tie_breaking,
                              seed=seed)
        pickle.dump(em_trained, open(em_trained_path_, 'wb'))
    else:
        em_trained = pickle.load(open(em_trained_path_, 'rb'))

    if parsing:
        do_parsing(em_trained, test_limit, ignore_punctuation, recompileGrammar or retrain,
                   [dir, "em_trained_gf_grammar"])

    grammarInfo = PyGrammarInfo(baseline_grammar, trace.get_nonterminal_map())
    storageManager = PyStorageManager()

    builder = PySplitMergeTrainerBuilder(trace, grammarInfo)
    builder.set_em_epochs(n_epochs)
    builder.set_split_randomization(1.0, seed + 1)
    if discr:
        builder.set_discriminative_expector(trace_discr, maxScale=10, threads=1)
    else:
        builder.set_simple_expector(threads=1)
    splitMergeTrainer = builder.set_percent_merger(65.0).build()

    if (not recompileGrammar) and (not retrain) and os.path.isfile(sm_info_path):
        print("Loading splits and weights of LA rules")
        latentAnnotation = list(map(lambda t: build_PyLatentAnnotation(t[0], t[1], t[2],
                                                                       grammarInfo, storageManager),
                                    pickle.load(open(sm_info_path, 'rb'))))
    else:
        latentAnnotation = [build_PyLatentAnnotation_initial(em_trained, grammarInfo, storageManager)]

    max_cycles = 4
    reparse = False
    # parsing = False
    for i in range(max_cycles + 1):
        if i < len(latentAnnotation):
            if reparse:
                smGrammar = latentAnnotation[i].build_sm_grammar(baseline_grammar,
                                                                 grammarInfo,
                                                                 rule_pruning=0.0001,
                                                                 rule_smoothing=0.01)
                print("Cycle: ", i, "Rules: ", len(smGrammar.rules()))
                do_parsing(smGrammar, test_limit, ignore_punctuation, recompileGrammar or retrain,
                           [dir, "sm_cycles" + str(i) + "_gf_grammar"])
        else:
            # setting the seed to achieve reproducibility in case of continued training
            splitMergeTrainer.reset_random_seed(seed + i + 1)
            latentAnnotation.append(splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
            pickle.dump(list(map(lambda la: la.serialize(), latentAnnotation)),
                        open(sm_info_path, 'wb'))
            smGrammar = latentAnnotation[i].build_sm_grammar(baseline_grammar,
                                                             grammarInfo,
                                                             rule_pruning=0.0001,
                                                             rule_smoothing=0.1)
            print("Cycle: ", i, "Rules: ", len(smGrammar.rules()))
            if parsing:
                do_parsing(smGrammar, test_limit, ignore_punctuation, recompileGrammar or retrain,
                           [dir, "sm_cycles" + str(i) + "_gf_grammar"])