def eval_full_wtvae128_iwtae512(epoch, full_model, optimizer, sample_loader,
                                args, img_output_dir, model_dir, writer,
                                save=True):
    """Write evaluation images for the full model and optionally checkpoint it.

    For every batch of ``sample_loader`` the function saves reconstructions,
    prior samples and reference images as PNG files in ``img_output_dir``.
    File names depend only on ``epoch``, so each batch overwrites the files
    written for the previous one; the last batch's images are what remain.

    Args:
        epoch: Value interpolated into every output file name.
        full_model: Model exposing ``wt_model`` (with ``wt`` and ``device``),
            ``iwt_model`` (with ``iwt`` and ``device``), ``sample(n)`` and a
            forward call returning ``(Y_low_hat, mask_hat, X_hat, mu, logvar)``.
        optimizer: Optimizer whose gradients are cleared per batch and whose
            state is stored in the checkpoint.
        sample_loader: Iterable yielding ``(X_128, X_512)`` batches.
        args: Unused; kept so existing callers keep working.
        img_output_dir: Directory that receives the PNG files.
        model_dir: Directory that receives the checkpoint.
        writer: Unused; kept so existing callers keep working.
        save: When true, write ``fullvae_epoch{epoch}.pth`` to ``model_dir``.
    """
    # Evaluation only: the full model and both sub-models go to eval mode.
    full_model.eval()
    full_model.wt_model.eval()
    full_model.iwt_model.eval()

    def _save(tensor, stem):
        # Every image lands at <img_output_dir>/<stem><epoch>.png.
        save_image(tensor.cpu(),
                   img_output_dir + '/{}{}.png'.format(stem, epoch))

    # Nothing is back-propagated here, so skip building autograd graphs.
    with torch.no_grad():
        for data in sample_loader:
            optimizer.zero_grad()

            X_128, X_512 = data
            Y_low_hat, mask_hat, X_hat, mu, logvar = full_model(X_128)
            Y_low_sample_hat, mask_sample_hat, X_sample_hat = full_model.sample(
                X_128.shape[0])

            # Reference path: transform the 512 input, keep its top-left
            # 128x128 patch, zero-pad it back to 512 and invert.
            X_wt = full_model.wt_model.wt(X_512.to(full_model.wt_model.device))
            Y_low = X_wt[:, :, :128, :128]
            Y_low_padded = zero_pad(Y_low, 512, device=Y_low.device)
            X_low = full_model.iwt_model.iwt(
                Y_low_padded.to(full_model.iwt_model.device))

            # Save images
            _save(Y_low_hat, 'y_wt_recon')
            _save(mask_hat, 'mask_recon')
            _save(X_hat, 'X_recon')
            _save(Y_low_sample_hat, 'y_wt_sample')
            _save(mask_sample_hat, 'mask_sample')
            _save(X_sample_hat, 'X_sample')
            _save(Y_low, 'y_wt')
            _save(X_512, 'X')
            _save(X_low, 'X_low')
            _save(X_wt, 'X_wt')

            # Must run after the saves above: the fill is in place and Y_low
            # is a view of the same patch of X_wt.
            X_wt[:, :, :128, :128].fill_(0)
            mask = full_model.iwt_model.iwt(
                X_wt.to(full_model.iwt_model.device))
            _save(mask, 'mask')

    # NOTE(review): the checkpoint is written once after the loop; the
    # original nesting of this statement could not be recovered - confirm.
    if save:
        torch.save(
            {
                'epoch': epoch,
                # The original read an undefined name `iwt_model` here; the
                # full model is stored, matching the "fullvae" file name.
                'model_state_dict': full_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            },
            model_dir + '/fullvae_epoch{}.pth'.format(epoch))
