def test(self):
    """Sample a 5x5 grid of images with annealed Langevin dynamics and save
    the intermediate results.

    Loads the score network from ``<log>/checkpoint.pth``, runs
    ``anneal_Langevin_dynamics`` from uniform noise, and for every returned
    step writes a PNG grid and the raw sample tensor into
    ``self.args.image_folder``; every 10th grid is also collected into an
    animated ``movie.gif``.
    """
    # Restore the trained score network (checkpoints were saved wrapped in
    # DataParallel, so wrap before loading the state dict).
    states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                        map_location=self.config.device)
    score = CondRefineNetDilated(self.config).to(self.config.device)
    score = torch.nn.DataParallel(score)
    score.load_state_dict(states[0])

    if not os.path.exists(self.args.image_folder):
        os.makedirs(self.args.image_folder)

    # Geometric noise schedule from sigma_begin down to sigma_end.
    sigmas = np.exp(np.linspace(np.log(self.config.model.sigma_begin),
                                np.log(self.config.model.sigma_end),
                                self.config.model.num_classes))

    score.eval()
    grid_size = 5
    imgs = []  # GIF frames, one every 10 steps

    # Only the initial-noise shape differs between the datasets; the
    # saving loop below is shared (previously it was duplicated verbatim
    # in both branches).
    if self.config.data.dataset == 'MNIST':
        samples = torch.rand(grid_size ** 2, 1, 28, 28,
                             device=self.config.device)
    else:
        samples = torch.rand(grid_size ** 2, 3, 32, 32,
                             device=self.config.device)

    all_samples = self.anneal_Langevin_dynamics(samples, score, sigmas,
                                                100, 0.00002)

    for i, sample in enumerate(tqdm.tqdm(all_samples,
                                         total=len(all_samples),
                                         desc='saving images')):
        sample = sample.view(grid_size ** 2, self.config.data.channels,
                             self.config.data.image_size,
                             self.config.data.image_size)

        if self.config.data.logit_transform:
            sample = torch.sigmoid(sample)

        image_grid = make_grid(sample, nrow=grid_size)
        if i % 10 == 0:
            # BUG FIX: use out-of-place mul/add/clamp so image_grid keeps
            # its [0, 1] range for the save_image call below — the in-place
            # variants left it scaled to [0, 255], saturating every 10th
            # saved PNG. (The stray `nrow=10` kwarg formerly passed to
            # save_image was dropped: the grid is already laid out above
            # with nrow=grid_size.)
            im = Image.fromarray(
                image_grid.mul(255).add(0.5).clamp(0, 255).permute(
                    1, 2, 0).to('cpu', torch.uint8).numpy())
            imgs.append(im)

        save_image(image_grid,
                   os.path.join(self.args.image_folder,
                                'image_{}.png'.format(i)))
        torch.save(sample,
                   os.path.join(self.args.image_folder,
                                'image_raw_{}.pth'.format(i)))

    if imgs:  # guard against an empty frame list (IndexError before)
        imgs[0].save(os.path.join(self.args.image_folder, "movie.gif"),
                     save_all=True, append_images=imgs[1:], duration=1,
                     loop=0)
def save_sampled_images(self):
    """Generate batches of samples from saved checkpoints and write each
    sample to disk as an individual PNG.

    For every checkpoint epoch in the hard-coded list, loads the score
    network, runs annealed Langevin dynamics ``num_of_step`` times with
    batches of ``bs`` uniform-noise images, and saves the results under
    ``<image_folder>/epochs<epoch>/img_<index>.png``.
    """
    num_of_step = 100  # number of sampling batches per checkpoint
    bs = 100  # samples per batch
    # Geometric noise schedule from sigma_begin down to sigma_end.
    sigmas = np.exp(
        np.linspace(np.log(self.config.model.sigma_begin),
                    np.log(self.config.model.sigma_end),
                    self.config.model.num_classes))
    if not os.path.exists(self.args.image_folder):
        os.makedirs(self.args.image_folder)
    print('Load checkpoint from' + self.args.log)
    for epochs in [100000, 170000]:  # checkpoint epochs to sample from
        # Checkpoints were saved wrapped in DataParallel, so wrap before
        # loading the state dict.
        states = torch.load(os.path.join(
            self.args.log, 'checkpoint_' + str(epochs) + '.pth'),
            map_location=self.config.device)
        score = CondRefineNetDilated(self.config).to(self.config.device)
        score = torch.nn.DataParallel(score)
        score.load_state_dict(states[0])
        score.eval()
        if not os.path.exists(
                os.path.join(self.args.image_folder,
                             'epochs' + str(epochs))):
            os.makedirs(
                os.path.join(self.args.image_folder,
                             'epochs' + str(epochs)))
        save_index = 0  # running filename index across all batches
        print("Begin epochs", epochs)
        for j in range(num_of_step):
            # Start from uniform noise; assumes 3x32x32 RGB images
            # (e.g. CIFAR-10) — TODO confirm for other datasets.
            samples = torch.rand(bs, 3, 32, 32, device=self.config.device)
            all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                samples, score, sigmas, 100, 0.00002)
            # Convert [0, 1] tensors to HWC uint8 arrays for PIL.
            images_new = all_samples.mul_(255).add_(0.5).clamp_(
                0, 255).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
            for k in range(len(images_new)):
                ims = Image.fromarray(images_new[k])
                ims.save(
                    os.path.join(self.args.image_folder,
                                 'epochs' + str(epochs),
                                 'img_' + str(save_index) + '.png'))
                print('Save images ', k)
                save_index += 1
def test_inpainting(self):
    """Run image inpainting (or, for NYUv2, depth prediction) with annealed
    Langevin dynamics and save the per-step results to ``image_folder``.

    NOTE(review): another ``test_inpainting`` is defined later in this
    module; that later definition shadows this one.
    """
    # Restore the trained score network (saved wrapped in DataParallel).
    states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                        map_location=self.config.device)
    score = CondRefineNetDilated(self.config).to(self.config.device)
    score = torch.nn.DataParallel(score)
    score.load_state_dict(states[0])

    if not os.path.exists(self.args.image_folder):
        os.makedirs(self.args.image_folder)

    # Geometric noise schedule from sigma_begin down to sigma_end.
    sigmas = np.exp(np.linspace(np.log(self.config.model.sigma_begin),
                                np.log(self.config.model.sigma_end),
                                self.config.model.num_classes))
    score.eval()
    imgs = []  # frames collected for the final GIF
    if self.config.data.dataset == 'CELEBA':
        dataset = CelebA(root=os.path.join(self.args.run, 'datasets',
                                           'celeba'),
                         split='test',
                         transform=transforms.Compose([
                             transforms.CenterCrop(140),
                             transforms.Resize(self.config.data.image_size),
                             transforms.ToTensor(),
                         ]), download=True)
        dataloader = DataLoader(dataset, batch_size=20, shuffle=True,
                                num_workers=0)  # changed num_workers from 4 to 0
        refer_image, _ = next(iter(dataloader))
        # 20 independent completions for each of the 20 reference images.
        samples = torch.rand(20, 20, 3, self.config.data.image_size,
                             self.config.data.image_size,
                             device=self.config.device)
        all_samples = self.anneal_Langevin_dynamics_inpainting(
            samples, refer_image, score, sigmas, 100, 0.00002)
        torch.save(refer_image,
                   os.path.join(self.args.image_folder, 'refer_image.pth'))
        for i, sample in enumerate(tqdm.tqdm(all_samples)):
            sample = sample.view(400, self.config.data.channels,
                                 self.config.data.image_size,
                                 self.config.data.image_size)
            if self.config.data.logit_transform:
                sample = torch.sigmoid(sample)
            image_grid = make_grid(sample, nrow=20)
            if i % 10 == 0:
                # NOTE(review): mul_/add_/clamp_ mutate image_grid in
                # place, so the save_image call below receives 0-255 data
                # on these steps — confirm this is intended.
                im = Image.fromarray(
                    image_grid.mul_(255).add_(0.5).clamp_(0, 255).permute(
                        1, 2, 0).to('cpu', torch.uint8).numpy())
                imgs.append(im)
            save_image(image_grid,
                       os.path.join(self.args.image_folder,
                                    'image_completion_{}.png'.format(i)))
            torch.save(sample,
                       os.path.join(self.args.image_folder,
                                    'image_completion_raw_{}.pth'.format(i)))
    elif self.config.data.dataset == 'NYUv2':
        # TODO implement inpainting and MSE calculate for NYUv2
        nyu_transform = transforms.Compose([
            transforms.CenterCrop((400, 400)),
            transforms.Resize(32),
            transforms.ToTensor()])
        dataset = NYUv2(os.path.join(self.args.run, 'datasets', 'nyuv2'),
                        train=False, download=True,
                        rgb_transform=nyu_transform,
                        depth_transform=nyu_transform)
        dataloader = DataLoader(dataset, batch_size=20, shuffle=True,
                                num_workers=0)
        data_iter = iter(dataloader)
        rgb_image, depth = next(data_iter)
        rgb_image = rgb_image.to(self.config.device)
        # Only the first depth map of the batch is kept.
        depth = depth[0].to(self.config.device)
        # MSE loss evaluation
        mse = torch.nn.MSELoss()
        # Dequantize: rescale and add uniform noise of one quantization bin.
        rgb_image = rgb_image / 256. * 255. + torch.rand_like(rgb_image) / 256.
        # torch.save(rgb_image, os.path.join(self.args.image_folder, 'rgb_image.pth'))
        samples = torch.rand(20, 20, self.config.data.channels,
                             self.config.data.image_size,
                             self.config.data.image_size).to(self.config.device)
        all_depth_samples, depth_pred = self.anneal_Langevin_dynamics_prediction(
            samples, rgb_image, score, sigmas, 100, 0.00002)
        print("MSE loss is %5.4f" % (mse(depth_pred, depth)))
        for i, sample in enumerate(tqdm.tqdm(all_depth_samples)):
            # Keep only the predicted depth channels (channels - 3 of them).
            sample = sample.view(400, self.config.data.channels - 3,
                                 self.config.data.image_size,
                                 self.config.data.image_size)
            # Prepend the ground-truth depth for visual comparison.
            sample = torch.cat((depth.to('cpu'), sample), 0)
            depth_grid = make_grid(sample, nrow=20)
            if i % 10 == 0:
                # dep = Image.fromarray(depth_grid.to('cpu').numpy().astype(np.float32), mode='F')
                dep = F.to_pil_image(depth_grid[0, :, :])
                imgs.append(dep)
            save_image(depth_grid,
                       os.path.join(self.args.image_folder,
                                    'depth_prediction_{}.png'.format(i)))
            # torch.save(sample, os.path.join(self.args.image_folder, 'depth_prediction_raw_{}.pth'.format(i)))
    else:
        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])
        # NOTE(review): if the dataset is neither CIFAR10 nor SVHN,
        # `dataset` is never bound and the DataLoader call below raises
        # NameError — confirm only these two datasets reach this branch.
        if self.config.data.dataset == 'CIFAR10':
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets',
                                           'cifar10'),
                              train=True, download=True,
                              transform=transform)
        elif self.config.data.dataset == 'SVHN':
            dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn'),
                           split='train', download=True,
                           transform=transform)
        dataloader = DataLoader(dataset, batch_size=20, shuffle=True,
                                num_workers=0)  # changed num_workers from 4 to 0
        data_iter = iter(dataloader)
        refer_image, _ = next(data_iter)
        torch.save(refer_image,
                   os.path.join(self.args.image_folder, 'refer_image.pth'))
        samples = torch.rand(20, 20, self.config.data.channels,
                             self.config.data.image_size,
                             self.config.data.image_size).to(self.config.device)
        all_samples = self.anneal_Langevin_dynamics_inpainting(
            samples, refer_image, score, sigmas, 100, 0.00002)
        for i, sample in enumerate(tqdm.tqdm(all_samples)):
            sample = sample.view(400, self.config.data.channels,
                                 self.config.data.image_size,
                                 self.config.data.image_size)
            if self.config.data.logit_transform:
                sample = torch.sigmoid(sample)
            image_grid = make_grid(sample, nrow=20)
            if i % 10 == 0:
                im = Image.fromarray(
                    image_grid.mul_(255).add_(0.5).clamp_(0, 255).permute(
                        1, 2, 0).to('cpu', torch.uint8).numpy())
                imgs.append(im)
            save_image(image_grid,
                       os.path.join(self.args.image_folder,
                                    'image_completion_{}.png'.format(i)))
            torch.save(sample,
                       os.path.join(self.args.image_folder,
                                    'image_completion_raw_{}.pth'.format(i)))
    # NOTE(review): raises IndexError if no frame was collected.
    imgs[0].save(os.path.join(self.args.image_folder, "movie.gif"),
                 save_all=True, append_images=imgs[1:], duration=1, loop=0)
