def test_2d_interp_adjoint_backward():
    """Check autograd through 2D ``KbInterpBack`` against ``KbInterpForw``.

    With ``x = KbInterpBack(y)`` and loss ``sum(x**2) / 2``, the test asserts
    that ``y.grad`` matches ``KbInterpForw`` applied to the detached ``x``
    within ``norm_tol``.

    ``norm_tol`` is not defined in this function; it is expected to be a
    module-level tolerance.

    NOTE(review): a later ``test_2d_interp_adjoint_backward`` in this file
    shadows this definition, so pytest will not collect this one.
    """
    dtype = torch.double
    nslice = 2
    ncoil = 4
    grid_size = (33, 24)
    klength = 112

    # Complex k-space data stored as a real tensor, with the real and
    # imaginary parts stacked along dim 2. (The original also built a random
    # image ``x`` here, but it was overwritten before use, so it was removed.)
    y = np.random.normal(size=(nslice, ncoil, klength)) + 1j * np.random.normal(
        size=(nslice, ncoil, klength)
    )
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)
    ktraj = torch.randn(nslice, 2, klength).to(dtype)

    kbinterp_ob = KbInterpForw(
        im_size=(20, 25), grid_size=grid_size, numpoints=(4, 6)
    )
    adjkbinterp_ob = KbInterpBack(
        im_size=(20, 25), grid_size=grid_size, numpoints=(4, 6)
    )

    y.requires_grad = True
    x = adjkbinterp_ob.forward(y, ktraj)
    ((x ** 2) / 2).sum().backward()
    y_grad = y.grad.clone().detach()

    # Detach so the reference computation does not join the autograd graph.
    y_hat = kbinterp_ob.forward(x.clone().detach(), ktraj)

    assert torch.norm(y_grad - y_hat) < norm_tol
def test_interp_3d_adjoint():
    """Compare 3D ``KbInterpBack`` with and without ``matadj=True``.

    Two adjoint-interpolation objects with identical geometry are applied to
    the same random k-space data. The test asserts that their outputs agree
    within ``norm_tol``, which is expected to be module-level because it is
    not defined here.

    NOTE(review): later ``test_interp_3d_adjoint`` definitions in this file
    shadow this one, so pytest will not collect it.
    """
    grid_size = (11, 33, 24)
    n_batch, n_coil, n_samples = 2, 4, 112
    kspace_shape = (n_batch, n_coil, n_samples)

    # Complex samples stored as real/imaginary stacked along dim 2.
    kdata = np.random.normal(size=kspace_shape) + 1j * np.random.normal(size=kspace_shape)
    y = torch.tensor(np.stack((kdata.real, kdata.imag), axis=2)).to(torch.double)
    ktraj = torch.randn(n_batch, 3, n_samples).to(torch.double)

    geometry = dict(im_size=(5, 20, 25), grid_size=grid_size, numpoints=(2, 4, 6))
    plain_ob = KbInterpBack(**geometry)
    matadj_ob = KbInterpBack(matadj=True, **geometry)

    x_normal = plain_ob(y, ktraj)
    x_matadj = matadj_ob(y, ktraj)

    assert torch.norm(x_normal - x_matadj) < norm_tol
def test_3d_interp_backward():
    """Check autograd through 3D ``KbInterpForw`` against ``KbInterpBack``.

    With ``y = KbInterpForw(x)`` and loss ``sum(y**2) / 2``, the test asserts
    that ``x.grad`` matches ``KbInterpBack`` applied to the detached ``y``
    within ``norm_tol``, which is expected to be module-level because it is
    not defined here.

    NOTE(review): later ``test_3d_interp_backward`` definitions in this file
    shadow this one, so pytest will not collect it.
    """
    dtype = torch.double
    nslice = 2
    ncoil = 4
    grid_size = (11, 33, 24)
    klength = 112

    # Random complex image on the grid, stored as real/imaginary stacked along
    # dim 2. (The original also built a random k-space ``y`` here, but it was
    # overwritten by the forward output before use, so it was removed.)
    x = np.random.normal(size=(nslice, ncoil) + grid_size) + 1j * np.random.normal(
        size=(nslice, ncoil) + grid_size
    )
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2)).to(dtype)
    ktraj = torch.randn(nslice, 3, klength).to(dtype)

    kbinterp_ob = KbInterpForw(
        im_size=(5, 20, 25), grid_size=grid_size, numpoints=(2, 4, 6)
    )
    adjkbinterp_ob = KbInterpBack(
        im_size=(5, 20, 25), grid_size=grid_size, numpoints=(2, 4, 6)
    )

    x.requires_grad = True
    y = kbinterp_ob.forward(x, ktraj)
    ((y ** 2) / 2).sum().backward()
    x_grad = x.grad.clone().detach()

    x_hat = adjkbinterp_ob.forward(y.clone().detach(), ktraj)

    assert torch.norm(x_grad - x_hat) < norm_tol
