Example #1
    # Requires module-level imports: `import random`, `import numpy as np`.
    def sample_to_sop(self, sop_sample: Sample) -> Sample:
        """Note: the SOP data source differs from MLM, since two sentences are needed."""
        sentences = sop_sample.inputs["text"]
        del sop_sample.inputs["text"]
        if self.data_params.segment_train:
            # Split one text into two consecutive parts at a random word index.
            inputlist = sentences.split(" ")
            nowords = len(inputlist)
            # Only split texts with at least 10 words; keep at least 5
            # words on each side of the split.
            if nowords >= 10:
                splitindex = random.randint(4, nowords - 5)
            else:
                splitindex = 0
            textpartone = " ".join(inputlist[:splitindex])
            textparttwo = " ".join(inputlist[splitindex:])
            # Truncate each part to at most max_token_text_part tokens: keep
            # the tail of the first part and the head of the second part, so
            # the two segments stay adjacent at the split point.
            first_enc_sentence = self.tokenizer.encode(textpartone)
            if len(first_enc_sentence) > self.data_params.max_token_text_part:
                first_enc_sentence = first_enc_sentence[
                    -self.data_params.max_token_text_part:]
            sec_enc_sentence = self.tokenizer.encode(textparttwo)
            if len(sec_enc_sentence) > self.data_params.max_token_text_part:
                sec_enc_sentence = sec_enc_sentence[
                    :self.data_params.max_token_text_part]
        else:
            first_enc_sentence, sec_enc_sentence = self.build_two_sentence_segments(
                sentences)
        # Apply MLM masking to both segments.
        first_mask_enc_sentence, first_masked_index_list = self.mask_enc_sentence(
            first_enc_sentence)
        sec_mask_enc_sentence, sec_masked_index_list = self.mask_enc_sentence(
            sec_enc_sentence)
        # Add CLS tag (id = tok_vocab_size) and SEP tags (id = tok_vocab_size + 1).
        # SOP target: 0 if the segment order was switched, 1 otherwise.
        if self.switch_sentences():
            text_index_list = ([self.data_params.tok_vocab_size] +
                               sec_mask_enc_sentence +
                               [self.data_params.tok_vocab_size + 1] +
                               first_mask_enc_sentence +
                               [self.data_params.tok_vocab_size + 1])
            masked_index_list = ([0] + sec_masked_index_list + [0] +
                                 first_masked_index_list + [0])
            tar_mlm = ([self.data_params.tok_vocab_size] + sec_enc_sentence +
                       [self.data_params.tok_vocab_size + 1] +
                       first_enc_sentence +
                       [self.data_params.tok_vocab_size + 1])
            tar_sop = [0]
        else:
            text_index_list = ([self.data_params.tok_vocab_size] +
                               first_mask_enc_sentence +
                               [self.data_params.tok_vocab_size + 1] +
                               sec_mask_enc_sentence +
                               [self.data_params.tok_vocab_size + 1])
            masked_index_list = ([0] + first_masked_index_list + [0] +
                                 sec_masked_index_list + [0])
            tar_mlm = ([self.data_params.tok_vocab_size] + first_enc_sentence +
                       [self.data_params.tok_vocab_size + 1] +
                       sec_enc_sentence +
                       [self.data_params.tok_vocab_size + 1])
            tar_sop = [1]
        sop_sample.inputs = {
            "text": np.asarray(text_index_list),
            "seq_length": np.asarray([len(text_index_list)]),
        }
        sop_sample.targets = {
            "tgt_mlm": np.asarray(tar_mlm),
            "mask_mlm": np.asarray(masked_index_list),
            "tgt_sop": np.asarray(tar_sop),
        }
        if self._wwa:
            # Whole-word attention: per-word token lengths and word segment ids.
            word_length_vector, segment_ids = self.build_whole_word_attention_inputs(
                tar_mlm)
            sop_sample.inputs["word_length_vector"] = np.asarray(
                word_length_vector)
            sop_sample.inputs["segment_ids"] = np.asarray(segment_ids)
        return sop_sample
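
To make the index layout concrete, here is a minimal, self-contained sketch of the assembly step above. The vocabulary size and token ids are made-up toy values; only the layout (CLS id = tok_vocab_size, SEP id = tok_vocab_size + 1, SOP target 0 when the order is switched) mirrors the method.

# Toy illustration of the SOP sequence layout (all values hypothetical).
tok_vocab_size = 100                 # CLS id = 100, SEP id = 101
first_enc = [7, 12, 9]               # pretend token ids of the first segment
sec_enc = [3, 42]                    # pretend token ids of the second segment
switched = True                      # pretend switch_sentences() returned True

cls_tok, sep_tok = tok_vocab_size, tok_vocab_size + 1
if switched:
    text_index_list = [cls_tok] + sec_enc + [sep_tok] + first_enc + [sep_tok]
    tar_sop = [0]                    # 0: segments are in swapped order
else:
    text_index_list = [cls_tok] + first_enc + [sep_tok] + sec_enc + [sep_tok]
    tar_sop = [1]                    # 1: segments are in original order

print(text_index_list)               # [100, 3, 42, 101, 7, 12, 9, 101]
print(tar_sop)                       # [0]
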
Example #2
    # Requires module-level imports: `import random`, `import numpy as np`.
    def sample_to_nsp(self, nsp_sample: Sample) -> Sample:
        """Note: the NSP data source differs from MLM, since two sentences are needed."""
        sentences = nsp_sample.inputs["text"]
        del nsp_sample.inputs["text"]
        # Randomly decide whether the two parts form a connected (positive) pair.
        take_connected_parts = self.bool_decision()
        if self.data_params.segment_train:
            firstinputlist = sentences[0].split(" ")
            nofirstwords = len(firstinputlist)
            # Only split texts with at least 10 words; keep at least 5
            # words on each side of the split.
            if nofirstwords >= 10:
                splitindex = random.randint(4, nofirstwords - 5)
            else:
                splitindex = 0
            textpartone = firstinputlist[:splitindex]
            # Truncate the first part to its last max_words_text_part words.
            if len(textpartone) > self.data_params.max_words_text_part:
                textpartone = textpartone[
                    -self.data_params.max_words_text_part:]
            if take_connected_parts:
                # Positive pair: the second part directly continues the first.
                textparttwo = firstinputlist[splitindex:]
                tar_nsp = [1]
            else:
                # Negative pair: the second part comes from a different sentence.
                secondinputlist = sentences[1].split(" ")
                nosecondwords = len(secondinputlist)
                if nosecondwords >= 10:
                    splitindex = random.randint(0, nosecondwords - 5)
                else:
                    splitindex = 0
                textparttwo = secondinputlist[splitindex:]
                tar_nsp = [0]
            # Truncate the second part to its first max_words_text_part words.
            if len(textparttwo) > self.data_params.max_words_text_part:
                textparttwo = textparttwo[:self.data_params.max_words_text_part]
            textpartone = " ".join(textpartone)
            textparttwo = " ".join(textparttwo)
            first_enc_sentence = self.tokenizer.encode(textpartone)
            sec_enc_sentence = self.tokenizer.encode(textparttwo)
        else:
            first_enc_sentence, sec_enc_sentence = self.build_two_sentence_segments(
                sentences, take_connected_parts)
            if take_connected_parts:
                tar_nsp = [1]
            else:
                tar_nsp = [0]
        # Apply MLM masking to both segments.
        first_mask_enc_sentence, first_masked_index_list = self.mask_enc_sentence(
            first_enc_sentence)
        sec_mask_enc_sentence, sec_masked_index_list = self.mask_enc_sentence(
            sec_enc_sentence)
        switch_order = self.bool_decision()
        # Add CLS tag (id = tok_vocab_size) and SEP tags (id = tok_vocab_size + 1);
        # the segment order may be randomly switched without affecting the NSP label.
        if switch_order:
            text_index_list = ([self.data_params.tok_vocab_size] +
                               sec_mask_enc_sentence +
                               [self.data_params.tok_vocab_size + 1] +
                               first_mask_enc_sentence +
                               [self.data_params.tok_vocab_size + 1])
            masked_index_list = ([0] + sec_masked_index_list + [0] +
                                 first_masked_index_list + [0])
            tar_mlm = ([self.data_params.tok_vocab_size] + sec_enc_sentence +
                       [self.data_params.tok_vocab_size + 1] +
                       first_enc_sentence +
                       [self.data_params.tok_vocab_size + 1])
        else:
            text_index_list = ([self.data_params.tok_vocab_size] +
                               first_mask_enc_sentence +
                               [self.data_params.tok_vocab_size + 1] +
                               sec_mask_enc_sentence +
                               [self.data_params.tok_vocab_size + 1])
            masked_index_list = ([0] + first_masked_index_list + [0] +
                                 sec_masked_index_list + [0])
            tar_mlm = ([self.data_params.tok_vocab_size] + first_enc_sentence +
                       [self.data_params.tok_vocab_size + 1] +
                       sec_enc_sentence +
                       [self.data_params.tok_vocab_size + 1])

        nsp_sample.inputs = {
            "text": np.asarray(text_index_list),
            "seq_length": np.asarray([len(text_index_list)]),
        }
        nsp_sample.targets = {
            "tgt_mlm": np.asarray(tar_mlm),
            "mask_mlm": np.asarray(masked_index_list),
            "tgt_nsp": np.asarray(tar_nsp),
        }
        if self._wwa:
            # Whole-word attention: per-word token lengths and word segment ids.
            word_length_vector, segment_ids = self.build_whole_word_attention_inputs(
                tar_mlm)
            nsp_sample.inputs["word_length_vector"] = np.asarray(
                word_length_vector)
            nsp_sample.inputs["segment_ids"] = np.asarray(segment_ids)

        return nsp_sample
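
For reference, the word-level split used in the segment_train branch can be sketched as a standalone function. This is an illustrative reconstruction, not part of the original class: the function name split_text_part and the max_words value are assumptions.

import random

# Illustrative standalone sketch of the word-level split used above; the
# function name and the max_words default are assumptions, not original API.
def split_text_part(text: str, max_words: int = 40):
    words = text.split(" ")
    # Only split texts with at least 10 words; keep >= 5 words per side.
    splitindex = random.randint(4, len(words) - 5) if len(words) >= 10 else 0
    # Keep the tail of the first part and the head of the second part, so
    # the two parts remain adjacent at the split point after truncation.
    part_one = words[:splitindex][-max_words:]
    part_two = words[splitindex:][:max_words]
    return " ".join(part_one), " ".join(part_two)

first, second = split_text_part(
    "one two three four five six seven eight nine ten eleven twelve")
print(first, "|", second)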