def plot_result(result, gt, iteration):
    """Display ``result`` as a grey-scale image annotated with quality scores.

    The title shows ``iteration`` together with the PSNR and SSIM of
    ``result`` measured against the reference image ``gt``.
    """
    scores = (PSNR(result, gt), SSIM(result, gt))
    plt.imshow(result, cmap='gray')
    plt.axis('off')
    plt.title('%d: PSNR: %.2f SSIM: %.4f' % ((iteration,) + scores))
    plt.show()
def callback_func(iteration, reconstruction, loss):
    """Record per-iteration metrics of an iterative reconstruction.

    For every ``iteration`` other than 0, this appends the following to the
    module-level history containers:

    * the iteration number,
    * ``loss``,
    * the PSNR and SSIM of ``reconstruction`` w.r.t. the module-level ``gt``,
    * ``reconstruction`` itself.
    """
    # The history containers are only appended to, never rebound, so they
    # resolve to the module globals without a ``global`` statement.
    if iteration != 0:
        iter_history.append(iteration)
        loss_history.append(loss)
        psnr_history.append(PSNR(reconstruction, gt))
        ssim_history.append(SSIM(reconstruction, gt))
        reco_history.append(reconstruction)
def plot_reconstructions(reconstructions, titles, ray_trafo, obs, gt,
                         save_name=None, fig_size=(18, 4.5), cmap='pink'):
    """Plot a ground truth next to several reconstructions of it.

    Each reconstruction panel is labelled with its title, its l2 data error
    ``||ray_trafo(reco) - obs||_2`` and its PSNR/SSIM w.r.t. ``gt``. The
    ground-truth panel shows the data error of ``gt`` itself.

    :param reconstructions: Iterable of reconstructions (accepted by
        ``ray_trafo``, ``PSNR`` and ``SSIM``).
    :param titles: Panel titles, one per reconstruction.
    :param ray_trafo: Forward operator used to compute the data errors; its
        output must provide ``asarray()``.
    :param obs: Observation (measured data) providing ``asarray()``.
    :param gt: Ground-truth image providing ``asarray()``. 90% of its
        maximum is used as the upper limit of the colour range.
    :param save_name: If truthy, the figure is saved as
        ``'<save_name>.pdf'`` before it is shown.
    :param fig_size: Figure size passed to ``plot_images``.
    :param cmap: Colour map passed to ``plot_images``.
    """
    # Accept any iterable; it is concatenated with ``[gt]`` below.
    reconstructions = list(reconstructions)
    obs_arr = obs.asarray()

    def data_error(img):
        # l2 norm of the residual between the simulated data of ``img``
        # and the observation
        return np.sqrt(
            np.sum(np.power(ray_trafo(img).asarray() - obs_arr, 2)))

    psnrs = [PSNR(reco, gt) for reco in reconstructions]
    ssims = [SSIM(reco, gt) for reco in reconstructions]
    l2_error0 = data_error(gt)
    l2_error = [data_error(reco) for reco in reconstructions]

    # plot results
    _, ax = plot_images([gt] + reconstructions,
                        fig_size=fig_size,
                        rect=(0.0, 0.0, 1.0, 1.0),
                        xticks=[],
                        yticks=[],
                        vrange=(0.0, 0.9 * np.max(gt.asarray())),
                        cbar=False,
                        interpolation='none',
                        cmap=cmap)

    # set labels ('\\ell' keeps the backslash without relying on the
    # invalid escape sequence '\e'; '\n' must remain a real newline)
    ax[0].set_title('Ground Truth')
    for j in range(len(reconstructions)):
        ax[j + 1].set_title(titles[j])
        ax[j + 1].set_xlabel(
            '$\\ell_2$ data error: {:.4f}\nPSNR: {:.1f}, SSIM: {:.2f}'.format(
                l2_error[j], psnrs[j], ssims[j]))
    ax[0].set_xlabel('$\\ell_2$ data error: {:.2f}'.format(l2_error0))
    plt.tight_layout()
    # NOTE(review): second call presumably lets the layout settle — confirm.
    plt.tight_layout()
    if save_name:
        plt.savefig('%s.pdf' % save_name)
    plt.show()
