import torch
import torch.nn as nn

# Embedding, WordEmbedding, AffectEmbedding, Encoder, PriorNet, RecognizeNet,
# PrepareState, and Decoder are assumed to be this project's own modules.


# Baseline CVAE model (word embeddings only; constructor shown).
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        # word embedding layer
        self.embedding = Embedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding dimension
            config.pad_id,  # pad_id
            config.dropout)

        # post encoder
        self.post_encoder = Encoder(
            config.post_encoder_cell_type,  # RNN cell type
            config.embedding_size,  # input size
            config.post_encoder_output_size,  # output size
            config.post_encoder_num_layers,  # number of RNN layers
            config.post_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # response encoder
        self.response_encoder = Encoder(
            config.response_encoder_cell_type,
            config.embedding_size,  # input size
            config.response_encoder_output_size,  # output size
            config.response_encoder_num_layers,  # number of RNN layers
            config.response_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # prior network
        self.prior_net = PriorNet(
            config.post_encoder_output_size,  # post input size
            config.latent_size,  # latent variable size
            config.dims_prior)  # hidden layer sizes

        # recognition network
        self.recognize_net = RecognizeNet(
            config.post_encoder_output_size,  # post input size
            config.response_encoder_output_size,  # response input size
            config.latent_size,  # latent variable size
            config.dims_recognize)  # hidden layer sizes

        # prepares the initial decoder state
        self.prepare_state = PrepareState(
            config.post_encoder_output_size + config.latent_size,
            config.decoder_cell_type,
            config.decoder_output_size,
            config.decoder_num_layers)

        # decoder
        self.decoder = Decoder(
            config.decoder_cell_type,  # RNN cell type
            config.embedding_size,  # input size
            config.decoder_output_size,  # output size
            config.decoder_num_layers,  # number of RNN layers
            config.dropout)  # dropout probability

        # output layer
        self.projector = nn.Sequential(
            nn.Linear(config.decoder_output_size, config.num_vocab),
            nn.Softmax(-1))
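
# --- Illustrative sketch (not from the original source). A minimal config
# object the constructor above could accept: the field names come from the
# constructor arguments, while every value is an assumed placeholder.
from argparse import Namespace

example_config = Namespace(
    num_vocab=30000, embedding_size=300, pad_id=0, dropout=0.1,
    post_encoder_cell_type='GRU', post_encoder_output_size=512,
    post_encoder_num_layers=2, post_encoder_bidirectional=True,
    response_encoder_cell_type='GRU', response_encoder_output_size=512,
    response_encoder_num_layers=2, response_encoder_bidirectional=True,
    latent_size=64, dims_prior=[256], dims_recognize=[256],
    decoder_cell_type='GRU', decoder_output_size=512, decoder_num_layers=2)
# model = Model(example_config)  # hypothetical usage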

# CVAE variant with a frozen affect embedding concatenated to the word
# embedding at the encoder and decoder inputs (constructor shown); the two
# encoders share one configuration.
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        # word embedding layer
        self.embedding = Embedding(config.num_vocab,
                                   config.embedding_size,
                                   config.pad_id,
                                   config.dropout)

        # affect embedding layer (frozen)
        self.affect_embedding = Embedding(config.num_vocab,
                                          config.affect_embedding_size,
                                          config.pad_id,
                                          config.dropout)
        self.affect_embedding.embedding.weight.requires_grad = False

        # post and response encoders
        self.post_encoder = Encoder(config.encoder_cell_type,
                                    config.embedding_size + config.affect_embedding_size,
                                    config.encoder_output_size,
                                    config.encoder_num_layers,
                                    config.encoder_bidirectional,
                                    config.dropout)
        self.response_encoder = Encoder(config.encoder_cell_type,
                                        config.embedding_size + config.affect_embedding_size,
                                        config.encoder_output_size,
                                        config.encoder_num_layers,
                                        config.encoder_bidirectional,
                                        config.dropout)

        # prior network p(z|x) and recognition network q(z|x,y)
        self.prior_net = PriorNet(config.encoder_output_size,
                                  config.latent_size,
                                  config.dims_prior)
        self.recognize_net = RecognizeNet(config.encoder_output_size,
                                          config.encoder_output_size,
                                          config.latent_size,
                                          config.dims_recognize)

        # prepares the initial decoder state from [z; x]
        self.prepare_state = PrepareState(config.encoder_output_size + config.latent_size,
                                          config.decoder_cell_type,
                                          config.decoder_output_size,
                                          config.decoder_num_layers)

        # the decoder input concatenates the word embedding, the affect
        # embedding, and the post context x
        self.decoder = Decoder(config.decoder_cell_type,
                               config.embedding_size + config.affect_embedding_size + config.encoder_output_size,
                               config.decoder_output_size,
                               config.decoder_num_layers,
                               config.dropout)

        # output layer
        self.projector = nn.Sequential(
            nn.Linear(config.decoder_output_size, config.num_vocab),
            nn.Softmax(-1))
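
# --- Illustrative sketch (not from the original source). Because the affect
# embedding is frozen (requires_grad = False), its weights are presumably
# copied in from a precomputed table, e.g. per-word valence/arousal/dominance
# scores. `affect_matrix` is a hypothetical [num_vocab, affect_embedding_size]
# array; only the `affect_embedding.embedding.weight` attribute is taken from
# the code above.
def load_affect_embeddings(model, affect_matrix):
    with torch.no_grad():
        model.affect_embedding.embedding.weight.copy_(
            torch.as_tensor(affect_matrix, dtype=torch.float32))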

# CVAE variant with affect embeddings and a bag-of-words (BOW) auxiliary
# predictor; the decoder also receives the post context x at every step.
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        # word embedding layer
        self.embedding = Embedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding dimension
            config.pad_id,  # pad_id
            config.dropout)

        # affect embedding layer (frozen)
        self.affect_embedding = Embedding(config.num_vocab,
                                          config.affect_embedding_size,
                                          config.pad_id,
                                          config.dropout)
        self.affect_embedding.embedding.weight.requires_grad = False

        # post encoder
        self.post_encoder = Encoder(
            config.post_encoder_cell_type,  # RNN cell type
            config.embedding_size + config.affect_embedding_size,  # input size
            config.post_encoder_output_size,  # output size
            config.post_encoder_num_layers,  # number of RNN layers
            config.post_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # response encoder
        self.response_encoder = Encoder(
            config.response_encoder_cell_type,
            config.embedding_size + config.affect_embedding_size,  # input size
            config.response_encoder_output_size,  # output size
            config.response_encoder_num_layers,  # number of RNN layers
            config.response_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # prior network
        self.prior_net = PriorNet(
            config.post_encoder_output_size,  # post input size
            config.latent_size,  # latent variable size
            config.dims_prior)  # hidden layer sizes

        # recognition network
        self.recognize_net = RecognizeNet(
            config.post_encoder_output_size,  # post input size
            config.response_encoder_output_size,  # response input size
            config.latent_size,  # latent variable size
            config.dims_recognize)  # hidden layer sizes

        # prepares the initial decoder state
        self.prepare_state = PrepareState(
            config.post_encoder_output_size + config.latent_size,
            config.decoder_cell_type,
            config.decoder_output_size,
            config.decoder_num_layers)

        # decoder
        self.decoder = Decoder(
            config.decoder_cell_type,  # RNN cell type
            config.embedding_size + config.affect_embedding_size + config.post_encoder_output_size,  # input size
            config.decoder_output_size,  # output size
            config.decoder_num_layers,  # number of RNN layers
            config.dropout)  # dropout probability

        # bag-of-words predictor
        self.bow_predictor = nn.Sequential(
            nn.Linear(config.post_encoder_output_size + config.latent_size,
                      config.num_vocab),
            nn.Softmax(-1))

        # output layer
        self.projector = nn.Sequential(
            nn.Linear(config.decoder_output_size, config.num_vocab),
            nn.Softmax(-1))

    def forward(self, inputs, inference=False, use_true=False, max_len=60, gpu=True):
        if not inference:  # training
            if use_true:  # teacher forcing: feed ground-truth tokens to the decoder
                id_posts = inputs['posts']  # [batch, seq]
                len_posts = inputs['len_posts']  # [batch]
                id_responses = inputs['responses']  # [batch, seq]
                len_responses = inputs['len_responses']  # [batch]
                sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
                len_decoder = id_responses.size(1) - 1

                embed_posts = torch.cat(
                    [self.embedding(id_posts), self.affect_embedding(id_posts)], 2)
                embed_responses = torch.cat(
                    [self.embedding(id_responses), self.affect_embedding(id_responses)], 2)
                # state: [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1), len_posts)
                _, state_responses = self.response_encoder(embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]

                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
                # reparameterization trick: z = mu + sigma * epsilon
                z = mu + (0.5 * logvar).exp() * sampled_latents  # [batch, latent]

                bow_predict = self.bow_predictor(torch.cat([z, x], 1))  # [batch, num_vocab]
                first_state = self.prepare_state(torch.cat([z, x], 1))  # [num_layer, batch, dim_out]
                decoder_inputs = embed_responses[:, :-1, :].transpose(0, 1)  # [seq-1, batch, embed_size]
                decoder_inputs = decoder_inputs.split([1] * len_decoder, 0)  # seq-1 tensors of [1, batch, embed_size]

                outputs = []
                for idx in range(len_decoder):
                    if idx == 0:
                        state = first_state  # initial decoder state
                    decoder_input = torch.cat([decoder_inputs[idx], x.unsqueeze(0)], 2)
                    # output: [1, batch, dim_out]
                    # state: [num_layer, batch, dim_out]
                    output, state = self.decoder(decoder_input, state)
                    outputs.append(output)

                outputs = torch.cat(outputs, 0).transpose(0, 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]

                return output_vocab, bow_predict, _mu, _logvar, mu, logvar
            else:  # free running: feed the decoder its own previous prediction
                id_posts = inputs['posts']  # [batch, seq]
                len_posts = inputs['len_posts']  # [batch]
                id_responses = inputs['responses']  # [batch, seq]
                len_responses = inputs['len_responses']  # [batch]
                sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
                len_decoder = id_responses.size(1) - 1
                batch_size = id_posts.size(0)

                embed_posts = torch.cat(
                    [self.embedding(id_posts), self.affect_embedding(id_posts)], 2)
                embed_responses = torch.cat(
                    [self.embedding(id_responses), self.affect_embedding(id_responses)], 2)
                # state: [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1), len_posts)
                _, state_responses = self.response_encoder(embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]

                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
                z = mu + (0.5 * logvar).exp() * sampled_latents  # [batch, latent]

                bow_predict = self.bow_predictor(torch.cat([z, x], 1))  # [batch, num_vocab]
                first_state = self.prepare_state(torch.cat([z, x], 1))  # [num_layer, batch, dim_out]

                first_input_id = (torch.ones((1, batch_size)) * self.config.start_id).long()
                if gpu:
                    first_input_id = first_input_id.cuda()

                outputs = []
                for idx in range(len_decoder):
                    if idx == 0:
                        state = first_state
                        decoder_input = torch.cat([self.embedding(first_input_id),
                                                   self.affect_embedding(first_input_id),
                                                   x.unsqueeze(0)], 2)
                    else:
                        decoder_input = torch.cat([self.embedding(next_input_id),
                                                   self.affect_embedding(next_input_id),
                                                   x.unsqueeze(0)], 2)
                    output, state = self.decoder(decoder_input, state)
                    outputs.append(output)

                    vocab_prob = self.projector(output)  # [1, batch, num_vocab]
                    next_input_id = torch.argmax(vocab_prob, 2)  # pick the most probable word as the next-step input [1, batch]

                outputs = torch.cat(outputs, 0).transpose(0, 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]

                return output_vocab, bow_predict, _mu, _logvar, mu, logvar
        else:  # inference
            id_posts = inputs['posts']  # [batch, seq]
            len_posts = inputs['len_posts']  # [batch]
            sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
            batch_size = id_posts.size(0)

            embed_posts = torch.cat(
                [self.embedding(id_posts), self.affect_embedding(id_posts)], 2)
            # state: [layers, batch, dim]
            _, state_posts = self.post_encoder(embed_posts.transpose(0, 1), len_posts)
            if isinstance(state_posts, tuple):  # for an LSTM, keep only h
                state_posts = state_posts[0]  # [layers, batch, dim]
            x = state_posts[-1, :, :]  # take the last layer [batch, dim]

            _mu, _logvar = self.prior_net(x)  # [batch, latent]
            # at test time z is sampled from the prior
            z = _mu + (0.5 * _logvar).exp() * sampled_latents  # [batch, latent]
            first_state = self.prepare_state(torch.cat([z, x], 1))  # [num_layer, batch, dim_out]

            done = torch.tensor([0] * batch_size).bool()
            first_input_id = (torch.ones((1, batch_size)) * self.config.start_id).long()
            if gpu:
                done = done.cuda()
                first_input_id = first_input_id.cuda()

            outputs = []
            for idx in range(max_len):
                if idx == 0:  # first time step
                    state = first_state  # initial decoder state
                    decoder_input = torch.cat([self.embedding(first_input_id),
                                               self.affect_embedding(first_input_id),
                                               x.unsqueeze(0)], 2)
                else:
                    decoder_input = torch.cat([self.embedding(next_input_id),
                                               self.affect_embedding(next_input_id),
                                               x.unsqueeze(0)], 2)
                # output: [1, batch, dim_out]
                # state: [num_layers, batch, dim_out]
                output, state = self.decoder(decoder_input, state)
                outputs.append(output)

                vocab_prob = self.projector(output)  # [1, batch, num_vocab]
                next_input_id = torch.argmax(vocab_prob, 2)  # pick the most probable word as the next-step input [1, batch]

                _done = next_input_id.squeeze(0) == self.config.end_id  # sequences finished at this step [batch]
                done = done | _done  # all finished sequences
                if done.sum() == batch_size:  # stop early once every sequence has finished
                    break

            outputs = torch.cat(outputs, 0).transpose(0, 1)  # [batch, seq, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]

            # no bag-of-words prediction or recognition statistics at inference time
            return output_vocab, None, _mu, _logvar, None, None

    def print_parameters(self):
        r"""Count the trainable parameters."""
        total_num = 0  # total number of parameters
        for param in self.parameters():
            num = 1
            if param.requires_grad:
                size = param.size()
                for dim in size:
                    num *= dim
                total_num += num
        print(f"total parameters: {total_num}")

    def save_model(self, epoch, global_step, path):
        r"""Save the model."""
        torch.save({'affect_embedding': self.affect_embedding.state_dict(),
                    'embedding': self.embedding.state_dict(),
                    'post_encoder': self.post_encoder.state_dict(),
                    'response_encoder': self.response_encoder.state_dict(),
                    'prior_net': self.prior_net.state_dict(),
                    'recognize_net': self.recognize_net.state_dict(),
                    'prepare_state': self.prepare_state.state_dict(),
                    'decoder': self.decoder.state_dict(),
                    'projector': self.projector.state_dict(),
                    'bow_predictor': self.bow_predictor.state_dict(),
                    'epoch': epoch,
                    'global_step': global_step}, path)

    def load_model(self, path):
        r"""Load the model."""
        checkpoint = torch.load(path)
        self.affect_embedding.load_state_dict(checkpoint['affect_embedding'])
        self.embedding.load_state_dict(checkpoint['embedding'])
        self.post_encoder.load_state_dict(checkpoint['post_encoder'])
        self.response_encoder.load_state_dict(checkpoint['response_encoder'])
        self.prior_net.load_state_dict(checkpoint['prior_net'])
        self.recognize_net.load_state_dict(checkpoint['recognize_net'])
        self.prepare_state.load_state_dict(checkpoint['prepare_state'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.projector.load_state_dict(checkpoint['projector'])
        self.bow_predictor.load_state_dict(checkpoint['bow_predictor'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']
        return epoch, global_step
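
# --- Illustrative sketch (not from the original source). One way to combine
# the six outputs of forward() into a CVAE training loss: reconstruction NLL,
# the KL term between the recognition and prior Gaussians, and the
# bag-of-words auxiliary loss. Only the output shapes come from the code
# above; the masking and weighting details are assumptions.
import torch.nn.functional as F

def cvae_loss(output_vocab, bow_predict, _mu, _logvar, mu, logvar,
              id_responses, pad_id):
    # forward() already applies Softmax, so take the log of the probabilities;
    # targets are the response tokens shifted left by one.
    target = id_responses[:, 1:]                      # [batch, seq-1]
    log_prob = output_vocab.clamp_min(1e-12).log()    # [batch, seq-1, num_vocab]
    nll = F.nll_loss(log_prob.transpose(1, 2), target, ignore_index=pad_id)

    # KL(q(z|x,y) || p(z|x)) for two diagonal Gaussians, averaged over the batch
    kld = 0.5 * torch.sum(
        _logvar - logvar + (logvar.exp() + (mu - _mu) ** 2) / _logvar.exp() - 1,
        dim=1).mean()

    # bag-of-words loss: every target token should be probable under bow_predict
    bow_log_prob = bow_predict.clamp_min(1e-12).log()  # [batch, num_vocab]
    bow = F.nll_loss(bow_log_prob.unsqueeze(2).expand(-1, -1, target.size(1)),
                     target, ignore_index=pad_id)

    return nll + kld + bow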

# CVAE variant whose forward pass uses word embeddings only; the affect
# embedding is constructed and checkpointed but not used in forward().
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        # affect embedding layer
        self.affect_embedding = AffectEmbedding(config.num_vocab,
                                                config.affect_embedding_size,
                                                config.pad_id)

        # word embedding layer
        self.embedding = WordEmbedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding dimension
            config.pad_id)  # pad_id

        # post encoder
        self.post_encoder = Encoder(
            config.post_encoder_cell_type,  # RNN cell type
            config.embedding_size,  # input size
            config.post_encoder_output_size,  # output size
            config.post_encoder_num_layers,  # number of RNN layers
            config.post_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # response encoder
        self.response_encoder = Encoder(
            config.response_encoder_cell_type,
            config.embedding_size,  # input size
            config.response_encoder_output_size,  # output size
            config.response_encoder_num_layers,  # number of RNN layers
            config.response_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # prior network
        self.prior_net = PriorNet(
            config.post_encoder_output_size,  # post input size
            config.latent_size,  # latent variable size
            config.dims_prior)  # hidden layer sizes

        # recognition network
        self.recognize_net = RecognizeNet(
            config.post_encoder_output_size,  # post input size
            config.response_encoder_output_size,  # response input size
            config.latent_size,  # latent variable size
            config.dims_recognize)  # hidden layer sizes

        # prepares the initial decoder state
        self.prepare_state = PrepareState(
            config.post_encoder_output_size + config.latent_size,
            config.decoder_cell_type,
            config.decoder_output_size,
            config.decoder_num_layers)

        # decoder
        self.decoder = Decoder(
            config.decoder_cell_type,  # RNN cell type
            config.embedding_size,  # input size
            config.decoder_output_size,  # output size
            config.decoder_num_layers,  # number of RNN layers
            config.dropout)  # dropout probability

        # output layer
        self.projector = nn.Sequential(
            nn.Linear(config.decoder_output_size, config.num_vocab),
            nn.Softmax(-1))

    def forward(self,
                inputs,
                inference=False,  # inference mode or not
                use_true=False,  # use ground-truth tokens while decoding
                max_len=60):  # maximum decoding length
        if not inference:  # training
            if use_true:  # teacher forcing
                id_posts = inputs['posts']  # [batch, seq]
                len_posts = inputs['len_posts']  # [batch]
                id_responses = inputs['responses']  # [batch, seq]
                len_responses = inputs['len_responses']  # [batch]
                sampled_latents = inputs['sampled_latents']  # [batch, latent_size]

                embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
                embed_responses = self.embedding(id_responses)  # [batch, seq, embed_size]
                # the decoder input is the response without its end_id
                decoder_inputs = embed_responses[:, :-1, :].transpose(0, 1)  # [seq-1, batch, embed_size]
                len_decoder = decoder_inputs.size(0)  # decoding length: seq-1
                decoder_inputs = decoder_inputs.split([1] * len_decoder, 0)  # per-step inputs: seq-1 tensors of [1, batch, embed_size]

                # state: [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1), len_posts)
                _, state_responses = self.response_encoder(embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]

                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
                z = mu + (0.5 * logvar).exp() * sampled_latents  # [batch, latent]
                first_state = self.prepare_state(torch.cat([z, x], 1))  # [num_layer, batch, dim_out]

                outputs = []
                for idx in range(len_decoder):
                    if idx == 0:
                        state = first_state  # initial decoder state
                    decoder_input = decoder_inputs[idx]  # current step input [1, batch, embed_size]
                    # output: [1, batch, dim_out]
                    # state: [num_layer, batch, dim_out]
                    output, state = self.decoder(decoder_input, state)
                    outputs.append(output)

                outputs = torch.cat(outputs, 0).transpose(0, 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]

                return output_vocab, _mu, _logvar, mu, logvar
            else:  # free running
                id_posts = inputs['posts']  # [batch, seq]
                len_posts = inputs['len_posts']  # [batch]
                id_responses = inputs['responses']  # [batch, seq]
                len_responses = inputs['len_responses']  # [batch]
                sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
                len_decoder = id_responses.size(1) - 1
                batch_size = id_posts.size(0)
                device = id_posts.device.type

                embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
                embed_responses = self.embedding(id_responses)  # [batch, seq, embed_size]
                # state: [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1), len_posts)
                _, state_responses = self.response_encoder(embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]

                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
                z = mu + (0.5 * logvar).exp() * sampled_latents  # [batch, latent]
                first_state = self.prepare_state(torch.cat([z, x], 1))  # [num_layer, batch, dim_out]

                first_input_id = (torch.ones((1, batch_size)) * self.config.start_id).long()
                if device == 'cuda':
                    first_input_id = first_input_id.cuda()

                outputs = []
                for idx in range(len_decoder):
                    if idx == 0:
                        state = first_state
                        decoder_input = self.embedding(first_input_id)
                    else:
                        decoder_input = self.embedding(next_input_id)  # current step input [1, batch, embed_size]
                    output, state = self.decoder(decoder_input, state)
                    outputs.append(output)

                    vocab_prob = self.projector(output)  # [1, batch, num_vocab]
                    next_input_id = torch.argmax(vocab_prob, 2)  # pick the most probable word as the next-step input [1, batch]

                outputs = torch.cat(outputs, 0).transpose(0, 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]

                return output_vocab, _mu, _logvar, mu, logvar
        else:  # inference
            id_posts = inputs['posts']  # [batch, seq]
            len_posts = inputs['len_posts']  # [batch]
            sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
            batch_size = id_posts.size(0)
            device = id_posts.device.type

            embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
            # state: [layers, batch, dim]
            _, state_posts = self.post_encoder(embed_posts.transpose(0, 1), len_posts)
            if isinstance(state_posts, tuple):  # for an LSTM, keep only h
                state_posts = state_posts[0]  # [layers, batch, dim]
            x = state_posts[-1, :, :]  # take the last layer [batch, dim]

            _mu, _logvar = self.prior_net(x)  # [batch, latent]
            # at test time z is sampled from the prior
            z = _mu + (0.5 * _logvar).exp() * sampled_latents  # [batch, latent]
            first_state = self.prepare_state(torch.cat([z, x], 1))  # [num_layer, batch, dim_out]

            outputs = []
            done = torch.BoolTensor([0] * batch_size)
            first_input_id = (torch.ones((1, batch_size)) * self.config.start_id).long()
            if device == 'cuda':
                done = done.cuda()
                first_input_id = first_input_id.cuda()

            for idx in range(max_len):
                if idx == 0:  # first time step
                    state = first_state  # initial decoder state
                    decoder_input = self.embedding(first_input_id)  # initial decoder input [1, batch, embed_size]
                # output: [1, batch, dim_out]
                # state: [num_layers, batch, dim_out]
                output, state = self.decoder(decoder_input, state)
                outputs.append(output)

                vocab_prob = self.projector(output)  # [1, batch, num_vocab]
                next_input_id = torch.argmax(vocab_prob, 2)  # pick the most probable word as the next-step input [1, batch]

                _done = next_input_id.squeeze(0) == self.config.end_id  # sequences finished at this step [batch]
                done = done | _done  # all finished sequences
                if done.sum() == batch_size:  # stop early once every sequence has finished
                    break
                else:
                    decoder_input = self.embedding(next_input_id)  # [1, batch, embed_size]

            outputs = torch.cat(outputs, 0).transpose(0, 1)  # [batch, seq, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]

            return output_vocab, _mu, _logvar, None, None

    # count the parameters
    def print_parameters(self):
        def statistic_param(params):
            total_num = 0  # total number of parameters
            for param in params:
                num = 1
                if param.requires_grad:
                    size = param.size()
                    for dim in size:
                        num *= dim
                    total_num += num
            return total_num

        print("embedding parameters: %d" % statistic_param(self.embedding.parameters()))
        print("post encoder parameters: %d" % statistic_param(self.post_encoder.parameters()))
        print("response encoder parameters: %d" % statistic_param(self.response_encoder.parameters()))
        print("prior network parameters: %d" % statistic_param(self.prior_net.parameters()))
        print("recognition network parameters: %d" % statistic_param(self.recognize_net.parameters()))
        print("decoder initial-state parameters: %d" % statistic_param(self.prepare_state.parameters()))
        print("decoder parameters: %d" % statistic_param(self.decoder.parameters()))
        print("output layer parameters: %d" % statistic_param(self.projector.parameters()))
        print("total parameters: %d" % statistic_param(self.parameters()))

    # save the model
    def save_model(self, epoch, global_step, path):
        torch.save({'affect_embedding': self.affect_embedding.state_dict(),
                    'embedding': self.embedding.state_dict(),
                    'post_encoder': self.post_encoder.state_dict(),
                    'response_encoder': self.response_encoder.state_dict(),
                    'prior_net': self.prior_net.state_dict(),
                    'recognize_net': self.recognize_net.state_dict(),
                    'prepare_state': self.prepare_state.state_dict(),
                    'decoder': self.decoder.state_dict(),
                    'projector': self.projector.state_dict(),
                    'epoch': epoch,
                    'global_step': global_step}, path)

    # load the model
    def load_model(self, path):
        checkpoint = torch.load(path)
        self.affect_embedding.load_state_dict(checkpoint['affect_embedding'])
        self.embedding.load_state_dict(checkpoint['embedding'])
        self.post_encoder.load_state_dict(checkpoint['post_encoder'])
        self.response_encoder.load_state_dict(checkpoint['response_encoder'])
        self.prior_net.load_state_dict(checkpoint['prior_net'])
        self.recognize_net.load_state_dict(checkpoint['recognize_net'])
        self.prepare_state.load_state_dict(checkpoint['prepare_state'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.projector.load_state_dict(checkpoint['projector'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']
        return epoch, global_step
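
# --- Illustrative sketch (not from the original source). A hypothetical
# helper that decodes greedy responses with the inference branch of forward()
# above; the dictionary keys match the ones forward() reads, while the helper
# name and latent sampling are assumptions.
def greedy_respond(model, config, id_posts, len_posts, max_len=60):
    inputs = {
        'posts': id_posts,        # [batch, seq]
        'len_posts': len_posts,   # [batch]
        'sampled_latents': torch.randn(id_posts.size(0), config.latent_size,
                                       device=id_posts.device),
    }
    model.eval()
    with torch.no_grad():
        output_vocab, _mu, _logvar, _, _ = model(inputs, inference=True,
                                                 max_len=max_len)
    return output_vocab.argmax(-1)  # [batch, decoded_len] greedy token ids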