def test_oracle(self):
    oracle = Oracle(self.passage)
    state = State(self.passage)
    actions_taken = []
    while True:
        actions = oracle.get_actions(state)
        action = next(iter(actions))
        state.transition(action)
        actions_taken.append("%s\n" % action)
        if state.finished:
            break
    with open("test_files/standard3.oracle_actions.txt") as f:
        self.assertSequenceEqual(actions_taken, f.readlines())
def test_oracle(self):
    for passage in self.load_passages():
        oracle = Oracle(passage)
        state = State(passage)
        actions_taken = []
        while True:
            actions = oracle.get_actions(state)
            action = next(iter(actions))
            state.transition(action)
            actions_taken.append("%s\n" % action)
            if state.finished:
                break
        with open("test_files/standard3.oracle_actions.txt") as f:
            self.assertSequenceEqual(actions_taken, f.readlines())
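The looping variant above relies on a load_passages helper that is not shown. A minimal sketch of what it might look like, assuming the fixtures are UCCA standard-format XML files under test_files/ (the glob pattern and conversion call are assumptions, not from the source):

import glob
import xml.etree.ElementTree as ETree

from ucca import convert

def load_passages(self):
    # Hypothetical helper: yield one ucca.core.Passage per XML file.
    # The test_files/ pattern mirrors the fixture path used above.
    for filename in sorted(glob.glob("test_files/standard*.xml")):
        yield convert.from_standard(ETree.parse(filename).getroot())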
def __init__(self, game):
    self.game = game
    State.__init__(self, game)
    # Set the menu
    self.menu_img = pygame.image.load(
        os.path.join(self.game.assets_dir, "map", "menu.png"))
    self.menu_rect = self.menu_img.get_rect()
    self.menu_rect.center = (self.game.GAME_W * .85, self.game.GAME_H * .4)
    # Set the cursor and menu states
    self.menu_options = {0: "Party", 1: "Items", 2: "Magic", 3: "Exit"}
    self.index = 0
    self.cursor_img = pygame.image.load(
        os.path.join(self.game.assets_dir, "map", "cursor.png"))
    self.cursor_rect = self.cursor_img.get_rect()
    self.cursor_pos_y = self.menu_rect.y + 38
    self.cursor_rect.x, self.cursor_rect.y = self.menu_rect.x + 10, self.cursor_pos_y
def execute(self, request_data) -> dict:
    # Load the context carried along with the request
    context = request_data.get('context', {})
    # Delete keys marked for deletion (missing keys are ignored)
    for key in self.properties.get('del_keys', []):
        context.pop(key, None)
    # Re-evaluate keys marked for update against the current context.
    # update_keys maps each key to its configured value, so iterate items()
    # instead of looking the mapping up a second time per key.
    for key, value in self.properties.get('update_keys', {}).items():
        context[key] = State.contextualize(context, value)
    request_data['context'] = context
    # Load the next state (False if no transition is configured)
    request_data['next_state'] = self.transitions.get('next_state', False)
    return request_data
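The execute method above delegates to State.contextualize, which is not shown. A minimal sketch of one plausible implementation, assuming configured values are template strings whose {placeholders} are filled from the context (the template convention is an assumption, not from the source):

@staticmethod
def contextualize(context, value):
    # Hypothetical: treat string values as templates and substitute
    # {placeholders} from the context; pass non-strings through unchanged.
    if isinstance(value, str):
        try:
            return value.format(**context)
        except (KeyError, IndexError):
            return value  # leave unresolved templates as-is
    return value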
def __init__(self, game):
    self.game = game
    State.__init__(self, game)
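The pygame snippets in this section all extend a State base class that is not included. A minimal sketch of the interface they assume, with a state stack kept on the game object (the stack attribute name and hook signatures are assumptions):

class State:
    def __init__(self, game):
        self.game = game
        self.prev_state = None

    def update(self, delta_time, actions):
        pass  # overridden by concrete states

    def render(self, surface):
        pass  # overridden by concrete states

    def enter_state(self):
        # Push this state onto the game's state stack (assumed attribute).
        if len(self.game.state_stack) > 1:
            self.prev_state = self.game.state_stack[-1]
        self.game.state_stack.append(self)

    def exit_state(self):
        self.game.state_stack.pop()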
class Parser(object):
    """ Main class to implement transition-based UCCA parser """

    def __init__(self, model_file=None, model_type=config.SPARSE_PERCEPTRON, beam=1):
        self.state = None  # State object created at each parse
        self.oracle = None  # Oracle object created at each parse
        self.scores = None  # NumPy array of action scores at each action
        self.action_count = 0
        self.correct_count = 0
        self.total_actions = 0
        self.total_correct = 0
        self.model = Model(model_type, model_file, Actions().all)
        self.beam = beam
        self.learning_rate = Config().args.learningrate
        self.decay_factor = Config().args.decayfactor
        self.state_hash_history = None  # For loop checking
        # Used in verify_passage to optionally ignore a mismatch in linkage nodes
        # (parenthesized so the conditional selects the lambda, not its body):
        self.ignore_node = (lambda n: n.tag == layer1.NodeTags.Linkage) if Config().args.nolinkage else None
        self.best_score = self.best_model = self.dev = self.iteration = self.batch = None

    def train(self, passages, dev=None, iterations=1, folds=None):
        """
        Train parser on given passages
        :param passages: iterable of passages to train on
        :param dev: iterable of passages to tune on
        :param iterations: number of iterations to perform
        :param folds: whether we are inside cross-validation with this many folds
        :return: trained model
        """
        if not passages:
            self.model.load()  # Nothing to train on; pre-trained model given
            return
        self.best_score = 0
        self.best_model = None
        self.dev = dev
        last = False
        if Config().args.devscores:
            with open(Config().args.devscores, "w") as f:
                print(",".join(["iteration"] + evaluation.Scores.field_titles()), file=f)
        for self.iteration in range(iterations):
            if last:
                break
            last = self.iteration == iterations - 1
            self.batch = 0
            print("Training iteration %d of %d: " % (self.iteration + 1, iterations))
            passages = [passage for _, passage in self.parse(passages, mode="train")]
            if last:
                if folds is None:  # Free some memory, as these are not needed any more
                    del passages[:]
            else:
                self.learning_rate *= self.decay_factor
                Config().random.shuffle(passages)
            last = self.eval_dev_and_save_model(last)
            if self.dev and last and folds is None:  # Free more memory
                del self.dev[:]
        print("Trained %d iterations" % iterations)
        self.model = self.best_model

    def eval_dev_and_save_model(self, last=False):
        model = self.model  # Save non-finalized model
        self.model = self.model.finalize()  # To evaluate finalized model on dev
        save_model = True
        if self.dev:
            print("Evaluating on dev passages")
            self.dev, scores = zip(*[(passage, evaluate_passage(predicted_passage, passage))
                                     for predicted_passage, passage in self.parse(self.dev, mode="dev")])
            self.dev = list(self.dev)
            scores = evaluation.Scores.aggregate(scores)
            score = scores.average_f1()
            print("Average labeled F1 score on dev: %.3f" % score)
            if Config().args.devscores:
                prefix = [self.iteration]
                if Config().args.saveeverybatch:
                    prefix.append(self.batch)
                with open(Config().args.devscores, "a") as f:
                    print(",".join([".".join(map(str, prefix))] + scores.fields()), file=f)
            if score >= self.best_score:
                print("Better than previous best score (%.3f)" % self.best_score)
                self.best_score = score
                save_model = True
            else:
                print("Not better than previous best score (%.3f)" % self.best_score)
                save_model = False
            if score >= 1:  # Score cannot go any better, so no point in more training
                last = True
        if save_model or self.best_model is None:
            self.best_model = self.model  # This is the finalized model
            self.best_model.save()
        if not last:
            self.model = model  # Restore non-finalized model
        return last

    def parse(self, passages, mode="test"):
        """
        Parse given passages
        :param passages: iterable of passages to parse
        :param mode: "train", "test" or "dev". If "train", use oracle to train on given passages.
                     Otherwise, just parse with classifier.
        :return: generator of pairs of (parsed passage, original passage)
        """
        train = mode == "train"
        dev = mode == "dev"
        test = mode == "test"
        assert train or dev or test, "Invalid parse mode: %s" % mode
        passage_word = "sentence" if Config().args.sentences else \
                       "paragraph" if Config().args.paragraphs else \
                       "passage"
        self.total_actions = 0
        self.total_correct = 0
        total_duration = 0
        total_tokens = 0
        num_passages = 0
        for passage in passages:
            l0 = passage.layer(layer0.LAYER_ID)
            num_tokens = len(l0.all)
            l1 = passage.layer(layer1.LAYER_ID)
            labeled = len(l1.all) > 1
            assert not train or labeled, "Cannot train on unannotated passage"
            print("%s %-7s" % (passage_word, passage.ID), end=Config().line_end, flush=True)
            started = time.time()
            self.action_count = 0
            self.correct_count = 0
            self.state = State(passage, callback=self.pos_tag)
            self.state_hash_history = set()
            self.oracle = Oracle(passage) if train else None
            failed = False
            try:
                self.parse_passage(train)  # This is where the actual parsing takes place
            except ParserException as e:
                if train:
                    raise
                Config().log("%s %s: %s" % (passage_word, passage.ID, e))
                if not test:
                    print("failed")
                failed = True
            predicted_passage = passage
            if not train or Config().args.verify:
                predicted_passage = self.state.create_passage(assert_proper=Config().args.verify)
            duration = time.time() - started
            total_duration += duration
            num_tokens -= len(self.state.buffer)
            total_tokens += num_tokens
            if train:  # We have an oracle to verify by
                if not failed and Config().args.verify:
                    self.verify_passage(passage, predicted_passage, train)
                if self.action_count:
                    print("%-16s" % ("%d%% (%d/%d)" % (100 * self.correct_count / self.action_count,
                                                       self.correct_count, self.action_count)),
                          end=Config().line_end)
            print("%0.3fs" % duration, end="")
            print("%-15s" % ("" if failed else " (%d tokens/s)" % (num_tokens / duration)), end="")
            print(Config().line_end, end="")
            if train:
                print(Config().line_end, flush=True)
            self.model.finish(train=train)
            self.total_correct += self.correct_count
            self.total_actions += self.action_count
            num_passages += 1
            if train and Config().args.saveeverybatch and num_passages % Config().args.batchsize == 0:
                self.eval_dev_and_save_model()
                self.batch += 1
            yield predicted_passage, passage
        if num_passages > 1:
            print("Parsed %d %ss" % (num_passages, passage_word))
            if self.oracle and self.total_actions:
                print("Overall %d%% correct transitions (%d/%d) on %s" %
                      (100 * self.total_correct / self.total_actions,
                       self.total_correct, self.total_actions, mode))
            print("Total time: %.3fs (average time/%s: %.3fs, average tokens/s: %d)" % (
                total_duration, passage_word, total_duration / num_passages,
                total_tokens / total_duration), flush=True)

    def parse_passage(self, train=False):
        """
        Internal method to parse a single passage
        :param train: use oracle to train on given passages, or just parse with classifier?
        """
        if Config().args.verbose:
            print("  initial state: %s" % self.state)
        while True:
            if Config().args.checkloops:
                self.check_loop(print_oracle=train)
            true_actions = []
            if self.oracle is not None:
                try:
                    true_actions = self.oracle.get_actions(self.state)
                except (AttributeError, AssertionError) as e:
                    if train:
                        raise ParserException("Error in oracle during training") from e
            features = self.model.extract_features(self.state)
            predicted_action = self.predict_action(features, true_actions)  # sets self.scores
            action = predicted_action
            correct_action = False
            if not true_actions:
                true_actions = "?"
            elif predicted_action in true_actions:
                self.correct_count += 1
                correct_action = True
            elif train:
                action = Config().random.choice(true_actions)
            if train and not (correct_action and self.model.update_only_on_error):
                best_true_action = true_actions[0] if len(true_actions) == 1 else \
                    true_actions[self.scores[[a.id for a in true_actions]].argmax()]
                rate = self.learning_rate
                if best_true_action.is_swap:
                    rate *= Config().args.importance
                self.model.update(features, predicted_action.id, best_true_action.id, rate)
            self.model.advance()
            self.action_count += 1
            try:
                self.state.transition(action)
            except AssertionError as e:
                raise ParserException("Invalid transition (%s): %s" % (action, e)) from e
            if Config().args.verbose:
                if self.oracle is None:
                    print("  action: %-15s %s" % (action, self.state))
                else:
                    print("  predicted: %-15s true: %-15s taken: %-15s %s" % (
                        predicted_action, "|".join(map(str, true_actions)), action, self.state))
                for line in self.state.log:
                    print("    " + line)
            if self.state.finished or train and not correct_action and Config().args.earlyupdate:
                return  # action is FINISH

    def check_loop(self, print_oracle):
        """
        Check if the current state has already occurred, indicating a loop
        :param print_oracle: whether to print the oracle in case of an assertion error
        """
        h = hash(self.state)
        # Parenthesize the conditional so the oracle string is appended to the
        # message rather than replacing it whenever print_oracle is false:
        assert h not in self.state_hash_history, \
            "\n".join(["Transition loop", self.state.str("\n")] +
                      ([self.oracle.str("\n")] if print_oracle else []))
        self.state_hash_history.add(h)

    def predict_action(self, features, true_actions):
        """
        Choose action based on classifier
        :param features: extracted feature values
        :param true_actions: from the oracle, to copy orig_node if the same action is selected
        :return: valid action with maximum probability according to classifier
        """
        self.scores = self.model.score(features)  # Returns a NumPy array
        best_action = self.select_action(self.scores.argmax(), true_actions)
        if self.state.is_valid(best_action):
            return best_action
        # Usually the best action is valid, so max is enough to choose it in O(n) time
        # Otherwise, sort all the other scores to choose the best valid one in O(n lg n)
        sorted_ids = self.scores.argsort()[::-1]
        actions = (self.select_action(i, true_actions) for i in sorted_ids)
        try:
            return next(a for a in actions if self.state.is_valid(a))
        except StopIteration as e:
            raise ParserException("No valid actions available\n" +
                                  ("True actions: %s" % true_actions if true_actions
                                   else self.oracle.log if self.oracle is not None else "") +
                                  "\nReturned actions: %s" % [self.select_action(i) for i in sorted_ids] +
                                  "\nScores: %s" % self.scores) from e

    @staticmethod
    def select_action(i, true_actions=()):
        """
        Find action with the given ID in true actions (if exists) or in all actions
        :param i: ID to lookup
        :param true_actions: preferred set of actions to look in first
        :return: Action with id=i
        """
        try:
            return next(a for a in true_actions if a.id == i)
        except StopIteration:
            return Actions().all[i]

    def verify_passage(self, passage, predicted_passage, show_diff):
        """
        Compare predicted passage to true passage and die if they differ
        :param passage: true passage
        :param predicted_passage: predicted passage to compare
        :param show_diff: if passages differ, show the difference between them?
                          Depends on predicted_passage having the original node IDs annotated
                          in the "remarks" field for each node.
        """
        assert passage.equals(predicted_passage, ignore_node=self.ignore_node), \
            "Failed to produce true passage" + \
            (diffutil.diff_passages(passage, predicted_passage) if show_diff else "")

    @staticmethod
    def pos_tag(state):
        """
        Function to pass to State to POS tag the tokens when created
        :param state: State object to modify
        """
        tokens = [token for tokens in state.tokens for token in tokens]
        tokens, tags = zip(*pos_tag(tokens))  # the module-level pos_tag function
        if Config().args.verbose:
            print(" ".join("%s/%s" % (token, tag) for (token, tag) in zip(tokens, tags)))
        for node, tag in zip(state.nodes, tags):
            node.pos_tag = tag
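A sketch of how this Parser might be driven end to end; the passage-loading helper and the model path are assumptions, not part of the source:

# Hypothetical driver: load_annotated_passages stands in for whatever
# produces ucca.core.Passage objects with gold annotation.
train_passages = load_annotated_passages("train-xml/")
dev_passages = load_annotated_passages("dev-xml/")

parser = Parser(model_file="models/ucca")  # assumed model path
parser.train(train_passages, dev=dev_passages, iterations=10)
for predicted, original in parser.parse(dev_passages, mode="test"):
    print(predicted.ID)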
def __init__(self, game):
    State.__init__(self, game)
def __init__(self, game):
    State.__init__(self, game)
    self.player = Player(self.game)
    self.grass_img = pygame.image.load(
        os.path.join(self.game.assets_dir, "map", "grass.png"))
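The map-state constructor above implies matching update/render hooks. A sketch of what they might look like, following the State base sketched earlier (the "start" action name and the PauseMenu class are assumptions):

def update(self, delta_time, actions):
    # Hypothetical: move the player, and open the pause menu on "start".
    self.player.update(delta_time, actions)
    if actions.get("start"):
        PauseMenu(self.game).enter_state()

def render(self, surface):
    surface.blit(self.grass_img, (0, 0))
    self.player.render(surface)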
class Parser(object):
    """ Main class to implement transition-based UCCA parser """

    def __init__(self, model_file=None, model_type=None, beam=1):
        self.state = None  # State object created at each parse
        self.oracle = None  # Oracle object created at each parse
        self.scores = None  # NumPy array of action scores at each action
        self.action_count = 0
        self.correct_count = 0
        self.total_actions = 0
        self.total_correct = 0
        self.model = Model(model_type, model_file, Actions().all)
        self.beam = beam  # Currently unused
        self.state_hash_history = None  # For loop checking
        # Used in verify_passage to optionally ignore a mismatch in linkage nodes:
        self.ignore_node = None if Config().args.linkage else lambda n: n.tag == layer1.NodeTags.Linkage
        self.best_score = self.dev = self.iteration = self.eval_index = None
        self.dev_scores = []
        self.trained = False

    def train(self, passages=None, dev=None, iterations=1):
        """
        Train parser on given passages
        :param passages: iterable of passages to train on
        :param dev: iterable of passages to tune on
        :param iterations: number of iterations to perform
        :return: trained model
        """
        self.trained = True
        if passages:
            if ClassifierProperty.trainable_after_saving in self.model.model.get_classifier_properties():
                try:
                    self.model.load()
                except FileNotFoundError:
                    print("not found, starting from untrained model.")
            self.best_score = 0
            self.dev = dev
            if Config().args.devscores:
                with open(Config().args.devscores, "w") as f:
                    print(",".join(["iteration"] +
                                   evaluation.Scores.field_titles(Config().args.constructions)), file=f)
            for self.iteration in range(1, iterations + 1):
                self.eval_index = 0
                print("Training iteration %d of %d: " % (self.iteration, iterations))
                list(self.parse(passages, mode=ParseMode.train))
                self.eval_and_save(self.iteration == iterations, finished_epoch=True)
                Config().random.shuffle(passages)
            print("Trained %d iterations" % iterations)
        if dev or not passages:
            self.model.load()

    def eval_and_save(self, last=False, finished_epoch=False):
        model = self.model  # Save non-finalized model
        self.model = self.model.finalize(finished_epoch=finished_epoch)
        if self.dev:
            print("Evaluating on dev passages")
            scores = [s for _, s in self.parse(self.dev, mode=ParseMode.dev, evaluate=True)]
            scores = evaluation.Scores.aggregate(scores)
            self.dev_scores.append(scores)
            score = scores.average_f1()
            print("Average labeled F1 score on dev: %.3f" % score)
            if Config().args.devscores:
                prefix = [self.iteration]
                if Config().args.save_every:
                    prefix.append(self.eval_index)
                with open(Config().args.devscores, "a") as f:
                    print(",".join([".".join(map(str, prefix))] + scores.fields()), file=f)
            if score >= self.best_score:
                print("Better than previous best score (%.3f)" % self.best_score)
                self.best_score = score
                self.model.save()
            else:
                print("Not better than previous best score (%.3f)" % self.best_score)
        elif last or Config().args.save_every is not None:
            self.model.save()
        if not last:
            self.model = model  # Restore non-finalized model

    def parse(self, passages, mode=ParseMode.test, evaluate=False):
        """
        Parse given passages
        :param passages: iterable of passages to parse
        :param mode: ParseMode value.
                     If train, use oracle to train on given passages.
                     Otherwise, just parse with classifier.
        :param evaluate: whether to evaluate parsed passages with respect to given ones.
                         Only possible when given passages are annotated.
        :return: generator of parsed passages (or in train mode, the original ones),
                 or, if evaluate=True, of pairs of (Passage, Scores).
        """
        assert mode in ParseMode, "Invalid parse mode: %s" % mode
        train = (mode is ParseMode.train)
        if not train and not self.trained:
            self.train()
        passage_word = "sentence" if Config().args.sentences else \
                       "paragraph" if Config().args.paragraphs else \
                       "passage"
        self.total_actions = 0
        self.total_correct = 0
        total_duration = 0
        total_tokens = 0
        passage_index = 0
        if not hasattr(passages, "__iter__"):  # Single passage given
            passages = (passages,)
        for passage_index, passage in enumerate(passages):
            l0 = passage.layer(layer0.LAYER_ID)
            num_tokens = len(l0.all)
            l1 = passage.layer(layer1.LAYER_ID)
            labeled = len(l1.all) > 1
            assert not train or labeled, "Cannot train on unannotated passage: %s" % passage.ID
            assert not evaluate or labeled, "Cannot evaluate on unannotated passage: %s" % passage.ID
            print("%s %-7s" % (passage_word, passage.ID), end=Config().line_end, flush=True)
            started = time.time()
            self.action_count = 0
            self.correct_count = 0
            textutil.annotate(passage, verbose=Config().args.verbose)  # tag POS and parse dependencies
            self.state = State(passage)
            self.state_hash_history = set()
            self.oracle = Oracle(passage) if train else None
            failed = False
            if ClassifierProperty.require_init_features in self.model.model.get_classifier_properties():
                self.model.init_features(self.state, train)
            try:
                self.parse_passage(train)  # This is where the actual parsing takes place
            except ParserException as e:
                if train:
                    raise
                Config().log("%s %s: %s" % (passage_word, passage.ID, e))
                failed = True
            predicted_passage = self.state.create_passage(assert_proper=Config().args.verify) \
                if not train or Config().args.verify else passage
            duration = time.time() - started
            total_duration += duration
            num_tokens -= len(self.state.buffer)
            total_tokens += num_tokens
            if train:  # We have an oracle to verify by
                if not failed and Config().args.verify:
                    self.verify_passage(passage, predicted_passage, train)
                if self.action_count:
                    print("%-16s" % ("%d%% (%d/%d)" % (100 * self.correct_count / self.action_count,
                                                       self.correct_count, self.action_count)),
                          end=Config().line_end)
            print("%0.3fs" % duration, end="")
            print("%-15s" % (" (failed)" if failed else " (%d tokens/s)" % (num_tokens / duration)), end="")
            print(Config().line_end, end="")
            if train:
                print(Config().line_end, flush=True)
            self.model.model.finished_item(train)
            self.total_correct += self.correct_count
            self.total_actions += self.action_count
            if train and Config().args.save_every and (passage_index + 1) % Config().args.save_every == 0:
                self.eval_and_save()
                self.eval_index += 1
            yield (predicted_passage, evaluate_passage(predicted_passage, passage)) \
                if evaluate else predicted_passage
        if passages:
            print("Parsed %d %ss" % (passage_index + 1, passage_word))
            if self.oracle and self.total_actions:
                print("Overall %d%% correct transitions (%d/%d) on %s" %
                      (100 * self.total_correct / self.total_actions,
                       self.total_correct, self.total_actions, mode.name))
            print("Total time: %.3fs (average time/%s: %.3fs, average tokens/s: %d)" % (
                total_duration, passage_word, total_duration / (passage_index + 1),
                total_tokens / total_duration), flush=True)

    def parse_passage(self, train):
        """
        Internal method to parse a single passage
        :param train: use oracle to train on given passages, or just parse with classifier?
        """
        if Config().args.verbose:
            print("  initial state: %s" % self.state)
        while True:
            if Config().args.check_loops:
                self.check_loop(print_oracle=train)
            true_actions = []
            if self.oracle is not None:
                try:
                    true_actions = self.oracle.get_actions(self.state)
                except (AttributeError, AssertionError) as e:
                    if train:
                        raise ParserException("Error in oracle during training") from e
            features = self.model.feature_extractor.extract_features(self.state)
            predicted_action = self.predict_action(features, true_actions)  # sets self.scores
            action = predicted_action
            correct_action = False
            if not true_actions:
                true_actions = "?"
            elif predicted_action in true_actions:
                self.correct_count += 1
                correct_action = True
            elif train:
                action = Config().random.choice(true_actions)
            if train and not (correct_action and ClassifierProperty.update_only_on_error in
                              self.model.model.get_classifier_properties()):
                best_true_action = true_actions[0] if len(true_actions) == 1 else \
                    true_actions[self.scores[[a.id for a in true_actions]].argmax()]
                self.model.model.update(features, predicted_action.id, best_true_action.id,
                                        Config().args.swap_importance if best_true_action.is_swap else 1)
            self.action_count += 1
            self.model.model.finished_step(train)
            try:
                self.state.transition(action)
            except AssertionError as e:
                raise ParserException("Invalid transition (%s): %s" % (action, e)) from e
            if Config().args.verbose:
                if self.oracle is None:
                    print("  action: %-15s %s" % (action, self.state))
                else:
                    print("  predicted: %-15s true: %-15s taken: %-15s %s" % (
                        predicted_action, "|".join(map(str, true_actions)), action, self.state))
                for line in self.state.log:
                    print("    " + line)
            if self.state.finished or train and not correct_action and Config().args.early_update:
                return  # action is FINISH

    def check_loop(self, print_oracle):
        """
        Check if the current state has already occurred, indicating a loop
        :param print_oracle: whether to print the oracle in case of an assertion error
        """
        h = hash(self.state)
        # Parenthesize the conditional so the oracle string is appended to the
        # message rather than replacing it whenever print_oracle is false:
        assert h not in self.state_hash_history, \
            "\n".join(["Transition loop", self.state.str("\n")] +
                      ([self.oracle.str("\n")] if print_oracle else []))
        self.state_hash_history.add(h)

    def predict_action(self, features, true_actions):
        """
        Choose action based on classifier
        :param features: extracted feature values
        :param true_actions: from the oracle, to copy orig_node if the same action is selected
        :return: valid action with maximum probability according to classifier
        """
        self.scores = self.model.model.score(features)  # Returns a NumPy array
        if Config().args.verbose >= 2:
            print("  scores: " + " ".join("%g" % s for s in self.scores))
        best_action = self.select_action(self.scores.argmax(), true_actions)
        if self.state.is_valid(best_action):
            return best_action
        # Usually the best action is valid, so max is enough to choose it in O(n) time
        # Otherwise, sort all the other scores to choose the best valid one in O(n lg n)
        sorted_ids = self.scores.argsort()[::-1]
        actions = (self.select_action(i, true_actions) for i in sorted_ids)
        try:
            return next(a for a in actions if self.state.is_valid(a))
        except StopIteration as e:
            raise ParserException("No valid actions available\n" +
                                  ("True actions: %s" % true_actions if true_actions
                                   else self.oracle.log if self.oracle is not None else "") +
                                  "\nReturned actions: %s" % [self.select_action(i) for i in sorted_ids] +
                                  "\nScores: %s" % self.scores) from e

    @staticmethod
    def select_action(i, true_actions=()):
        """
        Find action with the given ID in true actions (if exists) or in all actions
        :param i: ID to lookup
        :param true_actions: preferred set of actions to look in first
        :return: Action with id=i
        """
        try:
            return next(a for a in true_actions if a.id == i)
        except StopIteration:
            return Actions().all[i]

    def verify_passage(self, passage, predicted_passage, show_diff):
        """
        Compare predicted passage to true passage and die if they differ
        :param passage: true passage
        :param predicted_passage: predicted passage to compare
        :param show_diff: if passages differ, show the difference between them?
                          Depends on predicted_passage having the original node IDs annotated
                          in the "remarks" field for each node.
        """
        assert passage.equals(predicted_passage, ignore_node=self.ignore_node), \
            "Failed to produce true passage" + \
            (diffutil.diff_passages(passage, predicted_passage) if show_diff else "")
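For comparison with the first variant, a sketch of driving this version through ParseMode and the evaluate flag (the read_files_and_dirs helper from ucca.ioutil and the directory names are assumptions):

from ucca.ioutil import read_files_and_dirs

# Hypothetical driver for the second Parser variant.
train_passages = read_files_and_dirs(["train-xml/"])
dev_passages = read_files_and_dirs(["dev-xml/"])

parser = Parser(model_file="models/ucca")  # assumed model path
parser.train(train_passages, dev=dev_passages, iterations=10)

# evaluate=True yields (Passage, Scores) pairs per the docstring above:
for passage, scores in parser.parse(dev_passages, mode=ParseMode.test, evaluate=True):
    print(passage.ID, scores.average_f1())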