def test_nufft_2d_adjoint():
    """Check the inner-product identity for the 2D KbNufft/AdjKbNufft pair.

    Random image and k-space data are passed through ``KbNufft`` and
    ``AdjKbNufft`` with precomputed sparse interpolation matrices.  The test
    passes when ``inner_product(y, forward(x))`` and
    ``inner_product(adjoint(y), x)`` differ in norm by less than the
    externally defined ``norm_tol``.
    """
    dtype = torch.double
    nslice, ncoil = 2, 4
    im_size = (33, 24)
    klength = 112

    # complex image data, real and imaginary parts stacked along axis 2
    image_shape = (nslice, ncoil) + im_size
    image_np = np.random.normal(size=image_shape) \
        + 1j * np.random.normal(size=image_shape)
    image = torch.tensor(
        np.stack((image_np.real, image_np.imag), axis=2)
    ).to(dtype)

    # complex k-space data in the same stacked layout
    kdata_shape = (nslice, ncoil, klength)
    kdata_np = np.random.normal(size=kdata_shape) \
        + 1j * np.random.normal(size=kdata_shape)
    kdata = torch.tensor(
        np.stack((kdata_np.real, kdata_np.imag), axis=2)
    ).to(dtype)

    traj = torch.randn(nslice, 2, klength).to(dtype)

    forw_op = KbNufft(im_size=im_size, numpoints=(4, 6))
    adj_op = AdjKbNufft(im_size=im_size, numpoints=(4, 6))

    re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
    sparse_mats = {
        "real_interp_mats": re_mat,
        "imag_interp_mats": im_mat,
    }

    forw_out = forw_op(image, traj, sparse_mats)
    adj_out = adj_op(kdata, traj, sparse_mats)

    lhs = inner_product(kdata, forw_out, dim=2)
    rhs = inner_product(adj_out, image, dim=2)

    mismatch = torch.norm(lhs - rhs)
    assert mismatch < norm_tol
def test_mrisensenufft_2d_adjoint():
    """Check the inner-product identity for the 2D MriSenseNufft pair.

    A single-channel random image, multi-coil random k-space data and random
    sensitivity maps are generated.  ``MriSenseNufft`` and
    ``AdjMriSenseNufft`` are applied with precomputed sparse interpolation
    matrices, and the two inner products must agree to within the externally
    defined ``norm_tol``.
    """
    dtype = torch.double
    nslice, ncoil = 2, 4
    im_size = (33, 24)
    klength = 112

    # complex image with a singleton coil axis, real/imag stacked on axis 2
    image_shape = (nslice, 1) + im_size
    image_np = np.random.normal(size=image_shape) \
        + 1j * np.random.normal(size=image_shape)
    image = torch.tensor(
        np.stack((image_np.real, image_np.imag), axis=2)
    ).to(dtype)

    kdata_shape = (nslice, ncoil, klength)
    kdata_np = np.random.normal(size=kdata_shape) \
        + 1j * np.random.normal(size=kdata_shape)
    kdata = torch.tensor(
        np.stack((kdata_np.real, kdata_np.imag), axis=2)
    ).to(dtype)

    traj = torch.randn(nslice, 2, klength).to(dtype)

    smap_sz = (nslice, ncoil, 2) + im_size
    smap = torch.randn(*smap_sz).to(dtype)

    forw_op = MriSenseNufft(smap=smap, im_size=im_size)
    adj_op = AdjMriSenseNufft(smap=smap, im_size=im_size)

    re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
    sparse_mats = {
        "real_interp_mats": re_mat,
        "imag_interp_mats": im_mat,
    }

    forw_out = forw_op(image, traj, sparse_mats)
    adj_out = adj_op(kdata, traj, sparse_mats)

    lhs = inner_product(kdata, forw_out, dim=2)
    rhs = inner_product(adj_out, image, dim=2)

    mismatch = torch.norm(lhs - rhs)
    assert mismatch < norm_tol
