def train_D(self, context, context_lens, utt_lens, floors, response, res_lens):
    self.context_encoder.eval()
    self.discriminator.train()

    self.optimizer_D.zero_grad()
    batch_size = context.size(0)

    c = self.context_encoder(context, context_lens, utt_lens, floors)
    x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)

    # The critic is trained to maximize D(prior_z, c) - D(post_z, c);
    # the +1 / -1 gradient directions implement the Wasserstein objective.
    post_z = self.sample_code_post(x, c)
    errD_post = torch.mean(
        self.discriminator(torch.cat((post_z.detach(), c.detach()), 1)))
    errD_post.backward(self.one)

    prior_z = self.sample_code_prior(c)
    errD_prior = torch.mean(
        self.discriminator(torch.cat((prior_z.detach(), c.detach()), 1)))
    errD_prior.backward(self.minus_one)

    # WGAN-GP gradient penalty, computed on random interpolations between
    # prior and posterior latent codes.
    alpha = gData(torch.rand(batch_size, 1))
    alpha = alpha.expand(prior_z.size())
    interpolates = alpha * prior_z.data + ((1 - alpha) * post_z.data)
    interpolates = Variable(interpolates, requires_grad=True)
    d_input = torch.cat((interpolates, c.detach()), 1)
    disc_interpolates = torch.mean(self.discriminator(d_input))
    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=gData(torch.ones(disc_interpolates.size())),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]
    gradient_penalty = (
        (gradients.contiguous().view(gradients.size(0), -1).norm(2, dim=1)
         - 1) ** 2).mean() * self.lambda_gp
    gradient_penalty.backward()

    self.optimizer_D.step()
    costD = -(errD_prior - errD_post) + gradient_penalty
    return [('train_loss_D', costD.item())]
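# A minimal, self-contained sketch (not part of the model) of the same WGAN-GP
# penalty that train_D computes: lambda_gp * E[(||grad_x D(x)||_2 - 1)^2] on
# random interpolations between two batches of codes. The names
# toy_gradient_penalty / toy_critic below are hypothetical, for illustration only.
import torch
import torch.nn as nn


def toy_gradient_penalty(critic, prior_z, post_z, lambda_gp=10.0):
    # per-sample mixing coefficient in [0, 1]
    alpha = torch.rand(prior_z.size(0), 1).expand_as(prior_z)
    interpolates = (alpha * prior_z + (1 - alpha) * post_z).requires_grad_(True)
    d_out = critic(interpolates).mean()
    grads = torch.autograd.grad(outputs=d_out, inputs=interpolates,
                                create_graph=True, retain_graph=True,
                                only_inputs=True)[0]
    # push the per-sample gradient norm towards 1
    return ((grads.reshape(grads.size(0), -1).norm(2, dim=1) - 1) ** 2).mean() * lambda_gp


# usage sketch:
# toy_critic = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
# gp = toy_gradient_penalty(toy_critic, torch.randn(4, 16), torch.randn(4, 16))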
def evaluate(model, metrics, test_loader, ivocab, vocab, repeat, PAD_token=0):
    recall_bleus, prec_bleus, bows_extrema, bows_avg, bows_greedy, intra_dist1s, intra_dist2s, \
        avg_lens, inter_dist1s, inter_dist2s = [], [], [], [], [], [], [], [], [], []
    local_t = 0
    model.eval()
    pbar = tqdm(range(test_loader.num_batch))
    for bat in pbar:
        batch = test_loader.next_batch()
        if bat == test_loader.num_batch:  # end of epoch
            break
        local_t += 1
        context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
        # remove the sos token in the context and reduce the context length
        context, utt_lens = context[:, :, 1:], utt_lens - 1
        if local_t % 2000 == 0:
            logging.info("Batch %d \n" % (local_t))
        # print the context
        start = np.maximum(0, context_lens[0] - 5)
        for t_id in range(start, context.shape[1], 1):
            context_str = indexes2sent(context[0, t_id], ivocab, ivocab["</s>"], PAD_token)
            if local_t % 2000 == 0:
                logging.info("Context %d-%d: %s\n" % (t_id, floors[0, t_id], context_str))
        # print the true outputs
        ref_str, _ = indexes2sent(response[0], ivocab, ivocab["</s>"], ivocab["<s>"])
        ref_tokens = ref_str.split(' ')
        if local_t % 2000 == 0:
            logging.info("Target >> %s\n" % (ref_str.replace(" ' ", "'")))
        context, context_lens, utt_lens, floors = gVar(context), gVar(context_lens), gVar(utt_lens), gData(floors)
        sample_words, sample_lens = model.sample(context, context_lens, utt_lens, floors,
                                                 repeat, ivocab["<s>"], ivocab["</s>"])  # nparray: [repeat x seq_len]
        pred_sents, _ = indexes2sent(sample_words, ivocab, ivocab["</s>"], PAD_token)
        pred_tokens = [sent.split(' ') for sent in pred_sents]
        for r_id, pred_sent in enumerate(pred_sents):
            if local_t % 2000 == 0:
                logging.info("Sample %d >> %s\n" % (r_id, pred_sent.replace(" ' ", "'")))
        max_bleu, avg_bleu = metrics.sim_bleu(pred_tokens, ref_tokens)
        recall_bleus.append(max_bleu)
        prec_bleus.append(avg_bleu)
        bow_extrema, bow_avg, bow_greedy = metrics.sim_bow(sample_words, sample_lens,
                                                           response[:, 1:], res_lens - 2)
        bows_extrema.append(bow_extrema)
        bows_avg.append(bow_avg)
        bows_greedy.append(bow_greedy)
        intra_dist1, intra_dist2, inter_dist1, inter_dist2 = metrics.div_distinct(sample_words, sample_lens)
        intra_dist1s.append(intra_dist1)
        intra_dist2s.append(intra_dist2)
        avg_lens.append(np.mean(sample_lens))
        inter_dist1s.append(inter_dist1)
        inter_dist2s.append(inter_dist2)

    recall_bleu = float(np.mean(recall_bleus))
    prec_bleu = float(np.mean(prec_bleus))
    f1 = 2 * (prec_bleu * recall_bleu) / (prec_bleu + recall_bleu + 10e-12)
    bow_extrema = float(np.mean(bows_extrema))
    bow_avg = float(np.mean(bows_avg))
    bow_greedy = float(np.mean(bows_greedy))
    intra_dist1 = float(np.mean(intra_dist1s))
    intra_dist2 = float(np.mean(intra_dist2s))
    avg_len = float(np.mean(avg_lens))
    inter_dist1 = float(np.mean(inter_dist1s))
    inter_dist2 = float(np.mean(inter_dist2s))
    report = ("Avg recall BLEU %f, avg precision BLEU %f, F1 %f, \nbow_avg %f, bow_extrema %f, bow_greedy %f, \n"
              "intra_dist1 %f, intra_dist2 %f, inter_dist1 %f, inter_dist2 %f, \navg_len %f"
              % (recall_bleu, prec_bleu, f1, bow_avg, bow_extrema, bow_greedy,
                 intra_dist1, intra_dist2, inter_dist1, inter_dist2, avg_len))
    print(report)
    logging.info(report + "\n")
    print("Done testing")
    model.train()
    return recall_bleu, prec_bleu, bow_extrema, bow_avg, bow_greedy, intra_dist1, intra_dist2, \
        avg_len, inter_dist1, inter_dist2
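# Hedged usage sketch: how evaluate() is typically driven. The construction of
# test_loader, vocab/ivocab and the Metrics helper is assumed from the rest of
# this repo and is not shown here; only the call signature matches the function above.
# metrics = Metrics(model.embedder)   # assumed constructor; adapt to the repo's Metrics class
# results = evaluate(model, metrics, test_loader, ivocab, vocab, repeat=10)
# recall_bleu, prec_bleu = results[0], results[1]   # remaining entries follow the return order above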
def __init__(self, config, vocab_size, PAD_token=0):
    super(DialogWAE, self).__init__()
    self.vocab_size = vocab_size
    self.maxlen = config['maxlen']
    self.clip = config['clip']
    self.lambda_gp = config['lambda_gp']
    self.temp = config['temp']

    self.embedder = nn.Embedding(vocab_size, config['emb_size'], padding_idx=PAD_token)
    self.utt_encoder = Encoder(self.embedder, config['emb_size'], config['n_hidden'],
                               True, config['n_layers'], config['noise_radius'])
    self.context_encoder = ContextEncoder(self.utt_encoder, config['n_hidden'] * 2 + 2,
                                          config['n_hidden'], 1, config['noise_radius'])
    self.prior_net = Variation(config['n_hidden'], config['z_size'])  # p(e|c)
    self.post_net = Variation(config['n_hidden'] * 3, config['z_size'])  # q(e|c,x)

    self.post_generator = nn.Sequential(
        nn.Linear(config['z_size'], config['z_size']),
        nn.BatchNorm1d(config['z_size'], eps=1e-05, momentum=0.1),
        nn.ReLU(),
        nn.Linear(config['z_size'], config['z_size']),
        nn.BatchNorm1d(config['z_size'], eps=1e-05, momentum=0.1),
        nn.ReLU(),
        nn.Linear(config['z_size'], config['z_size']))
    self.post_generator.apply(self.init_weights)

    self.prior_generator = nn.Sequential(
        nn.Linear(config['z_size'], config['z_size']),
        nn.BatchNorm1d(config['z_size'], eps=1e-05, momentum=0.1),
        nn.ReLU(),
        nn.Linear(config['z_size'], config['z_size']),
        nn.BatchNorm1d(config['z_size'], eps=1e-05, momentum=0.1),
        nn.ReLU(),
        nn.Linear(config['z_size'], config['z_size']))
    self.prior_generator.apply(self.init_weights)

    self.decoder = Decoder(self.embedder, config['emb_size'],
                           config['n_hidden'] + config['z_size'], vocab_size, n_layers=1)

    self.discriminator = nn.Sequential(
        nn.Linear(config['n_hidden'] + config['z_size'], config['n_hidden'] * 2),
        nn.BatchNorm1d(config['n_hidden'] * 2, eps=1e-05, momentum=0.1),
        nn.LeakyReLU(0.2),
        nn.Linear(config['n_hidden'] * 2, config['n_hidden'] * 2),
        nn.BatchNorm1d(config['n_hidden'] * 2, eps=1e-05, momentum=0.1),
        nn.LeakyReLU(0.2),
        nn.Linear(config['n_hidden'] * 2, 1),
    )
    self.discriminator.apply(self.init_weights)

    # SGD for the autoencoder path, RMSprop for the two GAN players
    self.optimizer_AE = optim.SGD(list(self.context_encoder.parameters())
                                  + list(self.post_net.parameters())
                                  + list(self.post_generator.parameters())
                                  + list(self.decoder.parameters()),
                                  lr=config['lr_ae'])
    self.optimizer_G = optim.RMSprop(list(self.post_net.parameters())
                                     + list(self.post_generator.parameters())
                                     + list(self.prior_net.parameters())
                                     + list(self.prior_generator.parameters()),
                                     lr=config['lr_gan_g'])
    self.optimizer_D = optim.RMSprop(self.discriminator.parameters(), lr=config['lr_gan_d'])

    self.lr_scheduler_AE = optim.lr_scheduler.StepLR(self.optimizer_AE, step_size=10, gamma=0.6)

    self.criterion_ce = nn.CrossEntropyLoss()
    # constant +1 / -1 gradient directions used by the WGAN updates
    self.one = gData(torch.FloatTensor([1]))
    self.minus_one = self.one * -1
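# Hedged example: the keys below are exactly those read by __init__ above, plus
# 'n_iters_d' used by the training loop; the values are illustrative placeholders,
# not authoritative defaults for this repo.
example_config = {
    'maxlen': 40,         # max decoded response length
    'clip': 1.0,          # gradient-clipping threshold
    'lambda_gp': 10,      # weight of the WGAN-GP gradient penalty
    'temp': 1.0,          # sampling temperature
    'emb_size': 200,      # word-embedding size
    'n_hidden': 300,      # RNN hidden size
    'n_layers': 1,        # utterance-encoder layers
    'noise_radius': 0.2,  # Gaussian noise added inside the encoders
    'z_size': 200,        # latent-code size
    'lr_ae': 1.0,         # SGD lr for the autoencoder path
    'lr_gan_g': 5e-05,    # RMSprop lr for the generator nets
    'lr_gan_d': 1e-05,    # RMSprop lr for the discriminator/critic
    'n_iters_d': 5,       # critic updates per generator update
}
# model = DialogWAE(example_config, vocab_size=len(vocab))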
n_iters = train_loader.num_batch / max(1, config['n_iters_d'])
itr = 1
pbar = tqdm(range(train_loader.num_batch))
for bat in pbar:
    model.train()
    loss_records = []
    batch = train_loader.next_batch()
    if bat == train_loader.num_batch:  # end of epoch
        break
    context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
    # remove the sos token in the context and reduce the context length
    context, utt_lens = context[:, :, 1:], utt_lens - 1
    context, context_lens, utt_lens, floors, response, res_lens \
        = gVar(context), gVar(context_lens), gVar(utt_lens), gData(floors), gVar(response), gVar(res_lens)

    loss_AE = model.train_AE(context, context_lens, utt_lens, floors, response, res_lens)
    loss_records.extend(loss_AE)

    loss_G = model.train_G(context, context_lens, utt_lens, floors, response, res_lens)
    loss_records.extend(loss_G)

    for i in range(config['n_iters_d']):  # train discriminator/critic
        loss_D = model.train_D(context, context_lens, utt_lens, floors, response, res_lens)
        if i == 0:
            loss_records.extend(loss_D)
        if i == config['n_iters_d'] - 1: