예제 #1
0
def cs_total_variation(args, kspace, acquisition, acceleration, num_low_freqs):
    """
    Reconstruct one slice with BART: ESPIRiT sensitivity maps followed by a
    Total-Variation-regularized PICS reconstruction.

    Args:
        args: Options object; reads ``challenge``, ``num_iters`` and ``resolution``.
        kspace (torch.Tensor): k-space with real/imaginary parts in the last
            dimension. For ``args.challenge == 'singlecoil'`` a leading axis is
            added first; the layout is presumably (coils, rows, cols, 2) after
            that — TODO confirm with the caller.
        acquisition: Acquisition protocol; must be a key of ``REG_PARAM[args.challenge]``.
        acceleration (int): Acceleration factor; only 4 and 8 are accepted.
        num_low_freqs (int): Value passed to ``ecalib -r``.

    Returns:
        torch.Tensor: Magnitude of the reconstruction, center-cropped to at most
        ``args.resolution`` in each of its last two dimensions.

    Raises:
        ValueError: If ``acquisition`` or ``acceleration`` is not supported.
    """
    known_protocols = REG_PARAM[args.challenge]
    if acquisition not in known_protocols:
        raise ValueError(f'Invalid acquisition protocol: {acquisition}')
    if acceleration not in {4, 8}:
        raise ValueError(f'Invalid acceleration factor: {acceleration}')

    coil_kspace = kspace.unsqueeze(0) if args.challenge == 'singlecoil' else kspace
    # Move the leading axis to position 2, add a batch axis in front, then
    # convert the trailing real/imaginary pair to a complex NumPy array for BART.
    bart_input = tensor_to_complex_np(coil_kspace.permute(1, 2, 0, 3).unsqueeze(0))

    # Sensitivity-map estimation (BART ecalib).
    sens_maps = bart.bart(1, f'ecalib -d0 -m1 -r {num_low_freqs}', bart_input)

    # TV-regularized reconstruction; the weight depends on protocol and acceleration.
    reg_wt = known_protocols[acquisition][acceleration]
    command = f'pics -d0 -S -R T:7:0:{reg_wt} -i {args.num_iters}'
    recon = bart.bart(1, command, bart_input, sens_maps)
    magnitude = torch.from_numpy(np.abs(recon[0]))

    # Never request a crop larger than the image itself.
    crop_height = min(args.resolution, magnitude.shape[-2])
    crop_width = min(args.resolution, magnitude.shape[-1])
    return transforms.center_crop(magnitude, (crop_height, crop_width))
예제 #2
0
def resize(hparams, image, target):
    """
    Center-crop a complex image (and an optional target) and build a
    normalized magnitude input from it.

    Despite the name, no interpolation happens: the image is only center-cropped.

    Args:
        hparams: Options object; reads ``resolution`` and ``challenge``.
        image (torch.Tensor): Complex image with real/imaginary parts in the
            last dimension (rows at ``shape[-3]``, columns at ``shape[-2]``).
        target (numpy.array or None): Target image (rows at ``shape[-2]``,
            columns at ``shape[-1]``), or None.

    Returns:
        tuple: ``(image, image_abs, target, mean, std)``.
            ``image`` is the cropped complex image.
            ``image_abs`` is its magnitude (root-sum-of-squares combined for
            multicoil), instance-normalized and clamped to [-6, 6].
            ``target`` is the cropped target, normalized with the same
            ``mean``/``std`` and clamped, or ``torch.Tensor([0])`` when no
            target is given.
    """
    crop_height = min(hparams.resolution, image.shape[-3])
    crop_width = min(hparams.resolution, image.shape[-2])
    if target is not None:
        crop_height = min(crop_height, target.shape[-2])
        crop_width = min(crop_width, target.shape[-1])
    crop_size = (crop_height, crop_width)

    image = transforms.complex_center_crop(image, crop_size)
    magnitude = transforms.complex_abs(image)
    if hparams.challenge == "multicoil":
        # Combine per-coil magnitudes into a single image.
        magnitude = transforms.root_sum_of_squares(magnitude)
    magnitude, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
    magnitude = magnitude.clamp(-6, 6)

    if target is None:
        return image, magnitude, torch.Tensor([0]), mean, std

    target = transforms.to_tensor(target)
    target = transforms.center_crop(target, crop_size)
    # Normalize with the input's statistics so both share one scale.
    target = transforms.normalize(target, mean, std, eps=1e-11)
    return image, magnitude, target.clamp(-6, 6), mean, std
