class StED(LAED):
    def __init__(self, corpus, config):
        super(StED, self).__init__(config)
        self.vocab = corpus.vocab
        self.rev_vocab = corpus.rev_vocab
        self.vocab_size = len(self.vocab)
        self.go_id = self.rev_vocab[BOS]
        self.eos_id = self.rev_vocab[EOS]
        if not hasattr(config, "freeze_step"):
            config.freeze_step = 6000

        # build model here
        # word embeddings
        self.x_embedding = nn.Embedding(self.vocab_size, config.embed_size)

        # recognition network for the latent action
        self.x_encoder = EncoderRNN(config.embed_size, config.dec_cell_size,
                                    dropout_p=config.dropout,
                                    rnn_cell=config.rnn_cell,
                                    variable_lengths=False)
        self.q_y = nn.Linear(config.dec_cell_size, config.y_size * config.k)
        self.x_init_connector = nn_lib.LinearConnector(config.y_size * config.k,
                                                       config.dec_cell_size,
                                                       config.rnn_cell == 'lstm')

        # skip-thought decoders for the previous and the next utterance
        self.prev_decoder = DecoderRNN(self.vocab_size, config.max_dec_len,
                                       config.embed_size, config.dec_cell_size,
                                       self.go_id, self.eos_id,
                                       n_layers=1, rnn_cell=config.rnn_cell,
                                       input_dropout_p=config.dropout,
                                       dropout_p=config.dropout,
                                       use_attention=False,
                                       use_gpu=config.use_gpu,
                                       embedding=self.x_embedding)
        self.next_decoder = DecoderRNN(self.vocab_size, config.max_dec_len,
                                       config.embed_size, config.dec_cell_size,
                                       self.go_id, self.eos_id,
                                       n_layers=1, rnn_cell=config.rnn_cell,
                                       input_dropout_p=config.dropout,
                                       dropout_p=config.dropout,
                                       use_attention=False,
                                       use_gpu=config.use_gpu,
                                       embedding=self.x_embedding)

        # Encoder-Decoder STARTS here
        self.embedding = nn.Embedding(self.vocab_size, config.embed_size,
                                      padding_idx=self.rev_vocab[PAD])
        self.utt_encoder = RnnUttEncoder(config.utt_cell_size, config.dropout,
                                         bidirection=False,  # bidirection=True in the original code
                                         use_attn=config.utt_type == 'attn_rnn',
                                         vocab_size=self.vocab_size,
                                         embedding=self.embedding)
        self.ctx_encoder = EncoderRNN(self.utt_encoder.output_size,
                                      config.ctx_cell_size, 0.0,
                                      config.dropout, config.num_layer,
                                      config.rnn_cell,
                                      variable_lengths=config.fix_batch)

        # FNN to get the prior over Y
        self.p_fc1 = nn.Linear(config.ctx_cell_size, config.ctx_cell_size)
        self.p_y = nn.Linear(config.ctx_cell_size, config.y_size * config.k)

        # connector
        self.c_init_connector = nn_lib.LinearConnector(config.y_size * config.k,
                                                       config.dec_cell_size,
                                                       config.rnn_cell == 'lstm')
        # decoder
        self.decoder = DecoderRNN(self.vocab_size, config.max_dec_len,
                                  config.embed_size, config.dec_cell_size,
                                  self.go_id, self.eos_id,
                                  n_layers=1, rnn_cell=config.rnn_cell,
                                  input_dropout_p=config.dropout,
                                  dropout_p=config.dropout,
                                  use_attention=config.use_attn,
                                  attn_size=config.dec_cell_size,
                                  attn_mode=config.attn_type,
                                  use_gpu=config.use_gpu,
                                  embedding=self.embedding)

        # force G(z, c) to actually make use of z
        if config.use_attribute:
            self.attribute_loss = criterions.NLLEntropy(-100, config)

        self.cat_connector = nn_lib.GumbelConnector(config.use_gpu)
        self.greedy_cat_connector = nn_lib.GreedyConnector()
        self.nll_loss = criterions.NLLEntropy(self.rev_vocab[PAD], self.config)
        self.cat_kl_loss = criterions.CatKLLoss()
        self.log_uniform_y = Variable(torch.log(torch.ones(1) / config.k))
        self.entropy_loss = criterions.Entropy()
        if self.use_gpu:
            self.log_uniform_y = self.log_uniform_y.cuda()
        self.kl_w = 0.0

    def valid_loss(self, loss, batch_cnt=None):
        return loss.nll

    def valid_loss_(self, loss, batch_cnt=None):
        # for the VST part there is vst_prev_nll, vst_next_nll and reg_kl;
        # for the enc-dec part there is nll, pi_kl/pi_nll and maybe attribute_nll
        vst_loss = loss.vst_prev_nll + loss.vst_next_nll
        if self.config.use_reg_kl:
            vst_loss += loss.reg_kl

        if self.config.greedy_q:
            enc_loss = loss.nll + loss.pi_nll
        else:
            enc_loss = loss.nll + loss.pi_kl

        if self.config.use_attribute:
            enc_loss += loss.attribute_nll

        if batch_cnt is not None and batch_cnt > self.config.freeze_step:
            # phase 2: freeze the latent-action recognizer and train only the enc-dec
            total_loss = enc_loss
            if self.kl_w == 0.0:
                self.kl_w = 1.0
                self.flush_valid = True
                for param in self.x_embedding.parameters():
                    param.requires_grad = False
                for param in self.x_encoder.parameters():
                    param.requires_grad = False
                for param in self.q_y.parameters():
                    param.requires_grad = False
                for param in self.x_init_connector.parameters():
                    param.requires_grad = False
                for param in self.prev_decoder.parameters():
                    param.requires_grad = False
                for param in self.next_decoder.parameters():
                    param.requires_grad = False
        else:
            # phase 1: train the variational skip-thought (VST) objective
            total_loss = vst_loss

        return total_loss

    def pxz_forward(self, batch_size, results, prev_utts, next_utts, mode, gen_type):
        # map the latent sample to the initial state of both skip-thought decoders
        dec_init_state = self.x_init_connector(results.sample_y)
        prev_dec_inputs = prev_utts[:, 0:-1]
        next_dec_inputs = next_utts[:, 0:-1]

        # decode the previous and the next utterance from the same latent code
        prev_dec_outs, prev_dec_last, prev_dec_ctx = self.prev_decoder(
            batch_size, prev_dec_inputs, dec_init_state,
            mode=mode, gen_type=gen_type, beam_size=self.config.beam_size)
        next_dec_outs, next_dec_last, next_dec_ctx = self.next_decoder(
            batch_size, next_dec_inputs, dec_init_state,
            mode=mode, gen_type=gen_type, beam_size=self.config.beam_size)

        results['prev_outs'] = prev_dec_outs
        results['prev_ctx'] = prev_dec_ctx
        results['next_outs'] = next_dec_outs
        results['next_ctx'] = next_dec_ctx
        return results

    def forward_(self, data_feed, mode, sample_n=1, gen_type='greedy', return_latent=False):
        # variant of forward() that bypasses the context encoder
        ctx_lens = data_feed['context_lens']
        batch_size = len(ctx_lens)
        ctx_utts = self.np2var(data_feed['contexts'], LONG)
        out_utts = self.np2var(data_feed['outputs'], LONG)
        prev_utts = self.np2var(data_feed['prevs'], LONG)
        next_utts = self.np2var(data_feed['nexts'], LONG)

        vst_resp = self.pxz_forward(batch_size,
                                    self.qzx_forward(out_utts[:, 1:]),
                                    prev_utts, next_utts, mode, gen_type)

        # context encoder
        c_inputs = self.utt_encoder(ctx_utts)
        # c_outs, c_last = self.ctx_encoder(c_inputs, ctx_lens)
        # c_last = c_last.squeeze(0)
        c_last = c_inputs.squeeze(1)

        # prior network
        py_logits = self.p_y(F.tanh(self.p_fc1(c_last))).view(-1, self.config.k)
        log_py = F.log_softmax(py_logits, dim=1)

        if mode != GEN:
            sample_y, y_id = vst_resp.sample_y.detach(), vst_resp.y_ids.detach()
            y_id = y_id.view(-1, self.config.y_size)
            qy_id = y_id
        else:
            qy_id = vst_resp.y_ids.detach()
            qy_id = qy_id.view(-1, self.config.y_size)
            if sample_n > 1:
                if gen_type == 'greedy':
                    temp = []
                    temp_ids = []
                    # sample the prior network N times
                    for i in range(sample_n):
                        temp_y, temp_id = self.cat_connector(py_logits, 1.0,
                                                             hard=True,
                                                             return_max_id=True)
                        temp_ids.append(temp_id.view(-1, self.config.y_size))
                        temp.append(temp_y.view(-1, self.config.k * self.config.y_size))
                    sample_y = torch.cat(temp, dim=0)
                    y_id = torch.cat(temp_ids, dim=0)
                    batch_size *= sample_n
                    c_last = c_last.repeat(sample_n, 1)
                elif gen_type == 'sample':
                    sample_y, y_id = self.greedy_cat_connector(py_logits,
                                                               self.use_gpu,
                                                               return_max_id=True)
                    sample_y = sample_y.view(-1, self.config.k * self.config.y_size).repeat(sample_n, 1)
                    y_id = y_id.view(-1, self.config.y_size).repeat(sample_n, 1)
                    c_last = c_last.repeat(sample_n, 1)
                    batch_size *= sample_n
                else:
                    raise ValueError("unknown gen_type {}".format(gen_type))
            else:
                sample_y, y_id = self.cat_connector(py_logits, 1.0, hard=True,
                                                    return_max_id=True)

        # pack attention context
        if self.config.use_attn:
            dec_init_w = self.c_init_connector.get_w()
            init_embed = dec_init_w.view(1, self.config.y_size, self.config.k,
                                         self.config.dec_cell_size)
            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k, 1)
            attn_inputs = torch.sum(temp_sample_y * init_embed, dim=2)
        else:
            attn_inputs = None

        # map sample to the initial state of the decoder
        sample_y = sample_y.view(-1, self.config.k * self.config.y_size)
        dec_init_state = self.c_init_connector(sample_y) + c_last.unsqueeze(0)

        # decode
        dec_outs, dec_last, dec_ctx = self.decoder(batch_size,
                                                   out_utts[:, 0:-1],
                                                   dec_init_state,
                                                   attn_context=attn_inputs,
                                                   mode=mode, gen_type=gen_type,
                                                   beam_size=self.config.beam_size)

        # get decoder targets
        labels = out_utts[:, 1:].contiguous()
        prev_labels = prev_utts[:, 1:].contiguous()
        next_labels = next_utts[:, 1:].contiguous()
        dec_ctx[DecoderRNN.KEY_LATENT] = y_id
        dec_ctx[DecoderRNN.KEY_POLICY] = log_py
        dec_ctx[DecoderRNN.KEY_RECOG_LATENT] = qy_id

        # compute loss or return results
        if mode == GEN:
            return dec_ctx, labels
        else:
            # VAE-related losses
            log_qy = F.log_softmax(vst_resp.qy_logits, dim=1)
            vst_prev_nll = self.nll_loss(vst_resp.prev_outs, prev_labels)
            vst_next_nll = self.nll_loss(vst_resp.next_outs, next_labels)
            avg_log_qy = torch.exp(log_qy.view(-1, self.config.y_size, self.config.k))
            avg_log_qy = torch.log(torch.mean(avg_log_qy, dim=0) + 1e-15)
            reg_kl = self.cat_kl_loss(avg_log_qy, self.log_uniform_y,
                                      batch_size, unit_average=True)
            mi = (self.entropy_loss(avg_log_qy, unit_average=True)
                  - self.entropy_loss(log_qy, unit_average=True))

            # encoder-decoder losses
            enc_dec_nll = self.nll_loss(dec_outs, labels)
            pi_kl = self.cat_kl_loss(log_qy.detach(), log_py, batch_size,
                                     unit_average=True)
            pi_nll = F.cross_entropy(py_logits.view(-1, self.config.k),
                                     y_id.view(-1))
            _, max_pi = torch.max(py_logits.view(-1, self.config.k), dim=1)
            pi_err = torch.mean((max_pi != y_id.view(-1)).float())

            if self.config.use_attribute:
                # re-encode the (soft) decoder output and check that the latent
                # code can be recovered from it
                pad_embedded = self.x_embedding.weight[self.rev_vocab[PAD]].view(
                    1, 1, self.config.embed_size)
                pad_embedded = pad_embedded.repeat(batch_size, dec_outs.size(1), 1)
                mask = torch.sign(labels).float().unsqueeze(2)
                dec_out_p = torch.exp(dec_outs.view(-1, self.vocab_size))
                dec_out_embedded = torch.matmul(dec_out_p, self.x_embedding.weight)
                dec_out_embedded = dec_out_embedded.view(-1, dec_outs.size(1),
                                                         self.config.embed_size)
                valid_out_embedded = mask * dec_out_embedded + (1.0 - mask) * pad_embedded

                x_outs, x_last = self.x_encoder(valid_out_embedded)
                x_last = x_last.transpose(0, 1).contiguous().view(
                    -1, self.config.dec_cell_size)
                qy_logits = self.q_y(x_last).view(-1, self.config.k)
                attribute_outs = F.log_softmax(qy_logits, dim=-1)
                attribute_outs = attribute_outs.view(-1, self.config.y_size,
                                                     self.config.k)
                attribute_nll = self.attribute_loss(attribute_outs, y_id.detach())
                _, max_qy = torch.max(qy_logits.view(-1, self.config.k), dim=1)
                adv_err = torch.mean((max_qy != y_id.view(-1)).float())
            else:
                attribute_nll = None
                adv_err = None

            results = Pack(nll=enc_dec_nll, pi_kl=pi_kl, pi_nll=pi_nll,
                           attribute_nll=attribute_nll,
                           vst_prev_nll=vst_prev_nll, vst_next_nll=vst_next_nll,
                           reg_kl=reg_kl, mi=mi, pi_err=pi_err, adv_err=adv_err)

            if return_latent:
                results['log_py'] = log_py
                results['log_qy'] = log_qy
                results['dec_init_state'] = dec_init_state
                results['y_ids'] = y_id

            return results

    def forward(self, data_feed, mode, sample_n=1, gen_type='greedy', return_latent=False):
        ctx_lens = data_feed['context_lens']
        batch_size = len(ctx_lens)
        ctx_utts = self.np2var(data_feed['contexts'], LONG)
        out_utts = self.np2var(data_feed['outputs'], LONG)
        prev_utts = self.np2var(data_feed['prevs'], LONG)
        next_utts = self.np2var(data_feed['nexts'], LONG)

        vst_resp = self.pxz_forward(batch_size,
                                    self.qzx_forward(out_utts[:, 1:]),
                                    prev_utts, next_utts, mode, gen_type)

        # context encoder
        c_inputs = self.utt_encoder(ctx_utts)
        c_outs, c_last = self.ctx_encoder(c_inputs, ctx_lens)
        c_last = c_last.squeeze(0)

        # prior network
        py_logits = self.p_y(F.tanh(self.p_fc1(c_last))).view(-1, self.config.k)
        log_py = F.log_softmax(py_logits, dim=1)

        if mode != GEN:
            sample_y, y_id = vst_resp.sample_y.detach(), vst_resp.y_ids.detach()
            y_id = y_id.view(-1, self.config.y_size)
            qy_id = y_id
        else:
            qy_id = vst_resp.y_ids.detach()
            qy_id = qy_id.view(-1, self.config.y_size)
            if sample_n > 1:
                if gen_type == 'greedy':
                    temp = []
                    temp_ids = []
                    # sample the prior network N times
                    for i in range(sample_n):
                        temp_y, temp_id = self.cat_connector(py_logits, 1.0,
                                                             hard=True,
                                                             return_max_id=True)
                        temp_ids.append(temp_id.view(-1, self.config.y_size))
                        temp.append(temp_y.view(-1, self.config.k * self.config.y_size))
                    sample_y = torch.cat(temp, dim=0)
                    y_id = torch.cat(temp_ids, dim=0)
                    batch_size *= sample_n
                    c_last = c_last.repeat(sample_n, 1)
                elif gen_type == 'sample':
                    sample_y, y_id = self.greedy_cat_connector(py_logits,
                                                               self.use_gpu,
                                                               return_max_id=True)
                    sample_y = sample_y.view(-1, self.config.k * self.config.y_size).repeat(sample_n, 1)
                    y_id = y_id.view(-1, self.config.y_size).repeat(sample_n, 1)
                    c_last = c_last.repeat(sample_n, 1)
                    batch_size *= sample_n
                else:
                    raise ValueError("unknown gen_type {}".format(gen_type))
            else:
                sample_y, y_id = self.cat_connector(py_logits, 1.0, hard=True,
                                                    return_max_id=True)

        # pack attention context
        if self.config.use_attn:
            dec_init_w = self.c_init_connector.get_w()
            init_embed = dec_init_w.view(1, self.config.y_size, self.config.k,
                                         self.config.dec_cell_size)
            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k, 1)
            attn_inputs = torch.sum(temp_sample_y * init_embed, dim=2)
        else:
            attn_inputs = None

        # map sample to the initial state of the decoder
        sample_y = sample_y.view(-1, self.config.k * self.config.y_size)
        dec_init_state = self.c_init_connector(sample_y) + c_last.unsqueeze(0)

        # decode
        dec_outs, dec_last, dec_ctx = self.decoder(batch_size,
                                                   out_utts[:, 0:-1],
                                                   dec_init_state,
                                                   attn_context=attn_inputs,
                                                   mode=mode, gen_type=gen_type,
                                                   beam_size=self.config.beam_size)

        # get decoder targets
        labels = out_utts[:, 1:].contiguous()
        prev_labels = prev_utts[:, 1:].contiguous()
        next_labels = next_utts[:, 1:].contiguous()
        dec_ctx[DecoderRNN.KEY_LATENT] = y_id
        dec_ctx[DecoderRNN.KEY_POLICY] = log_py
        dec_ctx[DecoderRNN.KEY_RECOG_LATENT] = qy_id

        # compute loss or return results
        if mode == GEN:
            return dec_ctx, labels
        else:
            # VAE-related losses
            log_qy = F.log_softmax(vst_resp.qy_logits, dim=1)
            vst_prev_nll = self.nll_loss(vst_resp.prev_outs, prev_labels)
            vst_next_nll = self.nll_loss(vst_resp.next_outs, next_labels)
            avg_log_qy = torch.exp(log_qy.view(-1, self.config.y_size, self.config.k))
            avg_log_qy = torch.log(torch.mean(avg_log_qy, dim=0) + 1e-15)
            reg_kl = self.cat_kl_loss(avg_log_qy, self.log_uniform_y,
                                      batch_size, unit_average=True)
            mi = (self.entropy_loss(avg_log_qy, unit_average=True)
                  - self.entropy_loss(log_qy, unit_average=True))

            # encoder-decoder losses
            enc_dec_nll = self.nll_loss(dec_outs, labels)
            pi_kl = self.cat_kl_loss(log_qy.detach(), log_py, batch_size,
                                     unit_average=True)
            pi_nll = F.cross_entropy(py_logits.view(-1, self.config.k),
                                     y_id.view(-1))
            _, max_pi = torch.max(py_logits.view(-1, self.config.k), dim=1)
            pi_err = torch.mean((max_pi != y_id.view(-1)).float())

            if self.config.use_attribute:
                # re-encode the (soft) decoder output and check that the latent
                # code can be recovered from it
                pad_embedded = self.x_embedding.weight[self.rev_vocab[PAD]].view(
                    1, 1, self.config.embed_size)
                pad_embedded = pad_embedded.repeat(batch_size, dec_outs.size(1), 1)
                mask = torch.sign(labels).float().unsqueeze(2)
                dec_out_p = torch.exp(dec_outs.view(-1, self.vocab_size))
                dec_out_embedded = torch.matmul(dec_out_p, self.x_embedding.weight)
                dec_out_embedded = dec_out_embedded.view(-1, dec_outs.size(1),
                                                         self.config.embed_size)
                valid_out_embedded = mask * dec_out_embedded + (1.0 - mask) * pad_embedded

                x_outs, x_last = self.x_encoder(valid_out_embedded)
                x_last = x_last.transpose(0, 1).contiguous().view(
                    -1, self.config.dec_cell_size)
                qy_logits = self.q_y(x_last).view(-1, self.config.k)
                attribute_outs = F.log_softmax(qy_logits, dim=-1)
                attribute_outs = attribute_outs.view(-1, self.config.y_size,
                                                     self.config.k)
                attribute_nll = self.attribute_loss(attribute_outs, y_id.detach())
                _, max_qy = torch.max(qy_logits.view(-1, self.config.k), dim=1)
                adv_err = torch.mean((max_qy != y_id.view(-1)).float())
            else:
                attribute_nll = None
                adv_err = None

            results = Pack(nll=enc_dec_nll, pi_kl=pi_kl, pi_nll=pi_nll,
                           attribute_nll=attribute_nll,
                           vst_prev_nll=vst_prev_nll, vst_next_nll=vst_next_nll,
                           reg_kl=reg_kl, mi=mi, pi_err=pi_err, adv_err=adv_err)

            if return_latent:
                results['log_py'] = log_py
                results['log_qy'] = log_qy
                results['dec_init_state'] = dec_init_state
                results['y_ids'] = y_id

            return results

    def forward_debug(self, data_feed, mode, sample_n=1, gen_type='greedy', return_latent=False):
        # plain encoder-decoder baseline without the latent variable, for debugging
        ctx_lens = data_feed['context_lens']
        batch_size = len(ctx_lens)
        ctx_utts = self.np2var(data_feed['contexts'], LONG)
        out_utts = self.np2var(data_feed['outputs'], LONG)
        prev_utts = self.np2var(data_feed['prevs'], LONG)
        next_utts = self.np2var(data_feed['nexts'], LONG)

        # context encoder
        c_inputs = self.utt_encoder(ctx_utts)
        # c_outs, c_last = self.ctx_encoder(c_inputs, ctx_lens)
        # c_last = c_last.squeeze(0)
        c_last = c_inputs.squeeze(1)

        # use the context encoding directly as the decoder initial state
        dec_init_state = c_last.unsqueeze(0)

        # decode
        dec_outs, dec_last, dec_ctx = self.decoder(batch_size,
                                                   out_utts[:, 0:-1],
                                                   dec_init_state,
                                                   attn_context=None,
                                                   mode=mode, gen_type=gen_type,
                                                   beam_size=self.config.beam_size)

        # get decoder targets
        labels = out_utts[:, 1:].contiguous()

        # compute loss or return results
        if mode == GEN:
            return dec_ctx, labels
        else:
            # encoder-decoder loss
            enc_dec_nll = self.nll_loss(dec_outs, labels)
            results = Pack(nll=enc_dec_nll)
            if return_latent:
                results['dec_init_state'] = dec_init_state
            return results
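
# Usage sketch (not part of the original module): how the two-phase schedule in
# StED.valid_loss_ is meant to be driven from a training loop. `train_batches`,
# `optimizer`, and the TEACH_FORCE mode constant are assumptions about the
# surrounding codebase, named here for illustration only.
def _sted_training_sketch(model, train_batches, optimizer):
    for batch_cnt, data_feed in enumerate(train_batches):
        losses = model.forward(data_feed, mode=TEACH_FORCE)
        # before config.freeze_step this returns the skip-thought (vst_*) loss;
        # afterwards the recognizer is frozen and the enc-dec loss is returned
        total_loss = model.valid_loss_(losses, batch_cnt)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()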

class VAE(LAED):
    def __init__(self, corpus, config):
        super(VAE, self).__init__(config)
        self.vocab = corpus.vocab
        self.rev_vocab = corpus.rev_vocab
        self.vocab_size = len(self.vocab)
        self.go_id = self.rev_vocab[BOS]
        self.eos_id = self.rev_vocab[EOS]
        if not hasattr(config, "freeze_step"):
            config.freeze_step = 6000

        # build model here
        # word embeddings
        self.x_embedding = nn.Embedding(self.vocab_size, config.embed_size)

        # recognition network for the latent action
        self.x_encoder = EncoderRNN(config.embed_size, config.dec_cell_size,
                                    dropout_p=config.dropout,
                                    rnn_cell=config.rnn_cell,
                                    variable_lengths=False)
        # twice the latent size: mean and log-variance of the Gaussian posterior
        self.q_y = nn.Linear(config.dec_cell_size, config.y_size * config.k * 2)
        self.x_init_connector = nn_lib.LinearConnector(config.y_size * config.k,
                                                       config.dec_cell_size,
                                                       config.rnn_cell == 'lstm')
        # decoder
        self.x_decoder = DecoderRNN(self.vocab_size, config.max_dec_len,
                                    config.embed_size, config.dec_cell_size,
                                    self.go_id, self.eos_id,
                                    n_layers=1, rnn_cell=config.rnn_cell,
                                    input_dropout_p=config.dropout,
                                    dropout_p=config.dropout,
                                    use_attention=False,
                                    use_gpu=config.use_gpu,
                                    embedding=self.x_embedding)

        # Encoder-Decoder STARTS here
        self.embedding = nn.Embedding(self.vocab_size, config.embed_size,
                                      padding_idx=self.rev_vocab[PAD])
        self.utt_encoder = RnnUttEncoder(config.utt_cell_size, config.dropout,
                                         use_attn=config.utt_type == 'attn_rnn',
                                         vocab_size=self.vocab_size,
                                         embedding=self.embedding)
        self.ctx_encoder = EncoderRNN(self.utt_encoder.output_size,
                                      config.ctx_cell_size, 0.0,
                                      config.dropout, config.num_layer,
                                      config.rnn_cell,
                                      variable_lengths=self.config.fix_batch)

        # FNN to get Y
        self.p_fc1 = nn.Linear(config.ctx_cell_size, config.ctx_cell_size)
        self.p_y = nn.Linear(config.ctx_cell_size, config.y_size * config.k)

        # standard-normal prior parameters for the KL term
        self.z_mu = self.np2var(
            np.zeros((self.config.batch_size, config.y_size * config.k)), FLOAT)
        self.z_logvar = self.np2var(
            np.zeros((self.config.batch_size, config.y_size * config.k)), FLOAT)

        # connector
        self.c_init_connector = nn_lib.LinearConnector(config.y_size * config.k,
                                                       config.dec_cell_size,
                                                       config.rnn_cell == 'lstm')
        # decoder
        self.decoder = DecoderRNN(self.vocab_size, config.max_dec_len,
                                  config.embed_size, config.dec_cell_size,
                                  self.go_id, self.eos_id,
                                  n_layers=1, rnn_cell=config.rnn_cell,
                                  input_dropout_p=config.dropout,
                                  dropout_p=config.dropout,
                                  use_attention=config.use_attn,
                                  attn_size=config.dec_cell_size,
                                  attn_mode=config.attn_type,
                                  use_gpu=config.use_gpu,
                                  embedding=self.embedding)

        # force G(z, c) to actually make use of z
        if config.use_attribute:
            self.attribute_loss = criterions.NLLEntropy(-100, config)

        self.cat_connector = nn_lib.GumbelConnector(config.use_gpu)
        self.greedy_cat_connector = nn_lib.GreedyConnector()
        self.gaussian_connector = nn_lib.GaussianConnector()
        self.nll_loss = criterions.NLLEntropy(self.rev_vocab[PAD], self.config)
        self.kl_loss = criterions.NormKLLoss()
        self.log_uniform_y = Variable(torch.log(torch.ones(1) / config.k))
        self.entropy_loss = criterions.Entropy()
        if self.use_gpu:
            self.log_uniform_y = self.log_uniform_y.cuda()
        self.kl_w = 0.0

    def valid_loss(self, loss, batch_cnt=None):
        vae_loss = loss.vae_nll + loss.reg_kl
        enc_loss = loss.nll
        if self.config.use_attribute:
            enc_loss += 0.1 * loss.attribute_nll

        if batch_cnt is not None and batch_cnt > self.config.freeze_step:
            # phase 2: freeze the VAE modules and train only the enc-dec
            total_loss = enc_loss
            if self.kl_w == 0.0:
                self.kl_w = 1.0
                self.flush_valid = True
                for param in self.x_embedding.parameters():
                    param.requires_grad = False
                for param in self.x_encoder.parameters():
                    param.requires_grad = False
                for param in self.q_y.parameters():
                    param.requires_grad = False
                for param in self.x_init_connector.parameters():
                    param.requires_grad = False
                for param in self.x_decoder.parameters():
                    param.requires_grad = False
        else:
            # phase 1: train the VAE objective
            total_loss = vae_loss

        return total_loss

    def qzx_forward(self, out_utts):
        # output encoder
        output_embedding = self.x_embedding(out_utts)
        x_outs, x_last = self.x_encoder(output_embedding)
        x_last = x_last.transpose(0, 1).contiguous().view(
            -1, self.config.dec_cell_size)
        qy_logits = self.q_y(x_last)
        # split into mean and log-variance halves of size y_size * k each
        mu, logvar = torch.split(qy_logits,
                                 self.config.y_size * self.config.k, dim=-1)
        sample_y = self.gaussian_connector(mu, logvar,
                                           use_gpu=self.config.use_gpu)
        return Pack(qy_logits=qy_logits, mu=mu, logvar=logvar, sample_y=sample_y)

    def pxz_forward(self, batch_size, results, out_utts, mode, gen_type):
        # map the latent sample to the initial state of the decoder
        dec_init_state = self.x_init_connector(results.sample_y)
        dec_outs, dec_last, dec_ctx = self.x_decoder(
            batch_size, out_utts[:, 0:-1], dec_init_state,
            mode=mode, gen_type=gen_type, beam_size=self.config.beam_size)
        results['dec_outs'] = dec_outs
        results['dec_ctx'] = dec_ctx
        return results

    def forward(self, data_feed, mode, sample_n=1, gen_type='greedy', return_latent=False):
        ctx_lens = data_feed['context_lens']
        batch_size = len(ctx_lens)
        ctx_utts = self.np2var(data_feed['contexts'], LONG)
        out_utts = self.np2var(data_feed['outputs'], LONG)

        # first run the VAE
        vae_resp = self.pxz_forward(batch_size,
                                    self.qzx_forward(out_utts[:, 1:]),
                                    out_utts, mode, gen_type)
        qy_mu, qy_logvar = vae_resp.mu, vae_resp.logvar

        # context encoder
        c_inputs = self.utt_encoder(ctx_utts)
        c_outs, c_last = self.ctx_encoder(c_inputs, ctx_lens)
        c_last = c_last.squeeze(0)

        # prior network
        # py_logits = self.p_y(F.tanh(self.p_fc1(c_last))).view(-1, self.config.k * 2)
        # log_py = F.log_softmax(py_logits, dim=py_logits.dim() - 1)

        if mode != GEN:
            sample_y = vae_resp.sample_y.detach()
        else:
            if sample_n > 1:
                if gen_type == 'greedy':
                    temp = []
                    # sample the posterior N times
                    for i in range(sample_n):
                        temp_y = self.gaussian_connector(qy_mu, qy_logvar,
                                                         self.use_gpu)
                        temp.append(temp_y.view(-1, self.config.k * self.config.y_size))
                    sample_y = torch.cat(temp, dim=0)
                    batch_size *= sample_n
                    c_last = c_last.repeat(sample_n, 1)
                elif gen_type == 'sample':
                    sample_y = self.gaussian_connector(qy_mu, qy_logvar,
                                                       self.use_gpu)
                    sample_y = sample_y.view(-1, self.config.k * self.config.y_size).repeat(sample_n, 1)
                    c_last = c_last.repeat(sample_n, 1)
                    batch_size *= sample_n
                else:
                    raise ValueError("unknown gen_type {}".format(gen_type))
            else:
                sample_y = self.gaussian_connector(qy_mu, qy_logvar, self.use_gpu)

        # pack attention context
        if self.config.use_attn:
            attn_inputs = c_outs
        else:
            attn_inputs = None

        # map sample to the initial state of the decoder
        sample_y = sample_y.view(-1, self.config.k * self.config.y_size)
        dec_init_state = self.c_init_connector(sample_y) + c_last.unsqueeze(0)

        # decode
        dec_outs, dec_last, dec_ctx = self.decoder(
            batch_size, out_utts[:, 0:-1], dec_init_state,
            attn_context=attn_inputs, mode=mode, gen_type=gen_type,
            beam_size=self.config.beam_size)

        # get decoder targets
        labels = out_utts[:, 1:].contiguous()
        dec_ctx[DecoderRNN.KEY_RECOG_LATENT] = None
        dec_ctx[DecoderRNN.KEY_LATENT] = None
        dec_ctx[DecoderRNN.KEY_POLICY] = None

        # compute loss or return results
        if mode == GEN:
            return dec_ctx, labels
        else:
            # VAE-related losses
            vae_nll = self.nll_loss(vae_resp.dec_outs, labels)
            reg_kl = self.kl_loss(qy_mu, qy_logvar, self.z_mu, self.z_logvar)

            # encoder-decoder losses
            enc_dec_nll = self.nll_loss(dec_outs, labels)

            if self.config.use_attribute:
                # FIXME: y_id is never defined in this continuous-latent model;
                # this branch is a leftover from the categorical variant and will
                # raise NameError if config.use_attribute is set.
                pad_embedded = self.x_embedding.weight[self.rev_vocab[PAD]].view(
                    1, 1, self.config.embed_size)
                pad_embedded = pad_embedded.repeat(batch_size, dec_outs.size(1), 1)
                mask = torch.sign(labels).float().unsqueeze(2)
                dec_out_p = torch.exp(dec_outs.view(-1, self.vocab_size))
                dec_out_embedded = torch.matmul(dec_out_p, self.x_embedding.weight)
                dec_out_embedded = dec_out_embedded.view(-1, dec_outs.size(1),
                                                         self.config.embed_size)
                valid_out_embedded = mask * dec_out_embedded + (1.0 - mask) * pad_embedded

                x_outs, x_last = self.x_encoder(valid_out_embedded)
                x_last = x_last.transpose(0, 1).contiguous().view(
                    -1, self.config.dec_cell_size)
                qy_logits = self.q_y(x_last).view(-1, self.config.k)
                attribute_nll = F.cross_entropy(qy_logits, y_id.view(-1).detach())
                _, max_qy = torch.max(qy_logits.view(-1, self.config.k), dim=1)
                adv_err = torch.mean((max_qy != y_id.view(-1)).float())
            else:
                attribute_nll = None
                adv_err = None

            results = Pack(nll=enc_dec_nll, attribute_nll=attribute_nll,
                           vae_nll=vae_nll, reg_kl=reg_kl, adv_err=adv_err)

            if return_latent:
                results['sample_y'] = sample_y
                results['dec_init_state'] = dec_init_state

            return results
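
# For reference, a minimal sketch of the reparameterization step that
# nn_lib.GaussianConnector is assumed to implement for VAE.qzx_forward:
# z = mu + sigma * eps with eps ~ N(0, I), which keeps the sample
# differentiable w.r.t. the posterior parameters. The actual implementation
# lives in nn_lib; this standalone version is for illustration only.
def _gaussian_reparameterize_sketch(mu, logvar):
    std = torch.exp(0.5 * logvar)   # logvar = log(sigma^2)
    eps = torch.randn_like(std)     # one draw from the standard normal
    return mu + eps * std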