def test_interp_2d_adjoint():
    """Dot-product test for the 2D ``KbInterpForw`` / ``KbInterpBack`` pair.

    Asserts that ``inner_product(y, KbInterpForw(x), dim=2)`` and
    ``inner_product(KbInterpBack(y), x, dim=2)`` agree within ``norm_tol``,
    which is expected to be module-level because it is not defined here.

    NOTE(review): a later ``test_interp_2d_adjoint`` in this file shadows this
    definition, so pytest will not collect this one.
    """
    grid_size = (33, 24)
    batch, coils, n_k = 2, 4, 112

    def random_complex(shape):
        # Complex Gaussian samples as a real double tensor with the real and
        # imaginary parts stacked along dim 2.
        z = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
        return torch.tensor(np.stack((z.real, z.imag), axis=2)).to(torch.double)

    x = random_complex((batch, coils) + grid_size)
    y = random_complex((batch, coils, n_k))
    ktraj = torch.randn(batch, 2, n_k).to(torch.double)

    geometry = dict(im_size=(20, 25), grid_size=grid_size, numpoints=(4, 6))
    forw_ob = KbInterpForw(**geometry)
    back_ob = KbInterpBack(**geometry)

    x_forw = forw_ob(x, ktraj)
    y_back = back_ob(y, ktraj)
    lhs = inner_product(y, x_forw, dim=2)
    rhs = inner_product(y_back, x, dim=2)

    assert torch.norm(lhs - rhs) < norm_tol
def test_interp_3d_adjoint():
    """Dot-product test for 3D interpolation using precomputed sparse matrices.

    The interpolation matrices are built once with ``precomp_sparse_mats`` and
    passed to both ``KbInterpForw`` and ``KbInterpBack``. The test then asserts
    that ``inner_product(y, forw(x), dim=2)`` and
    ``inner_product(back(y), x, dim=2)`` agree within ``norm_tol``, which is
    expected to be module-level because it is not defined here.

    NOTE(review): a later ``test_interp_3d_adjoint`` in this file shadows this
    definition, so pytest will not collect this one.
    """
    grid_size = (11, 33, 24)
    batch, coils, n_k = 2, 4, 112

    def random_complex(shape):
        # Complex Gaussian samples as a real double tensor with the real and
        # imaginary parts stacked along dim 2.
        z = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
        return torch.tensor(np.stack((z.real, z.imag), axis=2)).to(torch.double)

    x = random_complex((batch, coils) + grid_size)
    y = random_complex((batch, coils, n_k))
    ktraj = torch.randn(batch, 3, n_k).to(torch.double)

    geometry = dict(im_size=(5, 20, 25), grid_size=grid_size, numpoints=(2, 4, 6))
    forw_ob = KbInterpForw(**geometry)
    back_ob = KbInterpBack(**geometry)

    real_mat, imag_mat = precomp_sparse_mats(ktraj, forw_ob)
    interp_mats = {"real_interp_mats": real_mat, "imag_interp_mats": imag_mat}

    x_forw = forw_ob(x, ktraj, interp_mats)
    y_back = back_ob(y, ktraj, interp_mats)
    lhs = inner_product(y, x_forw, dim=2)
    rhs = inner_product(y_back, x, dim=2)

    assert torch.norm(lhs - rhs) < norm_tol