예제 #3
0
    def __call__(self, image):
        """
        Turn a DICOM pixel array into a square, normalized tensor.

        Steps:
            1. Reverse the image along axis 0.
            2. Center-crop it to a square whose side is its shorter edge.
            3. Resize to ``self.resolution`` x ``self.resolution`` with bicubic
               interpolation.
            4. Instance-normalize and clamp to [-6, 6].

        Args:
            image (numpy.array): DICOM image.

        Returns:
            torch.Tensor: The normalized, clamped image.
        """
        flipped = np.flip(image, 0)
        side = min(flipped.shape[0], flipped.shape[1])
        square = transforms.center_crop(flipped, (side, side))
        resized = cv2.resize(square, dsize=(self.resolution, self.resolution),
                             interpolation=cv2.INTER_CUBIC)

        normalized, _, _ = transforms.normalize_instance(
            transforms.to_tensor(resized), eps=1e-11)
        return normalized.clamp(-6, 6)
예제 #4
0
 def __call__(self, k_space, mask, target, attrs, f_name, slice):
     """
     Crop k-space to ``self.resolution`` in image space and normalize the target.

     The k-space is inverse-transformed, center-cropped to
     ``self.resolution`` x ``self.resolution`` and transformed back. The mean
     and standard deviation of the cropped image are then used to normalize
     the cropped target, which is clamped to [-6, 6].

     Args:
         k_space (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
             data or (rows, cols, 2) for single coil data.
         mask (numpy.array): Mask from the test dataset. Not used here.
         target (numpy.array): Target image
         attrs (dict): Acquisition related information stored in the HDF5 object. Not used here.
         f_name (str): File name
         slice (int): Serial number of the slice.
     Returns:
         (tuple): tuple containing:
             k_space (torch.Tensor): k-space of the cropped image, shaped
                 (..., resolution, resolution, 2); a leading coil dimension
                 is kept for multi-coil data.
             target (torch.Tensor): Cropped target, normalized with the input's
                 mean/std and clamped to [-6, 6].
             f_name (str): File name
             slice (int): Serial number of the slice.
     """
     k_space = transforms.to_tensor(k_space)
     full_image = transforms.ifft2(k_space)
     cropped_image = transforms.complex_center_crop(
         full_image, (self.resolution, self.resolution))
     k_space = transforms.fft2(cropped_image)
     # Only the input statistics are needed; they put the target on the same
     # scale as the input. The normalized image itself is not returned.
     _, mean, std = transforms.normalize_instance(cropped_image, eps=1e-11)
     # Normalize target
     target = transforms.to_tensor(target)
     target = transforms.center_crop(target,
                                     (self.resolution, self.resolution))
     target = transforms.normalize(target, mean, std, eps=1e-11)
     target = target.clamp(-6, 6)
     return k_space, target, f_name, slice
예제 #5
0
    def __call__(self, img, fname, slice):
        """
        Convert an image array into a cropped, normalized tensor.

        Args:
            img (numpy.array): Image whose third axis is moved to the front
                before cropping (presumably rows x cols x channels — confirm
                with the caller).
            fname (str): File name, returned unchanged.
            slice (int): Slice index; not used.

        Returns:
            tuple: ``(image, mean, std, fname)``.
                ``image`` is the permuted tensor, center-cropped to
                ``self.resolution`` x ``self.resolution``, instance-normalized
                and clamped to [-6, 6].
                ``mean`` and ``std`` are the normalization statistics.
                ``fname`` is the file name.
        """
        permuted = transforms.to_tensor(img).permute(2, 0, 1)
        crop_size = (self.resolution, self.resolution)
        normalized, mean, std = transforms.normalize_instance(
            transforms.center_crop(permuted, crop_size), eps=1e-11)
        return normalized.clamp(-6, 6), mean, std, fname
예제 #6
0
 def test_step(self, batch, batch_idx):
     """
     Run the model on one test batch and package the cropped prediction.

     Args:
         batch: Sequence unpacked as ``(masked_kspace, mask, _, fname, slice, _)``;
             the third and sixth entries are ignored.
         batch_idx (int): Index of the batch; not used.

     Returns:
         dict: ``fname`` and ``slice`` from the batch, plus ``output``: the
         model output center-cropped to ``hparams.resolution`` in both
         dimensions, moved to the CPU and converted to a NumPy array.
     """
     masked_kspace, mask, _, fname, slice, _ = batch
     prediction = self.forward(masked_kspace, mask)
     side = self.hparams.resolution
     cropped = T.center_crop(prediction, (side, side))
     return {'fname': fname, 'slice': slice, 'output': cropped.cpu().numpy()}
