コード例 #1
0
    def _create_examples(path: str, set_type: str) -> List[InputExample]:
        """Load COPA-style examples from a JSON-lines file.

        Each line must provide ``idx``, ``premise``, ``choice1``, ``choice2`` and
        ``question``; ``label`` is optional and becomes ``None`` when absent.

        For the ``train`` and ``unlabeled`` splits a mirrored copy of every
        example is appended, with ``choice1``/``choice2`` swapped and the label
        flipped (0 <-> 1). Unlabeled examples stay unlabeled in the mirror.

        :param path: path to the JSON-lines file
        :param set_type: split name, used as the guid prefix
        :return: the loaded examples, followed by their mirrors where applicable
        """
        examples = []

        with open(path, encoding='utf8') as f:
            for line in f:
                example_json = json.loads(line)
                label = example_json.get('label')
                idx = example_json['idx']
                guid = "%s-%s" % (set_type, idx)
                text_a = example_json['premise']
                meta = {
                    'choice1': example_json['choice1'],
                    'choice2': example_json['choice2'],
                    'question': example_json['question']
                }
                example = InputExample(guid=guid, text_a=text_a, label=label, meta=meta, idx=idx)
                examples.append(example)

        if set_type == 'train' or set_type == 'unlabeled':
            mirror_examples = []
            for ex in examples:
                # Swapping the choices swaps which one is correct. A missing label
                # must stay missing rather than silently becoming 0.
                if ex.label is None:
                    label = None
                else:
                    label = 1 if ex.label == 0 else 0
                meta = {
                    'choice1': ex.meta['choice2'],
                    'choice2': ex.meta['choice1'],
                    'question': ex.meta['question']
                }
                mirror_example = InputExample(guid=ex.guid + 'm', text_a=ex.text_a, label=label, meta=meta)
                mirror_examples.append(mirror_example)
            examples += mirror_examples
            print_rank_0(f"Added {len(mirror_examples)} mirror examples, total size is {len(examples)}...")
        return examples
コード例 #2
0
    def _create_examples(path: str, set_type: str) -> List[InputExample]:
        """Expand a MultiRC-style JSON-lines file into one example per answer.

        Every (passage, question, answer) triple becomes an ``InputExample`` with
        the passage in ``text_a`` and the question in ``text_b``. The answer text
        and the three indices go into ``meta``. ``label`` is taken from the answer
        record when present, otherwise ``None``.

        :param path: path to the JSON-lines file
        :param set_type: split name, used as the guid prefix
        :return: all generated examples (a summary is printed via ``print_rank_0``)
        """
        examples = []

        with open(path, encoding='utf8') as f:
            for raw_line in f:
                record = json.loads(raw_line)
                p_idx = record['idx']
                passage = record['passage']
                passage_text = punctuation_standardization(passage['text'])
                for q_json in passage['questions']:
                    q_text = punctuation_standardization(q_json["question"])
                    q_idx = q_json['idx']
                    for a_json in q_json["answers"]:
                        a_label = a_json.get("label")
                        a_idx = a_json["idx"]
                        examples.append(InputExample(
                            guid=f'{set_type}-p{p_idx}-q{q_idx}-a{a_idx}',
                            text_a=passage_text,
                            text_b=q_text,
                            label=a_label,
                            meta={
                                'passage_idx': p_idx,
                                'question_idx': q_idx,
                                'answer_idx': a_idx,
                                'answer': punctuation_standardization(a_json["text"]),
                            },
                            idx=[p_idx, q_idx, a_idx],
                        ))

        n_questions = len({ex.meta['question_idx'] for ex in examples})
        label_counts = Counter(ex.label for ex in examples)
        print_rank_0(
            f"Returning {len(examples)} examples corresponding to {n_questions} questions with label "
            f"distribution {list(label_counts.items())}")
        return examples
コード例 #3
0
 def create_examples(self, split):
     """Create CMRC reading-comprehension examples for one split.

     Reads ``train.json``/``dev.json``/``test.json`` from ``self.data_dir``, using
     the nested ``data`` -> ``paragraphs`` -> ``qas`` layout. One example is
     emitted per distinct answer text of each question. The test split uses the
     single placeholder answer ``"FAKE_ANSWER"``.

     :param split: ``"train"``, ``"dev"`` or ``"test"``
     :return: list of ``InputExample`` with the paragraph context in ``text_a``
         and ``answer``/``question``/``ref`` in ``meta``; ``ref`` is the answer
         after a tokenizer encode/decode round trip
     :raises NotImplementedError: for any other split name
     """
     if split == "train":
         filename = "train.json"
     elif split == "dev":
         filename = "dev.json"
     elif split == "test":
         filename = "test.json"
     else:
         raise NotImplementedError(split)
     print_rank_0(f"Creating CMRC-{split} dataset from {self.data_dir}")
     example_list = []
     idx = 0
     with open(os.path.join(self.data_dir, filename), encoding='utf-8') as file:
         dataset = json.load(file)
     for article in dataset['data']:
         for paragraph in article['paragraphs']:
             context = paragraph['context']
             for qa in paragraph['qas']:
                 question = qa["question"]
                 if split != 'test':
                     # Deduplicate answer texts, keeping first-seen order. A set of
                     # strings iterates in a hash-randomised order, which would make
                     # example order and guids differ between processes/ranks.
                     answers = list(dict.fromkeys(answer['text'] for answer in qa["answers"]))
                 else:
                     answers = ["FAKE_ANSWER"]
                 for answer in answers:
                     guid = "%s-%s" % (split, idx)
                     meta = {
                         "answer": answer,
                         "question": question,
                         "ref": self.tokenizer.DecodeIds(self.tokenizer.EncodeAsIds(answer).tokenization)}
                     example = InputExample(guid=guid, text_a=context, meta=meta)
                     if idx < 10:
                         print_rank_0(
                             (context.encode('utf-8'), answer.encode('utf-8'), meta["ref"].encode('utf-8')))
                     example_list.append(example)
                     idx += 1
     print_rank_0(f"Creating {len(example_list)} examples for {split}")
     return example_list
