예제 #1
0
    def visualize(self, epoch, ims_np, ims_dir, n_vis=64):
        """Save image grids for the first ``n_vis`` samples.

        The latent codes are obtained by calling ``self.netZ`` on the integer
        indices ``0 .. n_vis - 1``, reshaped to ``(n_vis, -1, 1, 1)`` and fed
        through ``self.netG``; the result is then passed through ``self.netT``.

        Three files are written into ``ims_dir``:

        * ``implicit_epoch_<epoch>.png`` -- output of ``self.netG``
        * ``est_epoch_<epoch>.png``      -- output of ``self.netT``
        * ``target.png``                 -- the first ``n_vis`` rows of
          ``ims_np`` (overwritten on every call)

        Args:
            epoch: Epoch number, zero-padded to three digits in file names.
            ims_np: NumPy array of target images; only ``ims_np[:n_vis]`` is
                read.
            ims_dir: Directory the PNG files are written to.
            n_vis: Number of samples to visualize.
        """
        sample_ids = Variable(torch.from_numpy(np.arange(n_vis)).cuda())
        latents = self.netZ(sample_ids)
        implicit_batch = self.netG(latents.view(n_vis, -1, 1, 1))
        estimated_batch = self.netT(implicit_batch)
        target_batch = Variable(
            torch.from_numpy(ims_np[:n_vis]).float()).cuda()

        # (batch, destination) pairs, saved in this order.
        outputs = (
            (implicit_batch,
             '%s/implicit_epoch_%03d.png' % (ims_dir, epoch)),
            (estimated_batch,
             '%s/est_epoch_%03d.png' % (ims_dir, epoch)),
            (target_batch, '%s/target.png' % ims_dir),
        )
        for batch, path in outputs:
            # format_im presumably undoes the mu/sd normalization -- confirm
            # against utils.format_im.
            grid = utils.format_im(batch, self.mu, self.sd)
            vutils.save_image(grid, path, normalize=False)
예제 #2
0
    def visualize(self, epoch, ims_np, ims_dir, n_vis=64):
        """Encode the first ``n_vis`` target images and save image grids.

        The images are passed through ``self.netZ``; its output is reshaped
        to ``(n_vis, -1, 1, 1)`` and fed to ``self.netG``, whose output is in
        turn passed through ``self.netT``.

        Three files are written into ``ims_dir``:

        * ``implicit_epoch_<epoch>.png`` -- output of ``self.netG``
        * ``est_epoch_<epoch>.png``      -- output of ``self.netT``
        * ``target.png``                 -- the input images themselves
          (overwritten on every call)

        Args:
            epoch: Epoch number, zero-padded to three digits in file names.
            ims_np: NumPy array of target images, indexed by image along the
                first axis.
            ims_dir: Directory the PNG files are written to.
            n_vis: Maximum number of images to visualize; clamped to the
                number of images available in ``ims_np``.
        """
        # Never ask for more images than exist; otherwise the reshape to
        # (n_vis, ...) below would not match the encoder's actual batch size.
        n_vis = min(n_vis, ims_np.shape[0])

        # Only the first n_vis images are ever used, so move just those to
        # the GPU instead of allocating and copying the whole dataset on
        # every call. The same tensor serves as encoder input and as the
        # reference grid.
        I_target = Variable(torch.from_numpy(ims_np[:n_vis]).float()).cuda()

        zi = self.netZ(I_target)
        I_implicit = self.netG(zi.view(n_vis, -1, 1, 1))
        I_target_est = self.netT(I_implicit)

        ims = utils.format_im(I_implicit, self.mu, self.sd)
        vutils.save_image(ims,
                          '%s/implicit_epoch_%03d.png' % (ims_dir, epoch),
                          normalize=False)
        ims = utils.format_im(I_target_est, self.mu, self.sd)
        vutils.save_image(ims,
                          '%s/est_epoch_%03d.png' % (ims_dir, epoch),
                          normalize=False)
        ims = utils.format_im(I_target, self.mu, self.sd)
        vutils.save_image(ims, '%s/target.png' % ims_dir, normalize=False)
