Esempio n. 1
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Build one training sample (complex zero-filled image plus targets) from a slice.

        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image; only passed through as `target_inference`.
            attrs (dict): Acquisition related information stored in the HDF5 object.
                The keys 'max' and 'norm' are read.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing, in order:
                image: Zero-filled complex image (not cropped), normalized with `mean`/`std`
                    and clamped to [-6, 6] when CLAMP is set.
                target_train: Magnitude of the inverse transform of the full k-space,
                    center-cropped to 320 x 320; renormalized with `mean_abs`/`std_abs`
                    when RENORM is set and clamped when CLAMP is set.
                target_kspace_train: fft2 of the normalized inverse transform of the full k-space.
                mean, std: Statistics from normalize_instance_complex on the cropped
                    zero-filled image.
                mask: Sampling mask; an all-ones tensor shaped like `kspace` when
                    `self.use_mask` is False.
                mean_abs, std_abs: Statistics from normalize_instance on the magnitude of
                    the cropped zero-filled image.
                target_inference: `target` converted with transforms.to_tensor.
                attrs['max']
                attrs['norm'] cast to float32.
        """

        target_inference = transforms.to_tensor(target)

        kspace = transforms.to_tensor(kspace)
        # The training target is rebuilt from the full k-space rather than
        # taken from the `target` argument.
        target = transforms.ifft2(kspace)

        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        if self.use_mask:
            mask = transforms.get_mask(kspace, self.mask_func, seed)
            masked_kspace = mask * kspace
        else:
            # No under-sampling: treat every k-space entry as measured. `mask`
            # used to be left unbound on this path, so the return statement
            # raised UnboundLocalError.
            # NOTE(review): this mask has the full k-space shape, which may
            # differ from the shape get_mask returns - confirm with consumers.
            mask = kspace.new_ones(kspace.shape)
            masked_kspace = kspace

        image = transforms.ifft2(masked_kspace)
        image_crop = transforms.complex_center_crop(
            image, (self.resolution, self.resolution))
        # Only the statistics of the cropped region are kept; the normalized
        # crop itself is discarded.
        _, mean, std = transforms.normalize_instance_complex(image_crop,
                                                             eps=1e-11)

        image_abs = transforms.complex_abs(image_crop)
        _, mean_abs, std_abs = transforms.normalize_instance(image_abs,
                                                             eps=1e-11)

        image = transforms.normalize(image, mean, std)

        target_image_complex_norm = transforms.normalize(target, mean, std)
        target_kspace_train = transforms.fft2(target_image_complex_norm)

        target = transforms.complex_center_crop(target, (320, 320))
        target = transforms.complex_abs(target)
        target_train = target

        if RENORM:
            target_train = transforms.normalize(target_train, mean_abs,
                                                std_abs)

        if CLAMP:
            image = image.clamp(-6, 6)
            target_train = target_train.clamp(-6, 6)

        return image, target_train, target_kspace_train, mean, std, mask, mean_abs, std_abs, target_inference, attrs[
            'max'], attrs['norm'].astype(np.float32)
Esempio n. 2
0
def generate(generator, data, device):
    """
    Run the generator on a batch and post-process both raw and projected outputs.

    Args:
        generator: Callable applied to the permuted input batch.
        data: 8-item batch; items 0, 2, 3 and 4 are used as input, mean, std and mask.
        device: Device the tensors are moved to.
    Returns:
        (tuple): tuple containing:
            output_consistent: Projected output, unnormalized, center-cropped to
                320 x 320 and reduced with complex_abs.
            output_network: Raw generator output after the same post-processing.
            target_kspace: Second value returned by project_to_consistent_subspace.
            output_kspace: Third value returned by project_to_consistent_subspace.
    """
    net_input, _, mean, std, mask, _, _, _ = data
    net_input = net_input.to(device)
    mask = mask.to(device)

    # Last axis is moved to position 1 for the generator and back afterwards.
    raw_output = generator(net_input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

    # Projection to consistent K-space
    projected, target_kspace, output_kspace = project_to_consistent_subspace(
        raw_output, net_input, mask)

    # Append three singleton axes so the statistics broadcast over the batch.
    mean = mean.to(device).unsqueeze(1).unsqueeze(2).unsqueeze(3)
    std = std.to(device).unsqueeze(1).unsqueeze(2).unsqueeze(3)

    def to_magnitude(batch):
        # Undo the normalization, crop, then take the complex magnitude.
        restored = transforms.unnormalize(batch, mean, std)
        cropped = transforms.complex_center_crop(restored, (320, 320))
        return transforms.complex_abs(cropped)

    output_consistent = to_magnitude(projected)
    output_network = to_magnitude(raw_output)
    return output_consistent, output_network, target_kspace, output_kspace
Esempio n. 3
0
def train_step(model, data, device):
    """
    Compute the training loss for one batch: an L1 term on the merged
    (still normalized) tensors plus an L1 term on normalized magnitude images.

    Args:
        model: Callable applied to the permuted input batch.
        data: 7-item batch (input, target, mean, std, mean_image, std_image, mask).
        device: Device the tensors are moved to.
    Returns:
        Sum of the two L1 losses.
    """
    net_input, target, mean, std, mean_image, std_image, mask = data
    net_input = net_input.to(device)
    mask = mask.to(device)
    target = target.to(device)
    prediction = model(net_input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

    # Keep the input where mask is 1 and the prediction where it is 0.
    prediction = net_input * mask + (1 - mask) * prediction

    # First loss term: taken on the merged tensors before unnormalizing.
    loss_k_consistent = F.l1_loss(prediction, target)

    mean = mean.to(device)
    std = std.to(device)
    mean_image = mean_image.unsqueeze(1).unsqueeze(2).to(device)
    std_image = std_image.unsqueeze(1).unsqueeze(2).to(device)

    def to_normalized_magnitude(batch):
        # unnormalize -> ifft2 -> 320 x 320 crop -> magnitude -> image-level normalization
        image = transforms.ifft2(transforms.unnormalize(batch, mean, std))
        image = transforms.complex_center_crop(image, (320, 320))
        image = transforms.complex_abs(image)
        return transforms.normalize(image, mean_image, std_image)

    output_image = to_normalized_magnitude(prediction)
    # Only the target is clamped; the output is compared as-is.
    target_image = to_normalized_magnitude(target).clamp(-6, 6)

    # Second loss term: taken on the normalized magnitude images.
    loss_image = F.l1_loss(output_image, target_image)
    return loss_k_consistent + loss_image
Esempio n. 4
0
 def __call__(self, k_space, mask, target, attrs, f_name, slice):
     """
     Crop a slice in image space and return its k-space with a normalized target.

     Args:
         k_space (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
             data or (rows, cols, 2) for single coil data.
         mask (numpy.array): Mask from the test dataset (not used here)
         target (numpy.array): Target image
         attrs (dict): Acquisition related information stored in the HDF5 object (not used here).
         f_name (str): File name
         slice (int): Serial number of the slice.
     Returns:
         (tuple): tuple containing:
             k_space (torch.Tensor): fft2 of the image cropped to resolution x resolution
             target (torch.Tensor): Target cropped, normalized with the statistics of the
                 cropped image and clamped to [-6, 6].
             f_name (str): File name
             slice (int): Serial number of the slice.
     """
     k_space = transforms.to_tensor(k_space)
     full_image = transforms.ifft2(k_space)
     cropped_image = transforms.complex_center_crop(
         full_image, (self.resolution, self.resolution))
     k_space = transforms.fft2(cropped_image)
     # Only the statistics of the cropped image are needed; the normalized
     # image itself was never used, so it is no longer kept or clamped.
     _, mean, std = transforms.normalize_instance(cropped_image, eps=1e-11)
     # Normalize target
     target = transforms.to_tensor(target)
     target = transforms.center_crop(target,
                                     (self.resolution, self.resolution))
     target = transforms.normalize(target, mean, std, eps=1e-11)
     target = target.clamp(-6, 6)
     return k_space, target, f_name, slice
 def __call__(self, kspace, target, attrs, fname, slice):
     """
     Produce the normalized zero-filled magnitude image for one slice.

     Args:
         kspace (numpy.Array): k-space measurements
         target (numpy.Array): Target image (not used here)
         attrs (dict): Acquisition related information stored in the HDF5 object (not used here)
         fname: Identifier of the input file; iterated character by character to
             seed the mask when a mask function is configured
         slice (int): Serial number of the slice
     Returns:
         (tuple): tuple containing:
             image (torch.Tensor): Normalized zero-filled input image, clamped to [-6, 6]
             mean: Mean returned by normalize_instance
             std: Standard deviation returned by normalize_instance
             fname: Passed through unchanged
             slice (int): Passed through unchanged
     """
     sampled = transforms.to_tensor(kspace)
     if self.mask_func is not None:
         # The seed is derived from the file name so the mask is reproducible per file.
         sampled, _ = transforms.apply_mask(
             sampled, self.mask_func, tuple(map(ord, fname)))

     # Zero-filled reconstruction: inverse transform, crop, magnitude.
     zero_filled = transforms.ifft2(sampled)
     zero_filled = transforms.complex_center_crop(
         zero_filled, (self.resolution, self.resolution))
     zero_filled = transforms.complex_abs(zero_filled)
     if self.which_challenge == 'multicoil':
         zero_filled = transforms.root_sum_of_squares(zero_filled)

     zero_filled, mean, std = transforms.normalize_instance(zero_filled)
     return zero_filled.clamp(-6, 6), mean, std, fname, slice
Esempio n. 6
0
def test_complex_center_crop(shape, target_shape):
    """Check that complex_center_crop yields `target_shape` plus a trailing axis of size 2."""
    complex_shape = shape + [2]
    cropped = transforms.complex_center_crop(
        create_input(complex_shape), target_shape).numpy()
    expected_shape = target_shape + [2]
    assert list(cropped.shape) == expected_shape
Esempio n. 7
0
    def __call__(self, kspace, target, challenge, fname, slice_index):
        """
        Return full and masked k-space together with a normalized target.

        Args:
            kspace: k-space measurements, converted with transforms.to_tensor.
            target: Target image.
            challenge: 'multicoil' triggers root_sum_of_squares on the zero-filled image.
            fname: File name; seeds the mask when `self.use_seed` is set.
            slice_index: Passed through unchanged.
        Returns:
            (tuple): original_kspace, masked_kspace, mask, target, fname, slice_index.
                Both k-space tensors are passed through cartesianToPolar when
                `self.polar` is set; target is normalized with the zero-filled
                image statistics and clamped to [-6, 6].
        """
        full_kspace = transforms.to_tensor(kspace)
        if self.reduce:
            full_kspace = reducedimension(full_kspace, self.resolution)

        seed = tuple(map(ord, fname)) if self.use_seed else None
        sampled_kspace, mask = transforms.apply_mask(full_kspace,
                                                     self.mask_func, seed)

        # The zero-filled magnitude image is built only for its mean/std.
        zero_filled = transforms.ifft2(sampled_kspace)
        zero_filled = transforms.complex_center_crop(
            zero_filled, (self.resolution, self.resolution))
        zero_filled = transforms.complex_abs(zero_filled)
        if challenge == 'multicoil':
            zero_filled = transforms.root_sum_of_squares(zero_filled)
        _, mean, std = transforms.normalize_instance(zero_filled, eps=1e-11)

        target = transforms.normalize(
            transforms.to_tensor(target), mean, std, eps=1e-11).clamp(-6, 6)

        if self.polar:
            full_kspace = cartesianToPolar(full_kspace)
            sampled_kspace = cartesianToPolar(sampled_kspace)

        return full_kspace, sampled_kspace, mask, target, fname, slice_index
Esempio n. 8
0
    def data_consistency_kspace(self, prediction, k_space_slice, mask):
        """
        Merge a predicted image with the sampled k-space and return the result as an image.

        Args:
            - prediction: net (or block) predicted image, indexed as n x 2 x h x w
            - k_space_slice: initially sampled elements in k-space
            - mask: weights `k_space_slice`; (1 - mask) weights the predicted k-space
        Res:
            image obtained from the merged k-space, center-cropped to 320 x 320
            and permuted back to n x 2 x h x w
        """
        # Restrict to the first 320 rows/columns and move the size-2 axis last.
        image = prediction[:, :, 0:320, 0:320].permute(0, 2, 3, 1)
        # presumably pads the image up to the k-space slice size - see proper_padding
        image = self.proper_padding(image, k_space_slice)
        predicted_k_space = fastmri.fft2c(image)

        # Sampled entries come from the slice, the rest from the prediction.
        merged_k_space = (1 - mask) * predicted_k_space + mask * k_space_slice

        image = fastmri.ifft2c(merged_k_space)
        image = transforms.complex_center_crop(image, (320, 320))
        return image.permute(0, 3, 1, 2)
Esempio n. 9
0
def resize(hparams, image, target):
    """
    Center-crop a complex image (and optional target) and normalize the magnitude.

    Args:
        hparams: Provides `resolution` (upper bound of the crop) and `challenge`.
        image: Complex image; axes -3 and -2 are taken as height and width.
        target: Target image or None; when given, its last two axes also bound the crop.
    Returns:
        (tuple): cropped complex image, normalized magnitude clamped to [-6, 6],
            normalized and clamped target (or torch.Tensor([0]) when target is None),
            mean, std.
    """
    height_bounds = [hparams.resolution, image.shape[-3]]
    width_bounds = [hparams.resolution, image.shape[-2]]
    if target is not None:
        height_bounds.append(target.shape[-2])
        width_bounds.append(target.shape[-1])
    crop_size = (min(height_bounds), min(width_bounds))

    image = transforms.complex_center_crop(image, crop_size)

    magnitude = transforms.complex_abs(image)
    if hparams.challenge == "multicoil":
        magnitude = transforms.root_sum_of_squares(magnitude)
    magnitude, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
    magnitude = magnitude.clamp(-6, 6)

    if target is None:
        # Placeholder so callers always receive a tensor.
        target = torch.Tensor([0])
    else:
        target = transforms.center_crop(transforms.to_tensor(target), crop_size)
        target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)
    return image, magnitude, target, mean, std
Esempio n. 10
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Turn a slice into the rfft2 of its normalized, clamped magnitude image.

        Args:
            kspace (numpy.Array): k-space measurements
            target (numpy.Array): Target image (not used here)
            attrs (dict): Acquisition related information stored in the HDF5 object (not used here)
            fname: Passed through unchanged
            slice (int): Passed through unchanged
        Returns:
            (tuple): tuple containing:
                kspace: transforms.rfft2 of the normalized, clamped magnitude image
                mean: Mean returned by normalize_instance
                std: Standard deviation returned by normalize_instance
                fname
                slice
        """
        cropped = transforms.complex_center_crop(
            transforms.ifft2(transforms.to_tensor(kspace)),
            (self.resolution, self.resolution))
        # Root-sum-of-squares is applied before the magnitude here.
        if self.which_challenge == 'multicoil':
            cropped = transforms.root_sum_of_squares(cropped)

        magnitude = transforms.complex_abs(cropped)
        magnitude, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
        magnitude = magnitude.clamp(-6, 6)
        return transforms.rfft2(magnitude), mean, std, fname, slice
Esempio n. 11
0
    def __call__(self, kspace, mask, target, attrs, fname, slice):
        """
        Build a normalized zero-filled image and matching target for one slice.

        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            mask (numpy.array): Mask from the test dataset (the value passed in is not read)
            target (numpy.array): Target image, or None
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image, normalized and clamped to [-6, 6].
                target (torch.Tensor): Cropped, normalized, clamped target, or
                    torch.Tensor([0]) when no target is given.
                mean: Mean value used for normalization.
                std: Standard deviation value used for normalization.
                fname, slice: Passed through unchanged.
        """
        sampled = transforms.to_tensor(kspace)
        if self.mask_func:
            seed = tuple(map(ord, fname)) if self.use_seed else None
            sampled, mask = transforms.apply_mask(sampled, self.mask_func, seed)

        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(sampled)

        # The crop never exceeds the requested resolution, the image itself,
        # or (when present) the target.
        height_bounds = [self.resolution, image.shape[-3]]
        width_bounds = [self.resolution, image.shape[-2]]
        if target is not None:
            height_bounds.append(target.shape[-2])
            width_bounds.append(target.shape[-1])
        crop_size = (min(height_bounds), min(width_bounds))

        image = transforms.complex_abs(
            transforms.complex_center_crop(image, crop_size))
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        if target is None:
            target = torch.Tensor([0])
        else:
            target = transforms.center_crop(transforms.to_tensor(target), crop_size)
            target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)
        return image, target, mean, std, fname, slice
