def __init__(self, config, vocab_size, PAD_token=0):
    """Assemble the DFVAE sub-networks, optimizers, LR scheduler and loss.

    Args:
        config: hyper-parameter mapping; this constructor reads 'maxlen',
            'clip', 'lambda_gp', 'temp', 'emb_size', 'n_hidden', 'n_layers',
            'noise_radius', 'z_size', 'lr_ae' and 'lr_gan_g'.
        vocab_size: number of rows of the word-embedding table; also passed
            to the Decoder.
        PAD_token: ``padding_idx`` of the embedding table.
    """
    super(DFVAE, self).__init__()
    self.vocab_size = vocab_size
    self.maxlen = config['maxlen']
    self.clip = config['clip']
    self.lambda_gp = config['lambda_gp']
    self.temp = config['temp']

    emb_dim = config['emb_size']
    hidden = config['n_hidden']
    noise = config['noise_radius']

    # One embedding table, handed to both the utterance encoder and the decoder.
    self.embedder = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD_token)
    self.utt_encoder = Encoder(self.embedder, emb_dim, hidden, True,
                               config['n_layers'], noise)
    self.context_encoder = ContextEncoder(self.utt_encoder, hidden * 2 + 2,
                                          hidden, 1, noise)

    z_dim = config['z_size']
    self.prior_net = Variation(hidden, z_dim)      # p(e|c)
    self.post_net = Variation(hidden * 3, z_dim)   # q(e|c,x)

    def make_flows(flow_cls):
        # Three identically configured flow layers, constructed in order.
        return [flow_cls(z_dim, z_dim * 2, hidden, 3) for _ in range(3)]

    self.postflow1, self.postflow2, self.postflow3 = make_flows(flow.myIAF)
    self.priorflow1, self.priorflow2, self.priorflow3 = make_flows(flow.IAF)
    self.post_generator = nn_.SequentialFlow(
        self.postflow1, self.postflow2, self.postflow3)
    self.prior_generator = nn_.SequentialFlow(
        self.priorflow1, self.priorflow2, self.priorflow3)

    self.decoder = Decoder(self.embedder, emb_dim, hidden + z_dim,
                           vocab_size, n_layers=1)

    def params_of(*modules):
        # Concatenate parameter lists in the order the modules are given.
        return [p for module in modules for p in module.parameters()]

    # NOTE(review): if ContextEncoder/Decoder both register the shared
    # embedder, its weight appears twice in this list (PyTorch warns about
    # duplicate parameters) — confirm against those classes.
    self.optimizer_AE = optim.SGD(
        params_of(self.context_encoder, self.post_net, self.post_generator,
                  self.decoder, self.prior_net, self.prior_generator),
        lr=config['lr_ae'])
    self.optimizer_G = optim.RMSprop(
        params_of(self.post_net, self.post_generator,
                  self.prior_net, self.prior_generator),
        lr=config['lr_gan_g'])
    # The auto-encoder learning rate decays by 0.6 every 10 scheduler steps.
    self.lr_scheduler_AE = optim.lr_scheduler.StepLR(
        self.optimizer_AE, step_size=10, gamma=0.6)
    self.criterion_ce = nn.CrossEntropyLoss()
