コード例 #1
0
    def _split(self, x_y_meta):
        """Tokenize every ``(x, y, meta)`` triple of ``x_y_meta``.

        ``x`` is iterated and each element is standardized and encoded on its
        own; ``y`` is standardized and encoded as one string; ``meta`` is kept
        as-is.

        Returns:
            tuple: ``(x_all, y_all, meta_all)``, three parallel lists with one
            entry per input triple.
        """

        def _encode(text):
            # standardize the raw text, then turn it into token ids
            return self.tokenizer.encode(text_standardize(text))

        contexts, targets, metas = [], [], []
        for context, target, meta in x_y_meta:
            metas.append(meta)
            contexts.append([_encode(turn) for turn in context])
            targets.append(_encode(target))

        return contexts, targets, metas
コード例 #2
0
 def _split(self, data):
     """Tokenize ``(premise, hypothesis, label)`` triples for entailment.

     Premise and hypothesis are standardized and encoded. A label equal to
     ``'entailment'`` becomes ``torch.tensor(1)``; any other label becomes
     ``torch.tensor(0)``. Iteration is wrapped in ``tqdm`` for progress.

     Returns:
         tuple: ``(premise, hypothesis, label)``, three parallel lists.
     """
     positive_labels = {'entailment'}

     def _encode(text):
         # standardize the raw text, then turn it into token ids
         return self.tokenizer.encode(text_standardize(text))

     premises, hypotheses, labels = [], [], []
     for premise_text, hypothesis_text, raw_label in tqdm(data):
         premises.append(_encode(premise_text))
         hypotheses.append(_encode(hypothesis_text))
         is_positive = raw_label in positive_labels
         labels.append(torch.tensor(1 if is_positive else 0))
     return premises, hypotheses, labels
コード例 #3
0
 def _split(self, x_y_meta):
     """Tokenize every ``(x, y, meta, aug, keyword)`` tuple of ``x_y_meta``.

     Each element of ``x`` is standardized and encoded separately; ``y``,
     ``aug`` and ``keyword`` are each standardized and encoded as one string;
     ``meta`` is kept as-is.

     Returns:
         tuple: ``(x_all, y_all, meta_all, aug_all, keyword_all)``, five
         parallel lists with one entry per input tuple.
     """

     def _encode(text):
         # standardize the raw text, then turn it into token ids
         return self.tokenizer.encode(text_standardize(text))

     contexts, targets, metas, augments, keywords = [], [], [], [], []
     for context, target, meta, augment, keyword in x_y_meta:
         metas.append(meta)
         contexts.append([_encode(turn) for turn in context])
         targets.append(_encode(target))
         augments.append(_encode(augment))
         keywords.append(_encode(keyword))
     return contexts, targets, metas, augments, keywords
コード例 #4
0
 def _split(self, x_y_meta):
     """Tokenize ``(x, y, meta, aug, keyword)`` tuples (new aug format).

     ``aug`` arrives as a sequence of indexable items; the second element of
     each item is concatenated (no separator) into one string before encoding.
     Each element of ``x`` is encoded separately; ``y``, the joined ``aug`` and
     ``keyword`` are each encoded as one string; ``meta`` is kept as-is.
     Iteration is wrapped in ``tqdm`` for progress.

     Returns:
         tuple: ``(x_all, y_all, meta_all, aug_all, keyword_all)``, five
         parallel lists with one entry per input tuple.
     """

     def _encode(text):
         # standardize the raw text, then turn it into token ids
         return self.tokenizer.encode(text_standardize(text))

     contexts, targets, metas, augments, keywords = [], [], [], [], []
     for context, target, meta, aug_items, keyword in tqdm(x_y_meta):
         metas.append(meta)
         # update for the new data format: keep only field [1] of each item
         aug_text = ''.join(item[1] for item in aug_items)
         contexts.append([_encode(turn) for turn in context])
         targets.append(_encode(target))
         augments.append(_encode(aug_text))
         keywords.append(_encode(keyword))
     return contexts, targets, metas, augments, keywords
