def to_cropped_image(masked_kspace, target, attrs):
    """Turn masked k-space into a normalized, cropped magnitude image.

    Args:
        masked_kspace: k-space tensor handed straight to ``fastmri.ifft2c``.
        target: Reference image (tensor or ``np.ndarray``) or ``None``. When
            given, its last two dimensions define the crop size.
        attrs: Mapping that must provide ``"recon_size"`` when ``target`` is
            ``None``; its first two entries are used as the crop size.

    Returns:
        A ``(image, target)`` pair. ``image`` is the instance-normalized
        magnitude image clamped to ``[-6, 6]``. ``target`` is cropped,
        normalized with the image's mean/std and clamped the same way, or
        ``torch.Tensor([0])`` when no target was supplied.
    """
    # Zero-filled solution: inverse Fourier transform of the k-space.
    zero_filled = fastmri.ifft2c(masked_kspace)

    has_target = target is not None

    # Pick the crop size from the target if there is one, else from attrs.
    if has_target:
        rows, cols = target.shape[-2], target.shape[-1]
    else:
        rows, cols = attrs["recon_size"][0], attrs["recon_size"][1]

    # FLAIR 203 check: when the image's second-to-last axis is smaller than
    # the requested crop width, fall back to a square crop of that extent.
    extent = zero_filled.shape[-2]
    if extent < cols:
        rows = cols = extent
    crop_size = (rows, cols)

    # Crop the complex image, then take its magnitude.
    magnitude = fastmri.complex_abs(
        T.complex_center_crop(zero_filled, crop_size))

    # Normalize the input and keep the statistics for the target.
    image, mean, std = T.normalize_instance(magnitude, eps=1e-11)
    image = image.clamp(-6, 6)

    if not has_target:
        return image, torch.Tensor([0])

    # Normalize the target with the statistics of the input image.
    if isinstance(target, np.ndarray):
        target = T.to_tensor(target)
    target = T.normalize(
        T.center_crop(target, crop_size), mean, std, eps=1e-11)
    return image, target.clamp(-6, 6)
예제 #2
0
def load_data_from_pathlist(path):
    """Load targets and undersampled magnitude images from a list of files.

    Only the first ``len(path) // 3`` entries of ``path`` are read. For every
    file, each k-space slice is converted to a float tensor, masked with a
    ``RandomMaskFunc`` (center fraction 0.08, acceleration 4), inverse Fourier
    transformed, center-cropped with size ``(320, 320)`` and reduced to its
    magnitude. The slices of all files are concatenated along dimension 0.

    Args:
        path: Indexable sequence of dataset locations accepted by
            ``load_dataset``.

    Returns:
        A ``(target_image_tensor, total_sampled_image_tensor)`` pair. The
        sampled images are normalized with ``T.normalize_instance`` over the
        whole concatenated stack and clamped to ``[-6, 6]``; the targets are
        normalized with the same mean/std and clamped the same way.
    """
    use_num = len(path) // 3

    # The mask function does not depend on the file, so build it once.
    mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])

    total_target_list = []
    total_sampled_image_list = []
    for h5_num in range(use_num):
        total_kspace, slices_num, target = load_dataset(path[h5_num])

        sampled_image_list = []
        for i in range(slices_num):
            # numpy -> float tensor for this slice
            slice_kspace_tensor = T.to_tensor(total_kspace[i]).float()
            # Only the masked k-space is needed; the mask itself is unused.
            masked_kspace, _ = T.apply_mask(slice_kspace_tensor, mask_func)
            # Inverse FFT gives the complex image, which is then cropped.
            sampled_image = fastmri.ifft2c(masked_kspace)
            sampled_image = T.complex_center_crop(sampled_image, (320, 320))
            sampled_image_list.append(fastmri.complex_abs(sampled_image))

        # Stack this file's slices along a new leading dimension.
        total_sampled_image_list.append(
            torch.stack(sampled_image_list, dim=0))
        total_target_list.append(T.to_tensor(target))

    # Merge the per-file stacks along the slice dimension.
    total_target = torch.cat(total_target_list, dim=0)
    total_sampled_image_tensor = torch.cat(total_sampled_image_list, dim=0)

    # Normalize the inputs, then reuse their statistics for the targets.
    total_sampled_image_tensor, mean, std = T.normalize_instance(
        total_sampled_image_tensor, eps=1e-11)
    total_sampled_image_tensor = total_sampled_image_tensor.clamp(-6, 6)
    target_image_tensor = T.normalize(total_target, mean, std, eps=1e-11)
    target_image_tensor = target_image_tensor.clamp(-6, 6)
    return target_image_tensor, total_sampled_image_tensor