def test_2d_kbnufft_backward():
    """Compare autograd through KbNufft with an explicit AdjKbNufft call (2D).

    The loss ``sum(y**2) / 2`` with ``y = KbNufft(x)`` is backpropagated to
    ``x``.  The resulting gradient must match ``AdjKbNufft`` applied to the
    detached ``y`` to within the externally defined ``norm_tol``.
    """
    dtype = torch.double
    nslice, ncoil = 2, 4
    im_size = (33, 24)
    klength = 112

    image_shape = (nslice, ncoil) + im_size
    image_np = np.random.normal(size=image_shape) \
        + 1j * np.random.normal(size=image_shape)
    image = torch.tensor(
        np.stack((image_np.real, image_np.imag), axis=2)
    ).to(dtype)

    # this random k-space tensor is replaced by the forward output below;
    # it is still drawn so the random number streams advance as before
    kdata_shape = (nslice, ncoil, klength)
    kdata_np = np.random.normal(size=kdata_shape) \
        + 1j * np.random.normal(size=kdata_shape)
    kdata = torch.tensor(
        np.stack((kdata_np.real, kdata_np.imag), axis=2)
    ).to(dtype)

    traj = torch.randn(nslice, 2, klength).to(dtype)

    forw_op = KbNufft(im_size=im_size, numpoints=(4, 6))
    adj_op = AdjKbNufft(im_size=im_size, numpoints=(4, 6))

    re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
    sparse_mats = {
        "real_interp_mats": re_mat,
        "imag_interp_mats": im_mat,
    }

    image.requires_grad = True
    kdata = forw_op.forward(image, traj, sparse_mats)

    loss = ((kdata ** 2) / 2).sum()
    loss.backward()
    autograd_grad = image.grad.clone().detach()

    explicit_grad = adj_op.forward(kdata.clone().detach(), traj, sparse_mats)

    mismatch = torch.norm(autograd_grad - explicit_grad)
    assert mismatch < norm_tol
def test_2d_interp_adjoint_backward(
    params_2d, testing_tol, testing_dtype, device_list
):
    """Compare autograd through KbInterpBack with an explicit KbInterpForw call.

    For every device in ``device_list`` the loss ``sum(x**2) / 2`` with
    ``x = KbInterpBack(y)`` is backpropagated to ``y``.  The gradient must
    match ``KbInterpForw`` applied to the detached ``x`` to within
    ``testing_tol``.
    """
    batch_size = params_2d["batch_size"]
    im_size = params_2d["im_size"]
    grid_size = params_2d["grid_size"]
    numpoints = params_2d["numpoints"]

    # random complex grid data, real/imag stacked along axis 2
    grid_shape = (batch_size, 1) + grid_size
    grid_np = np.random.normal(size=grid_shape) \
        + 1j * np.random.normal(size=grid_shape)
    grid_data = torch.tensor(np.stack((grid_np.real, grid_np.imag), axis=2))

    kdata = params_2d["y"]
    traj = params_2d["ktraj"]

    for device in device_list:
        placement = {"dtype": testing_dtype, "device": device}

        grid_data = grid_data.detach().to(**placement)
        kdata = kdata.detach().to(**placement)
        traj = traj.detach().to(**placement)

        forw_op = KbInterpForw(
            im_size=im_size, grid_size=grid_size, numpoints=numpoints
        ).to(**placement)
        adj_op = KbInterpBack(
            im_size=im_size, grid_size=grid_size, numpoints=numpoints
        ).to(**placement)

        re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
        sparse_mats = {
            "real_interp_mats": re_mat,
            "imag_interp_mats": im_mat,
        }

        kdata.requires_grad = True
        grid_data = adj_op.forward(kdata, traj, sparse_mats)

        loss = ((grid_data ** 2) / 2).sum()
        loss.backward()
        autograd_grad = kdata.grad.clone().detach()

        explicit_grad = forw_op.forward(
            grid_data.clone().detach(), traj, sparse_mats
        )

        mismatch = torch.norm(autograd_grad - explicit_grad)
        assert mismatch < testing_tol
