Example #1
def __call__(self, sample):
    # Crop each requested entry to at most self.resolution along the last two axes
    for k in self.keys:
        x = sample[k]
        nx, ny = x.shape[-2], x.shape[-1]
        sample[k] = center_crop(x, (min(nx, self.resolution), min(ny, self.resolution)))
    return sample
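All of the examples call `center_crop`, whose definition is not part of this page. A minimal sketch of a compatible helper, assuming it symmetrically crops the last two dimensions to the requested (height, width); the project's actual implementation may handle edge cases differently:

def center_crop(x, shape):
    # Hedged sketch: crop the trailing two axes of `x` to `shape`, centered.
    # Works for torch tensors and NumPy arrays alike via `...` indexing.
    assert 0 < shape[0] <= x.shape[-2] and 0 < shape[1] <= x.shape[-1]
    top = (x.shape[-2] - shape[0]) // 2
    left = (x.shape[-1] - shape[1]) // 2
    return x[..., top:top + shape[0], left:left + shape[1]]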
Example #2
def run_sn(args, data_loader, model):
    """ Run Sigmanet """
    model.eval()
    logging.info('Run Sigmanet reconstruction')
    logging.info(f'Arguments: {args}')
    reconstructions = defaultdict(list)
    # keys = ['input', 'kspace', 'smaps', 'mask', 'fg_mask']
    # if args.mask_bg:
    #     keys.append('input_rss_mean')
    # attr_keys = ['mean', 'cov', 'norm']

    with torch.no_grad():
        for ii, sample in enumerate(tqdm(iter(data_loader))):
            sample = data_batch._read_data(sample, device=args.device)

            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']

            x = model(sample['input'], sample['kspace'], sample['smaps'],
                      sample['mask'], sample['attrs'])

            recons = postprocess(x, (rec_x, rec_y))

            # mask background using background mean value
            if args.mask_bg:
                fg_mask = center_crop(
                    sample['fg_mask'],
                    (rec_x, rec_y),
                ).squeeze(1)
                if args.use_bg_noise_mean:
                    bg_mean = sample['input_rss_mean'].reshape(-1, 1, 1)
                    recons = recons * fg_mask + (1 - fg_mask) * bg_mean
                else:
                    recons = recons * fg_mask

            # renormalize
            norm = sample['attrs']['norm'].reshape(-1, 1, 1)
            recons = recons * norm

            recons = recons.to('cpu').numpy()

            if args.debug and ii % 10 == 0:
                plt.imsave(
                    'run_sn_progress.png',
                    np.hstack(recons),
                    cmap='gray',
                )

            for bidx in range(recons.shape[0]):
                reconstructions[sample['fname']].append(
                    (sample['slidx'][bidx], recons[bidx]))

    reconstructions = {
        fname: np.stack([pred for _, pred in sorted(slice_preds)])
        for fname, slice_preds in reconstructions.items()
    }

    save_reconstructions(reconstructions, args.out_dir)
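Both runners end by handing the per-volume arrays to `save_reconstructions`, which is not shown here. Below is a minimal sketch under the assumption that it follows the common fastMRI convention of writing one HDF5 file per volume with a 'reconstruction' dataset:

import h5py
from pathlib import Path

def save_reconstructions(reconstructions, out_dir):
    # Hypothetical sketch: `reconstructions` maps file name -> stacked slice array.
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for fname, volume in reconstructions.items():
        with h5py.File(out_dir / fname, 'w') as f:
            f.create_dataset('reconstruction', data=volume)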
Example #3
def postprocess(tensor, shape):
    """Postprocess the tensor to be magnitude image and crop to the ROI,
    which is (min(nFE, shape[0]), min(nPE, shape[1]). The method expects either a
    tensor representing complex values
    (with shape [bs, nsmaps, nx, ny, 2])
    or a real-valued tensor
    (with shape [bs, nsmaps, nx, ny])

    """
    if tensor.shape[-1] == 2:
        tensor = mytorch.mri.root_sum_of_squares(tensor, dim=(1, -1), eps=1e-9)
    cropsize = (min(tensor.shape[-2], shape[0]), min(tensor.shape[-1], shape[1]))
    return center_crop(tensor, cropsize)
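`postprocess` delegates the coil combination to `mytorch.mri.root_sum_of_squares`, which is not shown here. A minimal sketch of the usual root-sum-of-squares reduction, assuming `eps` is simply added under the square root (the real implementation may differ):

import torch

def root_sum_of_squares(x, dim=1, eps=0.0):
    # Hedged sketch: sqrt of the sum of squares along `dim` (an axis or a tuple
    # of axes). For the complex case above, dim=(1, -1) collapses the coil axis
    # and the real/imaginary axis into one magnitude image per slice.
    return torch.sqrt((x ** 2).sum(dim=dim) + eps)

# Example with hypothetical shapes: batch of 2, 15 coil images, 320x368 grid.
x = torch.randn(2, 15, 320, 368, 2)                # [bs, nsmaps, nx, ny, 2]
mag = root_sum_of_squares(x, dim=(1, -1), eps=1e-9)
print(mag.shape)                                    # torch.Size([2, 320, 368])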
Example #4
def run_zero_filled_sense(args, data_loader):
    """ Run Adjoint (zero-filled SENSE) reconstruction """
    logging.info('Run zero-filled SENSE reconstruction')
    logging.info(f'Arguments: {args}')
    reconstructions = defaultdict(list)

    with torch.no_grad():
        for sample in tqdm(iter(data_loader)):
            sample = data_batch._read_data(sample)

            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']

            x = sample['input']

            recons = postprocess(x, (rec_x, rec_y))

            # mask background using background mean value
            if args.mask_bg:
                fg_mask = center_crop(
                    sample['fg_mask'],
                    (rec_x, rec_y),
                ).squeeze(1)
                if args.use_bg_noise_mean:
                    bg_mean = sample['input_rss_mean'].reshape(-1, 1, 1)
                    recons = recons * fg_mask + (1 - fg_mask) * bg_mean
                else:
                    recons = recons * fg_mask

            # renormalize
            norm = sample['attrs']['norm'].numpy()[:, np.newaxis, np.newaxis]
            recons = recons.numpy() * norm

            for bidx in range(recons.shape[0]):
                reconstructions[sample['fname']].append(
                    (sample['slidx'][bidx], recons[bidx]))

    reconstructions = {
        fname: np.stack([pred for _, pred in sorted(slice_preds)])
        for fname, slice_preds in reconstructions.items()
    }

    save_reconstructions(reconstructions, args.out_dir)
Example #5
def unpad2d(self, tensor, shape):
    # Undo zero-padding: crop back to `shape` unless the tensor already matches it
    if tensor.shape == shape:
        return tensor
    else:
        return center_crop(tensor, shape)
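A hedged usage sketch for `unpad2d`, wrapping it in a hypothetical stand-in class and reusing the `center_crop` sketch from Example #1; the tensor is 2D here so `shape` can be compared against `tensor.shape` directly:

import torch

class _CropOp:
    # Hypothetical stand-in for the class this method actually belongs to.
    def unpad2d(self, tensor, shape):
        if tensor.shape == shape:
            return tensor
        return center_crop(tensor, shape)

x = torch.randn(320, 320)             # e.g. an image padded up for the network
y = _CropOp().unpad2d(x, (318, 316))  # crop back to the pre-padding size
print(y.shape)                        # torch.Size([318, 316])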