コード例 #4
0
    def _create_examples(self,
                         path: str,
                         set_type: str,
                         hypothesis_name: str = "hypothesis",
                         premise_name: str = "premise") -> List[InputExample]:
        """Load sentence-pair (premise/hypothesis) examples from a JSON-lines file.

        String ``idx`` values are converted to ``int``. If that conversion fails,
        the 0-based line number is used instead. Non-string ``idx`` values are
        kept as-is. ``label`` is ``None`` when absent.

        :param path: path to the JSON-lines file
        :param set_type: split name, used as the guid prefix
        :param hypothesis_name: JSON key holding the hypothesis (``text_b``)
        :param premise_name: JSON key holding the premise (``text_a``)
        :return: list of examples in file order
        """
        def int_or(value, fallback):
            try:
                return int(value)
            except ValueError:
                return fallback

        collected = []
        with open(path, encoding='utf8') as f:
            for position, raw in enumerate(f):
                record = json.loads(raw)
                example_id = record['idx']
                if isinstance(example_id, str):
                    example_id = int_or(example_id, position)
                collected.append(InputExample(guid="%s-%s" % (set_type, example_id),
                                              text_a=record[premise_name],
                                              text_b=record[hypothesis_name],
                                              label=record.get('label'),
                                              idx=example_id))
        return collected
コード例 #5
0
 def __init__(self, args, split, tokenizer):
     """Load a line-aligned ``<stem>.source`` / ``<stem>.target`` split.

     The split name selects the file stem (``train``, ``val`` for ``dev``,
     ``test``). For the ``gigaword`` and ``cnn_dm`` tasks every stripped line is
     passed through the matching detokenizer; targets are detokenized with
     ``is_target=True``. Each example's ``meta["ref"]`` holds the target after a
     tokenizer encode/decode round trip.

     Populates ``self.examples`` (guid -> example) and ``self.example_list``.

     :raises NotImplementedError: for an unknown split name
     """
     self.args = args
     task, data_dir = args.task.lower(), args.data_dir
     self.max_src_length, self.max_tgt_length = args.src_seq_length, args.tgt_seq_length
     self.split = split
     self.tokenizer = tokenizer
     stems = {"train": "train", "dev": "val", "test": "test"}
     if split not in stems:
         raise NotImplementedError(split)
     filename = stems[split]
     print_rank_0(f"Creating {task}-{split} dataset from {data_dir}")
     self.dataset_name = split
     if task == "gigaword":
         detokenizer = gigaword_detokenize
     elif task == "cnn_dm":
         detokenizer = cnndm_detokenize
     else:
         detokenizer = None

     def load(suffix, **detok_kwargs):
         # Read one side of the parallel corpus, stripping and detokenizing lines.
         collected = []
         with open(os.path.join(data_dir, f"{filename}.{suffix}"), encoding='utf-8') as handle:
             for raw in handle:
                 text = raw.strip()
                 collected.append(detokenizer(text, **detok_kwargs) if detokenizer else text)
         return collected

     source_texts = load("source")
     target_texts = load("target", is_target=True)
     assert len(source_texts) == len(target_texts)
     self.examples, self.example_list = {}, []
     for idx, pair in enumerate(zip(source_texts, target_texts)):
         source_text, target_text = pair
         if (idx + 1) % 20000 == 0:
             print_rank_0(f"Complete {idx + 1} examples")
         guid = "%s-%s" % (split, idx)
         ref = tokenizer.DecodeIds(tokenizer.EncodeAsIds(target_text).tokenization)
         example = InputExample(guid=guid, text_a=source_text, text_b=target_text, meta={"ref": ref})
         if idx < 10:
             print_rank_0((source_text.encode('utf-8'), target_text.encode('utf-8'), ref.encode('utf-8')))
         self.examples[guid] = example
         self.example_list.append(example)
     print_rank_0(f"Return {len(self.examples)} {split} examples")
