def _test_features(config, feature_extractor_creator, filename, write_features):
    """
    Run the oracle over a passage, collect the extracted features at every step and
    compare them with a stored reference file
    :param config: configuration object; provides vocab() and set_format()
    :param feature_extractor_creator: callable building a feature extractor from config;
                                      its `annotated` attribute is passed to load_passage
    :param filename: passage file to load
    :param write_features: if true, (over)write the reference file before comparing
    """
    extractor = feature_extractor_creator(config)
    passage = load_passage(filename, annotate=feature_extractor_creator.annotated)
    textutil.annotate(passage, as_array=True, as_extra=False, vocab=config.vocab())
    config.set_format(passage.extra.get("format") or "ucca")
    oracle = Oracle(passage)
    state = State(passage)
    actions = Actions()
    # NOTE(review): indentation was lost in SOURCE; init_param is assumed to belong to the
    # non-numeric branch together with the dropout reset - confirm
    for key, param in extractor.params.items():
        if param.numeric:
            continue
        param.dropout = 0
        extractor.init_param(key)
    collected = [extractor.init_features(state)]
    while True:
        extract_features(extractor, state, collected)
        # Deterministic choice among oracle actions: smallest by string representation
        candidates = oracle.get_actions(state, actions).values()
        action = min(candidates, key=str)
        state.transition(action)
        if state.need_label:
            extract_features(extractor, state, collected)
            label, _ = oracle.get_label(state, action)
            state.label_node(label)
        if state.finished:
            break
    # One "name value" line per feature, sorted; each non-empty feature dict is followed
    # by a separator line made from a pair of empty strings
    lines = []
    for feature_map in collected:
        if not feature_map:
            continue
        for pair in sorted(feature_map.items()):
            lines.append("%s %s\n" % pair)
        lines.append("%s %s\n" % ("", ""))
    stem = "-".join((basename(filename), str(feature_extractor_creator)))
    compare_file = os.path.join("test_files", "features", stem + ".txt")
    if write_features:
        with open(compare_file, "w", encoding="utf-8") as out:
            out.writelines(lines)
    with open(compare_file, encoding="utf-8") as expected:
        assert expected.readlines() == lines, compare_file
def gen_actions(passage):
    """
    Generate the oracle action sequence for a passage, one string per transition
    :param passage: passage to run the oracle on
    :return: generator of str(action), followed by a space and the node label
             whenever the state requires a label after the transition
    """
    oracle = Oracle(passage)
    state = State(passage)
    actions = Actions()
    while True:
        # Deterministic choice among oracle actions: smallest by string representation
        action = min(oracle.get_actions(state, actions).values(), key=str)
        state.transition(action)
        parts = [str(action)]
        if state.need_label:
            label, _ = oracle.get_label(state, action)
            state.label_node(label)
            parts.append(str(label))
        yield " ".join(parts)
        if state.finished:
            return
