def train(epoch):
    model_encoder.eval()
    model_decoder.eval()
    generator.train()
    critic.train()
    c_train_loss = 0.
    g_train_loss = 0.
    g_batches = 0
    for i, x in enumerate(train_loader):
        x = x[0]
        if args.cuda:
            x = x.cuda()

        # Generate noise
        B = args.per_gpu_train_batch_size
        c_optimizer.zero_grad()
        noise = torch.from_numpy(
            np.random.normal(0, 1, (B, args.latent_size))).float()
        if args.cuda:
            noise = noise.cuda()

        # Get original text latent embeddings
        with torch.no_grad():
            pooled_hidden_fea = model_encoder(
                x, attention_mask=(x > 0).float())[1]
            mean, logvar = model_encoder.linear(pooled_hidden_fea).chunk(2, -1)
            z_real = mean.squeeze(1)

        # Train critic
        z_fake = generator(noise)
        real_score = critic(z_real)
        fake_score = critic(z_fake)
        grad_penalty = compute_grad_penalty(critic, z_real.data, z_fake.data)
        c_loss = -torch.mean(real_score) + torch.mean(fake_score) + \
            args.gp_lambda * grad_penalty
        c_train_loss += c_loss.item()
        c_loss.backward()
        c_optimizer.step()

        # Train generator every n_critic batches
        if i % args.n_critic == 0:
            g_batches += 1
            g_optimizer.zero_grad()
            fake_score = critic(generator(noise))
            g_loss = -torch.mean(fake_score)
            g_train_loss += g_loss.item()
            g_loss.backward()
            g_optimizer.step()

        if args.interval > 0 and i % args.interval == 0:
            logger.info(
                'Epoch: {} | Batch: {}/{} ({:.0f}%) | G Loss: {:.6f} | C Loss: {:.6f}'.format(
                    epoch, args.batch_size * i, len(train_loader.dataset),
                    100. * (args.batch_size * i) / len(train_loader.dataset),
                    g_loss.item(), c_loss.item()))
            test_noise = torch.Tensor(
                np.random.normal(0, 1, (1, args.latent_size))).to(args.device)
            test_new_z = generator(test_noise).data
            # create new sent
            test_z = rollout_test(model_decoder, test_new_z, tokenizer_decoder,
                                  args.max_seq_length, 1, 0, 1)
            logger.info("Text: {}".format(test_z))

    g_train_loss /= g_batches
    c_train_loss /= len(train_loader)
    logger.info('* (Train) Epoch: {} | G Loss: {:.4f} | C Loss: {:.4f}'.format(
        epoch, g_train_loss, c_train_loss))
    return (g_train_loss, c_train_loss)
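The helper compute_grad_penalty is called above but not defined in this snippet. Below is a minimal sketch, assuming the standard WGAN-GP gradient penalty (a norm penalty on the critic's gradient at points interpolated between real and fake latent codes); the project's actual helper may differ in details such as device handling.

import torch
from torch import autograd

def compute_grad_penalty(critic, real_data, fake_data):
    # Random interpolation coefficients, one per sample
    alpha = torch.rand(real_data.size(0), 1, device=real_data.device).expand_as(real_data)
    # Points on the line between real and fake latent codes
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    score = critic(interpolates)
    # Gradient of the critic score w.r.t. the interpolated inputs
    grads = autograd.grad(outputs=score, inputs=interpolates,
                          grad_outputs=torch.ones_like(score),
                          create_graph=True, retain_graph=True)[0]
    # Penalize deviation of the per-sample gradient norm from 1
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

The returned penalty is scaled by args.gp_lambda when it enters the critic loss in the training loop above.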
def train(epoch):
    model_encoder.eval()
    model_decoder.eval()
    generator.train()
    critic.train()
    classifier.train()
    cl_train_loss = 0.
    c_train_loss = 0.
    g_train_loss = 0.
    g_batches = 0
    c_batches = 0
    c_loss_0 = 1
    g_loss_0 = 1
    for i, x in enumerate(train_loader):
        label = x[3]
        x = x[0]
        if args.cuda:
            x = x.cuda()

        # Generate noise and random labels for the generator
        gen_labels = (torch.rand(args.per_gpu_train_batch_size, 1)
                      * args.n_classes).type(torch.LongTensor)
        B = args.per_gpu_train_batch_size
        noise = torch.from_numpy(
            np.random.normal(0, 1, (B, args.latent_size))).float()
        if args.cuda:
            noise = noise.cuda()
            label = label.cuda()
            gen_labels = gen_labels.cuda()

        # Get original text latent embeddings
        with torch.no_grad():
            pooled_hidden_fea = model_encoder(
                x, attention_mask=(x > 0).float())[1]
            mean, logvar = model_encoder.linear(pooled_hidden_fea).chunk(2, -1)
            z_real = mean.squeeze(1)

        # Score real and fake latent codes and compute the gradient penalty
        z_fake = generator(noise, gen_labels)
        real_score = critic(z_real, label)
        fake_score = critic(z_fake, gen_labels)
        grad_penalty = compute_grad_penalty(critic, z_real.data, z_fake.data, label.data)
        pred_class = classifier(z_real)
        cl_lab = label.clone().squeeze_()

        # Classifier loss on the real latent codes
        cl_optimizer.zero_grad()
        cl_loss = nn.CrossEntropyLoss()(pred_class.to(args.device), cl_lab)
        cl_train_loss += cl_loss.item()
        cl_loss.backward()
        cl_optimizer.step()

        # Critic loss (WGAN-GP) and generator loss (adversarial + class term, weighted by 10)
        c_loss = -torch.mean(real_score) + torch.mean(fake_score) + \
            args.gp_lambda * grad_penalty
        fake_score = critic(generator(noise, gen_labels), gen_labels)
        pred_gen_class = classifier(generator(noise, gen_labels)).to(args.device)
        cl_gen_lab = gen_labels.clone().squeeze_()
        g_cl_loss = nn.CrossEntropyLoss()(pred_gen_class, cl_gen_lab)
        g_loss = -torch.mean(fake_score) + g_cl_loss * 10

        # Update either the critic or the generator, whichever shows the larger
        # relative loss change since its last value; early epochs favour the
        # critic via the (2 + epoch) / epoch factor.
        r_g = abs((g_loss.item() - g_loss_0) / (g_loss_0 + 0.001))
        r_c = abs((c_loss.item() - c_loss_0) / (c_loss_0 + 0.001))
        if ((2 + epoch) / epoch) * r_c > r_g:
            c_optimizer.zero_grad()
            c_batches += 1
            c_train_loss += c_loss.item()
            c_loss.backward()
            c_optimizer.step()
        else:
            g_optimizer.zero_grad()
            g_batches += 1
            g_train_loss += g_loss.item()
            g_loss.backward()
            g_optimizer.step()
        c_loss_0 = c_loss.item()
        g_loss_0 = g_loss.item()

        if args.interval > 0 and i % args.interval == 0:
            logger.info(
                'Epoch: {} | Batch: {}/{} ({:.0f}%) | G Loss: {:.6f} | C Loss: {:.6f} | Cl Loss: {:.6f}'.format(
                    epoch, args.batch_size * i, len(train_loader.dataset),
                    100. * (args.batch_size * i) / len(train_loader.dataset),
                    g_loss.item(), c_loss.item(), cl_loss.item()))
            test_lab = (torch.rand(1, 1) * args.n_classes).type(torch.LongTensor).to(args.device)
            test_noise = torch.Tensor(
                np.random.normal(0, 1, (1, args.latent_size))).to(args.device)
            test_new_z = generator(test_noise, test_lab).data
            # create new sent
            test_z = rollout_test(model_decoder, test_new_z, tokenizer_decoder,
                                  args.max_seq_length, 1, 0, 1)
            logger.info("Label: {} | Text: {}".format(test_lab.item(), test_z))

    c_train_loss /= c_batches + 1
    g_train_loss /= g_batches + 1
    logger.info('* (Train) Epoch: {} | G Loss: {:.4f} | C Loss: {:.4f} | Updates G: {} | Updates C: {}'.format(
        epoch, g_train_loss, c_train_loss, g_batches, c_batches))
    return (g_train_loss, c_train_loss)
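In the conditional script the penalty helper also receives the real labels, since the critic here scores a (latent code, label) pair. A sketch of that variant under the same WGAN-GP assumption as before; again, the repository's own implementation may differ.

import torch

def compute_grad_penalty(critic, real_data, fake_data, labels):
    # Same WGAN-GP penalty, but the critic is conditioned on the class labels
    alpha = torch.rand(real_data.size(0), 1, device=real_data.device).expand_as(real_data)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    score = critic(interpolates, labels)
    grads = torch.autograd.grad(outputs=score, inputs=interpolates,
                                grad_outputs=torch.ones_like(score),
                                create_graph=True, retain_graph=True)[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()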
best_bleu = 0
reference = list()
with open(args.valid_data_file, "r") as valid:
    for sents in valid:
        reference.append(sents.replace("\n", ""))

for epoch in range(1, args.epochs + 1):
    g_loss, c_loss = train(epoch)

    # Sample 500 sentences from the generator (two batches of 250)
    data_test = list()
    for i in range(2):
        test_noise = torch.Tensor(
            np.random.normal(0, 1, (250, args.latent_size))).to(args.device)
        test_z = generator(test_noise).data
        new_sent = rollout_test(model_decoder, test_z, tokenizer_decoder,
                                args.max_seq_length, 250, 0, 1)
        data_test.extend(new_sent)

    # Forward and backward BLEU-2 against a random sample of the validation set
    p_reference = random.sample(reference, 500)
    bleu = calc_blue_parallel_func(p_reference, data_test, 2, 500)
    b_bleu = calc_blue_parallel_func(data_test, p_reference, 2, 500)
    logger.info("Bleu-2: {:0.3f} | B-Bleu-2: {:0.3f}".format(bleu, b_bleu))

    # Keep the generator checkpoint with the best combined score
    if (bleu + b_bleu) > best_bleu:
        best_bleu = bleu + b_bleu
        logger.info('* Saving. Best Score: {:0.3f} | Bleu-2: {:0.3f} | B-Bleu-2: {:0.3f}'.format(
            best_bleu, bleu, b_bleu))
        torch.save(generator.state_dict(),
                   args.output_dir + '/generator_' + str(args.gloabl_step_eval) + '.th')
label = list()
label.extend(args.new_sent * [int(args.generate_label)])
label = torch.LongTensor(label).to(args.device)

# Get number of generation iterations
for i in range(0, int(args.new_sent / args.batch_size)):
    # sample noise
    noise = torch.Tensor(
        np.random.normal(0, 1, (args.batch_size, args.latent_size))).to(args.device)
    new_z = generator(noise, label[i * args.batch_size:args.batch_size * (i + 1)]).data

    # create new sentences from the sampled latent codes
    sents = rollout_test(model_decoder, new_z, tokenizer_decoder,
                         args.max_seq_length, args.batch_size,
                         args.top_k, args.top_p)
    # Prefix each sentence with its label
    sents = [
        str(lab) + " " + str(sen) for lab, sen in zip(
            label.tolist()[i * args.batch_size:args.batch_size * (i + 1)], sents)
    ]

    if args.save:
        with open(args.output_dir + "/{}.txt".format(args.output_name), 'a') as file:
            for sent in sents:
                file.write(sent + "\n")
    else:
        for sent in sents:
            logger.info(sent)