Esempio n. 12
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Build masked k-space, zero-filled image and fully sampled image for a
        slice cropped to a square in image space.

        Args:
            kspace: k-space measurements, converted with transforms.to_tensor.
            target: Not read; the returned target is derived from `kspace`.
            attrs (dict): The key 'norm' is read.
            fname: File name; seeds the mask when `self.use_seed` is set.
            slice: Not used here.
        Returns:
            (tuple): masked square k-space, under-sampled complex image and fully
                sampled complex image (each permuted with (2, 0, 1)), then mean
                and std of the under-sampled magnitude image, then
                attrs['norm'] cast to float32.
        """
        side = (self.resolution, self.resolution)

        # Crop in image space, then return to k-space for the square image.
        full_image = transforms.ifft2(transforms.to_tensor(kspace))
        square_image = transforms.complex_center_crop(full_image, side)
        square_kspace = self.c3object.apply(transforms.fft2(square_image))

        if self.augmentation:
            square_kspace = self.augmentation.apply(square_kspace)

        # Fully sampled reference image (after any augmentation).
        reference_image = ifft_c3(square_kspace)

        # Apply mask
        seed = tuple(map(ord, fname)) if self.use_seed else None
        sampled_kspace, mask = transforms.apply_mask(
            square_kspace, self.mask_func, seed)

        undersampled_image = ifft_c3(sampled_kspace)

        # Only the statistics of the under-sampled magnitude are returned.
        _, mean, std = transforms.normalize_instance(
            transforms.complex_abs(undersampled_image), eps=1e-11)

        target = reference_image.permute(2, 0, 1)

        return (sampled_kspace.permute((2, 0, 1)),
                undersampled_image.permute((2, 0, 1)),
                target,
                mean,
                std,
                attrs['norm'].astype(np.float32))
Esempio n. 13
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Return the normalized, clamped complex image for one slice.

        Args:
            kspace: k-space measurements, converted with transforms.to_tensor.
            target: Not used here.
            attrs: Not used here.
            fname: Passed through unchanged.
            slice: Passed through unchanged.
        Returns:
            (tuple): image (center-cropped to resolution x resolution, normalized
                and clamped to [-6, 6]), mean, std, fname, slice.
        """
        kspace = transforms.to_tensor(kspace)
        image = transforms.ifft2_regular(kspace)
        image = transforms.complex_center_crop(
            image, (self.resolution, self.resolution))
        # The complex image is normalized directly; no magnitude is taken.
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)
        # The former `kspace = transforms.fft2(image)` was never used and has
        # been dropped to avoid a wasted transform per sample.

        return image, mean, std, fname, slice