def __init__(self, config, vocab_size, PAD_token=0):
    """Assemble the DialogWAE networks, critic, optimizers, scheduler and loss.

    Args:
        config: hyper-parameter mapping; this constructor reads 'maxlen',
            'clip', 'lambda_gp', 'temp', 'emb_size', 'n_hidden', 'n_layers',
            'noise_radius', 'z_size', 'lr_ae', 'lr_gan_g' and 'lr_gan_d'.
        vocab_size: number of rows of the word-embedding table; also passed
            to the Decoder.
        PAD_token: ``padding_idx`` of the embedding table.
    """
    super(DialogWAE, self).__init__()
    self.vocab_size = vocab_size
    self.maxlen = config['maxlen']
    self.clip = config['clip']
    self.lambda_gp = config['lambda_gp']
    self.temp = config['temp']

    emb_dim = config['emb_size']
    hidden = config['n_hidden']
    noise = config['noise_radius']

    # One embedding table, handed to both the utterance encoder and the decoder.
    self.embedder = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD_token)
    self.utt_encoder = Encoder(self.embedder, emb_dim, hidden, True,
                               config['n_layers'], noise)
    self.context_encoder = ContextEncoder(self.utt_encoder, hidden * 2 + 2,
                                          hidden, 1, noise)

    z_dim = config['z_size']
    self.prior_net = Variation(hidden, z_dim)      # p(e|c)
    self.post_net = Variation(hidden * 3, z_dim)   # q(e|c,x)

    def latent_mlp():
        # Linear -> BN -> ReLU, twice, then a final Linear; all z_dim wide.
        layers = []
        for _ in range(2):
            layers.extend([
                nn.Linear(z_dim, z_dim),
                nn.BatchNorm1d(z_dim, eps=1e-05, momentum=0.1),
                nn.ReLU(),
            ])
        layers.append(nn.Linear(z_dim, z_dim))
        net = nn.Sequential(*layers)
        net.apply(self.init_weights)
        return net

    self.post_generator = latent_mlp()
    self.prior_generator = latent_mlp()

    self.decoder = Decoder(self.embedder, emb_dim, hidden + z_dim,
                           vocab_size, n_layers=1)

    # Critic over cat(z, c): two BN + LeakyReLU hidden layers, scalar output.
    width = hidden * 2
    critic = nn.Sequential(
        nn.Linear(hidden + z_dim, width),
        nn.BatchNorm1d(width, eps=1e-05, momentum=0.1),
        nn.LeakyReLU(0.2),
        nn.Linear(width, width),
        nn.BatchNorm1d(width, eps=1e-05, momentum=0.1),
        nn.LeakyReLU(0.2),
        nn.Linear(width, 1),
    )
    self.discriminator = critic
    self.discriminator.apply(self.init_weights)

    def params_of(*modules):
        # Concatenate parameter lists in the order the modules are given.
        return [p for module in modules for p in module.parameters()]

    self.optimizer_AE = optim.SGD(
        params_of(self.context_encoder, self.post_net,
                  self.post_generator, self.decoder),
        lr=config['lr_ae'])
    self.optimizer_G = optim.RMSprop(
        params_of(self.post_net, self.post_generator,
                  self.prior_net, self.prior_generator),
        lr=config['lr_gan_g'])
    self.optimizer_D = optim.RMSprop(self.discriminator.parameters(),
                                     lr=config['lr_gan_d'])
    # The auto-encoder learning rate decays by 0.6 every 10 scheduler steps.
    self.lr_scheduler_AE = optim.lr_scheduler.StepLR(
        self.optimizer_AE, step_size=10, gamma=0.6)
    self.criterion_ce = nn.CrossEntropyLoss()
