Beispiel #1
0
    def training_step(self, batch, batch_idx):
        """Self-supervised step that splits the sampled k-space into two halves.

        The acquired k-space is partitioned by ``mask1`` and its complement
        ``mask2``. Each half is turned into a zero-filled magnitude image and
        both images are passed through the network together. The prediction
        made from one half is then compared, in k-space, with the samples held
        out in the other half.

        Args:
            batch: 7-tuple whose first element is the subsampled k-space and
                whose last element is the split mask ``mask1``; the other
                entries are ignored.
            batch_idx: index of the batch (unused).

        Returns:
            The scalar training loss (also logged as ``"loss"``).
        """
        subsampled_kspace, _, _, _, _, _, mask1 = batch
        # 1 - mask1, i.e. the complement when mask1 is a 0/1 mask.
        mask2 = -(mask1 - 1)

        # "+ 0.0" turns the -0.0 values produced by the multiply into +0.0.
        subsampled_kspace1 = subsampled_kspace * mask1 + 0.0
        subsampled_kspace2 = subsampled_kspace * mask2 + 0.0

        image1 = fastmri.complex_abs(fastmri.ifft2c(subsampled_kspace1))
        image2 = fastmri.complex_abs(fastmri.ifft2c(subsampled_kspace2))

        image = torch.vstack((image1, image2))
        output_image = self(image)

        # NOTE(review): indexing [0] / [1] assumes one slice per half (batch
        # size 1); with a larger batch vstack stacks several slices per half.
        output_image1 = output_image[0, :, :]
        output_image2 = output_image[1, :, :]

        def to_centered_kspace(real_image):
            # Add a zero imaginary channel and apply fft2c, the inverse of the
            # centered orthonormal ifft2c used above. torch.fft.fft2 would put
            # DC in the corner and use a different scaling, so the masks would
            # select the wrong frequencies.
            return fastmri.fft2c(
                torch.stack((real_image, torch.zeros_like(real_image)), dim=-1))

        # Prediction from half 1 is scored on the samples of half 2 and vice versa.
        output_kspace1 = to_centered_kspace(output_image1) * mask2 + 0.0
        output_kspace2 = to_centered_kspace(output_image2) * mask1 + 0.0

        loss = l1_l2_loss(output_kspace1, subsampled_kspace2) \
            + l1_l2_loss(output_kspace2, subsampled_kspace1)

        self.log("loss", loss.detach())

        return loss
Beispiel #2
0
    def forward(self, kspace_pred: torch.Tensor, ref_kspace: torch.Tensor):
        """
            Penalize a mismatch in the spread of neighbouring-pixel differences
            between the predicted and the reference image.

            Both k-space inputs are converted to magnitude images. For the
            horizontal and the vertical neighbour differences, the standard
            deviation of the prediction is compared with 1.25 times the
            standard deviation of the reference. The absolute gaps are summed
            and scaled by ``self.tv_weight``. No data-consistency term is
            computed here.

            Inputs:
            - kspace_pred: PyTorch tensor of shape (N, H, W, 2) holding predicted kspace.
            - ref_kspace : input masked kspace

            Returns:
            - loss: scalar tensor, ``self.tv_weight`` times the summed gaps.
        """

        def magnitude(kspace):
            return fastmri.complex_abs(fastmri.ifft2c(kspace))

        def std_gap(pred_diffs, ref_diffs):
            # Unbiased torch.var, as in the rest of this loss.
            return torch.sqrt(torch.var(pred_diffs)) - 1.25 * torch.sqrt(
                torch.var(ref_diffs))

        pred_img = magnitude(kspace_pred)
        ref_img = magnitude(ref_kspace)

        horizontal_gap = std_gap(
            (pred_img[:, :, :-1] - pred_img[:, :, 1:]).view(-1),
            (ref_img[:, :, :-1] - ref_img[:, :, 1:]).view(-1),
        )
        vertical_gap = std_gap(
            (pred_img[:, :-1, :] - pred_img[:, 1:, :]).view(-1),
            (ref_img[:, :-1, :] - ref_img[:, 1:, :]).view(-1),
        )

        combined = torch.abs(horizontal_gap) + torch.abs(vertical_gap)
        return self.tv_weight * combined
Beispiel #3
0
def _value_range(values: torch.Tensor):
    """Return ``(min, max)`` of ``values`` as Python numbers (no gradient)."""
    return values.min().item(), values.max().item()


def _hist_distance(pred: torch.Tensor, ref: torch.Tensor, bins: int,
                   pred_range, ref_range) -> torch.Tensor:
    """L2 norm of ``hist(pred) - hist(ref)`` divided by ``len(pred)``."""
    hist_pred = differentiable_histogram(pred,
                                         bins=bins,
                                         min=pred_range[0],
                                         max=pred_range[1])
    hist_ref = differentiable_histogram(ref,
                                        bins=bins,
                                        min=ref_range[0],
                                        max=ref_range[1])
    return torch.norm((hist_pred - hist_ref) / len(pred))


