def plot_result(result, gt, iteration):
    """Display ``result`` in grayscale, titled with its quality against ``gt``.

    The title has the form ``'<iteration>: PSNR: <psnr> SSIM: <ssim>'``.
    Both metrics come from the module-level ``PSNR`` / ``SSIM`` helpers.

    :param result: image passed to ``plt.imshow``.
    :param gt: reference image the metrics are computed against.
    :param iteration: number printed at the start of the title (``%d``).
    """
    psnr_value = PSNR(result, gt)
    ssim_value = SSIM(result, gt)
    plt.imshow(result, cmap='gray')
    plt.axis('off')
    title = '%d: PSNR: %.2f SSIM: %.4f' % (iteration, psnr_value, ssim_value)
    plt.title(title)
    plt.show()
def callback_func(iteration, reconstruction, loss):
    """Record the state of one iteration in the module-level history lists.

    Calls with ``iteration == 0`` are ignored. For any other iteration the
    following are appended, in this order:
    - the iteration number to ``iter_history``;
    - the loss to ``loss_history``;
    - PSNR and SSIM against the global ``gt`` to ``psnr_history`` and
      ``ssim_history``;
    - the reconstruction itself to ``reco_history``.
    """
    global iter_history, loss_history, psnr_history, ssim_history, reco_history
    if iteration != 0:
        iter_history.append(iteration)
        loss_history.append(loss)
        psnr_history.append(PSNR(reconstruction, gt))
        ssim_history.append(SSIM(reconstruction, gt))
        reco_history.append(reconstruction)
def fit_model(model, optimizer, scheduler, num_epochs, criterion, loaders,
              device, seed=0):
    """Train ``model`` for ``num_epochs`` epochs, reporting validation PSNR.

    Each epoch performs these steps:
    - one pass over ``loaders['train']``: zero grads, forward, ``criterion``
      loss, backward, clip the gradient norm to 1, optimizer step;
    - print the mean mini-batch loss;
    - step ``scheduler`` once;
    - print the mean PSNR over ``loaders['validation']``, computed without
      gradients.

    :param model: network to train; batches are moved with ``.cuda(device)``.
    :param optimizer: optimizer over ``model``'s parameters.
    :param scheduler: learning-rate scheduler, stepped once per epoch.
    :param num_epochs: number of passes over the training loader.
    :param criterion: callable ``criterion(outputs, targets)`` returning a
        scalar loss tensor.
    :param loaders: mapping with ``'train'`` and ``'validation'`` loaders
        yielding ``(input, target)`` pairs.
    :param device: CUDA device the batches are moved to.
    :param seed: seed forwarded to ``seed_everything`` before training.
    :returns: the trained ``model`` (the same object that was passed in).
    """
    train_loader = loaders['train']
    batches_per_epoch = len(train_loader)
    total_iterations = batches_per_epoch * num_epochs
    seed_everything(seed=seed)
    print('start training')

    for epoch in tqdm(range(num_epochs)):
        model.train()
        epoch_loss = 0
        for step, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.cuda(device)
            targets = targets.cuda(device)

            # PyTorch accumulates gradients across backward passes, so they
            # must be cleared before every mini-batch.
            optimizer.zero_grad()
            predictions = model(inputs)
            batch_loss = criterion(predictions, targets)
            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            optimizer.step()

            epoch_loss += float(batch_loss)
            if step % 100 == 0:
                print("iter : {}/{}, loss = {:.6f}".format(
                    epoch * batches_per_epoch + step, total_iterations,
                    float(batch_loss)))

        # mean mini-batch loss over this epoch
        epoch_loss = float(epoch_loss) / batches_per_epoch
        print("epoch : {}/{}, loss = {:.6f}".format(
            epoch + 1, num_epochs, epoch_loss))
        scheduler.step()

        model.eval()
        print('Evaluating model')
        with torch.no_grad():
            psnrs = [PSNR(model(obs.to(device)).cpu(), gt)
                     for obs, gt in loaders['validation']]
        print('mean psnr: {:f}'.format(np.mean(psnrs)))

    return model
