def plot_result(result, gt, iteration):
    """Show `result` as a grayscale image titled with its quality metrics.

    The title reports `iteration` together with the PSNR and SSIM of
    `result` measured against the reference image `gt`.
    """
    # PSNR is evaluated before SSIM (tuple elements are evaluated in order)
    scores = (PSNR(result, gt), SSIM(result, gt))
    plt.imshow(result, cmap='gray')
    plt.axis('off')
    title = '%d: PSNR: %.2f SSIM: %.4f' % ((iteration,) + scores)
    plt.title(title)
    plt.show()
 def callback_func(iteration, reconstruction, loss):
     """Record the metrics of one iteration of an iterative reconstruction.

     Iteration 0 is ignored. Otherwise the iteration index, the loss, the
     PSNR and SSIM of `reconstruction` against the module-level ``gt``, and
     the reconstruction itself are appended to the module-level history
     lists (they are only mutated, never rebound, so no ``global`` is needed).
     """
     if iteration != 0:
         iter_history.append(iteration)
         loss_history.append(loss)
         psnr_history.append(PSNR(reconstruction, gt))
         ssim_history.append(SSIM(reconstruction, gt))
         reco_history.append(reconstruction)
示例#3
0
def fit_model(model, optimizer, scheduler, num_epochs, criterion, loaders, device, seed = 0):
    """Train ``model`` and print the mean validation PSNR after every epoch.

    :param model: network to train; updated in place.
    :param optimizer: optimizer over ``model``'s parameters.
    :param scheduler: learning-rate scheduler, stepped once per epoch.
    :param num_epochs: number of passes over ``loaders['train']``.
    :param criterion: loss callable, called as ``criterion(outputs, targets)``.
    :param loaders: mapping with ``'train'`` and ``'validation'`` loaders that
        yield ``(input, target)`` batches.
    :param device: target device for the batches (anything ``Tensor.to``
        accepts, e.g. a GPU index).
    :param seed: seed passed to ``seed_everything`` before training.
    :returns: the trained ``model``.
    """
    train_loader = loaders['train']
    len_train_loader = len(train_loader)
    seed_everything(seed=seed)

    # A DistributedSampler only reshuffles when it is told the epoch number;
    # without set_epoch every epoch visits the samples in the same order.
    sampler = getattr(train_loader, 'sampler', None)

    print('start training')
    for epoch in tqdm(range(num_epochs)):
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)

        model.train()
        epoch_loss = 0.0
        for i, (x, d) in enumerate(train_loader):
            # `.to` (as used for validation below) also works for CPU devices
            x, d = x.to(device), d.to(device)

            # reset the gradients back to zero;
            # PyTorch accumulates gradients on subsequent backward passes
            optimizer.zero_grad()

            # compute reconstructions and the training reconstruction loss
            outputs = model(x)
            train_loss = criterion(outputs, d)

            # compute gradients, clip their global norm to 1, then update
            train_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            optimizer.step()

            # add the mini-batch training loss to the epoch loss
            epoch_loss += float(train_loss)

            if i % 100 == 0:
                print("iter : {}/{}, loss = {:.6f}".format(
                    epoch * len_train_loader + i,
                    len_train_loader * num_epochs,
                    float(train_loss)))

        # mean mini-batch loss of the epoch (guarded against an empty loader)
        epoch_loss /= max(len_train_loader, 1)

        # display the epoch training loss
        print("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, num_epochs,
                                                   epoch_loss))

        # update the step-size
        scheduler.step()

        psnrs = []
        model.eval()
        print('Evaluating model')
        with torch.no_grad():
            for obs, gt in loaders['validation']:
                reco = model(obs.to(device)).cpu()
                # NOTE(review): PSNR is applied to whole batches; unless the
                # validation batch size is 1 this presumably differs from a
                # per-image mean — confirm.
                psnrs.append(PSNR(reco, gt))
        print('mean psnr: {:f}'.format(np.mean(psnrs)))

    return model