def plot_reconstructors_tests(reconstructors, ray_trafo, test_data,
                              save_name=None, fig_size=(18, 4.5),
                              cmap='pink'):
    """Plot the ground truth and several reconstructions for each test pair.

    For every sample ``(y_delta, x)`` of ``test_data``, one figure (4 panel
    columns) is created. Each reconstructor's panel shows:

    * its ``name``,
    * its l2 data error ``||ray_trafo(reco) - y_delta||_2``,
    * its PSNR/SSIM w.r.t. ``x``.

    The ground-truth panel shows the data error of ``x``.

    :param reconstructors: Objects with a ``name`` attribute and a
        ``reconstruct(observation)`` method; iterated once per sample.
    :param ray_trafo: Forward operator used to compute the data errors; its
        output must provide ``asarray()``.
    :param test_data: Collection supporting ``len`` and integer indexing
        that yields ``(observation, ground_truth)`` pairs.
    :param save_name: If truthy, figure ``i`` is saved as
        ``'<save_name>-<i>.pdf'`` before it is shown.
    :param fig_size: Figure size passed to ``plot_images``.
    :param cmap: Colour map passed to ``plot_images``.
    """
    titles = [reconstructor.name for reconstructor in reconstructors]

    def data_error(img, obs_arr):
        # l2 norm of the residual between the simulated data of ``img``
        # and the observation array ``obs_arr``
        return np.sqrt(
            np.sum(np.power(ray_trafo(img).asarray() - obs_arr, 2)))

    for i in range(len(test_data)):
        y_delta, x = test_data[i]
        y_arr = y_delta.asarray()

        # compute reconstructions, data errors and psnr/ssim measures
        recos = [r.reconstruct(y_delta) for r in reconstructors]
        l2_error = [data_error(reco, y_arr) for reco in recos]
        l2_error0 = data_error(x, y_arr)
        psnrs = [PSNR(reco, x) for reco in recos]
        ssims = [SSIM(reco, x) for reco in recos]

        # plot results
        _, ax = plot_images([x] + recos,
                            fig_size=fig_size,
                            rect=(0.0, 0.0, 1.0, 1.0),
                            ncols=4,
                            nrows=-1,
                            xticks=[],
                            yticks=[],
                            vrange=(0.0, 0.9 * np.max(x.asarray())),
                            cbar=False,
                            interpolation='none',
                            cmap=cmap)

        # flatten the (possibly 2D) axes grid so panels index linearly
        ax = ax.reshape(-1)
        # '\\ell' avoids the invalid escape '\e'; '\n' stays a real newline
        ax[0].set_title('Ground Truth')
        ax[0].set_xlabel('$\\ell_2$ data error: {:.2f}'.format(l2_error0))
        for j in range(len(recos)):
            ax[j + 1].set_title(titles[j])
            ax[j + 1].set_xlabel(
                '$\\ell_2$ data error: {:.4f}\nPSNR: {:.1f}, SSIM: {:.2f}'
                .format(l2_error[j], psnrs[j], ssims[j]))
        plt.tight_layout()
        # NOTE(review): second call presumably lets the layout settle.
        plt.tight_layout()
        if save_name:
            plt.savefig('%s-%d.pdf' % (save_name, i))
        plt.show()
[0, 4, 4, 4, 4]]
# NOTE(review): the line above presumably closes the `skip_channels` list
# literal that starts before this chunk.

# Skip-connection sweep: for each configuration in `skip_channels`, a
# DeepImagePriorReconstructor is run on `obs` and its result is appended
# to `results`.
for skip in skip_channels:
    print('Skip Channels:')
    print(skip)
    # NOTE(review): `ch` and `sc` are not assigned in this chunk; presumably
    # they are left over from an earlier channels/scales sweep — confirm
    # that the intended values are used here.
    params.channels = (ch,) * sc
    params.skip_channels = skip
    params.scales = sc
    reconstructor = DeepImagePriorReconstructor(ray_trafo=ray_trafo, hyper_params=params.dict, name='DIP')
    result = reconstructor.reconstruct(obs)
    results.append(result)