def test_3d_mrisensenufft_coilpack_adjoint_backward(
    params_2d, testing_tol, testing_dtype, device_list
):
    """Compare autograd through coil-packed AdjMriSenseNufft with MriSenseNufft.

    For every device in ``device_list`` the loss ``sum(x**2) / 2`` with
    ``x = AdjMriSenseNufft(y)`` is backpropagated to ``y``.  The gradient
    must match ``MriSenseNufft`` applied to the detached ``x`` to within
    ``testing_tol``.  Both operators are built with ``coilpack=True``.

    NOTE(review): the test name says "3d" but the data comes from the
    ``params_2d`` fixture -- confirm whether ``params_3d`` was intended.
    """
    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]
    image = params_2d["x"]
    kdata = params_2d["y"]
    traj = params_2d["ktraj"]
    smap = params_2d["smap"]

    for device in device_list:
        placement = {"dtype": testing_dtype, "device": device}

        image = image.detach().to(**placement)
        kdata = kdata.detach().to(**placement)
        traj = traj.detach().to(**placement)

        forw_op = MriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints, coilpack=True
        ).to(**placement)
        adj_op = AdjMriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints, coilpack=True
        ).to(**placement)

        re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
        sparse_mats = {
            "real_interp_mats": re_mat,
            "imag_interp_mats": im_mat,
        }

        kdata.requires_grad = True
        image = adj_op.forward(kdata, traj, sparse_mats)

        loss = ((image ** 2) / 2).sum()
        loss.backward()
        autograd_grad = kdata.grad.clone().detach()

        explicit_grad = forw_op.forward(
            image.clone().detach(), traj, sparse_mats
        )

        mismatch = torch.norm(autograd_grad - explicit_grad)
        assert mismatch < testing_tol
def test_3d_mrisensenufft_coilpack_backward(
    params_2d, testing_tol, testing_dtype, device_list
):
    """Compare autograd through coil-packed MriSenseNufft with AdjMriSenseNufft.

    For every device in ``device_list`` the loss ``sum(y**2) / 2`` with
    ``y = MriSenseNufft(x)`` is backpropagated to ``x``.  The gradient must
    match ``AdjMriSenseNufft`` applied to the detached ``y`` to within
    ``testing_tol``.  Both operators are built with ``coilpack=True``.

    NOTE(review): the test name says "3d" but the data comes from the
    ``params_2d`` fixture -- confirm whether ``params_3d`` was intended.
    """
    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]
    image = params_2d["x"]
    kdata = params_2d["y"]
    traj = params_2d["ktraj"]
    smap = params_2d["smap"]

    for device in device_list:
        placement = {"dtype": testing_dtype, "device": device}

        image = image.detach().to(**placement)
        kdata = kdata.detach().to(**placement)
        traj = traj.detach().to(**placement)

        forw_op = MriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints, coilpack=True
        ).to(**placement)
        adj_op = AdjMriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints, coilpack=True
        ).to(**placement)

        re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
        sparse_mats = {
            "real_interp_mats": re_mat,
            "imag_interp_mats": im_mat,
        }

        image.requires_grad = True
        kdata = forw_op.forward(image, traj, sparse_mats)

        loss = ((kdata ** 2) / 2).sum()
        loss.backward()
        autograd_grad = image.grad.clone().detach()

        explicit_grad = adj_op.forward(
            kdata.clone().detach(), traj, sparse_mats
        )

        mismatch = torch.norm(autograd_grad - explicit_grad)
        assert mismatch < testing_tol