コード例 #6
0
 def create_examples(self, split):
     """Build summarisation examples from ``<stem>.source`` / ``<stem>.target``.

     Every stripped line is passed through ``punctuation_standardization`` and,
     for the ``gigaword`` and ``cnn_dm`` tasks, the matching detokenizer
     (targets with ``is_target=True``). Each example's ``meta["ref"]`` is the
     target after a tokenizer encode/decode round trip.

     :param split: ``"train"``, ``"dev"`` (file stem ``val``) or ``"test"``
     :return: list of ``InputExample`` in file order
     :raises NotImplementedError: for an unknown split name
     """
     stems = {"train": "train", "dev": "val", "test": "test"}
     if split not in stems:
         raise NotImplementedError(split)
     filename = stems[split]
     print_rank_0(
         f"Creating {self.task}-{split} dataset from {self.data_dir}")
     if self.task == "gigaword":
         detokenizer = gigaword_detokenize
     elif self.task == "cnn_dm":
         detokenizer = cnndm_detokenize
     else:
         detokenizer = None

     def load(suffix, **detok_kwargs):
         # Read one side of the corpus: strip, standardise punctuation, detokenize.
         collected = []
         with open(os.path.join(self.data_dir, f"{filename}.{suffix}"), encoding='utf-8') as handle:
             for raw in handle:
                 text = punctuation_standardization(raw.strip())
                 collected.append(detokenizer(text, **detok_kwargs) if detokenizer else text)
         return collected

     source_texts = load("source")
     target_texts = load("target", is_target=True)
     assert len(source_texts) == len(target_texts)
     result = []
     for idx, pair in enumerate(zip(source_texts, target_texts)):
         source_text, target_text = pair
         if (idx + 1) % 20000 == 0:
             print_rank_0(f"Complete {idx + 1} examples")
         ref = self.tokenizer.DecodeIds(self.tokenizer.EncodeAsIds(target_text).tokenization)
         example = InputExample(guid="%s-%s" % (split, idx),
                                text_a=source_text,
                                text_b=target_text,
                                meta={"ref": ref})
         if idx < 10:
             print_rank_0((source_text.encode('utf-8'), target_text.encode('utf-8'), ref.encode('utf-8')))
         result.append(example)
     return result
コード例 #7
0
    def _create_examples(path: str, set_type: str) -> List[InputExample]:
        """Turn each row of a TSV with a ``sentence`` column into an example.

        ``label`` is read from the ``label`` column when present, else ``None``.

        :param path: path to the TSV file (read with ``read_tsv``)
        :param set_type: split name, used as the guid prefix
        :return: one ``InputExample`` per row, in row order
        """
        frame = read_tsv(path)
        return [
            InputExample(guid=f"{set_type}-{row_idx}",
                         text_a=punctuation_standardization(row['sentence']),
                         label=row.get('label', None))
            for row_idx, row in frame.iterrows()
        ]
コード例 #8
0
 def __init__(self, args, split, tokenizer):
     """Load a blank-infilling language-modelling split.

     ``<stem>.txt`` provides the reference (target) lines. For ``train``/``dev``
     the source is the target itself. For ``test`` the source lines come from
     ``blank/test.maskratio<r>.blank`` (``r`` = ``args.blank_maskratio``), which
     must have the same number of lines. Every line is stripped and passed
     through ``blanklm_detokenize``.

     Populates ``self.examples`` (guid -> example), ``self.example_list`` and a
     ``random.Random(args.seed)`` in ``self.random``.

     :raises NotImplementedError: for an unknown split name
     :raises ValueError: if the test blank file and the reference file differ in length
     """
     self.args = args
     task, data_dir = args.task.lower(), args.data_dir
     self.max_src_length, self.max_tgt_length = args.src_seq_length, args.tgt_seq_length
     self.split = split
     assert args.tokenizer_type == "BertWordPieceTokenizer"
     self.tokenizer = tokenizer
     if split == "train":
         filename = "train"
     elif split == "dev":
         filename = "valid"
     elif split == "test":
         filename = "test"
     else:
         raise NotImplementedError(split)
     print_rank_0(f"Creating {task}-{split} dataset from {data_dir}")
     self.dataset_name = split
     detokenizer = blanklm_detokenize
     source_texts, target_texts = [], []
     with open(os.path.join(data_dir, f"{filename}.txt"),
               encoding='utf-8') as file:
         for line in file:
             target_texts.append(detokenizer(line.strip()))
     if split == 'test':
         blank_path = os.path.join(
             data_dir, f"blank/test.maskratio{args.blank_maskratio:.1f}.blank")
         with open(blank_path, encoding='utf-8') as file:
             for line in file:
                 source_texts.append(detokenizer(line.strip()))
         # zip() below would silently drop unmatched trailing lines.
         if len(source_texts) != len(target_texts):
             raise ValueError(
                 f"{blank_path} has {len(source_texts)} lines but the reference "
                 f"file has {len(target_texts)}")
     else:
         source_texts = target_texts
     self.examples, self.example_list = {}, []
     for idx, (source_text,
               target_text) in enumerate(zip(source_texts, target_texts)):
         if (idx + 1) % 20000 == 0:
             print_rank_0(f"Complete {idx + 1} examples")
         guid = "%s-%s" % (split, idx)
         meta = {"ref": target_text}
         example = InputExample(guid=guid,
                                text_a=source_text,
                                text_b=target_text,
                                meta=meta)
         self.examples[guid] = example
         self.example_list.append(example)
     print_rank_0(f"Return {len(self.examples)} {split} examples")
     self.random = random.Random(args.seed)