def eval(self, test_data):
    """Return the mean PSNR of ``self.reconstruct`` over ``test_data``.

    Switches ``self.model`` to evaluation mode first. For every
    ``(observation, ground_truth)`` pair, it reconstructs the observation and
    scores it with ``PSNR``. A progress bar is shown unless
    ``self.show_pbar`` is false.
    """
    self.model.eval()
    with tqdm(test_data, desc='test ', disable=not self.show_pbar) as progress:
        psnr_sum = sum(
            (PSNR(self.reconstruct(obs), gt) for obs, gt in progress), 0.0)
    return psnr_sum / len(test_data)
def plot_reconstructions(reconstructions, titles, ray_trafo, obs, gt,
                         save_name=None, fig_size=(18, 4.5), cmap='pink'):
    """Plot a ground truth next to several reconstructions of the same data.

    Panel layout:
    - the first panel shows ``gt``;
    - panel ``j + 1`` shows ``reconstructions[j]``, titled ``titles[j]``.

    Every panel is labelled with its data discrepancy
    ``||ray_trafo(image) - obs||_2``. Reconstruction panels also show
    PSNR and SSIM against ``gt``.

    :param reconstructions: sequence (list or tuple) of reconstructed images.
    :param titles: panel titles, one per reconstruction.
    :param ray_trafo: forward operator; ``ray_trafo(image).asarray()`` is
        compared element-wise with ``obs.asarray()``.
    :param obs: observed data the reconstructions were computed from.
    :param gt: ground-truth image.
    :param save_name: if truthy, the figure is saved as ``'<save_name>.pdf'``.
    :param fig_size: figure size forwarded to ``plot_images``.
    :param cmap: colormap forwarded to ``plot_images``.
    """
    obs_array = obs.asarray()

    def l2_data_error(image):
        # Euclidean norm of the residual between forward projection and data.
        return np.sqrt(
            np.sum(np.power(ray_trafo(image).asarray() - obs_array, 2)))

    psnrs = [PSNR(reco, gt) for reco in reconstructions]
    ssims = [SSIM(reco, gt) for reco in reconstructions]
    l2_error0 = l2_data_error(gt)
    l2_error = [l2_data_error(reco) for reco in reconstructions]

    # plot results; the colour range is clipped at 90% of the gt maximum
    _, ax = plot_images([gt, *reconstructions],
                        fig_size=fig_size,
                        rect=(0.0, 0.0, 1.0, 1.0),
                        xticks=[],
                        yticks=[],
                        vrange=(0.0, 0.9 * np.max(gt.asarray())),
                        cbar=False,
                        interpolation='none',
                        cmap=cmap)

    # set labels; the escaped backslash yields the literal LaTeX "\ell"
    # without relying on an invalid string escape sequence
    ax[0].set_title('Ground Truth')
    for j in range(len(reconstructions)):
        ax[j + 1].set_title(titles[j])
        ax[j + 1].set_xlabel(
            '$\\ell_2$ data error: {:.4f}\nPSNR: {:.1f}, SSIM: {:.2f}'.format(
                l2_error[j], psnrs[j], ssims[j]))
    ax[0].set_xlabel('$\\ell_2$ data error: {:.2f}'.format(l2_error0))
    plt.tight_layout()
    plt.tight_layout()
    if save_name:
        plt.savefig('%s.pdf' % save_name)
    plt.show()