def test_3d_interp_backward(
    params_3d, testing_tol, testing_dtype, device_list
):
    """Compare autograd through KbInterpForw with an explicit KbInterpBack call.

    For every device in ``device_list`` the loss ``sum(y**2) / 2`` with
    ``y = KbInterpForw(x)`` is backpropagated to ``x``.  The gradient must
    match ``KbInterpBack`` applied to the detached ``y`` to within
    ``testing_tol``.
    """
    batch_size = params_3d["batch_size"]
    im_size = params_3d["im_size"]
    grid_size = params_3d["grid_size"]
    numpoints = params_3d["numpoints"]

    # random complex grid data, real/imag stacked along axis 2
    grid_shape = (batch_size, 1) + grid_size
    grid_np = np.random.normal(size=grid_shape) \
        + 1j * np.random.normal(size=grid_shape)
    grid_data = torch.tensor(np.stack((grid_np.real, grid_np.imag), axis=2))

    kdata = params_3d["y"]
    traj = params_3d["ktraj"]

    for device in device_list:
        placement = {"dtype": testing_dtype, "device": device}

        grid_data = grid_data.detach().to(**placement)
        kdata = kdata.detach().to(**placement)
        traj = traj.detach().to(**placement)

        forw_op = KbInterpForw(
            im_size=im_size, grid_size=grid_size, numpoints=numpoints
        ).to(**placement)
        adj_op = KbInterpBack(
            im_size=im_size, grid_size=grid_size, numpoints=numpoints
        ).to(**placement)

        re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
        sparse_mats = {
            "real_interp_mats": re_mat,
            "imag_interp_mats": im_mat,
        }

        grid_data.requires_grad = True
        kdata = forw_op.forward(grid_data, traj, sparse_mats)

        loss = ((kdata ** 2) / 2).sum()
        loss.backward()
        autograd_grad = grid_data.grad.clone().detach()

        explicit_grad = adj_op.forward(
            kdata.clone().detach(), traj, sparse_mats
        )

        mismatch = torch.norm(autograd_grad - explicit_grad)
        assert mismatch < testing_tol
def test_interp_2d_adjoint(
    params_2d, testing_tol, testing_dtype, device_list
):
    """Check the inner-product identity for the 2D KbInterpForw/KbInterpBack pair.

    For every device in ``device_list``, random grid data is interpolated
    with ``KbInterpForw`` and the fixture k-space data is gridded with
    ``KbInterpBack``, both using precomputed sparse matrices.  The two inner
    products must agree to within ``testing_tol``.
    """
    batch_size = params_2d["batch_size"]
    im_size = params_2d["im_size"]
    grid_size = params_2d["grid_size"]
    numpoints = params_2d["numpoints"]

    # random complex grid data, real/imag stacked along axis 2
    grid_shape = (batch_size, 1) + grid_size
    grid_np = np.random.normal(size=grid_shape) \
        + 1j * np.random.normal(size=grid_shape)
    grid_data = torch.tensor(np.stack((grid_np.real, grid_np.imag), axis=2))

    kdata = params_2d["y"]
    traj = params_2d["ktraj"]

    for device in device_list:
        placement = {"dtype": testing_dtype, "device": device}

        grid_data = grid_data.detach().to(**placement)
        kdata = kdata.detach().to(**placement)
        traj = traj.detach().to(**placement)

        forw_op = KbInterpForw(
            im_size=im_size, grid_size=grid_size, numpoints=numpoints
        ).to(**placement)
        adj_op = KbInterpBack(
            im_size=im_size, grid_size=grid_size, numpoints=numpoints
        ).to(**placement)

        re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
        sparse_mats = {
            "real_interp_mats": re_mat,
            "imag_interp_mats": im_mat,
        }

        forw_out = forw_op(grid_data, traj, sparse_mats)
        adj_out = adj_op(kdata, traj, sparse_mats)

        lhs = inner_product(kdata, forw_out, dim=2)
        rhs = inner_product(adj_out, grid_data, dim=2)

        mismatch = torch.norm(lhs - rhs)
        assert mismatch < testing_tol
def test_mrisensenufft_3d_coilpack_adjoint(
    params_2d, testing_tol, testing_dtype, device_list
):
    """Check the inner-product identity for the coil-packed MriSenseNufft pair.

    For every device in ``device_list``, ``MriSenseNufft`` and
    ``AdjMriSenseNufft`` (both built with ``coilpack=True``) are applied with
    precomputed sparse matrices, and the two inner products must agree to
    within ``testing_tol``.

    NOTE(review): the test name says "3d" but the data comes from the
    ``params_2d`` fixture -- confirm whether ``params_3d`` was intended.
    """
    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]
    image = params_2d["x"]
    kdata = params_2d["y"]
    traj = params_2d["ktraj"]
    smap = params_2d["smap"]

    for device in device_list:
        placement = {"dtype": testing_dtype, "device": device}

        image = image.detach().to(**placement)
        kdata = kdata.detach().to(**placement)
        traj = traj.detach().to(**placement)

        forw_op = MriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints, coilpack=True
        ).to(**placement)
        adj_op = AdjMriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints, coilpack=True
        ).to(**placement)

        re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
        sparse_mats = {
            "real_interp_mats": re_mat,
            "imag_interp_mats": im_mat,
        }

        forw_out = forw_op(image, traj, sparse_mats)
        adj_out = adj_op(kdata, traj, sparse_mats)

        lhs = inner_product(kdata, forw_out, dim=2)
        rhs = inner_product(adj_out, image, dim=2)

        mismatch = torch.norm(lhs - rhs)
        assert mismatch < testing_tol
