def test_sense_tseg_off_res_model(self):
    """Check the time-segmented off-resonance Sense model.

    The operator built with ``tseg`` is compared against an explicit
    ``B1 * F * S * Ct1`` composition built from
    ``sp.mri.util.tseg_off_res_b_ct``, and its adjoint is checked.

    Explicit ``np.complex128`` / ``np.float64`` dtypes are used because the
    ``np.complex`` / ``np.float`` aliases were removed in NumPy 1.24.
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]
    img = sp.randn(img_shape, dtype=np.complex128)
    mps = sp.randn(mps_shape, dtype=np.complex128)

    y, x = np.mgrid[:16, :16]
    coord = np.stack([np.ravel(y - 8), np.ravel(x - 8)], axis=1)
    coord = coord.astype(np.float64)

    # Radial Gaussian B0 map (peak amplitude `a`) as a function of the
    # distance from grid index (0, 0).
    d = np.sqrt(x * x + y * y)
    sigma, mu, a = 2, 0.25, 400
    b0 = a * np.exp(-((d - mu) ** 2 / (2.0 * sigma ** 2)))

    # Share the segmentation parameters between the operator under test and
    # the reference computation so the two cannot drift apart.
    dt, lseg, n_bins = 4e-6, 1, 10
    tseg = {"b0": b0, "dt": dt, "lseg": lseg, "n_bins": n_bins}

    F = sp.linop.NUFFT(mps_shape, coord)
    b, ct = sp.mri.util.tseg_off_res_b_ct(
        b0=b0, bins=n_bins, lseg=lseg, dt=dt, T=coord.shape[0] * dt)
    B1 = sp.linop.Multiply(F.oshape, b.T)
    Ct1 = sp.linop.Multiply(img_shape, ct.reshape(img_shape))
    S = sp.linop.Multiply(img_shape, mps)

    A = linop.Sense(mps, coord=coord, tseg=tseg)

    check_linop_adjoint(A, dtype=np.complex128)
    npt.assert_allclose(B1 * F * S * Ct1 * img, A * img)
def test_kspace_precond_cart(self):
    """Compare ``kspace_precond`` on a masked Cartesian Sense model with a
    brute-force reference built from the explicit ``A A^H`` matrix.

    Only sampled k-space locations (``weights`` True) are compared.
    """
    nc = 4
    n = 10
    shape = (nc, n)
    mps = sp.randn(shape, dtype=np.complex128)
    mps /= np.linalg.norm(mps, axis=0, keepdims=True)
    # Random boolean sampling mask.
    weights = sp.randn([n]) >= 0
    A = sp.linop.Multiply(shape, weights**0.5) * linop.Sense(mps)

    # Build A A^H column by column by applying it to unit vectors.
    AAH = np.zeros((nc, n, nc, n), np.complex128)
    for d in range(nc):
        for j in range(n):
            x = np.zeros((nc, n), np.complex128)
            x[d, j] = 1.0
            AAH[:, :, d, j] = A(A.H(x))

    p_expected = np.ones((nc, n), np.complex128)
    for c in range(nc):
        for i in range(n):
            # Unsampled rows of A A^H are zero (including the diagonal), so
            # they are skipped to avoid dividing by zero.
            if weights[i]:
                p_expected_inv_ic = (np.sum(np.abs(AAH[c, i]) ** 2) /
                                     abs(AAH[c, i, c, i]))
                p_expected[c, i] = 1 / p_expected_inv_ic

    p = precond.kspace_precond(mps, weights=weights)
    npt.assert_allclose(p[:, weights], p_expected[:, weights])
def test_kspace_precond_noncart(self):
    """Compare ``kspace_precond`` on a 1D non-Cartesian Sense model with a
    brute-force reference built from the explicit ``A A^H`` matrix.

    The tolerance is loose (``atol=rtol=1e-2``) because the preconditioner
    is compared against a reference computed through the NUFFT.
    """
    n = 10
    nc = 3
    shape = [nc, n]
    mps = sp.randn(shape, dtype=np.complex128)
    mps /= np.linalg.norm(mps, axis=0, keepdims=True)
    coord = sp.randn([n, 1], dtype=np.float64)
    A = linop.Sense(mps, coord=coord)

    # Build A A^H column by column by applying it to unit vectors.
    AAH = np.zeros((nc, n, nc, n), np.complex128)
    for d in range(nc):
        for j in range(n):
            x = np.zeros(shape, np.complex128)
            x[d, j] = 1.0
            AAH[:, :, d, j] = A(A.H(x))

    p_expected = np.zeros([nc, n], np.complex128)
    for c in range(nc):
        for i in range(n):
            p_expected_inv_ic = (np.sum(np.abs(AAH[c, i]) ** 2) /
                                 abs(AAH[c, i, c, i]))
            p_expected[c, i] = 1 / p_expected_inv_ic

    p = precond.kspace_precond(mps, coord=coord)
    npt.assert_allclose(p, p_expected, atol=1e-2, rtol=1e-2)
def test_sense_model_batch(self):
    """Check a Cartesian Sense model with ``coil_batch_size=1``.

    The adjoint is checked, and the forward result must equal the FFT of the
    coil-weighted image over the last two axes.
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]
    img = sp.randn(img_shape, dtype=np.complex128)
    mps = sp.randn(mps_shape, dtype=np.complex128)

    A = linop.Sense(mps, coil_batch_size=1)

    check_linop_adjoint(A, dtype=np.complex128)
    npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]), A * img)