def hist_loss(current_kspace: torch.Tensor,
              masked_kspace: torch.Tensor,
              bins: int = 5):
    """Histogram-matching loss between two k-space tensors, in image space.

    Both inputs are converted to magnitude images. Three histogram distances
    are summed:

    * horizontal neighbour differences,
    * vertical neighbour differences,
    * raw pixel intensities.

    Each distance is the L2 norm of the difference of two ``bins``-bin
    histograms, divided by the number of values.

    Inputs:
    - current_kspace: predicted k-space; presumably (N, H, W, 2), confirm.
    - masked_kspace: reference (masked) k-space with the same layout.
    - bins: number of histogram bins.

    Returns:
    - scalar tensor holding the summed histogram distances.
    """
    output = fastmri.complex_abs(fastmri.ifft2c(current_kspace))
    gt = fastmri.complex_abs(fastmri.ifft2c(masked_kspace))

    # NOTE(review): for the difference histograms each side uses its own
    # min/max, so the two histograms' bins cover different intervals.
    # Confirm this is intended.
    hdiff_pred = (output[:, :, :-1] - output[:, :, 1:]).view(-1)
    hdiff_gt = (gt[:, :, :-1] - gt[:, :, 1:]).view(-1)
    hdiff_hist_loss = _hist_distance(hdiff_pred, hdiff_gt, bins,
                                     _value_range(hdiff_pred),
                                     _value_range(hdiff_gt))

    vdiff_pred = (output[:, :-1, :] - output[:, 1:, :]).view(-1)
    vdiff_gt = (gt[:, :-1, :] - gt[:, 1:, :]).view(-1)
    vdiff_hist_loss = _hist_distance(vdiff_pred, vdiff_gt, bins,
                                     _value_range(vdiff_pred),
                                     _value_range(vdiff_gt))

    # The intensity histograms share the reference's value range.
    output_flat = output.view(-1)
    gt_flat = gt.view(-1)
    gt_range = _value_range(gt_flat)
    gt_hist_loss = _hist_distance(output_flat, gt_flat, bins, gt_range,
                                  gt_range)

    return hdiff_hist_loss + vdiff_hist_loss + gt_hist_loss
Beispiel #4
0
def load_data(file_dir_path):
    """Load every file under a directory as full and undersampled images.

    For each file returned by ``get_files``, every k-space slice becomes two
    magnitude images:

    * one from the full k-space;
    * one after applying a random 4x undersampling mask with a 0.08 centre
      fraction.

    Args:
        file_dir_path: directory passed to ``get_files``.

    Returns:
        ``(total_image_tensor, total_sampled_image_tensor)``, each stacked
        with one leading entry per file. ``torch.stack`` requires every file
        to yield the same number of slices and the same image size. Both
        shapes are also printed.
    """
    file_paths = get_files(file_dir_path)
    # Created once and reused for every slice of every file.
    mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])

    total_image_list = []
    total_sampled_image_list = []
    for path in file_paths:
        # Each file is loaded in turn (previously file_path[0] was reloaded).
        total_kspace, slices_num = load_dataset(path)
        image_list = []
        sampled_image_list = []
        for i in range(slices_num):
            slice_kspace_tensor = T.to_tensor(total_kspace[i])
            # Fully sampled magnitude image.
            image_list.append(
                fastmri.complex_abs(fastmri.ifft2c(slice_kspace_tensor)))
            # Undersampled, zero-filled magnitude image.
            masked_kspace, _ = T.apply_mask(slice_kspace_tensor, mask_func)
            sampled_image_list.append(
                fastmri.complex_abs(fastmri.ifft2c(masked_kspace)))
        total_image_list.append(torch.stack(image_list, dim=0))
        total_sampled_image_list.append(torch.stack(sampled_image_list, dim=0))

    total_image_tensor = torch.stack(total_image_list, dim=0)
    total_sampled_image_tensor = torch.stack(total_sampled_image_list, dim=0)
    print(total_image_tensor.shape)
    print(total_sampled_image_tensor.shape)
    return total_image_tensor, total_sampled_image_tensor
def to_cropped_image(masked_kspace, target, attrs):
    """Build a normalized, center-cropped zero-filled image and its target.

    The crop size is taken from ``target`` when given, otherwise from
    ``attrs["recon_size"]``. The magnitude image is instance-normalized and
    clipped to [-6, 6]. The target is cropped, normalized with the same
    mean/std and clipped the same way. Without a target, ``torch.Tensor([0])``
    is returned in its place.
    """
    zero_filled = fastmri.ifft2c(masked_kspace)

    if target is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (target.shape[-2], target.shape[-1])

    # When the image is narrower than the requested crop (the "FLAIR 203"
    # case), fall back to a square crop of the available width.
    width = zero_filled.shape[-2]
    if width < crop_size[1]:
        crop_size = (width, width)

    image = fastmri.complex_abs(T.complex_center_crop(zero_filled, crop_size))
    image, mean, std = T.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)

    if target is None:
        return image, torch.Tensor([0])

    if isinstance(target, np.ndarray):
        target = T.to_tensor(target)
    cropped_target = T.center_crop(target, crop_size)
    normalized_target = T.normalize(cropped_target, mean, std, eps=1e-11)
    return image, normalized_target.clamp(-6, 6)
