def test_full_tokenizer(self):
        tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
        text = "lower newer"
        bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
        tokens = tokenizer.tokenize(text, add_prefix_space=True)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens + [tokenizer.unk_token]
        input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
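
# Note (not part of the original test): with the pretrained "roberta-base"
# vocabulary the same call pattern would look roughly like this (the exact
# split depends on the vocabulary, so treat the output as illustrative):
#
#   from transformers import RobertaTokenizer
#   tok = RobertaTokenizer.from_pretrained("roberta-base")
#   tok.tokenize("lower newer", add_prefix_space=True)
#   # -> e.g. ['\u0120lower', '\u0120newer']
#
# The test above uses a tiny fixture vocabulary (self.vocab_file /
# self.merges_file), which is why "newer" is split into single characters.
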
    def convert_examples_to_features(examples: List[MultiRCExample],
                                     tokenizer: RobertaTokenizer,
                                     max_seq_length: int = 512,
                                     **kwargs):
        unique_id = 1000000000
        features = []
        for (example_index,
             example) in tqdm(enumerate(examples),
                              desc='Convert examples to features',
                              total=len(examples)):
            query_tokens = tokenizer.tokenize(example.question_text)
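            # Make sure the tokenized question ends with a question mark
            # ('?' is a single token in the GPT-2/RoBERTa BPE vocabulary).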
            if query_tokens[-1] != '?':
                query_tokens.append('?')

            # word piece index -> token index
            tok_to_orig_index = []
            # token index -> word pieces group start index
            orig_to_tok_index = []
            # word pieces for all doc tokens
            all_doc_tokens = []
            for (i, token) in enumerate(example.doc_tokens):
                orig_to_tok_index.append(len(all_doc_tokens))
                sub_tokens = tokenizer.tokenize(token)
                for sub_token in sub_tokens:
                    tok_to_orig_index.append(i)
                    all_doc_tokens.append(sub_token)

            # Process sentence span list
            sentence_spans = []
            for (start, end) in example.sentence_span_list:
                piece_start = orig_to_tok_index[start]
                if end < len(example.doc_tokens) - 1:
                    piece_end = orig_to_tok_index[end + 1] - 1
                else:
                    piece_end = len(all_doc_tokens) - 1
                sentence_spans.append((piece_start, piece_end))

            # Process all tokens
            q_op_tokens = query_tokens + tokenizer.tokenize(
                example.option_text)
            doc_tokens = all_doc_tokens[:]
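            # Reserve 4 positions for the special tokens added below:
            # [CLS], the two [SEP]s after question + option, and the final [SEP].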
            utils.truncate_seq_pair(q_op_tokens, doc_tokens,
                                    max_seq_length - 4)

            tokens = [tokenizer.cls_token] + q_op_tokens + [
                tokenizer.sep_token, tokenizer.sep_token
            ]
            segment_ids = [0] * len(tokens)
            tokens = tokens + doc_tokens + [tokenizer.sep_token]
            segment_ids += [1] * (len(doc_tokens) + 1)

            sentence_list = []
            collected_sentence_indices = []
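            # Doc tokens start after [CLS] + question/option tokens + two [SEP]s.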
            doc_offset = len(q_op_tokens) + 3
            for sentence_index, (start, end) in enumerate(sentence_spans):
                assert start <= end, (example_index, sentence_index, start,
                                      end)
                if start >= len(doc_tokens):
                    break
                if end >= len(doc_tokens):
                    end = len(doc_tokens) - 1
                start = doc_offset + start
                end = doc_offset + end
                sentence_list.append((start, end))
                assert start < max_seq_length and end < max_seq_length
                collected_sentence_indices.append(sentence_index)

            sentence_ids = []
            for sentence_id in example.sentence_ids:
                if sentence_id in collected_sentence_indices:
                    sentence_ids.append(sentence_id)

            # For the "multiple" style, shift every sentence id by +1 and append a trailing 0:
            # sentence_ids = [x + 1 for x in sentence_ids]
            # sentence_ids.append(0)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            features.append(
                MultiRCFeature(
                    example_index=example_index,
                    qas_id=example.qas_id,
                    unique_id=unique_id,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    sentence_span_list=sentence_list,
                    answer=example.answer + 1,  # In the bert_hierarchical model, the output size is 3.
                    sentence_ids=sentence_ids))
            unique_id += 1

        logger.info(f'Converted {len(features)} features.')

        return features
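
# A hedged usage sketch (the "roberta-base" checkpoint and the torch stacking
# step are assumptions, not shown in the snippet above):
#
#   import torch
#   from transformers import RobertaTokenizer
#
#   tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
#   features = convert_examples_to_features(examples, tokenizer, max_seq_length=512)
#   input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
#
# Every feature is padded to exactly max_seq_length, so the id/mask/segment
# lists stack into fixed-shape tensors.
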
Example #3
    def convert_examples_to_features(examples: List[QAFullExample],
                                     tokenizer: RobertaTokenizer,
                                     max_seq_length, doc_stride,
                                     max_query_length):
        """Loads a data file into a list of `InputBatch`s."""

        unique_id = 1000000000
        features = []
        for (example_index,
             example) in tqdm(enumerate(examples),
                              desc='Convert examples to features',
                              total=len(examples)):
            query_tokens = tokenizer.tokenize(example.question_text)

            if len(query_tokens) > max_query_length:
                # query_tokens = query_tokens[0:max_query_length]
                # Drop tokens at the front of the query, which may belong to the previous query and answer.
                query_tokens = query_tokens[-max_query_length:]

            # word piece index -> token index
            tok_to_orig_index = []
            # token index -> word pieces group start index
            # tokenizer.tokenize(doc_tokens[i]) == all_doc_tokens[orig_to_tok_index[i]: orig_to_tok_index[i + 1]]
            orig_to_tok_index = []
            # word pieces for all doc tokens
            all_doc_tokens = []
            for (i, token) in enumerate(example.doc_tokens):
                orig_to_tok_index.append(len(all_doc_tokens))
                sub_tokens = tokenizer.tokenize(token)
                for sub_token in sub_tokens:
                    tok_to_orig_index.append(i)
                    all_doc_tokens.append(sub_token)

            # Process sentence span list
            sentence_spans = []
            for (start, end) in example.sentence_span_list:
                piece_start = orig_to_tok_index[start]
                if end < len(example.doc_tokens) - 1:
                    piece_end = orig_to_tok_index[end + 1] - 1
                else:
                    piece_end = len(all_doc_tokens) - 1
                sentence_spans.append((piece_start, piece_end))

            # Rationale start and end positions within a chunk, calculated from the start of the current chunk.
            # ral_start_position = None
            # ral_end_position = None

            ral_start_position = orig_to_tok_index[example.ral_start_position]
            if example.ral_end_position < len(example.doc_tokens) - 1:
                ral_end_position = orig_to_tok_index[example.ral_end_position +
                                                     1] - 1
            else:
                ral_end_position = len(all_doc_tokens) - 1
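            # Presumably tightens the sub-word span to better match example.orig_answer_text
            # (cf. _improve_answer_span in the original BERT SQuAD code).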
            ral_start_position, ral_end_position = utils.improve_answer_span(
                all_doc_tokens, ral_start_position, ral_end_position,
                tokenizer, example.orig_answer_text)

            # The -4 accounts for [CLS], the two [SEP] tokens after the query, and the final [SEP]
            max_tokens_for_doc = max_seq_length - len(query_tokens) - 4

            # We can have documents that are longer than the maximum sequence length.
            # To deal with this we do a sliding window approach, where we take chunks
            # of the up to our max length with a stride of `doc_stride`.
            _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
            doc_spans = []
            start_offset = 0
            while start_offset < len(all_doc_tokens):
                length = len(all_doc_tokens) - start_offset
                if length > max_tokens_for_doc:
                    length = max_tokens_for_doc
                doc_spans.append(_DocSpan(start=start_offset, length=length))
                if start_offset + length == len(all_doc_tokens):
                    break
                start_offset += min(length, doc_stride)
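            # Illustrative numbers (not from the code): with len(all_doc_tokens) == 600,
            # max_tokens_for_doc == 380 and doc_stride == 128, this loop yields
            # DocSpan(start=0, length=380), DocSpan(start=128, length=380) and
            # DocSpan(start=256, length=344).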

            sentence_spans_list = []
            sentence_ids_list = []
            for span_id, doc_span in enumerate(doc_spans):
                span_start = doc_span.start
                span_end = span_start + doc_span.length - 1

                span_sentence = []
                sen_ids = []

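                # Keep only the sentences that overlap this window, clipping their
                # boundaries to the window.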
                for sen_idx, (sen_start, sen_end) in enumerate(sentence_spans):
                    if sen_end < span_start:
                        continue
                    if sen_start > span_end:
                        break
                    span_sentence.append(
                        (max(sen_start, span_start), min(sen_end, span_end)))
                    sen_ids.append(sen_idx)

                sentence_spans_list.append(span_sentence)
                sentence_ids_list.append(sen_ids)

            ini_sen_id = example.sentence_id
            for (doc_span_index, doc_span) in enumerate(doc_spans):
                # Store the input tokens to transform into input ids later.
                tokens = []
                token_to_orig_map = {}
                token_is_max_context = {}
                segment_ids = []
                tokens.append(tokenizer.cls_token)
                segment_ids.append(0)
                for token in query_tokens:
                    tokens.append(token)
                    segment_ids.append(0)
                tokens.append(tokenizer.sep_token)
                segment_ids.append(0)
                tokens.append(tokenizer.sep_token)
                segment_ids.append(0)

                doc_start = doc_span.start
                doc_offset = len(query_tokens) + 3
                sentence_list = sentence_spans_list[doc_span_index]
                cur_sentence_list = []
                for sen_id, sen in enumerate(sentence_list):
                    new_sen = (sen[0] - doc_start + doc_offset,
                               sen[1] - doc_start + doc_offset)
                    cur_sentence_list.append(new_sen)

                for i in range(doc_span.length):
                    split_token_index = doc_span.start + i  # Original index of word piece in all_doc_tokens
                    # Index of word piece in input sequence -> Original word index in doc_tokens
                    token_to_orig_map[len(
                        tokens)] = tok_to_orig_index[split_token_index]
                    # Check if the word piece has the max context in all doc spans.
                    is_max_context = utils.check_is_max_context(
                        doc_spans, doc_span_index, split_token_index)

                    token_is_max_context[len(tokens)] = is_max_context
                    tokens.append(all_doc_tokens[split_token_index])
                    segment_ids.append(1)
                # tokens.append("[SEP]")
                tokens.append(tokenizer.sep_token)
                segment_ids.append(1)

                input_ids = tokenizer.convert_tokens_to_ids(tokens)

                # The mask has 1 for real tokens and 0 for padding tokens. Only real
                # tokens are attended to.
                input_mask = [1] * len(input_ids)

                # Zero-pad up to the sequence length.
                while len(input_ids) < max_seq_length:
                    input_ids.append(0)
                    input_mask.append(0)
                    segment_ids.append(0)

                assert len(input_ids) == max_seq_length
                assert len(input_mask) == max_seq_length
                assert len(segment_ids) == max_seq_length

                # ral_start = None
                # ral_end = None
                # answer_choice = None

                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1

                # Process rationale
                out_of_span = False
                if not (ral_start_position >= doc_start
                        and ral_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    # TODO:
                    #  Considering how to set rationale start and end positions for out of span instances.
                    ral_start = 0
                    ral_end = 0
                    answer_choice = 0
                else:
                    # [CLS] + query + two [SEP] tokens precede the doc tokens (matches the +3 offset used above).
                    doc_offset = len(query_tokens) + 3
                    ral_start = ral_start_position - doc_start + doc_offset
                    ral_end = ral_end_position - doc_start + doc_offset
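                    # 0 above is reserved for out-of-span windows; assuming is_impossible
                    # is 0/1, answerable windows get 1 and impossible ones get 2.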
                    answer_choice = example.is_impossible + 1

                # Process sentence id
                span_sen_id = -1
                for piece_sen_id, sen_id in enumerate(
                        sentence_ids_list[doc_span_index]):
                    if ini_sen_id == sen_id:
                        span_sen_id = piece_sen_id
                # If the supporting sentence does not fall in this span, replace the id with [].
                if span_sen_id == -1:
                    span_sen_id = []
                else:
                    span_sen_id = [span_sen_id]
                meta_data = {
                    'span_sen_to_orig_sen_map':
                    sentence_ids_list[doc_span_index]
                }

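                # This threshold disables the debug logging below; raise it (e.g. to 20)
                # to print a few converted features while debugging.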
                if example_index < 0:
                    logger.info("*** Example ***")
                    logger.info("unique_id: %s" % unique_id)
                    logger.info("example_index: %s" % example_index)
                    logger.info("doc_span_index: %s" % doc_span_index)
                    logger.info("sentence_spans_list: %s" %
                                " ".join([(str(x[0]) + '-' + str(x[1]))
                                          for x in cur_sentence_list]))

                    rationale_text = " ".join(tokens[ral_start:(ral_end + 1)])
                    logger.info("answer choice: %s" % str(answer_choice))
                    logger.info("rationale start position: %s" %
                                str(ral_start))
                    logger.info("rationale end position: %s" % str(ral_end))
                    logger.info("rationale: %s" % rationale_text)

                features.append(
                    QAFullInputFeatures(
                        qas_id=example.qas_id,
                        unique_id=unique_id,
                        example_index=example_index,
                        doc_span_index=doc_span_index,
                        sentence_span_list=cur_sentence_list,
                        tokens=tokens,
                        token_to_orig_map=token_to_orig_map,
                        token_is_max_context=token_is_max_context,
                        input_ids=input_ids,
                        input_mask=input_mask,
                        segment_ids=segment_ids,
                        is_impossible=answer_choice,
                        sentence_id=span_sen_id,
                        start_position=None,
                        end_position=None,
                        ral_start_position=ral_start,
                        ral_end_position=ral_end,
                        meta_data=meta_data))

                unique_id += 1

        return features
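
# A hedged usage sketch (checkpoint name and hyper-parameters are illustrative only):
#
#   tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
#   features = convert_examples_to_features(
#       examples, tokenizer, max_seq_length=384, doc_stride=128, max_query_length=64)
#
# Long documents produce several overlapping features per example (one per doc
# span), so len(features) is usually larger than len(examples).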