else: model.load_state_dict(torch.load('models/{}.bin' .format(args.model + utils.getModelName(args)), map_location=lambda storage, loc: storage) ) # Samples latent and conditional codes randomly from prior z = model.sample_z_prior(1) c = model.sample_c_prior(1) # Generate positive sample given z c[0, 0], c[0, 1] = 1, 0 _, c_idx = torch.max(c, dim=1) sample_idxs = model.sample_sentence(z, c, temp=0.1) print('\nSentiment: {}'.format(dataset.idx2label(int(c_idx)))) print('Generated: {}'.format(dataset.idxs2sentence(sample_idxs))) # Generate negative sample from the same z c[0, 0], c[0, 1] = 0, 1 _, c_idx = torch.max(c, dim=1) sample_idxs = model.sample_sentence(z, c, temp=0.8) print('\nSentiment: {}'.format(dataset.idx2label(int(c_idx)))) print('Generated: {}'.format(dataset.idxs2sentence(sample_idxs))) print()
# NOTE(review): this is the tail of a training-loop body whose header (which
# binds `ep`) is outside this view. The nesting below is reconstructed from
# collapsed source; confirm it against the surrounding loop.

# Backpropagate, clip the total gradient norm of model.vae_params to 5, then
# step the optimizer and clear the gradients for the next iteration.
loss.backward()
# clip_grad_norm_ is the in-place replacement for the deprecated
# clip_grad_norm and returns the total norm measured before clipping.
grad_norm = torch.nn.utils.clip_grad_norm_(model.vae_params, 5)
trainer.step()
trainer.zero_grad()

# Anneal the KL weight once past kld_start_inc. Clamping at kld_max stops the
# last increment from overshooting the cap. That can happen when kld_inc does
# not divide the range evenly, or through float accumulation
# (e.g. ten additions of 0.1 give 0.999... < 1.0).
if ep > kld_start_inc and kld_weight < kld_max:
    kld_weight = min(kld_weight + kld_inc, kld_max)

# Every LOG_INTERVAL epochs, log the losses and decode one sentence from a
# fresh prior sample (z, c) as a qualitative check.
if ep % LOG_INTERVAL == 0:
    z = model.sample_z_prior(1)
    c = model.sample_c_prior(1)
    _, c_idx = torch.max(c, dim=1)
    sample_idxs = model.sample_sentence(z, c)
    sample_sent = dataset.idxs2sentence(sample_idxs)

    # float()/int() replace `.data[0]`, which raises IndexError on 0-dim
    # tensors in current PyTorch. Both conversions also accept one-element
    # 1-D tensors, and grad_norm may be a tensor or a Python float.
    print(
        'epoch-{}; Loss: {:.4f}; Recon: {:.4f}; KL: {:.4f}; Grad_norm: {:.4f}; Code: {}'
        .format(ep, float(loss), float(recon_loss), float(kl_loss),
                float(grad_norm),
                'Positive' if int(c_idx) == 0 else 'Negative'))
    print('Sample: "{}"'.format(sample_sent))
    print()

# Every SAVE_INTERVAL epochs, write an epoch-tagged checkpoint.
if ep % SAVE_INTERVAL == 0:
    save_base_vae_iter(ep)

# NOTE(review): the nesting of this call is ambiguous in the collapsed source.
# As written here it runs every epoch. It presumably belongs once after the
# training loop (or under the SAVE_INTERVAL check above); confirm and re-indent.
save_base_vae()