def __call__(self, config):
    config.args.vocab = self.vocab
    config.args.word_vectors = self.wordvectors
    config.args.omit_features = self.omit
    return SparseFeatureExtractor(omit_features=self.omit) if self.name == SPARSE else DenseFeatureExtractor(
        OrderedDict((p.name, p.create_from_config()) for p in Model(None, config=config).param_defs()),
        indexed=self.indexed, node_dropout=0, omit_features=self.omit)
def test_model(model_type, formats, test_passage, iterations, omit_features, config):
    filename = "_".join(filter(None, [os.path.join("test_files", "models", "test"),
                                      model_type, omit_features] + formats))
    remove_existing(filename)
    config.update(dict(classifier=model_type, copy_shared=None, omit_features=omit_features))
    finalized = model = Model(filename, config=config)
    for i in range(iterations):
        parse(formats, model, test_passage, train=True)
        finalized = model.finalize(finished_epoch=True)
        assert not getattr(finalized.feature_extractor, "node_dropout", 0), finalized.feature_extractor.node_dropout
        parse(formats, model, test_passage, train=False)
        finalized.save()
    loaded = Model(filename, config=config)
    loaded.load()
    assert not getattr(loaded.feature_extractor, "node_dropout", 0), loaded.feature_extractor.node_dropout
    for key, param in sorted(model.feature_extractor.params.items()):
        loaded_param = loaded.feature_extractor.params[key]
        assert param == loaded_param
    assert_all_params_equal(finalized.all_params(), loaded.all_params(), decay=weight_decay(model))
def load_model(filename):
    model = Model(filename=filename)
    model.load()
    return model
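# Hedged usage sketch (not part of the original sources): load a previously saved
# model via load_model and inspect its feature parameters, i.e. the same
# feature_extractor.params mapping that test_model compares above. The default
# filename is only an example value.
def inspect_model(filename="test_files/models/test_sparse"):
    model = load_model(filename)
    for key, param in sorted(model.feature_extractor.params.items()):
        print(key, param)  # one line per feature parameter
    return model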
class Parser(object):
    """
    Main class to implement transition-based UCCA parser
    """
    def __init__(self, model_file=None, model_type=None, beam=1):
        self.args = Config().args
        self.state = None  # State object created at each parse
        self.oracle = None  # Oracle object created at each parse
        self.action_count = self.correct_action_count = self.total_actions = self.total_correct_actions = 0
        self.label_count = self.correct_label_count = self.total_labels = self.total_correct_labels = 0
        self.model = Model(model_type, model_file)
        self.update_only_on_error = \
            ClassifierProperty.update_only_on_error in self.model.model.get_classifier_properties()
        self.beam = beam  # Currently unused
        self.state_hash_history = None  # For loop checking
        # Used in verify_passage to optionally ignore a mismatch in linkage nodes:
        self.ignore_node = None if self.args.linkage else lambda n: n.tag == layer1.NodeTags.Linkage
        self.best_score = self.dev = self.iteration = self.eval_index = None
        self.trained = False

    def train(self, passages=None, dev=None, iterations=1):
        """
        Train parser on given passages
        :param passages: iterable of passages to train on
        :param dev: iterable of passages to tune on
        :param iterations: number of iterations to perform
        """
        self.trained = True
        if passages:
            if ClassifierProperty.trainable_after_saving in self.model.model.get_classifier_properties():
                try:
                    self.model.load()
                except FileNotFoundError:
                    print("not found, starting from untrained model.")
            self.best_score = 0
            self.dev = dev
            for self.iteration in range(1, iterations + 1):
                self.eval_index = 0
                print("Training iteration %d of %d: " % (self.iteration, iterations))
                Config().random.shuffle(passages)
                list(self.parse(passages, mode=ParseMode.train))
                yield self.eval_and_save(self.iteration == iterations, finished_epoch=True)
            print("Trained %d iterations" % iterations)
        if dev or not passages:
            self.model.load()

    def eval_and_save(self, last=False, finished_epoch=False):
        scores = None
        model = self.model
        self.model = self.model.finalize(finished_epoch=finished_epoch)
        if self.dev:
            if not self.best_score:
                self.model.save()
            print("Evaluating on dev passages")
            passage_scores = [s for _, s in self.parse(self.dev, mode=ParseMode.dev, evaluate=True)]
            scores = Scores(passage_scores)
            average_score = scores.average_f1()
            print("Average labeled F1 score on dev: %.3f" % average_score)
            print_scores(scores, self.args.devscores, prefix_title="iteration",
                         prefix=[self.iteration] + ([self.eval_index] if self.args.save_every else []))
            if average_score >= self.best_score:
                print("Better than previous best score (%.3f)" % self.best_score)
                if self.best_score:
                    self.model.save()
                self.best_score = average_score
            else:
                print("Not better than previous best score (%.3f)" % self.best_score)
        elif last or self.args.save_every is not None:
            self.model.save()
        if not last:
            self.model = model  # Restore non-finalized model
        self.model.load_labels()
        return scores

    def parse(self, passages, mode=ParseMode.test, evaluate=False):
        """
        Parse given passages
        :param passages: iterable of passages to parse
        :param mode: ParseMode value.
                     If train, use oracle to train on given passages. Otherwise, just parse with classifier.
        :param evaluate: whether to evaluate parsed passages with respect to given ones.
                         Only possible when given passages are annotated.
        :return: generator of parsed passages (or in train mode, the original ones),
                 or, if evaluate=True, of pairs of (Passage, Scores).
        """
        assert mode in ParseMode, "Invalid parse mode: %s" % mode
        train = (mode is ParseMode.train)
        if not train and not self.trained:
            list(self.train())
        passage_word = "sentence" if self.args.sentences else \
                       "paragraph" if self.args.paragraphs else \
                       "passage"
        self.total_actions = 0
        self.total_correct_actions = 0
        total_duration = 0
        total_tokens = 0
        passage_index = 0
        if not hasattr(passages, "__iter__"):  # Single passage given
            passages = (passages,)
        for passage_index, passage in enumerate(passages):
            labeled = any(n.outgoing or n.attrib.get(LABEL_ATTRIB) for n in passage.layer(layer1.LAYER_ID).all)
            assert not train or labeled, "Cannot train on unannotated passage: %s" % passage.ID
            assert not evaluate or labeled, "Cannot evaluate on unannotated passage: %s" % passage.ID
            print("%s %-7s" % (passage_word, passage.ID), end=Config().line_end, flush=True)
            started = time.time()
            self.action_count = self.correct_action_count = self.label_count = self.correct_label_count = 0
            textutil.annotate(passage, verbose=self.args.verbose > 1)  # tag POS and parse dependencies
            self.state = State(passage)
            self.state_hash_history = set()
            self.oracle = Oracle(passage) if train or (
                self.args.verbose or Config().args.use_gold_node_labels) and labeled or self.args.verify else None
            failed = False
            if ClassifierProperty.require_init_features in self.model.model.get_classifier_properties():
                self.model.init_features(self.state, train)
            try:
                self.parse_passage(train)  # This is where the actual parsing takes place
            except ParserException as e:
                if train:
                    raise
                Config().log("%s %s: %s" % (passage_word, passage.ID, e))
                failed = True
            guessed = self.state.create_passage(verify=self.args.verify) if not train or self.args.verify else passage
            duration = time.time() - started
            total_duration += duration
            num_tokens = len(set(self.state.terminals).difference(self.state.buffer))
            total_tokens += num_tokens
            if self.oracle:  # We have an oracle to verify by
                if not failed and self.args.verify:
                    self.verify_passage(guessed, passage, train)
                if self.action_count:
                    accuracy_str = "%d%% (%d/%d)" % (100 * self.correct_action_count / self.action_count,
                                                     self.correct_action_count, self.action_count)
                    if self.label_count:
                        accuracy_str += " %d%% (%d/%d)" % (100 * self.correct_label_count / self.label_count,
                                                           self.correct_label_count, self.label_count)
                    print("%-30s" % accuracy_str, end=Config().line_end)
            print("%0.3fs" % duration, end="")
            print("%-15s" % (" (failed)" if failed else " (%d tokens/s)" % (num_tokens / duration)), end="")
            print(Config().line_end, end="")
            if self.oracle:
                print(Config().line_end, flush=True)
            self.model.model.finished_item(train)
            self.total_correct_actions += self.correct_action_count
            self.total_actions += self.action_count
            self.total_correct_labels += self.correct_label_count
            self.total_labels += self.label_count
            if train and self.args.save_every and (passage_index + 1) % self.args.save_every == 0:
                self.eval_and_save()
                self.eval_index += 1
            yield (guessed, self.evaluate_passage(guessed, passage)) if evaluate else guessed

        if passages:
            print("Parsed %d %ss" % (passage_index + 1, passage_word))
            if self.oracle and self.total_actions:
                accuracy_str = "%d%% correct actions (%d/%d)" % (
                    100 * self.total_correct_actions / self.total_actions,
                    self.total_correct_actions, self.total_actions)
                if self.total_labels:
                    accuracy_str += ", %d%% correct labels (%d/%d)" % (
                        100 * self.total_correct_labels / self.total_labels,
                        self.total_correct_labels, self.total_labels)
                print("Overall %s on %s" % (accuracy_str, mode.name))
            if total_duration:
                print("Total time: %.3fs (average time/%s: %.3fs, average tokens/s: %d)" % (
                    total_duration, passage_word, total_duration / (passage_index + 1),
                    total_tokens / total_duration), flush=True)

    def parse_passage(self, train):
        """
        Internal method to parse a single passage
        :param train: use oracle to train on given passages, or just parse with classifier?
        """
        if self.args.verbose > 1:
            print(" initial state: %s" % self.state)
        while True:
            if self.args.check_loops:
                self.check_loop()
            features = self.model.feature_extractor.extract_features(self.state)
            true_actions = self.get_true_actions(train)
            action, predicted_action = self.choose_action(features, train, true_actions)
            try:
                self.state.transition(action)
            except AssertionError as e:
                raise ParserException("Invalid transition: %s %s" % (action, self.state)) from e
            if self.args.verbose > 1:
                if self.oracle:
                    print(" predicted: %-15s true: %-15s taken: %-15s %s" % (
                        predicted_action, "|".join(map(str, true_actions.values())), action, self.state))
                else:
                    print(" action: %-15s %s" % (action, self.state))
            if self.state.need_label:  # Label action that requires a choice of label
                true_label = self.get_true_label(action.orig_node)
                label, predicted_label = self.choose_label(features, train, true_label)
                self.state.label_node(label)
                if self.args.verbose > 1:
                    if self.oracle and not Config().args.use_gold_node_labels:
                        print(" predicted label: %-15s true label: %-15s" % (predicted_label, true_label))
                    else:
                        print(" label: %-15s" % label)
            self.model.model.finished_step(train)
            if self.args.verbose > 1:
                for line in self.state.log:
                    print(" " + line)
            if self.state.finished:
                return  # action is Finish (or early update is triggered)

    def get_true_actions(self, train):
        true_actions = {}
        if self.oracle:
            try:
                true_actions = self.oracle.get_actions(self.state, self.model.actions, create=train)
            except (AttributeError, AssertionError) as e:
                if train:
                    raise ParserException("Error in oracle during training") from e
        return true_actions

    def choose_action(self, features, train, true_actions):
        scores = self.model.model.score(features, axis=ACTION_AXIS)  # Returns NumPy array
        if self.args.verbose > 2:
            print(" action scores: " + ",".join(("%s: %g" % x for x in zip(self.model.actions.all, scores))))
        try:
            predicted_action = self.predict(scores, self.model.actions.all, self.state.is_valid_action)
        except StopIteration as e:
            raise ParserException("No valid action available\n%s" % (self.oracle.log if self.oracle else "")) from e
        action = true_actions.get(predicted_action.id)
        is_correct = (action is not None)
        if is_correct:
            self.correct_action_count += 1
        else:
            action = Config().random.choice(list(true_actions.values())) if train else predicted_action
        if train and not (is_correct and self.update_only_on_error):
            best_action = self.predict(scores[list(true_actions.keys())], list(true_actions.values()))
            self.model.model.update(features, axis=ACTION_AXIS, pred=predicted_action.id, true=best_action.id,
                                    importance=self.args.swap_importance if best_action.is_swap else 1)
        if train and not is_correct and self.args.early_update:
            self.state.finished = True
        self.action_count += 1
        return action, predicted_action

    def get_true_label(self, node):
        true_label = None
        if self.oracle:
            if node is not None:
                true_label = node.attrib.get(LABEL_ATTRIB)
            if true_label is not None:
                true_label, _, _ = true_label.partition(LABEL_SEPARATOR)
                if not self.state.is_valid_label(true_label):
                    raise ParserException("True label is invalid: %s %s" % (true_label, self.state))
        return true_label

    def choose_label(self, features, train, true_label):
        true_id = self.model.labels[true_label] if self.oracle else None  # Needs to happen before score()
        if Config().args.use_gold_node_labels:
            return true_label, true_label
        scores = self.model.model.score(features, axis=LABEL_AXIS)
        if self.args.verbose > 2:
            print(" label scores: " + ",".join(("%s: %g" % x for x in zip(self.model.labels.all, scores))))
        label = predicted_label = self.predict(scores, self.model.labels.all, self.state.is_valid_label)
        if self.oracle:
            is_correct = (label == true_label)
            if is_correct:
                self.correct_label_count += 1
            if train and not (is_correct and self.update_only_on_error):
                self.model.model.update(features, axis=LABEL_AXIS, pred=self.model.labels[label], true=true_id)
                label = true_label
        self.label_count += 1
        return label, predicted_label

    def predict(self, scores, values, is_valid=None):
        """
        Choose action/label based on classifier
        Usually the best action/label is valid, so max is enough to choose it in O(n) time
        Otherwise, sorts all the other scores to choose the best valid one in O(n lg n)
        :return: valid action/label with maximum probability according to classifier
        """
        return next(filter(is_valid, (values[i] for i in self.generate_descending(scores))))

    @staticmethod
    def generate_descending(scores):
        yield scores.argmax()
        yield from scores.argsort()[::-1]  # Contains the max, but otherwise items might be missed (different order)

    def check_loop(self):
        """
        Check if the current state has already occurred, indicating a loop
        """
        h = hash(self.state)
        assert h not in self.state_hash_history, \
            "\n".join(["Transition loop", self.state.str("\n")] + [self.oracle.str("\n")] if self.oracle else ())
        self.state_hash_history.add(h)

    def evaluate_passage(self, guessed, ref):
        converters = CONVERTERS.get(ref.extra.get("format"))  # returns (input converter, output converter) tuple
        score = EVALUATORS.get(ref.extra.get("format"), evaluation).evaluate(
            guessed, ref,
            converter=converters and converters[1],  # converter output is list of lines
            verbose=guessed and self.args.verbose > 2,
            constructions=self.args.constructions)
        print("F1=%.3f" % score.average_f1(), flush=True)
        return score

    def verify_passage(self, guessed, ref, show_diff):
        """
        Compare predicted passage to true passage and raise an exception if they differ
        :param ref: true passage
        :param guessed: predicted passage to compare
        :param show_diff: if passages differ, show the difference between them?
                          Depends on guessed having the original node IDs annotated
                          in the "remarks" field for each node
        """
        assert ref.equals(guessed, ignore_node=self.ignore_node), \
            "Failed to produce true passage" + (diffutil.diff_passages(ref, guessed) if show_diff else "")
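# Hedged usage sketch (not part of the original sources): a minimal driver for the
# Parser class above, based only on the signatures shown here. The passage iterables
# are assumed to hold annotated UCCA Passage objects, and the classifier type,
# model path and iteration count are example values.
def run_parser_example(train_passages, dev_passages, test_passages):
    parser = Parser(model_file="models/example", model_type="sparse", beam=1)
    # train() is a generator: each yielded value is the dev Scores of one iteration
    for dev_scores in parser.train(train_passages, dev=dev_passages, iterations=10):
        pass
    # with evaluate=True, parse() yields (Passage, Scores) pairs
    for guessed, scores in parser.parse(test_passages, evaluate=True):
        print(guessed.ID, "F1=%.3f" % scores.average_f1())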
class Parser(object):
    """
    Main class to implement transition-based UCCA parser
    """
    def __init__(self, model_file=None, model_type=None, beam=1):
        self.state = None  # State object created at each parse
        self.oracle = None  # Oracle object created at each parse
        self.scores = None  # NumPy array of action scores at each action
        self.action_count = 0
        self.correct_count = 0
        self.total_actions = 0
        self.total_correct = 0
        self.model = Model(model_type, model_file, Actions().all)
        self.beam = beam  # Currently unused
        self.state_hash_history = None  # For loop checking
        # Used in verify_passage to optionally ignore a mismatch in linkage nodes:
        self.ignore_node = None if Config().args.linkage else lambda n: n.tag == layer1.NodeTags.Linkage
        self.best_score = self.dev = self.iteration = self.eval_index = None
        self.dev_scores = []
        self.trained = False

    def train(self, passages=None, dev=None, iterations=1):
        """
        Train parser on given passages
        :param passages: iterable of passages to train on
        :param dev: iterable of passages to tune on
        :param iterations: number of iterations to perform
        :return: trained model
        """
        self.trained = True
        if passages:
            if ClassifierProperty.trainable_after_saving in self.model.model.get_classifier_properties():
                try:
                    self.model.load()
                except FileNotFoundError:
                    print("not found, starting from untrained model.")
            self.best_score = 0
            self.dev = dev
            if Config().args.devscores:
                with open(Config().args.devscores, "w") as f:
                    print(",".join(["iteration"] + evaluation.Scores.field_titles(Config().args.constructions)),
                          file=f)
            for self.iteration in range(1, iterations + 1):
                self.eval_index = 0
                print("Training iteration %d of %d: " % (self.iteration, iterations))
                list(self.parse(passages, mode=ParseMode.train))
                self.eval_and_save(self.iteration == iterations, finished_epoch=True)
                Config().random.shuffle(passages)
            print("Trained %d iterations" % iterations)
        if dev or not passages:
            self.model.load()

    def eval_and_save(self, last=False, finished_epoch=False):
        model = self.model
        self.model = self.model.finalize(finished_epoch=finished_epoch)
        if self.dev:
            print("Evaluating on dev passages")
            scores = [s for _, s in self.parse(self.dev, mode=ParseMode.dev, evaluate=True)]
            scores = evaluation.Scores.aggregate(scores)
            self.dev_scores.append(scores)
            score = scores.average_f1()
            print("Average labeled F1 score on dev: %.3f" % score)
            if Config().args.devscores:
                prefix = [self.iteration]
                if Config().args.save_every:
                    prefix.append(self.eval_index)
                with open(Config().args.devscores, "a") as f:
                    print(",".join([".".join(map(str, prefix))] + scores.fields()), file=f)
            if score >= self.best_score:
                print("Better than previous best score (%.3f)" % self.best_score)
                self.best_score = score
                self.model.save()
            else:
                print("Not better than previous best score (%.3f)" % self.best_score)
        elif last or Config().args.save_every is not None:
            self.model.save()
        if not last:
            self.model = model  # Restore non-finalized model

    def parse(self, passages, mode=ParseMode.test, evaluate=False):
        """
        Parse given passages
        :param passages: iterable of passages to parse
        :param mode: ParseMode value.
                     If train, use oracle to train on given passages. Otherwise, just parse with classifier.
        :param evaluate: whether to evaluate parsed passages with respect to given ones.
                         Only possible when given passages are annotated.
        :return: generator of parsed passages (or in train mode, the original ones),
                 or, if evaluate=True, of pairs of (Passage, Scores).
        """
        assert mode in ParseMode, "Invalid parse mode: %s" % mode
        train = (mode is ParseMode.train)
        if not train and not self.trained:
            self.train()
        passage_word = "sentence" if Config().args.sentences else \
                       "paragraph" if Config().args.paragraphs else \
                       "passage"
        self.total_actions = 0
        self.total_correct = 0
        total_duration = 0
        total_tokens = 0
        passage_index = 0
        if not hasattr(passages, "__iter__"):  # Single passage given
            passages = (passages,)
        for passage_index, passage in enumerate(passages):
            l0 = passage.layer(layer0.LAYER_ID)
            num_tokens = len(l0.all)
            l1 = passage.layer(layer1.LAYER_ID)
            labeled = len(l1.all) > 1
            assert not train or labeled, "Cannot train on unannotated passage: %s" % passage.ID
            assert not evaluate or labeled, "Cannot evaluate on unannotated passage: %s" % passage.ID
            print("%s %-7s" % (passage_word, passage.ID), end=Config().line_end, flush=True)
            started = time.time()
            self.action_count = 0
            self.correct_count = 0
            textutil.annotate(passage, verbose=Config().args.verbose)  # tag POS and parse dependencies
            self.state = State(passage)
            self.state_hash_history = set()
            self.oracle = Oracle(passage) if train else None
            failed = False
            if ClassifierProperty.require_init_features in self.model.model.get_classifier_properties():
                self.model.init_features(self.state, train)
            try:
                self.parse_passage(train)  # This is where the actual parsing takes place
            except ParserException as e:
                if train:
                    raise
                Config().log("%s %s: %s" % (passage_word, passage.ID, e))
                failed = True
            predicted_passage = self.state.create_passage(assert_proper=Config().args.verify) \
                if not train or Config().args.verify else passage
            duration = time.time() - started
            total_duration += duration
            num_tokens -= len(self.state.buffer)
            total_tokens += num_tokens
            if train:  # We have an oracle to verify by
                if not failed and Config().args.verify:
                    self.verify_passage(passage, predicted_passage, train)
                if self.action_count:
                    print("%-16s" % ("%d%% (%d/%d)" % (100 * self.correct_count / self.action_count,
                                                       self.correct_count, self.action_count)),
                          end=Config().line_end)
            print("%0.3fs" % duration, end="")
            print("%-15s" % (" (failed)" if failed else " (%d tokens/s)" % (num_tokens / duration)), end="")
            print(Config().line_end, end="")
            if train:
                print(Config().line_end, flush=True)
            self.model.model.finished_item(train)
            self.total_correct += self.correct_count
            self.total_actions += self.action_count
            if train and Config().args.save_every and (passage_index + 1) % Config().args.save_every == 0:
                self.eval_and_save()
                self.eval_index += 1
            yield (predicted_passage, evaluate_passage(predicted_passage, passage)) if evaluate else predicted_passage

        if passages:
            print("Parsed %d %ss" % (passage_index + 1, passage_word))
            if self.oracle and self.total_actions:
                print("Overall %d%% correct transitions (%d/%d) on %s" % (
                    100 * self.total_correct / self.total_actions,
                    self.total_correct, self.total_actions, mode.name))
            print("Total time: %.3fs (average time/%s: %.3fs, average tokens/s: %d)" % (
                total_duration, passage_word, total_duration / (passage_index + 1),
                total_tokens / total_duration), flush=True)

    def parse_passage(self, train):
        """
        Internal method to parse a single passage
        :param train: use oracle to train on given passages, or just parse with classifier?
        """
        if Config().args.verbose:
            print(" initial state: %s" % self.state)
        while True:
            if Config().args.check_loops:
                self.check_loop(print_oracle=train)
            true_actions = []
            if self.oracle is not None:
                try:
                    true_actions = self.oracle.get_actions(self.state)
                except (AttributeError, AssertionError) as e:
                    if train:
                        raise ParserException("Error in oracle during training") from e
            features = self.model.feature_extractor.extract_features(self.state)
            predicted_action = self.predict_action(features, true_actions)  # sets self.scores
            action = predicted_action
            correct_action = False
            if not true_actions:
                true_actions = "?"
            elif predicted_action in true_actions:
                self.correct_count += 1
                correct_action = True
            elif train:
                action = Config().random.choice(true_actions)
            if train and not (correct_action and ClassifierProperty.update_only_on_error in
                              self.model.model.get_classifier_properties()):
                best_true_action = true_actions[0] if len(true_actions) == 1 else \
                    true_actions[self.scores[[a.id for a in true_actions]].argmax()]
                self.model.model.update(features, predicted_action.id, best_true_action.id,
                                        Config().args.swap_importance if best_true_action.is_swap else 1)
            self.action_count += 1
            self.model.model.finished_step(train)
            try:
                self.state.transition(action)
            except AssertionError as e:
                raise ParserException("Invalid transition (%s): %s" % (action, e)) from e
            if Config().args.verbose:
                if self.oracle is None:
                    print(" action: %-15s %s" % (action, self.state))
                else:
                    print(" predicted: %-15s true: %-15s taken: %-15s %s" % (
                        predicted_action, "|".join(map(str, true_actions)), action, self.state))
                for line in self.state.log:
                    print(" " + line)
            if self.state.finished or train and not correct_action and Config().args.early_update:
                return  # action is FINISH

    def check_loop(self, print_oracle):
        """
        Check if the current state has already occurred, indicating a loop
        :param print_oracle: whether to print the oracle in case of an assertion error
        """
        h = hash(self.state)
        assert h not in self.state_hash_history, \
            "\n".join(["Transition loop", self.state.str("\n")] + [self.oracle.str("\n")] if print_oracle else ())
        self.state_hash_history.add(h)

    def predict_action(self, features, true_actions):
        """
        Choose action based on classifier
        :param features: extracted feature values
        :param true_actions: from the oracle, to copy orig_node if the same action is selected
        :return: valid action with maximum probability according to classifier
        """
        self.scores = self.model.model.score(features)  # Returns a NumPy array
        if Config().args.verbose >= 2:
            print(" scores: " + " ".join(("%g" % s for s in self.scores)))
        best_action = self.select_action(self.scores.argmax(), true_actions)
        if self.state.is_valid(best_action):
            return best_action
        # Usually the best action is valid, so max is enough to choose it in O(n) time
        # Otherwise, sort all the other scores to choose the best valid one in O(n lg n)
        sorted_ids = self.scores.argsort()[::-1]
        actions = (self.select_action(i, true_actions) for i in sorted_ids)
        try:
            return next(a for a in actions if self.state.is_valid(a))
        except StopIteration as e:
            raise ParserException(
                "No valid actions available\n" +
                ("True actions: %s" % true_actions if true_actions else
                 self.oracle.log if self.oracle is not None else "") +
                "\nReturned actions: %s" % [self.select_action(i) for i in sorted_ids] +
                "\nScores: %s" % self.scores) from e

    @staticmethod
    def select_action(i, true_actions=()):
        """
        Find action with the given ID in true actions (if exists) or in all actions
        :param i: ID to lookup
        :param true_actions: preferred set of actions to look in first
        :return: Action with id=i
        """
        try:
            return next(a for a in true_actions if a.id == i)
        except StopIteration:
            return Actions().all[i]

    def verify_passage(self, passage, predicted_passage, show_diff):
        """
        Compare predicted passage to true passage and die if they differ
        :param passage: true passage
        :param predicted_passage: predicted passage to compare
        :param show_diff: if passages differ, show the difference between them?
                          Depends on predicted_passage having the original node IDs annotated
                          in the "remarks" field for each node.
        """
        assert passage.equals(predicted_passage, ignore_node=self.ignore_node), \
            "Failed to produce true passage" + \
            (diffutil.diff_passages(passage, predicted_passage) if show_diff else "")
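# Standalone sketch (not part of the original sources) of the selection strategy used
# by predict/generate_descending and predict_action above: take the argmax first, and
# fall back to a full descending sort only if the top-scoring item is invalid.
import numpy as np

def best_valid(scores, values, is_valid):
    """Return the highest-scoring element of values for which is_valid holds."""
    best = values[int(scores.argmax())]
    if is_valid(best):  # O(n) fast path: the best item is usually valid
        return best
    for i in scores.argsort()[::-1]:  # O(n lg n) fallback over all items
        if is_valid(values[int(i)]):
            return values[int(i)]
    raise ValueError("No valid item available")

# Example: best_valid(np.array([.1, .9, .4]), ["SHIFT", "REDUCE", "SWAP"],
#                     lambda a: a != "REDUCE") returns "SWAP"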