Example #1
    def __call__(self, target, fname, slice_num=0, attrs=None, seed=None):
        # Preprocess a single slice.
        # target shape: [H, W, 1] (real) or [H, W, 2] (real/imag channels)
        img = target
        if target.shape[2] != 2:
            # Real input: append a zero imaginary channel.
            img = np.concatenate((target, np.zeros_like(target)), axis=2)
        assert img.shape[-1] == 2
        img = to_tensor(img)
        kspace = fastmri.fft2c(img)

        # Keep the (Hamming-weighted) center of k-space to build a
        # low-frequency input image.
        center_kspace, _ = apply_mask(kspace,
                                      self.mask_func,
                                      hamming=True,
                                      seed=seed)
        img_LF = fastmri.complex_abs(fastmri.ifft2c(center_kspace))
        img_LF = img_LF.unsqueeze(0)
        image, mean, std = normalize_instance(img_LF, eps=1e-11)
        image = image.clamp(-6, 6)
        # Move target to channel-first layout: [1, H, W]
        target = to_tensor(np.transpose(target, (2, 0, 1)))
        target = normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)
        target = target.squeeze(0)
        # This transform does not track a max value; return a placeholder.
        max_value = 0.0
        return image, target, mean, std, fname, slice_num, max_value
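Note that the target is normalized with the mean and std of the low-frequency input, not its own statistics, so input and target end up on the same scale. For reference, the fastMRI normalization helpers used here behave like this sketch:

import torch

def normalize(data, mean, std, eps=0.0):
    # Standard score with an epsilon guard against zero std.
    return (data - mean) / (std + eps)

def normalize_instance(data, eps=0.0):
    # Normalize a tensor by its own statistics and return them so the
    # same transform can be applied to (or undone on) other tensors.
    mean = data.mean()
    std = data.std()
    return normalize(data, mean, std, eps), mean, std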
Example #2
    def data_consistency_kspace(self, prediction, k_space_slice, mask):
        """
        Args:
            - prediction: net (or block) predicted real image in complex domain
            - k_space_slice: initially sampled elements in k-space
            - mask: corresponding nonzero location in kspace
        Res:
            image in k space where:
                - masked entries of initial slice are replaced with entries predicted by output
                - non masked entries of initial slice stay the same
        """

        prediction = prediction[:, :, 0:320, 0:320]
        prediction = prediction.permute(
            0, 2, 3, 1)  # prediction from 1 x 2 x h x w to 1 x h x w x 2
        prediction = self.proper_padding(
            prediction, k_space_slice)  # pad prediction to be 640 x 372 x 2
        k_space_prediction = fastmri.fft2c(
            prediction)  # transform prediction to kspace domain

        k_space_out = (
            1 - mask) * k_space_prediction + mask * k_space_slice  # data consistency
        prediction = fastmri.ifft2c(k_space_out)  # back to complex image
        prediction = transforms.complex_center_crop(
            prediction, (320, 320))  # crop image to 320 x 320
        prediction = prediction.permute(0, 3, 1, 2)  # back to 1 x 2 x h x w
        return prediction
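self.proper_padding is called above but not shown. A plausible sketch, assuming it zero-pads the cropped 1 x 320 x 320 x 2 prediction back to the center of the measured 640 x 372 grid:

import torch.nn.functional as F

def proper_padding(prediction, k_space_slice):
    # Zero-pad the (..., h, w, 2) prediction to the measured spatial size.
    h, w = k_space_slice.shape[-3], k_space_slice.shape[-2]
    ph, pw = prediction.shape[-3], prediction.shape[-2]
    pad_w, pad_h = (w - pw) // 2, (h - ph) // 2
    # F.pad pads trailing dimensions first: (complex, width, height)
    return F.pad(prediction,
                 (0, 0, pad_w, w - pw - pad_w, pad_h, h - ph - pad_h))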
Example #3
import torch.nn.functional as F
import fastmri

def add_ge_oversampling(masked_kspace):
    # Emulate GE-style 2x readout oversampling: zero-pad the image along
    # the readout dimension (doubling its extent), then go back to k-space.
    # Reuse the name "masked_kspace" throughout to preserve memory.
    readout_len = masked_kspace.shape[-3]
    pad_size = (0, 0, 0, 0, readout_len // 2, readout_len // 2)
    masked_kspace = fastmri.ifft2c(masked_kspace)
    masked_kspace = F.pad(masked_kspace, pad_size)
    masked_kspace = fastmri.fft2c(masked_kspace)

    return masked_kspace
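A quick shape check on dummy data (coils x readout x phase x 2 layout assumed, as in fastMRI multi-coil k-space):

import torch

kspace = torch.randn(15, 640, 372, 2)
oversampled = add_ge_oversampling(kspace)
assert oversampled.shape == (15, 1280, 372, 2)  # readout dim doubled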
Example #4
def test_fft2(shape):
    # fastmri.fft2c must match a centered, orthonormal numpy FFT.
    shape = shape + [2]  # trailing dimension holds (real, imag)
    x = create_input(shape)
    out_torch = fastmri.fft2c(x).numpy()
    out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

    input_numpy = transforms.tensor_to_complex_np(x)
    input_numpy = np.fft.ifftshift(input_numpy, (-2, -1))
    out_numpy = np.fft.fft2(input_numpy, norm="ortho")
    out_numpy = np.fft.fftshift(out_numpy, (-2, -1))

    assert np.allclose(out_torch, out_numpy)
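The property the test pins down can also be written directly with torch.fft; this reference sketch mirrors what fastmri.fft2c computes (not the library's literal source):

import torch

def fft2c_ref(x: torch.Tensor) -> torch.Tensor:
    # x: (..., H, W, 2) real tensor holding (real, imag) pairs.
    xc = torch.view_as_complex(x.contiguous())
    xc = torch.fft.ifftshift(xc, dim=(-2, -1))
    xc = torch.fft.fft2(xc, norm="ortho")
    xc = torch.fft.fftshift(xc, dim=(-2, -1))
    return torch.view_as_real(xc)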
Example #5
    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        zero = torch.zeros(1, 1, 1, 1).to(current_kspace)
        # Soft data consistency: penalize deviation from the measured
        # k-space at sampled locations, scaled by a learned weight.
        soft_dc = torch.where(mask, current_kspace - ref_kspace,
                              zero) * self.dc_weight
        # Regularizer: denoise in image domain, then return to k-space.
        model_term = fastmri.fft2c(self.model(fastmri.ifft2c(current_kspace)))

        return current_kspace - soft_dc - model_term
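Each cascade therefore computes k_next = k - w * M * (k - k_ref) - F(CNN(F^-1(k))). Only forward is shown; a plausible constructor, following fastMRI's VarNet where the data-consistency weight is a learned scalar (the class name and model type here are assumptions):

import torch
import torch.nn as nn

class KSpaceCascade(nn.Module):  # hypothetical name for the block above
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model  # image-domain denoiser, e.g. a small U-Net
        self.dc_weight = nn.Parameter(torch.ones(1))  # learned DC strength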
Example #6
    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        zero = torch.zeros(1, 1, 1, 1).to(current_kspace)
        soft_dc = torch.where(mask, current_kspace - ref_kspace,
                              zero) * self.dc_weight
        # Histogram term: differentiate hist_loss w.r.t. a detached copy of
        # the k-space, so its gradient acts as an extra correction without
        # entangling the main computation graph.
        with torch.enable_grad():
            X = current_kspace.clone().detach().requires_grad_(True)
            histl = hist_loss(X, ref_kspace.clone().detach())
            histl.backward()
            soft_histc = X.grad * self.hist_weight

        model_term = fastmri.fft2c(self.model(fastmri.ifft2c(current_kspace)))

        return current_kspace - soft_dc - soft_histc - model_term
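The enable_grad block turns the histogram loss into an explicit gradient step: it differentiates hist_loss with respect to a detached copy of the k-space and uses that gradient as a correction term. The same pattern can be written with torch.autograd.grad (a sketch; loss_fn stands in for hist_loss):

import torch

def loss_gradient(x, ref, loss_fn):
    # Gradient of loss_fn(x, ref) w.r.t. x, computed on a detached copy.
    x = x.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(loss_fn(x, ref.detach()), x)
    return grad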
Example #7
    def forward(self, mr_img, mk_space, mask):
        # Loop over the cascades
        for i in range(self.n_cascade):
            # The Laplacian decomposition yields the levels of the Gaussian/
            # Laplacian pyramids: the Gaussian levels are progressively
            # downsampled copies of the input, and the Laplacian levels hold
            # the detail ('errors') between adjacent Gaussian levels
            gaussian_3, lap_1, lap_2 = self.lapl_dec(mr_img.cpu())

            # Shuffle down by 4: trade spatial resolution for channels
            mr_img = self.shuffle_down_4(mr_img)

            # convBlock1 extracts shallow features
            mr_img = self.convBlock1(mr_img)

            # ---- 3 branches ----
            branch_1 = self.shuffle_up_4(mr_img)
            branch_1 = self.branch1(branch_1)

            branch_2 = self.shuffle_up_2(mr_img)
            branch_2 = self.branch2(branch_2)

            branch_3 = self.branch3(mr_img)

            branch_out1 = torch.stack((self.pixel_conv(branch_1[..., 0], lap_1[..., 0]),
                                       self.pixel_conv(branch_1[..., 1], lap_1[..., 1])), dim=-1)
            branch_out2 = torch.stack((self.pixel_conv(branch_2[..., 0], lap_2[..., 0]),
                                       self.pixel_conv(branch_2[..., 1], lap_2[..., 1])), dim=-1)
            branch_out3 = torch.stack((self.pixel_conv(branch_3[..., 0], gaussian_3[..., 0]),
                                       self.pixel_conv(branch_3[..., 1], gaussian_3[..., 1])), dim=-1)

            # The 2x linear upsample inside lapl_rec lets the branch outputs
            # be added together at matching resolutions
            output = self.lapl_rec(branch_out3, branch_out2)
            output = self.lapl_rec(output, branch_out1)

            # Data-consistency step: keep the measured samples (mk_space is
            # the masked k-space) and take the prediction elsewhere
            mr_img = fastmri.ifft2c((1.0 - mask) * fastmri.fft2c(output) + mk_space)
        return mr_img
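The last line is a hard data-consistency update: since mk_space already equals mask * full_kspace, adding it back restores the measured samples while (1 - mask) gates the prediction. As a standalone sketch:

import fastmri

def hard_dc(output_img, masked_kspace, mask):
    # masked_kspace is assumed to be mask * full_kspace.
    k_pred = fastmri.fft2c(output_img)
    return fastmri.ifft2c((1.0 - mask) * k_pred + masked_kspace)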
Example #8
    def sens_expand(self, x: torch.Tensor,
                    sens_maps: torch.Tensor) -> torch.Tensor:
        # Per-coil forward operator: multiply the image by the coil
        # sensitivity maps, then FFT to k-space.
        return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))
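In fastMRI's VarNet this operator is paired with its adjoint, sens_reduce, which maps per-coil k-space back to a single coil-combined image (coil dimension assumed at dim=1):

import fastmri

def sens_reduce(x, sens_maps):
    # Adjoint of sens_expand: inverse FFT each coil, multiply by the
    # conjugate sensitivity maps, and sum over the coil dimension.
    return fastmri.complex_mul(
        fastmri.ifft2c(x), fastmri.complex_conj(sens_maps)
    ).sum(dim=1, keepdim=True)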
Example #9
def sens_expand(x, sens_maps):
    # Free-function variant of the method above; sens_maps must be passed
    # in explicitly rather than captured from an enclosing scope.
    return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))
Example #10
    def __call__(self, kspace, mask, target, attrs, fname, slice_num):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows,
                cols, 2) for multi-coil data or (rows, cols, 2) for single coil
                data.
            mask (numpy.array): Mask from the test dataset.
            target (numpy.array): Target image.
            attrs (dict): Acquisition related information stored in the HDF5
                object.
            fname (str): File name.
            slice_num (int): Serial number of the slice.

        Returns:
            (tuple): tuple containing:
                LR_image (torch.Tensor): Low-resolution input image obtained
                    by cropping the center of k-space.
                target (torch.Tensor): Target image converted to a torch
                    Tensor.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                fname (str): File name.
                slice_num (int): Serial number of the slice.
        """
        kspace = transforms.to_tensor(kspace)

        # inverse FFT gives the complex zero-filled image
        image = fastmri.ifft2c(kspace)

        # crop input to correct size
        if target is not None:
            crop_size = (target.shape[-2], target.shape[-1])
        else:
            crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

        # some volumes (e.g. FLAIR 203) are narrower than the target crop;
        # fall back to a square crop that fits
        if image.shape[-2] < crop_size[1]:
            crop_size = (image.shape[-2], image.shape[-2])

        image = transforms.complex_center_crop(image, crop_size)

        # get the low-resolution input: crop the central 160 x 160 of
        # k-space (low frequencies only) and transform back
        imgfft = fastmri.fft2c(image)
        imgfft = transforms.complex_center_crop(imgfft, (160, 160))
        LR_image = fastmri.ifft2c(imgfft)

        # absolute value
        LR_image = fastmri.complex_abs(LR_image)

        # normalize input
        LR_image, mean, std = transforms.normalize_instance(LR_image, eps=1e-11)
        LR_image = LR_image.clamp(-6, 6)

        # normalize target
        if target is not None:
            target = transforms.to_tensor(target)
            target = transforms.center_crop(target, crop_size)
            target = transforms.normalize(target, mean, std, eps=1e-11)
            target = target.clamp(-6, 6)
        else:
            target = torch.Tensor([0])

        return LR_image, target, mean, std, fname, slice_num
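Cropping the central 160 x 160 of k-space keeps only low spatial frequencies, so the LR path amounts to a sinc-filtered 2x downsampling of the 320 x 320 image. A minimal sketch of that path on dummy data:

import torch
import fastmri
from fastmri.data import transforms

image = torch.randn(320, 320, 2)  # complex image as (real, imag) channels
kspace = fastmri.fft2c(image)
lr_k = transforms.complex_center_crop(kspace, (160, 160))
lr_image = fastmri.complex_abs(fastmri.ifft2c(lr_k))  # 160 x 160 magnitude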