# NOTE: the functions below are test/profiling snippets collected from the
# torchkbnufft test suite. They are assumed to run with roughly the following
# setup (an assumption -- exact module paths vary across torchkbnufft
# versions):
#
#     import time
#
#     import numpy as np
#     import torch
#
#     from torchkbnufft import (AdjKbNufft, AdjMriSenseNufft, KbInterpBack,
#                               KbInterpForw, KbNufft, MriSenseNufft,
#                               ToepSenseNufft)
#     from torchkbnufft.math import inner_product
#     from torchkbnufft.nufft.sparse_interp_mat import precomp_sparse_mats
#     from torchkbnufft.nufft.toep_functions import calc_toep_kernel
#
#     norm_tol = 1e-10  # assumed module-level tolerance used by the asserts
#
# Functions taking params_2d/params_3d, testing_tol, testing_dtype and
# device_list arguments rely on pytest fixtures defined in the suite's
# conftest.py.


def test_nufft_2d_adjoint():
    dtype = torch.double

    nslice = 2
    ncoil = 4
    im_size = (33, 24)
    klength = 112

    x = np.random.normal(size=(nslice, ncoil) + im_size) + \
        1j*np.random.normal(size=(nslice, ncoil) + im_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2)).to(dtype)

    y = np.random.normal(size=(nslice, ncoil, klength)) + \
        1j*np.random.normal(size=(nslice, ncoil, klength))
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)

    ktraj = torch.randn(*(nslice, 2, klength)).to(dtype)

    kbnufft_ob = KbNufft(im_size=im_size, numpoints=(4, 6))
    adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=(4, 6))

    real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    x_forw = kbnufft_ob(x, ktraj, interp_mats)
    y_back = adjkbnufft_ob(y, ktraj, interp_mats)

    inprod1 = inner_product(y, x_forw, dim=2)
    inprod2 = inner_product(y_back, x, dim=2)

    assert torch.norm(inprod1 - inprod2) < norm_tol
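
# The adjoint tests compare <y, A x> with <A^H y, x> using torchkbnufft's
# inner_product utility on tensors whose real/imag parts are stacked along a
# dimension. As a rough illustrative stand-in (an assumption about the
# behavior, not the library's implementation), such a complex inner product
# could look like this:
def complex_inner_product_sketch(a, b, dim=0):
    # conj(a) * b, summed over all elements; returns a (real, imag) pair
    a_real, a_imag = a.select(dim, 0), a.select(dim, 1)
    b_real, b_imag = b.select(dim, 0), b.select(dim, 1)
    real_part = (a_real * b_real + a_imag * b_imag).sum()
    imag_part = (a_real * b_imag - a_imag * b_real).sum()
    return torch.stack((real_part, imag_part))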

def test_mrisensenufft_2d_adjoint():
    dtype = torch.double

    nslice = 2
    ncoil = 4
    im_size = (33, 24)
    klength = 112

    x = np.random.normal(size=(nslice, 1) + im_size) + \
        1j*np.random.normal(size=(nslice, 1) + im_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2)).to(dtype)

    y = np.random.normal(size=(nslice, ncoil, klength)) + \
        1j*np.random.normal(size=(nslice, ncoil, klength))
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)

    ktraj = torch.randn(*(nslice, 2, klength)).to(dtype)

    smap_sz = (nslice, ncoil, 2) + im_size
    smap = torch.randn(*smap_sz).to(dtype)

    sensenufft_ob = MriSenseNufft(smap=smap, im_size=im_size)
    adjsensenufft_ob = AdjMriSenseNufft(smap=smap, im_size=im_size)

    real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    x_forw = sensenufft_ob(x, ktraj, interp_mats)
    y_back = adjsensenufft_ob(y, ktraj, interp_mats)

    inprod1 = inner_product(y, x_forw, dim=2)
    inprod2 = inner_product(y_back, x, dim=2)

    assert torch.norm(inprod1 - inprod2) < norm_tol
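
# Conceptually, the SENSE-NUFFT forward operator weights the single-channel
# image by each coil sensitivity map and transforms to k-space per coil, so
# its adjoint back-transforms each coil and combines with the conjugate maps.
# A hedged sketch of that structure, using a Cartesian orthonormal FFT as a
# stand-in for the NUFFT (not the torchkbnufft implementation):
def sense_forward_sketch(x, smaps):
    # x: (H, W) complex image; smaps: (ncoil, H, W) complex sensitivity maps
    return torch.fft.fft2(smaps * x, norm="ortho")


def sense_adjoint_sketch(y, smaps):
    # y: (ncoil, H, W) complex k-space data
    return (smaps.conj() * torch.fft.ifft2(y, norm="ortho")).sum(dim=0)


def _sense_adjoint_check_sketch():
    # quick check of the adjoint identity <y, A x> == <A^H y, x>
    x = torch.randn(8, 8, dtype=torch.complex128)
    smaps = torch.randn(4, 8, 8, dtype=torch.complex128)
    y = torch.randn(4, 8, 8, dtype=torch.complex128)
    lhs = torch.vdot(y.flatten(), sense_forward_sketch(x, smaps).flatten())
    rhs = torch.vdot(sense_adjoint_sketch(y, smaps).flatten(), x.flatten())
    assert torch.allclose(lhs, rhs)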

def test_2d_kbnufft_backward():
    dtype = torch.double

    nslice = 2
    ncoil = 4
    im_size = (33, 24)
    klength = 112

    x = np.random.normal(size=(nslice, ncoil) + im_size) + \
        1j*np.random.normal(size=(nslice, ncoil) + im_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2)).to(dtype)

    y = np.random.normal(size=(nslice, ncoil, klength)) + \
        1j*np.random.normal(size=(nslice, ncoil, klength))
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)

    ktraj = torch.randn(*(nslice, 2, klength)).to(dtype)

    kbnufft_ob = KbNufft(im_size=im_size, numpoints=(4, 6))
    adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=(4, 6))

    real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    x.requires_grad = True
    y = kbnufft_ob.forward(x, ktraj, interp_mats)

    ((y**2) / 2).sum().backward()
    x_grad = x.grad.clone().detach()

    x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj, interp_mats)

    assert torch.norm(x_grad - x_hat) < norm_tol
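
# A minimal self-contained sketch (not torchkbnufft code) of the identity the
# backward tests rely on: for a linear map A and the loss sum((A x)**2) / 2,
# autograd yields x.grad == A^T (A x), i.e. the adjoint applied to the output.
def autograd_adjoint_identity_sketch():
    A = torch.randn(5, 3, dtype=torch.double)
    x = torch.randn(3, dtype=torch.double, requires_grad=True)

    y = A @ x
    ((y ** 2) / 2).sum().backward()

    assert torch.allclose(x.grad, A.t() @ y.detach())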

def test_2d_interp_adjoint_backward(params_2d, testing_tol, testing_dtype,
                                    device_list):
    dtype = testing_dtype
    norm_tol = testing_tol

    batch_size = params_2d["batch_size"]
    im_size = params_2d["im_size"]
    grid_size = params_2d["grid_size"]
    numpoints = params_2d["numpoints"]

    x = np.random.normal(size=(batch_size, 1) + grid_size) + \
        1j*np.random.normal(size=(batch_size, 1) + grid_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2))
    y = params_2d["y"]
    ktraj = params_2d["ktraj"]

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(im_size=im_size,
                                   grid_size=grid_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)
        adjkbinterp_ob = KbInterpBack(im_size=im_size,
                                      grid_size=grid_size,
                                      numpoints=numpoints).to(dtype=dtype,
                                                              device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }

        y.requires_grad = True
        x = adjkbinterp_ob.forward(y, ktraj, interp_mats)

        ((x**2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()

        y_hat = kbinterp_ob.forward(x.clone().detach(), ktraj, interp_mats)

        assert torch.norm(y_grad - y_hat) < norm_tol

def test_3d_mrisensenufft_coilpack_adjoint_backward(params_2d, testing_tol,
                                                    testing_dtype,
                                                    device_list):
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]

    x = params_2d["x"]
    y = params_2d["y"]
    ktraj = params_2d["ktraj"]
    smap = params_2d["smap"]

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        sensenufft_ob = MriSenseNufft(smap=smap,
                                      im_size=im_size,
                                      numpoints=numpoints,
                                      coilpack=True).to(dtype=dtype,
                                                        device=device)
        adjsensenufft_ob = AdjMriSenseNufft(smap=smap,
                                            im_size=im_size,
                                            numpoints=numpoints,
                                            coilpack=True).to(dtype=dtype,
                                                              device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }

        y.requires_grad = True
        x = adjsensenufft_ob.forward(y, ktraj, interp_mats)

        ((x**2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()

        y_hat = sensenufft_ob.forward(x.clone().detach(), ktraj, interp_mats)

        assert torch.norm(y_grad - y_hat) < norm_tol

def test_3d_mrisensenufft_coilpack_backward(params_2d, testing_tol,
                                            testing_dtype, device_list):
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_2d['im_size']
    numpoints = params_2d['numpoints']

    x = params_2d['x']
    y = params_2d['y']
    ktraj = params_2d['ktraj']
    smap = params_2d['smap']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        sensenufft_ob = MriSenseNufft(smap=smap,
                                      im_size=im_size,
                                      numpoints=numpoints,
                                      coilpack=True).to(dtype=dtype,
                                                        device=device)
        adjsensenufft_ob = AdjMriSenseNufft(smap=smap,
                                            im_size=im_size,
                                            numpoints=numpoints,
                                            coilpack=True).to(dtype=dtype,
                                                              device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        x.requires_grad = True
        y = sensenufft_ob.forward(x, ktraj, interp_mats)

        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjsensenufft_ob.forward(y.clone().detach(), ktraj,
                                         interp_mats)

        assert torch.norm(x_grad - x_hat) < norm_tol

def test_3d_interp_backward(params_3d, testing_tol, testing_dtype,
                            device_list):
    dtype = testing_dtype
    norm_tol = testing_tol

    batch_size = params_3d['batch_size']
    im_size = params_3d['im_size']
    grid_size = params_3d['grid_size']
    numpoints = params_3d['numpoints']

    x = np.random.normal(size=(batch_size, 1) + grid_size) + \
        1j*np.random.normal(size=(batch_size, 1) + grid_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2))
    y = params_3d['y']
    ktraj = params_3d['ktraj']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(im_size=im_size,
                                   grid_size=grid_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)
        adjkbinterp_ob = KbInterpBack(im_size=im_size,
                                      grid_size=grid_size,
                                      numpoints=numpoints).to(dtype=dtype,
                                                              device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        x.requires_grad = True
        y = kbinterp_ob.forward(x, ktraj, interp_mats)

        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbinterp_ob.forward(y.clone().detach(), ktraj, interp_mats)

        assert torch.norm(x_grad - x_hat) < norm_tol

def test_interp_2d_adjoint(params_2d, testing_tol, testing_dtype, device_list):
    dtype = testing_dtype
    norm_tol = testing_tol

    batch_size = params_2d["batch_size"]
    im_size = params_2d["im_size"]
    grid_size = params_2d["grid_size"]
    numpoints = params_2d["numpoints"]

    x = np.random.normal(size=(batch_size, 1) + grid_size) + \
        1j*np.random.normal(size=(batch_size, 1) + grid_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2))
    y = params_2d["y"]
    ktraj = params_2d["ktraj"]

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(im_size=im_size,
                                   grid_size=grid_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)
        adjkbinterp_ob = KbInterpBack(im_size=im_size,
                                      grid_size=grid_size,
                                      numpoints=numpoints).to(dtype=dtype,
                                                              device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }

        x_forw = kbinterp_ob(x, ktraj, interp_mats)
        y_back = adjkbinterp_ob(y, ktraj, interp_mats)

        inprod1 = inner_product(y, x_forw, dim=2)
        inprod2 = inner_product(y_back, x, dim=2)

        assert torch.norm(inprod1 - inprod2) < norm_tol

def test_mrisensenufft_3d_coilpack_adjoint(params_2d, testing_tol,
                                           testing_dtype, device_list):
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]

    x = params_2d["x"]
    y = params_2d["y"]
    ktraj = params_2d["ktraj"]
    smap = params_2d["smap"]

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        sensenufft_ob = MriSenseNufft(smap=smap,
                                      im_size=im_size,
                                      numpoints=numpoints,
                                      coilpack=True).to(dtype=dtype,
                                                        device=device)
        adjsensenufft_ob = AdjMriSenseNufft(smap=smap,
                                            im_size=im_size,
                                            numpoints=numpoints,
                                            coilpack=True).to(dtype=dtype,
                                                              device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }

        x_forw = sensenufft_ob(x, ktraj, interp_mats)
        y_back = adjsensenufft_ob(y, ktraj, interp_mats)

        inprod1 = inner_product(y, x_forw, dim=2)
        inprod2 = inner_product(y_back, x, dim=2)

        assert torch.norm(inprod1 - inprod2) < norm_tol

def test_mrisensenufft_2d_adjoint(params_2d, testing_tol, testing_dtype,
                                  device_list):
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_2d['im_size']
    numpoints = params_2d['numpoints']

    x = params_2d['x']
    y = params_2d['y']
    ktraj = params_2d['ktraj']
    smap = params_2d['smap']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        sensenufft_ob = MriSenseNufft(smap=smap,
                                      im_size=im_size,
                                      numpoints=numpoints).to(dtype=dtype,
                                                              device=device)
        adjsensenufft_ob = AdjMriSenseNufft(smap=smap,
                                            im_size=im_size,
                                            numpoints=numpoints).to(
                                                dtype=dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        x_forw = sensenufft_ob(x, ktraj, interp_mats)
        y_back = adjsensenufft_ob(y, ktraj, interp_mats)

        inprod1 = inner_product(y, x_forw, dim=2)
        inprod2 = inner_product(y_back, x, dim=2)

        assert torch.norm(inprod1 - inprod2) < norm_tol

def test_3d_kbnufft_backward(params_3d, testing_tol, testing_dtype,
                             device_list):
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_3d["im_size"]
    numpoints = params_3d["numpoints"]

    x = params_3d["x"]
    y = params_3d["y"]
    ktraj = params_3d["ktraj"]

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size,
                             numpoints=numpoints).to(dtype=dtype,
                                                     device=device)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }

        x.requires_grad = True
        y = kbnufft_ob.forward(x, ktraj, interp_mats)

        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj, interp_mats)

        assert torch.norm(x_grad - x_hat) < norm_tol

def test_3d_kbnufft_adjoint_backward(params_3d, testing_tol, testing_dtype,
                                     device_list):
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_3d['im_size']
    numpoints = params_3d['numpoints']

    x = params_3d['x']
    y = params_3d['y']
    ktraj = params_3d['ktraj']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size,
                             numpoints=numpoints).to(dtype=dtype,
                                                     device=device)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        y.requires_grad = True
        x = adjkbnufft_ob.forward(y, ktraj, interp_mats)

        ((x**2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()

        y_hat = kbnufft_ob.forward(x.clone().detach(), ktraj, interp_mats)

        assert torch.norm(y_grad - y_hat) < norm_tol

def test_2d_interp_adjoint_backward():
    dtype = torch.double

    nslice = 2
    ncoil = 4
    im_size = (33, 24)
    klength = 112

    x = np.random.normal(size=(nslice, ncoil) + im_size) + \
        1j*np.random.normal(size=(nslice, ncoil) + im_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2)).to(dtype)

    y = np.random.normal(size=(nslice, ncoil, klength)) + \
        1j*np.random.normal(size=(nslice, ncoil, klength))
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)

    ktraj = torch.randn(*(nslice, 2, klength)).to(dtype)

    kbinterp_ob = KbInterpForw(im_size=(20, 25),
                               grid_size=im_size,
                               numpoints=(4, 6))
    adjkbinterp_ob = KbInterpBack(im_size=(20, 25),
                                  grid_size=im_size,
                                  numpoints=(4, 6))

    real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    y.requires_grad = True
    x = adjkbinterp_ob.forward(y, ktraj, interp_mats)

    ((x**2) / 2).sum().backward()
    y_grad = y.grad.clone().detach()

    y_hat = kbinterp_ob.forward(x.clone().detach(), ktraj, interp_mats)

    assert torch.norm(y_grad - y_hat) < norm_tol

def test_nufft_2d_adjoint(params_2d, testing_tol, testing_dtype, device_list):
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]

    x = params_2d["x"]
    y = params_2d["y"]
    ktraj = params_2d["ktraj"]

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size,
                             numpoints=numpoints).to(dtype=dtype,
                                                     device=device)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }

        x_forw = kbnufft_ob(x, ktraj, interp_mats)
        y_back = adjkbnufft_ob(y, ktraj, interp_mats)

        inprod1 = inner_product(y, x_forw, dim=2)
        inprod2 = inner_product(y_back, x, dim=2)

        assert torch.norm(inprod1 - inprod2) < norm_tol

def test_3d_mrisensenufft_adjoint_backward():
    dtype = torch.double

    nslice = 2
    ncoil = 4
    im_size = (11, 33, 24)
    klength = 112

    x = np.random.normal(size=(nslice, 1) + im_size) + \
        1j*np.random.normal(size=(nslice, 1) + im_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2)).to(dtype)

    y = np.random.normal(size=(nslice, ncoil, klength)) + \
        1j*np.random.normal(size=(nslice, ncoil, klength))
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)

    ktraj = torch.randn(*(nslice, 3, klength)).to(dtype)

    smap_sz = (nslice, ncoil, 2) + im_size
    smap = torch.randn(*smap_sz).to(dtype)

    sensenufft_ob = MriSenseNufft(smap=smap, im_size=im_size)
    adjsensenufft_ob = AdjMriSenseNufft(smap=smap, im_size=im_size)

    real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    y.requires_grad = True
    x = adjsensenufft_ob.forward(y, ktraj, interp_mats)

    ((x**2) / 2).sum().backward()
    y_grad = y.grad.clone().detach()

    y_hat = sensenufft_ob.forward(x.clone().detach(), ktraj, interp_mats)

    assert torch.norm(y_grad - y_hat) < norm_tol

def profile_torchkbnufft(image,
                         ktraj,
                         smap,
                         im_size,
                         device,
                         sparse_mats_flag=False,
                         use_toep=False):
    # run double precision on the CPU and single precision (float) on the GPU;
    # these are the precisions the reference implementations appear to use
    if device == torch.device("cpu"):
        dtype = torch.double
        if use_toep:
            num_nuffts = 20
        else:
            num_nuffts = 5
    else:
        dtype = torch.float
        if use_toep:
            num_nuffts = 50
        else:
            num_nuffts = 20
    cpudevice = torch.device("cpu")

    image = image.to(dtype=dtype)
    ktraj = ktraj.to(dtype=dtype)
    smap = smap.to(dtype=dtype)

    kbsense_ob = MriSenseNufft(smap=smap, im_size=im_size).to(dtype=dtype,
                                                              device=device)
    adjkbsense_ob = AdjMriSenseNufft(smap=smap,
                                     im_size=im_size).to(dtype=dtype,
                                                         device=device)

    adjkbnufft_ob = AdjKbNufft(im_size=im_size).to(dtype=dtype, device=device)

    # precompute toeplitz kernel if using toeplitz
    if use_toep:
        print("using toeplitz for forward/backward")
        kern = calc_toep_kernel(adjkbsense_ob, ktraj)
        toep_ob = ToepSenseNufft(smap=smap).to(dtype=dtype, device=device)

    # precompute the sparse interpolation matrices
    if sparse_mats_flag:
        print("using sparse interpolation matrices")
        real_mat, imag_mat = precomp_sparse_mats(ktraj, adjkbnufft_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }
    else:
        print("not using sparse interpolation matrices")
        interp_mats = None

    if use_toep:
        # warm-up computation
        for _ in range(num_nuffts):
            x = toep_ob(image.to(device=device),
                        kern.to(device=device)).to(cpudevice)
        # run the speed tests
        if device == torch.device("cuda"):
            torch.cuda.reset_max_memory_allocated()
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_nuffts):
            x = toep_ob(image.to(device=device), kern.to(device=device))
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
            max_mem = torch.cuda.max_memory_allocated()
            print("GPU forward max memory: {} GB".format(max_mem / 1e9))
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_nuffts
        print("toeplitz forward/backward average time: {}".format(avg_time))
    else:
        # warm-up computation
        for _ in range(num_nuffts):
            y = kbsense_ob(image.to(device=device), ktraj.to(device=device),
                           interp_mats).to(cpudevice)

        # run the forward speed tests
        if device == torch.device("cuda"):
            torch.cuda.reset_max_memory_allocated()
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_nuffts):
            y = kbsense_ob(image.to(device=device), ktraj.to(device=device),
                           interp_mats)
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
            max_mem = torch.cuda.max_memory_allocated()
            print("GPU forward max memory: {} GB".format(max_mem / 1e9))
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_nuffts
        print("forward average time: {}".format(avg_time))

        # warm-up computation
        for _ in range(num_nuffts):
            x = adjkbsense_ob(y.to(device), ktraj.to(device), interp_mats)

        # run the adjoint speed tests
        if device == torch.device("cuda"):
            torch.cuda.reset_max_memory_allocated()
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_nuffts):
            x = adjkbsense_ob(y.to(device), ktraj.to(device), interp_mats)
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
            max_mem = torch.cuda.max_memory_allocated()
            print("GPU adjoint max memory: {} GB".format(max_mem / 1e9))
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_nuffts
        print("backward average time: {}".format(avg_time))