# Show all results in a 3x4 grid; at most 12 results fit, since
# add_subplot(3, 4, 13) would fail.
fig = plt.figure(figsize=(9.1, 8.3))
for i in range(len(results)):
    ax = fig.add_subplot(3, 4, i+1)
    psnr = PSNR(results[i], gt)
    ssim = SSIM(results[i], gt)
    plot_image(results[i], ax=ax, xticks=[], yticks=[], cmap='pink')
    # Assumes the first 8 results come from a channels/scales sweep
    # (`channels`, `scales` defined earlier) and the rest from the skip
    # loop above — TODO confirm.
    if i < 8:
        ax.set_title('Channels: %d, Scales: %d' % (channels[i], scales[i]))
    else:
        ax.set_title('Skip: {}'.format(skip_channels[i-8]))
    ax.set_xlabel('PSNR: %.2f, SSIM: %.4f' % (psnr, ssim))
plt.tight_layout()
# NOTE(review): second call presumably lets the layout settle — confirm.
plt.tight_layout()
# The figure is saved both as PDF and as PGF before being shown.
plt.savefig('ellipses_architectures.pdf')
plt.savefig('ellipses_architectures.pgf')
plt.show()
def train(self, dataset):
    """Train ``self.model`` self-supervised and keep the best weights.

    For every batch, the observation ``y`` (scaled by ``1 / self.opnorm``
    if ``self.normalize_by_opnorm``) and a masked copy of it (from
    ``self.masker.mask``) are both reconstructed with ``self.fbp_op`` and
    passed through the model ``f``. With ``A = self.ray_trafo_module``,
    the loss is ``l_rec + 2 * sqrt(l_inv)``, where:

    * ``l_rec = MSE(A(f(FBP(y))), y)`` is a data-consistency term;
    * ``l_inv`` is the summed squared difference between
      ``A(f(FBP(y))) * mask`` and ``A(f(FBP(masked y))) * mask``, divided
      by ``mask.sum()``.

    Ground-truth images are only used to report PSNR/SSIM; they do not
    enter the loss.

    After each validation phase, the weights are kept if the mean
    validation PSNR improved. They are also saved to
    ``self.save_best_learned_params_path`` if that is set. At the end,
    the best weights are loaded back into ``self.model``.

    :param dataset: Dataset providing ``create_torch_dataset``, ``space``,
        ``supports_random_access`` and ``get_data_pairs``.
    :raises ImportError: If ``self.log_dir`` is set but tensorboard is not
        available.
    """
    if self.torch_manual_seed:
        torch.random.manual_seed(self.torch_manual_seed)

    self.init_transform(dataset=dataset)

    # create PyTorch datasets (a leading channel axis is added to both
    # observations and ground truths; only training data is transformed)
    reshape = ((1, ) + dataset.space[0].shape,
               (1, ) + dataset.space[1].shape)
    dataset_train = dataset.create_torch_dataset(
        part='train', reshape=reshape, transform=self._transform)
    dataset_validation = dataset.create_torch_dataset(
        part='validation', reshape=reshape)

    criterion = torch.nn.MSELoss()
    criterion_sum = torch.nn.MSELoss(reduction='sum')
    self.init_optimizer(dataset_train=dataset_train)

    # create PyTorch dataloaders
    shuffle = (dataset.supports_random_access() if self.shuffle == 'auto'
               else self.shuffle)
    data_loaders = {
        'train':
        DataLoader(dataset_train,
                   batch_size=self.batch_size,
                   num_workers=self.num_data_loader_workers,
                   shuffle=shuffle,
                   pin_memory=True,
                   worker_init_fn=self.worker_init_fn),
        'validation':
        DataLoader(dataset_validation,
                   batch_size=self.batch_size,
                   num_workers=self.num_data_loader_workers,
                   shuffle=shuffle,
                   pin_memory=True,
                   worker_init_fn=self.worker_init_fn)
    }

    dataset_sizes = {
        'train': len(dataset_train),
        'validation': len(dataset_validation)
    }
    # number of training batches per epoch; the tensorboard step unit
    batches_per_epoch = ceil(dataset_sizes['train'] / self.batch_size)

    self.init_scheduler(dataset_train=dataset_train)
    # cyclic schedulers are stepped per batch, all others once per epoch
    # (isinstance(None, ...) is False, so this is safe without a scheduler)
    schedule_every_batch = isinstance(self._scheduler, (CyclicLR, OneCycleLR))

    best_model_wts = deepcopy(self.model.state_dict())
    best_psnr = 0

    writer = None
    if self.log_dir is not None:
        if not TENSORBOARD_AVAILABLE:
            raise ImportError(
                'Missing tensorboard. Please install it or disable '
                'logging by specifying `log_dir=None`.')
        writer = SummaryWriter(log_dir=self.log_dir, max_queue=0)
        validation_samples = dataset.get_data_pairs(
            'validation', self.log_num_validation_samples)

    def fbp_batch(batch, out_shape):
        # Apply ``self.fbp_op`` to channel 0 of every sample of the CPU
        # tensor ``batch``; returns a float32 tensor of shape ``out_shape``.
        out = np.zeros(out_shape, dtype=np.float32)
        for k in range(len(batch)):
            out[k, 0, :, :] = self.fbp_op(batch[k, 0].numpy())
        return torch.from_numpy(out)

    self.model.to(self.device)
    self.model.train()

    for epoch in range(self.epochs):
        # Each epoch has a training and validation phase
        for phase in ['train', 'validation']:
            if phase == 'train':
                self.model.train()  # Set model to training mode
            else:
                self.model.eval()  # Set model to evaluate mode

            running_psnr = 0.0
            running_ssim = 0.0
            running_loss = 0.0
            running_size = 0
            num_iter = 0
            with tqdm(data_loaders[phase],
                      desc='epoch {:d}'.format(epoch + 1),
                      disable=not self.show_pbar) as pbar:
                for inputs, labels in pbar:
                    num_iter += 1
                    if self.normalize_by_opnorm:
                        inputs = (1. / self.opnorm) * inputs

                    # NOTE(review): the mask index cycles through
                    # 0 .. n_masks - 2 only (and n_masks == 1 would divide
                    # by zero) — confirm this is intended.
                    masked_inputs, mask = self.masker.mask(
                        inputs, num_iter % (self.masker.n_masks - 1))

                    # the model maps FBP images (of the masked and of the
                    # full observation) to refined images
                    masked_inputs_fbp = fbp_batch(
                        masked_inputs, labels.shape).to(self.device)
                    inputs_fbp = fbp_batch(inputs,
                                           labels.shape).to(self.device)
                    inputs = inputs.to(self.device)
                    labels = labels.to(self.device)
                    mask = mask.to(self.device)

                    # zero the parameter gradients
                    self._optimizer.zero_grad()

                    # forward; track gradients only in the train phase
                    with torch.set_grad_enabled(phase == 'train'):
                        out_raw = self.model(inputs_fbp)
                        out_masked = self.model(masked_inputs_fbp)
                        proj_raw = self.ray_trafo_module(out_raw)
                        proj_masked = self.ray_trafo_module(out_masked)
                        # data consistency of the unmasked prediction
                        l_rec = criterion(proj_raw, inputs)
                        # squared disagreement of both projections on the
                        # mask, normalized by mask.sum() (a mean over the
                        # masked entries if the mask is binary)
                        l_inv = criterion_sum(
                            proj_raw * mask,
                            proj_masked * mask) / mask.sum()
                        loss = l_rec + 2 * torch.sqrt(l_inv)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            torch.nn.utils.clip_grad_norm_(
                                self.model.parameters(), max_norm=1)
                            self._optimizer.step()
                            if (self._scheduler is not None
                                    and schedule_every_batch):
                                self._scheduler.step()

                    for k in range(out_raw.shape[0]):
                        labels_ = labels[k, 0].detach().cpu().numpy()
                        outputs_ = out_raw[k, 0].detach().cpu().numpy()
                        running_psnr += PSNR(outputs_, labels_)
                        running_ssim += SSIM(outputs_, labels_)

                    # statistics
                    running_loss += loss.item() * out_raw.shape[0]
                    running_size += out_raw.shape[0]

                    pbar.set_postfix({
                        'phase': phase,
                        'loss': running_loss / running_size,
                        'psnr': running_psnr / running_size,
                        'ssim': running_ssim / running_size
                    })
                    if self.log_dir is not None and phase == 'train':
                        step = (epoch * batches_per_epoch +
                                ceil(running_size / self.batch_size))
                        writer.add_scalar(
                            'loss/{}'.format(phase),
                            torch.tensor(running_loss / running_size), step)
                        writer.add_scalar(
                            'psnr/{}'.format(phase),
                            torch.tensor(running_psnr / running_size), step)
                        writer.add_scalar(
                            'ssim/{}'.format(phase),
                            torch.tensor(running_ssim / running_size), step)

                # per-epoch schedulers advance once per epoch, after the
                # training phase only (not again after validation)
                if (phase == 'train' and self._scheduler is not None
                        and not schedule_every_batch):
                    self._scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_psnr = running_psnr / dataset_sizes[phase]
                epoch_ssim = running_ssim / dataset_sizes[phase]

                if self.log_dir is not None and phase == 'validation':
                    step = (epoch + 1) * batches_per_epoch
                    writer.add_scalar('loss/{}'.format(phase), epoch_loss,
                                      step)
                    writer.add_scalar('psnr/{}'.format(phase), epoch_psnr,
                                      step)
                    writer.add_scalar('ssim/{}'.format(phase), epoch_ssim,
                                      step)

                # deep copy the model (if it is the best one seen so far)
                if phase == 'validation' and epoch_psnr > best_psnr:
                    best_psnr = epoch_psnr
                    best_model_wts = deepcopy(self.model.state_dict())
                    if self.save_best_learned_params_path is not None:
                        self.save_learned_params(
                            self.save_best_learned_params_path)

            if (phase == 'validation' and self.log_dir is not None
                    and self.log_num_validation_samples > 0):
                with torch.no_grad():
                    val_images = []
                    for (y, x) in validation_samples:
                        y = torch.from_numpy(np.asarray(y))[None, None]
                        # preprocess exactly as in training: optional
                        # opnorm scaling, then FBP (the model expects
                        # images, not raw observations)
                        if self.normalize_by_opnorm:
                            y = (1. / self.opnorm) * y
                        x = torch.from_numpy(np.asarray(x))[None, None].to(
                            self.device)
                        y_fbp = fbp_batch(y, x.shape).to(self.device)
                        reco = self.model(y_fbp)
                        reco -= torch.min(reco)
                        reco /= torch.max(reco)
                        val_images += [reco, x]
                    writer.add_images('validation_samples',
                                      torch.cat(val_images),
                                      (epoch + 1) * batches_per_epoch,
                                      dataformats='NCWH')

    print('Best val psnr: {:4f}'.format(best_psnr))
    self.model.load_state_dict(best_model_wts)
    # release the event file handle of the tensorboard writer
    if writer is not None:
        writer.close()
