def test_oracle(self):
    """
    Replay the oracle over self.passage and compare the resulting action
    sequence, one action per line, with the stored reference file.
    """
    gold_oracle = Oracle(self.passage)
    current_state = State(self.passage)
    recorded = []
    done = False
    while not done:
        # Always follow the first action the oracle proposes for this state
        candidates = gold_oracle.get_actions(current_state)
        chosen = next(iter(candidates))
        current_state.transition(chosen)
        # Keep the trailing newline so entries match what readlines() returns
        recorded.append("%s\n" % chosen)
        done = bool(current_state.finished)
    with open('test_files/standard3.oracle_actions.txt') as f:
        expected = f.readlines()
    self.assertSequenceEqual(recorded, expected)
def parse(self, passages, mode="test"):
    """
    Parse given passages lazily, printing per-passage progress as each one is
    consumed and an overall summary once more than one passage was parsed.

    Resets and updates self.total_actions / self.total_correct, and for every
    passage replaces self.state, self.state_hash_history, self.oracle,
    self.action_count and self.correct_count.

    :param passages: iterable of passages to parse
    :param mode: "train", "test" or "dev".
                 If "train", use oracle to train on given passages.
                 Otherwise, just parse with classifier.
    :return: generator of pairs of (parsed passage, original passage).
             In "train" mode without Config().verify, the first element is
             the original passage itself (no passage is created from the state).
    :raises ParserException: re-raised when parsing fails in "train" mode;
                             in other modes the failure is logged and the
                             passage is still yielded
    """
    train = mode == "train"
    dev = mode == "dev"
    test = mode == "test"
    assert train or dev or test, "Invalid parse mode: %s" % mode
    passage_word = "sentence" if Config().sentences else \
                   "paragraph" if Config().paragraphs else \
                   "passage"
    self.total_actions = 0
    self.total_correct = 0
    total_duration = 0
    total_tokens = 0
    num_passages = 0
    for passage in passages:
        l0 = passage.layer(layer0.LAYER_ID)
        num_tokens = len(l0.all)
        total_tokens += num_tokens
        l1 = passage.layer(layer1.LAYER_ID)
        # More than one node in layer 1 is taken to mean the passage is annotated
        labeled = len(l1.all) > 1
        assert not train or labeled, "Cannot train on unannotated passage"
        print("%s %-7s" % (passage_word, passage.ID),
              end=Config().line_end, flush=True)
        # perf_counter() is monotonic, so the measured duration is never
        # negative and is not affected by system clock adjustments
        started = time.perf_counter()
        self.action_count = 0
        self.correct_count = 0
        self.state = State(passage, callback=self.pos_tag)
        self.state_hash_history = set()
        self.oracle = Oracle(passage) if train else None
        failed = False
        try:
            self.parse_passage(train)  # This is where the actual parsing takes place
        except ParserException as e:
            if train:
                raise
            Config().log("%s %s: %s" % (passage_word, passage.ID, e))
            if not test:
                print("failed")
            failed = True
        predicted_passage = passage
        if not train or Config().verify:
            predicted_passage = self.state.create_passage(
                assert_proper=Config().verify)
        duration = time.perf_counter() - started
        total_duration += duration
        if train:  # We have an oracle to verify by
            if not failed and Config().verify:
                self.verify_passage(passage, predicted_passage, train)
            if self.action_count:
                print("%-16s" % ("%d%% (%d/%d)" %
                                 (100 * self.correct_count / self.action_count,
                                  self.correct_count, self.action_count)),
                      end=Config().line_end)
        print("%0.3fs" % duration, end="")
        # Omit the rate when parsing failed or when no time was measured
        # (a zero duration would otherwise raise ZeroDivisionError)
        if failed or duration <= 0:
            rate_text = ""
        else:
            rate_text = " (%d tokens/s)" % (num_tokens / duration)
        print("%-15s" % rate_text, end="")
        print(Config().line_end, end="")
        if train:
            print(Config().line_end, flush=True)
        self.total_correct += self.correct_count
        self.total_actions += self.action_count
        num_passages += 1
        yield predicted_passage, passage

    if num_passages > 1:
        print("Parsed %d %ss" % (num_passages, passage_word))
        if self.oracle and self.total_actions:
            print("Overall %d%% correct transitions (%d/%d) on %s" %
                  (100 * self.total_correct / self.total_actions,
                   self.total_correct, self.total_actions,
                   mode))
        # Report 0 tokens/s rather than dividing by a zero total duration
        average_rate = total_tokens / total_duration if total_duration > 0 else 0
        print("Total time: %.3fs (average time/%s: %.3fs, average tokens/s: %d)" % (
            total_duration, passage_word, total_duration / num_passages,
            average_rate), flush=True)