コード例 #9
0
 def create_examples(self, split):
     """Create XSum summarisation examples for one split.

     Document ids for the split are read from
     ``XSum-TRAINING-DEV-TEST-SPLIT-90-5-5.json`` (``dev`` maps to the
     ``validation`` key). Each ``<id>.summary`` file is split into sections on
     lines starting with ``[SN]``. The ``RESTBODY`` section becomes the source
     and ``FIRST-SENTENCE`` the target. ``meta["ref"]`` is the target after a
     tokenizer encode/decode round trip.

     :param split: ``"train"``, ``"dev"`` or ``"test"``
     :return: list of ``InputExample`` in id-list order
     :raises NotImplementedError: for an unknown split name
     """
     if split == "train":
         split_key = "train"
     elif split == "dev":
         split_key = "validation"
     elif split == "test":
         split_key = "test"
     else:
         raise NotImplementedError(split)
     print_rank_0(f"Creating XSUM-{split} dataset from {self.data_dir}")

     def parse_summary(lines):
         # Return (RESTBODY, FIRST-SENTENCE) text; None for a missing section.
         # The section name is taken as line[4:-4] (presumably "[SN]NAME[SN]").
         sections = {}
         section, sentences = None, []
         for line in lines:
             line = line.strip()
             if line.startswith("[SN]"):
                 if section is not None:
                     sections[section] = " ".join(sentences)
                 section = line[4:-4]
                 sentences = []
             elif line:
                 sentences.append(line)
         if section is not None:
             sections[section] = " ".join(sentences)
         return sections.get("RESTBODY"), sections.get("FIRST-SENTENCE")

     # Explicit UTF-8: the default encoding is locale-dependent.
     with open(os.path.join(self.data_dir, "XSum-TRAINING-DEV-TEST-SPLIT-90-5-5.json"),
               encoding='utf-8') as file:
         id_list = json.load(file)[split_key]
     source_texts, target_texts = [], []
     for i, doc_id in enumerate(id_list):
         with open(os.path.join(self.data_dir, f"{doc_id}.summary"), encoding='utf-8') as file:
             source_text, target_text = parse_summary(file)
         source_texts.append(source_text)
         target_texts.append(target_text)
         if (i + 1) % 1000 == 0:
             print_rank_0(f"Complete {i + 1} examples")
     assert len(source_texts) == len(target_texts)
     example_list = []
     for idx, (source_text, target_text) in enumerate(zip(source_texts, target_texts)):
         if (idx + 1) % 20000 == 0:
             print_rank_0(f"Complete {idx + 1} examples")
         guid = "%s-%s" % (split, idx)
         meta = {"ref": self.tokenizer.DecodeIds(self.tokenizer.EncodeAsIds(target_text).tokenization)}
         example = InputExample(guid=guid, text_a=source_text, text_b=target_text, meta=meta)
         if idx < 10:
             print_rank_0((source_text.encode('utf-8'), target_text.encode('utf-8'), meta["ref"].encode('utf-8')))
         example_list.append(example)
     return example_list
コード例 #10
0
    def _create_examples(path: str, set_type: str) -> List[InputExample]:
        """Load ``(label, body)`` rows from a headerless CSV file.

        Literal two-character ``\\n`` sequences and any remaining backslashes in
        the body are replaced by spaces, then punctuation is standardised.

        :param path: path to the CSV file
        :param set_type: split name, used as the guid prefix
        :return: one ``InputExample`` per row; the label is kept as the raw string
        """
        examples = []

        # newline='' is required by the csv module so quoted fields containing
        # line breaks parse correctly. UTF-8 avoids locale-dependent decoding.
        with open(path, encoding='utf8', newline='') as f:
            reader = csv.reader(f, delimiter=',')
            for idx, row in enumerate(reader):
                label, body = row
                guid = "%s-%s" % (set_type, idx)
                text_a = body.replace('\\n', ' ').replace('\\', ' ')
                text_a = punctuation_standardization(text_a)

                example = InputExample(guid=guid, text_a=text_a, label=label)
                examples.append(example)

        return examples
コード例 #11
0
    def _create_examples(path: str, set_type: str) -> List[InputExample]:
        """Load passage/question yes-no examples from a JSON-lines file.

        The label, when the key is present, is stringified and lower-cased
        (e.g. ``True`` -> ``"true"``). Otherwise it is ``None``.

        :param path: path to the JSON-lines file
        :param set_type: split name, used as the guid prefix
        :return: one ``InputExample`` per line
        """
        def to_example(record):
            example_idx = record['idx']
            label = str(record['label']).lower() if 'label' in record else None
            return InputExample(guid="%s-%s" % (set_type, example_idx),
                                text_a=punctuation_standardization(record['passage']),
                                text_b=punctuation_standardization(record['question']),
                                label=label,
                                idx=example_idx)

        with open(path, encoding='utf8') as f:
            return [to_example(json.loads(line)) for line in f]
コード例 #12
0
 def _create_examples(path: str, set_type: str) -> List[InputExample]:
     """Load word-in-context sentence-pair examples from a JSON-lines file.

     String ``idx`` values are converted with ``int``. The label is ``"true"``
     when the record's ``label`` is truthy and ``"false"`` otherwise, including
     when it is absent. NOTE(review): unlabeled splits therefore get ``"false"``;
     confirm downstream code expects that. The target word goes into ``meta``.

     :param path: path to the JSON-lines file
     :param set_type: split name, used as the guid prefix
     :return: one ``InputExample`` per line
     """
     collected = []
     with open(path, encoding='utf8') as f:
         for raw in f:
             record = json.loads(raw)
             example_idx = record['idx']
             if isinstance(example_idx, str):
                 example_idx = int(example_idx)
             is_positive = bool(record.get('label'))
             collected.append(InputExample(guid="%s-%s" % (set_type, example_idx),
                                           text_a=punctuation_standardization(record['sentence1']),
                                           text_b=punctuation_standardization(record['sentence2']),
                                           label="true" if is_positive else "false",
                                           idx=example_idx,
                                           meta={'word': record['word']}))
     return collected
