Code example #1
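This __call__ converts one (source, target) instance of parallel word / POS / dependency-label streams into fixed-length id features. With probability 0.5 the entire target side is replaced by [MASK] placeholders, while origin_input_word_ids always keeps the unmasked sequence. rand is presumably random.random, and truncate_tokens_pair / truncate_tokens are helpers from the surrounding module; a sketch of both follows the example.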
    def __call__(self, instance):
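        # instance: parallel source/target streams of word tokens, POS tags, and dependency labels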
        input_tokens, input_pos, input_dep, target_tokens, target_pos, target_dep = instance

        # -3 for the three special tokens [CLS], [SEP], [SEP]
        truncate_tokens_pair(input_tokens, target_tokens, self.max_len - 3)
        truncate_tokens_pair(input_pos, target_pos, self.max_len - 3)
        truncate_tokens_pair(input_dep, target_dep, self.max_len - 3)
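        # Defensive cap: after the pair truncation above the target side
        # should already fit within max_len, so these are normally no-ops.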
        target_tokens = truncate_tokens(target_tokens, self.max_len)
        target_pos = truncate_tokens(target_pos, self.max_len)
        target_dep = truncate_tokens(target_dep, self.max_len)

        # Add special tokens. With probability 0.5 the target side is kept
        # as-is; otherwise every target token is replaced by [MASK], so the
        # model has to reconstruct the target from the source alone.
        origin_word_tokens = (['[CLS]'] + input_tokens + ['[SEP]'] +
                              target_tokens + ['[SEP]'])
        if rand() < 0.5:
            word_tokens = origin_word_tokens
        else:
            word_tokens = (['[CLS]'] + input_tokens + ['[SEP]'] +
                           ['[MASK]'] * len(target_tokens) + ['[SEP]'])

        pos_tokens = ['[CLS]'] + input_pos + ['[SEP]'] + target_pos + ['[SEP]']
        dep_tokens = ['[CLS]'] + input_dep + ['[SEP]'] + target_dep + ['[SEP]']
        # Segment 0 covers [CLS] + source + [SEP]; segment 1 covers target + [SEP].
        input_segment_ids = ([0] * (len(input_tokens) + 2) +
                             [1] * (len(target_tokens) + 1))
        input_mask = [1] * len(word_tokens)
        target_mask = [1] * (len(target_tokens) + 1)
        input_len = len(input_tokens) + 2
        target_len = len(target_tokens) + 1

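        # Map the three token streams to vocabulary ids; origin_input_word_ids
        # keeps the unmasked target side.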
        input_word_ids, input_pos_ids, input_dep_ids = self.indexer(
            word_tokens, pos_tokens, dep_tokens)
        origin_input_word_ids, _, _ = self.indexer(origin_word_tokens, [], [])
        target_word_ids, target_pos_ids, target_dep_ids = self.indexer(
            target_tokens + ['[SEP]'], target_pos + ['[SEP]'],
            target_dep + ['[SEP]'])

        # Zero Padding
        input_n_pad = self.max_len - len(input_word_ids)
        origin_input_word_ids.extend([0] * input_n_pad)
        input_word_ids.extend([0] * input_n_pad)
        input_pos_ids.extend([0] * input_n_pad)
        input_dep_ids.extend([0] * input_n_pad)
        input_segment_ids.extend([0] * input_n_pad)
        input_mask.extend([0] * input_n_pad)

        target_n_pad = self.max_len - len(target_word_ids)
        target_word_ids.extend([0] * target_n_pad)
        target_pos_ids.extend([0] * target_n_pad)
        target_dep_ids.extend([0] * target_n_pad)
        target_mask.extend([0] * target_n_pad)

        return (origin_input_word_ids, input_word_ids, input_pos_ids,
                input_dep_ids, input_segment_ids, input_mask, target_word_ids,
                target_pos_ids, target_dep_ids, target_mask, input_len,
                target_len)
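Neither example defines truncate_tokens_pair or truncate_tokens. A minimal sketch consistent with how they are called above (the pair version trims in place to a shared budget and its return value is ignored; the single version is rebound to its result) might be:

def truncate_tokens_pair(tokens_a, tokens_b, max_len):
    # Pop from the longer sequence until the pair fits the shared budget.
    while len(tokens_a) + len(tokens_b) > max_len:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

def truncate_tokens(tokens, max_len):
    # Return a copy capped at max_len tokens.
    return tokens[:max_len]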
Code example #2
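This variant applies BERT-style masked-LM corruption instead of whole-target masking: up to max_pred positions, sampled at rate mask_prob, are recorded as prediction targets, and about 80% of them are replaced by [MASK] across all three token streams. shuffle and rand are presumably random.shuffle and random.random.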
    def __call__(self, instance):
        input_tokens, input_pos, input_dep, target_tokens, target_pos, target_dep = instance

        # -3 for the three special tokens [CLS], [SEP], [SEP]
        truncate_tokens_pair(input_tokens, target_tokens, self.max_len - 3)
        truncate_tokens_pair(input_pos, target_pos, self.max_len - 3)
        truncate_tokens_pair(input_dep, target_dep, self.max_len - 3)
        target_tokens = truncate_tokens(target_tokens, self.max_len)
        target_pos = truncate_tokens(target_pos, self.max_len)
        target_dep = truncate_tokens(target_dep, self.max_len)

        # Add Special Tokens
        word_tokens = (['[CLS]'] + input_tokens + ['[SEP]'] +
                       target_tokens + ['[SEP]'])
        pos_tokens = ['[CLS]'] + input_pos + ['[SEP]'] + target_pos + ['[SEP]']
        dep_tokens = ['[CLS]'] + input_dep + ['[SEP]'] + target_dep + ['[SEP]']
        input_segment_ids = ([0] * (len(input_tokens) + 2) +
                             [1] * (len(target_tokens) + 1))
        input_mask = [1] * len(word_tokens)

        target_mask = [1] * len(target_tokens)

        # For the masked language model
        masked_word_tokens, masked_pos_tokens, masked_dep_tokens, masked_pos = [], [], [], []
        # The number of predictions can be less than max_pred when the sequence is short.
        n_pred = min(self.max_pred,
                     max(1, int(round(len(word_tokens) * self.mask_prob))))
        # Candidate positions of masked tokens. Unlike standard BERT, [SEP]
        # stays a candidate (so the summary boundary can be predicted); only
        # [CLS] is excluded.
        cand_pos = [
            i for i, token in enumerate(word_tokens) if token != '[CLS]'
        ]
        shuffle(cand_pos)
        for pos in cand_pos[:n_pred]:
            masked_word_tokens.append(word_tokens[pos])
            masked_pos_tokens.append(pos_tokens[pos])
            masked_dep_tokens.append(dep_tokens[pos])
            masked_pos.append(pos)
            if rand() < 0.8:  # 80%: replace with [MASK] in all three streams
                word_tokens[pos] = '[MASK]'
                pos_tokens[pos] = '[MASK]'
                dep_tokens[pos] = '[MASK]'
            # Remaining 20%: keep the original token. The usual 10%
            # random-replacement branch is disabled:
            #elif rand() < 0.5: # 10%
            #    word_tokens[pos] = get_random_word(self.vocab_words)
            #    pos_tokens[pos] = get_random_word(self.vocab_pos)
            #    dep_tokens[pos] = get_random_word(self.vocab_dep)
        # When n_pred < max_pred, the loss is only computed over the n_pred masked positions.
        masked_weights = [1] * len(masked_pos_tokens)

        # Disabled: replace the entire right (target) side with [MASK] for summarization.
        #if rand() < 0.1:
        #    word_tokens = word_tokens[:len(input_tokens)+2] + ['[MASK]']*len(self.max_len - len(input_tokens)+2)
        #    pos_tokens = pos_tokens[:len(input_pos)+2] + ['[MASK]']*len(self.max_len - len(input_pos)+2)
        #    dep_tokens = dep_tokens[:len(input_dep)+2] + ['[MASK]']*len(self.max_len - len(input_dep)+2)

        input_word_ids, input_pos_ids, input_dep_ids = self.indexer(
            word_tokens, pos_tokens, dep_tokens)
        masked_word_ids, masked_pos_ids, masked_dep_ids = self.indexer(
            masked_word_tokens, masked_pos_tokens, masked_dep_tokens)
        target_word_ids, target_pos_ids, target_dep_ids = self.indexer(
            target_tokens, target_pos, target_dep)

        # Zero Padding
        input_n_pad = self.max_len - len(input_word_ids)
        input_word_ids.extend([0] * input_n_pad)
        input_pos_ids.extend([0] * input_n_pad)
        input_dep_ids.extend([0] * input_n_pad)
        input_segment_ids.extend([0] * input_n_pad)
        input_mask.extend([0] * input_n_pad)

        target_n_pad = self.max_len - len(target_word_ids)
        target_word_ids.extend([0] * target_n_pad)
        target_pos_ids.extend([0] * target_n_pad)
        target_dep_ids.extend([0] * target_n_pad)
        target_mask.extend([0] * target_n_pad)

        # Zero Padding for masked target
        if self.max_pred > n_pred:
            n_pad = self.max_pred - n_pred
            masked_word_ids.extend([0] * n_pad)
            masked_pos_ids.extend([0] * n_pad)
            masked_dep_ids.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)
            masked_weights.extend([0] * n_pad)

        return (input_word_ids, input_pos_ids, input_dep_ids,
                input_segment_ids, input_mask, masked_word_ids, masked_pos_ids,
                masked_dep_ids, masked_pos, masked_weights, target_word_ids,
                target_pos_ids, target_dep_ids, target_mask)
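For reference, the full BERT corruption rule that the disabled elif branch above would restore is 80% [MASK], 10% random token, 10% unchanged. A minimal sketch, assuming a vocab_words list standing in for self.vocab_words:

from random import choice, random

def corrupt_token(tokens, pos, vocab_words):
    # BERT-style 80/10/10 corruption at position pos, in place.
    r = random()
    if r < 0.8:
        tokens[pos] = '[MASK]'             # 80%: mask
    elif r < 0.9:
        tokens[pos] = choice(vocab_words)  # 10%: random vocabulary word
    # else: leave the original token unchanged (10%)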