def test_2d_kbnufft_backward(params_2d, testing_tol, testing_dtype, device_list):
    """Compare autograd through the 2D forward NUFFT with the adjoint NUFFT.

    For ``y = KbNufft(x)`` and the loss ``((y**2) / 2).sum()``, the gradient
    w.r.t. ``x`` is compared against ``AdjKbNufft(y)`` on every device in
    ``device_list``. The test passes when the norm of the difference is below
    ``testing_tol``.

    NOTE(review): another function named ``test_2d_kbnufft_backward`` appears
    right after this one in the viewed source; if both live in the same module,
    the later definition shadows this one and pytest will not collect it.
    """
    dtype = testing_dtype
    norm_tol = testing_tol
    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]
    x_orig = params_2d["x"]
    ktraj_orig = params_2d["ktraj"]

    for device in device_list:
        # Derive fresh per-device tensors from the fixture originals instead of
        # re-converting the previous iteration's (grad-tracking) results.
        x = x_orig.detach().to(dtype=dtype, device=device)
        ktraj = ktraj_orig.detach().to(dtype=dtype, device=device)
        kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )
        adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )

        # Gradient of the loss w.r.t. x via autograd through the forward op.
        x.requires_grad = True
        y = kbnufft_ob.forward(x, ktraj)
        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        # Reference: adjoint applied to the (detached) forward output.
        x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj)

        assert torch.norm(x_grad - x_hat) < norm_tol
def test_2d_kbnufft_backward():
    """Compare autograd through the 2D forward NUFFT with the adjoint NUFFT.

    Builds random complex image data (real/imaginary parts stacked along
    dim 2) and a random 2D trajectory. It computes the gradient of
    ``((y**2) / 2).sum()`` with ``y = KbNufft(x)`` and checks that it matches
    ``AdjKbNufft(y)`` to within ``norm_tol``.
    """
    dtype = torch.double
    nslice = 2
    ncoil = 4
    im_size = (33, 24)
    klength = 112

    # Seeded local generators make the test reproducible without mutating the
    # global numpy/torch RNG state that other tests may depend on.
    np_rng = np.random.default_rng(0)
    torch_gen = torch.Generator().manual_seed(0)

    x = np_rng.normal(size=(nslice, ncoil) + im_size) + 1j * np_rng.normal(
        size=(nslice, ncoil) + im_size
    )
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2)).to(dtype)
    ktraj = torch.randn(nslice, 2, klength, generator=torch_gen).to(dtype)

    kbnufft_ob = KbNufft(im_size=im_size, numpoints=(4, 6))
    adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=(4, 6))

    # Gradient of the loss w.r.t. x via autograd through the forward op.
    x.requires_grad = True
    y = kbnufft_ob.forward(x, ktraj)
    ((y**2) / 2).sum().backward()
    x_grad = x.grad.clone().detach()

    # Reference: adjoint applied to the (detached) forward output.
    x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj)

    # NOTE(review): norm_tol is not defined in this function; presumably a
    # module-level constant outside the viewed source — confirm.
    assert torch.norm(x_grad - x_hat) < norm_tol
def test_3d_kbnufft_adjoint_backward():
    """Compare autograd through the 3D adjoint NUFFT with the forward NUFFT.

    Builds random complex k-space data (real/imaginary parts stacked along
    dim 2), a random 3D trajectory, and precomputed sparse interpolation
    matrices. It computes the gradient of ``((x**2) / 2).sum()`` with
    ``x = AdjKbNufft(y)`` and checks that it matches ``KbNufft(x)`` to within
    ``norm_tol``.

    NOTE(review): another function named ``test_3d_kbnufft_adjoint_backward``
    appears right after this one in the viewed source; if both live in the same
    module, the later definition shadows this one and pytest will not collect it.
    """
    dtype = torch.double
    nslice = 2
    ncoil = 4
    im_size = (11, 33, 24)
    klength = 112

    # Seeded local generators make the test reproducible without mutating the
    # global numpy/torch RNG state that other tests may depend on.
    np_rng = np.random.default_rng(0)
    torch_gen = torch.Generator().manual_seed(0)

    y = np_rng.normal(size=(nslice, ncoil, klength)) + 1j * np_rng.normal(
        size=(nslice, ncoil, klength)
    )
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)
    ktraj = torch.randn(nslice, 3, klength, generator=torch_gen).to(dtype)

    kbnufft_ob = KbNufft(im_size=im_size, numpoints=(2, 4, 6))
    adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=(2, 4, 6))

    # Both operators share the same im_size/numpoints, so the matrices built
    # from the forward object are passed to both calls below.
    real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    # Gradient of the loss w.r.t. y via autograd through the adjoint op.
    y.requires_grad = True
    x = adjkbnufft_ob.forward(y, ktraj, interp_mats)
    ((x**2) / 2).sum().backward()
    y_grad = y.grad.clone().detach()

    # Reference: forward op applied to the (detached) adjoint output.
    y_hat = kbnufft_ob.forward(x.clone().detach(), ktraj, interp_mats)

    # NOTE(review): norm_tol is not defined in this function; presumably a
    # module-level constant outside the viewed source — confirm.
    assert torch.norm(y_grad - y_hat) < norm_tol