Esempio n. 14
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Build a normalized zero-filled image and target for one slice.

        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
                The key 'norm' is read.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image, normalized (clamped when CLAMP is set).
                target_train (torch.Tensor): Normalized target (clamped when CLAMP is set).
                mean: Mean value used for normalization.
                std: Standard deviation value used for normalization.
                norm: attrs['norm'] cast to float32.
                target (torch.Tensor): Target converted to a tensor, not normalized.
        """
        sampled = transforms.to_tensor(kspace)
        seed = tuple(map(ord, fname)) if self.use_seed else None
        if self.use_mask:
            mask = transforms.get_mask(sampled, self.mask_func, seed)
            sampled = mask * sampled

        # Zero-filled reconstruction: inverse transform, crop, magnitude.
        image = transforms.complex_center_crop(
            transforms.ifft2(sampled), (self.resolution, self.resolution))
        image = transforms.complex_abs(image)
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)

        image, mean, std = transforms.normalize_instance(image, eps=1e-11)

        # The raw target is kept alongside its normalized copy.
        target = transforms.to_tensor(target)
        target_train = transforms.normalize(target, mean, std, eps=1e-11)

        if CLAMP:
            image = image.clamp(-6, 6)
            target_train = target_train.clamp(-6, 6)

        norm = attrs['norm'].astype(np.float32)
        return image, target_train, mean, std, norm, target
Esempio n. 15
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Return the masked k-space, resized along axis 1 to `self.kspace_x`,
        together with the normalized zero-filled image.

        Args:
            kspace (numpy.Array): k-space measurements
            target (numpy.Array): Target image (not used here)
            attrs (dict): Acquisition related information stored in the HDF5 object (not used here)
            fname: Identifier of the input file; iterated character by character to
                seed the mask when a mask function is configured
            slice (int): Serial number of the slice
        Returns:
            (tuple): tuple containing:
                masked_kspace (torch.Tensor): Masked k-space, center-cropped or
                    zero-padded along axis 1 to `self.kspace_x` entries
                image (torch.Tensor): Normalized zero-filled input image, clamped to [-6, 6]
                mean: Mean of the zero-filled image
                std: Standard deviation of the zero-filled image
                fname: Passed through unchanged
                slice (int): Passed through unchanged
        """
        kspace = transforms.to_tensor(kspace)
        if self.mask_func is not None:
            seed = tuple(map(ord, fname))
            masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)
        else:
            masked_kspace = kspace
        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image
        image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        # Normalize input
        image, mean, std = transforms.normalize_instance(image)
        image = image.clamp(-6, 6)

        # difference between kspace actual and target dim
        current_width = masked_kspace.shape[1]
        extra = int(current_width - self.kspace_x)

        # clip kspace at input dim
        if extra > 0:
            # Slice by explicit start and length. The old `a:-a` form with
            # a = extra // 2 returned an empty tensor for extra == 1 and one
            # entry too many for any other odd extra.
            start = extra // 2
            masked_kspace = masked_kspace[:, start:start + self.kspace_x, :]

        # zero pad if necessary
        elif extra < 0:
            # Same reasoning for padding: with an odd difference the old slice
            # was one entry narrower than the data and the assignment failed.
            start = (-extra) // 2
            empty_kspace = torch.zeros((masked_kspace.shape[0], self.kspace_x, masked_kspace.shape[2]))
            empty_kspace[:, start:start + current_width, :] = masked_kspace
            masked_kspace = empty_kspace

        #TODO return mask as well for exclusive updates
        return masked_kspace, image, mean, std, fname, slice
