Example #1
    def get_ck_local(self, hop, story, story_size, domains):
        # Embed each story with the domain-specific embedding C_{domain}_ for
        # this hop, then sum over the memory-token dimension.
        embed = _cuda(torch.zeros(story_size + (self.embedding_dim, )))
        for i, domain in enumerate(domains):
            embed[i] = getattr(self, 'C_{}_'.format(domain))[hop](
                story.contiguous()[i])
        embed = torch.sum(embed, 2).squeeze(2)
        return embed
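
Every example on this page calls a _cuda helper that is not shown. A minimal sketch, assuming it does nothing more than move a tensor to the GPU when one is available:

import torch

def _cuda(x):
    # Assumed behaviour: move to the GPU if CUDA is available, otherwise no-op.
    return x.cuda() if torch.cuda.is_available() else x
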
Example #2
    def forward(self, input_seqs, input_lengths):
        # input_lengths: the length of each story in the batch.
        # It seems better not to collapse the dimensions here: neighbouring
        # elements need their similarity computed, and folding MEM_TOKEN_SIZE
        # into the sequence would split up the semantics of each word.
        embeddings = self.embedding(
            input_seqs)  # [batch_size, story_length, MEM_TOKEN_SIZE, hidden_size]
        # Alternative: keep batch_size, merge the other two dimensions, then embed:
        # embeddings = self.embedding(input_seqs.contiguous().view(input_seqs.size(0), -1).long())
        # embeddings = embeddings.view(input_seqs.size() + (embeddings.size(-1),))  # add a dimension
        embeddings = torch.sum(embeddings, 2)  # [batch_size, story_length, hidden_size]
        # Dropout randomly zeroes some embedding features (each embedding
        # dimension is one feature) to reduce overfitting.
        embeddings = self.dropout_layer(embeddings)

        # Initial hidden state: [2 * n_layers, batch_size, hidden_size].
        hidden_init = _cuda(
            torch.zeros(2 * self.n_layers, input_seqs.size(0),
                        self.hidden_size))
        if input_lengths:
            embeddings = nn.utils.rnn.pack_padded_sequence(embeddings,
                                                           input_lengths,
                                                           batch_first=True)
        # hidden: [num_layers * num_directions, batch_size, hidden_size]
        output, hidden = self.gru(embeddings, hidden_init)

        if input_lengths:  # undo the padding introduced by pack_padded_sequence
            # output: [batch_size, story_length, num_directions * hidden_size]
            output, _ = nn.utils.rnn.pad_packed_sequence(output,
                                                         batch_first=True)
        hidden = self.W(torch.cat((hidden[0], hidden[1]), dim=1))
        output = self.W(output)  # [batch_size, story_length, hidden_size]

        return output, hidden
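
The encoder above relies on the pack_padded_sequence / pad_packed_sequence roundtrip. A self-contained sketch of that pattern with made-up shapes (the GRU and lengths here are illustrative, not the original model's):

import torch
import torch.nn as nn

batch_size, max_len, hidden_size = 3, 5, 8
lengths = [5, 4, 2]  # descending order, as pack_padded_sequence expects by default
x = torch.randn(batch_size, max_len, hidden_size)
gru = nn.GRU(hidden_size, hidden_size, num_layers=1,
             bidirectional=True, batch_first=True)

packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
packed_out, hidden = gru(packed)  # hidden: [2, batch_size, hidden_size]
out, out_lengths = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
# out: [batch_size, max_len, 2 * hidden_size]; steps beyond each length are zero.
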
Example #3
    def forward(self, input_seqs, input_lengths):
        embedded = self.embedding(input_seqs.contiguous().view(
            input_seqs.size(0), -1).long())
        embedded = embedded.view(input_seqs.size() + (embedded.size(-1), ))
        embedded = torch.sum(embedded, 2).squeeze(2)
        embedded = self.dropout_layer(embedded.transpose(0, 1))
        global_outputs, global_hidden = self.global_gru(
            embedded, input_lengths)
        local_outputs = []

        mask = _cuda(torch.zeros((len(input_lengths), input_lengths[0])))
        for i, length in enumerate(input_lengths):
            mask[i, :length] = 1

        for domain in self.domains:
            local_rnn = getattr(self, '{}_gru'.format(domain))
            local_output, _ = local_rnn(embedded, input_lengths)
            local_outputs.append(local_output)

        local_outputs, scores = self.mix_attention(
            torch.stack(local_outputs, dim=-1), mask)
        outputs = self.MLP_H(
            torch.cat((F.dropout(local_outputs, self.dropout, self.training),
                       F.dropout(global_outputs, self.dropout, self.training)),
                      dim=-1))

        hidden = self.selfatten(outputs, input_lengths)
        outputs_ = self.W(outputs)
        hidden_ = self.W(hidden)
        label = self.global_classifier(global_outputs)
        return outputs_, hidden_, label, scores
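
self.selfatten and self.mix_attention are not defined in this snippet. One common reading of selfatten is length-masked self-attentive pooling of the fused outputs into a single vector; a hypothetical sketch under that assumption (not the author's implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttentivePooling(nn.Module):
    """Hypothetical stand-in for selfatten: pool padded outputs into one
    hidden vector with length-masked attention."""

    def __init__(self, hidden_size):
        super().__init__()
        self.scorer = nn.Linear(hidden_size, 1)

    def forward(self, outputs, lengths):
        # outputs: [batch, max_len, hidden]; lengths: true lengths per sample
        scores = self.scorer(outputs).squeeze(-1)           # [batch, max_len]
        pad = torch.zeros_like(scores, dtype=torch.bool)
        for i, l in enumerate(lengths):
            pad[i, l:] = True                               # True on padding
        scores = scores.masked_fill(pad, float('-inf'))
        weights = F.softmax(scores, dim=-1).unsqueeze(1)    # [batch, 1, max_len]
        return torch.bmm(weights, outputs).squeeze(1)       # [batch, hidden]
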
Example #4
    def get_state(self, bsz):
        """Get cell states and hidden states."""
        return _cuda(torch.zeros(2, bsz, self.hidden_size))
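
For reference, a minimal usage sketch (the GRU and sizes are made up): the zero tensor returned by get_state matches the (num_layers * num_directions, batch, hidden_size) initial state a single-layer bidirectional GRU expects:

import torch
import torch.nn as nn

hidden_size, bsz, seq_len = 16, 4, 7
gru = nn.GRU(hidden_size, hidden_size, num_layers=1, bidirectional=True)

h0 = torch.zeros(2, bsz, hidden_size)       # what get_state(bsz) returns
x = torch.randn(seq_len, bsz, hidden_size)  # (seq_len, batch, input_size)
out, hn = gru(x, h0)                        # hn: [2, bsz, hidden_size]
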
Example #5
    def forward(self, extKnow, story_size, story_lengths, copy_list,
                encode_hidden, target_batches, max_target_length, batch_size,
                use_teacher_forcing, get_decoded_words, global_pointer):
        # Initialize variables for vocab and pointer
        all_decoder_outputs_vocab = _cuda(
            torch.zeros(max_target_length, batch_size, self.num_vocab))
        all_decoder_outputs_ptr = _cuda(
            torch.zeros(max_target_length, batch_size, story_size[1]))
        decoder_input = _cuda(torch.LongTensor([SOS_token] * batch_size))
        memory_mask_for_step = _cuda(torch.ones(story_size[0], story_size[1]))
        decoded_fine, decoded_coarse = [], []

        hidden = self.relu(self.projector(encode_hidden)).unsqueeze(0)

        # Start to generate word-by-word
        for t in range(max_target_length):
            embed_q = self.dropout_layer(self.C(decoder_input))  # b * e
            if len(embed_q.size()) == 1: embed_q = embed_q.unsqueeze(0)
            _, hidden = self.sketch_rnn(embed_q.unsqueeze(0), hidden)
            query_vector = hidden[0]

            p_vocab = self.attend_vocab(self.C.weight, hidden.squeeze(0))
            all_decoder_outputs_vocab[t] = p_vocab
            _, topvi = p_vocab.data.topk(1)

            # query the external knowledge using the hidden state of the sketch RNN
            prob_soft, prob_logits = extKnow(query_vector, global_pointer)
            all_decoder_outputs_ptr[t] = prob_logits

            if use_teacher_forcing:
                decoder_input = target_batches[:, t]
            else:
                decoder_input = topvi.squeeze()

            if get_decoded_words:

                search_len = min(5, min(story_lengths))
                prob_soft = prob_soft * memory_mask_for_step
                _, toppi = prob_soft.data.topk(search_len)
                temp_f, temp_c = [], []

                for bi in range(batch_size):
                    token = topvi[bi].item()  #topvi[:,0][bi].item()
                    temp_c.append(self.lang.index2word[token])

                    if '@' in self.lang.index2word[token]:
                        cw = 'UNK'
                        for i in range(search_len):
                            if toppi[:, i][bi] < story_lengths[bi] - 1:
                                cw = copy_list[bi][toppi[:, i][bi].item()]
                                break
                        temp_f.append(cw)

                        if args['record']:
                            memory_mask_for_step[bi, toppi[:, i][bi].item()] = 0
                    else:
                        temp_f.append(self.lang.index2word[token])

                decoded_fine.append(temp_f)
                decoded_coarse.append(temp_c)

        return all_decoder_outputs_vocab, all_decoder_outputs_ptr, decoded_fine, decoded_coarse
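
attend_vocab is not shown here, but Example #10 computes the same quantity inline as hidden.squeeze(0).matmul(self.C.weight.transpose(1, 0)): unnormalized scores of the decoder state against the shared embedding matrix, with the softmax left to the loss. A minimal sketch under that reading:

import torch

def attend_vocab(embedding_weight, query):
    # embedding_weight: [vocab_size, embed_dim]; query: [batch, embed_dim]
    # Returns unnormalized vocabulary logits of shape [batch, vocab_size].
    return query.matmul(embedding_weight.transpose(1, 0))
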
Example #6
    def forward(self, context_hidden, context_outputs, context_lengths, context_mask, \
                context_entity, context_entity_lengths, context_entity_mask, context_entity_id, \
                kb_entity, kb_entity_id, kb_entity_row, kb_entity_lengths, kb_entity_mask, \
                entity, entity_lengths, entity_mask, entity_plain, entity_type, \
                target_batches, max_target_length, schedule_sampling, get_decoded_words):

        batch_size, entity_set_length = entity.size(0), entity.size(1)
        #context_entity_id = context_entity_id + context_entity_mask.long() * (entity_set_length-1)
        #kb_entity_id = kb_entity_id + kb_entity_mask.long() * (entity_set_length-1)

        # Initialize variables for vocab and pointer
        all_decoder_outputs_vocab = _cuda(
            torch.zeros(max_target_length, batch_size, self.num_vocab))
        all_decoder_outputs_ptr = _cuda(
            torch.zeros(max_target_length, batch_size, entity_set_length))

        decoder_input = _cuda(torch.LongTensor([SOS_token] * batch_size))
        memory_mask_for_step = _cuda(torch.ones(batch_size, entity_set_length))
        decoded_fine, decoded_coarse = [], []

        dec_hidden = self.relu(self.projector(context_hidden))

        # Start to generate word-by-word
        for t in range(max_target_length):
            pre_emb = self.dropout_layer(self.embedder(decoder_input))  # b * e
            if len(pre_emb.size()) == 1: pre_emb = pre_emb.unsqueeze(0)

            _, dec_hidden = self.gru(pre_emb.unsqueeze(0), dec_hidden)

            # For context distribution
            p_entity_context = _cuda(torch.zeros(batch_size,
                                                 entity_set_length))

            context_entity_hidden, context_entity_pro = self.context_entity_attention(dec_hidden.transpose(0,1), \
                                                            context_entity, mask=context_entity_mask, return_weights=True)
            p_entity_context.scatter_add_(1, context_entity_id,
                                          context_entity_pro.squeeze(1))

            # For KB distribution
            p_entity_kb = _cuda(torch.zeros(batch_size, entity_set_length))

            ## Row-level
            kb_entity_row_onehot = to_onehot(kb_entity_row,
                                             mask=kb_entity_mask).transpose(
                                                 1, 2)  # B x maxR x maxE
            kb_entity_row_hidden = torch.bmm(kb_entity_row_onehot,
                                             kb_entity)  # B x maxR x h
            kb_entity_row_sum = kb_entity_row_onehot.sum(
                2, keepdim=True, dtype=torch.float)  # B x maxR x 1
            kb_entity_row_mask = kb_entity_row_sum.squeeze(2).eq(0)
            kb_entity_row_sum = torch.clamp(kb_entity_row_sum, min=1)
            kb_entity_row_hidden = kb_entity_row_hidden / kb_entity_row_sum
            kb_entity_hidden, kb_entity_row_pro = self.kb_entity_attention(dec_hidden.transpose(0,1), \
                                                kb_entity_row_hidden, mask=kb_entity_row_mask, return_weights=True)
            kb_entity_row_pro = torch.bmm(
                kb_entity_row_pro, kb_entity_row_onehot).squeeze(1)  # B x maxE

            ## Entity-level
            kb_entity_logit = self.kb_entity_attention(dec_hidden.transpose(0,1), \
                                                            kb_entity, return_weights_only=True) # B x maxE x 1
            kb_entity_logit = kb_entity_logit * kb_entity_row_onehot  # B x maxR x maxE
            #kb_entity_logit.masked_fill_(torch.logical_not(kb_entity_row_onehot.bool()), -1e9)
            kb_entity_logit.masked_fill_(1 - kb_entity_row_onehot.byte(), -1e9)
            #kb_entity_logit = kb_entity_logit - (1 - kb_entity_row_onehot) * 1e10
            kb_entity_pro = F.softmax(kb_entity_logit, dim=2)
            kb_entity_pro = torch.gather(kb_entity_pro, 1,
                                         kb_entity_row.unsqueeze(1)).squeeze(1)
            #kb_entity_pro = kb_entity_pro.sum(1)
            kb_entity_pro = kb_entity_pro * kb_entity_row_pro
            p_entity_kb.scatter_add_(1, kb_entity_id, kb_entity_pro)
            """
            kb_entity_hidden, kb_entity_logit = self.kb_entity_attention(dec_hidden.transpose(0,1), \
                                                            kb_entity, mask=kb_entity_mask, return_weights=True)
            kb_entity_logit = kb_entity_logit.squeeze(1)
            p_entity_kb.scatter_add_(1, kb_entity_id, kb_entity_logit)
            """

            switch_input = self.switch(dec_hidden.squeeze(0))
            #pro_switch = self.softmax(switch_input)

            #if not get_decoded_words:
            #    pro_switch = nn.functional.gumbel_softmax(switch_input, tau=1.0 - (epoch / 15.0), hard=False)
            #else:
            #    pro_switch = nn.functional.gumbel_softmax(switch_input, tau=1.0 - (epoch / 15.0), hard=True)

            #p_entity = torch.cat((p_entity_context.unsqueeze(2), p_entity_kb.unsqueeze(2)), dim=2)
            #p_entity = torch.bmm(p_entity, pro_switch.unsqueeze(2)).squeeze(2)
            pro_switch = self.sigmoid(switch_input)
            p_entity = (
                1.0 - pro_switch) * p_entity_context + pro_switch * p_entity_kb

            # For Vocab
            vocab_attn = self.context_attention(dec_hidden.transpose(0, 1),
                                                context_outputs,
                                                mask=context_mask)
            #entity_hidden = torch.cat((context_entity_hidden, kb_entity_hidden), dim=1)
            #entity_hidden = torch.bmm(pro_switch.unsqueeze(1), entity_hidden)
            entity_hidden = context_entity_hidden.squeeze(1) * (
                1 - pro_switch) + kb_entity_hidden.squeeze(1) * pro_switch
            #concat_input = torch.cat((dec_hidden.squeeze(0), vocab_attn.squeeze(1)), dim=1)
            concat_input = torch.cat(
                (dec_hidden.squeeze(0), vocab_attn.squeeze(1),
                 entity_hidden.squeeze(1)),
                dim=1)
            concat_output = torch.tanh(self.concat(concat_input))
            #p_vocab = self.attend_vocab(self.embedder.weight, concat_output)
            p_vocab = self.vocab_matrix(concat_output)

            all_decoder_outputs_vocab[t] = p_vocab
            all_decoder_outputs_ptr[t] = p_entity

            use_teacher_forcing = random.random() < schedule_sampling
            if use_teacher_forcing:
                decoder_input = target_batches[:, t]
            else:
                _, topvi = p_vocab.data.topk(1)
                decoder_input = topvi.squeeze()

            if get_decoded_words:
                prob_soft = self.softmax(p_entity)
                search_len = min(5, min(entity_lengths))
                prob_soft = prob_soft * memory_mask_for_step
                _, toppi = prob_soft.data.topk(search_len)
                temp_f, temp_c = [], []

                for bi in range(batch_size):
                    token = topvi[bi].item()  #topvi[:,0][bi].item()
                    temp_c.append(self.vocab.index2word[token])

                    if '@' in self.vocab.index2word[token]:
                        slot = self.vocab.index2word[token]
                        cw = 'UNK'
                        for i in range(search_len):
                            top_index = toppi[:, i][bi].item()
                            #if top_index < entity_lengths[bi]-1 and entity_type[bi][top_index] == slot:
                            if top_index < entity_lengths[bi] - 1:
                                cw = entity_plain[bi][toppi[:, i][bi].item()]
                                break
                        temp_f.append(cw)

                        if args['record']:
                            memory_mask_for_step[bi, toppi[:, i][bi].item()] = 0
                    else:
                        temp_f.append(self.vocab.index2word[token])

                decoded_fine.append(temp_f)
                decoded_coarse.append(temp_c)

        return all_decoder_outputs_vocab, all_decoder_outputs_ptr, decoded_fine, decoded_coarse
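
to_onehot is not defined in this snippet. A minimal sketch, assuming it turns per-entity row indices [batch x maxE] into a one-hot tensor [batch x maxE x maxR] and zeroes out padded entities (mask assumed to be 1/True on padding):

import torch

def to_onehot(row_ids, mask=None, num_rows=None):
    # row_ids: [batch, max_entities] long tensor of KB row indices
    batch, max_entities = row_ids.size()
    if num_rows is None:
        num_rows = int(row_ids.max().item()) + 1
    onehot = torch.zeros(batch, max_entities, num_rows, device=row_ids.device)
    onehot.scatter_(2, row_ids.unsqueeze(2), 1.0)
    if mask is not None:
        onehot = onehot * (~mask.bool()).float().unsqueeze(2)
    return onehot
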
Example #7
    def forward(self, kb_ent, extKnow, context, context_mask, copy_list,
                encode_hidden, target_batches, max_target_length,
                schedule_sampling, get_decoded_words):
        batch_size = len(copy_list)
        story_size = max([len(seq) for seq in copy_list])
        extKnow_mask, _ = mask_and_length(kb_ent, PAD_token)

        # Initialize variables for vocab and pointer
        all_decoder_outputs_vocab = _cuda(
            torch.zeros(max_target_length, batch_size, self.num_vocab))
        all_decoder_outputs_ptr = _cuda(
            torch.zeros(max_target_length, batch_size, story_size))
        decoder_input = _cuda(torch.LongTensor([SOS_token] * batch_size))
        decoded_fine, decoded_coarse = [], []

        #hidden = self.relu(self.projector(encode_hidden)).unsqueeze(0)
        hidden = self.tanh(self.projector(encode_hidden)).unsqueeze(0)

        # Start to generate word-by-word
        for t in range(max_target_length + 1):
            rnn_input_list, concat_input_list = [], []

            embed_q = self.dropout_layer(self.embedder(decoder_input))  # b * e
            if len(embed_q.size()) == 1: embed_q = embed_q.unsqueeze(0)
            rnn_input_list.append(embed_q)

            rnn_input = torch.cat(rnn_input_list, dim=1)
            _, hidden = self.gru(rnn_input.unsqueeze(0), hidden)
            concat_input_list.append(hidden.squeeze(0))

            # get knowledge attention
            knowledge_outputs = self.knowledge_attention(hidden.transpose(0, 1),
                                                         extKnow,
                                                         mask=extKnow_mask)
            concat_input_list.append(knowledge_outputs.squeeze(1))

            # get context attention
            context_outputs = self.context_attention(hidden.transpose(0, 1),
                                                     context,
                                                     mask=context_mask)
            concat_input_list.append(context_outputs.squeeze(1))

            #concat_input = torch.cat((hidden.squeeze(0), context_outputs.squeeze(1), knowledge_outputs.squeeze(1)), dim=1)
            concat_input = torch.cat(concat_input_list, dim=1)
            concat_output = torch.tanh(self.concat(concat_input))

            if t < max_target_length:
                #p_vocab = self.attend_vocab(self.C.weight, concat_output)
                p_vocab = self.vocab_matrix(concat_output)
                all_decoder_outputs_vocab[t] = p_vocab

            if t > 0:
                p_entity = self.entity_ranking(concat_output.unsqueeze(1),
                                               extKnow,
                                               mask=extKnow_mask).squeeze(1)
                all_decoder_outputs_ptr[t - 1] = p_entity

            if t < max_target_length:
                use_teacher_forcing = random.random() < schedule_sampling
                if use_teacher_forcing:
                    decoder_input = target_batches[:, t]
                else:
                    _, topvi = p_vocab.data.topk(1)
                    decoder_input = topvi.squeeze()

        # Convert the stored distributions into decoded words
        if get_decoded_words:
            for t in range(max_target_length):
                p_vocab = all_decoder_outputs_vocab[t]
                p_entity = all_decoder_outputs_ptr[t]
                _, topvi = p_vocab.data.topk(1)

                search_len = min(5, story_size)
                _, toppi = p_entity.data.topk(search_len)
                temp_f, temp_c = [], []

                for bi in range(batch_size):
                    token = topvi[bi].item()  #topvi[:,0][bi].item()
                    temp_c.append(self.lang.index2word[token])

                    if '@' in self.lang.index2word[token]:
                        cw = 'UNK'
                        for i in range(search_len):
                            #if toppi[:,i][bi] < story_lengths[bi]-1:
                            if toppi[:, i][bi] > 0 and toppi[:, i][bi] < len(
                                    copy_list[bi]):
                                cw = copy_list[bi][toppi[:, i][bi].item()]
                                break

                        temp_f.append(cw)
                    else:
                        temp_f.append(self.lang.index2word[token])

                decoded_fine.append(temp_f)
                decoded_coarse.append(temp_c)

        return all_decoder_outputs_vocab, all_decoder_outputs_ptr, decoded_fine, decoded_coarse
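
mask_and_length is not shown either. A minimal sketch, assuming it derives a padding mask (True on padded positions) and per-sequence lengths from a padded batch of token ids; PAD_token here is a placeholder value:

import torch

PAD_token = 1  # assumed padding id; the real value comes from the project's vocabulary

def mask_and_length(seqs, pad_token=PAD_token):
    # seqs: [batch, max_len] (or [batch, max_len, MEM_TOKEN_SIZE]) of token ids
    if seqs.dim() == 3:
        mask = seqs.eq(pad_token).all(dim=2)  # a memory slot is padding if every token is PAD
    else:
        mask = seqs.eq(pad_token)
    lengths = (~mask).sum(dim=1)
    return mask, lengths
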
Example #8
    def flipByLength(self, input, lengths):
        # Reverse each sequence over its true length only; padded steps stay zero.
        output = _cuda(torch.zeros_like(input))
        for i, l in enumerate(lengths):
            output[i, :l, :] = torch.flip(input[i, :l, :], (0, ))
        return output
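
A small usage sketch with made-up inputs: each sequence is reversed over its true length only, so padded steps stay zero:

import torch

x = torch.arange(1, 13, dtype=torch.float).view(2, 3, 2)  # [batch=2, max_len=3, feat=2]
lengths = [3, 2]                                          # second sequence has one padded step

flipped = torch.zeros_like(x)
for i, l in enumerate(lengths):
    flipped[i, :l, :] = torch.flip(x[i, :l, :], (0, ))
# flipped[0] is x[0] reversed over all three steps;
# flipped[1] reverses only the first two steps and leaves the last step at zero.
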
Example #9
    def forward(self,
                extKnow,
                story_size,
                story_lengths,
                copy_list,
                encode_hidden,
                target_batches,
                max_target_length,
                batch_size,
                use_teacher_forcing,
                get_decoded_words,
                global_pointer,
                H=None,
                global_entity_type=None,
                domains=None):
        # Initialize variables for vocab and pointer
        all_decoder_outputs_vocab = _cuda(
            torch.zeros(max_target_length, batch_size, self.num_vocab))
        all_decoder_outputs_ptr = _cuda(
            torch.zeros(max_target_length, batch_size, story_size[1]))
        decoder_input = _cuda(self.domain_emb(domains.view(-1, ))) + self.C(
            _cuda(torch.LongTensor([SOS_token] * batch_size)))
        memory_mask_for_step = _cuda(torch.ones(story_size[0], story_size[1]))
        decoded_fine, decoded_coarse = [], []
        hidden = self.relu(self.projector(encode_hidden)).unsqueeze(0)
        hidden_locals = []
        for i in range(len(self.domains)):
            hidden_locals.append(hidden.clone())

        mask = _cuda(torch.ones((len(story_lengths), 1)))
        global_hiddens = []
        local_hiddens = []
        scores = []
        # Start to generate word-by-word
        for t in range(max_target_length):
            if t != 0:
                decoder_input = self.C(decoder_input)
            embed_q = self.dropout_layer(decoder_input)
            if len(embed_q.size()) == 2: embed_q = embed_q.unsqueeze(0)

            _, hidden = self.sketch_rnn_global(embed_q, hidden)
            hidden_locals_ = []
            for domain in self.domains.values():
                hidden_locals_.append(self.sketch_rnn_local[domain](
                    embed_q, hidden_locals[domain])[1])
            hidden_locals = hidden_locals_
            hidden_local, score = self.mix_attention(
                torch.stack(hidden_locals, dim=-1).transpose(0, 1), mask)
            hidden_local = hidden_local.transpose(0, 1)
            score = score.transpose(0, 1)
            scores.append(score)
            query_vector = self.MLP(
                torch.cat(
                    (F.dropout(hidden, self.dropout, self.training),
                     F.dropout(hidden_local, self.dropout, self.training)),
                    dim=-1))
            global_hiddens.append(hidden)
            local_hiddens.append(hidden_local)

            p_vocab, context = self.get_p_vocab(query_vector[0], H)

            all_decoder_outputs_vocab[t] = p_vocab
            _, topvi = p_vocab.data.topk(1)

            # query the external knowledge using the hidden state of the sketch RNN
            prob_soft, prob_logits = extKnow(context[0], global_pointer)
            all_decoder_outputs_ptr[t] = prob_logits

            if use_teacher_forcing:
                decoder_input = target_batches[:, t]
            else:
                decoder_input = topvi.squeeze()

            if get_decoded_words:

                search_len = min(5, min(story_lengths))
                prob_soft = prob_soft * memory_mask_for_step
                _, toppi = prob_soft.data.topk(search_len)
                temp_f, temp_c = [], []

                for bi in range(batch_size):
                    token = topvi[bi].item()
                    temp_c.append(self.lang.index2word[token])

                    if '@' in self.lang.index2word[token]:
                        gold_type = self.lang.index2word[token]
                        cw = 'UNK'
                        for i in range(search_len):
                            if toppi[:, i][bi] < story_lengths[bi] - 1:
                                cw = copy_list[bi][toppi[:, i][bi].item()]
                                break
                        temp_f.append(cw)

                        if args['record']:
                            memory_mask_for_step[bi, toppi[:, i][bi].item()] = 0
                    else:
                        temp_f.append(self.lang.index2word[token])

                decoded_fine.append(temp_f)
                decoded_coarse.append(temp_c)

        label = self.global_classifier(
            torch.cat(global_hiddens, dim=0).transpose(0, 1))
        scores = torch.cat(scores, dim=0).transpose(0, 1).contiguous()
        return all_decoder_outputs_vocab, all_decoder_outputs_ptr, decoded_fine, decoded_coarse, label, scores
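
mix_attention is also not defined in these snippets. A rough, hypothetical sketch of one way to softly mix the per-domain hidden states that are stacked on the last dimension (the mask argument is accepted but unused here):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MixAttention(nn.Module):
    """Hypothetical stand-in for mix_attention: score each domain-specific
    hidden state and return their softmax-weighted mixture plus the weights."""

    def __init__(self, hidden_size):
        super().__init__()
        self.scorer = nn.Linear(hidden_size, 1)

    def forward(self, stacked, mask=None):
        # stacked: [batch, 1, hidden, n_domains]
        scores = self.scorer(stacked.transpose(2, 3)).squeeze(-1)  # [batch, 1, n_domains]
        weights = F.softmax(scores, dim=-1)
        mixed = torch.einsum('bthd,btd->bth', stacked, weights)    # [batch, 1, hidden]
        return mixed, weights
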
Example #10
    def forward(self, story_size, ext_know, global_ptr, story_length,
                max_target_length, batch_size, encoded_hidden, evaluating,
                copy_list, use_teacher_forcing, response_target):
        record = _cuda(torch.ones(story_size[0],
                                  story_size[1]))  # [batch_size, story_length]
        # all_decoder_output_ptr holds the local pointer, defined over the
        # current dialogue (story).
        all_decoder_output_ptr = _cuda(
            torch.zeros(max_target_length, batch_size, story_size[1]))
        # all_decoder_output_vocab is defined over the vocabulary.
        all_decoder_output_vocab = _cuda(
            torch.zeros(max_target_length, batch_size, self.num_vocab))
        # One word is generated per step for every sample in the batch.
        decoder_input = _cuda(torch.LongTensor([SOS_token] * batch_size))
        # Project the concatenated encoding down (why the ReLU here?).
        hidden_init = self.relu(self.projector(encoded_hidden)).unsqueeze(0)
        decoded_fine, decoded_coarse = [], []

        # Generate the output word-by-word with the sketch RNN.
        for t in range(max_target_length):
            sketch_response = self.dropout_layer(
                self.C(decoder_input))  # [8] -> [1, 8, 128]
            # When batch_size == 1 the embedding can come back one-dimensional.
            if len(sketch_response.size()) == 1:
                sketch_response = sketch_response.unsqueeze(0)
            # Why is seq_len set to 1 here?
            _, hidden = self.sketch_rnn(
                sketch_response.unsqueeze(0),
                hidden_init)  # [seq_len, batch_size, embedding_dim]
            # hidden: [num_layers * num_directions, batch, embedding_dim];
            # I believe it contains the hidden state of every layer.
            query = hidden[0]
            # p_vocab: [batch_size, vocab_size]. The paper applies a softmax to
            # p_vocab, but the released code comments it out because it makes
            # the results noticeably worse.
            '''
            C has shape [vocab_size, embedding_dim]. The (unnormalized)
            attention scores are computed against the word-embedding matrix:
            the embedding layer holds a representation of every word in the
            vocabulary, whereas passing text through it only yields an
            embedding matrix whose size depends on the text length.
            '''
            p_vocab = hidden.squeeze(0).matmul(self.C.weight.transpose(
                1, 0))  # adding a softmax layer here degrades the results
            # p_vocab = self.attend_vocab(self.C.weight, hidden.squeeze(0))
            all_decoder_output_vocab[t] = p_vocab
            _, top_p_vocab = p_vocab.data.topk(1)

            # Query the external knowledge with the sketch RNN's hidden state
            # to get an attention distribution over the whole story (the local
            # pointer) and copy words from it.
            local_ptr, prob_soft = ext_know(query, global_ptr)
            all_decoder_output_ptr[t] = local_ptr

            if use_teacher_forcing:  # feeds in the ground-truth token; does this count as data leakage?
                decoder_input = response_target[:, t]
            else:
                # Use the prediction to keep updating sketch_response;
                # this is where the earlier bug was.
                decoder_input = top_p_vocab.squeeze()

            if evaluating:
                search_len = min(5, min(story_length))
                prob_soft = prob_soft * record
                _, top_p_soft = prob_soft.data.topk(search_len)

                tmp_f, tmp_c = [], []
                for bi in range(batch_size):
                    token = top_p_vocab[bi].item()
                    tmp_c.append(self.word_map.index2word[token])

                    # e.g. '@R_cuisine', '@R_location', '@R_number', '@R_price'
                    if '@' in self.word_map.index2word[token]:
                        cw = 'UNK'  # replaced by a copied value below if possible
                        for i in range(search_len):
                            # top_p_soft[i][bi] -> top_p_soft[:, i][bi]
                            if top_p_soft[bi][i] < story_length[bi] - 1:
                                cw = copy_list[bi][top_p_soft[bi][i].item()]
                                break
                        tmp_f.append(cw)  # this stays outside the inner loop
                        if args['record']:
                            # Zero out positions of copy_list already used.
                            record[bi][top_p_soft[bi][i].item()] = 0
                    else:
                        # Not one of the '@' sketch tags: record the word itself.
                        tmp_f.append(self.word_map.index2word[token])
                decoded_fine.append(tmp_f)
                decoded_coarse.append(tmp_c)
        return all_decoder_output_vocab, all_decoder_output_ptr, decoded_fine, decoded_coarse
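
For reference, the fine decoding step shared by these decoders can be read as a small standalone routine: when the sketch RNN emits a tag containing '@', the highest-scoring memory position inside the story's true length supplies the copied entity, otherwise 'UNK' is kept. A simplified sketch with made-up argument names:

def fill_sketch_tag(token_word, top_positions, copy_words, story_length, search_len=5):
    # token_word:    word decoded from p_vocab (e.g. '@R_cuisine' or a plain word)
    # top_positions: top-k memory positions for this batch element, best first
    # copy_words:    the story's surface words available for copying
    # story_length:  true (unpadded) length of the story
    if '@' not in token_word:
        return token_word                  # ordinary word: keep it as-is
    for i in range(search_len):
        pos = int(top_positions[i])
        if pos < story_length - 1:         # skip pointers into the padding / NULL slot
            return copy_words[pos]         # copy the pointed-to entity
    return 'UNK'                           # no valid pointer found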