def test_inpainting(self):
    """Inpaint a batch of held-out images with annealed Langevin dynamics.

    Loads the score network from ``<log>/checkpoint.pth``, draws 20
    reference images (CelebA test split, or CIFAR10/SVHN train split
    depending on ``config.data.dataset``), samples 20 completions per
    reference image, and writes per-step image grids, raw sample tensors,
    and an animated GIF of the trajectory to ``self.args.image_folder``.

    Raises:
        ValueError: if ``config.data.dataset`` is not CELEBA, CIFAR10,
            or SVHN.
    """
    # Restore the trained score network (checkpoints were saved wrapped
    # in DataParallel, so wrap before loading the state dict).
    states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                        map_location=self.config.device)
    score = CondRefineNetDilated(self.config).to(self.config.device)
    score = torch.nn.DataParallel(score)
    score.load_state_dict(states[0])

    if not os.path.exists(self.args.image_folder):
        os.makedirs(self.args.image_folder)

    # Geometric noise schedule from sigma_begin down to sigma_end.
    sigmas = np.exp(
        np.linspace(np.log(self.config.model.sigma_begin),
                    np.log(self.config.model.sigma_end),
                    self.config.model.num_classes))

    score.eval()

    # Build the dataset and the initial noise; only this wiring differs
    # between branches — the sampling/saving loop below is shared
    # (previously it was duplicated verbatim in both branches).
    if self.config.data.dataset == 'CELEBA':
        dataset = CelebA(
            root=os.path.join(self.args.run, 'datasets', 'celeba'),
            split='test',
            transform=transforms.Compose([
                transforms.CenterCrop(140),
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor(),
            ]),
            download=True)
        samples = torch.rand(20, 20, 3, self.config.data.image_size,
                             self.config.data.image_size,
                             device=self.config.device)
    else:
        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])
        if self.config.data.dataset == 'CIFAR10':
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets',
                                           'cifar10'),
                              train=True, download=True,
                              transform=transform)
        elif self.config.data.dataset == 'SVHN':
            dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn'),
                           split='train', download=True,
                           transform=transform)
        else:
            # BUG FIX: previously fell through with `dataset` unbound and
            # crashed with a NameError at the DataLoader call.
            raise ValueError('Unsupported dataset: {}'.format(
                self.config.data.dataset))
        samples = torch.rand(20, 20, self.config.data.channels,
                             self.config.data.image_size,
                             self.config.data.image_size).to(self.config.device)

    dataloader = DataLoader(dataset, batch_size=20, shuffle=True,
                            num_workers=4)
    refer_image, _ = next(iter(dataloader))
    torch.save(refer_image,
               os.path.join(self.args.image_folder, 'refer_image.pth'))

    # 20 independent completions per reference image -> 400 samples.
    all_samples = self.anneal_Langevin_dynamics_inpainting(
        samples, refer_image, score, sigmas, 100, 0.00002)

    imgs = []  # GIF frames, one every 10 steps
    for i, sample in enumerate(tqdm.tqdm(all_samples)):
        sample = sample.view(400, self.config.data.channels,
                             self.config.data.image_size,
                             self.config.data.image_size)

        if self.config.data.logit_transform:
            sample = torch.sigmoid(sample)

        image_grid = make_grid(sample, nrow=20)
        if i % 10 == 0:
            # BUG FIX: use out-of-place mul/add/clamp so image_grid keeps
            # its [0, 1] range for the save_image call below — the
            # in-place variants left it scaled to [0, 255], saturating
            # every 10th saved PNG.
            im = Image.fromarray(
                image_grid.mul(255).add(0.5).clamp(0, 255).permute(
                    1, 2, 0).to('cpu', torch.uint8).numpy())
            imgs.append(im)
        save_image(
            image_grid,
            os.path.join(self.args.image_folder,
                         'image_completion_{}.png'.format(i)))
        torch.save(
            sample,
            os.path.join(self.args.image_folder,
                         'image_completion_raw_{}.pth'.format(i)))

    if imgs:  # guard against an empty frame list (IndexError before)
        imgs[0].save(os.path.join(self.args.image_folder, "movie.gif"),
                     save_all=True, append_images=imgs[1:], duration=1,
                     loop=0)
def test(self):
    """Separate mixtures of N CIFAR-10 images with annealed Langevin
    dynamics and report PSNR against the best ground-truth permutation.

    Relies on module-level SAVE_DIR, N, BATCH_SIZE, GRID_SIZE,
    get_single_image, psnr, and permutations — assumed defined elsewhere
    in this module (TODO confirm).
    """
    all_psnr = []  # All signal to noise ratios over all the batches
    all_percentages = []  # All percentage accuracies
    dummy_metrics = []  # Metrics for the averaging value

    # Load the score network
    states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                        map_location=self.config.device)
    scorenet = CondRefineNetDilated(self.config).to(self.config.device)
    scorenet = torch.nn.DataParallel(scorenet)
    scorenet.load_state_dict(states[0])
    scorenet.eval()

    # Grab the first two samples from MNIST
    trans = transforms.Compose([transforms.ToTensor()])
    dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'),
                      train=False, download=True)
    data = dataset.data.transpose(0, 3, 1, 2)  # NHWC -> NCHW

    for iteration in range(100):
        print("Iteration {}".format(iteration))
        curr_dir = os.path.join(SAVE_DIR, "{:07d}".format(iteration))
        if not os.path.exists(curr_dir):
            os.makedirs(curr_dir)
        # image1, image2 = get_images_split(first_digits, second_digits)
        # Draw N ground-truth images and superimpose them.
        gt_images = []
        for _ in range(N):
            gt_images.append(get_single_image(data))
        mixed = sum(gt_images).float()
        mixed_grid = make_grid(mixed.detach() / float(N), nrow=GRID_SIZE,
                               pad_value=1., padding=1)
        save_image(mixed_grid, os.path.join(curr_dir, "mixed.png"))
        for i in range(N):
            gt_grid = make_grid(gt_images[i], nrow=GRID_SIZE,
                                pad_value=1., padding=1)
            save_image(gt_grid, os.path.join(curr_dir, "gt{}.png".format(i)))
        mixed = torch.Tensor(mixed).cuda().view(BATCH_SIZE, 3, 32, 32)
        # One optimization variable per source component.
        xs = []
        for _ in range(N):
            xs.append(nn.Parameter(
                torch.Tensor(BATCH_SIZE, 3, 32, 32).uniform_()).cuda())
        step_lr = 0.00002
        # Noise amounts
        sigmas = np.array([1., 0.59948425, 0.35938137, 0.21544347,
                           0.12915497, 0.07742637, 0.04641589, 0.02782559,
                           0.01668101, 0.01])
        n_steps_each = 200
        for idx, sigma in enumerate(sigmas):
            lambda_recon = 1.8/(sigma**2)
            # Not completely sure what this part is for
            labels = torch.ones(1, device=xs[0].device) * idx
            labels = labels.long()
            step_size = step_lr * (sigma / sigmas[-1]) ** 2
            for step in range(n_steps_each):
                noises = []
                for _ in range(N):
                    noises.append(torch.randn_like(xs[0]) *
                                  np.sqrt(step_size * 2))
                grads = []
                for i in range(N):
                    grads.append(scorenet(xs[i].view(BATCH_SIZE, 3, 32, 32),
                                          labels).detach())
                # Squared reconstruction error of the summed components.
                recon_loss = (torch.norm(torch.flatten(sum(xs) - mixed)) ** 2)
                print(recon_loss)
                recon_grads = torch.autograd.grad(recon_loss, xs)
                for i in range(N):
                    # Langevin step: score ascent, reconstruction descent,
                    # plus Gaussian noise.
                    xs[i] = xs[i] + (step_size * grads[i]) + (
                        -step_size * lambda_recon * recon_grads[i].detach()
                    ) + noises[i]
        for i in range(N):
            xs[i] = torch.clamp(xs[i], 0, 1)
        x_to_write = []
        for i in range(N):
            x_to_write.append(torch.Tensor(xs[i].detach().cpu()))

        # PSNR Measure
        for idx in range(BATCH_SIZE):
            # Score every assignment of outputs to ground truths and keep
            # the best permutation.
            best_psnr = -10000
            best_permutation = None
            for permutation in permutations(range(N)):
                curr_psnr = sum([psnr(xs[permutation[i]][idx],
                                      gt_images[i][idx].cuda())
                                 for i in range(N)])
                if curr_psnr > best_psnr:
                    best_psnr = curr_psnr
                    best_permutation = permutation
            all_psnr.append(best_psnr / float(N))
            for i in range(N):
                # Reorder outputs so component i aligns with gt i.
                x_to_write[i][idx] = xs[best_permutation[i]][idx]
                # Baseline: PSNR of the averaged mixture against each
                # ground-truth component.
                mixed_psnr = psnr(mixed.detach().cpu()[idx] / float(N),
                                  gt_images[i][idx])
                dummy_metrics.append(mixed_psnr)

        for i in range(N):
            x_grid = make_grid(x_to_write[i], nrow=GRID_SIZE,
                               pad_value=1., padding=1)
            save_image(x_grid, os.path.join(curr_dir, "x{}.png".format(i)))
        mixed_grid = make_grid(sum(xs)/float(N), nrow=GRID_SIZE,
                               pad_value=1., padding=1)
        save_image(mixed_grid, os.path.join(curr_dir, "recon.png".format(i)))

        # average_grid = make_grid(mixed.detach()/2., nrow=GRID_SIZE)
        # save_image(average_grid, "results/average_cifar.png")

        print("Curr mean {}".format(np.array(all_psnr).mean()))
        print("Const mean {}".format(np.array(dummy_metrics).mean()))
def calculate_fid(self):
    """Sample images from saved checkpoints and compute their FID against
    precomputed CIFAR-10 training-set statistics.

    Requires the TF1-style ``fid`` helper module and
    ``fid_stats_cifar10_train.npz`` in the working directory.
    """
    import fid
    import tensorflow as tf
    num_of_step = 500  # number of sampled batches per checkpoint
    bs = 100           # samples per batch
    # Geometric noise schedule from sigma_begin down to sigma_end.
    sigmas = np.exp(
        np.linspace(np.log(self.config.model.sigma_begin),
                    np.log(self.config.model.sigma_end),
                    self.config.model.num_classes))
    stats_path = 'fid_stats_cifar10_train.npz'  # training set statistics
    inception_path = fid.check_or_download_inception(
        None)  # download inception network
    print('Load checkpoint from ' + self.args.log)  # FIX: missing space
    #for epochs in range(140000, 200001, 1000):
    for epochs in [149000]:
        states = torch.load(os.path.join(
            self.args.log, 'checkpoint_' + str(epochs) + '.pth'),
            map_location=self.config.device)
        #states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'), map_location=self.config.device)
        score = CondRefineNetDilated(self.config).to(self.config.device)
        score = torch.nn.DataParallel(score)
        score.load_state_dict(states[0])
        score.eval()

        # BUG FIX: the MNIST branch previously sampled its first batch as
        # (bs, 1, 28, 28) but every subsequent batch as (bs, 3, 32, 32),
        # so np.concatenate failed on the shape mismatch. Pick the noise
        # shape once and use it for all batches.
        if self.config.data.dataset == 'MNIST':
            shape = (bs, 1, 28, 28)
        else:
            shape = (bs, 3, 32, 32)

        print("Begin epochs", epochs)
        batches = []
        for j in range(num_of_step):
            samples = torch.rand(*shape, device=self.config.device)
            all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                samples, score, sigmas, 100, 0.00002)
            # Convert [0, 1] tensors to HWC arrays in the 0-255 range.
            batches.append(all_samples.mul_(255).add_(0.5).clamp_(
                0, 255).permute(0, 2, 3, 1).to('cpu').numpy())
        images = np.concatenate(batches, axis=0)

        # load precalculated training set statistics
        f = np.load(stats_path)
        mu_real, sigma_real = f['mu'][:], f['sigma'][:]
        f.close()

        fid.create_inception_graph(
            inception_path)  # load the graph into the current TF graph
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            mu_gen, sigma_gen = fid.calculate_activation_statistics(
                images, sess, batch_size=100)

        fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen,
                                                   mu_real, sigma_real)
        print("FID: %s" % fid_value)