class PassageParser(AbstractParser):
    """ Parser for a single passage, has a state and optionally an oracle """

    def __init__(self, passage, *args, **kwargs):
        """
        :param passage: passage to parse; also the initial value of `out`
        :param args: forwarded to AbstractParser
        :param kwargs: forwarded to AbstractParser
        """
        super().__init__(*args, **kwargs)
        self.passage = self.out = passage
        # When training/evaluating, the passage dictates the format; otherwise pick the first (sorted)
        # format common to the model and the command line, falling back to the model's formats
        self.format = self.passage.extra.get("format") if self.training or self.evaluation else \
            sorted(set.intersection(*map(set, filter(None, (self.model.formats, self.config.args.formats)))) or
                   self.model.formats)[0]
        if self.training and self.config.args.verify:
            errors = list(validate(self.passage))
            assert not errors, errors
        self.in_format = self.format or "ucca"
        self.out_format = "ucca" if self.format in (None, "text") else self.format
        if self.config.args.use_bert and self.config.args.bert_multilingual is not None:
            self.lang = self.passage.attrib.get("lang")
            assert self.lang, "Attribute 'lang' is required per passage when using multilingual BERT"
        else:
            self.lang = self.passage.attrib.get("lang", self.config.args.lang)
        # Used in verify_passage to optionally ignore a mismatch in linkage nodes:
        self.ignore_node = None if self.config.args.linkage else lambda n: n.tag == layer1.NodeTags.Linkage
        self.state_hash_history = set()
        self.state = self.oracle = self.eval_type = None

    def init(self):
        """ Create the state (and the oracle, if needed) and initialize all models for this passage """
        self.config.set_format(self.in_format)
        WIKIFIER.enabled = self.config.args.wikification
        self.state = State(self.passage)
        # Passage is considered labeled if there are any edges or node labels in it
        edges, node_labels = map(any, zip(*[(n.outgoing, n.attrib.get(LABEL_ATTRIB))
                                            for n in self.passage.layer(layer1.LAYER_ID).all]))
        self.oracle = Oracle(self.passage) if self.training or self.config.args.verify or (
            (self.config.args.verbose > 1 or self.config.args.use_gold_node_labels or
             self.config.args.action_stats) and (edges or node_labels)) else None
        for model in self.models:
            model.init_model(self.config.format, lang=self.lang if self.config.args.multilingual else None)
            if ClassifierProperty.require_init_features in model.classifier_properties:
                model.init_features(self.state, self.training)

    def parse(self, display=True, write=False, accuracies=None):
        """
        Parse the passage, with a time limit of config.args.timeout
        :param display: passed to finish(): whether to print a summary line
        :param write: passed to finish(): whether to write the output passage
        :param accuracies: passed to finish(): optional dict to store action accuracy by passage ID
        :return: the value returned by finish()
        :raises ParserException: only when training; otherwise failures are logged
        """
        self.init()
        passage_id = self.passage.ID
        try:
            # NOTE(review): leaving the `with` block shuts the executor down with wait=True, so after
            # a timeout this presumably still blocks until parse_internal returns - confirm
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                executor.submit(self.parse_internal).result(self.config.args.timeout)
            status = "(%d tokens/s)" % self.tokens_per_second()
        except ParserException as e:
            if self.training:
                raise
            self.config.log("%s %s: %s" % (self.config.passage_word, passage_id, e))
            status = "(failed)"
        except concurrent.futures.TimeoutError:
            self.config.log("%s %s: timeout (%fs)" % (self.config.passage_word, passage_id, self.config.args.timeout))
            status = "(timeout)"
        return self.finish(status, display=display, write=write, accuracies=accuracies)

    def parse_internal(self):
        """
        Internal method to parse a single passage.
        If training, use oracle to train on given passages. Otherwise just parse with classifier.
        """
        self.config.print(" initial state: %s" % self.state)
        while True:
            if self.config.args.check_loops:
                self.check_loop()
            self.label_node()  # In case root node needs labeling
            true_actions = self.get_true_actions()
            action, predicted_action = self.choose(true_actions)
            self.state.transition(action)
            need_label, label, predicted_label, true_label = self.label_node(action)
            if self.config.args.action_stats:
                # Best effort: failure to append to the statistics file must not abort parsing
                try:
                    with open(self.config.args.action_stats, "a") as f:
                        print(",".join(map(str, [predicted_action, action] + list(true_actions.values()))), file=f)
                except OSError:
                    pass
            # Passed as a callable so the message is only built if it is actually printed
            self.config.print(lambda: "\n".join([" predicted: %-15s true: %-15s taken: %-15s %s" % (
                predicted_action, "|".join(map(str, true_actions.values())), action, self.state) if self.oracle else
                                                 " action: %-15s %s" % (action, self.state)] + (
                [" predicted label: %-9s true label: %s" % (predicted_label, true_label) if self.oracle and not
                 self.config.args.use_gold_node_labels else " label: %s" % label] if need_label else []) + [
                " " + l for l in self.state.log]))
            if self.state.finished:
                return  # action is Finish (or early update is triggered)

    def get_true_actions(self):
        """
        :return: dict of true actions from the oracle; empty if there is no oracle,
                 or if the oracle failed while not training
        :raises ParserException: if the oracle fails during training
        """
        true_actions = {}
        if self.oracle:
            try:
                true_actions = self.oracle.get_actions(self.state, self.model.actions, create=self.training)
            except (AttributeError, AssertionError) as e:
                if self.training:
                    raise ParserException("Error in getting action from oracle during training") from e
        return true_actions

    def get_true_label(self, node):
        """
        :param node: passed to the oracle's get_label along with the state
        :return: pair returned by the oracle; (None, None) without an oracle
                 or if the oracle failed while not training
        :raises ParserException: if the oracle fails during training
        """
        try:
            return self.oracle.get_label(self.state, node) if self.oracle else (None, None)
        except AssertionError as e:
            if self.training:
                raise ParserException("Error in getting label from oracle during training") from e
            return None, None

    def label_node(self, action=None):
        """
        Choose and apply a node label if the state currently requires one
        :param action: action just taken, if any; state.need_label is used in its place when None
        :return: (need_label, label, predicted_label, true_label); all but the first are None
                 if no label was needed
        """
        true_label = label = predicted_label = None
        need_label = self.state.need_label  # Label action that requires a choice of label
        if need_label:
            true_label, raw_true_label = self.get_true_label(action or need_label)
            label, predicted_label = self.choose(true_label, NODE_LABEL_KEY, "node label")
            self.state.label_node(raw_true_label if label == true_label else label)
        return need_label, label, predicted_label, true_label

    def choose(self, true, axis=None, name="action"):
        """
        Score all options with the model(s), pick the best valid one and update the classifier if training
        :param true: true action dict (for actions) or true label (for node labels)
        :param axis: classifier axis; defaults to the model's axis (actions)
        :param name: option kind, used in log and error messages
        :return: pair of (chosen, predicted) action or label
        :raises ParserException: if no valid option is available
        """
        if axis is None:
            axis = self.model.axis
        elif axis == NODE_LABEL_KEY and self.config.args.use_gold_node_labels:
            return true, true
        labels = self.model.classifier.labels[axis]
        if axis == NODE_LABEL_KEY:
            true_keys = (labels[true],) if self.oracle else ()  # Must be before score()
            is_valid = self.state.is_valid_label
        else:
            true_keys = None
            is_valid = self.state.is_valid_action
        scores, features = self.model.score(self.state, axis)
        for model in self.models[1:]:  # Ensemble if given more than one model; align label order and add scores
            # Score with the ensemble member itself: its label order only matches its own scores
            label_scores = dict(zip(model.classifier.labels[axis].all, model.score(self.state, axis)[0]))
            scores += [label_scores.get(a, 0) for a in labels.all]  # Product of Experts, assuming log(softmax)
        self.config.print(lambda: " %s scores: %s" % (name, tuple(zip(labels.all, scores))), level=4)
        try:
            label = pred = self.predict(scores, labels.all, is_valid)
        except StopIteration as e:
            raise ParserException("No valid %s available\n%s" % (name, self.oracle.log if self.oracle else "")) from e
        label, is_correct, true_keys, true_values = self.correct(axis, label, pred, scores, true, true_keys)
        if self.training:
            if not (is_correct and ClassifierProperty.update_only_on_error in self.model.classifier_properties):
                assert not self.model.is_finalized, "Updating finalized model"
                self.model.classifier.update(
                    features, axis=axis, true=true_keys, pred=labels[pred] if axis == NODE_LABEL_KEY else pred.id,
                    importance=[self.config.args.swap_importance if a.is_swap else 1 for a in true_values] or None)
            if not is_correct and self.config.args.early_update:
                self.state.finished = True
        for model in self.models:
            model.classifier.finished_step(self.training)
            if axis != NODE_LABEL_KEY:
                model.classifier.transition(label, axis=axis)
        return label, pred

    def correct(self, axis, label, pred, scores, true, true_keys):
        """
        Compare the prediction to the truth, update the counters and, when training, replace a wrong choice
        :return: (label to use, is_correct, true_keys, true_values)
        """
        true_values = is_correct = ()
        if axis == NODE_LABEL_KEY:
            if self.oracle:
                is_correct = (label == true)
                if is_correct:
                    self.correct_label_count += 1
                elif self.training:
                    label = true
            # NOTE(review): indentation was lost in SOURCE; the count is assumed to be incremented
            # with or without an oracle, mirroring action_count below - confirm
            self.label_count += 1
        else:  # action
            true_keys, true_values = map(list, zip(*true.items())) if true else (None, None)
            label = true.get(pred.id)
            is_correct = (label is not None)
            if is_correct:
                self.correct_action_count += 1
            else:
                # When training, fall back to the highest-scoring true action
                label = true_values[scores[true_keys].argmax()] if self.training else pred
            self.action_count += 1
        return label, is_correct, true_keys, true_values

    @staticmethod
    def predict(scores, values, is_valid=None):
        """
        Choose action/label based on classifier
        Usually the best action/label is valid, so max is enough to choose it in O(n) time
        Otherwise, sorts all the other scores to choose the best valid one in O(n lg n)
        :return: valid action/label with maximum probability according to classifier
        :raises StopIteration: if no value is valid
        """
        return next(filter(is_valid, (values[i] for i in PassageParser.generate_descending(scores))))

    @staticmethod
    def generate_descending(scores):
        """ Yield the argmax index first, then all indices from the reversed argsort """
        yield scores.argmax()
        yield from scores.argsort()[::-1]  # Contains the max, but otherwise items might be missed (different order)

    def finish(self, status, display=True, write=False, accuracies=None):
        """
        Finalize the parse: create the output passage and optionally write, verify and evaluate it
        :param status: status string to show in the summary line
        :param display: whether to print the summary line
        :param write: whether to write the output passage in each requested format
        :param accuracies: optional dict; the action accuracy is stored in it under the passage ID
        :return: tuple of (output passage,), with the evaluation score appended if evaluating
        """
        self.model.classifier.finished_item(self.training)
        for model in self.models[1:]:
            model.classifier.finished_item(renew=False)  # So that dynet.renew_cg happens only once
        if not self.training or self.config.args.verify:
            self.out = self.state.create_passage(verify=self.config.args.verify, format=self.out_format)
        if write:
            for out_format in self.config.args.formats or [self.out_format]:
                if self.config.args.normalize and out_format == "ucca":
                    normalize(self.out)
                ioutil.write_passage(self.out, output_format=out_format, binary=out_format == "pickle",
                                     outdir=self.config.args.outdir, prefix=self.config.args.prefix,
                                     converter=get_output_converter(out_format), verbose=self.config.args.verbose,
                                     append=self.config.args.join, basename=self.config.args.join)
        if self.oracle and self.config.args.verify:
            self.verify(self.out, self.passage)
        ret = (self.out,)
        if self.evaluation:
            ret += (self.evaluate(self.evaluation),)
            status = "%-14s %s F1=%.3f" % (status, self.eval_type, self.f1)
        if display:
            self.config.print("%s%.3fs %s" % (self.accuracy_str, self.duration, status), level=1)
        if accuracies is not None:
            accuracies[self.passage.ID] = self.correct_action_count / self.action_count if self.action_count else 0
        return ret

    @property
    def accuracy_str(self):
        """ Formatted action (and label) accuracy, or an empty string without an oracle or actions """
        if self.oracle and self.action_count:
            accuracy_str = "a=%-14s" % percents_str(self.correct_action_count, self.action_count)
            if self.label_count:
                accuracy_str += " l=%-14s" % percents_str(self.correct_label_count, self.label_count)
            return "%-33s" % accuracy_str
        return ""

    def evaluate(self, mode=ParseMode.test):
        """
        Evaluate the output passage against the input passage
        :param mode: ParseMode value; in dev mode only the relevant evaluation type is computed
        :return: score object returned by the evaluator, with `lang` set on it
        """
        if self.format:
            self.config.print("Converting to %s and evaluating..." % self.format)
        self.eval_type = UNLABELED if self.config.is_unlabeled(self.in_format) else LABELED
        evaluator = EVALUATORS.get(self.format, evaluate_ucca)
        score = evaluator(self.out, self.passage, converter=get_output_converter(self.format),
                          verbose=self.out and self.config.args.verbose > 3,
                          constructions=self.config.args.constructions,
                          eval_types=(self.eval_type,) if mode is ParseMode.dev else (LABELED, UNLABELED))
        self.f1 = average_f1(score, self.eval_type)
        score.lang = self.lang
        return score

    def check_loop(self):
        """
        Check if the current state has already occurred, indicating a loop
        """
        h = hash(self.state)
        # Parenthesized so that only the oracle part is conditional: a conditional expression binds
        # looser than +, which would otherwise leave the message empty whenever there is no oracle
        assert h not in self.state_hash_history, \
            "\n".join(["Transition loop", self.state.str("\n")] + ([self.oracle.str("\n")] if self.oracle else []))
        self.state_hash_history.add(h)

    def verify(self, guessed, ref):
        """
        Compare predicted passage to true passage and raise an exception if they differ
        :param ref: true passage
        :param guessed: predicted passage to compare
        """
        assert ref.equals(guessed, ignore_node=self.ignore_node), \
            "Failed to produce true passage" + (diffutil.diff_passages(ref, guessed) if self.training else "")

    @property
    def num_tokens(self):
        return len(set(self.state.terminals).difference(self.state.buffer))  # To count even incomplete parses

    @num_tokens.setter
    def num_tokens(self, _):
        # The value is always derived from the state, so assignments are ignored
        pass