def train(self, dataset):
    """Train ``self.model`` on ``dataset`` and keep the best validation weights.

    Each epoch runs two phases:
    - ``'train'``: forward, MSE loss, backward, gradient norm clipped to 1,
      optimizer step;
    - ``'validation'``: forward only, without gradients.

    The mean per-sample PSNR of each phase is tracked. Whenever the
    validation PSNR improves, the model state is deep-copied and, if
    ``self.save_best_learned_params_path`` is set, saved. After the last
    epoch the best weights are loaded back into ``self.model``.

    If ``self.log_dir`` is set, losses, PSNRs and up to
    ``self.log_num_validation_samples`` validation reconstructions are
    written with a TensorBoard ``SummaryWriter``.

    :param dataset: object providing ``create_torch_dataset``, ``space``,
        ``supports_random_access`` and ``get_data_pairs`` (as called below).
    :raises ImportError: if logging is requested but tensorboard is missing.
    """
    if self.torch_manual_seed:
        torch.random.manual_seed(self.torch_manual_seed)

    self.init_transform(dataset=dataset)

    # create PyTorch datasets; each element of a pair gets a leading axis of
    # length 1 prepended to its space shape
    dataset_train = dataset.create_torch_dataset(
        part='train',
        reshape=((1, ) + dataset.space[0].shape,
                 (1, ) + dataset.space[1].shape),
        transform=self._transform)

    dataset_validation = dataset.create_torch_dataset(
        part='validation',
        reshape=((1, ) + dataset.space[0].shape,
                 (1, ) + dataset.space[1].shape))

    # reset model before training
    self.init_model()

    criterion = torch.nn.MSELoss()
    self.init_optimizer(dataset_train=dataset_train)

    # create PyTorch dataloaders
    shuffle = (dataset.supports_random_access() if self.shuffle == 'auto'
               else self.shuffle)
    data_loaders = {
        'train': DataLoader(
            dataset_train, batch_size=self.batch_size,
            num_workers=self.num_data_loader_workers, shuffle=shuffle,
            pin_memory=True, worker_init_fn=self.worker_init_fn),
        'validation': DataLoader(
            dataset_validation, batch_size=self.batch_size,
            num_workers=self.num_data_loader_workers, shuffle=shuffle,
            pin_memory=True, worker_init_fn=self.worker_init_fn),
    }

    dataset_sizes = {
        'train': len(dataset_train),
        'validation': len(dataset_validation),
    }

    self.init_scheduler(dataset_train=dataset_train)
    # Cyclic / one-cycle schedules are advanced after every optimizer step;
    # all other schedulers once per epoch. isinstance(None, ...) is False, so
    # this flag is well defined even when no scheduler is configured.
    schedule_every_batch = isinstance(self._scheduler, (CyclicLR, OneCycleLR))

    best_model_wts = deepcopy(self.model.state_dict())
    best_psnr = 0

    writer = None
    if self.log_dir is not None:
        if not TENSORBOARD_AVAILABLE:
            raise ImportError(
                'Missing tensorboard. Please install it or disable '
                'logging by specifying `log_dir=None`.')
        writer = SummaryWriter(log_dir=self.log_dir, max_queue=0)
        validation_samples = dataset.get_data_pairs(
            'validation', self.log_num_validation_samples)

    # number of training batches per epoch; unit of the TensorBoard x-axis
    steps_per_epoch = ceil(dataset_sizes['train'] / self.batch_size)

    self.model.to(self.device)
    self.model.train()

    try:
        for epoch in range(self.epochs):
            # Each epoch has a training and validation phase
            for phase in ['train', 'validation']:
                if phase == 'train':
                    self.model.train()  # Set model to training mode
                else:
                    self.model.eval()  # Set model to evaluate mode

                running_psnr = 0.0
                running_loss = 0.0
                running_size = 0
                with tqdm(data_loaders[phase],
                          desc='epoch {:d}'.format(epoch + 1),
                          disable=not self.show_pbar) as pbar:
                    for inputs, labels in pbar:
                        if self.normalize_by_opnorm:
                            inputs = (1. / self.opnorm) * inputs
                        inputs = inputs.to(self.device)
                        labels = labels.to(self.device)

                        # zero the parameter gradients
                        self._optimizer.zero_grad()

                        # forward; track gradients only in the train phase
                        with torch.set_grad_enabled(phase == 'train'):
                            outputs = self.model(inputs)
                            loss = criterion(outputs, labels)

                            # backward + optimize only in the train phase
                            if phase == 'train':
                                loss.backward()
                                torch.nn.utils.clip_grad_norm_(
                                    self.model.parameters(), max_norm=1)
                                self._optimizer.step()
                                if (self._scheduler is not None
                                        and schedule_every_batch):
                                    self._scheduler.step()

                        # per-sample PSNR, computed on channel 0
                        for i in range(outputs.shape[0]):
                            labels_ = labels[i, 0].detach().cpu().numpy()
                            outputs_ = outputs[i, 0].detach().cpu().numpy()
                            running_psnr += PSNR(outputs_, labels_)

                        # statistics; MSELoss averages over the batch, so
                        # re-weight by batch size for a per-sample mean
                        running_loss += loss.item() * outputs.shape[0]
                        running_size += outputs.shape[0]

                        pbar.set_postfix({
                            'phase': phase,
                            'loss': running_loss / running_size,
                            'psnr': running_psnr / running_size
                        })
                        if self.log_dir is not None and phase == 'train':
                            step = (epoch * steps_per_epoch
                                    + ceil(running_size / self.batch_size))
                            writer.add_scalar(
                                'loss/{}'.format(phase),
                                torch.tensor(running_loss / running_size),
                                step)
                            writer.add_scalar(
                                'psnr/{}'.format(phase),
                                torch.tensor(running_psnr / running_size),
                                step)

                    # Epoch-wise schedulers advance once per epoch, i.e.
                    # only after the train phase; stepping in the validation
                    # phase as well would run the schedule twice as fast.
                    if (phase == 'train' and self._scheduler is not None
                            and not schedule_every_batch):
                        self._scheduler.step()

                    epoch_loss = running_loss / dataset_sizes[phase]
                    epoch_psnr = running_psnr / dataset_sizes[phase]

                    if self.log_dir is not None and phase == 'validation':
                        step = (epoch + 1) * steps_per_epoch
                        writer.add_scalar('loss/{}'.format(phase),
                                          epoch_loss, step)
                        writer.add_scalar('psnr/{}'.format(phase),
                                          epoch_psnr, step)

                    # deep copy the model (if it is the best one seen so far)
                    if phase == 'validation' and epoch_psnr > best_psnr:
                        best_psnr = epoch_psnr
                        best_model_wts = deepcopy(self.model.state_dict())
                        if self.save_best_learned_params_path is not None:
                            self.save_learned_params(
                                self.save_best_learned_params_path)

                if (phase == 'validation' and self.log_dir is not None
                        and self.log_num_validation_samples > 0):
                    with torch.no_grad():
                        val_images = []
                        for (y, x) in validation_samples:
                            y = torch.from_numpy(np.asarray(y))[None, None]
                            # feed the model the same input scaling that is
                            # used for the train/validation batches above
                            if self.normalize_by_opnorm:
                                y = (1. / self.opnorm) * y
                            y = y.to(self.device)
                            x = torch.from_numpy(np.asarray(x))[None, None].to(
                                self.device)
                            reco = self.model(y)
                            # rescale to [0, 1] for display
                            reco -= torch.min(reco)
                            reco /= torch.max(reco)
                            val_images += [reco, x]
                        writer.add_images('validation_samples',
                                          torch.cat(val_images),
                                          (epoch + 1) * steps_per_epoch,
                                          dataformats='NCWH')
    finally:
        # flush and release the TensorBoard event file
        if writer is not None:
            writer.close()

    print('Best val psnr: {:4f}'.format(best_psnr))
    self.model.load_state_dict(best_model_wts)
