def _load_data(self, idx):
    """Load item(s) ``idx`` and produce the measurements for them.

    In inverse-crime mode only images, coil maps and masks are loaded and
    ``self._sim_data`` synthesizes k-space from them; otherwise measured
    k-space is loaded as well and handed to ``self._sim_data``.

    Args:
        idx: index forwarded unchanged to ``load_data`` / ``load_data_ksp``.

    Returns:
        Tuple ``(imgs, maps, masks, out)``. For Cartesian data
        (``not self.noncart``) ``maps`` is returned after ``fftmod``;
        ``out`` is whatever ``self._sim_data`` returns.

    Raises:
        AssertionError: if ``self.inverse_crime`` and ``self.noncart`` are
            both set (NUFFT forward simulation is not implemented).
    """
    if self.inverse_crime:
        imgs, maps, masks = load_data(idx, self.data_file, self.gen_masks)
        # No measured k-space here: _sim_data synthesizes it from imgs/maps.
        ksp = None
    else:
        imgs, maps, masks, ksp = load_data_ksp(idx, self.data_file, self.gen_masks)

    if self.scale_data:
        # One scale per item: the 99th-percentile magnitude over the last two
        # axes, so sc.shape == imgs.shape[:-2]. Trailing singleton axes are
        # appended so sc broadcasts item-wise against imgs and ksp rather than
        # against their last axis (previously a batch-mode FIXME).
        # NOTE: an all-zero item yields sc == 0 and therefore NaN values.
        sc = np.asarray(np.percentile(abs(imgs), 99, axis=(-1, -2)))
        imgs = imgs / sc.reshape(sc.shape + (1,) * (imgs.ndim - sc.ndim))
        if ksp is not None:
            # assumes ksp shares the leading (batch) axes of imgs — TODO confirm
            ksp = ksp / sc.reshape(sc.shape + (1,) * (ksp.ndim - sc.ndim))

    if self.fully_sampled:
        masks = np.ones(masks.shape)

    if self.inverse_crime:
        assert not self.noncart, 'FIXME: forward sim of NUFFT'
        out = self._sim_data(imgs, maps, masks)
    else:
        out = self._sim_data(imgs, maps, masks, ksp)

    if not self.noncart:
        # Done after simulation, so _sim_data receives the maps as loaded.
        maps = fftmod(maps)
    return imgs, maps, masks, out
def _sim_data(self, imgs, maps, masks, ksp=None):
    """Return noisy, optionally coil-combined, measurements.

    Complex Gaussian noise (unit-variance real and imaginary parts, scaled by
    ``self.stdev / sqrt(2)``) is added to one of two inputs:

    * k-space synthesized as ``fft2uc(imgs[:, None] * maps)``, when
      ``self.inverse_crime`` is set and no ``ksp`` is supplied;
    * the supplied ``ksp`` otherwise.

    For Cartesian data the noisy k-space is multiplied by
    ``masks[:, None]``.

    If ``self.adjoint`` is set, the data is coil-combined as
    ``sum(conj(maps) * ifft2uc(out), axis=1)`` and size-1 axes are squeezed
    away. Otherwise Cartesian output is passed through ``fftmod``.

    Layout per the original note: N, nc, nx, ny. The indexing requires imgs
    and masks to be 3-D.

    Raises:
        AssertionError: for non-Cartesian data without ``ksp``, or
            non-Cartesian data combined with ``self.adjoint``.
    """
    # Noise follows the k-space layout: the supplied ksp for non-Cartesian
    # data, the coil-map shape otherwise.
    if self.noncart:
        assert ksp is not None, 'FIXME: NUFFT forward sim'
        shape = ksp.shape
    else:
        shape = maps.shape
    noise = np.random.randn(*shape) + 1j * np.random.randn(*shape)
    scaled_noise = 1 / np.sqrt(2) * self.stdev * noise

    if self.inverse_crime and ksp is None:
        # Synthesize multi-coil k-space from the images and coil maps.
        measured = fft2uc(imgs[:, None, :, :] * maps) + scaled_noise
        out = masks[:, None, :, :] * measured
    elif self.noncart:
        out = ksp + scaled_noise
    else:
        out = masks[:, None, :, :] * (ksp + scaled_noise)

    if not self.adjoint:
        return out if self.noncart else fftmod(out)

    assert not self.noncart, 'FIXME: support NUFFT sim'
    return np.sum(np.conj(maps) * ifft2uc(out), axis=1).squeeze()