コード例 #13
0
    def _create_examples(path: str, set_type: str) -> List[InputExample]:
        """Load extractive-QA examples from a JSON file with a top-level ``data`` list.

        Questions without any answer are skipped. Only the first answer is kept
        (in ``meta['answer']``), and every example gets the label ``'0'``.

        :param path: path to the JSON file
        :param set_type: split name, used as the guid prefix
        :return: one ``InputExample`` per answered question
        """
        examples = []
        # Explicit UTF-8: the default encoding is locale-dependent and may fail
        # on non-ASCII contexts.
        with open(path, encoding='utf8') as f:
            data = json.load(f)['data']

        for idx, passage in enumerate(data):
            for pid, paragraph in enumerate(passage['paragraphs']):
                context = paragraph['context']
                for qid, qas in enumerate(paragraph['qas']):
                    if len(qas['answers']) == 0:
                        continue
                    guid = f"{set_type}-{idx}-{pid}-{qid}"
                    example = InputExample(guid=guid, text_a=context, text_b=qas['question'], label='0',
                                           meta={'answer': qas['answers'][0]})
                    examples.append(example)

        return examples
コード例 #14
0
    def _create_examples(path: str, set_type: str) -> List[InputExample]:
        """Load prefixed text-to-text rows from a tab-separated table.

        Columns ``prefix``, ``input_text`` and ``target_text`` are stringified
        into ``label``, ``text_a`` and ``text_b``. The DataFrame index becomes
        ``idx``.

        :param path: path read with ``pandas.read_table``
        :param set_type: split name, used as the guid prefix
        :return: one ``InputExample`` per row
        """
        table = pd.read_table(path)
        return [
            InputExample(guid="%s-%s" % (set_type, row_idx),
                         text_a=str(row['input_text']),
                         text_b=str(row['target_text']),
                         label=str(row['prefix']),
                         idx=row_idx)
            for row_idx, row in table.iterrows()
        ]
コード例 #15
0
    def _create_examples(path: str, set_type: str) -> List[InputExample]:
        """Load ``(label, question_title, question_body, answer)`` CSV rows.

        The question title and body are joined with a space into ``text_a``; the
        answer becomes ``text_b``. Both are cleaned and punctuation-standardised.

        :param path: path to the headerless CSV file
        :param set_type: split name, used as the guid prefix
        :return: one ``InputExample`` per row; the label is the raw string
        """
        def unescape(text):
            # Replace literal two-character "\n" sequences (presumably escaped
            # line breaks) and any remaining backslashes with spaces.
            return text.replace('\\n', ' ').replace('\\', ' ')

        examples = []

        # newline='' is required by the csv module so quoted multi-line fields
        # are parsed correctly.
        with open(path, encoding='utf8', newline='') as f:
            reader = csv.reader(f, delimiter=',')
            for idx, row in enumerate(reader):
                label, question_title, question_body, answer = row
                guid = "%s-%s" % (set_type, idx)
                text_a = punctuation_standardization(
                    ' '.join([unescape(question_title), unescape(question_body)]))
                text_b = punctuation_standardization(unescape(answer))

                example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
                examples.append(example)

        return examples
コード例 #16
0
    def _create_examples(path: str, set_type: str) -> List[InputExample]:
        """Load single-sentence classification rows from a TSV file.

        Non-test splits are read with ``header=None``: the label is column 1 and
        the sentence column 3. The test split is read with ``read_tsv``'s default
        header handling; the sentence comes from the ``sentence`` column and the
        label is ``None``.

        :param path: path to the TSV file
        :param set_type: split name, used as the guid prefix
        :return: one ``InputExample`` per row
        """
        is_test = set_type == 'test'
        frame = read_tsv(path) if is_test else read_tsv(path, header=None)

        collected = []
        for row_idx, row in frame.iterrows():
            if is_test:
                sentence, label = row['sentence'], None
            else:
                sentence, label = row[3], row[1]
            collected.append(InputExample(guid=f"{set_type}-{row_idx}",
                                          text_a=punctuation_standardization(sentence),
                                          label=label))
        return collected
コード例 #17
0
    def _create_examples(self, path: str) -> List[InputExample]:
        """Load question/comment pair examples from a JSON-lines file.

        When ``self.language`` is set, lines whose ``language`` differs are
        skipped. The record ``id`` is used directly as the guid.

        :param path: path to the JSON-lines file
        :return: one ``InputExample`` per kept line
        """
        examples = []

        with open(path, encoding='utf8') as f:
            for line in f:
                example_json = json.loads(line)

                # Filter first so discarded lines are not normalised needlessly.
                if self.language is not None and example_json['language'] != self.language:
                    continue

                example = InputExample(guid=example_json['id'],
                                       text_a=punctuation_standardization(example_json['question']),
                                       text_b=punctuation_standardization(example_json['comment']),
                                       label=example_json['label'])
                examples.append(example)

        return examples
