def test_full_tokenizer(self):
    tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
    text = "lower newer"
    bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
    tokens = tokenizer.tokenize(text, add_prefix_space=True)
    self.assertListEqual(tokens, bpe_tokens)

    input_tokens = tokens + [tokenizer.unk_token]
    input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
    self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
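# A minimal usage sketch of the behavior exercised above (assumes the
# `transformers` package and a pretrained checkpoint are available; the exact
# splits below depend on the vocabulary, so they are only illustrative).
# RoBERTa's byte-level BPE folds a leading space into the token itself, encoded
# as "\u0120" (printed as "Ġ"), which is why the test passes
# add_prefix_space=True: it makes the first word tokenize like a mid-sentence word.
from transformers import RobertaTokenizer

roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
print(roberta_tokenizer.tokenize("lower newer", add_prefix_space=True))
# With the pretrained vocabulary both words carry the space marker,
# e.g. ['Ġlower', 'Ġnewer']; the toy vocabulary in the test splits further.
print(roberta_tokenizer.tokenize("lower newer"))
# Without the prefix space the first word lacks the 'Ġ' marker, e.g. ['lower', 'Ġnewer'].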
def convert_examples_to_features(examples: List[MultiRCExample], tokenizer: RobertaTokenizer,
                                 max_seq_length: int = 512, **kwargs):
    unique_id = 1000000000
    features = []
    for (example_index, example) in tqdm(enumerate(examples), desc='Convert examples to features',
                                         total=len(examples)):
        query_tokens = tokenizer.tokenize(example.question_text)
        if query_tokens[-1] != '?':
            query_tokens.append('?')

        # word piece index -> token index
        tok_to_orig_index = []
        # token index -> word pieces group start index
        orig_to_tok_index = []
        # word pieces for all doc tokens
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        # Process sentence span list: map word-level spans to word-piece spans.
        sentence_spans = []
        for (start, end) in example.sentence_span_list:
            piece_start = orig_to_tok_index[start]
            if end < len(example.doc_tokens) - 1:
                piece_end = orig_to_tok_index[end + 1] - 1
            else:
                piece_end = len(all_doc_tokens) - 1
            sentence_spans.append((piece_start, piece_end))

        # Process all tokens. The -4 accounts for [CLS], the two consecutive
        # [SEP]s after the query/option pair, and the final [SEP].
        q_op_tokens = query_tokens + tokenizer.tokenize(example.option_text)
        doc_tokens = all_doc_tokens[:]
        utils.truncate_seq_pair(q_op_tokens, doc_tokens, max_seq_length - 4)

        tokens = [tokenizer.cls_token] + q_op_tokens + [tokenizer.sep_token, tokenizer.sep_token]
        segment_ids = [0] * len(tokens)
        tokens = tokens + doc_tokens + [tokenizer.sep_token]
        segment_ids += [1] * (len(doc_tokens) + 1)

        # Shift sentence spans by the query/option prefix and drop or clip
        # sentences that fell past the truncated document.
        sentence_list = []
        collected_sentence_indices = []
        doc_offset = len(q_op_tokens) + 3
        for sentence_index, (start, end) in enumerate(sentence_spans):
            assert start <= end, (example_index, sentence_index, start, end)
            if start >= len(doc_tokens):
                break
            if end >= len(doc_tokens):
                end = len(doc_tokens) - 1
            start = doc_offset + start
            end = doc_offset + end
            sentence_list.append((start, end))
            assert start < max_seq_length and end < max_seq_length
            collected_sentence_indices.append(sentence_index)

        # Keep only the gold sentence ids that survived truncation.
        sentence_ids = []
        for sentence_id in example.sentence_ids:
            if sentence_id in collected_sentence_indices:
                sentence_ids.append(sentence_id)

        # For multiple style, append 0 at last and for each sentence id, +1
        # sentence_ids = [x + 1 for x in sentence_ids]
        # sentence_ids.append(0)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        features.append(
            MultiRCFeature(
                example_index=example_index,
                qas_id=example.qas_id,
                unique_id=unique_id,
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                sentence_span_list=sentence_list,
                answer=example.answer + 1,  # In the bert_hierarchical model, the output size is 3.
                sentence_ids=sentence_ids))
        unique_id += 1

    logger.info(f'Reading {len(features)} features.')
    return features
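# A hedged usage sketch for the converter above. `read_examples` and the file
# name are hypothetical stand-ins for whatever loader builds MultiRCExample
# objects in this repo; the tensor packing relies only on the fixed-length,
# zero-padded fields asserted inside the converter.
import torch
from torch.utils.data import TensorDataset
from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
examples = read_examples("multirc_train.json")  # hypothetical loader
features = convert_examples_to_features(examples, tokenizer, max_seq_length=512)

# Every feature is padded to max_seq_length, so the fields stack cleanly.
dataset = TensorDataset(
    torch.tensor([f.input_ids for f in features], dtype=torch.long),
    torch.tensor([f.input_mask for f in features], dtype=torch.long),
    torch.tensor([f.segment_ids for f in features], dtype=torch.long),
    torch.tensor([f.answer for f in features], dtype=torch.long))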
def convert_examples_to_features(examples: List[QAFullExample], tokenizer: RobertaTokenizer,
                                 max_seq_length, doc_stride, max_query_length):
    """Loads a data file into a list of `InputBatch`s."""
    unique_id = 1000000000
    features = []
    for (example_index, example) in tqdm(enumerate(examples), desc='Convert examples to features',
                                         total=len(examples)):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            # Keep the tail of the query: tokens prepended at the front may
            # belong to the previous query and answer.
            query_tokens = query_tokens[-max_query_length:]

        # word piece index -> token index
        tok_to_orig_index = []
        # token index -> word pieces group start index:
        # tokenizer.tokenize(doc_tokens[i]) == all_doc_tokens[orig_to_tok_index[i]: orig_to_tok_index[i + 1]]
        orig_to_tok_index = []
        # word pieces for all doc tokens
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        # Process sentence span list: map word-level spans to word-piece spans.
        sentence_spans = []
        for (start, end) in example.sentence_span_list:
            piece_start = orig_to_tok_index[start]
            if end < len(example.doc_tokens) - 1:
                piece_end = orig_to_tok_index[end + 1] - 1
            else:
                piece_end = len(all_doc_tokens) - 1
            sentence_spans.append((piece_start, piece_end))

        # Rationale start and end positions in the word-piece sequence, measured
        # from the start of the full document; chunk-relative offsets are applied
        # later, per doc span.
        ral_start_position = orig_to_tok_index[example.ral_start_position]
        if example.ral_end_position < len(example.doc_tokens) - 1:
            ral_end_position = orig_to_tok_index[example.ral_end_position + 1] - 1
        else:
            ral_end_position = len(all_doc_tokens) - 1
        ral_start_position, ral_end_position = utils.improve_answer_span(
            all_doc_tokens, ral_start_position, ral_end_position, tokenizer,
            example.orig_answer_text)

        # The -4 accounts for [CLS], the two consecutive [SEP]s after the query,
        # and the final [SEP].
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 4

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        # Clip sentence spans to each doc span and record which sentences survive.
        sentence_spans_list = []
        sentence_ids_list = []
        for span_id, doc_span in enumerate(doc_spans):
            span_start = doc_span.start
            span_end = span_start + doc_span.length - 1
            span_sentence = []
            sen_ids = []
            for sen_idx, (sen_start, sen_end) in enumerate(sentence_spans):
                if sen_end < span_start:
                    continue
                if sen_start > span_end:
                    break
                span_sentence.append((max(sen_start, span_start), min(sen_end, span_end)))
                sen_ids.append(sen_idx)
            sentence_spans_list.append(span_sentence)
            sentence_ids_list.append(sen_ids)

        ini_sen_id = example.sentence_id

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            # Store the input tokens to transform into input ids later.
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append(tokenizer.cls_token)
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append(tokenizer.sep_token)
            segment_ids.append(0)
            tokens.append(tokenizer.sep_token)
            segment_ids.append(0)

            # Shift this span's sentence spans from document coordinates to
            # input-sequence coordinates: [CLS] + query + [SEP] + [SEP] = +3.
            doc_start = doc_span.start
            doc_offset = len(query_tokens) + 3
            sentence_list = sentence_spans_list[doc_span_index]
            cur_sentence_list = []
            for sen_id, sen in enumerate(sentence_list):
                new_sen = (sen[0] - doc_start + doc_offset, sen[1] - doc_start + doc_offset)
                cur_sentence_list.append(new_sen)

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i  # Original index of word piece in all_doc_tokens
                # Index of word piece in input sequence -> original word index in doc_tokens
                token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
                # Check if the word piece has the max context among all doc spans.
                is_max_context = utils.check_is_max_context(doc_spans, doc_span_index,
                                                            split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
            tokens.append(tokenizer.sep_token)
            segment_ids.append(1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            doc_start = doc_span.start
            doc_end = doc_span.start + doc_span.length - 1

            # Process rationale
            out_of_span = False
            if not (ral_start_position >= doc_start and ral_end_position <= doc_end):
                out_of_span = True
            if out_of_span:
                # TODO: consider how to set rationale start and end positions
                # for out-of-span instances.
                ral_start = 0
                ral_end = 0
                answer_choice = 0
            else:
                # Doc tokens start after [CLS] + query + [SEP] + [SEP], i.e. at
                # offset len(query_tokens) + 3, matching the sentence spans above.
                # (The original `+ 2` looked like a leftover from the
                # single-[SEP] BERT layout.)
                doc_offset = len(query_tokens) + 3
                ral_start = ral_start_position - doc_start + doc_offset
                ral_end = ral_end_position - doc_start + doc_offset
                answer_choice = example.is_impossible + 1

            # Process sentence id: locate the gold sentence within this span.
            span_sen_id = -1
            for piece_sen_id, sen_id in enumerate(sentence_ids_list[doc_span_index]):
                if ini_sen_id == sen_id:
                    span_sen_id = piece_sen_id
            # Features whose gold sentence is not in this span get an empty list.
            if span_sen_id == -1:
                span_sen_id = []
            else:
                span_sen_id = [span_sen_id]
            meta_data = {'span_sen_to_orig_sen_map': sentence_ids_list[doc_span_index]}

            # Debug logging (currently disabled: example_index is never negative).
            if example_index < 0:
                logger.info("*** Example ***")
                logger.info("unique_id: %s" % unique_id)
                logger.info("example_index: %s" % example_index)
                logger.info("doc_span_index: %s" % doc_span_index)
                logger.info("sentence_spans_list: %s" % " ".join(
                    [(str(x[0]) + '-' + str(x[1])) for x in cur_sentence_list]))

                rationale_text = " ".join(tokens[ral_start:(ral_end + 1)])
                logger.info("answer choice: %s" % str(answer_choice))
                logger.info("rationale start position: %s" % str(ral_start))
                logger.info("rationale end position: %s" % str(ral_end))
                logger.info("rationale: %s" % rationale_text)

            features.append(
                QAFullInputFeatures(
                    qas_id=example.qas_id,
                    unique_id=unique_id,
                    example_index=example_index,
                    doc_span_index=doc_span_index,
                    sentence_span_list=cur_sentence_list,
                    tokens=tokens,
                    token_to_orig_map=token_to_orig_map,
                    token_is_max_context=token_is_max_context,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    is_impossible=answer_choice,
                    sentence_id=span_sen_id,
                    start_position=None,
                    end_position=None,
                    ral_start_position=ral_start,
                    ral_end_position=ral_end,
                    meta_data=meta_data))

            unique_id += 1

    return features
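# The sliding-window loop above calls utils.check_is_max_context, whose body is
# not shown in this file. The standard implementation from Google's original
# BERT run_squad.py is sketched below for reference; the repo's own `utils`
# version may differ. A token that falls into several overlapping doc spans is
# scored in each span by min(left context, right context) + 0.01 * span length,
# and only the highest-scoring span treats it as "max context".
def check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the doc span with maximum context for the token at `position`."""
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index
    return cur_span_index == best_span_index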