示例#4
0
    def eval(self, test_data):
        """Return the mean PSNR of ``self.reconstruct`` over ``test_data``.

        ``test_data`` must be sized and iterate over ``(obs, gt)`` pairs.
        The model is switched to evaluation mode first. The progress bar is
        suppressed when ``self.show_pbar`` is falsy.
        """
        self.model.eval()

        with tqdm(test_data, desc='test ',
                  disable=not self.show_pbar) as progress:
            psnr_sum = sum(
                (PSNR(self.reconstruct(observation), target)
                 for observation, target in progress),
                0.0)

        return psnr_sum / len(test_data)
示例#5
0
def plot_reconstructions(reconstructions,
                         titles,
                         ray_trafo,
                         obs,
                         gt,
                         save_name=None,
                         fig_size=(18, 4.5),
                         cmap='pink'):
    """
    Plot a ground truth next to several reconstructions with quality metrics.

    Each reconstruction panel is titled from ``titles``. It is labeled with
    its data discrepancy ``||ray_trafo(reco) - obs||_2`` and with its PSNR
    and SSIM with respect to ``gt``. The ground-truth panel shows the data
    discrepancy of ``gt`` itself as a reference.

    :param reconstructions: sequence of reconstructions (list or tuple).
    :param titles: panel titles, one per reconstruction.
    :param ray_trafo: forward operator applied to ``gt`` and each
        reconstruction.
    :param obs: observed data the forward projections are compared with.
    :param gt: ground-truth image.
    :param save_name: if given, the figure is saved as ``<save_name>.pdf``.
    :param fig_size: figure size passed to ``plot_images``.
    :param cmap: colormap passed to ``plot_images``.
    """
    # accept any sequence (the original ``[gt] + reconstructions`` needed a list)
    reconstructions = list(reconstructions)
    psnrs = [PSNR(reco, gt) for reco in reconstructions]
    ssims = [SSIM(reco, gt) for reco in reconstructions]

    obs_array = obs.asarray()

    def data_error(z):
        # Euclidean norm of the residual between forward projection and data
        return np.sqrt(np.sum(np.power(ray_trafo(z).asarray() - obs_array, 2)))

    l2_error0 = data_error(gt)
    l2_error = [data_error(reco) for reco in reconstructions]

    # plot results: ground truth first, then one panel per reconstruction
    _, ax = plot_images([gt] + reconstructions,
                        fig_size=fig_size,
                        rect=(0.0, 0.0, 1.0, 1.0),
                        xticks=[],
                        yticks=[],
                        vrange=(0.0, 0.9 * np.max(gt.asarray())),
                        cbar=False,
                        interpolation='none',
                        cmap=cmap)
    # flatten in case plot_images returns a grid of axes, so ax[k] is the
    # k-th panel
    ax = np.asarray(ax).reshape(-1)

    # set labels ('\\ell' avoids the invalid escape sequence '\e')
    ax[0].set_title('Ground Truth')
    for j, (err, psnr, ssim) in enumerate(zip(l2_error, psnrs, ssims)):
        panel = ax[j + 1]
        panel.set_title(titles[j])
        panel.set_xlabel(
            '$\\ell_2$ data error: {:.4f}\nPSNR: {:.1f}, SSIM: {:.2f}'.format(
                err, psnr, ssim))

    ax[0].set_xlabel('$\\ell_2$ data error: {:.2f}'.format(l2_error0))

    # NOTE(review): tight_layout is applied twice, presumably on purpose (a
    # second pass can settle label sizes) — confirm it is still needed.
    plt.tight_layout()
    plt.tight_layout()

    if save_name:
        plt.savefig('%s.pdf' % save_name)
    plt.show()