def test_mrisensenufft_2d_adjoint(
    params_2d, testing_tol, testing_dtype, device_list
):
    """Check the inner-product identity for the 2D MriSenseNufft pair.

    For every device in ``device_list``, ``MriSenseNufft`` and
    ``AdjMriSenseNufft`` are applied to the fixture data with precomputed
    sparse matrices, and the two inner products must agree to within
    ``testing_tol``.
    """
    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]
    image = params_2d["x"]
    kdata = params_2d["y"]
    traj = params_2d["ktraj"]
    smap = params_2d["smap"]

    for device in device_list:
        placement = {"dtype": testing_dtype, "device": device}

        image = image.detach().to(**placement)
        kdata = kdata.detach().to(**placement)
        traj = traj.detach().to(**placement)

        forw_op = MriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints
        ).to(**placement)
        adj_op = AdjMriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints
        ).to(**placement)

        re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
        sparse_mats = {
            "real_interp_mats": re_mat,
            "imag_interp_mats": im_mat,
        }

        forw_out = forw_op(image, traj, sparse_mats)
        adj_out = adj_op(kdata, traj, sparse_mats)

        lhs = inner_product(kdata, forw_out, dim=2)
        rhs = inner_product(adj_out, image, dim=2)

        mismatch = torch.norm(lhs - rhs)
        assert mismatch < testing_tol
def test_3d_kbnufft_backward(
    params_3d, testing_tol, testing_dtype, device_list
):
    """Compare autograd through KbNufft with an explicit AdjKbNufft call (3D).

    For every device in ``device_list`` the loss ``sum(y**2) / 2`` with
    ``y = KbNufft(x)`` is backpropagated to ``x``.  The gradient must match
    ``AdjKbNufft`` applied to the detached ``y`` to within ``testing_tol``.
    """
    im_size = params_3d["im_size"]
    numpoints = params_3d["numpoints"]
    image = params_3d["x"]
    kdata = params_3d["y"]
    traj = params_3d["ktraj"]

    for device in device_list:
        placement = {"dtype": testing_dtype, "device": device}

        image = image.detach().to(**placement)
        kdata = kdata.detach().to(**placement)
        traj = traj.detach().to(**placement)

        forw_op = KbNufft(im_size=im_size, numpoints=numpoints).to(**placement)
        adj_op = AdjKbNufft(
            im_size=im_size, numpoints=numpoints
        ).to(**placement)

        re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
        sparse_mats = {
            "real_interp_mats": re_mat,
            "imag_interp_mats": im_mat,
        }

        image.requires_grad = True
        kdata = forw_op.forward(image, traj, sparse_mats)

        loss = ((kdata ** 2) / 2).sum()
        loss.backward()
        autograd_grad = image.grad.clone().detach()

        explicit_grad = adj_op.forward(
            kdata.clone().detach(), traj, sparse_mats
        )

        mismatch = torch.norm(autograd_grad - explicit_grad)
        assert mismatch < testing_tol
