Example #1
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        # word embedding layer
        self.embedding = Embedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding dimension
            config.pad_id,
            config.dropout)

        self.affect_embedding = Embedding(config.num_vocab,
                                          config.affect_embedding_size,
                                          config.pad_id, config.dropout)
        self.affect_embedding.embedding.weight.requires_grad = False

        # encoder
        self.encoder = Encoder(
            config.encoder_decoder_cell_type,  # RNN type
            config.embedding_size + config.affect_embedding_size,  # input dim
            config.encoder_decoder_output_size,  # output dim
            config.encoder_decoder_num_layers,  # number of RNN layers
            config.encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # decoder
        self.decoder = Decoder(
            config.encoder_decoder_cell_type,  # RNN type
            config.embedding_size + config.affect_embedding_size +
            config.encoder_decoder_output_size,
            config.encoder_decoder_output_size,  # output dim
            config.encoder_decoder_num_layers,  # number of RNN layers
            config.dropout)  # dropout probability

        # output layer
        self.projector = nn.Sequential(
            nn.Linear(config.encoder_decoder_output_size, config.num_vocab),
            nn.Softmax(-1))

    def forward(self, inputs, inference=False, max_len=60, gpu=True):
        if not inference:  # training
            id_posts = inputs['posts']  # [batch, seq]
            len_posts = inputs['len_posts']  # [batch]
            id_responses = inputs['responses']  # [batch, seq]
            len_decoder = id_responses.size(1) - 1

            embed_posts = torch.cat(
                [self.embedding(id_posts),
                 self.affect_embedding(id_posts)], 2)
            embed_responses = torch.cat([
                self.embedding(id_responses),
                self.affect_embedding(id_responses)
            ], 2)

            # encoder_output: [seq, batch, dim]
            # state: [layers, batch, dim]
            _, state_encoder = self.encoder(embed_posts.transpose(0, 1),
                                            len_posts)
            if isinstance(state_encoder, tuple):
                context = state_encoder[0][-1, :, :]
            else:
                context = state_encoder[-1, :, :]
            context = context.unsqueeze(0)  # [1, batch, dim]

            # decoder input: the response with the end_id stripped
            decoder_inputs = embed_responses[:, :-1, :].transpose(
                0, 1)  # [seq-1, batch, embed_size]
            decoder_inputs = decoder_inputs.split([1] * len_decoder, 0)

            outputs = []
            state = state_encoder  # initial decoder state
            for idx in range(len_decoder):
                decoder_input = torch.cat([decoder_inputs[idx], context], 2)
                # output: [1, batch, dim_out]
                # state: [num_layer, batch, dim_out]
                output, state = self.decoder(decoder_input, state)
                outputs.append(output)

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq-1, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]

            return output_vocab
        else:  # inference
            id_posts = inputs['posts']  # [batch, seq]
            len_posts = inputs['len_posts']  # [batch]
            batch_size = id_posts.size(0)

            embed_posts = torch.cat(
                [self.embedding(id_posts),
                 self.affect_embedding(id_posts)], 2)

            # state: [layers, batch, dim]
            _, state_encoder = self.encoder(embed_posts.transpose(0, 1),
                                            len_posts)
            if isinstance(state_encoder, tuple):
                context = state_encoder[0][-1, :, :]
            else:
                context = state_encoder[-1, :, :]
            context = context.unsqueeze(0)  # [1, batch, dim]

            done = torch.tensor([0] * batch_size).bool()
            first_input_id = (torch.ones(
                (1, batch_size)) * self.config.start_id).long()
            if gpu:
                done = done.cuda()
                first_input_id = first_input_id.cuda()

            outputs = []
            for idx in range(max_len):
                if idx == 0:  # first time step
                    state = state_encoder  # initial decoder state
                    decoder_input = torch.cat([
                        self.embedding(first_input_id),
                        self.affect_embedding(first_input_id), context
                    ], 2)
                else:
                    decoder_input = torch.cat([
                        self.embedding(next_input_id),
                        self.affect_embedding(next_input_id), context
                    ], 2)

                # output: [1, batch, dim_out]
                # state: [num_layers, batch, dim_out]
                output, state = self.decoder(decoder_input, state)
                outputs.append(output)

                vocab_prob = self.projector(output)  # [1, batch, num_vocab]
                next_input_id = torch.argmax(
                    vocab_prob, 2)  # feed the most probable word to the next step [1, batch]
                _done = next_input_id.squeeze(
                    0) == self.config.end_id  # sequences that finished at this step [batch]
                done = done | _done  # all finished sequences
                if done.sum() == batch_size:  # stop early once every sequence is done
                    break

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]

            return output_vocab

    # count parameters
    def print_parameters(self):
        r""" Count trainable parameters """
        total_num = 0  # total number of trainable parameters
        for param in self.parameters():
            if param.requires_grad:
                num = 1
                for dim in param.size():
                    num *= dim
                total_num += num
        print(f"Total trainable parameters: {total_num}")

    def save_model(self, epoch, global_step, path):
        r""" 保存模型 """
        torch.save(
            {
                'embedding': self.embedding.state_dict(),
                'affect_embedding': self.affect_embedding.state_dict(),
                'encoder': self.encoder.state_dict(),
                'decoder': self.decoder.state_dict(),
                'projector': self.projector.state_dict(),
                'epoch': epoch,
                'global_step': global_step
            }, path)

    def load_model(self, path):
        r""" 载入模型 """
        checkpoint = torch.load(path)
        self.embedding.load_state_dict(checkpoint['embedding'])
        self.affect_embedding.load_state_dict(checkpoint['affect_embedding'])
        self.encoder.load_state_dict(checkpoint['encoder'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.projector.load_state_dict(checkpoint['projector'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']
        return epoch, global_step
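A minimal usage sketch for the model above (hypothetical, not part of the original code: the Config values are stand-ins, and the project-local Embedding/Encoder/Decoder modules are assumed importable and compatible with these fields):

import torch

class Config:
    num_vocab = 10000
    embedding_size = 300
    affect_embedding_size = 3
    pad_id = 0
    start_id = 1
    end_id = 2
    dropout = 0.1
    encoder_decoder_cell_type = 'GRU'
    encoder_decoder_output_size = 512
    encoder_decoder_num_layers = 2
    encoder_bidirectional = True

model = Model(Config())
batch = {
    'posts': torch.randint(3, Config.num_vocab, (16, 20)),      # [batch, seq]
    'len_posts': torch.full((16,), 20, dtype=torch.long),       # [batch]
    'responses': torch.randint(3, Config.num_vocab, (16, 22)),  # [batch, seq]
}
probs = model(batch, inference=False)  # teacher-forced training pass
generated = model(batch, inference=True, max_len=60, gpu=False)  # greedy decoding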
Example #2
decoder_buffer = io.BytesIO(decoder_bytes.read())

# initialize the encoder / decoder architectures used for evaluation
encoder_eval = Encoder(input_size=ori_subset_transformed.shape[1],
                       hidden_size=[256, 64, 16, 4, 2])
decoder_eval = Decoder(output_size=ori_subset_transformed.shape[1],
                       hidden_size=[2, 4, 16, 64, 256])

# push the models to CUDA if cuDNN is available
if (torch.backends.cudnn.version() is not None) and USE_CUDA:
    encoder_eval = encoder_eval.cuda()
    decoder_eval = decoder_eval.cuda()

# load trained models
# the models were trained on a GPU and may be restored on a CPU, so we provide map_location='cpu'
encoder_eval.load_state_dict(torch.load(encoder_buffer, map_location='cpu'))
decoder_eval.load_state_dict(torch.load(decoder_buffer, map_location='cpu'))

## specify a dataloader that provides the ability to evaluate the journal entries in an "unshuffled" batch-wise manner:
# convert pre-processed data to pytorch tensor
torch_dataset = torch.from_numpy(ori_subset_transformed.values).float()

# create a dataloader over the tensor (CPU-only, no worker processes)
dataloader_eval = DataLoader(torch_dataset,
                             batch_size=mini_batch_size,
                             shuffle=False,
                             num_workers=0)

# determine if CUDA is available at the compute node
if (torch.backends.cudnn.version() is not None) and USE_CUDA:
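    # (assumed continuation: the original snippet is truncated at the line above)
    encoder_eval = encoder_eval.cuda()
    decoder_eval = decoder_eval.cuda()

The truncation also cuts off the evaluation itself; a minimal sketch of the batch-wise reconstruction pass this setup leads into, assuming encoder_eval/decoder_eval compose an autoencoder and using mean-squared error as the reconstruction criterion:

reconstruction_criterion = torch.nn.MSELoss(reduction='mean')
encoder_eval.eval()
decoder_eval.eval()
batch_losses = []
with torch.no_grad():
    for mini_batch in dataloader_eval:
        if (torch.backends.cudnn.version() is not None) and USE_CUDA:
            mini_batch = mini_batch.cuda()
        reconstruction = decoder_eval(encoder_eval(mini_batch))
        batch_losses.append(
            reconstruction_criterion(reconstruction, mini_batch).item())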
Example #3
class Mem2SeqRunner(ExperimentRunnerBase):
    def __init__(self, args):
        super(Mem2SeqRunner, self).__init__(args)

        # Model parameters
        self.gru_size = 128
        self.emb_size = 128
        #TODO: Try hops 4 with task 3
        self.hops = 3
        self.dropout = 0.2

        self.encoder = Encoder(self.hops, self.nwords, self.gru_size)
        self.decoder = Decoder(self.emb_size, self.hops, self.gru_size,
                               self.nwords)

        self.optim_enc = torch.optim.Adam(self.encoder.parameters(), lr=0.001)
        self.optim_dec = torch.optim.Adam(self.decoder.parameters(), lr=0.001)
        if self.loss_weighting:
            self.optim_loss_weights = torch.optim.Adam([self.loss_weights],
                                                       lr=0.0001)
        self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optim_dec,
                                                        mode='max',
                                                        factor=0.5,
                                                        patience=1,
                                                        min_lr=0.0001,
                                                        verbose=True)

        if self.use_cuda:
            self.cross_entropy = self.cross_entropy.cuda()
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            if self.loss_weighting:
                self.loss_weights = self.loss_weights.cuda()

    def train_batch_wrapper(self, batch, new_epoch, clip_grads):
        context = batch[0].transpose(0, 1)
        responses = batch[1].transpose(0, 1)
        index = batch[2].transpose(0, 1)
        sentinel = batch[3].transpose(0, 1)
        context_lengths = batch[4]
        target_lengths = batch[5]
        return self.train_batch(context, responses, index, sentinel, new_epoch,
                                context_lengths, target_lengths, clip_grads)

    def train_batch(self, context, responses, index, sentinel, new_epoch,
                    context_lengths, target_lengths, clip_grads):

        # (TODO): remove transpose
        if new_epoch:  # (TODO): Change this part
            self.loss = 0
            self.ploss = 0
            self.vloss = 0
            self.n = 1

        context = context.type(self.TYPE)
        responses = responses.type(self.TYPE)
        index = index.type(self.TYPE)
        sentinel = sentinel.type(self.TYPE)

        self.optim_enc.zero_grad()
        self.optim_dec.zero_grad()
        if self.loss_weighting:
            self.optim_loss_weights.zero_grad()

        h = self.encoder(context.transpose(0, 1))
        self.decoder.load_memory(context.transpose(0, 1))
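        # initial decoder input: token id 2 (presumably the start-of-sequence symbol)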
        y = torch.from_numpy(np.array([2] * context.size(1),
                                      dtype=int)).type(self.TYPE)
        y_len = 0

        h = h.unsqueeze(0)
        output_vocab = torch.zeros(max(target_lengths), context.size(1),
                                   self.nwords)
        output_ptr = torch.zeros(max(target_lengths), context.size(1),
                                 context.size(0))
        if self.use_cuda:
            output_vocab = output_vocab.cuda()
            output_ptr = output_ptr.cuda()
        while y_len < responses.size(0):  # TODO: Add EOS condition
            p_ptr, p_vocab, h = self.decoder(context, y, h)
            output_vocab[y_len] = p_vocab
            output_ptr[y_len] = p_ptr
            # TODO: add teacher forcing ratio
            y = responses[y_len].type(self.TYPE)
            y_len += 1

        # masks zero out time steps beyond each target length
        mask_v = torch.ones(output_vocab.size())
        mask_p = torch.ones(output_ptr.size())
        if self.use_cuda:
            mask_p = mask_p.cuda()
            mask_v = mask_v.cuda()
        for i in range(responses.size(1)):
            mask_v[target_lengths[i]:, i, :] = 0
            mask_p[target_lengths[i]:, i, :] = 0
        output_vocab = output_vocab * mask_v
        output_ptr = output_ptr * mask_p

        loss_v = self.cross_entropy(
            output_vocab.contiguous().view(-1, self.nwords),
            responses.contiguous().view(-1))

        loss_ptr = self.cross_entropy(
            output_ptr.contiguous().view(-1, context.size(0)),
            index.contiguous().view(-1))
        if self.loss_weighting:
            loss = loss_ptr/(2*self.loss_weights[0]*self.loss_weights[0]) + loss_v/(2*self.loss_weights[1]*self.loss_weights[1]) + \
               torch.log(self.loss_weights[0] * self.loss_weights[1])
            loss_ptr = loss_ptr / (2 * self.loss_weights[0] *
                                   self.loss_weights[0])
            loss_v = loss_v / (2 * self.loss_weights[1] * self.loss_weights[1])
        else:
            loss = loss_ptr + loss_v
        loss.backward()
        ec = torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), 10.0)
        dc = torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), 10.0)
        self.optim_enc.step()
        self.optim_dec.step()
        if self.loss_weighting:
            self.optim_loss_weights.step()

        self.loss += loss.item()
        self.vloss += loss_v.item()
        self.ploss += loss_ptr.item()

        return loss.item(), loss_v.item(), loss_ptr.item()

    def evaluate_batch(self,
                       batch_size,
                       input_batches,
                       input_lengths,
                       target_batches,
                       target_lengths,
                       target_index,
                       target_gate,
                       src_plain,
                       profile_memory=None):

        # Set to eval mode to disable dropout
        self.encoder.train(False)
        self.decoder.train(False)
        # Run words through encoder
        decoder_hidden = self.encoder(input_batches.transpose(0,
                                                              1)).unsqueeze(0)
        self.decoder.load_memory(input_batches.transpose(0, 1))

        # Prepare input and output tensors
        decoder_input = torch.LongTensor([2] * batch_size)

        decoded_words = []
        all_decoder_outputs_vocab = torch.zeros(
            max(target_lengths), batch_size, self.nwords)
        all_decoder_outputs_ptr = torch.zeros(
            max(target_lengths), batch_size, input_batches.size(0))
        # all_decoder_outputs_gate = Variable(torch.zeros(self.max_r, batch_size))
        # Move new tensors to CUDA

        if self.use_cuda:
            all_decoder_outputs_vocab = all_decoder_outputs_vocab.cuda()
            all_decoder_outputs_ptr = all_decoder_outputs_ptr.cuda()
            # all_decoder_outputs_gate = all_decoder_outputs_gate.cuda()
            decoder_input = decoder_input.cuda()

        p = []
        for elm in src_plain:
            elm_temp = [word_triple[0] for word_triple in elm]
            p.append(elm_temp)

        self.from_whichs = []
        acc_gate, acc_ptr, acc_vac = 0.0, 0.0, 0.0
        # Run through decoder one time step at a time
        for t in range(max(target_lengths)):
            decoder_ptr, decoder_vocab, decoder_hidden = self.decoder(
                input_batches, decoder_input, decoder_hidden)
            all_decoder_outputs_vocab[t] = decoder_vocab
            topv, topvi = decoder_vocab.data.topk(1)
            all_decoder_outputs_ptr[t] = decoder_ptr
            topp, toppi = decoder_ptr.data.topk(1)
            top_ptr_i = torch.gather(input_batches[:, :, 0], 0,
                                     toppi.view(1, -1)).transpose(0, 1)
            next_in = [
                top_ptr_i[i].item() if
                (toppi[i].item() < input_lengths[i] - 1) else topvi[i].item()
                for i in range(batch_size)
            ]
            # if next_in in self.kb_entry.keys():
            #     ptr_distr.append([next_in, decoder_vocab.data])

            decoder_input = torch.LongTensor(next_in)  # chosen word is the next input
            if self.use_cuda:
                decoder_input = decoder_input.cuda()

            temp = []
            from_which = []
            for i in range(batch_size):
                if (toppi[i].item() < len(p[i]) - 1):
                    temp.append(p[i][toppi[i].item()])
                    from_which.append('p')
                else:
                    if target_index[t][i] != toppi[i].item():
                        self.incorrect_sentinel += 1
                    ind = topvi[i].item()
                    if ind == 3:
                        temp.append('<eos>')
                    else:
                        temp.append(self.i2w[ind])
                    from_which.append('v')
            decoded_words.append(temp)
            self.from_whichs.append(from_which)
        self.from_whichs = np.array(self.from_whichs)

        loss_v = self.cross_entropy(
            all_decoder_outputs_vocab.contiguous().view(-1, self.nwords),
            target_batches.contiguous().view(-1))
        loss_ptr = self.cross_entropy(
            all_decoder_outputs_ptr.contiguous().view(-1,
                                                      input_batches.size(0)),
            target_index.contiguous().view(-1))

        if self.loss_weighting:
            loss = loss_ptr/(2*self.loss_weights[0]*self.loss_weights[0]) + loss_v/(2*self.loss_weights[1]*self.loss_weights[1]) + \
               torch.log(self.loss_weights[0] * self.loss_weights[1])
        else:
            loss = loss_ptr + loss_v

        self.loss += loss.item()
        self.vloss += loss_v.item()
        self.ploss += loss_ptr.item()
        self.n += 1

        # Set back to training mode
        self.encoder.train(True)
        self.decoder.train(True)
        return decoded_words, self.from_whichs  # , acc_ptr, acc_vac

    def save_models(self, path):
        torch.save(self.encoder.state_dict(),
                   os.path.join(path, 'encoder.pth'))
        torch.save(self.decoder.state_dict(),
                   os.path.join(path, 'decoder.pth'))

    def load_models(self, path: str = '.'):
        self.encoder.load_state_dict(
            torch.load(os.path.join(path, 'encoder.pth')))
        self.decoder.load_state_dict(
            torch.load(os.path.join(path, 'decoder.pth')))
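The loss_weighting branches above combine the pointer and vocabulary losses with learned uncertainty weights. A standalone sketch of that combination (the runner's initialization of self.loss_weights is not shown, so the two-element weight tensor below is an assumption):

import torch

loss_weights = torch.ones(2, requires_grad=True)  # stand-in for self.loss_weights

def combine_losses(loss_ptr, loss_v, w=loss_weights):
    # mirrors the runner's formula: loss_i / (2 * w_i^2) + log(w_0 * w_1)
    return (loss_ptr / (2 * w[0] * w[0]) +
            loss_v / (2 * w[1] * w[1]) +
            torch.log(w[0] * w[1]))

loss = combine_losses(torch.tensor(1.3), torch.tensor(0.7))
loss.backward()  # gradients also flow into loss_weights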
Example #4
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        # word embedding layer
        self.embedding = Embedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding dimension
            config.pad_id,  # pad_id
            config.dropout)

        # affect embedding layer
        self.affect_embedding = Embedding(config.num_vocab,
                                          config.affect_embedding_size,
                                          config.pad_id, config.dropout)
        self.affect_embedding.embedding.weight.requires_grad = False

        # post encoder
        self.post_encoder = Encoder(
            config.post_encoder_cell_type,  # RNN type
            config.embedding_size + config.affect_embedding_size,  # input dim
            config.post_encoder_output_size,  # output dim
            config.post_encoder_num_layers,  # number of RNN layers
            config.post_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # response encoder
        self.response_encoder = Encoder(
            config.response_encoder_cell_type,
            config.embedding_size + config.affect_embedding_size,  # input dim
            config.response_encoder_output_size,  # output dim
            config.response_encoder_num_layers,  # number of RNN layers
            config.response_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # prior network
        self.prior_net = PriorNet(
            config.post_encoder_output_size,  # post input dim
            config.latent_size,  # latent dim
            config.dims_prior)  # hidden layer dims

        # recognition network
        self.recognize_net = RecognizeNet(
            config.post_encoder_output_size,  # post input dim
            config.response_encoder_output_size,  # response input dim
            config.latent_size,  # latent dim
            config.dims_recognize)  # hidden layer dims

        # prepare the initial decoder state
        self.prepare_state = PrepareState(
            config.post_encoder_output_size + config.latent_size,
            config.decoder_cell_type, config.decoder_output_size,
            config.decoder_num_layers)

        # decoder
        self.decoder = Decoder(
            config.decoder_cell_type,  # RNN type
            config.embedding_size + config.affect_embedding_size +
            config.post_encoder_output_size,
            config.decoder_output_size,  # output dim
            config.decoder_num_layers,  # number of RNN layers
            config.dropout)  # dropout probability

        # bag-of-words prediction
        self.bow_predictor = nn.Sequential(
            nn.Linear(config.post_encoder_output_size + config.latent_size,
                      config.num_vocab), nn.Softmax(-1))

        # output layer
        self.projector = nn.Sequential(
            nn.Linear(config.decoder_output_size, config.num_vocab),
            nn.Softmax(-1))

    def forward(self,
                inputs,
                inference=False,
                use_true=False,
                max_len=60,
                gpu=True):
        if not inference:  # training
            if use_true:  # decode with ground-truth inputs (teacher forcing)
                id_posts = inputs['posts']  # [batch, seq]
                len_posts = inputs['len_posts']  # [batch]
                id_responses = inputs['responses']  # [batch, seq]
                len_responses = inputs['len_responses']  # [batch]
                sampled_latents = inputs[
                    'sampled_latents']  # [batch, latent_size]
                len_decoder = id_responses.size(1) - 1

                embed_posts = torch.cat([
                    self.embedding(id_posts),
                    self.affect_embedding(id_posts)
                ], 2)
                embed_responses = torch.cat([
                    self.embedding(id_responses),
                    self.affect_embedding(id_responses)
                ], 2)

                # state: [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                                   len_posts)
                _, state_responses = self.response_encoder(
                    embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]

                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
                z = mu + (0.5 *
                          logvar).exp() * sampled_latents  # [batch, latent]

                bow_predict = self.bow_predictor(torch.cat(
                    [z, x], 1))  # [batch, num_vocab]

                first_state = self.prepare_state(torch.cat(
                    [z, x], 1))  # [num_layer, batch, dim_out]
                decoder_inputs = embed_responses[:, :-1, :].transpose(
                    0, 1)  # [seq-1, batch, embed_size]
                decoder_inputs = decoder_inputs.split(
                    [1] * len_decoder, 0)  # seq-1 tensors of [1, batch, embed_size]

                outputs = []
                state = first_state  # initial decoder state
                for idx in range(len_decoder):
                    decoder_input = torch.cat(
                        [decoder_inputs[idx],
                         x.unsqueeze(0)], 2)
                    # output: [1, batch, dim_out]
                    # state: [num_layer, batch, dim_out]
                    output, state = self.decoder(decoder_input, state)
                    outputs.append(output)

                outputs = torch.cat(outputs,
                                    0).transpose(0,
                                                 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(
                    outputs)  # [batch, seq-1, num_vocab]

                return output_vocab, bow_predict, _mu, _logvar, mu, logvar
            else:
                id_posts = inputs['posts']  # [batch, seq]
                len_posts = inputs['len_posts']  # [batch]
                id_responses = inputs['responses']  # [batch, seq]
                len_responses = inputs['len_responses']  # [batch]
                sampled_latents = inputs[
                    'sampled_latents']  # [batch, latent_size]
                len_decoder = id_responses.size(1) - 1
                batch_size = id_posts.size(0)

                embed_posts = torch.cat([
                    self.embedding(id_posts),
                    self.affect_embedding(id_posts)
                ], 2)
                embed_responses = torch.cat([
                    self.embedding(id_responses),
                    self.affect_embedding(id_responses)
                ], 2)

                # state: [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                                   len_posts)
                _, state_responses = self.response_encoder(
                    embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]

                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
                z = mu + (0.5 *
                          logvar).exp() * sampled_latents  # [batch, latent]

                bow_predict = self.bow_predictor(torch.cat(
                    [z, x], 1))  # [batch, num_vocab]

                first_state = self.prepare_state(torch.cat(
                    [z, x], 1))  # [num_layer, batch, dim_out]
                first_input_id = (torch.ones(
                    (1, batch_size)) * self.config.start_id).long()
                if gpu:
                    first_input_id = first_input_id.cuda()

                outputs = []
                for idx in range(len_decoder):
                    if idx == 0:
                        state = first_state
                        decoder_input = torch.cat([
                            self.embedding(first_input_id),
                            self.affect_embedding(first_input_id),
                            x.unsqueeze(0)
                        ], 2)
                    else:
                        decoder_input = torch.cat([
                            self.embedding(next_input_id),
                            self.affect_embedding(next_input_id),
                            x.unsqueeze(0)
                        ], 2)
                    output, state = self.decoder(decoder_input, state)
                    outputs.append(output)

                    vocab_prob = self.projector(
                        output)  # [1, batch, num_vocab]
                    next_input_id = torch.argmax(
                        vocab_prob, 2)  # feed the most probable word to the next step [1, batch]

                outputs = torch.cat(outputs,
                                    0).transpose(0,
                                                 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(
                    outputs)  # [batch, seq-1, num_vocab]
                return output_vocab, bow_predict, _mu, _logvar, mu, logvar
        else:  # inference
            id_posts = inputs['posts']  # [batch, seq]
            len_posts = inputs['len_posts']  # [batch]
            sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
            batch_size = id_posts.size(0)

            embed_posts = torch.cat(
                [self.embedding(id_posts),
                 self.affect_embedding(id_posts)], 2)

            # state: [layers, batch, dim]
            _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                               len_posts)
            if isinstance(state_posts, tuple):  # for LSTM, take h
                state_posts = state_posts[0]  # [layers, batch, dim]
            x = state_posts[-1, :, :]  # last layer [batch, dim]

            _mu, _logvar = self.prior_net(x)  # [batch, latent]
            z = _mu + (0.5 *
                       _logvar).exp() * sampled_latents  # [batch, latent]

            first_state = self.prepare_state(torch.cat(
                [z, x], 1))  # [num_layer, batch, dim_out]
            done = torch.tensor([0] * batch_size).bool()
            first_input_id = (torch.ones(
                (1, batch_size)) * self.config.start_id).long()
            if gpu:
                done = done.cuda()
                first_input_id = first_input_id.cuda()

            outputs = []
            for idx in range(max_len):
                if idx == 0:  # first time step
                    state = first_state  # initial decoder state
                    decoder_input = torch.cat([
                        self.embedding(first_input_id),
                        self.affect_embedding(first_input_id),
                        x.unsqueeze(0)
                    ], 2)
                else:
                    decoder_input = torch.cat([
                        self.embedding(next_input_id),
                        self.affect_embedding(next_input_id),
                        x.unsqueeze(0)
                    ], 2)
                # output: [1, batch, dim_out]
                # state: [num_layers, batch, dim_out]
                output, state = self.decoder(decoder_input, state)
                outputs.append(output)

                vocab_prob = self.projector(output)  # [1, batch, num_vocab]
                next_input_id = torch.argmax(
                    vocab_prob, 2)  # feed the most probable word to the next step [1, batch]
                _done = next_input_id.squeeze(
                    0) == self.config.end_id  # sequences that finished at this step [batch]
                done = done | _done  # all finished sequences
                if done.sum() == batch_size:  # stop early once every sequence is done
                    break

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]

            return output_vocab, None, _mu, _logvar, None, None

    def print_parameters(self):
        r""" Count trainable parameters """
        total_num = 0  # total number of trainable parameters
        for param in self.parameters():
            if param.requires_grad:
                num = 1
                for dim in param.size():
                    num *= dim
                total_num += num
        print(f"Total trainable parameters: {total_num}")

    def save_model(self, epoch, global_step, path):
        r""" 保存模型 """
        torch.save(
            {
                'affect_embedding': self.affect_embedding.state_dict(),
                'embedding': self.embedding.state_dict(),
                'post_encoder': self.post_encoder.state_dict(),
                'response_encoder': self.response_encoder.state_dict(),
                'prior_net': self.prior_net.state_dict(),
                'recognize_net': self.recognize_net.state_dict(),
                'prepare_state': self.prepare_state.state_dict(),
                'decoder': self.decoder.state_dict(),
                'projector': self.projector.state_dict(),
                'bow_predictor': self.bow_predictor.state_dict(),
                'epoch': epoch,
                'global_step': global_step
            }, path)

    def load_model(self, path):
        r""" 载入模型 """
        checkpoint = torch.load(path)
        self.affect_embedding.load_state_dict(checkpoint['affect_embedding'])
        self.embedding.load_state_dict(checkpoint['embedding'])
        self.post_encoder.load_state_dict(checkpoint['post_encoder'])
        self.response_encoder.load_state_dict(checkpoint['response_encoder'])
        self.prior_net.load_state_dict(checkpoint['prior_net'])
        self.recognize_net.load_state_dict(checkpoint['recognize_net'])
        self.prepare_state.load_state_dict(checkpoint['prepare_state'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.projector.load_state_dict(checkpoint['projector'])
        self.bow_predictor.load_state_dict(checkpoint['bow_predictor'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']
        return epoch, global_step
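Both training branches above draw the latent code with the reparameterization trick, z = mu + exp(0.5 * logvar) * eps, where sampled_latents plays the role of eps ~ N(0, I). A minimal self-contained check that this matches sampling from N(mu, exp(logvar)):

import torch

mu = torch.tensor([0.5, -1.0])
logvar = torch.tensor([0.2, -0.3])
eps = torch.randn(100000, 2)         # stands in for sampled_latents

z = mu + (0.5 * logvar).exp() * eps  # same expression as in forward()

print(z.mean(0))       # close to mu
print(z.var(0).log())  # close to logvar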
Example #5
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()

        self.config = config

        # affect embedding layer
        self.affect_embedding = AffectEmbedding(config.num_vocab,
                                                config.affect_embedding_size,
                                                config.pad_id)

        # word embedding layer
        self.embedding = WordEmbedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding dimension
            config.pad_id)  # pad_id

        # post encoder
        self.post_encoder = Encoder(
            config.post_encoder_cell_type,  # RNN type
            config.embedding_size,  # input dim
            config.post_encoder_output_size,  # output dim
            config.post_encoder_num_layers,  # number of RNN layers
            config.post_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # response encoder
        self.response_encoder = Encoder(
            config.response_encoder_cell_type,
            config.embedding_size,  # input dim
            config.response_encoder_output_size,  # output dim
            config.response_encoder_num_layers,  # number of RNN layers
            config.response_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # prior network
        self.prior_net = PriorNet(
            config.post_encoder_output_size,  # post input dim
            config.latent_size,  # latent dim
            config.dims_prior)  # hidden layer dims

        # recognition network
        self.recognize_net = RecognizeNet(
            config.post_encoder_output_size,  # post input dim
            config.response_encoder_output_size,  # response input dim
            config.latent_size,  # latent dim
            config.dims_recognize)  # hidden layer dims

        # prepare the initial decoder state
        self.prepare_state = PrepareState(
            config.post_encoder_output_size + config.latent_size,
            config.decoder_cell_type, config.decoder_output_size,
            config.decoder_num_layers)

        # decoder
        self.decoder = Decoder(
            config.decoder_cell_type,  # RNN type
            config.embedding_size,  # input dim
            config.decoder_output_size,  # output dim
            config.decoder_num_layers,  # number of RNN layers
            config.dropout)  # dropout probability

        # output layer
        self.projector = nn.Sequential(
            nn.Linear(config.decoder_output_size, config.num_vocab),
            nn.Softmax(-1))

    def forward(
            self,
            input,
            inference=False,  # whether in inference mode
            use_true=False,
            max_len=60):  # maximum decoding length

        if not inference:  # training
            if use_true:
                id_posts = input['posts']  # [batch, seq]
                len_posts = input['len_posts']  # [batch]
                id_responses = input['responses']  # [batch, seq]
                len_responses = input['len_responses']  # [batch]
                sampled_latents = input[
                    'sampled_latents']  # [batch, latent_size]

                embed_posts = self.embedding(
                    id_posts)  # [batch, seq, embed_size]
                embed_responses = self.embedding(
                    id_responses)  # [batch, seq, embed_size]

                # decoder inputs: the response with the end_id stripped
                decoder_input = embed_responses[:, :-1, :].transpose(
                    0, 1)  # [seq-1, batch, embed_size]
                len_decoder = decoder_input.size()[0]  # decoding length: seq-1
                decoder_input = decoder_input.split(
                    [1] * len_decoder,
                    0)  # per-step decoder inputs: seq-1 tensors of [1, batch, embed_size]

                # state = [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                                   len_posts)
                _, state_responses = self.response_encoder(
                    embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]

                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
                z = mu + (0.5 *
                          logvar).exp() * sampled_latents  # [batch, latent]

                first_state = self.prepare_state(torch.cat(
                    [z, x], 1))  # [num_layer, batch, dim_out]
                outputs = []

                state = first_state  # initial decoder state
                for idx in range(len_decoder):
                    input = decoder_input[
                        idx]  # current-step input [1, batch, embed_size]
                    # output: [1, batch, dim_out]
                    # state: [num_layer, batch, dim_out]
                    output, state = self.decoder(input, state)
                    outputs.append(output)

                outputs = torch.cat(outputs,
                                    0).transpose(0,
                                                 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(
                    outputs)  # [batch, seq-1, num_vocab]

                return output_vocab, _mu, _logvar, mu, logvar

            else:
                id_posts = input['posts']  # [batch, seq]
                len_posts = input['len_posts']  # [batch]
                id_responses = input['responses']  # [batch, seq]
                len_responses = input['len_responses']  # [batch]
                sampled_latents = input[
                    'sampled_latents']  # [batch, latent_size]
                len_decoder = id_responses.size()[1] - 1
                batch_size = id_posts.size()[0]
                device = id_posts.device.type

                embed_posts = self.embedding(
                    id_posts)  # [batch, seq, embed_size]
                embed_responses = self.embedding(
                    id_responses)  # [batch, seq, embed_size]

                # state = [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                                   len_posts)
                _, state_responses = self.response_encoder(
                    embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]

                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
                z = mu + (0.5 *
                          logvar).exp() * sampled_latents  # [batch, latent]

                first_state = self.prepare_state(torch.cat(
                    [z, x], 1))  # [num_layer, batch, dim_out]
                first_input_id = (torch.ones(
                    (1, batch_size)) * self.config.start_id).long()
                if device == 'cuda':
                    first_input_id = first_input_id.cuda()
                outputs = []

                for idx in range(len_decoder):
                    if idx == 0:
                        state = first_state
                        input = self.embedding(first_input_id)
                    else:
                        input = self.embedding(
                            next_input_id)  # current-step input [1, batch, embed_size]
                    output, state = self.decoder(input, state)
                    outputs.append(output)

                    vocab_prob = self.projector(
                        output)  # [1, batch, num_vocab]
                    next_input_id = torch.argmax(
                        vocab_prob, 2)  # feed the most probable word to the next step [1, batch]

                outputs = torch.cat(outputs,
                                    0).transpose(0,
                                                 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(
                    outputs)  # [batch, seq-1, num_vocab]

                return output_vocab, _mu, _logvar, mu, logvar

        else:  # inference

            id_posts = input['posts']  # [batch, seq]
            len_posts = input['len_posts']  # [batch]
            sampled_latents = input['sampled_latents']  # [batch, latent_size]
            batch_size = id_posts.size()[0]
            device = id_posts.device.type

            embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]

            # state = [layers, batch, dim]
            _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                               len_posts)
            if isinstance(state_posts, tuple):  # for LSTM, take h
                state_posts = state_posts[0]  # [layers, batch, dim]
            x = state_posts[-1, :, :]  # last layer [batch, dim]

            _mu, _logvar = self.prior_net(x)  # [batch, latent]
            z = _mu + (0.5 *
                       _logvar).exp() * sampled_latents  # [batch, latent]

            first_state = self.prepare_state(torch.cat(
                [z, x], 1))  # [num_layer, batch, dim_out]
            outputs = []

            done = torch.BoolTensor([0] * batch_size)
            first_input_id = (torch.ones(
                (1, batch_size)) * self.config.start_id).long()
            if device == 'cuda':
                done = done.cuda()
                first_input_id = first_input_id.cuda()

            for idx in range(max_len):
                if idx == 0:  # first time step
                    state = first_state  # initial decoder state
                    input = self.embedding(
                        first_input_id)  # initial decoder input [1, batch, embed_size]

                # output: [1, batch, dim_out]
                # state: [num_layers, batch, dim_out]
                output, state = self.decoder(input, state)
                outputs.append(output)

                vocab_prob = self.projector(output)  # [1, batch, num_vocab]
                next_input_id = torch.argmax(
                    vocab_prob, 2)  # feed the most probable word to the next step [1, batch]
                _done = next_input_id.squeeze(
                    0) == self.config.end_id  # sequences that finished at this step [batch]
                done = done | _done  # all finished sequences
                if done.sum() == batch_size:  # stop early once every sequence is done
                    break
                else:
                    input = self.embedding(
                        next_input_id)  # [1, batch, embed_size]

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]

            return output_vocab, _mu, _logvar, None, None

    # count parameters
    def print_parameters(self):
        def statistic_param(params):
            total_num = 0  # total number of trainable parameters
            for param in params:
                if param.requires_grad:
                    num = 1
                    for dim in param.size():
                        num *= dim
                    total_num += num
            return total_num

        print("嵌入层参数个数: %d" % statistic_param(self.embedding.parameters()))
        print("post编码器参数个数: %d" %
              statistic_param(self.post_encoder.parameters()))
        print("response编码器参数个数: %d" %
              statistic_param(self.response_encoder.parameters()))
        print("先验网络参数个数: %d" % statistic_param(self.prior_net.parameters()))
        print("识别网络参数个数: %d" %
              statistic_param(self.recognize_net.parameters()))
        print("解码器初始状态参数个数: %d" %
              statistic_param(self.prepare_state.parameters()))
        print("解码器参数个数: %d" % statistic_param(self.decoder.parameters()))
        print("输出层参数个数: %d" % statistic_param(self.projector.parameters()))
        print("参数总数: %d" % statistic_param(self.parameters()))

    # save the model
    def save_model(self, epoch, global_step, path):

        torch.save(
            {
                'affect_embedding': self.affect_embedding.state_dict(),
                'embedding': self.embedding.state_dict(),
                'post_encoder': self.post_encoder.state_dict(),
                'response_encoder': self.response_encoder.state_dict(),
                'prior_net': self.prior_net.state_dict(),
                'recognize_net': self.recognize_net.state_dict(),
                'prepare_state': self.prepare_state.state_dict(),
                'decoder': self.decoder.state_dict(),
                'projector': self.projector.state_dict(),
                'epoch': epoch,
                'global_step': global_step
            }, path)

    # load the model
    def load_model(self, path):

        checkpoint = torch.load(path)
        self.affect_embedding.load_state_dict(checkpoint['affect_embedding'])
        self.embedding.load_state_dict(checkpoint['embedding'])
        self.post_encoder.load_state_dict(checkpoint['post_encoder'])
        self.response_encoder.load_state_dict(checkpoint['response_encoder'])
        self.prior_net.load_state_dict(checkpoint['prior_net'])
        self.recognize_net.load_state_dict(checkpoint['recognize_net'])
        self.prepare_state.load_state_dict(checkpoint['prepare_state'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.projector.load_state_dict(checkpoint['projector'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']

        return epoch, global_step
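The training branches return (_mu, _logvar) from the prior network alongside (mu, logvar) from the recognition network; the training script that consumes them is not shown here, but such pairs are typically combined into a diagonal-Gaussian KL term. A sketch of that term (an assumption about the surrounding training code):

import torch

def gaussian_kld(mu_q, logvar_q, mu_p, logvar_p):
    # KL( N(mu_q, exp(logvar_q)) || N(mu_p, exp(logvar_p)) ), summed over latent dims
    return 0.5 * torch.sum(
        logvar_p - logvar_q +
        (logvar_q.exp() + (mu_q - mu_p) ** 2) / logvar_p.exp() - 1, 1)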
Example #6
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()

        self.config = config

        # affect embedding layer
        self.affect_embedding = AffectEmbedding(config.num_vocab,
                                                config.affect_embedding_size,
                                                config.pad_id)

        # word embedding layer
        self.embedding = WordEmbedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding dimension
            config.pad_id)  # pad_id

        # affect encoder
        self.affect_encoder = Encoder(
            config.encoder_decoder_cell_type,  # RNN type
            config.affect_embedding_size,  # input dim
            config.affect_encoder_output_size,  # output dim
            config.encoder_decoder_num_layers,  # number of layers
            config.encoder_bidirectional,  # bidirectional or not
            config.dropout)

        # encoder
        self.encoder = Encoder(
            config.encoder_decoder_cell_type,  # RNN type
            config.embedding_size,  # input dim
            config.encoder_decoder_output_size,  # output dim
            config.encoder_decoder_num_layers,  # number of RNN layers
            config.encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        self.attention = Attention(config.encoder_decoder_output_size,
                                   config.affect_encoder_output_size,
                                   config.attention_type,
                                   config.attention_size)

        self.prepare_state = PrepareState(
            config.encoder_decoder_cell_type,
            config.encoder_decoder_output_size +
            config.affect_encoder_output_size,
            config.encoder_decoder_output_size)

        self.linear_prepare_input = nn.Linear(
            config.embedding_size + config.affect_encoder_output_size +
            config.attention_size, config.decoder_input_size)

        # decoder
        self.decoder = Decoder(
            config.encoder_decoder_cell_type,  # RNN type
            config.decoder_input_size,  # input dim
            config.encoder_decoder_output_size,  # output dim
            config.encoder_decoder_num_layers,  # number of RNN layers
            config.dropout)  # dropout probability

        # output layer
        self.projector = nn.Sequential(
            nn.Linear(
                config.encoder_decoder_output_size + config.attention_size,
                config.num_vocab), nn.Softmax(-1))

    def forward(
            self,
            input,
            inference=False,  # whether in inference mode
            max_len=60):  # maximum decoding length

        if not inference:  # training

            id_posts = input['posts']  # [batch, seq]
            len_posts = input['len_posts']  # [batch]
            id_responses = input['responses']  # [batch, seq]
            batch_size = id_posts.size()[0]
            device = id_posts.device.type

            embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
            embed_responses = self.embedding(
                id_responses)  # [batch, seq, embed_size]
            affect_posts = self.affect_embedding(id_posts)

            # decoder inputs: the response with the end_id stripped
            decoder_input = embed_responses[:, :-1, :].transpose(
                0, 1)  # [seq-1, batch, embed_size]
            len_decoder = decoder_input.size()[0]  # decoding length: seq-1
            decoder_input = decoder_input.split(
                [1] * len_decoder, 0)  # per-step inputs: seq-1 tensors of [1, batch, embed_size]

            # state: [layers, batch, dim]
            _, state_encoder = self.encoder(embed_posts.transpose(0, 1),
                                            len_posts)
            # output_affect: [seq, batch, dim]
            output_affect, state_affect_encoder = self.affect_encoder(
                affect_posts.transpose(0, 1), len_posts)
            context_affect = output_affect[-1, :, :].unsqueeze(
                0)  # [1, batch, dim]

            outputs = []
            weights = []

            init_attn = torch.zeros(
                [1, batch_size, self.config.attention_size])
            if device == 'cuda':
                init_attn = init_attn.cuda()

            state = self.prepare_state(state_encoder,
                                       state_affect_encoder)  # initial decoder state
            for idx in range(len_decoder):
                if idx == 0:
                    input = torch.cat(
                        [decoder_input[idx], context_affect, init_attn], 2)
                else:
                    input = torch.cat([
                        decoder_input[idx], context_affect,
                        attn.transpose(0, 1)
                    ], 2)

                input = self.linear_prepare_input(input)

                # output: [1, batch, dim_out]
                # state: [num_layer, batch, dim_out]
                output, state = self.decoder(input, state)
                # attn: [batch, 1, attention_size]
                # weights: [batch, 1, encoder_len]
                attn, weight = self.attention(output.transpose(0, 1),
                                              output_affect.transpose(0, 1))
                outputs.append(torch.cat([output, attn.transpose(0, 1)], 2))
                weights.append(weight)

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq-1, dim_out]
            weights = torch.cat(weights, 1)  # [batch, seq-1, encoder_len]

            output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]

            return output_vocab, weights

        else:  # inference

            id_posts = input['posts']  # [batch, seq]
            len_posts = input['len_posts']  # [batch]
            batch_size = id_posts.size()[0]
            device = id_posts.device.type

            embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
            affect_posts = self.affect_embedding(id_posts)

            # state: [layers, batch, dim]
            _, state_encoder = self.encoder(embed_posts.transpose(0, 1),
                                            len_posts)
            output_affect, state_affect_encoder = self.affect_encoder(
                affect_posts.transpose(0, 1), len_posts)
            context_affect = output_affect[-1, :, :].unsqueeze(
                0)  # [1, batch, dim]

            outputs = []
            weights = []

            done = torch.BoolTensor([0] * batch_size)
            first_input_id = (torch.ones(
                (1, batch_size)) * self.config.start_id).long()
            init_attn = torch.zeros(
                [1, batch_size, self.config.attention_size])
            if device == 'cuda':
                done = done.cuda()
                first_input_id = first_input_id.cuda()
                init_attn = init_attn.cuda()

            for idx in range(max_len):

                if idx == 0:  # first time step
                    state = self.prepare_state(state_encoder,
                                               state_affect_encoder)  # initial decoder state
                    input = torch.cat([
                        self.embedding(first_input_id), context_affect,
                        init_attn
                    ], 2)  # initial decoder input [1, batch, embed_size]
                else:
                    input = torch.cat(
                        [input, context_affect,
                         attn.transpose(0, 1)], 2)

                input = self.linear_prepare_input(input)

                # output: [1, batch, dim_out]
                # state: [num_layers, batch, dim_out]
                output, state = self.decoder(input, state)
                # attn: [batch, 1, attention_size]
                # weights: [batch, 1, encoder_len]
                attn, weight = self.attention(output.transpose(0, 1),
                                              output_affect.transpose(0, 1))
                outputs.append(torch.cat([output, attn.transpose(0, 1)], 2))
                weights.append(weight)

                vocab_prob = self.projector(
                    torch.cat([output, attn.transpose(0, 1)],
                              2))  # [1, batch, num_vocab]
                next_input_id = torch.argmax(
                    vocab_prob, 2)  # feed the most probable word to the next step [1, batch]

                _done = next_input_id.squeeze(
                    0) == self.config.end_id  # sequences that finished at this step [batch]
                done = done | _done  # all finished sequences

                if done.sum() == batch_size:  # stop early once every sequence is done
                    break
                else:
                    input = self.embedding(
                        next_input_id)  # [1, batch, embed_size]

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq, dim_out]
            weights = torch.cat(weights, 1)
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]

            return output_vocab, weights

    # count parameters
    def print_parameters(self):
        def statistic_param(params):
            total_num = 0  # total number of trainable parameters
            for param in params:
                if param.requires_grad:
                    num = 1
                    for dim in param.size():
                        num *= dim
                    total_num += num
            return total_num

        print("嵌入层参数个数: %d" % statistic_param(self.embedding.parameters()))
        print("编码器参数个数: %d" % statistic_param(self.encoder.parameters()))
        print("情感编码器参数个数: %d" %
              statistic_param(self.affect_encoder.parameters()))
        print("准备状态参数个数: %d" %
              statistic_param(self.prepare_state.parameters()))
        print("准备输入参数个数: %d" %
              statistic_param(self.linear_prepare_input.parameters()))
        print("注意力参数个数: %d" % statistic_param(self.attention.parameters()))
        print("解码器参数个数: %d" % statistic_param(self.decoder.parameters()))
        print("输出层参数个数: %d" % statistic_param(self.projector.parameters()))
        print("参数总数: %d" % statistic_param(self.parameters()))

    # Save the model
    def save_model(self, epoch, global_step, path):
        torch.save(
            {
                'affect_embedding': self.affect_embedding.state_dict(),
                'embedding': self.embedding.state_dict(),
                'encoder': self.encoder.state_dict(),
                'affect_encoder': self.affect_encoder.state_dict(),
                'prepare_state': self.prepare_state.state_dict(),
                'linear_prepare_input': self.linear_prepare_input.state_dict(),
                'attention': self.attention.state_dict(),
                'projector': self.projector.state_dict(),
                'decoder': self.decoder.state_dict(),
                'epoch': epoch,
                'global_step': global_step
            }, path)

    # Load the model
    def load_model(self, path):
        checkpoint = torch.load(path)
        self.affect_embedding.load_state_dict(checkpoint['affect_embedding'])
        self.embedding.load_state_dict(checkpoint['embedding'])
        self.encoder.load_state_dict(checkpoint['encoder'])
        self.affect_encoder.load_state_dict(checkpoint['affect_encoder'])
        self.prepare_state.load_state_dict(checkpoint['prepare_state'])
        self.linear_prepare_input.load_state_dict(
            checkpoint['linear_prepare_input'])
        self.attention.load_state_dict(checkpoint['attention'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.projector.load_state_dict(checkpoint['projector'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']

        return epoch, global_step
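
A round-trip usage sketch for save_model/load_model above; the path and counter values are illustrative:

model.save_model(epoch=3, global_step=1200, path='checkpoint.pt')  # hypothetical path
epoch, global_step = model.load_model('checkpoint.pt')  # restores weights and bookkeeping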
Example #7
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        # Define the embedding layers
        self.embedding = Embedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding dimension
            config.pad_id,
            config.dropout)

        self.affect_embedding = Embedding(config.num_vocab,
                                          config.affect_embedding_size,
                                          config.pad_id, config.dropout)
        self.affect_embedding.embedding.weight.requires_grad = False

        # Encoder
        self.encoder = Encoder(
            config.encoder_cell_type,  # RNN cell type
            config.embedding_size + config.affect_embedding_size,  # input size
            config.encoder_output_size,  # output size
            config.encoder_num_layers,  # number of RNN layers
            config.encoder_bidirectional,  # whether bidirectional
            config.dropout)  # dropout probability

        # Output layer
        self.classifier = nn.Sequential(
            nn.Linear(config.encoder_output_size,
                      config.encoder_output_size // 2),
            nn.Linear(config.encoder_output_size // 2,
                      config.num_classifications), nn.Softmax(-1))

    def forward(self, inputs):
        x = inputs['x']  # [batch, len]
        len_x = inputs['len_x']  # [batch]

        embed_x = torch.cat(
            [self.embedding(x), self.affect_embedding(x)],
            2)  # [batch, len, embed]
        # state: [layers, batch, dim]
        _, state = self.encoder(embed_x.transpose(0, 1), len_x)
        if isinstance(state, tuple):
            state = state[0]
        context = state[-1, :, :]  # [batch, dim]

        output = self.classifier(context)  # [batch, num_classifications]
        return output
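
Because the classifier ends in Softmax, predictions follow from an argmax over the last dimension; a small sketch with hypothetical batch tensors:

probs = model({'x': id_batch, 'len_x': lengths})  # [batch, num_classifications]
pred = probs.argmax(-1)  # [batch] predicted class indices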

    def print_parameters(self):
        r""" Count and print the total number of trainable parameters. """
        total_num = 0
        for param in self.parameters():
            if param.requires_grad:
                total_num += param.numel()
        print(f"Total parameters: {total_num}")

    def save_model(self, epoch, global_step, path):
        r""" 保存模型 """
        torch.save(
            {
                'embedding': self.embedding.state_dict(),
                'affect_embedding': self.affect_embedding.state_dict(),
                'encoder': self.encoder.state_dict(),
                'classifier': self.classifier.state_dict(),
                'epoch': epoch,
                'global_step': global_step
            }, path)

    def load_model(self, path):
        r""" 载入模型 """
        checkpoint = torch.load(path)
        self.embedding.load_state_dict(checkpoint['embedding'])
        self.affect_embedding.load_state_dict(checkpoint['affect_embedding'])
        self.encoder.load_state_dict(checkpoint['encoder'])
        self.classifier.load_state_dict(checkpoint['classifier'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']
        return epoch, global_step
Example #8
File: train.py  Project: sztudy/poetry
n_iter = int(cfg.get('params', 'n_iter'))
wv_size = int(cfg.get('params', 'wv_size'))
e_hidden_size = int(cfg.get('params', 'e_hidden_size'))
d_hidden_size = int(cfg.get('params', 'd_hidden_size'))
a_hidden_size = int(cfg.get('params', 'a_hidden_size'))
d_linear_size = int(cfg.get('params', 'd_linear_size'))
encoder_lr = float(cfg.get('params', 'encoder_lr'))
decoder_lr = float(cfg.get('params', 'decoder_lr'))
momentum_rate = float(cfg.get('params', 'momentum_rate'))
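
The reads above assume an INI-style file parsed with configparser; a self-contained sketch of a [params] section that satisfies them (all values illustrative):

import configparser

cfg = configparser.ConfigParser()
cfg.read_string("""
[params]
n_iter = 20000
wv_size = 300
e_hidden_size = 256
d_hidden_size = 256
a_hidden_size = 128
d_linear_size = 512
encoder_lr = 0.01
decoder_lr = 0.01
momentum_rate = 0.9
""")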

encoder = Encoder(wv_size, e_hidden_size)
decoder = Decoder(2 * e_hidden_size, wv_size, d_hidden_size, len(ch_index),
                  d_linear_size)

try:
    encoder.load_state_dict(torch.load('model/encoder.params.pkl'))
    decoder.load_state_dict(torch.load('model/decoder.params.pkl'))
    print('loaded checkpoint')
except (FileNotFoundError, RuntimeError):
    # No (or incompatible) checkpoint found; start training from scratch.
    pass

if use_cuda:
    encoder = encoder.cuda()
    decoder = decoder.cuda()

criterion = nn.NLLLoss()
encoder_optimizer = optim.SGD(encoder.parameters(),
                              lr=encoder_lr,
                              momentum=momentum_rate)
decoder_optimizer = optim.SGD(decoder.parameters(),
                              lr=decoder_lr,
                              momentum=momentum_rate)
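
With NLLLoss and the two SGD optimizers in place, one update step typically looks like the sketch below; log_probs and target stand in for the training-loop tensors, which this excerpt does not show:

encoder_optimizer.zero_grad()
decoder_optimizer.zero_grad()
loss = criterion(log_probs, target)  # NLLLoss expects log-probabilities (e.g. from log_softmax)
loss.backward()
encoder_optimizer.step()
decoder_optimizer.step()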
Example #9
def main(args):
    print(device)

    # Path to save the trained models
    saving_model_path = args.saving_model_path

    # Resume from a checkpoint only when both model paths are provided
    check_point = bool(args.encoder_model_path and args.decoder_model_path)

    # Parse the starting epoch from the checkpoint filename; the slice [8:]
    # assumes an 8-character prefix such as "encoder-" (e.g. "encoder-12.pth").
    curr_epoch = int(args.encoder_model_path.split('/')[-1].split('.')[0]
                     [8:]) if check_point else 0

    # Load all vocabulary in the data set
    vocab = vocab_loader("vocab.txt")

    # Build the models
    encoder = Encoder(base_model=args.cnn_model,
                      embed_size=embedding_size,
                      init=not check_point,
                      train_cnn=args.train_cnn).to(device)
    decoder = Decoder(vocab_size=vocal_size,
                      input_size=embedding_size,
                      hidden_size=args.hidden_size).to(device)

    # Inception expects 299x299 inputs; other backbones use 224x224
    size_of_image = 299 if args.cnn_model == "inception" else 224
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size_of_image),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Standard ImageNet channel means and standard deviations
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    val_transform = transforms.Compose([
        transforms.Resize((size_of_image, size_of_image)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load training data
    train_loader = get_train_loader(vocab, train_image_file,
                                    train_captions_json, train_transform,
                                    args.batch_size, True)

    # Load validation data
    val_loader = get_val_loader(val_image_file, val_captions_json,
                                val_transform)

    # Load model weights from the checkpoint files
    if check_point:
        encoder.load_state_dict(
            torch.load(args.encoder_model_path, map_location=device))
        decoder.load_state_dict(
            torch.load(args.decoder_model_path, map_location=device))

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(filter(
        lambda p: p.requires_grad,
        list(encoder.parameters()) + list(decoder.parameters())),
                                 lr=args.learning_rate)

    while curr_epoch < args.num_epochs:
        curr_epoch += 1
        train(epoch=curr_epoch,
              num_epochs=args.num_epochs,
              vocab=vocab,
              train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              optimizer=optimizer,
              criterion=criterion,
              saving_model_path=saving_model_path)
        validation(vocab=vocab,
                   val_loader=val_loader,
                   encoder=encoder,
                   decoder=decoder,
                   beam_width=args.beam_width)
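
main(args) reads ten argument fields; a hedged argparse sketch covering exactly those fields (defaults are illustrative, not the project's):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--saving_model_path', default='models/')
parser.add_argument('--encoder_model_path', default='')
parser.add_argument('--decoder_model_path', default='')
parser.add_argument('--cnn_model', default='resnet')
parser.add_argument('--train_cnn', action='store_true')
parser.add_argument('--hidden_size', type=int, default=512)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--num_epochs', type=int, default=30)
parser.add_argument('--beam_width', type=int, default=3)

main(parser.parse_args())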