def get_init_posterior(data):
    """Return an initial latent posterior ``(best_z, mu, logvar)`` for ``data``.

    If ``cmd_args.unroll_test`` is false, the encoder's output is used as-is.
    Otherwise ``optimize_gaussian`` refines it against the loss below. It is
    driven by ``inner_opt_class`` for ``cmd_args.unroll_steps`` steps and is
    called with ``training=True``.
    """
    def objective(z, mu, logvar):
        # Loss of decoding ``z`` against the observed batch, given the
        # posterior parameters.
        return loss_function(decoder(z), data, mu, logvar)

    if not cmd_args.unroll_test:
        best_z, mu, logvar = encoder(data)
        return best_z, mu, logvar

    best_z, mu, logvar = optimize_gaussian(
        data,
        encoder,
        objective,
        inner_opt_class,
        nsteps=cmd_args.unroll_steps,
        training=True,
    )
    return best_z, mu, logvar
def test(epoch):
    """Evaluate the encoder/decoder pair on ``test_loader`` for one epoch.

    The encoder is put in eval mode and the inner optimizer is frozen for the
    duration of the pass.

    When ``cmd_args.unroll_test`` is set, each batch's posterior is refined
    with ``optimize_func``. The encoder's differentiable variables are
    snapshotted beforehand and restored afterwards, so evaluation does not
    alter them.

    A grid of up to 8 inputs and their reconstructions from the first batch
    is saved under ``cmd_args.save_dir``.

    Returns:
        The batch losses, each weighted by its batch size, summed and divided
        by ``len(test_loader.dataset)``. This presumes ``loss_function``
        returns a per-sample mean; confirm against its definition.
    """
    encoder.eval()
    inner_opt.set_freeze_flag(True)
    total = 0
    for batch_idx, (data, _) in tqdm(enumerate(test_loader)):
        data = convert_data(data)

        def objective(z, mu, logvar):
            return loss_function(decoder(z), data, mu, logvar)

        if not cmd_args.unroll_test:
            best_z, mu, logvar = encoder(data)
        else:
            inner_opt.zero_grad()
            snapshot = encoder.diff_var_dict()
            best_z, mu, logvar = optimize_func(
                data, encoder, objective, inner_opt, nsteps=cmd_args.unroll_steps
            )
            # Undo any changes the unrolled optimization made to the encoder.
            encoder.load_diff_var_dict(snapshot)

        loss = objective(best_z, mu, logvar)
        recon_batch = decoder(best_z)
        total += loss.item() * data.shape[0]

        if batch_idx == 0:
            n = min(data.size(0), 8)
            side = cmd_args.img_size
            recon_images = recon_batch.view(-1, 1, side, side)
            comparison = torch.cat([data[:n], recon_images[:n]])
            out_path = '%s/parametric_op_reconstruction_' % cmd_args.save_dir + str(epoch) + '.png'
            save_image(comparison.data.cpu(), out_path, nrow=n)

    inner_opt.set_freeze_flag(False)
    total /= len(test_loader.dataset)
    print('test epoch %d, average loss %.4f' % (epoch, total))
    return total
def train(epoch):
    """Run one training epoch over ``celeb_loader``.

    For each batch, the posterior is refined with ``optimize_func``, the loss
    on the refined latent is back-propagated, and ``optimizer`` takes a step.
    Both the decoder output and the data are flattened to ``(batch, -1)``
    before the loss is computed.

    Returns:
        The mean of the per-batch losses. This raises ``ZeroDivisionError``
        if the loader yields no batches.
    """
    encoder.train()
    running = 0
    progress = tqdm(celeb_loader)
    batches_seen = 0
    for batches_seen, (data, _) in enumerate(progress, 1):
        data = Variable(data)
        if cmd_args.ctx == 'gpu':
            data = data.cuda()
        optimizer.zero_grad()
        inner_opt.zero_grad()

        def objective(z, mu, logvar):
            rows = data.shape[0]
            return loss_function(decoder(z).view(rows, -1), data.view(rows, -1), mu, logvar)

        best_z, mu, logvar = optimize_func(
            data, encoder, objective, inner_opt, nsteps=cmd_args.unroll_steps
        )
        loss = objective(best_z, mu, logvar)
        loss.backward()
        batch_loss = loss.item()
        running += batch_loss
        optimizer.step()
        progress.set_description('minibatch loss: %.4f' % batch_loss)

    average = running / batches_seen
    print('Epoch %d, average loss %.4f' % (epoch, average))
    return average
