def inference(args):
    """Restore a trained VAE and run generation/reconstruction on a test set.

    Resets the default TF graph, builds the model inside a fresh session,
    and delegates the actual decoding and file output to the model's own
    ``inference`` routine.

    Args:
        args: namespace with latent_dim, lambda_kl, testing_images,
            gen_from, output_images and bsize attributes.
    """
    tf.reset_default_graph()
    with tf.Session() as sess:
        model = VAE(sess, latent_dim=args.latent_dim, lambda_kl=args.lambda_kl)
        model.build_model()
        model.inference(
            test_path=args.testing_images,
            gen_from=args.gen_from,
            out_path=args.output_images,
            bsize=args.bsize,
        )
def main(args):
    """Train a VAE on MNIST, periodically printing loss and saving samples.

    Commented-out alternatives keep the CVAE-on-MNIST and CVAE-on-facescrub
    configurations one uncomment away.

    Args:
        args: namespace with batch_size, latent_size, learning_rate, epochs,
            print_every, save_test_sample, save_recon_img (and num_labels for
            the CVAE variant).

    Side effects: downloads MNIST under ./data; writes sample images via
    ``save_img``; requires a CUDA device.
    """
    ### VAE on MNIST
    n_transform = transforms.Compose([transforms.ToTensor()])
    dataset = MNIST('data', transform=n_transform, download=True)
    ### CVAE on MNIST
    # n_transform = transforms.Compose([transforms.ToTensor()])
    # dataset = MNIST('data', transform=n_transform)
    ### CVAE on facescrub-5
    # n_transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor()])
    # dataset = ImageFolder('facescrub-5', transform=n_transform)
    data_loader = DataLoader(dataset, batch_size=args.batch_size,
                             shuffle=True, num_workers=2)

    vae = VAE(args.latent_size).cuda()
    ### CVAE
    # vae = CVAE(args.latent_size, num_labels=args.num_labels).cuda()
    optimizer = torch.optim.Adam(vae.parameters(), lr=args.learning_rate)

    # Fixed noise so samples saved across iterations are directly comparable.
    # FIX: torch.autograd.Variable has been a no-op since PyTorch 0.4 —
    # plain tensors carry autograd state, so the deprecated wrappers are gone.
    fix_noise = torch.randn((50, args.latent_size)).cuda()

    num_iter = 0
    # NOTE(review): the x10 multiplier on args.epochs looks intentional but
    # undocumented — confirm with the original author.
    for epoch in range(args.epochs * 10):
        for _, batch in enumerate(data_loader, 0):
            img = batch[0].cuda()
            # Unused by the plain VAE; kept for the CVAE variant below.
            label = batch[1].cuda()

            recon_img, mean, log_var, z = vae(img)
            ### CVAE
            # recon_img, mean, log_var, z = vae(img, label)
            loss = loss_fn(recon_img, img, mean, log_var)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            num_iter += 1
            if num_iter % args.print_every == 0:
                # FIX: loss.item() replaces the deprecated loss.data.item().
                print("Batch %04d/%i, Loss %9.4f"
                      % (num_iter, len(data_loader) - 1, loss.item()))
            if num_iter % args.save_test_sample == 0:
                x = vae.inference(fix_noise)
                save_img(args, x.detach(), num_iter)
            if num_iter % args.save_recon_img == 0:
                save_img(args, recon_img.detach(), num_iter, recon=True)
def main(ARGS, device):
    """
    Prepares the datasets for training, and optional, validation and testing.
    Then, initializes the VAE model and runs the training (/validation)
    process for a given number of epochs.

    Args:
        ARGS: parsed command-line arguments (data/model hyperparameters).
        device: torch.device the model and embeddings are moved to.

    Side effects: may create IMDB data files, prints per-epoch ELBOs and
    sampled sentences, and saves the model under trained_models/.
    """
    data_splits = ['train', 'val']

    datasets = {
        split: IMDB(ARGS.data_dir, split, ARGS.max_sequence_length,
                    ARGS.min_word_occ, ARGS.create_data)
        for split in data_splits
    }

    pretrained_embeddings = datasets['train'].get_pretrained_embeddings(
        ARGS.embed_dim).to(device)

    model = VAE(
        datasets['train'].vocab_size,
        ARGS.batch_size,
        device,
        pretrained_embeddings=pretrained_embeddings,
        trainset=datasets['train'],
        max_sequence_length=ARGS.max_sequence_length,
        lstm_dim=ARGS.lstm_dim,
        z_dim=ARGS.z_dim,
        embed_dim=ARGS.embed_dim,
        n_lstm_layers=ARGS.n_lstm_layers,
        kl_anneal_type=ARGS.kl_anneal_type,
        kl_anneal_x0=ARGS.kl_anneal_x0,
        kl_anneal_k=ARGS.kl_anneal_k,
        kl_fbits_lambda=ARGS.kl_fbits_lambda,
        word_keep_rate=ARGS.word_keep_rate,
    )
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters())

    print('Starting training process...')
    # FIX: create the checkpoint directory up front — os.listdir raised
    # FileNotFoundError on a fresh checkout where trained_models/ was absent.
    os.makedirs("trained_models", exist_ok=True)
    # Count existing checkpoints once so this run gets the next free index.
    amount_of_files = len(os.listdir("trained_models"))

    for epoch in range(ARGS.epochs):
        elbos = run_epoch(model, datasets, device, optimizer)
        train_elbo, val_elbo = elbos
        print(
            f"[Epoch {epoch} train elbo: {train_elbo}, val_elbo: {val_elbo}]")

        # Perform inference on the trained model
        with torch.no_grad():
            model.eval()
            samples = model.inference()
            print(*idx2word(samples, i2w=datasets['train'].i2w,
                            pad_idx=datasets['train'].pad_idx), sep='\n')

        # Overwrites the same file each epoch: latest checkpoint wins.
        model.save(f"trained_models/{amount_of_files + 1}.model")