Esempio n. 16
0
    def to_spatial(self, kspace, resolution):
        '''
        Convert k-space into a normalized magnitude image.

        kspace: k-space passed to transforms.ifft2
        resolution: side length of the square center crop
        Returns the magnitude image normalized with normalize_instance and
        clamped to [-6, 6].
        '''
        cropped = transforms.complex_center_crop(
            transforms.ifft2(kspace), (resolution, resolution))
        magnitude = transforms.complex_abs(cropped)
        # The statistics are not needed by callers of this method.
        normalized, _, _ = transforms.normalize_instance(magnitude, eps=1e-11)
        return normalized.clamp(-6, 6)
Esempio n. 17
0
def inference(model, data, device):
    """
    Run the model without gradients and return output and target magnitude images.

    Args:
        model: Callable applied to the permuted input batch.
        data: 7-item batch; items 0-3 and 6 are used as input, target, mean, std and mask.
        device: Device the tensors are moved to.
    Returns:
        (tuple): output and target, each unnormalized, passed through ifft2,
            center-cropped to 320 x 320 and reduced with complex_abs.
    """
    with torch.no_grad():
        net_input, target, mean, std, _, _, mask = data
        net_input = net_input.to(device)
        mask = mask.to(device)
        prediction = model(net_input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        # Keep the input where mask is 1 and the prediction where it is 0.
        prediction = net_input * mask + (1 - mask) * prediction
        target = target.to(device)

        mean = mean.to(device)
        std = std.to(device)

        def to_magnitude(batch):
            image = transforms.ifft2(transforms.unnormalize(batch, mean, std))
            image = transforms.complex_center_crop(image, (320, 320))
            return transforms.complex_abs(image)

        return to_magnitude(prediction), to_magnitude(target)
Esempio n. 18
0
def generate(generator, data, device):
    """
    Run the generator on a batch and return the magnitude of its output.

    Args:
        generator: Callable applied to the permuted input batch.
        data: 8-item batch; items 0, 2, 3 and 4 are used as input, mean, std and mask.
        device: Device the tensors are moved to.
    Returns:
        Generator output, unnormalized, center-cropped to 320 x 320 and
        reduced with complex_abs.
    """
    net_input, _, mean, std, mask, _, _, _ = data
    net_input = net_input.to(device)
    # The mask is moved to the device but not used further in this function.
    mask = mask.to(device)

    raw_output = generator(net_input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

    restored = transforms.unnormalize(raw_output, mean.to(device), std.to(device))
    cropped = transforms.complex_center_crop(restored, (320, 320))
    return transforms.complex_abs(cropped)
Esempio n. 19
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Return independently normalized complex image and target for one slice.

        Args:
            kspace: k-space measurements, converted with transforms.to_tensor.
            target: Target image.
            attrs (dict): The key 'norm' is read.
            fname: Not used here.
            slice: Not used here.
        Returns:
            (tuple): image, target, 0, 0, attrs['norm'] cast to float32. The two
                zeros stand in for mean and std, which are not reported.
        """
        cropped = transforms.complex_center_crop(
            transforms.ifft2_regular(transforms.to_tensor(kspace)),
            (self.resolution, self.resolution))
        # Image and target are each normalized with their own statistics,
        # which are then discarded.
        image, _, _ = transforms.normalize_instance(cropped, eps=1e-11)
        target, _, _ = transforms.normalize_instance(
            transforms.to_tensor(target), eps=1e-11)

        mean = std = 0
        return image, target, mean, std, attrs['norm'].astype(np.float32)
Esempio n. 20
0
def kspacetoimage(kspace, args):
    """
    Convert k-space into a normalized magnitude image.

    Args:
        kspace: k-space passed to transforms.ifft2.
        args: Provides `resolution` (square crop size) and `challenge`.
    Returns:
        Magnitude image normalized with normalize_instance and clamped to [-6, 6].
    """
    side = (args.resolution, args.resolution)
    magnitude = transforms.complex_abs(
        transforms.complex_center_crop(transforms.ifft2(kspace), side))
    if args.challenge == 'multicoil':
        magnitude = transforms.root_sum_of_squares(magnitude)
    # Normalization applies to both challenges; the statistics are not returned.
    normalized, _, _ = transforms.normalize_instance(magnitude, eps=1e-11)
    return normalized.clamp(-6, 6)
Esempio n. 21
0
def nkspacetoimage(args, kspace_fni, mean, std, eps=1e-11):
    """
    Convert normalized k-space into a normalized magnitude image.

    Args:
        args: Provides `resolution`, the side length of the square crop.
        kspace_fni: k-space tensor whose last axis has size 2; only the first
            entry along axis 0 of the transformed result is kept.
        mean: Offset added back after the inverse transform.
        std: Scale multiplied back after the inverse transform.
        eps: Passed to normalize_instance. It was previously ignored in favour
            of a hard-coded 1e-11, which is still the default.
    Returns:
        Magnitude image normalized with normalize_instance and clamped to [-6, 6].
    """
    assert kspace_fni.size(-1) == 2
    # Shifted inverse transform over axes -3 and -2.
    # NOTE(review): torch.ifft is the legacy API that newer PyTorch releases
    # no longer ship - confirm the supported torch version.
    image = transforms.ifftshift(kspace_fni, dim=(-3, -2))
    image = torch.ifft(image, 2)
    image = transforms.fftshift(image, dim=(-3, -2))
    # Undo the normalization using the supplied statistics.
    image = (image * std) + mean
    image = image[0]

    image = transforms.complex_center_crop(image,
                                           (args.resolution, args.resolution))
    # Absolute value
    image = transforms.complex_abs(image)
    # Re-normalize with the statistics of the magnitude image itself.
    image, _, _ = transforms.normalize_instance(image, eps=eps)
    image = image.clamp(-6, 6)
    return image
Esempio n. 22
0
def generate(generator, data, device):
    """
    Run the generator on a batch and return the magnitude of the (optionally
    projected) result.

    Args:
        generator: Callable applied to the permuted input batch.
        data: 8-item batch; items 0, 2, 3 and 4 are used as input, mean, std and mask.
        device: Device the tensors are moved to.
    Returns:
        Result unnormalized, center-cropped to 320 x 320 and reduced with
        complex_abs. The result is project_to_consistent_subspace(...) when
        PROJECT is set, otherwise the raw generator output.
    """
    input, _, mean, std, mask, _, _, _ = data
    input = input.to(device)
    mask = mask.to(device)

    # Use network to predict residual
    residual = generator(input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

    # Projection to consistent K-space
    if PROJECT:
        output = project_to_consistent_subspace(residual, input, mask)
    else:
        # `output` used to be unbound on this path and the function crashed
        # with UnboundLocalError. Fall back to the raw generator output.
        # NOTE(review): confirm the raw output (rather than input + residual)
        # is the intended non-projected result.
        output = residual

    # Take loss on the cropped, real valued image (abs)
    mean = mean.to(device)
    std = std.to(device)
    output = transforms.unnormalize(output, mean, std)
    output = transforms.complex_center_crop(output, (320, 320))
    output = transforms.complex_abs(output)

    return output
Esempio n. 23
0
def save_zero_filled(data_dir, out_dir, which_challenge, resolution):
    """
    Compute zero-filled reconstructions for every file in a directory and save them.

    Args:
        data_dir: Directory object; each entry from `iterdir()` is opened as an HDF5 file.
        out_dir: Destination passed to save_reconstructions.
        which_challenge: 'multicoil' triggers root_sum_of_squares over dim 1.
        resolution: Upper bound for the crop height and width.
    """
    reconstructions = {}

    for path in data_dir.iterdir():
        print("file:{}".format(path))
        with h5py.File(path, "r") as hf:
            stored_kspace = transforms.to_tensor(hf['kspace'][()])

        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(stored_kspace)
        # The crop never exceeds the image itself.
        crop_size = (min(resolution, image.shape[-3]),
                     min(resolution, image.shape[-2]))
        image = transforms.complex_abs(
            transforms.complex_center_crop(image, crop_size))
        if which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image, dim=1)

        reconstructions[path.name] = image

    save_reconstructions(reconstructions, out_dir)
Esempio n. 24
0
def data_transform(kspace, mask_function, target, data_attributes, filename,
                   slice_num):
    """
    Preprocess one k-space slice into network input and target tensors.

    Args:
        - kspace: k-space data, converted to a tensor and instance-normalized
        - mask_function: accepted but not used by this body
        - target: target image, converted to a tensor and instance-normalized
        - data_attributes: mapping from which the 'max' entry is returned
        - filename: accepted but not used by this body
        - slice_num: returned unchanged

    Returns:
        - the normalized k-space tensor
        - the masked image, cropped to 320 x 320, permuted with (2, 0, 1) and
          instance-normalized
        - the normalized target with a new leading dimension
        - the mask returned by transforms.apply_mask
        - data_attributes['max']
        - slice_num

    NOTE(review): the mask function is read from the name `mask_func` in the
    enclosing scope, not from the `mask_function` parameter. The original
    TODO said it is passed "from outside"; confirm this is intended.
    """
    full_kspace = transforms.normalize_instance(
        transforms.to_tensor(kspace))[0]

    # Undersample, then go back to the image domain and crop.
    undersampled, mask = transforms.apply_mask(data=full_kspace,
                                               mask_func=mask_func)
    zero_filled = fastmri.ifft2c(undersampled)
    zero_filled = transforms.complex_center_crop(zero_filled, (320, 320))

    # Move the last axis to the front before normalizing.
    net_input = transforms.normalize_instance(
        zero_filled.permute(2, 0, 1))[0]

    reference = transforms.normalize_instance(
        transforms.to_tensor(target))[0]
    reference = torch.unsqueeze(reference, 0)

    max_value = data_attributes['max']
    return full_kspace, net_input, reference, mask, max_value, slice_num
Esempio n. 25
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Build a scaled two-channel (real, imaginary) input/target pair from one
        k-space slice.

        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Ignored; the name is rebound to an image derived from `kspace`.
            attrs (dict): Acquisition related information stored in the HDF5 object; only
                'norm' is read, and the value derived from it is not returned.
            fname (str): File name, used to derive the mask seed when `self.use_seed` is set.
            slice (int): Serial number of the slice (not used by this body).
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled image, real and imaginary parts stacked and
                    moved to the leading axis, scaled by 6.0 / max magnitude of the zero-filled
                    image.
                target (torch.Tensor): Image from the unmasked, cropped k-space, laid out and
                    scaled the same way (using the zero-filled image's max magnitude).

        NOTE(review): the mask, masked k-space, `image_mod` and `norm` values computed below
        are never returned.
        """
        kspace = transforms.to_tensor(kspace)
        # Crop in the image domain, then transform back so that `kspace` now
        # describes the cropped image.
        gt = transforms.ifft2(kspace)
        gt = transforms.complex_center_crop(gt, (self.resolution, self.resolution))
        kspace = transforms.fft2(gt)

        # Apply mask; the seed is derived from the file name when use_seed is set
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed)
        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        masked_kspace = transforms.fft2_nshift(image)
        # Crop input image
        image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Largest magnitude of the zero-filled image; used as the common scale
        # for both input and target so the input's peak magnitude becomes 6.0.
        image_mod = transforms.complex_abs(image).max()
        image_r = image[:, :, 0]*6.0/image_mod
        image_i = image[:, :, 1]*6.0/image_mod
        # Unscaled alternative kept from the original:
        # image_r = image[:, :, 0]
        # image_i = image[:, :, 1]
        # Apply Root-Sum-of-Squares if multicoil data
        # NOTE(review): this result is overwritten by the np.stack below.
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)

        # Re-assemble the scaled parts and move the real/imag axis to the front
        image = np.stack((image_r, image_i), axis=-1)
        image = image.transpose((2, 0, 1))
        image = transforms.to_tensor(image)

        # The target comes from the unmasked (cropped) k-space, not from the
        # `target` argument.
        target = transforms.ifft2(kspace)
        target = transforms.complex_center_crop(target, (self.resolution, self.resolution))
        # Scale the target with the same factor as the input
        target_r = target[:, :, 0]*6.0/image_mod
        target_i = target[:, :, 1]*6.0/image_mod
        # Unscaled alternative kept from the original:
        # target_r = target[:, :, 0]
        # target_i = target[:, :, 1]

        target = np.stack((target_r, target_i), axis=-1)
        target = target.transpose((2, 0, 1))
        target = transforms.to_tensor(target)

        # Everything from here to the return is computed but not returned.
        image_mod = np.stack((image_mod, image_mod), axis=0)
        image_mod = transforms.to_tensor(image_mod)

        norm = attrs['norm'].astype(np.float32)
        norm = np.stack((norm, norm), axis=-1)
        norm = transforms.to_tensor(norm)

        mask = mask.expand(kspace.shape)
        mask = mask.transpose(0, 2).transpose(1, 2)
        mask = transforms.ifftshift(mask)

        masked_kspace = masked_kspace.transpose(0, 2).transpose(1, 2)

        return image, target
