def _get_results(self):
  """Compute entity-level counts (correct / predicted / gold spans) over all
  sentences, then let the parent scorer turn them into P/R/F1.

  Spans are compared as exact (start, end, label) matches after decoding the
  tag sequences with the inverse label mapping.
  """
  n_correct = n_predicted = n_gold = 0
  for labels, preds in zip(self._labels, self._preds):
    gold_spans = set(
        tagging_utils.get_span_labels(labels, self._inv_label_mapping))
    pred_spans = set(
        tagging_utils.get_span_labels(preds, self._inv_label_mapping))
    # A span counts as correct only on an exact boundary + label match.
    n_correct += len(gold_spans & pred_spans)
    n_gold += len(gold_spans)
    n_predicted += len(pred_spans)
  self._n_correct = n_correct
  self._n_predicted = n_predicted
  self._n_gold = n_gold
  return super(EntityLevelF1Scorer, self)._get_results()
def _get_improved_span_labels(self, generate_labels, generate_preds):
  """Re-anchor per-feature span labels into original-document coordinates,
  filtering out spans that fall outside the feature's valid token window.

  Returns a list with one set of (start, end, label) spans per eval example,
  aggregated across all features produced for that example.

  NOTE(review): if both flags are True, `generate_preds` wins (the second
  assignment overwrites the first); if both are False, `sentence_tags` is
  unbound and the inner call raises — presumably callers pass exactly one
  flag as True. Confirm against call sites.
  """
  # Map each feature's example id back to its row in self._labels/_preds.
  eid_to_idx_dict = {eid: idx for idx, eid in enumerate(self._eids)}
  all_results = []
  for example_index, example in enumerate(self._eval_examples):
    result_spans = []
    features = self._task.featurize(example, False, for_eval=True)
    for (feature_index, feature) in enumerate(features):
      idx = eid_to_idx_dict[feature[self._name + "_eid"]]
      if generate_labels:
        sentence_tags = self._labels[idx]
      if generate_preds:
        sentence_tags = self._preds[idx]
      labeled_positions = self._labeled_positions[idx]
      for (s, e, l) in tagging_utils.get_span_labels(
          sentence_tags, self._inv_label_mapping):
        # Positions within the feature's labeled-token sequence.
        start_index = labeled_positions[s]
        end_index = labeled_positions[e]
        # Shift s/e into original-document word coordinates (used both for
        # logging the surface text and for the returned spans).
        s = s + feature[self._name + "_doc_span_orig_start"]
        e = e + feature[self._name + "_doc_span_orig_start"]
        # Shift token indices into this doc span's local frame; the +1
        # presumably skips a leading special token ([CLS]) — TODO confirm.
        start_index = start_index - feature[self._name + "_doc_span_start"] + 1
        end_index = end_index - feature[self._name + "_doc_span_start"] + 1
        # Filters below drop spans that cannot be mapped back to original
        # tokens; each rejection is logged with a distinct error code.
        if start_index not in feature[self._name + "_token_to_orig_map"]:
          utils.log(example.orig_id, generate_labels,
                    "".join(example.words[s:e + 1]), l, "error", "4")
          continue
        if end_index not in feature[self._name + "_token_to_orig_map"]:
          utils.log(example.orig_id, generate_labels,
                    "".join(example.words[s:e + 1]), l, "error", "5")
          continue
        # Only keep spans whose start token has maximal context in this
        # feature (avoids double-counting across overlapping doc spans).
        if not feature[self._name + "_token_is_max_context"].get(
            start_index, False):
          utils.log(example.orig_id, generate_labels,
                    "".join(example.words[s:e + 1]), l, "error", "6")
          continue
        if end_index < start_index:
          utils.log(example.orig_id, generate_labels,
                    "".join(example.words[s:e + 1]), l, "error", "7")
          continue
        length = end_index - start_index + 1
        if length > self._config.max_answer_length:
          utils.log(example.orig_id, generate_labels,
                    "".join(example.words[s:e + 1]), l, "error", "8")
          continue
        result_spans.append((s, e, l))
    all_results.append(set(result_spans))
  return all_results
def __init__(self, eid, task_name, words, tags, is_token_level, label_mapping):
  """Construct a tagging example, mapping its tags to integer label ids.

  For token-level tasks the tags are used as-is; for span-level tasks they
  are first canonicalized by round-tripping through span form so that the
  stored labels follow LABEL_ENCODING.
  """
  super(TaggingExample, self).__init__(task_name)
  self.eid = eid
  self.words = words
  if is_token_level:
    labels = tags
  else:
    # Canonicalize: tags -> spans -> LABEL_ENCODING-style tags.
    labels = tagging_utils.get_tags(
        tagging_utils.get_span_labels(tags), len(words), LABEL_ENCODING)
  self.labels = [label_mapping[l] for l in labels]
def _get_label_mapping(self, provided_split=None, provided_sentences=None):
  """Return the tag -> id mapping for this task, building and caching it if
  needed.

  Lookup order: in-memory cache, then the pickled mapping on disk, then a
  fresh build from whichever of train/dev/test split files exist. When
  `provided_split`/`provided_sentences` are given, those sentences are used
  for that split instead of re-reading it.

  For the CCG task, tags never seen in training are all collapsed onto a
  single extra id so the label space stays closed over the train set.
  """
  if self._label_mapping is not None:
    return self._label_mapping
  if tf.io.gfile.exists(self._label_mapping_path):
    self._label_mapping = utils.load_pickle(self._label_mapping_path)
    return self._label_mapping

  utils.log("Writing label mapping for task", self.name)
  tag_counts = collections.Counter()
  train_tags = set()
  raw_data_dir = self.config.raw_data_dir(self.name)  # loop-invariant
  for split in ["train", "dev", "test"]:
    if not tf.io.gfile.exists(os.path.join(raw_data_dir, split + ".txt")):
      continue
    if split == provided_split:
      split_sentences = provided_sentences
    else:
      split_sentences = self._get_labeled_sentences(split)
    for _, tags in split_sentences:
      if not self._is_token_level:
        # Canonicalize span-level tags into LABEL_ENCODING form before
        # counting, matching how examples are labeled at featurize time.
        span_labels = tagging_utils.get_span_labels(tags)
        tags = tagging_utils.get_tags(span_labels, len(tags), LABEL_ENCODING)
      tag_counts.update(tags)
      # NOTE(review): this condition looks like it was meant to be
      # `split == "train"` (collect the tags seen in the train split); as
      # written, train_tags is populated from every split, and only when
      # the caller passed provided_split="train". Preserved as-is — confirm
      # intent before changing.
      if provided_split == "train":
        train_tags.update(tags)

  if self.name == "ccg":
    # Tags absent from training all share one id past the trained labels.
    infrequent_tags = [tag for tag in tag_counts if tag not in train_tags]
    label_mapping = {
        label: i for i, label in enumerate(
            sorted(t for t in tag_counts if t not in infrequent_tags))
    }
    n = len(label_mapping)
    for tag in infrequent_tags:
      label_mapping[tag] = n
  else:
    label_mapping = {label: i for i, label in enumerate(sorted(tag_counts))}

  utils.write_pickle(label_mapping, self._label_mapping_path)
  self._label_mapping = label_mapping
  return label_mapping