Example #1
 def __init__(self, model_file=None, model_type=None, beam=1):
     self.args = Config().args
     self.state = None  # State object created at each parse
     self.oracle = None  # Oracle object created at each parse
     self.action_count = self.correct_action_count = self.total_actions = self.total_correct_actions = 0
     self.label_count = self.correct_label_count = self.total_labels = self.total_correct_labels = 0
     self.model = Model(model_type, model_file)
     self.update_only_on_error = \
         ClassifierProperty.update_only_on_error in self.model.model.get_classifier_properties()
     self.beam = beam  # Currently unused
     self.state_hash_history = None  # For loop checking
     # Used in verify_passage to optionally ignore a mismatch in linkage nodes:
     self.ignore_node = None if self.args.linkage else lambda n: n.tag == layer1.NodeTags.Linkage
     self.best_score = self.dev = self.iteration = self.eval_index = None
     self.trained = False
Example #2
 def __call__(self, config):
     config.args.vocab = self.vocab
     config.args.word_vectors = self.wordvectors
     config.args.omit_features = self.omit
     return SparseFeatureExtractor(omit_features=self.omit) if self.name == SPARSE else DenseFeatureExtractor(
         OrderedDict((p.name, p.create_from_config()) for p in Model(None, config=config).param_defs()),
         indexed=self.indexed, node_dropout=0, omit_features=self.omit)
Example #3
File: parse.py Project: ml-lab/tupa
 def __init__(self, model_file=None, model_type=None, beam=1):
     self.state = None  # State object created at each parse
     self.oracle = None  # Oracle object created at each parse
     self.scores = None  # NumPy array of action scores at each action
     self.action_count = 0
     self.correct_count = 0
     self.total_actions = 0
     self.total_correct = 0
     self.model = Model(model_type, model_file, Actions().all)
     self.beam = beam  # Currently unused
     self.state_hash_history = None  # For loop checking
     # Used in verify_passage to optionally ignore a mismatch in linkage nodes:
     self.ignore_node = None if Config().args.linkage else lambda n: n.tag == layer1.NodeTags.Linkage
     self.best_score = self.dev = self.iteration = self.eval_index = None
     self.dev_scores = []
     self.trained = False
Example #4
def test_model(model_type, formats, test_passage, iterations, omit_features, config):
    filename = "_".join(filter(None, [os.path.join("test_files", "models", "test"), model_type, omit_features]+formats))
    remove_existing(filename)
    config.update(dict(classifier=model_type, copy_shared=None, omit_features=omit_features))
    finalized = model = Model(filename, config=config)
    for i in range(iterations):
        parse(formats, model, test_passage, train=True)
        finalized = model.finalize(finished_epoch=True)
        assert not getattr(finalized.feature_extractor, "node_dropout", 0), finalized.feature_extractor.node_dropout
        parse(formats, model, test_passage, train=False)
        finalized.save()
    loaded = Model(filename, config=config)
    loaded.load()
    assert not getattr(loaded.feature_extractor, "node_dropout", 0), loaded.feature_extractor.node_dropout
    for key, param in sorted(model.feature_extractor.params.items()):
        loaded_param = loaded.feature_extractor.params[key]
        assert param == loaded_param
    assert_all_params_equal(finalized.all_params(), loaded.all_params(), decay=weight_decay(model))
Example #5
def load_model(filename):
    model = Model(filename=filename)
    model.load()
    return model
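A minimal usage sketch for this helper, reusing the model path from the test examples above (any prefix a model was saved under would work) and inspecting the loaded feature parameters the way the tests do:

import os

model = load_model(os.path.join("test_files", "models", "test"))  # illustrative path, from Example #4
for key, param in sorted(model.feature_extractor.params.items()):
    print(key, param)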
Example #6
class Parser(object):
    """
    Main class to implement transition-based UCCA parser
    """
    def __init__(self, model_file=None, model_type=None, beam=1):
        self.args = Config().args
        self.state = None  # State object created at each parse
        self.oracle = None  # Oracle object created at each parse
        self.action_count = self.correct_action_count = self.total_actions = self.total_correct_actions = 0
        self.label_count = self.correct_label_count = self.total_labels = self.total_correct_labels = 0
        self.model = Model(model_type, model_file)
        self.update_only_on_error = \
            ClassifierProperty.update_only_on_error in self.model.model.get_classifier_properties()
        self.beam = beam  # Currently unused
        self.state_hash_history = None  # For loop checking
        # Used in verify_passage to optionally ignore a mismatch in linkage nodes:
        self.ignore_node = None if self.args.linkage else lambda n: n.tag == layer1.NodeTags.Linkage
        self.best_score = self.dev = self.iteration = self.eval_index = None
        self.trained = False

    def train(self, passages=None, dev=None, iterations=1):
        """
        Train parser on given passages
        :param passages: iterable of passages to train on
        :param dev: iterable of passages to tune on
        :param iterations: number of iterations to perform
        """
        self.trained = True
        if passages:
            if ClassifierProperty.trainable_after_saving in \
                    self.model.model.get_classifier_properties():
                try:
                    self.model.load()
                except FileNotFoundError:
                    print("not found, starting from untrained model.")
            self.best_score = 0
            self.dev = dev
            for self.iteration in range(1, iterations + 1):
                self.eval_index = 0
                print("Training iteration %d of %d: " %
                      (self.iteration, iterations))
                Config().random.shuffle(passages)
                list(self.parse(passages, mode=ParseMode.train))
                yield self.eval_and_save(self.iteration == iterations,
                                         finished_epoch=True)
            print("Trained %d iterations" % iterations)
        if dev or not passages:
            self.model.load()

    def eval_and_save(self, last=False, finished_epoch=False):
        scores = None
        model = self.model
        self.model = self.model.finalize(finished_epoch=finished_epoch)
        if self.dev:
            if not self.best_score:
                self.model.save()
            print("Evaluating on dev passages")
            passage_scores = [
                s for _, s in self.parse(
                    self.dev, mode=ParseMode.dev, evaluate=True)
            ]
            scores = Scores(passage_scores)
            average_score = scores.average_f1()
            print("Average labeled F1 score on dev: %.3f" % average_score)
            print_scores(scores,
                         self.args.devscores,
                         prefix_title="iteration",
                         prefix=[self.iteration] +
                         ([self.eval_index] if self.args.save_every else []))
            if average_score >= self.best_score:
                print("Better than previous best score (%.3f)" %
                      self.best_score)
                if self.best_score:
                    self.model.save()
                self.best_score = average_score
            else:
                print("Not better than previous best score (%.3f)" %
                      self.best_score)
        elif last or self.args.save_every is not None:
            self.model.save()
        if not last:
            self.model = model  # Restore non-finalized model
            self.model.load_labels()
        return scores

    def parse(self, passages, mode=ParseMode.test, evaluate=False):
        """
        Parse given passages
        :param passages: iterable of passages to parse
        :param mode: ParseMode value.
                     If train, use oracle to train on given passages.
                     Otherwise, just parse with classifier.
        :param evaluate: whether to evaluate parsed passages with respect to given ones.
                         Only possible when given passages are annotated.
        :return: generator of parsed passages (or in train mode, the original ones),
                 or, if evaluate=True, of pairs of (Passage, Scores).
        """
        assert mode in ParseMode, "Invalid parse mode: %s" % mode
        train = (mode is ParseMode.train)
        if not train and not self.trained:
            list(self.train())
        passage_word = "sentence" if self.args.sentences else \
                       "paragraph" if self.args.paragraphs else \
                       "passage"
        self.total_actions = 0
        self.total_correct_actions = 0
        total_duration = 0
        total_tokens = 0
        passage_index = 0
        if not hasattr(passages, "__iter__"):  # Single passage given
            passages = (passages, )
        for passage_index, passage in enumerate(passages):
            labeled = any(n.outgoing or n.attrib.get(LABEL_ATTRIB)
                          for n in passage.layer(layer1.LAYER_ID).all)
            assert not train or labeled, "Cannot train on unannotated passage: %s" % passage.ID
            assert not evaluate or labeled, "Cannot evaluate on unannotated passage: %s" % passage.ID
            print("%s %-7s" % (passage_word, passage.ID),
                  end=Config().line_end,
                  flush=True)
            started = time.time()
            self.action_count = self.correct_action_count = self.label_count = self.correct_label_count = 0
            textutil.annotate(passage, verbose=self.args.verbose > 1)  # tag POS and parse dependencies
            self.state = State(passage)
            self.state_hash_history = set()
            self.oracle = Oracle(passage) if train or \
                (self.args.verbose or Config().args.use_gold_node_labels) and labeled or \
                self.args.verify else None
            failed = False
            if ClassifierProperty.require_init_features in \
                    self.model.model.get_classifier_properties():
                self.model.init_features(self.state, train)
            try:
                self.parse_passage(train)  # This is where the actual parsing takes place
            except ParserException as e:
                if train:
                    raise
                Config().log("%s %s: %s" % (passage_word, passage.ID, e))
                failed = True
            guessed = self.state.create_passage(verify=self.args.verify) \
                if not train or self.args.verify else passage
            duration = time.time() - started
            total_duration += duration
            num_tokens = len(
                set(self.state.terminals).difference(self.state.buffer))
            total_tokens += num_tokens
            if self.oracle:  # We have an oracle to verify by
                if not failed and self.args.verify:
                    self.verify_passage(guessed, passage, train)
                if self.action_count:
                    accuracy_str = "%d%% (%d/%d)" % (
                        100 * self.correct_action_count / self.action_count,
                        self.correct_action_count, self.action_count)
                    if self.label_count:
                        accuracy_str += " %d%% (%d/%d)" % (
                            100 * self.correct_label_count / self.label_count,
                            self.correct_label_count, self.label_count)
                    print("%-30s" % accuracy_str, end=Config().line_end)
            print("%0.3fs" % duration, end="")
            print("%-15s" % (" (failed)" if failed else " (%d tokens/s)" %
                             (num_tokens / duration)),
                  end="")
            print(Config().line_end, end="")
            if self.oracle:
                print(Config().line_end, flush=True)
            self.model.model.finished_item(train)
            self.total_correct_actions += self.correct_action_count
            self.total_actions += self.action_count
            self.total_correct_labels += self.correct_label_count
            self.total_labels += self.label_count
            if train and self.args.save_every and (passage_index + 1) % self.args.save_every == 0:
                self.eval_and_save()
                self.eval_index += 1
            yield (guessed, self.evaluate_passage(guessed, passage)) if evaluate else guessed

        if passages:
            print("Parsed %d %ss" % (passage_index + 1, passage_word))
            if self.oracle and self.total_actions:
                accuracy_str = "%d%% correct actions (%d/%d)" % (
                    100 * self.total_correct_actions / self.total_actions,
                    self.total_correct_actions, self.total_actions)
                if self.total_labels:
                    accuracy_str += ", %d%% correct labels (%d/%d)" % (
                        100 * self.total_correct_labels / self.total_labels,
                        self.total_correct_labels, self.total_labels)
                print("Overall %s on %s" % (accuracy_str, mode.name))
            if total_duration:
                print(
                    "Total time: %.3fs (average time/%s: %.3fs, average tokens/s: %d)"
                    % (total_duration, passage_word, total_duration /
                       (passage_index + 1), total_tokens / total_duration),
                    flush=True)

    def parse_passage(self, train):
        """
        Internal method to parse a single passage
        :param train: use oracle to train on given passages, or just parse with classifier?
        """
        if self.args.verbose > 1:
            print("  initial state: %s" % self.state)
        while True:
            if self.args.check_loops:
                self.check_loop()
            features = self.model.feature_extractor.extract_features(self.state)
            true_actions = self.get_true_actions(train)
            action, predicted_action = self.choose_action(features, train, true_actions)
            try:
                self.state.transition(action)
            except AssertionError as e:
                raise ParserException("Invalid transition: %s %s" %
                                      (action, self.state)) from e
            if self.args.verbose > 1:
                if self.oracle:
                    print("  predicted: %-15s true: %-15s taken: %-15s %s" %
                          (predicted_action, "|".join(
                              map(str,
                                  true_actions.values())), action, self.state))
                else:
                    print("  action: %-15s %s" % (action, self.state))
            if self.state.need_label:  # Label action that requires a choice of label
                true_label = self.get_true_label(action.orig_node)
                label, predicted_label = self.choose_label(
                    features, train, true_label)
                self.state.label_node(label)
                if self.args.verbose > 1:
                    if self.oracle and not Config().args.use_gold_node_labels:
                        print("  predicted label: %-15s true label: %-15s" %
                              (predicted_label, true_label))
                    else:
                        print("  label: %-15s" % label)
            self.model.model.finished_step(train)
            if self.args.verbose > 1:
                for line in self.state.log:
                    print("    " + line)
            if self.state.finished:
                return  # action is Finish (or early update is triggered)

    def get_true_actions(self, train):
        true_actions = {}
        if self.oracle:
            try:
                true_actions = self.oracle.get_actions(self.state,
                                                       self.model.actions,
                                                       create=train)
            except (AttributeError, AssertionError) as e:
                if train:
                    raise ParserException(
                        "Error in oracle during training") from e
        return true_actions

    def choose_action(self, features, train, true_actions):
        scores = self.model.model.score(features, axis=ACTION_AXIS)  # Returns NumPy array
        if self.args.verbose > 2:
            print("  action scores: " +
                  ",".join(("%s: %g" % x
                            for x in zip(self.model.actions.all, scores))))
        try:
            predicted_action = self.predict(scores, self.model.actions.all,
                                            self.state.is_valid_action)
        except StopIteration as e:
            raise ParserException(
                "No valid action available\n%s" %
                (self.oracle.log if self.oracle else "")) from e
        action = true_actions.get(predicted_action.id)
        is_correct = (action is not None)
        if is_correct:
            self.correct_action_count += 1
        else:
            action = Config().random.choice(list(true_actions.values())) \
                if train else predicted_action
        if train and not (is_correct and self.update_only_on_error):
            best_action = self.predict(scores[list(true_actions.keys())],
                                       list(true_actions.values()))
            self.model.model.update(features,
                                    axis=ACTION_AXIS,
                                    pred=predicted_action.id,
                                    true=best_action.id,
                                    importance=self.args.swap_importance
                                    if best_action.is_swap else 1)
        if train and not is_correct and self.args.early_update:
            self.state.finished = True
        self.action_count += 1
        return action, predicted_action

    def get_true_label(self, node):
        true_label = None
        if self.oracle:
            if node is not None:
                true_label = node.attrib.get(LABEL_ATTRIB)
            if true_label is not None:
                true_label, _, _ = true_label.partition(LABEL_SEPARATOR)
                if not self.state.is_valid_label(true_label):
                    raise ParserException("True label is invalid: %s %s" %
                                          (true_label, self.state))
        return true_label

    def choose_label(self, features, train, true_label):
        true_id = self.model.labels[true_label] if self.oracle else None  # Needs to happen before score()
        if Config().args.use_gold_node_labels:
            return true_label, true_label
        scores = self.model.model.score(features, axis=LABEL_AXIS)
        if self.args.verbose > 2:
            print("  label scores: " +
                  ",".join(("%s: %g" % x
                            for x in zip(self.model.labels.all, scores))))
        label = predicted_label = self.predict(scores, self.model.labels.all,
                                               self.state.is_valid_label)
        if self.oracle:
            is_correct = (label == true_label)
            if is_correct:
                self.correct_label_count += 1
            if train and not (is_correct and self.update_only_on_error):
                self.model.model.update(features,
                                        axis=LABEL_AXIS,
                                        pred=self.model.labels[label],
                                        true=true_id)
                label = true_label
        self.label_count += 1
        return label, predicted_label

    def predict(self, scores, values, is_valid=None):
        """
        Choose action/label based on classifier
        Usually the best action/label is valid, so max is enough to choose it in O(n) time
        Otherwise, sorts all the other scores to choose the best valid one in O(n lg n)
        :return: valid action/label with maximum probability according to classifier
        """
        return next(
            filter(is_valid,
                   (values[i] for i in self.generate_descending(scores))))

    @staticmethod
    def generate_descending(scores):
        yield scores.argmax()
        # Contains the max, but otherwise items might be missed (different order)
        yield from scores.argsort()[::-1]

    def check_loop(self):
        """
        Check if the current state has already occurred, indicating a loop
        """
        h = hash(self.state)
        assert h not in self.state_hash_history, \
            "\n".join(["Transition loop", self.state.str("\n")] +
                      ([self.oracle.str("\n")] if self.oracle else []))  # Parenthesized so the message is kept even without an oracle
        self.state_hash_history.add(h)

    def evaluate_passage(self, guessed, ref):
        converters = CONVERTERS.get(ref.extra.get("format"))  # returns (input converter, output converter) tuple
        score = EVALUATORS.get(ref.extra.get("format"), evaluation).evaluate(
            guessed, ref,
            converter=converters and converters[1],  # converter output is list of lines
            verbose=guessed and self.args.verbose > 2,
            constructions=self.args.constructions)
        print("F1=%.3f" % score.average_f1(), flush=True)
        return score

    def verify_passage(self, guessed, ref, show_diff):
        """
        Compare predicted passage to true passage and raise an exception if they differ
        :param ref: true passage
        :param guessed: predicted passage to compare
        :param show_diff: if passages differ, show the difference between them?
                          Depends on guessed having the original node IDs annotated in the "remarks" field for each node
        """
        assert ref.equals(guessed, ignore_node=self.ignore_node), \
            "Failed to produce true passage" + (diffutil.diff_passages(ref, guessed) if show_diff else "")
Example #7
File: parse.py Project: ml-lab/tupa
class Parser(object):
    """
    Main class to implement transition-based UCCA parser
    """
    def __init__(self, model_file=None, model_type=None, beam=1):
        self.state = None  # State object created at each parse
        self.oracle = None  # Oracle object created at each parse
        self.scores = None  # NumPy array of action scores at each action
        self.action_count = 0
        self.correct_count = 0
        self.total_actions = 0
        self.total_correct = 0
        self.model = Model(model_type, model_file, Actions().all)
        self.beam = beam  # Currently unused
        self.state_hash_history = None  # For loop checking
        # Used in verify_passage to optionally ignore a mismatch in linkage nodes:
        self.ignore_node = None if Config().args.linkage else lambda n: n.tag == layer1.NodeTags.Linkage
        self.best_score = self.dev = self.iteration = self.eval_index = None
        self.dev_scores = []
        self.trained = False

    def train(self, passages=None, dev=None, iterations=1):
        """
        Train parser on given passages
        :param passages: iterable of passages to train on
        :param dev: iterable of passages to tune on
        :param iterations: number of iterations to perform
        :return: trained model
        """
        self.trained = True
        if passages:
            if ClassifierProperty.trainable_after_saving in \
                    self.model.model.get_classifier_properties():
                try:
                    self.model.load()
                except FileNotFoundError:
                    print("not found, starting from untrained model.")
            self.best_score = 0
            self.dev = dev
            if Config().args.devscores:
                with open(Config().args.devscores, "w") as f:
                    print(",".join(["iteration"] +
                                   evaluation.Scores.field_titles(
                                       Config().args.constructions)),
                          file=f)
            for self.iteration in range(1, iterations + 1):
                self.eval_index = 0
                print("Training iteration %d of %d: " %
                      (self.iteration, iterations))
                list(self.parse(passages, mode=ParseMode.train))
                self.eval_and_save(self.iteration == iterations,
                                   finished_epoch=True)
                Config().random.shuffle(passages)
            print("Trained %d iterations" % iterations)
        if dev or not passages:
            self.model.load()

    def eval_and_save(self, last=False, finished_epoch=False):
        model = self.model
        self.model = self.model.finalize(finished_epoch=finished_epoch)
        if self.dev:
            print("Evaluating on dev passages")
            scores = [
                s for _, s in self.parse(
                    self.dev, mode=ParseMode.dev, evaluate=True)
            ]
            scores = evaluation.Scores.aggregate(scores)
            self.dev_scores.append(scores)
            score = scores.average_f1()
            print("Average labeled F1 score on dev: %.3f" % score)
            if Config().args.devscores:
                prefix = [self.iteration]
                if Config().args.save_every:
                    prefix.append(self.eval_index)
                with open(Config().args.devscores, "a") as f:
                    print(",".join([".".join(map(str, prefix))] +
                                   scores.fields()),
                          file=f)
            if score >= self.best_score:
                print("Better than previous best score (%.3f)" %
                      self.best_score)
                self.best_score = score
                self.model.save()
            else:
                print("Not better than previous best score (%.3f)" %
                      self.best_score)
        elif last or Config().args.save_every is not None:
            self.model.save()
        if not last:
            self.model = model  # Restore non-finalized model

    def parse(self, passages, mode=ParseMode.test, evaluate=False):
        """
        Parse given passages
        :param passages: iterable of passages to parse
        :param mode: ParseMode value.
                     If train, use oracle to train on given passages.
                     Otherwise, just parse with classifier.
        :param evaluate: whether to evaluate parsed passages with respect to given ones.
                         Only possible when given passages are annotated.
        :return: generator of parsed passages (or in train mode, the original ones),
                 or, if evaluate=True, of pairs of (Passage, Scores).
        """
        assert mode in ParseMode, "Invalid parse mode: %s" % mode
        train = (mode is ParseMode.train)
        if not train and not self.trained:
            self.train()
        passage_word = "sentence" if Config().args.sentences else \
                       "paragraph" if Config().args.paragraphs else \
                       "passage"
        self.total_actions = 0
        self.total_correct = 0
        total_duration = 0
        total_tokens = 0
        passage_index = 0
        if not hasattr(passages, "__iter__"):  # Single passage given
            passages = (passages, )
        for passage_index, passage in enumerate(passages):
            l0 = passage.layer(layer0.LAYER_ID)
            num_tokens = len(l0.all)
            l1 = passage.layer(layer1.LAYER_ID)
            labeled = len(l1.all) > 1
            assert not train or labeled, "Cannot train on unannotated passage: %s" % passage.ID
            assert not evaluate or labeled, "Cannot evaluate on unannotated passage: %s" % passage.ID
            print("%s %-7s" % (passage_word, passage.ID),
                  end=Config().line_end,
                  flush=True)
            started = time.time()
            self.action_count = 0
            self.correct_count = 0
            textutil.annotate(passage, verbose=Config().args.verbose)  # tag POS and parse dependencies
            self.state = State(passage)
            self.state_hash_history = set()
            self.oracle = Oracle(passage) if train else None
            failed = False
            if ClassifierProperty.require_init_features in \
                    self.model.model.get_classifier_properties():
                self.model.init_features(self.state, train)
            try:
                self.parse_passage(train)  # This is where the actual parsing takes place
            except ParserException as e:
                if train:
                    raise
                Config().log("%s %s: %s" % (passage_word, passage.ID, e))
                failed = True
            predicted_passage = self.state.create_passage(assert_proper=Config().args.verify) \
                if not train or Config().args.verify else passage
            duration = time.time() - started
            total_duration += duration
            num_tokens -= len(self.state.buffer)
            total_tokens += num_tokens
            if train:  # We have an oracle to verify by
                if not failed and Config().args.verify:
                    self.verify_passage(passage, predicted_passage, train)
                if self.action_count:
                    print("%-16s" %
                          ("%d%% (%d/%d)" %
                           (100 * self.correct_count / self.action_count,
                            self.correct_count, self.action_count)),
                          end=Config().line_end)
            print("%0.3fs" % duration, end="")
            print("%-15s" % (" (failed)" if failed else " (%d tokens/s)" %
                             (num_tokens / duration)),
                  end="")
            print(Config().line_end, end="")
            if train:
                print(Config().line_end, flush=True)
            self.model.model.finished_item(train)
            self.total_correct += self.correct_count
            self.total_actions += self.action_count
            if train and Config().args.save_every and (passage_index + 1) % Config().args.save_every == 0:
                self.eval_and_save()
                self.eval_index += 1
            yield (predicted_passage, evaluate_passage(predicted_passage, passage)) \
                if evaluate else predicted_passage

        if passages:
            print("Parsed %d %ss" % (passage_index + 1, passage_word))
            if self.oracle and self.total_actions:
                print("Overall %d%% correct transitions (%d/%d) on %s" %
                      (100 * self.total_correct / self.total_actions,
                       self.total_correct, self.total_actions, mode.name))
            print(
                "Total time: %.3fs (average time/%s: %.3fs, average tokens/s: %d)"
                % (total_duration, passage_word, total_duration /
                   (passage_index + 1), total_tokens / total_duration),
                flush=True)

    def parse_passage(self, train):
        """
        Internal method to parse a single passage
        :param train: use oracle to train on given passages, or just parse with classifier?
        """
        if Config().args.verbose:
            print("  initial state: %s" % self.state)
        while True:
            if Config().args.check_loops:
                self.check_loop(print_oracle=train)

            true_actions = []
            if self.oracle is not None:
                try:
                    true_actions = self.oracle.get_actions(self.state)
                except (AttributeError, AssertionError) as e:
                    if train:
                        raise ParserException(
                            "Error in oracle during training") from e

            features = self.model.feature_extractor.extract_features(self.state)
            predicted_action = self.predict_action(features, true_actions)  # sets self.scores
            action = predicted_action
            correct_action = False
            if not true_actions:
                true_actions = "?"
            elif predicted_action in true_actions:
                self.correct_count += 1
                correct_action = True
            elif train:
                action = Config().random.choice(true_actions)
            if train and not (correct_action
                              and ClassifierProperty.update_only_on_error
                              in self.model.model.get_classifier_properties()):
                best_true_action = true_actions[0] if len(true_actions) == 1 else \
                    true_actions[self.scores[[a.id for a in true_actions]].argmax()]
                self.model.model.update(
                    features, predicted_action.id, best_true_action.id,
                    Config().args.swap_importance
                    if best_true_action.is_swap else 1)
            self.action_count += 1
            self.model.model.finished_step(train)
            try:
                self.state.transition(action)
            except AssertionError as e:
                raise ParserException("Invalid transition (%s): %s" %
                                      (action, e)) from e
            if Config().args.verbose:
                if self.oracle is None:
                    print("  action: %-15s %s" % (action, self.state))
                else:
                    print("  predicted: %-15s true: %-15s taken: %-15s %s" %
                          (predicted_action, "|".join(map(
                              str, true_actions)), action, self.state))
                for line in self.state.log:
                    print("    " + line)
            if self.state.finished or train and not correct_action and Config().args.early_update:
                return  # action is FINISH

    def check_loop(self, print_oracle):
        """
        Check if the current state has already occurred, indicating a loop
        :param print_oracle: whether to print the oracle in case of an assertion error
        """
        h = hash(self.state)
        assert h not in self.state_hash_history, \
            "\n".join(["Transition loop", self.state.str("\n")] +
                      ([self.oracle.str("\n")] if print_oracle else []))  # Parenthesized so the message is kept when print_oracle is False
        self.state_hash_history.add(h)

    def predict_action(self, features, true_actions):
        """
        Choose action based on classifier
        :param features: extracted feature values
        :param true_actions: from the oracle, to copy orig_node if the same action is selected
        :return: valid action with maximum probability according to classifier
        """
        self.scores = self.model.model.score(features)  # Returns a NumPy array
        if Config().args.verbose >= 2:
            print("  scores: " + " ".join(("%g" % s for s in self.scores)))
        best_action = self.select_action(self.scores.argmax(), true_actions)
        if self.state.is_valid(best_action):
            return best_action
        # Usually the best action is valid, so max is enough to choose it in O(n) time
        # Otherwise, sort all the other scores to choose the best valid one in O(n lg n)
        sorted_ids = self.scores.argsort()[::-1]
        actions = (self.select_action(i, true_actions) for i in sorted_ids)
        try:
            return next(a for a in actions if self.state.is_valid(a))
        except StopIteration as e:
            raise ParserException(
                "No valid actions available\n" +
                ("True actions: %s" % true_actions if true_actions else self.
                 oracle.log if self.oracle is not None else "") +
                "\nReturned actions: %s" %
                [self.select_action(i)
                 for i in sorted_ids] + "\nScores: %s" % self.scores) from e

    @staticmethod
    def select_action(i, true_actions=()):
        """
        Find action with the given ID in true actions (if exists) or in all actions
        :param i: ID to lookup
        :param true_actions: preferred set of actions to look in first
        :return: Action with id=i
        """
        try:
            return next(a for a in true_actions if a.id == i)
        except StopIteration:
            return Actions().all[i]

    def verify_passage(self, passage, predicted_passage, show_diff):
        """
        Compare predicted passage to true passage and raise an exception if they differ
        :param passage: true passage
        :param predicted_passage: predicted passage to compare
        :param show_diff: if passages differ, show the difference between them?
                          Depends on predicted_passage having the original node IDs annotated
                          in the "remarks" field for each node.
        """
        assert passage.equals(predicted_passage, ignore_node=self.ignore_node),\
            "Failed to produce true passage" + \
            (diffutil.diff_passages(passage, predicted_passage) if show_diff else "")
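Both parser versions pick an action the same way: take the classifier's argmax, and fall back to a full descending sort only when the top action is invalid (see predict/generate_descending in Example #6 and predict_action in Example #7). A standalone NumPy sketch of that pattern, with `is_valid` standing in for the state's validity check:

import numpy as np

def best_valid(scores, is_valid):
    # The argmax is usually valid, so the common case is O(n);
    # the sorted fallback costs O(n lg n) and revisits the argmax,
    # which is harmless because it was already rejected.
    best = int(scores.argmax())
    if is_valid(best):
        return best
    for i in scores.argsort()[::-1]:
        if is_valid(int(i)):
            return int(i)
    raise ValueError("No valid actions available")

print(best_valid(np.array([0.3, 0.1, 0.9]), lambda i: i != 2))  # -> 0: the argmax (2) is invalid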