Beispiel #6
0
    def data_consistency_kspace(self, prediction, k_space_slice, mask):
        """Blend a predicted image with the originally sampled k-space.

        Args:
            - prediction: network output with the two real/imaginary channels
              in dim 1 (permuted below from (N, 2, H, W) to (N, H, W, 2)).
            - k_space_slice: initially sampled k-space.
            - mask: weight of the sampled data; the result takes
              ``k_space_slice`` where mask is 1 and the prediction where it is 0.
        Res:
            Complex image cropped to 320 x 320, in (N, 2, H, W) layout.
        """
        # Keep the top-left 320 x 320 region and move channels last.
        channels_last = prediction[:, :, 0:320, 0:320].permute(0, 2, 3, 1)
        # Pad back to the k-space slice size (helper defined on the class).
        padded = self.proper_padding(channels_last, k_space_slice)
        predicted_kspace = fastmri.fft2c(padded)

        keep_predicted = 1 - mask
        merged_kspace = keep_predicted * predicted_kspace + mask * k_space_slice

        merged_image = transforms.complex_center_crop(
            fastmri.ifft2c(merged_kspace), (320, 320))
        return merged_image.permute(0, 3, 1, 2)
Beispiel #7
0
def save_zero_filled(data_dir, out_dir, which_challenge):
    """Compute and save zero-filled reconstructions for a directory of scans.

    Args:
        data_dir: ``pathlib.Path`` directory containing the HDF5 files.
        out_dir: destination passed to ``fastmri.save_reconstructions``.
        which_challenge: ``"multicoil"`` combines the images with
            root-sum-of-squares over dim 1; any other value leaves them as is.
    """
    reconstructions = {}

    for f in data_dir.iterdir():
        # iterdir() also yields sub-directories and unrelated files, which
        # h5py.File cannot open; only process actual HDF5 files.
        if not f.is_file() or not h5py.is_hdf5(str(f)):
            continue

        with h5py.File(f, "r") as hf:
            enc = ismrmrd.xsd.CreateFromDocument(hf["ismrmrd_header"][()]).encoding[0]
            masked_kspace = transforms.to_tensor(hf["kspace"][()])

            # extract target image width, height from ismrmrd header
            crop_size = (enc.reconSpace.matrixSize.x, enc.reconSpace.matrixSize.y)

            # inverse Fourier Transform to get zero filled solution
            image = fastmri.ifft2c(masked_kspace)

            # FLAIR 203: image narrower than the crop -> square crop of its width
            if image.shape[-2] < crop_size[1]:
                crop_size = (image.shape[-2], image.shape[-2])

            image = transforms.complex_center_crop(image, crop_size)
            image = fastmri.complex_abs(image)

            # apply Root-Sum-of-Squares if multicoil data
            if which_challenge == "multicoil":
                image = fastmri.rss(image, dim=1)

            reconstructions[f.name] = image

    fastmri.save_reconstructions(reconstructions, out_dir)
def save_zero_filled(data_dir, out_dir, which_challenge):
    """Compute and save zero-filled reconstructions for every ``*.h5`` file.

    The crop size is the ``encodedSpace`` matrix size read from each file's
    ISMRMRD header. Each crop is narrowed to a square when the image is too
    narrow. For ``"multicoil"`` the magnitude images are combined with
    root-sum-of-squares over dim 1. Results are keyed by file name and
    written with ``fastmri.save_reconstructions``.
    """
    size_path = ["encoding", "encodedSpace", "matrixSize"]
    reconstructions = {}

    for h5_path in tqdm(list(data_dir.glob("*.h5"))):
        with h5py.File(h5_path, "r") as hf:
            header = etree.fromstring(hf["ismrmrd_header"][()])
            masked_kspace = transforms.to_tensor(hf["kspace"][()])

            # Target width and height, in that order, from the header.
            crop_size = tuple(
                int(et_query(header, size_path + [axis])) for axis in ("x", "y")
            )

            zero_filled = fastmri.ifft2c(masked_kspace)

            width = zero_filled.shape[-2]
            if width < crop_size[1]:  # FLAIR 203
                crop_size = (width, width)

            image = fastmri.complex_abs(
                transforms.complex_center_crop(zero_filled, crop_size))

            if which_challenge == "multicoil":
                image = fastmri.rss(image, dim=1)

            reconstructions[h5_path.name] = image

    fastmri.save_reconstructions(reconstructions, out_dir)
Beispiel #9
0
    def __call__(self, target, fname, slice_num=0, attrs=None, seed=None):
        """Build a low-frequency input image and a normalized target.

        ``target`` is an array indexed ``[H, W, C]``. If it does not already
        have two channels, it is concatenated with zeros along axis 2 and the
        result must then have exactly two channels. That array is treated as a
        complex image and converted to k-space. The k-space is masked with
        ``self.mask_func`` (``hamming=True``) and converted back to a
        magnitude image, which is instance-normalized and clipped to [-6, 6].
        The target is moved to channels-first layout and normalized with the
        same statistics.

        Returns:
            ``(image, target, mean, std, fname, slice_num, max_value)``, with
            ``max_value`` always 0.0.
        """
        complex_img = target
        if target.shape[2] != 2:
            complex_img = np.concatenate((target, np.zeros_like(target)), axis=2)
        assert complex_img.shape[-1] == 2
        kspace = fastmri.fft2c(to_tensor(complex_img))

        low_freq_kspace, _ = apply_mask(kspace,
                                        self.mask_func,
                                        hamming=True,
                                        seed=seed)
        low_freq_img = fastmri.complex_abs(fastmri.ifft2c(low_freq_kspace))
        image, mean, std = normalize_instance(low_freq_img.unsqueeze(0),
                                              eps=1e-11)
        image = image.clamp(-6, 6)

        # [H, W, C] -> [C, H, W], normalized with the input's statistics.
        target_t = to_tensor(np.transpose(target, (2, 0, 1)))
        target_t = normalize(target_t, mean, std, eps=1e-11).clamp(-6, 6)
        target_t = target_t.squeeze(0)

        max_value = 0.0
        return image, target_t, mean, std, fname, slice_num, max_value