def test(self):
    """Separate mixtures of one animal and one machine CIFAR-10 image via
    annealed Langevin dynamics; report PSNR/SSIM (color and grayscale)
    against both assignment orders and a mixed-image baseline, and save
    grids plus a collection of "strange" cases.

    Relies on module-level SAVE_DIR, BATCH_SIZE, GRID_SIZE, ANIMAL_IDX,
    MACHINE_IDX, get_images_split, psnr, get_ssim, get_ssim_grayscale —
    assumed defined elsewhere in this module (TODO confirm).
    """
    # For metrics
    all_psnr = []  # All signal to noise ratios over all the batches
    all_grayscale_psnr = [
    ]  # All signal to noise ratios over all the batches
    all_mixed_psnr = []  # All signal to noise ratios over all the batches
    all_mixed_grayscale_psnr = [
    ]  # All signal to noise ratios over all the batches
    all_ssim = []
    all_grayscale_ssim = []
    all_mixed_ssim = []  # All signal to noise ratios over all the batches
    all_mixed_grayscale_ssim = [
    ]  # All signal to noise ratios over all the batches
    strange_cases = {"gt1": [], "gt2": [], "mixed": [], "x": [], "y": []}

    # For inception score
    output_to_incept = []
    mixed_to_incept = []

    # For videos
    all_x = []
    all_y = []
    all_mixed = []

    # Load the score network
    states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                        map_location=self.config.device)
    scorenet = CondRefineNetDilated(self.config).to(self.config.device)
    scorenet = torch.nn.DataParallel(scorenet)
    scorenet.load_state_dict(states[0])
    scorenet.eval()

    # Grab the first two samples from MNIST
    dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'),
                      train=False, download=True)
    data = dataset.data.transpose(0, 3, 1, 2)  # NHWC -> NCHW
    # Split the test set into animal and machine classes so each mixture
    # combines one image from each group.
    all_animals_idx = np.isin(dataset.targets, ANIMAL_IDX)
    all_machines_idx = np.isin(dataset.targets, MACHINE_IDX)
    all_animals = dataset.data[all_animals_idx].transpose(0, 3, 1, 2)
    all_machines = dataset.data[all_machines_idx].transpose(0, 3, 1, 2)

    for iteration in range(105, 250):
        print("Iteration {}".format(iteration))
        curr_dir = os.path.join(SAVE_DIR, "{:07d}".format(iteration))
        if not os.path.exists(curr_dir):
            os.makedirs(curr_dir)

        # rand_idx_1 = np.random.randint(0, data.shape[0] - 1, BATCH_SIZE)
        # rand_idx_2 = np.random.randint(0, data.shape[0] - 1, BATCH_SIZE)
        # image1 = torch.tensor(data[rand_idx_1, :].astype(np.float) / 255.).float()
        # image2 = torch.tensor(data[rand_idx_2, :].astype(np.float) / 255.).float()
        image1, image2 = get_images_split(all_animals, all_machines)
        mixed = (image1 + image2).float()
        mixed_grid = make_grid(mixed.detach() / 2., nrow=GRID_SIZE)
        save_image(mixed_grid, os.path.join(curr_dir, "mixed.png"))
        gt1_grid = make_grid(image1, nrow=GRID_SIZE)
        save_image(gt1_grid, os.path.join(curr_dir, "gt1.png"))
        gt2_grid = make_grid(image2, nrow=GRID_SIZE)
        save_image(gt2_grid, os.path.join(curr_dir, "gt2.png"))
        mixed = torch.Tensor(mixed).cuda().view(BATCH_SIZE, 3, 32, 32)
        # Two optimization variables, one per separated component.
        y = nn.Parameter(torch.Tensor(BATCH_SIZE, 3, 32, 32).uniform_()).cuda()
        x = nn.Parameter(torch.Tensor(BATCH_SIZE, 3, 32, 32).uniform_()).cuda()
        step_lr = 0.00002
        # Noise amounts
        sigmas = np.array([
            1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
            0.04641589, 0.02782559, 0.01668101, 0.01
        ])
        n_steps_each = 100
        #lambda_recon = 1.5
        for idx, sigma in enumerate(sigmas):
            lambda_recon = 1.0 / (sigma**2)
            # Not completely sure what this part is for
            labels = torch.ones(1, device=x.device) * idx
            labels = labels.long()
            step_size = step_lr * (sigma / sigmas[-1])**2
            for step in range(n_steps_each):
                noise_x = torch.randn_like(x) * np.sqrt(step_size * 2)
                noise_y = torch.randn_like(y) * np.sqrt(step_size * 2)
                grad_x = scorenet(x.view(BATCH_SIZE, 3, 32, 32),
                                  labels).detach()
                grad_y = scorenet(y.view(BATCH_SIZE, 3, 32, 32),
                                  labels).detach()
                # Squared reconstruction error of the summed components.
                recon_loss = (torch.norm(torch.flatten(y + x - mixed))**2)
                recon_grads = torch.autograd.grad(recon_loss, [x, y])
                #x = x + (step_size * grad_x) + noise_x
                #y = y + (step_size * grad_y) + noise_y
                # Langevin step: score ascent, reconstruction descent,
                # plus Gaussian noise.
                x = x + (step_size * grad_x) + (
                    -step_size * lambda_recon * recon_grads[0].detach()
                ) + noise_x
                y = y + (step_size * grad_y) + (
                    -step_size * lambda_recon * recon_grads[1].detach()
                ) + noise_y

                # Video
                # if (step % 5) == 0:
                #     all_x.append(x.detach().cpu().numpy())
                #     all_y.append(y.detach().cpu().numpy())
                #     all_mixed.append((x.detach().cpu().numpy() + y.detach().cpu().numpy()) / 2.)
            #lambda_recon *= 2.8

        # Inception
        # from model import get_inception_score
        # output_to_incept += [np.clip(x[idx].detach().cpu().numpy().transpose(1,2,0), 0, 1) * 255. for idx in range(x.shape[0])]
        # output_to_incept += [np.clip(y[idx].detach().cpu().numpy().transpose(1,2,0), 0, 1) * 255. for idx in range(y.shape[0])]
        # mixed_to_incept += [np.clip((mixed[idx] / 2.).detach().cpu().numpy().transpose(1,2,0), 0, 1) * 255. for idx in range(y.shape[0])]
        # mixed_to_incept += [np.clip((mixed[idx] / 2.).detach().cpu().numpy().transpose(1,2,0), 0, 1) * 255. for idx in range(y.shape[0])]
        x_to_write = torch.Tensor(x.detach().cpu())
        y_to_write = torch.Tensor(y.detach().cpu())
        # x_movie = np.array(np.stack(all_x, axis=0))
        # y_movie = np.array(np.stack(all_y, axis=0))
        # mixed_movie = np.array(np.stack(all_mixed, axis=0))
        for idx in range(BATCH_SIZE):
            # PSNR
            # Score both assignments (x->gt1/y->gt2 vs. swapped) and keep
            # the better one.
            est1 = psnr(x[idx], image1[idx].cuda()) + psnr(
                y[idx], image2[idx].cuda())
            est2 = psnr(x[idx], image2[idx].cuda()) + psnr(
                y[idx], image1[idx].cuda())
            correct_psnr = max(est1, est2) / 2.
            all_psnr.append(correct_psnr)

            grayscale_est1 = psnr(
                x[idx].mean(0), image1[idx].mean(0).cuda()) + psnr(
                    y[idx].mean(0), image2[idx].mean(0).cuda())
            grayscale_est2 = psnr(
                x[idx].mean(0), image2[idx].mean(0).cuda()) + psnr(
                    y[idx].mean(0), image1[idx].mean(0).cuda())
            grayscale_psnr = max(grayscale_est1, grayscale_est2) / 2.
            all_grayscale_psnr.append(grayscale_psnr)

            # Mixed PSNR
            mixed_psnr = psnr((mixed[idx] / 2.), image1[idx].cuda())
            all_mixed_psnr.append(mixed_psnr)
            grayscale_mixed_psnr = psnr((mixed[idx] / 2.).mean(0),
                                        image1[idx].mean(0).cuda())
            all_mixed_grayscale_psnr.append(grayscale_mixed_psnr)
            if est2 > est1:
                # Swap so x always corresponds to gt1 in the saved grids.
                x_to_write[idx] = y[idx]
                y_to_write[idx] = x[idx]
                # tmp = x_movie[:, idx].copy()
                # x_movie[:, idx] = y_movie[:, idx]
                # y_movie[:, idx] = tmp

            # SSIM
            est1 = get_ssim(x[idx], image1[idx]) + get_ssim(
                y[idx], image2[idx])
            est2 = get_ssim(x[idx], image2[idx]) + get_ssim(
                y[idx], image1[idx])
            correct_ssim = max(est1, est2) / 2.
            all_ssim.append(correct_ssim)

            # NOTE(review): these mix get_ssim_grayscale for x with plain
            # get_ssim for y — possibly both terms should be grayscale;
            # confirm before trusting the grayscale SSIM numbers.
            grayscale_est1 = get_ssim_grayscale(
                x[idx], image1[idx]) + get_ssim(y[idx], image2[idx])
            grayscale_est2 = get_ssim_grayscale(
                x[idx], image2[idx]) + get_ssim(y[idx], image1[idx])
            grayscale_ssim = max(grayscale_est1, grayscale_est2) / 2.
            all_grayscale_ssim.append(grayscale_ssim)

            # Mixed ssim
            mixed_ssim = get_ssim(
                (mixed[idx] / 2.), image1[idx]) + get_ssim(
                    (mixed[idx] / 2.), image2[idx])
            all_mixed_ssim.append(mixed_ssim / 2.)
            grayscale_mixed_ssim = get_ssim_grayscale(
                (mixed[idx] / 2.), image1[idx]) + get_ssim_grayscale(
                    (mixed[idx] / 2.), image2[idx])
            all_mixed_grayscale_ssim.append(grayscale_mixed_ssim / 2.)

            # Collect separations that look bad in color yet fine in
            # grayscale (thresholds are empirical).
            if correct_psnr < 19 and grayscale_psnr > 20.5:
                strange_cases["gt1"].append(
                    image1[idx].detach().cpu().numpy())
                strange_cases["gt2"].append(
                    image2[idx].detach().cpu().numpy())
                strange_cases["mixed"].append(
                    mixed[idx].detach().cpu().numpy())
                strange_cases["x"].append(
                    x_to_write[idx].detach().cpu().numpy())
                strange_cases["y"].append(
                    y_to_write[idx].detach().cpu().numpy())
                print("Added strange case")

        # # Write x and y
        x_grid = make_grid(x_to_write, nrow=GRID_SIZE)
        save_image(x_grid, os.path.join(curr_dir, "x.png"))
        y_grid = make_grid(y_to_write, nrow=GRID_SIZE)
        save_image(y_grid, os.path.join(curr_dir, "y.png"))

        # Running (cumulative) means over all iterations so far.
        print("PSNR {}".format(np.array(all_psnr).mean()))
        print("Mixed PSNR {}".format(np.array(all_mixed_psnr).mean()))
        print("PSNR Grayscale {}".format(
            np.array(all_grayscale_psnr).mean()))
        print("Mixed PSNR Grayscale {}".format(
            np.array(all_mixed_grayscale_psnr).mean()))
        print("SSIM {}".format(np.array(all_ssim).mean()))
        print("Mixed SSIM {}".format(np.array(all_mixed_ssim).mean()))
        print("SSIM Grayscale {}".format(
            np.array(all_grayscale_ssim).mean()))
        print("Mixed SSIM Grayscale {}".format(
            np.array(all_mixed_grayscale_ssim).mean()))

    # Write video frames
    # padding = 50
    # dim_w = 172 * 4 + padding * 5
    # dim_h = 172 * 2 + padding * 3
    # for frame_idx in range(x_movie.shape[0]):
    #     print(frame_idx)
    #     x_grid = make_grid(torch.Tensor(x_movie[frame_idx]), nrow=GRID_SIZE)
    #     # save_image(x_grid, "results/videos/x/x_{}.png".format(frame_idx))
    #     y_grid = make_grid(torch.Tensor(y_movie[frame_idx]), nrow=GRID_SIZE)
    #     # save_image(y_grid, "results/videos/y/y_{}.png".format(frame_idx))
    #     recon_grid = make_grid(torch.Tensor(mixed_movie[frame_idx]), nrow=GRID_SIZE)
    #     # save_image(recon_grid, "results/videos/mixed/mixed_{}.png".format(frame_idx))
    #     output_frame = torch.zeros(3, dim_h, dim_w)
    #     output_frame[:, 50:(50+172), 50:(50+172)] = gt1_grid
    #     output_frame[:, (100+172):(100+172*2), 50:(50+172)] = gt2_grid
    #     output_frame[:, (75 + 86):(75 + 86 + 172), (50 * 2 + 172):(50 * 2 + 172 * 2)] = mixed_grid
    #     output_frame[:, (75 + 86):(75 + 86 + 172), (50 * 3 + 172 * 2):(50 * 3 + 172 * 3)] = recon_grid
    #     output_frame[:,50:(50 + 172),(50 * 4 + 172 * 3):(50 * 4 + 172 * 4)] = x_grid
    #     output_frame[:,(50 * 2 + 172):(50 * 2 + 172 * 2),(50 * 4 + 172 * 3):(50 * 4 + 172 * 4)] = y_grid
    #     save_image(output_frame, "results/videos/combined/{:03d}.png".format(frame_idx))

    # Calculate inception scores
    # print("Output inception score {}".format(get_inception_score(output_to_incept)))
    # print("Mixed inception score {}".format(get_inception_score(mixed_to_incept)))

    # Write strange results
    # NOTE(review): np.stack raises ValueError if no strange case was
    # collected during the whole run.
    y1 = np.stack(strange_cases["y"], axis=0)
    y_grid = make_grid(torch.Tensor(y1))
    save_image(y_grid, "results/y_strange.png")
    x1 = np.stack(strange_cases["x"], axis=0)
    x_grid = make_grid(torch.Tensor(x1))
    save_image(x_grid, "results/x_strange.png")
    gt1 = np.stack(strange_cases["gt1"], axis=0)
    gt1_grid = make_grid(torch.Tensor(gt1))
    save_image(gt1_grid, "results/gt1_strange.png")
    gt2 = np.stack(strange_cases["gt2"], axis=0)
    gt2_grid = make_grid(torch.Tensor(gt2))
    save_image(gt2_grid, "results/gt2_strange.png")
    mixed = np.stack(strange_cases["mixed"], axis=0) / 2.
    mixed_grid = make_grid(torch.Tensor(mixed))
    save_image(mixed_grid, "results/mixed_strange.png")