예제 #3
0
def test_normalize_instance(shape):
    """Check the statistics returned and produced by ``normalize_instance``.

    The reported mean/std must match those of the input, and the normalized
    output must have (approximately) zero mean and unit standard deviation.
    """
    data = create_input(shape)
    normalized, mean, stddev = transforms.normalize_instance(data)
    normalized = normalized.numpy()

    # Reported statistics agree with the raw input.
    assert np.isclose(data.numpy().mean(), mean, rtol=1e-2)
    assert np.isclose(data.numpy().std(), stddev, rtol=1e-2)

    # Output is standardized.
    assert np.isclose(normalized.mean(), 0, rtol=1e-2, atol=1e-3)
    assert np.isclose(normalized.std(), 1, rtol=1e-2, atol=1e-3)
예제 #4
0
def _base_fastmri_unet_transform(
    kspace,
    mask,
    ground_truth,
    attrs,
    which_challenge="singlecoil",
):
    """Produce a normalized zero-filled input image from k-space and a mask.

    Args:
        kspace: k-space data, converted with ``fastmri_transforms.to_tensor``.
        mask: Sampling mask tensor; it is trimmed along its last axis to
            ``kspace.shape[-2]`` entries before being applied.
        ground_truth: Reference image or ``None``. When given, its last two
            dimensions define the crop size.
        attrs: Mapping that must provide ``"recon_size"`` when
            ``ground_truth`` is ``None``.
        which_challenge: ``"multicoil"`` additionally applies
            ``fastmri.rss`` to the magnitude image.

    Returns:
        ``(image, mean, std)`` where ``image`` is the normalized image clamped
        to ``[-6, 6]`` with a new leading dimension, and ``mean``/``std`` are
        the statistics returned by ``normalize_instance``.
    """
    kspace = fastmri_transforms.to_tensor(kspace)

    # Masks can have a variable size, so trim to the k-space extent.
    num_cols = kspace.shape[-2]
    trimmed_mask = mask[..., :num_cols]
    # NOTE(review): the "+ 0.0" presumably turns -0.0 into +0.0 - confirm.
    masked_kspace = kspace * trimmed_mask.unsqueeze(-1) + 0.0

    # Zero-filled solution via the inverse Fourier transform.
    zero_filled = fastmri.ifft2c(masked_kspace)

    # Crop size comes from the ground truth if present, else from attrs.
    if ground_truth is None:
        rows, cols = attrs["recon_size"][0], attrs["recon_size"][1]
    else:
        rows, cols = ground_truth.shape[-2], ground_truth.shape[-1]

    # FLAIR 203 check: fall back to a square crop when the image is narrower
    # than the requested crop width.
    extent = zero_filled.shape[-2]
    if extent < cols:
        rows = cols = extent
    crop_size = (rows, cols)

    # noinspection PyTypeChecker
    cropped = fastmri_transforms.complex_center_crop(zero_filled, crop_size)

    # Magnitude image.
    magnitude = fastmri.complex_abs(cropped)

    # Root-Sum-of-Squares only for multicoil data.
    if which_challenge == "multicoil":
        magnitude = fastmri.rss(magnitude)

    # Normalize and clamp the input.
    normalized, mean, std = fastmri_transforms.normalize_instance(
        magnitude, eps=1e-11)
    normalized = normalized.clamp(-6, 6)

    return normalized.unsqueeze(0), mean, std
def visualize_kspace(kspace, dim=None, crop=False, output_filepath=None):
    """Reconstruct an image from k-space and show it or save it to disk.

    Args:
        kspace: k-space tensor passed to ``fastmri.ifft2c``.
        dim: If not ``None``, ``fastmri.rss`` is applied along this dimension
            after the magnitude has been taken.
        crop: If true, the complex image is center-cropped to a square whose
            side is ``image.shape[-2]``, then its magnitude is
            instance-normalized and clamped to ``[-6, 6]``. Otherwise only the
            magnitude is taken.
        output_filepath: Object exposing ``.parent`` (e.g. ``pathlib.Path``).
            When given, the figure is saved there without axes; otherwise it
            is displayed with ``plt.show()``.
    """
    # Work in the image domain from here on.
    image = fastmri.ifft2c(kspace)
    if crop:
        side = image.shape[-2]
        image = T.complex_center_crop(image, (side, side))
        image = fastmri.complex_abs(image)
        image, _, _ = T.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)
    else:
        # Compute absolute value to get a real image
        image = fastmri.complex_abs(image)
    if dim is not None:
        image = fastmri.rss(image, dim=dim)
    img = np.abs(image.numpy())

    if output_filepath is not None:
        # exist_ok avoids the race between checking for the directory and
        # creating it when several processes write to the same location.
        output_filepath.parent.mkdir(parents=True, exist_ok=True)

    plt.imshow(img, cmap='gray')
    if output_filepath is None:
        plt.show()
    else:
        plt.axis("off")
        plt.savefig(output_filepath, bbox_inches="tight", pad_inches=0)