reconstructor.load_hyper_params(hyper_params_path)

#%% expose FBP cache to reconstructor by assigning `fbp_dataset` attribute
# uncomment the next line to generate the cache files (~20 GB)
# generate_fbp_cache_files(dataset, ray_trafo, CACHE_FILES)
cached_fbp_dataset = get_cached_fbp_dataset(dataset, ray_trafo, CACHE_FILES)
dataset.fbp_dataset = cached_fbp_dataset

#%% train
# reduce the batch size here if the model does not fit into GPU memory
# reconstructor.batch_size = 16
reconstructor.train(dataset)

#%% evaluate: reconstruct every test observation and score it against gt
recos, psnrs = [], []
for obs, gt in test_data:
    reco = reconstructor.reconstruct(obs)
    recos.append(reco)
    psnrs.append(PSNR(reco, gt))

print(f'mean psnr: {np.mean(psnrs):f}')

# show the first three test reconstructions next to their ground truths
for i in range(3):
    _, ax = plot_images([recos[i], test_data.ground_truth[i]],
                        fig_size=(10, 4))
    ax[0].set_xlabel(f'PSNR: {psnrs[i]:.2f}')
    ax[0].set_title('FBPUNetReconstructor')
    ax[1].set_title('ground truth')
    ax[0].figure.suptitle(f'test sample {i:d}')