def test(self):
    """Separate mixtures of two MNIST digits (one from 0-4, one from 5-9)
    with annealed Langevin dynamics on a pretrained NCSN score network.

    For each of 100 iterations: build a batch of digit pairs, sum them into a
    mixture, then jointly sample two estimates (x, y) whose score-model prior
    and reconstruction constraint x + y == mixed are followed. Writes image
    grids and PSNR stats under SAVE_DIR, collecting low-PSNR "bad cases".

    NOTE(review): ends in pdb.set_trace() — leftover debugging breakpoint.
    """
    all_psnr = []  # All signal to noise ratios over all the batches
    all_percentages = []  # All percentage accuracies
    dummy_metrics = []  # Metrics for the averaging value (mixed/2 baseline)
    bad_cases = {"gt1": [], "gt2": [], "mixed": [], "x": [], "y": []}

    # Load the score network from the latest checkpoint.
    states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                        map_location=self.config.device)
    scorenet = CondRefineNetDilated(self.config).to(self.config.device)
    scorenet = torch.nn.DataParallel(scorenet)
    scorenet.load_state_dict(states[0])
    scorenet.eval()

    # NOTE(review): `trans` is built but never used below.
    trans = transforms.Compose([transforms.ToTensor()])
    # NOTE(review): train=False but train_labels/train_data are accessed —
    # torchvision exposes these as deprecated aliases of targets/data; the
    # arrays returned are the test split. Confirm with the torchvision
    # version in use.
    dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'),
                    train=False, download=True)
    first_digits_idx = dataset.train_labels <= 4
    second_digits_idx = dataset.train_labels >= 5
    first_digits = dataset.train_data[first_digits_idx]
    second_digits = dataset.train_data[second_digits_idx]

    for iteration in range(100):
        print("Iteration {}".format(iteration))
        curr_dir = os.path.join(SAVE_DIR, "{:07d}".format(iteration))
        if not os.path.exists(curr_dir):
            os.makedirs(curr_dir)
        image1, image2 = get_images_split(first_digits, second_digits)
        #image1, image2 = get_images_no_split(dataset)
        # Mixture is a plain sum; halved only when written as an image.
        mixed = (image1 + image2).float()
        # mixed = torch.clamp(image1 + image2, 0, 1).float()  # Capsule net

        mixed_grid = make_grid(mixed.detach() / 2., nrow=GRID_SIZE,
                               pad_value=1., padding=1)
        save_image(mixed_grid, os.path.join(curr_dir, "mixed.png"))
        gt1_grid = make_grid(image1, nrow=GRID_SIZE, pad_value=1., padding=1)
        save_image(gt1_grid, os.path.join(curr_dir, "gt1.png"))
        gt2_grid = make_grid(image2, nrow=GRID_SIZE, pad_value=1., padding=1)
        save_image(gt2_grid, os.path.join(curr_dir, "gt2.png"))

        mixed = torch.Tensor(mixed).cuda().view(BATCH_SIZE, 1, 28, 28)
        # Random initial guesses for the two separated sources.
        y = nn.Parameter(torch.Tensor(BATCH_SIZE, 1, 28, 28).uniform_()).cuda()
        x = nn.Parameter(torch.Tensor(BATCH_SIZE, 1, 28, 28).uniform_()).cuda()

        step_lr = 0.00002
        # Noise amounts (geometric schedule from 1.0 down to 0.01).
        sigmas = np.array([
            1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
            0.04641589, 0.02782559, 0.01668101, 0.01
        ])
        n_steps_each = 100
        for idx, sigma in enumerate(sigmas):
            # Reconstruction weight grows as the noise level shrinks.
            lambda_recon = 1.8 / (sigma**2)
            # Conditioning label selects the noise level inside the network.
            labels = torch.ones(1, device=x.device) * idx
            labels = labels.long()
            step_size = step_lr * (sigma / sigmas[-1])**2
            for step in range(n_steps_each):
                noise_x = torch.randn_like(x) * np.sqrt(step_size * 2)
                noise_y = torch.randn_like(y) * np.sqrt(step_size * 2)
                grad_x = scorenet(x.view(BATCH_SIZE, 1, 28, 28),
                                  labels).detach()
                grad_y = scorenet(y.view(BATCH_SIZE, 1, 28, 28),
                                  labels).detach()
                # Soft constraint: the two estimates should sum to the mixture.
                recon_loss = (torch.norm(torch.flatten(y + x - mixed))**2)
                #recon_loss = torch.norm(torch.flatten(torch.clamp(x + y, -10000000, 1) - mixed)) ** 2
                #recon_loss = torch.norm(torch.flatten((1 / (1 + torch.exp(-5 * (y + x - 0.5)))) - mixed)) ** 2
                #recon_loss = torch.norm(torch.flatten((y - mixed)) ** 2
                # print(recon_loss)
                recon_grads = torch.autograd.grad(recon_loss, [x, y])
                #x = x + (step_size * grad_x) + noise_x
                #y = y + (step_size * grad_y) + noise_y
                # Langevin step: score ascent + reconstruction descent + noise.
                x = x + (step_size * grad_x) + (
                    -step_size * lambda_recon * recon_grads[0].detach()) + noise_x
                y = y + (step_size * grad_y) + (
                    -step_size * lambda_recon * recon_grads[1].detach()) + noise_y
                # x = x + (-step_size * lambda_recon * recon_grads[0].detach()) + noise_x
                # y = y + (-step_size * lambda_recon * recon_grads[1].detach()) + noise_y

        # Clamp for writing purposes
        x = torch.clamp(x, 0, 1)
        y = torch.clamp(y, 0, 1)
        x_to_write = torch.Tensor(x.detach().cpu())
        y_to_write = torch.Tensor(y.detach().cpu())

        # PSNR Measure: score both assignments of (x, y) to (gt1, gt2) and
        # keep the better one (the separation is order-ambiguous).
        for idx in range(BATCH_SIZE):
            est1 = psnr(x[idx], image1[idx].cuda()) + psnr(
                y[idx], image2[idx].cuda())
            est2 = psnr(x[idx], image2[idx].cuda()) + psnr(
                y[idx], image1[idx].cuda())
            correct_estimate = max(est1, est2) / 2.
            all_psnr.append(correct_estimate)
            if est2 > est1:
                # Swap so x_to_write aligns with gt1 and y_to_write with gt2.
                x_to_write[idx] = y[idx]
                y_to_write[idx] = x[idx]
            # Baseline: the halved mixture scored against both ground truths.
            mixed_psnr = psnr(mixed[idx] / 2., image1[idx].cuda()) + psnr(
                mixed[idx] / 2., image2[idx].cuda())
            dummy_metrics.append(mixed_psnr / 2.)
            if correct_estimate < 12.:
                # Keep failures (< 12 dB) for later inspection grids.
                bad_cases["gt1"].append(image1[idx].detach().cpu().numpy())
                bad_cases["gt2"].append(image2[idx].detach().cpu().numpy())
                bad_cases["mixed"].append(
                    mixed[idx].detach().cpu().numpy())
                bad_cases["x"].append(
                    x_to_write[idx].detach().cpu().numpy())
                bad_cases["y"].append(
                    y_to_write[idx].detach().cpu().numpy())
                print("Added bad case")

        # Percentage Measure
        # x_thresh = (x > 0.01)
        # y_thresh = (y > 0.01)
        # image1_thresh = (image1 > 0.01)
        # image2_thresh = (image2 > 0.01)
        # avg_thresh = ((mixed.detach()[idx] / 2.) > 0.01)
        # for idx in range(BATCH_SIZE):
        #     est1 = np.count_nonzero((x_thresh[idx] == image1_thresh[idx].cuda()).detach().cpu()) + np.count_nonzero((y_thresh[idx] == image2_thresh[idx].cuda()).detach().cpu())
        #     est2 = np.count_nonzero((x_thresh[idx] == image2_thresh[idx].cuda()).detach().cpu()) + np.count_nonzero((y_thresh[idx] == image1_thresh[idx].cuda()).detach().cpu())
        #     correct_estimate = max(est1, est2)
        #     percentage = correct_estimate / (2 * x.shape[-1] * x.shape[-2])
        #     all_percentages.append(percentage)
        #     dummy_count = np.count_nonzero((avg_thresh == image1_thresh[idx].cuda()).detach().cpu()) + np.count_nonzero((avg_thresh == image2_thresh[idx].cuda()).detach().cpu())
        #     dummy_percentage = dummy_count / (2 * x.shape[-1] * x.shape[-2])
        #     dummy_metrics.append(dummy_percentage)

        # Recon Grid: clamp(x) + clamp(y) re-clamped for display.
        recon_grid = make_grid(torch.clamp(
            torch.clamp(x_to_write, 0, 1) + torch.clamp(y_to_write, 0, 1), 0,
            1), nrow=GRID_SIZE, pad_value=1., padding=1)
        save_image(recon_grid, os.path.join(curr_dir, "recon.png"))

        # Write x and y
        x_grid = make_grid(x_to_write, nrow=GRID_SIZE, pad_value=1., padding=1)
        save_image(x_grid, os.path.join(curr_dir, "x.png"))
        y_grid = make_grid(y_to_write, nrow=GRID_SIZE, pad_value=1., padding=1)
        save_image(y_grid, os.path.join(curr_dir, "y.png"))
        # average_grid = make_grid(mixed.detach()/2., nrow=GRID_SIZE)
        # save_image(average_grid, "results/average_cifar.png")

        print("Curr mean {}".format(np.array(all_psnr).mean()))
        #print("Curr mean {}".format(np.array(all_percentages).mean()))
        print("Curr dummy mean {}".format(np.array(dummy_metrics).mean()))

    # Final summary grids of all collected bad cases.
    y1 = np.stack(bad_cases["y"], axis=0)
    y_grid = make_grid(torch.Tensor(y1), pad_value=1., padding=1)
    save_image(y_grid, os.path.join(SAVE_DIR, "y_strange.png"))
    x1 = np.stack(bad_cases["x"], axis=0)
    x_grid = make_grid(torch.Tensor(x1), pad_value=1., padding=1)
    save_image(x_grid, os.path.join(SAVE_DIR, "x_strange.png"))
    gt1 = np.stack(bad_cases["gt1"], axis=0)
    gt1_grid = make_grid(torch.Tensor(gt1), pad_value=1., padding=1)
    save_image(gt1_grid, os.path.join(SAVE_DIR, "gt1_strange.png"))
    gt2 = np.stack(bad_cases["gt2"], axis=0)
    gt2_grid = make_grid(torch.Tensor(gt2), pad_value=1., padding=1)
    save_image(gt2_grid, os.path.join(SAVE_DIR, "gt2_strange.png"))
    mixed = np.stack(bad_cases["mixed"], axis=0) / 2.
    mixed_grid = make_grid(torch.Tensor(mixed), pad_value=1., padding=1)
    save_image(mixed_grid, os.path.join(SAVE_DIR, "mixed_strange.png"))

    # Leftover interactive debugging stop.
    import pdb
    pdb.set_trace()