Esempio n. 26
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Produce the zero-filled input, its targets and, when available, a
        previously saved reconstruction for one slice.

        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image, or None.
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image.
                target_train: Target used for training (normalized and optionally
                    clamped when self.normalize is set); [] without a target.
                mean: Mean used for normalization, or -1.0 when not normalizing.
                std: Standard deviation used for normalization, or -1.0.
                norm: attrs['norm'] as float32, or -1.0 without a target.
                target: Target converted to a tensor; [] without a target.
                image_updated: The 'reconstruction' slice from the first matching
                    saved file, or [] when no such file exists.
        """
        kspace = transforms.to_tensor(kspace)

        # Optional undersampling; the seed is derived from the file name.
        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace = kspace
        if self.use_mask:
            mask = transforms.get_mask(kspace, self.mask_func, seed)
            masked_kspace = mask * kspace

        # Zero-filled solution: inverse FFT, crop, magnitude.
        zero_filled = transforms.ifft2(masked_kspace)
        zero_filled = transforms.complex_center_crop(
            zero_filled, (self.resolution, self.resolution))
        image = transforms.complex_abs(zero_filled)
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)

        # -1.0 marks "no normalization applied".
        mean = std = -1.0
        if self.normalize:
            image, mean, std = transforms.normalize_instance(image, eps=1e-11)
            if CLAMP:
                image = image.clamp(-6, 6)

        if target is None:
            target_train, target, norm = [], [], -1.0
        else:
            # `target` is kept unnormalized (for visualization); `target_train`
            # is the normalized/clamped variant used for training.
            target = transforms.to_tensor(target)
            target_train = target
            if self.normalize:
                target_train = transforms.normalize(target, mean, std,
                                                    eps=1e-11)
                if CLAMP:
                    target_train = target_train.clamp(-6, 6)
            norm = attrs['norm'].astype(np.float32)

        # Look for a saved reconstruction; the first directory that contains
        # the file wins.
        image_updated = []
        reconstruction_dirs = (
            '/home/manivasagam/code/fastMRIPrivate/models/unet_volumes/reconstructions_train/',
            '/home/manivasagam/code/fastMRIPrivate/models/unet_volumes/reconstructions_val/',
            '/home/manivasagam/code/fastMRIPrivate/models/unet_volumes/reconstructions_test/',
        )
        for directory in reconstruction_dirs:
            updated_fname = directory + fname
            if os.path.exists(updated_fname):
                with h5py.File(updated_fname, 'r') as data:
                    image_updated = transforms.to_tensor(
                        data['reconstruction'][slice])
                break

        return image, target_train, mean, std, norm, target, image_updated
Esempio n. 27
0
def croppedimage(kspace, resolution):
    """Inverse-transform ``kspace`` and center-crop to ``resolution`` x ``resolution``."""
    crop_shape = (resolution, resolution)
    return transforms.complex_center_crop(transforms.ifft2(kspace), crop_shape)
Esempio n. 28
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Crop a slice to a square, undersample it and return k-space, image and
        normalization statistics.

        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Ignored; the returned target is derived from `kspace`.
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                masked k-space of the square image, permuted with (2, 0, 1)
                zero-filled complex image, permuted with (2, 0, 1)
                target: fully sampled square complex image, permuted with (2, 0, 1)
                mean: Mean of the zero-filled magnitude image.
                std: Standard deviation of the zero-filled magnitude image.
                norm: attrs['norm'] as float32.
        """
        # Full-size image -> square crop -> k-space of the square image.
        full_image = transforms.ifft2(transforms.to_tensor(kspace))
        square_image = transforms.complex_center_crop(
            full_image, (self.resolution, self.resolution))
        square_kspace = transforms.fft2(square_image)

        # Augmentation acts on k-space; refresh the image so the target
        # matches the augmented data.
        if self.augmentation:
            square_kspace = self.augmentation.apply(square_kspace)
            square_image = transforms.ifft2(square_kspace)

        seed = tuple(map(ord, fname)) if self.use_seed else None
        undersampled_kspace, _mask = transforms.apply_mask(
            square_kspace, self.mask_func, seed)

        # Zero-filled reconstruction; only its magnitude statistics are kept,
        # the image itself is returned un-normalized.
        zero_filled = transforms.ifft2(undersampled_kspace)
        zero_filled_abs = transforms.complex_abs(zero_filled)
        _, mean, std = transforms.normalize_instance(zero_filled_abs,
                                                     eps=1e-11)

        target = square_image.permute(2, 0, 1)

        return (
            undersampled_kspace.permute((2, 0, 1)),
            zero_filled.permute((2, 0, 1)),
            target,
            mean,
            std,
            attrs['norm'].astype(np.float32),
        )