class PassageParser(AbstractParser):
    """ Parser for a single passage, has a state and optionally an oracle """

    def __init__(self, passage, *args, **kwargs):
        """
        :param passage: passage to parse; also the initial value of `out`
        :param args: forwarded to AbstractParser
        :param kwargs: forwarded to AbstractParser
        """
        super().__init__(*args, **kwargs)
        self.passage = self.out = passage
        # When training/evaluating, the passage dictates the format; otherwise pick the first (sorted)
        # format common to the model and the command line, falling back to the model's formats
        self.format = self.passage.extra.get("format") if self.training or self.evaluation else \
            sorted(set.intersection(*map(set, filter(None, (self.model.formats, self.config.args.formats)))) or
                   self.model.formats)[0]
        if self.training and self.config.args.verify:
            errors = list(validate(self.passage))
            assert not errors, errors
        self.in_format = self.format or "ucca"
        self.out_format = "ucca" if self.format in (None, "text") else self.format
        self.lang = self.passage.attrib.get("lang", self.config.args.lang)
        # Used in verify_passage to optionally ignore a mismatch in linkage nodes:
        self.ignore_node = None if self.config.args.linkage else lambda n: n.tag == layer1.NodeTags.Linkage
        self.state_hash_history = set()
        self.state = self.oracle = self.eval_type = None

    def init(self):
        """ Create the state (and the oracle, if needed) and initialize all models for this passage """
        self.config.set_format(self.in_format)
        WIKIFIER.enabled = self.config.args.wikification
        self.state = State(self.passage)
        # Passage is considered labeled if there are any edges or node labels in it
        edges, node_labels = map(any, zip(*[(n.outgoing, n.attrib.get(LABEL_ATTRIB))
                                            for n in self.passage.layer(layer1.LAYER_ID).all]))
        self.oracle = Oracle(self.passage) if self.training or self.config.args.verify or (
            (self.config.args.verbose > 1 or self.config.args.use_gold_node_labels or
             self.config.args.action_stats) and (edges or node_labels)) else None
        for model in self.models:
            model.init_model(self.config.format, lang=self.lang if self.config.args.multilingual else None)
            if ClassifierProperty.require_init_features in model.classifier_properties:
                model.init_features(self.state, self.training)

    def parse(self, display=True, write=False, accuracies=None):
        """
        Parse the passage, with a time limit of config.args.timeout
        :param display: passed to finish(): whether to print a summary line
        :param write: passed to finish(): whether to write the output passage
        :param accuracies: passed to finish(): optional dict to store action accuracy by passage ID
        :return: the value returned by finish()
        :raises ParserException: only when training; otherwise failures are logged
        """
        self.init()
        passage_id = self.passage.ID
        try:
            # NOTE(review): leaving the `with` block shuts the executor down with wait=True, so after
            # a timeout this presumably still blocks until parse_internal returns - confirm
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                executor.submit(self.parse_internal).result(self.config.args.timeout)
            status = "(%d tokens/s)" % self.tokens_per_second()
        except ParserException as e:
            if self.training:
                raise
            self.config.log("%s %s: %s" % (self.config.passage_word, passage_id, e))
            status = "(failed)"
        except concurrent.futures.TimeoutError:
            self.config.log("%s %s: timeout (%fs)" % (self.config.passage_word, passage_id, self.config.args.timeout))
            status = "(timeout)"
        return self.finish(status, display=display, write=write, accuracies=accuracies)

    def parse_internal(self):
        """
        Internal method to parse a single passage.
        If training, use oracle to train on given passages. Otherwise just parse with classifier.
        """
        self.config.print(" initial state: %s" % self.state)
        while True:
            if self.config.args.check_loops:
                self.check_loop()
            self.label_node()  # In case root node needs labeling
            true_actions = self.get_true_actions()
            action, predicted_action = self.choose(true_actions)
            self.state.transition(action)
            need_label, label, predicted_label, true_label = self.label_node(action)
            if self.config.args.action_stats:
                # Best effort: failure to append to the statistics file must not abort parsing
                try:
                    with open(self.config.args.action_stats, "a") as f:
                        print(",".join(map(str, [predicted_action, action] + list(true_actions.values()))), file=f)
                except OSError:
                    pass
            # Passed as a callable so the message is only built if it is actually printed
            self.config.print(lambda: "\n".join([" predicted: %-15s true: %-15s taken: %-15s %s" % (
                predicted_action, "|".join(map(str, true_actions.values())), action, self.state) if self.oracle else
                                                 " action: %-15s %s" % (action, self.state)] + (
                [" predicted label: %-9s true label: %s" % (predicted_label, true_label) if self.oracle and not
                 self.config.args.use_gold_node_labels else " label: %s" % label] if need_label else []) + [
                " " + l for l in self.state.log]))
            if self.state.finished:
                return  # action is Finish (or early update is triggered)

    def get_true_actions(self):
        """
        :return: dict of true actions from the oracle; empty if there is no oracle,
                 or if the oracle failed while not training
        :raises ParserException: if the oracle fails during training
        """
        true_actions = {}
        if self.oracle:
            try:
                true_actions = self.oracle.get_actions(self.state, self.model.actions, create=self.training)
            except (AttributeError, AssertionError) as e:
                if self.training:
                    raise ParserException("Error in getting action from oracle during training") from e
        return true_actions

    def get_true_label(self, node):
        """
        :param node: passed to the oracle's get_label along with the state
        :return: pair returned by the oracle; (None, None) without an oracle
                 or if the oracle failed while not training
        :raises ParserException: if the oracle fails during training
        """
        try:
            return self.oracle.get_label(self.state, node) if self.oracle else (None, None)
        except AssertionError as e:
            if self.training:
                raise ParserException("Error in getting label from oracle during training") from e
            return None, None

    def label_node(self, action=None):
        """
        Choose and apply a node label if the state currently requires one
        :param action: action just taken, if any; state.need_label is used in its place when None
        :return: (need_label, label, predicted_label, true_label); all but the first are None
                 if no label was needed
        """
        true_label = label = predicted_label = None
        need_label = self.state.need_label  # Label action that requires a choice of label
        if need_label:
            true_label, raw_true_label = self.get_true_label(action or need_label)
            label, predicted_label = self.choose(true_label, NODE_LABEL_KEY, "node label")
            self.state.label_node(raw_true_label if label == true_label else label)
        return need_label, label, predicted_label, true_label

    def choose(self, true, axis=None, name="action"):
        """
        Score all options with the model(s), pick the best valid one and update the classifier if training
        :param true: true action dict (for actions) or true label (for node labels)
        :param axis: classifier axis; defaults to the model's axis (actions)
        :param name: option kind, used in log and error messages
        :return: pair of (chosen, predicted) action or label
        :raises ParserException: if no valid option is available
        """
        if axis is None:
            axis = self.model.axis
        elif axis == NODE_LABEL_KEY and self.config.args.use_gold_node_labels:
            return true, true
        labels = self.model.classifier.labels[axis]
        if axis == NODE_LABEL_KEY:
            true_keys = (labels[true],) if self.oracle else ()  # Must be before score()
            is_valid = self.state.is_valid_label
        else:
            true_keys = None
            is_valid = self.state.is_valid_action
        scores, features = self.model.score(self.state, axis)
        for model in self.models[1:]:  # Ensemble if given more than one model; align label order and add scores
            # Score with the ensemble member itself: its label order only matches its own scores
            label_scores = dict(zip(model.classifier.labels[axis].all, model.score(self.state, axis)[0]))
            scores += [label_scores.get(a, 0) for a in labels.all]  # Product of Experts, assuming log(softmax)
        self.config.print(lambda: " %s scores: %s" % (name, tuple(zip(labels.all, scores))), level=4)
        try:
            label = pred = self.predict(scores, labels.all, is_valid)
        except StopIteration as e:
            raise ParserException("No valid %s available\n%s" % (name, self.oracle.log if self.oracle else "")) from e
        label, is_correct, true_keys, true_values = self.correct(axis, label, pred, scores, true, true_keys)
        if self.training:
            if not (is_correct and ClassifierProperty.update_only_on_error in self.model.classifier_properties):
                assert not self.model.is_finalized, "Updating finalized model"
                self.model.classifier.update(
                    features, axis=axis, true=true_keys, pred=labels[pred] if axis == NODE_LABEL_KEY else pred.id,
                    importance=[self.config.args.swap_importance if a.is_swap else 1 for a in true_values] or None)
            if not is_correct and self.config.args.early_update:
                self.state.finished = True
        for model in self.models:
            model.classifier.finished_step(self.training)
            if axis != NODE_LABEL_KEY:
                model.classifier.transition(label, axis=axis)
        return label, pred

    def correct(self, axis, label, pred, scores, true, true_keys):
        """
        Compare the prediction to the truth, update the counters and, when training, replace a wrong choice
        :return: (label to use, is_correct, true_keys, true_values)
        """
        true_values = is_correct = ()
        if axis == NODE_LABEL_KEY:
            if self.oracle:
                is_correct = (label == true)
                if is_correct:
                    self.correct_label_count += 1
                elif self.training:
                    label = true
            # NOTE(review): indentation was lost in SOURCE; the count is assumed to be incremented
            # with or without an oracle, mirroring action_count below - confirm
            self.label_count += 1
        else:  # action
            true_keys, true_values = map(list, zip(*true.items())) if true else (None, None)
            label = true.get(pred.id)
            is_correct = (label is not None)
            if is_correct:
                self.correct_action_count += 1
            else:
                # When training, fall back to the highest-scoring true action
                label = true_values[scores[true_keys].argmax()] if self.training else pred
            self.action_count += 1
        return label, is_correct, true_keys, true_values

    @staticmethod
    def predict(scores, values, is_valid=None):
        """
        Choose action/label based on classifier
        Usually the best action/label is valid, so max is enough to choose it in O(n) time
        Otherwise, sorts all the other scores to choose the best valid one in O(n lg n)
        :return: valid action/label with maximum probability according to classifier
        :raises StopIteration: if no value is valid
        """
        return next(filter(is_valid, (values[i] for i in PassageParser.generate_descending(scores))))

    @staticmethod
    def generate_descending(scores):
        """ Yield the argmax index first, then all indices from the reversed argsort """
        yield scores.argmax()
        yield from scores.argsort()[::-1]  # Contains the max, but otherwise items might be missed (different order)

    def finish(self, status, display=True, write=False, accuracies=None):
        """
        Finalize the parse: create the output passage and optionally write, verify and evaluate it
        :param status: status string to show in the summary line
        :param display: whether to print the summary line
        :param write: whether to write the output passage in each requested format
        :param accuracies: optional dict; the action accuracy is stored in it under the passage ID
        :return: tuple of (output passage,), with the evaluation score appended if evaluating
        """
        self.model.classifier.finished_item(self.training)
        for model in self.models[1:]:
            model.classifier.finished_item(renew=False)  # So that dynet.renew_cg happens only once
        if not self.training or self.config.args.verify:
            self.out = self.state.create_passage(verify=self.config.args.verify, format=self.out_format)
        if write:
            for out_format in self.config.args.formats or [self.out_format]:
                if self.config.args.normalize and out_format == "ucca":
                    normalize(self.out)
                ioutil.write_passage(self.out, output_format=out_format, binary=out_format == "pickle",
                                     outdir=self.config.args.outdir, prefix=self.config.args.prefix,
                                     converter=get_output_converter(out_format), verbose=self.config.args.verbose,
                                     append=self.config.args.join, basename=self.config.args.join)
        if self.oracle and self.config.args.verify:
            self.verify(self.out, self.passage)
        ret = (self.out,)
        if self.evaluation:
            ret += (self.evaluate(self.evaluation),)
            status = "%-14s %s F1=%.3f" % (status, self.eval_type, self.f1)
        if display:
            self.config.print("%s%.3fs %s" % (self.accuracy_str, self.duration, status), level=1)
        if accuracies is not None:
            accuracies[self.passage.ID] = self.correct_action_count / self.action_count if self.action_count else 0
        return ret

    @property
    def accuracy_str(self):
        """ Formatted action (and label) accuracy, or an empty string without an oracle or actions """
        if self.oracle and self.action_count:
            accuracy_str = "a=%-14s" % percents_str(self.correct_action_count, self.action_count)
            if self.label_count:
                accuracy_str += " l=%-14s" % percents_str(self.correct_label_count, self.label_count)
            return "%-33s" % accuracy_str
        return ""

    def evaluate(self, mode=ParseMode.test):
        """
        Evaluate the output passage against the input passage
        :param mode: ParseMode value; in dev mode only the relevant evaluation type is computed
        :return: score object returned by the evaluator, with `lang` set on it
        """
        if self.format:
            self.config.print("Converting to %s and evaluating..." % self.format)
        self.eval_type = UNLABELED if self.config.is_unlabeled(self.in_format) else LABELED
        evaluator = EVALUATORS.get(self.format, evaluate_ucca)
        score = evaluator(self.out, self.passage, converter=get_output_converter(self.format),
                          verbose=self.out and self.config.args.verbose > 3,
                          constructions=self.config.args.constructions,
                          eval_types=(self.eval_type,) if mode is ParseMode.dev else (LABELED, UNLABELED))
        self.f1 = average_f1(score, self.eval_type)
        score.lang = self.lang
        return score

    def check_loop(self):
        """
        Check if the current state has already occurred, indicating a loop
        """
        h = hash(self.state)
        # Parenthesized so that only the oracle part is conditional: a conditional expression binds
        # looser than +, which would otherwise leave the message empty whenever there is no oracle
        assert h not in self.state_hash_history, \
            "\n".join(["Transition loop", self.state.str("\n")] + ([self.oracle.str("\n")] if self.oracle else []))
        self.state_hash_history.add(h)

    def verify(self, guessed, ref):
        """
        Compare predicted passage to true passage and raise an exception if they differ
        :param ref: true passage
        :param guessed: predicted passage to compare
        """
        assert ref.equals(guessed, ignore_node=self.ignore_node), \
            "Failed to produce true passage" + (diffutil.diff_passages(ref, guessed) if self.training else "")

    @property
    def num_tokens(self):
        return len(set(self.state.terminals).difference(self.state.buffer))  # To count even incomplete parses

    @num_tokens.setter
    def num_tokens(self, _):
        # The value is always derived from the state, so assignments are ignored
        pass