class DFVAE(nn.Module):
    """Dialogue model whose latent codes are transformed by stacks of IAF flows.

    ``context_encoder`` summarises the dialogue history into ``c`` and
    ``utt_encoder`` summarises the response into ``x``.  ``post_net``
    (q(e|c,x)) and ``prior_net`` (p(e|c)) each return a sample together with
    ``mu`` and ``log_s``.  ``post_generator`` / ``prior_generator`` are
    three-layer flow stacks that transform those samples, and ``decoder``
    generates the response from ``cat(z, c)``.
    """

    def __init__(self, config, vocab_size, PAD_token=0):
        super(DFVAE, self).__init__()
        self.vocab_size = vocab_size
        self.maxlen = config['maxlen']
        self.clip = config['clip']
        self.lambda_gp = config['lambda_gp']
        self.temp = config['temp']
        # One embedding table, handed to both the utterance encoder and the decoder.
        self.embedder = nn.Embedding(vocab_size, config['emb_size'],
                                     padding_idx=PAD_token)
        self.utt_encoder = Encoder(self.embedder, config['emb_size'],
                                   config['n_hidden'], True, config['n_layers'],
                                   config['noise_radius'])
        self.context_encoder = ContextEncoder(self.utt_encoder,
                                              config['n_hidden'] * 2 + 2,
                                              config['n_hidden'], 1,
                                              config['noise_radius'])
        self.prior_net = Variation(config['n_hidden'], config['z_size'])  # p(e|c)
        self.post_net = Variation(config['n_hidden'] * 3, config['z_size'])  # q(e|c,x)
        self.postflow1 = flow.myIAF(config['z_size'], config['z_size'] * 2,
                                    config['n_hidden'], 3)
        self.postflow2 = flow.myIAF(config['z_size'], config['z_size'] * 2,
                                    config['n_hidden'], 3)
        self.postflow3 = flow.myIAF(config['z_size'], config['z_size'] * 2,
                                    config['n_hidden'], 3)
        self.priorflow1 = flow.IAF(config['z_size'], config['z_size'] * 2,
                                   config['n_hidden'], 3)
        self.priorflow2 = flow.IAF(config['z_size'], config['z_size'] * 2,
                                   config['n_hidden'], 3)
        self.priorflow3 = flow.IAF(config['z_size'], config['z_size'] * 2,
                                   config['n_hidden'], 3)
        self.post_generator = nn_.SequentialFlow(self.postflow1, self.postflow2,
                                                 self.postflow3)
        self.prior_generator = nn_.SequentialFlow(self.priorflow1, self.priorflow2,
                                                  self.priorflow3)
        self.decoder = Decoder(self.embedder, config['emb_size'],
                               config['n_hidden'] + config['z_size'],
                               vocab_size, n_layers=1)
        # NOTE(review): if ContextEncoder/Decoder both register the shared
        # embedder, its weight appears twice in this list (PyTorch warns about
        # duplicate parameters) — confirm against those classes.
        self.optimizer_AE = optim.SGD(
            list(self.context_encoder.parameters())
            + list(self.post_net.parameters())
            + list(self.post_generator.parameters())
            + list(self.decoder.parameters())
            + list(self.prior_net.parameters())
            + list(self.prior_generator.parameters()),
            lr=config['lr_ae'])
        self.optimizer_G = optim.RMSprop(
            list(self.post_net.parameters())
            + list(self.post_generator.parameters())
            + list(self.prior_net.parameters())
            + list(self.prior_generator.parameters()),
            lr=config['lr_gan_g'])
        # The auto-encoder learning rate decays by 0.6 every 10 scheduler steps.
        self.lr_scheduler_AE = optim.lr_scheduler.StepLR(self.optimizer_AE,
                                                         step_size=10, gamma=0.6)
        self.criterion_ce = nn.CrossEntropyLoss()

    def init_weights(self, m):
        """Initialise an ``nn.Linear``: weight ~ U(-0.02, 0.02), bias = 0.

        Other module types are left untouched; bias-less Linear layers only
        get their weight initialised.
        """
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-0.02, 0.02)
            if m.bias is not None:
                m.bias.data.fill_(0)

    def _masked_ce(self, output, response):
        """Cross-entropy of decoder ``output`` against ``response[:, 1:]``.

        Only target positions with token id > 0 contribute (id 0 is the
        default PAD_token); logits are divided by ``self.temp`` first.
        """
        flattened_output = output.view(-1, self.vocab_size)
        dec_target = response[:, 1:].contiguous().view(-1)
        mask = dec_target.gt(0)  # [(batch_sz*seq_len)]
        masked_target = dec_target.masked_select(mask)
        output_mask = mask.unsqueeze(1).expand(mask.size(0), self.vocab_size)
        masked_output = flattened_output.masked_select(output_mask).view(
            -1, self.vocab_size)
        return self.criterion_ce(masked_output / self.temp, masked_target)

    def sample_post2(self, x, c):
        """Sample q(e|c,x) and push it through the posterior flow stack.

        Returns ``(e, mu, log_s, z, det_f)``; ``det_f`` is the flow's second
        output, seeded with an identity matrix (``train_G`` treats it as a
        per-sample weight matrix).
        """
        xc = torch.cat((x, c), 1)
        e, mu, log_s = self.post_net(xc)
        # Build the identity on e's device/dtype so the flow never mixes CPU
        # and GPU tensors.
        eye = torch.eye(e.shape[1], device=e.device, dtype=e.dtype)
        z, det_f, _, _ = self.post_generator((e, eye, c, mu))
        return e, mu, log_s, z, det_f

    def sample_post(self, x, c):
        """Posterior flow followed by the prior flow.

        Returns ``(tilde_z, z, mu, log_s, det_f, det_g)``, where ``z`` is the
        posterior-flow output and ``tilde_z`` the prior-flow output.
        """
        _, mu, log_s, z, det_f = self.sample_post2(x, c)
        tilde_z, det_g, _ = self.prior_generator((z, det_f, c))
        return tilde_z, z, mu, log_s, det_f, det_g

    def sample_code_post(self, x, c):
        """Same as ``sample_post`` without the intermediate ``z``.

        Returns ``(tilde_z, mu, log_s, det_f, det_g)``.
        """
        tilde_z, _, mu, log_s, det_f, det_g = self.sample_post(x, c)
        return tilde_z, mu, log_s, det_f, det_g

    def sample_code_prior(self, c):
        """Return ``(e, mu, log_s)`` from ``prior_net``; no flow is applied."""
        e, mu, log_s = self.prior_net(c)
        return e, mu, log_s

    def sample_prior(self, c):
        """Sample p(e|c), push it through ``prior_generator``; return ``(z, det_prior)``."""
        e, _, _ = self.prior_net(c)
        z, det_prior, _ = self.prior_generator((e, 0, c))
        return z, det_prior

    def train_AE(self, context, context_lens, utt_lens, floors, response, res_lens):
        """One reconstruction step with ``optimizer_AE``.

        Returns ``[('train_loss_AE', float)]``.
        """
        self.context_encoder.train()
        self.decoder.train()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        # A single posterior sample feeds the decoder (an extra, unused sample
        # used to be drawn here).
        z_post = self.sample_code_post(x, c)[0]
        output = self.decoder(torch.cat((z_post, c), 1), None,
                              response[:, :-1], (res_lens - 1))

        self.optimizer_AE.zero_grad()
        AE_term = self._masked_ce(output, response)
        AE_term.backward()
        torch.nn.utils.clip_grad_norm_(
            list(self.context_encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.post_generator.parameters())
            + list(self.prior_generator.parameters())
            + list(self.post_net.parameters()), self.clip)
        self.optimizer_AE.step()
        return [('train_loss_AE', AE_term.item())]

    def train_G(self, context, context_lens, utt_lens, floors, response, res_lens):
        """One latent-matching step with ``optimizer_G``.

        Encoder outputs are detached, so only the latent networks receive
        gradients.  Returns ``[('KL_loss', float)]``.
        """
        self.context_encoder.eval()
        self.optimizer_G.zero_grad()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        _, mu_post, log_s_post, _, weight = self.sample_post2(x.detach(), c.detach())
        _, _, log_s_prior = self.sample_code_prior(c.detach())
        # NOTE(review): the mean term is mu_post**2 rather than
        # (mu_post - mu_prior)**2, i.e. it behaves as if the prior mean were 0
        # — confirm intent.  The "- 100" offset shifts the reported value but
        # does not affect gradients.
        KL_loss = torch.sum(
            log_s_prior - log_s_post
            + torch.exp(log_s_post) / torch.exp(log_s_prior)
            * torch.sum(weight**2, dim=2)
            + mu_post**2 / torch.exp(log_s_prior), 1) / 2 - 100
        KL_loss.mean().backward()
        # Clip exactly the modules optimizer_G updates; previously prior_net
        # was missing and prior_generator was listed twice.
        torch.nn.utils.clip_grad_norm_(
            list(self.post_generator.parameters())
            + list(self.prior_generator.parameters())
            + list(self.post_net.parameters())
            + list(self.prior_net.parameters()), self.clip)
        self.optimizer_G.step()
        return [('KL_loss', KL_loss.mean().item())]

    def valid(self, context, context_lens, utt_lens, floors, response, res_lens):
        """Evaluate reconstruction loss and the KL-style latent cost.

        Returns ``[('valid_loss_AE', float), ('valid_loss_G', float)]``.
        NOTE(review): the latent cost here is summed over the batch, while
        ``train_G`` reports a mean — confirm this is intended.
        """
        self.context_encoder.eval()
        self.decoder.eval()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        post_z, mu_post, log_s_post, _, _ = self.sample_code_post(x, c)
        _, _, log_s_prior = self.sample_code_prior(c)
        KL_loss = torch.sum(
            log_s_prior - log_s_post
            + (torch.exp(log_s_post) + mu_post**2) / torch.exp(log_s_prior), 1) / 2
        costG = KL_loss.sum()

        output = self.decoder(torch.cat((post_z, c), 1), None,
                              response[:, :-1], (res_lens - 1))
        lossAE = self._masked_ce(output, response)
        return [('valid_loss_AE', lossAE.item()), ('valid_loss_G', costG.item())]

    def sample(self, context, context_lens, utt_lens, floors, repeat, SOS_tok, EOS_tok):
        """Greedily decode ``repeat`` responses from prior samples.

        ``c.expand(repeat, -1)`` requires the context batch size to be 1.
        """
        self.context_encoder.eval()
        self.decoder.eval()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        c_repeated = c.expand(repeat, -1)
        prior_z, _ = self.sample_prior(c_repeated)
        sample_words, sample_lens = self.decoder.sampling(
            torch.cat((prior_z, c_repeated), 1), None, self.maxlen,
            SOS_tok, EOS_tok, "greedy")
        return sample_words, sample_lens

    def gen(self, context, prior_z, context_lens, utt_lens, floors, repeat, SOS_tok, EOS_tok):
        """Greedily decode from caller-supplied latent codes ``prior_z``."""
        self.context_encoder.eval()
        self.decoder.eval()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        c_repeated = c.expand(repeat, -1)
        sample_words, sample_lens = self.decoder.sampling(
            torch.cat((prior_z, c_repeated), 1), None, self.maxlen,
            SOS_tok, EOS_tok, "greedy")
        return sample_words, sample_lens

    def sample_latent(self, context, context_lens, utt_lens, floors, repeat, SOS_tok, EOS_tok):
        """Return ``(prior_z, e)``: prior-flow output and its base sample."""
        self.context_encoder.eval()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        c_repeated = c.expand(repeat, -1)
        e, _, _ = self.sample_code_prior(c_repeated)
        prior_z, _, _ = self.prior_generator((e, 0, c_repeated))
        return prior_z, e

    def sample_latent_post(self, context, context_lens, utt_lens, floors,
                           response, res_lens, repeat):
        """Return ``(z_post, z)`` from ``sample_post`` for ``repeat`` copies."""
        self.context_encoder.eval()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        c_repeated = c.expand(repeat, -1)
        x_repeated = x.expand(repeat, -1)
        z_post, z, _, _, _, _ = self.sample_post(x_repeated, c_repeated)
        return z_post, z

    def adjust_lr(self):
        """Advance the auto-encoder learning-rate scheduler by one step."""
        self.lr_scheduler_AE.step()
