Example #1
    def forward(self, image):
        """Forward method for the MultiDomainConvTranspose2d class."""
        kspace = [
            fft2(im,
                 centered=self.fft_centered,
                 normalization=self.fft_normalization,
                 spatial_dims=self.spatial_dims) for im in torch.split(
                     image.permute(0, 2, 3, 1).contiguous(), 2, -1)
        ]
        kspace = torch.cat(kspace, -1).permute(0, 3, 1, 2)
        kspace = self.kspace_conv(kspace)

        backward = [
            ifft2(
                ks.float(),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ).type(image.type()) for ks in torch.split(
                kspace.permute(0, 2, 3, 1).contiguous(), 2, -1)
        ]
        backward = torch.cat(backward, -1).permute(0, 3, 1, 2)

        image = self.image_conv(image)
        return torch.cat([image, backward], dim=self.coil_dim)
Example #2
def test_centered_ifft2_forward_normalization(shape):
    """
    Test centered 2D Inverse Fast Fourier Transform with forward normalization.

    Args:
        shape: shape of the input

    Returns:
        None
    """
    shape = shape + [2]
    x = create_input(shape)
    out_torch = ifft2(x,
                      centered=True,
                      normalization="forward",
                      spatial_dims=[-2, -1]).numpy()
    out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

    input_numpy = tensor_to_complex_np(x)
    input_numpy = np.fft.ifftshift(input_numpy, (-2, -1))
    out_numpy = np.fft.ifft2(input_numpy, norm="forward")
    out_numpy = np.fft.fftshift(out_numpy, (-2, -1))

    if not np.allclose(out_torch, out_numpy):
        raise AssertionError
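This test pins down the convention the rest of these examples assume for ifft2: the last dimension carries the real and imaginary parts, centered=True wraps the inverse FFT in an ifftshift/fftshift pair, and the normalization string is forwarded to the underlying FFT (the non-centered case appears in Example #16 below). The following is a minimal sketch of that behavior in plain PyTorch; the name ifft2_sketch and the exact signature are illustrative, not the mridc implementation itself.

import torch


def ifft2_sketch(data: torch.Tensor,
                 centered: bool = True,
                 normalization: str = "ortho",
                 spatial_dims=(-2, -1)) -> torch.Tensor:
    """2D inverse FFT for tensors shaped [..., H, W, 2] (real/imag in the last dim)."""
    x = torch.view_as_complex(data.contiguous())          # [..., H, W] complex
    if centered:
        x = torch.fft.ifftshift(x, dim=spatial_dims)      # undo k-space centering
    x = torch.fft.ifftn(x, dim=spatial_dims, norm=normalization)
    if centered:
        x = torch.fft.fftshift(x, dim=spatial_dims)       # re-center the image
    return torch.view_as_real(x)                          # back to [..., 2]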
Example #3
File: utils.py Project: wdika/mridc
def log_likelihood_gradient(
    eta: torch.Tensor,
    masked_kspace: torch.Tensor,
    sense: torch.Tensor,
    mask: torch.Tensor,
    sigma: float,
    fft_centered: bool,
    fft_normalization: str,
    spatial_dims: Sequence[int],
    coil_dim: int,
) -> torch.Tensor:
    """
    Computes the gradient of the log-likelihood function.

    Parameters
    ----------
    eta: Initial guess for the reconstruction.
    masked_kspace: Subsampled k-space data.
    sense: Coil sensitivity maps.
    mask: Sampling mask.
    sigma: Noise level.
    fft_centered: Whether to center the FFT.
    fft_normalization: Type of FFT normalization (e.g. "ortho").
    spatial_dims: Spatial dimensions of the data.
    coil_dim: Coil dimension of the data.

    Returns
    -------
    Gradient of the log-likelihood function.
    """
    eta_real, eta_imag = map(lambda x: torch.unsqueeze(x, 0), eta.chunk(2, -1))
    sense_real, sense_imag = sense.chunk(2, -1)

    re_se = eta_real * sense_real - eta_imag * sense_imag
    im_se = eta_real * sense_imag + eta_imag * sense_real

    pred = ifft2(
        mask * (fft2(
            torch.cat((re_se, im_se), -1),
            centered=fft_centered,
            normalization=fft_normalization,
            spatial_dims=spatial_dims,
        ) - masked_kspace),
        centered=fft_centered,
        normalization=fft_normalization,
        spatial_dims=spatial_dims,
    )

    pred_real, pred_imag = pred.chunk(2, -1)

    re_out = torch.sum(pred_real * sense_real + pred_imag * sense_imag,
                       coil_dim) / (sigma**2.0)
    im_out = torch.sum(pred_imag * sense_real - pred_real * sense_imag,
                       coil_dim) / (sigma**2.0)

    eta_real = eta_real.squeeze(0)
    eta_imag = eta_imag.squeeze(0)

    return torch.cat((eta_real, eta_imag, re_out, im_out),
                     0).unsqueeze(0).squeeze(-1)
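For orientation, the quantity assembled above is the image-space gradient of a Gaussian data-fidelity term, stacked with the current estimate: roughly

    grad = S^H F^{-1} M (F S eta - y) / sigma^2

where S applies the coil sensitivities, F is the (optionally centered, normalized) FFT, M the sampling mask, and y the measured k-space. Real and imaginary parts are carried as separate channels, which is why eta and sense are chunked along the last dimension. This summary is read off the code above rather than taken from separate documentation.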
Example #4
File: crnn.py Project: wdika/mridc
    def process_intermediate_pred(self, pred, sensitivity_maps, target):
        """
        Process the intermediate prediction.

        Parameters
        ----------
        pred: Intermediate prediction.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        target: Target data to crop to size.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: torch.Tensor, shape [batch_size, n_x, n_y, 2]
            Processed prediction.
        """
        pred = ifft2(
            pred, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims
        )
        pred = coil_combination(pred, sensitivity_maps, method=self.coil_combination_method, dim=self.coil_dim)
        pred = torch.view_as_complex(pred)
        _, pred = center_crop_to_smallest(target, pred)
        return pred
Example #5
 def AT(x):
     return torch.sum(
         complex_mul(
             ifft2(x * mask,
                   centered=fft_centered,
                   normalization=fft_normalization,
                   spatial_dims=spatial_dims),
             complex_conj(smaps),
         ),
         dim=(-5),
     )
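AT above is the adjoint of the usual SENSE sampling operator: mask the k-space, transform to image space, weight by the conjugate sensitivities, and sum over the coil dimension. For symmetry, a matching forward operator could look like the sketch below; it assumes the same enclosing scope (mask, smaps, fft_centered, fft_normalization, spatial_dims) and the same mridc helpers (fft2, complex_mul) as the snippet above, so it is an illustrative counterpart rather than code taken from the project.

def A(x):
    # coil-expand the image, transform to k-space, and apply the sampling mask
    return mask * fft2(
        complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps),
        centered=fft_centered,
        normalization=fft_normalization,
        spatial_dims=spatial_dims,
    )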
Example #6
    def test_step(self, batch: Dict[float, torch.Tensor],
                  batch_idx: int) -> Tuple[str, int, torch.Tensor]:
        """
        Test step.

        Parameters
        ----------
        batch: Batch of data.
            Dict of torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        batch_idx: Batch index.
            int

        Returns
        -------
        name: Name of the volume.
            str
        slice_num: Slice number.
            int
        pred: Predicted data.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        """
        kspace, y, sensitivity_maps, mask, init_pred, target, fname, slice_num, _ = batch
        y, mask, _ = self.process_inputs(y, mask)

        if self.use_sens_net:
            sensitivity_maps = self.sens_net(kspace, mask)
            if self.coil_combination_method.upper() == "SENSE":
                target = sense(
                    ifft2(
                        kspace,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    sensitivity_maps,
                    dim=self.coil_dim,
                )

        prediction = self.forward(y, sensitivity_maps, mask, target)

        slice_num = int(slice_num)
        name = str(fname[0])  # type: ignore
        key = f"{name}_images_idx_{slice_num}"  # type: ignore
        output = torch.abs(prediction).detach().cpu()
        target = torch.abs(target).detach().cpu()
        output = output / output.max()  # type: ignore
        target = target / target.max()  # type: ignore
        error = torch.abs(target - output)
        self.log_image(f"{key}/target", target)
        self.log_image(f"{key}/reconstruction", output)
        self.log_image(f"{key}/error", error)

        return name, slice_num, prediction.detach().cpu().numpy()
Example #7
 def _backward_operator(self, kspace, sampling_mask, sensitivity_map):
     """Backward operator."""
     kspace = torch.where(
         sampling_mask == 0,
         torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device), kspace)
     return (complex_mul(
         ifft2(
             kspace.float(),
             centered=self.fft_centered,
             normalization=self.fft_normalization,
             spatial_dims=self.spatial_dims,
         ),
         complex_conj(sensitivity_map),
     ).sum(self.coil_dim).type(kspace.type()))
Example #8
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or  torch.Tensor, shape [batch_size, n_x, n_y, 2]
             If self.accumulate_loss is True, returns a list of all intermediate estimates.
             If False, returns the final estimate.
        """
        estimation = y.clone()

        for cascade in self.cascades:
            # Forward pass through the cascades
            estimation = cascade(estimation, y, sensitivity_maps, mask)

        estimation = ifft2(
            estimation,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        estimation = coil_combination(estimation,
                                      sensitivity_maps,
                                      method=self.coil_combination_method,
                                      dim=self.coil_dim)
        estimation = torch.view_as_complex(estimation)
        _, estimation = center_crop_to_smallest(target, estimation)
        return estimation
Example #9
    def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """
        Transform the input k-space to image space and coil-combine it using the sensitivity maps.

        Parameters
        ----------
        x: Input data.
        sens_maps: Coil Sensitivity maps.

        Returns
        -------
        SENSE coil-combined reconstruction.
        """
        x = ifft2(x, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)
        return complex_mul(x, complex_conj(sens_maps)).sum(self.coil_dim)
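The usual counterpart to this reduction (sens_expand in fastMRI-style variational networks) projects a coil-combined image back to per-coil k-space. A sketch of such a method is given below; it assumes the same self.fft_* and self.coil_dim attributes and the mridc fft2/complex_mul helpers, and the method name and its presence on this particular class are illustrative.

    def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """Expand a coil-combined image to per-coil k-space via the sensitivity maps."""
        return fft2(
            complex_mul(x.unsqueeze(self.coil_dim), sens_maps),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )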
Example #10
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or  torch.Tensor, shape [batch_size, n_x, n_y, 2]
             If self.accumulate_loss is True, returns a list of all intermediate estimates.
             If False, returns the final estimate.
        """
        image = ifft2(y,
                      centered=self.fft_centered,
                      normalization=self.fft_normalization,
                      spatial_dims=self.spatial_dims)

        if hasattr(self, "standardization"):
            image = self.standardization(image, sensitivity_maps)

        output_image = self._compute_model_per_coil(
            self.unet, image.permute(0, 1, 4, 2, 3)).permute(0, 1, 3, 4, 2)
        output_image = coil_combination(output_image,
                                        sensitivity_maps,
                                        method=self.coil_combination_method,
                                        dim=self.coil_dim)
        output_image = torch.view_as_complex(output_image)
        _, output_image = center_crop_to_smallest(target, output_image)
        return output_image
Example #11
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or  torch.Tensor, shape [batch_size, n_x, n_y, 2]
             If self.accumulate_loss is True, returns a list of all intermediate estimates.
             If False, returns the final estimate.
        """
        sensitivity_maps = self.sens_net(
            y, mask) if self.use_sens_net else sensitivity_maps
        image = self.model(y, sensitivity_maps, mask)
        image = torch.view_as_complex(
            coil_combination(
                ifft2(
                    image,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                sensitivity_maps,
                method=self.coil_combination_method,
                dim=self.coil_dim,
            ))
        _, image = center_crop_to_smallest(target, image)
        return image
Example #12
File: base.py Project: wdika/mridc
    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the model.

        Parameters
        ----------
        masked_kspace: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [batch_size, 1, n_x, n_y, 1]
        num_low_frequencies: Number of low frequencies to keep.
            int

        Returns
        -------
        Normalized UNet output tensor.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        """
        if self.mask_center:
            pad, num_low_freqs = self.get_pad_and_num_low_freqs(
                mask, num_low_frequencies)
            masked_kspace = batched_mask_center(masked_kspace,
                                                pad,
                                                pad + num_low_freqs,
                                                mask_type=self.mask_type)

        # convert to image space
        images, batches = self.chans_to_batch_dim(
            ifft2(
                masked_kspace,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ))

        # estimate sensitivities
        images = self.batch_chans_to_chan_dim(self.norm_unet(images), batches)
        if self.normalize:
            images = self.divide_root_sum_of_squares(images, self.coil_dim)
        return images
Example #13
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or  torch.Tensor, shape [batch_size, n_x, n_y, 2]
             If self.accumulate_loss is True, returns a list of all intermediate estimates.
             If False, returns the final estimate.
        """
        DC_sens = self.sens_net(y, mask)
        sensitivity_maps = DC_sens.clone()
        image = complex_mul(
            ifft2(y, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims),
            complex_conj(sensitivity_maps),
        ).sum(self.coil_dim)
        for idx in range(self.num_iter):
            sensitivity_maps = self.update_C(idx, DC_sens, sensitivity_maps, image, y, mask)
            image = self.update_X(idx, image, sensitivity_maps, y, mask)
        image = torch.view_as_complex(image)
        _, image = center_crop_to_smallest(target, image)
        return image
Example #14
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        target: torch.Tensor = None,
    ) -> Union[list, Any]:
        """
        Forward pass of the zero-filled method.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: torch.Tensor, shape [batch_size, n_x, n_y, 2]
            Predicted data.
        """
        pred = coil_combination(
            ifft2(y,
                  centered=self.fft_centered,
                  normalization=self.fft_normalization,
                  spatial_dims=self.spatial_dims),
            sensitivity_maps,
            method=self.coil_combination_method.upper(),
            dim=self.coil_dim,
        )
        pred = check_stacked_complex(pred)
        _, pred = center_crop_to_smallest(target, pred)
        return pred
Example #15
    def forward(self, x, y, smaps, mask):
        """
        Forward pass of the data-consistency block.

        Parameters
        ----------
        x: Input image.
        y: Subsampled k-space data.
        smaps: Coil sensitivity maps.
        mask: Sampling mask.

        Returns
        -------
        Output image.
        """
        A_x = torch.sum(
            fft2(
                complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ),
            -4,
            keepdim=True,
        )
        k_dc = (1 - mask) * A_x + mask * (self.alpha * A_x +
                                          (1 - self.alpha) * y)
        x_dc = torch.sum(
            complex_mul(
                ifft2(
                    k_dc,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                complex_conj(smaps),
            ),
            dim=(-5),
        )
        return self.beta * x + (1 - self.beta) * x_dc
Example #16
def test_non_centered_ifft2(shape):
    """
    Test non-centered 2D Inverse Fast Fourier Transform.

    Args:
        shape: shape of the input

    Returns:
        None
    """
    shape = shape + [2]
    x = create_input(shape)
    out_torch = ifft2(x,
                      centered=False,
                      normalization="ortho",
                      spatial_dims=[-2, -1]).numpy()
    out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

    input_numpy = tensor_to_complex_np(x)
    out_numpy = np.fft.ifft2(input_numpy, norm="ortho")

    if not np.allclose(out_torch, out_numpy):
        raise AssertionError
Example #17
    def sens_reduce(self, x: torch.Tensor,
                    sens_maps: torch.Tensor) -> torch.Tensor:
        """
        Transform the input k-space to image space and coil-combine it using the sensitivity maps.

        Parameters
        ----------
        x: Input data.
            torch.Tensor, shape [batch_size, n_coils, height, width, 2]
        sens_maps: Sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, height, width, 2]

        Returns
        -------
        SENSE reconstruction.
            torch.Tensor, shape [batch_size, height, width, 2]
        """
        x = ifft2(x,
                  centered=self.fft_centered,
                  normalization=self.fft_normalization,
                  spatial_dims=self.spatial_dims)
        return complex_mul(x, complex_conj(sens_maps)).sum(dim=self.coil_dim,
                                                           keepdim=True)
Example #18
    def forward(self, x, y, mask):
        """
        Forward pass of the data-consistency block.

        Parameters
        ----------
        x: Input image.
        y: Subsampled k-space data.
        mask: Sampling mask.

        Returns
        -------
        Output image.
        """
        A_x = fft2(x,
                   centered=self.fft_centered,
                   normalization=self.fft_normalization,
                   spatial_dims=self.spatial_dims)
        k_dc = (1 - mask) * A_x + mask * (self.lambda_ * A_x +
                                          (1 - self.lambda_) * y)
        return ifft2(k_dc,
                     centered=self.fft_centered,
                     normalization=self.fft_normalization,
                     spatial_dims=self.spatial_dims)
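The k_dc interpolation blends the prediction with the measured data: where the mask is 1 it takes a lambda_-weighted combination of the two, and where it is 0 it keeps the prediction. A small self-contained check of that behavior, with hypothetical shapes and lambda_ = 0 so the block reduces to hard data consistency:

import torch

# hypothetical shapes: [batch, coils, H, W, 2] k-space, [1, 1, H, W, 1] mask
A_x = torch.randn(1, 4, 8, 8, 2)                  # predicted k-space
y = torch.randn(1, 4, 8, 8, 2)                    # measured (subsampled) k-space
mask = (torch.rand(1, 1, 8, 8, 1) > 0.5).float()
lambda_ = 0.0                                     # 0 -> measured samples are kept exactly

k_dc = (1 - mask) * A_x + mask * (lambda_ * A_x + (1 - lambda_) * y)

sampled = mask.expand_as(y).bool()
assert torch.allclose(k_dc[sampled], y[sampled])        # hard data consistency where sampled
assert torch.allclose(k_dc[~sampled], A_x[~sampled])    # prediction kept elsewhere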
Example #19
    def forward(
        self,
        pred: torch.Tensor,
        masked_kspace: torch.Tensor,
        sense: torch.Tensor,
        mask: torch.Tensor,
        eta: torch.Tensor = None,
        hx: torch.Tensor = None,
        sigma: float = 1.0,
        keep_eta: bool = False,
    ) -> Tuple[Any, Union[list, torch.Tensor, None]]:
        """
        Forward pass of the RIMBlock.

        Parameters
        ----------
        pred: Predicted k-space.
        masked_kspace: Subsampled k-space.
        sense: Coil sensitivity maps.
        mask: Sampling mask.
        eta: Initial guess for the reconstruction.
        hx: Initial guess for the hidden state.
        sigma: Noise level.
        keep_eta: Whether to keep eta.

        Returns
        -------
        Reconstructed image and hidden states.
        """
        if hx is None:
            hx = [
                masked_kspace.new_zeros(
                    (masked_kspace.size(0), f, *masked_kspace.size()[2:-1]))
                for f in self.recurrent_filters if f != 0
            ]

        if isinstance(pred, list):
            pred = pred[-1].detach()

        if eta is None or eta.ndim < 3:
            eta = (pred if keep_eta else torch.sum(
                complex_mul(
                    ifft2(
                        pred,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sense),
                ),
                self.coil_dim,
            ))

        etas = []
        for _ in range(self.time_steps):
            grad_eta = log_likelihood_gradient(
                eta,
                masked_kspace,
                sense,
                mask,
                sigma=sigma,
                fft_centered=self.fft_centered,
                fft_normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
                coil_dim=self.coil_dim,
            ).contiguous()

            for h, convrnn in enumerate(self.layers):
                hx[h] = convrnn(grad_eta, hx[h])
                grad_eta = hx[h]

            eta = eta + self.final_layer(grad_eta).permute(0, 2, 3, 1)
            etas.append(eta)

        eta = etas

        if self.no_dc:
            return eta, None

        soft_dc = torch.where(mask, pred - masked_kspace,
                              self.zero.to(masked_kspace)) * self.dc_weight
        current_kspace = [
            masked_kspace - soft_dc - fft2(
                complex_mul(e.unsqueeze(self.coil_dim), sense),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ) for e in eta
        ]

        return current_kspace, None
Example #20
    def update_X(self, idx, image, sensitivity_maps, y, mask):
        """
        Update the image.

        .. math::
            x_{k} = (1 - 2 * \lambda_{{k}_{I}} * \mu_{k} - 2 * \lambda_{{k}_{F}} * \mu_{k}) * x_{k}

            x_{k} = 2 * \mu_{k} * (\lambda_{{k}_{I}} * D_I(x_{k}) + \lambda_{{k}_{F}} * F^{-1}(D_F(f)))

            A(x_{k}) - b = M * F * (C * x_{k}) - b

            x_{k} = 2 * \mu_{k} * A^{*} * (A(x_{k}) - b)

        Parameters
        ----------
        idx: int
            The current iteration index.
        image: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The predicted image.
        sensitivity_maps: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]
            The coil sensitivity maps.
        y: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The subsampled k-space data.
        mask: torch.Tensor [batch_size, 1, num_rows, num_cols]
            The subsampled mask.

        Returns
        -------
        image: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The updated image.
        """
        # (1 - 2 * lambda_{k}_{I} * mi_{k} - 2 * lambda_{k}_{F} * mi_{k}) * x_{k}
        image_term_1 = (
            1 - 2 * self.reg_param_I[idx] * self.lr_image[idx] - 2 * self.reg_param_F[idx] * self.lr_image[idx]
        ) * image
        # D_I(x_{k})
        image_term_2_DI = self.image_model(image.unsqueeze(self.coil_dim)).squeeze(self.coil_dim).contiguous()
        image_term_2_DF = ifft2(
            self.kspace_model(
                fft2(
                    image,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).unsqueeze(self.coil_dim)
            )
            .squeeze(self.coil_dim)
            .contiguous(),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        # 2 * mi_{k} * (lambda_{k}_{I} * D_I(x_{k}) + lambda_{k}_{F} * F^-1(D_F(f)))
        image_term_2 = (
            2
            * self.lr_image[idx]
            * (self.reg_param_I[idx] * image_term_2_DI + self.reg_param_F[idx] * image_term_2_DF)
        )
        # A(x_{k}) - b = M * F * (C * x_{k}) - b
        image_term_3_A = fft2(
            complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        image_term_3_A = torch.where(mask == 0, torch.tensor([0.0], dtype=y.dtype).to(y.device), image_term_3_A) - y
        # 2 * mi_{k} * A^* * (A(x_{k}) - b)
        image_term_3_Aconj = complex_mul(
            ifft2(
                image_term_3_A,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ),
            complex_conj(sensitivity_maps),
        ).sum(self.coil_dim)
        image_term_3 = 2 * self.lr_image[idx] * image_term_3_Aconj
        image = image_term_1 + image_term_2 - image_term_3
        return image
Example #21
    def update_C(self, idx, DC_sens, sensitivity_maps, image, y, mask) -> torch.Tensor:
        """
        Update the coil sensitivity maps.

        .. math::
            C = (1 - 2 * \lambda_{k}^{C} * \nu_{k}) * C_{k}

            C = 2 * \lambda_{k}^{C} * \nu_{k} * D_{C}(F^{-1}(b))

            A(x_{k}) = M * F * (C * x_{k})

            C = 2 * \nu_{k} * F^{-1}(M^{T} * (M * F * (C * x_{k}) - b)) * x_{k}^{*}

        Parameters
        ----------
        idx: int
            The current iteration index.
        DC_sens: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]
            The initial coil sensitivity maps.
        sensitivity_maps: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]
            The coil sensitivity maps.
        image: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The predicted image.
        y: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The subsampled k-space data.
        mask: torch.Tensor [batch_size, 1, num_rows, num_cols]
            The subsampled mask.

        Returns
        -------
        sensitivity_maps: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]
            The updated coil sensitivity maps.
        """
        # (1 - 2 * lambda_{k}^{C} * ni_{k}) * C_{k}
        sense_term_1 = (1 - 2 * self.reg_param_C[idx] * self.lr_sens[idx]) * sensitivity_maps
        # 2 * lambda_{k}^{C} * ni_{k} * D_{C}(F^-1(b))
        sense_term_2 = 2 * self.reg_param_C[idx] * self.lr_sens[idx] * DC_sens
        # A(x_{k}) = M * F * (C * x_{k})
        sense_term_3_A = fft2(
            complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        sense_term_3_A = torch.where(mask == 0, torch.tensor([0.0], dtype=y.dtype).to(y.device), sense_term_3_A)
        # 2 * ni_{k} * F^-1(M.T * (M * F * (C * x_{k}) - b)) * x_{k}^*
        sense_term_3_mask = torch.where(
            mask == 1,
            torch.tensor([0.0], dtype=y.dtype).to(y.device),
            sense_term_3_A - y,
        )

        sense_term_3_backward = ifft2(
            sense_term_3_mask,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        sense_term_3 = 2 * self.lr_sens[idx] * sense_term_3_backward * complex_conj(image).unsqueeze(self.coil_dim)
        sensitivity_maps = sense_term_1 + sense_term_2 - sense_term_3
        return sensitivity_maps
Example #22
File: base.py Project: wdika/mridc
    def test_step(self, batch: Dict[float, torch.Tensor],
                  batch_idx: int) -> Tuple[str, int, torch.Tensor]:
        """
        Performs a test step.

        Parameters
        ----------
        batch: Batch of data. Dict[str, torch.Tensor], with keys,
            'y': subsampled kspace,
                torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
            'sensitivity_maps': sensitivity_maps,
                torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
            'mask': sampling mask,
                torch.Tensor, shape [1, 1, n_x, n_y, 1]
            'init_pred': initial prediction. For example zero-filled or PICS.
                torch.Tensor, shape [batch_size, n_x, n_y, 2]
            'target': target data,
                torch.Tensor, shape [batch_size, n_x, n_y, 2]
            'fname': filename,
                str, shape [batch_size]
            'slice_idx': slice_idx,
                torch.Tensor, shape [batch_size]
            'acc': acceleration factor,
                torch.Tensor, shape [batch_size]
            'max_value': maximum value of the magnitude image space,
                torch.Tensor, shape [batch_size]
            'crop_size': crop size,
                torch.Tensor, shape [n_x, n_y]
        batch_idx: Batch index.
            int

        Returns
        -------
        name: Name of the volume.
            str
        slice_num: Slice number.
            int
        pred: Predicted data.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        """
        kspace, y, sensitivity_maps, mask, init_pred, target, fname, slice_num, _ = batch
        y, mask, init_pred, r = self.process_inputs(y, mask, init_pred)

        if self.use_sens_net:
            sensitivity_maps = self.sens_net(kspace, mask)
            if self.coil_combination_method.upper() == "SENSE":
                target = sense(
                    ifft2(
                        kspace,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    sensitivity_maps,
                    dim=self.coil_dim,
                )

        preds = self.forward(y, sensitivity_maps, mask, init_pred, target)

        if self.accumulate_estimates:
            try:
                preds = next(preds)
            except StopIteration:
                pass

        # Cascades
        if isinstance(preds, list):
            preds = preds[-1]

        # Time-steps
        if isinstance(preds, list):
            preds = preds[-1]

        slice_num = int(slice_num)
        name = str(fname[0])  # type: ignore
        key = f"{name}_images_idx_{slice_num}"  # type: ignore

        output = torch.abs(preds).detach().cpu()
        output = output / output.max()  # type: ignore

        target = torch.abs(target).detach().cpu()
        target = target / target.max()  # type: ignore

        error = torch.abs(target - output)

        self.log_image(f"{key}/target", target)
        self.log_image(f"{key}/reconstruction", output)
        self.log_image(f"{key}/error", error)

        return name, slice_num, preds.detach().cpu().numpy()
Example #23
File: base.py Project: wdika/mridc
    def validation_step(self, batch: Dict[float, torch.Tensor],
                        batch_idx: int) -> Dict:
        """
        Performs a validation step.

        Parameters
        ----------
        batch: Batch of data. Dict[str, torch.Tensor], with keys,
            'y': subsampled kspace,
                torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
            'sensitivity_maps': sensitivity_maps,
                torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
            'mask': sampling mask,
                torch.Tensor, shape [1, 1, n_x, n_y, 1]
            'init_pred': initial prediction. For example zero-filled or PICS.
                torch.Tensor, shape [batch_size, n_x, n_y, 2]
            'target': target data,
                torch.Tensor, shape [batch_size, n_x, n_y, 2]
            'fname': filename,
                str, shape [batch_size]
            'slice_idx': slice_idx,
                torch.Tensor, shape [batch_size]
            'acc': acceleration factor,
                torch.Tensor, shape [batch_size]
            'max_value': maximum value of the magnitude image space,
                torch.Tensor, shape [batch_size]
            'crop_size': crop size,
                torch.Tensor, shape [n_x, n_y]
        batch_idx: Batch index.
            int

        Returns
        -------
        Dict[str, torch.Tensor], with keys,
        'loss': loss,
            torch.Tensor, shape [1]
        'log': log,
            dict, shape [1]
        """
        kspace, y, sensitivity_maps, mask, init_pred, target, fname, slice_num, _ = batch
        y, mask, init_pred, r = self.process_inputs(y, mask, init_pred)

        if self.use_sens_net:
            sensitivity_maps = self.sens_net(kspace, mask)
            if self.coil_combination_method.upper() == "SENSE":
                target = sense(
                    ifft2(
                        kspace,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    sensitivity_maps,
                    dim=self.coil_dim,
                )

        preds = self.forward(y, sensitivity_maps, mask, init_pred, target)

        if self.accumulate_estimates:
            try:
                preds = next(preds)
            except StopIteration:
                pass

            val_loss = sum(
                self.process_loss(target, preds, _loss_fn=self.eval_loss_fn))
        else:
            val_loss = self.process_loss(target,
                                         preds,
                                         _loss_fn=self.eval_loss_fn)

        # Cascades
        if isinstance(preds, list):
            preds = preds[-1]

        # Time-steps
        if isinstance(preds, list):
            preds = preds[-1]

        key = f"{fname[0]}_images_idx_{int(slice_num)}"  # type: ignore
        output = torch.abs(preds).detach().cpu()
        target = torch.abs(target).detach().cpu()
        output = output / output.max()  # type: ignore
        target = target / target.max()  # type: ignore
        error = torch.abs(target - output)
        self.log_image(f"{key}/target", target)
        self.log_image(f"{key}/reconstruction", output)
        self.log_image(f"{key}/error", error)

        target = target.numpy()  # type: ignore
        output = output.numpy()  # type: ignore
        self.mse_vals[fname][slice_num] = torch.tensor(mse(target,
                                                           output)).view(1)
        self.nmse_vals[fname][slice_num] = torch.tensor(nmse(target,
                                                             output)).view(1)
        self.ssim_vals[fname][slice_num] = torch.tensor(
            ssim(target, output, maxval=output.max() - output.min())).view(1)
        self.psnr_vals[fname][slice_num] = torch.tensor(
            psnr(target, output, maxval=output.max() - output.min())).view(1)

        return {"val_loss": val_loss}
Example #24
File: base.py Project: wdika/mridc
    def training_step(self, batch: Dict[float, torch.Tensor],
                      batch_idx: int) -> Dict[str, torch.Tensor]:
        """
        Performs a training step.

        Parameters
        ----------
        batch: Batch of data.
            Dict[str, torch.Tensor], with keys,

            'y': subsampled kspace,
                torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
            'sensitivity_maps': sensitivity_maps,
                torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
            'mask': sampling mask,
                torch.Tensor, shape [1, 1, n_x, n_y, 1]
            'init_pred': initial prediction. For example zero-filled or PICS.
                torch.Tensor, shape [batch_size, n_x, n_y, 2]
            'target': target data,
                torch.Tensor, shape [batch_size, n_x, n_y, 2]
            'fname': filename,
                str, shape [batch_size]
            'slice_idx': slice_idx,
                torch.Tensor, shape [batch_size]
            'acc': acceleration factor,
                torch.Tensor, shape [batch_size]
            'max_value': maximum value of the magnitude image space,
                torch.Tensor, shape [batch_size]
            'crop_size': crop size,
                torch.Tensor, shape [n_x, n_y]
        batch_idx: Batch index.
            int

        Returns
        -------
        Dict[str, torch.Tensor], with keys,
        'loss': loss,
            torch.Tensor, shape [1]
        'log': log,
            dict, shape [1]
        """
        kspace, y, sensitivity_maps, mask, init_pred, target, _, _, acc = batch
        y, mask, init_pred, r = self.process_inputs(y, mask, init_pred)

        if self.use_sens_net:
            sensitivity_maps = self.sens_net(kspace, mask)
            if self.coil_combination_method.upper() == "SENSE":
                target = sense(
                    ifft2(
                        kspace,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    sensitivity_maps,
                    dim=self.coil_dim,
                )

        preds = self.forward(y, sensitivity_maps, mask, init_pred, target)

        if self.accumulate_estimates:
            try:
                preds = next(preds)
            except StopIteration:
                pass

            train_loss = sum(
                self.process_loss(target, preds, _loss_fn=self.train_loss_fn))
        else:
            train_loss = self.process_loss(target,
                                           preds,
                                           _loss_fn=self.train_loss_fn)

        acc = r if r != 0 else acc
        tensorboard_logs = {
            f"train_loss_{acc}x": train_loss.item(),  # type: ignore
            "lr": self._optimizer.param_groups[0]["lr"],  # type: ignore
        }
        return {"loss": train_loss, "log": tensorboard_logs}
Example #25
def evaluate(
    arguments,
    reconstruction_key,
    mask_background,
    output_path,
    method,
    acc,
    no_params,
    slice_start,
    slice_end,
    coil_dim,
):
    """
    Evaluates the reconstructions.

    Parameters
    ----------
    arguments: The CLI arguments.
    reconstruction_key: The key of the reconstruction to evaluate.
    mask_background: The background mask.
    output_path: The output path.
    method: The reconstruction method.
    acc: The acceleration factor.
    no_params: The number of parameters.
    slice_start: The start slice. (optional)
    slice_end: The end slice. (optional)
    coil_dim: The coil dimension. (optional)

    Returns
    -------
    dict: A dict where the keys are metric names and the values are the mean of the metric.
    """
    _metrics = Metrics(METRIC_FUNCS, output_path,
                       method) if arguments.type == "mean_std" else {}

    for tgt_file in tqdm(arguments.target_path.iterdir()):
        if exists(arguments.predictions_path / tgt_file.name):
            with h5py.File(tgt_file, "r") as target, h5py.File(
                    arguments.predictions_path / tgt_file.name, "r") as recons:
                kspace = target["kspace"][()]

                if arguments.sense_path is not None:
                    sense = h5py.File(arguments.sense_path / tgt_file.name,
                                      "r")["sensitivity_map"][()]
                elif "sensitivity_map" in target:
                    sense = target["sensitivity_map"][()]

                sense = sense.squeeze().astype(np.complex64)

                if sense.shape != kspace.shape:
                    sense = np.transpose(sense, (0, 3, 1, 2))

                target = np.abs(
                    tensor_to_complex_np(
                        torch.sum(
                            complex_mul(
                                ifft2(to_tensor(kspace),
                                      centered="fastmri"
                                      in str(arguments.sense_path).lower()),
                                complex_conj(to_tensor(sense)),
                            ),
                            coil_dim,
                        )))

                recons = recons[reconstruction_key][()]

                if recons.ndim == 4:
                    recons = recons.squeeze(coil_dim)

                if arguments.crop_size is not None:
                    crop_size = arguments.crop_size
                    crop_size[0] = min(target.shape[-2], int(crop_size[0]))
                    crop_size[1] = min(target.shape[-1], int(crop_size[1]))
                    crop_size[0] = min(recons.shape[-2], int(crop_size[0]))
                    crop_size[1] = min(recons.shape[-1], int(crop_size[1]))

                    target = center_crop(target, crop_size)
                    recons = center_crop(recons, crop_size)

                if mask_background:
                    for sl in range(target.shape[0]):
                        mask = convex_hull_image(
                            np.where(
                                np.abs(target[sl]) > threshold_otsu(
                                    np.abs(target[sl])), 1, 0)  # type: ignore
                        )
                        target[sl] = target[sl] * mask
                        recons[sl] = recons[sl] * mask

                if slice_start is not None:
                    target = target[slice_start:]
                    recons = recons[slice_start:]

                if slice_end is not None:
                    target = target[:slice_end]
                    recons = recons[:slice_end]

                for sl in range(target.shape[0]):
                    target[sl] = target[sl] / np.max(np.abs(target[sl]))
                    recons[sl] = recons[sl] / np.max(np.abs(recons[sl]))

                target = np.abs(target)
                recons = np.abs(recons)

                if arguments.type == "mean_std":
                    _metrics.push(target, recons)
                else:
                    _target = np.expand_dims(target, coil_dim)
                    _recons = np.expand_dims(recons, coil_dim)
                    for sl in range(target.shape[0]):
                        _metrics["FNAME"] = tgt_file.name
                        _metrics["SLICE"] = sl
                        _metrics["ACC"] = acc
                        _metrics["METHOD"] = method
                        _metrics["MSE"] = [mse(target[sl], recons[sl])]
                        _metrics["NMSE"] = [nmse(target[sl], recons[sl])]
                        _metrics["PSNR"] = [psnr(target[sl], recons[sl])]
                        _metrics["SSIM"] = [ssim(_target[sl], _recons[sl])]
                        _metrics["PARAMS"] = no_params

                        if not exists(arguments.output_path):
                            pd.DataFrame(columns=_metrics.keys()).to_csv(
                                arguments.output_path, index=False, mode="w")
                        pd.DataFrame(_metrics).to_csv(arguments.output_path,
                                                      index=False,
                                                      header=False,
                                                      mode="a")

    return _metrics
Example #26
File: lpd.py Project: wdika/mridc
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or  torch.Tensor, shape [batch_size, n_x, n_y, 2]
             If self.accumulate_loss is True, returns a list of all intermediate estimates.
             If False, returns the final estimate.
        """
        input_image = complex_mul(
            ifft2(
                torch.where(mask == 0,
                            torch.tensor([0.0], dtype=y.dtype).to(y.device),
                            y),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ),
            complex_conj(sensitivity_maps),
        ).sum(self.coil_dim)
        dual_buffer = torch.cat([y] * self.num_dual, -1).to(y.device)
        primal_buffer = torch.cat([input_image] * self.num_primal,
                                  -1).to(y.device)

        for idx in range(self.num_iter):
            # Dual
            f_2 = primal_buffer[..., 2:4].clone()
            f_2 = torch.where(
                mask == 0,
                torch.tensor([0.0], dtype=f_2.dtype).to(f_2.device),
                fft2(
                    complex_mul(f_2.unsqueeze(self.coil_dim),
                                sensitivity_maps),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).type(f_2.type()),
            )
            dual_buffer = self.dual_net[idx](dual_buffer, f_2, y)

            # Primal
            h_1 = dual_buffer[..., 0:2].clone()
            h_1 = complex_mul(
                ifft2(
                    torch.where(
                        mask == 0,
                        torch.tensor([0.0], dtype=h_1.dtype).to(h_1.device),
                        h_1),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                complex_conj(sensitivity_maps),
            ).sum(self.coil_dim)
            primal_buffer = self.primal_net[idx](primal_buffer, h_1)

        output = primal_buffer[..., 0:2]
        output = (output**2).sum(-1).sqrt()
        _, output = center_crop_to_smallest(target, output)
        return output
Example #27
File: rvn.py Project: wdika/mridc
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or  torch.Tensor, shape [batch_size, n_x, n_y, 2]
             If self.accumulate_loss is True, returns a list of all intermediate estimates.
             If False, returns the final estimate.
        """
        previous_state: Optional[torch.Tensor] = None

        if self.initializer is not None:
            if self.initializer_initialization == "sense":
                initializer_input_image = (complex_mul(
                    ifft2(
                        y,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sensitivity_maps),
                ).sum(self.coil_dim).unsqueeze(self.coil_dim))
            elif self.initializer_initialization == "input_image":
                if "initial_image" not in kwargs:
                    raise ValueError(
                        "`initial_image` is required as input if initializer_initialization "
                        f"is {self.initializer_initialization}.")
                initializer_input_image = kwargs["initial_image"].unsqueeze(
                    self.coil_dim)
            elif self.initializer_initialization == "zero_filled":
                initializer_input_image = ifft2(
                    y,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                )

            previous_state = self.initializer(
                fft2(
                    initializer_input_image,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).sum(1).permute(0, 3, 1, 2))

        kspace_prediction = y.clone()

        for step in range(self.num_steps):
            block = self.block_list[
                step] if self.no_parameter_sharing else self.block_list[0]
            kspace_prediction, previous_state = block(
                kspace_prediction,
                y,
                mask,
                sensitivity_maps,
                previous_state,
            )

        eta = ifft2(
            kspace_prediction,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        eta = coil_combination(eta,
                               sensitivity_maps,
                               method=self.coil_combination_method,
                               dim=self.coil_dim)
        eta = torch.view_as_complex(eta)
        _, eta = center_crop_to_smallest(target, eta)
        return eta
Example #28
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or  torch.Tensor, shape [batch_size, n_x, n_y, 2]
             If self.accumulate_loss is True, returns a list of all intermediate estimates.
             If False, returns the final estimate.
        """
        kspace = y.clone()
        zero = torch.zeros(1, 1, 1, 1, 1).to(kspace)

        for idx in range(self.num_iter):
            soft_dc = torch.where(mask.bool(), kspace - y, zero) * self.dc_weight

            kspace = self.kspace_model_list[idx](kspace)
            if kspace.shape[-1] != 2:
                kspace = kspace.permute(0, 1, 3, 4, 2).to(target)
                kspace = torch.view_as_real(kspace[..., 0] + 1j * kspace[..., 1])  # this is necessary, but why?

            image = complex_mul(
                ifft2(
                    kspace,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                complex_conj(sensitivity_maps),
            ).sum(self.coil_dim)
            image = self.image_model_list[idx](image.unsqueeze(self.coil_dim)).squeeze(self.coil_dim)

            if not self.no_dc:
                image = fft2(
                    complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).type(image.type())
                image = kspace - soft_dc - image
                image = complex_mul(
                    ifft2(
                        image,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sensitivity_maps),
                ).sum(self.coil_dim)

            if idx < self.num_iter - 1:
                kspace = fft2(
                    complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).type(image.type())

        image = torch.view_as_complex(image)
        _, image = center_crop_to_smallest(target, image)
        return image
Example #29
    def forward(
        self,
        current_kspace: torch.Tensor,
        masked_kspace: torch.Tensor,
        sampling_mask: torch.Tensor,
        sensitivity_map: torch.Tensor,
        hidden_state: Union[None, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes forward pass of RecurrentVarNetBlock.

        Parameters
        ----------
        current_kspace: Current k-space prediction.
            torch.Tensor, shape [batch_size, n_coil, height, width, 2]
        masked_kspace: Subsampled k-space.
            torch.Tensor, shape [batch_size, n_coil, height, width, 2]
        sampling_mask: Sampling mask.
            torch.Tensor, shape [batch_size, 1, height, width, 1]
        sensitivity_map: Coil sensitivities.
            torch.Tensor, shape [batch_size, n_coil, height, width, 2]
        hidden_state: ConvGRU hidden state.
            None or torch.Tensor, shape [batch_size, n_l, height, width, hidden_channels]

        Returns
        -------
        new_kspace: New k-space prediction.
            torch.Tensor, shape [batch_size, n_coil, height, width, 2]
        hidden_state: Next hidden state.
            list of torch.Tensor, shape [batch_size, hidden_channels, height, width, num_layers]
        """
        kspace_error = torch.where(
            sampling_mask == 0,
            torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
            current_kspace - masked_kspace,
        )

        recurrent_term = torch.cat(
            [
                complex_mul(
                    ifft2(
                        kspace,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sensitivity_map),
                ).sum(self.coil_dim)
                for kspace in torch.split(current_kspace, 2, -1)
            ],
            dim=-1,
        ).permute(0, 3, 1, 2)

        recurrent_term, hidden_state = self.regularizer(recurrent_term, hidden_state)  # :math:`w_t`, :math:`h_{t+1}`
        recurrent_term = recurrent_term.permute(0, 2, 3, 1)

        recurrent_term = torch.cat(
            [
                fft2(
                    complex_mul(image.unsqueeze(self.coil_dim), sensitivity_map),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                )
                for image in torch.split(recurrent_term, 2, -1)
            ],
            dim=-1,
        )

        new_kspace = current_kspace - self.learning_rate * kspace_error + recurrent_term

        return new_kspace, hidden_state
Example #30
    def test_step(self, batch: Dict[float, torch.Tensor],
                  batch_idx: int) -> Tuple[str, int, torch.Tensor]:
        """
        Test step.

        Parameters
        ----------
        batch: Batch of data.
            Dict of torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        batch_idx: Batch index.
            int

        Returns
        -------
        name: Name of the volume.
            str
        slice_num: Slice number.
            int
        pred: Predicted data.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        """
        kspace, y, sensitivity_maps, mask, _, target, fname, slice_num, _ = batch
        y, mask, _ = self.process_inputs(y, mask)

        if self.use_sens_net:
            sensitivity_maps = self.sens_net(kspace, mask)
            if self.coil_combination_method.upper() == "SENSE":
                target = sense(
                    ifft2(
                        kspace,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    sensitivity_maps,
                    dim=self.coil_dim,
                )

        y = torch.view_as_complex(y).permute(0, 2, 3, 1).detach().cpu().numpy()

        if sensitivity_maps is None and not self.sens_net:
            raise ValueError(
                "Sensitivity maps are required for PICS. "
                "Please set use_sens_net to True if precomputed sensitivity maps are not available."
            )

        sensitivity_maps = torch.view_as_complex(sensitivity_maps)
        if self.fft_type != "orthogonal":
            sensitivity_maps = torch.fft.fftshift(sensitivity_maps,
                                                  dim=(-2, -1))
        sensitivity_maps = sensitivity_maps.permute(
            0, 2, 3, 1).detach().cpu().numpy()  # type: ignore

        prediction = torch.from_numpy(
            self.forward(y, sensitivity_maps, mask, target)).unsqueeze(0)
        if self.fft_type != "orthogonal":
            prediction = torch.fft.fftshift(prediction, dim=(-2, -1))

        slice_num = int(slice_num)
        name = str(fname[0])  # type: ignore
        key = f"{name}_images_idx_{slice_num}"  # type: ignore
        output = torch.abs(prediction).detach().cpu()
        target = torch.abs(target).detach().cpu()
        output = output / output.max()  # type: ignore
        target = target / target.max()  # type: ignore
        error = torch.abs(target - output)
        self.log_image(f"{key}/target", target)
        self.log_image(f"{key}/reconstruction", output)
        self.log_image(f"{key}/error", error)

        return name, slice_num, prediction.detach().cpu().numpy()