예제 #6
0
    def __call__(self, kspace, mask, target, attrs, fname, slice_num):
        """
        Build a normalized low-resolution image from k-space.

        The k-space is inverse transformed and center-cropped, transformed
        back with ``fastmri.fft2c``, center-cropped to ``(160, 160)`` in
        k-space and inverse transformed again to give the low-resolution
        image.

        Args:
            kspace (numpy.array): Input k-space, converted with
                ``transforms.to_tensor``.
            mask (numpy.array): Not used by this method.
            target (numpy.array): Target image, or None. When given, its
                last two dimensions define the crop size.
            attrs (dict): Must provide ``"recon_size"`` when ``target`` is
                None.
            fname (str): File name, returned unchanged.
            slice_num (int): Serial number of the slice, returned unchanged.

        Returns:
            (tuple): tuple containing:
                LR_image (torch.Tensor): Normalized low-resolution magnitude
                    image, clamped to [-6, 6].
                target (torch.Tensor): Target cropped, normalized with the
                    statistics of ``LR_image`` and clamped to [-6, 6], or
                    ``torch.Tensor([0])`` when no target was given.
                mean: Mean value used for normalization.
                std: Standard deviation value used for normalization.
                fname (str): File name.
                slice_num (int): Serial number of the slice.
        """
        full_image = fastmri.ifft2c(transforms.to_tensor(kspace))

        has_target = target is not None

        # Crop size comes from the target if present, else from attrs.
        if has_target:
            rows, cols = target.shape[-2], target.shape[-1]
        else:
            rows, cols = attrs["recon_size"][0], attrs["recon_size"][1]

        # FLAIR 203 check: square crop when the image is narrower than the
        # requested crop width.
        extent = full_image.shape[-2]
        if extent < cols:
            rows = cols = extent
        crop_size = (rows, cols)

        full_image = transforms.complex_center_crop(full_image, crop_size)

        # Low-resolution image: keep only the central (160, 160) of k-space.
        low_freq = transforms.complex_center_crop(
            fastmri.fft2c(full_image), (160, 160))
        LR_image = fastmri.complex_abs(fastmri.ifft2c(low_freq))

        # Normalize the input and keep the statistics for the target.
        LR_image, mean, std = transforms.normalize_instance(
            LR_image, eps=1e-11)
        LR_image = LR_image.clamp(-6, 6)

        # Normalize the target with the input's statistics.
        if has_target:
            target = transforms.center_crop(
                transforms.to_tensor(target), crop_size)
            target = transforms.normalize(
                target, mean, std, eps=1e-11).clamp(-6, 6)
        else:
            target = torch.Tensor([0])

        return LR_image, target, mean, std, fname, slice_num
