def test_kspace_precond_cart(self):
    """Check ``precond.kspace_precond`` on a masked Cartesian SENSE model.

    A dense reference for ``A A^H`` is assembled one column at a time. The
    expected preconditioner entry at (c, i) is
    ``1 / (sum_{d, j} |AAH[c, i, d, j]|**2 / |AAH[c, i, c, i]|)``.
    Only entries where ``weights == 1`` are compared.
    """
    nc = 4
    n = 10
    shape = (nc, n)
    # ``np.complex`` was removed in NumPy 1.24; use the explicit dtype.
    mps = sp.randn(shape, dtype=np.complex128)
    mps /= np.linalg.norm(mps, axis=0, keepdims=True)
    weights = sp.randn([n]) >= 0  # random boolean mask
    A = sp.linop.Multiply(shape, weights**0.5) * linop.Sense(mps)

    # Dense A A^H: column (d, j) is A(A^H(e_{d,j})).
    AAH = np.zeros((nc, n, nc, n), np.complex128)
    for d in range(nc):
        for j in range(n):
            x = np.zeros((nc, n), np.complex128)
            x[d, j] = 1.0
            AAH[:, :, d, j] = A(A.H(x))

    # Masked-out entries keep the placeholder 1 and are excluded below.
    p_expected = np.ones((nc, n), np.complex128)
    for c in range(nc):
        for i in range(n):
            if weights[i]:
                row_energy = np.sum(np.abs(AAH[c, i]) ** 2)
                p_expected[c, i] = 1 / (row_energy / abs(AAH[c, i, c, i]))

    p = precond.kspace_precond(mps, weights=weights)
    npt.assert_allclose(p[:, weights == 1], p_expected[:, weights == 1])
def __init__(self, y, mps, lamda, weights=None, coord=None,
             device=sp.cpu_device, coil_batch_size=None, **kwargs):
    """Build a SENSE model with an L1 penalty on ``sp.linop.Gradient``.

    Args:
        y: measurements. They are scaled by ``weights**0.5`` when weights
            are available, then moved to ``device``.
        mps: coil sensitivity maps passed to ``linop.Sense``.
        lamda: scale of the L1 penalty (used by both ``proxg`` and ``g``).
        weights: optional weights; passed through ``_estimate_weights``.
        coord: optional coordinates passed to ``linop.Sense``.
        device: device that the (scaled) ``y`` is moved to.
        coil_batch_size: forwarded to ``linop.Sense``.
        **kwargs: forwarded to the parent ``__init__``.
    """
    weights = _estimate_weights(y, weights, coord)
    scaled_y = y if weights is None else y * weights**0.5
    y = sp.to_device(scaled_y, device=device)

    A = linop.Sense(mps, coord=coord, weights=weights,
                    coil_batch_size=coil_batch_size)
    G = sp.linop.Gradient(A.ishape)
    proxg = sp.prox.L1Reg(G.oshape, lamda)

    def g(x):
        # lamda * sum(|x|), evaluated on x's own device. The parent
        # presumably calls this on G(x); confirm against the parent class.
        x_device = sp.get_device(x)
        xp = x_device.xp
        with x_device:
            return lamda * xp.sum(xp.abs(x))

    super().__init__(A, y, proxg=proxg, g=g, G=G, **kwargs)
def __init__(self, y, mps, lamda=0, weights=None, tseg=None, coord=None,
             device=sp.cpu_device, coil_batch_size=None, comm=None,
             show_pbar=True, transp_nufft=False, **kwargs):
    """Build a SENSE forward model and hand it to the parent solver.

    Args:
        y: measurements. They are scaled by ``weights**0.5`` when weights
            are available, then moved to ``device``.
        mps: coil sensitivity maps passed to ``linop.Sense``.
        lamda: forwarded to the parent as ``lamda``.
        weights: optional weights; passed through ``_estimate_weights``.
        tseg: forwarded to ``linop.Sense``.
        coord: optional coordinates passed to ``linop.Sense``.
        device: device that the (scaled) ``y`` is moved to.
        coil_batch_size, comm, transp_nufft: forwarded to ``linop.Sense``.
        show_pbar: progress-bar flag. When ``comm`` is given, it is kept
            only on rank 0.
        **kwargs: forwarded to the parent ``__init__``.
    """
    weights = _estimate_weights(y, weights, coord)
    scaled_y = y if weights is None else y * weights**0.5
    y = sp.to_device(scaled_y, device=device)

    A = linop.Sense(mps, coord=coord, weights=weights, tseg=tseg,
                    coil_batch_size=coil_batch_size, comm=comm,
                    transp_nufft=transp_nufft)

    # In a distributed run, only rank 0 shows a progress bar.
    if comm is not None:
        show_pbar = show_pbar and comm.rank == 0

    super().__init__(A, y, lamda=lamda, show_pbar=show_pbar, **kwargs)