def callback_func(iteration, reconstruction, loss):
    """Plot the current DIP iterate next to the ground truth and show it.

    Reads the module-level ``gt`` and ``TEST_SAMPLE``. It is passed as
    ``callback_func`` to the reconstructor constructed below.
    """
    _, axes = plot_images([reconstruction, gt], fig_size=(10, 4))
    reco_ax, gt_ax = axes[0], axes[1]
    reco_ax.set_xlabel(f'loss: {loss:f}')
    reco_ax.set_title(f'DIP iteration {iteration:d}')
    gt_ax.set_title('ground truth')
    reco_ax.figure.suptitle(f'test sample {TEST_SAMPLE:d}')
    plt.show()


reconstructor = DeepImagePriorCTReconstructor(
    dataset.get_ray_trafo(impl=IMPL),
    callback_func=callback_func,
    callback_func_interval=100)

#%% obtain reference hyper parameters
if not check_for_params('diptv', 'lodopab'):
    download_params('diptv', 'lodopab')
params_path = get_params_path('diptv', 'lodopab')
reconstructor.load_params(params_path)

#%% evaluate
reco = reconstructor.reconstruct(obs)
psnr = PSNR(reco, gt)
print(f'psnr: {psnr:f}')

_, ax = plot_images([reco, gt], fig_size=(10, 4))
ax[0].set_xlabel(f'PSNR: {psnr:.2f}')
ax[0].set_title('DeepImagePriorCTReconstructor')
ax[1].set_title('ground truth')
ax[0].figure.suptitle(f'test sample {TEST_SAMPLE:d}')
def plot_reconstructors_tests(reconstructors, ray_trafo, test_data,
                              save_name=None, fig_size=(18, 4.5),
                              cmap='pink'):
    """Plot, for every test sample, its ground truth and each reconstruction.

    One figure is produced per element of ``test_data``, with panels laid out
    four per row. The first panel shows the ground truth; panel ``j + 1``
    shows the output of ``reconstructors[j]``, titled with its ``name``.
    Every panel is labelled with the data discrepancy
    ``||ray_trafo(image) - observation||_2``. Reconstruction panels also
    show PSNR and SSIM against the ground truth.

    :param reconstructors: objects with a ``reconstruct(observation)`` method
        and a ``name`` attribute.
    :param ray_trafo: forward operator; ``ray_trafo(image).asarray()`` is
        compared element-wise with ``observation.asarray()``.
    :param test_data: indexable sequence (supports ``len``) of
        ``(observation, ground_truth)`` pairs.
    :param save_name: if truthy, figure ``i`` is saved as
        ``'<save_name>-<i>.pdf'``.
    :param fig_size: figure size forwarded to ``plot_images``.
    :param cmap: colormap forwarded to ``plot_images``.
    """
    titles = [reconstructor.name for reconstructor in reconstructors]

    def l2_data_error(image, obs_array):
        # Euclidean norm of the residual between forward projection and data.
        return np.sqrt(
            np.sum(np.power(ray_trafo(image).asarray() - obs_array, 2)))

    for i in range(len(test_data)):
        y_delta, x = test_data[i]
        y_array = y_delta.asarray()

        # compute reconstructions, data errors and psnr / ssim measures
        recos = [r.reconstruct(y_delta) for r in reconstructors]
        l2_error = [l2_data_error(reco, y_array) for reco in recos]
        l2_error0 = l2_data_error(x, y_array)
        psnrs = [PSNR(reco, x) for reco in recos]
        ssims = [SSIM(reco, x) for reco in recos]

        # plot results; the colour range is clipped at 90% of the gt maximum
        _, ax = plot_images([x, *recos],
                            fig_size=fig_size,
                            rect=(0.0, 0.0, 1.0, 1.0),
                            ncols=4,
                            nrows=-1,
                            xticks=[],
                            yticks=[],
                            vrange=(0.0, 0.9 * np.max(x.asarray())),
                            cbar=False,
                            interpolation='none',
                            cmap=cmap)

        # set labels; the escaped backslash yields the literal LaTeX "\ell"
        # without relying on an invalid string escape sequence
        ax = ax.reshape(-1)
        ax[0].set_title('Ground Truth')
        ax[0].set_xlabel('$\\ell_2$ data error: {:.2f}'.format(l2_error0))
        for j in range(len(recos)):
            ax[j + 1].set_title(titles[j])
            ax[j + 1].set_xlabel(
                '$\\ell_2$ data error: {:.4f}\nPSNR: {:.1f}, SSIM: {:.2f}'.
                format(l2_error[j], psnrs[j], ssims[j]))
        plt.tight_layout()
        plt.tight_layout()
        if save_name:
            plt.savefig('%s-%d.pdf' % (save_name, i))
        plt.show()
