def sample_to_sop(self, sop_sample: Sample) -> Sample:
    """Note: the SOP data source differs from MLM, since two sentences are needed."""
    sentences = sop_sample.inputs["text"]
    del sop_sample.inputs["text"]
    if self.data_params.segment_train:
        inputlist = sentences.split(" ")
        nowords = len(inputlist)
        # minimal word number is 10
        if nowords >= 10:
            splitindex = random.randint(4, nowords - 5)
        else:
            splitindex = 0
        textpartone = inputlist[:splitindex]
        textparttwo = inputlist[splitindex:]
        textpartone = " ".join(textpartone)
        textparttwo = " ".join(textparttwo)
        # trim each encoded part to at most max_token_text_part tokens:
        # keep the end of the first part and the start of the second part
        first_enc_sentence = self.tokenizer.encode(textpartone)
        if len(first_enc_sentence) > self.data_params.max_token_text_part:
            first_enc_sentence = first_enc_sentence[
                len(first_enc_sentence) - self.data_params.max_token_text_part:]
        sec_enc_sentence = self.tokenizer.encode(textparttwo)
        if len(sec_enc_sentence) > self.data_params.max_token_text_part:
            sec_enc_sentence = sec_enc_sentence[:self.data_params.max_token_text_part]
    else:
        first_enc_sentence, sec_enc_sentence = self.build_two_sentence_segments(sentences)

    first_mask_enc_sentence, first_masked_index_list = self.mask_enc_sentence(first_enc_sentence)
    sec_mask_enc_sentence, sec_masked_index_list = self.mask_enc_sentence(sec_enc_sentence)
    # add CLS tag (tok_vocab_size) and SEP tags (tok_vocab_size + 1)
    if self.switch_sentences():
        # swapped order -> SOP target 0
        text_index_list = ([self.data_params.tok_vocab_size] + sec_mask_enc_sentence
                           + [self.data_params.tok_vocab_size + 1] + first_mask_enc_sentence
                           + [self.data_params.tok_vocab_size + 1])
        masked_index_list = [0] + sec_masked_index_list + [0] + first_masked_index_list + [0]
        tar_mlm = ([self.data_params.tok_vocab_size] + sec_enc_sentence
                   + [self.data_params.tok_vocab_size + 1] + first_enc_sentence
                   + [self.data_params.tok_vocab_size + 1])
        tar_sop = [0]
    else:
        # original order -> SOP target 1
        text_index_list = ([self.data_params.tok_vocab_size] + first_mask_enc_sentence
                           + [self.data_params.tok_vocab_size + 1] + sec_mask_enc_sentence
                           + [self.data_params.tok_vocab_size + 1])
        masked_index_list = [0] + first_masked_index_list + [0] + sec_masked_index_list + [0]
        tar_mlm = ([self.data_params.tok_vocab_size] + first_enc_sentence
                   + [self.data_params.tok_vocab_size + 1] + sec_enc_sentence
                   + [self.data_params.tok_vocab_size + 1])
        tar_sop = [1]

    sop_sample.inputs = {
        "text": np.asarray(text_index_list),
        "seq_length": np.asarray([len(text_index_list)]),
    }
    sop_sample.targets = {
        "tgt_mlm": np.asarray(tar_mlm),
        "mask_mlm": np.asarray(masked_index_list),
        "tgt_sop": np.asarray(tar_sop),
    }
    if self._wwa:
        word_length_vector, segment_ids = self.build_whole_word_attention_inputs(tar_mlm)
        sop_sample.inputs["word_length_vector"] = np.asarray(word_length_vector)
        sop_sample.inputs["segment_ids"] = np.asarray(segment_ids)
    return sop_sample
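# Illustration (assumed values, not part of the pipeline): with
# tok_vocab_size = 100 the CLS id is 100 and the SEP id is 101. For encoded
# sentences a = [5, 6] and b = [7, 8, 9] the layouts produced above are
#   unswitched:  text = [100, 5, 6, 101, 7, 8, 9, 101]  ->  tgt_sop = [1]
#   switched:    text = [100, 7, 8, 9, 101, 5, 6, 101]  ->  tgt_sop = [0]
# mask_mlm always carries a 0 at the CLS/SEP positions, so these special
# tokens are never selected as MLM targets.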
def sample_to_nsp(self, nsp_sample: Sample) -> Sample:
    """Note: the NSP data source differs from MLM, since two sentences are needed."""
    sentences = nsp_sample.inputs["text"]
    del nsp_sample.inputs["text"]
    take_connected_parts = self.bool_decision()
    if self.data_params.segment_train:
        firstinputlist = sentences[0].split(" ")
        nofirstwords = len(firstinputlist)
        # minimal word number is 10
        if nofirstwords >= 10:
            splitindex = random.randint(4, nofirstwords - 5)
        else:
            splitindex = 0
        textpartone = firstinputlist[:splitindex]
        # maximal length of each text part is max_words_text_part words
        if len(textpartone) > self.data_params.max_words_text_part:
            textpartone = textpartone[
                len(textpartone) - self.data_params.max_words_text_part:]
        if take_connected_parts:
            # positive pair: second part continues the first text
            textparttwo = firstinputlist[splitindex:]
            tar_nsp = [1]
        else:
            # negative pair: second part is a random suffix of another text
            secondinputlist = sentences[1].split(" ")
            nosecondwords = len(secondinputlist)
            if nosecondwords >= 10:
                splitindex = random.randint(0, nosecondwords - 5)
            else:
                splitindex = 0
            textparttwo = secondinputlist[splitindex:]
            tar_nsp = [0]
        if len(textparttwo) > self.data_params.max_words_text_part:
            textparttwo = textparttwo[:self.data_params.max_words_text_part]
        textpartone = " ".join(textpartone)
        textparttwo = " ".join(textparttwo)
        first_enc_sentence = self.tokenizer.encode(textpartone)
        sec_enc_sentence = self.tokenizer.encode(textparttwo)
    else:
        first_enc_sentence, sec_enc_sentence = self.build_two_sentence_segments(
            sentences, take_connected_parts)
        tar_nsp = [1] if take_connected_parts else [0]

    first_mask_enc_sentence, first_masked_index_list = self.mask_enc_sentence(first_enc_sentence)
    sec_mask_enc_sentence, sec_masked_index_list = self.mask_enc_sentence(sec_enc_sentence)
    switch_order = self.bool_decision()
    # add CLS tag (tok_vocab_size) and SEP tags (tok_vocab_size + 1)
    if switch_order:
        text_index_list = ([self.data_params.tok_vocab_size] + sec_mask_enc_sentence
                           + [self.data_params.tok_vocab_size + 1] + first_mask_enc_sentence
                           + [self.data_params.tok_vocab_size + 1])
        masked_index_list = [0] + sec_masked_index_list + [0] + first_masked_index_list + [0]
        tar_mlm = ([self.data_params.tok_vocab_size] + sec_enc_sentence
                   + [self.data_params.tok_vocab_size + 1] + first_enc_sentence
                   + [self.data_params.tok_vocab_size + 1])
    else:
        text_index_list = ([self.data_params.tok_vocab_size] + first_mask_enc_sentence
                           + [self.data_params.tok_vocab_size + 1] + sec_mask_enc_sentence
                           + [self.data_params.tok_vocab_size + 1])
        masked_index_list = [0] + first_masked_index_list + [0] + sec_masked_index_list + [0]
        tar_mlm = ([self.data_params.tok_vocab_size] + first_enc_sentence
                   + [self.data_params.tok_vocab_size + 1] + sec_enc_sentence
                   + [self.data_params.tok_vocab_size + 1])

    nsp_sample.inputs = {
        "text": np.asarray(text_index_list),
        "seq_length": np.asarray([len(text_index_list)]),
    }
    nsp_sample.targets = {
        "tgt_mlm": np.asarray(tar_mlm),
        "mask_mlm": np.asarray(masked_index_list),
        "tgt_nsp": np.asarray(tar_nsp),
    }
    if self._wwa:
        word_length_vector, segment_ids = self.build_whole_word_attention_inputs(tar_mlm)
        nsp_sample.inputs["word_length_vector"] = np.asarray(word_length_vector)
        nsp_sample.inputs["segment_ids"] = np.asarray(segment_ids)
    return nsp_sample
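# Illustration (assumed behaviour, for readability only): in segment_train
# mode the first text is split at a random word index. With
# take_connected_parts the second part is the remainder of the same text and
# tgt_nsp = [1]; otherwise a random suffix of the second text is used and
# tgt_nsp = [0]. The order of the two parts in the final token sequence is
# then switched at random (bool_decision), independently of the NSP label,
# presumably so that segment order alone carries no label information.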