示例#6
0
    def train(self, dataset):
        """Train ``self.model`` on ``dataset``, keeping the best weights.

        Each epoch runs a training phase on the ``'train'`` part of
        ``dataset`` and a validation phase on its ``'validation'`` part, both
        with an MSE loss. Per-sample PSNR is accumulated for monitoring.

        After the last epoch, the weights with the highest mean validation
        PSNR are loaded back into ``self.model``. Whenever they improve, they
        are also saved to ``self.save_best_learned_params_path`` (if set).
        With ``self.log_dir`` set, losses, PSNRs and a few validation
        reconstructions are written to TensorBoard.

        :param dataset: dataset object offering ``create_torch_dataset``,
            ``supports_random_access``, ``get_data_pairs`` and ``space``.
        :raises ImportError: if ``self.log_dir`` is set but tensorboard is
            not installed.
        """
        if self.torch_manual_seed:
            torch.random.manual_seed(self.torch_manual_seed)

        self.init_transform(dataset=dataset)

        # create PyTorch datasets (the transform is applied to the training
        # part only)
        dataset_train = dataset.create_torch_dataset(
            part='train',
            reshape=((1, ) + dataset.space[0].shape,
                     (1, ) + dataset.space[1].shape),
            transform=self._transform)

        dataset_validation = dataset.create_torch_dataset(
            part='validation',
            reshape=((1, ) + dataset.space[0].shape,
                     (1, ) + dataset.space[1].shape))

        # reset model before training
        self.init_model()

        criterion = torch.nn.MSELoss()
        self.init_optimizer(dataset_train=dataset_train)

        # create PyTorch dataloaders
        shuffle = (dataset.supports_random_access()
                   if self.shuffle == 'auto' else self.shuffle)
        data_loaders = {
            'train':
            DataLoader(dataset_train,
                       batch_size=self.batch_size,
                       num_workers=self.num_data_loader_workers,
                       shuffle=shuffle,
                       pin_memory=True,
                       worker_init_fn=self.worker_init_fn),
            'validation':
            DataLoader(dataset_validation,
                       batch_size=self.batch_size,
                       num_workers=self.num_data_loader_workers,
                       shuffle=shuffle,
                       pin_memory=True,
                       worker_init_fn=self.worker_init_fn)
        }

        dataset_sizes = {
            'train': len(dataset_train),
            'validation': len(dataset_validation)
        }

        self.init_scheduler(dataset_train=dataset_train)
        # Cyclic schedulers are stepped after every training batch, all
        # others once per epoch after the training phase. This is False when
        # there is no scheduler.
        schedule_every_batch = isinstance(self._scheduler,
                                          (CyclicLR, OneCycleLR))

        best_model_wts = deepcopy(self.model.state_dict())
        # Start below any attainable value so the first validation result is
        # always recorded. Starting at 0 would silently restore the initial
        # weights if validation PSNR never became positive.
        best_psnr = -float('inf')

        if self.log_dir is not None:
            if not TENSORBOARD_AVAILABLE:
                raise ImportError(
                    'Missing tensorboard. Please install it or disable '
                    'logging by specifying `log_dir=None`.')
            writer = SummaryWriter(log_dir=self.log_dir, max_queue=0)
            validation_samples = dataset.get_data_pairs(
                'validation', self.log_num_validation_samples)

        self.model.to(self.device)
        self.model.train()

        # number of training batches per epoch; x-axis unit for TensorBoard
        steps_per_epoch = ceil(dataset_sizes['train'] / self.batch_size)

        for epoch in range(self.epochs):
            # Each epoch has a training and validation phase
            for phase in ['train', 'validation']:
                if phase == 'train':
                    self.model.train()  # Set model to training mode
                else:
                    self.model.eval()  # Set model to evaluate mode

                running_psnr = 0.0
                running_loss = 0.0
                running_size = 0
                with tqdm(data_loaders[phase],
                          desc='epoch {:d}'.format(epoch + 1),
                          disable=not self.show_pbar) as pbar:
                    for inputs, labels in pbar:
                        if self.normalize_by_opnorm:
                            inputs = (1. / self.opnorm) * inputs
                        inputs = inputs.to(self.device)
                        labels = labels.to(self.device)

                        # zero the parameter gradients
                        self._optimizer.zero_grad()

                        # forward
                        # track gradients only if in train phase
                        with torch.set_grad_enabled(phase == 'train'):
                            outputs = self.model(inputs)
                            loss = criterion(outputs, labels)

                            # backward + optimize only if in training phase
                            if phase == 'train':
                                loss.backward()
                                torch.nn.utils.clip_grad_norm_(
                                    self.model.parameters(), max_norm=1)
                                self._optimizer.step()
                                if schedule_every_batch:
                                    self._scheduler.step()

                        # per-sample PSNR on channel 0 of each batch element
                        for i in range(outputs.shape[0]):
                            labels_ = labels[i, 0].detach().cpu().numpy()
                            outputs_ = outputs[i, 0].detach().cpu().numpy()
                            running_psnr += PSNR(outputs_, labels_)

                        # statistics
                        running_loss += loss.item() * outputs.shape[0]
                        running_size += outputs.shape[0]

                        pbar.set_postfix({
                            'phase': phase,
                            'loss': running_loss / running_size,
                            'psnr': running_psnr / running_size
                        })
                        if self.log_dir is not None and phase == 'train':
                            step = (epoch * steps_per_epoch +
                                    ceil(running_size / self.batch_size))
                            writer.add_scalar(
                                'loss/{}'.format(phase),
                                torch.tensor(running_loss / running_size),
                                step)
                            writer.add_scalar(
                                'psnr/{}'.format(phase),
                                torch.tensor(running_psnr / running_size),
                                step)

                    # Epoch-wise schedulers advance exactly once per epoch,
                    # after the training phase. Stepping after validation as
                    # well would advance them twice per epoch.
                    if (phase == 'train' and self._scheduler is not None
                            and not schedule_every_batch):
                        self._scheduler.step()

                    epoch_loss = running_loss / dataset_sizes[phase]
                    epoch_psnr = running_psnr / dataset_sizes[phase]

                    if self.log_dir is not None and phase == 'validation':
                        step = (epoch + 1) * steps_per_epoch
                        writer.add_scalar('loss/{}'.format(phase), epoch_loss,
                                          step)
                        writer.add_scalar('psnr/{}'.format(phase), epoch_psnr,
                                          step)

                    # deep copy the model (if it is the best one seen so far)
                    if phase == 'validation' and epoch_psnr > best_psnr:
                        best_psnr = epoch_psnr
                        best_model_wts = deepcopy(self.model.state_dict())
                        if self.save_best_learned_params_path is not None:
                            self.save_learned_params(
                                self.save_best_learned_params_path)

                if (phase == 'validation' and self.log_dir is not None
                        and self.log_num_validation_samples > 0):
                    with torch.no_grad():
                        val_images = []
                        for (y, x) in validation_samples:
                            y = torch.from_numpy(np.asarray(y))[None, None].to(
                                self.device)
                            # scale the input exactly as in the training loop
                            if self.normalize_by_opnorm:
                                y = (1. / self.opnorm) * y
                            x = torch.from_numpy(np.asarray(x))[None, None].to(
                                self.device)
                            reco = self.model(y)
                            # rescale to [0, 1] for display
                            reco -= torch.min(reco)
                            reco /= torch.max(reco)
                            val_images += [reco, x]
                        # NOTE(review): 'NCWH' presumably transposes the 2D
                        # images for display — confirm intended orientation.
                        writer.add_images(
                            'validation_samples',
                            torch.cat(val_images),
                            (epoch + 1) * steps_per_epoch,
                            dataformats='NCWH')

        # flush and release the TensorBoard writer
        if self.log_dir is not None:
            writer.close()

        print('Best val psnr: {:4f}'.format(best_psnr))
        self.model.load_state_dict(best_model_wts)
