def test_shepp_logan_SenseRecon_with_comm(self):
    """Reconstruct the Shepp-Logan phantom with coils distributed across processes.

    Each process keeps every ``comm.size``-th coil (starting at its rank) of
    the k-space and sensitivity maps. Every ``SenseRecon`` call below is
    therefore given ``comm`` so the per-process coil subsets are combined.
    Otherwise each rank would reconstruct from its own coils only when more
    than one process is running.
    """
    img, mps, ksp = self.shepp_logan_setup()
    lamda = 0

    comm = sp.Communicator()
    # Rank-strided coil split: this process handles coils rank, rank+size, ...
    ksp = ksp[comm.rank::comm.size]
    mps = mps[comm.rank::comm.size]

    img_rec = app.SenseRecon(
        ksp, mps, lamda, comm=comm, alg_name='ConjugateGradient',
        show_pbar=False).run()
    npt.assert_allclose(img, img_rec, atol=1e-3, rtol=1e-3)

    # comm is required here too: ksp/mps hold only this rank's coils.
    img_rec = app.SenseRecon(
        ksp, mps, lamda, comm=comm, alg_name='GradientMethod',
        show_pbar=False).run()
    npt.assert_allclose(img, img_rec, atol=1e-3, rtol=1e-3)

    img_rec = app.SenseRecon(
        ksp, mps, lamda, comm=comm, alg_name='PrimalDualHybridGradient',
        max_iter=1000, show_pbar=False).run()
    npt.assert_allclose(img, img_rec, atol=1e-3, rtol=1e-3)
def jsens_calib(ksp, coord, dcf, ishape, device=sp.Device(-1)):
    """Estimate coil sensitivity maps with ``mr.app.JsenseRecon``.

    The input k-space is first passed through the project helper
    ``nft.nufft_adj`` (with ``id_channel=True``, presumably keeping channels
    separate — TODO confirm). The first returned image is then FFT'd along
    axes 1-3 and handed to JsenseRecon with fixed calibration settings.

    Args:
        ksp: k-space samples for ``nft.nufft_adj``.
        coord: sample coordinates for ``nft.nufft_adj``.
        dcf: density-compensation weights for ``nft.nufft_adj``.
        ishape: image shape forwarded to ``nft.nufft_adj``.
        device: sigpy device for both the adjoint NUFFT and JsenseRecon.
            The default is ``sp.Device(-1)``, evaluated once at definition.

    Returns:
        Whatever ``mr.app.JsenseRecon(...).run()`` returns (the sensitivity
        maps).
    """
    adj_imgs = nft.nufft_adj(
        [ksp], [coord], [dcf], device=device, ishape=ishape, id_channel=True)
    # NOTE(review): np.asarray will likely fail if adj_imgs[0] lives on a GPU
    # device; confirm nufft_adj returns a host array for non-CPU devices.
    # FFT over axes 1-3 implies a leading (presumably channel) axis 0.
    cart_ksp = sp.fft(input=np.asarray(adj_imgs[0]), axes=(1, 2, 3))
    return mr.app.JsenseRecon(
        cart_ksp,
        mps_ker_width=12,
        ksp_calib_width=32,
        lamda=0,
        device=device,
        comm=sp.Communicator(),
        max_iter=10,
        max_inner_iter=10,
    ).run()
def test_sense_model_with_comm(self):
    """Check the adjoint of a coil-distributed ``linop.Sense`` operator.

    Every rank builds the same random image and maps. The operator is
    constructed from this rank's strided coil subset with ``comm``. Its
    adjoint applied to the matching k-space subset must equal the full
    coil-combined adjoint computed directly from all coils.
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]

    comm = sp.Communicator()
    # np.complex (an alias of builtin complex) was removed in NumPy 1.24;
    # complex128 is the same dtype.
    img = sp.randn(img_shape, dtype=np.complex128)
    mps = sp.randn(mps_shape, dtype=np.complex128)
    # allreduce is used for its in-place effect (return value unused),
    # presumably so every rank ends up with identical img and mps.
    comm.allreduce(img)
    comm.allreduce(mps)
    ksp = sp.fft(img * mps, axes=[-1, -2])

    A = linop.Sense(mps[comm.rank::comm.size], comm=comm)
    # Reference adjoint: inverse FFT per coil, weight by conj(maps), sum coils.
    npt.assert_allclose(A.H(ksp[comm.rank::comm.size]), np.sum(
        sp.ifft(ksp, axes=[-1, -2]) * mps.conjugate(), 0))
def test_shepp_logan_SenseRecon_with_comm(self):
    """Reconstruct the Shepp-Logan phantom from coils split across processes.

    The same distributed problem is solved once per solver, each in its own
    subTest, and every result must match the phantom to 1e-2 tolerance.
    """
    img, mps, ksp = self.shepp_logan_setup()
    lamda = 0

    comm = sp.Communicator()
    first, stride = comm.rank, comm.size
    # Keep only this process's rank-strided share of the coils.
    ksp = ksp[first::stride]
    mps = mps[first::stride]

    solver_names = (
        'ConjugateGradient',
        'GradientMethod',
        'PrimalDualHybridGradient',
        'ADMM',
    )
    for name in solver_names:
        with self.subTest(solver=name):
            recon_app = app.SenseRecon(
                ksp, mps, lamda, comm=comm, solver=name, show_pbar=False)
            img_rec = recon_app.run()
            npt.assert_allclose(img, img_rec, atol=1e-2, rtol=1e-2)
parser.add_argument('ksp_file', type=str) parser.add_argument('coord_file', type=str) parser.add_argument('dcf_file', type=str) parser.add_argument('mps_file', type=str) args = parser.parse_args() logging.basicConfig(level=logging.DEBUG) logging.info('Reading data.') ksp = np.load(args.ksp_file, mmap_mode='r') coord = np.load(args.coord_file) dcf = np.load(args.dcf_file) # Choose device comm = sp.Communicator() if args.multi_gpu: device = comm.rank else: device = args.device logging.info('Jsense Recon.') ksp = np.array_split(ksp, comm.size)[comm.rank] mps = mr.app.JsenseRecon(ksp, coord=coord, weights=dcf, mps_ker_width=args.mps_ker_width, ksp_calib_width=args.ksp_calib_width, lamda=args.lamda, device=device, comm=comm,