def eval_wtvae_pair(epoch, model, sample_loader, args, img_output_dir, model_dir):
    """Write evaluation images for a WT-VAE fed with paired inputs and save it.

    Each batch of ``sample_loader`` is indexed as ``data[0]`` (encoder input)
    and ``data[1]`` (image the wavelet target is computed from). File names
    depend only on ``epoch``, so each batch overwrites the previous batch's
    images. The model's ``state_dict`` is written to ``model_dir`` at the end.

    Args:
        epoch: Value interpolated into every output file name.
        model: Model exposing ``device``, ``filters``, ``encode`` (returning
            three values) and ``decode``.
        sample_loader: Iterable of indexable batches with two entries.
        args: Namespace providing ``z_dim`` and ``num_wt``.
        img_output_dir: Directory that receives the PNG files.
        model_dir: Directory that receives ``wtvae_epoch{epoch}.pth``.
    """
    with torch.no_grad():
        model.eval()

        for data in sample_loader:
            data0 = data[0].to(model.device)
            data1 = data[1].to(model.device)

            # The prior sample is drawn before encode() so the order of
            # random draws stays the same as before.
            z_sample1 = torch.randn(data1.shape[0], args.z_dim).to(model.device)
            z, _, _ = model.encode(data0)

            # Decode the encoded latent and the prior sample.
            y = model.decode(z)
            y_sample = model.decode(z_sample1)

            # Pad the reconstruction to 2**num_wt times its own size; plain
            # integer arithmetic keeps target_dim a Python int.
            target_dim = (2 ** args.num_wt) * y.shape[2]
            y_padded = zero_pad(y, target_dim=target_dim, device=model.device)

            # Target: transform of the second input, cropped to y's size.
            x_wt = wt(data1, model.filters, levels=args.num_wt)
            x_wt = x_wt[:, :, :y.shape[2], :y.shape[3]]

            for tensor, stem in (
                (y_padded, 'recon_y_padded'),
                (y, 'recon_y'),
                (y_sample, 'sample_y'),
                (x_wt, 'target'),
            ):
                save_image(tensor.cpu(),
                           img_output_dir + '/{}{}.png'.format(stem, epoch))

    # NOTE(review): the checkpoint is written once after the loop; the
    # original nesting of this statement could not be recovered - confirm.
    torch.save(model.state_dict(),
               model_dir + '/wtvae_epoch{}.pth'.format(epoch))
# Evaluation pass over the sample loader; gradients are not needed here.
# NOTE(review): the original indentation of this fragment was lost, so the
# nesting below (saves inside the loop, checkpoint after it) is assumed.
with torch.no_grad():
    wt_model.eval()

    for data in sample_loader:
        # One standard-normal latent per batch element.
        z_sample1 = torch.randn(data.shape[0], args.z_dim).to(device)
        x = data.clone().detach().to(device)

        # An earlier variant of encode() also returned m1_idx / m2_idx and
        # decode() took them as extra arguments; the calls below use the
        # three-value form.
        z, mu_wt, logvar_wt = wt_model.encode(data.to(device))
        y = wt_model.decode(z)
        y_sample = wt_model.decode(z_sample1)

        y_padded = zero_pad(y, target_dim=512, device=device)
        y_sample_padded = zero_pad(y_sample, target_dim=512, device=device)

        # Fold the channel axis into the batch axis so wt() sees one channel
        # per item, then restore x's shape and keep the top-left 128x128.
        x_wt = wt(
            x.reshape(x.shape[0] * x.shape[1], 1, x.shape[2], x.shape[3]),
            wt_model.filters,
            levels=2,
        )
        x_wt = x_wt.reshape(x.shape)
        x_wt = x_wt[:, :, :128, :128]

        # Same four images, same order, same file names as before.
        for _img, _name in (
            (y_padded, '/sample_padded_y{}.png'),
            (y, '/sample_recon_y{}.png'),
            (y_sample, '/sample_y{}.png'),
            (x_wt, '/sample{}.png'),
        ):
            save_image(_img.cpu(), img_output_dir + _name.format(epoch))