def test(self):
    """Separate a mixture of two fixed CIFAR-10 test images (indices 130
    and 131) with annealed Langevin dynamics on a pretrained NCSN.

    Two estimates x and y are sampled jointly so that each follows the
    score-model prior while x + y reconstructs the mixture. Results are
    written to mixed.png / gt0.png / gt1.png / x.png / y.png /
    out_mixed.png in the working directory.

    Fixes vs. original: ``np.float`` (removed in NumPy >= 1.24) replaced by
    the behavior-identical ``np.float64``; leftover ``pdb.set_trace()``
    debugger stop removed.
    """
    # Load the score network from the latest checkpoint.
    states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                        map_location=self.config.device)
    scorenet = CondRefineNetDilated(self.config).to(self.config.device)
    scorenet = torch.nn.DataParallel(scorenet)
    scorenet.load_state_dict(states[0])
    scorenet.eval()

    # Grab two fixed samples from the CIFAR-10 test split and mix them.
    dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'),
                      train=False, download=True)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    # HWC uint8 -> CHW float64 in [0, 255].
    image0 = np.array(dataset[130][0]).astype(np.float64).transpose(2, 0, 1)
    image1 = np.array(dataset[131][0]).astype(np.float64).transpose(2, 0, 1)
    mixed = (image0 + image1)
    mixed = mixed / 255.  # mixture in [0, 2]

    # cv2 expects HWC BGR, hence transpose + channel reversal.
    cv2.imwrite("mixed.png",
                (mixed * 127.5).astype(np.uint8).transpose(1, 2, 0)[:, :, ::-1])
    cv2.imwrite("gt0.png",
                (image0).astype(np.uint8).transpose(1, 2, 0)[:, :, ::-1])
    cv2.imwrite("gt1.png",
                (image1).astype(np.uint8).transpose(1, 2, 0)[:, :, ::-1])

    mixed = torch.Tensor(mixed).cuda()
    # Random initial guesses for the two separated sources.
    y = nn.Parameter(torch.Tensor(3, 32, 32).uniform_()).cuda()
    x = nn.Parameter(torch.Tensor(3, 32, 32).uniform_()).cuda()

    step_lr = 0.00002
    # Noise amounts (geometric schedule from 1.0 down to 0.01).
    sigmas = np.array([
        1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
        0.04641589, 0.02782559, 0.01668101, 0.01
    ])
    n_steps_each = 100
    lambda_recon = 1.5  # Weight to put on reconstruction error vs p(x)
    for idx, sigma in enumerate(sigmas):
        # Conditioning label selects the noise level inside the network.
        labels = torch.ones(1, device=x.device) * idx
        labels = labels.long()
        step_size = step_lr * (sigma / sigmas[-1])**2
        for step in range(n_steps_each):
            noise_x = torch.randn_like(x) * np.sqrt(step_size * 2)
            noise_y = torch.randn_like(y) * np.sqrt(step_size * 2)
            grad_x = scorenet(x.view(1, 3, 32, 32), labels).detach()
            grad_y = scorenet(y.view(1, 3, 32, 32), labels).detach()
            # Soft constraint: the two estimates should sum to the mixture.
            recon_loss = (torch.norm(torch.flatten(y + x - mixed))**2)
            print(recon_loss)
            recon_grads = torch.autograd.grad(recon_loss, [x, y])
            #x = x + (step_size * grad_x) + noise_x
            #y = y + (step_size * grad_y) + noise_y
            # Langevin step: score ascent + reconstruction descent + noise.
            x = x + (step_size * grad_x) + (
                -step_size * lambda_recon * recon_grads[0].detach()) + noise_x
            y = y + (step_size * grad_y) + (
                -step_size * lambda_recon * recon_grads[1].detach()) + noise_y
        lambda_recon *= 2.8  # tighten the constraint at each noise level

    # Write x and y (the broadcast update above left them (1, 3, 32, 32)).
    x_np = x.detach().cpu().numpy()[0, :, :, :]
    x_np = np.clip(x_np, 0, 1)
    cv2.imwrite("x.png",
                (x_np * 255).astype(np.uint8).transpose(1, 2, 0)[:, :, ::-1])
    y_np = y.detach().cpu().numpy()[0, :, :, :]
    y_np = np.clip(y_np, 0, 1)
    cv2.imwrite("y.png",
                (y_np * 255).astype(np.uint8).transpose(1, 2, 0)[:, :, ::-1])
    cv2.imwrite(
        "out_mixed.png",
        (y_np * 127.5).astype(np.uint8).transpose(1, 2, 0)[:, :, ::-1] +
        (x_np * 127.5).astype(np.uint8).transpose(1, 2, 0)[:, :, ::-1])
def test(self):
    """Denoise a noisy CelebA test image with annealed Langevin dynamics
    on a pretrained NCSN.

    Adds Gaussian noise (sigma = 0.1) to the first CelebA test sample, then
    samples x so that it follows the score-model prior while staying close
    to the noisy observation. Writes input_image.png,
    input_image_noisy.png and x.png to the working directory.

    Fixes vs. original: ``np.float`` (removed in NumPy >= 1.24) replaced by
    the behavior-identical ``np.float64``; leftover ``pdb.set_trace()``
    debugger stop removed.
    """
    # Load the score network from the latest checkpoint.
    states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                        map_location=self.config.device)
    scorenet = CondRefineNetDilated(self.config).to(self.config.device)
    scorenet = torch.nn.DataParallel(scorenet)
    scorenet.load_state_dict(states[0])
    scorenet.eval()

    # Grab the first sample from the CelebA test split.
    dataset = CelebA(os.path.join(self.args.run, 'datasets', 'celeba'),
                     split='test', download=True)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    # HWC uint8 -> CHW float64 in [0, 255].
    input_image = np.array(dataset[0][0]).astype(np.float64).transpose(
        2, 0, 1)
    # input_image = cv2.imread("/projects/grail/vjayaram/source_separation/ncsn/run/datasets/celeba/celeba/img_align_celeba/012690.jpg")
    # input_image = cv2.resize(input_image, (32, 32))[:,:,::-1].transpose(2, 0, 1)
    input_image = input_image / 255.
    noise = np.random.randn(*input_image.shape) / 10

    cv2.imwrite("input_image.png", (input_image * 255).astype(
        np.uint8).transpose(1, 2, 0)[:, :, ::-1])
    input_image += noise
    input_image = np.clip(input_image, 0, 1)
    cv2.imwrite("input_image_noisy.png", (input_image * 255).astype(
        np.uint8).transpose(1, 2, 0)[:, :, ::-1])

    input_image = torch.Tensor(input_image).cuda()
    # Random initial guess for the clean image.
    x = nn.Parameter(torch.Tensor(3, 32, 32).uniform_()).cuda()

    step_lr = 0.00002
    # Noise amounts (geometric schedule from 1.0 down to 0.01).
    sigmas = np.array([
        1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
        0.04641589, 0.02782559, 0.01668101, 0.01
    ])
    n_steps_each = 100
    lambda_recon = 1.5  # Weight to put on reconstruction error vs p(x)
    for idx, sigma in enumerate(sigmas):
        # Conditioning label selects the noise level inside the network.
        labels = torch.ones(1, device=x.device) * idx
        labels = labels.long()
        step_size = step_lr * (sigma / sigmas[-1])**2
        for step in range(n_steps_each):
            noise_x = torch.randn_like(x) * np.sqrt(step_size * 2)
            grad_x = scorenet(x.view(1, 3, 32, 32), labels).detach()
            # Soft constraint: stay close to the noisy observation.
            recon_loss = (torch.norm(torch.flatten(input_image - x))**2)
            print(recon_loss)
            recon_grads = torch.autograd.grad(recon_loss, [x])
            # Langevin step: score ascent + data-fit descent + noise.
            x = x + (step_size * grad_x) + (
                -step_size * lambda_recon * recon_grads[0].detach()) + noise_x
        lambda_recon *= 1.6  # tighten the constraint at each noise level

    # Write x (the broadcast update above left it (1, 3, 32, 32)).
    x_np = x.detach().cpu().numpy()[0, :, :, :]
    x_np = np.clip(x_np, 0, 1)
    cv2.imwrite("x.png",
                (x_np * 255).astype(np.uint8).transpose(1, 2, 0)[:, :, ::-1])
    # y_np = y.detach().cpu().numpy()[0,:,:,:]
    # y_np = np.clip(y_np, 0, 1)
    # cv2.imwrite("y.png", (y_np * 255).astype(np.uint8).transpose(1, 2, 0)[:,:,::-1])
    # cv2.imwrite("out_mixed.png", (y_np * 127.5).astype(np.uint8).transpose(1, 2, 0)[:,:,::-1] + (x_np * 127.5).astype(np.uint8).transpose(1, 2, 0)[:,:,::-1])
