コード例 #1
0
 def test_oracle(self):
     """
     Replay the oracle on the test passage and compare with the expected actions.

     On every step the first action offered by the oracle is applied to the
     state and recorded as one newline-terminated line, until the state
     reports that it is finished. The recorded lines must equal the lines of
     the expected-actions file.
     """
     oracle = Oracle(self.passage)
     state = State(self.passage)
     recorded = []
     done = False
     while not done:
         candidates = oracle.get_actions(state)
         chosen = next(iter(candidates))  # always follow the first candidate
         state.transition(chosen)
         recorded.append("%s\n" % chosen)
         done = state.finished
     with open('test_files/standard3.oracle_actions.txt') as f:
         expected = f.readlines()
     self.assertSequenceEqual(recorded, expected)
コード例 #2
0
ファイル: parse.py プロジェクト: viksit/ucca
    def parse(self, passages, mode="test"):
        """
        Parse given passages, printing progress and timing for each one.

        This is a generator: nothing is parsed until it is iterated, each
        passage is parsed when its result is requested, and the overall
        summary is printed only once the generator is exhausted (and only if
        more than one passage was parsed).

        Per-passage parser state is stored on the instance (self.state,
        self.oracle, self.state_hash_history, self.action_count,
        self.correct_count), and the running totals are accumulated in
        self.total_actions and self.total_correct.

        :param passages: iterable of passages to parse
        :param mode: "train", "test" or "dev".
                     If "train", use oracle to train on given passages.
                     Otherwise, just parse with classifier.
        :return: generator of pairs of (parsed passage, original passage).
                 In "train" mode without Config().verify, the first element
                 is the original passage itself.
        :raises AssertionError: if mode is invalid, or if mode is "train" and
                                a passage is not annotated
        :raises ParserException: if parsing a passage fails in "train" mode;
                                 in the other modes the failure is logged and
                                 the passage is reported as failed instead
        """
        train = mode == "train"
        dev = mode == "dev"
        test = mode == "test"
        assert train or dev or test, "Invalid parse mode: %s" % mode
        passage_word = "sentence" if Config().sentences else \
                       "paragraph" if Config().paragraphs else \
                       "passage"
        self.total_actions = 0
        self.total_correct = 0
        total_duration = 0
        total_tokens = 0
        num_passages = 0
        for passage in passages:
            l0 = passage.layer(layer0.LAYER_ID)
            num_tokens = len(l0.all)
            total_tokens += num_tokens
            l1 = passage.layer(layer1.LAYER_ID)
            # More than one element in layer 1 is treated as "annotated"
            labeled = len(l1.all) > 1
            assert not train or labeled, "Cannot train on unannotated passage"
            print("%s %-7s" % (passage_word, passage.ID),
                  end=Config().line_end,
                  flush=True)
            started = time.time()
            # Reset the per-passage state before parsing
            self.action_count = 0
            self.correct_count = 0
            self.state = State(passage, callback=self.pos_tag)
            self.state_hash_history = set()
            self.oracle = Oracle(passage) if train else None
            failed = False
            try:
                self.parse_passage(
                    train)  # This is where the actual parsing takes place
            except ParserException as e:
                if train:
                    raise
                Config().log("%s %s: %s" % (passage_word, passage.ID, e))
                if not test:
                    print("failed")
                failed = True
            predicted_passage = passage
            if not train or Config().verify:
                predicted_passage = self.state.create_passage(
                    assert_proper=Config().verify)
            duration = time.time() - started
            total_duration += duration
            if train:  # We have an oracle to verify by
                if not failed and Config().verify:
                    self.verify_passage(passage, predicted_passage, train)
                if self.action_count:
                    print("%-16s" %
                          ("%d%% (%d/%d)" %
                           (100 * self.correct_count / self.action_count,
                            self.correct_count, self.action_count)),
                          end=Config().line_end)
            print("%0.3fs" % duration, end="")
            # The wall clock may not advance (or may even step backwards)
            # during a very fast parse; the rate is undefined then, so it is
            # omitted rather than dividing by zero.
            if failed or duration <= 0:
                rate = ""
            else:
                rate = " (%d tokens/s)" % (num_tokens / duration)
            print("%-15s" % rate, end="")
            print(Config().line_end, end="")
            if train:
                print(Config().line_end, flush=True)
            self.total_correct += self.correct_count
            self.total_actions += self.action_count
            num_passages += 1
            yield predicted_passage, passage

        if num_passages > 1:
            print("Parsed %d %ss" % (num_passages, passage_word))
            if self.oracle and self.total_actions:
                print("Overall %d%% correct transitions (%d/%d) on %s" %
                      (100 * self.total_correct / self.total_actions,
                       self.total_correct, self.total_actions, mode))
            # Same guard as for the per-passage rate: report 0 when no time
            # was measured instead of raising ZeroDivisionError.
            if total_duration > 0:
                tokens_per_second = total_tokens / total_duration
            else:
                tokens_per_second = 0
            print(
                "Total time: %.3fs (average time/%s: %.3fs, average tokens/s: %d)"
                % (total_duration, passage_word, total_duration / num_passages,
                   tokens_per_second),
                flush=True)