Example #1
def _test_features(config, feature_extractor_creator, filename, write_features):
    feature_extractor = feature_extractor_creator(config)
    passage = load_passage(filename, annotate=feature_extractor_creator.annotated)
    textutil.annotate(passage, as_array=True, as_extra=False, vocab=config.vocab())
    config.set_format(passage.extra.get("format") or "ucca")
    oracle = Oracle(passage)
    state = State(passage)
    actions = Actions()
    for key, param in feature_extractor.params.items():
        if not param.numeric:
            param.dropout = 0
            feature_extractor.init_param(key)
    features = [feature_extractor.init_features(state)]
    while True:
        extract_features(feature_extractor, state, features)
        action = min(oracle.get_actions(state, actions).values(), key=str)
        state.transition(action)
        if state.need_label:
            extract_features(feature_extractor, state, features)
            label, _ = oracle.get_label(state, action)
            state.label_node(label)
        if state.finished:
            break
    features = ["%s %s\n" % i for f in features if f for i in (sorted(f.items()) + [("", "")])]
    compare_file = os.path.join("test_files", "features", "-".join((basename(filename), str(feature_extractor_creator)))
                                + ".txt")
    if write_features:
        with open(compare_file, "w", encoding="utf-8") as f:
            f.writelines(features)
    with open(compare_file, encoding="utf-8") as f:
        assert f.readlines() == features, compare_file
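
A minimal sketch of how a pytest suite might drive _test_features; the fixture names, the sample file path, and the parametrization are assumptions for illustration, not part of the original code:

import pytest

# Hypothetical wiring: the real suite defines its own config/creator fixtures.
@pytest.mark.parametrize("filename", ["test_files/standard3.xml"])  # assumed sample passage
def test_features_roundtrip(config, feature_extractor_creator, filename):
    # Pass write_features=True once to (re)generate the expected-output file,
    # then False so the assertion actually compares against it.
    _test_features(config, feature_extractor_creator, filename, write_features=False)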
Example #2
def gen_actions(passage):
    oracle = Oracle(passage)
    state = State(passage)
    actions = Actions()
    while True:
        action = min(oracle.get_actions(state, actions).values(), key=str)
        state.transition(action)
        s = str(action)
        if state.need_label:
            label, _ = oracle.get_label(state, action)
            state.label_node(label)
            s += " " + str(label)
        yield s
        if state.finished:
            break
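
For example, the generator can be used to dump the gold-standard action sequence for a passage. The load_passage helper from Example #1 is assumed to be callable with just a path, and the file path is illustrative:

def print_oracle(filename):
    # One action per line; when the action labels a node, gen_actions appends
    # the chosen label after a space.
    passage = load_passage(filename)
    for action_str in gen_actions(passage):
        print(action_str)

# print_oracle("test_files/standard3.xml")  # illustrative path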
Example #3
class PassageParser(AbstractParser):
    """ Parser for a single passage, has a state and optionally an oracle """
    def __init__(self, passage, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.passage = self.out = passage
        self.format = self.passage.extra.get("format") if self.training or self.evaluation else \
            sorted(set.intersection(*map(set, filter(None, (self.model.formats, self.config.args.formats)))) or
                   self.model.formats)[0]
        if self.training and self.config.args.verify:
            errors = list(validate(self.passage))
            assert not errors, errors
        self.in_format = self.format or "ucca"
        self.out_format = "ucca" if self.format in (None, "text") else self.format
        if self.config.args.use_bert and self.config.args.bert_multilingual is not None:
            self.lang = self.passage.attrib.get("lang")
            assert self.lang, "Attribute 'lang' is required per passage when using multilingual BERT"
        else:
            self.lang = self.passage.attrib.get("lang", self.config.args.lang)
        # Used in verify_passage to optionally ignore a mismatch in linkage nodes:
        self.ignore_node = None if self.config.args.linkage else lambda n: n.tag == layer1.NodeTags.Linkage
        self.state_hash_history = set()
        self.state = self.oracle = self.eval_type = None

    def init(self):
        self.config.set_format(self.in_format)
        WIKIFIER.enabled = self.config.args.wikification
        self.state = State(self.passage)
        # Passage is considered labeled if there are any edges or node labels in it
        edges, node_labels = map(any, zip(*[(n.outgoing, n.attrib.get(LABEL_ATTRIB))
                                            for n in self.passage.layer(layer1.LAYER_ID).all]))
        self.oracle = Oracle(self.passage) if self.training or self.config.args.verify or (
                (self.config.args.verbose > 1 or self.config.args.use_gold_node_labels or self.config.args.action_stats)
                and (edges or node_labels)) else None
        for model in self.models:
            model.init_model(self.config.format, lang=self.lang if self.config.args.multilingual else None)
            if ClassifierProperty.require_init_features in model.classifier_properties:
                model.init_features(self.state, self.training)

    def parse(self, display=True, write=False, accuracies=None):
        self.init()
        passage_id = self.passage.ID
        try:
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                executor.submit(self.parse_internal).result(self.config.args.timeout)
            status = "(%d tokens/s)" % self.tokens_per_second()
        except ParserException as e:
            if self.training:
                raise
            self.config.log("%s %s: %s" % (self.config.passage_word, passage_id, e))
            status = "(failed)"
        except concurrent.futures.TimeoutError:
            self.config.log("%s %s: timeout (%fs)" % (self.config.passage_word, passage_id, self.config.args.timeout))
            status = "(timeout)"
        return self.finish(status, display=display, write=write, accuracies=accuracies)

    def parse_internal(self):
        """
        Internal method to parse a single passage.
        If training, use oracle to train on given passages. Otherwise just parse with classifier.
        """
        self.config.print("  initial state: %s" % self.state)
        while True:
            if self.config.args.check_loops:
                self.check_loop()
            self.label_node()  # In case root node needs labeling
            true_actions = self.get_true_actions()
            action, predicted_action = self.choose(true_actions)
            self.state.transition(action)
            need_label, label, predicted_label, true_label = self.label_node(action)
            if self.config.args.action_stats:
                try:
                    with open(self.config.args.action_stats, "a") as f:
                        print(",".join(map(str, [predicted_action, action] + list(true_actions.values()))), file=f)
                except OSError:
                    pass
            self.config.print(lambda: "\n".join(["  predicted: %-15s true: %-15s taken: %-15s %s" % (
                predicted_action, "|".join(map(str, true_actions.values())), action, self.state) if self.oracle else
                                          "  action: %-15s %s" % (action, self.state)] + (
                ["  predicted label: %-9s true label: %s" % (predicted_label, true_label) if self.oracle and not
                 self.config.args.use_gold_node_labels else "  label: %s" % label] if need_label else []) + [
                "    " + l for l in self.state.log]))
            if self.state.finished:
                return  # action is Finish (or early update is triggered)

    def get_true_actions(self):
        true_actions = {}
        if self.oracle:
            try:
                true_actions = self.oracle.get_actions(self.state, self.model.actions, create=self.training)
            except (AttributeError, AssertionError) as e:
                if self.training:
                    raise ParserException("Error in getting action from oracle during training") from e
        return true_actions

    def get_true_label(self, node):
        try:
            return self.oracle.get_label(self.state, node) if self.oracle else (None, None)
        except AssertionError as e:
            if self.training:
                raise ParserException("Error in getting label from oracle during training") from e
            return None, None

    def label_node(self, action=None):
        true_label = label = predicted_label = None
        need_label = self.state.need_label  # Label action that requires a choice of label
        if need_label:
            true_label, raw_true_label = self.get_true_label(action or need_label)
            label, predicted_label = self.choose(true_label, NODE_LABEL_KEY, "node label")
            self.state.label_node(raw_true_label if label == true_label else label)
        return need_label, label, predicted_label, true_label

    def choose(self, true, axis=None, name="action"):
        if axis is None:
            axis = self.model.axis
        elif axis == NODE_LABEL_KEY and self.config.args.use_gold_node_labels:
            return true, true
        labels = self.model.classifier.labels[axis]
        if axis == NODE_LABEL_KEY:
            true_keys = (labels[true],) if self.oracle else ()  # Must be before score()
            is_valid = self.state.is_valid_label
        else:
            true_keys = None
            is_valid = self.state.is_valid_action
        scores, features = self.model.score(self.state, axis)
        for model in self.models[1:]:  # Ensemble if given more than one model; align label order and add scores
            label_scores = dict(zip(model.classifier.labels[axis].all, model.score(self.state, axis)[0]))
            scores += [label_scores.get(a, 0) for a in labels.all]  # Product of Experts, assuming log(softmax)
        self.config.print(lambda: "  %s scores: %s" % (name, tuple(zip(labels.all, scores))), level=4)
        try:
            label = pred = self.predict(scores, labels.all, is_valid)
        except StopIteration as e:
            raise ParserException("No valid %s available\n%s" % (name, self.oracle.log if self.oracle else "")) from e
        label, is_correct, true_keys, true_values = self.correct(axis, label, pred, scores, true, true_keys)
        if self.training:
            if not (is_correct and ClassifierProperty.update_only_on_error in self.model.classifier_properties):
                assert not self.model.is_finalized, "Updating finalized model"
                self.model.classifier.update(
                    features, axis=axis, true=true_keys, pred=labels[pred] if axis == NODE_LABEL_KEY else pred.id,
                    importance=[self.config.args.swap_importance if a.is_swap else 1 for a in true_values] or None)
            if not is_correct and self.config.args.early_update:
                self.state.finished = True
        for model in self.models:
            model.classifier.finished_step(self.training)
            if axis != NODE_LABEL_KEY:
                model.classifier.transition(label, axis=axis)
        return label, pred

    def correct(self, axis, label, pred, scores, true, true_keys):
        true_values = is_correct = ()
        if axis == NODE_LABEL_KEY:
            if self.oracle:
                is_correct = (label == true)
                if is_correct:
                    self.correct_label_count += 1
                elif self.training:
                    label = true
            self.label_count += 1
        else:  # action
            true_keys, true_values = map(list, zip(*true.items())) if true else (None, None)
            label = true.get(pred.id)
            is_correct = (label is not None)
            if is_correct:
                self.correct_action_count += 1
            else:
                label = true_values[scores[true_keys].argmax()] if self.training else pred
            self.action_count += 1
        return label, is_correct, true_keys, true_values

    @staticmethod
    def predict(scores, values, is_valid=None):
        """
        Choose action/label based on classifier
        Usually the best action/label is valid, so max is enough to choose it in O(n) time
        Otherwise, sorts all the other scores to choose the best valid one in O(n lg n)
        :return: valid action/label with maximum probability according to classifier
        """
        return next(filter(is_valid, (values[i] for i in PassageParser.generate_descending(scores))))

    @staticmethod
    def generate_descending(scores):
        yield scores.argmax()
        yield from scores.argsort()[::-1]  # Contains the max, but otherwise items might be missed (different order)

    def finish(self, status, display=True, write=False, accuracies=None):
        self.model.classifier.finished_item(self.training)
        for model in self.models[1:]:
            model.classifier.finished_item(renew=False)  # So that dynet.renew_cg happens only once
        if not self.training or self.config.args.verify:
            self.out = self.state.create_passage(verify=self.config.args.verify, format=self.out_format)
        if write:
            for out_format in self.config.args.formats or [self.out_format]:
                if self.config.args.normalize and out_format == "ucca":
                    normalize(self.out)
                ioutil.write_passage(self.out, output_format=out_format, binary=out_format == "pickle",
                                     outdir=self.config.args.outdir, prefix=self.config.args.prefix,
                                     converter=get_output_converter(out_format), verbose=self.config.args.verbose,
                                     append=self.config.args.join, basename=self.config.args.join)
        if self.oracle and self.config.args.verify:
            self.verify(self.out, self.passage)
        ret = (self.out,)
        if self.evaluation:
            ret += (self.evaluate(self.evaluation),)
            status = "%-14s %s F1=%.3f" % (status, self.eval_type, self.f1)
        if display:
            self.config.print("%s%.3fs %s" % (self.accuracy_str, self.duration, status), level=1)
        if accuracies is not None:
            accuracies[self.passage.ID] = self.correct_action_count / self.action_count if self.action_count else 0
        return ret

    @property
    def accuracy_str(self):
        if self.oracle and self.action_count:
            accuracy_str = "a=%-14s" % percents_str(self.correct_action_count, self.action_count)
            if self.label_count:
                accuracy_str += " l=%-14s" % percents_str(self.correct_label_count, self.label_count)
            return "%-33s" % accuracy_str
        return ""

    def evaluate(self, mode=ParseMode.test):
        if self.format:
            self.config.print("Converting to %s and evaluating..." % self.format)
        self.eval_type = UNLABELED if self.config.is_unlabeled(self.in_format) else LABELED
        evaluator = EVALUATORS.get(self.format, evaluate_ucca)
        score = evaluator(self.out, self.passage, converter=get_output_converter(self.format),
                          verbose=self.out and self.config.args.verbose > 3,
                          constructions=self.config.args.constructions,
                          eval_types=(self.eval_type,) if mode is ParseMode.dev else (LABELED, UNLABELED))
        self.f1 = average_f1(score, self.eval_type)
        score.lang = self.lang
        return score

    def check_loop(self):
        """
        Check if the current state has already occurred, indicating a loop
        """
        h = hash(self.state)
        assert h not in self.state_hash_history, \
            "\n".join(["Transition loop", self.state.str("\n")] + [self.oracle.str("\n")] if self.oracle else ())
        self.state_hash_history.add(h)

    def verify(self, guessed, ref):
        """
        Compare predicted passage to true passage and raise an exception if they differ
        :param ref: true passage
        :param guessed: predicted passage to compare
        """
        assert ref.equals(guessed, ignore_node=self.ignore_node), \
            "Failed to produce true passage" + (diffutil.diff_passages(ref, guessed) if self.training else "")

    @property
    def num_tokens(self):
        return len(set(self.state.terminals).difference(self.state.buffer))  # To count even incomplete parses

    @num_tokens.setter
    def num_tokens(self, _):
        pass
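
The predict/generate_descending pair above is the part most worth isolating: try argmax first (O(n)), and only fall back to a full descending sort (O(n lg n)) when the top-scoring item is invalid. A self-contained NumPy sketch of the same idea:

import numpy as np

def predict_valid(scores, values, is_valid=None):
    """Return the highest-scoring value passing is_valid; raise StopIteration if none do."""
    scores = np.asarray(scores)
    def descending():
        yield scores.argmax()              # O(n) fast path: the best item is usually valid
        yield from scores.argsort()[::-1]  # O(n lg n) fallback; re-yields the max, harmlessly
    return next(filter(is_valid, (values[i] for i in descending())))

print(predict_valid([0.1, 0.7, 0.2], ["A", "B", "C"], lambda v: v != "B"))  # -> C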
Example #4
class PassageParser(AbstractParser):
    """ Parser for a single passage, has a state and optionally an oracle """
    def __init__(self, passage, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.passage = self.out = passage
        self.format = self.passage.extra.get("format") if self.training or self.evaluation else \
            sorted(set.intersection(*map(set, filter(None, (self.model.formats, self.config.args.formats)))) or
                   self.model.formats)[0]
        if self.training and self.config.args.verify:
            errors = list(validate(self.passage))
            assert not errors, errors
        self.in_format = self.format or "ucca"
        self.out_format = "ucca" if self.format in (None, "text") else self.format
        self.lang = self.passage.attrib.get("lang", self.config.args.lang)
        # Used in verify_passage to optionally ignore a mismatch in linkage nodes:
        self.ignore_node = None if self.config.args.linkage else lambda n: n.tag == layer1.NodeTags.Linkage
        self.state_hash_history = set()
        self.state = self.oracle = self.eval_type = None

    def init(self):
        self.config.set_format(self.in_format)
        WIKIFIER.enabled = self.config.args.wikification
        self.state = State(self.passage)
        # Passage is considered labeled if there are any edges or node labels in it
        edges, node_labels = map(any, zip(*[(n.outgoing, n.attrib.get(LABEL_ATTRIB))
                                            for n in self.passage.layer(layer1.LAYER_ID).all]))
        self.oracle = Oracle(self.passage) if self.training or self.config.args.verify or (
                (self.config.args.verbose > 1 or self.config.args.use_gold_node_labels or self.config.args.action_stats)
                and (edges or node_labels)) else None
        for model in self.models:
            model.init_model(self.config.format, lang=self.lang if self.config.args.multilingual else None)
            if ClassifierProperty.require_init_features in model.classifier_properties:
                model.init_features(self.state, self.training)

    def parse(self, display=True, write=False, accuracies=None):
        self.init()
        passage_id = self.passage.ID
        try:
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                executor.submit(self.parse_internal).result(self.config.args.timeout)
            status = "(%d tokens/s)" % self.tokens_per_second()
        except ParserException as e:
            if self.training:
                raise
            self.config.log("%s %s: %s" % (self.config.passage_word, passage_id, e))
            status = "(failed)"
        except concurrent.futures.TimeoutError:
            self.config.log("%s %s: timeout (%fs)" % (self.config.passage_word, passage_id, self.config.args.timeout))
            status = "(timeout)"
        return self.finish(status, display=display, write=write, accuracies=accuracies)

    def parse_internal(self):
        """
        Internal method to parse a single passage.
        If training, use oracle to train on given passages. Otherwise just parse with classifier.
        """
        self.config.print("  initial state: %s" % self.state)
        while True:
            if self.config.args.check_loops:
                self.check_loop()
            self.label_node()  # In case root node needs labeling
            true_actions = self.get_true_actions()
            action, predicted_action = self.choose(true_actions)
            self.state.transition(action)
            need_label, label, predicted_label, true_label = self.label_node(action)
            if self.config.args.action_stats:
                try:
                    with open(self.config.args.action_stats, "a") as f:
                        print(",".join(map(str, [predicted_action, action] + list(true_actions.values()))), file=f)
                except OSError:
                    pass
            self.config.print(lambda: "\n".join(["  predicted: %-15s true: %-15s taken: %-15s %s" % (
                predicted_action, "|".join(map(str, true_actions.values())), action, self.state) if self.oracle else
                                          "  action: %-15s %s" % (action, self.state)] + (
                ["  predicted label: %-9s true label: %s" % (predicted_label, true_label) if self.oracle and not
                 self.config.args.use_gold_node_labels else "  label: %s" % label] if need_label else []) + [
                "    " + l for l in self.state.log]))
            if self.state.finished:
                return  # action is Finish (or early update is triggered)

    def get_true_actions(self):
        true_actions = {}
        if self.oracle:
            try:
                true_actions = self.oracle.get_actions(self.state, self.model.actions, create=self.training)
            except (AttributeError, AssertionError) as e:
                if self.training:
                    raise ParserException("Error in getting action from oracle during training") from e
        return true_actions

    def get_true_label(self, node):
        try:
            return self.oracle.get_label(self.state, node) if self.oracle else (None, None)
        except AssertionError as e:
            if self.training:
                raise ParserException("Error in getting label from oracle during training") from e
            return None, None

    def label_node(self, action=None):
        true_label = label = predicted_label = None
        need_label = self.state.need_label  # Label action that requires a choice of label
        if need_label:
            true_label, raw_true_label = self.get_true_label(action or need_label)
            label, predicted_label = self.choose(true_label, NODE_LABEL_KEY, "node label")
            self.state.label_node(raw_true_label if label == true_label else label)
        return need_label, label, predicted_label, true_label

    def choose(self, true, axis=None, name="action"):
        if axis is None:
            axis = self.model.axis
        elif axis == NODE_LABEL_KEY and self.config.args.use_gold_node_labels:
            return true, true
        labels = self.model.classifier.labels[axis]
        if axis == NODE_LABEL_KEY:
            true_keys = (labels[true],) if self.oracle else ()  # Must be before score()
            is_valid = self.state.is_valid_label
        else:
            true_keys = None
            is_valid = self.state.is_valid_action
        scores, features = self.model.score(self.state, axis)
        for model in self.models[1:]:  # Ensemble if given more than one model; align label order and add scores
            label_scores = dict(zip(model.classifier.labels[axis].all, model.score(self.state, axis)[0]))
            scores += [label_scores.get(a, 0) for a in labels.all]  # Product of Experts, assuming log(softmax)
        self.config.print(lambda: "  %s scores: %s" % (name, tuple(zip(labels.all, scores))), level=4)
        try:
            label = pred = self.predict(scores, labels.all, is_valid)
        except StopIteration as e:
            raise ParserException("No valid %s available\n%s" % (name, self.oracle.log if self.oracle else "")) from e
        label, is_correct, true_keys, true_values = self.correct(axis, label, pred, scores, true, true_keys)
        if self.training:
            if not (is_correct and ClassifierProperty.update_only_on_error in self.model.classifier_properties):
                assert not self.model.is_finalized, "Updating finalized model"
                self.model.classifier.update(
                    features, axis=axis, true=true_keys, pred=labels[pred] if axis == NODE_LABEL_KEY else pred.id,
                    importance=[self.config.args.swap_importance if a.is_swap else 1 for a in true_values] or None)
            if not is_correct and self.config.args.early_update:
                self.state.finished = True
        for model in self.models:
            model.classifier.finished_step(self.training)
            if axis != NODE_LABEL_KEY:
                model.classifier.transition(label, axis=axis)
        return label, pred

    def correct(self, axis, label, pred, scores, true, true_keys):
        true_values = is_correct = ()
        if axis == NODE_LABEL_KEY:
            if self.oracle:
                is_correct = (label == true)
                if is_correct:
                    self.correct_label_count += 1
                elif self.training:
                    label = true
            self.label_count += 1
        else:  # action
            true_keys, true_values = map(list, zip(*true.items())) if true else (None, None)
            label = true.get(pred.id)
            is_correct = (label is not None)
            if is_correct:
                self.correct_action_count += 1
            else:
                label = true_values[scores[true_keys].argmax()] if self.training else pred
            self.action_count += 1
        return label, is_correct, true_keys, true_values

    @staticmethod
    def predict(scores, values, is_valid=None):
        """
        Choose action/label based on classifier
        Usually the best action/label is valid, so max is enough to choose it in O(n) time
        Otherwise, sorts all the other scores to choose the best valid one in O(n lg n)
        :return: valid action/label with maximum probability according to classifier
        """
        return next(filter(is_valid, (values[i] for i in PassageParser.generate_descending(scores))))

    @staticmethod
    def generate_descending(scores):
        yield scores.argmax()
        yield from scores.argsort()[::-1]  # Contains the max, but otherwise items might be missed (different order)

    def finish(self, status, display=True, write=False, accuracies=None):
        self.model.classifier.finished_item(self.training)
        for model in self.models[1:]:
            model.classifier.finished_item(renew=False)  # So that dynet.renew_cg happens only once
        if not self.training or self.config.args.verify:
            self.out = self.state.create_passage(verify=self.config.args.verify, format=self.out_format)
        if write:
            for out_format in self.config.args.formats or [self.out_format]:
                if self.config.args.normalize and out_format == "ucca":
                    normalize(self.out)
                ioutil.write_passage(self.out, output_format=out_format, binary=out_format == "pickle",
                                     outdir=self.config.args.outdir, prefix=self.config.args.prefix,
                                     converter=get_output_converter(out_format), verbose=self.config.args.verbose,
                                     append=self.config.args.join, basename=self.config.args.join)
        if self.oracle and self.config.args.verify:
            self.verify(self.out, self.passage)
        ret = (self.out,)
        if self.evaluation:
            ret += (self.evaluate(self.evaluation),)
            status = "%-14s %s F1=%.3f" % (status, self.eval_type, self.f1)
        if display:
            self.config.print("%s%.3fs %s" % (self.accuracy_str, self.duration, status), level=1)
        if accuracies is not None:
            accuracies[self.passage.ID] = self.correct_action_count / self.action_count if self.action_count else 0
        return ret

    @property
    def accuracy_str(self):
        if self.oracle and self.action_count:
            accuracy_str = "a=%-14s" % percents_str(self.correct_action_count, self.action_count)
            if self.label_count:
                accuracy_str += " l=%-14s" % percents_str(self.correct_label_count, self.label_count)
            return "%-33s" % accuracy_str
        return ""

    def evaluate(self, mode=ParseMode.test):
        if self.format:
            self.config.print("Converting to %s and evaluating..." % self.format)
        self.eval_type = UNLABELED if self.config.is_unlabeled(self.in_format) else LABELED
        evaluator = EVALUATORS.get(self.format, evaluate_ucca)
        score = evaluator(self.out, self.passage, converter=get_output_converter(self.format),
                          verbose=self.out and self.config.args.verbose > 3,
                          constructions=self.config.args.constructions,
                          eval_types=(self.eval_type,) if mode is ParseMode.dev else (LABELED, UNLABELED))
        self.f1 = average_f1(score, self.eval_type)
        score.lang = self.lang
        return score

    def check_loop(self):
        """
        Check if the current state has already occurred, indicating a loop
        """
        h = hash(self.state)
        assert h not in self.state_hash_history, \
            "\n".join(["Transition loop", self.state.str("\n")] + [self.oracle.str("\n")] if self.oracle else ())
        self.state_hash_history.add(h)

    def verify(self, guessed, ref):
        """
        Compare predicted passage to true passage and raise an exception if they differ
        :param ref: true passage
        :param guessed: predicted passage to compare
        """
        assert ref.equals(guessed, ignore_node=self.ignore_node), \
            "Failed to produce true passage" + (diffutil.diff_passages(ref, guessed) if self.training else "")

    @property
    def num_tokens(self):
        return len(set(self.state.terminals).difference(self.state.buffer))  # To count even incomplete parses

    @num_tokens.setter
    def num_tokens(self, _):
        pass
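
The ensemble branch of choose (the loop over self.models[1:]) is easy to miss: each extra model's scores are realigned to the first model's label order before being summed, which the comment justifies as a Product of Experts when the scores are log-softmax values (summing logs multiplies the distributions). A toy sketch of just that alignment-and-sum step, with made-up label sets:

import numpy as np

def ensemble_scores(primary_labels, scored_models):
    # Sum (assumed log-softmax) scores across models, aligning each model's
    # labels to primary_labels; labels a model lacks contribute 0, mirroring
    # label_scores.get(a, 0) in choose() above.
    total = np.zeros(len(primary_labels))
    for labels, scores in scored_models:
        label_scores = dict(zip(labels, scores))
        total += [label_scores.get(a, 0.0) for a in primary_labels]
    return total

labels = ["SHIFT", "REDUCE", "SWAP"]
m1 = (labels, np.log([0.6, 0.3, 0.1]))          # toy log-softmax scores
m2 = (["REDUCE", "SHIFT"], np.log([0.5, 0.5]))  # different label order, no SWAP
print(labels[int(ensemble_scores(labels, [m1, m2]).argmax())])  # -> SHIFT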
Example #5
class Parser(object):
    """
    Main class to implement transition-based UCCA parser
    """
    def __init__(self, model_file=None, model_type=None, beam=1):
        self.args = Config().args
        self.state = None  # State object created at each parse
        self.oracle = None  # Oracle object created at each parse
        self.action_count = self.correct_action_count = self.total_actions = self.total_correct_actions = 0
        self.label_count = self.correct_label_count = self.total_labels = self.total_correct_labels = 0
        self.model = Model(model_type, model_file)
        self.update_only_on_error = \
            ClassifierProperty.update_only_on_error in self.model.model.get_classifier_properties()
        self.beam = beam  # Currently unused
        self.state_hash_history = None  # For loop checking
        # Used in verify_passage to optionally ignore a mismatch in linkage nodes:
        self.ignore_node = None if self.args.linkage else lambda n: n.tag == layer1.NodeTags.Linkage
        self.best_score = self.dev = self.iteration = self.eval_index = None
        self.trained = False

    def train(self, passages=None, dev=None, iterations=1):
        """
        Train parser on given passages
        :param passages: iterable of passages to train on
        :param dev: iterable of passages to tune on
        :param iterations: number of iterations to perform
        """
        self.trained = True
        if passages:
            if ClassifierProperty.trainable_after_saving in self.model.model.get_classifier_properties():
                try:
                    self.model.load()
                except FileNotFoundError:
                    print("not found, starting from untrained model.")
            self.best_score = 0
            self.dev = dev
            for self.iteration in range(1, iterations + 1):
                self.eval_index = 0
                print("Training iteration %d of %d: " %
                      (self.iteration, iterations))
                Config().random.shuffle(passages)
                list(self.parse(passages, mode=ParseMode.train))
                yield self.eval_and_save(self.iteration == iterations,
                                         finished_epoch=True)
            print("Trained %d iterations" % iterations)
        if dev or not passages:
            self.model.load()

    def eval_and_save(self, last=False, finished_epoch=False):
        scores = None
        model = self.model
        self.model = self.model.finalize(finished_epoch=finished_epoch)
        if self.dev:
            if not self.best_score:
                self.model.save()
            print("Evaluating on dev passages")
            passage_scores = [s for _, s in self.parse(self.dev, mode=ParseMode.dev, evaluate=True)]
            scores = Scores(passage_scores)
            average_score = scores.average_f1()
            print("Average labeled F1 score on dev: %.3f" % average_score)
            print_scores(scores, self.args.devscores, prefix_title="iteration",
                         prefix=[self.iteration] + ([self.eval_index] if self.args.save_every else []))
            if average_score >= self.best_score:
                print("Better than previous best score (%.3f)" %
                      self.best_score)
                if self.best_score:
                    self.model.save()
                self.best_score = average_score
            else:
                print("Not better than previous best score (%.3f)" %
                      self.best_score)
        elif last or self.args.save_every is not None:
            self.model.save()
        if not last:
            self.model = model  # Restore non-finalized model
            self.model.load_labels()
        return scores

    def parse(self, passages, mode=ParseMode.test, evaluate=False):
        """
        Parse given passages
        :param passages: iterable of passages to parse
        :param mode: ParseMode value.
                     If train, use oracle to train on given passages.
                     Otherwise, just parse with classifier.
        :param evaluate: whether to evaluate parsed passages with respect to given ones.
                         Only possible when given passages are annotated.
        :return: generator of parsed passages (or in train mode, the original ones),
                 or, if evaluate=True, of pairs of (Passage, Scores).
        """
        assert mode in ParseMode, "Invalid parse mode: %s" % mode
        train = (mode is ParseMode.train)
        if not train and not self.trained:
            list(self.train())
        passage_word = "sentence" if self.args.sentences else \
                       "paragraph" if self.args.paragraphs else \
                       "passage"
        self.total_actions = 0
        self.total_correct_actions = 0
        total_duration = 0
        total_tokens = 0
        passage_index = 0
        if not hasattr(passages, "__iter__"):  # Single passage given
            passages = (passages, )
        for passage_index, passage in enumerate(passages):
            labeled = any(n.outgoing or n.attrib.get(LABEL_ATTRIB)
                          for n in passage.layer(layer1.LAYER_ID).all)
            assert not train or labeled, "Cannot train on unannotated passage: %s" % passage.ID
            assert not evaluate or labeled, "Cannot evaluate on unannotated passage: %s" % passage.ID
            print("%s %-7s" % (passage_word, passage.ID),
                  end=Config().line_end,
                  flush=True)
            started = time.time()
            self.action_count = self.correct_action_count = self.label_count = self.correct_label_count = 0
            textutil.annotate(passage, verbose=self.args.verbose > 1)  # tag POS and parse dependencies
            self.state = State(passage)
            self.state_hash_history = set()
            self.oracle = Oracle(passage) if train or \
                (self.args.verbose or Config().args.use_gold_node_labels) and labeled or \
                self.args.verify else None
            failed = False
            if ClassifierProperty.require_init_features in self.model.model.get_classifier_properties():
                self.model.init_features(self.state, train)
            try:
                self.parse_passage(train)  # This is where the actual parsing takes place
            except ParserException as e:
                if train:
                    raise
                Config().log("%s %s: %s" % (passage_word, passage.ID, e))
                failed = True
            guessed = self.state.create_passage(verify=self.args.verify) \
                if not train or self.args.verify else passage
            duration = time.time() - started
            total_duration += duration
            num_tokens = len(set(self.state.terminals).difference(self.state.buffer))
            total_tokens += num_tokens
            if self.oracle:  # We have an oracle to verify by
                if not failed and self.args.verify:
                    self.verify_passage(guessed, passage, train)
                if self.action_count:
                    accuracy_str = "%d%% (%d/%d)" % (
                        100 * self.correct_action_count / self.action_count,
                        self.correct_action_count, self.action_count)
                    if self.label_count:
                        accuracy_str += " %d%% (%d/%d)" % (
                            100 * self.correct_label_count / self.label_count,
                            self.correct_label_count, self.label_count)
                    print("%-30s" % accuracy_str, end=Config().line_end)
            print("%0.3fs" % duration, end="")
            print("%-15s" % (" (failed)" if failed else " (%d tokens/s)" %
                             (num_tokens / duration)),
                  end="")
            print(Config().line_end, end="")
            if self.oracle:
                print(Config().line_end, flush=True)
            self.model.model.finished_item(train)
            self.total_correct_actions += self.correct_action_count
            self.total_actions += self.action_count
            self.total_correct_labels += self.correct_label_count
            self.total_labels += self.label_count
            if train and self.args.save_every and (passage_index + 1) % self.args.save_every == 0:
                self.eval_and_save()
                self.eval_index += 1
            yield (guessed, self.evaluate_passage(guessed, passage)) if evaluate else guessed

        if passages:
            print("Parsed %d %ss" % (passage_index + 1, passage_word))
            if self.oracle and self.total_actions:
                accuracy_str = "%d%% correct actions (%d/%d)" % (
                    100 * self.total_correct_actions / self.total_actions,
                    self.total_correct_actions, self.total_actions)
                if self.total_labels:
                    accuracy_str += ", %d%% correct labels (%d/%d)" % (
                        100 * self.total_correct_labels / self.total_labels,
                        self.total_correct_labels, self.total_labels)
                print("Overall %s on %s" % (accuracy_str, mode.name))
            if total_duration:
                print("Total time: %.3fs (average time/%s: %.3fs, average tokens/s: %d)" %
                      (total_duration, passage_word, total_duration / (passage_index + 1),
                       total_tokens / total_duration), flush=True)

    def parse_passage(self, train):
        """
        Internal method to parse a single passage
        :param train: use oracle to train on given passages, or just parse with classifier?
        """
        if self.args.verbose > 1:
            print("  initial state: %s" % self.state)
        while True:
            if self.args.check_loops:
                self.check_loop()
            features = self.model.feature_extractor.extract_features(self.state)
            true_actions = self.get_true_actions(train)
            action, predicted_action = self.choose_action(features, train, true_actions)
            try:
                self.state.transition(action)
            except AssertionError as e:
                raise ParserException("Invalid transition: %s %s" %
                                      (action, self.state)) from e
            if self.args.verbose > 1:
                if self.oracle:
                    print("  predicted: %-15s true: %-15s taken: %-15s %s" %
                          (predicted_action, "|".join(
                              map(str,
                                  true_actions.values())), action, self.state))
                else:
                    print("  action: %-15s %s" % (action, self.state))
            if self.state.need_label:  # Label action that requires a choice of label
                true_label = self.get_true_label(action.orig_node)
                label, predicted_label = self.choose_label(features, train, true_label)
                self.state.label_node(label)
                if self.args.verbose > 1:
                    if self.oracle and not Config().args.use_gold_node_labels:
                        print("  predicted label: %-15s true label: %-15s" %
                              (predicted_label, true_label))
                    else:
                        print("  label: %-15s" % label)
            self.model.model.finished_step(train)
            if self.args.verbose > 1:
                for line in self.state.log:
                    print("    " + line)
            if self.state.finished:
                return  # action is Finish (or early update is triggered)

    def get_true_actions(self, train):
        true_actions = {}
        if self.oracle:
            try:
                true_actions = self.oracle.get_actions(self.state, self.model.actions, create=train)
            except (AttributeError, AssertionError) as e:
                if train:
                    raise ParserException("Error in oracle during training") from e
        return true_actions

    def choose_action(self, features, train, true_actions):
        scores = self.model.model.score(features, axis=ACTION_AXIS)  # Returns NumPy array
        if self.args.verbose > 2:
            print("  action scores: " +
                  ",".join(("%s: %g" % x
                            for x in zip(self.model.actions.all, scores))))
        try:
            predicted_action = self.predict(scores, self.model.actions.all, self.state.is_valid_action)
        except StopIteration as e:
            raise ParserException("No valid action available\n%s" % (self.oracle.log if self.oracle else "")) from e
        action = true_actions.get(predicted_action.id)
        is_correct = (action is not None)
        if is_correct:
            self.correct_action_count += 1
        else:
            action = Config().random.choice(list(true_actions.values())) if train else predicted_action
        if train and not (is_correct and self.update_only_on_error):
            best_action = self.predict(scores[list(true_actions.keys())], list(true_actions.values()))
            self.model.model.update(features, axis=ACTION_AXIS, pred=predicted_action.id, true=best_action.id,
                                    importance=self.args.swap_importance if best_action.is_swap else 1)
        if train and not is_correct and self.args.early_update:
            self.state.finished = True
        self.action_count += 1
        return action, predicted_action

    def get_true_label(self, node):
        true_label = None
        if self.oracle:
            if node is not None:
                true_label = node.attrib.get(LABEL_ATTRIB)
            if true_label is not None:
                true_label, _, _ = true_label.partition(LABEL_SEPARATOR)
                if not self.state.is_valid_label(true_label):
                    raise ParserException("True label is invalid: %s %s" %
                                          (true_label, self.state))
        return true_label

    def choose_label(self, features, train, true_label):
        true_id = self.model.labels[true_label] if self.oracle else None  # Needs to happen before score()
        if Config().args.use_gold_node_labels:
            return true_label, true_label
        scores = self.model.model.score(features, axis=LABEL_AXIS)
        if self.args.verbose > 2:
            print("  label scores: " +
                  ",".join(("%s: %g" % x
                            for x in zip(self.model.labels.all, scores))))
        label = predicted_label = self.predict(scores, self.model.labels.all, self.state.is_valid_label)
        if self.oracle:
            is_correct = (label == true_label)
            if is_correct:
                self.correct_label_count += 1
            if train and not (is_correct and self.update_only_on_error):
                self.model.model.update(features, axis=LABEL_AXIS, pred=self.model.labels[label], true=true_id)
                label = true_label
        self.label_count += 1
        return label, predicted_label

    def predict(self, scores, values, is_valid=None):
        """
        Choose action/label based on classifier
        Usually the best action/label is valid, so max is enough to choose it in O(n) time
        Otherwise, sorts all the other scores to choose the best valid one in O(n lg n)
        :return: valid action/label with maximum probability according to classifier
        """
        return next(filter(is_valid, (values[i] for i in self.generate_descending(scores))))

    @staticmethod
    def generate_descending(scores):
        yield scores.argmax()
        yield from scores.argsort()[::-1]  # Contains the max, but otherwise items might be missed (different order)

    def check_loop(self):
        """
        Check if the current state has already occurred, indicating a loop
        """
        h = hash(self.state)
        assert h not in self.state_hash_history, \
            "\n".join(["Transition loop", self.state.str("\n")] + ([self.oracle.str("\n")] if self.oracle else []))
        self.state_hash_history.add(h)

    def evaluate_passage(self, guessed, ref):
        converters = CONVERTERS.get(ref.extra.get("format"))  # returns (input converter, output converter) tuple
        score = EVALUATORS.get(ref.extra.get("format"), evaluation).evaluate(
            guessed, ref,
            converter=converters and converters[1],  # converter output is list of lines
            verbose=guessed and self.args.verbose > 2,
            constructions=self.args.constructions)
        print("F1=%.3f" % score.average_f1(), flush=True)
        return score

    def verify_passage(self, guessed, ref, show_diff):
        """
        Compare predicted passage to true passage and raise an exception if they differ
        :param ref: true passage
        :param guessed: predicted passage to compare
        :param show_diff: if passages differ, show the difference between them?
                          Depends on guessed having the original node IDs annotated in the "remarks" field for each node
        """
        assert ref.equals(guessed, ignore_node=self.ignore_node), \
            "Failed to produce true passage" + (diffutil.diff_passages(ref, guessed) if show_diff else "")