def __init__(self, y, mps, eps, weights=None, coord=None,
             device=sp.cpu_device, coil_batch_size=None, comm=None,
             show_pbar=True, **kwargs):
    """Build a SENSE model with a unit-weight L1 prox on finite differences.

    Args:
        y: measurements. They are scaled by ``weights**0.5`` when weights
            are available, then moved to ``device``.
        mps: coil sensitivity maps passed to ``linop.Sense``.
        eps: passed positionally to the parent after ``proxg``. It is
            presumably the constraint bound; confirm against the parent.
        weights: optional weights; passed through ``_estimate_weights``.
        coord: optional coordinates passed to ``linop.Sense``.
        device: device that the (scaled) ``y`` is moved to.
        coil_batch_size, comm: forwarded to ``linop.Sense``.
        show_pbar: progress-bar flag. When ``comm`` is given, it is kept
            only on rank 0.
        **kwargs: forwarded to the parent ``__init__``.
    """
    # Only rank 0 shows a progress bar in a distributed run.
    if comm is not None:
        show_pbar = show_pbar and comm.rank == 0

    weights = _estimate_weights(y, weights, coord)
    scaled_y = y if weights is None else y * weights**0.5
    y = sp.to_device(scaled_y, device=device)

    A = linop.Sense(mps, coord=coord, weights=weights, comm=comm,
                    coil_batch_size=coil_batch_size)
    G = sp.linop.FiniteDifference(A.ishape)
    proxg = sp.prox.L1Reg(G.oshape, 1)

    super().__init__(A, y, proxg, eps, G=G, show_pbar=show_pbar, **kwargs)
def __init__(self, y, mps, lamda, weights=None, coord=None, wave_name='db4',
             device=sp.cpu_device, coil_batch_size=None, comm=None,
             show_pbar=True, **kwargs):
    """Build a SENSE model with an L1 penalty on wavelet coefficients.

    Args:
        y: measurements. They are scaled by ``weights**0.5`` when weights
            are available, then moved to ``device``.
        mps: coil sensitivity maps. ``mps.shape[1:]`` is used as the
            wavelet image shape.
        lamda: scale of the L1 penalty (used by both ``proxg`` and ``g``).
        weights: optional weights; passed through ``_estimate_weights``.
        coord: optional coordinates passed to ``linop.Sense``.
        wave_name: wavelet name for ``sp.linop.Wavelet``.
        device: device that the (scaled) ``y`` is moved to.
        coil_batch_size, comm: forwarded to ``linop.Sense``.
        show_pbar: progress-bar flag. When ``comm`` is given, it is kept
            only on rank 0.
        **kwargs: forwarded to the parent ``__init__``.
    """
    weights = _estimate_weights(y, weights, coord)
    scaled_y = y if weights is None else y * weights**0.5
    y = sp.to_device(scaled_y, device=device)

    A = linop.Sense(mps, coord=coord, weights=weights, comm=comm,
                    coil_batch_size=coil_batch_size)

    W = sp.linop.Wavelet(mps.shape[1:], wave_name=wave_name)
    # Prox of lamda * ||W x||_1, expressed through the transform W.
    proxg = sp.prox.UnitaryTransform(sp.prox.L1Reg(W.oshape, lamda), W)

    def g(x):
        # lamda * sum(|W(x)|), evaluated on x's own device.
        x_device = sp.get_device(x)
        xp = x_device.xp
        with x_device:
            return lamda * xp.sum(xp.abs(W(x)))

    if comm is not None:
        show_pbar = show_pbar and comm.rank == 0

    super().__init__(A, y, proxg=proxg, g=g, show_pbar=show_pbar, **kwargs)