torch.save(wt_model.state_dict(), model_dir + '/wtvae_epoch{}.pth'.format(epoch))

# Persist the recorded training losses.
np.save(model_dir+'/train_losses.npy', train_losses)
# NOTE(review): this fragment starts and ends mid-scope and its original
# indentation was lost; the nesting below is assumed.
full_model.iwt_model.eval()

for data in sample_loader:
    data128 = data[0]
    data512 = data[1]

    # WT model works on devices[0]; IWT inputs are built on devices[1].
    z, mu_wt, logvar_wt = full_model.wt_model.encode(data128.to(devices[0]))

    # Two latent variants for the WT model: a pure N(0, 1) draw and the
    # encoded latent perturbed by N(0, 1) noise.
    z_sample1 = torch.randn(z.shape).to(devices[0])
    z_sample3 = z + torch.randn(z.shape).to(devices[0])

    y = full_model.wt_model.decode(z)
    y_sample = full_model.wt_model.decode(z_sample1)
    y_sample_gaussian = full_model.wt_model.decode(z_sample3)

    y_padded = zero_pad(y, target_dim=512, device=devices[1])
    y_sample_padded = zero_pad(y_sample, target_dim=512, device=devices[1])
    y_sample_padded_gaussian = zero_pad(
        y_sample_gaussian, target_dim=512, device=devices[1])

    # Apply the inverse transform twice to a detached copy of the padded
    # reconstruction.
    y_padded_iwt = y_padded.clone().detach()
    for i in range(2):
        y_padded_iwt = full_model.iwt_model.iwt(y_padded_iwt)

    # Encode the real 512 input and the inverted reconstruction, both
    # paired with y_padded.
    mu, var, m1_idx, m2_idx = full_model.iwt_model.encode(
        data512.to(devices[1]), y_padded)
    mu_iwt, var_iwt, m1_idx_iwt, m2_idx_iwt = full_model.iwt_model.encode(
        y_padded_iwt, y_padded)

    z_sample2 = torch.randn(mu.shape).to(devices[1])

    # NOTE(review): decoding goes through the bare name `iwt_model`, not
    # `full_model.iwt_model` - presumably the same object; confirm.
    x_hat = iwt_model.decode(y_padded, mu, m1_idx, m2_idx)
    x_sample = iwt_model.decode(y_padded, z_sample2, m1_idx, m2_idx)
    x_hat_iwt = iwt_model.decode(y_padded, mu_iwt, m1_idx_iwt, m2_idx_iwt)
# NOTE(review): this fragment starts and ends mid-scope and its original
# indentation was lost; the nesting below is assumed.
iwt_model.eval()

for data in sample_loader:
    data128 = data[0].to(device)
    data512 = data[1].to(device)

    z, mu_wt, logvar_wt = wt_model.encode(data128)

    # Two latent variants for the WT model: a pure N(0, 1) draw and the
    # encoded latent perturbed by N(0, 1) noise.
    z_sample1 = torch.randn(z.shape).to(device)
    z_sample2 = z + torch.randn(z.shape).to(device)

    y = wt_model.decode(z)
    y_sample1 = wt_model.decode(z_sample1)
    y_sample2 = wt_model.decode(z_sample2)

    y_padded = zero_pad(y, target_dim=512, device=device)
    y_sample_padded1 = zero_pad(y_sample1, target_dim=512, device=device)
    y_sample_padded2 = zero_pad(y_sample2, target_dim=512, device=device)

    data512_wt = wt_fn(data512)

    # Mask the transformed 512 input (presumably zeroing its first patch,
    # per the original comment) and bring it back through iwt_fn.
    data512_mask = zero_mask(data512_wt, args.num_iwt, 1)
    data512_mask = iwt_fn(data512_mask)

    # Run the IWT model on the masked image and transform its output.
    mask, mu, var = iwt_model(data512_mask)
    mask_wt = wt_fn(mask)