Example #1
    def test(self):
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'), map_location=self.config.device)
        score = CondRefineNetDilated(self.config).to(self.config.device)
        score = torch.nn.DataParallel(score)

        score.load_state_dict(states[0])

        if not os.path.exists(self.args.image_folder):
            os.makedirs(self.args.image_folder)

        sigmas = np.exp(np.linspace(np.log(self.config.model.sigma_begin), np.log(self.config.model.sigma_end),
                                    self.config.model.num_classes))

        score.eval()
        grid_size = 5

        imgs = []
        if self.config.data.dataset == 'MNIST':
            samples = torch.rand(grid_size ** 2, 1, 28, 28, device=self.config.device)
            all_samples = self.anneal_Langevin_dynamics(samples, score, sigmas, 100, 0.00002)

            for i, sample in enumerate(tqdm.tqdm(all_samples, total=len(all_samples), desc='saving images')):
                sample = sample.view(grid_size ** 2, self.config.data.channels, self.config.data.image_size,
                                     self.config.data.image_size)

                if self.config.data.logit_transform:
                    sample = torch.sigmoid(sample)

                image_grid = make_grid(sample, nrow=grid_size)
                if i % 10 == 0:
                    im = Image.fromarray(image_grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy())
                    imgs.append(im)

                save_image(image_grid, os.path.join(self.args.image_folder, 'image_{}.png'.format(i)))
                torch.save(sample, os.path.join(self.args.image_folder, 'image_raw_{}.pth'.format(i)))


        else:
            samples = torch.rand(grid_size ** 2, 3, 32, 32, device=self.config.device)

            all_samples = self.anneal_Langevin_dynamics(samples, score, sigmas, 100, 0.00002)

            for i, sample in enumerate(tqdm.tqdm(all_samples, total=len(all_samples), desc='saving images')):
                sample = sample.view(grid_size ** 2, self.config.data.channels, self.config.data.image_size,
                                     self.config.data.image_size)

                if self.config.data.logit_transform:
                    sample = torch.sigmoid(sample)

                image_grid = make_grid(sample, nrow=grid_size)
                if i % 10 == 0:
                    im = Image.fromarray(image_grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy())
                    imgs.append(im)

                save_image(image_grid, os.path.join(self.args.image_folder, 'image_{}.png'.format(i)))
                torch.save(sample, os.path.join(self.args.image_folder, 'image_raw_{}.pth'.format(i)))

        imgs[0].save(os.path.join(self.args.image_folder, "movie.gif"), save_all=True, append_images=imgs[1:], duration=1, loop=0)
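These methods call an anneal_Langevin_dynamics helper that is not shown. A minimal sketch of what it plausibly does, following Song & Ermon's annealed Langevin sampling and written as a free function for brevity (the signature is inferred from the calls above; treat the body as an assumption):

import torch

@torch.no_grad()
def anneal_Langevin_dynamics(x, scorenet, sigmas, n_steps_each=100, step_lr=0.00002):
    # returns every intermediate sample so the caller can render the trajectory
    images = []
    for idx, sigma in enumerate(sigmas):
        # the score network is conditioned on the index of the current noise level
        labels = torch.ones(x.shape[0], device=x.device, dtype=torch.long) * idx
        # scale the step so the signal-to-noise ratio of each update stays constant
        step_size = step_lr * (sigma / sigmas[-1]) ** 2
        for _ in range(n_steps_each):
            noise = torch.randn_like(x) * (2 * step_size) ** 0.5
            x = x + step_size * scorenet(x, labels) + noise
            images.append(x.to('cpu'))
    return images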
Example #2
    def save_sampled_images(self):
        num_of_step = 100
        bs = 100

        sigmas = np.exp(
            np.linspace(np.log(self.config.model.sigma_begin),
                        np.log(self.config.model.sigma_end),
                        self.config.model.num_classes))

        if not os.path.exists(self.args.image_folder):
            os.makedirs(self.args.image_folder)

        print('Load checkpoint from ' + self.args.log)

        for epochs in [100000, 170000]:
            states = torch.load(os.path.join(
                self.args.log, 'checkpoint_' + str(epochs) + '.pth'),
                                map_location=self.config.device)
            score = CondRefineNetDilated(self.config).to(self.config.device)
            score = torch.nn.DataParallel(score)

            score.load_state_dict(states[0])
            score.eval()

            if not os.path.exists(
                    os.path.join(self.args.image_folder,
                                 'epochs' + str(epochs))):
                os.makedirs(
                    os.path.join(self.args.image_folder,
                                 'epochs' + str(epochs)))

            save_index = 0
            print("Begin epochs", epochs)
            for j in range(num_of_step):
                samples = torch.rand(bs, 3, 32, 32, device=self.config.device)
                all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                    samples, score, sigmas, 100, 0.00002)
                images_new = all_samples.mul_(255).add_(0.5).clamp_(
                    0, 255).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
                for k in range(len(images_new)):
                    ims = Image.fromarray(images_new[k])
                    ims.save(
                        os.path.join(self.args.image_folder,
                                     'epochs' + str(epochs),
                                     'img_' + str(save_index) + '.png'))
                    print('Saved image', save_index)
                    save_index += 1
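anneal_Langevin_dynamics_GenerateImages is also not shown; judging from how its return value is used above (a single (B, C, H, W) tensor that is immediately converted to uint8 images), it is presumably the same annealed sampler returning only the final batch. A hedged sketch:

@torch.no_grad()
def anneal_Langevin_dynamics_GenerateImages(x, scorenet, sigmas, n_steps_each=100, step_lr=0.00002):
    for idx, sigma in enumerate(sigmas):
        labels = torch.ones(x.shape[0], device=x.device, dtype=torch.long) * idx
        step_size = step_lr * (sigma / sigmas[-1]) ** 2
        for _ in range(n_steps_each):
            noise = torch.randn_like(x) * (2 * step_size) ** 0.5
            x = x + step_size * scorenet(x, labels) + noise
    return x  # final samples only, shape (B, C, H, W)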
Example #3
    def test_inpainting(self):
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'), map_location=self.config.device)
        score = CondRefineNetDilated(self.config).to(self.config.device)
        score = torch.nn.DataParallel(score)

        score.load_state_dict(states[0])

        if not os.path.exists(self.args.image_folder):
            os.makedirs(self.args.image_folder)

        sigmas = np.exp(np.linspace(np.log(self.config.model.sigma_begin), np.log(self.config.model.sigma_end),
                                    self.config.model.num_classes))
        score.eval()

        imgs = []
        if self.config.data.dataset == 'CELEBA':
            dataset = CelebA(root=os.path.join(self.args.run, 'datasets', 'celeba'), split='test',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(self.config.data.image_size),
                                 transforms.ToTensor(),
                             ]), download=True)

            dataloader = DataLoader(dataset, batch_size=20, shuffle=True,
                                    num_workers=0)  # changed num_workers from 4 to 0
            refer_image, _ = next(iter(dataloader))

            samples = torch.rand(20, 20, 3, self.config.data.image_size, self.config.data.image_size,
                                 device=self.config.device)

            all_samples = self.anneal_Langevin_dynamics_inpainting(samples, refer_image, score, sigmas, 100, 0.00002)
            torch.save(refer_image, os.path.join(self.args.image_folder, 'refer_image.pth'))

            for i, sample in enumerate(tqdm.tqdm(all_samples)):
                sample = sample.view(400, self.config.data.channels, self.config.data.image_size,
                                     self.config.data.image_size)

                if self.config.data.logit_transform:
                    sample = torch.sigmoid(sample)

                image_grid = make_grid(sample, nrow=20)
                if i % 10 == 0:
                    im = Image.fromarray(
                        image_grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy())
                    imgs.append(im)

                save_image(image_grid, os.path.join(self.args.image_folder, 'image_completion_{}.png'.format(i)))
                torch.save(sample, os.path.join(self.args.image_folder, 'image_completion_raw_{}.pth'.format(i)))

        elif self.config.data.dataset == 'NYUv2':
            # TODO implement inpainting and MSE calculate for NYUv2
            nyu_transform = transforms.Compose([transforms.CenterCrop((400, 400)),
                                                transforms.Resize(32),
                                                transforms.ToTensor()])
            dataset = NYUv2(os.path.join(self.args.run, 'datasets', 'nyuv2'), train=False, download=True,
                            rgb_transform=nyu_transform, depth_transform=nyu_transform)

            dataloader = DataLoader(dataset, batch_size=20, shuffle=True,
                                    num_workers=0)

            data_iter = iter(dataloader)
            rgb_image, depth = next(data_iter)
            rgb_image = rgb_image.to(self.config.device)
            depth = depth[0].to(self.config.device)

            # MSE loss evaluation
            mse = torch.nn.MSELoss()

            rgb_image = rgb_image / 256. * 255. + torch.rand_like(rgb_image) / 256.

            # torch.save(rgb_image, os.path.join(self.args.image_folder, 'rgb_image.pth'))
            samples = torch.rand(20, 20, self.config.data.channels, self.config.data.image_size,
                                 self.config.data.image_size).to(self.config.device)

            all_depth_samples, depth_pred = self.anneal_Langevin_dynamics_prediction(samples, rgb_image, score, sigmas, 100, 0.00002)

            print("MSE loss is %5.4f" % (mse(depth_pred, depth)))

            for i, sample in enumerate(tqdm.tqdm(all_depth_samples)):
                sample = sample.view(400, self.config.data.channels - 3, self.config.data.image_size,
                                     self.config.data.image_size)

                sample = torch.cat((depth.to('cpu'), sample), 0)

                depth_grid = make_grid(sample, nrow=20)
                if i % 10 == 0:
                    # dep = Image.fromarray(depth_grid.to('cpu').numpy().astype(np.float32), mode='F')
                    dep = F.to_pil_image(depth_grid[0, :, :])
                    imgs.append(dep)

                save_image(depth_grid, os.path.join(self.args.image_folder, 'depth_prediction_{}.png'.format(i)))
                # torch.save(sample, os.path.join(self.args.image_folder, 'depth_prediction_raw_{}.pth'.format(i)))



        else:
            transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])

            if self.config.data.dataset == 'CIFAR10':
                dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=True, download=True,
                                  transform=transform)
            elif self.config.data.dataset == 'SVHN':
                dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn'), split='train', download=True,
                               transform=transform)

            dataloader = DataLoader(dataset, batch_size=20, shuffle=True,
                                    num_workers=0)  # changed num_workers from 4 to 0
            data_iter = iter(dataloader)
            refer_image, _ = next(data_iter)

            torch.save(refer_image, os.path.join(self.args.image_folder, 'refer_image.pth'))
            samples = torch.rand(20, 20, self.config.data.channels, self.config.data.image_size,
                                 self.config.data.image_size).to(self.config.device)

            all_samples = self.anneal_Langevin_dynamics_inpainting(samples, refer_image, score, sigmas, 100, 0.00002)

            for i, sample in enumerate(tqdm.tqdm(all_samples)):
                sample = sample.view(400, self.config.data.channels, self.config.data.image_size,
                                     self.config.data.image_size)

                if self.config.data.logit_transform:
                    sample = torch.sigmoid(sample)

                image_grid = make_grid(sample, nrow=20)
                if i % 10 == 0:
                    im = Image.fromarray(
                        image_grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy())
                    imgs.append(im)

                save_image(image_grid, os.path.join(self.args.image_folder, 'image_completion_{}.png'.format(i)))
                torch.save(sample, os.path.join(self.args.image_folder, 'image_completion_raw_{}.pth'.format(i)))

        imgs[0].save(os.path.join(self.args.image_folder, "movie.gif"), save_all=True, append_images=imgs[1:],
                     duration=1, loop=0)
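The inpainting variant anneal_Langevin_dynamics_inpainting is assumed as well. In NCSN-style inpainting, part of each image is repeatedly clamped to a noisy copy of the reference while Langevin dynamics fills in the rest; the sketch below pins the left half, though the exact mask the authors use may differ. The caller passes x of shape (n_rows, n_cols, C, H, W) and one reference image per row:

@torch.no_grad()
def anneal_Langevin_dynamics_inpainting(x, refer_image, scorenet, sigmas,
                                        n_steps_each=100, step_lr=0.00002):
    n_rows, n_cols = x.shape[0], x.shape[1]
    x = x.contiguous().view(-1, *x.shape[2:])  # (n_rows * n_cols, C, H, W)
    # repeat each reference image across its row of candidate completions
    refer = refer_image.to(x.device).unsqueeze(1)
    refer = refer.expand(n_rows, n_cols, *refer_image.shape[1:])
    refer = refer.contiguous().view(-1, *refer_image.shape[1:])
    half = x.shape[-1] // 2
    images = []
    for idx, sigma in enumerate(sigmas):
        labels = torch.ones(x.shape[0], device=x.device, dtype=torch.long) * idx
        step_size = step_lr * (sigma / sigmas[-1]) ** 2
        for _ in range(n_steps_each):
            # pin the known half to the reference, perturbed at the current sigma
            x[:, :, :, :half] = refer[:, :, :, :half] + torch.randn_like(refer[:, :, :, :half]) * sigma
            noise = torch.randn_like(x) * (2 * step_size) ** 0.5
            x = x + step_size * scorenet(x, labels) + noise
            images.append(x.to('cpu'))
    return images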
Example #4
    def test_inpainting(self):
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                            map_location=self.config.device)
        score = CondRefineNetDilated(self.config).to(self.config.device)
        score = torch.nn.DataParallel(score)

        score.load_state_dict(states[0])

        if not os.path.exists(self.args.image_folder):
            os.makedirs(self.args.image_folder)

        sigmas = np.exp(
            np.linspace(np.log(self.config.model.sigma_begin),
                        np.log(self.config.model.sigma_end),
                        self.config.model.num_classes))
        score.eval()

        imgs = []
        if self.config.data.dataset == 'CELEBA':
            dataset = CelebA(
                root=os.path.join(self.args.run, 'datasets', 'celeba'),
                split='test',
                transform=transforms.Compose([
                    transforms.CenterCrop(140),
                    transforms.Resize(self.config.data.image_size),
                    transforms.ToTensor(),
                ]),
                download=True)

            dataloader = DataLoader(dataset,
                                    batch_size=20,
                                    shuffle=True,
                                    num_workers=4)
            refer_image, _ = next(iter(dataloader))

            samples = torch.rand(20,
                                 20,
                                 3,
                                 self.config.data.image_size,
                                 self.config.data.image_size,
                                 device=self.config.device)

            all_samples = self.anneal_Langevin_dynamics_inpainting(
                samples, refer_image, score, sigmas, 100, 0.00002)
            torch.save(refer_image,
                       os.path.join(self.args.image_folder, 'refer_image.pth'))

            for i, sample in enumerate(tqdm.tqdm(all_samples)):
                sample = sample.view(400, self.config.data.channels,
                                     self.config.data.image_size,
                                     self.config.data.image_size)

                if self.config.data.logit_transform:
                    sample = torch.sigmoid(sample)

                image_grid = make_grid(sample, nrow=20)
                if i % 10 == 0:
                    im = Image.fromarray(
                        image_grid.mul_(255).add_(0.5).clamp_(0, 255).permute(
                            1, 2, 0).to('cpu', torch.uint8).numpy())
                    imgs.append(im)

                save_image(
                    image_grid,
                    os.path.join(self.args.image_folder,
                                 'image_completion_{}.png'.format(i)))
                torch.save(
                    sample,
                    os.path.join(self.args.image_folder,
                                 'image_completion_raw_{}.pth'.format(i)))

        else:
            transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])

            if self.config.data.dataset == 'CIFAR10':
                dataset = CIFAR10(os.path.join(self.args.run, 'datasets',
                                               'cifar10'),
                                  train=True,
                                  download=True,
                                  transform=transform)
            elif self.config.data.dataset == 'SVHN':
                dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn'),
                               split='train',
                               download=True,
                               transform=transform)

            dataloader = DataLoader(dataset,
                                    batch_size=20,
                                    shuffle=True,
                                    num_workers=4)
            data_iter = iter(dataloader)
            refer_image, _ = next(data_iter)

            torch.save(refer_image,
                       os.path.join(self.args.image_folder, 'refer_image.pth'))
            samples = torch.rand(
                20, 20, self.config.data.channels, self.config.data.image_size,
                self.config.data.image_size).to(self.config.device)

            all_samples = self.anneal_Langevin_dynamics_inpainting(
                samples, refer_image, score, sigmas, 100, 0.00002)

            for i, sample in enumerate(tqdm.tqdm(all_samples)):
                sample = sample.view(400, self.config.data.channels,
                                     self.config.data.image_size,
                                     self.config.data.image_size)

                if self.config.data.logit_transform:
                    sample = torch.sigmoid(sample)

                image_grid = make_grid(sample, nrow=20)
                if i % 10 == 0:
                    im = Image.fromarray(
                        image_grid.mul_(255).add_(0.5).clamp_(0, 255).permute(
                            1, 2, 0).to('cpu', torch.uint8).numpy())
                    imgs.append(im)

                save_image(
                    image_grid,
                    os.path.join(self.args.image_folder,
                                 'image_completion_{}.png'.format(i)))
                torch.save(
                    sample,
                    os.path.join(self.args.image_folder,
                                 'image_completion_raw_{}.pth'.format(i)))

        imgs[0].save(os.path.join(self.args.image_folder, "movie.gif"),
                     save_all=True,
                     append_images=imgs[1:],
                     duration=1,
                     loop=0)
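A side note on the noise schedule used throughout: the exp(linspace(log ...)) idiom builds geometrically spaced sigmas and is exactly np.geomspace. Assuming the config values sigma_begin=1.0, sigma_end=0.01, num_classes=10 (an assumption, but consistent with the later examples), it reproduces the hard-coded array that appears below:

import numpy as np

sigmas = np.exp(np.linspace(np.log(1.0), np.log(0.01), 10))
assert np.allclose(sigmas, np.geomspace(1.0, 0.01, 10))
# [1.         0.59948425 0.35938137 0.21544347 0.12915497 0.07742637
#  0.04641589 0.02782559 0.01668101 0.01      ]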
Example #5
    def test(self):
        all_psnr = []  # All signal to noise ratios over all the batches
        all_percentages = []  # All percentage accuracies
        dummy_metrics = []  # Metrics for the averaging value

        # Load the score network
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'), map_location=self.config.device)
        scorenet = CondRefineNetDilated(self.config).to(self.config.device)
        scorenet = torch.nn.DataParallel(scorenet)
        scorenet.load_state_dict(states[0])
        scorenet.eval()

        # Load the CIFAR-10 test set; we index .data directly, so no transform is needed
        dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=False, download=True)
        data = dataset.data.transpose(0, 3, 1, 2)


        for iteration in range(100):
            print("Iteration {}".format(iteration))
            curr_dir = os.path.join(SAVE_DIR, "{:07d}".format(iteration))
            if not os.path.exists(curr_dir):
                os.makedirs(curr_dir)
            # image1, image2 = get_images_split(first_digits, second_digits)
            gt_images = []
            for _ in range(N):
                gt_images.append(get_single_image(data))

            mixed = sum(gt_images).float()

            mixed_grid = make_grid(mixed.detach() / float(N), nrow=GRID_SIZE, pad_value=1., padding=1)
            save_image(mixed_grid, os.path.join(curr_dir, "mixed.png"))

            for i in range(N):
                gt_grid = make_grid(gt_images[i], nrow=GRID_SIZE, pad_value=1., padding=1)
                save_image(gt_grid, os.path.join(curr_dir, "gt{}.png".format(i)))

            mixed = torch.Tensor(mixed).cuda().view(BATCH_SIZE, 3, 32, 32)

            xs = []
            for _ in range(N):
                xs.append(nn.Parameter(torch.Tensor(BATCH_SIZE, 3, 32, 32).uniform_()).cuda())

            step_lr = 0.00002

            # Noise amounts
            sigmas = np.array([1., 0.59948425, 0.35938137, 0.21544347, 0.12915497,
                              0.07742637, 0.04641589, 0.02782559, 0.01668101, 0.01])
            n_steps_each = 200

            for idx, sigma in enumerate(sigmas):
                lambda_recon = 1.8/(sigma**2)
                # labels selects the noise-level index that conditions the score network
                labels = torch.ones(1, device=xs[0].device) * idx
                labels = labels.long()
                step_size = step_lr * (sigma / sigmas[-1]) ** 2

                for step in range(n_steps_each):
                    noises = []
                    for _ in range(N):
                        noises.append(torch.randn_like(xs[0]) * np.sqrt(step_size * 2))

                    grads = []
                    for i in range(N):
                        grads.append(scorenet(xs[i].view(BATCH_SIZE, 3, 32, 32), labels).detach())

                    recon_loss = (torch.norm(torch.flatten(sum(xs) - mixed)) ** 2)
                    print(recon_loss)
                    recon_grads = torch.autograd.grad(recon_loss, xs)

                    for i in range(N):
                        xs[i] = xs[i] + (step_size * grads[i]) + (-step_size * lambda_recon * recon_grads[i].detach()) + noises[i]

            for i in range(N):
                xs[i] = torch.clamp(xs[i], 0, 1)

            x_to_write = []
            for i in range(N):
                x_to_write.append(torch.Tensor(xs[i].detach().cpu()))

            # PSNR Measure
            for idx in range(BATCH_SIZE):
                best_psnr = -10000
                best_permutation = None
                for permutation in permutations(range(N)):
                    curr_psnr = sum([psnr(xs[permutation[i]][idx], gt_images[i][idx].cuda()) for i in range(N)])
                    if curr_psnr > best_psnr:
                        best_psnr = curr_psnr
                        best_permutation = permutation

                all_psnr.append(best_psnr / float(N))
                for i in range(N):
                    x_to_write[i][idx] = xs[best_permutation[i]][idx] 

                    mixed_psnr = psnr(mixed.detach().cpu()[idx] / float(N), gt_images[i][idx])
                    dummy_metrics.append(mixed_psnr)
                
            for i in range(N):
                x_grid = make_grid(x_to_write[i], nrow=GRID_SIZE, pad_value=1., padding=1)
                save_image(x_grid, os.path.join(curr_dir, "x{}.png".format(i)))

            mixed_grid = make_grid(sum(xs)/float(N), nrow=GRID_SIZE, pad_value=1., padding=1)
            save_image(mixed_grid, os.path.join(curr_dir, "recon.png".format(i)))


            # average_grid = make_grid(mixed.detach()/2., nrow=GRID_SIZE)
            # save_image(average_grid, "results/average_cifar.png")
            
            print("Curr mean {}".format(np.array(all_psnr).mean()))
            print("Const mean {}".format(np.array(dummy_metrics).mean()))
Example #6
    def calculate_fid(self):
        import fid
        import tensorflow as tf

        num_of_step = 500
        bs = 100

        sigmas = np.exp(
            np.linspace(np.log(self.config.model.sigma_begin),
                        np.log(self.config.model.sigma_end),
                        self.config.model.num_classes))
        stats_path = 'fid_stats_cifar10_train.npz'  # training set statistics
        inception_path = fid.check_or_download_inception(
            None)  # download inception network

        print('Load checkpoint from ' + self.args.log)
        #for epochs in range(140000, 200001, 1000):
        for epochs in [149000]:
            states = torch.load(os.path.join(
                self.args.log, 'checkpoint_' + str(epochs) + '.pth'),
                                map_location=self.config.device)
            #states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'), map_location=self.config.device)
            score = CondRefineNetDilated(self.config).to(self.config.device)
            score = torch.nn.DataParallel(score)

            score.load_state_dict(states[0])

            score.eval()

            if self.config.data.dataset == 'MNIST':
                print("Begin epochs", epochs)
                samples = torch.rand(bs, 1, 28, 28, device=self.config.device)
                all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                    samples, score, sigmas, 100, 0.00002)
                images = all_samples.mul_(255).add_(0.5).clamp_(
                    0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
                for j in range(num_of_step - 1):
                    # MNIST samples must keep shape (1, 28, 28); the original
                    # snippet reused the CIFAR shape here, which would break the
                    # np.concatenate below
                    samples = torch.rand(bs,
                                         1,
                                         28,
                                         28,
                                         device=self.config.device)
                    all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                        samples, score, sigmas, 100, 0.00002)
                    images_new = all_samples.mul_(255).add_(0.5).clamp_(
                        0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
                    images = np.concatenate((images, images_new), axis=0)

            else:
                print("Begin epochs", epochs)
                samples = torch.rand(bs, 3, 32, 32, device=self.config.device)
                all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                    samples, score, sigmas, 100, 0.00002)
                images = all_samples.mul_(255).add_(0.5).clamp_(
                    0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
                for j in range(num_of_step - 1):
                    samples = torch.rand(bs,
                                         3,
                                         32,
                                         32,
                                         device=self.config.device)
                    all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                        samples, score, sigmas, 100, 0.00002)
                    images_new = all_samples.mul_(255).add_(0.5).clamp_(
                        0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
                    images = np.concatenate((images, images_new), axis=0)

            # load precalculated training set statistics
            f = np.load(stats_path)
            mu_real, sigma_real = f['mu'][:], f['sigma'][:]
            f.close()

            fid.create_inception_graph(
                inception_path)  # load the graph into the current TF graph
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    images, sess, batch_size=100)

            fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen,
                                                       mu_real, sigma_real)
            print("FID: %s" % fid_value)
Example #7
    def test(self):
        # For metrics
        all_psnr = []  # PSNR over all batches
        all_grayscale_psnr = []  # grayscale PSNR over all batches

        all_mixed_psnr = []  # PSNR of the raw mixture baseline
        all_mixed_grayscale_psnr = []  # grayscale PSNR of the raw mixture baseline

        all_ssim = []
        all_grayscale_ssim = []

        all_mixed_ssim = []  # SSIM of the raw mixture baseline
        all_mixed_grayscale_ssim = []  # grayscale SSIM of the raw mixture baseline

        strange_cases = {"gt1": [], "gt2": [], "mixed": [], "x": [], "y": []}

        # For inception score
        output_to_incept = []
        mixed_to_incept = []

        # For videos
        all_x = []
        all_y = []
        all_mixed = []

        # Load the score network
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                            map_location=self.config.device)
        scorenet = CondRefineNetDilated(self.config).to(self.config.device)
        scorenet = torch.nn.DataParallel(scorenet)
        scorenet.load_state_dict(states[0])
        scorenet.eval()

        # Load the CIFAR-10 test set and split it into animal and machine classes
        dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'),
                          train=False,
                          download=True)
        data = dataset.data.transpose(0, 3, 1, 2)
        all_animals_idx = np.isin(dataset.targets, ANIMAL_IDX)
        all_machines_idx = np.isin(dataset.targets, MACHINE_IDX)

        all_animals = dataset.data[all_animals_idx].transpose(0, 3, 1, 2)
        all_machines = dataset.data[all_machines_idx].transpose(0, 3, 1, 2)

        for iteration in range(105, 250):
            print("Iteration {}".format(iteration))
            curr_dir = os.path.join(SAVE_DIR, "{:07d}".format(iteration))
            if not os.path.exists(curr_dir):
                os.makedirs(curr_dir)

            # rand_idx_1 = np.random.randint(0, data.shape[0] - 1, BATCH_SIZE)
            # rand_idx_2 = np.random.randint(0, data.shape[0] - 1, BATCH_SIZE)

            # image1 = torch.tensor(data[rand_idx_1, :].astype(np.float) / 255.).float()
            # image2 = torch.tensor(data[rand_idx_2, :].astype(np.float) / 255.).float()

            image1, image2 = get_images_split(all_animals, all_machines)
            mixed = (image1 + image2).float()

            mixed_grid = make_grid(mixed.detach() / 2., nrow=GRID_SIZE)
            save_image(mixed_grid, os.path.join(curr_dir, "mixed.png"))

            gt1_grid = make_grid(image1, nrow=GRID_SIZE)
            save_image(gt1_grid, os.path.join(curr_dir, "gt1.png"))

            gt2_grid = make_grid(image2, nrow=GRID_SIZE)
            save_image(gt2_grid, os.path.join(curr_dir, "gt2.png"))

            mixed = torch.Tensor(mixed).cuda().view(BATCH_SIZE, 3, 32, 32)

            y = nn.Parameter(torch.Tensor(BATCH_SIZE, 3, 32,
                                          32).uniform_()).cuda()
            x = nn.Parameter(torch.Tensor(BATCH_SIZE, 3, 32,
                                          32).uniform_()).cuda()

            step_lr = 0.00002

            # Noise amounts
            sigmas = np.array([
                1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
                0.04641589, 0.02782559, 0.01668101, 0.01
            ])
            n_steps_each = 100

            #lambda_recon = 1.5
            for idx, sigma in enumerate(sigmas):
                lambda_recon = 1.0 / (sigma**2)

                # labels selects the noise-level index that conditions the score network
                labels = torch.ones(1, device=x.device) * idx
                labels = labels.long()
                step_size = step_lr * (sigma / sigmas[-1])**2

                for step in range(n_steps_each):
                    noise_x = torch.randn_like(x) * np.sqrt(step_size * 2)
                    noise_y = torch.randn_like(y) * np.sqrt(step_size * 2)

                    grad_x = scorenet(x.view(BATCH_SIZE, 3, 32, 32),
                                      labels).detach()
                    grad_y = scorenet(y.view(BATCH_SIZE, 3, 32, 32),
                                      labels).detach()

                    recon_loss = (torch.norm(torch.flatten(y + x - mixed))**2)
                    recon_grads = torch.autograd.grad(recon_loss, [x, y])

                    #x = x + (step_size * grad_x) + noise_x
                    #y = y + (step_size * grad_y) + noise_y
                    x = x + (step_size *
                             grad_x) + (-step_size * lambda_recon *
                                        recon_grads[0].detach()) + noise_x
                    y = y + (step_size *
                             grad_y) + (-step_size * lambda_recon *
                                        recon_grads[1].detach()) + noise_y

                    # Video
                    # if (step % 5) == 0:
                    #     all_x.append(x.detach().cpu().numpy())
                    #     all_y.append(y.detach().cpu().numpy())
                    #     all_mixed.append((x.detach().cpu().numpy() + y.detach().cpu().numpy()) / 2.)
                #lambda_recon *= 2.8

            # Inception
            # from model import get_inception_score
            # output_to_incept += [np.clip(x[idx].detach().cpu().numpy().transpose(1,2,0), 0, 1) * 255. for idx in range(x.shape[0])]
            # output_to_incept += [np.clip(y[idx].detach().cpu().numpy().transpose(1,2,0), 0, 1) * 255. for idx in range(y.shape[0])]
            # mixed_to_incept += [np.clip((mixed[idx] / 2.).detach().cpu().numpy().transpose(1,2,0), 0, 1) * 255. for idx in range(y.shape[0])]
            # mixed_to_incept += [np.clip((mixed[idx] / 2.).detach().cpu().numpy().transpose(1,2,0), 0, 1) * 255. for idx in range(y.shape[0])]

            x_to_write = torch.Tensor(x.detach().cpu())
            y_to_write = torch.Tensor(y.detach().cpu())

            # x_movie = np.array(np.stack(all_x, axis=0))
            # y_movie = np.array(np.stack(all_y, axis=0))
            # mixed_movie = np.array(np.stack(all_mixed, axis=0))

            for idx in range(BATCH_SIZE):
                # PSNR
                est1 = psnr(x[idx], image1[idx].cuda()) + psnr(
                    y[idx], image2[idx].cuda())
                est2 = psnr(x[idx], image2[idx].cuda()) + psnr(
                    y[idx], image1[idx].cuda())
                correct_psnr = max(est1, est2) / 2.
                all_psnr.append(correct_psnr)

                grayscale_est1 = psnr(
                    x[idx].mean(0), image1[idx].mean(0).cuda()) + psnr(
                        y[idx].mean(0), image2[idx].mean(0).cuda())
                grayscale_est2 = psnr(
                    x[idx].mean(0), image2[idx].mean(0).cuda()) + psnr(
                        y[idx].mean(0), image1[idx].mean(0).cuda())
                grayscale_psnr = max(grayscale_est1, grayscale_est2) / 2.
                all_grayscale_psnr.append(grayscale_psnr)

                # Mixed PSNR
                mixed_psnr = psnr((mixed[idx] / 2.), image1[idx].cuda())
                all_mixed_psnr.append(mixed_psnr)

                grayscale_mixed_psnr = psnr((mixed[idx] / 2.).mean(0),
                                            image1[idx].mean(0).cuda())
                all_mixed_grayscale_psnr.append(grayscale_mixed_psnr)

                if est2 > est1:
                    x_to_write[idx] = y[idx]
                    y_to_write[idx] = x[idx]

                    # tmp = x_movie[:, idx].copy()
                    # x_movie[:, idx] = y_movie[:, idx]
                    # y_movie[:, idx] = tmp

                # SSIM
                est1 = get_ssim(x[idx], image1[idx]) + get_ssim(
                    y[idx], image2[idx])
                est2 = get_ssim(x[idx], image2[idx]) + get_ssim(
                    y[idx], image1[idx])
                correct_ssim = max(est1, est2) / 2.
                all_ssim.append(correct_ssim)

                grayscale_est1 = get_ssim_grayscale(
                    x[idx], image1[idx]) + get_ssim_grayscale(
                        y[idx], image2[idx])
                grayscale_est2 = get_ssim_grayscale(
                    x[idx], image2[idx]) + get_ssim_grayscale(
                        y[idx], image1[idx])
                grayscale_ssim = max(grayscale_est1, grayscale_est2) / 2.
                all_grayscale_ssim.append(grayscale_ssim)

                # Mixed ssim
                mixed_ssim = get_ssim(
                    (mixed[idx] / 2.), image1[idx]) + get_ssim(
                        (mixed[idx] / 2.), image2[idx])
                all_mixed_ssim.append(mixed_ssim / 2.)

                grayscale_mixed_ssim = get_ssim_grayscale(
                    (mixed[idx] / 2.), image1[idx]) + get_ssim_grayscale(
                        (mixed[idx] / 2.), image2[idx])
                all_mixed_grayscale_ssim.append(grayscale_mixed_ssim / 2.)

                if correct_psnr < 19 and grayscale_psnr > 20.5:
                    strange_cases["gt1"].append(
                        image1[idx].detach().cpu().numpy())
                    strange_cases["gt2"].append(
                        image2[idx].detach().cpu().numpy())
                    strange_cases["mixed"].append(
                        mixed[idx].detach().cpu().numpy())
                    strange_cases["x"].append(
                        x_to_write[idx].detach().cpu().numpy())
                    strange_cases["y"].append(
                        y_to_write[idx].detach().cpu().numpy())
                    print("Added strange case")

            # Write x and y
            x_grid = make_grid(x_to_write, nrow=GRID_SIZE)
            save_image(x_grid, os.path.join(curr_dir, "x.png"))

            y_grid = make_grid(y_to_write, nrow=GRID_SIZE)
            save_image(y_grid, os.path.join(curr_dir, "y.png"))

            print("PSNR {}".format(np.array(all_psnr).mean()))
            print("Mixed PSNR {}".format(np.array(all_mixed_psnr).mean()))

            print("PSNR Grayscale {}".format(
                np.array(all_grayscale_psnr).mean()))
            print("Mixed PSNR Grayscale {}".format(
                np.array(all_mixed_grayscale_psnr).mean()))

            print("SSIM {}".format(np.array(all_ssim).mean()))
            print("Mixed SSIM {}".format(np.array(all_mixed_ssim).mean()))

            print("SSIM Grayscale {}".format(
                np.array(all_grayscale_ssim).mean()))
            print("Mixed SSIM Grayscale {}".format(
                np.array(all_mixed_grayscale_ssim).mean()))

            # Write video frames
            # padding = 50
            # dim_w = 172 * 4 + padding * 5
            # dim_h = 172 * 2 + padding * 3
            # for frame_idx in range(x_movie.shape[0]):
            #     print(frame_idx)
            #     x_grid = make_grid(torch.Tensor(x_movie[frame_idx]), nrow=GRID_SIZE)
            #     # save_image(x_grid, "results/videos/x/x_{}.png".format(frame_idx))

            #     y_grid = make_grid(torch.Tensor(y_movie[frame_idx]), nrow=GRID_SIZE)
            #     # save_image(y_grid, "results/videos/y/y_{}.png".format(frame_idx))

            #     recon_grid = make_grid(torch.Tensor(mixed_movie[frame_idx]), nrow=GRID_SIZE)
            #     # save_image(recon_grid, "results/videos/mixed/mixed_{}.png".format(frame_idx))

            #     output_frame = torch.zeros(3, dim_h, dim_w)
            #     output_frame[:, 50:(50+172), 50:(50+172)] = gt1_grid
            #     output_frame[:, (100+172):(100+172*2), 50:(50+172)] = gt2_grid
            #     output_frame[:, (75 + 86):(75 + 86 + 172), (50 * 2 + 172):(50 * 2 + 172 * 2)] = mixed_grid
            #     output_frame[:, (75 + 86):(75 + 86 + 172), (50 * 3 + 172 * 2):(50 * 3 + 172 * 3)] = recon_grid
            #     output_frame[:,50:(50 + 172),(50 * 4 + 172 * 3):(50 * 4 + 172 * 4)] = x_grid
            #     output_frame[:,(50 * 2 + 172):(50 * 2 + 172 * 2),(50 * 4 + 172 * 3):(50 * 4 + 172 * 4)] = y_grid
            #     save_image(output_frame, "results/videos/combined/{:03d}.png".format(frame_idx))

        # Calculate inception scores
        # print("Output inception score {}".format(get_inception_score(output_to_incept)))
        # print("Mixed inception score {}".format(get_inception_score(mixed_to_incept)))

        # Write strange results
        y1 = np.stack(strange_cases["y"], axis=0)
        y_grid = make_grid(torch.Tensor(y1))
        save_image(y_grid, "results/y_strange.png")

        x1 = np.stack(strange_cases["x"], axis=0)
        x_grid = make_grid(torch.Tensor(x1))
        save_image(x_grid, "results/x_strange.png")

        gt1 = np.stack(strange_cases["gt1"], axis=0)
        gt1_grid = make_grid(torch.Tensor(gt1))
        save_image(gt1_grid, "results/gt1_strange.png")

        gt2 = np.stack(strange_cases["gt2"], axis=0)
        gt2_grid = make_grid(torch.Tensor(gt2))
        save_image(gt2_grid, "results/gt2_strange.png")

        mixed = np.stack(strange_cases["mixed"], axis=0) / 2.
        mixed_grid = make_grid(torch.Tensor(mixed))
        save_image(mixed_grid, "results/mixed_strange.png")
Example #8
    def test(self):
        all_psnr = []  # All signal to noise ratios over all the batches
        all_percentages = []  # All percentage accuracies
        dummy_metrics = []  # Metrics for the averaging value

        bad_cases = {"gt1": [], "gt2": [], "mixed": [], "x": [], "y": []}

        # Load the score network
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                            map_location=self.config.device)
        scorenet = CondRefineNetDilated(self.config).to(self.config.device)
        scorenet = torch.nn.DataParallel(scorenet)
        scorenet.load_state_dict(states[0])
        scorenet.eval()

        dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'),
                        train=False,
                        download=True)

        # the dataset was loaded with train=False, so use the unified accessors
        first_digits_idx = dataset.targets <= 4
        second_digits_idx = dataset.targets >= 5

        first_digits = dataset.data[first_digits_idx]
        second_digits = dataset.data[second_digits_idx]

        for iteration in range(100):
            print("Iteration {}".format(iteration))
            curr_dir = os.path.join(SAVE_DIR, "{:07d}".format(iteration))
            if not os.path.exists(curr_dir):
                os.makedirs(curr_dir)

            image1, image2 = get_images_split(first_digits, second_digits)
            #image1, image2 = get_images_no_split(dataset)

            mixed = (image1 + image2).float()
            # mixed = torch.clamp(image1 + image2, 0, 1).float()  # Capsule net

            mixed_grid = make_grid(mixed.detach() / 2.,
                                   nrow=GRID_SIZE,
                                   pad_value=1.,
                                   padding=1)
            save_image(mixed_grid, os.path.join(curr_dir, "mixed.png"))

            gt1_grid = make_grid(image1,
                                 nrow=GRID_SIZE,
                                 pad_value=1.,
                                 padding=1)
            save_image(gt1_grid, os.path.join(curr_dir, "gt1.png"))

            gt2_grid = make_grid(image2,
                                 nrow=GRID_SIZE,
                                 pad_value=1.,
                                 padding=1)
            save_image(gt2_grid, os.path.join(curr_dir, "gt2.png"))

            mixed = torch.Tensor(mixed).cuda().view(BATCH_SIZE, 1, 28, 28)

            y = nn.Parameter(torch.Tensor(BATCH_SIZE, 1, 28,
                                          28).uniform_()).cuda()
            x = nn.Parameter(torch.Tensor(BATCH_SIZE, 1, 28,
                                          28).uniform_()).cuda()

            step_lr = 0.00002

            # Noise amounts
            sigmas = np.array([
                1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
                0.04641589, 0.02782559, 0.01668101, 0.01
            ])
            n_steps_each = 100

            for idx, sigma in enumerate(sigmas):
                lambda_recon = 1.8 / (sigma**2)
                # labels selects the noise-level index that conditions the score network
                labels = torch.ones(1, device=x.device) * idx
                labels = labels.long()
                step_size = step_lr * (sigma / sigmas[-1])**2

                for step in range(n_steps_each):
                    noise_x = torch.randn_like(x) * np.sqrt(step_size * 2)
                    noise_y = torch.randn_like(y) * np.sqrt(step_size * 2)

                    grad_x = scorenet(x.view(BATCH_SIZE, 1, 28, 28),
                                      labels).detach()
                    grad_y = scorenet(y.view(BATCH_SIZE, 1, 28, 28),
                                      labels).detach()

                    recon_loss = (torch.norm(torch.flatten(y + x - mixed))**2)
                    #recon_loss = torch.norm(torch.flatten(torch.clamp(x + y, -10000000, 1) - mixed)) ** 2

                    #recon_loss = torch.norm(torch.flatten((1 / (1 + torch.exp(-5 * (y + x - 0.5))))  - mixed)) ** 2
                    #recon_loss = torch.norm(torch.flatten((y - mixed)) ** 2
                    # print(recon_loss)

                    recon_grads = torch.autograd.grad(recon_loss, [x, y])

                    #x = x + (step_size * grad_x) + noise_x
                    #y = y + (step_size * grad_y) + noise_y
                    x = x + (step_size *
                             grad_x) + (-step_size * lambda_recon *
                                        recon_grads[0].detach()) + noise_x
                    y = y + (step_size *
                             grad_y) + (-step_size * lambda_recon *
                                        recon_grads[1].detach()) + noise_y

                    # x = x + (-step_size * lambda_recon * recon_grads[0].detach()) + noise_x
                    # y = y + (-step_size * lambda_recon * recon_grads[1].detach()) + noise_y

            # Clamp for writing purposes
            x = torch.clamp(x, 0, 1)
            y = torch.clamp(y, 0, 1)

            x_to_write = torch.Tensor(x.detach().cpu())
            y_to_write = torch.Tensor(y.detach().cpu())

            # PSNR Measure
            for idx in range(BATCH_SIZE):
                est1 = psnr(x[idx], image1[idx].cuda()) + psnr(
                    y[idx], image2[idx].cuda())
                est2 = psnr(x[idx], image2[idx].cuda()) + psnr(
                    y[idx], image1[idx].cuda())
                correct_estimate = max(est1, est2) / 2.
                all_psnr.append(correct_estimate)

                if est2 > est1:
                    x_to_write[idx] = y[idx]
                    y_to_write[idx] = x[idx]

                mixed_psnr = psnr(mixed[idx] / 2., image1[idx].cuda()) + psnr(
                    mixed[idx] / 2., image2[idx].cuda())
                dummy_metrics.append(mixed_psnr / 2.)

                if correct_estimate < 12.:
                    bad_cases["gt1"].append(image1[idx].detach().cpu().numpy())
                    bad_cases["gt2"].append(image2[idx].detach().cpu().numpy())
                    bad_cases["mixed"].append(
                        mixed[idx].detach().cpu().numpy())
                    bad_cases["x"].append(
                        x_to_write[idx].detach().cpu().numpy())
                    bad_cases["y"].append(
                        y_to_write[idx].detach().cpu().numpy())
                    print("Added bad case")

            # Percentage Measure
            # x_thresh = (x > 0.01)
            # y_thresh = (y > 0.01)
            # image1_thresh = (image1 > 0.01)
            # image2_thresh = (image2 > 0.01)
            # avg_thresh = ((mixed.detach()[idx] / 2.) > 0.01)
            # for idx in range(BATCH_SIZE):
            #     est1 = np.count_nonzero((x_thresh[idx] == image1_thresh[idx].cuda()).detach().cpu()) + np.count_nonzero((y_thresh[idx] == image2_thresh[idx].cuda()).detach().cpu())
            #     est2 = np.count_nonzero((x_thresh[idx] == image2_thresh[idx].cuda()).detach().cpu()) + np.count_nonzero((y_thresh[idx] == image1_thresh[idx].cuda()).detach().cpu())
            #     correct_estimate = max(est1, est2)
            #     percentage = correct_estimate / (2 * x.shape[-1] * x.shape[-2])
            #     all_percentages.append(percentage)

            #     dummy_count = np.count_nonzero((avg_thresh == image1_thresh[idx].cuda()).detach().cpu()) + np.count_nonzero((avg_thresh == image2_thresh[idx].cuda()).detach().cpu())
            #     dummy_percentage = dummy_count / (2 * x.shape[-1] * x.shape[-2])
            #     dummy_metrics.append(dummy_percentage)

            # Recon Grid
            recon_grid = make_grid(torch.clamp(
                torch.clamp(x_to_write, 0, 1) + torch.clamp(y_to_write, 0, 1),
                0, 1),
                                   nrow=GRID_SIZE,
                                   pad_value=1.,
                                   padding=1)
            save_image(recon_grid, os.path.join(curr_dir, "recon.png"))
            # Write x and y
            x_grid = make_grid(x_to_write,
                               nrow=GRID_SIZE,
                               pad_value=1.,
                               padding=1)
            save_image(x_grid, os.path.join(curr_dir, "x.png"))

            y_grid = make_grid(y_to_write,
                               nrow=GRID_SIZE,
                               pad_value=1.,
                               padding=1)
            save_image(y_grid, os.path.join(curr_dir, "y.png"))

            # average_grid = make_grid(mixed.detach()/2., nrow=GRID_SIZE)
            # save_image(average_grid, "results/average_cifar.png")

            print("Curr mean {}".format(np.array(all_psnr).mean()))
            #print("Curr mean {}".format(np.array(all_percentages).mean()))
            print("Curr dummy mean {}".format(np.array(dummy_metrics).mean()))

        y1 = np.stack(bad_cases["y"], axis=0)
        y_grid = make_grid(torch.Tensor(y1), pad_value=1., padding=1)
        save_image(y_grid, os.path.join(SAVE_DIR, "y_strange.png"))

        x1 = np.stack(bad_cases["x"], axis=0)
        x_grid = make_grid(torch.Tensor(x1), pad_value=1., padding=1)
        save_image(x_grid, os.path.join(SAVE_DIR, "x_strange.png"))

        gt1 = np.stack(bad_cases["gt1"], axis=0)
        gt1_grid = make_grid(torch.Tensor(gt1), pad_value=1., padding=1)
        save_image(gt1_grid, os.path.join(SAVE_DIR, "gt1_strange.png"))

        gt2 = np.stack(bad_cases["gt2"], axis=0)
        gt2_grid = make_grid(torch.Tensor(gt2), pad_value=1., padding=1)
        save_image(gt2_grid, os.path.join(SAVE_DIR, "gt2_strange.png"))

        mixed = np.stack(bad_cases["mixed"], axis=0) / 2.
        mixed_grid = make_grid(torch.Tensor(mixed), pad_value=1., padding=1)
        save_image(mixed_grid, os.path.join(SAVE_DIR, "mixed_strange.png"))

        import pdb
        pdb.set_trace()
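get_images_split, used by Examples #7 and #8, is not shown either. One plausible implementation: sample BATCH_SIZE images from each pool (reusing the BATCH_SIZE constant sketched earlier), scale to [0, 1], and add a channel axis for grayscale MNIST data (the details are assumptions):

import numpy as np
import torch

def get_images_split(first_pool, second_pool):
    def draw(pool):
        idx = np.random.randint(0, len(pool), BATCH_SIZE)
        imgs = torch.as_tensor(np.asarray(pool)[idx]).float() / 255.
        return imgs.unsqueeze(1) if imgs.dim() == 3 else imgs  # (B, 1, H, W) for MNIST
    return draw(first_pool), draw(second_pool)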
Example #9
    def test(self):
        # Load the score network
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                            map_location=self.config.device)
        scorenet = CondRefineNetDilated(self.config).to(self.config.device)
        scorenet = torch.nn.DataParallel(scorenet)
        scorenet.load_state_dict(states[0])
        scorenet.eval()

        # Grab two samples from the CIFAR-10 test set
        dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'),
                          train=False,
                          download=True)
        image0 = np.array(dataset[130][0]).astype(np.float64).transpose(2, 0, 1)
        image1 = np.array(dataset[131][0]).astype(np.float64).transpose(2, 0, 1)

        mixed = (image0 + image1)
        mixed = mixed / 255.

        cv2.imwrite("mixed.png",
                    (mixed * 127.5).astype(np.uint8).transpose(1, 2,
                                                               0)[:, :, ::-1])
        cv2.imwrite("gt0.png",
                    (image0).astype(np.uint8).transpose(1, 2, 0)[:, :, ::-1])
        cv2.imwrite("gt1.png",
                    (image1).astype(np.uint8).transpose(1, 2, 0)[:, :, ::-1])
        mixed = torch.Tensor(mixed).cuda()

        y = nn.Parameter(torch.Tensor(3, 32, 32).uniform_()).cuda()
        x = nn.Parameter(torch.Tensor(3, 32, 32).uniform_()).cuda()

        step_lr = 0.00002

        # Noise amounts
        sigmas = np.array([
            1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
            0.04641589, 0.02782559, 0.01668101, 0.01
        ])
        n_steps_each = 100
        lambda_recon = 1.5  # Weight to put on reconstruction error vs p(x)

        for idx, sigma in enumerate(sigmas):
            # labels selects the noise-level index that conditions the score network
            labels = torch.ones(1, device=x.device) * idx
            labels = labels.long()
            step_size = step_lr * (sigma / sigmas[-1])**2

            for step in range(n_steps_each):
                noise_x = torch.randn_like(x) * np.sqrt(step_size * 2)
                noise_y = torch.randn_like(y) * np.sqrt(step_size * 2)

                grad_x = scorenet(x.view(1, 3, 32, 32), labels).detach()
                grad_y = scorenet(y.view(1, 3, 32, 32), labels).detach()

                recon_loss = (torch.norm(torch.flatten(y + x - mixed))**2)
                print(recon_loss)
                recon_grads = torch.autograd.grad(recon_loss, [x, y])

                #x = x + (step_size * grad_x) + noise_x
                #y = y + (step_size * grad_y) + noise_y
                x = x + (step_size *
                         grad_x) + (-step_size * lambda_recon *
                                    recon_grads[0].detach()) + noise_x
                y = y + (step_size *
                         grad_y) + (-step_size * lambda_recon *
                                    recon_grads[1].detach()) + noise_y

            lambda_recon *= 2.8

        # Write x and y
        x_np = x.detach().cpu().numpy()[0, :, :, :]
        x_np = np.clip(x_np, 0, 1)
        cv2.imwrite("x.png",
                    (x_np * 255).astype(np.uint8).transpose(1, 2,
                                                            0)[:, :, ::-1])

        y_np = y.detach().cpu().numpy()[0, :, :, :]
        y_np = np.clip(y_np, 0, 1)
        cv2.imwrite("y.png",
                    (y_np * 255).astype(np.uint8).transpose(1, 2,
                                                            0)[:, :, ::-1])

        cv2.imwrite(
            "out_mixed.png",
            (y_np * 127.5).astype(np.uint8).transpose(1, 2, 0)[:, :, ::-1] +
            (x_np * 127.5).astype(np.uint8).transpose(1, 2, 0)[:, :, ::-1])

        import pdb
        pdb.set_trace()
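All of the separation examples run the same conditioned Langevin update. Writing the mixture as m, the two estimates as x and y, the score network as s_theta, and the step size at level i as alpha_i, one step is (a LaTeX restatement of the code above, not a formula quoted from the source):

x \leftarrow x + \alpha_i \, s_\theta(x, i) - \alpha_i \lambda_i \nabla_x \lVert x + y - m \rVert^2 + \sqrt{2 \alpha_i} \, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)

with \alpha_i = \text{step\_lr} \cdot (\sigma_i / \sigma_L)^2, and symmetrically for y. Examples #5, #7 and #8 set \lambda_i \propto 1 / \sigma_i^2, while Examples #9 and #10 start from a fixed \lambda and grow it geometrically across noise levels.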
Example #10
    def test(self):
        # Load the score network
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                            map_location=self.config.device)
        scorenet = CondRefineNetDilated(self.config).to(self.config.device)
        scorenet = torch.nn.DataParallel(scorenet)
        scorenet.load_state_dict(states[0])
        scorenet.eval()

        # Grab one sample from the CelebA test split
        dataset = CelebA(os.path.join(self.args.run, 'datasets', 'celeba'),
                         split='test',
                         download=True)
        input_image = np.array(dataset[0][0]).astype(np.float64).transpose(
            2, 0, 1)

        # input_image = cv2.imread("/projects/grail/vjayaram/source_separation/ncsn/run/datasets/celeba/celeba/img_align_celeba/012690.jpg")

        # input_image = cv2.resize(input_image, (32, 32))[:,:,::-1].transpose(2, 0, 1)
        input_image = input_image / 255.
        noise = np.random.randn(*input_image.shape) / 10
        cv2.imwrite("input_image.png", (input_image * 255).astype(
            np.uint8).transpose(1, 2, 0)[:, :, ::-1])
        input_image += noise
        input_image = np.clip(input_image, 0, 1)

        cv2.imwrite("input_image_noisy.png", (input_image * 255).astype(
            np.uint8).transpose(1, 2, 0)[:, :, ::-1])

        input_image = torch.Tensor(input_image).cuda()
        x = nn.Parameter(torch.Tensor(3, 32, 32).uniform_()).cuda()

        step_lr = 0.00002

        # Noise amounts
        sigmas = np.array([
            1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
            0.04641589, 0.02782559, 0.01668101, 0.01
        ])
        n_steps_each = 100
        lambda_recon = 1.5  # Weight to put on reconstruction error vs p(x)

        for idx, sigma in enumerate(sigmas):
            # labels selects the noise-level index that conditions the score network
            labels = torch.ones(1, device=x.device) * idx
            labels = labels.long()
            step_size = step_lr * (sigma / sigmas[-1])**2

            for step in range(n_steps_each):
                noise_x = torch.randn_like(x) * np.sqrt(step_size * 2)

                grad_x = scorenet(x.view(1, 3, 32, 32), labels).detach()

                recon_loss = (torch.norm(torch.flatten(input_image - x))**2)
                print(recon_loss)
                recon_grads = torch.autograd.grad(recon_loss, [x])

                #x = x + (step_size * grad_x) + noise_x
                x = x + (step_size *
                         grad_x) + (-step_size * lambda_recon *
                                    recon_grads[0].detach()) + noise_x

            lambda_recon *= 1.6

        # Write x and y
        x_np = x.detach().cpu().numpy()[0, :, :, :]
        x_np = np.clip(x_np, 0, 1)
        cv2.imwrite("x.png",
                    (x_np * 255).astype(np.uint8).transpose(1, 2,
                                                            0)[:, :, ::-1])

Example #11
    def test(self):
        # Load the score network
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                            map_location=self.config.device)
        scorenet = CondRefineNetDilated(self.config).to(self.config.device)
        scorenet = torch.nn.DataParallel(scorenet)
        scorenet.load_state_dict(states[0])
        scorenet.eval()

        batch_size = 1
        samples = 2
        files_list = glob.glob('./ground turth/*.png')
        files_list = natsorted(files_list)
        length = len(files_list)
        result_all = np.zeros([101, 2])
        for z, file_path in enumerate(files_list):
            img = cv2.imread(file_path)
            img2 = cv2.imread('./iGM-3C/img_{}_Rec_x_end_rgb.png'.format(z))
            img = cv2.resize(img, (128, 128))

            YCbCrimg2 = cv2.cvtColor(img2, cv2.COLOR_BGR2YCrCb)
            x0 = img.copy()
            x1 = np.concatenate((img2, YCbCrimg2), 2)

            original_image = img.copy()
            x0 = torch.tensor(x0.transpose(2, 0, 1),
                              dtype=torch.float).unsqueeze(0) / 255.0
            x1 = torch.tensor(x1.transpose(2, 0, 1),
                              dtype=torch.float).unsqueeze(0) / 255.0
            x_stack = torch.zeros(
                [x0.shape[0] * samples, x0.shape[1], x0.shape[2], x0.shape[3]],
                dtype=torch.float32)

            for i in range(samples):
                x_stack[i * batch_size:(i + 1) * batch_size, ...] = x0
            x0 = x_stack

            gray = (x0[:, 0, ...] + x0[:, 1, ...] + x0[:, 2, ...]).cuda() / 3.0
            gray1 = (x1[:, 0, ...] + x1[:, 1, ...] + x1[:, 2, ...] +
                     x1[:, 3, ...] + x1[:, 4, ...] +
                     x1[:, 5, ...]).cuda() / 6.0

            gray_mixed = torch.stack([gray, gray, gray], dim=1)
            gray_mixed_1 = torch.stack(
                [gray1, gray1, gray1, gray1, gray1, gray1], dim=1)

            x0 = nn.Parameter(
                torch.Tensor(samples * batch_size, 6, x0.shape[2],
                             x0.shape[3]).uniform_(-1, 1)).cuda()
            x01 = x0.clone()

            step_lr = 0.0003 * 0.04  #bedroom 0.04   church 0.02

            sigmas = np.array([
                1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
                0.04641589, 0.02782559, 0.01668101, 0.01
            ])
            n_steps_each = 100
            max_psnr = 0
            max_ssim = 0
            for idx, sigma in enumerate(sigmas):
                lambda_recon = 1. / sigma**2
                labels = torch.ones(1, device=x0.device) * idx
                labels = labels.long()

                step_size = step_lr * (sigma / sigmas[-1])**2

                print('sigma = {}'.format(sigma))
                for step in range(n_steps_each):
                    print('current step %03d iter' % step)
                    x0_mix = (x01[:, 0, ...] + x01[:, 1, ...] +
                              x01[:, 2, ...]) / 3.0
                    x1_mix = (x01[:, 0, ...] + x01[:, 1, ...] +
                              x01[:, 2, ...] + x01[:, 3, ...] +
                              x01[:, 4, ...] + x01[:, 5, ...]) / 6.0

                    error = torch.stack([x0_mix, x0_mix, x0_mix],
                                        dim=1) - gray_mixed
                    error1 = torch.stack(
                        [x1_mix, x1_mix, x1_mix, x1_mix, x1_mix, x1_mix],
                        dim=1) - gray_mixed_1

                    noise_x = torch.randn_like(x01) * np.sqrt(step_size * 2)

                    grad_x0 = scorenet(x01, labels).detach()

                    x0 = x01 + step_size * (grad_x0)
                    x0 = x0 - 0.1 * step_size * lambda_recon * error1  #bedroom 0.1  church 1.5
                    x0[:, 0:3,
                       ...] = x0[:, 0:3,
                                 ...] - step_size * lambda_recon * (error)

                    x0 = torch.mean(x0, dim=0)
                    x0 = torch.stack([x0, x0], dim=0)
                    x01 = x0.clone() + noise_x

                    x_rec = x0.clone().detach().cpu().numpy().transpose(
                        0, 2, 3, 1)

                    for j in range(x_rec.shape[0]):
                        x_rec_ = np.squeeze(x_rec[j, ...])
                        x_rec_ycbcr2rgb = cv2.cvtColor(x_rec_[..., 3:],
                                                       cv2.COLOR_YCrCb2BGR)
                        x_rec_ycbcr2rgb = np.clip(x_rec_ycbcr2rgb, 0, 1)

                    x_rec_ycbcr2rgb = x_rec_ycbcr2rgb[np.newaxis, ...]

                    x_rec = (x_rec[..., :3] + x_rec_ycbcr2rgb) / 2
                    original_image = np.array(original_image, dtype=np.float32)

                    for i in range(x_rec.shape[0]):
                        psnr = compare_psnr(x_rec[i, ...] * 255.0,
                                            original_image,
                                            data_range=255)
                        ssim = compare_ssim(x_rec[i, ...],
                                            original_image / 255.0,
                                            data_range=1,
                                            multichannel=True)
                        print("current {} step".format(step), 'PSNR :', psnr,
                              'SSIM :', ssim)
                    if max_psnr < psnr:
                        result_all[z, 0] = psnr
                        max_psnr = psnr
                        cv2.imwrite(
                            os.path.join(
                                self.args.image_folder,
                                'img_{}_Rec_6ch_finally.png'.format(z)),
                            (x_rec[i, ...] * 256.0).clip(0,
                                                         255).astype(np.uint8))
                        result_all[length,
                                   0] = sum(result_all[:length, 0]) / length

                    if max_ssim < ssim:
                        result_all[z, 1] = ssim
                        max_ssim = ssim
                        result_all[length,
                                   1] = sum(result_all[:length, 1]) / length

                    write_Data(result_all, z)

            x_save = x0.clone().detach().cpu().numpy().transpose(0, 2, 3, 1)

            for j in range(x_save.shape[0]):
                x_save_ = np.squeeze(x_save[j, ...])
                print(np.max(x_save_), np.min(x_save_))
                x_save_ycbcr2rgb = cv2.cvtColor(x_save_[..., 3:],
                                                cv2.COLOR_YCrCb2BGR)
                print(np.max(x_save_ycbcr2rgb), np.min(x_save_ycbcr2rgb))

            x_save_ycbcr2rgb = x_save_ycbcr2rgb[np.newaxis, ...]

            x_save = (x_save[..., :3] + x_save_ycbcr2rgb) / 2
            x_save = np.array(x_save).transpose(0, 3, 1, 2)
            x_save_R = x_save[:, 2:3, :, :]
            x_save_G = x_save[:, 1:2, :, :]
            x_save_B = x_save[:, 0:1, :, :]
            x_save = np.concatenate((x_save_R, x_save_G, x_save_B), 1)

            self.write_images(
                torch.tensor(x_save).detach().cpu(), 'x_end.png', 1, z)
Example #12
    def test(self):
        # Load the score network
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                            map_location=self.config.device)
        scorenet = CondRefineNetDilated(self.config).to(self.config.device)
        scorenet = torch.nn.DataParallel(scorenet)
        scorenet.load_state_dict(states[0])
        scorenet.eval()

        batch_size = 1
        samples = 1
        image_count = 2

        files_list = glob.glob('./ground turth/*.png')
        files_list = natsorted(files_list)
        length = len(files_list)
        result_all = np.zeros([101, 2])
        for z, file_path in enumerate(files_list):
            x0 = cv2.imread(file_path)
            x0 = cv2.resize(x0, (128, 128))
            original_image = x0.copy()
            x0 = torch.tensor(x0.transpose(2, 0, 1),
                              dtype=torch.float).unsqueeze(0) / 255.0
            x_stack = torch.zeros(
                [x0.shape[0] * samples, x0.shape[1], x0.shape[2], x0.shape[3]],
                dtype=torch.float32)

            for i in range(samples):
                x_stack[i * batch_size:(i + 1) * batch_size, ...] = x0
            x0 = x_stack

            gray = (
                (x0[:, 0, ...] + x0[:, 1, ...] + x0[:, 2, ...])).cuda() / 3.0
            print(gray.shape)
            gray_mixed = torch.stack([gray, gray, gray], dim=1)
            print(gray_mixed.shape)

            x0 = nn.Parameter(
                torch.Tensor(samples * batch_size, 3, x0.shape[2],
                             x0.shape[3]).uniform_(-1, 1)).cuda()
            x01 = x0.clone()


            step_lr = 0.00003 * 0.1  # bedroom 0.1  church 0.2
            sigmas = np.array([
                1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
                0.04641589, 0.02782559, 0.01668101, 0.01
            ])
            n_steps_each = 60
            max_psnr = 0
            max_ssim = 0
            for idx, sigma in enumerate(sigmas):
                lambda_recon = 1. / sigma**2
                labels = torch.ones(1, device=x0.device) * idx
                labels = labels.long()

                step_size = step_lr * (sigma / sigmas[-1])**2

                print('sigma = {}'.format(sigma))
                for step in range(n_steps_each):
                    print('current step %03d iter' % step)
                    x0_mix = ((x01[:, 0, ...] + x01[:, 1, ...] +
                               x01[:, 2, ...])) / 3.0
                    error = torch.stack([x0_mix, x0_mix, x0_mix],
                                        dim=1) - gray_mixed
                    noise_x = torch.randn_like(x01) * np.sqrt(step_size * 2)
                    grad_x0 = scorenet(x01, labels).detach()
                    x0 = x01 + step_size * (grad_x0 - lambda_recon *
                                            (error))  # + noise_x
                    x01 = x0.clone() + noise_x

                    x_rec = x0.clone().detach().cpu().numpy().transpose(
                        0, 2, 3, 1)
                    max_result, post = 0, 0
                    for i in range(x_rec.shape[0]):
                        psnr = compare_psnr(x_rec[i, ...] * 255.0,
                                            original_image,
                                            data_range=255)
                        ssim = compare_ssim(x_rec[i, ...],
                                            original_image / 255.0,
                                            data_range=1,
                                            multichannel=True)
                        if max_result < psnr:
                            max_result = psnr
                            post = i
                        print("current {} step".format(step), 'PSNR :', psnr,
                              'SSIM :', ssim)

                    savemat('./Rec_Best', {'img': x_rec[post, ...]})
                    if max_psnr < psnr:
                        result_all[z, 0] = psnr
                        max_psnr = psnr
                        result_all[length,
                                   0] = sum(result_all[:length, 0]) / length

                    if max_ssim < ssim:
                        result_all[z, 1] = ssim
                        max_ssim = ssim
                        result_all[length,
                                   1] = sum(result_all[:length, 1]) / length

                    write_Data(result_all, z)

            x_save = x0.clone().detach().cpu()
            x_save = np.array(x_save)
            x_save_R = x_save[:, 2:3, :, :]
            x_save_G = x_save[:, 1:2, :, :]
            x_save_B = x_save[:, 0:1, :, :]
            x_save = np.concatenate((x_save_R, x_save_G, x_save_B), 1)
            self.write_images(
                torch.tensor(x_save).detach().cpu(), 'x_end_rgb.png', samples,
                z)
Example #13
    def test(self):
        # Load the score network
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                            map_location=self.config.device)
        scorenet = CondRefineNetDilated(self.config).to(self.config.device)
        scorenet = torch.nn.DataParallel(scorenet)
        scorenet.load_state_dict(states[0])
        scorenet.eval()

        image1, image2 = get_images_manual()
        mixed = (image1 + image2).float()
        curr_dir = SAVE_DIR

        if not os.path.exists(curr_dir):
            os.makedirs(curr_dir)

        mixed_grid = make_grid(mixed.detach() / 2., nrow=GRID_SIZE)
        save_image(mixed_grid, os.path.join(curr_dir, "mixed.png"))

        gt1_grid = make_grid(image1, nrow=GRID_SIZE)
        save_image(gt1_grid, os.path.join(curr_dir, "gt1.png"))

        gt2_grid = make_grid(image2, nrow=GRID_SIZE)
        save_image(gt2_grid, os.path.join(curr_dir, "gt2.png"))

        mixed = torch.Tensor(mixed).cuda().view(BATCH_SIZE, 3, 32, 32)

        y = nn.Parameter(torch.Tensor(BATCH_SIZE, 3, 32, 32).uniform_()).cuda()
        x = nn.Parameter(torch.Tensor(BATCH_SIZE, 3, 32, 32).uniform_()).cuda()

        step_lr = 0.00002

        # Noise amounts
        sigmas = np.array([
            1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
            0.04641589, 0.02782559, 0.01668101, 0.01
        ])
        n_steps_each = 100

        for idx, sigma in enumerate(sigmas):
            lambda_recon = 1.8 / (sigma**2)

            # Condition the score network on the current noise-level index
            labels = torch.ones(1, device=x.device) * idx
            labels = labels.long()
            step_size = step_lr * (sigma / sigmas[-1])**2

            for step in range(n_steps_each):
                noise_x = torch.randn_like(x) * np.sqrt(step_size * 2)
                noise_y = torch.randn_like(y) * np.sqrt(step_size * 2)

                grad_x = scorenet(x.view(BATCH_SIZE, 3, 32, 32),
                                  labels).detach()
                grad_y = scorenet(y.view(BATCH_SIZE, 3, 32, 32),
                                  labels).detach()

                recon_loss = (torch.norm(torch.flatten(y + x - mixed))**2)
                print(recon_loss)
                recon_grads = torch.autograd.grad(recon_loss, [x, y])

                x = x + (step_size *
                         grad_x) + (-step_size * lambda_recon *
                                    recon_grads[0].detach()) + noise_x
                y = y + (step_size *
                         grad_y) + (-step_size * lambda_recon *
                                    recon_grads[1].detach()) + noise_y

        x_to_write = torch.Tensor(x.detach().cpu())
        y_to_write = torch.Tensor(y.detach().cpu())

        for idx in range(BATCH_SIZE):
            # PSNR
            est1 = psnr(x[idx], image1[idx].cuda()) + psnr(
                y[idx], image2[idx].cuda())
            est2 = psnr(x[idx], image2[idx].cuda()) + psnr(
                y[idx], image1[idx].cuda())

            if est2 > est1:
                x_to_write[idx] = y[idx]
                y_to_write[idx] = x[idx]

        # # Write x and y
        x_grid = make_grid(x_to_write, nrow=GRID_SIZE)
        save_image(x_grid, os.path.join(curr_dir, "x.png"))

        y_grid = make_grid(y_to_write, nrow=GRID_SIZE)
        save_image(y_grid, os.path.join(curr_dir, "y.png"))
    def train(self):
        if self.config.data.random_flip is False:
            tran_transform = test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])
        else:
            tran_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])

        if self.config.data.dataset == 'CIFAR10':
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=True, download=True,
                              transform=tran_transform)
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10_test'), train=False, download=True,
                                   transform=test_transform)

        elif self.config.data.dataset == 'MNIST':
            dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=True, download=True,
                            transform=tran_transform)
            test_dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist_test'), train=False, download=True,
                                 transform=test_transform)

        elif self.config.data.dataset == 'CELEBA':
            if self.config.data.random_flip:
                dataset = CelebA(root=os.path.join(self.args.run, 'datasets', 'celeba'), split='train',
                                 transform=transforms.Compose([
                                     transforms.CenterCrop(140),
                                     transforms.Resize(self.config.data.image_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                 ]), download=True)
            else:
                dataset = CelebA(root=os.path.join(self.args.run, 'datasets', 'celeba'), split='train',
                                 transform=transforms.Compose([
                                     transforms.CenterCrop(140),
                                     transforms.Resize(self.config.data.image_size),
                                     transforms.ToTensor(),
                                 ]), download=True)

            test_dataset = CelebA(root=os.path.join(self.args.run, 'datasets', 'celeba_test'), split='test',
                                  transform=transforms.Compose([
                                      transforms.CenterCrop(140),
                                      transforms.Resize(self.config.data.image_size),
                                      transforms.ToTensor(),
                                  ]), download=True)

        elif self.config.data.dataset == 'SVHN':
            dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn'), split='train', download=True,
                           transform=tran_transform)
            test_dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn_test'), split='test', download=True,
                                transform=test_transform)

        elif self.config.data.dataset == 'NYUv2':
            if self.config.data.random_flip is False:
                nyu_train_transform = nyu_test_transform = transforms.Compose([
                    transforms.CenterCrop((400, 400)),
                    transforms.Resize(32),
                    transforms.ToTensor()
                ])
            else:
                nyu_train_transform = transforms.Compose([
                    transforms.CenterCrop((400, 400)),
                    transforms.Resize(32),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.ToTensor()
                ])
                nyu_test_transform = transforms.Compose([
                    transforms.CenterCrop((400, 400)),
                    transforms.Resize(32),
                    transforms.ToTensor()
                ])

            dataset = NYUv2(os.path.join(self.args.run, 'datasets', 'nyuv2'), train=True, download=True,
                            rgb_transform=nyu_train_transform, depth_transform=nyu_train_transform)
            test_dataset = NYUv2(os.path.join(self.args.run, 'datasets', 'nyuv2'), train=False, download=True,
                                 rgb_transform=nyu_test_transform, depth_transform=nyu_test_transform)

        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                num_workers=0)  # changed num_workers from 4 to 0
        test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                 num_workers=0, drop_last=True)  # changed num_workers from 4 to 0

        test_iter = iter(test_loader)
        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)
        score = CondRefineNetDilated(self.config).to(self.config.device)

        score = torch.nn.DataParallel(score)

        optimizer = self.get_optimizer(score.parameters())

        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
            score.load_state_dict(states[0])
            optimizer.load_state_dict(states[1])

        step = 0

        sigmas = torch.tensor(
            np.exp(np.linspace(np.log(self.config.model.sigma_begin), np.log(self.config.model.sigma_end),
                               self.config.model.num_classes))).float().to(self.config.device)

        for epoch in range(self.config.training.n_epochs):
            for i, (X, y) in enumerate(dataloader):
                step += 1
                score.train()
                X = X.to(self.config.device)
                X = X / 256. * 255. + torch.rand_like(X) / 256.

                if self.config.data.logit_transform:
                    X = self.logit_transform(X)

                if self.config.data.dataset == 'NYUv2':
                    # concatenate depth map with image
                    y = y[0]
                    # code to see resized depth map
                    # input_gt_depth_image = y[0][0].data.cpu().numpy().astype(np.float32)
                    # plot.imsave('gt_depth_map_{}.png'.format(i), input_gt_depth_image,
                    #             cmap="viridis")
                    y = y.to(self.config.device)
                    X = torch.cat((X, y), 1)

                labels = torch.randint(0, len(sigmas), (X.shape[0],), device=X.device)

                if self.config.training.algo == 'dsm':
                    loss = anneal_dsm_score_estimation(score, X, labels, sigmas, self.config.training.anneal_power)
                elif self.config.training.algo == 'ssm':
                    loss = anneal_sliced_score_estimation_vr(score, X, labels, sigmas,
                                                             n_particles=self.config.training.n_particles)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                tb_logger.add_scalar('loss', loss, global_step=step)
                logging.info("step: {}, loss: {}".format(step, loss.item()))

                if step >= self.config.training.n_iters:
                    return 0

                if step % 100 == 0:
                    score.eval()
                    try:
                        test_X, test_y = next(test_iter)
                    except StopIteration:
                        test_iter = iter(test_loader)
                        test_X, test_y = next(test_iter)

                    test_X = test_X.to(self.config.device)

                    test_X = test_X / 256. * 255. + torch.rand_like(test_X) / 256.
                    if self.config.data.logit_transform:
                        test_X = self.logit_transform(test_X)

                    if self.config.data.dataset == 'NYUv2':
                        test_y = test_y[0]
                        test_y = test_y.to(self.config.device)
                        test_X = torch.cat((test_X, test_y), 1)

                    test_labels = torch.randint(0, len(sigmas), (test_X.shape[0],), device=test_X.device)

                    with torch.no_grad():
                        test_dsm_loss = anneal_dsm_score_estimation(score, test_X, test_labels, sigmas,
                                                                    self.config.training.anneal_power)

                    tb_logger.add_scalar('test_dsm_loss', test_dsm_loss, global_step=step)

                if step % self.config.training.snapshot_freq == 0:
                    states = [
                        score.state_dict(),
                        optimizer.state_dict(),
                    ]
                    torch.save(states, os.path.join(self.args.log, 'checkpoint_{}.pth'.format(step)))
                    torch.save(states, os.path.join(self.args.log, 'checkpoint.pth'))
    def train(self):
        if self.config.data.random_flip is False:
            tran_transform = test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])
        else:
            tran_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])

        if self.config.data.dataset == 'CIFAR10':
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets',
                                           'cifar10'),
                              train=True,
                              download=True,
                              transform=tran_transform)
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets',
                                                'cifar10_test'),
                                   train=False,
                                   download=True,
                                   transform=test_transform)
        elif self.config.data.dataset == 'MNIST':
            dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'),
                            train=True,
                            download=True,
                            transform=tran_transform)
            test_dataset = MNIST(os.path.join(self.args.run, 'datasets',
                                              'mnist_test'),
                                 train=False,
                                 download=True,
                                 transform=test_transform)

        elif self.config.data.dataset == 'CELEBA':
            if self.config.data.random_flip:
                dataset = CelebA(
                    root=os.path.join(self.args.run, 'datasets', 'celeba'),
                    split='train',
                    transform=transforms.Compose([
                        transforms.CenterCrop(140),
                        transforms.Resize(self.config.data.image_size),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                    ]),
                    download=False)
            else:
                dataset = CelebA(
                    root=os.path.join(self.args.run, 'datasets', 'celeba'),
                    split='train',
                    transform=transforms.Compose([
                        transforms.CenterCrop(140),
                        transforms.Resize(self.config.data.image_size),
                        transforms.ToTensor(),
                    ]),
                    download=False)

            test_dataset = CelebA(
                root=os.path.join(self.args.run, 'datasets', 'celeba_test'),
                split='test',
                transform=transforms.Compose([
                    transforms.CenterCrop(140),
                    transforms.Resize(self.config.data.image_size),
                    transforms.ToTensor(),
                ]),
                download=False)

        elif self.config.data.dataset == 'SVHN':
            dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn'),
                           split='train',
                           download=True,
                           transform=tran_transform)
            test_dataset = SVHN(os.path.join(self.args.run, 'datasets',
                                             'svhn_test'),
                                split='test',
                                download=True,
                                transform=test_transform)

        dataloader = DataLoader(dataset,
                                batch_size=self.config.training.batch_size,
                                shuffle=True,
                                num_workers=4)
        test_loader = DataLoader(test_dataset,
                                 batch_size=self.config.training.batch_size,
                                 shuffle=True,
                                 num_workers=4,
                                 drop_last=True)

        test_iter = iter(test_loader)
        self.config.input_dim = self.config.data.image_size**2 * self.config.data.channels

        tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)
        score = CondRefineNetDilated(self.config).to(self.config.device)

        score = torch.nn.DataParallel(score)

        optimizer = self.get_optimizer(score.parameters())

        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
            score.load_state_dict(states[0])
            optimizer.load_state_dict(states[1])

        step = 0

        sigmas = torch.tensor(
            np.exp(
                np.linspace(np.log(self.config.model.sigma_begin),
                            np.log(self.config.model.sigma_end),
                            self.config.model.num_classes))).float().to(
                                self.config.device)

        time_record = []
        for epoch in range(self.config.training.n_epochs):
            for i, (X, y) in enumerate(dataloader):
                step += 1
                score.train()
                X = X.to(self.config.device)
                X = X / 256. * 255. + torch.rand_like(X) / 256.

                if self.config.data.logit_transform:
                    X = self.logit_transform(X)

                labels = torch.randint(0,
                                       len(sigmas), (X.shape[0], ),
                                       device=X.device)
                if self.config.training.algo == 'dsm':
                    t = time.time()
                    loss = anneal_dsm_score_estimation(
                        score, X, labels, sigmas,
                        self.config.training.anneal_power)
                elif self.config.training.algo == 'dsm_tracetrick':
                    t = time.time()
                    loss = anneal_dsm_score_estimation_TraceTrick(
                        score, X, labels, sigmas,
                        self.config.training.anneal_power)
                elif self.config.training.algo == 'ssm':
                    t = time.time()
                    loss = anneal_sliced_score_estimation_vr(
                        score,
                        X,
                        labels,
                        sigmas,
                        n_particles=self.config.training.n_particles)
                elif self.config.training.algo == 'esm_scorenet':
                    t = time.time()
                    loss = anneal_ESM_scorenet(
                        score,
                        X,
                        labels,
                        sigmas,
                        n_particles=self.config.training.n_particles)
                elif self.config.training.algo == 'esm_scorenet_VR':
                    t = time.time()
                    loss = anneal_ESM_scorenet_VR(
                        score,
                        X,
                        labels,
                        sigmas,
                        n_particles=self.config.training.n_particles)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                t = time.time() - t
                time_record.append(t)

                if step >= self.config.training.n_iters:
                    return 0

                if step % 100 == 0:
                    tb_logger.add_scalar('loss', loss, global_step=step)
                    logging.info(
                        "step: {}, loss: {}, time per step: {:.3f} +- {:.3f} ms"
                        .format(step, loss.item(),
                                np.mean(time_record) * 1e3,
                                np.std(time_record) * 1e3))

                if step >= 140000 and step % self.config.training.snapshot_freq == 0:
                    states = [
                        score.state_dict(),
                        optimizer.state_dict(),
                    ]
                    torch.save(
                        states,
                        os.path.join(self.args.log,
                                     'checkpoint_{}.pth'.format(step)))
                    torch.save(states,
                               os.path.join(self.args.log, 'checkpoint.pth'))
Example #16
    def train(self):
        if self.config.data.random_flip is False:
            tran_transform = test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])
        else:
            tran_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])

        dataset = LoadDataset('./ground turth', tran_transform)
        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True, num_workers=4)


        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)
        score = CondRefineNetDilated(self.config).to(self.config.device)

        score = torch.nn.DataParallel(score)

        optimizer = self.get_optimizer(score.parameters())

        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
            score.load_state_dict(states[0])
            optimizer.load_state_dict(states[1])

        step = 0

        sigmas = torch.tensor(
            np.exp(np.linspace(np.log(self.config.model.sigma_begin), np.log(self.config.model.sigma_end),
                               self.config.model.num_classes))).float().to(self.config.device)

        for epoch in range(self.config.training.n_epochs):
            for i, X in enumerate(dataloader):
                X = X.float()
                step += 1
                score.train()
                X = X.to(self.config.device)
                X = X / 256. * 255. + torch.rand_like(X) / 256.
                if self.config.data.logit_transform:
                    X = self.logit_transform(X)

                labels = torch.randint(0, len(sigmas), (X.shape[0],), device=X.device)
                if self.config.training.algo == 'dsm':
                    loss = anneal_dsm_score_estimation(score, X, labels, sigmas, self.config.training.anneal_power)
                elif self.config.training.algo == 'ssm':
                    loss = anneal_sliced_score_estimation_vr(score, X, labels, sigmas,
                                                             n_particles=self.config.training.n_particles)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                tb_logger.add_scalar('loss', loss, global_step=step)
                logging.info("step: {}, loss: {}".format(step, loss.item()))

                if step >= self.config.training.n_iters:
                    return 0

                
                if step % self.config.training.snapshot_freq == 0:
                    states = [
                        score.state_dict(),
                        optimizer.state_dict(),
                    ]
                    torch.save(states, os.path.join(self.args.log, 'checkpoint_{}.pth'.format(step)))

                    torch.save(states, os.path.join(self.args.log, 'checkpoint.pth'))
Example #17
    def test(self):
        # Load the score network
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                            map_location=self.config.device)
        scorenet = CondRefineNetDilated(self.config).to(self.config.device)
        scorenet = torch.nn.DataParallel(scorenet)
        scorenet.load_state_dict(states[0])
        scorenet.eval()

        # Load the CIFAR-10 test set used to draw random batches
        dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'),
                          train=False,
                          download=True)
        data = dataset.data.transpose(0, 3, 1, 2)

        for iteration in range(500):
            print("Iteration {}".format(iteration))

            curr_dir = os.path.join(SAVE_DIR, "{:07d}".format(iteration))
            if not os.path.exists(curr_dir):
                os.makedirs(curr_dir)

            # randint's upper bound is exclusive, so data.shape[0] makes
            # every test image reachable.
            rand_idx = np.random.randint(0, data.shape[0], BATCH_SIZE)
            image = torch.tensor(data[rand_idx, :].astype(np.float64) /
                                 255.).float()

            # GT color images
            image_grid = make_grid(image, nrow=GRID_SIZE)
            save_image(image_grid, os.path.join(curr_dir, "gt.png"))

            # Grayscale image
            image_gray = image.mean(1).view(BATCH_SIZE, 1, 32, 32)
            image_grid = make_grid(image_gray, nrow=GRID_SIZE)
            save_image(image_grid, os.path.join(curr_dir, "grayscale.png"))

            image_gray = image_gray.cuda()
            x = nn.Parameter(torch.Tensor(BATCH_SIZE, 3, 32,
                                          32).uniform_()).cuda()

            step_lr = 0.00002

            # Noise amounts
            sigmas = np.array([
                1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
                0.04641589, 0.02782559, 0.01668101, 0.01
            ])
            n_steps_each = 100
            # Weight to put on reconstruction error vs p(x)

            for idx, sigma in enumerate(sigmas):
                lambda_recon = 1.5 / (sigma**2)
                # Condition the score network on the current noise-level index
                labels = torch.ones(1, device=x.device) * idx
                labels = labels.long()
                step_size = step_lr * (sigma / sigmas[-1])**2

                for step in range(n_steps_each):
                    noise_x = torch.randn_like(x) * np.sqrt(step_size * 2)
                    grad_x = scorenet(x, labels).detach()

                    recon_loss = (torch.norm(
                        torch.flatten(
                            x.mean(1).view(BATCH_SIZE, 1, 32, 32) -
                            image_gray))**2)
                    print(recon_loss)
                    recon_grads = torch.autograd.grad(recon_loss, [x])

                    x = x + (step_size *
                             grad_x) + (-step_size * lambda_recon *
                                        recon_grads[0].detach()) + noise_x

            # Write x
            image_grid = make_grid(x, nrow=GRID_SIZE)
            save_image(image_grid, os.path.join(curr_dir, "x.png"))