def __init__(self, y, mps, lamda, weights=None, coord=None,
             device=sp.cpu_device, coil_batch_size=None, comm=None,
             show_pbar=True, **kwargs):
    """Build a SENSE model with an L1 penalty on finite differences.

    Args:
        y: measurements. They are scaled by ``weights**0.5`` when weights
            are available, then moved to ``device``.
        mps: coil sensitivity maps passed to ``linop.Sense``.
        lamda: scale of the L1 penalty (used by both ``proxg`` and ``g``).
        weights: optional weights; passed through ``_estimate_weights``.
        coord: optional coordinates passed to ``linop.Sense``.
        device: device that the (scaled) ``y`` is moved to.
        coil_batch_size, comm: forwarded to ``linop.Sense``.
        show_pbar: progress-bar flag. When ``comm`` is given, it is kept
            only on rank 0.
        **kwargs: forwarded to the parent ``__init__``.
    """
    weights = _estimate_weights(y, weights, coord)
    scaled_y = y if weights is None else y * weights**0.5
    y = sp.to_device(scaled_y, device=device)

    A = linop.Sense(mps, coord=coord, weights=weights, comm=comm,
                    coil_batch_size=coil_batch_size)
    G = sp.linop.FiniteDifference(A.ishape)
    proxg = sp.prox.L1Reg(G.oshape, lamda)

    def g(x):
        # lamda * sum(|x|), evaluated on x's own device. The parent
        # presumably calls this on G(x); confirm against the parent class.
        x_device = sp.get_device(x)
        xp = x_device.xp
        with x_device:
            return lamda * xp.sum(xp.abs(x))

    if comm is not None:
        show_pbar = show_pbar and comm.rank == 0

    super().__init__(A, y, proxg=proxg, g=g, G=G, show_pbar=show_pbar,
                     **kwargs)
