def test_interp_adjoint_gpu(shape, kdata_shape, is_complex):
    """Check the KbInterp / KbInterpAdjoint adjoint relation on a CUDA device.

    Runs in double precision with a fixed seed and passes a sparse
    interpolation matrix (from ``tkbn.calc_tensor_spmatrix``) to
    ``nufft_adjoint_test``. Skipped when CUDA is unavailable.

    Args:
        shape: Shape of the image-domain input; ``shape[2:]`` (complex) or
            ``shape[2:-1]`` (real) gives the spatial image size.
        kdata_shape: Shape of the k-space input; ``kdata_shape[2]`` is used
            as the number of trajectory points.
        is_complex: Whether the inputs are complex-valued tensors.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    device = torch.device("cuda")

    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    # Restore the global default dtype even if an assertion fails; otherwise
    # every subsequent test in the session silently runs in double precision.
    try:
        torch.manual_seed(123)
        # For real-valued inputs the last dimension of ``shape`` is not a
        # spatial dimension (presumably a real/imag channel axis — confirm).
        im_size = shape[2:] if is_complex else shape[2:-1]

        image = create_input_plus_noise(shape, is_complex).to(device)
        kdata = create_input_plus_noise(kdata_shape, is_complex).to(device)
        ktraj = create_ktraj(len(im_size), kdata_shape[2]).to(device)

        forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size).to(device)
        adj_ob = tkbn.KbInterpAdjoint(im_size=im_size, grid_size=im_size).to(
            device
        )

        # test with sparse matrices
        spmat = tkbn.calc_tensor_spmatrix(
            ktraj,
            im_size,
            grid_size=im_size,
        )

        nufft_adjoint_test(image, kdata, ktraj, forw_ob, adj_ob, spmat)
    finally:
        torch.set_default_dtype(default_dtype)
def test_interp_autograd(shape, kdata_shape, is_complex):
    """Check autograd through KbInterp / KbInterpAdjoint on the CPU.

    Runs in double precision with a fixed seed and passes a sparse
    interpolation matrix (from ``tkbn.calc_tensor_spmatrix``) to
    ``nufft_autograd_test``.

    Args:
        shape: Shape of the image-domain input; ``shape[2:]`` (complex) or
            ``shape[2:-1]`` (real) gives the spatial image size.
        kdata_shape: Shape of the k-space input; ``kdata_shape[2]`` is used
            as the number of trajectory points.
        is_complex: Whether the inputs are complex-valued tensors.
    """
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    # Restore the global default dtype even if an assertion fails; otherwise
    # every subsequent test in the session silently runs in double precision.
    try:
        torch.manual_seed(123)
        # For real-valued inputs the last dimension of ``shape`` is not a
        # spatial dimension (presumably a real/imag channel axis — confirm).
        im_size = shape[2:] if is_complex else shape[2:-1]

        image = create_input_plus_noise(shape, is_complex)
        kdata = create_input_plus_noise(kdata_shape, is_complex)
        ktraj = create_ktraj(len(im_size), kdata_shape[2])

        forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size)
        adj_ob = tkbn.KbInterpAdjoint(im_size=im_size, grid_size=im_size)

        # test with sparse matrices
        spmat = tkbn.calc_tensor_spmatrix(
            ktraj,
            im_size,
            grid_size=im_size,
        )

        nufft_autograd_test(image, kdata, ktraj, forw_ob, adj_ob, spmat)
    finally:
        torch.set_default_dtype(default_dtype)
def test_interp_batches(shape, kdata_shape, is_complex):
    """Check that batched interpolation matches per-item interpolation.

    A separate random trajectory is drawn for each of the ``shape[0]`` batch
    entries. The forward (KbInterp) and adjoint (KbInterpAdjoint) operators
    are applied once to the whole batch and once per batch item. The
    concatenated per-item results must be ``allclose`` to the batched ones.

    Args:
        shape: Shape of the image-domain input; ``shape[0]`` is the batch
            size and ``shape[2:]`` (complex) or ``shape[2:-1]`` (real) gives
            the spatial image size.
        kdata_shape: Shape of the k-space input; ``kdata_shape[2]`` is used
            as the number of trajectory points.
        is_complex: Whether the inputs are complex-valued tensors.
    """
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    # Restore the global default dtype even if an assertion fails; otherwise
    # every subsequent test in the session silently runs in double precision.
    try:
        torch.manual_seed(123)
        # For real-valued inputs the last dimension of ``shape`` is not a
        # spatial dimension (presumably a real/imag channel axis — confirm).
        im_size = shape[2:] if is_complex else shape[2:-1]

        image = create_input_plus_noise(shape, is_complex)
        kdata = create_input_plus_noise(kdata_shape, is_complex)
        # One trajectory per batch entry, uniformly distributed in [-pi, pi).
        ktraj = (
            torch.rand(size=(shape[0], len(im_size), kdata_shape[2])) * 2 * np.pi
            - np.pi
        )

        forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size)
        adj_ob = tkbn.KbInterpAdjoint(im_size=im_size, grid_size=im_size)

        # Forward: per-item calls (re-adding the batch dim) vs one batched call.
        forloop_test_forw = [
            forw_ob(image_it.unsqueeze(0), ktraj_it)
            for image_it, ktraj_it in zip(image, ktraj)
        ]
        batched_test_forw = forw_ob(image, ktraj)
        assert torch.allclose(torch.cat(forloop_test_forw), batched_test_forw)

        # Adjoint: per-item calls vs one batched call.
        forloop_test_adj = [
            adj_ob(data_it.unsqueeze(0), ktraj_it)
            for data_it, ktraj_it in zip(kdata, ktraj)
        ]
        batched_test_adj = adj_ob(kdata, ktraj)
        assert torch.allclose(torch.cat(forloop_test_adj), batched_test_adj)
    finally:
        torch.set_default_dtype(default_dtype)