示例#7
0
reconstructor.load_hyper_params(hyper_params_path)

#%% expose FBP cache to reconstructor by assigning `fbp_dataset` attribute
# uncomment the next line to generate the cache files (~20 GB)
# generate_fbp_cache_files(dataset, ray_trafo, CACHE_FILES)
cached_fbp_dataset = get_cached_fbp_dataset(dataset, ray_trafo, CACHE_FILES)
dataset.fbp_dataset = cached_fbp_dataset

#%% train
# reduce the batch size here if the model does not fit into GPU memory
# reconstructor.batch_size = 16
reconstructor.train(dataset)

#%% evaluate
# reconstruct every test sample and record its PSNR against the ground truth
recos = []
psnrs = []
for obs, gt in test_data:
    reco = reconstructor.reconstruct(obs)
    recos.append(reco)
    psnrs.append(PSNR(reco, gt))

print('mean psnr: {:f}'.format(np.mean(psnrs)))

# Show up to three test samples next to their ground truth. The bound avoids
# an IndexError when the test set has fewer than three samples.
for i in range(min(3, len(recos))):
    _, ax = plot_images([recos[i], test_data.ground_truth[i]],
                        fig_size=(10, 4))
    ax[0].set_xlabel('PSNR: {:.2f}'.format(psnrs[i]))
    ax[0].set_title('FBPUNetReconstructor')
    ax[1].set_title('ground truth')
    ax[0].figure.suptitle('test sample {:d}'.format(i))
