    print(
        '[LOG TRAIN {}] epoch: {:04}/{:04}, discriminator loss: {:.4f}'.format(
            now, epoch + 1, num_epochs, epoch_discriminator_loss))
    print('[LOG TRAIN {}] epoch: {:04}/{:04}, generator loss: {:.4f}'.format(
        now, epoch + 1, num_epochs, epoch_generator_loss))

    # =================== save model snapshots to disk ============================

    # save trained encoder model file to disk
    now = datetime.utcnow().strftime("%Y%m%d-%H_%M_%S")
    encoder_model_name = "{}_ep_{}_encoder_model.pth".format(now, (epoch + 1))
    torch.save(encoder_train.state_dict(),
               os.path.join("./models", encoder_model_name))

    # save trained decoder model file to disk
    decoder_model_name = "{}_ep_{}_decoder_model.pth".format(now, (epoch + 1))
    torch.save(decoder_train.state_dict(),
               os.path.join("./models", decoder_model_name))

    # save trained discriminator model file to disk
    decoder_model_name = "{}_ep_{}_discriminator_model.pth".format(
        now, (epoch + 1))
    torch.save(discriminator_train.state_dict(),
               os.path.join("./models", decoder_model_name))

totalElapsedTime = time.time() - start
# save execution summary
exec_summary = "{}_ep_{}_model_exec_summary.txt".format(now, epoch + 1)
with open(os.path.join("./models", exec_summary), "w+") as f:
    f.write("training elapsed time: %f" % totalElapsedTime)
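To reuse a snapshot later, load the state dict back into a module with the same architecture. A minimal sketch; the stand-in network and filename below are illustrative only, since the real Encoder class is not shown here:

import os

import torch
import torch.nn as nn

# Stand-in for the saved Encoder; the module that loads a state dict must
# have exactly the same architecture (and state_dict keys) as the one saved.
encoder = nn.Sequential(nn.Linear(128, 64))

# The filename is illustrative; point this at one of the snapshots above.
state = torch.load(os.path.join("./models", "encoder_model.pth"),
                   map_location="cpu")
encoder.load_state_dict(state)
encoder.eval()  # disable dropout/batch-norm updates for inference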
Example #2
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        # word embedding layer
        self.embedding = Embedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding dimension
            config.pad_id,  # pad_id
            config.dropout)

        # affect embedding layer
        self.affect_embedding = Embedding(config.num_vocab,
                                          config.affect_embedding_size,
                                          config.pad_id, config.dropout)
        self.affect_embedding.embedding.weight.requires_grad = False

        # post encoder
        self.post_encoder = Encoder(
            config.post_encoder_cell_type,  # RNN cell type
            config.embedding_size + config.affect_embedding_size,  # input size
            config.post_encoder_output_size,  # output size
            config.post_encoder_num_layers,  # number of RNN layers
            config.post_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # response encoder
        self.response_encoder = Encoder(
            config.response_encoder_cell_type,
            config.embedding_size + config.affect_embedding_size,  # input size
            config.response_encoder_output_size,  # output size
            config.response_encoder_num_layers,  # number of RNN layers
            config.response_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # prior network
        self.prior_net = PriorNet(
            config.post_encoder_output_size,  # post input size
            config.latent_size,  # latent variable size
            config.dims_prior)  # hidden layer sizes

        # recognition network
        self.recognize_net = RecognizeNet(
            config.post_encoder_output_size,  # post input size
            config.response_encoder_output_size,  # response input size
            config.latent_size,  # latent variable size
            config.dims_recognize)  # hidden layer sizes

        # prepares the initial decoder state
        self.prepare_state = PrepareState(
            config.post_encoder_output_size + config.latent_size,
            config.decoder_cell_type, config.decoder_output_size,
            config.decoder_num_layers)

        # decoder
        self.decoder = Decoder(
            config.decoder_cell_type,  # RNN cell type
            config.embedding_size + config.affect_embedding_size +
            config.post_encoder_output_size,
            config.decoder_output_size,  # output size
            config.decoder_num_layers,  # number of RNN layers
            config.dropout)  # dropout probability

        # bag-of-words predictor
        self.bow_predictor = nn.Sequential(
            nn.Linear(config.post_encoder_output_size + config.latent_size,
                      config.num_vocab), nn.Softmax(-1))

        # output (projection) layer
        self.projector = nn.Sequential(
            nn.Linear(config.decoder_output_size, config.num_vocab),
            nn.Softmax(-1))

    def forward(self,
                inputs,
                inference=False,
                use_true=False,
                max_len=60,
                gpu=True):
        if not inference:  # training
            if use_true:  # teacher forcing: feed ground-truth tokens while decoding
                id_posts = inputs['posts']  # [batch, seq]
                len_posts = inputs['len_posts']  # [batch]
                id_responses = inputs['responses']  # [batch, seq]
                len_responses = inputs['len_responses']  # [batch]
                sampled_latents = inputs[
                    'sampled_latents']  # [batch, latent_size]
                len_decoder = id_responses.size(1) - 1

                embed_posts = torch.cat([
                    self.embedding(id_posts),
                    self.affect_embedding(id_posts)
                ], 2)
                embed_responses = torch.cat([
                    self.embedding(id_responses),
                    self.affect_embedding(id_responses)
                ], 2)

                # state: [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                                   len_posts)
                _, state_responses = self.response_encoder(
                    embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]

                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
                z = mu + (0.5 *
                          logvar).exp() * sampled_latents  # [batch, latent]

                bow_predict = self.bow_predictor(torch.cat(
                    [z, x], 1))  # [batch, num_vocab]

                first_state = self.prepare_state(torch.cat(
                    [z, x], 1))  # [num_layer, batch, dim_out]
                decoder_inputs = embed_responses[:, :-1, :].transpose(
                    0, 1)  # [seq-1, batch, embed_size]
                decoder_inputs = decoder_inputs.split(
                    [1] * len_decoder, 0)  # seq-1 tensors of [1, batch, embed_size]

                outputs = []
                for idx in range(len_decoder):
                    if idx == 0:
                        state = first_state  # initial decoder state
                    decoder_input = torch.cat(
                        [decoder_inputs[idx],
                         x.unsqueeze(0)], 2)
                    # output: [1, batch, dim_out]
                    # state: [num_layer, batch, dim_out]
                    output, state = self.decoder(decoder_input, state)
                    outputs.append(output)

                outputs = torch.cat(outputs,
                                    0).transpose(0,
                                                 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(
                    outputs)  # [batch, seq-1, num_vocab]

                return output_vocab, bow_predict, _mu, _logvar, mu, logvar
            else:
                id_posts = inputs['posts']  # [batch, seq]
                len_posts = inputs['len_posts']  # [batch]
                id_responses = inputs['responses']  # [batch, seq]
                len_responses = inputs['len_responses']  # [batch]
                sampled_latents = inputs[
                    'sampled_latents']  # [batch, latent_size]
                len_decoder = id_responses.size(1) - 1
                batch_size = id_posts.size(0)

                embed_posts = torch.cat([
                    self.embedding(id_posts),
                    self.affect_embedding(id_posts)
                ], 2)
                embed_responses = torch.cat([
                    self.embedding(id_responses),
                    self.affect_embedding(id_responses)
                ], 2)

                # state: [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                                   len_posts)
                _, state_responses = self.response_encoder(
                    embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]

                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
                z = mu + (0.5 *
                          logvar).exp() * sampled_latents  # [batch, latent]

                bow_predict = self.bow_predictor(torch.cat(
                    [z, x], 1))  # [batch, num_vocab]

                first_state = self.prepare_state(torch.cat(
                    [z, x], 1))  # [num_layer, batch, dim_out]
                first_input_id = (torch.ones(
                    (1, batch_size)) * self.config.start_id).long()
                if gpu:
                    first_input_id = first_input_id.cuda()

                outputs = []
                for idx in range(len_decoder):
                    if idx == 0:
                        state = first_state
                        decoder_input = torch.cat([
                            self.embedding(first_input_id),
                            self.affect_embedding(first_input_id),
                            x.unsqueeze(0)
                        ], 2)
                    else:
                        decoder_input = torch.cat([
                            self.embedding(next_input_id),
                            self.affect_embedding(next_input_id),
                            x.unsqueeze(0)
                        ], 2)
                    output, state = self.decoder(decoder_input, state)
                    outputs.append(output)

                    vocab_prob = self.projector(
                        output)  # [1, batch, num_vocab]
                    next_input_id = torch.argmax(
                        vocab_prob, 2)  # most probable word becomes the next input [1, batch]

                outputs = torch.cat(outputs,
                                    0).transpose(0,
                                                 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(
                    outputs)  # [batch, seq-1, num_vocab]
                return output_vocab, bow_predict, _mu, _logvar, mu, logvar
        else:  # inference
            id_posts = inputs['posts']  # [batch, seq]
            len_posts = inputs['len_posts']  # [batch]
            sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
            batch_size = id_posts.size(0)

            embed_posts = torch.cat(
                [self.embedding(id_posts),
                 self.affect_embedding(id_posts)], 2)

            # state: [layers, batch, dim]
            _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                               len_posts)
            if isinstance(state_posts, tuple):  # for an LSTM, take h
                state_posts = state_posts[0]  # [layers, batch, dim]
            x = state_posts[-1, :, :]  # take the last layer [batch, dim]

            _mu, _logvar = self.prior_net(x)  # [batch, latent]
            z = _mu + (0.5 *
                       _logvar).exp() * sampled_latents  # [batch, latent]

            first_state = self.prepare_state(torch.cat(
                [z, x], 1))  # [num_layer, batch, dim_out]
            done = torch.tensor([0] * batch_size).bool()
            first_input_id = (torch.ones(
                (1, batch_size)) * self.config.start_id).long()
            if gpu:
                done = done.cuda()
                first_input_id = first_input_id.cuda()

            outputs = []
            for idx in range(max_len):
                if idx == 0:  # first time step
                    state = first_state  # initial decoder state
                    decoder_input = torch.cat([
                        self.embedding(first_input_id),
                        self.affect_embedding(first_input_id),
                        x.unsqueeze(0)
                    ], 2)
                else:
                    decoder_input = torch.cat([
                        self.embedding(next_input_id),
                        self.affect_embedding(next_input_id),
                        x.unsqueeze(0)
                    ], 2)
                # output: [1, batch, dim_out]
                # state: [num_layers, batch, dim_out]
                output, state = self.decoder(decoder_input, state)
                outputs.append(output)

                vocab_prob = self.projector(output)  # [1, batch, num_vocab]
                next_input_id = torch.argmax(
                    vocab_prob, 2)  # most probable word becomes the next input [1, batch]
                _done = next_input_id.squeeze(
                    0) == self.config.end_id  # sequences that finished at this step [batch]
                done = done | _done  # all sequences finished so far
                if done.sum() == batch_size:  # stop early once every sequence is done
                    break

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]

            return output_vocab, _, _mu, _logvar, None, None

    def print_parameters(self):
        r""" Count the trainable parameters. """
        total_num = 0  # total number of parameters
        for param in self.parameters():
            if param.requires_grad:
                num = 1
                for dim in param.size():
                    num *= dim
                total_num += num
        print(f"Total parameters: {total_num}")

    def save_model(self, epoch, global_step, path):
        r""" 保存模型 """
        torch.save(
            {
                'affect_embedding': self.affect_embedding.state_dict(),
                'embedding': self.embedding.state_dict(),
                'post_encoder': self.post_encoder.state_dict(),
                'response_encoder': self.response_encoder.state_dict(),
                'prior_net': self.prior_net.state_dict(),
                'recognize_net': self.recognize_net.state_dict(),
                'prepare_state': self.prepare_state.state_dict(),
                'decoder': self.decoder.state_dict(),
                'projector': self.projector.state_dict(),
                'bow_predictor': self.bow_predictor.state_dict(),
                'epoch': epoch,
                'global_step': global_step
            }, path)

    def load_model(self, path):
        r""" 载入模型 """
        checkpoint = torch.load(path)
        self.affect_embedding.load_state_dict(checkpoint['affect_embedding'])
        self.embedding.load_state_dict(checkpoint['embedding'])
        self.post_encoder.load_state_dict(checkpoint['post_encoder'])
        self.response_encoder.load_state_dict(checkpoint['response_encoder'])
        self.prior_net.load_state_dict(checkpoint['prior_net'])
        self.recognize_net.load_state_dict(checkpoint['recognize_net'])
        self.prepare_state.load_state_dict(checkpoint['prepare_state'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.projector.load_state_dict(checkpoint['projector'])
        self.bow_predictor.load_state_dict(checkpoint['bow_predictor'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']
        return epoch, global_step
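The training loss for this model is not shown on this page, but forward already returns the prior parameters (_mu, _logvar) and the recognition parameters (mu, logvar), which a CVAE-style objective combines through a KL term. A minimal sketch of the closed-form KL divergence between the two diagonal Gaussians; this is standard CVAE math, not code from this repository:

import torch

def gaussian_kld(mu_q, logvar_q, mu_p, logvar_p):
    # KL( N(mu_q, var_q) || N(mu_p, var_p) ) for diagonal Gaussians,
    # summed over the latent dimension -> [batch]
    return 0.5 * torch.sum(
        logvar_p - logvar_q
        + (logvar_q.exp() + (mu_q - mu_p).pow(2)) / logvar_p.exp()
        - 1.0,
        dim=1)

# With the tensors returned by Model.forward (each [batch, latent]):
# kld = gaussian_kld(mu, logvar, _mu, _logvar).mean()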
Example #3
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        # word embedding layer
        self.embedding = Embedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding dimension
            config.pad_id,
            config.dropout)

        self.affect_embedding = Embedding(config.num_vocab,
                                          config.affect_embedding_size,
                                          config.pad_id, config.dropout)
        self.affect_embedding.embedding.weight.requires_grad = False

        # encoder
        self.encoder = Encoder(
            config.encoder_decoder_cell_type,  # RNN cell type
            config.embedding_size + config.affect_embedding_size,  # input size
            config.encoder_decoder_output_size,  # output size
            config.encoder_decoder_num_layers,  # number of RNN layers
            config.encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # decoder
        self.decoder = Decoder(
            config.encoder_decoder_cell_type,  # RNN cell type
            config.embedding_size + config.affect_embedding_size +
            config.encoder_decoder_output_size,
            config.encoder_decoder_output_size,  # output size
            config.encoder_decoder_num_layers,  # number of RNN layers
            config.dropout)  # dropout probability

        # output (projection) layer
        self.projector = nn.Sequential(
            nn.Linear(config.encoder_decoder_output_size, config.num_vocab),
            nn.Softmax(-1))

    def forward(self, inputs, inference=False, max_len=60, gpu=True):
        if not inference:  # training
            id_posts = inputs['posts']  # [batch, seq]
            len_posts = inputs['len_posts']  # [batch]
            id_responses = inputs['responses']  # [batch, seq]
            len_decoder = id_responses.size(1) - 1

            embed_posts = torch.cat(
                [self.embedding(id_posts),
                 self.affect_embedding(id_posts)], 2)
            embed_responses = torch.cat([
                self.embedding(id_responses),
                self.affect_embedding(id_responses)
            ], 2)

            # encoder_output: [seq, batch, dim]
            # state: [layers, batch, dim]
            _, state_encoder = self.encoder(embed_posts.transpose(0, 1),
                                            len_posts)
            if isinstance(state_encoder, tuple):
                context = state_encoder[0][-1, :, :]
            else:
                context = state_encoder[-1, :, :]
            context = context.unsqueeze(0)  # [1, batch, dim]

            # the decoder input is the response with the final end_id removed
            decoder_inputs = embed_responses[:, :-1, :].transpose(
                0, 1)  # [seq-1, batch, embed_size]
            decoder_inputs = decoder_inputs.split([1] * len_decoder, 0)

            outputs = []
            for idx in range(len_decoder):
                if idx == 0:
                    state = state_encoder  # initial decoder state
                decoder_input = torch.cat([decoder_inputs[idx], context], 2)
                # output: [1, batch, dim_out]
                # state: [num_layer, batch, dim_out]
                output, state = self.decoder(decoder_input, state)
                outputs.append(output)

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq-1, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]

            return output_vocab
        else:  # inference
            id_posts = inputs['posts']  # [batch, seq]
            len_posts = inputs['len_posts']  # [batch]
            batch_size = id_posts.size(0)

            embed_posts = torch.cat(
                [self.embedding(id_posts),
                 self.affect_embedding(id_posts)], 2)

            # state: [layers, batch, dim]
            _, state_encoder = self.encoder(embed_posts.transpose(0, 1),
                                            len_posts)
            if isinstance(state_encoder, tuple):
                context = state_encoder[0][-1, :, :]
            else:
                context = state_encoder[-1, :, :]
            context = context.unsqueeze(0)  # [1, batch, dim]

            done = torch.tensor([0] * batch_size).bool()
            first_input_id = (torch.ones(
                (1, batch_size)) * self.config.start_id).long()
            if gpu:
                done = done.cuda()
                first_input_id = first_input_id.cuda()

            outputs = []
            for idx in range(max_len):
                if idx == 0:  # first time step
                    state = state_encoder  # initial decoder state
                    decoder_input = torch.cat([
                        self.embedding(first_input_id),
                        self.affect_embedding(first_input_id), context
                    ], 2)
                else:
                    decoder_input = torch.cat([
                        self.embedding(next_input_id),
                        self.affect_embedding(next_input_id), context
                    ], 2)

                # output: [1, batch, dim_out]
                # state: [num_layers, batch, dim_out]
                output, state = self.decoder(decoder_input, state)
                outputs.append(output)

                vocab_prob = self.projector(output)  # [1, batch, num_vocab]
                next_input_id = torch.argmax(
                    vocab_prob, 2)  # most probable word becomes the next input [1, batch]
                _done = next_input_id.squeeze(
                    0) == self.config.end_id  # sequences that finished at this step [batch]
                done = done | _done  # all sequences finished so far
                if done.sum() == batch_size:  # stop early once every sequence is done
                    break

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]

            return output_vocab

    # count parameters
    def print_parameters(self):
        r""" Count the trainable parameters. """
        total_num = 0  # total number of parameters
        for param in self.parameters():
            if param.requires_grad:
                num = 1
                for dim in param.size():
                    num *= dim
                total_num += num
        print(f"Total parameters: {total_num}")

    def save_model(self, epoch, global_step, path):
        r""" 保存模型 """
        torch.save(
            {
                'embedding': self.embedding.state_dict(),
                'affect_embedding': self.affect_embedding.state_dict(),
                'encoder': self.encoder.state_dict(),
                'decoder': self.decoder.state_dict(),
                'projector': self.projector.state_dict(),
                'epoch': epoch,
                'global_step': global_step
            }, path)

    def load_model(self, path):
        r""" 载入模型 """
        checkpoint = torch.load(path)
        self.embedding.load_state_dict(checkpoint['embedding'])
        self.affect_embedding.load_state_dict(checkpoint['affect_embedding'])
        self.encoder.load_state_dict(checkpoint['encoder'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.projector.load_state_dict(checkpoint['projector'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']
        return epoch, global_step
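A hypothetical way to instantiate and exercise this seq2seq variant. The config field names follow what __init__ reads, but every value is illustrative, and the Embedding, Encoder, and Decoder modules are assumed importable from the same project:

from argparse import Namespace

import torch

# Illustrative configuration only; real values come from the project's config.
config = Namespace(
    num_vocab=10000, embedding_size=300, affect_embedding_size=3,
    pad_id=0, start_id=1, end_id=2, dropout=0.1,
    encoder_decoder_cell_type='GRU', encoder_decoder_output_size=300,
    encoder_decoder_num_layers=2, encoder_bidirectional=True)

model = Model(config)
model.print_parameters()

# Greedy decoding on CPU for a single dummy post of 5 tokens.
inputs = {'posts': torch.randint(3, 10000, (1, 5)),
          'len_posts': torch.tensor([5])}
with torch.no_grad():
    output_vocab = model(inputs, inference=True, max_len=20, gpu=False)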
Example #4
class Mem2SeqRunner(ExperimentRunnerBase):
    def __init__(self, args):
        super(Mem2SeqRunner, self).__init__(args)

        # Model parameters
        self.gru_size = 128
        self.emb_size = 128
        # TODO: Try hops 4 with task 3
        self.hops = 3
        self.dropout = 0.2

        self.encoder = Encoder(self.hops, self.nwords, self.gru_size)
        self.decoder = Decoder(self.emb_size, self.hops, self.gru_size,
                               self.nwords)

        self.optim_enc = torch.optim.Adam(self.encoder.parameters(), lr=0.001)
        self.optim_dec = torch.optim.Adam(self.decoder.parameters(), lr=0.001)
        if self.loss_weighting:
            self.optim_loss_weights = torch.optim.Adam([self.loss_weights],
                                                       lr=0.0001)
        self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optim_dec,
                                                        mode='max',
                                                        factor=0.5,
                                                        patience=1,
                                                        min_lr=0.0001,
                                                        verbose=True)

        if self.use_cuda:
            self.cross_entropy = self.cross_entropy.cuda()
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            if self.loss_weighting:
                self.loss_weights = self.loss_weights.cuda()

    def train_batch_wrapper(self, batch, new_epoch, clip_grads):
        context = batch[0].transpose(0, 1)
        responses = batch[1].transpose(0, 1)
        index = batch[2].transpose(0, 1)
        sentinel = batch[3].transpose(0, 1)
        context_lengths = batch[4]
        target_lengths = batch[5]
        return self.train_batch(context, responses, index, sentinel, new_epoch,
                                context_lengths, target_lengths, clip_grads)

    def train_batch(self, context, responses, index, sentinel, new_epoch,
                    context_lengths, target_lengths, clip_grads):

        # (TODO): remove transpose
        if new_epoch:  # (TODO): Change this part
            self.loss = 0
            self.ploss = 0
            self.vloss = 0
            self.n = 1

        context = context.type(self.TYPE)
        responses = responses.type(self.TYPE)
        index = index.type(self.TYPE)
        sentinel = sentinel.type(self.TYPE)

        self.optim_enc.zero_grad()
        self.optim_dec.zero_grad()
        if self.loss_weighting:
            self.optim_loss_weights.zero_grad()

        h = self.encoder(context.transpose(0, 1))
        self.decoder.load_memory(context.transpose(0, 1))
        y = torch.from_numpy(np.array([2] * context.size(1),
                                      dtype=int)).type(self.TYPE)
        y_len = 0

        h = h.unsqueeze(0)
        output_vocab = torch.zeros(max(target_lengths), context.size(1),
                                   self.nwords)
        output_ptr = torch.zeros(max(target_lengths), context.size(1),
                                 context.size(0))
        if self.use_cuda:
            output_vocab = output_vocab.cuda()
            output_ptr = output_ptr.cuda()
        while y_len < responses.size(0):  # TODO: Add EOS condition
            p_ptr, p_vocab, h = self.decoder(context, y, h)
            output_vocab[y_len] = p_vocab
            output_ptr[y_len] = p_ptr
            # TODO: Add teacher forcing ratio
            y = responses[y_len].type(self.TYPE)
            y_len += 1

        # print(loss)
        mask_v = torch.ones(output_vocab.size())
        mask_p = torch.ones(output_ptr.size())
        if self.use_cuda:
            mask_p = mask_p.cuda()
            mask_v = mask_v.cuda()
        for i in range(responses.size(1)):
            mask_v[target_lengths[i]:, i, :] = 0
            mask_p[target_lengths[i]:, i, :] = 0

        loss_v = self.cross_entropy(
            output_vocab.contiguous().view(-1, self.nwords),
            responses.contiguous().view(-1))

        loss_ptr = self.cross_entropy(
            output_ptr.contiguous().view(-1, context.size(0)),
            index.contiguous().view(-1))
        if self.loss_weighting:
            loss = (loss_ptr /
                    (2 * self.loss_weights[0] * self.loss_weights[0]) +
                    loss_v /
                    (2 * self.loss_weights[1] * self.loss_weights[1]) +
                    torch.log(self.loss_weights[0] * self.loss_weights[1]))
            loss_ptr = loss_ptr / (2 * self.loss_weights[0] *
                                   self.loss_weights[0])
            loss_v = loss_v / (2 * self.loss_weights[1] * self.loss_weights[1])
        else:
            loss = loss_ptr + loss_v
        loss.backward()
        ec = torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), 10.0)
        dc = torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), 10.0)
        self.optim_enc.step()
        self.optim_dec.step()
        if self.loss_weighting:
            self.optim_loss_weights.step()

        self.loss += loss.item()
        self.vloss += loss_v.item()
        self.ploss += loss_ptr.item()

        return loss.item(), loss_v.item(), loss_ptr.item()

    def evaluate_batch(self,
                       batch_size,
                       input_batches,
                       input_lengths,
                       target_batches,
                       target_lengths,
                       target_index,
                       target_gate,
                       src_plain,
                       profile_memory=None):

        # Set to not-training mode to disable dropout
        self.encoder.train(False)
        self.decoder.train(False)
        # Run words through encoder
        decoder_hidden = self.encoder(input_batches.transpose(0,
                                                              1)).unsqueeze(0)
        self.decoder.load_memory(input_batches.transpose(0, 1))

        # Prepare input and output variables
        decoder_input = Variable(torch.LongTensor([2] * batch_size))

        decoded_words = []
        all_decoder_outputs_vocab = Variable(
            torch.zeros(max(target_lengths), batch_size, self.nwords))
        all_decoder_outputs_ptr = Variable(
            torch.zeros(max(target_lengths), batch_size,
                        input_batches.size(0)))
        # all_decoder_outputs_gate = Variable(torch.zeros(self.max_r, batch_size))
        # Move new Variables to CUDA

        if self.use_cuda:
            all_decoder_outputs_vocab = all_decoder_outputs_vocab.cuda()
            all_decoder_outputs_ptr = all_decoder_outputs_ptr.cuda()
            # all_decoder_outputs_gate = all_decoder_outputs_gate.cuda()
            decoder_input = decoder_input.cuda()

        p = []
        for elm in src_plain:
            elm_temp = [word_triple[0] for word_triple in elm]
            p.append(elm_temp)

        self.from_whichs = []
        acc_gate, acc_ptr, acc_vac = 0.0, 0.0, 0.0
        # Run through decoder one time step at a time
        for t in range(max(target_lengths)):
            decoder_ptr, decoder_vocab, decoder_hidden = self.decoder(
                input_batches, decoder_input, decoder_hidden)
            all_decoder_outputs_vocab[t] = decoder_vocab
            topv, topvi = decoder_vocab.data.topk(1)
            all_decoder_outputs_ptr[t] = decoder_ptr
            topp, toppi = decoder_ptr.data.topk(1)
            top_ptr_i = torch.gather(input_batches[:, :, 0], 0,
                                     Variable(toppi.view(1,
                                                         -1))).transpose(0, 1)
            next_in = [
                top_ptr_i[i].item() if
                (toppi[i].item() < input_lengths[i] - 1) else topvi[i].item()
                for i in range(batch_size)
            ]
            # if next_in in self.kb_entry.keys():
            #     ptr_distr.append([next_in, decoder_vocab.data])

            decoder_input = Variable(
                torch.LongTensor(next_in))  # Chosen word is next input
            if self.use_cuda: decoder_input = decoder_input.cuda()

            temp = []
            from_which = []
            for i in range(batch_size):
                if (toppi[i].item() < len(p[i]) - 1):
                    temp.append(p[i][toppi[i].item()])
                    from_which.append('p')
                else:
                    if target_index[t][i] != toppi[i].item():
                        self.incorrect_sentinel += 1
                    ind = topvi[i].item()
                    if ind == 3:
                        temp.append('<eos>')
                    else:
                        temp.append(self.i2w[ind])
                    from_which.append('v')
            decoded_words.append(temp)
            self.from_whichs.append(from_which)
        self.from_whichs = np.array(self.from_whichs)

        loss_v = self.cross_entropy(
            all_decoder_outputs_vocab.contiguous().view(-1, self.nwords),
            target_batches.contiguous().view(-1))
        loss_ptr = self.cross_entropy(
            all_decoder_outputs_ptr.contiguous().view(-1,
                                                      input_batches.size(0)),
            target_index.contiguous().view(-1))

        if self.loss_weighting:
            loss = (loss_ptr /
                    (2 * self.loss_weights[0] * self.loss_weights[0]) +
                    loss_v /
                    (2 * self.loss_weights[1] * self.loss_weights[1]) +
                    torch.log(self.loss_weights[0] * self.loss_weights[1]))
        else:
            loss = loss_ptr + loss_v

        self.loss += loss.item()
        self.vloss += loss_v.item()
        self.ploss += loss_ptr.item()
        self.n += 1

        # Set back to training mode
        self.encoder.train(True)
        self.decoder.train(True)
        return decoded_words, self.from_whichs  # , acc_ptr, acc_vac

    def save_models(self, path):
        torch.save(self.encoder.state_dict(),
                   os.path.join(path, 'encoder.pth'))
        torch.save(self.decoder.state_dict(),
                   os.path.join(path, 'decoder.pth'))

    def load_models(self, path: str = '.'):
        self.encoder.load_state_dict(
            torch.load(os.path.join(path, 'encoder.pth')))
        self.decoder.load_state_dict(
            torch.load(os.path.join(path, 'decoder.pth')))
Example #5
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()

        self.config = config

        # affect embedding layer
        self.affect_embedding = AffectEmbedding(config.num_vocab,
                                                config.affect_embedding_size,
                                                config.pad_id)

        # word embedding layer
        self.embedding = WordEmbedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding dimension
            config.pad_id)  # pad_id

        # post encoder
        self.post_encoder = Encoder(
            config.post_encoder_cell_type,  # RNN cell type
            config.embedding_size,  # input size
            config.post_encoder_output_size,  # output size
            config.post_encoder_num_layers,  # number of RNN layers
            config.post_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # response encoder
        self.response_encoder = Encoder(
            config.response_encoder_cell_type,
            config.embedding_size,  # input size
            config.response_encoder_output_size,  # output size
            config.response_encoder_num_layers,  # number of RNN layers
            config.response_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # prior network
        self.prior_net = PriorNet(
            config.post_encoder_output_size,  # post input size
            config.latent_size,  # latent variable size
            config.dims_prior)  # hidden layer sizes

        # recognition network
        self.recognize_net = RecognizeNet(
            config.post_encoder_output_size,  # post input size
            config.response_encoder_output_size,  # response input size
            config.latent_size,  # latent variable size
            config.dims_recognize)  # hidden layer sizes

        # prepares the initial decoder state
        self.prepare_state = PrepareState(
            config.post_encoder_output_size + config.latent_size,
            config.decoder_cell_type, config.decoder_output_size,
            config.decoder_num_layers)

        # decoder
        self.decoder = Decoder(
            config.decoder_cell_type,  # RNN cell type
            config.embedding_size,  # input size
            config.decoder_output_size,  # output size
            config.decoder_num_layers,  # number of RNN layers
            config.dropout)  # dropout probability

        # output (projection) layer
        self.projector = nn.Sequential(
            nn.Linear(config.decoder_output_size, config.num_vocab),
            nn.Softmax(-1))

    def forward(
            self,
            input,
            inference=False,  # inference mode?
            use_true=False,
            max_len=60):  # maximum decoding length

        if not inference:  # training
            if use_true:
                id_posts = input['posts']  # [batch, seq]
                len_posts = input['len_posts']  # [batch]
                id_responses = input['responses']  # [batch, seq]
                len_responses = input['len_responses']  # [batch]
                sampled_latents = input[
                    'sampled_latents']  # [batch, latent_size]

                embed_posts = self.embedding(
                    id_posts)  # [batch, seq, embed_size]
                embed_responses = self.embedding(
                    id_responses)  # [batch, seq, embed_size]

                # the decoder input is the response with the final end_id removed
                decoder_input = embed_responses[:, :-1, :].transpose(
                    0, 1)  # [seq-1, batch, embed_size]
                len_decoder = decoder_input.size()[0]  # decoding length: seq-1
                decoder_input = decoder_input.split(
                    [1] * len_decoder,
                    0)  # per-step inputs: seq-1 tensors of [1, batch, embed_size]

                # state = [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                                   len_posts)
                _, state_responses = self.response_encoder(
                    embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]

                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
                z = mu + (0.5 *
                          logvar).exp() * sampled_latents  # [batch, latent]

                first_state = self.prepare_state(torch.cat(
                    [z, x], 1))  # [num_layer, batch, dim_out]
                outputs = []

                for idx in range(len_decoder):
                    if idx == 0:
                        state = first_state  # initial decoder state
                    input = decoder_input[
                        idx]  # input at this time step [1, batch, embed_size]
                    # output: [1, batch, dim_out]
                    # state: [num_layer, batch, dim_out]
                    output, state = self.decoder(input, state)
                    outputs.append(output)

                outputs = torch.cat(outputs,
                                    0).transpose(0,
                                                 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(
                    outputs)  # [batch, seq-1, num_vocab]

                return output_vocab, _mu, _logvar, mu, logvar

            else:
                id_posts = input['posts']  # [batch, seq]
                len_posts = input['len_posts']  # [batch]
                id_responses = input['responses']  # [batch, seq]
                len_responses = input['len_responses']  # [batch]
                sampled_latents = input[
                    'sampled_latents']  # [batch, latent_size]
                len_decoder = id_responses.size()[1] - 1
                batch_size = id_posts.size()[0]
                device = id_posts.device.type

                embed_posts = self.embedding(
                    id_posts)  # [batch, seq, embed_size]
                embed_responses = self.embedding(
                    id_responses)  # [batch, seq, embed_size]

                # state = [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                                   len_posts)
                _, state_responses = self.response_encoder(
                    embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]

                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
                z = mu + (0.5 *
                          logvar).exp() * sampled_latents  # [batch, latent]

                first_state = self.prepare_state(torch.cat(
                    [z, x], 1))  # [num_layer, batch, dim_out]
                first_input_id = (torch.ones(
                    (1, batch_size)) * self.config.start_id).long()
                if device == 'cuda':
                    first_input_id = first_input_id.cuda()
                outputs = []

                for idx in range(len_decoder):
                    if idx == 0:
                        state = first_state
                        input = self.embedding(first_input_id)
                    else:
                        input = self.embedding(
                            next_input_id)  # input at this time step [1, batch, embed_size]
                    output, state = self.decoder(input, state)
                    outputs.append(output)

                    vocab_prob = self.projector(
                        output)  # [1, batch, num_vocab]
                    next_input_id = torch.argmax(
                        vocab_prob, 2)  # most probable word becomes the next input [1, batch]

                outputs = torch.cat(outputs,
                                    0).transpose(0,
                                                 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(
                    outputs)  # [batch, seq-1, num_vocab]

                return output_vocab, _mu, _logvar, mu, logvar

        else:  # inference

            id_posts = input['posts']  # [batch, seq]
            len_posts = input['len_posts']  # [batch]
            sampled_latents = input['sampled_latents']  # [batch, latent_size]
            batch_size = id_posts.size()[0]
            device = id_posts.device.type

            embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]

            # state = [layers, batch, dim]
            _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                               len_posts)
            if isinstance(state_posts, tuple):  # for an LSTM, take h
                state_posts = state_posts[0]  # [layers, batch, dim]
            x = state_posts[-1, :, :]  # take the last layer [batch, dim]

            _mu, _logvar = self.prior_net(x)  # [batch, latent]
            z = _mu + (0.5 *
                       _logvar).exp() * sampled_latents  # [batch, latent]

            first_state = self.prepare_state(torch.cat(
                [z, x], 1))  # [num_layer, batch, dim_out]
            outputs = []

            done = torch.BoolTensor([0] * batch_size)
            first_input_id = (torch.ones(
                (1, batch_size)) * self.config.start_id).long()
            if device == 'cuda':
                done = done.cuda()
                first_input_id = first_input_id.cuda()

            for idx in range(max_len):
                if idx == 0:  # first time step
                    state = first_state  # initial decoder state
                    input = self.embedding(
                        first_input_id)  # initial decoder input [1, batch, embed_size]

                # output: [1, batch, dim_out]
                # state: [num_layers, batch, dim_out]
                output, state = self.decoder(input, state)
                outputs.append(output)

                vocab_prob = self.projector(output)  # [1, batch, num_vocab]
                next_input_id = torch.argmax(
                    vocab_prob, 2)  # most probable word becomes the next input [1, batch]
                _done = next_input_id.squeeze(
                    0) == self.config.end_id  # sequences that finished at this step [batch]
                done = done | _done  # all sequences finished so far
                if done.sum() == batch_size:  # stop early once every sequence is done
                    break
                else:
                    input = self.embedding(
                        next_input_id)  # [1, batch, embed_size]

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]

            return output_vocab, _mu, _logvar, None, None

    # count parameters
    def print_parameters(self):
        def statistic_param(params):
            total_num = 0  # total number of parameters
            for param in params:
                if param.requires_grad:
                    num = 1
                    for dim in param.size():
                        num *= dim
                    total_num += num
            return total_num

        print("嵌入层参数个数: %d" % statistic_param(self.embedding.parameters()))
        print("post编码器参数个数: %d" %
              statistic_param(self.post_encoder.parameters()))
        print("response编码器参数个数: %d" %
              statistic_param(self.response_encoder.parameters()))
        print("先验网络参数个数: %d" % statistic_param(self.prior_net.parameters()))
        print("识别网络参数个数: %d" %
              statistic_param(self.recognize_net.parameters()))
        print("解码器初始状态参数个数: %d" %
              statistic_param(self.prepare_state.parameters()))
        print("解码器参数个数: %d" % statistic_param(self.decoder.parameters()))
        print("输出层参数个数: %d" % statistic_param(self.projector.parameters()))
        print("参数总数: %d" % statistic_param(self.parameters()))

    # save the model
    def save_model(self, epoch, global_step, path):

        torch.save(
            {
                'affect_embedding': self.affect_embedding.state_dict(),
                'embedding': self.embedding.state_dict(),
                'post_encoder': self.post_encoder.state_dict(),
                'response_encoder': self.response_encoder.state_dict(),
                'prior_net': self.prior_net.state_dict(),
                'recognize_net': self.recognize_net.state_dict(),
                'prepare_state': self.prepare_state.state_dict(),
                'decoder': self.decoder.state_dict(),
                'projector': self.projector.state_dict(),
                'epoch': epoch,
                'global_step': global_step
            }, path)

    # load the model
    def load_model(self, path):

        checkpoint = torch.load(path)
        self.affect_embedding.load_state_dict(checkpoint['affect_embedding'])
        self.embedding.load_state_dict(checkpoint['embedding'])
        self.post_encoder.load_state_dict(checkpoint['post_encoder'])
        self.response_encoder.load_state_dict(checkpoint['response_encoder'])
        self.prior_net.load_state_dict(checkpoint['prior_net'])
        self.recognize_net.load_state_dict(checkpoint['recognize_net'])
        self.prepare_state.load_state_dict(checkpoint['prepare_state'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.projector.load_state_dict(checkpoint['projector'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']

        return epoch, global_step
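forward expects input['sampled_latents'] to be a standard-normal draw, so that z = mu + exp(0.5 * logvar) * eps implements the reparameterization trick. A minimal sketch of a batch a data pipeline might feed in; every shape and value below is illustrative:

import torch

batch_size, seq_len, latent_size = 32, 20, 64  # illustrative sizes

inputs = {
    'posts': torch.randint(4, 10000, (batch_size, seq_len)),
    'len_posts': torch.full((batch_size,), seq_len, dtype=torch.long),
    'responses': torch.randint(4, 10000, (batch_size, seq_len + 2)),
    'len_responses': torch.full((batch_size,), seq_len + 2, dtype=torch.long),
    # eps ~ N(0, I): z = mu + std * eps is then a sample from N(mu, std^2)
    'sampled_latents': torch.randn(batch_size, latent_size),
}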
Example #6
        print('Epoch: %s' % i)
        gan.train()
        discriminator.train()
        start = time.time()
        for data in train_loader:
            t += 1
            discriminator_optimizer.zero_grad()
            loss = get_disc_loss(data, gan, discriminator, data[0].shape[0],
                                 args.z_size, args.use_penalty)
            # backpropagate before stepping (assumes get_disc_loss returns the
            # loss tensor without calling backward() itself)
            loss.backward()
            discriminator_optimizer.step()
            if args.use_weight_clip:
                discriminator.apply(weight_clip)
            if t == args.n_dis:
                t = 0
                generator_optimizer.zero_grad()
                loss = get_gen_loss(gan, discriminator, data[0].shape[0],
                                    args.z_size)
                loss.backward()  # same assumption for get_gen_loss
                generator_optimizer.step()

        end = time.time()
        print('Epoch time: %s' % (end - start))
        gan.eval()
        discriminator.eval()

        data, labels = get_data(test_loader)
        vizualize(data, gan, 0, args.z_size, viz, args.save_path,
                  args.batch_size)
        if i % 10 == 0:
            torch.save(gan.state_dict(),
                       os.path.join(args.model_path, 'model[%s].ph' % i))
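weight_clip is applied with Module.apply but is not defined in this snippet; for a WGAN-style critic it is typically hard clipping of every learnable tensor. A plausible sketch, with the threshold value as an assumption:

import torch.nn as nn

CLIP_VALUE = 0.01  # assumed threshold; the WGAN paper's common default

def weight_clip(module: nn.Module):
    # Module.apply() calls this once per submodule; clamp its own tensors.
    for param in module.parameters(recurse=False):
        param.data.clamp_(-CLIP_VALUE, CLIP_VALUE)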
Example #7
            real_center = torch.FloatTensor(batch_size, 1, opt.crop_point_num,
                                            3)
            input_cropped1 = torch.FloatTensor(batch_size, opt.pnum, 3)
            input_cropped1 = input_cropped1.data.copy_(real_point)

            real_point = torch.unsqueeze(real_point, 1)
            input_cropped1 = torch.unsqueeze(input_cropped1, 1)

            input_cropped1 = torch.squeeze(input_cropped1, 1)

            input_cropped2 = input_cropped2.to(device)
            fake_center1, fake_fine = gen_net(input_cropped1)
            CD_loss = criterion_PointLoss(torch.squeeze(fake_fine, 1),
                                          torch.squeeze(real_center, 1))
            print('test CD loss: %.4f' % (CD_loss))
            f.write('\n' + 'test result:  %.4f' % (CD_loss))
            break
        f.close()
        schedulerD.step()
        schedulerG.step()
        if epoch % 5 == 0:
            torch.save({
                'epoch': epoch + 1,
                'state_dict': gen_net.state_dict()
            }, 'Trained_Model/gen_net' + str(epoch) + '.pth')
            torch.save({
                'epoch': epoch + 1,
                'state_dict': dis_net.state_dict()
            }, 'Trained_Model/dis_net' + str(epoch) + '.pth')

print('done')
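criterion_PointLoss is not defined in this fragment; for point-cloud completion it is usually a Chamfer distance between the predicted and ground-truth point sets. A minimal dense-pairwise sketch of that assumption (exact, but memory-hungry for large clouds):

import torch

def chamfer_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Symmetric Chamfer distance for point sets a: [batch, n_a, 3],
    b: [batch, n_b, 3]; returns one value per batch element."""
    diff = a.unsqueeze(2) - b.unsqueeze(1)  # [batch, n_a, n_b, 3]
    dist = (diff ** 2).sum(-1)              # pairwise squared distances
    return dist.min(2).values.mean(1) + dist.min(1).values.mean(1)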
Example #8
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()

        self.config = config

        # affect embedding layer
        self.affect_embedding = AffectEmbedding(config.num_vocab,
                                                config.affect_embedding_size,
                                                config.pad_id)

        # word embedding layer
        self.embedding = WordEmbedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding dimension
            config.pad_id)  # pad_id

        # affect encoder
        self.affect_encoder = Encoder(
            config.encoder_decoder_cell_type,  # RNN cell type
            config.affect_embedding_size,  # input size
            config.affect_encoder_output_size,  # output size
            config.encoder_decoder_num_layers,  # number of layers
            config.encoder_bidirectional,  # bidirectional or not
            config.dropout)

        # encoder
        self.encoder = Encoder(
            config.encoder_decoder_cell_type,  # RNN cell type
            config.embedding_size,  # input size
            config.encoder_decoder_output_size,  # output size
            config.encoder_decoder_num_layers,  # number of RNN layers
            config.encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        self.attention = Attention(config.encoder_decoder_output_size,
                                   config.affect_encoder_output_size,
                                   config.attention_type,
                                   config.attention_size)

        self.prepare_state = PrepareState(
            config.encoder_decoder_cell_type,
            config.encoder_decoder_output_size +
            config.affect_encoder_output_size,
            config.encoder_decoder_output_size)

        self.linear_prepare_input = nn.Linear(
            config.embedding_size + config.affect_encoder_output_size +
            config.attention_size, config.decoder_input_size)

        # decoder
        self.decoder = Decoder(
            config.encoder_decoder_cell_type,  # RNN cell type
            config.decoder_input_size,  # input size
            config.encoder_decoder_output_size,  # output size
            config.encoder_decoder_num_layers,  # number of RNN layers
            config.dropout)  # dropout probability

        # output (projection) layer
        self.projector = nn.Sequential(
            nn.Linear(
                config.encoder_decoder_output_size + config.attention_size,
                config.num_vocab), nn.Softmax(-1))

    def forward(
            self,
            input,
            inference=False,  # inference mode?
            max_len=60):  # maximum decoding length

        if not inference:  # training

            id_posts = input['posts']  # [batch, seq]
            len_posts = input['len_posts']  # [batch]
            id_responses = input['responses']  # [batch, seq]
            batch_size = id_posts.size()[0]
            device = id_posts.device.type

            embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
            embed_responses = self.embedding(
                id_responses)  # [batch, seq, embed_size]
            affect_posts = self.affect_embedding(id_posts)

            # the decoder input is the response with the final end_id removed
            decoder_input = embed_responses[:, :-1, :].transpose(
                0, 1)  # [seq-1, batch, embed_size]
            len_decoder = decoder_input.size()[0]  # decoding length: seq-1
            decoder_input = decoder_input.split(
                [1] * len_decoder, 0)  # per-step inputs: seq-1 tensors of [1, batch, embed_size]

            # state: [layers, batch, dim]
            _, state_encoder = self.encoder(embed_posts.transpose(0, 1),
                                            len_posts)
            # output_affect: [seq, batch, dim]
            output_affect, state_affect_encoder = self.affect_encoder(
                affect_posts.transpose(0, 1), len_posts)
            context_affect = output_affect[-1, :, :].unsqueeze(
                0)  # [1, batch, dim]

            outputs = []
            weights = []

            init_attn = torch.zeros(
                [1, batch_size, self.config.attention_size])
            if device == 'cuda':
                init_attn = init_attn.cuda()

            for idx in range(len_decoder):
                if idx == 0:
                    state = self.prepare_state(
                        state_encoder, state_affect_encoder)  # initial decoder state
                    input = torch.cat(
                        [decoder_input[idx], context_affect, init_attn], 2)
                else:
                    input = torch.cat([
                        decoder_input[idx], context_affect,
                        attn.transpose(0, 1)
                    ], 2)

                input = self.linear_prepare_input(input)

                # output: [1, batch, dim_out]
                # state: [num_layer, batch, dim_out]
                output, state = self.decoder(input, state)
                # attn: [batch, 1, attention_size]
                # weights: [batch, 1, encoder_len]
                attn, weight = self.attention(output.transpose(0, 1),
                                              output_affect.transpose(0, 1))
                outputs.append(torch.cat([output, attn.transpose(0, 1)], 2))
                weights.append(weight)

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq-1, dim_out]
            weights = torch.cat(weights, 1)  # [batch, seq-1, encoder_len]

            output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]

            return output_vocab, weights

        else:  # inference

            id_posts = input['posts']  # [batch, seq]
            len_posts = input['len_posts']  # [batch]
            batch_size = id_posts.size()[0]
            device = id_posts.device.type

            embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
            affect_posts = self.affect_embedding(id_posts)

            # state: [layers, batch, dim]
            _, state_encoder = self.encoder(embed_posts.transpose(0, 1),
                                            len_posts)
            output_affect, state_affect_encoder = self.affect_encoder(
                affect_posts.transpose(0, 1), len_posts)
            context_affect = output_affect[-1, :, :].unsqueeze(
                0)  # [1, batch, dim]

            outputs = []
            weights = []

            done = torch.BoolTensor([0] * batch_size)
            first_input_id = (torch.ones(
                (1, batch_size)) * self.config.start_id).long()
            init_attn = torch.zeros(
                [1, batch_size, self.config.attention_size])
            if device == 'cuda':
                done = done.cuda()
                first_input_id = first_input_id.cuda()
                init_attn = init_attn.cuda()

            for idx in range(max_len):

                if idx == 0:  # first time step
                    state = self.prepare_state(
                        state_encoder, state_affect_encoder)  # initial decoder state
                    input = torch.cat([
                        self.embedding(first_input_id), context_affect,
                        init_attn
                    ], 2)  # initial decoder input [1, batch, embed_size]
                else:
                    input = torch.cat(
                        [input, context_affect,
                         attn.transpose(0, 1)], 2)

                input = self.linear_prepare_input(input)

                # output: [1, batch, dim_out]
                # state: [num_layers, batch, dim_out]
                output, state = self.decoder(input, state)
                # attn: [batch, 1, attention_size]
                # weights: [batch, 1, encoder_len]
                attn, weight = self.attention(output.transpose(0, 1),
                                              output_affect.transpose(0, 1))
                outputs.append(torch.cat([output, attn.transpose(0, 1)], 2))
                weights.append(weight)

                vocab_prob = self.projector(
                    torch.cat([output, attn.transpose(0, 1)],
                              2))  # [1, batch, num_vocab]
                next_input_id = torch.argmax(
                    vocab_prob, 2)  # most probable word becomes the next input [1, batch]

                _done = next_input_id.squeeze(
                    0) == self.config.end_id  # sequences that finished at this step [batch]
                done = done | _done  # all sequences finished so far

                if done.sum() == batch_size:  # stop early once every sequence is done
                    break
                else:
                    input = self.embedding(
                        next_input_id)  # [1, batch, embed_size]

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq, dim_out]
            weights = torch.cat(weights, 1)
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]

            return output_vocab, weights

    # count parameters
    def print_parameters(self):
        def statistic_param(params):
            total_num = 0  # total number of parameters
            for param in params:
                if param.requires_grad:
                    num = 1
                    for dim in param.size():
                        num *= dim
                    total_num += num
            return total_num

        print("嵌入层参数个数: %d" % statistic_param(self.embedding.parameters()))
        print("编码器参数个数: %d" % statistic_param(self.encoder.parameters()))
        print("情感编码器参数个数: %d" %
              statistic_param(self.affect_encoder.parameters()))
        print("准备状态参数个数: %d" %
              statistic_param(self.prepare_state.parameters()))
        print("准备输入参数个数: %d" %
              statistic_param(self.linear_prepare_input.parameters()))
        print("注意力参数个数: %d" % statistic_param(self.attention.parameters()))
        print("解码器参数个数: %d" % statistic_param(self.decoder.parameters()))
        print("输出层参数个数: %d" % statistic_param(self.projector.parameters()))
        print("参数总数: %d" % statistic_param(self.parameters()))

    # save the model
    def save_model(self, epoch, global_step, path):

        torch.save(
            {
                'affect_embedding': self.affect_embedding.state_dict(),
                'embedding': self.embedding.state_dict(),
                'encoder': self.encoder.state_dict(),
                'affect_encoder': self.affect_encoder.state_dict(),
                'prepare_state': self.prepare_state.state_dict(),
                'linear_prepare_input': self.linear_prepare_input.state_dict(),
                'attention': self.attention.state_dict(),
                'projector': self.projector.state_dict(),
                'decoder': self.decoder.state_dict(),
                'epoch': epoch,
                'global_step': global_step
            }, path)

    # load the model
    def load_model(self, path):

        checkpoint = torch.load(path)
        self.affect_embedding.load_state_dict(checkpoint['affect_embedding'])
        self.embedding.load_state_dict(checkpoint['embedding'])
        self.encoder.load_state_dict(checkpoint['encoder'])
        self.affect_encoder.load_state_dict(checkpoint['affect_encoder'])
        self.prepare_state.load_state_dict(checkpoint['prepare_state'])
        self.linear_prepare_input.load_state_dict(
            checkpoint['linear_prepare_input'])
        self.attention.load_state_dict(checkpoint['attention'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.projector.load_state_dict(checkpoint['projector'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']

        return epoch, global_step
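The hand-rolled statistic_param loop above multiplies out each tensor's dimensions explicitly; PyTorch's Tensor.numel() gives the same count in one line. A minimal equivalent:

import torch.nn as nn

def count_trainable(module: nn.Module) -> int:
    # Sum of element counts over all tensors that receive gradients.
    return sum(p.numel() for p in module.parameters() if p.requires_grad)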