# Example #1
0
def test_interp_adjoint_gpu(shape, kdata_shape, is_complex):
    """Run the KbInterp / KbInterpAdjoint adjoint check on a CUDA device.

    Skipped when CUDA is unavailable. The actual check is delegated to
    ``nufft_adjoint_test`` (defined elsewhere). It is also given the
    sparse-matrix data produced by ``tkbn.calc_tensor_spmatrix``.

    Args:
        shape: Image tensor shape. Spatial dims start at index 2. For
            non-complex inputs the trailing dim is excluded from ``im_size``
            (presumably a real/imag channel -- confirm against
            ``create_input_plus_noise``).
        kdata_shape: K-space data shape. ``kdata_shape[2]`` is used as the
            number of trajectory points.
        is_complex: Whether the generated inputs are complex tensors.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    device = torch.device("cuda")

    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    # Restore the process-wide default dtype even when the check fails;
    # otherwise a failure here leaks float64 defaults into later tests.
    try:
        torch.manual_seed(123)
        im_size = shape[2:] if is_complex else shape[2:-1]

        image = create_input_plus_noise(shape, is_complex).to(device)
        kdata = create_input_plus_noise(kdata_shape, is_complex).to(device)
        ktraj = create_ktraj(len(im_size), kdata_shape[2]).to(device)

        forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size).to(device)
        adj_ob = tkbn.KbInterpAdjoint(
            im_size=im_size, grid_size=im_size
        ).to(device)

        # test with sparse matrices
        spmat = tkbn.calc_tensor_spmatrix(ktraj, im_size, grid_size=im_size)

        nufft_adjoint_test(image, kdata, ktraj, forw_ob, adj_ob, spmat)
    finally:
        torch.set_default_dtype(default_dtype)
# Example #2
0
def test_interp_autograd(shape, kdata_shape, is_complex):
    """Run the KbInterp / KbInterpAdjoint autograd check on the CPU.

    The check itself is delegated to ``nufft_autograd_test`` (defined
    elsewhere). It is also given the sparse-matrix data produced by
    ``tkbn.calc_tensor_spmatrix``.

    Args:
        shape: Image tensor shape. Spatial dims start at index 2. For
            non-complex inputs the trailing dim is excluded from ``im_size``
            (presumably a real/imag channel -- confirm against
            ``create_input_plus_noise``).
        kdata_shape: K-space data shape. ``kdata_shape[2]`` is used as the
            number of trajectory points.
        is_complex: Whether the generated inputs are complex tensors.
    """
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    # Restore the process-wide default dtype even when the check fails;
    # otherwise a failure here leaks float64 defaults into later tests.
    try:
        torch.manual_seed(123)
        im_size = shape[2:] if is_complex else shape[2:-1]

        image = create_input_plus_noise(shape, is_complex)
        kdata = create_input_plus_noise(kdata_shape, is_complex)
        ktraj = create_ktraj(len(im_size), kdata_shape[2])

        forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size)
        adj_ob = tkbn.KbInterpAdjoint(im_size=im_size, grid_size=im_size)

        # test with sparse matrices
        spmat = tkbn.calc_tensor_spmatrix(ktraj, im_size, grid_size=im_size)

        nufft_autograd_test(image, kdata, ktraj, forw_ob, adj_ob, spmat)
    finally:
        torch.set_default_dtype(default_dtype)
# Example #3
0
def test_interp_batches(shape, kdata_shape, is_complex):
    """Check that batched KbInterp / KbInterpAdjoint calls match per-item calls.

    Each batch element goes through the forward and adjoint operators on its
    own, paired with its own trajectory. The concatenated per-item results
    are then compared with a single batched call via ``torch.allclose``.

    Args:
        shape: Image tensor shape. ``shape[0]`` is the batch size and spatial
            dims start at index 2. For non-complex inputs the trailing dim is
            excluded from ``im_size`` (presumably a real/imag channel --
            confirm against ``create_input_plus_noise``).
        kdata_shape: K-space data shape. ``kdata_shape[2]`` is used as the
            number of trajectory points.
        is_complex: Whether the generated inputs are complex tensors.
    """
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    # Restore the process-wide default dtype even when an assert fails;
    # otherwise a failure here leaks float64 defaults into later tests.
    try:
        torch.manual_seed(123)
        im_size = shape[2:] if is_complex else shape[2:-1]

        image = create_input_plus_noise(shape, is_complex)
        kdata = create_input_plus_noise(kdata_shape, is_complex)
        # One trajectory per batch element, drawn uniformly from [-pi, pi).
        ktraj = (
            torch.rand(size=(shape[0], len(im_size), kdata_shape[2]))
            * 2
            * np.pi
            - np.pi
        )

        forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size)
        adj_ob = tkbn.KbInterpAdjoint(im_size=im_size, grid_size=im_size)

        # Per-item forward calls keep a leading batch dim of 1 so that
        # concatenating them rebuilds the batched layout.
        forloop_test_forw = [
            forw_ob(image_it.unsqueeze(0), ktraj_it)
            for image_it, ktraj_it in zip(image, ktraj)
        ]
        batched_test_forw = forw_ob(image, ktraj)
        assert torch.allclose(torch.cat(forloop_test_forw), batched_test_forw)

        forloop_test_adj = [
            adj_ob(data_it.unsqueeze(0), ktraj_it)
            for data_it, ktraj_it in zip(kdata, ktraj)
        ]
        batched_test_adj = adj_ob(kdata, ktraj)
        assert torch.allclose(torch.cat(forloop_test_adj), batched_test_adj)
    finally:
        torch.set_default_dtype(default_dtype)