def check_linop_adjoint(A, dtype=np.float64, device=sp.cpu_device):
    """Assert that ``A.H`` is the adjoint of ``A`` using a random dot-product test.

    Random ``x`` (shape ``A.ishape``) and ``y`` (shape ``A.oshape``) are drawn,
    and ``vdot(A x, y)`` must match ``vdot(x, A.H y)`` within
    ``atol=rtol=1e-5``.

    Args:
        A: linear operator exposing ``ishape``, ``oshape`` and ``H``.
        dtype: dtype of the random test vectors. Defaults to ``np.float64``.
            The previous default, ``np.float``, was an alias of ``float``
            (equivalent to float64) and was removed in NumPy 1.24. Because it
            was evaluated at definition time, the module failed to import.
        device: device on which the vectors are created and the test is run.

    Raises:
        AssertionError: if the two inner products differ beyond tolerance.
    """
    device = sp.Device(device)
    x = sp.randn(A.ishape, dtype=dtype, device=device)
    y = sp.randn(A.oshape, dtype=dtype, device=device)

    xp = device.xp
    with device:
        lhs = xp.vdot(A * x, y)
        rhs = xp.vdot(x, A.H * y)

        xp.testing.assert_allclose(lhs, rhs, atol=1e-5, rtol=1e-5)
def test_noncart_sense_model_batch(self):
    """Check a non-Cartesian Sense model on a full integer grid with
    ``coil_batch_size=1``.

    The adjoint is checked, and the forward result is compared (loosely)
    with the Cartesian FFT of the coil-weighted image.
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]
    img = sp.randn(img_shape, dtype=np.complex128)
    mps = sp.randn(mps_shape, dtype=np.complex128)

    # Integer coordinates covering the full 16x16 grid, shifted by -8.
    y, x = np.mgrid[:16, :16]
    coord = np.stack([np.ravel(y - 8), np.ravel(x - 8)], axis=1)
    coord = coord.astype(np.float64)

    A = linop.Sense(mps, coord=coord, coil_batch_size=1)

    check_linop_adjoint(A, dtype=np.complex128)
    npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]).ravel(),
                        (A * img).ravel(), atol=0.1, rtol=0.1)
def test_sense_model_with_comm(self):
    """Check the adjoint of a coil-distributed Sense model.

    Each rank holds the coils ``rank::size``. The model's adjoint, reduced
    through ``comm``, must equal the coil-combined inverse FFT computed over
    all coils.
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]
    comm = sp.Communicator()

    img = sp.randn(img_shape, dtype=np.complex128)
    mps = sp.randn(mps_shape, dtype=np.complex128)
    comm.allreduce(img)
    comm.allreduce(mps)
    ksp = sp.fft(img * mps, axes=[-1, -2])

    A = linop.Sense(mps[comm.rank::comm.size], comm=comm)

    npt.assert_allclose(
        A.H(ksp[comm.rank::comm.size]),
        np.sum(sp.ifft(ksp, axes=[-1, -2]) * mps.conjugate(), 0))