from dival.measure import PSNR, SSIM
from dival.reconstructors.odl_reconstructors import FBPReconstructor

# Sanity check: FBP reconstruction (Hann filter, frequency scaling 1.0) of
# the first 3 training samples, clipped to [0, 1] and compared with the
# ground truth. `dataset`, `islice`, `np`, `plt` and `plot_images` are
# assumed to be defined/imported earlier in the file (not visible here).
fbp_reconstructor = FBPReconstructor(dataset.ray_trafo, hyper_params={
    'filter_type': 'Hann',
    'frequency_scaling': 1.0
})
for i, (obs, gt) in islice(enumerate(dataset.generator(part='train')), 3):
    reco = fbp_reconstructor.reconstruct(obs)
    reco = np.clip(reco, 0., 1.)
    _, ax = plot_images([reco, gt], fig_size=(10, 4))
    ax[0].figure.suptitle('train sample {:d}'.format(i))
    ax[0].set_title('FBP reconstruction')
    ax[1].set_title('Ground truth')
    psnr = PSNR(reco, gt)
    ssim = SSIM(reco, gt)
    ax[0].set_xlabel('PSNR: {:.2f}dB, SSIM: {:.3f}'.format(psnr, ssim))
    print('metrics for FBP reconstruction on sample {:d}:'.format(i))
    print('PSNR: {:.2f}dB, SSIM: {:.3f}'.format(psnr, ssim))
plt.show()

# %% simulate and store fan beam observations
# Set to True to skip the (interactive) simulation step below.
SKIP_SIMULATION = False

if not SKIP_SIMULATION:
    from dival.util.input import input_yes_no
    print('start simulating and storing fan beam observations for all lodopab '
          'ground truth samples? [y]/n')
    # presumably input_yes_no() returns a falsy value when the user declines
    if not input_yes_no():
        raise RuntimeError('cancelled by user')