Code example #1
def collate(data, tokenizer, block_size, device):
    """ Collate formats the data passed to the data loader.

    In particular we tokenize the data batch after batch to avoid keeping them
    all in memory. We output the data as a namedtuple to fit the original BertAbs's
    API.
    """
    data = [x for x in data if len(x[1]) != 0]  # remove examples with an empty story
    names = [name for name, _, _ in data]
    summaries = [" ".join(summary_list) for _, _, summary_list in data]

    encoded_text = [encode_for_summarization(story, summary, tokenizer) for _, story, summary in data]
    encoded_stories = torch.tensor(
        [truncate_or_pad(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
    )
    encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
    encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)

    batch = Batch(
        document_names=names,
        batch_size=len(encoded_stories),
        src=encoded_stories.to(device),
        segs=encoder_token_type_ids.to(device),
        mask_src=encoder_mask.to(device),
        tgt_str=summaries,
    )

    return batch
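
Because DataLoader passes only the raw list of samples to collate_fn, the extra arguments (tokenizer, block_size, device) have to be bound first, for example with functools.partial. A minimal usage sketch, assuming a dataset that yields (name, story, summary) tuples and a BERT tokenizer; dataset, tokenizer, and block_size=512 are placeholders:

from functools import partial

import torch
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loader = DataLoader(
    dataset,  # placeholder: yields (name, story, summary) tuples
    batch_size=4,
    collate_fn=partial(collate, tokenizer=tokenizer, block_size=512, device=device),
)

for batch in loader:
    # batch.src, batch.segs, and batch.mask_src are ready for the encoder
    print(batch.batch_size, batch.src.shape)
    break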
Code example #2
def collate(data, tokenizer, block_size):
    """ List of tuple as an input. """
    # remove the files with empty an story/summary, encode and fit to block
    data = filter(lambda x: not (len(x[0]) == 0 or len(x[1]) == 0), data)
    data = [
        encode_for_summarization(story, summary, tokenizer)
        for story, summary in data
    ]
    data = [(
        fit_to_block_size(story, block_size, tokenizer.pad_token_id),
        fit_to_block_size(summary, block_size, tokenizer.pad_token_id),
    ) for story, summary in data]

    stories = torch.tensor([story for story, summary in data])
    summaries = torch.tensor([summary for story, summary in data])
    encoder_token_type_ids = compute_token_type_ids(stories,
                                                    tokenizer.cls_token_id)
    encoder_mask = build_mask(stories, tokenizer.pad_token_id)
    decoder_mask = build_mask(summaries, tokenizer.pad_token_id)
    lm_labels = build_lm_labels(summaries, tokenizer.pad_token_id)

    return (
        stories,
        summaries,
        encoder_token_type_ids,
        encoder_mask,
        decoder_mask,
        lm_labels,
    )
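
Both collate variants rely on helpers whose behavior is implied by how their outputs are used: build_mask marks non-padding positions for attention, and build_lm_labels hides padding from the language-modeling loss. A plausible minimal sketch, not necessarily the original implementation; the ignore index -100 is an assumption (PyTorch's CrossEntropyLoss default):

import torch

def build_mask(sequence, pad_token_id):
    # attention mask: 1 for real tokens, 0 for padding
    mask = torch.ones_like(sequence)
    mask[sequence == pad_token_id] = 0
    return mask

def build_lm_labels(sequence, pad_token_id):
    # copy of the target sequence with padding replaced by an ignore index
    # so the loss skips padded positions (-100 is an assumption here)
    labels = sequence.clone()
    labels[labels == pad_token_id] = -100
    return labels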
Code example #3
    def test_compute_token_type_ids(self):
        separator = 101
        batch = torch.tensor([[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6],
                              [1, 101, 3, 4, 101, 6]])
        expected = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0],
                                 [1, 0, 0, 0, 1, 1]])

        result = compute_token_type_ids(batch, separator)
        np.testing.assert_array_equal(result, expected)
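
The test pins down the semantics: segment ids start at 1 and flip each time the separator appears, with the separator token itself taking the new id. A minimal implementation consistent with the expected tensor above (a sketch, not necessarily the library's own code):

import torch

def compute_token_type_ids(batch, separator_token_id):
    # alternate segment ids (1, 0, 1, ...) at every separator occurrence
    batch_embeddings = []
    for sequence in batch:
        sentence_num = -1  # -1 % 2 == 1, so every sequence starts with id 1
        embeddings = []
        for token_id in sequence:
            if token_id == separator_token_id:
                sentence_num += 1
            embeddings.append(sentence_num % 2)
        batch_embeddings.append(embeddings)
    return torch.tensor(batch_embeddings)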
Code example #4
def gen_batch_data(x, y, batch_size):
    '''
    Batch data generator.
    :param x: array of question records, each a dict with a "que_text" field
    :param y: array of answer records, each a dict with an "ans_text" field
    :param batch_size: number of examples per batch
    :return: yields (ques, anss, encoder_token_type_ids, encoder_mask,
             decoder_mask, lm_labels) tuples
    '''

    tokenizer = AutoTokenizer.from_pretrained(BERT_PATH)
    indices = np.arange(x.shape[0])
    np.random.shuffle(indices)
    x = x[indices]
    y = y[indices]
    i = 0

    while True:
        bi = i * batch_size
        ei = min(i * batch_size + batch_size, len(indices))
        if ei == len(indices):
            i = 0
        else:
            i += 1

        # Truncate the raw text so that, with the two special tokens added by
        # the tokenizer, the encoded sequence stays within the maximum length.
        data_que = [
            tokenizer.encode(que["que_text"][0:max_que_seq_len - 2])
            for que in x[bi:ei]
        ]
        data_ans = [
            tokenizer.encode(ans["ans_text"][0:max_ans_seq_len - 2])
            for ans in y[bi:ei]
        ]

        data_que = padding(data_que, tokenizer.pad_token_id)
        data_ans = padding(data_ans, tokenizer.pad_token_id)

        ques = torch.tensor(data_que, dtype=torch.long)
        anss = torch.tensor(data_ans, dtype=torch.long)
        encoder_token_type_ids = compute_token_type_ids(
            ques, tokenizer.sep_token_id)
        encoder_mask = build_mask(ques, tokenizer.pad_token_id)
        decoder_mask = build_mask(anss, tokenizer.pad_token_id)
        lm_labels = build_lm_labels(anss, tokenizer.pad_token_id)

        yield (
            ques,
            anss,
            encoder_token_type_ids,
            encoder_mask,
            decoder_mask,
            lm_labels,
        )
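
A short driver sketch for the generator, assuming x and y are numpy arrays of dicts carrying "que_text" and "ans_text" fields, as the code above indexes them:

batch_gen = gen_batch_data(x, y, batch_size=16)

ques, anss, token_type_ids, enc_mask, dec_mask, lm_labels = next(batch_gen)
assert ques.shape == enc_mask.shape   # one mask entry per encoder token
assert anss.shape == lm_labels.shape  # labels mirror the decoder input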
Code example #5
    def decode_sequence(self, encoder_seq, max_decode_seq_length=max_ans_seq_len_predict, topk=2):
        tokenizer = self.tokenizer
        token_dict = tokenizer.vocab
        IGNORE_WORD_IDX = token_dict[FIRST_VALIDED_TOKEN]
        encoder_token_ids = [tokenizer.encode(encoder_seq[:max_que_seq_len-2])]
        # replicate the encoder input into topk beams
        encoder_token_ids = torch.tensor(encoder_token_ids).repeat((topk,1)).to(getDevive()[0])
        encoder_token_type_ids = compute_token_type_ids(encoder_token_ids, tokenizer.sep_token_id)
        encoder_mask = build_mask(encoder_token_ids, tokenizer.pad_token_id,)

        # Encode the question once; the hidden states are reused at every decoding step
        self.QA_model.eval()
        encoder_outputs = self.QA_model.encoder(encoder_token_ids, token_type_ids= encoder_token_type_ids,attention_mask = encoder_mask)
        encoder_hidden_states = encoder_outputs[0]

        # Decode with beam search over the topk best partial sequences
        input_seq = torch.tensor([[tokenizer.cls_token_id,]]).repeat((topk,1)).to(getDevive()[0])
        pre_score = np.zeros((topk, topk))  # accumulated log-probability per beam

        stop_condition = False
        target_seq = None
        target_seq_output = None
        while not stop_condition:
            # prepare the decoder attention mask for the current prefix
            decoder_mask = build_mask(input_seq, tokenizer.pad_token_id).to(getDevive()[0])
            # at prediction time there are no labels to compute a loss against, so lm_labels is omitted
            outputs = self.QA_model.decoder(input_seq,
                                            encoder_hidden_states=encoder_hidden_states,
                                            encoder_attention_mask=encoder_mask,
                                            attention_mask=decoder_mask,
            )
            output_tokens = outputs[0][:, -1, IGNORE_WORD_IDX:].data.cpu().numpy()

            # On the first step all topk beams are identical, so the top-k
            # candidates of the first beam are enough.
            if target_seq is None:
                arg_topk = output_tokens.argsort(axis=-1)[:, -topk:]  # top-k candidates per beam
                target_seq = arg_topk[0, :].reshape((topk, 1)) + IGNORE_WORD_IDX
                tmp_cur_score = np.log(np.sort(output_tokens[0, :], axis=-1)[-topk:])
                tmp_cur_score = np.tile(tmp_cur_score.reshape((topk, 1)), (1, topk))
                cur_score = pre_score + tmp_cur_score
                pre_score = cur_score
                target_seq_output = target_seq

            else:
                # If a beam already emitted the end token '[SEP]', force the
                # probability of emitting '[SEP]' again to the maximum (1.0)
                # so the finished beam stays unchanged.
                for i, word in enumerate(target_seq[:, 0]):
                    if word == token_dict[sentence_end_token]:
                        output_tokens[i, token_dict[sentence_end_token] - IGNORE_WORD_IDX] = 1.0
                # Work in log space to avoid underflow: products of
                # probabilities become sums of log-probabilities.
                arg_topk = output_tokens.argsort(axis=-1)[:, -topk:]  # top-k candidates per beam
                tmp_cur_score = np.log(np.sort(output_tokens, axis=-1)[:, -topk:])
                cur_score = pre_score + tmp_cur_score
                maxIdx = np.unravel_index(np.argsort(cur_score, axis=None)[-topk:], cur_score.shape)
                pre_score = np.tile(cur_score[maxIdx].reshape((topk, 1)), (1, topk))
                target_seq = arg_topk[maxIdx].reshape((topk, 1)) + IGNORE_WORD_IDX

                target_seq_output = np.concatenate((target_seq_output[maxIdx[0], :], target_seq), axis=-1)

            if (target_seq_output.shape[1] >= max_decode_seq_length
                    or (target_seq == token_dict[sentence_end_token] * np.ones((topk,1))).all()):
                stop_condition = True

            input_seq = torch.cat((input_seq, torch.from_numpy(target_seq).to(input_seq.device)), dim=-1)

        # print("==")
        # 最后一行,概率最大
        # maxIdx为元组,维数为 pre_score维度值
        # maxIdx = np.unravel_index(np.argmax(pre_score, axis=None), pre_score.shape)
        # print(maxIdx[0])
        target_seq_output = target_seq_output[-1,:].reshape(1,-1)
        for i, word in enumerate(target_seq_output[0,:]):
            if word == token_dict[sentence_end_token]:
                break
        target_seq_output = target_seq_output[:,0:i]
        return target_seq_output
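
Accumulating log-probabilities rather than multiplying raw probabilities keeps beam scores numerically stable as sequences grow. A usage sketch, assuming qa is an instance of the class that defines decode_sequence (the name is a placeholder):

token_ids = qa.decode_sequence("What is the capital of France?", topk=3)
print(qa.tokenizer.decode(token_ids[0].tolist()))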