class DialogWAE(nn.Module):
    """Dialogue model whose latent prior is matched adversarially (WGAN-GP critic).

    ``post_net`` (q(e|c,x)) / ``prior_net`` (p(e|c)) produce samples that
    the MLP generators map to latent codes; ``discriminator`` scores
    ``cat(z, c)``, and ``train_D`` adds a gradient penalty on interpolated
    codes weighted by ``lambda_gp``.

    NOTE(review): ``train_G``/``train_D`` rely on module-level ``one``,
    ``minus_one`` and ``gData`` defined outside this view.  ``torch.mean``
    returns a 0-dim tensor, so if ``one`` has shape [1], newer PyTorch rejects
    ``.backward(one)`` with a shape mismatch — confirm.
    """

    def __init__(self, config, vocab_size, PAD_token=0):
        super(DialogWAE, self).__init__()
        self.vocab_size = vocab_size
        self.maxlen = config['maxlen']
        self.clip = config['clip']
        self.lambda_gp = config['lambda_gp']
        self.temp = config['temp']
        # One embedding table, handed to both the utterance encoder and the decoder.
        self.embedder = nn.Embedding(vocab_size, config['emb_size'],
                                     padding_idx=PAD_token)
        self.utt_encoder = Encoder(self.embedder, config['emb_size'],
                                   config['n_hidden'], True, config['n_layers'],
                                   config['noise_radius'])
        self.context_encoder = ContextEncoder(self.utt_encoder,
                                              config['n_hidden'] * 2 + 2,
                                              config['n_hidden'], 1,
                                              config['noise_radius'])
        self.prior_net = Variation(config['n_hidden'], config['z_size'])  # p(e|c)
        self.post_net = Variation(config['n_hidden'] * 3, config['z_size'])  # q(e|c,x)
        self.post_generator = nn.Sequential(
            nn.Linear(config['z_size'], config['z_size']),
            nn.BatchNorm1d(config['z_size'], eps=1e-05, momentum=0.1),
            nn.ReLU(),
            nn.Linear(config['z_size'], config['z_size']),
            nn.BatchNorm1d(config['z_size'], eps=1e-05, momentum=0.1),
            nn.ReLU(),
            nn.Linear(config['z_size'], config['z_size']))
        self.post_generator.apply(self.init_weights)
        self.prior_generator = nn.Sequential(
            nn.Linear(config['z_size'], config['z_size']),
            nn.BatchNorm1d(config['z_size'], eps=1e-05, momentum=0.1),
            nn.ReLU(),
            nn.Linear(config['z_size'], config['z_size']),
            nn.BatchNorm1d(config['z_size'], eps=1e-05, momentum=0.1),
            nn.ReLU(),
            nn.Linear(config['z_size'], config['z_size']))
        self.prior_generator.apply(self.init_weights)
        self.decoder = Decoder(self.embedder, config['emb_size'],
                               config['n_hidden'] + config['z_size'],
                               vocab_size, n_layers=1)
        self.discriminator = nn.Sequential(
            nn.Linear(config['n_hidden'] + config['z_size'], config['n_hidden'] * 2),
            nn.BatchNorm1d(config['n_hidden'] * 2, eps=1e-05, momentum=0.1),
            nn.LeakyReLU(0.2),
            nn.Linear(config['n_hidden'] * 2, config['n_hidden'] * 2),
            nn.BatchNorm1d(config['n_hidden'] * 2, eps=1e-05, momentum=0.1),
            nn.LeakyReLU(0.2),
            nn.Linear(config['n_hidden'] * 2, 1),
        )
        self.discriminator.apply(self.init_weights)
        self.optimizer_AE = optim.SGD(
            list(self.context_encoder.parameters())
            + list(self.post_net.parameters())
            + list(self.post_generator.parameters())
            + list(self.decoder.parameters()),
            lr=config['lr_ae'])
        self.optimizer_G = optim.RMSprop(
            list(self.post_net.parameters())
            + list(self.post_generator.parameters())
            + list(self.prior_net.parameters())
            + list(self.prior_generator.parameters()),
            lr=config['lr_gan_g'])
        self.optimizer_D = optim.RMSprop(self.discriminator.parameters(),
                                         lr=config['lr_gan_d'])
        # The auto-encoder learning rate decays by 0.6 every 10 scheduler steps.
        self.lr_scheduler_AE = optim.lr_scheduler.StepLR(self.optimizer_AE,
                                                         step_size=10, gamma=0.6)
        self.criterion_ce = nn.CrossEntropyLoss()

    def init_weights(self, m):
        """Initialise an ``nn.Linear``: weight ~ U(-0.02, 0.02), bias = 0.

        Other module types are left untouched; bias-less Linear layers only
        get their weight initialised.
        """
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-0.02, 0.02)
            if m.bias is not None:
                m.bias.data.fill_(0)

    def _set_discriminator_trainable(self, flag):
        """Toggle ``requires_grad`` on every discriminator parameter."""
        for p in self.discriminator.parameters():
            p.requires_grad = flag

    def _masked_ce(self, output, response):
        """Cross-entropy of decoder ``output`` against ``response[:, 1:]``.

        Only target positions with token id > 0 contribute (id 0 is the
        default PAD_token); logits are divided by ``self.temp`` first.
        """
        flattened_output = output.view(-1, self.vocab_size)
        dec_target = response[:, 1:].contiguous().view(-1)
        mask = dec_target.gt(0)  # [(batch_sz*seq_len)]
        masked_target = dec_target.masked_select(mask)
        output_mask = mask.unsqueeze(1).expand(mask.size(0), self.vocab_size)
        masked_output = flattened_output.masked_select(output_mask).view(
            -1, self.vocab_size)
        return self.criterion_ce(masked_output / self.temp, masked_target)

    def sample_code_post(self, x, c):
        """Posterior latent code: ``post_generator(post_net(cat(x, c)) sample)``."""
        e, _, _ = self.post_net(torch.cat((x, c), 1))
        z = self.post_generator(e)
        return z

    def sample_code_prior(self, c):
        """Prior latent code: ``prior_generator(prior_net(c) sample)``."""
        e, _, _ = self.prior_net(c)
        z = self.prior_generator(e)
        return z

    def train_AE(self, context, context_lens, utt_lens, floors, response, res_lens):
        """One reconstruction step with ``optimizer_AE``.

        Returns ``[('train_loss_AE', float)]``.
        NOTE(review): post_net/post_generator are updated by optimizer_AE but
        are not gradient-clipped here — confirm intent.
        """
        self.context_encoder.train()
        self.decoder.train()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        z = self.sample_code_post(x, c)
        output = self.decoder(torch.cat((z, c), 1), None,
                              response[:, :-1], (res_lens - 1))

        self.optimizer_AE.zero_grad()
        loss = self._masked_ce(output, response)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            list(self.context_encoder.parameters())
            + list(self.decoder.parameters()), self.clip)
        self.optimizer_AE.step()
        return [('train_loss_AE', loss.item())]

    def train_G(self, context, context_lens, utt_lens, floors, response, res_lens):
        """One generator step: raise the critic score of prior codes relative to posterior ones.

        The critic is frozen for the duration of the step and always
        unfrozen afterwards, even if the step raises.
        Returns ``[('train_loss_G', float)]``.
        """
        self.context_encoder.eval()
        self.optimizer_G.zero_grad()
        self._set_discriminator_trainable(False)
        try:
            c = self.context_encoder(context, context_lens, utt_lens, floors)
            # ----------------- posterior samples -----------------
            x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
            z_post = self.sample_code_post(x.detach(), c.detach())
            errG_post = torch.mean(
                self.discriminator(torch.cat((z_post, c.detach()), 1)))
            errG_post.backward(minus_one)
            # ----------------- prior samples ---------------------
            prior_z = self.sample_code_prior(c.detach())
            errG_prior = torch.mean(
                self.discriminator(torch.cat((prior_z, c.detach()), 1)))
            errG_prior.backward(one)
            self.optimizer_G.step()
        finally:
            self._set_discriminator_trainable(True)
        costG = errG_prior - errG_post
        return [('train_loss_G', costG.item())]

    def train_D(self, context, context_lens, utt_lens, floors, response, res_lens):
        """One critic step with a gradient penalty on interpolated codes.

        Returns ``[('train_loss_D', float)]`` where the loss is
        ``-(D(prior) - D(post)) + gradient_penalty``.
        """
        self.context_encoder.eval()
        self.discriminator.train()
        self.optimizer_D.zero_grad()
        batch_size = context.size(0)
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)

        post_z = self.sample_code_post(x, c)
        errD_post = torch.mean(
            self.discriminator(torch.cat((post_z.detach(), c.detach()), 1)))
        errD_post.backward(one)

        prior_z = self.sample_code_prior(c)
        errD_prior = torch.mean(
            self.discriminator(torch.cat((prior_z.detach(), c.detach()), 1)))
        errD_prior.backward(minus_one)

        # Random per-sample mix of prior and posterior codes; made a fresh
        # leaf that requires grad so the critic's input gradient can be taken.
        alpha = gData(torch.rand(batch_size, 1))
        alpha = alpha.expand(prior_z.size())
        interpolates = (alpha * prior_z.detach()
                        + (1 - alpha) * post_z.detach()).requires_grad_(True)
        d_input = torch.cat((interpolates, c.detach()), 1)
        disc_interpolates = torch.mean(self.discriminator(d_input))
        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=gData(torch.ones(disc_interpolates.size())),
            create_graph=True, retain_graph=True, only_inputs=True)[0]
        # Penalise deviation of the per-sample gradient L2 norm from 1.
        gradient_penalty = ((gradients.contiguous().view(gradients.size(0), -1)
                             .norm(2, dim=1) - 1) ** 2).mean() * self.lambda_gp
        gradient_penalty.backward()

        self.optimizer_D.step()
        costD = -(errD_prior - errD_post) + gradient_penalty
        return [('train_loss_D', costD.item())]

    def valid(self, context, context_lens, utt_lens, floors, response, res_lens):
        """Evaluate reconstruction loss and critic-based costs.

        Returns ``[('valid_loss_AE', float), ('valid_loss_G', float),
        ('valid_loss_D', float)]``.
        """
        self.context_encoder.eval()
        self.discriminator.eval()
        self.decoder.eval()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        post_z = self.sample_code_post(x, c)
        prior_z = self.sample_code_prior(c)
        errD_post = torch.mean(self.discriminator(torch.cat((post_z, c), 1)))
        errD_prior = torch.mean(self.discriminator(torch.cat((prior_z, c), 1)))
        costD = -(errD_prior - errD_post)
        costG = -costD

        output = self.decoder(torch.cat((post_z, c), 1), None,
                              response[:, :-1], (res_lens - 1))
        lossAE = self._masked_ce(output, response)
        return [('valid_loss_AE', lossAE.item()),
                ('valid_loss_G', costG.item()),
                ('valid_loss_D', costD.item())]

    def sample(self, context, context_lens, utt_lens, floors, repeat, SOS_tok, EOS_tok):
        """Greedily decode ``repeat`` responses from prior latent codes.

        ``c.expand(repeat, -1)`` requires the context batch size to be 1.
        """
        self.context_encoder.eval()
        self.decoder.eval()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        c_repeated = c.expand(repeat, -1)
        prior_z = self.sample_code_prior(c_repeated)
        sample_words, sample_lens = self.decoder.sampling(
            torch.cat((prior_z, c_repeated), 1), None, self.maxlen,
            SOS_tok, EOS_tok, "greedy")
        return sample_words, sample_lens

    def adjust_lr(self):
        """Advance the auto-encoder learning-rate scheduler by one step."""
        self.lr_scheduler_AE.step()