def test_sense_tseg_off_res_model(self):
    """Check time-segmented ``linop.Sense`` against an explicit operator product.

    Builds a Gaussian-shaped B0 map and gets the segmentation factors
    ``b``/``ct`` from ``sp.mri.util.tseg_off_res_b_ct``. It then verifies
    that ``linop.Sense(..., tseg=...)`` passes the adjoint check and equals
    ``B1 * F * S * Ct1`` applied to a random image.
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]

    # ``np.complex``/``np.float`` were removed in NumPy 1.24.
    img = sp.randn(img_shape, dtype=np.complex128)
    mps = sp.randn(mps_shape, dtype=np.complex128)

    y, x = np.mgrid[:16, :16]
    coord = np.stack([np.ravel(y - 8), np.ravel(x - 8)], axis=1)
    coord = coord.astype(np.float64)

    d = np.sqrt(x * x + y * y)
    sigma, mu, a = 2, 0.25, 400
    b0 = a * np.exp(-((d - mu) ** 2 / (2.0 * sigma ** 2)))
    # These settings must match the explicit tseg_off_res_b_ct call below.
    tseg = {"b0": b0, "dt": 4e-6, "lseg": 1, "n_bins": 10}

    F = sp.linop.NUFFT(mps_shape, coord)
    b, ct = sp.mri.util.tseg_off_res_b_ct(b0=b0, bins=10, lseg=1, dt=4e-6,
                                          T=coord.shape[0] * 4e-6)
    B1 = sp.linop.Multiply(F.oshape, b.T)
    Ct1 = sp.linop.Multiply(img_shape, ct.reshape(img_shape))
    S = sp.linop.Multiply(img_shape, mps)

    A = linop.Sense(mps, coord=coord, tseg=tseg)

    check_linop_adjoint(A, dtype=np.complex128)
    npt.assert_allclose(B1 * F * S * Ct1 * img, A * img)
def __init__(self, y, mps, eps, wave_name='db4', weights=None, coord=None,
             device=sp.cpu_device, coil_batch_size=None, **kwargs):
    """Build a SENSE model with a unit-weight L1 prox on wavelet coefficients.

    Args:
        y: measurements. They are scaled by ``weights**0.5`` when weights
            are available, then moved to ``device``.
        mps: coil sensitivity maps. ``mps.shape[1:]`` is used as the
            wavelet image shape.
        eps: passed positionally to the parent after ``proxg``. It is
            presumably the constraint bound; confirm against the parent.
        wave_name: wavelet name for ``sp.linop.Wavelet``.
        weights: optional weights; passed through ``_estimate_weights``.
        coord: optional coordinates passed to ``linop.Sense``.
        device: device that the (scaled) ``y`` is moved to.
        coil_batch_size: forwarded to ``linop.Sense``.
        **kwargs: forwarded to the parent ``__init__``.
    """
    weights = _estimate_weights(y, weights, coord)
    scaled_y = y if weights is None else y * weights**0.5
    y = sp.to_device(scaled_y, device=device)

    A = linop.Sense(mps, coord=coord, weights=weights,
                    coil_batch_size=coil_batch_size)

    W = sp.linop.Wavelet(mps.shape[1:], wave_name=wave_name)
    l1 = sp.prox.L1Reg(W.oshape, 1)
    proxg = sp.prox.UnitaryTransform(l1, W)

    super().__init__(A, y, proxg, eps, **kwargs)
def test_kspace_precond_noncart(self):
    """Check ``precond.kspace_precond`` on a 1-D non-Cartesian SENSE model.

    Assembles ``A A^H`` densely and compares the preconditioner with
    ``1 / (sum_{d, j} |AAH[c, i, d, j]|**2 / |AAH[c, i, c, i]|)``
    using 1% absolute/relative tolerances.
    """
    n = 10
    nc = 3
    shape = [nc, n]
    # ``np.complex``/``np.float`` were removed in NumPy 1.24.
    mps = sp.randn(shape, dtype=np.complex128)
    mps /= np.linalg.norm(mps, axis=0, keepdims=True)
    coord = sp.randn([n, 1], dtype=np.float64)
    A = linop.Sense(mps, coord=coord)

    # Dense A A^H: column (d, j) is A(A^H(e_{d,j})).
    AAH = np.zeros((nc, n, nc, n), np.complex128)
    for d in range(nc):
        for j in range(n):
            x = np.zeros(shape, np.complex128)
            x[d, j] = 1.0
            AAH[:, :, d, j] = A(A.H(x))

    p_expected = np.zeros([nc, n], np.complex128)
    for c in range(nc):
        for i in range(n):
            row_energy = np.sum(np.abs(AAH[c, i]) ** 2)
            p_expected[c, i] = 1 / (row_energy / abs(AAH[c, i, c, i]))

    p = precond.kspace_precond(mps, coord=coord)
    npt.assert_allclose(p, p_expected, atol=1e-2, rtol=1e-2)
def test_stspa_spiral(self):
    """Design pulses with ``rf.stspa`` on a spiral trajectory and check the fit.

    The trajectory comes from ``rf.spiral_arch``. The designed pulses are
    mapped through the adjoint of a non-Cartesian ``linop.Sense`` and
    compared against the target pattern.
    """
    target, sens = self.problem_2d(8)

    # Spiral design parameters.
    fov = 0.55
    gts = 6.4e-6
    gslew = 190
    gamp = 40
    R = 1
    dx = 0.025  # in m

    # Only the k-space trajectory (second output) is used.
    _, k, _, _ = rf.spiral_arch(fov / R, dx, gts, gslew, gamp)

    excitation = linop.Sense(sens, coord=k, ishape=target.shape).H

    pulses = rf.stspa(target, sens, k, dt=4e-6, alpha=1, b0=None, st=None,
                      explicit=False, max_iter=100, tol=1E-4)

    # NOTE(review): the third positional argument of
    # assert_array_almost_equal is ``decimal`` (number of decimal places).
    # Passing 1E-3 gives a threshold of 1.5 * 10**-0.001 ~= 1.5, so this
    # check is very loose. It is left unchanged; tightening it (e.g.
    # decimal=3) first needs confirmation that stspa with alpha=1 reaches
    # that accuracy.
    npt.assert_array_almost_equal(excitation * pulses, target, 1E-3)
def test_stspa_radial(self):
    """Design pulses with ``rf.stspa`` on a golden-angle radial trajectory.

    The designed pulses are mapped through the adjoint of a non-Cartesian
    ``linop.Sense`` and compared against the target pattern.
    """
    target, sens = self.problem_2d(8)

    # Makes a dim*dim*2 trajectory. ``np.float`` was removed in NumPy 1.24.
    traj = sp.mri.radial((sens.shape[1], sens.shape[1], 2), target.shape,
                         golden=True, dtype=np.float64)
    # Reshape to an Nt*2 trajectory.
    traj = np.reshape(traj, [traj.shape[0] * traj.shape[1], 2])

    A = linop.Sense(sens, coord=traj, weights=None, ishape=target.shape).H

    pulses = rf.stspa(target, sens, traj, dt=4e-6, alpha=1, b0=None,
                      st=None, explicit=False, max_iter=100, tol=1E-4)

    # NOTE(review): 1E-3 lands in the ``decimal`` parameter, giving a
    # threshold of ~1.5. It is kept as-is pending confirmation of the
    # accuracy stspa reaches with alpha=1.
    npt.assert_array_almost_equal(A * pulses, target, 1E-3)
def test_sense_model_batch(self):
    """Check Cartesian ``linop.Sense`` with ``coil_batch_size=1``.

    The operator must pass the adjoint check and match
    ``sp.fft(img * mps)`` over the last two axes.
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]

    # ``np.complex`` was removed in NumPy 1.24.
    img = sp.randn(img_shape, dtype=np.complex128)
    mps = sp.randn(mps_shape, dtype=np.complex128)

    A = linop.Sense(mps, coil_batch_size=1)

    check_linop_adjoint(A, dtype=np.complex128)
    npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]), A * img)