Esempio n. 29
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Build stacked real/imaginary tensors for the full and undersampled
        square k-space, plus a root-sum-of-squares zero-filled image.

        Args:
            kspace: Input k-space; indexed below as a 4-D tensor whose last
                axis holds the real and imaginary parts.
            target: Target image; returned multiplied by 10000.
            attrs: Not used by this body.
            fname: File name, used to derive the mask seed when
                `self.use_seed` is set.
            slice: Not used by this body.
        Returns:
            (tuple): tuple containing:
                stacked k-space (after `self.c3object.apply`, the 10000 scale
                    and optional augmentation)
                stacked masked k-space
                stacked image obtained with `ifft_c3` before augmentation
                root-sum-of-squares (dim=0) of the zero-filled magnitude image
                target * 10000
        """

        def interleave(tensor, count):
            # For each index along the leading axis emit the [..., 0] plane
            # followed by the [..., 1] plane, then stack them all.
            return torch.stack([
                tensor[index, :, :, part]
                for index in range(count)
                for part in (0, 1)
            ])

        # Full-size image -> square crop.
        full_image = transforms.ifft2(transforms.to_tensor(kspace))
        image_square = transforms.complex_center_crop(
            full_image, (self.resolution, self.resolution))

        # Square k-space, transformed by c3object and scaled by 10000.
        kspace_square = self.c3object.apply(
            transforms.fft2(image_square)) * 10000
        # Image used for training the domain transform; taken before any
        # augmentation is applied to the k-space.
        reference_image = ifft_c3(kspace_square)

        if self.augmentation:
            kspace_square = self.augmentation.apply(kspace_square)

        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace_square, _mask = transforms.apply_mask(
            kspace_square, self.mask_func, seed)

        # Zero-filled image, magnitude, then combine over dim 0.
        zero_filled = ifft_c3(masked_kspace_square)
        zero_filled_abs = transforms.complex_abs(zero_filled)
        zero_filled_rss = transforms.root_sum_of_squares(zero_filled_abs,
                                                         dim=0)

        stacked_kspace_square = interleave(
            kspace_square, len(kspace_square[:, 0, 0, 0]))
        stacked_masked_kspace_square = interleave(
            masked_kspace_square, len(masked_kspace_square[:, 0, 0, 0]))
        # The plane count comes from the cropped image while the planes come
        # from the ifft_c3 output, as in the original.
        stacked_image_square = interleave(
            reference_image, len(image_square[:, 0, 0, 0]))

        return (
            stacked_kspace_square,
            stacked_masked_kspace_square,
            stacked_image_square,
            zero_filled_rss,
            target * 10000,
        )
        '''
Esempio n. 30
0
def tosquare(ksp,shp):
    """Crop ``ksp`` to ``shp`` in the image domain and return its scaled k-space.

    The k-space of the cropped image is multiplied by ``c3m`` and by 100000.
    """
    cropped = T.complex_center_crop(T.ifft2(ksp), shp)
    return c3m * T.fft2(cropped) * 100000