Пример #1
0
def train_step(model, data, device):
    """Compute the combined k-space + image-domain L1 loss for one batch.

    The network prediction is made data-consistent by keeping the input
    values where ``mask`` is 1 and the network output elsewhere. Two L1
    terms are then summed:

    * k-space term: between the (still normalized) data-consistent
      prediction and the (still normalized) target;
    * image term: after un-normalizing with ``mean``/``std``, inverse FFT,
      320x320 center crop and magnitude, both images are re-normalized with
      ``mean_image``/``std_image``; only the target image is clamped to
      [-6, 6].

    Args:
        model: network called on the input permuted with (0, 3, 1, 2); its
            output is permuted back with (0, 2, 3, 1).
        data: 7-tuple ``(input, target, mean, std, mean_image, std_image, mask)``.
        device: torch device every tensor is moved to.

    Returns:
        Scalar loss tensor (sum of the two L1 terms).
    """
    kspace_in, kspace_target, mean, std, mean_image, std_image, mask = data
    kspace_in = kspace_in.to(device)
    mask = mask.to(device)
    kspace_target = kspace_target.to(device)

    prediction = model(kspace_in.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    # Data consistency: keep the input where mask == 1, network output elsewhere.
    prediction = kspace_in * mask + (1 - mask) * prediction

    # k-space term on the normalized tensors.
    loss_k_consistent = F.l1_loss(prediction, kspace_target)

    mean = mean.to(device)
    std = std.to(device)
    mean_image = mean_image.unsqueeze(1).unsqueeze(2).to(device)
    std_image = std_image.unsqueeze(1).unsqueeze(2).to(device)

    def to_normalized_magnitude(kspace):
        # Un-normalize k-space, go to image space, crop, take magnitude and
        # re-normalize with the per-image statistics.
        image = transforms.ifft2(transforms.unnormalize(kspace, mean, std))
        image = transforms.complex_abs(
            transforms.complex_center_crop(image, (320, 320)))
        return transforms.normalize(image, mean_image, std_image)

    output_image = to_normalized_magnitude(prediction)
    target_image = to_normalized_magnitude(kspace_target).clamp(-6, 6)

    # Image term on the image-normalized magnitudes.
    loss_image = F.l1_loss(output_image, target_image)
    return loss_k_consistent + loss_image
Пример #2
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """Build one training sample from a k-space slice.

        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image; returned as a tensor for inference.
            attrs (dict): Acquisition related information stored in the HDF5 object.
                Must contain ``'max'`` and ``'norm'``.
            fname (str): File name; seeds the mask when ``self.use_seed`` is set.
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Full-size zero-filled complex image normalized
                    with ``mean``/``std`` (clamped to [-6, 6] when ``CLAMP``).
                target_train (torch.Tensor): 320x320 center-cropped magnitude of the
                    fully-sampled image; renormalized with ``mean_abs``/``std_abs``
                    when ``RENORM`` and clamped when ``CLAMP``.
                target_kspace_train (torch.Tensor): ``fft2`` of the fully-sampled
                    image normalized with ``mean``/``std``.
                mean, std: Statistics from ``normalize_instance_complex`` on the
                    cropped zero-filled image.
                mask (torch.Tensor): Sampling mask; all ones when ``self.use_mask``
                    is False.
                mean_abs, std_abs: Statistics of the cropped zero-filled magnitude.
                target_inference (torch.Tensor): ``target`` converted to a tensor.
                max: ``attrs['max']``.
                norm (numpy.float32): ``attrs['norm']``.
        """
        target_inference = transforms.to_tensor(target)

        kspace = transforms.to_tensor(kspace)
        full_image = transforms.ifft2(kspace)

        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        if self.use_mask:
            mask = transforms.get_mask(kspace, self.mask_func, seed)
            masked_kspace = mask * kspace
        else:
            # ``mask`` used to be left unbound on this path, so the return
            # statement raised UnboundLocalError. An all-ones mask ("every
            # sample measured") keeps the returned tuple layout stable.
            mask = kspace.new_ones(kspace.shape)
            masked_kspace = kspace

        image = transforms.ifft2(masked_kspace)
        image_crop = transforms.complex_center_crop(
            image, (self.resolution, self.resolution))
        # Only the statistics are needed; the normalized crops are discarded.
        _, mean, std = transforms.normalize_instance_complex(image_crop,
                                                             eps=1e-11)
        image_abs = transforms.complex_abs(image_crop)
        _, mean_abs, std_abs = transforms.normalize_instance(image_abs,
                                                             eps=1e-11)

        image = transforms.normalize(image, mean, std)

        # Training k-space target: fully-sampled image normalized with the
        # zero-filled statistics, transformed back to k-space.
        target_image_complex_norm = transforms.normalize(full_image, mean, std)
        target_kspace_train = transforms.fft2(target_image_complex_norm)

        target_train = transforms.complex_abs(
            transforms.complex_center_crop(full_image, (320, 320)))

        if RENORM:
            target_train = transforms.normalize(target_train, mean_abs,
                                                std_abs)

        if CLAMP:
            image = image.clamp(-6, 6)
            target_train = target_train.clamp(-6, 6)

        return (image, target_train, target_kspace_train, mean, std, mask,
                mean_abs, std_abs, target_inference, attrs['max'],
                attrs['norm'].astype(np.float32))
Пример #3
0
 def forward(self, masked_kspace, mask):
     """Run the unrolled cascades and return a root-sum-of-squares image.

     Coil sensitivities are estimated once by ``self.sens_net`` and shared by
     every cascade. Each cascade refines the current k-space estimate and
     also receives the original ``masked_kspace`` and ``mask``.

     Returns:
         ``T.root_sum_of_squares`` over ``dim=1`` (presumably the coil axis)
         of the magnitude of the inverse FFT of the final estimate.
     """
     sensitivity_maps = self.sens_net(masked_kspace, mask)
     estimate = masked_kspace.clone()
     for refine in self.cascades:
         estimate = refine(estimate, masked_kspace, mask, sensitivity_maps)
     magnitude = T.complex_abs(T.ifft2(estimate))
     return T.root_sum_of_squares(magnitude, dim=1)
Пример #4
0
 def __call__(self, k_space, mask, target, attrs, f_name, slice):
     """
     Args:
         k_space (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
             data or (rows, cols, 2) for single coil data.
         mask (numpy.array): Mask from the test dataset. Not used by this transform.
         target (numpy.array): Target image
         attrs (dict): Acquisition related information stored in the HDF5 object.
             Not used by this transform.
         f_name (str): File name
         slice (int): Serial number of the slice.
     Returns:
         (tuple): tuple containing:
             k_space (torch.Tensor): k-space of the center-cropped image
                 (resolution x resolution x 2 for single-coil input).
             target (torch.Tensor): Center-cropped target, normalized with the
                 statistics ``normalize_instance`` reports for the cropped
                 complex image, then clamped to [-6, 6].
             f_name (str): File name
             slice (int): Serial number of the slice.
     """
     k_space = transforms.to_tensor(k_space)
     full_image = transforms.ifft2(k_space)
     cropped_image = transforms.complex_center_crop(
         full_image, (self.resolution, self.resolution))
     k_space = transforms.fft2(cropped_image)
     # Only the statistics are needed to normalize the target; the normalized
     # (and previously clamped) image itself was never used.
     _, mean, std = transforms.normalize_instance(cropped_image, eps=1e-11)
     # Normalize target
     target = transforms.to_tensor(target)
     target = transforms.center_crop(target,
                                     (self.resolution, self.resolution))
     target = transforms.normalize(target, mean, std, eps=1e-11)
     target = target.clamp(-6, 6)
     return k_space, target, f_name, slice
Пример #5
0
 def __call__(self, kspace, target, attrs, fname, slice):
     """Build a k-space-domain training sample.

     Args:
         kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
             data or (rows, cols, 2) for single coil data.
         target (numpy.array): Target image. Not used by this transform.
         attrs (dict): Acquisition related information stored in the HDF5 object. Not used.
         fname (str): File name; seeds the mask when ``self.use_seed`` is set.
         slice (int): Serial number of the slice. Not used.
     Returns:
         (tuple): tuple containing:
             masked_kspace (torch.Tensor): Undersampled k-space normalized with
                 ``normalize_instance_complex``.
             kspace (torch.Tensor): Fully-sampled k-space normalized with the same
                 ``mean``/``std``.
             mean, std: Statistics used to normalize both k-space tensors.
             mean_image, std_image: Statistics of the zero-filled image
                 (``normalize_instance`` with ``eps=1e-11``).
             mask (torch.Tensor): Sampling mask; all ones when ``self.use_mask``
                 is False.
     """
     kspace = transforms.to_tensor(kspace)
     # Apply mask
     seed = None if not self.use_seed else tuple(map(ord, fname))
     if self.use_mask:
         mask = transforms.get_mask(kspace, self.mask_func, seed)
         masked_kspace = mask * kspace
     else:
         # ``mask`` used to be unbound on this path, so the return statement
         # raised UnboundLocalError. All ones means "every sample measured".
         mask = kspace.new_ones(kspace.shape)
         masked_kspace = kspace
     image = transforms.ifft2(masked_kspace)
     _, mean_image, std_image = transforms.normalize_instance(image, eps=1e-11)
     masked_kspace, mean, std = transforms.normalize_instance_complex(masked_kspace)
     kspace = transforms.normalize(kspace, mean, std)

     return masked_kspace, kspace, mean, std, mean_image, std_image, mask
Пример #6
0
    def __call__(self, kspace, target, challenge, fname, slice_index):
        """Produce full and undersampled k-space plus a normalized target.

        The zero-filled magnitude image is built only to obtain the statistics
        used to normalize ``target``; it is not returned.

        Returns:
            (original_kspace, masked_kspace, mask, target, fname, slice_index).
            Both k-space tensors are passed through ``cartesianToPolar`` when
            ``self.polar`` is set. ``target`` is normalized with the zero-filled
            statistics (``eps=1e-11``) and clamped to [-6, 6].
        """
        full_kspace = transforms.to_tensor(kspace)
        if self.reduce:
            full_kspace = reducedimension(full_kspace, self.resolution)

        seed = tuple(map(ord, fname)) if self.use_seed else None
        undersampled, mask = transforms.apply_mask(full_kspace,
                                                   self.mask_func, seed)

        # Zero-filled magnitude image -> normalization statistics.
        zero_filled = transforms.complex_center_crop(
            transforms.ifft2(undersampled), (self.resolution, self.resolution))
        zero_filled = transforms.complex_abs(zero_filled)
        if challenge == 'multicoil':
            zero_filled = transforms.root_sum_of_squares(zero_filled)
        _, mean, std = transforms.normalize_instance(zero_filled, eps=1e-11)

        normalized_target = transforms.normalize(
            transforms.to_tensor(target), mean, std, eps=1e-11).clamp(-6, 6)

        if self.polar:
            full_kspace = cartesianToPolar(full_kspace)
            undersampled = cartesianToPolar(undersampled)

        return full_kspace, undersampled, mask, normalized_target, fname, slice_index
Пример #7
0
def nufft_adjoint(input, coord, oshape, oversamp=1.25, width=4.0, n=128, device='cuda'):
    """Adjoint NUFFT: grid non-Cartesian samples and transform to image space.

    Steps: scale ``coord`` to the oversampled grid, grid ``input`` with the
    kernel from ``_get_kaiser_bessel_kernel``, inverse FFT with
    ``transforms.ifft2``, resize to ``oshape``, rescale, and apodize.

    Args:
        input: non-Cartesian samples passed to ``interp.gridding``.
        coord: sample coordinates; ``coord.shape[-1]`` is the number of
            spatial dimensions.
        oshape: output shape (converted to a list).
        oversamp: grid oversampling factor.
        width: kernel width.
        n: kernel table size passed to ``_get_kaiser_bessel_kernel``.
        device: device forwarded to the helpers.

    Returns:
        The resulting tensor with axes permuted by (0, 2, 3, 1).

    NOTE(review): the scale line mirrors sigpy's nufft_adjoint, which uses an
    unnormalized IFFT (norm=None). If ``transforms.ifft2`` is orthonormal
    (as ``test_ifft2`` in this file checks against numpy ``norm='ortho'``),
    the overall scale differs from sigpy's by sqrt(prod(os_shape[-ndim:]))
    — confirm which scaling callers expect.
    """
    ndim = coord.shape[-1]
    beta = numpy.pi * (((width / oversamp) * (oversamp - 0.5)) ** 2 - 0.8) ** 0.5
    oshape = list(oshape)
    os_shape = _get_oversamp_shape(oshape, ndim, oversamp)

    # Gridding onto the oversampled Cartesian grid.
    scaled_coord = _scale_coord(coord, oshape, oversamp, device)
    kernel = _get_kaiser_bessel_kernel(n, width, beta, scaled_coord.dtype, device)
    grid = interp.gridding(input, os_shape, width, kernel, scaled_coord, device)

    # IFFT; transforms.ifft2 presumably expects the real/imag pair last.
    image = transforms.ifft2(grid.permute(0, 2, 3, 1))

    # Crop back to oshape and rescale (in place, as before).
    image = util.resize(image.permute(0, 3, 1, 2), oshape, device=device)
    image *= util.prod(os_shape[-ndim:]) / util.prod(oshape[-ndim:]) ** 0.5

    # Apodize
    image = _apodize(image, ndim, oversamp, width, beta, device)
    return image.permute(0, 2, 3, 1)
Пример #8
0
    def __call__(self, ksp, sens, mask, fname, slice):
        """Prepare a multi-coil sample as normalized magnitude/phase images.

        Args:
            ksp (numpy.ndarray): k-space whose third axis interleaves real
                (even indices) and imaginary (odd indices) parts.
            sens (numpy.ndarray): coil sensitivity maps.
            mask (numpy.ndarray): sampling mask; duplicated into a trailing
                axis of size 2 and cast to float.
            fname: path-like object; only ``fname.name`` is returned.
            slice (int): slice index.

        Returns:
            tuple: (padded magnitude / scale, padded phase, k-space / scale,
            sensitivities, mask, file name, slice, max of the numpy zero-filled
            reconstruction), where ``scale`` is the maximum of the padded
            coil-combined magnitude.
        """
        mask = torch.from_numpy(mask)
        mask = torch.stack((mask, mask), dim=-1).float()

        # De-interleave real/imaginary parts once and reuse the result.
        ksp_cmplx = ksp[:, :, ::2] + 1j * ksp[:, :, 1::2]
        sens_t = T.to_tensor(sens)
        ksp_t = T.to_tensor(ksp_cmplx)
        ksp_us = ksp_t.permute(2, 0, 1, 3)

        img_us = T.ifft2(ksp_us)
        img_us_sens = T.combine_all_coils(img_us, sens_t)

        pha_us = T.phase(img_us_sens)
        mag_us = T.complex_abs(img_us_sens)

        mag_us_pad = T.pad(mag_us, [256, 256])
        pha_us_pad = T.pad(pha_us, [256, 256])

        # The complex array used to be rebuilt from ``ksp`` here with the same
        # expression; ``ksp_cmplx`` already holds exactly those values.
        img_us_np = T.zero_filled_reconstruction(ksp_cmplx)

        # Computed once instead of twice.
        scale = mag_us_pad.max()

        return (mag_us_pad / scale, pha_us_pad, ksp_us / scale, sens_t, mask,
                fname.name, slice, img_us_np.max())
Пример #9
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """Return the k-space of the normalized, clamped magnitude image.

        Args:
            kspace (numpy.Array): k-space measurements
            target (numpy.Array): Target image (not used)
            attrs (dict): Acquisition related information stored in the HDF5 object (not used)
            fname (pathlib.Path): Path to the input file
            slice (int): Serial number of the slice
        Returns:
            (tuple): tuple containing:
                kspace (torch.Tensor): ``transforms.rfft2`` of the normalized
                    magnitude image clamped to [-6, 6]
                mean (float): Mean of the cropped magnitude image
                std (float): Standard deviation of the cropped magnitude image
                fname (pathlib.Path): Path to the input file
                slice (int): Serial number of the slice
        """
        data = transforms.to_tensor(kspace)
        crop = (self.resolution, self.resolution)
        image = transforms.complex_center_crop(transforms.ifft2(data), crop)
        # RSS is applied to the complex tensor before the magnitude. If RSS is
        # sqrt(sum(x**2)) over coils, this equals taking |.| first and then
        # RSS (up to rounding) — confirm against transforms.root_sum_of_squares.
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        magnitude = transforms.complex_abs(image)
        normalized, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
        return transforms.rfft2(normalized.clamp(-6, 6)), mean, std, fname, slice
 def __call__(self, kspace, target, attrs, fname, slice):
     """Return the normalized zero-filled magnitude image for one slice.

     Args:
         kspace (numpy.Array): k-space measurements
         target (numpy.Array): Target image (not used)
         attrs (dict): Acquisition related information stored in the HDF5 object (not used)
         fname (pathlib.Path): Path to the input file
         slice (int): Serial number of the slice
     Returns:
         (tuple): tuple containing:
             image (torch.Tensor): Normalized zero-filled input image
             mean (float): Mean of the zero-filled image
             std (float): Standard deviation of the zero-filled image
             fname (pathlib.Path): Path to the input file
             slice (int): Serial number of the slice

     NOTE(review): ``tuple(map(ord, fname))`` needs a str; a pathlib.Path is
     not iterable, so confirm what callers actually pass as ``fname``.
     """
     data = transforms.to_tensor(kspace)
     undersampled = data
     if self.mask_func is not None:
         seed = tuple(map(ord, fname))
         undersampled, _ = transforms.apply_mask(data, self.mask_func, seed)

     # Zero-filled solution -> crop -> magnitude.
     crop = (self.resolution, self.resolution)
     image = transforms.complex_center_crop(transforms.ifft2(undersampled), crop)
     image = transforms.complex_abs(image)
     if self.which_challenge == 'multicoil':
         image = transforms.root_sum_of_squares(image)

     # Default eps of normalize_instance is used here (unlike other transforms).
     normalized, mean, std = transforms.normalize_instance(image)
     return normalized.clamp(-6, 6), mean, std, fname, slice
Пример #11
0
def data_for_training(rawdata, sensitivity, mask_func, norm=True):
    '''Build undersampled / fully-sampled coil-combined data for one slice.

    When ``norm`` is truthy, every output is divided by the maximum complex
    magnitude of the zero-filled coil images (floored at 1e-6), because no
    ground truth exists at inference time.

    Args:
        rawdata: numpy k-space; transposed with (2, 0, 1) and cast to complex64.
        sensitivity: numpy coil sensitivities; transposed with (2, 0, 1).
        mask_func: called with the k-space shape array whose leading
            dimensions (all but the last three) are set to 1.
        norm: whether to apply the max-magnitude normalization.

    Returns:
        (sense_und, sense_gt, rawdata_und, masks, sensitivity)
    '''
    kspace = T.to_tensor(np.complex64(rawdata.transpose(2, 0, 1)))
    sensitivity = T.to_tensor(sensitivity.transpose(2, 0, 1))
    coils, Ny, Nx, ps = kspace.shape

    # ifftshift, then multiply by a (-1)**(x + y) checkerboard.
    xx, yy = np.meshgrid(np.arange(1, Nx + 1), np.arange(1, Ny + 1))
    checkerboard = torch.from_numpy((-1) ** (xx + yy)).view(1, Ny, Nx, 1).float()
    shift_kspace = T.ifftshift(kspace, dim=(-3, -2)) * checkerboard

    # Build and shift the sampling mask.
    mask_shape = np.array(shift_kspace.shape)
    mask_shape[:-3] = 1
    mask = T.ifftshift(mask_func(mask_shape))

    # Undersample.
    masked_kspace = torch.where(mask == 0, torch.Tensor([0]), shift_kspace)
    masks = mask.repeat(coils, Ny, 1, ps)

    img_gt = T.ifft2(shift_kspace)
    img_und = T.ifft2(masked_kspace)

    scale = 1
    if norm:
        scale = T.complex_abs(img_und).max()
        if scale < 1e-6:
            scale = 1e-6

    img_gt = img_gt / scale
    img_und = img_und / scale
    rawdata_und = masked_kspace / scale

    sense_gt = cobmine_all_coils(img_gt, sensitivity)
    sense_und = cobmine_all_coils(img_und, sensitivity)
    return sense_und, sense_gt, rawdata_und, masks, sensitivity
Пример #12
0
def forward_adjoint_helper(device, hparams, mask_func, kspace, target=None):
    """Undersample ``kspace``, inverse-FFT it and resize the result.

    Returns:
        (image, image_abs): the first two of the five values returned by
        ``resize(hparams, image, target)``.
    """
    masked_kspace, _ = apply_mask(device, mask_func, kspace, hparams.seed)
    kspace_tensor = (masked_kspace if torch.is_tensor(masked_kspace)
                     else transforms.to_tensor(masked_kspace))
    resized = resize(hparams, transforms.ifft2(kspace_tensor), target)
    image, image_abs, _, _, _ = resized
    return image, image_abs
Пример #13
0
def k_space_to_image_with_mask(kspace, mask_func=None, seed=None, crop_size=(320, 320)):
    """Convert k-space to a normalized, clamped magnitude image.

    Args:
        kspace: k-space tensor accepted by ``transforms.ifft2``.
        mask_func: optional mask function; when truthy it is applied with
            ``transforms.apply_mask`` before the inverse FFT.
        seed: seed forwarded to ``transforms.apply_mask``.
        crop_size: (height, width) of the center crop applied to the
            magnitude image; defaults to the previously hard-coded (320, 320).

    Returns:
        The cropped magnitude image normalized with its own mean/std
        (``eps=1e-11``) and clamped to [-6, 6].
    """
    if mask_func:
        kspace, _ = transforms.apply_mask(kspace, mask_func, seed)
    # Inverse FFT (zero-filled solution when a mask was applied), magnitude, crop.
    image = transforms.complex_abs(transforms.ifft2(kspace))
    image = transforms.center_crop(image, crop_size)
    image, _, _ = transforms.normalize_instance(image, eps=1e-11)
    return image.clamp(-6, 6)
Пример #14
0
    def __call__(self, kspace, mask, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            mask (numpy.array): Mask from the test dataset (not returned).
            target (numpy.array): Target image, or None.
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Normalized zero-filled magnitude image, clamped to [-6, 6].
                target (torch.Tensor): Cropped target normalized with the image statistics
                    and clamped, or ``torch.Tensor([0])`` when ``target`` is None.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                fname (str): File name
                slice (int): Serial number of the slice.
        """
        kspace = transforms.to_tensor(kspace)

        masked_kspace = kspace
        if self.mask_func:
            seed = tuple(map(ord, fname)) if self.use_seed else None
            masked_kspace, mask = transforms.apply_mask(
                kspace, self.mask_func, seed)

        image = transforms.ifft2(masked_kspace)

        # Crop to self.resolution, but never beyond the image or target size.
        height_limits = [self.resolution, image.shape[-3]]
        width_limits = [self.resolution, image.shape[-2]]
        if target is not None:
            height_limits.append(target.shape[-2])
            width_limits.append(target.shape[-1])
        crop_size = (min(height_limits), min(width_limits))

        image = transforms.complex_abs(
            transforms.complex_center_crop(image, crop_size))
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        if target is None:
            target = torch.Tensor([0])
        else:
            target = transforms.center_crop(transforms.to_tensor(target), crop_size)
            target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)
        return image, target, mean, std, fname, slice
Пример #15
0
def data_for_training(rawdata, sensitivity, mask, norm=True):
    '''Build undersampled / fully-sampled coil-combined data for one slice.

    When ``norm`` is truthy, every output is divided by the maximum complex
    magnitude of the zero-filled coil images (floored at 1e-6), because no
    ground truth exists at inference time.

    Args:
        rawdata: k-space tensor unpacked as (coils, Ny, Nx, ps).
        sensitivity: coil sensitivities passed to ``cobmine_all_coils``.
        mask: 2-D sampling mask; ifftshifted and expanded to (coils, ., ., ps).
        norm: whether to apply the max-magnitude normalization.

    Returns:
        (sense_und, sense_gt, sense_und_kspace, rawdata_und, expanded mask,
        sensitivity)
    '''
    coils, Ny, Nx, ps = rawdata.shape

    # ifftshift, then multiply by a (-1)**(x + y) checkerboard.
    xx, yy = np.meshgrid(np.arange(1, Nx + 1), np.arange(1, Ny + 1))
    checkerboard = torch.from_numpy((-1) ** (xx + yy)).view(1, Ny, Nx, 1).float()
    shift_kspace = T.ifftshift(rawdata, dim=(-3, -2)) * checkerboard

    # Shift the mask and broadcast it over coils and the trailing axis.
    coil_mask = T.ifftshift(mask).unsqueeze(0).unsqueeze(-1).float()
    coil_mask = coil_mask.repeat(coils, 1, 1, ps)

    masked_kspace = shift_kspace * coil_mask

    img_gt = T.ifft2(shift_kspace)
    img_und = T.ifft2(masked_kspace)

    scale = 1
    if norm:
        scale = T.complex_abs(img_und).max()
        if scale < 1e-6:
            scale = 1e-6

    img_gt = img_gt / scale
    img_und = img_und / scale
    rawdata_und = masked_kspace / scale

    sense_gt = cobmine_all_coils(img_gt, sensitivity)
    sense_und = cobmine_all_coils(img_und, sensitivity)
    sense_und_kspace = T.fft2(sense_und)

    return sense_und, sense_gt, sense_und_kspace, rawdata_und, coil_mask, sensitivity
Пример #16
0
def imagenormalize(data, divisor=None):
    """Return k-space generated from the normalized image of ``data``.

    Args:
        data: k-space tensor accepted by ``transforms.ifft2``.
        divisor: forwarded to ``normalize``.

    Returns:
        (kspace, divisor): k-space of the normalized image, and the divisor
        returned by ``normalize``.
    """
    # Image from the (masked) k-space.
    image = transforms.ifft2(data)
    normalized_image, divisor = normalize(image, divisor)
    # Bug fix: the un-normalized ``image`` used to be transformed here, so the
    # normalization computed above was silently discarded.
    # Uses the legacy torch.fft(input, signal_ndim) call, like the rest of
    # this file (removed in torch >= 1.8).
    kspace = transforms.ifftshift(normalized_image, dim=(-3, -2))
    kspace = torch.fft(kspace, 2)
    kspace = transforms.fftshift(kspace, dim=(-3, -2))
    return kspace, divisor
Пример #17
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """Build an undersampled square-k-space training sample.

        The rectangular k-space is inverse-transformed, center-cropped to a
        ``self.resolution`` square, transformed back with ``transforms.fft2``
        and passed through ``self.c3object.apply`` (then
        ``self.augmentation.apply`` when an augmentation is set). The square
        k-space is then undersampled with ``self.mask_func``.

        The ``target`` argument is ignored: the returned target is
        ``ifft_c3`` of the fully-sampled square k-space.

        Returns:
            tuple: (masked k-space, zero-filled complex image, target image,
            mean, std, ``attrs['norm']`` as float32). The two images and the
            k-space are permuted with (2, 0, 1). ``mean``/``std`` are the
            statistics of the zero-filled magnitude image.
        """
        rect_image = transforms.ifft2(transforms.to_tensor(kspace))
        square_image = transforms.complex_center_crop(
            rect_image, (self.resolution, self.resolution))
        square_kspace = self.c3object.apply(transforms.fft2(square_image))
        if self.augmentation:
            square_kspace = self.augmentation.apply(square_kspace)

        fully_sampled = ifft_c3(square_kspace)

        seed = tuple(map(ord, fname)) if self.use_seed else None
        undersampled_kspace, _ = transforms.apply_mask(
            square_kspace, self.mask_func, seed)
        zero_filled = ifft_c3(undersampled_kspace)

        _, mean, std = transforms.normalize_instance(
            transforms.complex_abs(zero_filled), eps=1e-11)

        return (undersampled_kspace.permute((2, 0, 1)),
                zero_filled.permute((2, 0, 1)),
                fully_sampled.permute(2, 0, 1),
                mean, std, attrs['norm'].astype(np.float32))
Пример #18
0
def test_ifft2(shape):
    """transforms.ifft2 must match numpy's centered, orthonormal 2-D IFFT."""
    data = create_input(shape + [2])
    result = transforms.ifft2(data).numpy()
    result = result[..., 0] + 1j * result[..., 1]

    reference = utils.tensor_to_complex_np(data)
    reference = np.fft.ifft2(np.fft.ifftshift(reference, (-2, -1)), norm='ortho')
    reference = np.fft.fftshift(reference, (-2, -1))
    assert np.allclose(result, reference)
Пример #19
0
def mnormalize(masked_kspace):
    """Normalize ``masked_kspace`` with its own mean and standard deviation.

    Returns:
        (normalized k-space, mean, std) as produced by
        ``transforms.normalize_instance(masked_kspace, eps=1e-11)``.

    NOTE(review): an earlier version first built the k-space of the
    normalized image (ifft2 -> normalize_instance -> fft). It then overwrote
    that result and its mean/std with the call kept below, so the function
    has always returned the directly-normalized k-space. The dead
    computation was removed; it also called ``torch.fft(x, 2)``, which
    raises on torch >= 1.8. If image-domain normalization was the real
    intent, this needs revisiting.
    """
    normalized, mean, std = transforms.normalize_instance(masked_kspace,
                                                          eps=1e-11)
    return normalized, mean, std
Пример #20
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Normalized zero-filled magnitude image
                    (clamped to [-6, 6] when ``CLAMP``).
                target_train (torch.Tensor): Target normalized with the image
                    statistics (clamped when ``CLAMP``), for training.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                norm (numpy.float32): ``attrs['norm']``.
                target (torch.Tensor): Un-normalized target tensor (for visualization).
        """
        kspace = transforms.to_tensor(kspace)
        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace = kspace
        if self.use_mask:
            masked_kspace = transforms.get_mask(kspace, self.mask_func, seed) * kspace

        # Zero-filled solution -> crop -> magnitude (-> RSS for multicoil).
        crop = (self.resolution, self.resolution)
        image = transforms.complex_center_crop(transforms.ifft2(masked_kspace), crop)
        image = transforms.complex_abs(image)
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)

        target = transforms.to_tensor(target)
        target_train = transforms.normalize(target, mean, std, eps=1e-11)

        if CLAMP:
            image = image.clamp(-6, 6)
            target_train = target_train.clamp(-6, 6)

        norm = attrs['norm'].astype(np.float32)
        return image, target_train, mean, std, norm, target
Пример #21
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.Array): k-space measurements
            target (numpy.Array): Target image (not used)
            attrs (dict): Acquisition related information stored in the HDF5 object (not used)
            fname (pathlib.Path): Path to the input file
            slice (int): Serial number of the slice
        Returns:
            (tuple): tuple containing:
                masked_kspace (torch.Tensor): Undersampled k-space, center-cropped or
                    zero-padded along dim 1 to exactly ``self.kspace_x``
                image (torch.Tensor): Normalized zero-filled input image
                mean (float): Mean of the zero-filled image
                std (float): Standard deviation of the zero-filled image
                fname (pathlib.Path): Path to the input file
                slice (int): Serial number of the slice
        """
        kspace = transforms.to_tensor(kspace)
        if self.mask_func is not None:
            seed = tuple(map(ord, fname))
            masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)
        else:
            masked_kspace = kspace
        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image
        image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        # Normalize input
        image, mean, std = transforms.normalize_instance(image)
        image = image.clamp(-6, 6)

        # Center-crop or zero-pad dim 1 of the k-space to exactly kspace_x.
        # The old slicing ``[extra//2 : -(extra//2)]`` broke for odd size
        # differences: extra == 1 produced an empty tensor, other odd crops
        # kept one column too many, and odd pads raised a shape mismatch.
        # Even differences give the same result as before.
        # NOTE(review): the zero-pad buffer is 3-D, so this assumes
        # single-coil (rows, cols, 2) k-space — confirm for multicoil.
        width = masked_kspace.shape[1]
        target_width = int(self.kspace_x)
        if width > target_width:
            start = (width - target_width) // 2
            masked_kspace = masked_kspace[:, start:start + target_width, :]
        elif width < target_width:
            start = (target_width - width) // 2
            padded = torch.zeros(
                (masked_kspace.shape[0], target_width, masked_kspace.shape[2]))
            padded[:, start:start + width, :] = masked_kspace
            masked_kspace = padded

        # TODO: return the mask as well for exclusive updates.
        return masked_kspace, image, mean, std, fname, slice
Пример #22
0
    def to_spatial(self, kspace, resolution):
        '''Convert (post-enhancement) k-space to a normalized magnitude image.

        Args:
            kspace: PyTorch k-space tensor after enhancement.
            resolution: side length of the square center crop.

        Returns:
            The cropped magnitude image normalized with its own mean/std
            (``eps=1e-11``) and clamped to [-6, 6].
        '''
        square = (resolution, resolution)
        cropped = transforms.complex_center_crop(transforms.ifft2(kspace), square)
        normalized, _, _ = transforms.normalize_instance(
            transforms.complex_abs(cropped), eps=1e-11)
        return normalized.clamp(-6, 6)
Пример #23
0
def onormalize(original_kspace, mean, std, eps=1e-11):
    """Normalize ``original_kspace`` with the given statistics.

    Args:
        original_kspace: k-space tensor.
        mean: mean forwarded to ``transforms.normalize``.
        std: standard deviation forwarded to ``transforms.normalize``.
        eps: stabilizer forwarded to ``transforms.normalize``. It used to be
            ignored in favour of a hard-coded 1e-11; that equals the default,
            so calls that do not pass ``eps`` behave exactly as before.

    Returns:
        ``transforms.normalize(original_kspace, mean, std, eps=eps)``.

    NOTE(review): an earlier version also computed the k-space of the
    normalized image and then overwrote it with the value returned here.
    That dead code was removed; it also used ``torch.fft(x, 2)``, which
    raises on torch >= 1.8.
    """
    return transforms.normalize(original_kspace, mean, std, eps=eps)
Пример #24
0
def inference(model, data, device):
    """Run the model without gradients and return cropped magnitude images.

    The prediction is made data-consistent with ``mask`` (input kept where
    mask == 1). Both prediction and target are then un-normalized with
    ``mean``/``std``, inverse-FFT'd, center-cropped to 320x320 and reduced
    to magnitude.

    Args:
        model: network called on the input permuted with (0, 3, 1, 2).
        data: 7-tuple ``(input, target, mean, std, _, _, mask)``.
        device: torch device the tensors are moved to.

    Returns:
        (output, target) magnitude images.
    """
    with torch.no_grad():
        kspace_in, kspace_target, mean, std, _, _, mask = data
        kspace_in = kspace_in.to(device)
        mask = mask.to(device)
        prediction = model(kspace_in.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        prediction = kspace_in * mask + (1 - mask) * prediction
        kspace_target = kspace_target.to(device)
        mean = mean.to(device)
        std = std.to(device)

        def to_magnitude(kspace):
            image = transforms.ifft2(transforms.unnormalize(kspace, mean, std))
            return transforms.complex_abs(
                transforms.complex_center_crop(image, (320, 320)))

        return to_magnitude(prediction), to_magnitude(kspace_target)
Пример #25
0
def kspacetoimage(kspace, args):
    """Convert k-space data into a normalized, clamped magnitude image.

    Args:
        kspace: K-space tensor accepted by ``transforms.ifft2``.
        args: Namespace providing ``resolution`` (square crop size) and
            ``challenge``; the value ``'multicoil'`` triggers a
            root-sum-of-squares combination.

    Returns:
        The magnitude image after cropping, optional RSS combination,
        instance normalization (eps=1e-11) and clamping to [-6, 6].
    """
    # Zero-filled solution: inverse FFT, crop to the requested size, magnitude.
    spatial = transforms.ifft2(kspace)
    side = (args.resolution, args.resolution)
    magnitude = transforms.complex_abs(
        transforms.complex_center_crop(spatial, side))

    if args.challenge == 'multicoil':
        # Combine the coil images with root-sum-of-squares.
        magnitude = transforms.root_sum_of_squares(magnitude)

    # Per-instance normalization; the returned statistics are not needed here.
    normalized, _mu, _sigma = transforms.normalize_instance(magnitude, eps=1e-11)
    return normalized.clamp(-6, 6)
Пример #26
0
def main():
    """Show the sampled images for the first batch of the display loader.

    Seeds torch, parses ``Args``, prefixes ``args.data_path`` with ``"../"``,
    builds a grid of sampling coordinates for the first k-space batch,
    samples it with ``sample_vector`` and shows ``ifft2`` of each sample.

    Raises:
        RuntimeError: If ``display_data_loader`` yields no batches (previously
            this surfaced as an UnboundLocalError on ``images``).
    """
    torch.manual_seed(0)
    args = Args().parse_args()
    args.data_path = "../" + args.data_path
    _train_loader, _val_loader, display_data_loader = load_data(args)

    # Only the first batch of the display loader is used.
    first_batch = next(iter(display_data_loader), None)
    if first_batch is None:
        raise RuntimeError("display_data_loader yielded no batches")
    k_space, _target, _f_name, _slice_info = first_batch

    rows, cols = k_space.shape[1], k_space.shape[2]
    # sampling_vector[j][i] == [i, j] for i < rows, j < cols.
    sampling_vector = torch.tensor(
        [[[i, j] for i in range(rows)] for j in range(cols)]).float()
    # NOTE(review): both coordinates are shifted by half of shape[1]; for
    # non-square k-space the second coordinate may need shape[2] — confirm.
    sampling_vector = sampling_vector - 0.5 * rows
    sampling_vector = sampling_vector.reshape(-1, 2)
    sampling_vector = sampling_vector.expand(k_space.shape[0], -1, -1)
    images = sample_vector(k_space, sampling_vector)

    for image in images:
        show(ifft2(image))
Пример #27
0
def kspaceto2dimage(kspace, polar, cropping=False, resolution=None):
    """Convert k-space data into a normalized, clamped 2D magnitude image.

    Args:
        kspace: K-space data. Converted with ``polarToCartesian`` first when
            ``polar`` is true.
        polar (bool): Whether ``kspace`` must be converted to Cartesian.
        cropping (bool): If true, build the image with ``croppedimage`` at
            ``resolution``; otherwise apply ``transforms.ifft2`` directly.
        resolution: Crop size; required (and must be truthy) when
            ``cropping`` is true.

    Returns:
        Instance-normalized (eps=1e-11) magnitude image clamped to [-6, 6].

    Raises:
        ValueError: If ``cropping`` is true but ``resolution`` is missing.
            ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers still catch it.
    """
    # Validate arguments before doing any (potentially expensive) conversion.
    if cropping and not resolution:
        raise ValueError(
            "If cropping = True, pass the value for resolution for the function: kspaceto2dimage"
        )

    if polar:
        kspace = polarToCartesian(kspace)

    if cropping:
        image = croppedimage(kspace, resolution)
    else:
        image = transforms.ifft2(kspace)
    # Absolute value
    image = transforms.complex_abs(image)
    # Normalize input; the returned statistics are not used here.
    image, _mean, _std = transforms.normalize_instance(image, eps=1e-11)
    return image.clamp(-6, 6)
Пример #28
0
def save_zero_filled(data_dir, out_dir, which_challenge, resolution):
    """Compute zero-filled reconstructions for each entry of ``data_dir`` and save them.

    Args:
        data_dir (pathlib.Path): Directory whose entries are opened as HDF5
            files containing a ``'kspace'`` dataset.
        out_dir: Destination passed to ``save_reconstructions``.
        which_challenge (str): ``'multicoil'`` enables a root-sum-of-squares
            combination over dim 1.
        resolution (int): Maximum crop size; never larger than the image.
    """
    def _zero_filled(kspace):
        # Inverse Fourier Transform of the stored k-space.
        img = transforms.ifft2(kspace)
        # Crop to ``resolution`` without exceeding the image's own size.
        crop = (min(resolution, img.shape[-3]), min(resolution, img.shape[-2]))
        img = transforms.complex_abs(transforms.complex_center_crop(img, crop))
        if which_challenge == 'multicoil':
            img = transforms.root_sum_of_squares(img, dim=1)
        return img

    reconstructions = {}
    for path in data_dir.iterdir():
        print("file:{}".format(path))
        with h5py.File(path, "r") as hf:
            reconstructions[path.name] = _zero_filled(
                transforms.to_tensor(hf['kspace'][()]))
    save_reconstructions(reconstructions, out_dir)
Пример #29
0
    def forward(self, masked_kspace, mask):
        """Run the network on the central, contiguously sampled k-space lines.

        Finds the contiguous run of sampled lines around the centre of
        ``mask`` (along dimension -2) and keeps only that band via
        ``T.mask_center``. It then applies ``T.ifft2`` and runs
        ``self.norm_unet`` with channels folded into the batch dimension.
        Finally it restores the channel dimension and divides by the
        root-sum-of-squares.
        (Presumably this estimates coil sensitivity maps — confirm.)

        Args:
            masked_kspace: Undersampled k-space tensor.
            mask: Sampling mask; each ``mask[..., i, :]`` slice is expected to
                hold a single element so it can be used as a truth value.

        Returns:
            The tensor produced by ``self.divide_root_sum_of_squares``.
        """
        def get_low_frequency_lines(mask):
            # Walk outward from the centre index while lines are sampled.
            # The bounds checks stop the walk at the edges, so a fully
            # sampled mask no longer raises IndexError on the right side or
            # wraps around through negative indices on the left side.
            num_lines = mask.shape[-2]
            l = r = num_lines // 2
            while r < num_lines and mask[..., r, :]:
                r += 1

            while l >= 0 and mask[..., l, :]:
                l -= 1

            return l + 1, r

        l, r = get_low_frequency_lines(mask)
        num_low_freqs = r - l
        pad = (mask.shape[-2] - num_low_freqs + 1) // 2
        x = T.mask_center(masked_kspace, pad, pad + num_low_freqs)
        x = T.ifft2(x)
        x, b = self.chans_to_batch_dim(x)
        x = self.norm_unet(x)
        x = self.batch_chans_to_chan_dim(x, b)
        x = self.divide_root_sum_of_squares(x)
        return x
Пример #30
0
 def __call__(self, kspace, target, attrs, fname, slice_info):
     """
     Args:
         kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
             data or (rows, cols, 2) for single coil data.
         target (numpy.array): Target image
         attrs (dict): Acquisition related information stored in the HDF5 object.
             Not used by this transform.
         fname (str): File name
         slice_info (int): Serial number of the slice.
     Returns:
         (tuple): tuple containing:
             image (torch.Tensor): Zero-filled input image as returned by ``resize``.
             target (torch.Tensor): Target image as returned by ``resize``.
             kspace (torch.Tensor): Input k-space converted with ``transforms.to_tensor``.
             mean (float): Mean returned by ``resize`` (presumably the normalization mean).
             std (float): Std returned by ``resize`` (presumably the normalization std).
             fname (str): File name, passed through unchanged.
             slice_info (int): Slice number, passed through unchanged.
     """
     kspace = transforms.to_tensor(kspace)
     # Inverse Fourier Transform to get zero filled solution
     image = transforms.ifft2(kspace)
     # ``resize`` presumably crops the image to the configured resolution if
     # larger; its second return value is not used here.
     image, _, target, mean, std = resize(self.hparams, image, target)
     return image, target, kspace, mean, std, fname, slice_info