Example #1
def log_likelihood_gradient(
    eta: torch.Tensor,
    masked_kspace: torch.Tensor,
    sense: torch.Tensor,
    mask: torch.Tensor,
    sigma: float,
    fft_centered: bool,
    fft_normalization: str,
    spatial_dims: Sequence[int],
    coil_dim: int,
) -> torch.Tensor:
    """
    Computes the gradient of the log-likelihood function.

    Parameters
    ----------
    eta: Initial guess for the reconstruction.
    masked_kspace: Subsampled k-space data.
    sense: Coil sensitivity maps.
    mask: Sampling mask.
    sigma: Noise level.
    fft_centered: Whether to center the FFT.
    fft_normalization: FFT normalization mode (e.g. "ortho", "forward", "backward").
    spatial_dims: Spatial dimensions of the data.
    coil_dim: Coil dimension of the data.

    Returns
    -------
    Gradient of the log-likelihood function.
    """
    eta_real, eta_imag = map(lambda x: torch.unsqueeze(x, 0), eta.chunk(2, -1))
    sense_real, sense_imag = sense.chunk(2, -1)

    re_se = eta_real * sense_real - eta_imag * sense_imag
    im_se = eta_real * sense_imag + eta_imag * sense_real

    pred = ifft2(
        mask * (fft2(
            torch.cat((re_se, im_se), -1),
            centered=fft_centered,
            normalization=fft_normalization,
            spatial_dims=spatial_dims,
        ) - masked_kspace),
        centered=fft_centered,
        normalization=fft_normalization,
        spatial_dims=spatial_dims,
    )

    pred_real, pred_imag = pred.chunk(2, -1)

    re_out = torch.sum(pred_real * sense_real + pred_imag * sense_imag,
                       coil_dim) / (sigma**2.0)
    im_out = torch.sum(pred_imag * sense_real - pred_real * sense_imag,
                       coil_dim) / (sigma**2.0)

    eta_real = eta_real.squeeze(0)
    eta_imag = eta_imag.squeeze(0)

    return torch.cat((eta_real, eta_imag, re_out, im_out),
                     0).unsqueeze(0).squeeze(-1)
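
A minimal usage sketch (the tensor shapes below are assumptions for illustration, not taken from the examples, and the package's fft2/ifft2 helpers are assumed to be importable): the function stacks the current estimate and its data-consistency gradient into a four-channel tensor, which is what the recurrent layers in Example #18 consume.

import torch

# Hypothetical single-slice shapes: batch 1, 4 coils, 32x32 image,
# real/imaginary parts stored in the trailing dimension of size 2.
eta = torch.zeros(1, 32, 32, 2)
masked_kspace = torch.zeros(1, 4, 32, 32, 2)
sense = torch.ones(1, 4, 32, 32, 2)
mask = torch.ones(1, 1, 32, 32, 1)

grad_eta = log_likelihood_gradient(
    eta, masked_kspace, sense, mask,
    sigma=1.0,
    fft_centered=True,
    fft_normalization="ortho",
    spatial_dims=[-2, -1],
    coil_dim=1,
)
# grad_eta has shape [1, 4, 32, 32]: the channels are
# (eta_real, eta_imag, gradient_real, gradient_imag).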
Example #2
def test_centered_fft2_forward_normalization(shape):
    """
    Test centered 2D Fast Fourier Transform with forward normalization.

    Args:
        shape: shape of the input

    Returns:
        None
    """
    shape = shape + [2]
    x = create_input(shape)
    out_torch = fft2(x,
                     centered=True,
                     normalization="forward",
                     spatial_dims=[-2, -1]).numpy()
    out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

    input_numpy = tensor_to_complex_np(x)
    input_numpy = np.fft.ifftshift(input_numpy, (-2, -1))
    out_numpy = np.fft.fft2(input_numpy, norm="forward")
    out_numpy = np.fft.fftshift(out_numpy, (-2, -1))

    if not np.allclose(out_torch, out_numpy):
        raise AssertionError
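
The test pins down the convention behind the `centered` and `normalization` arguments: for centered=True the input is ifft-shifted before the FFT and the result is fft-shifted afterwards, and the normalization string is forwarded to the FFT backend. A minimal sketch of such a wrapper, consistent with this test but not the library's actual implementation, could look like this:

import torch

def fft2_sketch(data, centered=True, normalization="ortho", spatial_dims=(-2, -1)):
    # `data` stores real/imaginary parts in a trailing dimension of size 2;
    # `spatial_dims` indexes the complex-valued view of the tensor.
    x = torch.view_as_complex(data.contiguous())
    if centered:
        # Move the image origin to index 0 before the FFT ...
        x = torch.fft.ifftshift(x, dim=spatial_dims)
    x = torch.fft.fftn(x, dim=spatial_dims, norm=normalization)
    if centered:
        # ... and move the zero-frequency component back to the center.
        x = torch.fft.fftshift(x, dim=spatial_dims)
    return torch.view_as_real(x)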
Example #3
    def forward(self, image):
        """Forward method for the MultiDomainConvTranspose2d class."""
        kspace = [
            fft2(im,
                 centered=self.fft_centered,
                 normalization=self.fft_normalization,
                 spatial_dims=self.spatial_dims) for im in torch.split(
                     image.permute(0, 2, 3, 1).contiguous(), 2, -1)
        ]
        kspace = torch.cat(kspace, -1).permute(0, 3, 1, 2)
        kspace = self.kspace_conv(kspace)

        backward = [
            ifft2(
                ks.float(),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ).type(image.type()) for ks in torch.split(
                kspace.permute(0, 2, 3, 1).contiguous(), 2, -1)
        ]
        backward = torch.cat(backward, -1).permute(0, 3, 1, 2)

        image = self.image_conv(image)
        return torch.cat([image, backward], dim=self.coil_dim)
Example #4
 def A(x):
     x = (fft2(
         complex_mul(x.expand_as(smaps), smaps),
         centered=ctx.fft_centered,
         normalization=ctx.fft_normalization,
         spatial_dims=ctx.spatial_dims,
     ) * mask)
     return torch.sum(x, dim=-4, keepdim=True)
Example #5
 def _forward_operator(self, image, sampling_mask, sensitivity_map):
     """Forward operator."""
     return torch.where(
         sampling_mask == 0,
         torch.tensor([0.0], dtype=image.dtype).to(image.device),
         fft2(
             complex_mul(image.unsqueeze(self.coil_dim), sensitivity_map),
             centered=self.fft_centered,
             normalization=self.fft_normalization,
             spatial_dims=self.spatial_dims,
         ).type(image.type()),
     )
Example #6
    def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """
        Expand the image to multicoil k-space.

        The image is multiplied by the coil sensitivity maps and transformed
        with the forward FFT.

        Parameters
        ----------
        x: Input image data.
        sens_maps: Coil sensitivity maps.

        Returns
        -------
        Multicoil k-space, of the same size as sens_maps.
        """
        return fft2(
            complex_mul(x, sens_maps),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
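
The inverse mapping, coil combination, appears throughout the later examples as complex_mul(ifft2(...), complex_conj(sens_maps)).sum(coil_dim). A hedged sketch of it as a free function, reusing the same helpers shown above (the name sens_reduce and the free-function form are assumptions):

def sens_reduce(kspace, sens_maps, fft_centered, fft_normalization, spatial_dims, coil_dim):
    """Back-project multicoil k-space to a coil-combined image (SENSE adjoint)."""
    # Inverse FFT per coil, multiply by the conjugate sensitivities, sum over coils.
    coil_images = ifft2(
        kspace,
        centered=fft_centered,
        normalization=fft_normalization,
        spatial_dims=spatial_dims,
    )
    return complex_mul(coil_images, complex_conj(sens_maps)).sum(coil_dim)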
Example #7
    def forward(self, x, y, smaps, mask):
        """
        Forward pass of the data-consistency block.

        Parameters
        ----------
        x: Input image.
        y: Subsampled k-space data.
        smaps: Coil sensitivity maps.
        mask: Sampling mask.

        Returns
        -------
        Output image.
        """
        A_x = torch.sum(
            fft2(
                complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ),
            -4,
            keepdim=True,
        )
        k_dc = (1 - mask) * A_x + mask * (self.alpha * A_x +
                                          (1 - self.alpha) * y)
        x_dc = torch.sum(
            complex_mul(
                ifft2(
                    k_dc,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                complex_conj(smaps),
            ),
            dim=(-5),
        )
        return self.beta * x + (1 - self.beta) * x_dc
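
Restated in the same docstring math convention, with :math:`C_c` the coil sensitivities, :math:`F` the FFT and :math:`M` the sampling mask, the block above enforces data consistency by interpolating between predicted and measured k-space at the sampled locations:

.. math::
    A(x) = \sum_{c} F (C_{c} x)

    k_{dc} = (1 - M) A(x) + M (\alpha A(x) + (1 - \alpha) y)

    x_{out} = \beta x + (1 - \beta) \sum_{c} C_{c}^{*} F^{-1}(k_{dc})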
Example #8
def test_non_centered_fft2(shape):
    """
    Test non-centered 2D Fast Fourier Transform.

    Args:
        shape: shape of the input

    Returns:
        None
    """
    shape = shape + [2]
    x = create_input(shape)
    out_torch = fft2(x,
                     centered=False,
                     normalization="ortho",
                     spatial_dims=[-2, -1]).numpy()
    out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

    input_numpy = tensor_to_complex_np(x)
    out_numpy = np.fft.fft2(input_numpy, norm="ortho")

    if not np.allclose(out_torch, out_numpy):
        raise AssertionError
Example #9
    def sens_expand(self, x: torch.Tensor,
                    sens_maps: torch.Tensor) -> torch.Tensor:
        """
        Expand the image to multicoil k-space.

        Parameters
        ----------
        x: Input image data.
            torch.Tensor, shape [batch_size, n_coils, height, width, 2]
        sens_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, height, width, 2]

        Returns
        -------
        Multicoil k-space, of the same size as the input.
            torch.Tensor, shape [batch_size, n_coils, height, width, 2]
        """
        return fft2(
            complex_mul(x, sens_maps),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
Example #10
    def forward(self, x, y, smaps, mask):
        """

        Parameters
        ----------
        x: Input image.
        y: Subsampled k-space data.
        smaps: Coil sensitivity maps.
        mask: Sampling mask.

        Returns
        -------
        data_loss: Data term loss.
        """
        A_x_y = (torch.sum(
            fft2(
                complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ) * mask,
            -4,
            keepdim=True,
        ) - y)
        gradD_x = torch.sum(
            complex_mul(
                ifft2(
                    A_x_y * mask,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                complex_conj(smaps),
            ),
            dim=(-5),
        )
        return x - self.data_weight * gradD_x
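
Written out in the same notation, this block takes one gradient step on the data-fidelity term, with :math:`w` being self.data_weight (a restatement of the code above):

.. math::
    r = M \sum_{c} F (C_{c} x) - y

    \nabla D(x) = \sum_{c} C_{c}^{*} F^{-1}(M r)

    x_{out} = x - w \, \nabla D(x)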
Example #11
    def forward(self, x, y, mask):
        """
        Forward pass of the data-consistency block.

        Parameters
        ----------
        x: Input image.
        y: Subsampled k-space data.
        mask: Sampling mask.

        Returns
        -------
        Output image.
        """
        A_x = fft2(x,
                   centered=self.fft_centered,
                   normalization=self.fft_normalization,
                   spatial_dims=self.spatial_dims)
        k_dc = (1 - mask) * A_x + mask * (self.lambda_ * A_x +
                                          (1 - self.lambda_) * y)
        return ifft2(k_dc,
                     centered=self.fft_centered,
                     normalization=self.fft_normalization,
                     spatial_dims=self.spatial_dims)
Example #12
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or  torch.Tensor, shape [batch_size, n_x, n_y, 2]
             If self.accumulate_loss is True, returns a list of all intermediate estimates.
             If False, returns the final estimate.
        """
        kspace = y.clone()
        zero = torch.zeros(1, 1, 1, 1, 1).to(kspace)

        for idx in range(self.num_iter):
            soft_dc = torch.where(mask.bool(), kspace - y, zero) * self.dc_weight

            kspace = self.kspace_model_list[idx](kspace)
            if kspace.shape[-1] != 2:
                kspace = kspace.permute(0, 1, 3, 4, 2).to(target)
                kspace = torch.view_as_real(kspace[..., 0] + 1j * kspace[..., 1])  # this is necessary, but why?

            image = complex_mul(
                ifft2(
                    kspace,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                complex_conj(sensitivity_maps),
            ).sum(self.coil_dim)
            image = self.image_model_list[idx](image.unsqueeze(self.coil_dim)).squeeze(self.coil_dim)

            if not self.no_dc:
                image = fft2(
                    complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).type(image.type())
                image = kspace - soft_dc - image
                image = complex_mul(
                    ifft2(
                        image,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sensitivity_maps),
                ).sum(self.coil_dim)

            if idx < self.num_iter - 1:
                kspace = fft2(
                    complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).type(image.type())

        image = torch.view_as_complex(image)
        _, image = center_crop_to_smallest(target, image)
        return image
Example #13
    def forward(
        self,
        current_kspace: torch.Tensor,
        masked_kspace: torch.Tensor,
        sampling_mask: torch.Tensor,
        sensitivity_map: torch.Tensor,
        hidden_state: Union[None, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes forward pass of RecurrentVarNetBlock.

        Parameters
        ----------
        current_kspace: Current k-space prediction.
            torch.Tensor, shape [batch_size, n_coil, height, width, 2]
        masked_kspace: Subsampled k-space.
            torch.Tensor, shape [batch_size, n_coil, height, width, 2]
        sampling_mask: Sampling mask.
            torch.Tensor, shape [batch_size, 1, height, width, 1]
        sensitivity_map: Coil sensitivities.
            torch.Tensor, shape [batch_size, n_coil, height, width, 2]
        hidden_state: ConvGRU hidden state.
            None or torch.Tensor, shape [batch_size, n_l, height, width, hidden_channels]

        Returns
        -------
        new_kspace: New k-space prediction.
            torch.Tensor, shape [batch_size, n_coil, height, width, 2]
        hidden_state: Next hidden state.
            list of torch.Tensor, shape [batch_size, hidden_channels, height, width, num_layers]
        """
        kspace_error = torch.where(
            sampling_mask == 0,
            torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
            current_kspace - masked_kspace,
        )

        recurrent_term = torch.cat(
            [
                complex_mul(
                    ifft2(
                        kspace,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sensitivity_map),
                ).sum(self.coil_dim)
                for kspace in torch.split(current_kspace, 2, -1)
            ],
            dim=-1,
        ).permute(0, 3, 1, 2)

        recurrent_term, hidden_state = self.regularizer(recurrent_term, hidden_state)  # :math:`w_t`, :math:`h_{t+1}`
        recurrent_term = recurrent_term.permute(0, 2, 3, 1)

        recurrent_term = torch.cat(
            [
                fft2(
                    complex_mul(image.unsqueeze(self.coil_dim), sensitivity_map),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                )
                for image in torch.split(recurrent_term, 2, -1)
            ],
            dim=-1,
        )

        new_kspace = current_kspace - self.learning_rate * kspace_error + recurrent_term

        return new_kspace, hidden_state
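
In the same notation, with :math:`\alpha_t` the learned step size (self.learning_rate) and :math:`\mathcal{R}_{\theta}` the ConvGRU regularizer wrapped in the coil-reduce and coil-expand steps, the update restates as:

.. math::
    k_{t+1} = k_{t} - \alpha_{t} M (k_{t} - y) + \mathcal{R}_{\theta}(k_{t}, h_{t})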
Example #14
    def update_X(self, idx, image, sensitivity_maps, y, mask):
        """
        Update the image.

        .. math::
            x_{k} = (1 - 2 \lambda_{k_I} \mu_{k} - 2 \lambda_{k_F} \mu_{k}) x_{k}

            x_{k} = 2 \mu_{k} (\lambda_{k_I} D_I(x_{k}) + \lambda_{k_F} F^{-1}(D_F(f)))

            A(x_{k}) - b = M F (C x_{k}) - b

            x_{k} = 2 \mu_{k} A^{*} (A(x_{k}) - b)

        Parameters
        ----------
        idx: int
            The current iteration index.
        image: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The predicted image.
        sensitivity_maps: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]
            The coil sensitivity maps.
        y: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The subsampled k-space data.
        mask: torch.Tensor [batch_size, 1, num_rows, num_cols]
            The subsampled mask.

        Returns
        -------
        image: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The updated image.
        """
        # (1 - 2 * lambda_{k}_{I} * mi_{k} - 2 * lambda_{k}_{F} * mi_{k}) * x_{k}
        image_term_1 = (
            1 - 2 * self.reg_param_I[idx] * self.lr_image[idx] - 2 * self.reg_param_F[idx] * self.lr_image[idx]
        ) * image
        # D_I(x_{k})
        image_term_2_DI = self.image_model(image.unsqueeze(self.coil_dim)).squeeze(self.coil_dim).contiguous()
        image_term_2_DF = ifft2(
            self.kspace_model(
                fft2(
                    image,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).unsqueeze(self.coil_dim)
            )
            .squeeze(self.coil_dim)
            .contiguous(),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        # 2 * mi_{k} * (lambda_{k}_{I} * D_I(x_{k}) + lambda_{k}_{F} * F^-1(D_F(f)))
        image_term_2 = (
            2
            * self.lr_image[idx]
            * (self.reg_param_I[idx] * image_term_2_DI + self.reg_param_F[idx] * image_term_2_DF)
        )
        # A(x_{k}) - b = M * F * (C * x_{k}) - b
        image_term_3_A = fft2(
            complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        image_term_3_A = torch.where(mask == 0, torch.tensor([0.0], dtype=y.dtype).to(y.device), image_term_3_A) - y
        # 2 * mi_{k} * A^* * (A(x_{k}) - b)
        image_term_3_Aconj = complex_mul(
            ifft2(
                image_term_3_A,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ),
            complex_conj(sensitivity_maps),
        ).sum(self.coil_dim)
        image_term_3 = 2 * self.lr_image[idx] * image_term_3_Aconj
        image = image_term_1 + image_term_2 - image_term_3
        return image
Example #15
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or  torch.Tensor, shape [batch_size, n_x, n_y, 2]
             If self.accumulate_loss is True, returns a list of all intermediate estimates.
             If False, returns the final estimate.
        """
        previous_state: Optional[torch.Tensor] = None

        if self.initializer is not None:
            if self.initializer_initialization == "sense":
                initializer_input_image = (complex_mul(
                    ifft2(
                        y,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sensitivity_maps),
                ).sum(self.coil_dim).unsqueeze(self.coil_dim))
            elif self.initializer_initialization == "input_image":
                if "initial_image" not in kwargs:
                    raise ValueError(
                        "`initial_image` is required as input if initializer_initialization "
                        f"is {self.initializer_initialization}.")
                initializer_input_image = kwargs["initial_image"].unsqueeze(
                    self.coil_dim)
            elif self.initializer_initialization == "zero_filled":
                initializer_input_image = ifft2(
                    y,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                )

            previous_state = self.initializer(
                fft2(
                    initializer_input_image,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).sum(1).permute(0, 3, 1, 2))

        kspace_prediction = y.clone()

        for step in range(self.num_steps):
            block = self.block_list[
                step] if self.no_parameter_sharing else self.block_list[0]
            kspace_prediction, previous_state = block(
                kspace_prediction,
                y,
                mask,
                sensitivity_maps,
                previous_state,
            )

        eta = ifft2(
            kspace_prediction,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        eta = coil_combination(eta,
                               sensitivity_maps,
                               method=self.coil_combination_method,
                               dim=self.coil_dim)
        eta = torch.view_as_complex(eta)
        _, eta = center_crop_to_smallest(target, eta)
        return eta
Example #16
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or  torch.Tensor, shape [batch_size, n_x, n_y, 2]
             If self.accumulate_loss is True, returns a list of all intermediate estimates.
             If False, returns the final estimate.
        """
        input_image = complex_mul(
            ifft2(
                torch.where(mask == 0,
                            torch.tensor([0.0], dtype=y.dtype).to(y.device),
                            y),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ),
            complex_conj(sensitivity_maps),
        ).sum(self.coil_dim)
        dual_buffer = torch.cat([y] * self.num_dual, -1).to(y.device)
        primal_buffer = torch.cat([input_image] * self.num_primal,
                                  -1).to(y.device)

        for idx in range(self.num_iter):
            # Dual
            f_2 = primal_buffer[..., 2:4].clone()
            f_2 = torch.where(
                mask == 0,
                torch.tensor([0.0], dtype=f_2.dtype).to(f_2.device),
                fft2(
                    complex_mul(f_2.unsqueeze(self.coil_dim),
                                sensitivity_maps),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).type(f_2.type()),
            )
            dual_buffer = self.dual_net[idx](dual_buffer, f_2, y)

            # Primal
            h_1 = dual_buffer[..., 0:2].clone()
            h_1 = complex_mul(
                ifft2(
                    torch.where(
                        mask == 0,
                        torch.tensor([0.0], dtype=h_1.dtype).to(h_1.device),
                        h_1),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                complex_conj(sensitivity_maps),
            ).sum(self.coil_dim)
            primal_buffer = self.primal_net[idx](primal_buffer, h_1)

        output = primal_buffer[..., 0:2]
        output = (output**2).sum(-1).sqrt()
        _, output = center_crop_to_smallest(target, output)
        return output
Example #17
    def update_C(self, idx, DC_sens, sensitivity_maps, image, y, mask) -> torch.Tensor:
        """
        Update the coil sensitivity maps.

        .. math::
            C = (1 - 2 \lambda_{k}^{C} \nu_{k}) C_{k}

            C = 2 \lambda_{k}^{C} \nu_{k} D_{C}(F^{-1}(b))

            A(x_{k}) = M F (C x_{k})

            C = 2 \nu_{k} F^{-1}(M^{T} (M F (C x_{k}) - b)) x_{k}^{*}

        Parameters
        ----------
        idx: int
            The current iteration index.
        DC_sens: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]
            The initial coil sensitivity maps.
        sensitivity_maps: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]
            The coil sensitivity maps.
        image: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The predicted image.
        y: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The subsampled k-space data.
        mask: torch.Tensor [batch_size, 1, num_rows, num_cols]
            The subsampled mask.

        Returns
        -------
        sensitivity_maps: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]
            The updated coil sensitivity maps.
        """
        # (1 - 2 * lambda_{k}^{C} * ni_{k}) * C_{k}
        sense_term_1 = (1 - 2 * self.reg_param_C[idx] * self.lr_sens[idx]) * sensitivity_maps
        # 2 * lambda_{k}^{C} * ni_{k} * D_{C}(F^-1(b))
        sense_term_2 = 2 * self.reg_param_C[idx] * self.lr_sens[idx] * DC_sens
        # A(x_{k}) = M * F * (C * x_{k})
        sense_term_3_A = fft2(
            complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        sense_term_3_A = torch.where(mask == 0, torch.tensor([0.0], dtype=y.dtype).to(y.device), sense_term_3_A)
        # 2 * ni_{k} * F^-1(M.T * (M * F * (C * x_{k}) - b)) * x_{k}^*
        sense_term_3_mask = torch.where(
            mask == 1,
            torch.tensor([0.0], dtype=y.dtype).to(y.device),
            sense_term_3_A - y,
        )

        sense_term_3_backward = ifft2(
            sense_term_3_mask,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        sense_term_3 = 2 * self.lr_sens[idx] * sense_term_3_backward * complex_conj(image).unsqueeze(self.coil_dim)
        sensitivity_maps = sense_term_1 + sense_term_2 - sense_term_3
        return sensitivity_maps
Example #18
    def forward(
        self,
        pred: torch.Tensor,
        masked_kspace: torch.Tensor,
        sense: torch.Tensor,
        mask: torch.Tensor,
        eta: torch.Tensor = None,
        hx: torch.Tensor = None,
        sigma: float = 1.0,
        keep_eta: bool = False,
    ) -> Tuple[Any, Union[list, torch.Tensor, None]]:
        """
        Forward pass of the RIMBlock.

        Parameters
        ----------
        pred: Predicted k-space.
        masked_kspace: Subsampled k-space.
        sense: Coil sensitivity maps.
        mask: Sample mask.
        eta: Initial guess for the eta.
        hx: Initial guess for the hidden state.
        sigma: Noise level.
        keep_eta: Whether to keep the eta.

        Returns
        -------
        Reconstructed image and hidden states.
        """
        if hx is None:
            hx = [
                masked_kspace.new_zeros(
                    (masked_kspace.size(0), f, *masked_kspace.size()[2:-1]))
                for f in self.recurrent_filters if f != 0
            ]

        if isinstance(pred, list):
            pred = pred[-1].detach()

        if eta is None or eta.ndim < 3:
            eta = (pred if keep_eta else torch.sum(
                complex_mul(
                    ifft2(
                        pred,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sense),
                ),
                self.coil_dim,
            ))

        etas = []
        for _ in range(self.time_steps):
            grad_eta = log_likelihood_gradient(
                eta,
                masked_kspace,
                sense,
                mask,
                sigma=sigma,
                fft_centered=self.fft_centered,
                fft_normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
                coil_dim=self.coil_dim,
            ).contiguous()

            for h, convrnn in enumerate(self.layers):
                hx[h] = convrnn(grad_eta, hx[h])
                grad_eta = hx[h]

            eta = eta + self.final_layer(grad_eta).permute(0, 2, 3, 1)
            etas.append(eta)

        eta = etas

        if self.no_dc:
            return eta, None

        soft_dc = torch.where(mask, pred - masked_kspace,
                              self.zero.to(masked_kspace)) * self.dc_weight
        current_kspace = [
            masked_kspace - soft_dc - fft2(
                complex_mul(e.unsqueeze(self.coil_dim), sense),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ) for e in eta
        ]

        return current_kspace, None
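
Restated compactly, each time step refines the image estimate with a learned recurrent update driven by the log-likelihood gradient from Example #1, and (unless no_dc is set) the final estimates are mapped back to k-space with a soft data-consistency term weighted by self.dc_weight:

.. math::
    \eta_{t+1} = \eta_{t} + g_{\phi}\big(\nabla_{\eta} \log p(y \mid \eta_{t}), h_{t}\big)

    k = y - w_{dc} M (k_{pred} - y) - F(C \eta_{T})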
Example #19
    def __call__(
        self,
        kspace: np.ndarray,
        sensitivity_map: np.ndarray,
        mask: np.ndarray,
        eta: np.ndarray,
        target: np.ndarray,
        attrs: Dict,
        fname: str,
        slice_idx: int,
    ) -> Tuple[torch.Tensor, Union[Union[List, torch.Tensor], torch.Tensor],
               Union[Optional[torch.Tensor], Any], Union[List, Any], Union[
                   Optional[torch.Tensor], Any], Union[torch.Tensor, Any], str,
               int, Union[Union[List, torch.Tensor], Any], ]:
        """
        Apply the data transform.

        Parameters
        ----------
        kspace: The kspace.
        sensitivity_map: The sensitivity map.
        mask: The mask.
        eta: The initial estimation.
        target: The target.
        attrs: The attributes.
        fname: The file name.
        slice_idx: The slice number.

        Returns
        -------
        The transformed data.
        """
        kspace = to_tensor(kspace)

        # This condition is necessary in case of auto estimation of sense maps.
        if sensitivity_map is not None and sensitivity_map.size != 0:
            sensitivity_map = to_tensor(sensitivity_map)

        # Apply zero-filling on kspace
        if self.kspace_zero_filling_size is not None and self.kspace_zero_filling_size not in (
                "", "None"):
            padding_top = np.floor_divide(
                abs(int(self.kspace_zero_filling_size[0]) - kspace.shape[1]),
                2)
            padding_bottom = padding_top
            padding_left = np.floor_divide(
                abs(int(self.kspace_zero_filling_size[1]) - kspace.shape[2]),
                2)
            padding_right = padding_left

            kspace = torch.view_as_complex(kspace)
            kspace = torch.nn.functional.pad(kspace,
                                             pad=(padding_left, padding_right,
                                                  padding_top, padding_bottom),
                                             mode="constant",
                                             value=0)
            kspace = torch.view_as_real(kspace)

            sensitivity_map = fft2(
                sensitivity_map,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )
            sensitivity_map = torch.view_as_complex(sensitivity_map)
            sensitivity_map = torch.nn.functional.pad(
                sensitivity_map,
                pad=(padding_left, padding_right, padding_top, padding_bottom),
                mode="constant",
                value=0,
            )
            sensitivity_map = torch.view_as_real(sensitivity_map)
            sensitivity_map = ifft2(
                sensitivity_map,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )

        # Initial estimation
        eta = to_tensor(
            eta) if eta is not None and eta.size != 0 else torch.tensor([])

        # If the target is not given, we need to compute it.
        if self.coil_combination_method.upper() == "RSS":
            target = rss(
                ifft2(
                    kspace,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                dim=self.coil_dim,
            )
        elif self.coil_combination_method.upper() == "SENSE":
            if sensitivity_map is not None and sensitivity_map.size != 0:
                target = sense(
                    ifft2(
                        kspace,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    sensitivity_map,
                    dim=self.coil_dim,
                )
        elif target is not None and target.size != 0:
            target = to_tensor(target)
        elif "target" in attrs or "target_rss" in attrs:
            target = torch.tensor(attrs["target"])
        else:
            raise ValueError("No target found")

        target = torch.view_as_complex(target)
        target = torch.abs(target / torch.max(torch.abs(target)))

        seed = tuple(map(ord, fname)) if self.use_seed else None
        acq_start = attrs["padding_left"] if "padding_left" in attrs else 0
        acq_end = attrs["padding_right"] if "padding_right" in attrs else 0

        # This should be outside the condition because it needs to be returned in the end, even if cropping is off.
        # crop_size = torch.tensor([attrs["recon_size"][0], attrs["recon_size"][1]])
        crop_size = target.shape
        if self.crop_size is not None and self.crop_size not in ("", "None"):
            # Check for smallest size against the target shape.
            h = min(int(self.crop_size[0]), target.shape[0])
            w = min(int(self.crop_size[1]), target.shape[1])

            # Check for smallest size against the stored recon shape in metadata.
            if crop_size[0] != 0:
                h = h if h <= crop_size[0] else crop_size[0]
            if crop_size[1] != 0:
                w = w if w <= crop_size[1] else crop_size[1]

            self.crop_size = (int(h), int(w))

            target = center_crop(target, self.crop_size)
            if sensitivity_map is not None and sensitivity_map.size != 0:
                sensitivity_map = (ifft2(
                    complex_center_crop(
                        fft2(
                            sensitivity_map,
                            centered=self.fft_centered,
                            normalization=self.fft_normalization,
                            spatial_dims=self.spatial_dims,
                        ),
                        self.crop_size,
                    ),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ) if self.kspace_crop else complex_center_crop(
                    sensitivity_map, self.crop_size))

            if eta is not None and eta.ndim > 2:
                eta = (ifft2(
                    complex_center_crop(
                        fft2(
                            eta,
                            centered=self.fft_centered,
                            normalization=self.fft_normalization,
                            spatial_dims=self.spatial_dims,
                        ),
                        self.crop_size,
                    ),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ) if self.kspace_crop else complex_center_crop(
                    eta, self.crop_size))

        # Cropping before masking will maintain the shape of original kspace intact for masking.
        if self.crop_size is not None and self.crop_size not in (
                "", "None") and self.crop_before_masking:
            kspace = (complex_center_crop(kspace, self.crop_size)
                      if self.kspace_crop else fft2(
                          complex_center_crop(
                              ifft2(
                                  kspace,
                                  centered=self.fft_centered,
                                  normalization=self.fft_normalization,
                                  spatial_dims=self.spatial_dims,
                              ),
                              self.crop_size,
                          ),
                          centered=self.fft_centered,
                          normalization=self.fft_normalization,
                          spatial_dims=self.spatial_dims,
                      ))

        # Undersample kspace if undersampling is enabled.
        if self.mask_func is None:
            masked_kspace = kspace
            acc = torch.tensor([np.around(mask.size / mask.sum())
                                ]) if mask is not None else torch.tensor([1])

            if mask is None:
                mask = torch.ones(
                    [masked_kspace.shape[-3], masked_kspace.shape[-2]],
                    dtype=torch.float32  # type: ignore
                )
            else:
                mask = torch.from_numpy(mask)
                if mask.shape[0] == masked_kspace.shape[2]:  # type: ignore
                    mask = mask.permute(1, 0)
                elif mask.shape[0] != masked_kspace.shape[1]:  # type: ignore
                    mask = torch.ones(
                        [masked_kspace.shape[-3], masked_kspace.shape[-2]],
                        dtype=torch.float32  # type: ignore
                    )

            if mask.ndim == 1:
                # `mask` is already a torch tensor at this point.
                mask = mask.unsqueeze(0)

            if mask.shape[-2] == 1:  # 1D mask
                mask = mask.unsqueeze(0).unsqueeze(-1)
            else:  # 2D mask
                # Crop loaded mask.
                if self.crop_size is not None and self.crop_size not in (
                        "", "None"):
                    mask = center_crop(mask, self.crop_size)

                mask = mask.unsqueeze(0).unsqueeze(-1)

            if self.shift_mask:
                mask = torch.fft.fftshift(mask, dim=[-3, -2])

            masked_kspace = masked_kspace * mask
            mask = mask.byte()
        elif isinstance(self.mask_func, list):
            masked_kspaces = []
            masks = []
            accs = []
            for m in self.mask_func:
                _masked_kspace, _mask, _acc = apply_mask(
                    kspace,
                    m,
                    seed,
                    (acq_start, acq_end),
                    shift=self.shift_mask,
                    half_scan_percentage=self.half_scan_percentage,
                    center_scale=self.mask_center_scale,
                )
                masked_kspaces.append(_masked_kspace)
                masks.append(_mask.byte())
                accs.append(_acc)
            masked_kspace = masked_kspaces
            mask = masks
            acc = accs
        else:
            masked_kspace, mask, acc = apply_mask(
                kspace,
                self.mask_func[0],  # type: ignore
                seed,
                (acq_start, acq_end),
                shift=self.shift_mask,
                half_scan_percentage=self.half_scan_percentage,
                center_scale=self.mask_center_scale,
            )
            mask = mask.byte()

        # Cropping after masking.
        if self.crop_size is not None and self.crop_size not in (
                "", "None") and not self.crop_before_masking:
            masked_kspace = (complex_center_crop(masked_kspace, self.crop_size)
                             if self.kspace_crop else fft2(
                                 complex_center_crop(
                                     ifft2(
                                         masked_kspace,
                                         centered=self.fft_centered,
                                         normalization=self.fft_normalization,
                                         spatial_dims=self.spatial_dims,
                                     ),
                                     self.crop_size,
                                 ),
                                 centered=self.fft_centered,
                                 normalization=self.fft_normalization,
                                 spatial_dims=self.spatial_dims,
                             ))

            mask = center_crop(mask.squeeze(-1), self.crop_size).unsqueeze(-1)

        # Normalize by the max value.
        if self.normalize_inputs:
            if isinstance(self.mask_func, list):
                masked_kspaces = []
                for y in masked_kspace:
                    if self.fft_normalization in ("orthogonal",
                                                  "orthogonal_norm_only",
                                                  "ortho"):
                        imspace = ifft2(
                            y,
                            centered=self.fft_centered,
                            normalization=self.fft_normalization,
                            spatial_dims=self.spatial_dims,
                        )
                        imspace = imspace / torch.max(torch.abs(imspace))
                        masked_kspaces.append(
                            fft2(
                                imspace,
                                centered=self.fft_centered,
                                normalization=self.fft_normalization,
                                spatial_dims=self.spatial_dims,
                            ))
                    elif self.fft_normalization == "fft_norm":
                        imspace = ifft2(
                            y,
                            centered=self.fft_centered,
                            normalization=self.fft_normalization,
                            spatial_dims=self.spatial_dims,
                        )
                        masked_kspaces.append(
                            fft2(
                                imspace,
                                centered=self.fft_centered,
                                normalization=self.fft_normalization,
                                spatial_dims=self.spatial_dims,
                            ))
                    elif self.fft_normalization == "backward":
                        imspace = ifft2(y,
                                        centered=self.fft_centered,
                                        normalization="backward",
                                        spatial_dims=self.spatial_dims)
                        masked_kspaces.append(
                            fft2(
                                imspace,
                                centered=self.fft_centered,
                                normalization="backward",
                                spatial_dims=self.spatial_dims,
                            ))
                    else:
                        imspace = torch.fft.ifftn(torch.view_as_complex(y),
                                                  dim=[-2, -1],
                                                  norm=None)
                        imspace = imspace / torch.max(torch.abs(imspace))
                        masked_kspaces.append(
                            torch.view_as_real(
                                torch.fft.fftn(imspace,
                                               dim=[-2, -1],
                                               norm=None)))
                masked_kspace = masked_kspaces
            elif self.fft_normalization in ("orthogonal",
                                            "orthogonal_norm_only", "ortho"):
                imspace = ifft2(
                    masked_kspace,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                )
                imspace = imspace / torch.max(torch.abs(imspace))
                masked_kspace = fft2(
                    imspace,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                )
            elif self.fft_normalization == "fft_norm":
                masked_kspace = fft2(
                    ifft2(
                        masked_kspace,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                )
            elif self.fft_normalization == "backward":
                masked_kspace = fft2(
                    ifft2(
                        masked_kspace,
                        centered=self.fft_centered,
                        normalization="backward",
                        spatial_dims=self.spatial_dims,
                    ),
                    centered=self.fft_centered,
                    normalization="backward",
                    spatial_dims=self.spatial_dims,
                )
            else:
                imspace = torch.fft.ifftn(torch.view_as_complex(masked_kspace),
                                          dim=[-2, -1],
                                          norm=None)
                imspace = imspace / torch.max(torch.abs(imspace))
                masked_kspace = torch.view_as_real(
                    torch.fft.fftn(imspace, dim=[-2, -1], norm=None))

            if sensitivity_map.size != 0:
                sensitivity_map = sensitivity_map / torch.max(
                    torch.abs(sensitivity_map))

            if eta.size != 0 and eta.ndim > 2:
                eta = eta / torch.max(torch.abs(eta))

            target = target / torch.max(torch.abs(target))

        return kspace, masked_kspace, sensitivity_map, mask, eta, target, fname, slice_idx, acc
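
A hedged usage sketch of the transform's output (the transform instance name is hypothetical; only the return order is taken from the code above):

# `transform` is an instance of the class whose __call__ is shown above.
(
    kspace,
    masked_kspace,
    sensitivity_map,
    mask,
    eta,
    target,
    fname,
    slice_idx,
    acc,
) = transform(kspace, sensitivity_map, mask, eta, target, attrs, fname, slice_idx)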