예제 #1
0
    def __init__(self, config):
        """Create every sub-module of the model from ``config``.

        Args:
            config: Hyper-parameter object. Only the attributes read in this
                method are required; it is also stored as ``self.config``.
        """
        super(Model, self).__init__()
        self.config = config
        cfg = config  # short local alias used for all hyper-parameter reads

        # Embedding layer: vocabulary size, embedding width, pad id, dropout.
        self.embedding = Embedding(
            cfg.num_vocab, cfg.embedding_size, cfg.pad_id, cfg.dropout)

        # Post encoder: RNN cell type, input width, output width,
        # number of layers, bidirectional flag, dropout probability.
        self.post_encoder = Encoder(
            cfg.post_encoder_cell_type,
            cfg.embedding_size,
            cfg.post_encoder_output_size,
            cfg.post_encoder_num_layers,
            cfg.post_encoder_bidirectional,
            cfg.dropout)

        # Response encoder: same argument layout as the post encoder.
        self.response_encoder = Encoder(
            cfg.response_encoder_cell_type,
            cfg.embedding_size,
            cfg.response_encoder_output_size,
            cfg.response_encoder_num_layers,
            cfg.response_encoder_bidirectional,
            cfg.dropout)

        # Prior network: post encoding width, latent width, hidden widths.
        self.prior_net = PriorNet(
            cfg.post_encoder_output_size, cfg.latent_size, cfg.dims_prior)

        # Recognition network: post width, response width, latent width,
        # hidden widths.
        self.recognize_net = RecognizeNet(
            cfg.post_encoder_output_size,
            cfg.response_encoder_output_size,
            cfg.latent_size,
            cfg.dims_recognize)

        # Maps the concatenated post encoding and latent variable to the
        # decoder's initial state.
        self.prepare_state = PrepareState(
            cfg.post_encoder_output_size + cfg.latent_size,
            cfg.decoder_cell_type,
            cfg.decoder_output_size,
            cfg.decoder_num_layers)

        # Decoder: RNN cell type, input width, output width, number of
        # layers, dropout probability.
        self.decoder = Decoder(
            cfg.decoder_cell_type,
            cfg.embedding_size,
            cfg.decoder_output_size,
            cfg.decoder_num_layers,
            cfg.dropout)

        # Output layer: linear projection to the vocabulary, then softmax
        # over the last dimension.
        to_vocab = nn.Linear(cfg.decoder_output_size, cfg.num_vocab)
        self.projector = nn.Sequential(to_vocab, nn.Softmax(-1))
예제 #2
0
    def __init__(self, config):
        """Create every sub-module of the affect-aware model from ``config``.

        Args:
            config: Hyper-parameter object. Only the attributes read in this
                method are required; it is also stored as ``self.config``.
        """
        super(Model, self).__init__()
        self.config = config
        cfg = config  # short local alias used for all hyper-parameter reads

        # Width of a token representation: word embedding concatenated with
        # the affect embedding.
        token_size = cfg.embedding_size + cfg.affect_embedding_size

        # Word embedding layer.
        self.embedding = Embedding(
            cfg.num_vocab, cfg.embedding_size, cfg.pad_id, cfg.dropout)

        # Affect embedding layer; its weight is excluded from gradient
        # updates.
        self.affect_embedding = Embedding(
            cfg.num_vocab, cfg.affect_embedding_size, cfg.pad_id, cfg.dropout)
        self.affect_embedding.embedding.weight.requires_grad = False

        # Post and response encoders share the same hyper-parameters.
        self.post_encoder = Encoder(
            cfg.encoder_cell_type,
            token_size,
            cfg.encoder_output_size,
            cfg.encoder_num_layers,
            cfg.encoder_bidirectional,
            cfg.dropout)
        self.response_encoder = Encoder(
            cfg.encoder_cell_type,
            token_size,
            cfg.encoder_output_size,
            cfg.encoder_num_layers,
            cfg.encoder_bidirectional,
            cfg.dropout)

        # Prior network (post encoding only).
        self.prior_net = PriorNet(
            cfg.encoder_output_size, cfg.latent_size, cfg.dims_prior)

        # Recognition network (post encoding and response encoding).
        self.recognize_net = RecognizeNet(
            cfg.encoder_output_size,
            cfg.encoder_output_size,
            cfg.latent_size,
            cfg.dims_recognize)

        # Maps the concatenated post encoding and latent variable to the
        # decoder's initial state.
        self.prepare_state = PrepareState(
            cfg.encoder_output_size + cfg.latent_size,
            cfg.decoder_cell_type,
            cfg.decoder_output_size,
            cfg.decoder_num_layers)

        # The decoder input is a token representation plus an encoder-sized
        # vector.
        self.decoder = Decoder(
            cfg.decoder_cell_type,
            token_size + cfg.encoder_output_size,
            cfg.decoder_output_size,
            cfg.decoder_num_layers,
            cfg.dropout)

        # Output layer: linear projection to the vocabulary, then softmax
        # over the last dimension.
        to_vocab = nn.Linear(cfg.decoder_output_size, cfg.num_vocab)
        self.projector = nn.Sequential(to_vocab, nn.Softmax(-1))
