import time

import pytest
import torch
import torchkbnufft as tkbn

# create_input_plus_noise, create_ktraj, nufft_autograd_test, and
# nufft_adjoint_test are shared helpers assumed to be defined elsewhere in
# the test suite (e.g., its conftest).


def test_interp_autograd(shape, kdata_shape, is_complex):
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    torch.manual_seed(123)
    if is_complex:
        im_size = shape[2:]
    else:
        # real-valued inputs carry a trailing dimension of size 2 for the
        # (real, imag) parts, so drop it when reading off the image size
        im_size = shape[2:-1]

    image = create_input_plus_noise(shape, is_complex)
    kdata = create_input_plus_noise(kdata_shape, is_complex)
    ktraj = create_ktraj(len(im_size), kdata_shape[2])

    forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size)
    adj_ob = tkbn.KbInterpAdjoint(im_size=im_size, grid_size=im_size)

    # test with sparse matrices
    spmat = tkbn.calc_tensor_spmatrix(
        ktraj,
        im_size,
        grid_size=im_size,
    )

    nufft_autograd_test(image, kdata, ktraj, forw_ob, adj_ob, spmat)

    torch.set_default_dtype(default_dtype)
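
# nufft_autograd_test itself is not shown in this section. As a rough sketch
# of the kind of check such a helper performs, a minimal autograd smoke test
# might backpropagate through both operators; this is a hypothetical stand-in,
# not the repo's actual helper:
def autograd_smoke_test(image, kdata, ktraj, forw_ob, adj_ob, spmat):
    # backpropagate through the forward interpolator against target k-data
    image = image.detach().clone().requires_grad_(True)
    loss = (forw_ob(image, ktraj) - kdata).abs().pow(2).sum()
    loss.backward()
    assert image.grad is not None

    # and through the adjoint, exercising the sparse interpolation matrices
    kdata = kdata.detach().clone().requires_grad_(True)
    loss = adj_ob(kdata, ktraj, spmat).abs().pow(2).sum()
    loss.backward()
    assert kdata.grad is not None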

def test_interp_complex_real_match(shape, kdata_shape, is_complex):
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    torch.manual_seed(123)
    im_size = shape[2:]

    # torch.view_as_real below requires a complex input, so this test
    # expects is_complex to be parametrized as True
    image = create_input_plus_noise(shape, is_complex)
    ktraj = create_ktraj(len(im_size), kdata_shape[2])

    forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size)

    kdata_complex = forw_ob(image, ktraj)
    kdata_real = torch.view_as_complex(forw_ob(torch.view_as_real(image), ktraj))

    assert torch.allclose(kdata_complex, kdata_real)

    # test with sparse matrices
    spmat = tkbn.calc_tensor_spmatrix(
        ktraj,
        im_size,
        grid_size=im_size,
    )

    kdata_complex = forw_ob(image, ktraj, spmat)
    kdata_real = torch.view_as_complex(
        forw_ob(torch.view_as_real(image), ktraj, spmat)
    )

    assert torch.allclose(kdata_complex, kdata_real)

    torch.set_default_dtype(default_dtype)
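
# For context: the complex/real matching above relies on PyTorch's
# view_as_real/view_as_complex pair, which round-trips a complex tensor
# through a real tensor whose trailing dimension of size 2 holds the
# (real, imag) parts. A minimal illustration (not part of the original
# test suite):
def test_view_as_real_round_trip():
    z = torch.randn(3, 4, dtype=torch.complex128)
    r = torch.view_as_real(z)  # shape (3, 4, 2); last dim is (real, imag)
    assert r.shape == (3, 4, 2)
    assert torch.equal(torch.view_as_complex(r), z)  # lossless round trip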

def test_interp_adjoint_gpu(shape, kdata_shape, is_complex):
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    device = torch.device("cuda")
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    torch.manual_seed(123)
    if is_complex:
        im_size = shape[2:]
    else:
        im_size = shape[2:-1]

    image = create_input_plus_noise(shape, is_complex).to(device)
    kdata = create_input_plus_noise(kdata_shape, is_complex).to(device)
    ktraj = create_ktraj(len(im_size), kdata_shape[2]).to(device)

    forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size).to(device)
    adj_ob = tkbn.KbInterpAdjoint(im_size=im_size, grid_size=im_size).to(device)

    # test with sparse matrices (built from the on-device ktraj, so the
    # sparse tensors land on the GPU as well)
    spmat = tkbn.calc_tensor_spmatrix(
        ktraj,
        im_size,
        grid_size=im_size,
    )

    nufft_adjoint_test(image, kdata, ktraj, forw_ob, adj_ob, spmat)

    torch.set_default_dtype(default_dtype)
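
# nufft_adjoint_test is likewise a shared helper not shown here. The standard
# adjointness check it presumably implements verifies the inner-product
# identity <A x, y> = <x, A^H y>; a hypothetical minimal version:
def adjointness_check(image, kdata, ktraj, forw_ob, adj_ob, atol=1e-10):
    # <A x, y> and <x, A^H y> must agree for a true forward/adjoint pair
    lhs = (forw_ob(image, ktraj).conj() * kdata).sum()
    rhs = (image.conj() * adj_ob(kdata, ktraj)).sum()
    assert torch.allclose(lhs, rhs, atol=atol)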

def profile_torchkbnufft(
    image,
    ktraj,
    smap,
    im_size,
    grid_size,
    device,
    sparse_mats_flag=False,
    toep_flag=False,
):
    # run double precision on CPU, single precision on GPU; these choices
    # match common reference implementations
    if device == torch.device("cpu"):
        complex_dtype = torch.complex128
        real_dtype = torch.double
        num_nuffts = 20 if toep_flag else 5
    else:
        complex_dtype = torch.complex64
        real_dtype = torch.float
        num_nuffts = 50 if toep_flag else 20
    cpudevice = torch.device("cpu")

    res = ""
    image = image.to(dtype=complex_dtype)
    ktraj = ktraj.to(dtype=real_dtype)
    smap = smap.to(dtype=complex_dtype)
    interp_mats = None

    forw_ob = tkbn.KbNufft(
        im_size=im_size, grid_size=grid_size, dtype=complex_dtype, device=device
    )
    adj_ob = tkbn.KbNufftAdjoint(
        im_size=im_size, grid_size=grid_size, dtype=complex_dtype, device=device
    )

    # precompute toeplitz kernel if using toeplitz
    if toep_flag:
        kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size, grid_size=grid_size)
        toep_ob = tkbn.ToepNufft()

    # precompute the sparse interpolation matrices
    if sparse_mats_flag:
        interp_mats = tkbn.calc_tensor_spmatrix(ktraj, im_size, grid_size=grid_size)
        interp_mats = tuple(t.to(device) for t in interp_mats)

    if toep_flag:
        # warm-up computation
        for _ in range(num_nuffts):
            x = toep_ob(
                image.to(device=device),
                kernel.to(device=device),
                smaps=smap.to(device=device),
            ).to(cpudevice)

        # run the speed tests; note that the per-iteration host-to-device
        # copies are included in the timing
        if device == torch.device("cuda"):
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_nuffts):
            x = toep_ob(
                image.to(device=device),
                kernel.to(device=device),
                smaps=smap.to(device=device),
            )
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
            max_mem = torch.cuda.max_memory_allocated()
            res += "GPU toeplitz max memory: {} GB, ".format(max_mem / 1e9)
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_nuffts
        res += "toeplitz forward/backward average time: {}".format(avg_time)
    else:
        # warm-up computation
        for _ in range(num_nuffts):
            y = forw_ob(
                image.to(device=device),
                ktraj.to(device=device),
                interp_mats,
                smaps=smap.to(device=device),
            ).to(cpudevice)

        # run the forward speed tests
        if device == torch.device("cuda"):
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_nuffts):
            y = forw_ob(
                image.to(device=device),
                ktraj.to(device=device),
                interp_mats,
                smaps=smap.to(device=device),
            )
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
            max_mem = torch.cuda.max_memory_allocated()
            res += "GPU forward max memory: {} GB, ".format(max_mem / 1e9)
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_nuffts
        res += "forward average time: {}, ".format(avg_time)

        # warm-up computation
        for _ in range(num_nuffts):
            x = adj_ob(
                y.to(device), ktraj.to(device), interp_mats, smaps=smap.to(device)
            )

        # run the adjoint speed tests
        if device == torch.device("cuda"):
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_nuffts):
            x = adj_ob(
                y.to(device), ktraj.to(device), interp_mats, smaps=smap.to(device)
            )
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
            max_mem = torch.cuda.max_memory_allocated()
            res += "GPU adjoint max memory: {} GB, ".format(max_mem / 1e9)
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_nuffts
        res += "backward average time: {}".format(avg_time)

    print(res)
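
# A hypothetical driver showing how profile_torchkbnufft might be invoked.
# The image size, coil count, k-space length, and the uniform random
# trajectory below are illustrative assumptions, not the benchmark's actual
# configuration:
import math

if __name__ == "__main__":
    im_size = (256, 256)
    grid_size = tuple(2 * d for d in im_size)  # 2x oversampled grid
    nklocs = 405 * 512  # e.g., 405 spokes of 512 samples each

    image = torch.randn(1, 1, *im_size, dtype=torch.complex128)  # (batch, 1, H, W)
    smap = torch.randn(1, 8, *im_size, dtype=torch.complex128)  # 8-coil maps
    ktraj = (torch.rand(2, nklocs) - 0.5) * 2 * math.pi  # k-locs in [-pi, pi)

    for device in (torch.device("cpu"), torch.device("cuda")):
        if device.type == "cuda" and not torch.cuda.is_available():
            continue
        # plain table-based, sparse-matrix, and toeplitz variants
        for sparse_mats_flag, toep_flag in [(False, False), (True, False), (False, True)]:
            profile_torchkbnufft(
                image,
                ktraj,
                smap,
                im_size,
                grid_size,
                device,
                sparse_mats_flag=sparse_mats_flag,
                toep_flag=toep_flag,
            )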