def test_noncart_sense_model_batch(self):
    """Check non-Cartesian ``linop.Sense`` with ``coil_batch_size=1``.

    The coordinates are the full integer grid centred at 0. The operator
    must pass the adjoint check and approximately match the Cartesian FFT
    of ``img * mps`` (10% tolerances).
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]

    # ``np.complex``/``np.float`` were removed in NumPy 1.24.
    img = sp.randn(img_shape, dtype=np.complex128)
    mps = sp.randn(mps_shape, dtype=np.complex128)

    y, x = np.mgrid[:16, :16]
    coord = np.stack([np.ravel(y - 8), np.ravel(x - 8)], axis=1)
    coord = coord.astype(np.float64)

    A = linop.Sense(mps, coord=coord, coil_batch_size=1)

    check_linop_adjoint(A, dtype=np.complex128)
    npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]).ravel(),
                        (A * img).ravel(), atol=0.1, rtol=0.1)
def test_sense_model_with_comm(self):
    """Check the adjoint of a coil-split ``linop.Sense`` with a communicator.

    Each rank builds the operator from ``mps[rank::size]``. Its adjoint
    applied to the matching k-space slice must equal the full
    coil-combined image ``sum_c ifft(ksp_c) * conj(mps_c)``.
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]
    comm = sp.Communicator()

    # ``np.complex`` was removed in NumPy 1.24.
    img = sp.randn(img_shape, dtype=np.complex128)
    mps = sp.randn(mps_shape, dtype=np.complex128)
    # The return values are ignored, so these presumably reduce in place to
    # give every rank identical inputs; confirm against sp.Communicator.
    comm.allreduce(img)
    comm.allreduce(mps)

    ksp = sp.fft(img * mps, axes=[-1, -2])

    A = linop.Sense(mps[comm.rank::comm.size], comm=comm)

    npt.assert_allclose(A.H(ksp[comm.rank::comm.size]),
                        np.sum(sp.ifft(ksp, axes=[-1, -2]) * mps.conjugate(), 0))
