Example #1

import numpy as np
import torch
from torchvision.utils import save_image

# zero_pad and wt are project-specific helpers (zero-padding up to a target
# spatial size, and an n-level 2-D wavelet transform); hedged sketches of
# both appear later on this page.
def eval_wtvae_pair(epoch, model, sample_loader, args, img_output_dir,
                    model_dir):
    with torch.no_grad():
        model.eval()

        for data in sample_loader:
            # Each batch is a pair: data[0] feeds the encoder, data[1] is the
            # full-resolution image used to build the wavelet target below
            data0 = data[0].to(model.device)
            data1 = data[1].to(model.device)

            # Encode data0; also draw an independent latent sample from the prior
            z_sample1 = torch.randn(data1.shape[0],
                                    args.z_dim).to(model.device)
            z, mu_wt, logvar_wt = model.encode(data0)

            # Decode the posterior latent (reconstruction) and the prior sample
            y = model.decode(z)
            y_sample = model.decode(z_sample1)

            # Pad the decoded patches back up to full resolution: each wavelet
            # level halves the spatial size, so the original image is
            # 2**num_wt times the size of the decoded patch
            target_dim = np.power(2, args.num_wt) * y.shape[2]
            y_padded = zero_pad(y, target_dim=target_dim, device=model.device)
            y_sample_padded = zero_pad(y_sample,
                                       target_dim=target_dim,
                                       device=model.device)  # computed but never saved

            # Wavelet-transform the full-resolution input and keep the
            # low-frequency (top-left) patch that the decoder reconstructs
            x_wt = wt(data1, model.filters, levels=args.num_wt)
            x_wt = x_wt[:, :, :y.shape[2], :y.shape[3]]

            save_image(y_padded.cpu(),
                       img_output_dir + '/recon_y_padded{}.png'.format(epoch))
            save_image(y.cpu(),
                       img_output_dir + '/recon_y{}.png'.format(epoch))
            save_image(y_sample.cpu(),
                       img_output_dir + '/sample_y{}.png'.format(epoch))
            save_image(x_wt.cpu(),
                       img_output_dir + '/target{}.png'.format(epoch))

    # Checkpoint the model after each evaluation pass
    torch.save(model.state_dict(),
               model_dir + '/wtvae_epoch{}.pth'.format(epoch))
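Both examples call a project helper zero_pad whose implementation is not shown on this page. A minimal sketch consistent with how it is called here, assuming the patch is placed in the top-left corner to mirror the wavelet layout, might look like:

import torch

def zero_pad(img, target_dim, device):
    # Hypothetical sketch, not the project's actual helper: embed a
    # (B, C, H, W) patch in the top-left corner of an all-zero canvas of
    # spatial size target_dim x target_dim.
    b, c, h, w = img.shape
    padded = torch.zeros(b, c, target_dim, target_dim, device=device)
    padded[:, :, :h, :w] = img
    return padded

Top-left placement matches the convention in which an n-level wavelet decomposition keeps its low-frequency coefficients in the top-left quadrant, which is also why both examples slice x_wt with [:, :, :H, :W].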
Example #2
# NOTE: the source page truncated this example at the top; the function header
# and the no-grad/eval preamble below are reconstructed from Example #1's
# pattern and from the free variables the body uses, so the name and signature
# are assumptions.
def eval_wtvae(epoch, wt_model, sample_loader, train_losses, args,
               img_output_dir, model_dir, device, writer):
    with torch.no_grad():
        wt_model.eval()

        for data in sample_loader:
            # Draw a latent sample from the prior and move the batch over
            z_sample1 = torch.randn(data.shape[0], args.z_dim).to(device)
            x = data.clone().detach().to(device)

            # Variant kept from the original, for models whose encoder also
            # returns max-pooling indices:
            # z, mu_wt, logvar_wt, m1_idx, m2_idx = wt_model.encode(data.to(device))
            # y = wt_model.decode(z, m1_idx, m2_idx)
            # y_sample = wt_model.decode(z_sample1, m1_idx, m2_idx)

            z, mu_wt, logvar_wt = wt_model.encode(x)
            y = wt_model.decode(z)
            y_sample = wt_model.decode(z_sample1)

            y_padded = zero_pad(y, target_dim=512, device=device)
            y_sample_padded = zero_pad(y_sample, target_dim=512,
                                       device=device)  # computed but never saved

            # 2-level wavelet transform of the input; keep the 128x128
            # low-frequency (top-left) patch as the comparison target
            x_wt = wt(x.reshape(x.shape[0] * x.shape[1], 1, x.shape[2], x.shape[3]),
                      wt_model.filters, levels=2)
            x_wt = x_wt.reshape(x.shape)
            x_wt = x_wt[:, :, :128, :128]

            save_image(y_padded.cpu(), img_output_dir + '/sample_padded_y{}.png'.format(epoch))
            save_image(y.cpu(), img_output_dir + '/sample_recon_y{}.png'.format(epoch))
            save_image(y_sample.cpu(), img_output_dir + '/sample_y{}.png'.format(epoch))
            save_image(x_wt.cpu(), img_output_dir + '/sample{}.png'.format(epoch))

    # Checkpoint the model after each evaluation pass
    torch.save(wt_model.state_dict(), model_dir + '/wtvae_epoch{}.pth'.format(epoch))
    
    # Save train losses and plot
    np.save(model_dir + '/train_losses.npy', train_losses)
    save_plot(train_losses, img_output_dir + '/train_loss.png')
    
    writer.close()
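The other shared helper, wt, is likewise not shown here. A minimal Haar-style sketch, assuming filters is a (4, 1, 2, 2) analysis bank ordered LL/LH/HL/HH and even input sizes (the project's real transform and filter construction may differ), could look like:

import torch
import torch.nn.functional as F

def wt(x, filters, levels=1):
    # Hypothetical sketch, not the project's actual helper: one strided
    # convolution per level, arranging the four sub-bands in the standard
    # layout (LL top-left, LH top-right, HL bottom-left, HH bottom-right),
    # then recursing on the LL quadrant for the remaining levels.
    b, c, h, w = x.shape
    coeffs = F.conv2d(x.reshape(b * c, 1, h, w), filters, stride=2)
    hh, ww = coeffs.shape[-2], coeffs.shape[-1]
    out = torch.zeros(b * c, 1, h, w, device=x.device)
    out[:, 0, :hh, :ww] = coeffs[:, 0]              # LL
    out[:, 0, :hh, ww:2 * ww] = coeffs[:, 1]        # LH
    out[:, 0, hh:2 * hh, :ww] = coeffs[:, 2]        # HL
    out[:, 0, hh:2 * hh, ww:2 * ww] = coeffs[:, 3]  # HH
    out = out.reshape(b, c, h, w)
    if levels > 1:
        out[:, :, :hh, :ww] = wt(out[:, :, :hh, :ww], filters, levels - 1)
    return out

Under this layout, two levels applied to a 512x512 input leave a 128x128 low-frequency patch in the top-left corner, which is exactly what Example #2 extracts with x_wt[:, :, :128, :128].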