コード例 #18
0
 def __init__(self, args, split, tokenizer):
     """Load a line-aligned ``<stem>.source`` / ``<stem>.target`` split.

     The split selects the file stem (``train``, ``valid`` for ``dev``,
     ``test``). Lines are stripped; ``meta["ref"]`` is the raw target text.
     Populates ``self.examples`` (guid -> example) and ``self.example_list``.

     :raises NotImplementedError: for an unknown split name
     :raises ValueError: if the source and target files differ in line count
     """
     self.args = args
     task, data_dir = args.task.lower(), args.data_dir
     self.max_src_length, self.max_tgt_length = args.src_seq_length, args.tgt_seq_length
     self.split = split
     self.tokenizer = tokenizer
     if split == "train":
         filename = "train"
     elif split == "dev":
         filename = "valid"
     elif split == "test":
         filename = "test"
     else:
         raise NotImplementedError(split)
     print_rank_0(f"Creating {task}-{split} dataset from {data_dir}")
     self.dataset_name = split
     with open(os.path.join(data_dir, f"{filename}.source"),
               encoding='utf-8') as file:
         source_texts = [line.strip() for line in file]
     with open(os.path.join(data_dir, f"{filename}.target"),
               encoding='utf-8') as file:
         target_texts = [line.strip() for line in file]
     # zip() below would silently drop unmatched trailing lines.
     if len(source_texts) != len(target_texts):
         raise ValueError(
             f"{filename}.source has {len(source_texts)} lines but "
             f"{filename}.target has {len(target_texts)}")
     self.examples, self.example_list = {}, []
     for idx, (source_text,
               target_text) in enumerate(zip(source_texts, target_texts)):
         if (idx + 1) % 20000 == 0:
             print_rank_0(f"Complete {idx + 1} examples")
         guid = "%s-%s" % (split, idx)
         meta = {"ref": target_text}
         example = InputExample(guid=guid,
                                text_a=source_text,
                                text_b=target_text,
                                meta=meta)
         self.examples[guid] = example
         self.example_list.append(example)
     print_rank_0(f"Return {len(self.examples)} {split} examples")
コード例 #19
0
    def _create_examples(path, set_type, for_train=False) -> List[InputExample]:
        """Load multiple-choice reading examples from ``middle/*.txt`` and ``high/*.txt``.

        Each file line is a JSON object with ``id``, ``article``, ``questions``,
        ``options`` and ``answers``. One example is produced per question, with
        the cleaned article in ``text_a``, the question in ``text_b``, the answer
        as the label and the options in ``meta["choices"]``.

        :param path: directory containing the ``middle`` and ``high`` sub-directories
        :param set_type: split name used in guids and ``idx``
        :param for_train: unused; kept for interface compatibility
        :return: list of examples, middle files first, each group in sorted filename order
        """
        examples = []

        def clean_text(text):
            """Remove new lines and multiple spaces and adjust end of sentence dot."""

            text = text.replace("\n", " ")
            text = re.sub(r'\s+', ' ', text)
            for _ in range(3):
                text = text.replace(' . ', '. ')

            return text

        # glob() returns files in arbitrary, filesystem-dependent order; sort so
        # the example order is reproducible across machines.
        filenames = (sorted(glob.glob(os.path.join(path, "middle", '*.txt')))
                     + sorted(glob.glob(os.path.join(path, "high", "*.txt"))))
        for filename in filenames:
            with open(filename, 'r', encoding='utf-8') as f:
                for line in f:
                    data = json.loads(line)
                    idx = data["id"]
                    context = data["article"]
                    questions = data["questions"]
                    choices = data["options"]
                    answers = data["answers"]
                    # Check the length.
                    assert len(questions) == len(answers)
                    assert len(questions) == len(choices)

                    context = clean_text(context)
                    for question_idx, question in enumerate(questions):
                        answer = answers[question_idx]
                        choice = choices[question_idx]
                        guid = f'{set_type}-p{idx}-q{question_idx}'
                        ex_idx = [set_type, idx, question_idx]
                        meta = {
                            "choices": choice
                        }
                        example = InputExample(guid=guid, text_a=context, text_b=question, label=answer, meta=meta,
                                               idx=ex_idx)
                        examples.append(example)
        return examples