def test_interp_2d_adjoint(params_2d, testing_tol, testing_dtype, device_list):
    """Dot-product test for 2D ``KbInterpForw`` / ``KbInterpBack`` per device.

    For every device in ``device_list``, this asserts that
    ``inner_product(y, forw(x), dim=2)`` and ``inner_product(back(y), x, dim=2)``
    agree within ``testing_tol``. Here ``x`` is a random complex grid (real and
    imaginary parts stacked along dim 2), and ``y``/``ktraj`` come from the
    ``params_2d`` fixture.
    """
    grid_size = params_2d["grid_size"]
    geometry = dict(
        im_size=params_2d["im_size"],
        grid_size=grid_size,
        numpoints=params_2d["numpoints"],
    )
    image_shape = (params_2d["batch_size"], 1) + grid_size

    noise = np.random.normal(size=image_shape) + 1j * np.random.normal(size=image_shape)
    x = torch.tensor(np.stack((noise.real, noise.imag), axis=2))
    y = params_2d["y"]
    ktraj = params_2d["ktraj"]

    for device in device_list:
        placement = dict(dtype=testing_dtype, device=device)
        # Move every input to the current device/dtype, dropping any graph.
        x, y, ktraj = (t.detach().to(**placement) for t in (x, y, ktraj))

        forw_ob = KbInterpForw(**geometry).to(**placement)
        back_ob = KbInterpBack(**geometry).to(**placement)

        x_forw = forw_ob(x, ktraj)
        y_back = back_ob(y, ktraj)
        lhs = inner_product(y, x_forw, dim=2)
        rhs = inner_product(y_back, x, dim=2)

        assert torch.norm(lhs - rhs) < testing_tol
