Example #1
def _test_features(config, feature_extractor_creator, filename, write_features):
    feature_extractor = feature_extractor_creator(config)
    passage = load_passage(filename, annotate=feature_extractor_creator.annotated)
    textutil.annotate(passage, as_array=True, as_extra=False, vocab=config.vocab())
    config.set_format(passage.extra.get("format") or "ucca")
    oracle = Oracle(passage)
    state = State(passage)
    actions = Actions()
    for key, param in feature_extractor.params.items():
        if not param.numeric:
            param.dropout = 0
            feature_extractor.init_param(key)
    features = [feature_extractor.init_features(state)]
    while True:
        extract_features(feature_extractor, state, features)
        action = min(oracle.get_actions(state, actions).values(), key=str)
        state.transition(action)
        if state.need_label:
            extract_features(feature_extractor, state, features)
            label, _ = oracle.get_label(state, action)
            state.label_node(label)
        if state.finished:
            break
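    # Render each non-empty feature dict as sorted "key value" lines, followed by a
    # near-blank separator line produced by the ("", "") sentinel below.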
    features = ["%s %s\n" % i for f in features if f for i in (sorted(f.items()) + [("", "")])]
    compare_file = os.path.join("test_files", "features", "-".join((basename(filename), str(feature_extractor_creator)))
                                + ".txt")
    if write_features:
        with open(compare_file, "w", encoding="utf-8") as f:
            f.writelines(features)
    with open(compare_file, encoding="utf-8") as f:
        assert f.readlines() == features, compare_file
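To (re)generate the expected-output files, the same helper can be called with write_features=True. A minimal sketch, where make_config, creators and passage_files are hypothetical stand-ins for the project's actual test fixtures:

config = make_config()                    # hypothetical test configuration
for creator in creators:                  # hypothetical feature-extractor creators
    for filename in passage_files:        # hypothetical passage file paths
        _test_features(config, creator, filename, write_features=True)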
Example #2
 def init(self):
     self.config.set_format(self.in_format)
     WIKIFIER.enabled = self.config.args.wikification
     self.state = State(self.passage)
     # Passage is considered labeled if there are any edges or node labels in it
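     # zip(*) transposes the per-node (outgoing edges, node label) pairs, so each
     # any() call checks one of the two properties across all layer-1 nodes.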
     edges, node_labels = map(any, zip(*[(n.outgoing, n.attrib.get(LABEL_ATTRIB))
                                         for n in self.passage.layer(layer1.LAYER_ID).all]))
     self.oracle = Oracle(self.passage) if self.training or self.config.args.verify or (
             (self.config.args.verbose > 1 or self.config.args.use_gold_node_labels or self.config.args.action_stats)
             and (edges or node_labels)) else None
     for model in self.models:
         model.init_model(self.config.format, lang=self.lang if self.config.args.multilingual else None)
         if ClassifierProperty.require_init_features in model.classifier_properties:
             model.init_features(self.state, self.training)
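The Oracle-creation condition above packs several flags into one expression. The following standalone predicate is only a restatement of that expression for readability; needs_oracle is not a name from the source:

def needs_oracle(training, args, edges, node_labels):
    # Restatement of the Oracle-creation condition used in init() above.
    labeled = edges or node_labels
    return training or args.verify or (
        (args.verbose > 1 or args.use_gold_node_labels or args.action_stats) and labeled)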
Example #3
 def test_oracle(self):
     for passage in self.load_passages():
         oracle = Oracle(passage)
         state = State(passage)
         actions_taken = []
         while True:
             actions = oracle.get_actions(state)
             action = next(iter(actions))
             state.transition(action)
             actions_taken.append("%s\n" % action)
             if state.finished:
                 break
         with open("test_files/standard3.oracle_actions.txt") as f:
             self.assertSequenceEqual(actions_taken, f.readlines())
Example #4
def gen_actions(passage):
    oracle = Oracle(passage)
    state = State(passage)
    actions = Actions()
    while True:
        action = min(oracle.get_actions(state, actions).values(), key=str)
        state.transition(action)
        s = str(action)
        if state.need_label:
            label, _ = oracle.get_label(state, action)
            state.label_node(label)
            s += " " + str(label)
        yield s
        if state.finished:
            break
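A possible driver for gen_actions(), writing one action (plus node label, when one is produced) per line. This is a sketch only; load_passage and the file paths are illustrative assumptions:

passage = load_passage("test_files/standard3.xml")  # hypothetical input path
with open("standard3.oracle_actions.txt", "w", encoding="utf-8") as f:
    for s in gen_actions(passage):
        f.write(s + "\n")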
Example #5
 def test_oracle(self):
     for passage in self.load_passages():
         oracle = Oracle(passage)
         state = State(passage)
         actions = Actions()
         actions_taken = []
         while True:
             action = min(oracle.get_actions(state, actions).values(), key=str)
             state.transition(action)
             actions_taken.append("%s\n" % action)
             if state.finished:
                 break
         # with open("test_files/standard3.oracle_actions.txt", "w") as f:
         #     f.writelines(actions_taken)
         with open("test_files/standard3.oracle_actions.txt") as f:
             self.assertSequenceEqual(actions_taken, f.readlines())
Example #6
def gen_actions(passage, feature_extractor):
    global envTrainingData, allLabels, allTypes, allActions
    oracle = Oracle(passage)
    state = State(passage)
    actions = Actions()
    while True:
        acts = oracle.get_actions(state, actions).values()
        type_label_maps = {
            a.type: a.tag
            for a in acts
        }  # There should be no duplicate types with different tags, since there is only one gold tree
        obs = feature_extractor.extract_features(state)['numeric']
        # Drop unused feature positions, deleting from the highest index down so that
        # earlier positions are not shifted by prior deletions.
        for index in sorted({7, 9, 11, 14, 15, 16, 17, 18, 22}, reverse=True):
            del obs[index]
        for act in allActions:
            cur_type = act['type']
            cur_has_label = act['hasLabel']
            cur_label = act['label']
            # TODO: Reconsider the reward mechanism.
            # Penalizing mistakes vs. rewarding correct actions: the latter encourages an
            # episode to go on endlessly, while the former encourages it to end as soon as possible.
            # For now, stay neutral: 100% correct = +0.5 reward; 100% wrong = -0.5 reward.
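            # Resulting reward per case: action type not allowed by the oracle -> -0.5;
            # type allowed but a labeled action's label differs from the oracle tag -> 0.0;
            # type allowed and the label matches (or the action is unlabeled) -> +0.5.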
            r = -0.5
            if cur_type in type_label_maps:  # If action type matches
                r += 0.5
                # If the action has no label, or its label matches the oracle's tag
                if not cur_has_label or cur_label == type_label_maps[cur_type]:
                    r += 0.5
            actNum = allActions.index(act)
            trainingData = {'obs': obs, 'act': actNum, 'r': r}
            envTrainingData.append(trainingData)
        action = min(acts, key=str)
        state.transition(action)
        s = str(action)
        # if state.need_label:
        #     label, _ = oracle.get_label(state, action)
        #     state.label_node(label)
        #     s += " " + str(label)
        yield s
        if state.finished:
            break
Example #7
    def parse(self, passages, mode=ParseMode.test, evaluate=False):
        """
        Parse given passages
        :param passages: iterable of passages to parse
        :param mode: ParseMode value.
                     If train, use oracle to train on given passages.
                     Otherwise, just parse with classifier.
        :param evaluate: whether to evaluate parsed passages with respect to given ones.
                         Only possible when given passages are annotated.
        :return: generator of parsed passages (or in train mode, the original ones),
                 or, if evaluate=True, of pairs of (Passage, Scores).
        """
        assert mode in ParseMode, "Invalid parse mode: %s" % mode
        train = (mode is ParseMode.train)
        if not train and not self.trained:
            list(self.train())
        passage_word = "sentence" if self.args.sentences else \
                       "paragraph" if self.args.paragraphs else \
                       "passage"
        self.total_actions = 0
        self.total_correct_actions = 0
        total_duration = 0
        total_tokens = 0
        passage_index = 0
        if not hasattr(passages, "__iter__"):  # Single passage given
            passages = (passages, )
        for passage_index, passage in enumerate(passages):
            labeled = any(n.outgoing or n.attrib.get(LABEL_ATTRIB)
                          for n in passage.layer(layer1.LAYER_ID).all)
            assert not train or labeled, "Cannot train on unannotated passage: %s" % passage.ID
            assert not evaluate or labeled, "Cannot evaluate on unannotated passage: %s" % passage.ID
            print("%s %-7s" % (passage_word, passage.ID),
                  end=Config().line_end,
                  flush=True)
            started = time.time()
            self.action_count = self.correct_action_count = self.label_count = self.correct_label_count = 0
            textutil.annotate(passage, verbose=self.args.verbose > 1)  # tag POS and parse dependencies
            self.state = State(passage)
            self.state_hash_history = set()
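            # An oracle is kept when training or verifying, and also (for verbose
            # output or gold node labels) when the passage is labeled.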
            self.oracle = Oracle(passage) if train or self.args.verify or (
                (self.args.verbose or Config().args.use_gold_node_labels) and labeled) else None
            failed = False
            if ClassifierProperty.require_init_features in self.model.model.get_classifier_properties():
                self.model.init_features(self.state, train)
            try:
                self.parse_passage(train)  # This is where the actual parsing takes place
            except ParserException as e:
                if train:
                    raise
                Config().log("%s %s: %s" % (passage_word, passage.ID, e))
                failed = True
            guessed = self.state.create_passage(
                verify=self.args.verify
            ) if not train or self.args.verify else passage
            duration = time.time() - started
            total_duration += duration
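            # Tokens actually processed: terminals not still left on the buffer.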
            num_tokens = len(
                set(self.state.terminals).difference(self.state.buffer))
            total_tokens += num_tokens
            if self.oracle:  # We have an oracle to verify by
                if not failed and self.args.verify:
                    self.verify_passage(guessed, passage, train)
                if self.action_count:
                    accuracy_str = "%d%% (%d/%d)" % (
                        100 * self.correct_action_count / self.action_count,
                        self.correct_action_count, self.action_count)
                    if self.label_count:
                        accuracy_str += " %d%% (%d/%d)" % (
                            100 * self.correct_label_count / self.label_count,
                            self.correct_label_count, self.label_count)
                    print("%-30s" % accuracy_str, end=Config().line_end)
            print("%0.3fs" % duration, end="")
            print("%-15s" % (" (failed)" if failed else " (%d tokens/s)" %
                             (num_tokens / duration)),
                  end="")
            print(Config().line_end, end="")
            if self.oracle:
                print(Config().line_end, flush=True)
            self.model.model.finished_item(train)
            self.total_correct_actions += self.correct_action_count
            self.total_actions += self.action_count
            self.total_correct_labels += self.correct_label_count
            self.total_labels += self.label_count
            if train and self.args.save_every and (passage_index + 1) % self.args.save_every == 0:
                self.eval_and_save()
                self.eval_index += 1
            yield (guessed, self.evaluate_passage(guessed, passage)) if evaluate else guessed

        if passages:
            print("Parsed %d %ss" % (passage_index + 1, passage_word))
            if self.oracle and self.total_actions:
                accuracy_str = "%d%% correct actions (%d/%d)" % (
                    100 * self.total_correct_actions / self.total_actions,
                    self.total_correct_actions, self.total_actions)
                if self.total_labels:
                    accuracy_str += ", %d%% correct labels (%d/%d)" % (
                        100 * self.total_correct_labels / self.total_labels,
                        self.total_correct_labels, self.total_labels)
                print("Overall %s on %s" % (accuracy_str, mode.name))
            if total_duration:
                print(
                    "Total time: %.3fs (average time/%s: %.3fs, average tokens/s: %d)"
                    % (total_duration, passage_word, total_duration /
                       (passage_index + 1), total_tokens / total_duration),
                    flush=True)
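A hedged usage sketch of the parse() generator, following its docstring; the parser object and the load_passages() helper are hypothetical stand-ins, not taken from the source:

# Hypothetical caller: only parse()'s signature and its documented return value
# (pairs of (Passage, Scores) when evaluate=True) come from the code above.
for guessed, scores in parser.parse(load_passages(), mode=ParseMode.test, evaluate=True):
    print(guessed.ID, scores)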