def test(self):
    """iGM 6-channel colorization/reconstruction experiment.

    For every ground-truth PNG, reads a previously reconstructed RGB image
    (iGM-3C output), builds a 6-channel target (BGR + YCrCb), and runs
    annealed Langevin dynamics on a 6-channel latent whose per-channel
    means are constrained to match the grayscale of the target. Tracks the
    best PSNR/SSIM per image into ``result_all`` (row ``length`` holds the
    running averages) and writes the best reconstruction plus a final RGB
    grid via ``self.write_images``.
    """
    # Load the score network from the latest checkpoint.
    states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                        map_location=self.config.device)
    scorenet = CondRefineNetDilated(self.config).to(self.config.device)
    scorenet = torch.nn.DataParallel(scorenet)
    scorenet.load_state_dict(states[0])
    scorenet.eval()

    batch_size = 1
    samples = 2  # two parallel chains, averaged every step below
    # NOTE(review): directory name "ground turth" is spelled this way on
    # disk — do not "fix" without renaming the data directory.
    files_list = glob.glob('./ground turth/*.png')
    files_list = natsorted(files_list)
    length = len(files_list)
    # Column 0: PSNR, column 1: SSIM; row `length` holds averages.
    result_all = np.zeros([101, 2])

    for z, file_path in enumerate(files_list):
        img = cv2.imread(file_path)
        # Companion reconstruction produced by the 3-channel model.
        img2 = cv2.imread('./iGM-3C/img_{}_Rec_x_end_rgb.png'.format(z))
        img = cv2.resize(img, (128, 128))
        YCbCrimg2 = cv2.cvtColor(img2, cv2.COLOR_BGR2YCrCb)

        x0 = img.copy()
        # 6-channel stack: BGR + YCrCb of the companion image.
        x1 = np.concatenate((img2, YCbCrimg2), 2)
        original_image = img.copy()
        # HWC uint8 -> NCHW float in [0, 1].
        x0 = torch.tensor(x0.transpose(2, 0, 1),
                          dtype=torch.float).unsqueeze(0) / 255.0
        x1 = torch.tensor(x1.transpose(2, 0, 1),
                          dtype=torch.float).unsqueeze(0) / 255.0

        # Replicate x0 across the sample (chain) dimension.
        x_stack = torch.zeros(
            [x0.shape[0] * samples, x0.shape[1], x0.shape[2], x0.shape[3]],
            dtype=torch.float32)
        for i in range(samples):
            x_stack[i * batch_size:(i + 1) * batch_size, ...] = x0
        x0 = x_stack

        # Grayscale targets: mean over 3 channels (GT) and 6 channels (aux).
        gray = (x0[:, 0, ...] + x0[:, 1, ...] + x0[:, 2, ...]).cuda() / 3.0
        gray1 = (x1[:, 0, ...] + x1[:, 1, ...] + x1[:, 2, ...] +
                 x1[:, 3, ...] + x1[:, 4, ...] + x1[:, 5, ...]).cuda() / 6.0
        gray_mixed = torch.stack([gray, gray, gray], dim=1)
        gray_mixed_1 = torch.stack(
            [gray1, gray1, gray1, gray1, gray1, gray1], dim=1)

        # Random 6-channel initial latent in [-1, 1].
        x0 = nn.Parameter(
            torch.Tensor(samples * batch_size, 6, x0.shape[2],
                         x0.shape[3]).uniform_(-1, 1)).cuda()
        x01 = x0.clone()

        step_lr = 0.0003 * 0.04  #bedroom 0.04 church 0.02
        # Noise amounts (geometric schedule from 1.0 down to 0.01).
        sigmas = np.array([
            1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
            0.04641589, 0.02782559, 0.01668101, 0.01
        ])
        n_steps_each = 100
        max_psnr = 0
        max_ssim = 0
        for idx, sigma in enumerate(sigmas):
            # Constraint weight grows as the noise level shrinks.
            lambda_recon = 1. / sigma**2
            labels = torch.ones(1, device=x0.device) * idx
            labels = labels.long()
            step_size = step_lr * (sigma / sigmas[-1])**2
            print('sigma = {}'.format(sigma))
            for step in range(n_steps_each):
                print('current step %03d iter' % step)
                # Current grayscale of the first 3 / all 6 channels.
                x0_mix = (x01[:, 0, ...] + x01[:, 1, ...] +
                          x01[:, 2, ...]) / 3.0
                x1_mix = (x01[:, 0, ...] + x01[:, 1, ...] + x01[:, 2, ...] +
                          x01[:, 3, ...] + x01[:, 4, ...] +
                          x01[:, 5, ...]) / 6.0
                error = torch.stack([x0_mix, x0_mix, x0_mix],
                                    dim=1) - gray_mixed
                error1 = torch.stack(
                    [x1_mix, x1_mix, x1_mix, x1_mix, x1_mix, x1_mix],
                    dim=1) - gray_mixed_1
                noise_x = torch.randn_like(x01) * np.sqrt(step_size * 2)
                grad_x0 = scorenet(x01, labels).detach()
                # Score ascent, then the two grayscale-consistency pulls.
                x0 = x01 + step_size * (grad_x0)
                x0 = x0 - 0.1 * step_size * lambda_recon * error1  #bedroom 0.1 church 1.5
                x0[:, 0:3, ...] = x0[:, 0:3, ...] - step_size * \
                    lambda_recon * (error)
                # Average the two chains, then duplicate back to 2 samples.
                x0 = torch.mean(x0, dim=0)
                x0 = torch.stack([x0, x0], dim=0)
                x01 = x0.clone() + noise_x

                # Evaluate: fuse BGR half with the YCrCb half converted back.
                x_rec = x0.clone().detach().cpu().numpy().transpose(
                    0, 2, 3, 1)
                for j in range(x_rec.shape[0]):
                    # NOTE(review): x_rec is reassigned inside this loop, so
                    # after j == 0 it is the fused (..., 3)-channel array —
                    # the j == 1 iteration slices that result, not the
                    # original 6-channel tensor. Confirm this is intended.
                    x_rec_ = np.squeeze(x_rec[j, ...])
                    x_rec_ycbcr2rgb = cv2.cvtColor(x_rec_[..., 3:],
                                                   cv2.COLOR_YCrCb2BGR)
                    x_rec_ycbcr2rgb = np.clip(x_rec_ycbcr2rgb, 0, 1)
                    x_rec_ycbcr2rgb = x_rec_ycbcr2rgb[np.newaxis, ...]
                    x_rec = (x_rec[..., :3] + x_rec_ycbcr2rgb) / 2
                original_image = np.array(original_image, dtype=np.float32)
                for i in range(x_rec.shape[0]):
                    psnr = compare_psnr(x_rec[i, ...] * 255.0,
                                        original_image,
                                        data_range=255)
                    ssim = compare_ssim(x_rec[i, ...],
                                        original_image / 255.0,
                                        data_range=1,
                                        multichannel=True)
                    print("current {} step".format(step), 'PSNR :', psnr,
                          'SSIM :', ssim)
                    if max_psnr < psnr:
                        # New best PSNR: record it and save the image.
                        result_all[z, 0] = psnr
                        max_psnr = psnr
                        cv2.imwrite(
                            os.path.join(
                                self.args.image_folder,
                                'img_{}_Rec_6ch_finally.png'.format(z)),
                            (x_rec[i, ...] * 256.0).clip(0, 255).astype(
                                np.uint8))
                        result_all[length, 0] = sum(
                            result_all[:length, 0]) / length
                    if max_ssim < ssim:
                        result_all[z, 1] = ssim
                        max_ssim = ssim
                        result_all[length, 1] = sum(
                            result_all[:length, 1]) / length
                    write_Data(result_all, z)

        # Final save: fuse channels and reorder BGR -> RGB for the grid.
        x_save = x0.clone().detach().cpu().numpy().transpose(0, 2, 3, 1)
        for j in range(x_save.shape[0]):
            # NOTE(review): same in-loop reassignment pattern as x_rec above.
            x_save_ = np.squeeze(x_save[j, ...])
            print(np.max(x_save_), np.min(x_save_))
            x_save_ycbcr2rgb = cv2.cvtColor(x_save_[..., 3:],
                                            cv2.COLOR_YCrCb2BGR)
            print(np.max(x_save_ycbcr2rgb), np.min(x_save_ycbcr2rgb))
            x_save_ycbcr2rgb = torch.tensor(x_save_ycbcr2rgb)
            x_save_ycbcr2rgb = torch.unsqueeze(x_save_ycbcr2rgb, 0)
            x_save_ycbcr2rgb = np.array(x_save_ycbcr2rgb)
            x_save = (x_save[..., :3] + x_save_ycbcr2rgb) / 2
        x_save = np.array(x_save).transpose(0, 3, 1, 2)
        x_save_R = x_save[:, 2:3, :, :]
        x_save_G = x_save[:, 1:2, :, :]
        x_save_B = x_save[:, 0:1, :, :]
        x_save = np.concatenate((x_save_R, x_save_G, x_save_B), 1)
        self.write_images(
            torch.tensor(x_save).detach().cpu(), 'x_end.png', 1, z)
def test(self):
    """iGM 3-channel colorization experiment.

    For every ground-truth PNG, constrains a random RGB latent so its
    channel mean matches the grayscale of the target while following the
    NCSN score prior (annealed Langevin dynamics). Tracks best PSNR/SSIM
    per image in ``result_all`` (row ``length`` holds running averages),
    saves the per-step best sample via ``savemat``, and writes a final RGB
    grid via ``self.write_images``.
    """
    # Load the score network from the latest checkpoint.
    states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                        map_location=self.config.device)
    scorenet = CondRefineNetDilated(self.config).to(self.config.device)
    scorenet = torch.nn.DataParallel(scorenet)
    scorenet.load_state_dict(states[0])
    scorenet.eval()

    batch_size = 1
    samples = 1
    image_count = 2  # NOTE(review): unused below
    # NOTE(review): directory name "ground turth" is spelled this way on
    # disk — do not "fix" without renaming the data directory.
    files_list = glob.glob('./ground turth/*.png')
    files_list = natsorted(files_list)
    length = len(files_list)
    # Column 0: PSNR, column 1: SSIM; row `length` holds averages.
    result_all = np.zeros([101, 2])

    for z, file_path in enumerate(files_list):
        x0 = cv2.imread(file_path)
        x0 = cv2.resize(x0, (128, 128))
        original_image = x0.copy()
        # HWC uint8 -> NCHW float in [0, 1].
        x0 = torch.tensor(x0.transpose(2, 0, 1),
                          dtype=torch.float).unsqueeze(0) / 255.0

        # Replicate across the sample dimension (samples == 1 here).
        x_stack = torch.zeros(
            [x0.shape[0] * samples, x0.shape[1], x0.shape[2], x0.shape[3]],
            dtype=torch.float32)
        for i in range(samples):
            x_stack[i * batch_size:(i + 1) * batch_size, ...] = x0
        x0 = x_stack

        # Grayscale target: per-pixel channel mean of the ground truth.
        gray = (
            (x0[:, 0, ...] + x0[:, 1, ...] + x0[:, 2, ...])).cuda() / 3.0
        print(gray.shape)
        gray_mixed = torch.stack([gray, gray, gray], dim=1)
        print(gray_mixed.shape)

        # Random RGB initial latent in [-1, 1].
        x0 = nn.Parameter(
            torch.Tensor(samples * batch_size, 3, x0.shape[2],
                         x0.shape[3]).uniform_(-1, 1)).cuda()
        x01 = x0.clone()
        x0_mix = ((x01[:, 0, ...] + x01[:, 1, ...] + x01[:, 2, ...])) / 3.0
        # NOTE(review): `recon` is computed but never used below.
        recon = (torch.stack([x0_mix, x0_mix, x0_mix], dim=1) -
                 gray_mixed)**2

        step_lr = 0.00003 * 0.1  # bedroom 0.1 church 0.2
        # Noise amounts (geometric schedule from 1.0 down to 0.01).
        sigmas = np.array([
            1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
            0.04641589, 0.02782559, 0.01668101, 0.01
        ])
        n_steps_each = 60
        max_psnr = 0
        max_ssim = 0
        for idx, sigma in enumerate(sigmas):
            # Constraint weight grows as the noise level shrinks.
            lambda_recon = 1. / sigma**2
            labels = torch.ones(1, device=x0.device) * idx
            labels = labels.long()
            step_size = step_lr * (sigma / sigmas[-1])**2
            print('sigma = {}'.format(sigma))
            for step in range(n_steps_each):
                print('current step %03d iter' % step)
                # Grayscale-consistency error of the current latent.
                x0_mix = ((x01[:, 0, ...] + x01[:, 1, ...] +
                           x01[:, 2, ...])) / 3.0
                error = torch.stack([x0_mix, x0_mix, x0_mix],
                                    dim=1) - gray_mixed
                noise_x = torch.randn_like(x01) * np.sqrt(step_size * 2)
                grad_x0 = scorenet(x01, labels).detach()
                # Langevin step: score ascent minus constraint gradient.
                x0 = x01 + step_size * (grad_x0 - lambda_recon *
                                        (error))  # + noise_x
                x01 = x0.clone() + noise_x

                # Per-step evaluation against the ground truth.
                x_rec = x0.clone().detach().cpu().numpy().transpose(
                    0, 2, 3, 1)
                max_result, post = 0, 0
                for i in range(x_rec.shape[0]):
                    psnr = compare_psnr(x_rec[i, ...] * 255.0,
                                        original_image,
                                        data_range=255)
                    ssim = compare_ssim(x_rec[i, ...],
                                        original_image / 255.0,
                                        data_range=1,
                                        multichannel=True)
                    if max_result < psnr:
                        # Track the best sample index for this step.
                        max_result = psnr
                        post = i
                    print("current {} step".format(step), 'PSNR :', psnr,
                          'SSIM :', ssim)
                savemat('./Rec_Best', {'img': x_rec[post, ...]})
                # NOTE(review): psnr/ssim here are from the LAST sample of
                # the loop above, not necessarily the best one.
                if max_psnr < psnr:
                    result_all[z, 0] = psnr
                    max_psnr = psnr
                    result_all[length, 0] = sum(
                        result_all[:length, 0]) / length
                if max_ssim < ssim:
                    result_all[z, 1] = ssim
                    max_ssim = ssim
                    result_all[length, 1] = sum(
                        result_all[:length, 1]) / length
                write_Data(result_all, z)

        # Final save: reorder BGR -> RGB channels for the output grid.
        x_save = x0.clone().detach().cpu()
        x_save = np.array(x_save)
        x_save_R = x_save[:, 2:3, :, :]
        x_save_G = x_save[:, 1:2, :, :]
        x_save_B = x_save[:, 0:1, :, :]
        x_save = np.concatenate((x_save_R, x_save_G, x_save_B), 1)
        self.write_images(
            torch.tensor(x_save).detach().cpu(), 'x_end_rgb.png', samples,
            z)