def test_3d_kbnufft_adjoint_backward(
    params_3d, testing_tol, testing_dtype, device_list
):
    """Compare autograd through AdjKbNufft with an explicit KbNufft call (3D).

    For every device in ``device_list`` the loss ``sum(x**2) / 2`` with
    ``x = AdjKbNufft(y)`` is backpropagated to ``y``.  The gradient must
    match ``KbNufft`` applied to the detached ``x`` to within ``testing_tol``.
    """
    im_size = params_3d["im_size"]
    numpoints = params_3d["numpoints"]
    image = params_3d["x"]
    kdata = params_3d["y"]
    traj = params_3d["ktraj"]

    for device in device_list:
        placement = {"dtype": testing_dtype, "device": device}

        image = image.detach().to(**placement)
        kdata = kdata.detach().to(**placement)
        traj = traj.detach().to(**placement)

        forw_op = KbNufft(im_size=im_size, numpoints=numpoints).to(**placement)
        adj_op = AdjKbNufft(
            im_size=im_size, numpoints=numpoints
        ).to(**placement)

        re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
        sparse_mats = {
            "real_interp_mats": re_mat,
            "imag_interp_mats": im_mat,
        }

        kdata.requires_grad = True
        image = adj_op.forward(kdata, traj, sparse_mats)

        loss = ((image ** 2) / 2).sum()
        loss.backward()
        autograd_grad = kdata.grad.clone().detach()

        explicit_grad = forw_op.forward(
            image.clone().detach(), traj, sparse_mats
        )

        mismatch = torch.norm(autograd_grad - explicit_grad)
        assert mismatch < testing_tol
def test_2d_interp_adjoint_backward():
    """Compare autograd through KbInterpBack with an explicit KbInterpForw call.

    The loss ``sum(x**2) / 2`` with ``x = KbInterpBack(y)`` is backpropagated
    to ``y``.  The gradient must match ``KbInterpForw`` applied to the
    detached ``x`` to within the externally defined ``norm_tol``.
    """
    dtype = torch.double
    nslice, ncoil = 2, 4
    im_size = (33, 24)
    klength = 112

    # this random grid tensor is replaced by the adjoint output below;
    # it is still drawn so the random number streams advance as before
    grid_shape = (nslice, ncoil) + im_size
    grid_np = np.random.normal(size=grid_shape) \
        + 1j * np.random.normal(size=grid_shape)
    grid_data = torch.tensor(
        np.stack((grid_np.real, grid_np.imag), axis=2)
    ).to(dtype)

    kdata_shape = (nslice, ncoil, klength)
    kdata_np = np.random.normal(size=kdata_shape) \
        + 1j * np.random.normal(size=kdata_shape)
    kdata = torch.tensor(
        np.stack((kdata_np.real, kdata_np.imag), axis=2)
    ).to(dtype)

    traj = torch.randn(nslice, 2, klength).to(dtype)

    # the local ``im_size`` is passed as the interpolation grid size here
    forw_op = KbInterpForw(
        im_size=(20, 25), grid_size=im_size, numpoints=(4, 6)
    )
    adj_op = KbInterpBack(
        im_size=(20, 25), grid_size=im_size, numpoints=(4, 6)
    )

    re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
    sparse_mats = {
        "real_interp_mats": re_mat,
        "imag_interp_mats": im_mat,
    }

    kdata.requires_grad = True
    grid_data = adj_op.forward(kdata, traj, sparse_mats)

    loss = ((grid_data ** 2) / 2).sum()
    loss.backward()
    autograd_grad = kdata.grad.clone().detach()

    explicit_grad = forw_op.forward(
        grid_data.clone().detach(), traj, sparse_mats
    )

    mismatch = torch.norm(autograd_grad - explicit_grad)
    assert mismatch < norm_tol
def test_nufft_2d_adjoint(
    params_2d, testing_tol, testing_dtype, device_list
):
    """Check the inner-product identity for the 2D KbNufft/AdjKbNufft pair.

    For every device in ``device_list``, ``KbNufft`` and ``AdjKbNufft`` are
    applied to the fixture data with precomputed sparse matrices, and the two
    inner products must agree to within ``testing_tol``.
    """
    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]
    image = params_2d["x"]
    kdata = params_2d["y"]
    traj = params_2d["ktraj"]

    for device in device_list:
        placement = {"dtype": testing_dtype, "device": device}

        image = image.detach().to(**placement)
        kdata = kdata.detach().to(**placement)
        traj = traj.detach().to(**placement)

        forw_op = KbNufft(im_size=im_size, numpoints=numpoints).to(**placement)
        adj_op = AdjKbNufft(
            im_size=im_size, numpoints=numpoints
        ).to(**placement)

        re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
        sparse_mats = {
            "real_interp_mats": re_mat,
            "imag_interp_mats": im_mat,
        }

        forw_out = forw_op(image, traj, sparse_mats)
        adj_out = adj_op(kdata, traj, sparse_mats)

        lhs = inner_product(kdata, forw_out, dim=2)
        rhs = inner_product(adj_out, image, dim=2)

        mismatch = torch.norm(lhs - rhs)
        assert mismatch < testing_tol