Beispiel #10
0
def load_data_from_pathlist(path):
    """Load undersampled input images and targets from a list of files.

    Only the first ``len(path) // 3`` files are used. For each file:

    * every slice's k-space is cast to float and masked with a random 4x
      mask (0.08 centre fraction);
    * the masked k-space is inverse-transformed, centre-cropped to
      320 x 320 and converted to a magnitude image;
    * the file's target is converted to a tensor.

    Inputs and targets from all files are concatenated along dim 0. Both are
    then normalized with the mean/std of the whole input tensor and clipped
    to [-6, 6].

    Args:
        path: indexable sequence of file paths accepted by ``load_dataset``.

    Returns:
        ``(target_image_tensor, total_sampled_image_tensor)``.
    """
    use_num = len(path) // 3
    # Created once and reused for every slice of every file.
    mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])

    total_target_list = []
    total_sampled_image_list = []
    for h5_num in range(use_num):
        total_kspace, slices_num, target = load_dataset(path[h5_num])
        sampled_image_list = []
        for i in range(slices_num):
            slice_kspace_tensor = T.to_tensor(total_kspace[i]).float()
            masked_kspace, _ = T.apply_mask(slice_kspace_tensor, mask_func)
            sampled_image = T.complex_center_crop(
                fastmri.ifft2c(masked_kspace), (320, 320))
            sampled_image_list.append(fastmri.complex_abs(sampled_image))
        # (slices_num, 320, 320) for this file.
        total_sampled_image_list.append(torch.stack(sampled_image_list, dim=0))
        total_target_list.append(T.to_tensor(target))

    total_target = torch.cat(total_target_list, dim=0)
    total_sampled_image_tensor = torch.cat(total_sampled_image_list, dim=0)

    # Both tensors share the statistics of the input images.
    total_sampled_image_tensor, mean, std = T.normalize_instance(
        total_sampled_image_tensor, eps=1e-11)
    total_sampled_image_tensor = total_sampled_image_tensor.clamp(-6, 6)
    target_image_tensor = T.normalize(total_target, mean, std, eps=1e-11)
    target_image_tensor = target_image_tensor.clamp(-6, 6)
    return target_image_tensor, total_sampled_image_tensor
Beispiel #11
0
    def forward(self, masked_kspace, mask):
        """Refine the k-space through all cascades and return an RSS image.

        Sensitivity maps are estimated once. Each cascade updates the k-space
        estimate, starting from a copy of ``masked_kspace``. The final
        estimate is inverse-transformed, converted to magnitudes and reduced
        with root-sum-of-squares over dim 1.
        """
        sens_maps = self.sens_net(masked_kspace, mask)
        estimate = masked_kspace.clone()

        for block in self.cascades:
            estimate = block(estimate, masked_kspace, mask, sens_maps)

        magnitudes = fastmri.complex_abs(fastmri.ifft2c(estimate))
        return fastmri.rss(magnitudes, dim=1)
def add_ge_oversampling(masked_kspace):
    """Zero-pad the image along dim -3 and return the padded k-space.

    The k-space is moved to image space and padded with ``readout_len // 2``
    zeros on each side of dimension -3, where ``readout_len`` is that
    dimension's current size. The result is transformed back to k-space.
    """
    half = masked_kspace.shape[-3] // 2
    # F.pad takes (before, after) pairs starting from the last dimension:
    # dims -1 and -2 are untouched and dim -3 gains `half` on both sides.
    pad_size = (0, 0, 0, 0, half, half)
    # Nested calls keep no named reference to the intermediates, so each can
    # be freed once the next step has consumed it.
    return fastmri.fft2c(F.pad(fastmri.ifft2c(masked_kspace), pad_size))
Beispiel #13
0
def test_ifft2(shape):
    """fastmri.ifft2c must match NumPy's centered, orthonormal inverse FFT."""
    shape = shape + [2]
    x = create_input(shape)

    torch_pair = fastmri.ifft2c(x).numpy()
    torch_complex = torch_pair[..., 0] + 1j * torch_pair[..., 1]

    axes = (-2, -1)
    numpy_complex = np.fft.fftshift(
        np.fft.ifft2(
            np.fft.ifftshift(transforms.tensor_to_complex_np(x), axes),
            norm="ortho",
        ),
        axes,
    )

    assert np.allclose(torch_complex, numpy_complex)
    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """One unrolled update: soft data consistency plus a learned term.

        Returns ``current_kspace - soft_dc - model_term``, where:

        * ``soft_dc`` is the residual ``current_kspace - ref_kspace``, kept
          only where ``mask`` is true and scaled by ``self.dc_weight``;
        * ``model_term`` is ``self.model`` applied in image space and mapped
          back to k-space.
        """
        zero = torch.zeros(1, 1, 1, 1).to(current_kspace)
        residual = current_kspace - ref_kspace
        soft_dc = torch.where(mask, residual, zero) * self.dc_weight

        image_update = self.model(fastmri.ifft2c(current_kspace))
        model_term = fastmri.fft2c(image_update)

        return current_kspace - soft_dc - model_term
