# Example #1
# 0
def test_2d_kbnufft_backward(params_2d, testing_tol, testing_dtype,
                             device_list):
    """Check that autograd through ``KbNufft`` agrees with ``AdjKbNufft``.

    For each device, the forward NUFFT is applied to a leaf tensor ``x`` and
    the loss ``sum(y**2) / 2`` is backpropagated. The resulting ``x.grad`` is
    compared against ``AdjKbNufft`` applied to the forward output ``y``. The
    test passes when ``torch.norm(x_grad - x_hat)`` is below ``testing_tol``.

    Args:
        params_2d: Fixture dict; this test reads ``im_size``, ``numpoints``,
            ``x`` and ``ktraj`` from it.
        testing_tol: Upper bound on the norm of the gradient mismatch.
        testing_dtype: dtype the tensors and NUFFT modules are cast to.
        device_list: Devices on which the check is repeated.
    """
    # NOTE(review): a function with this same name is defined again later in
    # this file, and that later definition shadows this one at import time.
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]

    # Convert from the fixture originals on every device pass, rather than
    # from whatever the previous iteration left in a rebound local.
    x_src = params_2d["x"]
    ktraj_src = params_2d["ktraj"]

    for device in device_list:
        # detach() gives a new tensor without autograd history, so x.grad is
        # unset at the start of each pass.
        x = x_src.detach().to(dtype=dtype, device=device)
        ktraj = ktraj_src.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size,
                             numpoints=numpoints).to(dtype=dtype,
                                                     device=device)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)

        x.requires_grad = True
        y = kbnufft_ob.forward(x, ktraj)

        # The gradient of sum(y**2)/2 w.r.t. y is y itself, so x.grad is the
        # backward (vector-Jacobian) pass of the forward op applied to y.
        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj)

        assert torch.norm(x_grad - x_hat) < norm_tol
# Example #2
# 0
def test_2d_kbnufft_backward():
    """Check that autograd through ``KbNufft`` agrees with ``AdjKbNufft`` (2D).

    Random complex image data ``x`` is made a leaf and passed through the
    forward NUFFT. The loss ``sum(y**2) / 2`` is then backpropagated. The
    resulting ``x.grad`` must match ``AdjKbNufft`` applied to the forward
    output ``y``, to within ``norm_tol`` in norm.
    """
    dtype = torch.double
    # norm_tol was previously referenced here without being defined.
    # NOTE(review): 1e-10 is assumed to match the tolerance used by sibling
    # tests in this module; confirm.
    norm_tol = 1e-10

    nslice = 2
    ncoil = 4
    im_size = (33, 24)
    klength = 112
    numpoints = (4, 6)

    # Complex image stored as a real tensor with real/imag stacked on axis 2:
    # shape (nslice, ncoil, 2) + im_size.
    x = np.random.normal(size=(nslice, ncoil) + im_size) + \
        1j*np.random.normal(size=(nslice, ncoil) + im_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2)).to(dtype)

    ktraj = torch.randn(nslice, 2, klength).to(dtype)

    kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints)
    adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints)

    x.requires_grad = True
    y = kbnufft_ob.forward(x, ktraj)

    # The gradient of sum(y**2)/2 w.r.t. y is y itself, so x.grad is the
    # backward (vector-Jacobian) pass of the forward op applied to y.
    ((y**2) / 2).sum().backward()
    x_grad = x.grad.clone().detach()

    x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj)

    assert torch.norm(x_grad - x_hat) < norm_tol
# Example #3
# 0
def test_3d_kbnufft_adjoint_backward():
    """Check that autograd through ``AdjKbNufft`` agrees with ``KbNufft`` (3D).

    Random complex k-space data ``y`` is made a leaf and passed through the
    adjoint NUFFT, using precomputed sparse interpolation matrices. The loss
    ``sum(x**2) / 2`` is then backpropagated. The resulting ``y.grad`` must
    match ``KbNufft`` applied to the adjoint output ``x``, to within
    ``norm_tol`` in norm.
    """
    # NOTE(review): a function with this same name is defined immediately
    # after this one in the file, so this definition is shadowed at import
    # time and will not be collected by pytest.
    dtype = torch.double
    # norm_tol was previously referenced here without being defined.
    # NOTE(review): 1e-10 is assumed to match the tolerance used by sibling
    # tests in this module; confirm.
    norm_tol = 1e-10

    nslice = 2
    ncoil = 4
    im_size = (11, 33, 24)
    klength = 112
    numpoints = (2, 4, 6)

    # Complex k-space data stored as a real tensor with real/imag stacked on
    # axis 2: shape (nslice, ncoil, 2, klength).
    y = np.random.normal(size=(nslice, ncoil, klength)) + \
        1j*np.random.normal(size=(nslice, ncoil, klength))
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)

    ktraj = torch.randn(nslice, 3, klength).to(dtype)

    kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints)
    adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints)

    real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    y.requires_grad = True
    x = adjkbnufft_ob.forward(y, ktraj, interp_mats)

    # The gradient of sum(x**2)/2 w.r.t. x is x itself, so y.grad is the
    # backward (vector-Jacobian) pass of the adjoint op applied to x.
    ((x**2) / 2).sum().backward()
    y_grad = y.grad.clone().detach()

    y_hat = kbnufft_ob.forward(x.clone().detach(), ktraj, interp_mats)

    assert torch.norm(y_grad - y_hat) < norm_tol