def test_sense_model(self):
    """Check Cartesian ``linop.Sense``.

    The operator must pass the adjoint check and match
    ``sp.fft(img * mps)`` over the last two axes.
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]

    # ``np.complex`` was removed in NumPy 1.24.
    img = sp.randn(img_shape, dtype=np.complex128)
    mps = sp.randn(mps_shape, dtype=np.complex128)
    # (An unused ``mask`` local was removed; it never reached the operator.)

    A = linop.Sense(mps)

    check_linop_adjoint(A, dtype=np.complex128)
    npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]), A * img)
def test_circulant_precond_noncart(self):
    """Check ``precond.circulant_precond`` on a 1-D non-Cartesian model.

    With constant, column-normalized coil maps, the expected preconditioner
    is ``1 / diag(F A^H A F^H)``, computed one basis vector at a time. The
    comparison uses 10% tolerances.
    """
    nc = 4
    n = 10
    shape = [nc, n]
    # ``np.complex``/``np.float`` were removed in NumPy 1.24.
    mps = np.ones(shape, dtype=np.complex128)
    mps /= np.linalg.norm(mps, axis=0, keepdims=True)
    coord = sp.randn([n, 1], dtype=np.float64)
    A = linop.Sense(mps, coord=coord)
    F = sp.linop.FFT([n])

    p_expected = np.zeros(n, np.complex128)
    for i in range(n):
        x = np.zeros(n, np.complex128)
        x[i] = 1.0
        p_expected[i] = 1 / F(A.H(A(F.H(x))))[i]

    p = precond.circulant_precond(mps, coord=coord)
    npt.assert_allclose(p, p_expected, atol=1e-1, rtol=1e-1)
def test_circulant_precond_cart(self):
    """Check ``precond.circulant_precond`` on a masked Cartesian model.

    The expected preconditioner is ``1 / diag(F A^H A F^H)``, computed one
    basis vector at a time. Only entries where ``weights == 1`` are compared.
    """
    nc = 4
    n = 10
    shape = (nc, n)
    # ``np.complex`` was removed in NumPy 1.24.
    mps = sp.randn(shape, dtype=np.complex128)
    mps /= np.linalg.norm(mps, axis=0, keepdims=True)
    weights = sp.randn([n]) >= 0  # random boolean mask
    A = sp.linop.Multiply(shape, weights**0.5) * linop.Sense(mps)
    F = sp.linop.FFT([n])

    # Masked-out entries stay 0 and are excluded from the comparison.
    p_expected = np.zeros(n, np.complex128)
    for i in range(n):
        if weights[i]:
            x = np.zeros(n, np.complex128)
            x[i] = 1.0
            p_expected[i] = 1 / F(A.H(A(F.H(x))))[i]

    p = precond.circulant_precond(mps, weights=weights)
    npt.assert_allclose(p[weights == 1], p_expected[weights == 1])
def __init__(self, y, mps, lamda=0, weights=None, coord=None,
             device=sp.cpu_device, coil_batch_size=None, **kwargs):
    """Build a SENSE forward model and hand it to the parent solver.

    Args:
        y: measurements. They are scaled by ``weights**0.5`` when weights
            are available, then moved to ``device``.
        mps: coil sensitivity maps passed to ``linop.Sense``.
        lamda: forwarded to the parent as ``lamda``.
        weights: optional weights; passed through ``_estimate_weights``.
        coord: optional coordinates passed to ``linop.Sense``.
        device: device that the (scaled) ``y`` is moved to.
        coil_batch_size: forwarded to ``linop.Sense``.
        **kwargs: forwarded to the parent ``__init__``.
    """
    weights = _estimate_weights(y, weights, coord)
    scaled_y = y if weights is None else y * weights**0.5
    y = sp.to_device(scaled_y, device=device)

    sense_op = linop.Sense(mps, coord=coord, weights=weights,
                           coil_batch_size=coil_batch_size)

    super().__init__(sense_op, y, lamda=lamda, **kwargs)
def __init__(self, y, mps, eps, weights=None, coord=None,
             device=sp.cpu_device, coil_batch_size=None, **kwargs):
    """Build a SENSE model with a unit-weight L2 prox on the image.

    Args:
        y: measurements. They are scaled by ``weights**0.5`` when weights
            are available, then moved to ``device``.
        mps: coil sensitivity maps passed to ``linop.Sense``.
        eps: passed positionally to the parent after ``proxg``. It is
            presumably the constraint bound; confirm against the parent.
        weights: optional weights; passed through ``_estimate_weights``.
        coord: optional coordinates passed to ``linop.Sense``.
        device: device that the (scaled) ``y`` is moved to.
        coil_batch_size: forwarded to ``linop.Sense``.
        **kwargs: forwarded to the parent ``__init__``.
    """
    weights = _estimate_weights(y, weights, coord)
    scaled_y = y if weights is None else y * weights**0.5
    y = sp.to_device(scaled_y, device=device)

    sense_op = linop.Sense(mps, coord=coord, weights=weights,
                           coil_batch_size=coil_batch_size)
    l2_prox = sp.prox.L2Reg(sense_op.ishape, 1)

    super().__init__(sense_op, y, l2_prox, eps, **kwargs)