def test_3d_kbnufft_adjoint_backward(params_3d, testing_tol, testing_dtype, device_list):
    """Compare autograd through the 3D adjoint NUFFT with the forward NUFFT.

    Uses precomputed sparse interpolation matrices. For ``x = AdjKbNufft(y)``
    and the loss ``((x**2) / 2).sum()``, the gradient w.r.t. ``y`` is compared
    against ``KbNufft(x)`` on every device in ``device_list``. The test passes
    when the norm of the difference is below ``testing_tol``.
    """
    dtype = testing_dtype
    norm_tol = testing_tol
    im_size = params_3d["im_size"]
    numpoints = params_3d["numpoints"]
    y_orig = params_3d["y"]
    ktraj_orig = params_3d["ktraj"]

    for device in device_list:
        # Derive fresh per-device tensors from the fixture originals instead of
        # re-converting the previous iteration's (grad-tracking) results.
        y = y_orig.detach().to(dtype=dtype, device=device)
        ktraj = ktraj_orig.detach().to(dtype=dtype, device=device)
        kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )
        adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )

        # Both operators share im_size/numpoints, so the matrices built from the
        # forward object are passed to both calls below.
        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat,
        }

        # Gradient of the loss w.r.t. y via autograd through the adjoint op.
        y.requires_grad = True
        x = adjkbnufft_ob.forward(y, ktraj, interp_mats)
        ((x**2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()

        # Reference: forward op applied to the (detached) adjoint output.
        y_hat = kbnufft_ob.forward(x.clone().detach(), ktraj, interp_mats)

        assert torch.norm(y_grad - y_hat) < norm_tol
def test_3d_kbnufft_backward(params_3d, testing_tol, testing_dtype, device_list):
    """Compare autograd through the 3D forward NUFFT with the adjoint NUFFT.

    Uses precomputed sparse interpolation matrices. For ``y = KbNufft(x)`` and
    the loss ``((y**2) / 2).sum()``, the gradient w.r.t. ``x`` is compared
    against ``AdjKbNufft(y)`` on every device in ``device_list``. The test
    passes when the norm of the difference is below ``testing_tol``.
    """
    dtype = testing_dtype
    norm_tol = testing_tol
    im_size = params_3d['im_size']
    numpoints = params_3d['numpoints']
    x_orig = params_3d['x']
    ktraj_orig = params_3d['ktraj']

    for device in device_list:
        # Derive fresh per-device tensors from the fixture originals instead of
        # re-converting the previous iteration's (grad-tracking) results.
        x = x_orig.detach().to(dtype=dtype, device=device)
        ktraj = ktraj_orig.detach().to(dtype=dtype, device=device)
        kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )
        adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )

        # Both operators share im_size/numpoints, so the matrices built from the
        # forward object are passed to both calls below.
        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat,
        }

        # Gradient of the loss w.r.t. x via autograd through the forward op.
        x.requires_grad = True
        y = kbnufft_ob.forward(x, ktraj, interp_mats)
        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        # Reference: adjoint applied to the (detached) forward output.
        x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj, interp_mats)

        assert torch.norm(x_grad - x_hat) < norm_tol