コード例 #5
0
    def __getitem__(self, index):
        """Build one K-BERT style example for response generation.

        The input is the last ``self.args.num_turns`` context utterances. Each
        utterance is prefixed with a speaker token (alternating
        ``self.speaker1`` / ``self.speaker2``). When ``self.args.kbert`` is set
        and the turn has COMET output, the utterance is followed by
        ``self.augment`` plus the encoded COMET text. The input is then
        followed by ``self.ref`` + response + ``self.eos``. All sequences are
        truncated to ``self.max_length``.

        Args:
            index: Position of the example in ``self.data``.

        Returns:
            tuple: ``(x, type_x, soft_position_x, lm_x, total_input_length,
            attention_mask)``.

            - ``x``, ``type_x``, ``soft_position_x`` and ``lm_x`` are 1-D
              tensors.
            - ``lm_x`` is -100 on every input position.
            - ``total_input_length`` is the input length before the response
              and before truncation.
            - ``attention_mask`` is a ``(len(x), len(x))`` lower-triangular
              float tensor. It is rewritten for the knowledge tokens when both
              ``self.args.kbert`` and ``self.args.kbert_mask`` are set.

        Raises:
            ValueError: (from ``list.index``) when knowledge is inserted for
                a turn whose ``srl_mask`` row contains no ``1``.
        """
        # 0. unpack needed input info
        sample = self.data[index]
        context = sample['context']
        srl_mask = sample['srl_mask']
        comet_output = sample['comet']  # a list of dict or None
        response = sample['response']
        num_turns = self.args.num_turns
        use_kbert = self.args.kbert

        # 1. encode the response once; it is appended after the context below.
        response_encoded = self.tokenizer.encode(text_standardize(response))

        # 2. encode each of the last `num_turns` utterances.
        # NOTE(review): the hard-coded 10 assumes every context holds exactly
        # 10 utterances - confirm against the data.
        context_encoded = [
            self.tokenizer.encode(text_standardize(context[i]))
            for i in range(10 - num_turns, 10)
        ]

        # 3. encode the comet output: for every relation keep the first beam
        # that is not 'none'. None entries stay None.
        # NOTE(review): comet_encoded[i] and srl_mask[i] are paired with
        # context_encoded[i] below, so they are assumed to be aligned with the
        # last `num_turns` utterances - confirm.
        comet_encoded = []
        for comet_i in comet_output:
            if comet_i is None:
                comet_encoded.append(None)
                continue
            comet_text_i = ""
            for rel in comet_i:
                for candidate in comet_i[rel]['beams']:
                    if candidate != 'none':
                        comet_text_i += rel + " " + candidate + " "
                        break
            comet_encoded.append(
                self.tokenizer.encode(text_standardize(comet_text_i)))

        # 4. lay out speaker token + utterance (+ augment token + knowledge).
        x = []
        type_x = []
        soft_position_x = []
        is_speaker1 = bool(num_turns % 2)
        soft_loc = 0  # soft position of the current turn's speaker token
        for i in range(num_turns):
            utterance = context_encoded[i]
            speaker = self.speaker1 if is_speaker1 else self.speaker2
            x += [speaker]
            x += utterance
            type_x += [speaker] * (len(utterance) + 1)
            soft_position_x += list(
                range(soft_loc, soft_loc + len(utterance) + 1))

            # Knowledge goes after the utterance, but its soft positions
            # continue from the last SRL-related token of the utterance.
            if use_kbert and comet_encoded[i] is not None:
                knowledge = comet_encoded[i]
                # index (within the utterance) of the last 1 in srl_mask[i]
                last_related_token_index = (
                    len(srl_mask[i]) - 1 - srl_mask[i][::-1].index(1))
                x += [self.augment] + knowledge
                type_x += [self.augment] * (len(knowledge) + 1)
                # +2: one for the speaker token, one so the augment token is
                # one position past the last related token
                aug_loc = soft_loc + 2 + last_related_token_index
                soft_position_x += list(
                    range(aug_loc, aug_loc + len(knowledge) + 1))

            soft_loc += len(utterance) + 1
            is_speaker1 = not is_speaker1

        # every input position is excluded from the loss
        lm_x = [-100] * len(x)
        total_input_length = len(x)

        # 5. append the response
        x += [self.ref] + response_encoded + [self.eos]
        type_x += [self.ref] * (len(response_encoded) + 2)
        lm_x += [-100] + response_encoded + [self.eos]
        soft_position_x += list(
            range(soft_loc, soft_loc + len(response_encoded) + 2))

        x = x[:self.max_length]
        type_x = type_x[:self.max_length]
        lm_x = lm_x[:self.max_length]
        soft_position_x = soft_position_x[:self.max_length]

        # 6. build the attention mask
        seq_len = len(x)
        attention_mask = torch.tril(torch.ones(seq_len, seq_len))
        # Knowledge tokens exist in x only when kbert is on; without them the
        # offsets below would point at ordinary utterance tokens.
        if self.args.kbert_mask and use_kbert:
            aug_start = 0  # position of the current turn's augment token
            utt_start = 0  # position of the current turn's speaker token
            for turn in range(num_turns):
                aug_start += len(context_encoded[turn]) + 1
                knowledge = comet_encoded[turn]
                if knowledge is not None:
                    aug_end = aug_start + len(knowledge) + 1
                    # positions at or beyond seq_len were cut by truncation
                    for aug_token_pos in range(aug_start,
                                               min(aug_end, seq_len)):
                        # hide the knowledge token from everything, and
                        # everything from it ...
                        attention_mask[aug_token_pos, :] = 0
                        attention_mask[:, aug_token_pos] = 0
                        # ... then let it see the SRL-related tokens of its
                        # utterance ...
                        for normal_token_pos in range(
                                len(context_encoded[turn])):
                            attention_mask[
                                aug_token_pos, utt_start + normal_token_pos +
                                1] += srl_mask[turn][normal_token_pos]
                        # ... and itself plus the earlier knowledge tokens.
                        attention_mask[aug_token_pos,
                                       aug_start:aug_token_pos + 1] += 1
                    aug_start = aug_end
                    utt_start += len(knowledge) + 1
                utt_start += len(context_encoded[turn]) + 1

        x = torch.tensor(x)
        type_x = torch.tensor(type_x)
        if not self.args.kbert_position:
            soft_position_x = list(range(len(x)))
        soft_position_x = torch.tensor(soft_position_x)
        lm_x = torch.tensor(lm_x)
        return x, type_x, soft_position_x, lm_x, total_input_length, attention_mask
