def test_full_tokenizer(self):
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
        "runn", "##ing", ","
    ]
    with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
        vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
        vocab_file = vocab_writer.name

    tokenizer = BertTokenizer(vocab_file)
    os.remove(vocab_file)

    tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
    self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
    self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

    vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
    # `from_pretrained` is a classmethod that returns a new tokenizer; keep
    # the return value (the original discarded it), so the reloaded
    # vocabulary is actually what gets tested below.
    tokenizer = tokenizer.from_pretrained(vocab_file)
    os.remove(vocab_file)

    tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
    self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
    self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
def test_full_tokenizer_raises_error_for_long_sequences(self):
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
        "runn", "##ing", ","
    ]
    with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
        vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
        vocab_file = vocab_writer.name
    tokenizer = BertTokenizer(vocab_file, max_len=10)
    os.remove(vocab_file)

    # Every word maps to [UNK] (id 0); ten tokens fit within max_len.
    tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time")
    indices = tokenizer.convert_tokens_to_ids(tokens)
    self.assertListEqual(indices, [0 for _ in range(10)])

    # An eleventh token pushes the sequence over max_len and must raise.
    tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time .")
    self.assertRaises(ValueError, tokenizer.convert_tokens_to_ids, tokens)
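# --- A minimal sketch (not part of the original tests) of the WordPiece
# --- round trip the two tests above exercise; `BertTokenizer` is assumed to
# --- be the pytorch-pretrained-bert implementation imported in this file,
# --- and the vocab path is a throwaway placeholder.
def _demo_wordpiece_roundtrip(tmp_path="/tmp/demo_vocab.txt"):
    vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "un", "##want", "##ed"]
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write("".join(t + "\n" for t in vocab_tokens))
    tokenizer = BertTokenizer(tmp_path)
    tokens = tokenizer.tokenize("unwanted")        # -> ["un", "##want", "##ed"]
    ids = tokenizer.convert_tokens_to_ids(tokens)  # -> [3, 4, 5]
    return tokens, ids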
def __init__(self, tokens: List[str], tokenizer: BertTokenizer,
             max_seq_length: int):
    self.max_seq_length = max_seq_length if max_seq_length else 0
    self.tokens = [CLS] + tokenizer.tokenize(" ".join(tokens)) + [SEP]
    self.token_mask_ids = [
        idx for idx, token in enumerate(self.tokens) if token == MASK
    ]
    self.len = len(self.tokens)
    if self.max_seq_length and self.len > self.max_seq_length:
        # NOTE: over-long inputs only log a warning, and `input_ids` /
        # `attention_mask` are left unset in that case.
        logger.warning("'tokens_a' is over {}: {}".format(
            max_seq_length, self.len))
        # raise RuntimeError("'tokens_a' is over {}: {}".format(max_seq_length, self.len))
    else:
        self.input_ids = tokenizer.convert_tokens_to_ids(
            self.tokens) + [0] * max(self.max_seq_length - self.len, 0)
        self.attention_mask = [1] * self.len + [0] * max(
            self.max_seq_length - self.len, 0)
def convert_example_to_features(uid: int, text_a: str, seq_length: int,
                                tokenizer: BertTokenizer) -> InputFeatures:
    tokens_a = tokenizer.tokenize(text_a)
    # Only single sentences are handled here, and long ones are truncated,
    # so we just reserve room for the [CLS]/[SEP] added at head and tail.
    tokens_a = tokens_a[:seq_length - 2]

    # For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids:   0    0   0  0    0   0   0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
    input_type_ids = [0] * (len(tokens_a) + 2)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # Zero-pad up to the sequence length.
    padding = [0] * (seq_length - len(input_ids))
    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids) + padding
    input_ids += padding
    input_type_ids += padding

    assert len(input_ids) == seq_length
    assert len(input_mask) == seq_length
    assert len(input_type_ids) == seq_length

    return InputFeatures(
        unique_id=uid,
        tokens=tokens,
        input_ids=input_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids,
    )
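# --- A minimal sketch (not from the original) of driving
# --- convert_example_to_features; the uid and seq_length values are
# --- illustrative, and `tokenizer` is assumed to be a BertTokenizer built
# --- elsewhere in this file.
def _demo_convert(tokenizer: BertTokenizer) -> InputFeatures:
    return convert_example_to_features(
        uid=0,
        text_a="the dog is hairy .",
        seq_length=16,
        tokenizer=tokenizer)
    # The returned features are zero-padded to input length 16.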
def __init__(self, text: str, tokenizer: BertTokenizer, max_seq_length: int,
             language: str):
    self.max_seq_length = max_seq_length if max_seq_length else 0
    self.lang = language
    if self.lang == "ja":
        # Group consecutive "M" placeholder characters into [MASK] tokens and
        # run the remaining spans through the Juman morphological analyzer.
        tokens = []
        for v, group in groupby(text, key=lambda x: x == "M"):
            if v:
                tokens += [MASK for _ in group]
            else:
                tokens += [
                    morph.midasi for morph in juman.analysis("".join(group))
                ]
    elif self.lang == "en":
        tokens = [
            MASK if token == "M" else token for token in text.split(" ")
        ]
    else:
        raise ValueError("Unsupported value: {}".format(self.lang))

    self.original_tokens = tokens
    self.original_token_mask_ids = [
        idx for idx, token in enumerate(self.original_tokens)
        if token == MASK
    ]
    self.tokens = [CLS] + tokenizer.tokenize(" ".join(tokens)) + [SEP]
    self.token_mask_ids = [
        idx for idx, token in enumerate(self.tokens) if token == MASK
    ]
    self.len = len(self.tokens)
    if self.max_seq_length and self.len > self.max_seq_length:
        raise RuntimeError("'tokens_a' is over {}: {}".format(
            max_seq_length, self.len))
    self.input_ids = tokenizer.convert_tokens_to_ids(
        self.tokens) + [0] * max(self.max_seq_length - self.len, 0)
    self.attention_mask = [1] * self.len + [0] * max(
        self.max_seq_length - self.len, 0)
def convert_examples_to_features(examples: List[SQuADFullExample],
                                 tokenizer: BertTokenizer, max_seq_length,
                                 doc_stride, max_query_length,
                                 is_training: bool):
    """Loads a data file into a list of `InputBatch`s."""
    unique_id = 1000000000
    features = []
    for (example_index, example) in tqdm(enumerate(examples),
                                         desc='Convert examples to features',
                                         total=len(examples)):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            # query_tokens = query_tokens[0:max_query_length]
            # Keep the tail: the tokens at the front of the query may belong
            # to the previous query and answer.
            query_tokens = query_tokens[-max_query_length:]

        # word piece index -> token index
        tok_to_orig_index = []
        # token index -> word-piece group start index:
        # BertTokenizer.tokenize(doc_tokens[i]) ==
        #     all_doc_tokens[orig_to_tok_index[i]: orig_to_tok_index[i + 1]]
        orig_to_tok_index = []
        # word pieces for all doc tokens
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        # Map each sentence span from word indices to word-piece indices.
        sentence_spans = []
        for (start, end) in example.sentence_span_list:
            piece_start = orig_to_tok_index[start]
            if end < len(example.doc_tokens) - 1:
                piece_end = orig_to_tok_index[end + 1] - 1
            else:
                piece_end = len(all_doc_tokens) - 1
            sentence_spans.append((piece_start, piece_end))

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence
        # length. To deal with this we use a sliding window approach, taking
        # chunks of up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        # For each doc span, collect the sentence spans (clipped to the span
        # boundaries) and the ids of the sentences they came from.
        sentence_spans_list = []
        sentence_ids_list = []
        for span_id, doc_span in enumerate(doc_spans):
            span_start = doc_span.start
            span_end = span_start + doc_span.length - 1
            span_sentence = []
            sen_ids = []
            for sen_idx, (sen_start, sen_end) in enumerate(sentence_spans):
                if sen_end < span_start:
                    continue
                if sen_start > span_end:
                    break
                span_sentence.append(
                    (max(sen_start, span_start), min(sen_end, span_end)))
                sen_ids.append(sen_idx)
            sentence_spans_list.append(span_sentence)
            sentence_ids_list.append(sen_ids)

        ini_sen_id = example.sentence_id
        for (doc_span_index, doc_span) in enumerate(doc_spans):
            # Store the input tokens to transform into input ids later.
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append("[CLS]")
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append("[SEP]")
            segment_ids.append(0)

            doc_start = doc_span.start
            doc_offset = len(query_tokens) + 2
            sentence_list = sentence_spans_list[doc_span_index]
            cur_sentence_list = []
            for sen_id, sen in enumerate(sentence_list):
                new_sen = (sen[0] - doc_start + doc_offset,
                           sen[1] - doc_start + doc_offset)
                cur_sentence_list.append(new_sen)

            for i in range(doc_span.length):
                # Original index of word piece in all_doc_tokens
                split_token_index = doc_span.start + i
                # Index of word piece in input sequence -> original word
                # index in doc_tokens
                token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
                # Check if the word piece has the max context in all doc spans.
                is_max_context = utils.check_is_max_context(
                    doc_spans, doc_span_index, split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only
            # real tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            answer_choice = -1

            # Find the position of the initial sentence id within this span.
            span_sen_id = -1
            for piece_sen_id, sen_id in enumerate(
                    sentence_ids_list[doc_span_index]):
                if ini_sen_id == sen_id:
                    span_sen_id = piece_sen_id

            meta_data = {
                'span_sen_to_orig_sen_map': sentence_ids_list[doc_span_index]
            }

            if example_index < 0:
                logger.info("*** Example ***")
                logger.info("unique_id: %s" % unique_id)
                logger.info("example_index: %s" % example_index)
                logger.info("doc_span_index: %s" % doc_span_index)
                logger.info("sentence_spans_list: %s" % " ".join(
                    [(str(x[0]) + '-' + str(x[1])) for x in cur_sentence_list]))
                logger.info("answer choice: %s" % str(answer_choice))

            features.append(
                QAFullInputFeatures(
                    qas_id=example.qas_id,
                    unique_id=unique_id,
                    example_index=example_index,
                    doc_span_index=doc_span_index,
                    sentence_span_list=cur_sentence_list,
                    tokens=tokens,
                    token_to_orig_map=token_to_orig_map,
                    token_is_max_context=token_is_max_context,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    is_impossible=answer_choice,
                    sentence_id=span_sen_id,
                    start_position=None,
                    end_position=None,
                    ral_start_position=-1,
                    ral_end_position=-1,
                    meta_data=meta_data))
            unique_id += 1

    return features
class BertCorrector(Detector):
    def __init__(self, bert_model_dir='', bert_model_vocab='',
                 max_seq_length=384):
        super(BertCorrector, self).__init__()
        self.name = 'bert_corrector'
        self.bert_model_dir = os.path.join(pwd_path, bert_model_dir)
        self.bert_model_vocab = os.path.join(pwd_path, bert_model_vocab)
        self.max_seq_length = max_seq_length
        self.initialized_bert_corrector = False

    def check_bert_corrector_initialized(self):
        if not self.initialized_bert_corrector:
            self.initialize_bert_corrector()

    def initialize_bert_corrector(self):
        t1 = time.time()
        self.bert_tokenizer = BertTokenizer(self.bert_model_vocab)
        # Prepare model
        self.model = BertForMaskedLM.from_pretrained(self.bert_model_dir)
        print("Loaded model: %s, vocab file: %s, spend: %.3f s." %
              (self.bert_model_dir, self.bert_model_vocab, time.time() - t1))
        self.initialized_bert_corrector = True

    def convert_sentence_to_features(self, sentence, tokenizer,
                                     max_seq_length, error_begin_idx=0,
                                     error_end_idx=0):
        """Loads a sentence into a list of `InputBatch`s."""
        self.check_bert_corrector_initialized()
        features = []
        tokens_a = list(sentence)

        # For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids:   0    0   0  0    0   0   0
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
        # Replace the erroneous span with [MASK] tokens (+1 offsets for [CLS]).
        k = error_begin_idx + 1
        for i in range(error_end_idx - error_begin_idx):
            tokens[k] = '[MASK]'
            k += 1
        segment_ids = [0] * len(tokens)

        input_ids = self.bert_tokenizer.convert_tokens_to_ids(tokens)
        mask_ids = [i for i, v in enumerate(input_ids) if v == MASK_ID]

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_mask=input_mask,
                          mask_ids=mask_ids,
                          segment_ids=segment_ids,
                          input_tokens=tokens))
        return features

    def check_vocab_has_all_token(self, sentence):
        self.check_bert_corrector_initialized()
        flag = True
        for i in list(sentence):
            if i not in self.bert_tokenizer.vocab:
                flag = False
                break
        return flag

    def bert_lm_infer(self, sentence, error_begin_idx=0, error_end_idx=0):
        self.check_bert_corrector_initialized()
        corrected_item = sentence[error_begin_idx:error_end_idx]
        eval_features = self.convert_sentence_to_features(
            sentence=sentence,
            tokenizer=self.bert_tokenizer,
            max_seq_length=self.max_seq_length,
            error_begin_idx=error_begin_idx,
            error_end_idx=error_end_idx)
        for f in eval_features:
            input_ids = torch.tensor([f.input_ids])
            segment_ids = torch.tensor([f.segment_ids])
            predictions = self.model(input_ids, segment_ids)
            # Take the argmax prediction at each masked position.
            masked_ids = f.mask_ids
            if masked_ids:
                for idx, i in enumerate(masked_ids):
                    predicted_index = torch.argmax(predictions[0, i]).item()
                    predicted_token = self.bert_tokenizer.convert_ids_to_tokens(
                        [predicted_index])[0]
                    print('original text is:', f.input_tokens)
                    print('Mask predict is:', predicted_token)
                    corrected_item = predicted_token
        return corrected_item

    def correct(self, sentence=''):
        """
        Correct a sentence.
        :param sentence: sentence text
        :return: corrected sentence, list(wrong, right, begin_idx, end_idx)
        """
        detail = []
        maybe_errors = self.detect(sentence)
        maybe_errors = sorted(maybe_errors,
                              key=operator.itemgetter(2),
                              reverse=False)
        for item, begin_idx, end_idx, err_type in maybe_errors:
            # Fix the errors one by one.
            before_sent = sentence[:begin_idx]
            after_sent = sentence[end_idx:]
            if err_type == error_type["confusion"]:
                # Words listed in the confusion set map directly to their
                # corrections.
                corrected_item = self.custom_confusion[item]
            elif err_type == error_type["char"]:
                # Skip wrong characters that are not Chinese.
                if not is_chinese_string(item):
                    continue
                if not self.check_vocab_has_all_token(sentence):
                    continue
                # Let the masked LM propose the most likely character(s).
                corrected_item = self.bert_lm_infer(sentence,
                                                    error_begin_idx=begin_idx,
                                                    error_end_idx=end_idx)
            elif err_type == error_type["word"]:
                corrected_item = item
            else:
                print('unknown error_type')
            # output
            if corrected_item != item:
                sentence = before_sent + corrected_item + after_sent
                detail_word = [item, corrected_item, begin_idx, end_idx]
                detail.append(detail_word)
        detail = sorted(detail, key=operator.itemgetter(2))
        return sentence, detail
def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--bert_model_dir", default=None, type=str, required=True,
                        help="Bert pre-trained model config dir")
    parser.add_argument("--bert_model_vocab", default=None, type=str, required=True,
                        help="Bert pre-trained model vocab path")
    parser.add_argument("--output_dir", default="./output", type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")
    # Other parameters
    parser.add_argument("--predict_file", default=None, type=str,
                        help="for predictions.")
    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--verbose_logging", default=False, action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    args = parser.parse_args()

    device = torch.device("cpu")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer(args.bert_model_vocab)

    # Prepare model
    model = BertForMaskedLM.from_pretrained(args.bert_model_dir)

    # Save a trained model
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    if not os.path.exists(output_model_file):
        torch.save(model_to_save.state_dict(), output_model_file)

    # Load a trained model that you have fine-tuned
    model_state_dict = torch.load(output_model_file)
    model.to(device)

    # Tokenized input
    text = "吸 烟 的 人 容 易 得 癌 症"
    print(text)
    tokenized_text = tokenizer.tokenize(text)

    # Mask a token that we will try to predict back with `BertForMaskedLM`
    masked_index = 8
    tokenized_text[masked_index] = '[MASK]'
    print(tokenized_text)

    # Convert tokens to vocabulary indices
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
    segments_ids = [0] * len(indexed_tokens)

    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    # Put the model in evaluation mode
    model.eval()

    # Predict all tokens
    predictions = model(tokens_tensor, segments_tensors)
    predicted_index = torch.argmax(predictions[0, masked_index]).item()
    print(predicted_index)
    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
    print(predicted_token)
    # infer one line end

    if args.predict_file:
        eval_examples = read_lm_examples(input_file=args.predict_file)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length)
        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("Start predict ...")
        for f in eval_features:
            input_ids = torch.tensor([f.input_ids])
            segment_ids = torch.tensor([f.segment_ids])
            predictions = model(input_ids, segment_ids)
            # Report the argmax prediction at each masked position.
            masked_ids = f.mask_ids
            if masked_ids:
                print(masked_ids)
                for idx, i in enumerate(masked_ids):
                    predicted_index = torch.argmax(predictions[0, i]).item()
                    predicted_token = tokenizer.convert_ids_to_tokens(
                        [predicted_index])[0]
                    print('original text is:', f.input_tokens)
                    print('Mask predict is:', predicted_token)
def convert_examples_to_features(examples: List[QAFullExample],
                                 tokenizer: BertTokenizer, max_seq_length,
                                 doc_stride, max_query_length,
                                 is_training: bool):
    """Loads a data file into a list of `InputBatch`s."""
    unique_id = 1000000000
    features = []
    drop = 0
    for (example_index, example) in tqdm(enumerate(examples),
                                         desc='Convert examples to features',
                                         total=len(examples)):
        query_tokens = tokenizer.tokenize(example.question_text)
        # if len(query_tokens) > max_query_length:
        #     # Keep the tail: tokens at the front of the query may belong to
        #     # the previous query and answer.
        #     query_tokens = query_tokens[-max_query_length:]
        query_tokens = ["[CLS]"] + query_tokens + ["[SEP]"]
        ques_input_ids = tokenizer.convert_tokens_to_ids(query_tokens)
        ques_input_mask = [1] * len(ques_input_ids)

        assert len(ques_input_ids) <= max_query_length
        while len(ques_input_ids) < max_query_length:
            ques_input_ids.append(0)
            ques_input_mask.append(0)
        assert len(ques_input_ids) == max_query_length
        assert len(ques_input_mask) == max_query_length

        # Wrap each document sentence as its own [CLS] ... [SEP] sequence,
        # truncating sentences that would exceed max_seq_length.
        doc_sen_tokens = example.doc_tokens
        all_doc_tokens = []
        for sentence in doc_sen_tokens:
            cur_sen_doc_tokens = ["[CLS]"]
            for token in sentence:
                sub_tokens = tokenizer.tokenize(token)
                if len(cur_sen_doc_tokens) + 1 + len(sub_tokens) > max_seq_length:
                    drop += 1
                    break
                cur_sen_doc_tokens.extend(sub_tokens)
            cur_sen_doc_tokens.append("[SEP]")
            all_doc_tokens.append(cur_sen_doc_tokens)

        pass_input_ids = []
        pass_input_mask = []
        for sentence in all_doc_tokens:
            sentence_input_ids = tokenizer.convert_tokens_to_ids(sentence)
            sentence_input_mask = [1] * len(sentence_input_ids)
            assert len(sentence_input_ids) <= max_seq_length, len(sentence_input_ids)
            while len(sentence_input_ids) < max_seq_length:
                sentence_input_ids.append(0)
                sentence_input_mask.append(0)
            assert len(sentence_input_ids) == max_seq_length
            assert len(sentence_input_mask) == max_seq_length
            pass_input_ids.append(sentence_input_ids)
            pass_input_mask.append(sentence_input_mask)

        features.append(
            SingleSentenceFeature(qas_id=example.qas_id,
                                  unique_id=unique_id,
                                  example_index=example_index,
                                  tokens=all_doc_tokens,
                                  ques_input_ids=ques_input_ids,
                                  ques_input_mask=ques_input_mask,
                                  pass_input_ids=pass_input_ids,
                                  pass_input_mask=pass_input_mask,
                                  is_impossible=example.is_impossible,
                                  sentence_id=example.sentence_id))
        unique_id += 1

    logger.info(f'Read {len(features)} features and truncated {drop} sentences')
    return features
def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--bert_model_dir",
                        default='../data/bert_models/chinese_finetuned_lm/',
                        type=str,
                        help="Bert pre-trained model config dir")
    parser.add_argument("--bert_model_vocab",
                        default='../data/bert_models/chinese_finetuned_lm/vocab.txt',
                        type=str,
                        help="Bert pre-trained model vocab path")
    parser.add_argument("--output_dir", default="./output", type=str,
                        help="The output directory where the model checkpoints and predictions will be written.")
    # Other parameters
    parser.add_argument("--predict_file", default='../data/cn/lm_test_zh.txt',
                        type=str, help="for predictions.")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=64, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--verbose_logging", default=False, action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    args = parser.parse_args()

    device = torch.device("cpu")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer(args.bert_model_vocab)
    MASK_ID = tokenizer.convert_tokens_to_ids([MASK_TOKEN])[0]
    print('MASK_ID,', MASK_ID)

    # Prepare model
    model = BertForMaskedLM.from_pretrained(args.bert_model_dir)

    # Save a trained model
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    if not os.path.exists(output_model_file):
        torch.save(model_to_save.state_dict(), output_model_file)

    # Load a trained model that you have fine-tuned
    model_state_dict = torch.load(output_model_file)
    model.to(device)

    # Tokenized input
    text = "吸烟的人容易得癌症"
    tokenized_text = tokenizer.tokenize(text)
    print(text, '=>', tokenized_text)

    # Mask a token that we will try to predict back with `BertForMaskedLM`
    masked_index = 8
    tokenized_text[masked_index] = '[MASK]'

    # Convert tokens to vocabulary indices
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
    segments_ids = [0] * len(indexed_tokens)

    # Convert inputs to PyTorch tensors
    print('tokens, segments_ids:', indexed_tokens, segments_ids)
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    # Put the model in evaluation mode
    model.eval()

    # Predict all tokens
    predictions = model(tokens_tensor, segments_tensors)
    predicted_index = torch.argmax(predictions[0, masked_index]).item()
    print(predicted_index)
    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
    print(predicted_token)
    # infer one line end

    # Predict the probability of each character and the sentence perplexity.
    text = "吸烟的人容易得癌症"
    tokenized_text = tokenizer.tokenize(text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    segments_ids = [0] * len(indexed_tokens)
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    sentence_loss = 0.0
    sentence_count = 0
    for idx, label in enumerate(text):
        print(label)
        label_id = tokenizer.convert_tokens_to_ids([label])[0]
        lm_labels = [-1] * len(indexed_tokens)
        if idx != 0:
            lm_labels[idx] = label_id
        if idx == 1:
            lm_labels = indexed_tokens
        print(lm_labels)
        masked_lm_labels = torch.tensor([lm_labels])

        # When labels are provided, the model returns the masked LM loss.
        loss = model(tokens_tensor, segments_tensors,
                     masked_lm_labels=masked_lm_labels)
        print('loss:', loss)
        prob = float(np.exp(-loss.item()))
        print('prob:', prob)
        # Accumulate the loss (the original accumulated `prob`, which makes
        # the exp(mean) below meaningless); ppl = exp(mean loss) is a proper
        # perplexity.
        sentence_loss += loss.item()
        sentence_count += 1
    ppl = float(np.exp(sentence_loss / sentence_count))
    print('ppl:', ppl)

    # Infer each character by masking it out, one position at a time.
    text = "吸烟的人容易得癌症"
    for masked_index, label in enumerate(text):
        tokenized_text = tokenizer.tokenize(text)
        print(text, '=>', tokenized_text)
        tokenized_text[masked_index] = '[MASK]'
        print(tokenized_text)

        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([segments_ids])
        predictions = model(tokens_tensor, segments_tensors)

        print('expected label:', label)
        predicted_index = torch.argmax(predictions[0, masked_index]).item()
        predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
        print('predict label:', predicted_token)

        # Top-5 candidates at the masked position.
        scores = predictions[0, masked_index]
        top_scores = torch.sort(scores, 0, True)
        top_score_val = top_scores[0][:5]
        top_score_idx = top_scores[1][:5]
        for j in range(len(top_score_idx)):
            print('Mask predict is:',
                  tokenizer.convert_ids_to_tokens([top_score_idx[j].item()])[0],
                  ' prob:', top_score_val[j].item())
        print()

    if args.predict_file:
        eval_examples = read_lm_examples(input_file=args.predict_file)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            mask_token=MASK_TOKEN,
            mask_id=MASK_ID)
        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("Start predict ...")
        for f in eval_features:
            input_ids = torch.tensor([f.input_ids])
            segment_ids = torch.tensor([f.segment_ids])
            predictions = model(input_ids, segment_ids)
            mask_positions = f.mask_positions
            if mask_positions:
                for idx, i in enumerate(mask_positions):
                    if not i:
                        continue
                    # Top-5 candidates at each masked position.
                    scores = predictions[0, i]
                    top_scores = torch.sort(scores, 0, True)
                    top_score_val = top_scores[0][:5]
                    top_score_idx = top_scores[1][:5]
                    print('original text is:', f.input_tokens)
                    for j in range(len(top_score_idx)):
                        print('Mask predict is:',
                              tokenizer.convert_ids_to_tokens(
                                  [top_score_idx[j].item()])[0],
                              ' prob:', top_score_val[j].item())
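# --- Illustrative invocation (not from the original); the script name and
# --- paths are placeholders matching the argument defaults above.
# python predict_lm.py \
#     --bert_model_dir ../data/bert_models/chinese_finetuned_lm/ \
#     --bert_model_vocab ../data/bert_models/chinese_finetuned_lm/vocab.txt \
#     --predict_file ../data/cn/lm_test_zh.txt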
def convert_examples_to_features(examples: List[QAFullExample],
                                 tokenizer: BertTokenizer,
                                 max_seq_length: int, doc_stride: int,
                                 max_query_length: int, is_training: bool):
    unique_id = 1000000000
    features = []
    for (example_index, example) in tqdm(enumerate(examples),
                                         desc='Converting examples to features..',
                                         total=len(examples)):
        query_tokens = tokenizer.tokenize(example.question_text)
        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[-max_query_length:]

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        sentence_spans = []
        for (start, end) in example.sentence_span_list:
            piece_start = orig_to_tok_index[start]
            if end < len(example.doc_tokens) - 1:
                piece_end = orig_to_tok_index[end + 1] - 1
            else:
                piece_end = len(all_doc_tokens) - 1
            sentence_spans.append((piece_start, piece_end))

        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        sentence_spans_list = []
        sentence_ids_list = []
        for span_id, doc_span in enumerate(doc_spans):
            span_start = doc_span.start
            span_end = span_start + doc_span.length - 1
            span_sentence = []
            sen_ids = []
            for sen_idx, (sen_start, sen_end) in enumerate(sentence_spans):
                if sen_end < span_start:
                    continue
                if sen_start > span_end:
                    break
                span_sentence.append((max(sen_start, span_start),
                                      min(sen_end, span_end)))
                sen_ids.append(sen_idx)
            sentence_spans_list.append(span_sentence)
            sentence_ids_list.append(sen_ids)

        ini_sen_id: List[int] = example.sentence_id
        for (doc_span_index, doc_span) in enumerate(doc_spans):
            token_to_orig_map = {}
            token_is_max_context = {}
            tokens = ["[CLS]"] + query_tokens + ["[SEP]"]
            segment_ids = [0] * len(tokens)

            doc_start = doc_span.start
            doc_offset = len(query_tokens) + 2
            sentence_list = sentence_spans_list[doc_span_index]
            cur_sentence_list = []
            for sen_id, sen in enumerate(sentence_list):
                new_sen = (sen[0] - doc_start + doc_offset,
                           sen[1] - doc_start + doc_offset)
                cur_sentence_list.append(new_sen)

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
                is_max_context = utils.check_is_max_context(
                    doc_spans, doc_span_index, split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [1] * len(input_ids)
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
            assert len(input_ids) == len(input_mask) == len(segment_ids) == max_seq_length

            doc_start = doc_span.start
            doc_end = doc_span.start + doc_span.length - 1

            # There are multiple evidence sentences in some examples. To avoid
            # a multi-label setting, we choose the evidence sentence with the
            # max length.
            span_sen_id = -1
            max_evidence_length = 0
            for piece_sen_id, sen_id in enumerate(sentence_ids_list[doc_span_index]):
                if sen_id in ini_sen_id:
                    evidence_length = (cur_sentence_list[piece_sen_id][1]
                                       - cur_sentence_list[piece_sen_id][0])
                    if evidence_length > max_evidence_length:
                        max_evidence_length = evidence_length
                        span_sen_id = piece_sen_id

            meta_data = {
                'span_sen_to_orig_sen_map': sentence_ids_list[doc_span_index]
            }

            if span_sen_id == -1:
                answer_choice = 0
            else:
                answer_choice = example.is_impossible + 1

            features.append(QAFullInputFeatures(
                qas_id=example.qas_id,
                unique_id=unique_id,
                example_index=example_index,
                doc_span_index=doc_span_index,
                sentence_span_list=cur_sentence_list,
                tokens=tokens,
                token_to_orig_map=token_to_orig_map,
                token_is_max_context=token_is_max_context,
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                is_impossible=answer_choice,
                sentence_id=span_sen_id,
                start_position=None,
                end_position=None,
                ral_start_position=None,
                ral_end_position=None,
                meta_data=meta_data
            ))
            unique_id += 1

    return features
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument('--ffn_type', default=0, type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv', default=0, type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb', action='store_true',
                        help="Using segment embedding for self-attention.")

    # decoding parameters
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument('--subset', type=int, default=0,
                        help="Decode a subset of the input dataset.")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split", type=str, default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed', type=int, default=123,
                        help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids', action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--batch_size', type=int, default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument('--topk', type=int, default=10, help="Value of K.")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Ignore the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode', default="s2s", choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length', type=int, default=128,
                        help="maximum length of target sequence")
    parser.add_argument('--s2s_special_token', action='store_true',
                        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment', action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument('--s2s_share_segment', action='store_true',
                        help="Sharing segment embeddings for the encoder of S2S "
                             "(used with --s2s_add_segment).")
    parser.add_argument('--pos_shift', action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument('--not_predict_token', type=str, default=None,
                        help="Do not predict the tokens during decoding.")

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1.")
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # tokenizer = BertTokenizer.from_pretrained(
    #     args.bert_model, do_lower_case=args.do_lower_case)
    tokenizer = BertTokenizer(
        vocab_file='/ps2/intern/clsi/BERT/bert_weights/cased_L-24_H-1024_A-16/vocab.txt',
        do_lower_case=args.do_lower_case)
    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    bi_uni_pipeline.append(
        seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids,
            mode="s2s",
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

    def _get_token_id_set(s):
        r = None
        if s:
            w_list = []
            for w in s.split('|'):
                if w.startswith('[') and w.endswith(']'):
                    w_list.append(w.upper())
                else:
                    w_list.append(w)
            r = set(tokenizer.convert_tokens_to_ids(w_list))
        return r

    forbid_ignore_set = _get_token_id_set(args.forbid_ignore_word)
    not_predict_set = _get_token_id_set(args.not_predict_token)
    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=pair_num_relation,
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            not_predict_set=not_predict_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            mode=args.mode,
            max_position_embeddings=args.max_seq_length,
            ffn_type=args.ffn_type,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            pos_shift=args.pos_shift,
            topk=args.topk,
            config_path=args.config_path)
        del model_recover

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        torch.cuda.empty_cache()
        model.eval()
        max_src_length = args.max_seq_length - 2 - args.max_tgt_length

        ## for YFG style json
        # testset = loads_json(args.input_file, 'Load Test Set: ' + args.input_file)
        # if args.subset > 0:
        #     logger.info("Decoding subset: %d", args.subset)
        #     testset = testset[:args.subset]

        with open(args.input_file, encoding="utf-8") as fin:
            data = json.load(fin)

        data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer

        PQA_dict = {}  # will store the generated distractors
        dis_n = 0
        len_tot = 0

        # Process one example at a time and store the distractors in PQA
        # json form.
        counter = 0
        for race_id, example in tqdm(data.items()):
            counter += 1
            if args.subset > 0 and counter >= args.subset:
                break
            eg_dict = {}
            eg_dict["question"] = example['question']
            eg_dict["context"] = example['context']
            label = int(example["label"])
            options = example["options"]
            answer = options[label]

            pred1 = None
            pred2 = None
            pred3 = None

            question = example['question']
            question = question.replace('_', ' ')
            line = ' '.join(nltk.word_tokenize(example['context']) +
                            nltk.word_tokenize(question))
            line = [data_tokenizer.tokenize(line)[:max_src_length]]
            buf = line

            max_a_len = max([len(x) for x in buf])
            instances = []
            for instance in [(x, max_a_len) for x in buf]:
                for proc in bi_uni_pipeline:
                    instances.append(proc(instance))
            with torch.no_grad():
                batch = seq2seq_loader.batch_list_to_batch_tensors(instances)
                batch = [t.to(device) if t is not None else None for t in batch]
                input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                traces = model(input_ids, token_type_ids, position_ids,
                               input_mask, task_idx=task_idx, mask_qkv=mask_qkv)
                if args.beam_size > 1:
                    traces = {k: v.tolist() for k, v in traces.items()}
                    output_ids = traces['pred_seq']
                else:
                    output_ids = traces.tolist()

                # NOTE: only supports single-batch decoding. The second and
                # third beam sequences are kept as backups, which assumes the
                # beam search returned at least three sequences.
                for i in range(len(buf)):
                    for s in range(len(output_ids)):
                        output_seq = output_ids[s]
                        output_buf = tokenizer.convert_ids_to_tokens(output_seq)
                        output_tokens = []
                        for t in output_buf:
                            if t in ("[SEP]", "[PAD]"):
                                break
                            output_tokens.append(t)
                        if s == 1:
                            backup_1 = output_tokens
                        if s == 2:
                            backup_2 = output_tokens
                        # Keep up to three mutually dissimilar predictions
                        # (Jaccard similarity below 0.5).
                        if pred1 is None:
                            pred1 = output_tokens
                        elif jaccard_similarity(pred1, output_tokens) < 0.5:
                            if pred2 is None:
                                pred2 = output_tokens
                            elif pred3 is None:
                                if jaccard_similarity(pred2, output_tokens) < 0.5:
                                    pred3 = output_tokens
                        if pred1 is not None and pred2 is not None and pred3 is not None:
                            break
                    # Fall back to the backup beams if fewer than three
                    # sufficiently different predictions were found.
                    if pred2 is None:
                        pred2 = backup_1
                        if pred3 is None:
                            pred3 = backup_2
                    elif pred3 is None:
                        pred3 = backup_1

            new_distractors = [pred1, pred2, pred3]
            for dis in new_distractors:
                len_tot += len(dis)
                dis_n += 1
            new_distractors = [' '.join(detokenize(dis))
                               for dis in new_distractors if dis is not None]
            assert len(new_distractors) == 3, "Number of distractors WRONG"
            new_distractors.insert(label, answer)
            eg_dict["options"] = new_distractors
            eg_dict["label"] = label
            PQA_dict[race_id] = eg_dict

        logger.info("Average length of distractors: " + str(len_tot / dis_n))
        with open(args.output_file, mode='w', encoding='utf-8') as f:
            json.dump(PQA_dict, f, indent=4)
def batch_generator_with_multi(file_path: str, tokenizer: BertTokenizer,
                               max_seq_length: int, batch_size: int, device,
                               data_limit: int):
    batch_inputs = []
    batch_att_mask = []
    batch_target_ids = []
    batch_position = []
    for n, instance in enumerate(read_instance(file_path)):
        if data_limit == n:
            logger.info(
                "The maximum number of rows has been reached: {}".format(
                    data_limit))
            batch_inputs = []
            break

        tokens_with_mask = instance["surfaces"]
        mask_ids = [
            idx for idx, token in enumerate(tokens_with_mask) if token == MASK
        ]
        tokenized_tokens = tokenizer.tokenize(" ".join(tokens_with_mask))
        if n < 3:
            logger.debug(tokens_with_mask)
            logger.debug(tokenized_tokens)
            logger.debug(mask_ids)

        subword_mask_ids = [
            idx for idx, subword in enumerate(tokenized_tokens)
            if subword == MASK
        ]
        within_mask_ids = [
            idx for idx in subword_mask_ids if idx < max_seq_length - 2
        ]
        out_mask_ids = [
            idx for idx in subword_mask_ids if idx >= max_seq_length - 2
        ]

        buffer = []
        if within_mask_ids:
            in_tokens = [CLS] + tokenized_tokens[0:max_seq_length - 2] + [SEP]
            target_ids = [
                idx for idx, token in enumerate(in_tokens) if token == MASK
            ]
            buffer.append((in_tokens, target_ids))
        if out_mask_ids:
            logger.debug("exceed {}".format(max_seq_length))
            in_tokens = [CLS] + tokenized_tokens[len(tokenized_tokens) -
                                                 max_seq_length + 2:] + [SEP]
            target_ids = [
                idx for idx, token in enumerate(in_tokens) if token == MASK
            ][-len(out_mask_ids):]
            assert len(in_tokens) == max_seq_length
            logger.debug(in_tokens)
            logger.debug(target_ids)
            buffer.append((in_tokens, target_ids))
            if out_mask_ids[-1] >= (max_seq_length - 2) * 2:
                raise RuntimeError("Sentence is too long.")

        len_batch = len(batch_inputs)
        for in_tokens, target_ids in buffer:
            input_ids = tokenizer.convert_tokens_to_ids(
                in_tokens) + [0] * (max_seq_length - len(in_tokens))
            att_mask = [1] * len(in_tokens) + [0] * (max_seq_length -
                                                     len(in_tokens))
            batch_inputs.append(input_ids)
            batch_att_mask.append(att_mask)
            batch_target_ids += [(len_batch, i) for i in target_ids]

        batch_position += [(instance["unique_id"], instance["sentence id"],
                            instance["file name"], mask_idx)
                           for mask_idx in mask_ids]

        if len(batch_inputs) >= batch_size:
            assert len(batch_inputs) == len(batch_att_mask)
            batch_inputs = torch.LongTensor(batch_inputs).to(device)
            batch_att_mask = torch.LongTensor(batch_att_mask).to(device)
            batch_target_ids = [[i[0] for i in batch_target_ids],
                                [i[1] for i in batch_target_ids]]
            yield batch_inputs, batch_att_mask, batch_target_ids, batch_position
            batch_inputs = []
            batch_att_mask = []
            batch_target_ids = []
            batch_position = []

    if batch_inputs:
        assert len(batch_inputs) == len(batch_att_mask)
        batch_inputs = torch.LongTensor(batch_inputs).to(device)
        batch_att_mask = torch.LongTensor(batch_att_mask).to(device)
        batch_target_ids = [[i[0] for i in batch_target_ids],
                            [i[1] for i in batch_target_ids]]
        yield batch_inputs, batch_att_mask, batch_target_ids, batch_position
def batch_generator_with_single(file_path: str, tokenizer: BertTokenizer,
                                max_seq_length: int, batch_size: int, device,
                                data_limit: int):
    batch_inputs = []
    batch_att_mask = []
    batch_target_ids = []
    batch_position = []
    for n, instance in enumerate(read_instance(file_path)):
        if data_limit == n:
            logger.info(
                "The maximum number of rows has been reached: {}".format(
                    data_limit))
            batch_inputs = []
            break
        if n % 1000 == 0:
            logger.info("Unique ID: {}".format(instance["unique_id"]))

        tokens_with_mask = instance["surfaces"]
        mask_ids = [
            idx for idx, token in enumerate(tokens_with_mask) if token == MASK
        ]
        for mask_idx in mask_ids:
            # Mask exactly one position per generated input sequence.
            original_tokens = copy.deepcopy(instance["original_surfaces"])
            original_tokens[mask_idx] = MASK
            tokenized_tokens = tokenizer.tokenize(" ".join(original_tokens))
            if n < 3:
                logger.debug(original_tokens)
                logger.debug(tokenized_tokens)
                logger.debug(mask_idx)

            if tokenized_tokens.index(MASK) < max_seq_length - 2:
                in_tokens = [CLS] + tokenized_tokens[0:max_seq_length - 2] + [SEP]
            elif len(tokenized_tokens) < (max_seq_length - 2) * 2:
                logger.debug("exceed {}".format(max_seq_length))
                in_tokens = [CLS] + tokenized_tokens[len(tokenized_tokens) -
                                                     max_seq_length + 2:] + [SEP]
                assert len(in_tokens) == max_seq_length
            else:
                raise RuntimeError("Sentence is too long.")
            assert MASK in in_tokens

            input_ids = tokenizer.convert_tokens_to_ids(
                in_tokens) + [0] * (max_seq_length - len(in_tokens))
            att_mask = [1] * len(in_tokens) + [0] * (max_seq_length -
                                                     len(in_tokens))
            batch_inputs.append(input_ids)
            batch_att_mask.append(att_mask)
            batch_target_ids.append(in_tokens.index(MASK))
            batch_position.append(
                (instance["unique_id"], instance["sentence id"],
                 instance["file name"], mask_idx))

            if len(batch_inputs) == batch_size:
                assert len(batch_inputs) == len(batch_att_mask) == len(
                    batch_target_ids) == len(batch_position)
                batch_inputs = torch.LongTensor(batch_inputs).to(device)
                batch_att_mask = torch.LongTensor(batch_att_mask).to(device)
                batch_target_ids = [[i for i in range(batch_size)],
                                    batch_target_ids]
                yield batch_inputs, batch_att_mask, batch_target_ids, batch_position
                batch_inputs = []
                batch_att_mask = []
                batch_target_ids = []
                batch_position = []

    if batch_inputs:
        batch_target_ids = [[i for i in range(len(batch_inputs))],
                            batch_target_ids]
        batch_inputs = torch.LongTensor(batch_inputs).to(device)
        batch_att_mask = torch.LongTensor(batch_att_mask).to(device)
        yield batch_inputs, batch_att_mask, batch_target_ids, batch_position
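# --- A minimal sketch (not from the original) of consuming the generators
# --- above; `model` is assumed to be a BertForMaskedLM on the same device,
# --- the file path is a placeholder, and data_limit=-1 is assumed to disable
# --- the row limit (it never equals the enumeration index).
def _demo_consume(model, tokenizer: BertTokenizer, device):
    for inputs, att_mask, target_ids, positions in batch_generator_with_single(
            file_path='data/instances.jsonl',  # hypothetical input file
            tokenizer=tokenizer, max_seq_length=128, batch_size=8,
            device=device, data_limit=-1):
        with torch.no_grad():
            predictions = model(inputs, attention_mask=att_mask)
        # Gather the vocabulary logits at each masked position
        # (target_ids holds [row indices, column indices]).
        mask_logits = predictions[target_ids[0], target_ids[1]]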
class CreateDataset(Dataset):
    def __init__(self, data_path, max_seq_len, vocab_path, example_type, seed):
        self.seed = seed
        self.max_seq_len = max_seq_len
        self.example_type = example_type
        self.data_path = data_path
        self.vocab_path = vocab_path
        self.reset()

    # Initialization
    def reset(self):
        # Load the vocabulary that ships with the pretrained Bert model.
        self.tokenizer = BertTokenizer(vocab_file=self.vocab_path)
        # Build the examples
        self.build_examples()

    # Read the dataset
    def read_data(self, quotechar=None):
        '''
        Data is tab-separated by default.
        :param quotechar:
        :return:
        '''
        lines = []
        with open(self.data_path, 'r', encoding='utf-8') as fr:
            reader = csv.reader(fr, delimiter='\t', quotechar=quotechar)
            for line in reader:
                lines.append(line)
        return lines

    # Build the data examples
    def build_examples(self):
        lines = self.read_data()
        self.examples = []
        for i, line in enumerate(lines):
            guid = '%s-%d' % (self.example_type, i)
            label = line[0]
            text_a = line[1]
            example = InputExample(guid=guid, text_a=text_a, label=label)
            self.examples.append(example)
        del lines

    # Convert an example into a feature
    def build_features(self, example):
        '''
        # For sentence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids:   0   0   0    0    0      0     0   0   1  1  1   1  1   1
        # For single sentences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids:   0    0   0  0    0   0   0
        # type_ids indicate whether a token belongs to the first or the
        # second sentence.
        '''
        # Convert to tokens
        tokens_a = self.tokenizer.tokenize(example.text_a)
        # Account for [CLS] and [SEP] with "- 2"
        if len(tokens_a) > self.max_seq_len - 2:
            tokens_a = tokens_a[:(self.max_seq_len - 2)]
        # Add the special tokens at the start and end of the sentence.
        tokens = ['[CLS]'] + tokens_a + ['[SEP]']
        segment_ids = [0] * len(tokens)  # the type_ids
        # Map tokens to their vocabulary ids.
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        # Input mask
        input_mask = [1] * len(input_ids)
        # Zero-pad up to the maximum sequence length.
        padding = [0] * (self.max_seq_len - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding
        # Label
        label_id = int(example.label)
        feature = InputFeature(input_ids=input_ids,
                               input_mask=input_mask,
                               segment_ids=segment_ids,
                               label_id=label_id)
        return feature

    def _preprocess(self, index):
        example = self.examples[index]
        feature = self.build_features(example)
        return (np.array(feature.input_ids), np.array(feature.input_mask),
                np.array(feature.segment_ids), np.array(feature.label_id))

    def __getitem__(self, index):
        return self._preprocess(index)

    def __len__(self):
        return len(self.examples)
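# --- A minimal sketch (not from the original) of wrapping CreateDataset in a
# --- PyTorch DataLoader; the file paths and sizes are illustrative.
from torch.utils.data import DataLoader

def _demo_loader():
    train_dataset = CreateDataset(data_path='data/train.tsv',
                                  max_seq_len=128,
                                  vocab_path='data/vocab.txt',
                                  example_type='train',
                                  seed=42)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    for input_ids, input_mask, segment_ids, label_ids in train_loader:
        pass  # feed the batch to the model here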
def get_data(filename, tokenizer: BertTokenizer, opts: DataOptions, limit=0):
    dataset = []
    max_chunks = 0
    max_sent_len = 0
    qid2supportingfacts = dict()
    qid2sentids = dict()
    for features in get_features(filename, tokenizer, opts):
        # convert to torch tensors
        slen = max([len(ct) for ct in features.chunk_tokens])
        chunk_token_ids = [
            tokenizer.convert_tokens_to_ids(ct) + [0] * (slen - len(ct))
            for ct in features.chunk_tokens
        ]
        segment_ids = [
            sids + [0] * (slen - len(sids)) for sids in features.segment_ids
        ]
        if len(features.chunk_tokens) > max_chunks:
            max_chunks = len(features.chunk_tokens)
        max_sent_len = max(max_sent_len,
                           (np.array(features.sent_ends) -
                            np.array(features.sent_starts)).max())

        sent_targets = None
        if features.supporting_facts is not None:
            qid2supportingfacts[features.id] = features.supporting_facts
            for sid in features.sent_ids:
                qid2sentids.setdefault(features.id, set()).add(sid)
            sent_targets = torch.zeros(len(features.sent_ids),
                                       dtype=torch.float)
            for sf in features.supporting_facts:
                if sf not in features.sent_ids:
                    continue
                sent_targets[features.sent_ids.index(sf)] = 1

        assert len(features.sent_starts) == len(features.sent_ends) == len(
            features.sent_ids)
        assert len(chunk_token_ids) == len(features.chunk_lengths)
        dataset.append((features.id, features.sent_ids, features.question_len,
                        features.chunk_lengths,
                        torch.tensor(chunk_token_ids, dtype=torch.long),
                        torch.tensor(segment_ids, dtype=torch.long),
                        torch.tensor(features.sent_starts, dtype=torch.long),
                        torch.tensor(features.sent_ends, dtype=torch.long),
                        sent_targets))
        if 0 < limit <= len(dataset):
            break
        if len(dataset) % 5000 == 0:
            logger.info(f'loading dataset item {len(dataset)} from {filename}')

    logger.info(
        f'in {filename}: max_chunks = {max_chunks}, max_sent_length = {max_sent_len}')

    # Count how many supporting facts were lost to truncation.
    out_of_recall = 0
    total_positives = 0
    for id, sps in qid2supportingfacts.items():
        total_positives += len(sps)
        sent_ids = qid2sentids.get(id)
        for sp in sps:
            if sp not in sent_ids:
                out_of_recall += 1
    if len(qid2supportingfacts) > 0:
        logger.info(
            f'in {filename}, due to truncations we have lost {out_of_recall} '
            f'out of {total_positives} positives')
    return dataset, qid2supportingfacts