Beispiel #15
0
    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        """Run the cascades and return a root-sum-of-squares image.

        ``num_low_frequencies`` is forwarded to the sensitivity network.
        Starting from a copy of ``masked_kspace``, each cascade updates the
        k-space estimate. The result is converted to magnitude images and
        combined over dim 1.
        """
        sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)

        kspace_estimate = masked_kspace.clone()
        for cascade_block in self.cascades:
            kspace_estimate = cascade_block(kspace_estimate, masked_kspace,
                                            mask, sens_maps)

        magnitudes = fastmri.complex_abs(fastmri.ifft2c(kspace_estimate))
        return fastmri.rss(magnitudes, dim=1)
Beispiel #16
0
    def test_step(self, batch, batch_idx):
        """Reconstruct one test batch from its zero-filled magnitude image.

        Only the k-space (index 0), file name (index 2) and slice number
        (index 3) of the 7-tuple ``batch`` are used.

        Returns:
            dict with ``"fname"``, ``"slice"`` and the network ``"output"`` as
            a NumPy array on the CPU.
        """
        masked_kspace, _, fname, slice_num, _, _, _ = batch

        image = fastmri.complex_abs(fastmri.ifft2c(masked_kspace))
        output = self.forward(image)

        return {
            "fname": fname,
            "slice": slice_num,
            "output": output.cpu().numpy(),
        }