コード例 #6
0
    def __getitem__(self, index):
        """Build one example whose context carries COMET knowledge branches.

        ``self.data[index]`` is indexed as:

        - ``[0]``: the context utterances;
        - ``[1]``: the response;
        - ``[3]``: the COMET info, turned into a deque by
          ``self.get_comet_aug_deque``.

        Each deque entry is consumed while its ``[0]`` equals the current
        context index. Its ``[1]`` text is inserted as a branch
        (``self.augment`` + encoded text) right after that utterance.

        Returns:
            tuple: ``(x, type_x, soft_position_x, lm_x, total_input_length,
            attention_mask)``.

            - ``x``, ``type_x``, ``soft_position_x`` and ``lm_x`` are 1-D
              int64 tensors.
            - ``lm_x`` is -100 on every input position.
            - ``total_input_length`` is the input length before the response.
            - ``attention_mask`` has shape ``(1, L, L)``. It is causal, and
              tokens after a branch cannot attend to that branch.
        """
        x = []
        type_x = []
        soft_position_x = []
        # (branch_start, branch_len) for every inserted knowledge branch
        branches = []

        sample = self.data[index]
        dq = self.get_comet_aug_deque(sample[3])  # the comet info
        context = sample[0]
        response = sample[1]

        is_speaker1 = bool(self.args.num_turns % 2)
        # soft position of the current utterance's speaker token
        soft_loc = 0
        # NOTE(review): the hard-coded 10 assumes every context holds exactly
        # 10 utterances - confirm against the data.
        for i in range(10 - self.args.num_turns, 10):
            utterance_encoded = self.tokenizer.encode(
                text_standardize(context[i]))

            # add the prefix special token for each utterance
            speaker = self.speaker1 if is_speaker1 else self.speaker2
            x += [speaker]
            x += utterance_encoded
            type_x += [speaker] * (len(utterance_encoded) + 1)
            soft_position_x += list(
                range(soft_loc, soft_loc + len(utterance_encoded) + 1))

            # add every branch attached to this utterance
            while dq and dq[0][0] == i:
                comet_output = dq.popleft()[1]
                comet_encoded = self.tokenizer.encode(
                    text_standardize(comet_output))

                # actual start of THIS branch (not of the first branch of
                # the utterance)
                branch_start = len(x)
                x += [self.augment] + comet_encoded
                type_x += [self.augment] * (len(comet_encoded) + 1)
                # NOTE(review): branch positions restart at the utterance's
                # speaker position, so they overlap the utterance's own
                # positions - confirm this is intended.
                soft_position_x += list(
                    range(soft_loc, soft_loc + len(comet_encoded) + 1))
                branches.append((branch_start, len(comet_encoded) + 1))

            # advance past the speaker token and the utterance
            soft_loc += len(utterance_encoded) + 1
            is_speaker1 = not is_speaker1

        # every input position is excluded from the loss
        lm_x = [-100] * len(x)
        total_input_length = len(x)

        response_encoded = self.tokenizer.encode(text_standardize(response))
        x += [self.ref_start] + response_encoded + [self.eos]

        type_x += [self.ref_start] * (len(response_encoded) + 2)
        lm_x += [-100] + response_encoded + [self.eos]

        soft_position_x += list(
            range(soft_loc, soft_loc + len(response_encoded) + 2))

        # Ids, positions and labels are indices: keep them as int64.
        # torch.Tensor(...) would produce float32.
        x = torch.tensor(x, dtype=torch.long)
        type_x = torch.tensor(type_x, dtype=torch.long)
        soft_position_x = torch.tensor(soft_position_x, dtype=torch.long)
        lm_x = torch.tensor(lm_x, dtype=torch.long)
        x_len = x.shape[0]

        # process the mask: causal, and every token after a branch is
        # blocked from attending to that branch's tokens
        attention_mask = torch.tril(torch.ones(x_len, x_len))
        for branch_start, branch_len in branches:
            branch_end = branch_start + branch_len  # 1st token after branch
            attention_mask[branch_end:, branch_start:branch_end] = 0
        attention_mask = attention_mask.view(1, x_len, x_len)

        return x, type_x, soft_position_x, lm_x, total_input_length, attention_mask