示例#8
0
def callback_func(iteration, reconstruction, loss):
    """Plot the current DIP reconstruction next to the ground truth.

    Passed as ``callback_func`` to ``DeepImagePriorCTReconstructor`` in this
    script. It reads the module-level ``gt`` and ``TEST_SAMPLE``.
    """
    _, axes = plot_images([reconstruction, gt], fig_size=(10, 4))
    reco_ax, gt_ax = axes[0], axes[1]
    reco_ax.set_xlabel('loss: {:f}'.format(loss))
    reco_ax.set_title('DIP iteration {:d}'.format(iteration))
    gt_ax.set_title('ground truth')
    reco_ax.figure.suptitle('test sample {:d}'.format(TEST_SAMPLE))
    plt.show()

# Deep Image Prior reconstructor on the dataset's ray transform. The plotting
# `callback_func` is presumably invoked every `callback_func_interval` (100)
# iterations — confirm against DeepImagePriorCTReconstructor.
reconstructor = DeepImagePriorCTReconstructor(
    dataset.get_ray_trafo(impl=IMPL),
    callback_func=callback_func, callback_func_interval=100)

#%% obtain reference hyper parameters
# Download the 'diptv' parameters for 'lodopab' only when check_for_params
# reports them missing (presumably a local check), then load them.
if not check_for_params('diptv', 'lodopab'):
    download_params('diptv', 'lodopab')
params_path = get_params_path('diptv', 'lodopab')
reconstructor.load_params(params_path)

#%% evaluate
# Reconstruct the single test observation `obs` and score it against `gt`.
reco = reconstructor.reconstruct(obs)
psnr = PSNR(reco, gt)

print('psnr: {:f}'.format(psnr))
# Side-by-side plot: reconstruction (labeled with its PSNR) and ground truth.
_, ax = plot_images([reco, gt],
                    fig_size=(10, 4))
ax[0].set_xlabel('PSNR: {:.2f}'.format(psnr))
ax[0].set_title('DeepImagePriorCTReconstructor')
ax[1].set_title('ground truth')
ax[0].figure.suptitle('test sample {:d}'.format(TEST_SAMPLE))
示例#9
0
def plot_reconstructors_tests(reconstructors,
                              ray_trafo,
                              test_data,
                              save_name=None,
                              fig_size=(18, 4.5),
                              cmap='pink'):
    """
    Plot each test sample's ground truth next to several reconstructions.

    One figure is drawn for every ``(y_delta, x)`` pair in ``test_data``.
    Panels are arranged with 4 columns: ``x`` first, then the reconstruction
    of ``y_delta`` by each reconstructor. Each reconstruction panel is
    labeled with its data discrepancy ``||ray_trafo(reco) - y_delta||_2``,
    its PSNR and its SSIM.

    :param reconstructors: List of Reconstructor objects to compute the
        reconstructions; their ``name`` attributes become the panel titles.
    :param ray_trafo: forward operator used for the data discrepancy.
    :param test_data: indexable, sized collection of
        ``(observation, ground_truth)`` pairs.
    :param save_name: if given, figure ``i`` is saved as
        ``<save_name>-<i>.pdf``.
    :param fig_size: figure size passed to ``plot_images``.
    :param cmap: colormap passed to ``plot_images``.
    """
    titles = [reconstructor.name for reconstructor in reconstructors]

    def data_error(z, data):
        # Euclidean norm of the residual between forward projection and data
        return np.sqrt(np.sum(np.power(ray_trafo(z).asarray() - data, 2)))

    for i in range(len(test_data)):
        y_delta, x = test_data[i]
        y_array = y_delta.asarray()

        # compute reconstructions and their data errors, psnr and ssim
        recos = [r.reconstruct(y_delta) for r in reconstructors]
        l2_error = [data_error(reco, y_array) for reco in recos]
        l2_error0 = data_error(x, y_array)

        psnrs = [PSNR(reco, x) for reco in recos]
        ssims = [SSIM(reco, x) for reco in recos]

        # plot results: ground truth first, then one panel per reconstructor
        _, ax = plot_images([x] + recos,
                            fig_size=fig_size,
                            rect=(0.0, 0.0, 1.0, 1.0),
                            ncols=4,
                            nrows=-1,
                            xticks=[],
                            yticks=[],
                            vrange=(0.0, 0.9 * np.max(x.asarray())),
                            cbar=False,
                            interpolation='none',
                            cmap=cmap)

        # Set labels. Flatten the axes grid (ncols=4) so that ax[k] is the
        # k-th panel. '\\ell' avoids the invalid escape sequence '\e'.
        ax = ax.reshape(-1)
        ax[0].set_title('Ground Truth')
        ax[0].set_xlabel('$\\ell_2$ data error: {:.2f}'.format(l2_error0))
        for j, (title, err, psnr, ssim) in enumerate(
                zip(titles, l2_error, psnrs, ssims), start=1):
            ax[j].set_title(title)
            ax[j].set_xlabel(
                '$\\ell_2$ data error: {:.4f}\nPSNR: {:.1f}, SSIM: {:.2f}'.
                format(err, psnr, ssim))

        # NOTE(review): tight_layout is applied twice, presumably on purpose
        # (a second pass can settle label sizes) — confirm it is needed.
        plt.tight_layout()
        plt.tight_layout()

        if save_name:
            plt.savefig('%s-%d.pdf' % (save_name, i))
        plt.show()