def train(cond):
    """Train a (conditional) VAE on MNIST and dump sample/latent figures.

    Args:
        cond: if True, train a CVAE conditioned on the digit label;
            otherwise a plain VAE.

    Side effects: downloads MNIST under ./data, writes sample grids and
    latent-space scatter plots under ./figs/<timestamp>/, and saves the
    whole model object to figs/vae.pth.
    """
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(0)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ts = time.time()

    dataset = MNIST(root='data', train=True,
                    transform=transforms.ToTensor(), download=True)
    data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

    def loss_fn(recon_x, x, mean, log_var):
        # Per-pixel reconstruction BCE + KL divergence, averaged per sample.
        BCE = torch.nn.functional.binary_cross_entropy(
            recon_x.view(-1, 28 * 28), x.view(-1, 28 * 28), reduction='sum')
        KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        return (BCE + KLD) / x.size(0)

    vae = VAE(encoder_layer_sizes=[784, 256], latent_size=2,
              decoder_layer_sizes=[256, 784], conditional=cond,
              num_labels=10 if cond else 0).to(device)
    optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)

    logs = defaultdict(list)
    for epoch in range(10):
        tracker_epoch = defaultdict(lambda: defaultdict(dict))

        for iteration, (x, y) in enumerate(data_loader):
            x, y = x.to(device), y.to(device)

            if cond:
                recon_x, mean, log_var, z = vae(x, y)
            else:
                recon_x, mean, log_var, z = vae(x)

            # Record each sample's 2-D latent position for the scatter plot.
            # FIX: renamed 'id' -> 'sample_id' (was shadowing the builtin).
            for i, yi in enumerate(y):
                sample_id = len(tracker_epoch)
                tracker_epoch[sample_id]['x'] = z[i, 0].item()
                tracker_epoch[sample_id]['y'] = z[i, 1].item()
                tracker_epoch[sample_id]['label'] = yi.item()

            loss = loss_fn(recon_x, x, mean, log_var)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            logs['loss'].append(loss.item())

            if iteration % 100 == 0 or iteration == len(data_loader) - 1:
                print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
                    epoch, 10, iteration, len(data_loader) - 1, loss.item()))

                if cond:
                    # FIX: conditioning labels must live on the same device
                    # as the model, or inference fails on GPU.
                    c = torch.arange(0, 10).long().unsqueeze(1).to(device)
                    x = vae.inference(n=c.size(0), c=c)
                else:
                    x = vae.inference(n=10)

                # FIX: a stray extra plt.figure() leaked one empty figure per
                # logging step; a single sized figure is all that is needed.
                plt.figure(figsize=(5, 10))
                for p in range(10):
                    plt.subplot(5, 2, p + 1)
                    if cond:
                        plt.text(0, 0, "c={:d}".format(c[p].item()),
                                 color='black', backgroundcolor='white',
                                 fontsize=8)
                    # FIX: .numpy() raises on CUDA tensors — detach and move
                    # to CPU before converting for imshow.
                    plt.imshow(x[p].view(28, 28).detach().cpu().numpy())
                    plt.axis('off')

                # makedirs(exist_ok=True) replaces the racy exists()/mkdir()
                # dance and creates intermediate directories in one call.
                os.makedirs(os.path.join('figs', str(ts)), exist_ok=True)
                plt.savefig(os.path.join(
                    'figs', str(ts),
                    "E{:d}I{:d}.png".format(epoch, iteration)), dpi=300)
                plt.clf()
                plt.close('all')

        df = pd.DataFrame.from_dict(tracker_epoch, orient='index')
        g = sns.lmplot(x='x', y='y', hue='label',
                       data=df.groupby('label').head(100),
                       fit_reg=False, legend=True)
        g.savefig(os.path.join('figs', str(ts),
                               "E{:d}-Dist.png".format(epoch)), dpi=300)

    # NOTE(review): saves the full pickled module, not a state_dict — kept
    # as-is since downstream loaders may rely on the format.
    torch.save(vae, os.path.join('figs', 'vae.pth'))