def test(epoch):
    """Evaluate on the first batch of ``test_loader`` and save a reconstruction grid.

    The encoder is put in eval mode and the inner optimizer is frozen for the
    pass. Only the first batch is evaluated: the loop breaks after saving the
    image grid.

    When ``cmd_args.unroll_test`` is set, the posterior is refined with
    ``optimize_func``. The encoder's differentiable variables are snapshotted
    beforehand and restored afterwards.

    Returns:
        The loss of the evaluated batch(es), weighted by batch size and divided
        by the number of evaluated samples, or ``0.0`` if the loader is empty.
        Previously this function returned ``None``. This presumes
        ``loss_function`` returns a per-sample mean; confirm.
    """
    encoder.eval()
    inner_opt.set_freeze_flag(True)
    test_loss = 0
    num_samples = 0
    for i, (data, _) in enumerate(test_loader):
        if cmd_args.ctx == 'gpu':
            data = data.cuda()
        data = Variable(data)
        fz = lambda z, mu, logvar: loss_function(
            decoder(z).view(data.shape[0], -1), data.view(data.shape[0], -1), mu, logvar
        )
        if cmd_args.unroll_test:
            inner_opt.zero_grad()
            bak_dict = encoder.diff_var_dict()
            best_z, mu, logvar = optimize_func(data, encoder, fz, inner_opt, nsteps=cmd_args.unroll_steps)
            # Undo any changes the unrolled optimization made to the encoder.
            encoder.load_diff_var_dict(bak_dict)
        else:
            best_z, mu, logvar = encoder(data)
        loss = fz(best_z, mu, logvar)
        recon_batch = decoder(best_z)
        test_loss += loss.item() * data.shape[0]
        num_samples += data.shape[0]
        if i == 0:
            n = min(data.size(0), 8)
            # Use -1 rather than cmd_args.batch_size: the first batch can hold
            # fewer samples (e.g. a test set smaller than batch_size), and a
            # fixed batch dimension would make .view() fail.
            comparison = torch.cat([data[:n], recon_batch.view(-1, 3, 64, 64)[:n]])
            save_image(comparison.data.cpu(), '%s/vae_reconstruction_' % cmd_args.save_dir + str(epoch) + '.png', nrow=n)
            break
    inner_opt.set_freeze_flag(False)
    return test_loss / num_samples if num_samples else 0.0
def get_init_posterior(data):
    """Return an initial latent posterior ``(best_z, mu, logvar)`` for ``data``.

    If ``cmd_args.unroll_test`` is false, the encoder's output is used
    directly.

    Otherwise the posterior is refined with ``optimize_func`` using
    ``inner_opt`` for ``cmd_args.unroll_steps`` steps. The encoder's
    differentiable variables are snapshotted before refinement and restored
    after it.
    """
    def objective(z, mu, logvar):
        # Decoder output and data are both flattened to (batch, -1).
        rows = data.shape[0]
        return loss_function(decoder(z).view(rows, -1), data.view(rows, -1), mu, logvar)

    if not cmd_args.unroll_test:
        best_z, mu, logvar = encoder(data)
        return best_z, mu, logvar

    inner_opt.zero_grad()
    saved_vars = encoder.diff_var_dict()
    best_z, mu, logvar = optimize_func(
        data, encoder, objective, inner_opt, nsteps=cmd_args.unroll_steps
    )
    encoder.load_diff_var_dict(saved_vars)
    return best_z, mu, logvar
