示例#1
0
    def _load_data(self, idx):
        """Load the example(s) at ``idx`` and simulate their measurements.

        Args:
            idx: index forwarded unchanged to ``load_data`` /
                ``load_data_ksp``.

        Returns:
            Tuple ``(imgs, maps, masks, out)`` where ``out`` is the result of
            ``self._sim_data``. For Cartesian data (``not self.noncart``),
            ``maps`` is passed through ``fftmod`` before being returned.

        Raises:
            AssertionError: if ``inverse_crime`` is combined with ``noncart``
                (forward NUFFT simulation is not implemented).
        """
        if self.inverse_crime:
            # No measured k-space is loaded; with ksp=None, _sim_data
            # forward-simulates k-space from imgs and maps.
            imgs, maps, masks = load_data(idx, self.data_file, self.gen_masks)
            ksp = None
        else:
            imgs, maps, masks, ksp = load_data_ksp(idx, self.data_file,
                                                   self.gen_masks)

        if self.scale_data:
            # Normalize by the 99th percentile of |imgs| over the last two
            # axes, computed separately for every leading (batch) index.
            # Trailing singleton axes are appended to the scale so it
            # broadcasts per example rather than along the data's last axis.
            sc = np.percentile(np.abs(imgs), 99, axis=(-1, -2))
            # Avoid producing NaN/inf for an all-zero image.
            sc = np.where(sc > 0, sc, 1.)
            imgs = imgs / sc[..., None, None]
            if ksp is not None:
                # NOTE(review): assumes ksp shares imgs' leading batch axes
                # (consistent with the masks[:, None] layout used by
                # _sim_data); confirm for non-Cartesian ksp.
                ksp = ksp / sc.reshape(sc.shape + (1,) * (ksp.ndim - sc.ndim))

        if self.fully_sampled:
            masks = np.ones(masks.shape)

        if self.inverse_crime:
            assert not self.noncart, 'FIXME: forward sim of NUFFT'
        # On the inverse-crime path ksp is None, which is the same as
        # omitting it (the _sim_data default).
        out = self._sim_data(imgs, maps, masks, ksp)

        if not self.noncart:
            maps = fftmod(maps)
        return imgs, maps, masks, out
示例#2
0
    def _sim_data(self, imgs, maps, masks, ksp=None):

        # N, nc, nx, ny
        if self.noncart:
            assert ksp is not None, 'FIXME: NUFFT forward sim'
            noise = np.random.randn(
                *ksp.shape) + 1j * np.random.randn(*ksp.shape)
        else:
            noise = np.random.randn(
                *maps.shape) + 1j * np.random.randn(*maps.shape)

        if self.inverse_crime and ksp is None:
            out = masks[:, None, :, :] * (fft2uc(imgs[:, None, :, :] * maps) +
                                          1 / np.sqrt(2) * self.stdev * noise)
        else:
            if self.noncart:
                out = ksp + 1 / np.sqrt(2) * self.stdev * noise
            else:
                out = masks[:, None, :, :] * (
                    ksp + 1 / np.sqrt(2) * self.stdev * noise)

        if self.adjoint:
            assert not self.noncart, 'FIXME: support NUFFT sim'
            out = np.sum(np.conj(maps) * ifft2uc(out), axis=1).squeeze()
        else:
            if not self.noncart:
                out = fftmod(out)

        return out