[0, 4, 4, 4, 4]]  # presumably closes the `skip_channels` list opened above this chunk
# Skip-connection sweep: build one DIP reconstructor per skip setting,
# reconstruct `obs` with it and collect the result in `results`.
# NOTE(review): `ch` and `sc` are not assigned in the visible code; they look
# like values left over from an earlier loop. Confirm this is intended.
for skip in skip_channels:
    print('Skip Channels:')
    print(skip)
    params.channels = (ch,) * sc
    params.skip_channels = skip
    params.scales = sc
    reconstructor = DeepImagePriorReconstructor(ray_trafo=ray_trafo,
                                                hyper_params=params.dict,
                                                name='DIP')
    result = reconstructor.reconstruct(obs)
    results.append(result)

# Plot every collected result on a 3x4 subplot grid (so at most 12 results
# fit), each labelled with its PSNR / SSIM against `gt`.
fig = plt.figure(figsize=(9.1, 8.3))
for i in range(len(results)):
    ax = fig.add_subplot(3, 4, i+1)
    psnr = PSNR(results[i], gt)
    ssim = SSIM(results[i], gt)
    plot_image(results[i], ax=ax, xticks=[], yticks=[], cmap='pink')
    # The first 8 results are titled from `channels`/`scales`, the rest from
    # `skip_channels`. NOTE(review): this assumes exactly 8 results were
    # appended before the skip sweep — verify against the code above.
    if i < 8:
        ax.set_title('Channels: %d, Scales: %d' % (channels[i], scales[i]))
    else:
        ax.set_title('Skip: {}'.format(skip_channels[i-8]))
    ax.set_xlabel('PSNR: %.2f, SSIM: %.4f' % (psnr, ssim))

# tight_layout is called twice; presumably intentional (a second pass can
# refine spacing) — confirm.
plt.tight_layout()
plt.tight_layout()
plt.savefig('ellipses_architectures.pdf')
plt.savefig('ellipses_architectures.pgf')
plt.show()
def plot_result(result, gt, iteration):
    """Display ``result`` with the ``bone`` colormap, titled with its PSNR.

    The title has the form ``'<iteration>: <psnr>'``, where the PSNR is
    computed against ``gt`` and shown with three decimals.

    NOTE(review): if this module also defines another ``plot_result`` earlier,
    this later definition replaces it.
    """
    score = PSNR(result, gt)
    plt.imshow(result, cmap='bone')
    caption = '%d: %.3f' % (iteration, score)
    plt.title(caption)
    plt.axis('off')
    plt.show()