Beispiel #17
0
    def forward(self, masked_kspace: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        """Estimate normalized sensitivity maps from undersampled k-space.

        When ``self.mask_center`` is set, only the central lines located by
        ``self.get_pad_and_num_low_freqs`` are kept. The k-space is then:

        1. transformed to image space;
        2. folded so channels move into the batch dimension;
        3. run through ``self.norm_unet``;
        4. unfolded and divided by its root-sum-of-squares.
        """
        kspace = masked_kspace
        if self.mask_center:
            pad, num_low_freqs = self.get_pad_and_num_low_freqs(
                mask, self.num_sense_lines)
            kspace = transforms.batched_mask_center(kspace, pad,
                                                    pad + num_low_freqs)

        folded, batches = self.chans_to_batch_dim(fastmri.ifft2c(kspace))
        estimates = self.norm_unet(folded)
        unfolded = self.batch_chans_to_chan_dim(estimates, batches)
        return self.divide_root_sum_of_squares(unfolded)
Beispiel #18
0
    def validation_step(self, batch, batch_idx):
        """Predict k-space for one batch and collect validation outputs.

        The predicted k-space is converted to a magnitude image. That image
        and the target are cropped to their common smallest size. The
        validation loss compares the predicted k-space with the input masked
        k-space.
        """
        masked_kspace, mask, target, fname, slice_num, max_value, _ = batch

        kspace_pred = self(masked_kspace, mask)
        image = fastmri.complex_abs(fastmri.ifft2c(kspace_pred))
        target, image = transforms.center_crop_to_smallest(target, image)

        result = {
            "batch_idx": batch_idx,
            "fname": fname,
            "slice_num": slice_num,
            "max_value": max_value,
            "output": image,
            "target": target,
        }
        result["val_loss"] = self.loss(kspace_pred, masked_kspace)
        return result
Beispiel #19
0
    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Unrolled step with soft data consistency and a histogram term.

        Returns ``current_kspace - soft_dc - soft_histc - model_term``, where:

        * ``soft_dc`` is the masked residual scaled by ``self.dc_weight``;
        * ``soft_histc`` is ``self.hist_weight`` times the gradient of
          ``hist_loss`` with respect to a detached copy of
          ``current_kspace``, so no gradient flows back through it;
        * ``model_term`` is ``self.model`` applied in image space and mapped
          back to k-space.
        """
        zero = torch.zeros(1, 1, 1, 1).to(current_kspace)
        residual = current_kspace - ref_kspace
        soft_dc = torch.where(mask, residual, zero) * self.dc_weight

        # Gradients are needed even when the caller runs under no_grad.
        with torch.enable_grad():
            probe = current_kspace.clone().detach().requires_grad_(True)
            hist_value = hist_loss(probe, ref_kspace.clone().detach())
            hist_value.backward()
            soft_histc = probe.grad * self.hist_weight

        image_update = self.model(fastmri.ifft2c(current_kspace))
        model_term = fastmri.fft2c(image_update)

        return current_kspace - soft_dc - soft_histc - model_term
Beispiel #20
0
    def test_step(self, batch, batch_idx):
        """Reconstruct one test slice and crop it to the requested size.

        Returns:
            dict with ``"fname"``, ``"slice"`` and the cropped magnitude
            ``"output"`` as a NumPy array on the CPU.
        """
        masked_kspace, mask, _, fname, slice_num, _, crop_size = batch
        # VarNet always runs with batch size 1, so take the first entry.
        crop_size = crop_size[0]

        magnitude = fastmri.complex_abs(
            fastmri.ifft2c(self(masked_kspace, mask)))

        width = magnitude.shape[-1]
        if width < crop_size[1]:  # FLAIR 203
            crop_size = (width, width)

        cropped = transforms.center_crop(magnitude, crop_size)
        return {
            "fname": fname,
            "slice": slice_num,
            "output": cropped.cpu().numpy(),
        }
Beispiel #21
0
    def forward(self, masked_kspace: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        """Estimate sensitivity maps from the central block of k-space.

        The squeezed mask (presumably 1-D, confirm) is searched for the last
        zero before index ``mask.shape[-2] // 2`` and the first zero at or
        after it. ``transforms.mask_center`` then keeps a block of
        ``right - left`` lines, centred via ``pad``. The result goes through
        the image-space U-Net and is normalized by root-sum-of-squares.
        """
        flat_mask = mask.squeeze()
        cent = mask.shape[-2] // 2
        left = torch.nonzero(flat_mask[:cent] == 0)[-1]
        right = torch.nonzero(flat_mask[cent:] == 0)[0] + cent
        # NOTE(review): right - left is one more than the number of sampled
        # lines strictly between the two zeros; confirm this is intended.
        num_low_freqs = right - left
        pad = (mask.shape[-2] - num_low_freqs + 1) // 2

        center_kspace = transforms.mask_center(masked_kspace, pad,
                                               pad + num_low_freqs)

        folded, batch_size = self.chans_to_batch_dim(
            fastmri.ifft2c(center_kspace))
        maps = self.batch_chans_to_chan_dim(self.norm_unet(folded), batch_size)
        return self.divide_root_sum_of_squares(maps)
Beispiel #22
0
def _base_fastmri_unet_transform(
    kspace,
    mask,
    ground_truth,
    attrs,
    which_challenge="singlecoil",
):
    """Turn raw k-space and a mask into a normalized zero-filled U-Net input.

    The mask is trimmed to the k-space width and applied. The zero-filled
    image is then:

    1. centre-cropped to the ground-truth size, or to ``attrs["recon_size"]``
       when there is no ground truth;
    2. converted to magnitude (RSS-combined for ``"multicoil"``);
    3. instance-normalized and clipped to [-6, 6].

    Returns:
        ``(image.unsqueeze(0), mean, std)``.
    """
    kspace_t = fastmri_transforms.to_tensor(kspace)

    # Masks may be wider than this k-space; keep only the matching columns.
    trimmed_mask = mask[..., :kspace_t.shape[-2]]
    # "+ 0.0" turns the -0.0 values produced by the multiply into +0.0.
    masked_kspace = kspace_t * trimmed_mask.unsqueeze(-1) + 0.0

    zero_filled = fastmri.ifft2c(masked_kspace)

    if ground_truth is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (ground_truth.shape[-2], ground_truth.shape[-1])

    width = zero_filled.shape[-2]
    if width < crop_size[1]:  # FLAIR 203
        crop_size = (width, width)

    # noinspection PyTypeChecker
    cropped = fastmri_transforms.complex_center_crop(zero_filled, crop_size)
    image = fastmri.complex_abs(cropped)

    if which_challenge == "multicoil":
        image = fastmri.rss(image)

    image, mean, std = fastmri_transforms.normalize_instance(image, eps=1e-11)
    return image.clamp(-6, 6).unsqueeze(0), mean, std
Beispiel #23
0
    def training_step(self, batch, batch_idx):
        """Self-supervised step: train on one k-space subset, score on the other.

        ``mask_loss`` selects the held-out k-space samples. Its complement
        ``mask_train`` selects the samples that form the network input. The
        prediction is mapped back to k-space and compared only with the
        held-out samples.

        Args:
            batch: 7-tuple whose first element is the subsampled k-space and
                whose last element is ``mask_loss``; the rest is ignored.
            batch_idx: index of the batch (unused).

        Returns:
            The scalar training loss (also logged as ``"loss"``).
        """
        subsampled_kspace, _, _, _, _, _, mask_loss = batch
        # 1 - mask_loss, i.e. the complement when mask_loss is a 0/1 mask.
        mask_train = -(mask_loss - 1)

        # "+ 0.0" turns the -0.0 values produced by the multiply into +0.0.
        subsampled_kspace_train = subsampled_kspace * mask_train + 0.0
        subsampled_kspace_loss = subsampled_kspace * mask_loss + 0.0

        image_train = fastmri.complex_abs(fastmri.ifft2c(subsampled_kspace_train))

        output_image = self(image_train)

        # Use fft2c, the inverse of the centered orthonormal ifft2c above.
        # torch.fft.fft2 would put DC in the corner and use a different
        # scaling, so mask_loss would select the wrong frequencies.
        output_kspace = fastmri.fft2c(
            torch.stack((output_image, torch.zeros_like(output_image)), dim=-1))

        output_kspace_loss = output_kspace * mask_loss + 0.0

        loss = l1_l2_loss(output_kspace_loss, subsampled_kspace_loss)

        self.log("loss", loss.detach())

        return loss
Beispiel #24
0
def data_transform(kspace, mask_function, target, data_attributes, filename,
                   slice_num):
    """
    Preprocess one k-space slice into network inputs; meant to be invoked
    from the SliceData class.

    Args:
        - kspace: complete sampled k-space of the slice.
        - mask_function: accepted for the caller's signature but NOT used;
          the body applies a ``mask_func`` defined outside this function.
        - target: the target image to be reconstructed from the k-space.
        - data_attributes: attributes of the whole HDF5 file; must contain
          ``'max'``.
        - filename: not used.
        - slice_num: returned unchanged.

    Returns:
        - kspace_t: instance-normalized k-space tensor, before masking.
        - masked_image: masked k-space as a complex image, center-cropped to
          320 x 320, permuted channels-first and instance-normalized.
        - target: instance-normalized target with a leading dimension added.
        - mask: mask generated by the masking function.
        - max_value: ``data_attributes['max']`` (for the SSIM loss).
        - slice_num: as passed in.
    """
    kspace_t = transforms.normalize_instance(transforms.to_tensor(kspace))[0]

    # NOTE(review): reads an outer ``mask_func`` rather than the
    # ``mask_function`` argument; confirm that name exists at module level.
    masked_kspace, mask = transforms.apply_mask(
        data=kspace_t, mask_func=mask_func)

    masked_image = transforms.complex_center_crop(
        fastmri.ifft2c(masked_kspace), (320, 320))
    # (H, W, 2) -> (2, H, W), then normalize.
    masked_image = transforms.normalize_instance(
        masked_image.permute(2, 0, 1))[0]

    target_t = transforms.normalize_instance(transforms.to_tensor(target))[0]
    target_t = target_t.unsqueeze(0)

    return kspace_t, masked_image, target_t, mask, data_attributes[
        'max'], slice_num
Beispiel #25
0
    def forward(self, masked_kspace, mask):
        """Estimate sensitivity maps from the contiguous sampled centre.

        Starting at ``mask.shape[-2] // 2``, the mask is scanned outwards
        along dim -2 while lines are truthy. The resulting half-open run
        ``[start, stop)`` sizes the block kept by ``T.mask_center``. The image
        of that block goes through the U-Net and is normalized by
        root-sum-of-squares.
        """
        def get_low_frequency_lines(mask):
            centre = mask.shape[-2] // 2
            # Walk right, then left, while the line is sampled.
            stop = centre
            while mask[..., stop, :]:
                stop += 1
            start = centre
            while mask[..., start, :]:
                start -= 1
            return start + 1, stop

        start, stop = get_low_frequency_lines(mask)
        num_low_freqs = stop - start
        pad = (mask.shape[-2] - num_low_freqs + 1) // 2

        images = fastmri.ifft2c(
            T.mask_center(masked_kspace, pad, pad + num_low_freqs))
        folded, batch_size = self.chans_to_batch_dim(images)
        maps = self.batch_chans_to_chan_dim(self.norm_unet(folded), batch_size)
        return self.divide_root_sum_of_squares(maps)
def visualize_kspace(kspace, dim=None, crop=False, output_filepath=None):
    """Display or save the magnitude image reconstructed from ``kspace``.

    Args:
        kspace: k-space tensor with a trailing real/imaginary dimension.
        dim: if given, combine images over this dimension with
            root-sum-of-squares.
        crop: if True, center-crop to a square whose side is ``shape[-2]`` of
            the complex image, then instance-normalize and clip to [-6, 6].
        output_filepath: ``pathlib.Path`` to save the figure to. Parent
            directories are created as needed. When None, the figure is shown
            interactively instead.
    """
    image = fastmri.ifft2c(kspace)
    if crop:
        side = image.shape[-2]
        image = T.complex_center_crop(image, (side, side))
        image = fastmri.complex_abs(image)
        image, _, _ = T.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)
    else:
        # Compute absolute value to get a real image
        image = fastmri.complex_abs(image)
    if dim is not None:
        image = fastmri.rss(image, dim=dim)
    # detach()/cpu() so tensors that require grad or live on a GPU can be
    # converted to NumPy.
    img = np.abs(image.detach().cpu().numpy())

    if output_filepath is not None:
        # exist_ok avoids the race between an existence check and mkdir.
        output_filepath.parent.mkdir(parents=True, exist_ok=True)
        plt.imshow(img, cmap='gray')
        plt.axis("off")
        plt.savefig(output_filepath, bbox_inches="tight", pad_inches=0)
        # Close the figure so later calls don't draw on top of this one.
        plt.close()
    else:
        plt.imshow(img, cmap='gray')
        plt.show()
Beispiel #27
0
    def forward(self, mr_img, mk_space, mask):
        """Run ``self.n_cascade`` Laplacian-pyramid reconstruction cascades.

        Each cascade:

        1. decomposes a CPU copy of the current image with ``self.lapl_dec``;
        2. extracts features through three branches at different scales;
        3. fuses each branch with its pyramid level via ``self.pixel_conv``,
           separately for index 0 and 1 of the last dimension;
        4. rebuilds the image with ``self.lapl_rec``;
        5. applies a data-consistency step in k-space.

        Args:
            mr_img: current image estimate; its last dimension holds two
                channels (indexed ``[..., 0]`` / ``[..., 1]``), presumably
                real/imaginary. Confirm.
            mk_space: measured k-space. It is added without masking in the
                data-consistency step, so it is presumably already zero at
                unsampled locations. Confirm with the caller.
            mask: sampling mask; ``1.0 - mask`` keeps the predicted k-space
                where nothing was sampled.

        Returns:
            The image after the last cascade.
        """
        # One iteration per cascade.
        for i in range(self.n_cascade):
            # The laplacian decomposition will produce different images composing the gaussian/laplacian pyramids
            # The gaussians are the downsampled images from the original one
            # The laplacian represents the 'errors' within the images related to the gaussian images
            # NOTE(review): .cpu() puts the pyramid on the CPU while the branch
            # outputs stay on mr_img's device; confirm lapl_dec needs CPU input
            # and that pixel_conv handles mixed devices.
            gaussian_3, lap_1, lap_2 = self.lapl_dec(mr_img.cpu())

            # Resize along with the input/output channels
            mr_img = self.shuffle_down_4(mr_img)

            # The convBlock1 is here performed to extract shallow features
            mr_img = self.convBlock1(mr_img)

            # ---- 3 branches ----
            branch_1 = self.shuffle_up_4(mr_img)
            branch_1 = self.branch1(branch_1)

            branch_2 = self.shuffle_up_2(mr_img)
            branch_2 = self.branch2(branch_2)

            branch_3 = self.branch3(mr_img)

            # Fuse each branch with its pyramid level, one channel at a time.
            branch_out1 = torch.stack((self.pixel_conv(branch_1[...,0], lap_1[...,0]),
                                   self.pixel_conv(branch_1[...,1], lap_1[...,1])), dim=-1)
            branch_out2 = torch.stack((self.pixel_conv(branch_2[...,0], lap_2[...,0]),
                                   self.pixel_conv(branch_2[...,1], lap_2[...,1])), dim=-1)
            branch_out3 = torch.stack((self.pixel_conv(branch_3[...,0], gaussian_3[...,0]),
                                   self.pixel_conv(branch_3[...,1], gaussian_3[...,1])), dim=-1)

            # Performing the 2x linear upsample in order to add the different branches correctly
            output = self.lapl_rec(branch_out3, branch_out2)
            output = self.lapl_rec(output, branch_out1)

            # Data consistency: predicted k-space where mask is 0, plus mk_space.
            mr_img = fastmri.ifft2c((1.0 - mask) * fastmri.fft2c(output) + mk_space)
        return mr_img
Beispiel #28
0
    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: np.ndarray,
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str,
               int, float]:
        """
        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data or (rows, cols) for single coil data.
            mask: Mask from the test dataset. Not used here: when
                ``self.mask_func`` is set, a new mask is generated instead.
            target: Target image.
            attrs: Acquisition related information stored in the HDF5 object.
            fname: File name.
            slice_num: Serial number of the slice.

        Returns:
            tuple containing:
                image: Zero-filled input image.
                target: Target image converted to a torch.Tensor.
                mean: Mean value used for normalization.
                std: Standard deviation value used for normalization.
                fname: File name.
                slice_num: Serial number of the slice.
                max_value: ``attrs["max"]`` if present, else 0.0.
        """
        kspace = to_tensor(kspace)

        # check for max value
        max_value = attrs.get("max", 0.0)

        # apply mask
        if self.mask_func:
            seed = tuple(map(ord, fname)) if self.use_seed else None
            # Pass ``seed`` by keyword and keep only the masked data (the
            # generated mask is unused), so the call does not depend on
            # apply_mask's positional order or on how many values it returns.
            masked_kspace = apply_mask(kspace, self.mask_func, seed=seed)[0]
        else:
            masked_kspace = kspace

        # inverse Fourier transform to get zero filled solution
        image = fastmri.ifft2c(masked_kspace)

        # crop input to correct size
        if target is not None:
            crop_size = (target.shape[-2], target.shape[-1])
        else:
            crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

        # FLAIR 203: image narrower than the crop -> square crop of its width
        if image.shape[-2] < crop_size[1]:
            crop_size = (image.shape[-2], image.shape[-2])

        image = complex_center_crop(image, crop_size)

        # absolute value
        image = fastmri.complex_abs(image)

        # apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == "multicoil":
            image = fastmri.rss(image)

        # normalize input
        image, mean, std = normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        # normalize target with the input's statistics
        if target is not None:
            target = to_tensor(target)
            target = center_crop(target, crop_size)
            target = normalize(target, mean, std, eps=1e-11)
            target = target.clamp(-6, 6)
        else:
            target = torch.Tensor([0])

        return image, target, mean, std, fname, slice_num, max_value
Beispiel #29
0
 def sens_reduce(self, x: torch.Tensor,
                 sens_maps: torch.Tensor) -> torch.Tensor:
     """Combine multi-channel k-space into one image using sensitivity maps.

     ``x`` is taken to image space, multiplied by the complex conjugate of
     ``sens_maps`` and summed over dim 1, which is kept as a singleton.
     """
     images = fastmri.ifft2c(x)
     weighted = fastmri.complex_mul(images, fastmri.complex_conj(sens_maps))
     return weighted.sum(dim=1, keepdim=True)
Beispiel #30
0
 def sens_reduce(x):
     """Reduce ``x`` to one image: ifft2c, times conj(sens_maps), sum over dim 1.

     ``sens_maps`` is read from the enclosing scope.
     """
     images = fastmri.ifft2c(x)
     conj_maps = fastmri.complex_conj(sens_maps)
     return fastmri.complex_mul(images, conj_maps).sum(dim=1, keepdim=True)