def test_sense_model(self):
    """Check a Cartesian Sense model.

    The adjoint is checked, and the forward result must equal the FFT of the
    coil-weighted image over the last two axes.
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]
    img = sp.randn(img_shape, dtype=np.complex128)
    mps = sp.randn(mps_shape, dtype=np.complex128)

    A = linop.Sense(mps)

    check_linop_adjoint(A, dtype=np.complex128)
    npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]), A * img)
def _get_vars(self):
    """Allocate the working state used by the alternating updates.

    Sets ``t_idx`` (a shuffled batch-index source), an empty batch buffer
    ``y_t``, a random filter bank ``L`` normalised by its root-sum-of-squares
    over the selected axes, an empty ``L_old`` of the same shape, and the
    coefficient app ``R``. Arrays are created inside ``self.device``'s
    context.
    """
    self.t_idx = sp.ShuffledNumbers(self.num_batches)
    xp = self.device.xp
    with self.device:
        batch_shape = (self.batch_size, ) + self.y.shape[1:]
        self.y_t = xp.empty(batch_shape, dtype=self.dtype)

        self.L = sp.randn(self.L_shape, dtype=self.dtype,
                          device=self.device)
        # The trailing `data_ndim` axes are always reduced; in multi-channel
        # mode axis 0 is included as well.
        trailing_axes = tuple(range(-self.data_ndim, 0))
        norm_axes = ((0, ) + trailing_axes) if self.multi_channel \
            else trailing_axes
        sq_sum = xp.sum(xp.abs(self.L)**2, axis=norm_axes, keepdims=True)
        self.L /= sq_sum**0.5

        self.L_old = xp.empty(self.L_shape, dtype=self.dtype)

        self.R = ConvSparseCoefficients(
            self.y, self.L,
            lamda=self.lamda,
            multi_channel=self.multi_channel,
            mode=self.mode,
            max_iter=self.max_inner_iter,
            max_power_iter=self.max_power_iter)
def test_circulant_precond_cart(self):
    """Compare ``circulant_precond`` on a masked Cartesian Sense model with
    a reference computed as ``1 / diag(F A^H A F^H)``.

    Only sampled locations (``weights`` True) are compared.
    """
    nc = 4
    n = 10
    shape = (nc, n)
    mps = sp.randn(shape, dtype=np.complex128)
    mps /= np.linalg.norm(mps, axis=0, keepdims=True)
    # Random boolean sampling mask.
    weights = sp.randn([n]) >= 0
    A = sp.linop.Multiply(shape, weights**0.5) * linop.Sense(mps)

    F = sp.linop.FFT([n])
    p_expected = np.zeros(n, np.complex128)
    for i in range(n):
        if weights[i]:
            x = np.zeros(n, np.complex128)
            x[i] = 1.0
            p_expected[i] = 1 / F(A.H(A(F.H(x))))[i]

    p = precond.circulant_precond(mps, weights=weights)
    npt.assert_allclose(p[weights], p_expected[weights])
def _init_vars(self):
    """Initialise one random normalised ``L_j`` and one zero ``R_j`` per block.

    For each ``j`` in ``range(self.J)``:

    - ``L_j`` is random with shape ``self.B[j].ishape``, divided by its
      root-sum-of-squares over the last ``self.D`` axes.
    - ``R_j`` is zeros with shape ``(self.T, ) + L_j_norm.shape``.

    The results are stored in the lists ``self.L`` and ``self.R``.
    """
    self.L = []
    self.R = []
    # NumPy reductions accept only an int or a tuple of ints for ``axis``.
    # A bare ``range`` object raises TypeError, so materialise it as a tuple.
    norm_axes = tuple(range(-self.D, 0))
    for j in range(self.J):
        L_j_shape = self.B[j].ishape
        L_j = sp.randn(L_j_shape, dtype=self.dtype, device=self.device)
        L_j_norm = self.xp.sum(self.xp.abs(L_j)**2, axis=norm_axes,
                               keepdims=True)**0.5
        L_j /= L_j_norm

        R_j_shape = (self.T, ) + L_j_norm.shape
        R_j = self.xp.zeros(R_j_shape, dtype=self.dtype)

        self.L.append(L_j)
        self.R.append(R_j)
def test_circulant_precond_noncart(self):
    """Compare ``circulant_precond`` on a 1D non-Cartesian Sense model
    (uniform coil maps) with a reference computed as
    ``1 / diag(F A^H A F^H)``.

    The tolerance is loose (``atol=rtol=1e-1``).
    """
    nc = 4
    n = 10
    shape = [nc, n]
    mps = np.ones(shape, dtype=np.complex128)
    mps /= np.linalg.norm(mps, axis=0, keepdims=True)
    coord = sp.randn([n, 1], dtype=np.float64)
    A = linop.Sense(mps, coord=coord)

    F = sp.linop.FFT([n])
    p_expected = np.zeros(n, np.complex128)
    for i in range(n):
        x = np.zeros(n, np.complex128)
        x[i] = 1.0
        p_expected[i] = 1 / F(A.H(A(F.H(x))))[i]

    p = precond.circulant_precond(mps, coord=coord)
    npt.assert_allclose(p, p_expected, atol=1e-1, rtol=1e-1)
def test_spatial_explicit_model(self):
    """Adjoint test for ``linop.PtxSpatialExplicit`` on a 3x3x3 grid.

    The trajectory is a single-interleave ``sp.mri.spiral`` passed through
    ``rf.stack_of``. The coil maps are random complex64.
    """
    dim = 3
    img_shape = [dim] * 3
    mps_shape = [8] + img_shape
    dt = 4e-6

    traj = sp.mri.spiral(fov=dim / 2, N=dim, f_sampling=1, R=1,
                         ninterleaves=1, alpha=1, gm=0.03, sm=200)
    traj = rf.stack_of(traj, 3, 0.1)

    mps = sp.randn(mps_shape, dtype=np.complex64)
    model = linop.PtxSpatialExplicit(mps, traj, dt, img_shape)

    check_linop_adjoint(model, dtype=np.complex64)