Example 1
    def test_la_viterbi_parsing_2(self):
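        # Build the small example grammar from the paper, wrap it in a latent
        # annotation with splits [2, 1], and parse "a a a" with the
        # DiscodopKbestParser in latent Viterbi mode; then check that the
        # Viterbi derivation spans exactly the expected ranges.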
        grammar = self.build_paper_grammar()
        inp = ["a"] * 3
        nontMap = Enumerator()
        gi = PyGrammarInfo(grammar, nontMap)
        sm = PyStorageManager()
        print(nontMap.object_index("S"))
        print(nontMap.object_index("B"))
        la = build_PyLatentAnnotation(
            [2, 1], [1.0], [[0.25, 1.0], [1.0, 0.0],
                            [0.0, 0.5, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0]], gi, sm)
        self.assertTrue(la.is_proper())

        parser = DiscodopKbestParser(grammar,
                                     la=la,
                                     nontMap=nontMap,
                                     grammarInfo=gi,
                                     latent_viterbi_mode=True)
        parser.set_input(inp)
        parser.parse()
        self.assertTrue(parser.recognized())
        der = parser.latent_viterbi_derivation(True)
        print(der)
        ranges = {der.spanned_ranges(idx)[0] for idx in der.ids()}
        self.assertSetEqual({(0, 3), (0, 2), (0, 1), (1, 2), (2, 3)}, ranges)
Example 2
    def create_initial_la(self):
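        # When feature-driven latent annotations are enabled, build the initial
        # LA from the feature log (nonterminal splits, root and rule weights),
        # log some statistics about the splits, and add random noise; otherwise
        # delegate to the superclass implementation.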
        if self.induction_settings.feature_la:
            print("building initial LA from features", file=self.logger)
            nonterminal_splits, rootWeights, ruleWeights, split_id \
                = build_nont_splits_dict(self.base_grammar,
                                         self.feature_log,
                                         self.organizer.nonterminal_map,
                                         feat_function=self.induction_settings.feat_function)
            print("number of nonterminals:", len(nonterminal_splits), file=self.logger)
            print("total splits", sum(nonterminal_splits), file=self.logger)
            max_splits = max(nonterminal_splits)
            max_splits_index = nonterminal_splits.index(max_splits)
            max_splits_nont = self.organizer.nonterminal_map.index_object(max_splits_index)
            print("max. nonterminal splits", max_splits, "at index ", max_splits_index,
                  "i.e.,", max_splits_nont, file=self.logger)
            for key in split_id[max_splits_nont]:
                print(key, file=self.logger)
            print("splits for NE/1", file=self.logger)
            for key in split_id["NE/1"]:
                print(key, file=self.logger)
            for rule in self.base_grammar.lhs_nont_to_rules("NE/1"):
                print(rule, ruleWeights[rule.get_idx()], file=self.logger)
            print("number of rules", len(ruleWeights), file=self.logger)
            print("total split rules", sum(map(len, ruleWeights)), file=self.logger)
            print("number of split rules with 0 prob.",
                  sum(map(sum, map(lambda xs: map(lambda x: 1 if x == 0.0 else 0, xs), ruleWeights))),
                  file=self.logger)

            la = build_PyLatentAnnotation(nonterminal_splits, rootWeights, ruleWeights, self.organizer.grammarInfo,
                                          self.organizer.storageManager)
            la.add_random_noise(seed=self.organizer.seed)
            self.split_id = split_id
            return la
        else:
            return super(ConstituentSMExperiment, self).create_initial_la()
    def read_stage_file(self):
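        # Restore latent annotations serialized by a previous stage: recompute
        # the training reducts if they are missing, then rebuild each pickled
        # (splits, rootWeights, ruleWeights) triple into a PyLatentAnnotation.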
        # super(SplitMergeExperiment, self).read_stage_file()
        if "latent_annotations" in self.stage_dict:
            # this is a workaround
            if self.organizer.training_reducts is None:
                self.update_reducts(self.compute_reducts(
                    self.resources[TRAINING]),
                                    type=TRAINING)
                self.write_stage_file()
            # this was a workaround
            self.initialize_training_environment()

            las = self.stage_dict["latent_annotations"]
            for key in las:
                with open(las[key], "rb") as f:
                    splits, rootWeights, ruleWeights = pickle.load(f)
                    # print(key)
                    # print(len(splits), len(rootWeights), len(ruleWeights))
                    # print(len(self.base_grammar.nonts()))
                    la = build_PyLatentAnnotation(
                        splits, rootWeights, ruleWeights,
                        self.organizer.grammarInfo,
                        self.organizer.storageManager)
                    self.organizer.latent_annotations[int(key)] = la
        if "last_sm_cycle" in self.stage_dict:
            self.organizer.last_sm_cycle = int(
                self.stage_dict["last_sm_cycle"])
    def __test_projection(self,
                          split_weights,
                          goal_weights,
                          merge_method=False):
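        # Build a tiny LCFRS (S -> A A, A -> a | b), attach a latent annotation
        # with the given split_weights, project it back onto the base grammar
        # (directly, or via an explicit merge of all splits), and compare the
        # projected rule weights against goal_weights.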
        grammar = LCFRS("S")
        # rule 0
        lhs = LCFRS_lhs("S")
        lhs.add_arg([LCFRS_var(0, 0), LCFRS_var(1, 0)])
        grammar.add_rule(lhs, ["A", "A"])

        # rule 1
        lhs = LCFRS_lhs("A")
        lhs.add_arg(["a"])
        grammar.add_rule(lhs, [])

        lhs = LCFRS_lhs("A")
        lhs.add_arg(["b"])
        grammar.add_rule(lhs, [], weight=2.0)

        grammar.make_proper()
        # print(grammar)

        nonterminal_map = Enumerator()
        grammarInfo = PyGrammarInfo(grammar, nonterminal_map)
        storageManager = PyStorageManager()

        la = build_PyLatentAnnotation([1, 2], [1.0], split_weights,
                                      grammarInfo, storageManager)

        # parser = LCFRS_parser(grammar)
        # parser.set_input(["a", "b"])
        # parser.parse()
        # der = parser.best_derivation_tree()

        # print(la.serialize())
        if merge_method:
            la.project_weights(grammar, grammarInfo)
        else:
            splits, _, _ = la.serialize()
            merge_sources = [[[
                split for split in range(0, splits[nont_idx])
            ]] for nont_idx in range(0, nonterminal_map.get_counter())]

            # print("Projecting to fine grammar LA", file=self.logger)
            coarse_la = la.project_annotation_by_merging(grammarInfo,
                                                         merge_sources,
                                                         debug=False)
            coarse_la.project_weights(grammar, grammarInfo)

        # print(grammar)
        for i in range(3):
            self.assertAlmostEqual(
                grammar.rule_index(i).weight(), goal_weights[i])
    def load_secondary_latent_annotations(self, paths):
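        # Load additional pickled latent annotations and attach them to the
        # organizer, after checking that they match the base grammar in the
        # number of nonterminals and rules.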
        for path in paths:
            if path == '':
                continue
            with open(path, "rb") as f:
                splits, rootWeights, ruleWeights = pickle.load(f)
                # very basic tests to avoid incompatible LAs
                assert len(splits) == len(self.base_grammar.nonts())
                assert len(ruleWeights) == len(self.base_grammar.rules())
                la = build_PyLatentAnnotation(splits, rootWeights, ruleWeights,
                                              self.organizer.grammarInfo,
                                              self.organizer.storageManager)
                self.organizer.secondary_latent_annotations.append(la)
Example 6
def main(limit=3000,
         test_limit=sys.maxsize,
         max_length=sys.maxsize,
         dir=dir,
         train='../res/negra-dep/negra-lower-punct-train.conll',
         test='../res/negra-dep/negra-lower-punct-test.conll',
         recursive_partitioning='cfg',
         nonterminal_labeling='childtop-deprel',
         terminal_labeling='form-unk-30/pos',
         emEpochs=20,
         emTieBreaking=True,
         emInit="rfe",
         splitRandomization=1.0,
         mergePercentage=85.0,
         smCycles=6,
         rule_pruning=0.0001,
         rule_smoothing=0.01,
         validation=True,
         validationMethod='likelihood',
         validationCorpus=None,
         validationSplit=20,
         validationDropIterations=6,
         seed=1337,
         discr=False,
         maxScaleDiscr=10,
         recompileGrammar="True",
         retrain=False,
         parsing=True,
         reparse=False,
         parser="CFG",
         k_best=50,
         minimum_risk=False,
         oracle_parse=False):

    # set various parameters
    recompileGrammar = (recompileGrammar == "True")

    # print(recompileGrammar)

    def result(gram, add=None):
        if add is not None:
            return os.path.join(
                dir, gram + '_experiment_parse_results_' + add + '.conll')
        else:
            return os.path.join(dir, gram + '_experiment_parse_results.conll')

    recursive_partitioning = grammar.induction.recursive_partitioning.the_recursive_partitioning_factory(
    ).get_partitioning(recursive_partitioning)
    top_level, low_level = tuple(nonterminal_labeling.split('-'))
    nonterminal_labeling = d_l.the_labeling_factory(
    ).create_simple_labeling_strategy(top_level, low_level)

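    # select the parser implementation requested via the `parser` argument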
    if parser == "CFG":
        assert all([
            rp.__name__
            in ["left_branching", "right_branching", "cfg", "fanout_1"]
            for rp in recursive_partitioning
        ])
        parser = CFGParser
    elif parser == "GF":
        parser = GFParser
    elif parser == "GF-k-best":
        parser = GFParser_k_best
    elif parser == "CoarseToFine":
        parser = Coarse_to_fine_parser
    elif parser == "FST":
        if recursive_partitioning == "left_branching":
            parser = LeftBranchingFSTParser
        elif recursive_partitioning == "right_branching":
            parser = RightBranchingFSTParser
        else:
            assert False and "expect left/right branching recursive partitioning for FST parsing"

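    # split off a validation section from the training corpus unless a
    # dedicated validation corpus was supplied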
    if validation:
        if validationCorpus is not None:
            corpus_validation = Corpus(validationCorpus)
            train_limit = limit
        else:
            train_limit = int(limit * (100.0 - validationSplit) / 100.0)
            corpus_validation = Corpus(train, start=train_limit, end=limit)
    else:
        train_limit = limit

    corpus_induce = Corpus(train, end=limit)
    corpus_train = Corpus(train, end=train_limit)
    corpus_test = Corpus(test, end=test_limit)

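    # choose the terminal labeling strategy according to the terminal_labeling
    # string (form/POS with unknown-word threshold, optionally with
    # morphological features)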
    match = re.match(r'^form-unk-(\d+)-morph.*$', terminal_labeling)
    if match:
        unk_threshold = int(match.group(1))
        term_labelling = grammar.induction.terminal_labeling.FormPosTerminalsUnkMorph(
            corpus_induce.get_trees(),
            unk_threshold,
            pos_filter=["NE", "CARD"],
            add_morph={
                'NN': ['case', 'number', 'gender']
                # , 'NE': ['case', 'number', 'gender']
                # , 'VMFIN': ['number', 'person']
                # , 'VVFIN': ['number', 'person']
                # , 'VAFIN': ['number', 'person']
            })
    else:
        match = re.match(r'^form-unk-(\d+).*$', terminal_labeling)
        if match:
            unk_threshold = int(match.group(1))
            term_labelling = grammar.induction.terminal_labeling.FormPosTerminalsUnk(
                corpus_induce.get_trees(),
                unk_threshold,
                pos_filter=["NE", "CARD"])
        else:
            term_labelling = grammar.induction.terminal_labeling.the_terminal_labeling_factory(
            ).get_strategy(terminal_labeling)

    if not os.path.isdir(dir):
        os.makedirs(dir)

    # start actual training
    # we use the training corpus until limit for grammar induction (i.e., also the validation section)
    print("Computing baseline id: ")
    baseline_id = grammar_id(corpus_induce, nonterminal_labeling,
                             term_labelling, recursive_partitioning)
    print(baseline_id)
    baseline_path = compute_grammar_name(dir, baseline_id, "baseline")

    if recompileGrammar or not os.path.isfile(baseline_path):
        print("Inducing grammar from corpus")
        (n_trees, baseline_grammar) = d_i.induce_grammar(
            corpus_induce.get_trees(), nonterminal_labeling,
            term_labelling.token_label, recursive_partitioning, start)
        print("Induced grammar using", n_trees, ".")
        pickle.dump(baseline_grammar, open(baseline_path, 'wb'))
    else:
        print("Loading grammar from file")
        baseline_grammar = pickle.load(open(baseline_path, 'rb'))

    print("Rules: ", len(baseline_grammar.rules()))

    if parsing:
        parser_ = GFParser_k_best if parser == Coarse_to_fine_parser else parser
        baseline_parser = do_parsing(baseline_grammar,
                                     corpus_test,
                                     term_labelling,
                                     result,
                                     baseline_id,
                                     parser_,
                                     k_best=k_best,
                                     minimum_risk=minimum_risk,
                                     oracle_parse=oracle_parse,
                                     recompile=recompileGrammar,
                                     dir=dir,
                                     reparse=reparse)

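    # split/merge training of latent annotations (the guard below is always true)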
    if True:
        em_trained = pickle.load(open(baseline_path, 'rb'))
        reduct_path = compute_reduct_name(dir, baseline_id, corpus_train)
        if recompileGrammar or not os.path.isfile(reduct_path):
            trace = compute_reducts(em_trained, corpus_train.get_trees(),
                                    term_labelling)
            trace.serialize(reduct_path)
        else:
            print("loading trace")
            trace = PySDCPTraceManager(em_trained, term_labelling)
            trace.load_traces_from_file(reduct_path)

        if discr:
            reduct_path_discr = compute_reduct_name(dir, baseline_id,
                                                    corpus_train, '_discr')
            if recompileGrammar or not os.path.isfile(reduct_path_discr):
                trace_discr = compute_LCFRS_reducts(
                    em_trained,
                    corpus_train.get_trees(),
                    terminal_labelling=term_labelling,
                    nonterminal_map=trace.get_nonterminal_map())
                trace_discr.serialize(reduct_path_discr)
            else:
                print("loading trace discriminative")
                trace_discr = PyLCFRSTraceManager(em_trained,
                                                  trace.get_nonterminal_map())
                trace_discr.load_traces_from_file(reduct_path_discr)

        # todo refactor EM training, to use the LA version (but without any splits)
        """
        em_trained_path_ = em_trained_path(dir, grammar_id, n_epochs=emEpochs, init=emInit, tie_breaking=emTieBreaking, seed=seed)

        if recompileGrammar or retrain or not os.path.isfile(em_trained_path_):
            emTrainer = PyEMTrainer(trace)
            emTrainer.em_training(em_trained, n_epochs=emEpochs, init=emInit, tie_breaking=emTieBreaking, seed=seed)
            pickle.dump(em_trained, open(em_trained_path_, 'wb'))
        else:
            em_trained = pickle.load(open(em_trained_path_, 'rb'))

        if parsing:
            do_parsing(em_trained, test_limit, ignore_punctuation, term_labelling, recompileGrammar or retrain, [dir, "em_trained_gf_grammar"])
        """

        grammarInfo = PyGrammarInfo(baseline_grammar,
                                    trace.get_nonterminal_map())
        storageManager = PyStorageManager()

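        # configure the split/merge trainer: EM epochs, rule smoothing, split
        # randomization, the expector (discriminative or simple), and an
        # optional validator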
        builder = PySplitMergeTrainerBuilder(trace, grammarInfo)
        builder.set_em_epochs(emEpochs)
        builder.set_smoothing_factor(rule_smoothing)
        builder.set_split_randomization(splitRandomization, seed + 1)
        if discr:
            builder.set_discriminative_expector(trace_discr,
                                                maxScale=maxScaleDiscr,
                                                threads=1)
        else:
            builder.set_simple_expector(threads=1)
        if validation:
            if validationMethod == "likelihood":
                reduct_path_validation = compute_reduct_name(
                    dir, baseline_id, corpus_validation)
                if recompileGrammar or not os.path.isfile(
                        reduct_path_validation):
                    validation_trace = compute_reducts(
                        em_trained, corpus_validation.get_trees(),
                        term_labelling)
                    validation_trace.serialize(reduct_path_validation)
                else:
                    print("loading trace validation")
                    validation_trace = PySDCPTraceManager(
                        em_trained, term_labelling)
                    validation_trace.load_traces_from_file(
                        reduct_path_validation)
                builder.set_simple_validator(validation_trace,
                                             maxDrops=validationDropIterations,
                                             threads=1)
            else:
                validator = build_score_validator(
                    baseline_grammar, grammarInfo, trace.get_nonterminal_map(),
                    storageManager, term_labelling, baseline_parser,
                    corpus_validation, validationMethod)
                builder.set_score_validator(validator,
                                            validationDropIterations)
        splitMergeTrainer = builder.set_percent_merger(mergePercentage).build()
        if validation:
            splitMergeTrainer.setMaxDrops(1, mode="smoothing")
            splitMergeTrainer.setEMepochs(1, mode="smoothing")

        sm_info_path = compute_sm_info_path(dir, baseline_id, emEpochs,
                                            rule_smoothing, splitRandomization,
                                            seed, discr, validation,
                                            corpus_validation, emInit)

        if (not recompileGrammar) and (
                not retrain) and os.path.isfile(sm_info_path):
            print("Loading splits and weights of LA rules")
            latentAnnotation = [
                build_PyLatentAnnotation(splits, root_weights, rule_weights,
                                         grammarInfo, storageManager)
                for splits, root_weights, rule_weights
                in pickle.load(open(sm_info_path, 'rb'))
            ]
        else:
            # latentAnnotation = [build_PyLatentAnnotation_initial(em_trained, grammarInfo, storageManager)]
            latentAnnotation = [
                build_PyLatentAnnotation_initial(baseline_grammar, grammarInfo,
                                                 storageManager)
            ]

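        # run the split/merge cycles: reuse previously trained annotations if
        # available, otherwise train a new cycle and persist the serialized
        # annotations; optionally parse the test corpus after every cycle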
        for cycle in range(smCycles + 1):
            if cycle < len(latentAnnotation):
                smGrammar = latentAnnotation[cycle].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=rule_pruning
                    # , rule_smoothing=rule_smoothing
                )
            else:
                # setting the seed to achieve reproducibility in case of continued training
                splitMergeTrainer.reset_random_seed(seed + cycle + 1)
                latentAnnotation.append(
                    splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
                pickle.dump([la.serialize() for la in latentAnnotation],
                            open(sm_info_path, 'wb'))
                smGrammar = latentAnnotation[cycle].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=rule_pruning
                    # , rule_smoothing=rule_smoothing
                )
            print("Cycle: ", cycle, "Rules: ", len(smGrammar.rules()))
            if parsing:
                grammar_identifier = compute_sm_grammar_id(
                    baseline_id, emEpochs, rule_smoothing, splitRandomization,
                    seed, discr, validation, corpus_validation, emInit, cycle)
                if parser == Coarse_to_fine_parser:
                    opt = {
                        'latentAnnotation': latentAnnotation[:cycle + 1],
                        'grammarInfo': grammarInfo,
                        'nontMap': trace.get_nonterminal_map()
                    }
                    do_parsing(baseline_grammar,
                               corpus_test,
                               term_labelling,
                               result,
                               grammar_identifier,
                               parser,
                               k_best=k_best,
                               minimum_risk=minimum_risk,
                               oracle_parse=oracle_parse,
                               recompile=recompileGrammar,
                               dir=dir,
                               reparse=reparse,
                               opt=opt)
                else:
                    do_parsing(smGrammar,
                               corpus_test,
                               term_labelling,
                               result,
                               grammar_identifier,
                               parser,
                               k_best=k_best,
                               minimum_risk=minimum_risk,
                               oracle_parse=oracle_parse,
                               recompile=recompileGrammar,
                               dir=dir,
                               reparse=reparse)
Example 7
def main(limit=300,
         ignore_punctuation=False,
         baseline_path=baseline_path,
         recompileGrammar=True,
         retrain=True,
         parsing=True,
         seed=1337):
    max_length = 20
    trees = length_limit(parse_conll_corpus(train, False, limit), max_length)

    if recompileGrammar or not os.path.isfile(baseline_path):
        (n_trees,
         baseline_grammar) = d_i.induce_grammar(trees, empty_labelling,
                                                term_labelling.token_label,
                                                recursive_partitioning, start)
        pickle.dump(baseline_grammar, open(baseline_path, 'wb'))
    else:
        baseline_grammar = pickle.load(open(baseline_path, 'rb'))

    test_limit = 10000
    print("Rules: ", len(baseline_grammar.rules()))

    if parsing:
        do_parsing(baseline_grammar, test_limit, ignore_punctuation,
                   recompileGrammar, [dir, "baseline_gf_grammar"])

    em_trained = pickle.load(open(baseline_path, 'rb'))
    if recompileGrammar or not os.path.isfile(reduct_path):
        trees = length_limit(parse_conll_corpus(train, False, limit),
                             max_length)
        trace = compute_reducts(em_trained, trees, term_labelling)
        trace.serialize(reduct_path)
    else:
        print("loading trace")
        trace = PySDCPTraceManager(em_trained, term_labelling)
        trace.load_traces_from_file(reduct_path)

    discr = False
    if discr:
        if recompileGrammar or not os.path.isfile(reduct_path_discr):
            trees = length_limit(parse_conll_corpus(train, False, limit),
                                 max_length)
            trace_discr = compute_LCFRS_reducts(
                em_trained,
                trees,
                term_labelling,
                nonterminal_map=trace.get_nonterminal_map())
            trace_discr.serialize(reduct_path_discr)
        else:
            print("loading trace discriminative")
            trace_discr = PyLCFRSTraceManager(em_trained,
                                              trace.get_nonterminal_map())
            trace_discr.load_traces_from_file(reduct_path_discr)

    n_epochs = 20
    init = "rfe"
    tie_breaking = True
    em_trained_path_ = em_trained_path(n_epochs, init, tie_breaking)

    if recompileGrammar or retrain or not os.path.isfile(em_trained_path_):
        emTrainer = PyEMTrainer(trace)
        emTrainer.em_training(em_trained,
                              n_epochs=n_epochs,
                              init=init,
                              tie_breaking=tie_breaking,
                              seed=seed)
        pickle.dump(em_trained, open(em_trained_path_, 'wb'))
    else:
        em_trained = pickle.load(open(em_trained_path_, 'rb'))

    if parsing:
        do_parsing(em_trained, test_limit, ignore_punctuation,
                   recompileGrammar or retrain,
                   [dir, "em_trained_gf_grammar"])

    grammarInfo = PyGrammarInfo(baseline_grammar, trace.get_nonterminal_map())
    storageManager = PyStorageManager()

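    # configure the split/merge trainer: EM epochs, split randomization, and
    # the expector (discriminative or simple)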
    builder = PySplitMergeTrainerBuilder(trace, grammarInfo)
    builder.set_em_epochs(n_epochs)
    builder.set_split_randomization(1.0, seed + 1)
    if discr:
        builder.set_discriminative_expector(trace_discr,
                                            maxScale=10,
                                            threads=1)
    else:
        builder.set_simple_expector(threads=1)
    splitMergeTrainer = builder.set_percent_merger(65.0).build()

    if (not recompileGrammar) and (
            not retrain) and os.path.isfile(sm_info_path):
        print("Loading splits and weights of LA rules")
        latentAnnotation = [
            build_PyLatentAnnotation(splits, root_weights, rule_weights,
                                     grammarInfo, storageManager)
            for splits, root_weights, rule_weights
            in pickle.load(open(sm_info_path, 'rb'))
        ]
    else:
        latentAnnotation = [
            build_PyLatentAnnotation_initial(em_trained, grammarInfo,
                                             storageManager)
        ]

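    # run the split/merge cycles: reuse existing annotations (optionally
    # reparsing), otherwise train a new cycle, persist it, and parse the test
    # corpus with the refined grammar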
    max_cycles = 4
    reparse = False
    # parsing = False
    for i in range(max_cycles + 1):
        if i < len(latentAnnotation):
            if reparse:
                smGrammar = latentAnnotation[i].build_sm_grammar(
                    baseline_grammar,
                    grammarInfo,
                    rule_pruning=0.0001,
                    rule_smoothing=0.01)
                print("Cycle: ", i, "Rules: ", len(smGrammar.rules()))
                do_parsing(smGrammar, test_limit, ignore_punctuation,
                           recompileGrammar or retrain,
                           [dir, "sm_cycles" + str(i) + "_gf_grammar"])
        else:
            # setting the seed to achieve reproducibility in case of continued training
            splitMergeTrainer.reset_random_seed(seed + i + 1)
            latentAnnotation.append(
                splitMergeTrainer.split_merge_cycle(latentAnnotation[-1]))
            pickle.dump([la.serialize() for la in latentAnnotation],
                        open(sm_info_path, 'wb'))
            smGrammar = latentAnnotation[i].build_sm_grammar(
                baseline_grammar,
                grammarInfo,
                rule_pruning=0.0001,
                rule_smoothing=0.1)
            print("Cycle: ", i, "Rules: ", len(smGrammar.rules()))
            if parsing:
                do_parsing(smGrammar, test_limit, ignore_punctuation,
                           recompileGrammar or retrain,
                           [dir, "sm_cycles" + str(i) + "_gf_grammar"])