Example #1
def __init__(self, eid, task_name, words, tags, is_token_level,
             label_mapping):
    super(TaggingExample, self).__init__(task_name)
    self.eid = eid
    self.words = words
    if is_token_level:
        # Token-level tasks (e.g. POS tagging) use the tags as-is.
        labels = tags
    else:
        # Span-level tasks (e.g. NER, chunking) are first converted to
        # one tag per token using the configured encoding (e.g. BIO).
        span_labels = tagging_utils.get_span_labels(tags)
        labels = tagging_utils.get_tags(span_labels, len(words),
                                        LABEL_ENCODING)
    # Map each string tag to its integer id.
    self.labels = [label_mapping[l] for l in labels]
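A minimal sketch of how the constructor's inputs fit together for a token-level task. The words, tags, and label_mapping values below are invented for illustration; only the final list comprehension mirrors the snippet above:

words = ["John", "lives", "in", "Berlin"]
tags = ["B-PER", "O", "O", "B-LOC"]
label_mapping = {"B-LOC": 0, "B-PER": 1, "O": 2}

# With is_token_level=True the tags are used directly, so the
# resulting labels would be:
labels = [label_mapping[t] for t in tags]
print(labels)  # [1, 2, 2, 0]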
Example #2
def _get_label_mapping(self, provided_split=None, provided_sentences=None):
    # Reuse a mapping cached in memory or on disk if one exists.
    if self._label_mapping is not None:
        return self._label_mapping
    if tf.io.gfile.exists(self._label_mapping_path):
        self._label_mapping = utils.load_pickle(self._label_mapping_path)
        return self._label_mapping
    utils.log("Writing label mapping for task", self.name)
    tag_counts = collections.Counter()
    train_tags = set()
    for split in ["train", "dev", "test"]:
        if not tf.io.gfile.exists(
                os.path.join(self.config.raw_data_dir(self.name),
                             split + ".txt")):
            continue
        # Reuse the sentences passed in for the provided split; load the
        # other splits from disk.
        if split == provided_split:
            split_sentences = provided_sentences
        else:
            split_sentences = self._get_labeled_sentences(split)
        for _, tags in split_sentences:
            if not self._is_token_level:
                # Convert span annotations into per-token tags.
                span_labels = tagging_utils.get_span_labels(tags)
                tags = tagging_utils.get_tags(span_labels, len(tags),
                                              LABEL_ENCODING)
            for tag in tags:
                tag_counts[tag] += 1
                if split == "train":
                    train_tags.add(tag)
    if self.name == "ccg":
        # CCG's supertag set is large; tags never seen in the training
        # split are all collapsed into a single shared id.
        infrequent_tags = [tag for tag in tag_counts if tag not in train_tags]
        label_mapping = {
            label: i
            for i, label in enumerate(
                sorted(t for t in tag_counts if t not in infrequent_tags))
        }
        n = len(label_mapping)
        for tag in infrequent_tags:
            label_mapping[tag] = n
    else:
        # Sort so the tag -> id assignment is deterministic.
        labels = sorted(tag_counts.keys())
        label_mapping = {label: i for i, label in enumerate(labels)}
    utils.write_pickle(label_mapping, self._label_mapping_path)
    self._label_mapping = label_mapping
    return label_mapping
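To see what the non-CCG branch produces, here is a self-contained sketch that counts toy BIO tags and assigns ids in sorted order; the tags are invented for illustration:

import collections

tag_counts = collections.Counter(
    ["O", "B-PER", "I-PER", "O", "B-LOC", "O"])
label_mapping = {label: i for i, label in enumerate(sorted(tag_counts))}
print(label_mapping)  # {'B-LOC': 0, 'B-PER': 1, 'I-PER': 2, 'O': 3}

The sorting is what keeps the pickled mapping stable across runs.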
Example #3
def _get_labeled_sentences(self, split):
    sentences = []
    entry_ids = []
    if split not in self._word_to_char_mapping:
        self._word_to_char_mapping[split] = collections.OrderedDict()
    with tf.io.gfile.GFile(
            os.path.join(
                self.config.raw_data_dir(self.name),
                split + ("-debug" if self.config.debug else "") + ".json"),
            "r") as f:
        input_data = json.load(f)["data"]

    for entry in input_data:
        entry_ids.append(entry["id"])
        for paragraph in entry["paragraphs"]:
            paragraph_text = paragraph["context"]
            doc_tokens = []
            char_to_word_offset = []
            span_labels = []
            text_b_texts = []

            # Tokenize the context: whitespace separates tokens and every
            # Chinese character becomes its own token. char_to_word_offset
            # records, for each character, the index of its token.
            prev_is_whitespace = True
            prev_is_chinese = True
            for c in paragraph_text:
                if tagging_utils.is_whitespace(c):
                    prev_is_whitespace = True
                else:
                    if (prev_is_whitespace or prev_is_chinese or
                            tagging_utils.is_chinese_char(c)):
                        doc_tokens.append(c)
                        prev_is_chinese = tagging_utils.is_chinese_char(c)
                    else:
                        doc_tokens[-1] += c
                        prev_is_chinese = False
                    prev_is_whitespace = False
                char_to_word_offset.append(len(doc_tokens) - 1)

            for qa in paragraph["qas"]:
                question_text = qa["question"]
                text_b_text = tagging_utils.get_event_type(question_text)
                label_text = tagging_utils.get_question_text(question_text)
                text_b_texts.append(text_b_text)
                if split in ("train", "dev"):
                    if not qa["is_impossible"]:
                        # Convert the character-level answer span into
                        # token-level start/end positions.
                        answer = qa["answers"][0]
                        answer_offset = answer["answer_start"]
                        answer_length = len(answer["text"])
                        if (answer_offset + answer_length - 1 >=
                                len(char_to_word_offset)):
                            utils.log("End position is out of document!")
                            continue
                        start_position = char_to_word_offset[answer_offset]
                        end_position = char_to_word_offset[
                            answer_offset + answer_length - 1]
                        span_labels.append(
                            (start_position, end_position, label_text))

            # Every question in a paragraph must share one event type.
            assert len(set(text_b_texts)) == 1
            tags = tagging_utils.get_tags(span_labels, len(doc_tokens),
                                          LABEL_ENCODING)

            sentence = list(zip(doc_tokens, tags))
            words, tags = zip(*sentence)
            sentences.append((words, tags, text_b_texts[0]))

            # Invert the offsets: map each token index to the offset of
            # its last character (later characters overwrite earlier ones).
            self._word_to_char_mapping[split][entry["id"]] = {
                w: c for c, w in enumerate(char_to_word_offset)
            }
    assert len(sentences) == len(entry_ids)
    return sentences, entry_ids
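The trickiest part of this method is the char_to_word_offset bookkeeping, so here is a standalone sketch of just that loop for a purely whitespace-tokenized string (the is_chinese_char handling from the snippet above is omitted):

paragraph_text = "ab cd"
doc_tokens, char_to_word_offset = [], []
prev_is_whitespace = True
for c in paragraph_text:
    if c.isspace():
        prev_is_whitespace = True
    else:
        if prev_is_whitespace:
            doc_tokens.append(c)  # start a new token
        else:
            doc_tokens[-1] += c   # extend the current token
        prev_is_whitespace = False
    char_to_word_offset.append(len(doc_tokens) - 1)

print(doc_tokens)           # ['ab', 'cd']
print(char_to_word_offset)  # [0, 0, 0, 1, 1]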