예제 #3
0
class Model(nn.Module):
    """Latent-variable response generator with affect embeddings and a
    bag-of-words head.

    Tokens are represented by a trainable word embedding concatenated with a
    frozen affect embedding. A prior network (post only) and a recognition
    network (post and response) each yield a mean / log-variance pair for the
    latent variable. The sampled latent concatenated with the post encoding
    initialises the decoder and feeds the bag-of-words predictor.
    """

    # Sub-modules whose state dicts are stored in a checkpoint. ``save_model``
    # and ``load_model`` both iterate this tuple so they cannot drift apart.
    _CHECKPOINT_MODULES = (
        'affect_embedding',
        'embedding',
        'post_encoder',
        'response_encoder',
        'prior_net',
        'recognize_net',
        'prepare_state',
        'decoder',
        'projector',
        'bow_predictor',
    )

    def __init__(self, config):
        """Create every sub-module from ``config``.

        Args:
            config: Hyper-parameter object. Only the attributes read in this
                method (plus ``start_id`` / ``end_id`` used while decoding)
                are required.
        """
        super().__init__()
        self.config = config

        # Word embedding layer.
        self.embedding = Embedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding width
            config.pad_id,  # pad id
            config.dropout)

        # Affect embedding layer; its weight is frozen.
        self.affect_embedding = Embedding(config.num_vocab,
                                          config.affect_embedding_size,
                                          config.pad_id, config.dropout)
        self.affect_embedding.embedding.weight.requires_grad = False

        # Post encoder.
        self.post_encoder = Encoder(
            config.post_encoder_cell_type,  # RNN cell type
            config.embedding_size + config.affect_embedding_size,  # input
            config.post_encoder_output_size,  # output width
            config.post_encoder_num_layers,  # number of RNN layers
            config.post_encoder_bidirectional,  # bidirectional flag
            config.dropout)  # dropout probability

        # Response encoder.
        self.response_encoder = Encoder(
            config.response_encoder_cell_type,
            config.embedding_size + config.affect_embedding_size,  # input
            config.response_encoder_output_size,  # output width
            config.response_encoder_num_layers,  # number of RNN layers
            config.response_encoder_bidirectional,  # bidirectional flag
            config.dropout)  # dropout probability

        # Prior network.
        self.prior_net = PriorNet(
            config.post_encoder_output_size,  # post encoding width
            config.latent_size,  # latent width
            config.dims_prior)  # hidden widths

        # Recognition network.
        self.recognize_net = RecognizeNet(
            config.post_encoder_output_size,  # post encoding width
            config.response_encoder_output_size,  # response encoding width
            config.latent_size,  # latent width
            config.dims_recognize)  # hidden widths

        # Builds the decoder's initial state.
        self.prepare_state = PrepareState(
            config.post_encoder_output_size + config.latent_size,
            config.decoder_cell_type, config.decoder_output_size,
            config.decoder_num_layers)

        # Decoder; its input is the token representation plus the post
        # encoding.
        self.decoder = Decoder(
            config.decoder_cell_type,  # RNN cell type
            config.embedding_size + config.affect_embedding_size +
            config.post_encoder_output_size,
            config.decoder_output_size,  # output width
            config.decoder_num_layers,  # number of RNN layers
            config.dropout)  # dropout probability

        # Bag-of-words predictor.
        self.bow_predictor = nn.Sequential(
            nn.Linear(config.post_encoder_output_size + config.latent_size,
                      config.num_vocab), nn.Softmax(-1))

        # Output layer.
        self.projector = nn.Sequential(
            nn.Linear(config.decoder_output_size, config.num_vocab),
            nn.Softmax(-1))

    def _embed(self, ids):
        """Return word and affect embeddings of ``ids`` concatenated on dim 2."""
        return torch.cat(
            [self.embedding(ids), self.affect_embedding(ids)], 2)

    @staticmethod
    def _encode(encoder, embedded, lengths):
        """Run ``encoder`` and return its last-layer state, [batch, dim]."""
        _, state = encoder(embedded.transpose(0, 1), lengths)
        if isinstance(state, tuple):  # LSTM-style (h, c): keep h only
            state = state[0]
        return state[-1, :, :]

    def _greedy_decode(self, x, state, device, num_steps, stop_at_end):
        """Decode greedily, feeding back the arg-max token at every step.

        Args:
            x: Post encoding, [batch, dim]; appended to every decoder input.
            state: Initial decoder state.
            device: Device on which the start-token ids are created.
            num_steps: Maximum number of decoding steps.
            stop_at_end: When true, stop as soon as every sequence in the
                batch has produced ``config.end_id``.

        Returns:
            Decoder outputs stacked as [batch, steps, dim_out].
        """
        batch_size = x.size(0)
        context = x.unsqueeze(0)  # [1, batch, dim]
        # Create the ids directly on the inputs' device instead of calling
        # ``.cuda()``, which fails on CPU-only hosts and always targets the
        # current CUDA device.
        next_input_id = torch.full((1, batch_size),
                                   self.config.start_id,
                                   dtype=torch.long,
                                   device=device)
        done = torch.zeros(batch_size, dtype=torch.bool, device=device)

        outputs = []
        for _step in range(num_steps):
            decoder_input = torch.cat([
                self.embedding(next_input_id),
                self.affect_embedding(next_input_id),
                context
            ], 2)
            # output: [1, batch, dim_out]
            output, state = self.decoder(decoder_input, state)
            outputs.append(output)

            vocab_prob = self.projector(output)  # [1, batch, num_vocab]
            # The most probable token becomes the next input, [1, batch].
            next_input_id = torch.argmax(vocab_prob, 2)
            if stop_at_end:
                done = done | (next_input_id.squeeze(0) == self.config.end_id)
                if done.all():  # every sequence finished: stop early
                    break

        return torch.cat(outputs, 0).transpose(0, 1)

    def forward(self,
                inputs,
                inference=False,
                use_true=False,
                max_len=60,
                gpu=True):
        """Run the model in training or inference mode.

        Args:
            inputs: Dict with ``'posts'`` [batch, seq], ``'len_posts'``
                [batch] and ``'sampled_latents'`` [batch, latent_size]; in
                training mode also ``'responses'`` [batch, seq] and
                ``'len_responses'`` [batch].
            inference: When true, sample the latent from the prior network
                and decode greedily for at most ``max_len`` steps.
            use_true: Training only. When true the decoder is fed the
                ground-truth response tokens; otherwise it is fed its own
                arg-max predictions.
            max_len: Maximum number of decoding steps in inference mode.
            gpu: Retained for backward compatibility and ignored; tensors
                are created on the device of ``inputs['posts']``.

        Returns:
            A 6-tuple ``(output_vocab, bow_predict, _mu, _logvar, mu,
            logvar)``. ``_mu`` / ``_logvar`` come from the prior network and
            ``mu`` / ``logvar`` from the recognition network. In inference
            mode ``bow_predict``, ``mu`` and ``logvar`` are ``None``.
        """
        id_posts = inputs['posts']  # [batch, seq]
        len_posts = inputs['len_posts']  # [batch]
        sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
        device = id_posts.device

        if inference:
            x = self._encode(self.post_encoder, self._embed(id_posts),
                             len_posts)  # [batch, dim]
            _mu, _logvar = self.prior_net(x)  # [batch, latent]
            z = _mu + (0.5 * _logvar).exp() * sampled_latents
            first_state = self.prepare_state(torch.cat([z, x], 1))
            outputs = self._greedy_decode(x, first_state, device, max_len,
                                          stop_at_end=True)
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]
            # No bag-of-words prediction is computed here. The slot used to
            # leak the encoder's throwaway output; it is now None.
            return output_vocab, None, _mu, _logvar, None, None

        id_responses = inputs['responses']  # [batch, seq]
        len_responses = inputs['len_responses']  # [batch]
        len_decoder = id_responses.size(1) - 1

        embed_posts = self._embed(id_posts)
        embed_responses = self._embed(id_responses)
        x = self._encode(self.post_encoder, embed_posts, len_posts)
        y = self._encode(self.response_encoder, embed_responses,
                         len_responses)

        _mu, _logvar = self.prior_net(x)  # [batch, latent]
        mu, logvar = self.recognize_net(x, y)  # [batch, latent]
        z = mu + (0.5 * logvar).exp() * sampled_latents  # [batch, latent]

        latent_and_post = torch.cat([z, x], 1)
        bow_predict = self.bow_predictor(latent_and_post)  # [batch, vocab]
        state = self.prepare_state(latent_and_post)

        if use_true:
            # Teacher forcing: the decoder sees the response without its
            # final token, one time step at a time.
            teacher_inputs = embed_responses[:, :-1, :].transpose(0, 1)
            context = x.unsqueeze(0)  # [1, batch, dim]
            outputs = []
            for idx in range(len_decoder):
                decoder_input = torch.cat(
                    [teacher_inputs[idx:idx + 1], context], 2)
                output, state = self.decoder(decoder_input, state)
                outputs.append(output)
            outputs = torch.cat(outputs, 0).transpose(0, 1)
        else:
            outputs = self._greedy_decode(x, state, device, len_decoder,
                                          stop_at_end=False)

        output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]
        return output_vocab, bow_predict, _mu, _logvar, mu, logvar

    def print_parameters(self):
        """Print the number of trainable parameters.

        Parameters with ``requires_grad=False`` are skipped entirely. The
        previous implementation added 1 for each of them.
        """
        total_num = sum(param.numel() for param in self.parameters()
                        if param.requires_grad)
        print(f"参数总数: {total_num}")

    def save_model(self, epoch, global_step, path):
        """Save every sub-module's state dict plus the training counters.

        Args:
            epoch: Epoch counter stored under ``'epoch'``.
            global_step: Step counter stored under ``'global_step'``.
            path: Destination passed to ``torch.save``.
        """
        checkpoint = {
            name: getattr(self, name).state_dict()
            for name in self._CHECKPOINT_MODULES
        }
        checkpoint['epoch'] = epoch
        checkpoint['global_step'] = global_step
        torch.save(checkpoint, path)

    def load_model(self, path, map_location=None):
        """Restore every sub-module from a checkpoint written by ``save_model``.

        Args:
            path: Checkpoint location passed to ``torch.load``.
            map_location: Optional device remapping forwarded to
                ``torch.load`` (e.g. ``'cpu'`` to load a checkpoint saved on
                GPU). ``None`` keeps the previous behaviour.

        Returns:
            The stored ``(epoch, global_step)`` pair.
        """
        checkpoint = torch.load(path, map_location=map_location)
        for name in self._CHECKPOINT_MODULES:
            getattr(self, name).load_state_dict(checkpoint[name])
        return checkpoint['epoch'], checkpoint['global_step']
