def iagan_recover(x,
                  gen,
                  forward_model,
                  optimizer_type,
                  mode='clamped_normal',
                  limit=1,
                  z_lr1=1e-4,
                  z_lr2=1e-5,
                  model_lr=1e-5,
                  z_steps1=1600,
                  z_steps2=3000,
                  restarts=1,
                  run_dir=None,
                  run_name=None,
                  disable_tqdm=False,
                  **kwargs):
    best_psnr = -float("inf")
    best_return_val = None

    for i in trange(restarts, desc='Restarts', leave=False,
                    disable=disable_tqdm):
        if run_name is not None:
            current_run_name = f'{run_name}_{i}'
        else:
            current_run_name = None
        return_val = _iagan_recover(x=x,
                                    gen=gen,
                                    forward_model=forward_model,
                                    optimizer_type=optimizer_type,
                                    mode=mode,
                                    limit=limit,
                                    z_lr1=z_lr1,
                                    z_lr2=z_lr2,
                                    model_lr=model_lr,
                                    z_steps1=z_steps1,
                                    z_steps2=z_steps2,
                                    run_dir=run_dir,
                                    run_name=current_run_name,
                                    disable_tqdm=disable_tqdm,
                                    **kwargs)

        p = psnr_from_mse(return_val[2])
        if p > best_psnr:
            best_psnr = p
            best_return_val = return_val

    return best_return_val
def recover(x,
            gen,
            optimizer_type,
            n_cuts,
            forward_model,
            mode='clamped_normal',
            limit=1,
            z_lr=0.5,
            n_steps=2000,
            restarts=1,
            run_dir=None,
            run_name=None,
            disable_tqdm=False,
            return_z1_z2=False,
            **kwargs):
    best_psnr = -float("inf")
    best_return_val = None

    for i in trange(restarts, desc='Restarts', leave=False,
                    disable=disable_tqdm):
        if run_name is not None:
            current_run_name = f'{run_name}_{i}'
        else:
            current_run_name = None
        return_val = _recover(x=x,
                              gen=gen,
                              optimizer_type=optimizer_type,
                              n_cuts=n_cuts,
                              forward_model=forward_model,
                              mode=mode,
                              limit=limit,
                              z_lr=z_lr,
                              n_steps=n_steps,
                              run_dir=run_dir,
                              run_name=current_run_name,
                              disable_tqdm=disable_tqdm,
                              return_z1_z2=return_z1_z2,
                              **kwargs)

        p = psnr_from_mse(return_val[2])
        if p > best_psnr:
            best_psnr = p
            best_return_val = return_val

    return best_return_val
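# Example usage of `recover` (a minimal sketch, not part of the original
# file; how `gen` and the forward model are constructed is an assumption --
# only the call signature below comes from this module):
#
#   gen = ...  # generator with checkpoint weights already loaded
#   forward_model = ...  # e.g. a GaussianCompressiveSensing instance
#   x_hat, y, best_mse = recover(x,
#                                gen,
#                                optimizer_type='lbfgs',
#                                n_cuts=2,
#                                forward_model=forward_model,
#                                z_lr=0.5,
#                                n_steps=25,
#                                restarts=3)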
def deep_decoder_recover(x,
                         forward_model,
                         optimizer='lbfgs',
                         num_filters=64,
                         depth=6,  # TODO
                         lr=1,
                         img_size=64,
                         steps=50,
                         restarts=1,
                         run_dir=None,
                         run_name=None,
                         disable_tqdm=False,
                         **kwargs):
    best_psnr = -float("inf")
    best_return_val = None

    for i in trange(restarts, desc='Restarts', leave=False,
                    disable=disable_tqdm):
        if run_name is not None:
            current_run_name = f'{run_name}_{i}'
        else:
            current_run_name = None
        return_val = _deep_decoder_recover(x=x,
                                           forward_model=forward_model,
                                           optimizer=optimizer,
                                           num_filters=num_filters,
                                           depth=depth,
                                           lr=lr,
                                           img_size=img_size,
                                           steps=steps,
                                           run_dir=run_dir,
                                           run_name=current_run_name,
                                           disable_tqdm=disable_tqdm,
                                           **kwargs)

        p = psnr_from_mse(return_val[2])
        if p > best_psnr:
            best_psnr = p
            best_return_val = return_val

    return best_return_val
def _recover(x,
             gen,
             optimizer_type,
             n_cuts,
             forward_model,
             mode='clamped_normal',
             limit=1,
             z_lr=0.5,
             n_steps=2000,
             run_dir=None,
             run_name=None,
             disable_tqdm=False,
             return_z1_z2=False,
             **kwargs):
    """
    Args:
        x - input image, torch tensor (C x H x W)
        gen - generator, already loaded with checkpoint weights
        forward_model - corrupts the image
        n_steps - number of optimization steps during recovery
        run_name - use None for no logging
    """
    # Keep batch_size = 1
    batch_size = 1

    z1_dim, z2_dim = gen.input_shapes[n_cuts]

    if isinstance(forward_model, GaussianCompressiveSensing):
        # Rescale the measurement-noise level of the 64x64x3 setting of
        # Bora et al. to this image size and number of measurements
        n_pixel_bora = 64 * 64 * 3
        n_pixel = np.prod(x.shape)
        noise = torch.randn(batch_size,
                            forward_model.n_measure,
                            device=x.device)
        noise *= 0.1 * torch.sqrt(
            torch.tensor(n_pixel / forward_model.n_measure / n_pixel_bora))

    if mode == 'lasso_inverse' and isinstance(forward_model,
                                              GaussianCompressiveSensing):
        # Warm-start z1 (and z2) from a recovery of the Lasso-DCT estimate
        lasso_x_hat = recover_dct(x.cpu().numpy().transpose([1, 2, 0]),
                                  forward_model.n_measure,
                                  0.01,
                                  128,
                                  A=forward_model.A.cpu().numpy(),
                                  noise=noise.cpu().numpy())
        _, _, _, z1_z2_dict = recover(
            torch.tensor(lasso_x_hat.transpose([2, 0, 1]),
                         dtype=torch.float).to(DEVICE),
            gen,
            optimizer_type=optimizer_type,
            n_cuts=n_cuts,
            forward_model=forward_model,
            mode='clamped_normal',
            limit=limit,
            z_lr=z_lr,
            n_steps=n_steps,
            restarts=1,
            return_z1_z2=True)

        z1 = torch.nn.Parameter(z1_z2_dict['z1'])
        params = [z1]
        if len(z2_dim) > 0:
            z2 = torch.nn.Parameter(z1_z2_dict['z2'])
            params.append(z2)
        else:
            z2 = None
    else:
        z1 = torch.nn.Parameter(
            get_z_vector((batch_size, *z1_dim),
                         mode=mode,
                         limit=limit,
                         device=x.device))
        params = [z1]
        if len(z2_dim) > 0:
            z2 = torch.nn.Parameter(
                get_z_vector((batch_size, *z2_dim),
                             mode=mode,
                             limit=limit,
                             device=x.device))
            params.append(z2)
        else:
            z2 = None

    if optimizer_type == 'adamw':
        optimizer_z = torch.optim.AdamW(params,
                                        lr=z_lr,
                                        betas=(0.5, 0.999),
                                        weight_decay=0)
        scheduler_z = None
        save_img_every_n = 50
    elif optimizer_type == 'lbfgs':
        optimizer_z = torch.optim.LBFGS(params, lr=z_lr)
        scheduler_z = None
        save_img_every_n = 2
    else:
        raise NotImplementedError()

    if run_name is not None:
        logdir = os.path.join('recovery_tensorboard_logs', run_dir, run_name)
        if os.path.exists(logdir):
            print("Overwriting pre-existing logs!")
            shutil.rmtree(logdir)
        writer = SummaryWriter(logdir)

    # Save original and distorted image
    if run_name is not None:
        writer.add_image("Original/Clamp", x.clamp(0, 1))
        if forward_model.viewable:
            writer.add_image(
                "Distorted/Clamp",
                forward_model(x.unsqueeze(0).clamp(0, 1)).squeeze(0))

    # Recover image under forward model
    x = x.expand(batch_size, *x.shape)
    y_observed = forward_model(x)
    if isinstance(forward_model, GaussianCompressiveSensing):
        y_observed += noise

    for j in trange(n_steps, leave=False, desc='Recovery',
                    disable=disable_tqdm):

        def closure():
            optimizer_z.zero_grad()
            x_hats = gen.forward(z1, z2, n_cuts=n_cuts, **kwargs)
            if gen.rescale:
                x_hats = (x_hats + 1) / 2
            train_mses = F.mse_loss(forward_model(x_hats),
                                    y_observed,
                                    reduction='none')
            train_mses = train_mses.view(batch_size, -1).mean(1)
            train_mse = train_mses.sum()
            train_mse.backward()
            return train_mse

        # Step first, then identify the current "best" and "worst"
        optimizer_z.step(closure)

        with torch.no_grad():
            x_hats = gen.forward(z1, z2, n_cuts=n_cuts, **kwargs)
            if gen.rescale:
                x_hats = (x_hats + 1) / 2
            train_mses = F.mse_loss(forward_model(x_hats),
                                    y_observed,
                                    reduction='none')
            train_mses = train_mses.view(batch_size, -1).mean(1)
            train_mse = train_mses.sum()

            train_mses_clamped = F.mse_loss(
                forward_model(x_hats.detach().clamp(0, 1)),
                y_observed,
                reduction='none').view(batch_size, -1).mean(1)
            orig_mses_clamped = F.mse_loss(
                x_hats.detach().clamp(0, 1), x,
                reduction='none').view(batch_size, -1).mean(1)

        # batch_size = 1, so best and worst are meaningless.
        # Restarts is handled in outer function
        best_train_mse, best_idx = train_mses_clamped.min(0)
        worst_train_mse, worst_idx = train_mses_clamped.max(0)
        best_orig_mse = orig_mses_clamped[best_idx]
        worst_orig_mse = orig_mses_clamped[worst_idx]

        if run_name is not None and j == 0:
            writer.add_image('Start', x_hats[best_idx].clamp(0, 1))

        if run_name is not None:
            writer.add_scalar('TRAIN_MSE/best', best_train_mse, j + 1)
            writer.add_scalar('TRAIN_MSE/worst', worst_train_mse, j + 1)
            writer.add_scalar('TRAIN_MSE/sum', train_mse, j + 1)
            writer.add_scalar('ORIG_MSE/best', best_orig_mse, j + 1)
            writer.add_scalar('ORIG_MSE/worst', worst_orig_mse, j + 1)
            writer.add_scalar('ORIG_PSNR/best', psnr_from_mse(best_orig_mse),
                              j + 1)
            writer.add_scalar('ORIG_PSNR/worst',
                              psnr_from_mse(worst_orig_mse), j + 1)

            if j % save_img_every_n == 0:
                writer.add_image('Recovered/Best',
                                 x_hats[best_idx].clamp(0, 1), j + 1)

        if scheduler_z is not None:
            scheduler_z.step()

    if run_name is not None:
        writer.add_image('Final', x_hats[best_idx].clamp(0, 1))

    if return_z1_z2:
        return x_hats[best_idx], forward_model(x)[0], best_train_mse, {
            'z1': z1,
            'z2': z2
        }
    else:
        return x_hats[best_idx], forward_model(x)[0], best_train_mse
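# Note: `psnr_from_mse` is used throughout this module but defined elsewhere
# in the repo. A minimal sketch of the standard formula, assuming images in
# [0, 1] (consistent with the clamps above):
#
#   def psnr_from_mse(mse):
#       # PSNR = 10 * log10(MAX^2 / MSE), with MAX = 1 for [0, 1] images
#       return 10 * torch.log10(1 / mse)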
def _deep_decoder_recover(x,
                          forward_model,
                          optimizer,
                          num_filters,
                          depth,
                          lr,
                          img_size,
                          steps,
                          run_dir,
                          run_name,
                          disable_tqdm,
                          **kwargs):
    # Keep batch_size = 1
    batch_size = 1

    if isinstance(forward_model, GaussianCompressiveSensing):
        # Rescale the measurement-noise level of the 64x64x3 setting of
        # Bora et al. to this image size and number of measurements
        n_pixel_bora = 64 * 64 * 3
        n_pixel = np.prod(x.shape)
        noise = torch.randn(batch_size,
                            forward_model.n_measure,
                            device=x.device)
        noise *= 0.1 * torch.sqrt(
            torch.tensor(n_pixel / forward_model.n_measure / n_pixel_bora))

    # z is a fixed latent vector. Assuming each of the (depth - 1)
    # upsampling stages doubles the resolution, the starting spatial size
    # is img_size / 2^(depth - 1) (img_size must be a power of two).
    start_imsize = 2 ** (int(np.log2(img_size)) - depth + 1)
    z = torch.randn(batch_size,
                    num_filters,
                    start_imsize,
                    start_imsize,
                    device=x.device)

    # make a fresh DD model for every run
    model = DeepDecoder(num_filters=num_filters,
                        img_size=img_size,
                        depth=depth).to(x.device)

    if optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        save_img_every_n = 50
    elif optimizer == 'lbfgs':
        optimizer = torch.optim.LBFGS(model.parameters(), lr=lr)
        save_img_every_n = 2
    else:
        raise NotImplementedError()

    if run_name is not None:
        logdir = os.path.join('recovery_tensorboard_logs', run_dir, run_name)
        if os.path.exists(logdir):
            print("Overwriting pre-existing logs!")
            shutil.rmtree(logdir)
        writer = SummaryWriter(logdir)
    else:
        writer = None

    # Save original and distorted image
    if run_name is not None:
        writer.add_image("Original/Clamp", x.clamp(0, 1))
        if forward_model.viewable:
            writer.add_image(
                "Distorted/Clamp",
                forward_model(x.unsqueeze(0).clamp(0, 1)).squeeze(0))

    # Make noisy gaussian measurements
    x = x.expand(batch_size, *x.shape)
    y_observed = forward_model(x)
    if isinstance(forward_model, GaussianCompressiveSensing):
        y_observed += noise

    def closure():
        optimizer.zero_grad()
        x_hat = model.forward(z)
        loss = F.mse_loss(forward_model(x_hat), y_observed)
        loss.backward()
        return loss

    for j in trange(steps, desc='Fit', leave=False, disable=disable_tqdm):
        optimizer.step(closure)

        with torch.no_grad():
            x_hat = model.forward(z)
            train_mse_clamped = F.mse_loss(
                forward_model(x_hat.detach().clamp(0, 1)), y_observed)

        if writer is not None:
            writer.add_scalar('TRAIN_MSE', train_mse_clamped, j + 1)
            writer.add_scalar('TRAIN_PSNR', psnr_from_mse(train_mse_clamped),
                              j + 1)

            orig_mse_clamped = F.mse_loss(x_hat.detach().clamp(0, 1), x)
            writer.add_scalar('ORIG_MSE', orig_mse_clamped, j + 1)
            writer.add_scalar('ORIG_PSNR', psnr_from_mse(orig_mse_clamped),
                              j + 1)

            if j % save_img_every_n == 0:
                writer.add_image('Recovered', x_hat.squeeze().clamp(0, 1),
                                 j + 1)

    if writer is not None:
        writer.add_image('Final', x_hat.squeeze().clamp(0, 1))

    return x_hat.squeeze(), forward_model(x).squeeze(), train_mse_clamped
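# Sizing check for the Deep Decoder latent (a worked example under the
# doubling assumption noted above): with img_size=64 and depth=6,
#
#   start_imsize = 2 ** (log2(64) - 6 + 1) = 2 ** 1 = 2
#
# and 2 doubled across the 5 upsampling stages gives
# 2 -> 4 -> 8 -> 16 -> 32 -> 64, matching the target resolution.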
def _iagan_recover(x,
                   gen,
                   forward_model,
                   optimizer_type='adam',
                   mode='clamped_normal',
                   limit=1,
                   z_lr1=1e-4,
                   z_lr2=1e-5,
                   model_lr=1e-5,
                   z_steps1=1600,
                   z_steps2=3000,
                   run_dir=None,  # IAGAN
                   run_name=None,  # datetime or config
                   disable_tqdm=False,
                   **kwargs):
    # Keep batch_size = 1
    batch_size = 1

    z1_dim, z2_dim = gen.input_shapes[0]  # n_cuts = 0

    if isinstance(forward_model, GaussianCompressiveSensing):
        # Rescale the measurement-noise level of the 64x64x3 setting of
        # Bora et al. to this image size and number of measurements
        n_pixel_bora = 64 * 64 * 3
        n_pixel = np.prod(x.shape)
        noise = torch.randn(batch_size,
                            forward_model.n_measure,
                            device=x.device)
        noise *= 0.1 * torch.sqrt(
            torch.tensor(n_pixel / forward_model.n_measure / n_pixel_bora))

    # z1 is the actual latent code.
    # z2 is the additional input for n_cuts logic (not used here)
    z1 = torch.nn.Parameter(
        get_z_vector((batch_size, *z1_dim),
                     mode=mode,
                     limit=limit,
                     device=x.device))
    params = [z1]
    if len(z2_dim) > 0:
        z2 = torch.nn.Parameter(
            get_z_vector((batch_size, *z2_dim),
                         mode=mode,
                         limit=limit,
                         device=x.device))
        params.append(z2)
    else:
        z2 = None

    if optimizer_type == 'adam':
        # Only z1 is optimized; z2 (if present) is held fixed
        optimizer_z = torch.optim.Adam([z1], lr=z_lr1)
        optimizer_model = torch.optim.Adam(gen.parameters(), lr=model_lr)
    else:
        raise NotImplementedError()

    if run_name is not None:
        logdir = os.path.join('recovery_tensorboard_logs', run_dir, run_name)
        if os.path.exists(logdir):
            print("Overwriting pre-existing logs!")
            shutil.rmtree(logdir)
        writer = SummaryWriter(logdir)

    # Save original and distorted image
    if run_name is not None:
        writer.add_image("Original/Clamp", x.clamp(0, 1))
        if forward_model.viewable:
            writer.add_image(
                "Distorted/Clamp",
                forward_model(x.unsqueeze(0).clamp(0, 1)).squeeze(0))

    # Make noisy gaussian measurements
    x = x.expand(batch_size, *x.shape)
    y_observed = forward_model(x)
    if isinstance(forward_model, GaussianCompressiveSensing):
        y_observed += noise

    # Stage 1: optimize latent code only
    save_img_every_n = 50
    for j in trange(z_steps1, desc='Stage1', leave=False,
                    disable=disable_tqdm):
        optimizer_z.zero_grad()
        x_hat = gen.forward(z1, z2, n_cuts=0, **kwargs)
        if gen.rescale:
            x_hat = (x_hat + 1) / 2
        train_mse = F.mse_loss(forward_model(x_hat), y_observed)
        train_mse.backward()
        optimizer_z.step()

        train_mse_clamped = F.mse_loss(
            forward_model(x_hat.detach().clamp(0, 1)), y_observed)
        orig_mse_clamped = F.mse_loss(x_hat.detach().clamp(0, 1), x)

        if run_name is not None and j == 0:
            writer.add_image('Stage1/Start', x_hat.squeeze().clamp(0, 1))

        if run_name is not None:
            writer.add_scalar('Stage1/TRAIN_MSE', train_mse_clamped, j + 1)
            writer.add_scalar('Stage1/ORIG_MSE', orig_mse_clamped, j + 1)
            writer.add_scalar('Stage1/ORIG_PSNR',
                              psnr_from_mse(orig_mse_clamped), j + 1)

            if j % save_img_every_n == 0:
                writer.add_image('Stage1/Recovered',
                                 x_hat.squeeze().clamp(0, 1), j + 1)

    if run_name is not None:
        writer.add_image('Stage1_Final', x_hat.squeeze().clamp(0, 1))

    # Stage 2: optimize latent code and model
    save_img_every_n = 20
    optimizer_z = torch.optim.Adam([z1], lr=z_lr2)
    for j in trange(z_steps2, desc='Stage2', leave=False,
                    disable=disable_tqdm):
        optimizer_z.zero_grad()
        optimizer_model.zero_grad()
        x_hat = gen.forward(z1, z2, n_cuts=0, **kwargs)
        if gen.rescale:
            x_hat = (x_hat + 1) / 2
        train_mse = F.mse_loss(forward_model(x_hat), y_observed)
        train_mse.backward()
        optimizer_z.step()
        optimizer_model.step()

        train_mse_clamped = F.mse_loss(
            forward_model(x_hat.detach().clamp(0, 1)), y_observed)
        orig_mse_clamped = F.mse_loss(x_hat.detach().clamp(0, 1), x)

        if run_name is not None and j == 0:
            writer.add_image('Stage2/Start', x_hat.squeeze().clamp(0, 1))

        if run_name is not None:
            writer.add_scalar('Stage2/TRAIN_MSE', train_mse_clamped, j + 1)
            writer.add_scalar('Stage2/ORIG_MSE', orig_mse_clamped, j + 1)
            writer.add_scalar('Stage2/ORIG_PSNR',
                              psnr_from_mse(orig_mse_clamped), j + 1)

            if j % save_img_every_n == 0:
                writer.add_image('Stage2/Recovered',
                                 x_hat.squeeze().clamp(0, 1), j + 1)

    if run_name is not None:
        writer.add_image('Stage2_Final', x_hat.squeeze().clamp(0, 1))

    return x_hat.squeeze(), forward_model(x).squeeze(), train_mse_clamped
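# The two IAGAN stages above optimize, in order (G = gen, A = forward_model,
# theta = generator weights):
#
#   Stage 1:  min_z           || A(G(z)) - y ||^2    (generator frozen)
#   Stage 2:  min_{z, theta}  || A(G_theta(z)) - y ||^2
#
# i.e. Stage 2 fine-tunes the generator jointly with the latent code,
# warm-started from the Stage 1 solution.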
def _mgan_recover(x,
                  gen,
                  n_cuts,
                  forward_model,
                  optimizer_type='sgd',
                  mode='zero',
                  limit=1,
                  z_lr=1,
                  n_steps=2000,
                  z_number=20,
                  run_dir=None,
                  run_name=None,
                  disable_tqdm=False,
                  **kwargs):
    """
    Args:
        x - input image, torch tensor (C x H x W)
        gen - generator, already loaded with checkpoint weights
        forward_model - corrupts the image
        n_cuts - the intermediate layer to combine z vectors
        n_steps - number of optimization steps during recovery
        run_name - use None for no logging
    """
    z1_dim, _ = gen.input_shapes[0]
    _, z2_dim = gen.input_shapes[n_cuts]

    if isinstance(forward_model, GaussianCompressiveSensing):
        # Rescale the measurement-noise level of the 64x64x3 setting of
        # Bora et al. to this image size and number of measurements
        n_pixel_bora = 64 * 64 * 3
        n_pixel = np.prod(x.shape)
        noise = torch.randn(1, forward_model.n_measure, device=x.device)
        noise *= 0.1 * torch.sqrt(
            torch.tensor(n_pixel / forward_model.n_measure / n_pixel_bora))

    # z_number latent codes, plus per-channel mixing weights alpha
    z1 = torch.nn.Parameter(
        get_z_vector((z_number, *z1_dim),
                     mode=mode,
                     limit=limit,
                     device=x.device))
    alpha = torch.nn.Parameter(
        get_z_vector((z_number, gen.input_shapes[n_cuts][0][0]),
                     mode=mode,
                     limit=limit,
                     device=x.device))
    params = [z1, alpha]
    if len(z2_dim) > 0:
        z2 = torch.nn.Parameter(
            get_z_vector((1, *z2_dim),
                         mode=mode,
                         limit=limit,
                         device=x.device))
        params.append(z2)
    else:
        z2 = None

    if optimizer_type == 'sgd':
        optimizer_z = torch.optim.SGD(params, lr=z_lr)
        scheduler_z = None
        save_img_every_n = 50
    elif optimizer_type == 'adam':
        optimizer_z = torch.optim.Adam(params, lr=z_lr)
        scheduler_z = None
        # scheduler_z = torch.optim.lr_scheduler.CosineAnnealingLR(
        #     optimizer_z, n_steps, 0.05 * z_lr)
        save_img_every_n = 50
    else:
        raise NotImplementedError()

    if run_name is not None:
        logdir = os.path.join('recovery_tensorboard_logs', run_dir, run_name)
        if os.path.exists(logdir):
            print("Overwriting pre-existing logs!")
            shutil.rmtree(logdir)
        writer = SummaryWriter(logdir)

    # Save original and distorted image
    if run_name is not None:
        writer.add_image("Original/Clamp", x.clamp(0, 1))
        if forward_model.viewable:
            writer.add_image(
                "Distorted/Clamp",
                forward_model(x.unsqueeze(0).clamp(0, 1)).squeeze(0))

    # Recover image under forward model
    x = x.expand(1, *x.shape)
    y_observed = forward_model(x)
    if isinstance(forward_model, GaussianCompressiveSensing):
        y_observed += noise

    for j in trange(n_steps, leave=False, desc='Recovery',
                    disable=disable_tqdm):
        optimizer_z.zero_grad()
        # Decode each z_k up to layer n_cuts, blend the feature maps with
        # alpha, then decode the blend through the remaining layers
        F_l = gen.forward(z1, None, n_cuts=0, end=n_cuts, **kwargs)
        F_l_2 = (F_l * alpha[:, :, None, None]).sum(0, keepdim=True)
        x_hats = gen.forward(F_l_2, z2, n_cuts=n_cuts, end=None, **kwargs)
        if gen.rescale:
            x_hats = (x_hats + 1) / 2
        train_mse = F.mse_loss(forward_model(x_hats), y_observed)
        train_mse.backward()
        optimizer_z.step()

        train_mse_clamped = F.mse_loss(
            forward_model(x_hats.detach().clamp(0, 1)), y_observed)
        orig_mse_clamped = F.mse_loss(x_hats.detach().clamp(0, 1), x)

        if run_name is not None and j == 0:
            writer.add_image('Start', x_hats.clamp(0, 1).squeeze(0))

        if run_name is not None:
            writer.add_scalar('TRAIN_MSE', train_mse_clamped, j + 1)
            writer.add_scalar('ORIG_MSE', orig_mse_clamped, j + 1)
            writer.add_scalar('ORIG_PSNR', psnr_from_mse(orig_mse_clamped),
                              j + 1)

            if j % save_img_every_n == 0:
                writer.add_image('Recovered', x_hats.clamp(0, 1).squeeze(0),
                                 j + 1)

        if scheduler_z is not None:
            scheduler_z.step()

    if run_name is not None:
        writer.add_image('Final', x_hats.clamp(0, 1).squeeze(0))

    return x_hats.squeeze(0), forward_model(x)[0], train_mse_clamped
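# _mgan_recover blends z_number latent codes at an intermediate layer, in the
# style of mGANprior: writing G_head for the layers up to n_cuts and G_tail
# for the layers after it, the reconstruction optimized above is
#
#   x_hat = G_tail( sum_k alpha_k * G_head(z_k) )
#
# where each alpha_k weights the channels of the k-th intermediate feature
# map (hence the broadcast alpha[:, :, None, None] in the loop).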