def test_3d_mrisensenufft_adjoint_backward():
    """Compare autograd through AdjMriSenseNufft with MriSenseNufft (3D).

    The loss ``sum(x**2) / 2`` with ``x = AdjMriSenseNufft(y)`` is
    backpropagated to ``y``.  The gradient must match ``MriSenseNufft``
    applied to the detached ``x`` to within the externally defined
    ``norm_tol``.
    """
    dtype = torch.double
    nslice, ncoil = 2, 4
    im_size = (11, 33, 24)
    klength = 112

    # this random image tensor is replaced by the adjoint output below;
    # it is still drawn so the random number streams advance as before
    image_shape = (nslice, 1) + im_size
    image_np = np.random.normal(size=image_shape) \
        + 1j * np.random.normal(size=image_shape)
    image = torch.tensor(
        np.stack((image_np.real, image_np.imag), axis=2)
    ).to(dtype)

    kdata_shape = (nslice, ncoil, klength)
    kdata_np = np.random.normal(size=kdata_shape) \
        + 1j * np.random.normal(size=kdata_shape)
    kdata = torch.tensor(
        np.stack((kdata_np.real, kdata_np.imag), axis=2)
    ).to(dtype)

    traj = torch.randn(nslice, 3, klength).to(dtype)

    smap_sz = (nslice, ncoil, 2) + im_size
    smap = torch.randn(*smap_sz).to(dtype)

    forw_op = MriSenseNufft(smap=smap, im_size=im_size)
    adj_op = AdjMriSenseNufft(smap=smap, im_size=im_size)

    re_mat, im_mat = precomp_sparse_mats(traj, forw_op)
    sparse_mats = {
        "real_interp_mats": re_mat,
        "imag_interp_mats": im_mat,
    }

    kdata.requires_grad = True
    image = adj_op.forward(kdata, traj, sparse_mats)

    loss = ((image ** 2) / 2).sum()
    loss.backward()
    autograd_grad = kdata.grad.clone().detach()

    explicit_grad = forw_op.forward(image.clone().detach(), traj, sparse_mats)

    mismatch = torch.norm(autograd_grad - explicit_grad)
    assert mismatch < norm_tol
def _time_profiled_op(run_once, num_iters, device, mem_label, time_label):
    """Time ``num_iters`` calls of ``run_once`` and print the average.

    On a CUDA device the peak-memory counter is reset beforehand, the device
    is synchronized before the clock is started and stopped, and the peak
    allocation is printed using ``mem_label``.

    Args:
        run_once: Zero-argument callable executing one operation.
        num_iters: Number of timed repetitions.
        device: ``torch.device`` the operation runs on.
        mem_label: Format string for the peak-memory message (GB).
        time_label: Format string for the average-time message.

    Returns:
        The value returned by the last call of ``run_once``.
    """
    on_cuda = device.type == "cuda"

    if on_cuda:
        torch.cuda.reset_max_memory_allocated(device)
        torch.cuda.synchronize(device)

    result = None
    start_time = time.perf_counter()
    for _ in range(num_iters):
        result = run_once()
    if on_cuda:
        # CUDA kernels are launched asynchronously; wait for them to finish
        # so the elapsed time covers the actual computation
        torch.cuda.synchronize(device)
    # stop the clock before any reporting so printing is not timed
    end_time = time.perf_counter()

    if on_cuda:
        max_mem = torch.cuda.max_memory_allocated(device)
        print(mem_label.format(max_mem / 1e9))

    avg_time = (end_time - start_time) / num_iters
    print(time_label.format(avg_time))

    return result