예제 #3
0
    def eval_target_images(self, netZ, netT, ims_np, opt_params, vis_epochs=10):
        """Save a grid of 8 input images followed by 5 sampled generations.

        ``netZ`` and ``netT`` are stored on ``self``. The first 8 images of
        ``ims_np`` are encoded by ``netZ`` into ``(z_mu, z_var)``; on each of
        5 passes a latent is drawn with ``netZ.sample`` and decoded by
        ``self.netG``. The 8 inputs plus the 5 x 8 decoded images are written
        to ``nam_eval_ims/eval_out.png``.

        Args:
            netZ: Encoder; called on the image batch and must also provide
                ``sample(z_mu, z_var)``.
            netT: Network applied to the decoded images.
            ims_np: 4-D NumPy array of square images.
            opt_params: Not used by this method.
            vis_epochs: Not used by this method.
        """
        n, nc, sz, sz_y = ims_np.shape
        assert (sz == sz_y), "Input must be square!"
        self.netZ = netZ
        self.netT = netT

        first_eight = np.arange(8)
        images = Variable(torch.FloatTensor(8, nc, sz, sz).cuda())
        images.copy_(torch.from_numpy(ims_np[first_eight]))

        # 48 slots = 8 inputs + 5 passes x 8 outputs. The buffer is fixed at
        # 3 channels, as in the original code.
        output_implicit = torch.FloatTensor(48, 3, sz, sz).cuda()
        output_implicit[:8] = images

        for row in range(1, 6):
            z_mu, z_var = self.netZ(images)
            zi = self.netZ.sample(z_mu, z_var)
            I_implicit = self.netG(zi.view(8, -1, 1, 1))
            # The netT result is not saved; the forward pass is kept exactly
            # as in the original.
            I_target_est = self.netT(I_implicit)
            start = 8 * row
            output_implicit[start:start + 8] = I_implicit

        grid = utils.format_im(output_implicit, self.mu, self.sd)
        vutils.save_image(grid, 'nam_eval_ims/eval_out.png', normalize=False)
예제 #4
0
# Move all four networks to the GPU when one is available.
if torch.cuda.is_available():
    encoderX, decoderX = encoderX.cuda(), decoderX.cuda()
    encoderY, decoderY = encoderY.cuda(), decoderY.cuda()


# Constant +1 / -1 tensors; not used in this section (presumably consumed
# elsewhere in the script -- confirm).
one = torch.Tensor([1])
mone = one * -1

if torch.cuda.is_available():
    one = one.cuda()
    mone = mone.cuda()


# Build an evaluation grid: the first 8 Y images followed by 5 rows of
# encoderY -> decoderX outputs for those same images.
nc = 3
sz = 64
ids = np.arange(8)
images = torch.FloatTensor(8, nc, sz, sz)
images.copy_(torch.from_numpy(Y_images[ids]))
# 48 slots = 8 inputs + 5 passes x 8 outputs; every slot is filled below.
output_implicit = torch.FloatTensor(48, nc, sz, sz)

# Use the same CUDA guard as the networks above, so the data ends up on the
# same device as the models instead of calling .cuda() unconditionally.
if torch.cuda.is_available():
    images = images.cuda()
    output_implicit = output_implicit.cuda()
images = Variable(images)

output_implicit[:8] = images
# Inference only: skip autograd bookkeeping, as the periodic visualization
# later in this file does.
with torch.no_grad():
    for i in range(5):
        z_y = encoderY(images)
        I_implicit = decoderX(z_y)
        output_implicit[8*(i+1):8*(i+2)] = I_implicit

ims = format_im(output_implicit)
save_image(ims, 'images/eval_out.png', normalize=False)
        # End of an inner loop body whose header is outside this excerpt:
        # advance the step counter.
        step += 1

    # Per-epoch summary. The *_avg values are computed outside this excerpt
    # (presumably averages over the epoch's steps -- confirm).
    print(
        "Epoch: [%d/%d],  Reconstruction Loss: %.4f %.4f MMD Loss : %.4f Cross Term : %.4f %.4f Total Loss : %.4f"
        % (epoch + 1, args.epochs, recon_avg_x, recon_avg_y, mmd_avg,
           cross_x_avg, cross_y_avg, loss_avg))

    # Whenever epoch is a multiple of 10: dump visualizations for the first
    # 64 Y images and checkpoint all four networks.
    if epoch % 10 == 0:
        # Stage the first 64 Y images in a GPU float buffer.
        images_y = torch.FloatTensor(64, nc_y, sz, sz)
        images_y = Variable(images_y.cuda())
        images_y.data.copy_(torch.from_numpy(Y_images[:64]))

        # Inference only -- no autograd bookkeeping needed.
        with torch.no_grad():
            z_y = encoderY(images_y)
            # Same latent decoded by both decoders.
            I_implicit = decoderX(z_y)
            I_target_est = decoderY(z_y)
            I_target = Variable(torch.from_numpy(Y_images[:64]).float()).cuda()

            I_target_est = format_im(I_target_est)
            I_implicit = format_im(I_implicit)
            I_target = format_im(I_target)

            # The per-epoch files are numbered; target.png is overwritten
            # each time with the same 64 input images.
            save_image(I_implicit, 'images/implicit_epoch_%03d.png' % epoch)
            save_image(I_target_est, 'images/target_est_%03d.png' % epoch)
            save_image(I_target, 'images/target.png')

        # Checkpoints use fixed file names, so each save replaces the
        # previous one.
        torch.save(encoderX.state_dict(), 'nets/E_X.pth')
        torch.save(decoderX.state_dict(), 'nets/D_X.pth')
        torch.save(encoderY.state_dict(), 'nets/E_Y.pth')
        torch.save(decoderY.state_dict(), 'nets/D_Y.pth')