def train(epoch):
    """Run one training epoch over ``train_loader``.

    For each batch, ``optimize_gaussian`` refines the posterior. It is driven
    by ``inner_opt_class`` for ``cmd_args.unroll_steps`` steps and called with
    ``training=True``. The loss on the refined latent is then
    back-propagated and ``optimizer`` takes a step.

    The progress bar also shows a monitoring-only reconstruction loss:
    ``binary_cross_entropy(decoder(mu), data)``.

    Returns:
        The mean of the per-batch losses. Previously this function returned
        ``None``. This raises ``ZeroDivisionError`` if the loader yields no
        batches.
    """
    import torch

    train_loss = 0
    encoder.train()
    pbar = tqdm(train_loader)
    num_mini_batches = 0
    for (data, _) in pbar:
        data = convert_data(data)
        optimizer.zero_grad()
        fz = lambda z, mu, logvar: loss_function(decoder(z), data, mu, logvar)
        best_z, mu, logvar = optimize_gaussian(
            data, encoder, fz, inner_opt_class, nsteps=cmd_args.unroll_steps, training=True
        )
        loss = fz(best_z, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        # This loss is only displayed and never back-propagated. Computing it
        # without grad avoids building (and keeping alive until the next
        # iteration) an autograd graph through the decoder.
        with torch.no_grad():
            recon_loss = binary_cross_entropy(decoder(mu), data)
        pbar.set_description('minibatch loss: %.4f, recon: %.4f' % (loss.item(), recon_loss.item()))
        num_mini_batches += 1
    avg_loss = train_loss / num_mini_batches
    msg = 'train epoch %d, average loss %.4f' % (epoch, avg_loss)
    print(msg)
    return avg_loss
def train(epoch):
    """Run one training epoch over ``train_loader`` and print the average loss.

    For each batch, both optimizers' gradients are cleared and
    ``optimize_func`` refines the posterior with ``inner_opt`` for
    ``cmd_args.unroll_steps`` steps. The loss on the refined latent is then
    back-propagated and ``optimizer`` takes a step.

    Nothing is returned. Printing the average raises ``ZeroDivisionError`` if
    the loader yields no batches.
    """
    running = 0
    encoder.train()
    progress = tqdm(train_loader)
    batches_seen = 0
    for batches_seen, (data, _) in enumerate(progress, 1):
        data = convert_data(data)
        optimizer.zero_grad()
        inner_opt.zero_grad()

        def objective(z, mu, logvar):
            return loss_function(decoder(z), data, mu, logvar)

        best_z, mu, logvar = optimize_func(
            data, encoder, objective, inner_opt, nsteps=cmd_args.unroll_steps
        )
        loss = objective(best_z, mu, logvar)
        loss.backward()
        batch_loss = loss.item()
        running += batch_loss
        optimizer.step()
        progress.set_description('minibatch loss: %.4f' % batch_loss)

    print('train epoch %d, average loss %.4f' % (epoch, running / batches_seen))
def test(epoch):
    """Evaluate the encoder/decoder pair on ``test_loader`` for one epoch.

    When ``cmd_args.unroll_test`` is set, each batch's posterior is refined
    with ``optimize_gaussian`` (``inner_opt_class``, ``cmd_args.unroll_steps``
    steps). Otherwise the encoder output is used directly.

    A grid of up to 8 inputs and their reconstructions from the first batch
    is saved under ``cmd_args.save_dir``.

    Returns:
        The batch losses, each weighted by its batch size, summed and divided
        by ``len(test_loader.dataset)``. This presumes ``loss_function``
        returns a per-sample mean; confirm.
    """
    encoder.eval()
    test_loss = 0
    for i, (data, _) in tqdm(enumerate(test_loader)):
        data = convert_data(data)
        fz = lambda z, mu, logvar: loss_function(decoder(z), data, mu, logvar)
        if cmd_args.unroll_test:
            # NOTE(review): training=True is passed even at test time;
            # confirm this is intended.
            best_z, mu, logvar = optimize_gaussian(
                data, encoder, fz, inner_opt_class, nsteps=cmd_args.unroll_steps, training=True
            )
        else:
            best_z, mu, logvar = encoder(data)
        loss = fz(best_z, mu, logvar)
        recon_batch = decoder(best_z)
        test_loss += loss.item() * data.shape[0]
        if i == 0:
            n = min(data.size(0), 8)
            # Use -1 rather than cmd_args.batch_size: the first batch can hold
            # fewer samples (e.g. a test set smaller than batch_size), and a
            # fixed batch dimension would make .view() fail.
            comparison = torch.cat(
                [data[:n], recon_batch.view(-1, 1, cmd_args.img_size, cmd_args.img_size)[:n]]
            )
            save_image(comparison.data.cpu(), '%s/unroll_gauss_reconstruction_' % cmd_args.save_dir + str(epoch) + '.png', nrow=n)
    test_loss /= len(test_loader.dataset)
    msg = 'test epoch %d, average loss %.4f' % (epoch, test_loss)
    print(msg)
    return test_loss