コード例 #20
0
    def _create_examples(path, set_type, seed=42, max_train_candidates_per_question: int = 10, for_train=False) -> List[
        InputExample]:
        """Load entity-cloze examples from a JSON-lines file.

        Entities are the passage spans ``text[start:end + 1]``, deduplicated and
        sorted. ``@highlight`` markers in the passage are rewritten to ``"- "``.

        * ``set_type == 'train'`` or ``for_train``: one example per distinct
          correct answer. Its candidates are the answer followed by at most
          ``max_train_candidates_per_question - 1`` non-answer entities, randomly
          subsampled with ``random.Random(seed)`` when there are more. Label ``"0"``.
        * otherwise: one example per question with all entities as candidates
          and all correct answers in ``meta['answers']``. Label ``"1"``.

        :param path: path to the JSON-lines file
        :param set_type: split name used in guids
        :param seed: seed for the candidate-subsampling shuffler
        :param max_train_candidates_per_question: cap on training candidates, answer included
        :param for_train: force training-style expansion for any split
        :return: list of examples (a summary is printed via ``print_rank_0``)
        """
        examples = []

        entity_shuffler = random.Random(seed)

        with open(path, encoding='utf8') as f:
            for line in f:
                example_json = json.loads(line)

                idx = example_json['idx']
                text = punctuation_standardization(example_json['passage']['text'])
                entities = set()

                for entity_json in example_json['passage']['entities']:
                    start = entity_json['start']
                    end = entity_json['end']
                    entity = punctuation_standardization(text[start:end + 1])
                    entities.add(entity)

                entities = sorted(entities)

                text = text.replace("@highlight\n", "- ")  # we follow the GPT-3 paper wrt @highlight annotations
                questions = example_json['qas']

                for question_json in questions:
                    question = punctuation_standardization(question_json['query'])
                    question_idx = question_json['idx']
                    # Deduplicate answers in first-seen order. Iterating a set of
                    # strings depends on hash randomisation, which made answer_idx
                    # and the seeded shuffler's output vary between processes.
                    answers = list(dict.fromkeys(
                        punctuation_standardization(answer_json['text'])
                        for answer_json in question_json.get('answers', [])))

                    if set_type == 'train' or for_train:
                        # create a single example per *correct* answer
                        answer_set = set(answers)
                        negatives = [ent for ent in entities if ent not in answer_set]
                        for answer_idx, answer in enumerate(answers):
                            # Copy so every answer shuffles the same sorted negatives.
                            candidates = list(negatives)
                            if len(candidates) > max_train_candidates_per_question - 1:
                                entity_shuffler.shuffle(candidates)
                                candidates = candidates[:max_train_candidates_per_question - 1]

                            guid = f'{set_type}-p{idx}-q{question_idx}-a{answer_idx}'
                            meta = {
                                'passage_idx': idx,
                                'question_idx': question_idx,
                                'candidates': [answer] + candidates,
                                'answers': [answer]
                            }
                            ex_idx = [idx, question_idx, answer_idx]
                            example = InputExample(guid=guid, text_a=text, text_b=question, label="0", meta=meta,
                                                   idx=ex_idx, num_choices=len(candidates) + 1)
                            examples.append(example)

                    else:
                        # create just one example with *all* correct answers and *all* answer candidates
                        guid = f'{set_type}-p{idx}-q{question_idx}'
                        meta = {
                            'passage_idx': idx,
                            'question_idx': question_idx,
                            'candidates': entities,
                            'answers': answers
                        }
                        example = InputExample(guid=guid, text_a=text, text_b=question, label="1", meta=meta,
                                               idx=question_idx, num_choices=len(entities))
                        examples.append(example)

        question_indices = list(set(example.meta['question_idx'] for example in examples))
        label_distribution = Counter(example.label for example in examples)
        print_rank_0(
            f"Returning {len(examples)} examples corresponding to {len(question_indices)} questions with label "
            f"distribution {list(label_distribution.items())}")
        return examples