def test(self):
    """Separate a manually supplied pair of 32x32 RGB images with annealed
    Langevin dynamics on a pretrained NCSN.

    Loads the two images via ``get_images_manual()``, mixes them by
    summation, then samples estimates x and y that each follow the score
    prior while x + y reconstructs the mixture. Writes mixed/gt/x/y grids
    into SAVE_DIR, resolving the (x, y) <-> (gt1, gt2) assignment by PSNR.
    """
    # Load the score network from the latest checkpoint.
    states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                        map_location=self.config.device)
    scorenet = CondRefineNetDilated(self.config).to(self.config.device)
    scorenet = torch.nn.DataParallel(scorenet)
    scorenet.load_state_dict(states[0])
    scorenet.eval()

    image1, image2 = get_images_manual()
    # Mixture is a plain sum; halved only when written as an image.
    mixed = (image1 + image2).float()

    curr_dir = SAVE_DIR
    if not os.path.exists(curr_dir):
        os.makedirs(curr_dir)
    mixed_grid = make_grid(mixed.detach() / 2., nrow=GRID_SIZE)
    save_image(mixed_grid, os.path.join(curr_dir, "mixed.png"))
    gt1_grid = make_grid(image1, nrow=GRID_SIZE)
    save_image(gt1_grid, os.path.join(curr_dir, "gt1.png"))
    gt2_grid = make_grid(image2, nrow=GRID_SIZE)
    save_image(gt2_grid, os.path.join(curr_dir, "gt2.png"))

    mixed = torch.Tensor(mixed).cuda().view(BATCH_SIZE, 3, 32, 32)
    # Random initial guesses for the two separated sources.
    y = nn.Parameter(torch.Tensor(BATCH_SIZE, 3, 32, 32).uniform_()).cuda()
    x = nn.Parameter(torch.Tensor(BATCH_SIZE, 3, 32, 32).uniform_()).cuda()

    step_lr = 0.00002
    # Noise amounts (geometric schedule from 1.0 down to 0.01).
    sigmas = np.array([
        1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
        0.04641589, 0.02782559, 0.01668101, 0.01
    ])
    n_steps_each = 100
    #lambda_recon = 1.5
    for idx, sigma in enumerate(sigmas):
        # Reconstruction weight grows as the noise level shrinks.
        lambda_recon = 1.8 / (sigma**2)
        # Conditioning label selects the noise level inside the network.
        labels = torch.ones(1, device=x.device) * idx
        labels = labels.long()
        step_size = step_lr * (sigma / sigmas[-1])**2
        for step in range(n_steps_each):
            noise_x = torch.randn_like(x) * np.sqrt(step_size * 2)
            noise_y = torch.randn_like(y) * np.sqrt(step_size * 2)
            grad_x = scorenet(x.view(BATCH_SIZE, 3, 32, 32),
                              labels).detach()
            grad_y = scorenet(y.view(BATCH_SIZE, 3, 32, 32),
                              labels).detach()
            # Soft constraint: the two estimates should sum to the mixture.
            recon_loss = (torch.norm(torch.flatten(y + x - mixed))**2)
            print(recon_loss)
            recon_grads = torch.autograd.grad(recon_loss, [x, y])
            # Langevin step: score ascent + reconstruction descent + noise.
            x = x + (step_size * grad_x) + (
                -step_size * lambda_recon * recon_grads[0].detach()) + noise_x
            y = y + (step_size * grad_y) + (
                -step_size * lambda_recon * recon_grads[1].detach()) + noise_y

    x_to_write = torch.Tensor(x.detach().cpu())
    y_to_write = torch.Tensor(y.detach().cpu())
    for idx in range(BATCH_SIZE):
        # PSNR: pick the assignment of (x, y) to (gt1, gt2) scoring higher
        # (the separation is order-ambiguous).
        est1 = psnr(x[idx], image1[idx].cuda()) + psnr(
            y[idx], image2[idx].cuda())
        est2 = psnr(x[idx], image2[idx].cuda()) + psnr(
            y[idx], image1[idx].cuda())
        if est2 > est1:
            x_to_write[idx] = y[idx]
            y_to_write[idx] = x[idx]

    # Write x and y
    x_grid = make_grid(x_to_write, nrow=GRID_SIZE)
    save_image(x_grid, os.path.join(curr_dir, "x.png"))
    y_grid = make_grid(y_to_write, nrow=GRID_SIZE)
    save_image(y_grid, os.path.join(curr_dir, "y.png"))
def train(self):
    """Train the NCSN score network with annealed denoising score matching.

    Builds train/test datasets for the configured dataset (CIFAR10, MNIST,
    CELEBA, SVHN, or NYUv2 — the latter concatenates the depth map onto the
    RGB channels), then runs the training loop: uniform dequantization,
    optional logit transform, random noise-level labels, DSM or sliced-SM
    loss, tensorboard logging, periodic test-loss evaluation and
    checkpoint snapshots. Returns 0 once ``n_iters`` steps are reached.
    """
    if self.config.data.random_flip is False:
        tran_transform = test_transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])
    else:
        tran_transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor()
        ])
        test_transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

    # Dataset selection; each branch defines `dataset` and `test_dataset`.
    if self.config.data.dataset == 'CIFAR10':
        dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'),
                          train=True,
                          download=True,
                          transform=tran_transform)
        test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets',
                                            'cifar10_test'),
                               train=False,
                               download=True,
                               transform=test_transform)
    elif self.config.data.dataset == 'MNIST':
        dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'),
                        train=True,
                        download=True,
                        transform=tran_transform)
        test_dataset = MNIST(os.path.join(self.args.run, 'datasets',
                                          'mnist_test'),
                             train=False,
                             download=True,
                             transform=test_transform)
    elif self.config.data.dataset == 'CELEBA':
        if self.config.data.random_flip:
            dataset = CelebA(root=os.path.join(self.args.run, 'datasets',
                                               'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(self.config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=True)
        else:
            dataset = CelebA(root=os.path.join(self.args.run, 'datasets',
                                               'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(self.config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=True)
        test_dataset = CelebA(root=os.path.join(self.args.run, 'datasets',
                                                'celeba_test'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(self.config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=True)
    elif self.config.data.dataset == 'SVHN':
        dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn'),
                       split='train',
                       download=True,
                       transform=tran_transform)
        test_dataset = SVHN(os.path.join(self.args.run, 'datasets',
                                         'svhn_test'),
                            split='test',
                            download=True,
                            transform=test_transform)
    elif self.config.data.dataset == 'NYUv2':
        # NYUv2 crops to 400x400 then resizes to 32; the same transform is
        # applied to both the RGB image and the depth map.
        if self.config.data.random_flip is False:
            nyu_train_transform = nyu_test_transform = transforms.Compose([
                transforms.CenterCrop((400, 400)),
                transforms.Resize(32),
                transforms.ToTensor()
            ])
        else:
            nyu_train_transform = transforms.Compose([
                transforms.CenterCrop((400, 400)),
                transforms.Resize(32),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor()
            ])
            nyu_test_transform = transforms.Compose([
                transforms.CenterCrop((400, 400)),
                transforms.Resize(32),
                transforms.ToTensor()
            ])
        dataset = NYUv2(os.path.join(self.args.run, 'datasets', 'nyuv2'),
                        train=True,
                        download=True,
                        rgb_transform=nyu_train_transform,
                        depth_transform=nyu_train_transform)
        test_dataset = NYUv2(os.path.join(self.args.run, 'datasets',
                                          'nyuv2'),
                             train=False,
                             download=True,
                             rgb_transform=nyu_test_transform,
                             depth_transform=nyu_test_transform)

    dataloader = DataLoader(dataset,
                            batch_size=self.config.training.batch_size,
                            shuffle=True,
                            num_workers=0)  # changed num_workers from 4 to 0
    test_loader = DataLoader(test_dataset,
                             batch_size=self.config.training.batch_size,
                             shuffle=True,
                             num_workers=0,
                             drop_last=True)  # changed num_workers from 4 to 0
    test_iter = iter(test_loader)
    self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

    # Fresh tensorboard run directory (previous run is deleted).
    tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
    if os.path.exists(tb_path):
        shutil.rmtree(tb_path)
    tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)

    score = CondRefineNetDilated(self.config).to(self.config.device)
    score = torch.nn.DataParallel(score)
    optimizer = self.get_optimizer(score.parameters())

    if self.args.resume_training:
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
        score.load_state_dict(states[0])
        optimizer.load_state_dict(states[1])

    step = 0
    # Geometric noise schedule sigma_begin -> sigma_end, one per class label.
    sigmas = torch.tensor(
        np.exp(np.linspace(np.log(self.config.model.sigma_begin),
                           np.log(self.config.model.sigma_end),
                           self.config.model.num_classes))).float().to(
                               self.config.device)

    for epoch in range(self.config.training.n_epochs):
        for i, (X, y) in enumerate(dataloader):
            step += 1
            score.train()
            X = X.to(self.config.device)
            # Uniform dequantization of 8-bit pixel values.
            X = X / 256. * 255. + torch.rand_like(X) / 256.
            if self.config.data.logit_transform:
                X = self.logit_transform(X)

            if self.config.data.dataset == 'NYUv2':
                # concatenate depth map with image
                y = y[0]
                # code to see resized depth map
                # input_gt_depth_image = y[0][0].data.cpu().numpy().astype(np.float32)
                # plot.imsave('gt_depth_map_{}.png'.format(i), input_gt_depth_image,
                #             cmap="viridis")
                y = y.to(self.config.device)
                X = torch.cat((X, y), 1)

            # One random noise-level label per sample.
            labels = torch.randint(0, len(sigmas), (X.shape[0],),
                                   device=X.device)
            if self.config.training.algo == 'dsm':
                loss = anneal_dsm_score_estimation(
                    score, X, labels, sigmas,
                    self.config.training.anneal_power)
            elif self.config.training.algo == 'ssm':
                loss = anneal_sliced_score_estimation_vr(
                    score, X, labels, sigmas,
                    n_particles=self.config.training.n_particles)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tb_logger.add_scalar('loss', loss, global_step=step)
            logging.info("step: {}, loss: {}".format(step, loss.item()))

            if step >= self.config.training.n_iters:
                return 0

            if step % 100 == 0:
                # Periodic evaluation on one held-out batch. NOTE(review):
                # the test loss is always DSM even when training with ssm.
                score.eval()
                try:
                    test_X, test_y = next(test_iter)
                except StopIteration:
                    test_iter = iter(test_loader)
                    test_X, test_y = next(test_iter)
                test_X = test_X.to(self.config.device)
                test_X = test_X / 256. * 255. + torch.rand_like(
                    test_X) / 256.
                if self.config.data.logit_transform:
                    test_X = self.logit_transform(test_X)

                if self.config.data.dataset == 'NYUv2':
                    test_y = test_y[0]
                    test_y = test_y.to(self.config.device)
                    test_X = torch.cat((test_X, test_y), 1)

                test_labels = torch.randint(0, len(sigmas),
                                            (test_X.shape[0],),
                                            device=test_X.device)
                with torch.no_grad():
                    test_dsm_loss = anneal_dsm_score_estimation(
                        score, test_X, test_labels, sigmas,
                        self.config.training.anneal_power)
                tb_logger.add_scalar('test_dsm_loss', test_dsm_loss,
                                     global_step=step)

            if step % self.config.training.snapshot_freq == 0:
                # Save both a step-tagged snapshot and the rolling latest.
                states = [
                    score.state_dict(),
                    optimizer.state_dict(),
                ]
                torch.save(states,
                           os.path.join(self.args.log,
                                        'checkpoint_{}.pth'.format(step)))
                torch.save(states,
                           os.path.join(self.args.log, 'checkpoint.pth'))
def train(self):
    """Train a CondRefineNetDilated score network with annealed noise levels.

    Builds the dataset/dataloader pair named by ``self.config.data.dataset``,
    constructs the score network and optimizer (optionally resuming from
    ``checkpoint.pth``), then runs the annealed score-matching loop selected by
    ``self.config.training.algo``, logging loss/step-time to TensorBoard and
    snapshotting checkpoints late in training.

    Returns:
        0 once ``self.config.training.n_iters`` optimization steps are done.

    Raises:
        ValueError: if ``self.config.training.algo`` is not a supported
            objective (previously this crashed later with a confusing
            ``NameError`` on ``loss``).
    """
    # Train/test transforms; horizontal flip augmentation is train-only.
    if self.config.data.random_flip is False:
        tran_transform = test_transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])
    else:
        tran_transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor()
        ])
        test_transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

    # Select the dataset pair named by the config.
    if self.config.data.dataset == 'CIFAR10':
        dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'),
                          train=True, download=True, transform=tran_transform)
        test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10_test'),
                               train=False, download=True, transform=test_transform)
    elif self.config.data.dataset == 'MNIST':
        dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'),
                        train=True, download=True, transform=tran_transform)
        test_dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist_test'),
                             train=False, download=True, transform=test_transform)
    elif self.config.data.dataset == 'CELEBA':
        if self.config.data.random_flip:
            dataset = CelebA(
                root=os.path.join(self.args.run, 'datasets', 'celeba'),
                split='train',
                transform=transforms.Compose([
                    transforms.CenterCrop(140),
                    transforms.Resize(self.config.data.image_size),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ]), download=False)
        else:
            dataset = CelebA(
                root=os.path.join(self.args.run, 'datasets', 'celeba'),
                split='train',
                transform=transforms.Compose([
                    transforms.CenterCrop(140),
                    transforms.Resize(self.config.data.image_size),
                    transforms.ToTensor(),
                ]), download=False)
        test_dataset = CelebA(
            root=os.path.join(self.args.run, 'datasets', 'celeba_test'),
            split='test',
            transform=transforms.Compose([
                transforms.CenterCrop(140),
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor(),
            ]), download=False)
    elif self.config.data.dataset == 'SVHN':
        dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn'),
                       split='train', download=True, transform=tran_transform)
        test_dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn_test'),
                            split='test', download=True, transform=test_transform)

    dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size,
                            shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size,
                             shuffle=True, num_workers=4, drop_last=True)
    test_iter = iter(test_loader)  # retained for a periodic-eval hook; currently unused
    self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

    # Start each experiment with a fresh TensorBoard run directory.
    tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
    if os.path.exists(tb_path):
        shutil.rmtree(tb_path)
    tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)

    score = CondRefineNetDilated(self.config).to(self.config.device)
    score = torch.nn.DataParallel(score)
    optimizer = self.get_optimizer(score.parameters())

    if self.args.resume_training:
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
        score.load_state_dict(states[0])
        optimizer.load_state_dict(states[1])

    step = 0
    # Geometric (log-linear) schedule of noise levels, sigma_begin -> sigma_end.
    sigmas = torch.tensor(
        np.exp(np.linspace(np.log(self.config.model.sigma_begin),
                           np.log(self.config.model.sigma_end),
                           self.config.model.num_classes))).float().to(self.config.device)
    time_record = []

    for epoch in range(self.config.training.n_epochs):
        for i, (X, y) in enumerate(dataloader):
            step += 1
            score.train()
            X = X.to(self.config.device)
            # Uniform dequantization: map 8-bit pixel values to continuous [0, 1).
            X = X / 256. * 255. + torch.rand_like(X) / 256.
            if self.config.data.logit_transform:
                X = self.logit_transform(X)

            # One random noise-level index per sample.
            labels = torch.randint(0, len(sigmas), (X.shape[0],), device=X.device)

            # Timer hoisted out of the algo branches (was duplicated in each);
            # measures loss computation + backward + optimizer step.
            t = time.time()
            if self.config.training.algo == 'dsm':
                loss = anneal_dsm_score_estimation(
                    score, X, labels, sigmas, self.config.training.anneal_power)
            elif self.config.training.algo == 'dsm_tracetrick':
                loss = anneal_dsm_score_estimation_TraceTrick(
                    score, X, labels, sigmas, self.config.training.anneal_power)
            elif self.config.training.algo == 'ssm':
                loss = anneal_sliced_score_estimation_vr(
                    score, X, labels, sigmas,
                    n_particles=self.config.training.n_particles)
            elif self.config.training.algo == 'esm_scorenet':
                loss = anneal_ESM_scorenet(
                    score, X, labels, sigmas,
                    n_particles=self.config.training.n_particles)
            elif self.config.training.algo == 'esm_scorenet_VR':
                loss = anneal_ESM_scorenet_VR(
                    score, X, labels, sigmas,
                    n_particles=self.config.training.n_particles)
            else:
                # Fail fast instead of the NameError the fall-through used to cause.
                raise ValueError(
                    'Unknown training algo: {}'.format(self.config.training.algo))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            t = time.time() - t
            time_record.append(t)

            if step >= self.config.training.n_iters:
                return 0
            if step % 100 == 0:
                tb_logger.add_scalar('loss', loss, global_step=step)
                logging.info(
                    "step: {}, loss: {}, time per step: {:.3f} +- {:.3f} ms"
                    .format(step, loss.item(),
                            np.mean(time_record) * 1e3,
                            np.std(time_record) * 1e3))

            # Checkpoints are only written late in training (>= 140k steps).
            if step >= 140000 and step % self.config.training.snapshot_freq == 0:
                states = [
                    score.state_dict(),
                    optimizer.state_dict(),
                ]
                torch.save(states,
                           os.path.join(self.args.log, 'checkpoint_{}.pth'.format(step)))
                torch.save(states, os.path.join(self.args.log, 'checkpoint.pth'))