class Parser(object):
    """
    Main class to implement transition-based UCCA parser
    """

    def __init__(self, model_file=None, model_type=None, beam=1):
        """
        :param model_file: passed to Model as its second argument
        :param model_type: passed to Model as its first argument
        :param beam: stored but currently unused
        """
        self.args = Config().args
        self.state = None  # State object created at each parse
        self.oracle = None  # Oracle object created at each parse
        self.action_count = self.correct_action_count = self.total_actions = self.total_correct_actions = 0
        self.label_count = self.correct_label_count = self.total_labels = self.total_correct_labels = 0
        self.model = Model(model_type, model_file)
        self.update_only_on_error = \
            ClassifierProperty.update_only_on_error in self.model.model.get_classifier_properties()
        self.beam = beam  # Currently unused
        self.state_hash_history = None  # For loop checking
        # Used in verify_passage to optionally ignore a mismatch in linkage nodes:
        self.ignore_node = None if self.args.linkage else lambda n: n.tag == layer1.NodeTags.Linkage
        self.best_score = self.dev = self.iteration = self.eval_index = None
        self.trained = False

    def train(self, passages=None, dev=None, iterations=1):
        """
        Train parser on given passages
        :param passages: iterable of passages to train on
        :param dev: iterable of passages to tune on
        :param iterations: number of iterations to perform
        :return: generator yielding the result of eval_and_save() after each iteration
        """
        self.trained = True
        if passages:
            if ClassifierProperty.trainable_after_saving in self.model.model.get_classifier_properties():
                try:
                    self.model.load()
                except FileNotFoundError:
                    print("not found, starting from untrained model.")
            self.best_score = 0
            self.dev = dev
            for self.iteration in range(1, iterations + 1):
                self.eval_index = 0
                print("Training iteration %d of %d: " % (self.iteration, iterations))
                Config().random.shuffle(passages)
                list(self.parse(passages, mode=ParseMode.train))  # parse() is a generator: exhaust it
                yield self.eval_and_save(self.iteration == iterations, finished_epoch=True)
            print("Trained %d iterations" % iterations)
        if dev or not passages:
            self.model.load()

    def eval_and_save(self, last=False, finished_epoch=False):
        """
        Finalize the model, evaluate it on the dev passages if there are any, and save it if appropriate
        :param last: whether this is the final call; if not, the non-finalized model is restored afterwards
        :param finished_epoch: passed to the model's finalize()
        :return: Scores on the dev passages, or None if there are none
        """
        scores = None
        model = self.model
        self.model = self.model.finalize(finished_epoch=finished_epoch)
        if self.dev:
            if not self.best_score:
                self.model.save()
            print("Evaluating on dev passages")
            passage_scores = [s for _, s in self.parse(self.dev, mode=ParseMode.dev, evaluate=True)]
            scores = Scores(passage_scores)
            average_score = scores.average_f1()
            print("Average labeled F1 score on dev: %.3f" % average_score)
            print_scores(scores, self.args.devscores, prefix_title="iteration",
                         prefix=[self.iteration] + ([self.eval_index] if self.args.save_every else []))
            if average_score >= self.best_score:
                print("Better than previous best score (%.3f)" % self.best_score)
                if self.best_score:
                    self.model.save()
                self.best_score = average_score
            else:
                print("Not better than previous best score (%.3f)" % self.best_score)
        elif last or self.args.save_every is not None:
            self.model.save()
        if not last:
            self.model = model  # Restore non-finalized model
            self.model.load_labels()
        return scores

    def parse(self, passages, mode=ParseMode.test, evaluate=False):
        """
        Parse given passages
        :param passages: iterable of passages to parse
        :param mode: ParseMode value.
                     If train, use oracle to train on given passages.
                     Otherwise, just parse with classifier.
        :param evaluate: whether to evaluate parsed passages with respect to given ones.
                         Only possible when given passages are annotated.
        :return: generator of parsed passages (or in train mode, the original ones),
                 or, if evaluate=True, of pairs of (Passage, Scores).
        """
        assert mode in ParseMode, "Invalid parse mode: %s" % mode
        train = (mode is ParseMode.train)
        if not train and not self.trained:
            list(self.train())
        passage_word = "sentence" if self.args.sentences else \
            "paragraph" if self.args.paragraphs else \
            "passage"
        self.total_actions = 0
        self.total_correct_actions = 0
        # Reset the label totals as well, so the overall label accuracy printed below covers
        # the same passages as the overall action accuracy
        self.total_labels = 0
        self.total_correct_labels = 0
        total_duration = 0
        total_tokens = 0
        passage_index = 0
        if not hasattr(passages, "__iter__"):  # Single passage given
            passages = (passages,)
        for passage_index, passage in enumerate(passages):
            labeled = any(n.outgoing or n.attrib.get(LABEL_ATTRIB) for n in passage.layer(layer1.LAYER_ID).all)
            assert not train or labeled, "Cannot train on unannotated passage: %s" % passage.ID
            assert not evaluate or labeled, "Cannot evaluate on unannotated passage: %s" % passage.ID
            print("%s %-7s" % (passage_word, passage.ID), end=Config().line_end, flush=True)
            started = time.time()
            self.action_count = self.correct_action_count = self.label_count = self.correct_label_count = 0
            textutil.annotate(passage, verbose=self.args.verbose > 1)  # tag POS and parse dependencies
            self.state = State(passage)
            self.state_hash_history = set()
            self.oracle = Oracle(passage) if train or (
                self.args.verbose or Config().args.use_gold_node_labels) and labeled or self.args.verify else None
            failed = False
            if ClassifierProperty.require_init_features in self.model.model.get_classifier_properties():
                self.model.init_features(self.state, train)
            try:
                self.parse_passage(train)  # This is where the actual parsing takes place
            except ParserException as e:
                if train:
                    raise
                Config().log("%s %s: %s" % (passage_word, passage.ID, e))
                failed = True
            guessed = self.state.create_passage(verify=self.args.verify) if not train or self.args.verify else passage
            duration = time.time() - started
            total_duration += duration
            num_tokens = len(set(self.state.terminals).difference(self.state.buffer))
            total_tokens += num_tokens
            if self.oracle:  # We have an oracle to verify by
                if not failed and self.args.verify:
                    self.verify_passage(guessed, passage, train)
                if self.action_count:
                    accuracy_str = "%d%% (%d/%d)" % (100 * self.correct_action_count / self.action_count,
                                                     self.correct_action_count, self.action_count)
                    if self.label_count:
                        accuracy_str += " %d%% (%d/%d)" % (100 * self.correct_label_count / self.label_count,
                                                           self.correct_label_count, self.label_count)
                    print("%-30s" % accuracy_str, end=Config().line_end)
            print("%0.3fs" % duration, end="")
            print("%-15s" % (" (failed)" if failed else " (%d tokens/s)" % (num_tokens / duration)), end="")
            print(Config().line_end, end="")
            if self.oracle:
                print(Config().line_end, flush=True)
            self.model.model.finished_item(train)
            self.total_correct_actions += self.correct_action_count
            self.total_actions += self.action_count
            self.total_correct_labels += self.correct_label_count
            self.total_labels += self.label_count
            if train and self.args.save_every and (passage_index + 1) % self.args.save_every == 0:
                self.eval_and_save()
                self.eval_index += 1
            yield (guessed, self.evaluate_passage(guessed, passage)) if evaluate else guessed

        if passages:
            print("Parsed %d %ss" % (passage_index + 1, passage_word))
            if self.oracle and self.total_actions:
                accuracy_str = "%d%% correct actions (%d/%d)" % (100 * self.total_correct_actions / self.total_actions,
                                                                 self.total_correct_actions, self.total_actions)
                if self.total_labels:
                    accuracy_str += ", %d%% correct labels (%d/%d)" % (
                        100 * self.total_correct_labels / self.total_labels,
                        self.total_correct_labels, self.total_labels)
                print("Overall %s on %s" % (accuracy_str, mode.name))
            if total_duration:
                print("Total time: %.3fs (average time/%s: %.3fs, average tokens/s: %d)" % (
                    total_duration, passage_word, total_duration / (passage_index + 1),
                    total_tokens / total_duration), flush=True)

    def parse_passage(self, train):
        """
        Internal method to parse a single passage
        :param train: use oracle to train on given passages, or just parse with classifier?
        :raises ParserException: if the chosen transition is invalid
        """
        if self.args.verbose > 1:
            print(" initial state: %s" % self.state)
        while True:
            if self.args.check_loops:
                self.check_loop()
            features = self.model.feature_extractor.extract_features(self.state)
            true_actions = self.get_true_actions(train)
            action, predicted_action = self.choose_action(features, train, true_actions)
            try:
                self.state.transition(action)
            except AssertionError as e:
                raise ParserException("Invalid transition: %s %s" % (action, self.state)) from e
            if self.args.verbose > 1:
                if self.oracle:
                    print(" predicted: %-15s true: %-15s taken: %-15s %s" % (
                        predicted_action, "|".join(map(str, true_actions.values())), action, self.state))
                else:
                    print(" action: %-15s %s" % (action, self.state))
            if self.state.need_label:  # Label action that requires a choice of label
                true_label = self.get_true_label(action.orig_node)
                label, predicted_label = self.choose_label(features, train, true_label)
                self.state.label_node(label)
                if self.args.verbose > 1:
                    if self.oracle and not Config().args.use_gold_node_labels:
                        print(" predicted label: %-15s true label: %-15s" % (predicted_label, true_label))
                    else:
                        print(" label: %-15s" % label)
            self.model.model.finished_step(train)
            if self.args.verbose > 1:
                for line in self.state.log:
                    print(" " + line)
            if self.state.finished:
                return  # action is Finish (or early update is triggered)

    def get_true_actions(self, train):
        """
        :param train: whether training; oracle errors are only raised when training
        :return: dict of true actions from the oracle; empty without an oracle or on a suppressed error
        :raises ParserException: if the oracle fails during training
        """
        true_actions = {}
        if self.oracle:
            try:
                true_actions = self.oracle.get_actions(self.state, self.model.actions, create=train)
            except (AttributeError, AssertionError) as e:
                if train:
                    raise ParserException("Error in oracle during training") from e
        return true_actions

    def choose_action(self, features, train, true_actions):
        """
        Predict the best valid action and, when training, update the model towards the best true action
        :param features: extracted features for the current state
        :param train: whether to update the model
        :param true_actions: dict of true actions from the oracle
        :return: pair of (action to take, predicted action)
        :raises ParserException: if no valid action is available
        """
        scores = self.model.model.score(features, axis=ACTION_AXIS)  # Returns NumPy array
        if self.args.verbose > 2:
            print(" action scores: " + ",".join(("%s: %g" % x for x in zip(self.model.actions.all, scores))))
        try:
            predicted_action = self.predict(scores, self.model.actions.all, self.state.is_valid_action)
        except StopIteration as e:
            raise ParserException("No valid action available\n%s" % (self.oracle.log if self.oracle else "")) from e
        action = true_actions.get(predicted_action.id)
        is_correct = (action is not None)
        if is_correct:
            self.correct_action_count += 1
        else:
            action = Config().random.choice(list(true_actions.values())) if train else predicted_action
        if train and not (is_correct and self.update_only_on_error):
            # Best true action: the one with the highest score among the true ones
            best_action = self.predict(scores[list(true_actions.keys())], list(true_actions.values()))
            self.model.model.update(features, axis=ACTION_AXIS, pred=predicted_action.id, true=best_action.id,
                                    importance=self.args.swap_importance if best_action.is_swap else 1)
        if train and not is_correct and self.args.early_update:
            self.state.finished = True
        self.action_count += 1
        return action, predicted_action

    def get_true_label(self, node):
        """
        :param node: node to read the label attribute from, or None
        :return: true label up to the first LABEL_SEPARATOR, or None if unavailable
        :raises ParserException: if the true label is not valid in the current state
        """
        true_label = None
        if self.oracle:
            if node is not None:
                true_label = node.attrib.get(LABEL_ATTRIB)
            if true_label is not None:
                true_label, _, _ = true_label.partition(LABEL_SEPARATOR)
                if not self.state.is_valid_label(true_label):
                    raise ParserException("True label is invalid: %s %s" % (true_label, self.state))
        return true_label

    def choose_label(self, features, train, true_label):
        """
        Predict the best valid node label and, when training, update the model towards the true label
        :param features: extracted features for the current state
        :param train: whether to update the model
        :param true_label: true label, as returned by get_true_label()
        :return: pair of (label to use, predicted label)
        """
        true_id = self.model.labels[true_label] if self.oracle else None  # Needs to happen before score()
        if Config().args.use_gold_node_labels:
            return true_label, true_label
        scores = self.model.model.score(features, axis=LABEL_AXIS)
        if self.args.verbose > 2:
            print(" label scores: " + ",".join(("%s: %g" % x for x in zip(self.model.labels.all, scores))))
        label = predicted_label = self.predict(scores, self.model.labels.all, self.state.is_valid_label)
        if self.oracle:
            is_correct = (label == true_label)
            if is_correct:
                self.correct_label_count += 1
            if train and not (is_correct and self.update_only_on_error):
                self.model.model.update(features, axis=LABEL_AXIS, pred=self.model.labels[label], true=true_id)
                label = true_label
        self.label_count += 1
        return label, predicted_label

    def predict(self, scores, values, is_valid=None):
        """
        Choose action/label based on classifier
        Usually the best action/label is valid, so max is enough to choose it in O(n) time
        Otherwise, sorts all the other scores to choose the best valid one in O(n lg n)
        :return: valid action/label with maximum probability according to classifier
        :raises StopIteration: if no value is valid
        """
        return next(filter(is_valid, (values[i] for i in self.generate_descending(scores))))

    @staticmethod
    def generate_descending(scores):
        """ Yield the argmax index first, then all indices from the reversed argsort """
        yield scores.argmax()
        yield from scores.argsort()[::-1]  # Contains the max, but otherwise items might be missed (different order)

    def check_loop(self):
        """
        Check if the current state has already occurred, indicating a loop
        """
        h = hash(self.state)
        # Parenthesized so that only the oracle part is conditional: a conditional expression binds
        # looser than +, which would otherwise leave the message empty whenever there is no oracle
        assert h not in self.state_hash_history, \
            "\n".join(["Transition loop", self.state.str("\n")] + ([self.oracle.str("\n")] if self.oracle else []))
        self.state_hash_history.add(h)

    def evaluate_passage(self, guessed, ref):
        """
        Evaluate a predicted passage against the reference, printing the average F1
        :param guessed: predicted passage
        :param ref: true passage; its "format" entry selects the converter and the evaluator
        :return: score object returned by the evaluator
        """
        converters = CONVERTERS.get(ref.extra.get("format"))  # returns (input converter, output converter) tuple
        score = EVALUATORS.get(ref.extra.get("format"), evaluation).evaluate(
            guessed, ref, converter=converters and converters[1],  # converter output is list of lines
            verbose=guessed and self.args.verbose > 2, constructions=self.args.constructions)
        print("F1=%.3f" % score.average_f1(), flush=True)
        return score

    def verify_passage(self, guessed, ref, show_diff):
        """
        Compare predicted passage to true passage and raise an exception if they differ
        :param ref: true passage
        :param guessed: predicted passage to compare
        :param show_diff: if passages differ, show the difference between them?
                          Depends on guessed having the original node IDs annotated in the "remarks" field for each node
        """
        assert ref.equals(guessed, ignore_node=self.ignore_node), \
            "Failed to produce true passage" + (diffutil.diff_passages(ref, guessed) if show_diff else "")