Пример #1
0
    def __getitem__(self, idx):
        """Load the idx-th slice of the dataset (plus symmetric neighbours).

        The global slice index ``idx`` is mapped to a (volume, slice)
        pair, a window of ``2 * num_sym_slices + 1`` k-space slices is
        read around it (zero-padded at volume boundaries), a ground
        truth is simulated or loaded, and the pair is run through
        ``self._process_data`` and the optional ``self.transform``.

        Parameters
        ----------
        idx : int
            Global slice index across all volumes in ``self.df``.

        Returns
        -------
        The output of ``self._process_data`` (optionally transformed).
        """
        # cumulative index of the LAST slice of each volume; searchsorted
        # on this maps a global slice index to the volume that contains it
        vols = self.df["kspace_shape"].apply(lambda s: s[0]).cumsum() - 1

        # get index of volume and slice within volume
        vol_idx = vols.searchsorted(idx)
        sl_idx = idx if vol_idx == 0 else idx - (vols.iloc[vol_idx - 1] + 1)

        # symmetric slice window for multi-slice mode (may run past the
        # volume boundaries; handled by clamping + padding below)
        sl_from = sl_idx - self.num_sym_slices
        sl_to = sl_idx + self.num_sym_slices + 1

        fname = self.df["fname"].iloc[vol_idx]
        # use a context manager so the HDF5 handle is always closed
        # (the original opened the file without ever closing it)
        with h5py.File(fname, "r") as data:
            sl_num = data["kspace"].shape[0]
            # read only the in-range part of the window ...
            kspace_vol = transforms.to_tensor(
                np.asarray(
                    data["kspace"][max(0, sl_from) : min(sl_to, sl_num), ...]
                )
            )
            # ... and zero-pad along the slice axis to the full window size
            kspace_vol_padded = torch.nn.functional.pad(
                kspace_vol,
                (0, 0, 0, 0, 0, 0, max(0, -sl_from), max(0, sl_to - sl_num)),
            )

            if self.simulate_gt:
                if self.multi_slice_gt:
                    gt = transforms.ifft2(kspace_vol_padded)
                else:
                    # only the central slice of the window
                    nss = self.num_sym_slices
                    gt = transforms.ifft2(
                        kspace_vol_padded[nss : nss + 1, ...]
                    )
            else:
                if self.multi_slice_gt:
                    raise NotImplementedError(
                        "Multi slice currently only works for simulated targets."
                    )
                # fall back to None when the file ships no reference
                # reconstruction (e.g. test data)
                gt = (
                    transforms.to_tensor(
                        np.asarray(
                            data["reconstruction_esc"][sl_idx : sl_idx + 1, ...]
                        )
                    )
                    if "reconstruction_esc" in data
                    else None
                )

        out = self._process_data(kspace_vol_padded, gt)

        return self.transform(out) if self.transform is not None else out
Пример #2
0
    def tikh(self, rhs, kernel, rho):
        """Tikhonov-regularized inversion via FFT diagonalization.

        Solves the normal equation

            (F*F + rho W*W) x = z

        for a Tikhonov-regularized least squares fit, assuming the
        regularizer can be diagonalized by FFTs, i.e.

            W*W = F*D*F

        for some diagonal matrix D, so the whole operator is a
        pointwise (complex) division in Fourier space.

        Parameters
        ----------
        rhs : torch.Tensor
            The right hand side tensor z, often F*y for some y.
        kernel : torch.Tensor
            The Fourier kernel of W, containing the diagonal elements D.
        rho : float
            The regularization parameter.

        """
        # complex images carry a real/imag pair in dimension -3
        assert rhs.ndim >= 3 and rhs.shape[-3] == 2
        dev = rhs.device
        # diagonal of (F*F + rho W*W) in Fourier space
        mask_diag = prep_fft_channel(to_complex(self.mask.unsqueeze(0).to(dev)))
        denominator = mask_diag + rho * kernel.to(dev)
        numerator = transforms.fft2(prep_fft_channel(rhs))
        quotient = div_complex(numerator, denominator)
        return unprep_fft_channel(transforms.ifft2(quotient))
Пример #3
0
 def adj(self, y):
     """Adjoint operator: the zero-filled inverse Fourier transform.

     The measured coefficients ``y`` are scattered back to their mask
     positions in a zero-initialized flat k-space before inverting.
     """
     flat_kspace = torch.zeros(
         *y.shape[:-1], self.n[0] * self.n[1], device=y.device
     )
     # place measurements at the sampled frequency locations
     flat_kspace[..., im2vec(self.mask)] = y
     image = transforms.ifft2(prep_fft_channel(vec2im(flat_kspace, self.n)))
     return unprep_fft_channel(image)
Пример #4
0
 def __call__(self, inputs):
     """Scale all image-like inputs by the norm of a reference tensor.

     The reference is either the target (last element of ``inputs``) or
     the zero-filled reconstruction of the k-space data (first element),
     depending on ``self.use_target``.
     """
     if self.use_target:
         ref = inputs[-1]
     else:
         ref = unprep_fft_channel(
             transforms.ifft2(prep_fft_channel(inputs[0]))
         )
     norm = torch.norm(ref, p=self.p)
     # for the mean reduction, divide by the p-th root of the element
     # count (not defined for the infinity norm)
     if self.reduction == "mean" and self.p != "inf":
         norm = norm / np.prod(ref.shape) ** (1 / self.p)
     if len(inputs) == 2:
         return inputs[0] / norm, inputs[1] / norm
     return inputs[0] / norm, inputs[1], inputs[2] / norm
Пример #5
0
 def __call__(self, inputs):
     """Return the zero-filled reconstruction of the k-space input
     together with the unchanged target (the mask is discarded)."""
     kspace, _unused_mask, target = inputs
     reconstruction = transforms.ifft2(prep_fft_channel(kspace))
     return unprep_fft_channel(reconstruction), target