def train(self):
    """Train the score network on a local image folder via ``LoadDataset``.

    Builds the CondRefineNetDilated score network and optimizer (optionally
    resuming from ``checkpoint.pth``), then runs annealed score matching
    ('dsm' or 'ssm'), logging every step to TensorBoard and saving periodic
    checkpoints.

    Returns:
        0 once ``self.config.training.n_iters`` optimization steps are done.
    """
    # Train/test transforms; horizontal flip augmentation is train-only.
    if self.config.data.random_flip is False:
        tran_transform = test_transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])
    else:
        tran_transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor()
        ])
        test_transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

    # NOTE(review): 'ground turth' looks like a typo of 'ground truth', but the
    # string must match the on-disk directory name -- confirm before renaming.
    dataset = LoadDataset('./ground turth', tran_transform)
    dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size,
                            shuffle=True, num_workers=4)
    self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

    # Start each experiment with a fresh TensorBoard run directory.
    tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
    if os.path.exists(tb_path):
        shutil.rmtree(tb_path)
    tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)

    score = CondRefineNetDilated(self.config).to(self.config.device)
    score = torch.nn.DataParallel(score)
    optimizer = self.get_optimizer(score.parameters())

    if self.args.resume_training:
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
        score.load_state_dict(states[0])
        optimizer.load_state_dict(states[1])

    step = 0
    # Geometric (log-linear) schedule of noise levels, sigma_begin -> sigma_end.
    sigmas = torch.tensor(
        np.exp(np.linspace(np.log(self.config.model.sigma_begin),
                           np.log(self.config.model.sigma_end),
                           self.config.model.num_classes))).float().to(self.config.device)

    for epoch in range(self.config.training.n_epochs):
        for i, X in enumerate(dataloader):
            # as_tensor avoids the "copy-construct from a tensor" UserWarning
            # and a needless copy when the loader already yields float32
            # tensors (torch.tensor() always copies).
            X = torch.as_tensor(X, dtype=torch.float32)
            step += 1
            score.train()
            X = X.to(self.config.device)
            # Uniform dequantization: map 8-bit pixel values to continuous [0, 1).
            X = X / 256. * 255. + torch.rand_like(X) / 256.
            if self.config.data.logit_transform:
                X = self.logit_transform(X)

            # One random noise-level index per sample.
            labels = torch.randint(0, len(sigmas), (X.shape[0],), device=X.device)
            if self.config.training.algo == 'dsm':
                loss = anneal_dsm_score_estimation(
                    score, X, labels, sigmas, self.config.training.anneal_power)
            elif self.config.training.algo == 'ssm':
                loss = anneal_sliced_score_estimation_vr(
                    score, X, labels, sigmas,
                    n_particles=self.config.training.n_particles)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tb_logger.add_scalar('loss', loss, global_step=step)
            logging.info("step: {}, loss: {}".format(step, loss.item()))

            if step >= self.config.training.n_iters:
                return 0
            if step % self.config.training.snapshot_freq == 0:
                states = [
                    score.state_dict(),
                    optimizer.state_dict(),
                ]
                torch.save(states,
                           os.path.join(self.args.log, 'checkpoint_{}.pth'.format(step)))
                torch.save(states, os.path.join(self.args.log, 'checkpoint.pth'))
def test(self):
    """Colorize random CIFAR-10 test images via annealed Langevin dynamics.

    Loads the score network from ``checkpoint.pth``, then for each of 500
    iterations: samples a random batch of test images, saves the ground-truth
    color images and their grayscale projection, and optimizes a color image
    ``x`` whose channel-mean matches the grayscale target while following the
    learned score at decreasing noise levels. Outputs are written under
    ``SAVE_DIR/<iteration>/`` (gt.png, grayscale.png, x.png).
    """
    # Load the score network
    states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                        map_location=self.config.device)
    scorenet = CondRefineNetDilated(self.config).to(self.config.device)
    scorenet = torch.nn.DataParallel(scorenet)
    scorenet.load_state_dict(states[0])
    scorenet.eval()

    # CIFAR-10 test images, NHWC uint8 -> NCHW.
    dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'),
                      train=False, download=True)
    data = dataset.data.transpose(0, 3, 1, 2)

    for iteration in range(500):
        print("Iteration {}".format(iteration))
        curr_dir = os.path.join(SAVE_DIR, "{:07d}".format(iteration))
        if not os.path.exists(curr_dir):
            os.makedirs(curr_dir)

        # randint's upper bound is exclusive; the original `data.shape[0] - 1`
        # could never select the last image.
        rand_idx = np.random.randint(0, data.shape[0], BATCH_SIZE)
        # np.float was removed in NumPy 1.24 (AttributeError at runtime);
        # float32 matches the .float() cast that follows anyway.
        image = torch.tensor(data[rand_idx, :].astype(np.float32) / 255.).float()

        # GT color images
        image_grid = make_grid(image, nrow=GRID_SIZE)
        save_image(image_grid, os.path.join(curr_dir, "gt.png"))

        # Grayscale target = per-pixel mean over the three color channels.
        image_gray = image.mean(1).view(BATCH_SIZE, 1, 32, 32)
        image_grid = make_grid(image_gray, nrow=GRID_SIZE)
        save_image(image_grid, os.path.join(curr_dir, "grayscale.png"))
        image_gray = image_gray.cuda()

        # Optimization variable: a random color image with requires_grad=True.
        x = nn.Parameter(torch.Tensor(BATCH_SIZE, 3, 32, 32).uniform_()).cuda()
        step_lr = 0.00002

        # Noise amounts
        sigmas = np.array([1., 0.59948425, 0.35938137, 0.21544347, 0.12915497,
                           0.07742637, 0.04641589, 0.02782559, 0.01668101, 0.01])
        n_steps_each = 100

        for idx, sigma in enumerate(sigmas):
            # Weight to put on reconstruction error vs p(x)
            lambda_recon = 1.5 / (sigma ** 2)
            # Conditioning label selecting the current noise level.
            labels = torch.ones(1, device=x.device) * idx
            labels = labels.long()

            step_size = step_lr * (sigma / sigmas[-1]) ** 2
            for step in range(n_steps_each):
                noise_x = torch.randn_like(x) * np.sqrt(step_size * 2)
                grad_x = scorenet(x, labels).detach()
                recon_loss = (torch.norm(torch.flatten(
                    x.mean(1).view(BATCH_SIZE, 1, 32, 32) - image_gray)) ** 2)
                print(recon_loss)
                recon_grads = torch.autograd.grad(recon_loss, [x])
                # Langevin update: follow the score, pull toward the grayscale
                # target, and inject Gaussian noise.
                x = x + (step_size * grad_x) \
                    + (-step_size * lambda_recon * recon_grads[0].detach()) \
                    + noise_x
                # Detach so the autograd graph does not grow across Langevin
                # steps (the update above was recorded every step, retaining
                # all intermediate tensors for the whole anneal).
                x = x.detach().requires_grad_(True)

        # Write x
        image_grid = make_grid(x, nrow=GRID_SIZE)
        save_image(image_grid, os.path.join(curr_dir, "x.png"))
        # Removed leftover `pdb.set_trace()` debug trap that halted every
        # iteration of this loop.