예제 #4
0
class Model(nn.Module):
    """Latent-variable response generator built on a word embedding.

    A prior network (post only) and a recognition network (post and response)
    each yield a mean / log-variance pair for the latent variable. The
    sampled latent concatenated with the post encoding initialises the
    decoder. An affect embedding is constructed and checkpointed but is not
    used by ``forward``.
    """

    # Sub-modules whose state dicts are stored in a checkpoint. ``save_model``
    # and ``load_model`` both iterate this tuple so they cannot drift apart.
    _CHECKPOINT_MODULES = (
        'affect_embedding',
        'embedding',
        'post_encoder',
        'response_encoder',
        'prior_net',
        'recognize_net',
        'prepare_state',
        'decoder',
        'projector',
    )

    def __init__(self, config):
        """Create every sub-module from ``config``.

        Args:
            config: Hyper-parameter object. Only the attributes read in this
                method (plus ``start_id`` / ``end_id`` used while decoding)
                are required.
        """
        super().__init__()

        self.config = config

        # Affect embedding layer.
        self.affect_embedding = AffectEmbedding(config.num_vocab,
                                                config.affect_embedding_size,
                                                config.pad_id)

        # Word embedding layer.
        self.embedding = WordEmbedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding width
            config.pad_id)  # pad id

        # Post encoder.
        self.post_encoder = Encoder(
            config.post_encoder_cell_type,  # RNN cell type
            config.embedding_size,  # input width
            config.post_encoder_output_size,  # output width
            config.post_encoder_num_layers,  # number of RNN layers
            config.post_encoder_bidirectional,  # bidirectional flag
            config.dropout)  # dropout probability

        # Response encoder.
        self.response_encoder = Encoder(
            config.response_encoder_cell_type,
            config.embedding_size,  # input width
            config.response_encoder_output_size,  # output width
            config.response_encoder_num_layers,  # number of RNN layers
            config.response_encoder_bidirectional,  # bidirectional flag
            config.dropout)  # dropout probability

        # Prior network.
        self.prior_net = PriorNet(
            config.post_encoder_output_size,  # post encoding width
            config.latent_size,  # latent width
            config.dims_prior)  # hidden widths

        # Recognition network.
        self.recognize_net = RecognizeNet(
            config.post_encoder_output_size,  # post encoding width
            config.response_encoder_output_size,  # response encoding width
            config.latent_size,  # latent width
            config.dims_recognize)  # hidden widths

        # Builds the decoder's initial state.
        self.prepare_state = PrepareState(
            config.post_encoder_output_size + config.latent_size,
            config.decoder_cell_type, config.decoder_output_size,
            config.decoder_num_layers)

        # Decoder.
        self.decoder = Decoder(
            config.decoder_cell_type,  # RNN cell type
            config.embedding_size,  # input width
            config.decoder_output_size,  # output width
            config.decoder_num_layers,  # number of RNN layers
            config.dropout)  # dropout probability

        # Output layer.
        self.projector = nn.Sequential(
            nn.Linear(config.decoder_output_size, config.num_vocab),
            nn.Softmax(-1))

    @staticmethod
    def _encode(encoder, embedded, lengths):
        """Run ``encoder`` and return its last-layer state, [batch, dim]."""
        _, state = encoder(embedded.transpose(0, 1), lengths)
        if isinstance(state, tuple):  # LSTM-style (h, c): keep h only
            state = state[0]
        return state[-1, :, :]

    def _greedy_decode(self, state, batch_size, device, num_steps,
                       stop_at_end):
        """Decode greedily, feeding back the arg-max token at every step.

        Args:
            state: Initial decoder state.
            batch_size: Number of sequences decoded in parallel.
            device: Device on which the start-token ids are created.
            num_steps: Maximum number of decoding steps.
            stop_at_end: When true, stop as soon as every sequence in the
                batch has produced ``config.end_id``.

        Returns:
            Decoder outputs stacked as [batch, steps, dim_out].
        """
        # Create the ids directly on the inputs' device. ``.cuda()`` always
        # targets the current CUDA device, which may differ from the one
        # holding the inputs.
        next_input_id = torch.full((1, batch_size),
                                   self.config.start_id,
                                   dtype=torch.long,
                                   device=device)
        done = torch.zeros(batch_size, dtype=torch.bool, device=device)

        outputs = []
        for _step in range(num_steps):
            step_input = self.embedding(next_input_id)  # [1, batch, embed]
            # output: [1, batch, dim_out]
            output, state = self.decoder(step_input, state)
            outputs.append(output)

            vocab_prob = self.projector(output)  # [1, batch, num_vocab]
            # The most probable token becomes the next input, [1, batch].
            next_input_id = torch.argmax(vocab_prob, 2)
            if stop_at_end:
                done = done | (next_input_id.squeeze(0) == self.config.end_id)
                if done.all():  # every sequence finished: stop early
                    break

        return torch.cat(outputs, 0).transpose(0, 1)

    def forward(
            self,
            input,
            inference=False,  # whether to run in inference mode
            use_true=False,
            max_len=60):  # maximum decoding length
        """Run the model in training or inference mode.

        Args:
            input: Dict with ``'posts'`` [batch, seq], ``'len_posts'``
                [batch] and ``'sampled_latents'`` [batch, latent_size]; in
                training mode also ``'responses'`` [batch, seq] and
                ``'len_responses'`` [batch].
            inference: When true, sample the latent from the prior network
                and decode greedily for at most ``max_len`` steps.
            use_true: Training only. When true the decoder is fed the
                ground-truth response tokens; otherwise it is fed its own
                arg-max predictions.
            max_len: Maximum number of decoding steps in inference mode.

        Returns:
            A 5-tuple ``(output_vocab, _mu, _logvar, mu, logvar)``. ``_mu`` /
            ``_logvar`` come from the prior network and ``mu`` / ``logvar``
            from the recognition network (``None`` in inference mode).
        """
        # The public parameter name shadows the builtin ``input``; use a
        # local alias and never rebind the parameter itself.
        batch = input
        id_posts = batch['posts']  # [batch, seq]
        len_posts = batch['len_posts']  # [batch]
        sampled_latents = batch['sampled_latents']  # [batch, latent_size]
        batch_size = id_posts.size(0)
        device = id_posts.device

        embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]

        if inference:
            x = self._encode(self.post_encoder, embed_posts, len_posts)
            _mu, _logvar = self.prior_net(x)  # [batch, latent]
            z = _mu + (0.5 * _logvar).exp() * sampled_latents
            first_state = self.prepare_state(torch.cat([z, x], 1))
            outputs = self._greedy_decode(first_state, batch_size, device,
                                          max_len, stop_at_end=True)
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]
            return output_vocab, _mu, _logvar, None, None

        id_responses = batch['responses']  # [batch, seq]
        len_responses = batch['len_responses']  # [batch]
        len_decoder = id_responses.size(1) - 1

        embed_responses = self.embedding(
            id_responses)  # [batch, seq, embed_size]
        x = self._encode(self.post_encoder, embed_posts, len_posts)
        y = self._encode(self.response_encoder, embed_responses,
                         len_responses)

        _mu, _logvar = self.prior_net(x)  # [batch, latent]
        mu, logvar = self.recognize_net(x, y)  # [batch, latent]
        z = mu + (0.5 * logvar).exp() * sampled_latents  # [batch, latent]

        state = self.prepare_state(torch.cat([z, x], 1))

        if use_true:
            # Teacher forcing: the decoder sees the response without its
            # final token, one time step at a time.
            teacher_inputs = embed_responses[:, :-1, :].transpose(0, 1)
            outputs = []
            for idx in range(len_decoder):
                output, state = self.decoder(teacher_inputs[idx:idx + 1],
                                             state)
                outputs.append(output)
            outputs = torch.cat(outputs, 0).transpose(0, 1)
        else:
            outputs = self._greedy_decode(state, batch_size, device,
                                          len_decoder, stop_at_end=False)

        output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]
        return output_vocab, _mu, _logvar, mu, logvar

    def print_parameters(self):
        """Print trainable-parameter counts per sub-module and in total.

        Parameters with ``requires_grad=False`` are skipped entirely. The
        previous implementation added 1 for each of them.
        """
        def statistic_param(params):
            return sum(param.numel() for param in params
                       if param.requires_grad)

        print("嵌入层参数个数: %d" % statistic_param(self.embedding.parameters()))
        print("post编码器参数个数: %d" %
              statistic_param(self.post_encoder.parameters()))
        print("response编码器参数个数: %d" %
              statistic_param(self.response_encoder.parameters()))
        print("先验网络参数个数: %d" % statistic_param(self.prior_net.parameters()))
        print("识别网络参数个数: %d" %
              statistic_param(self.recognize_net.parameters()))
        print("解码器初始状态参数个数: %d" %
              statistic_param(self.prepare_state.parameters()))
        print("解码器参数个数: %d" % statistic_param(self.decoder.parameters()))
        print("输出层参数个数: %d" % statistic_param(self.projector.parameters()))
        print("参数总数: %d" % statistic_param(self.parameters()))

    def save_model(self, epoch, global_step, path):
        """Save every sub-module's state dict plus the training counters.

        Args:
            epoch: Epoch counter stored under ``'epoch'``.
            global_step: Step counter stored under ``'global_step'``.
            path: Destination passed to ``torch.save``.
        """
        checkpoint = {
            name: getattr(self, name).state_dict()
            for name in self._CHECKPOINT_MODULES
        }
        checkpoint['epoch'] = epoch
        checkpoint['global_step'] = global_step
        torch.save(checkpoint, path)

    def load_model(self, path, map_location=None):
        """Restore every sub-module from a checkpoint written by ``save_model``.

        Args:
            path: Checkpoint location passed to ``torch.load``.
            map_location: Optional device remapping forwarded to
                ``torch.load`` (e.g. ``'cpu'`` to load a checkpoint saved on
                GPU). ``None`` keeps the previous behaviour.

        Returns:
            The stored ``(epoch, global_step)`` pair.
        """
        checkpoint = torch.load(path, map_location=map_location)
        for name in self._CHECKPOINT_MODULES:
            getattr(self, name).load_state_dict(checkpoint[name])
        return checkpoint['epoch'], checkpoint['global_step']