Example #1
    def convert_examples_to_features(examples: List[MultiChoiceFullExample], tokenizer, max_seq_length: int = 512):
        unique_id = 1000000000
        features = []
        for (example_index, example) in tqdm(enumerate(examples), desc='Convert examples to features', total=len(examples)):
            query_tokens = tokenizer.tokenize(example.question_text)

            # word piece index -> token index
            tok_to_orig_index = []
            # token index -> word pieces group start index
            orig_to_tok_index = []
            # word pieces for all doc tokens
            all_doc_tokens = []
            for (i, token) in enumerate(example.doc_tokens):
                orig_to_tok_index.append(len(all_doc_tokens))
                sub_tokens = tokenizer.tokenize(token)
                for sub_token in sub_tokens:
                    tok_to_orig_index.append(i)
                    all_doc_tokens.append(sub_token)

            # Map sentence spans from original-token indices to word-piece indices (inclusive)
            sentence_spans = []
            for (start, end) in example.sentence_span_list:
                piece_start = orig_to_tok_index[start]
                if end < len(example.doc_tokens) - 1:
                    piece_end = orig_to_tok_index[end + 1] - 1
                else:
                    piece_end = len(all_doc_tokens) - 1
                sentence_spans.append((piece_start, piece_end))

            options = example.options
            choice_features = []
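            # Build one [CLS] question + option [SEP] document [SEP] sequence per answer option.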
            for option_id, option in enumerate(options):
                q_op_tokens = query_tokens + tokenizer.tokenize(option)
                doc_tokens = all_doc_tokens[:]
                utils.truncate_seq_pair(q_op_tokens, doc_tokens, max_seq_length - 3)

                tokens = ["[CLS]"] + q_op_tokens + ["[SEP]"]
                segment_ids = [0] * len(tokens)
                tokens = tokens + doc_tokens + ["[SEP]"]
                segment_ids += [1] * (len(doc_tokens) + 1)

                sentence_list = []
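                # Document tokens start right after [CLS] + q_op_tokens + [SEP], i.e. at index len(q_op_tokens) + 2.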
                doc_offset = len(q_op_tokens) + 2
                for (start, end) in sentence_spans:
                    assert start <= end
                    if start >= len(doc_tokens):
                        break
                    if end >= len(doc_tokens):
                        end = len(doc_tokens) - 1
                    start = doc_offset + start
                    end = doc_offset + end
                    sentence_list.append((start, end))
                    assert start < max_seq_length and end < max_seq_length

                input_ids = tokenizer.convert_tokens_to_ids(tokens)
                input_mask = [1] * len(input_ids)

                # Zero-pad up to the sequence length.
                while len(input_ids) < max_seq_length:
                    input_ids.append(0)
                    input_mask.append(0)
                    segment_ids.append(0)

                assert len(input_ids) == max_seq_length
                assert len(input_mask) == max_seq_length
                assert len(segment_ids) == max_seq_length

                choice_features.append({
                    "input_ids": input_ids,
                    "input_mask": input_mask,
                    "segment_ids": segment_ids,
                    "sentence_span_list": sentence_list
                })
            features.append(MultiChoiceFullFeature(
                example_index=example_index,
                qas_id=example.qas_id,
                unique_id=unique_id,
                choice_features=choice_features,
                answer=example.answer
            ))
            unique_id += 1

        logger.info(f'Generated {len(features)} features.')

        return features
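Both examples call utils.truncate_seq_pair to fit the question/option tokens and the document tokens into max_seq_length - 3 (the three reserved slots are [CLS] and the two [SEP] tokens). That helper is not shown here; a minimal sketch of the standard BERT-style pair truncation it presumably performs, popping tokens from the longer of the two sequences until the pair fits, could look like this (the repository's actual implementation may differ):

    def truncate_seq_pair(tokens_a, tokens_b, max_length):
        """Truncate a token pair in place until len(tokens_a) + len(tokens_b) <= max_length."""
        while len(tokens_a) + len(tokens_b) > max_length:
            # Drop from the longer side so both sequences keep roughly equal context.
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()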
Example #2
    def convert_examples_to_features(examples: List[MultiRCExample],
                                     tokenizer,
                                     max_seq_length: int = 512):
        unique_id = 1000000000
        features = []
        for (example_index,
             example) in tqdm(enumerate(examples),
                              desc='Convert examples to features',
                              total=len(examples)):
            query_tokens = tokenizer.tokenize(example.question_text)
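            # Ensure the question ends with a question mark.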
            if query_tokens[-1] != '?':
                query_tokens.append('?')

            # word piece index -> token index
            tok_to_orig_index = []
            # token index -> word pieces group start index
            orig_to_tok_index = []
            # word pieces for all doc tokens
            all_doc_tokens = []
            for (i, token) in enumerate(example.doc_tokens):
                orig_to_tok_index.append(len(all_doc_tokens))
                sub_tokens = tokenizer.tokenize(token)
                for sub_token in sub_tokens:
                    tok_to_orig_index.append(i)
                    all_doc_tokens.append(sub_token)

            # Map sentence spans from original-token indices to word-piece indices (inclusive)
            sentence_spans = []
            for (start, end) in example.sentence_span_list:
                piece_start = orig_to_tok_index[start]
                if end < len(example.doc_tokens) - 1:
                    piece_end = orig_to_tok_index[end + 1] - 1
                else:
                    piece_end = len(all_doc_tokens) - 1
                sentence_spans.append((piece_start, piece_end))

            # Process all tokens
            q_op_tokens = query_tokens + tokenizer.tokenize(
                example.option_text)
            doc_tokens = all_doc_tokens[:]
            utils.truncate_seq_pair(q_op_tokens, doc_tokens,
                                    max_seq_length - 3)

            tokens = ["[CLS]"] + q_op_tokens + ["[SEP]"]
            segment_ids = [0] * len(tokens)
            tokens = tokens + doc_tokens + ["[SEP]"]
            segment_ids += [1] * (len(doc_tokens) + 1)

            sentence_list = []
            collected_sentence_indices = []
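            # Offset of the first document token: [CLS] + q_op_tokens + [SEP] occupy the first len(q_op_tokens) + 2 positions.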
            doc_offset = len(q_op_tokens) + 2
            for sentence_index, (start, end) in enumerate(sentence_spans):
                assert start <= end, (example_index, sentence_index, start,
                                      end)
                if start >= len(doc_tokens):
                    break
                if end >= len(doc_tokens):
                    end = len(doc_tokens) - 1
                start = doc_offset + start
                end = doc_offset + end
                sentence_list.append((start, end))
                assert start < max_seq_length and end < max_seq_length
                collected_sentence_indices.append(sentence_index)

            sentence_ids = []
            for sentence_id in example.sentence_ids:
                if sentence_id in collected_sentence_indices:
                    sentence_ids.append(sentence_id)

            # For the 'multiple' style, shift each retained sentence id by 1 and append 0 as the final id.
            sentence_ids = [x + 1 for x in sentence_ids]
            sentence_ids.append(0)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            features.append(
                MultiRCFeature(
                    example_index=example_index,
                    qas_id=example.qas_id,
                    unique_id=unique_id,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    sentence_span_list=sentence_list,
                    # In the bert_hierarchical model, the output size is 3.
                    answer=example.answer + 1,
                    sentence_ids=sentence_ids))
            unique_id += 1

        logger.info(f'Generated {len(features)} features.')

        return features
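As a rough usage sketch, assuming a PyTorch training loop, that MultiRCFeature exposes its constructor arguments as attributes, and that features is the list returned by the second function, the fixed-length fields can be packed into a TensorDataset as follows. The variable-length sentence_span_list and sentence_ids fields would need a custom collate_fn and are omitted here; the batch size is arbitrary.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # features: the list of MultiRCFeature objects returned above.
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_answers = torch.tensor([f.answer for f in features], dtype=torch.long)

    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_answers)
    train_loader = DataLoader(dataset, batch_size=8, shuffle=True)

    for input_ids, input_mask, segment_ids, answers in train_loader:
        pass  # feed each batch to the model here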