def main_worker(gpu, ngpus_per_node, args):
    """Per-GPU entry point: build, train, save and test an ``ista_unet``.

    Steps:
    - join the distributed process group;
    - wrap the model in ``DistributedDataParallel`` and train it with
      ``fit_model``;
    - on local GPU 0 only, pickle the run configuration, save the weights
      under ``<model_save_dir>/<model_name>/<uuid4>/`` and print the mean
      PSNR on ``loaders['test']``.

    :param gpu: local CUDA device index used by this process.
    :param ngpus_per_node: ignored; recomputed from
        ``torch.cuda.device_count()``.
    :param args: hyper-parameter namespace. It is mutated in place
        (``gpu``, ``rank``, ``BATCH_SIZE``, ``num_workers``).
    """
    args.gpu = gpu
    # The passed value is overridden by the number of visible CUDA devices.
    ngpus_per_node = torch.cuda.device_count()
    print("Use GPU: {} for training".format(args.gpu))
    # NOTE(review): assumes `args.rank` holds the node rank on entry; it is
    # converted to this process's global rank here.
    args.rank = args.rank * ngpus_per_node + gpu
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)

    print('==> Making model..')
    model_setting_dict = {
        'kernel_size': args.KERNEL_SIZE,
        'hidden_layer_width_list': args.HIDDEN_WIDTHS,
        'n_classes': args.IMAGE_CHANNEL_NUM,
        'ista_num_steps': args.ISTA_NUM_STEPS,
        'lasso_lambda_scalar': args.LASSO_LAMBDA_SCALAR,
        'uncouple_adjoint_bool': args.UNCOUPLE_ADJOINT_BOOL,
        'relu_out_bool': args.RELU_OUT_BOOL
    }
    model = ista_unet(**model_setting_dict)
    torch.cuda.set_device(args.gpu)
    model.cuda(args.gpu)
    # split batch size and loader workers across the local GPUs
    args.BATCH_SIZE = int(args.BATCH_SIZE / ngpus_per_node)
    args.num_workers = int(args.num_workers / ngpus_per_node)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[args.gpu])
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('The number of parameters of model is', num_params)

    print('==> Preparing data..')
    dataset_setting_dict = {
        'batch_size': args.BATCH_SIZE,
        'num_workers': args.num_workers,
        'distributed_bool': True
    }
    loaders = get_dataloaders_ellipses(**dataset_setting_dict)
    print(len(loaders['train'].dataset))

    optimizer = optim.Adam(model.parameters(), lr=args.LEARNING_RATE)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     T_max=args.NUM_EPOCH,
                                                     eta_min=2e-5)

    # SECURITY(review): `eval` executes arbitrary Python taken from
    # `args.LOSS_STR`. Acceptable only for trusted launch scripts; map a fixed
    # set of loss names to constructors if that value can be untrusted.
    fit_setting_dict = {
        'num_epochs': args.NUM_EPOCH,
        'criterion': eval(args.LOSS_STR),
        'optimizer': optimizer,
        'scheduler': scheduler,
        'device': args.gpu
    }
    config_dict = {
        **model_setting_dict,
        **dataset_setting_dict,
        **fit_setting_dict
    }
    trained_model = fit_model(model, loaders=loaders, **fit_setting_dict)

    if args.gpu == 0:
        guid = str(uuid.uuid4())
        guid_dir = os.path.join(args.model_save_dir, args.model_name, guid)
        os.makedirs(guid_dir)
        config_dict['guid'] = guid
        with open(os.path.join(guid_dir, 'config_dict.pickle'),
                  'wb') as handle:
            pickle.dump(config_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        torch.save(trained_model.module.state_dict(),
                   os.path.join(guid_dir, 'ista_unet.pt'))

        # Evaluate the wrapped module, not the DDP wrapper: only this process
        # runs the loop, and a DDP forward may issue collective communication
        # (e.g. buffer broadcast) that the other ranks never join.
        eval_model = trained_model.module
        psnrs = []
        eval_model.eval()
        eval_model.to(args.gpu)
        print('Evaluating model')
        with torch.no_grad():
            for obs, gt in loaders['test']:
                reco = eval_model(obs.to(args.gpu)).cpu()
                psnrs.append(PSNR(reco, gt))
        print('mean psnr: {:f}'.format(np.mean(psnrs)))