def test_3d_kbnufft_adjoint_backward(params_3d, testing_tol, testing_dtype,
                                     device_list):
    """Check that autograd through ``AdjKbNufft`` agrees with ``KbNufft`` (3D).

    For each device, the k-space data ``y`` is made a leaf and passed through
    the adjoint NUFFT, using precomputed sparse interpolation matrices. The
    loss ``sum(x**2) / 2`` is then backpropagated. The resulting ``y.grad``
    must match ``KbNufft`` applied to the adjoint output ``x``, to within
    ``testing_tol`` in norm.

    Args:
        params_3d: Fixture dict; this test reads ``im_size``, ``numpoints``,
            ``y`` and ``ktraj`` from it.
        testing_tol: Upper bound on the norm of the gradient mismatch.
        testing_dtype: dtype the tensors and NUFFT modules are cast to.
        device_list: Devices on which the check is repeated.
    """
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_3d["im_size"]
    numpoints = params_3d["numpoints"]

    # Convert from the fixture originals on every device pass, rather than
    # from whatever the previous iteration left in a rebound local.
    y_src = params_3d["y"]
    ktraj_src = params_3d["ktraj"]

    for device in device_list:
        # detach() gives a new tensor without autograd history, so y.grad is
        # unset at the start of each pass.
        y = y_src.detach().to(dtype=dtype, device=device)
        ktraj = ktraj_src.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size,
                             numpoints=numpoints).to(dtype=dtype,
                                                     device=device)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }

        y.requires_grad = True
        x = adjkbnufft_ob.forward(y, ktraj, interp_mats)

        # The gradient of sum(x**2)/2 w.r.t. x is x itself, so y.grad is the
        # backward (vector-Jacobian) pass of the adjoint op applied to x.
        ((x**2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()

        y_hat = kbnufft_ob.forward(x.clone().detach(), ktraj, interp_mats)

        assert torch.norm(y_grad - y_hat) < norm_tol
def test_3d_kbnufft_backward(params_3d, testing_tol, testing_dtype,
                             device_list):
    """Check that autograd through ``KbNufft`` agrees with ``AdjKbNufft`` (3D).

    For each device, the image ``x`` is made a leaf and passed through the
    forward NUFFT, using precomputed sparse interpolation matrices. The loss
    ``sum(y**2) / 2`` is then backpropagated. The resulting ``x.grad`` must
    match ``AdjKbNufft`` applied to the forward output ``y``, to within
    ``testing_tol`` in norm.

    Args:
        params_3d: Fixture dict; this test reads ``im_size``, ``numpoints``,
            ``x`` and ``ktraj`` from it.
        testing_tol: Upper bound on the norm of the gradient mismatch.
        testing_dtype: dtype the tensors and NUFFT modules are cast to.
        device_list: Devices on which the check is repeated.
    """
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_3d['im_size']
    numpoints = params_3d['numpoints']

    # Convert from the fixture originals on every device pass, rather than
    # from whatever the previous iteration left in a rebound local.
    x_src = params_3d['x']
    ktraj_src = params_3d['ktraj']

    for device in device_list:
        # detach() gives a new tensor without autograd history, so x.grad is
        # unset at the start of each pass.
        x = x_src.detach().to(dtype=dtype, device=device)
        ktraj = ktraj_src.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size,
                             numpoints=numpoints).to(dtype=dtype,
                                                     device=device)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        x.requires_grad = True
        y = kbnufft_ob.forward(x, ktraj, interp_mats)

        # The gradient of sum(y**2)/2 w.r.t. y is y itself, so x.grad is the
        # backward (vector-Jacobian) pass of the forward op applied to y.
        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj, interp_mats)

        assert torch.norm(x_grad - x_hat) < norm_tol