def test_2d_interp_adjoint_backward(params_2d, testing_tol, testing_dtype, device_list):
    """Check autograd through 2D ``KbInterpBack`` (sparse-matrix path) per device.

    For every device in ``device_list``, with ``x = KbInterpBack(y)`` using
    precomputed interpolation matrices and loss ``sum(x**2) / 2``, the test
    asserts that ``y.grad`` matches ``KbInterpForw`` applied to the detached
    ``x`` within ``testing_tol``.
    """
    dtype = testing_dtype
    norm_tol = testing_tol
    im_size = params_2d["im_size"]
    grid_size = params_2d["grid_size"]
    numpoints = params_2d["numpoints"]
    # The original also built a random image ``x`` here, but it was always
    # overwritten by the adjoint output before being read, so it was removed.
    y = params_2d["y"]
    ktraj = params_2d["ktraj"]

    for device in device_list:
        # Detach so each iteration starts from a fresh leaf tensor.
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)
        kbinterp_ob = KbInterpForw(
            im_size=im_size, grid_size=grid_size, numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbinterp_ob = KbInterpBack(
            im_size=im_size, grid_size=grid_size, numpoints=numpoints
        ).to(dtype=dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {"real_interp_mats": real_mat, "imag_interp_mats": imag_mat}

        y.requires_grad = True
        x = adjkbinterp_ob.forward(y, ktraj, interp_mats)
        ((x ** 2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()

        y_hat = kbinterp_ob.forward(x.clone().detach(), ktraj, interp_mats)

        assert torch.norm(y_grad - y_hat) < norm_tol
def test_3d_interp_backward(params_3d, testing_tol, testing_dtype, device_list):
    """Check autograd through 3D ``KbInterpForw`` (sparse-matrix path) per device.

    For every device in ``device_list``, with ``y = KbInterpForw(x)`` using
    precomputed interpolation matrices and loss ``sum(y**2) / 2``, the test
    asserts that ``x.grad`` matches ``KbInterpBack`` applied to the detached
    ``y`` within ``testing_tol``.

    NOTE(review): a later ``test_3d_interp_backward`` in this file shadows
    this definition, so pytest will not collect this one.
    """
    dtype = testing_dtype
    norm_tol = testing_tol
    batch_size = params_3d["batch_size"]
    im_size = params_3d["im_size"]
    grid_size = params_3d["grid_size"]
    numpoints = params_3d["numpoints"]

    # Random complex grid stored as real/imaginary stacked along dim 2.
    x = np.random.normal(size=(batch_size, 1) + grid_size) + 1j * np.random.normal(
        size=(batch_size, 1) + grid_size
    )
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2))
    # The fixture's ``y`` was previously loaded and converted here, but it was
    # always overwritten by the forward output before use, so it was removed.
    ktraj = params_3d["ktraj"]

    for device in device_list:
        # Detach so each iteration starts from a fresh leaf tensor.
        x = x.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)
        kbinterp_ob = KbInterpForw(
            im_size=im_size, grid_size=grid_size, numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbinterp_ob = KbInterpBack(
            im_size=im_size, grid_size=grid_size, numpoints=numpoints
        ).to(dtype=dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {"real_interp_mats": real_mat, "imag_interp_mats": imag_mat}

        x.requires_grad = True
        y = kbinterp_ob.forward(x, ktraj, interp_mats)
        ((y ** 2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbinterp_ob.forward(y.clone().detach(), ktraj, interp_mats)

        assert torch.norm(x_grad - x_hat) < norm_tol
def test_3d_interp_backward(params_3d, testing_tol, testing_dtype, device_list):
    """Check autograd through 3D ``KbInterpForw`` against ``KbInterpBack`` per device.

    For every device in ``device_list``, with ``y = KbInterpForw(x)`` and loss
    ``sum(y**2) / 2``, the test asserts that ``x.grad`` matches
    ``KbInterpBack`` applied to the detached ``y`` within ``testing_tol``.
    """
    dtype = testing_dtype
    norm_tol = testing_tol
    batch_size = params_3d["batch_size"]
    im_size = params_3d["im_size"]
    grid_size = params_3d["grid_size"]
    numpoints = params_3d["numpoints"]

    # Random complex grid stored as real/imaginary stacked along dim 2.
    x = np.random.normal(size=(batch_size, 1) + grid_size) + 1j * np.random.normal(
        size=(batch_size, 1) + grid_size
    )
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2))
    # The fixture's ``y`` was previously loaded and converted here, but it was
    # always overwritten by the forward output before use, so it was removed.
    ktraj = params_3d["ktraj"]

    for device in device_list:
        # Detach so each iteration starts from a fresh leaf tensor.
        x = x.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)
        kbinterp_ob = KbInterpForw(
            im_size=im_size, grid_size=grid_size, numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbinterp_ob = KbInterpBack(
            im_size=im_size, grid_size=grid_size, numpoints=numpoints
        ).to(dtype=dtype, device=device)

        x.requires_grad = True
        y = kbinterp_ob.forward(x, ktraj)
        ((y ** 2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbinterp_ob.forward(y.clone().detach(), ktraj)

        assert torch.norm(x_grad - x_hat) < norm_tol
def test_3d_interp_adjoint_backward():
    """Check autograd through 3D ``KbInterpBack`` (sparse-matrix path).

    With ``x = KbInterpBack(y)`` using precomputed interpolation matrices and
    loss ``sum(x**2) / 2``, the test asserts that ``y.grad`` matches
    ``KbInterpForw`` applied to the detached ``x`` within ``norm_tol``, which
    is expected to be module-level because it is not defined here.
    """
    dtype = torch.double
    nslice = 2
    ncoil = 4
    grid_size = (11, 33, 24)
    klength = 112

    # Complex k-space data stored as real/imaginary stacked along dim 2.
    # (The original also built a random image ``x`` here, but it was
    # overwritten by the adjoint output before use, so it was removed.)
    y = np.random.normal(size=(nslice, ncoil, klength)) + 1j * np.random.normal(
        size=(nslice, ncoil, klength)
    )
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)
    ktraj = torch.randn(nslice, 3, klength).to(dtype)

    kbinterp_ob = KbInterpForw(
        im_size=(5, 20, 25), grid_size=grid_size, numpoints=(2, 4, 6)
    )
    adjkbinterp_ob = KbInterpBack(
        im_size=(5, 20, 25), grid_size=grid_size, numpoints=(2, 4, 6)
    )

    real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
    interp_mats = {"real_interp_mats": real_mat, "imag_interp_mats": imag_mat}

    y.requires_grad = True
    x = adjkbinterp_ob.forward(y, ktraj, interp_mats)
    ((x ** 2) / 2).sum().backward()
    y_grad = y.grad.clone().detach()

    y_hat = kbinterp_ob.forward(x.clone().detach(), ktraj, interp_mats)

    assert torch.norm(y_grad - y_hat) < norm_tol
def test_interp_3d_adjoint(params_3d, testing_tol, testing_dtype, device_list):
    """Dot-product test for 3D interpolation with sparse matrices, per device.

    For every device in ``device_list``, interpolation matrices are
    precomputed from ``ktraj`` and shared by ``KbInterpForw`` and
    ``KbInterpBack``. The test then asserts that
    ``inner_product(y, forw(x), dim=2)`` and ``inner_product(back(y), x, dim=2)``
    agree within ``testing_tol``.
    """
    grid_size = params_3d["grid_size"]
    geometry = dict(
        im_size=params_3d["im_size"],
        grid_size=grid_size,
        numpoints=params_3d["numpoints"],
    )
    image_shape = (params_3d["batch_size"], 1) + grid_size

    noise = np.random.normal(size=image_shape) + 1j * np.random.normal(size=image_shape)
    x = torch.tensor(np.stack((noise.real, noise.imag), axis=2))
    y = params_3d["y"]
    ktraj = params_3d["ktraj"]

    for device in device_list:
        placement = dict(dtype=testing_dtype, device=device)
        # Move every input to the current device/dtype, dropping any graph.
        x, y, ktraj = (t.detach().to(**placement) for t in (x, y, ktraj))

        forw_ob = KbInterpForw(**geometry).to(**placement)
        back_ob = KbInterpBack(**geometry).to(**placement)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, forw_ob)
        interp_mats = {"real_interp_mats": real_mat, "imag_interp_mats": imag_mat}

        x_forw = forw_ob(x, ktraj, interp_mats)
        y_back = back_ob(y, ktraj, interp_mats)
        lhs = inner_product(y, x_forw, dim=2)
        rhs = inner_product(y_back, x, dim=2)

        assert torch.norm(lhs - rhs) < testing_tol
def test_kb_matching(testing_tol):
    """Check that all NUFFT/interp modules build matching ``.table`` values.

    For every combination of ``kbwidth``, ``order`` and image size (one 2D and
    one 3D), the ``.table`` of ``AdjKbNufft`` serves as the reference. It is
    compared entry by entry against ``KbNufft``, ``KbInterpBack``,
    ``KbInterpForw``, ``MriSenseNufft`` and ``AdjMriSenseNufft`` built with the
    same parameters.
    """
    norm_tol = testing_tol

    def check_tables(table1, table2):
        # Require equal lengths. Previously, extra entries in table2 were
        # silently ignored because only table1's indices were visited.
        assert len(table1) == len(table2)
        for ref_table, other_table in zip(table1, table2):
            assert np.linalg.norm(ref_table - other_table) < norm_tol

    im_szs = [(256, 256), (10, 256, 256)]
    kbwidths = [2.34, 5]
    orders = [0, 2]

    for kbwidth in kbwidths:
        for order in orders:
            for im_sz in im_szs:
                # Random coil map with a single leading coil dimension, needed
                # only by the MriSense* constructors.
                smap = torch.randn(*((1,) + im_sz))

                base_table = AdjKbNufft(im_sz, order=order, kbwidth=kbwidth).table

                cur_table = KbNufft(im_sz, order=order, kbwidth=kbwidth).table
                check_tables(base_table, cur_table)

                cur_table = KbInterpBack(im_sz, order=order, kbwidth=kbwidth).table
                check_tables(base_table, cur_table)

                cur_table = KbInterpForw(im_sz, order=order, kbwidth=kbwidth).table
                check_tables(base_table, cur_table)

                cur_table = MriSenseNufft(
                    smap, im_sz, order=order, kbwidth=kbwidth
                ).table
                check_tables(base_table, cur_table)

                cur_table = AdjMriSenseNufft(
                    smap, im_sz, order=order, kbwidth=kbwidth
                ).table
                check_tables(base_table, cur_table)
def test_2d_init_inputs():
    """Construct every 2D NUFFT/interp module with scalar and tuple options.

    All object initializations contain assertions, so this test raises if any
    option's dimensions don't match. Each module is built once with scalar
    ``numpoints``/``table_oversamp``/``kbwidth``/``order`` values and once with
    per-axis tuples of those values.
    """
    im_sz = (256, 256)
    grid_sz = (512, 512)
    n_shift = (128, 128)
    norm = "None"

    option_sets = (
        # test 2d scalar inputs
        dict(numpoints=6, table_oversamp=2 ** 10, kbwidth=2.34, order=0),
        # test 2d tuple inputs
        dict(
            numpoints=(6, 6),
            table_oversamp=(2 ** 10, 2 ** 10),
            kbwidth=(2.34, 2.34),
            order=(0, 0),
        ),
    )

    for options in option_sets:
        smap = torch.randn(*((1,) + im_sz))
        interp_kwargs = dict(im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, **options)
        # The NUFFT wrappers take the same options plus ``norm``.
        nufft_kwargs = dict(interp_kwargs, norm=norm)

        KbInterpForw(**interp_kwargs)
        KbInterpBack(**interp_kwargs)
        KbNufft(**nufft_kwargs)
        AdjKbNufft(**nufft_kwargs)
        MriSenseNufft(smap=smap, **nufft_kwargs)
        AdjMriSenseNufft(smap=smap, **nufft_kwargs)