示例#10
0
                 [0, 4, 4, 4, 4]]

# Skip-connection sweep: for every entry of `skip_channels`, configure the
# shared `params`, run a fresh DIP reconstruction of `obs`, and append it to
# `results`.
# NOTE(review): `ch` and `sc` are not assigned in this loop; they keep the
# values they last had before it (presumably from an earlier sweep) — confirm.
for skip in skip_channels:
    print('Skip Channels:')
    print(skip)
    params.channels = (ch,) * sc
    params.skip_channels = skip
    params.scales = sc
    reconstructor = DeepImagePriorReconstructor(ray_trafo=ray_trafo, hyper_params=params.dict, name='DIP')
    result = reconstructor.reconstruct(obs)
    results.append(result)

# One 3x4 grid (at most 12 panels) of all results, each labeled with PSNR and
# SSIM against `gt`.
fig = plt.figure(figsize=(9.1, 8.3))
for i in range(len(results)):
    ax = fig.add_subplot(3, 4, i+1)
    psnr = PSNR(results[i], gt)
    ssim = SSIM(results[i], gt)
    plot_image(results[i], ax=ax, xticks=[], yticks=[], cmap='pink')
    # Assumes `results` already held 8 entries from a channels/scales sweep
    # (indexed via `channels`/`scales`) before the skip results above were
    # appended — TODO confirm.
    if i < 8:
        ax.set_title('Channels: %d, Scales: %d' % (channels[i], scales[i]))
    else:
        ax.set_title('Skip: {}'.format(skip_channels[i-8]))
    ax.set_xlabel('PSNR: %.2f, SSIM: %.4f' % (psnr, ssim))

# NOTE(review): tight_layout is applied twice, presumably on purpose — confirm.
plt.tight_layout()
plt.tight_layout()

# Save as PDF and as PGF (the PGF backend presumably needs a LaTeX install).
plt.savefig('ellipses_architectures.pdf')
plt.savefig('ellipses_architectures.pgf')
plt.show()
def plot_result(result, gt, iteration):
    """Display `result` with the 'bone' colormap, titled with its PSNR.

    The title reads ``'<iteration>: <psnr>'``. The PSNR of `result` against
    the reference `gt` is printed with three decimals.
    """
    score = PSNR(result, gt)
    plt.imshow(result, cmap='bone')
    caption = '%d: %.3f' % (iteration, score)
    plt.title(caption)
    plt.axis('off')
    plt.show()