def profile_torchkbnufft(image, ktraj, smap, im_size, device,
                         sparse_mats_flag=False, use_toep=False):
    """Print average run times (and peak GPU memory) of the SENSE-NUFFT ops.

    With ``use_toep`` set, a Toeplitz kernel is precomputed and the combined
    forward/backward ``ToepSenseNufft`` operation is timed.  Otherwise the
    forward ``MriSenseNufft`` and the adjoint ``AdjMriSenseNufft`` are timed
    one after the other.  Every timed loop is preceded by a warm-up loop of
    the same length.

    Args:
        image: Image tensor fed to the forward operator.
        ktraj: k-space trajectory tensor.
        smap: Sensitivity map tensor used to build the SENSE operators.
        im_size: Image size passed to the operator constructors.
        device: Device to profile on, as a ``torch.device`` or anything
            ``torch.device`` accepts (e.g. ``"cuda:0"``).  Indexed CUDA
            devices are recognised as CUDA, so synchronization and memory
            reporting also happen for them.
        sparse_mats_flag: If True, precompute sparse interpolation matrices
            and pass them to the forward/adjoint operators.
        use_toep: If True, profile the Toeplitz forward/backward operator
            instead of the separate forward and adjoint operators.

    Returns:
        None.  Results are reported with ``print``.
    """
    # normalise so that strings and indexed devices ("cuda:0") are handled;
    # comparing against torch.device("cuda") would miss indexed devices
    device = torch.device(device)
    cpudevice = torch.device("cpu")

    # run double precision for CPU, float for GPU
    # these seem to be present in reference implementations
    if device.type == "cpu":
        dtype = torch.double
        num_nuffts = 20 if use_toep else 5
    else:
        dtype = torch.float
        num_nuffts = 50 if use_toep else 20

    image = image.to(dtype=dtype)
    ktraj = ktraj.to(dtype=dtype)
    smap = smap.to(dtype=dtype)

    kbsense_ob = MriSenseNufft(
        smap=smap, im_size=im_size).to(dtype=dtype, device=device)
    adjkbsense_ob = AdjMriSenseNufft(
        smap=smap, im_size=im_size).to(dtype=dtype, device=device)
    adjkbnufft_ob = AdjKbNufft(
        im_size=im_size).to(dtype=dtype, device=device)

    # precompute toeplitz kernel if using toeplitz
    if use_toep:
        print("using toeplitz for forward/backward")
        kern = calc_toep_kernel(adjkbsense_ob, ktraj)
        toep_ob = ToepSenseNufft(smap=smap).to(dtype=dtype, device=device)

    # precompute the sparse interpolation matrices
    if sparse_mats_flag:
        print("using sparse interpolation matrices")
        real_mat, imag_mat = precomp_sparse_mats(ktraj, adjkbnufft_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }
    else:
        print("not using sparse interpolation matrices")
        interp_mats = None

    if use_toep:
        def run_toep():
            return toep_ob(image.to(device=device), kern.to(device=device))

        # warm-up computation
        for _ in range(num_nuffts):
            run_toep().to(cpudevice)

        # run the speed tests
        _time_profiled_op(
            run_toep, num_nuffts, device,
            "GPU forward max memory: {} GB",
            "toeplitz forward/backward average time: {}",
        )
        return

    def run_forward():
        return kbsense_ob(
            image.to(device=device), ktraj.to(device=device), interp_mats)

    # warm-up computation
    for _ in range(num_nuffts):
        run_forward().to(cpudevice)

    # run the forward speed tests; the last output feeds the adjoint runs
    y = _time_profiled_op(
        run_forward, num_nuffts, device,
        "GPU forward max memory: {} GB",
        "forward average time: {}",
    )

    def run_adjoint():
        return adjkbsense_ob(y.to(device), ktraj.to(device), interp_mats)

    # warm-up computation
    for _ in range(num_nuffts):
        run_adjoint()

    # run the adjoint speed tests
    _time_profiled_op(
        run_adjoint, num_nuffts, device,
        "GPU adjoint max memory: {} GB",
        "backward average time: {}",
    )