예제 #7
0
    def __call__(self, kspace, mask, target, attrs, fname, slice_num):
        """
        Build a normalized zero-filled image (and target) from k-space.

        Args:
            kspace (numpy.array): Input k-space, converted with
                ``transforms.to_tensor``.
            mask (numpy.array): Incoming mask; it is not read by this
                method (the mask comes from ``self.mask_func`` when set).
            target (numpy.array): Target image, or None. When given, its
                last two dimensions define the crop size.
            attrs (dict): Must provide ``"recon_size"`` when ``target`` is
                None.
            fname (str): File name; also seeds the mask when
                ``self.use_seed`` is set.
            slice_num (int): Serial number of the slice.

        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image, normalized
                    and clamped to [-6, 6].
                target (torch.Tensor): Target cropped, normalized with the
                    image statistics and clamped to [-6, 6], or
                    ``torch.Tensor([0])`` when no target was given.
                mean: Mean value used for normalization.
                std: Standard deviation value used for normalization.
                fname (str): File name.
                slice_num (int): Serial number of the slice.
        """
        kspace = transforms.to_tensor(kspace)

        # Apply the mask function if one is configured.
        if self.mask_func:
            # Seeding with the file name characters makes the mask
            # reproducible per file.
            seed = tuple(map(ord, fname)) if self.use_seed else None
            masked_kspace, _ = transforms.apply_mask(
                kspace, self.mask_func, seed)
        else:
            masked_kspace = kspace

        # Zero-filled solution via the inverse Fourier transform.
        zero_filled = fastmri.ifft2c(masked_kspace)

        has_target = target is not None

        # Crop size comes from the target if present, else from attrs.
        if has_target:
            rows, cols = target.shape[-2], target.shape[-1]
        else:
            rows, cols = attrs["recon_size"][0], attrs["recon_size"][1]

        # FLAIR 203 check: square crop when the image is narrower than the
        # requested crop width.
        extent = zero_filled.shape[-2]
        if extent < cols:
            rows = cols = extent
        crop_size = (rows, cols)

        # Crop, then take the magnitude.
        image = fastmri.complex_abs(
            transforms.complex_center_crop(zero_filled, crop_size))

        # Root-Sum-of-Squares only for multicoil data.
        if self.which_challenge == "multicoil":
            image = fastmri.rss(image)

        # Normalize the input and keep the statistics for the target.
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        # Normalize the target with the input's statistics.
        if has_target:
            target = transforms.center_crop(
                transforms.to_tensor(target), crop_size)
            target = transforms.normalize(
                target, mean, std, eps=1e-11).clamp(-6, 6)
        else:
            target = torch.Tensor([0])

        return image, target, mean, std, fname, slice_num
    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: np.ndarray,
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str,
               int, float]:
        """
        Build a normalized zero-filled image (and target) from k-space.

        In test mode (``self.test_mode``) the target and ``attrs`` crop
        information are ignored: the image is cropped to a square of side
        ``image.shape[-2]`` and the returned target is a placeholder.

        Args:
            kspace: Input k-space, converted with ``T.to_tensor``.
            mask: Incoming mask; it is not read by this method (the mask
                comes from ``self.mask_func`` when set).
            target: Target image, or None.
            attrs: Acquisition related information. ``"max"`` is optional;
                ``"recon_size"`` is required outside test mode when
                ``target`` is None.
            fname: File name; also seeds the mask when ``self.use_seed`` is
                set.
            slice_num: Serial number of the slice.

        Returns:
            tuple containing:
                image: Zero-filled input image, normalized and clamped to
                    [-6, 6].
                target: Target cropped, normalized with the image
                    statistics and clamped to [-6, 6], or
                    ``torch.Tensor([0])`` in test mode or without a target.
                mean: Mean value used for normalization.
                std: Standard deviation value used for normalization.
                fname: File name.
                slice_num: Serial number of the slice.
                max_value: ``attrs["max"]`` if present, else 0.0.
        """
        kspace = T.to_tensor(kspace)

        # Maximum value is optional in attrs; default to 0.0.
        if "max" in attrs:
            max_value = attrs["max"]
        else:
            max_value = 0.0

        # Apply the mask function if one is configured.
        if self.mask_func:
            seed = tuple(map(ord, fname)) if self.use_seed else None
            masked_kspace, _ = T.apply_mask(kspace, self.mask_func, seed)
        else:
            masked_kspace = kspace

        # Zero-filled solution via the inverse Fourier transform.
        zero_filled = fastmri.ifft2c(masked_kspace)
        extent = zero_filled.shape[-2]

        if self.test_mode:
            # Test mode never consults the target or attrs for the crop.
            crop_size = (extent, extent)
        else:
            if target is not None:
                crop_size = (target.shape[-2], target.shape[-1])
            else:
                crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
            # FLAIR 203 check: square crop when the image is narrower than
            # the requested crop width.
            if extent < crop_size[1]:
                crop_size = (extent, extent)

        # Crop, then take the magnitude.
        image = fastmri.complex_abs(
            T.complex_center_crop(zero_filled, crop_size))

        # Root-Sum-of-Squares only for multicoil data.
        if self.which_challenge == "multicoil":
            image = fastmri.rss(image)

        # Normalize the input and keep the statistics for the target.
        image, mean, std = T.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        # Normalize the target with the input's statistics.
        if self.test_mode or target is None:
            target = torch.Tensor([0])
        else:
            target = T.center_crop(T.to_tensor(target), crop_size)
            target = T.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)

        return image, target, mean, std, fname, slice_num, max_value