コード例 #21
0
    def create_examples(self, split):
        """Create sliding-window extractive-QA examples for one split.

        ``squad_v1`` uses the v1.1 files, any other task the v2.0 files. The
        ``test`` split reads the dev file. Each context is tokenized and cut into
        windows starting every ``self.max_src_length // 2`` tokens, each at most
        ``self.max_src_length - 3 - len(question_tokens)`` tokens long. A window
        whose decoded text contains the (round-tripped) first answer keeps that
        answer; otherwise its answer is ``'N/A'``. For ``squad_v1``, ``'N/A'``
        windows are dropped.

        :param split: ``"train"``, ``"dev"`` or ``"test"``
        :return: list of ``InputExample`` with the window text in ``text_a``
        :raises NotImplementedError: for an unknown split name
        """
        if split == "train":
            filename = "train-v1.1.json" if self.task == "squad_v1" else "train-v2.0.json"
        elif split == "dev":
            filename = "dev-v1.1.json" if self.task == "squad_v1" else "dev-v2.0.json"
        elif split == "test":
            filename = "dev-v1.1.json" if self.task == "squad_v1" else "dev-v2.0.json"
        else:
            raise NotImplementedError(split)
        print_rank_0(f"Creating SQuAD-{split} dataset from {self.data_dir}")
        example_list = []
        idx = 0
        total_qas = 0
        total_na = 0
        with open(os.path.join(self.data_dir, filename),
                  encoding='utf-8') as file:
            dataset = json.load(file)['data']
        for paragraphs in dataset:
            for paragraph in paragraphs['paragraphs']:
                context = paragraph['context']
                context_tokens = self.tokenizer.EncodeAsIds(
                    context).tokenization
                transformer_encode = self.transformer_tokenizer(
                    context,
                    return_offsets_mapping=True,
                    add_special_tokens=False,
                    verbose=False)
                assert transformer_encode['input_ids'] == context_tokens
                token_to_char = transformer_encode['offset_mapping']
                for qa in paragraph['qas']:
                    total_qas += 1
                    question = qa["question"]
                    question_tokens = self.tokenizer.EncodeAsIds(
                        " " + question).tokenization
                    answers = [answer["text"] for answer in qa["answers"]]
                    if len(qa['answers']) == 0:
                        answers = ['N/A']
                    # Window length and the round-tripped first answer do not
                    # depend on the window; compute them once per question.
                    length = self.max_src_length - 3 - len(question_tokens)
                    gold_answer = answers[0]
                    answer_tokens_text = self.tokenizer.DecodeIds(
                        self.tokenizer.EncodeAsIds(gold_answer).tokenization)
                    for start in range(0, len(context_tokens),
                                       self.max_src_length // 2):
                        tokens = context_tokens[start:start + length]
                        new_context = self.tokenizer.DecodeIds(tokens)
                        if answer_tokens_text and answer_tokens_text in new_context:
                            answer = gold_answer
                        else:
                            answer = 'N/A'
                        if self.task == 'squad_v1' and answer == 'N/A':
                            continue
                        guid = "%s-%s" % (split, idx)
                        meta = {
                            "context": context,
                            "context_tokens": context_tokens,
                            "token_to_char": token_to_char,
                            "answer": answer,
                            "answers": answers,
                            "question": question,
                            "ref": answer
                        }
                        example = InputExample(guid=guid,
                                               text_a=new_context,
                                               meta=meta,
                                               idx=qa['id'])
                        example_list.append(example)
                        idx += 1
                        total_na += (answer == 'N/A')
                        # A short window means the end of the context was reached.
                        if len(tokens) < length:
                            break
        print_rank_0(
            f"Creating {len(example_list)} / {total_qas} examples for {split}. {total_na} N/A"
        )
        return example_list
コード例 #22
0
    def _create_examples(self, path: str, set_type: str, cloze_eval=True) -> List[InputExample]:
        """Load WSC-style examples from a JSON-lines file.

        Each line is a JSON object with ``idx``, ``text``, a ``target`` dict
        (``span1_text``, ``span2_text``, ``span1_index``, ``span2_index``), an
        optional ``label`` and an optional ``candidates`` list of
        ``{"text": ...}`` dicts. Word indices that are off by one relative to
        the whitespace-split text are realigned before examples are built.

        Args:
            path: Path of the JSON-lines file (read as UTF-8).
            set_type: Split name; used in the guid, and ``'train'`` enables the
                train-only filtering / augmentation below.
            cloze_eval: When true (and the task is not ``wsc1``), training
                examples whose label is not ``'True'`` are dropped.

        Returns:
            The list of constructed ``InputExample`` objects.

        Raises:
            AssertionError: If ``span2_text`` cannot be found at, or one word
                around, ``span2_index``.
        """
        examples = []
        # Training-time candidate lists longer than this are split into chunks
        # of this size; the last chunk is padded by wrapping to the list start.
        chunk_size = 9

        with open(path, encoding='utf8') as f:
            for line in f:
                example_json = json.loads(line)
                idx = example_json['idx']
                label = str(example_json['label']) if 'label' in example_json else None
                guid = "%s-%s" % (set_type, idx)
                text_a = punctuation_standardization(example_json['text'])
                target = example_json['target']
                meta = {
                    'span1_text': target['span1_text'],
                    'span2_text': target['span2_text'],
                    'span1_index': target['span1_index'],
                    'span2_index': target['span2_index']
                }

                # Re-bound on every line so an example without 'candidates' can
                # never reuse the previous example's list (or hit a NameError).
                has_candidates = 'candidates' in example_json
                candidates = []
                if has_candidates:
                    # Order-preserving de-duplication: keeps the first occurrence.
                    candidates = list(dict.fromkeys(cand['text'] for cand in example_json['candidates']))

                # The indices in the dataset are wrong for some examples, so we
                # try to realign each span by shifting it one word left/right.
                span1_index, span1_text = meta['span1_index'], meta['span1_text']
                span2_index, span2_text = meta['span2_index'], meta['span2_text']
                words_a = text_a.split()
                words_a_lower = text_a.lower().split()
                words_span1_text = span1_text.lower().split()
                span1_len = len(words_span1_text)

                if words_a_lower[span1_index:span1_index + span1_len] != words_span1_text:
                    for offset in (-1, 1):
                        shifted = span1_index + offset
                        # A negative start would make the slice wrap to the end.
                        if shifted >= 0 and words_a_lower[shifted:shifted + span1_len] == words_span1_text:
                            span1_index = shifted
                            break

                if words_a[span2_index] != span2_text:
                    for offset in (-1, 1):
                        shifted = span2_index + offset
                        # Bounds check: index -1 would silently read the last
                        # word, and len(words_a) would raise IndexError.
                        if 0 <= shifted < len(words_a) and words_a[shifted] == span2_text:
                            span2_index = shifted
                            break

                    # The span may be glued to trailing characters (e.g. "it."):
                    # split that word into the span text and its remainder.
                    if words_a[span2_index] != span2_text and words_a[span2_index].startswith(span2_text):
                        words_a = words_a[:span2_index] \
                                  + [words_a[span2_index][:len(span2_text)], words_a[span2_index][len(span2_text):]] \
                                  + words_a[span2_index + 1:]

                assert words_a[span2_index] == span2_text, \
                    f"Got '{words_a[span2_index]}' but expected '{span2_text}' at index {span2_index} for '{words_a}'"

                text_a = ' '.join(words_a)
                meta['span1_index'], meta['span2_index'] = span1_index, span2_index

                if self.args.task == 'wsc1':
                    examples.append(InputExample(guid=guid, text_a=text_a, text_b=span1_text,
                                                 label=label, meta=meta, idx=idx))
                    if set_type == 'train' and label == 'True':
                        # Every candidate becomes an extra example labelled 'False'
                        # (all of them share the same meta dict).
                        for cand in candidates:
                            examples.append(InputExample(guid=guid, text_a=text_a, text_b=cand,
                                                         label='False', meta=meta, idx=idx))
                    continue

                if cloze_eval and set_type == 'train' and label != 'True':
                    continue
                if set_type == 'train' and len(candidates) > chunk_size:
                    for start in range(0, len(candidates), chunk_size):
                        _meta = copy.deepcopy(meta)
                        chunk = candidates[start:start + chunk_size]
                        if len(chunk) < chunk_size:
                            chunk += candidates[:chunk_size - len(chunk)]
                        _meta['candidates'] = chunk
                        examples.append(InputExample(guid=guid, text_a=text_a, label=label, meta=_meta, idx=idx))
                else:
                    if has_candidates:
                        meta['candidates'] = candidates
                    examples.append(InputExample(guid=guid, text_a=text_a, label=label, meta=meta, idx=idx))

        return examples