예제 #7
0
    def __call__(self, kspace, mask, target, attrs, fname, slice):
        """
        Build a normalized zero-filled magnitude image (and target) from k-space.

        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            mask (numpy.array): Mask from the test dataset. Not read: when
                ``self.mask_func`` is set a mask is generated by
                ``transforms.apply_mask``; otherwise the k-space is used as-is.
            target (numpy.array or None): Target image.
            attrs (dict): Acquisition related information stored in the HDF5 object (unused).
            fname (str): File name; when ``self.use_seed`` is true it also
                seeds the mask.
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled magnitude image.
                    Cropped, root-sum-of-squares combined for multicoil,
                    normalized and clamped to [-6, 6].
                target (torch.Tensor): Target cropped to the same size,
                    normalized with the input's mean/std and clamped, or
                    ``torch.Tensor([0])`` when there is no target.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                fname (str): File name.
                slice (int): Serial number of the slice.
        """
        kspace = transforms.to_tensor(kspace)

        if self.mask_func:
            seed = tuple(map(ord, fname)) if self.use_seed else None
            kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)

        # Zero-filled reconstruction: inverse FFT of the (possibly masked) k-space.
        image = transforms.ifft2(kspace)

        # Crop to self.resolution, but never beyond the image or target size.
        height = min(self.resolution, image.shape[-3])
        width = min(self.resolution, image.shape[-2])
        if target is not None:
            height = min(height, target.shape[-2])
            width = min(width, target.shape[-1])
        crop_size = (height, width)

        image = transforms.complex_abs(transforms.complex_center_crop(image, crop_size))
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        if target is None:
            target = torch.Tensor([0])
        else:
            target = transforms.center_crop(transforms.to_tensor(target), crop_size)
            target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)
        return image, target, mean, std, fname, slice
예제 #8
0
def evaluate(args, recons_key):
    """
    Compute reconstruction metrics over every target volume in ``args.target_path``.

    For each target file, the reconstruction with the same file name is read
    from ``args.predictions_path``. Volumes whose ``acquisition`` or
    ``acceleration`` attribute does not match the requested filter are
    skipped before the prediction file is opened. Their predictions therefore
    do not need to exist.

    Args:
        args: Options object. Reads:
            ``target_path`` and ``predictions_path``: path objects supporting
            ``iterdir`` and ``/``.
            ``acquisition`` and ``acceleration``: filters; a falsy value
            disables that filter.
        recons_key (str): Dataset key of the ground truth inside each target file.

    Returns:
        Metrics: Accumulated metrics for all evaluated volumes.
    """
    metrics = Metrics(METRIC_FUNCS)

    for tgt_file in args.target_path.iterdir():
        with h5py.File(tgt_file, 'r') as target_h5:
            if args.acquisition and args.acquisition != target_h5.attrs['acquisition']:
                continue
            if args.acceleration and target_h5.attrs['acceleration'] != args.acceleration:
                continue
            target = target_h5[recons_key][()]

        # Only opened for volumes that passed the filters.
        with h5py.File(args.predictions_path / tgt_file.name, 'r') as recons_h5:
            recons = recons_h5['reconstruction'][()]

        # Crop both volumes to a square whose side is the target's last dimension.
        crop_size = (target.shape[-1], target.shape[-1])
        target = transforms.center_crop(target, crop_size)
        recons = transforms.center_crop(recons, crop_size)
        metrics.push(target, recons)
    return metrics
예제 #9
0
def k_space_to_image_with_mask(kspace, mask_func=None, seed=None, resolution=320):
    """
    Convert k-space to a normalized zero-filled magnitude image.

    Args:
        kspace (torch.Tensor): k-space with real/imaginary parts in the last
            dimension. It must already be a tensor; no conversion is done here.
        mask_func (callable, optional): If given, passed to
            ``transforms.apply_mask`` to undersample the k-space before the
            inverse FFT.
        seed (optional): Seed forwarded to ``transforms.apply_mask``.
        resolution (int or tuple, optional): Crop size. An int gives a square
            ``resolution`` x ``resolution`` crop; a ``(height, width)`` pair is
            used as-is. Defaults to 320.

    Returns:
        torch.Tensor: Center-cropped magnitude image, instance-normalized and
        clamped to [-6, 6].
    """
    if mask_func:
        kspace, _ = transforms.apply_mask(kspace, mask_func, seed)
    # Zero-filled reconstruction: inverse FFT of the (possibly masked) k-space.
    image = transforms.complex_abs(transforms.ifft2(kspace))

    if isinstance(resolution, (tuple, list)):
        crop_size = tuple(resolution)
    else:
        crop_size = (resolution, resolution)
    image = transforms.center_crop(image, crop_size)

    image, _, _ = transforms.normalize_instance(image, eps=1e-11)
    return image.clamp(-6, 6)