示例#12
0
def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point for DistributedDataParallel training of ista_unet.

    The worker joins the process group and builds the model on GPU ``gpu``.
    It divides ``args.BATCH_SIZE`` and ``args.num_workers`` by the number of
    GPUs on this node, then trains with ``fit_model``. On local GPU index 0
    only, it saves the config and weights under a fresh UUID directory and
    prints the mean test PSNR.

    :param gpu: local GPU index of this process (also used as the device).
    :param ngpus_per_node: ignored; recomputed from
        ``torch.cuda.device_count()``.
    :param args: parsed command-line namespace. ``args.gpu``, ``args.rank``,
        ``args.BATCH_SIZE`` and ``args.num_workers`` are updated in place.
    """
    args.gpu = gpu
    # NOTE(review): the `ngpus_per_node` argument is overridden by the local
    # device count; confirm callers always pass the same value.
    ngpus_per_node = torch.cuda.device_count()
    print("Use GPU: {} for training".format(args.gpu))

    # args.rank presumably holds the node rank on entry; turn it into the
    # global rank of this process.
    args.rank = args.rank * ngpus_per_node + gpu
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)

    print('==> Making model..')

    model_setting_dict = {
        'kernel_size': args.KERNEL_SIZE,
        'hidden_layer_width_list': args.HIDDEN_WIDTHS,
        'n_classes': args.IMAGE_CHANNEL_NUM,
        'ista_num_steps': args.ISTA_NUM_STEPS,
        'lasso_lambda_scalar': args.LASSO_LAMBDA_SCALAR,
        'uncouple_adjoint_bool': args.UNCOUPLE_ADJOINT_BOOL,
        'relu_out_bool': args.RELU_OUT_BOOL
    }

    model = ista_unet(**model_setting_dict)
    torch.cuda.set_device(args.gpu)
    model.cuda(args.gpu)
    # Split the batch size and loader workers across this node's GPUs. Keep
    # at least one sample per batch: a batch size of 0 is presumably rejected
    # by the data loader.
    args.BATCH_SIZE = max(1, int(args.BATCH_SIZE / ngpus_per_node))
    args.num_workers = int(args.num_workers / ngpus_per_node)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[args.gpu])
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print('The number of parameters of model is', num_params)

    print('==> Preparing data..')
    dataset_setting_dict = {
        'batch_size': args.BATCH_SIZE,
        'num_workers': args.num_workers,
        'distributed_bool': True
    }

    loaders = get_dataloaders_ellipses(**dataset_setting_dict)

    print(len(loaders['train'].dataset))

    optimizer = optim.Adam(model.parameters(), lr=args.LEARNING_RATE)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     T_max=args.NUM_EPOCH,
                                                     eta_min=2e-5)

    fit_setting_dict = {
        'num_epochs': args.NUM_EPOCH,
        # SECURITY NOTE(review): `eval` executes arbitrary code taken from the
        # command line. Only run with trusted arguments, or map loss names to
        # loss objects explicitly instead.
        'criterion': eval(args.LOSS_STR),
        'optimizer': optimizer,
        'scheduler': scheduler,
        'device': args.gpu
    }

    # NOTE(review): this also holds the optimizer, scheduler and criterion
    # objects, which end up in the pickle below; consider storing only plain
    # hyper-parameters.
    config_dict = {
        **model_setting_dict,
        **dataset_setting_dict,
        **fit_setting_dict
    }

    trained_model = fit_model(model, loaders=loaders, **fit_setting_dict)

    if args.gpu == 0:
        guid = str(uuid.uuid4())
        guid_dir = os.path.join(args.model_save_dir, args.model_name, guid)
        os.makedirs(guid_dir)

        config_dict['guid'] = guid
        with open(os.path.join(guid_dir, 'config_dict.pickle'),
                  'wb') as handle:
            pickle.dump(config_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

        torch.save(trained_model.module.state_dict(),
                   os.path.join(guid_dir, 'ista_unet.pt'))

        # Only this process runs the test loop, so evaluate the unwrapped
        # network instead of going through the DDP wrapper's forward. The
        # network is already on args.gpu (model.cuda above).
        eval_model = trained_model.module
        eval_model.eval()

        psnrs = []
        print('Evaluating model')
        with torch.no_grad():
            for obs, gt in loaders['test']:
                reco = eval_model(obs.to(args.gpu)).cpu()
                psnrs.append(PSNR(reco, gt))
        print('mean psnr: {:f}'.format(np.mean(psnrs)))