Example #1
 def A(x):
     x = (fft2(
         complex_mul(x.expand_as(smaps), smaps),
         centered=ctx.fft_centered,
         normalization=ctx.fft_normalization,
         spatial_dims=ctx.spatial_dims,
     ) * mask)
     return torch.sum(x, dim=-4, keepdim=True)
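These snippets follow the fastMRI-style convention where complex tensors are stored as real tensors with a trailing dimension of size 2 holding the real and imaginary parts. As a point of reference, a minimal sketch of the `complex_mul` and `complex_conj` helpers under that assumption (the actual mridc implementations may add validation):

import torch

def complex_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Elementwise complex product of two [..., 2] real/imag tensors."""
    re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1]
    im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]
    return torch.stack((re, im), dim=-1)

def complex_conj(x: torch.Tensor) -> torch.Tensor:
    """Complex conjugate of a [..., 2] real/imag tensor."""
    return torch.stack((x[..., 0], -x[..., 1]), dim=-1)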
Example #2
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]
             If self.accumulate_loss is True, returns a list of all intermediate estimates.
             If False, returns the final estimate.
        """
        init_pred = torch.sum(
            complex_mul(
                ifft2(y,
                      centered=self.fft_centered,
                      normalization=self.fft_normalization,
                      spatial_dims=self.spatial_dims),
                complex_conj(sensitivity_maps),
            ),
            self.coil_dim,
        )
        image = self.model(init_pred, y, sensitivity_maps, mask)
        image = torch.sum(complex_mul(image, complex_conj(sensitivity_maps)),
                          self.coil_dim)
        image = torch.view_as_complex(image)
        _, image = center_crop_to_smallest(target, image)
        return image
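Several examples end with `center_crop_to_smallest(target, image)`. A minimal sketch consistent with how it is used here, i.e. crop both tensors to the smaller of their trailing spatial shapes and return the pair (the mridc implementation may differ in details):

def center_crop(x, shape):
    """Center-crop the last two dimensions of x to (height, width)."""
    h, w = shape
    top = (x.shape[-2] - h) // 2
    left = (x.shape[-1] - w) // 2
    return x[..., top:top + h, left:left + w]

def center_crop_to_smallest(a, b):
    """Crop both tensors to the smaller of their trailing spatial shapes."""
    shape = (min(a.shape[-2], b.shape[-2]), min(a.shape[-1], b.shape[-1]))
    return center_crop(a, shape), center_crop(b, shape)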
Example #3
 def complexDot(data1, data2):
     """Complex dot product of two tensors."""
     nBatch = data1.shape[0]
     mult = complex_mul(data1, complex_conj(data2))
     re, im = torch.unbind(mult, dim=-1)
     return torch.stack([
         torch.sum(re.view(nBatch, -1), dim=-1),
         torch.sum(im.view(nBatch, -1), dim=-1)
     ], -1)
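`complexDot` can be sanity-checked against PyTorch's native complex arithmetic; this sketch assumes `complex_mul`/`complex_conj` behave as outlined under Example #1:

a = torch.randn(3, 8, 8, 2)
b = torch.randn(3, 8, 8, 2)
out = complexDot(a, b)  # batched sum of a * conj(b), shape [3, 2]
ref = (torch.view_as_complex(a) * torch.view_as_complex(b).conj()).sum(dim=(-2, -1))
assert torch.allclose(out, torch.view_as_real(ref), atol=1e-5)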
Example #4
 def forward(self, coil_images: torch.Tensor,
             sensitivity_map: torch.Tensor) -> torch.Tensor:
     """Forward pass."""
     combined_image = complex_mul(
         coil_images, complex_conj(sensitivity_map)).sum(self.coil_dim)
     residual_image = combined_image.unsqueeze(self.coil_dim) - complex_mul(
         combined_image.unsqueeze(self.coil_dim), sensitivity_map)
     return torch.cat(
         [
             torch.cat(
                 [
                     combined_image,
                     residual_image.select(self.coil_dim, idx)
                 ],
                 self.channel_dim,
             ).unsqueeze(self.coil_dim)
             for idx in range(coil_images.size(self.coil_dim))
         ],
         self.coil_dim,
     )
Example #5
 def AT(x):
     return torch.sum(
         complex_mul(
             ifft2(x * mask,
                   centered=fft_centered,
                   normalization=fft_normalization,
                   spatial_dims=spatial_dims),
             complex_conj(smaps),
         ),
         dim=(-5),
     )
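`A` (Example #1) and `AT` form the SENSE forward/adjoint pair A = M F S and A^H = S^H F^{-1} M, so they should satisfy <A x, y> = <x, A^H y>. A hedged check with plain `torch.fft` standing in for the centered `fft2`/`ifft2` (orthonormal scaling and toy shapes assumed):

torch.manual_seed(0)
coils, nx, ny = 4, 16, 16
smaps = torch.randn(coils, nx, ny, dtype=torch.complex64)
mask = (torch.rand(1, nx, ny) > 0.5).to(torch.complex64)
x = torch.randn(nx, ny, dtype=torch.complex64)
y = torch.randn(coils, nx, ny, dtype=torch.complex64)
Ax = mask * torch.fft.fft2(smaps * x, norm="ortho")                    # A(x) = M F S x
ATy = (smaps.conj() * torch.fft.ifft2(mask * y, norm="ortho")).sum(0)  # S^H F^-1 M y
assert torch.allclose((Ax * y.conj()).sum(), (x * ATy.conj()).sum(), atol=1e-3)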
Example #6
 def _forward_operator(self, image, sampling_mask, sensitivity_map):
     """Forward operator."""
     return torch.where(
         sampling_mask == 0,
         torch.tensor([0.0], dtype=image.dtype).to(image.device),
         fft2(
             complex_mul(image.unsqueeze(self.coil_dim), sensitivity_map),
             centered=self.fft_centered,
             normalization=self.fft_normalization,
             spatial_dims=self.spatial_dims,
         ).type(image.type()),
     )
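For a binary mask, the `torch.where(sampling_mask == 0, ...)` pattern is equivalent to elementwise masking; a hypothetical one-off check:

mask = torch.randint(0, 2, (1, 1, 8, 8, 1))
kspace = torch.randn(2, 4, 8, 8, 2)
masked = torch.where(mask == 0, torch.tensor([0.0]), kspace)
assert torch.equal(masked, kspace * mask)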
Example #7
    def forward(self, x, y, smaps, mask):
        """
        Forward pass of the data-consistency block.

        Parameters
        ----------
        x: Input image.
        y: Subsampled k-space data.
        smaps: Coil sensitivity maps.
        mask: Sampling mask.

        Returns
        -------
        Output image.
        """
        A_x = torch.sum(
            fft2(
                complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ),
            -4,
            keepdim=True,
        )
        k_dc = (1 - mask) * A_x + mask * (self.alpha * A_x +
                                          (1 - self.alpha) * y)
        x_dc = torch.sum(
            complex_mul(
                ifft2(
                    k_dc,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                complex_conj(smaps),
            ),
            dim=(-5),
        )
        return self.beta * x + (1 - self.beta) * x_dc
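The interpolation `k_dc = (1 - M) * A_x + M * (alpha * A_x + (1 - alpha) * y)` keeps the prediction at unsampled locations and blends prediction and measurement at sampled ones; `alpha = 0` hard-replaces sampled entries with `y`. A toy illustration with hypothetical values:

M = torch.tensor([0.0, 1.0])    # second entry sampled
A_x = torch.tensor([3.0, 5.0])  # predicted k-space
y = torch.tensor([0.0, 7.0])    # measured k-space
alpha = 0.0
k_dc = (1 - M) * A_x + M * (alpha * A_x + (1 - alpha) * y)
assert torch.equal(k_dc, torch.tensor([3.0, 7.0]))  # hard DC at the sampled entry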
Example #8
 def _backward_operator(self, kspace, sampling_mask, sensitivity_map):
     """Backward operator."""
     kspace = torch.where(
         sampling_mask == 0,
         torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device), kspace)
     return (complex_mul(
         ifft2(
             kspace.float(),
             centered=self.fft_centered,
             normalization=self.fft_normalization,
             spatial_dims=self.spatial_dims,
         ),
         complex_conj(sensitivity_map),
     ).sum(self.coil_dim).type(kspace.type()))
Example #9
    def solve(x0, M, tol, max_iter):
        """Solve the linear system Mx=b using conjugate gradient."""
        nBatch = x0.shape[0]
        x = torch.zeros(x0.shape).to(x0.device)
        r = x0.clone()
        p = x0.clone()
        x0x0 = (x0.pow(2)).view(nBatch, -1).sum(-1)
        rr = torch.stack([(r.pow(2)).view(nBatch, -1).sum(-1),
                          torch.zeros(nBatch).to(x0.device)],
                         dim=-1)

        it = 0
        while torch.min(rr[..., 0] / x0x0) > tol and it < max_iter:
            it += 1
            q = M(p)

            data1 = rr
            data2 = ConjugateGradient.complexDot(p, q)

            re1, im1 = torch.unbind(data1, -1)
            re2, im2 = torch.unbind(data2, -1)
            alpha = torch.stack([re1 * re2 + im1 * im2, im1 * re2 - re1 * im2],
                                -1) / complex_abs(data2)**2

            x += complex_mul(alpha.reshape(nBatch, 1, 1, 1, -1), p.clone())
            r -= complex_mul(alpha.reshape(nBatch, 1, 1, 1, -1), q.clone())
            rr_new = torch.stack([(r.pow(2)).view(nBatch, -1).sum(-1),
                                  torch.zeros(nBatch).to(x0.device)],
                                 dim=-1)
            beta = torch.stack([
                rr_new[..., 0] / rr[..., 0],
                torch.zeros(nBatch).to(x0.device)
            ],
                               dim=-1)
            p = r.clone() + complex_mul(beta.reshape(nBatch, 1, 1, 1, -1), p)
            rr = rr_new.clone()
        return x
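Example #9 is a complex-valued conjugate gradient. For orientation, the same iteration on a real symmetric positive-definite system, as a self-contained sketch in plain torch (not the mridc implementation):

def cg(b, M, tol=1e-10, max_iter=100):
    """Solve M x = b for an SPD linear operator M by conjugate gradient."""
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rr = (r * r).sum()
    for _ in range(max_iter):
        q = M(p)
        alpha = rr / (p * q).sum()
        x = x + alpha * p
        r = r - alpha * q
        rr_new = (r * r).sum()
        if rr_new < tol:
            break
        p = r + (rr_new / rr) * p
        rr = rr_new
    return x

A = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
b = torch.tensor([1.0, 2.0])
x = cg(b, lambda v: A @ v)
assert torch.allclose(A @ x, b, atol=1e-5)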
Example #10
    def forward(self, x, y, smaps, mask):
        """
        Compute a gradient descent step on the data-consistency term.

        Parameters
        ----------
        x: Input image.
        y: Subsampled k-space data.
        smaps: Coil sensitivity maps.
        mask: Sampling mask.

        Returns
        -------
        data_loss: Data term loss.
        """
        A_x_y = (torch.sum(
            fft2(
                complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ) * mask,
            -4,
            keepdim=True,
        ) - y)
        gradD_x = torch.sum(
            complex_mul(
                ifft2(
                    A_x_y * mask,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                complex_conj(smaps),
            ),
            dim=(-5),
        )
        return x - self.data_weight * gradD_x
Example #11
    def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """
        Coil-combine the input k-space into an image using the sensitivity maps (SENSE reduction).

        Parameters
        ----------
        x: Input data.
        sens_maps: Coil Sensitivity maps.

        Returns
        -------
        SENSE coil-combined reconstruction.
        """
        x = ifft2(x, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)
        return complex_mul(x, complex_conj(sens_maps)).sum(self.coil_dim)
Example #12
    def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """
        Expand the image to per-coil k-space using the sensitivity maps (SENSE expansion).

        Parameters
        ----------
        x: Input data.
        sens_maps: Coil Sensitivity maps.

        Returns
        -------
        SENSE reconstruction expanded to the same size as the input sens_maps.
        """
        return fft2(
            complex_mul(x, sens_maps),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
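`sens_expand` and `sens_reduce` are adjoint SENSE operators: when the maps are normalized so that sum_c |S_c|^2 = 1, reducing after expanding recovers the input, since sum_c S_c^* F^{-1}(F(S_c x)) = x. A sketch with plain `torch.fft` standing in for the centered transforms:

torch.manual_seed(0)
coils, nx, ny = 4, 16, 16
smaps = torch.randn(coils, nx, ny, dtype=torch.complex64)
smaps = smaps / smaps.abs().pow(2).sum(0, keepdim=True).sqrt()   # normalize maps
x = torch.randn(nx, ny, dtype=torch.complex64)
k = torch.fft.fft2(smaps * x, norm="ortho")                      # sens_expand
x_rt = (smaps.conj() * torch.fft.ifft2(k, norm="ortho")).sum(0)  # sens_reduce
assert torch.allclose(x_rt, x, atol=1e-4)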
Example #13
    def sens_reduce(self, x: torch.Tensor,
                    sens_maps: torch.Tensor) -> torch.Tensor:
        """
        Coil-combine the input k-space into an image using the sensitivity maps (SENSE reduction).

        Parameters
        ----------
        x: Input data.
            torch.Tensor, shape [batch_size, n_coils, height, width, 2]
        sens_maps: Sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, height, width, 2]

        Returns
        -------
        SENSE reconstruction.
            torch.Tensor, shape [batch_size, height, width, 2]
        """
        x = ifft2(x,
                  centered=self.fft_centered,
                  normalization=self.fft_normalization,
                  spatial_dims=self.spatial_dims)
        return complex_mul(x, complex_conj(sens_maps)).sum(dim=self.coil_dim,
                                                           keepdim=True)
Example #14
0
    def sens_expand(self, x: torch.Tensor,
                    sens_maps: torch.Tensor) -> torch.Tensor:
        """
        Expand the image to per-coil k-space using the sensitivity maps (SENSE expansion).

        Parameters
        ----------
        x: Input data.
            torch.Tensor, shape [batch_size, n_coils, height, width, 2]
        sens_maps: Sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, height, width, 2]

        Returns
        -------
        SENSE reconstruction expanded to the same size as the input.
            torch.Tensor, shape [batch_size, n_coils, height, width, 2]
        """
        return fft2(
            complex_mul(x, sens_maps),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
Example #15
    def update_C(self, idx, DC_sens, sensitivity_maps, image, y, mask) -> torch.Tensor:
        """
        Update the coil sensitivity maps.

        .. math::
            C_{k+1} = (1 - 2 \lambda_{k}^{C} \nu_{k}) C_{k}
                + 2 \lambda_{k}^{C} \nu_{k} D_{C}(F^{-1}(b))
                - 2 \nu_{k} F^{-1}(M^{T} (M F (C_{k} x_{k}) - b)) x_{k}^{*},
            \quad \text{where} \quad A(x_{k}) = M F (C_{k} x_{k}).

        Parameters
        ----------
        idx: int
            The current iteration index.
        DC_sens: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]
            The initial coil sensitivity maps.
        sensitivity_maps: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]
            The coil sensitivity maps.
        image: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The predicted image.
        y: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The subsampled k-space data.
        mask: torch.Tensor [batch_size, 1, num_rows, num_cols]
            The subsampled mask.

        Returns
        -------
        sensitivity_maps: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]
            The updated coil sensitivity maps.
        """
        # (1 - 2 * lambda_{k}^{C} * nu_{k}) * C_{k}
        sense_term_1 = (1 - 2 * self.reg_param_C[idx] * self.lr_sens[idx]) * sensitivity_maps
        # 2 * lambda_{k}^{C} * nu_{k} * D_{C}(F^{-1}(b))
        sense_term_2 = 2 * self.reg_param_C[idx] * self.lr_sens[idx] * DC_sens
        # A(x_{k}) = M * F * (C * x_{k})
        sense_term_3_A = fft2(
            complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        sense_term_3_A = torch.where(mask == 0, torch.tensor([0.0], dtype=y.dtype).to(y.device), sense_term_3_A)
        # 2 * nu_{k} * F^{-1}(M^T * (M * F * (C * x_{k}) - b)) * x_{k}^*
        sense_term_3_mask = torch.where(
            mask == 1,
            torch.tensor([0.0], dtype=y.dtype).to(y.device),
            sense_term_3_A - y,
        )

        sense_term_3_backward = ifft2(
            sense_term_3_mask,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        sense_term_3 = 2 * self.lr_sens[idx] * sense_term_3_backward * complex_conj(image).unsqueeze(self.coil_dim)
        sensitivity_maps = sense_term_1 + sense_term_2 - sense_term_3
        return sensitivity_maps
Example #16
File: lpd.py Project: wdika/mridc
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]
             If self.accumulate_loss is True, returns a list of all intermediate estimates.
             If False, returns the final estimate.
        """
        input_image = complex_mul(
            ifft2(
                torch.where(mask == 0,
                            torch.tensor([0.0], dtype=y.dtype).to(y.device),
                            y),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ),
            complex_conj(sensitivity_maps),
        ).sum(self.coil_dim)
        dual_buffer = torch.cat([y] * self.num_dual, -1).to(y.device)
        primal_buffer = torch.cat([input_image] * self.num_primal,
                                  -1).to(y.device)

        for idx in range(self.num_iter):
            # Dual
            f_2 = primal_buffer[..., 2:4].clone()
            f_2 = torch.where(
                mask == 0,
                torch.tensor([0.0], dtype=f_2.dtype).to(f_2.device),
                fft2(
                    complex_mul(f_2.unsqueeze(self.coil_dim),
                                sensitivity_maps),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).type(f_2.type()),
            )
            dual_buffer = self.dual_net[idx](dual_buffer, f_2, y)

            # Primal
            h_1 = dual_buffer[..., 0:2].clone()
            h_1 = complex_mul(
                ifft2(
                    torch.where(
                        mask == 0,
                        torch.tensor([0.0], dtype=h_1.dtype).to(h_1.device),
                        h_1),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                complex_conj(sensitivity_maps),
            ).sum(self.coil_dim)
            primal_buffer = self.primal_net[idx](primal_buffer, h_1)

        output = primal_buffer[..., 0:2]
        output = (output**2).sum(-1).sqrt()
        _, output = center_crop_to_smallest(target, output)
        return output
Example #17
File: rvn.py Project: wdika/mridc
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]
             If self.accumulate_loss is True, returns a list of all intermediate estimates.
             If False, returns the final estimate.
        """
        previous_state: Optional[torch.Tensor] = None

        if self.initializer is not None:
            if self.initializer_initialization == "sense":
                initializer_input_image = (complex_mul(
                    ifft2(
                        y,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sensitivity_maps),
                ).sum(self.coil_dim).unsqueeze(self.coil_dim))
            elif self.initializer_initialization == "input_image":
                if "initial_image" not in kwargs:
                    raise ValueError(
                        "`initial_image` is required as input if initializer_initialization "
                        f"is {self.initializer_initialization}.")
                initializer_input_image = kwargs["initial_image"].unsqueeze(
                    self.coil_dim)
            elif self.initializer_initialization == "zero_filled":
                initializer_input_image = ifft2(
                    y,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                )

            previous_state = self.initializer(
                fft2(
                    initializer_input_image,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).sum(1).permute(0, 3, 1, 2))

        kspace_prediction = y.clone()

        for step in range(self.num_steps):
            block = self.block_list[
                step] if self.no_parameter_sharing else self.block_list[0]
            kspace_prediction, previous_state = block(
                kspace_prediction,
                y,
                mask,
                sensitivity_maps,
                previous_state,
            )

        eta = ifft2(
            kspace_prediction,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        eta = coil_combination(eta,
                               sensitivity_maps,
                               method=self.coil_combination_method,
                               dim=self.coil_dim)
        eta = torch.view_as_complex(eta)
        _, eta = center_crop_to_smallest(target, eta)
        return eta
Example #18
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data to compute the loss.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]
             If self.accumulate_loss is True, returns a list of all intermediate estimates.
             If False, returns the final estimate.
        """
        kspace = y.clone()
        zero = torch.zeros(1, 1, 1, 1, 1).to(kspace)

        for idx in range(self.num_iter):
            soft_dc = torch.where(mask.bool(), kspace - y, zero) * self.dc_weight

            kspace = self.kspace_model_list[idx](kspace)
            if kspace.shape[-1] != 2:
                kspace = kspace.permute(0, 1, 3, 4, 2).to(target)
                kspace = torch.view_as_real(kspace[..., 0] + 1j * kspace[..., 1])  # rebuild a contiguous real/imag last-dim tensor after the permute

            image = complex_mul(
                ifft2(
                    kspace,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                complex_conj(sensitivity_maps),
            ).sum(self.coil_dim)
            image = self.image_model_list[idx](image.unsqueeze(self.coil_dim)).squeeze(self.coil_dim)

            if not self.no_dc:
                image = fft2(
                    complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).type(image.type())
                image = kspace - soft_dc - image
                image = complex_mul(
                    ifft2(
                        image,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sensitivity_maps),
                ).sum(self.coil_dim)

            if idx < self.num_iter - 1:
                kspace = fft2(
                    complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).type(image.type())

        image = torch.view_as_complex(image)
        _, image = center_crop_to_smallest(target, image)
        return image
Example #19
    def forward(
        self,
        current_kspace: torch.Tensor,
        masked_kspace: torch.Tensor,
        sampling_mask: torch.Tensor,
        sensitivity_map: torch.Tensor,
        hidden_state: Union[None, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes forward pass of RecurrentVarNetBlock.

        Parameters
        ----------
        current_kspace: Current k-space prediction.
            torch.Tensor, shape [batch_size, n_coil, height, width, 2]
        masked_kspace: Subsampled k-space.
            torch.Tensor, shape [batch_size, n_coil, height, width, 2]
        sampling_mask: Sampling mask.
            torch.Tensor, shape [batch_size, 1, height, width, 1]
        sensitivity_map: Coil sensitivities.
            torch.Tensor, shape [batch_size, n_coil, height, width, 2]
        hidden_state: ConvGRU hidden state.
            None or torch.Tensor, shape [batch_size, n_l, height, width, hidden_channels]

        Returns
        -------
        new_kspace: New k-space prediction.
            torch.Tensor, shape [batch_size, n_coil, height, width, 2]
        hidden_state: Next hidden state.
            list of torch.Tensor, shape [batch_size, hidden_channels, height, width, num_layers]
        """
        kspace_error = torch.where(
            sampling_mask == 0,
            torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
            current_kspace - masked_kspace,
        )

        recurrent_term = torch.cat(
            [
                complex_mul(
                    ifft2(
                        kspace,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sensitivity_map),
                ).sum(self.coil_dim)
                for kspace in torch.split(current_kspace, 2, -1)
            ],
            dim=-1,
        ).permute(0, 3, 1, 2)

        recurrent_term, hidden_state = self.regularizer(recurrent_term, hidden_state)  # :math:`w_t`, :math:`h_{t+1}`
        recurrent_term = recurrent_term.permute(0, 2, 3, 1)

        recurrent_term = torch.cat(
            [
                fft2(
                    complex_mul(image.unsqueeze(self.coil_dim), sensitivity_map),
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                )
                for image in torch.split(recurrent_term, 2, -1)
            ],
            dim=-1,
        )

        new_kspace = current_kspace - self.learning_rate * kspace_error + recurrent_term

        return new_kspace, hidden_state
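The `torch.split(current_kspace, 2, -1)` calls treat the last dimension as stacked (real, imag) pairs, so each two-channel slice is processed as one complex tensor; a hypothetical illustration:

stacked = torch.randn(1, 4, 8, 8, 6)  # three complex tensors stacked along dim -1
parts = torch.split(stacked, 2, -1)   # three slices of shape [1, 4, 8, 8, 2]
assert len(parts) == 3 and all(p.shape[-1] == 2 for p in parts)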
Example #20
    def backward(ctx, grad_x):
        """
        Backward pass of the conjugate gradient solver.

        Parameters
        ----------
        ctx: Context object.
        grad_x: Gradient of the output image.

        Returns
        -------
        grad_z: Gradient of the input image.
        """
        ATy, rhs, smaps, mask, lambdaa = ctx.saved_tensors

        def A(x):
            x = (fft2(
                complex_mul(x.expand_as(smaps), smaps),
                centered=ctx.fft_centered,
                normalization=ctx.fft_normalization,
                spatial_dims=ctx.spatial_dims,
            ) * mask)
            return torch.sum(x, dim=-4, keepdim=True)

        def AT(x):
            return torch.sum(
                complex_mul(
                    ifft2(
                        x * mask,
                        centered=ctx.fft_centered,
                        normalization=ctx.fft_normalization,
                        spatial_dims=ctx.spatial_dims,
                    ),
                    complex_conj(smaps),
                ),
                dim=(-5),
            )

        def M(p):
            return lambdaa * AT(A(p)) + p

        Qe = ConjugateGradient.solve(grad_x, M, ctx.tol, ctx.max_iter)
        QQe = ConjugateGradient.solve(Qe, M, ctx.tol, ctx.max_iter)

        grad_z = Qe

        grad_lambdaa = (complex_mul(
            ifft2(Qe,
                  centered=ctx.fft_centered,
                  normalization=ctx.fft_normalization,
                  spatial_dims=ctx.spatial_dims),
            complex_conj(ATy),
        ).sum() - complex_mul(
            ifft2(QQe,
                  centered=ctx.fft_centered,
                  normalization=ctx.fft_normalization,
                  spatial_dims=ctx.spatial_dims),
            complex_conj(rhs),
        ).sum())

        return grad_z, grad_lambdaa, None, None, None, None, None, None
Example #21
    def update_X(self, idx, image, sensitivity_maps, y, mask):
        """
        Update the image.

        .. math::
            x_{k+1} = (1 - 2 \lambda_{k}^{I} \mu_{k} - 2 \lambda_{k}^{F} \mu_{k}) x_{k}
                + 2 \mu_{k} (\lambda_{k}^{I} D_{I}(x_{k}) + \lambda_{k}^{F} F^{-1}(D_{F}(f)))
                - 2 \mu_{k} A^{*}(A(x_{k}) - b),
            \quad \text{where} \quad A(x_{k}) = M F (C x_{k}).

        Parameters
        ----------
        idx: int
            The current iteration index.
        image: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The predicted image.
        sensitivity_maps: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]
            The coil sensitivity maps.
        y: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The subsampled k-space data.
        mask: torch.Tensor [batch_size, 1, num_rows, num_cols]
            The subsampled mask.

        Returns
        -------
        image: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
            The updated image.
        """
        # (1 - 2 * lambda_{k}^{I} * mu_{k} - 2 * lambda_{k}^{F} * mu_{k}) * x_{k}
        image_term_1 = (
            1 - 2 * self.reg_param_I[idx] * self.lr_image[idx] - 2 * self.reg_param_F[idx] * self.lr_image[idx]
        ) * image
        # D_I(x_{k})
        image_term_2_DI = self.image_model(image.unsqueeze(self.coil_dim)).squeeze(self.coil_dim).contiguous()
        image_term_2_DF = ifft2(
            self.kspace_model(
                fft2(
                    image,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ).unsqueeze(self.coil_dim)
            )
            .squeeze(self.coil_dim)
            .contiguous(),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        # 2 * mu_{k} * (lambda_{k}^{I} * D_I(x_{k}) + lambda_{k}^{F} * F^{-1}(D_F(f)))
        image_term_2 = (
            2
            * self.lr_image[idx]
            * (self.reg_param_I[idx] * image_term_2_DI + self.reg_param_F[idx] * image_term_2_DF)
        )
        # A(x_{k}) - b = M * F * (C * x_{k}) - b
        image_term_3_A = fft2(
            complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        image_term_3_A = torch.where(mask == 0, torch.tensor([0.0], dtype=y.dtype).to(y.device), image_term_3_A) - y
        # 2 * mu_{k} * A^* (A(x_{k}) - b)
        image_term_3_Aconj = complex_mul(
            ifft2(
                image_term_3_A,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ),
            complex_conj(sensitivity_maps),
        ).sum(self.coil_dim)
        image_term_3 = 2 * self.lr_image[idx] * image_term_3_Aconj
        image = image_term_1 + image_term_2 - image_term_3
        return image
Example #22
    def forward(
        self,
        pred: torch.Tensor,
        masked_kspace: torch.Tensor,
        sense: torch.Tensor,
        mask: torch.Tensor,
        eta: torch.Tensor = None,
        hx: torch.Tensor = None,
        sigma: float = 1.0,
        keep_eta: bool = False,
    ) -> Tuple[Any, Union[list, torch.Tensor, None]]:
        """
        Forward pass of the RIMBlock.

        Parameters
        ----------
        pred: Predicted k-space.
        masked_kspace: Subsampled k-space.
        sense: Coil sensitivity maps.
        mask: Sample mask.
        eta: Initial guess for the eta.
        hx: Initial guess for the hidden state.
        sigma: Noise level.
        keep_eta: Whether to keep the eta.

        Returns
        -------
        Reconstructed image and hidden states.
        """
        if hx is None:
            hx = [
                masked_kspace.new_zeros(
                    (masked_kspace.size(0), f, *masked_kspace.size()[2:-1]))
                for f in self.recurrent_filters if f != 0
            ]

        if isinstance(pred, list):
            pred = pred[-1].detach()

        if eta is None or eta.ndim < 3:
            eta = (pred if keep_eta else torch.sum(
                complex_mul(
                    ifft2(
                        pred,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sense),
                ),
                self.coil_dim,
            ))

        etas = []
        for _ in range(self.time_steps):
            grad_eta = log_likelihood_gradient(
                eta,
                masked_kspace,
                sense,
                mask,
                sigma=sigma,
                fft_centered=self.fft_centered,
                fft_normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
                coil_dim=self.coil_dim,
            ).contiguous()

            for h, convrnn in enumerate(self.layers):
                hx[h] = convrnn(grad_eta, hx[h])
                grad_eta = hx[h]

            eta = eta + self.final_layer(grad_eta).permute(0, 2, 3, 1)
            etas.append(eta)

        eta = etas

        if self.no_dc:
            return eta, None

        soft_dc = torch.where(mask, pred - masked_kspace,
                              self.zero.to(masked_kspace)) * self.dc_weight
        current_kspace = [
            masked_kspace - soft_dc - fft2(
                complex_mul(e.unsqueeze(self.coil_dim), sense),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ) for e in eta
        ]

        return current_kspace, None
Example #23
def evaluate(
    arguments,
    reconstruction_key,
    mask_background,
    output_path,
    method,
    acc,
    no_params,
    slice_start,
    slice_end,
    coil_dim,
):
    """
    Evaluates the reconstructions.

    Parameters
    ----------
    arguments: The CLI arguments.
    reconstruction_key: The key of the reconstruction to evaluate.
    mask_background: The background mask.
    output_path: The output path.
    method: The reconstruction method.
    acc: The acceleration factor.
    no_params: The number of parameters.
    slice_start: The start slice. (optional)
    slice_end: The end slice. (optional)
    coil_dim: The coil dimension. (optional)

    Returns
    -------
    dict: A dict where the keys are metric names and the values are the mean of the metric.
    """
    _metrics = Metrics(METRIC_FUNCS, output_path,
                       method) if arguments.type == "mean_std" else {}

    for tgt_file in tqdm(arguments.target_path.iterdir()):
        if exists(arguments.predictions_path / tgt_file.name):
            with h5py.File(tgt_file, "r") as target, h5py.File(
                    arguments.predictions_path / tgt_file.name, "r") as recons:
                kspace = target["kspace"][()]

                if arguments.sense_path is not None:
                    sense = h5py.File(arguments.sense_path / tgt_file.name,
                                      "r")["sensitivity_map"][()]
                elif "sensitivity_map" in target:
                    sense = target["sensitivity_map"][()]

                sense = sense.squeeze().astype(np.complex64)

                if sense.shape != kspace.shape:
                    sense = np.transpose(sense, (0, 3, 1, 2))

                target = np.abs(
                    tensor_to_complex_np(
                        torch.sum(
                            complex_mul(
                                ifft2(to_tensor(kspace),
                                      centered="fastmri"
                                      in str(arguments.sense_path).lower()),
                                complex_conj(to_tensor(sense)),
                            ),
                            coil_dim,
                        )))

                recons = recons[reconstruction_key][()]

                if recons.ndim == 4:
                    recons = recons.squeeze(coil_dim)

                if arguments.crop_size is not None:
                    crop_size = arguments.crop_size
                    crop_size[0] = min(target.shape[-2], int(crop_size[0]))
                    crop_size[1] = min(target.shape[-1], int(crop_size[1]))
                    crop_size[0] = min(recons.shape[-2], int(crop_size[0]))
                    crop_size[1] = min(recons.shape[-1], int(crop_size[1]))

                    target = center_crop(target, crop_size)
                    recons = center_crop(recons, crop_size)

                if mask_background:
                    for sl in range(target.shape[0]):
                        mask = convex_hull_image(
                            np.where(
                                np.abs(target[sl]) > threshold_otsu(
                                    np.abs(target[sl])), 1, 0)  # type: ignore
                        )
                        target[sl] = target[sl] * mask
                        recons[sl] = recons[sl] * mask

                if slice_start is not None:
                    target = target[slice_start:]
                    recons = recons[slice_start:]

                if slice_end is not None:
                    target = target[:slice_end]
                    recons = recons[:slice_end]

                for sl in range(target.shape[0]):
                    target[sl] = target[sl] / np.max(np.abs(target[sl]))
                    recons[sl] = recons[sl] / np.max(np.abs(recons[sl]))

                target = np.abs(target)
                recons = np.abs(recons)

                if arguments.type == "mean_std":
                    _metrics.push(target, recons)
                else:
                    _target = np.expand_dims(target, coil_dim)
                    _recons = np.expand_dims(recons, coil_dim)
                    for sl in range(target.shape[0]):
                        _metrics["FNAME"] = tgt_file.name
                        _metrics["SLICE"] = sl
                        _metrics["ACC"] = acc
                        _metrics["METHOD"] = method
                        _metrics["MSE"] = [mse(target[sl], recons[sl])]
                        _metrics["NMSE"] = [nmse(target[sl], recons[sl])]
                        _metrics["PSNR"] = [psnr(target[sl], recons[sl])]
                        _metrics["SSIM"] = [ssim(_target[sl], _recons[sl])]
                        _metrics["PARAMS"] = no_params

                        if not exists(arguments.output_path):
                            pd.DataFrame(columns=_metrics.keys()).to_csv(
                                arguments.output_path, index=False, mode="w")
                        pd.DataFrame(_metrics).to_csv(arguments.output_path,
                                                      index=False,
                                                      header=False,
                                                      mode="a")

    return _metrics