def eval_wtvae_pair(epoch, model, sample_loader, args, img_output_dir, model_dir):
    """Write WT-VAE reconstruction/sample images for a paired loader and checkpoint the model.

    For every batch ``data`` from ``sample_loader``:
      * ``data[0]`` is encoded and decoded into a reconstruction ``y``;
      * a batch of standard-normal latents of size ``args.z_dim`` is decoded
        into ``y_sample``;
      * ``data[1]`` is passed through ``wt`` with ``args.num_wt`` levels and
        cropped to ``y``'s spatial size to form the target image.
    Four PNG grids are written to ``img_output_dir``. After the loop, the model's
    ``state_dict`` is saved to ``model_dir/wtvae_epoch{epoch}.pth``.

    Args:
        epoch: Epoch index; used only to build the output file names.
        model: Model exposing ``device``, ``filters``, ``encode`` (returning a
            3-tuple whose first item is the latent ``z``) and ``decode``.
            Assumed to be a ``torch.nn.Module``, since its train/eval mode is
            read and set.
        sample_loader: Iterable of batches where ``data[0]`` and ``data[1]``
            are tensors.
        args: Namespace providing ``z_dim`` and ``num_wt``.
        img_output_dir: Directory prefix for the PNG outputs (joined with '/').
        model_dir: Directory prefix for the checkpoint (joined with '/').

    Notes:
        * Output file names depend only on ``epoch``. If ``sample_loader``
          yields several batches, each batch overwrites the previous one's
          images, so only the last batch's images remain.
        * The model's train/eval mode on entry is restored on exit, even if
          an exception is raised. The caller does not need to re-enable
          training mode.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data in sample_loader:
                data0 = data[0].to(model.device)
                data1 = data[1].to(model.device)

                # Encode the first item, and draw prior latents (one per batch
                # element) for unconditional samples.
                z_prior = torch.randn(data1.shape[0], args.z_dim).to(model.device)
                z, _mu, _logvar = model.encode(data0)

                # Decode both the posterior latents and the prior latents.
                y = model.decode(z)
                y_sample = model.decode(z_prior)

                # Pad the reconstruction to a side length 2**num_wt times y's
                # height. Presumably this is the resolution before the
                # num_wt-level transform; TODO confirm.
                target_dim = np.power(2, args.num_wt) * y.shape[2]
                y_padded = zero_pad(y, target_dim=target_dim, device=model.device)

                # Target: wavelet transform of the second item, cropped to the
                # reconstruction's spatial size.
                # NOTE(review): data1 is passed to `wt` with all its channels;
                # confirm `wt` supports multi-channel input.
                x_wt = wt(data1, model.filters, levels=args.num_wt)
                x_wt = x_wt[:, :, :y.shape[2], :y.shape[3]]

                save_image(y_padded.cpu(), img_output_dir + '/recon_y_padded{}.png'.format(epoch))
                save_image(y.cpu(), img_output_dir + '/recon_y{}.png'.format(epoch))
                save_image(y_sample.cpu(), img_output_dir + '/sample_y{}.png'.format(epoch))
                save_image(x_wt.cpu(), img_output_dir + '/target{}.png'.format(epoch))

        torch.save(model.state_dict(), model_dir + '/wtvae_epoch{}.pth'.format(epoch))
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)
# Fragment of a training script. Indentation was lost in this copy, and the
# structure below is reconstructed.
# NOTE(review): this presumably sits inside an outer per-epoch loop (it reads
# `epoch`), with the final "Save train losses and plot" section after that
# loop. Confirm against the original file.
#
# Sampling pass: for each batch from `sample_loader`, encode/decode it with
# `wt_model` and decode random standard-normal latents. Then save image grids
# of the zero-padded reconstruction, the raw reconstruction, the random
# samples, and a 2-level `wt` target cropped to 128x128.
for data in sample_loader:
    # One standard-normal latent vector per batch element, for unconditional samples.
    z_sample1 = torch.randn(data.shape[0], args.z_dim).to(device)
    # Detached device copy of the batch; used only to build the WT target below.
    x = data.clone().detach().to(device)
    # Older API variant where encode() returned two extra values (m1_idx,
    # m2_idx) that decode() consumed:
    # z, mu_wt, logvar_wt, m1_idx, m2_idx = wt_model.encode(data.to(device))
    # y = wt_model.decode(z, m1_idx, m2_idx)
    # y_sample = wt_model.decode(z_sample1, m1_idx, m2_idx)
    # NOTE(review): `data.to(device)` transfers the batch again even though `x`
    # already holds a device copy. mu_wt / logvar_wt are not used here.
    z, mu_wt, logvar_wt = wt_model.encode(data.to(device))
    y = wt_model.decode(z)
    y_sample = wt_model.decode(z_sample1)
    # Zero-pad outputs to a hard-coded 512 side length. Presumably this equals
    # 128 * 2**levels for levels=2 below; TODO confirm and derive from config.
    # NOTE(review): y_sample_padded is computed but never saved or used.
    y_padded = zero_pad(y, target_dim=512, device=device)
    y_sample_padded = zero_pad(y_sample, target_dim=512, device=device)
    # Wavelet-transform target. Assumes x is 4-D, presumably (batch, channels,
    # H, W). Channels are folded into the batch dimension so `wt` sees
    # 1-channel images (presumably what it expects; confirm). The result is
    # then reshaped back to x's shape.
    x_wt = wt(x.reshape(x.shape[0] * x.shape[1], 1, x.shape[2], x.shape[3]), wt_model.filters, levels=2)
    x_wt = x_wt.reshape(x.shape)
    # Keep only the top-left 128x128 region (hard-coded; presumably the
    # decoder's output size). TODO confirm.
    x_wt = x_wt[:, :, :128, :128]
    # File names depend only on `epoch`. With several batches, each batch
    # overwrites the previous batch's images.
    save_image(y_padded.cpu(), img_output_dir + '/sample_padded_y{}.png'.format(epoch))
    save_image(y.cpu(), img_output_dir + '/sample_recon_y{}.png'.format(epoch))
    save_image(y_sample.cpu(), img_output_dir + '/sample_y{}.png'.format(epoch))
    save_image(x_wt.cpu(), img_output_dir + '/sample{}.png'.format(epoch))
# Checkpoint the model weights under an epoch-numbered file name.
torch.save(wt_model.state_dict(), model_dir + '/wtvae_epoch{}.pth'.format(epoch))

# Save train losses and plot
np.save(model_dir+'/train_losses.npy', train_losses)
save_plot(train_losses, img_output_dir + '/train_loss.png')
# Close `writer` (presumably a TensorBoard SummaryWriter; confirm).
writer.close()