Exemplo n.º 1
0
    def tikh(self, rhs, kernel, rho):
        """Solve a Tikhonov-regularized normal equation with FFTs.

        Computes the solution ``x`` of

            (F*F + rho W*W) x = z

        where ``z`` is often ``F*y`` for some measurement ``y``. The
        regularizer ``W*W`` is assumed to be diagonalized by the 2D Fourier
        transform ``G``, i.e.

            W*W = G* D G

        for some diagonal matrix ``D``. In the Fourier domain the system
        matrix is then the diagonal ``self.mask + rho * D``. The solution is
        therefore obtained by one ``fft2``, an elementwise complex division
        by that diagonal, and one ``ifft2``.

        Parameters
        ----------
        rhs : torch.Tensor
            The right hand side ``z``. It must have at least 3 dimensions,
            with size 2 in dimension -3 (real and imaginary channels).
        kernel : torch.Tensor
            Fourier-domain diagonal ``D`` of ``W*W``. It is moved to
            ``rhs``'s device. NOTE(review): it must broadcast against the
            ``prep_fft_channel``-formatted mask; confirm the expected shape.
        rho : float
            The regularization weight.

        Returns
        -------
        torch.Tensor
            The solution ``x``, converted back with ``unprep_fft_channel``.

        Raises
        ------
        AssertionError
            If ``rhs`` is not a two-channel complex image. This check only
            runs when assertions are enabled.
        """
        # rhs must store a complex image as two real channels at dim -3.
        assert rhs.ndim >= 3 and rhs.shape[-3] == 2  # assert complex images
        device = rhs.device
        rhs_freq = transforms.fft2(prep_fft_channel(rhs))
        # Diagonal of the system matrix in the Fourier domain: the sampling
        # mask (contribution of F*F) plus the weighted regularizer kernel D.
        mask_freq = prep_fft_channel(
            to_complex(self.mask.unsqueeze(0).to(device))
        )
        diagonal = mask_freq + rho * kernel.to(device)
        solution_freq = div_complex(rhs_freq, diagonal)
        return unprep_fft_channel(transforms.ifft2(solution_freq))
Exemplo n.º 2
0
    def __call__(self, inputs):
        """Resize the target to ``self.shape`` and resimulate its k-space.

        The incoming k-space is discarded. New measurements are generated
        from the zero-padded and center-cropped target by a 2D FFT. The
        result is then multiplied by the mask obtained from calling ``mask``
        with the k-space shape.

        Parameters
        ----------
        inputs : tuple
            A ``(kspace, mask, target)`` triple. ``kspace`` is ignored.
            ``mask`` is a callable that takes a shape and returns a tensor
            expandable (``expand_as``) to the k-space. ``target`` is an image
            tensor whose last two dimensions are rows and columns.

        Returns
        -------
        tuple
            ``(new_kspace, new_mask, target)``. ``target`` is padded/cropped
            to ``self.shape``. If its dimension -3 does not have size 2, it
            is reduced to the two channels of its middle channel pair.
        """
        _, sample_mask_fn, target = inputs

        # Zero-pad the last two dims up to self.shape where they are too
        # small. Odd amounts put the extra pixel on the right/bottom, since
        # -(-n // 2) is ceil(n / 2).
        pad_rows = max(0, self.shape[0] - target.shape[-2])
        pad_cols = max(0, self.shape[1] - target.shape[-1])
        padding = (
            pad_cols // 2, -(-pad_cols // 2),  # left, right
            pad_rows // 2, -(-pad_rows // 2),  # top, bottom
        )
        # Then crop any dimension that is still too large.
        resized = transforms.center_crop(
            torch.nn.functional.pad(target, padding), self.shape
        )

        # Resimulate the measurements from the resized target.
        kspace_full = transforms.fft2(prep_fft_channel(resized))
        sampling = sample_mask_fn(kspace_full.shape).expand_as(kspace_full)
        kspace_out = unprep_fft_channel(kspace_full * sampling)
        mask_out = unprep_fft_channel(sampling)

        n_channels = resized.shape[-3]
        if n_channels != 2:
            # Keep a single channel pair: the middle one of the
            # n_channels // 2 pairs. NOTE(review): presumably a stack of
            # (real, imag) slices where the central slice is wanted; confirm.
            first = ((n_channels // 2) // 2) * 2
            resized = resized[..., first:first + 2, :, :]

        return kspace_out, mask_out, resized
Exemplo n.º 3
0
 def dot(self, x):
     """Subsampled Fourier transform.

     Applies the 2D FFT to ``x`` and keeps only the coefficients selected
     by ``self.mask``.

     Parameters
     ----------
     x : torch.Tensor
         Input image(s) in the layout expected by ``prep_fft_channel``.
         NOTE(review): this is presumably real/imag channels at dim -3;
         confirm.

     Returns
     -------
     torch.Tensor
         The Fourier coefficients at the positions selected by
         ``self.mask``. They are taken along the last dimension produced by
         ``im2vec`` (presumably the flattened spatial grid).
     """
     full_fft = unprep_fft_channel(transforms.fft2(prep_fft_channel(x)))
     # Move the mask to x's device before indexing. A mask that lives on a
     # different device (e.g. a CUDA mask with a CPU input) would otherwise
     # make the indexing fail.
     sample_idx = im2vec(self.mask.to(x.device))
     return im2vec(full_fft)[..., sample_idx]