def eval_full_wtvae128_iwtae512(epoch,
                                full_model,
                                optimizer,
                                sample_loader,
                                args,
                                img_output_dir,
                                model_dir,
                                writer,
                                save=True):
    """Evaluate the full WT-VAE(128) + IWT(512) model and dump diagnostic images.

    For every ``(X_128, X_512)`` batch in ``sample_loader`` this saves, into
    ``img_output_dir``, the model's reconstructions, unconditional samples,
    wavelet-space targets and a "mask" image.
    All file names are suffixed with ``epoch``, so later batches overwrite
    earlier ones. If ``save`` is true, it also writes a checkpoint to
    ``model_dir``.

    Args:
        epoch: Epoch index used in output file names and stored in the
            checkpoint.
        full_model: Model whose forward on ``X_128`` returns
            ``(Y_low_hat, mask_hat, X_hat, mu, logvar)``. It must expose
            ``sample(n)`` plus ``wt_model`` (with ``wt`` and ``device``) and
            ``iwt_model`` (with ``iwt`` and ``device``) submodules.
        optimizer: Optimizer whose ``state_dict`` is stored in the checkpoint.
        sample_loader: Iterable of ``(X_128, X_512)`` batches.
        args: Unused; kept for interface compatibility.
        img_output_dir: Directory the PNG images are written to.
        model_dir: Directory the checkpoint is written to.
        writer: Unused; kept for interface compatibility.
        save: If true, write ``fullvae_epoch{epoch}.pth`` to ``model_dir``.
    """
    # Evaluation only: put the model and both submodules in eval mode.
    full_model.eval()
    full_model.wt_model.eval()
    full_model.iwt_model.eval()

    # No backward pass happens here, so skip building autograd graphs.
    with torch.no_grad():
        for X_128, X_512 in sample_loader:
            # Model reconstruction of the 128px input plus unconditional samples.
            Y_low_hat, mask_hat, X_hat, _mu, _logvar = full_model(X_128)
            Y_low_sample_hat, mask_sample_hat, X_sample_hat = full_model.sample(
                X_128.shape[0])

            # Targets: wavelet transform of the 512px image, its top-left
            # 128x128 region, and that region zero-padded (presumably back to
            # 512x512 -- see zero_pad) and inverse-transformed.
            X_wt = full_model.wt_model.wt(X_512.to(full_model.wt_model.device))
            Y_low = X_wt[:, :, :128, :128]
            Y_low_padded = zero_pad(Y_low, 512, device=Y_low.device)
            X_low = full_model.iwt_model.iwt(
                Y_low_padded.to(full_model.iwt_model.device))

            # Save images
            save_image(Y_low_hat.cpu(),
                       img_output_dir + '/y_wt_recon{}.png'.format(epoch))
            save_image(mask_hat.cpu(),
                       img_output_dir + '/mask_recon{}.png'.format(epoch))
            save_image(X_hat.cpu(),
                       img_output_dir + '/X_recon{}.png'.format(epoch))

            save_image(Y_low_sample_hat.cpu(),
                       img_output_dir + '/y_wt_sample{}.png'.format(epoch))
            save_image(mask_sample_hat.cpu(),
                       img_output_dir + '/mask_sample{}.png'.format(epoch))
            save_image(X_sample_hat.cpu(),
                       img_output_dir + '/X_sample{}.png'.format(epoch))

            save_image(Y_low.cpu(), img_output_dir + '/y_wt{}.png'.format(epoch))
            save_image(X_512.cpu(), img_output_dir + '/X{}.png'.format(epoch))
            save_image(X_low.cpu(), img_output_dir + '/X_low{}.png'.format(epoch))
            save_image(X_wt.cpu(), img_output_dir + '/X_wt{}.png'.format(epoch))

            # Zero the top-left 128x128 region in place (X_wt and its view
            # Y_low have already been saved above) and invert the remainder
            # to get the "mask" image.
            X_wt[:, :, :128, :128].fill_(0)
            mask = full_model.iwt_model.iwt(X_wt.to(full_model.iwt_model.device))
            save_image(mask.cpu(), img_output_dir + '/mask{}.png'.format(epoch))

    if save:
        torch.save(
            {
                'epoch': epoch,
                # Checkpoint the whole model; the file is named after the full
                # VAE. (The previous code referenced an undefined ``iwt_model``.)
                'model_state_dict': full_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, model_dir + '/fullvae_epoch{}.pth'.format(epoch))


def eval_wtvae_pair(epoch, model, sample_loader, args, img_output_dir,
                    model_dir):
    """Run a paired WT-VAE over ``sample_loader``, save images, then checkpoint.

    Each batch is indexable as ``batch[0]`` (the encoder input) and
    ``batch[1]`` (the image whose wavelet transform serves as the target).
    The function saves four images per batch to ``img_output_dir`` and
    finally writes ``model.state_dict()`` to
    ``model_dir + '/wtvae_epoch{epoch}.pth'``.
    The four images are the padded reconstruction, the reconstruction, a
    prior sample and the cropped wavelet target.
    """
    with torch.no_grad():
        model.eval()
        dev = model.device

        for batch in sample_loader:
            enc_in = batch[0].to(dev)
            target_img = batch[1].to(dev)

            # Draw the prior latent before encoding so the RNG order is unchanged.
            prior_z = torch.randn(target_img.shape[0], args.z_dim).to(dev)
            latent, _mu, _logvar = model.encode(enc_in)

            recon = model.decode(latent)
            prior_recon = model.decode(prior_z)

            # Each wavelet level doubles the spatial size.
            pad_to = np.power(2, args.num_wt) * recon.shape[2]
            recon_padded = zero_pad(recon, target_dim=pad_to, device=dev)
            # Padded prior sample is computed but not written out.
            _prior_padded = zero_pad(prior_recon,
                                     target_dim=pad_to,
                                     device=dev)

            wt_target = wt(target_img, model.filters, levels=args.num_wt)
            height, width = recon.shape[2], recon.shape[3]
            wt_target = wt_target[:, :, :height, :width]

            to_save = (
                (recon_padded, '/recon_y_padded{}.png'),
                (recon, '/recon_y{}.png'),
                (prior_recon, '/sample_y{}.png'),
                (wt_target, '/target{}.png'),
            )
            for tensor, suffix in to_save:
                save_image(tensor.cpu(), img_output_dir + suffix.format(epoch))

    torch.save(model.state_dict(),
               model_dir + '/wtvae_epoch{}.pth'.format(epoch))
# Example No. 3
# 0
        with torch.no_grad():
            wt_model.eval()
            
            for data in sample_loader:
                z_sample1 = torch.randn(data.shape[0], args.z_dim).to(device)
                x = data.clone().detach().to(device)

                # z, mu_wt, logvar_wt, m1_idx, m2_idx = wt_model.encode(data.to(device))
                # y = wt_model.decode(z, m1_idx, m2_idx)
                # y_sample = wt_model.decode(z_sample1, m1_idx, m2_idx)
                
                z, mu_wt, logvar_wt = wt_model.encode(data.to(device))
                y = wt_model.decode(z)
                y_sample = wt_model.decode(z_sample1)

                y_padded = zero_pad(y, target_dim=512, device=device)
                y_sample_padded = zero_pad(y_sample, target_dim=512, device=device)

                x_wt = wt(x.reshape(x.shape[0] * x.shape[1], 1, x.shape[2], x.shape[3]), wt_model.filters, levels=2)
                x_wt = x_wt.reshape(x.shape)
                x_wt = x_wt[:, :, :128, :128]
                
                save_image(y_padded.cpu(), img_output_dir + '/sample_padded_y{}.png'.format(epoch))
                save_image(y.cpu(), img_output_dir + '/sample_recon_y{}.png'.format(epoch))
                save_image(y_sample.cpu(), img_output_dir + '/sample_y{}.png'.format(epoch))
                save_image(x_wt.cpu(), img_output_dir + '/sample{}.png'.format(epoch))
    
        torch.save(wt_model.state_dict(), model_dir + '/wtvae_epoch{}.pth'.format(epoch))
    
    # Save train losses and plot
    np.save(model_dir+'/train_losses.npy', train_losses)
        full_model.iwt_model.eval()
        
        for data in sample_loader:
            data128 = data[0]
            data512 = data[1]
            z, mu_wt, logvar_wt = full_model.wt_model.encode(data128.to(devices[0]))

            # Creating z sample for WT model by adding Gaussian noise ~ N(0,1)
            z_sample1 = torch.randn(z.shape).to(devices[0])
            z_sample3 = z + torch.randn(z.shape).to(devices[0])

            y = full_model.wt_model.decode(z)
            y_sample = full_model.wt_model.decode(z_sample1)
            y_sample_gaussian = full_model.wt_model.decode(z_sample3)

            y_padded = zero_pad(y, target_dim=512, device=devices[1])
            y_sample_padded = zero_pad(y_sample, target_dim=512, device=devices[1])
            y_sample_padded_gaussian = zero_pad(y_sample_gaussian, target_dim=512, device=devices[1])

            y_padded_iwt = y_padded.clone().detach()
            for i in range(2):
                y_padded_iwt = full_model.iwt_model.iwt(y_padded_iwt)
            
            mu, var, m1_idx, m2_idx = full_model.iwt_model.encode(data512.to(devices[1]), y_padded)
            mu_iwt, var_iwt, m1_idx_iwt, m2_idx_iwt = full_model.iwt_model.encode(y_padded_iwt, y_padded)
            z_sample2 = torch.randn(mu.shape).to(devices[1])

            x_hat = iwt_model.decode(y_padded, mu, m1_idx, m2_idx)
            x_sample = iwt_model.decode(y_padded, z_sample2, m1_idx, m2_idx)

            x_hat_iwt = iwt_model.decode(y_padded, mu_iwt, m1_idx_iwt, m2_idx_iwt)
        iwt_model.eval()

        for data in sample_loader:
            data128 = data[0].to(device)
            data512 = data[1].to(device)
            z, mu_wt, logvar_wt = wt_model.encode(data128)

            # Creating z sample for WT model by adding Gaussian noise ~ N(0,1)
            z_sample1 = torch.randn(z.shape).to(device)
            z_sample2 = z + torch.randn(z.shape).to(device)

            y = wt_model.decode(z)
            y_sample1 = wt_model.decode(z_sample1)
            y_sample2 = wt_model.decode(z_sample2)

            y_padded = zero_pad(y, target_dim=512, device=device)
            y_sample_padded1 = zero_pad(y_sample1,
                                        target_dim=512,
                                        device=device)
            y_sample_padded2 = zero_pad(y_sample2,
                                        target_dim=512,
                                        device=device)

            data512_wt = wt_fn(data512)
            # Zero out first patch and apply IWT
            data512_mask = zero_mask(data512_wt, args.num_iwt, 1)
            data512_mask = iwt_fn(data512_mask)

            mask, mu, var = iwt_model(data512_mask)

            mask_wt = wt_fn(mask)