예제 #10
0
def cs_total_variation(args, kspace):
    """
    Run ESPIRIT coil sensitivity estimation and Total Variation Minimization based
    reconstruction algorithm using the BART toolkit.

    Args:
        args: Options object; reads ``challenge``, ``reg_wt``, ``num_iters``
            and ``resolution``.
        kspace (torch.Tensor): k-space with real/imaginary parts in the last
            dimension. For ``args.challenge == 'singlecoil'`` a leading axis is
            added first. The layout is presumably (coils, rows, cols, 2) after
            that — TODO confirm with the caller.

    Returns:
        torch.Tensor: Magnitude of the reconstruction, center-cropped to
        ``args.resolution`` in each of its last two dimensions. A dimension
        that is already smaller than that is kept whole.
    """
    if args.challenge == 'singlecoil':
        kspace = kspace.unsqueeze(0)
    kspace = kspace.permute(1, 2, 0, 3).unsqueeze(0)
    kspace = tensor_to_complex_np(kspace)

    # Estimate sensitivity maps
    sens_maps = bart.bart(1, 'ecalib -d0 -m1', kspace)

    # Use Total Variation Minimization to reconstruct the image
    pred = bart.bart(
        1, f'pics -d0 -S -R T:7:0:{args.reg_wt} -i {args.num_iters}', kspace,
        sens_maps)
    pred = torch.from_numpy(np.abs(pred[0]))

    # Crop to the requested size, but never request a crop larger than the
    # image itself.
    crop_height = min(args.resolution, pred.shape[-2])
    crop_width = min(args.resolution, pred.shape[-1])
    return transforms.center_crop(pred, (crop_height, crop_width))
예제 #11
0
def evaluate():
    """
    Compute metrics between saved reconstructions and their targets.

    Options are parsed with ``create_arg_parser()``. Reconstructions are read
    from ``summary/<test_name>/rec``; targets with the same file name are read
    from ``<data_path>/<data_split>``.

    Each target is prepared as follows:
        1. Permuted so its third axis comes first.
        2. Center-cropped to 144 x 144.
        3. Instance-normalized and clamped to [-6, 6].

    Pair handling:
        - Pairs whose prepared target has a maximum of 0 are skipped.
        - Target and reconstruction are then min-max scaled to [0, 1] and
          pushed to the metrics.
        - Pairs where either image is constant are skipped, because scaling
          would divide by zero and push NaN into the metrics.

    Returns:
        Metrics: Accumulated metrics.
    """
    args = create_arg_parser().parse_args()
    args.target_path = f'{args.data_path}/{args.data_split}'
    args.predictions_path = f'summary/{args.test_name}/rec'
    metrics = Metrics(METRIC_FUNCS)
    target_dir = pathlib.Path(args.target_path)

    for rcn_file in pathlib.Path(args.predictions_path).iterdir():
        # Open read-only explicitly: older h5py versions default to append mode.
        with h5py.File(rcn_file, 'r') as recons_h5, h5py.File(
                target_dir / rcn_file.name, 'r') as target_h5:
            target = target_h5['data'][()]
            recons = recons_h5['reconstruction'][()]

        target = transforms.to_tensor(target)
        target = transforms.center_crop(target.permute(2, 0, 1), (144, 144))
        target, mean, std = transforms.normalize_instance(target, eps=1e-11)
        target = target.clamp(-6, 6)
        if target.max() == 0:
            continue

        target -= target.min()
        recons -= recons.min()
        target_range = target.max()
        recons_range = recons.max()
        if target_range == 0 or recons_range == 0:
            # Constant image: min-max scaling would divide by zero.
            continue
        target /= target_range
        recons /= recons_range
        metrics.push(target.numpy(), recons.squeeze())
    return metrics
예제 #12
0
def test_center_crop(shape, target_shape):
    """Center-cropping an input of ``shape`` must yield exactly ``target_shape``."""
    data = create_input(shape)
    cropped = transforms.center_crop(data, target_shape).numpy()
    assert list(cropped.shape) == target_shape