# Example #1
# 0
def test_3d_mrisensenufft_adjoint_backward(norm_tol=1e-10):
    """Check autograd through 3D AdjMriSenseNufft against MriSenseNufft.

    With x = A^H y, ``sum(x**2) / 2`` is backpropagated to ``y`` and the
    resulting ``y.grad`` is compared with the forward operator applied to
    ``x``.

    Args:
        norm_tol: Maximum allowed norm of ``y.grad - A x``. The original body
            referenced an undefined ``norm_tol``. It is now a defaulted
            keyword, so pytest does not try to resolve it as a fixture.
    """
    dtype = torch.double

    nslice = 2
    ncoil = 4
    im_size = (11, 33, 24)
    klength = 112

    # Complex k-space data with real/imag parts stacked on dim 2. No image
    # tensor is drawn: the image is produced by the adjoint operator below.
    y = np.random.normal(size=(nslice, ncoil, klength)) + \
        1j*np.random.normal(size=(nslice, ncoil, klength))
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)

    ktraj = torch.randn(nslice, 3, klength).to(dtype)

    smap_sz = (nslice, ncoil, 2) + im_size
    smap = torch.randn(*smap_sz).to(dtype)

    sensenufft_ob = MriSenseNufft(smap=smap, im_size=im_size)
    adjsensenufft_ob = AdjMriSenseNufft(smap=smap, im_size=im_size)

    y.requires_grad = True
    x = adjsensenufft_ob.forward(y, ktraj)

    ((x**2) / 2).sum().backward()
    y_grad = y.grad.clone().detach()

    y_hat = sensenufft_ob.forward(x.clone().detach(), ktraj)

    assert torch.norm(y_grad - y_hat) < norm_tol
def test_mrisensenufft_3d_coilpack_adjoint(
    params_2d, testing_tol, testing_dtype, device_list
):
    """Dot-product adjoint test for the coil-packed sense-NUFFT pair.

    On every device in ``device_list``, checks that <y, A x> equals
    <A^H y, x>, with inner products taken over dim 2.
    NOTE(review): despite ``3d`` in the name, this uses ``params_2d``.
    """
    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]
    x, y = params_2d["x"], params_2d["y"]
    ktraj, smap = params_2d["ktraj"], params_2d["smap"]

    for device in device_list:
        cast = {"dtype": testing_dtype, "device": device}
        x = x.detach().to(**cast)
        y = y.detach().to(**cast)
        ktraj = ktraj.detach().to(**cast)

        op_kwargs = {"smap": smap, "im_size": im_size,
                     "numpoints": numpoints, "coilpack": True}
        forw_op = MriSenseNufft(**op_kwargs).to(**cast)
        adj_op = AdjMriSenseNufft(**op_kwargs).to(**cast)

        ax = forw_op(x, ktraj)
        ahy = adj_op(y, ktraj)

        lhs = inner_product(y, ax, dim=2)
        rhs = inner_product(ahy, x, dim=2)

        assert torch.norm(lhs - rhs) < testing_tol
def test_mrisensenufft_2d_adjoint(norm_tol=1e-10):
    """Dot-product adjoint test for 2D MriSenseNufft / AdjMriSenseNufft.

    Checks that <y, A x> equals <A^H y, x>, with inner products over dim 2.

    Args:
        norm_tol: Maximum allowed norm of the inner-product difference. The
            original body referenced an undefined ``norm_tol``. It is now a
            defaulted keyword, so pytest does not treat it as a fixture.
    """
    dtype = torch.double

    nslice = 2
    ncoil = 4
    im_size = (33, 24)
    klength = 112

    # Complex image / k-space data with real/imag parts stacked on dim 2.
    x = np.random.normal(size=(nslice, 1) + im_size) + \
        1j*np.random.normal(size=(nslice, 1) + im_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2)).to(dtype)

    y = np.random.normal(size=(nslice, ncoil, klength)) + \
        1j*np.random.normal(size=(nslice, ncoil, klength))
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)

    ktraj = torch.randn(nslice, 2, klength).to(dtype)

    smap_sz = (nslice, ncoil, 2) + im_size
    smap = torch.randn(*smap_sz).to(dtype)

    sensenufft_ob = MriSenseNufft(smap=smap, im_size=im_size)
    adjsensenufft_ob = AdjMriSenseNufft(smap=smap, im_size=im_size)

    x_forw = sensenufft_ob(x, ktraj)
    y_back = adjsensenufft_ob(y, ktraj)

    inprod1 = inner_product(y, x_forw, dim=2)
    inprod2 = inner_product(y_back, x, dim=2)

    assert torch.norm(inprod1 - inprod2) < norm_tol
def test_mrisensenufft_3d_adjoint(norm_tol=1e-10):
    """Dot-product adjoint test for 3D sense NUFFT with sparse interp mats.

    Precomputes sparse interpolation matrices and checks that <y, A x>
    equals <A^H y, x>, with inner products over dim 2.

    Args:
        norm_tol: Maximum allowed norm of the inner-product difference. The
            original body referenced an undefined ``norm_tol``. It is now a
            defaulted keyword, so pytest does not treat it as a fixture.
    """
    dtype = torch.double

    nslice = 2
    ncoil = 4
    im_size = (11, 33, 24)
    klength = 112

    # Complex image / k-space data with real/imag parts stacked on dim 2.
    x = np.random.normal(size=(nslice, 1) + im_size) + \
        1j*np.random.normal(size=(nslice, 1) + im_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2)).to(dtype)

    y = np.random.normal(size=(nslice, ncoil, klength)) + \
        1j*np.random.normal(size=(nslice, ncoil, klength))
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)

    ktraj = torch.randn(nslice, 3, klength).to(dtype)

    smap_sz = (nslice, ncoil, 2) + im_size
    smap = torch.randn(*smap_sz).to(dtype)

    sensenufft_ob = MriSenseNufft(smap=smap, im_size=im_size)
    adjsensenufft_ob = AdjMriSenseNufft(smap=smap, im_size=im_size)

    real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    x_forw = sensenufft_ob(x, ktraj, interp_mats)
    y_back = adjsensenufft_ob(y, ktraj, interp_mats)

    inprod1 = inner_product(y, x_forw, dim=2)
    inprod2 = inner_product(y_back, x, dim=2)

    assert torch.norm(inprod1 - inprod2) < norm_tol
def test_3d_mrisensenufft_coilpack_backward(params_2d, testing_tol,
                                            testing_dtype, device_list):
    """Autograd through coil-packed MriSenseNufft must match its adjoint.

    Uses precomputed sparse interpolation matrices. ``sum(y**2) / 2`` with
    y = A x is backpropagated to ``x``, and ``x.grad`` is compared with
    A^H y on every device in ``device_list``.
    """
    p = params_2d
    im_size, numpoints = p['im_size'], p['numpoints']
    x, y, ktraj, smap = p['x'], p['y'], p['ktraj'], p['smap']

    for device in device_list:
        x = x.detach().to(dtype=testing_dtype, device=device)
        y = y.detach().to(dtype=testing_dtype, device=device)
        ktraj = ktraj.detach().to(dtype=testing_dtype, device=device)

        common = {'smap': smap, 'im_size': im_size,
                  'numpoints': numpoints, 'coilpack': True}
        sensenufft_ob = MriSenseNufft(**common).to(
            dtype=testing_dtype, device=device)
        adjsensenufft_ob = AdjMriSenseNufft(**common).to(
            dtype=testing_dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
        interp_mats = {'real_interp_mats': real_mat,
                       'imag_interp_mats': imag_mat}

        # x becomes a gradient-tracking leaf, then is pushed through A.
        x.requires_grad = True
        y = sensenufft_ob.forward(x, ktraj, interp_mats)
        loss = (y ** 2 / 2).sum()
        loss.backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjsensenufft_ob.forward(y.clone().detach(), ktraj,
                                         interp_mats)
        assert torch.norm(x_grad - x_hat) < testing_tol
def test_3d_mrisensenufft_coilpack_adjoint_backward(params_2d, testing_tol,
                                                    testing_dtype,
                                                    device_list):
    """Autograd through coil-packed AdjMriSenseNufft must match the forward op.

    Uses precomputed sparse interpolation matrices. ``sum(x**2) / 2`` with
    x = A^H y is backpropagated to ``y``, and ``y.grad`` is compared with
    A x on every device in ``device_list``.
    """
    dtype, norm_tol = testing_dtype, testing_tol
    im_size, numpoints = params_2d["im_size"], params_2d["numpoints"]
    x, y = params_2d["x"], params_2d["y"]
    ktraj, smap = params_2d["ktraj"], params_2d["smap"]

    for device in device_list:
        to_kw = {"dtype": dtype, "device": device}
        x = x.detach().to(**to_kw)
        y = y.detach().to(**to_kw)
        ktraj = ktraj.detach().to(**to_kw)

        op_kw = {"smap": smap, "im_size": im_size,
                 "numpoints": numpoints, "coilpack": True}
        sensenufft_ob = MriSenseNufft(**op_kw).to(**to_kw)
        adjsensenufft_ob = AdjMriSenseNufft(**op_kw).to(**to_kw)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
        interp_mats = {"real_interp_mats": real_mat,
                       "imag_interp_mats": imag_mat}

        # y becomes a gradient-tracking leaf, then is pushed through A^H.
        y.requires_grad = True
        x = adjsensenufft_ob.forward(y, ktraj, interp_mats)
        loss = (x ** 2 / 2).sum()
        loss.backward()
        y_grad = y.grad.clone().detach()

        y_hat = sensenufft_ob.forward(x.clone().detach(), ktraj, interp_mats)
        assert torch.norm(y_grad - y_hat) < norm_tol
# Example #7
# 0
def test_3d_mrisensenufft_coilpack_backward(params_2d, testing_tol,
                                            testing_dtype, device_list):
    """Autograd through coil-packed MriSenseNufft must match its adjoint.

    No interpolation matrices are precomputed. ``sum(y**2) / 2`` with y = A x
    is backpropagated to ``x``, and ``x.grad`` is compared with A^H y on
    every device in ``device_list``.
    """
    keys = ("im_size", "numpoints", "x", "y", "ktraj", "smap")
    im_size, numpoints, x, y, ktraj, smap = (params_2d[k] for k in keys)

    for device in device_list:
        x = x.detach().to(dtype=testing_dtype, device=device)
        y = y.detach().to(dtype=testing_dtype, device=device)
        ktraj = ktraj.detach().to(dtype=testing_dtype, device=device)

        forward_op = MriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints, coilpack=True
        ).to(dtype=testing_dtype, device=device)
        adjoint_op = AdjMriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints, coilpack=True
        ).to(dtype=testing_dtype, device=device)

        x.requires_grad = True
        y = forward_op.forward(x, ktraj)
        (y ** 2 / 2).sum().backward()
        grad_x = x.grad.clone().detach()

        expected = adjoint_op.forward(y.clone().detach(), ktraj)
        assert torch.norm(grad_x - expected) < testing_tol
# Example #8
# 0
def test_3d_mrisensenufft_adjoint_backward(params_3d, testing_tol,
                                           testing_dtype, device_list):
    """Autograd through 3D AdjMriSenseNufft must match MriSenseNufft.

    ``sum(x**2) / 2`` with x = A^H y is backpropagated to ``y``, and
    ``y.grad`` is compared with A x on every device in ``device_list``.
    """
    keys = ("im_size", "numpoints", "x", "y", "ktraj", "smap")
    im_size, numpoints, x, y, ktraj, smap = (params_3d[k] for k in keys)

    for device in device_list:
        x = x.detach().to(dtype=testing_dtype, device=device)
        y = y.detach().to(dtype=testing_dtype, device=device)
        ktraj = ktraj.detach().to(dtype=testing_dtype, device=device)

        forward_op = MriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints
        ).to(dtype=testing_dtype, device=device)
        adjoint_op = AdjMriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints
        ).to(dtype=testing_dtype, device=device)

        y.requires_grad = True
        x = adjoint_op.forward(y, ktraj)
        (x ** 2 / 2).sum().backward()
        grad_y = y.grad.clone().detach()

        expected = forward_op.forward(x.clone().detach(), ktraj)
        assert torch.norm(grad_y - expected) < testing_tol
# Example #9
# 0
def test_3d_mrisensenufft_backward(norm_tol=1e-10):
    """Check autograd through 3D MriSenseNufft against AdjMriSenseNufft.

    Uses precomputed sparse interpolation matrices. ``sum(y**2) / 2`` with
    y = A x is backpropagated to ``x``, and ``x.grad`` is compared with the
    adjoint operator applied to ``y``.

    Args:
        norm_tol: Maximum allowed norm of ``x.grad - A^H y``. The original
            body referenced an undefined ``norm_tol``. It is now a defaulted
            keyword, so pytest does not treat it as a fixture.
    """
    dtype = torch.double

    nslice = 2
    ncoil = 4
    im_size = (11, 33, 24)
    klength = 112

    # Complex image with real/imag parts stacked on dim 2. No k-space tensor
    # is drawn: the k-space data is produced by the forward operator below.
    x = np.random.normal(size=(nslice, 1) + im_size) + \
        1j*np.random.normal(size=(nslice, 1) + im_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2)).to(dtype)

    ktraj = torch.randn(nslice, 3, klength).to(dtype)

    smap_sz = (nslice, ncoil, 2) + im_size
    smap = torch.randn(*smap_sz).to(dtype)

    sensenufft_ob = MriSenseNufft(smap=smap, im_size=im_size)
    adjsensenufft_ob = AdjMriSenseNufft(smap=smap, im_size=im_size)

    real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    x.requires_grad = True
    y = sensenufft_ob.forward(x, ktraj, interp_mats)

    ((y**2) / 2).sum().backward()
    x_grad = x.grad.clone().detach()

    x_hat = adjsensenufft_ob.forward(y.clone().detach(), ktraj, interp_mats)

    assert torch.norm(x_grad - x_hat) < norm_tol
def test_mrisensenufft_3d_coilpack_adjoint(params_2d, testing_tol,
                                           testing_dtype, device_list):
    """Dot-product adjoint test for coil-packed operators with sparse mats.

    On every device in ``device_list``, precomputes sparse interpolation
    matrices and checks that <y, A x> equals <A^H y, x> over dim 2.
    NOTE(review): despite ``3d`` in the name, this uses ``params_2d``.
    """
    keys = ('im_size', 'numpoints', 'x', 'y', 'ktraj', 'smap')
    im_size, numpoints, x, y, ktraj, smap = (params_2d[k] for k in keys)

    for device in device_list:
        x = x.detach().to(dtype=testing_dtype, device=device)
        y = y.detach().to(dtype=testing_dtype, device=device)
        ktraj = ktraj.detach().to(dtype=testing_dtype, device=device)

        forward_op = MriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints, coilpack=True
        ).to(dtype=testing_dtype, device=device)
        adjoint_op = AdjMriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints, coilpack=True
        ).to(dtype=testing_dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, forward_op)
        mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

        ax = forward_op(x, ktraj, mats)
        ahy = adjoint_op(y, ktraj, mats)

        difference = (inner_product(y, ax, dim=2)
                      - inner_product(ahy, x, dim=2))
        assert torch.norm(difference) < testing_tol
def test_mrisensenufft_3d_adjoint(params_3d, testing_tol, testing_dtype,
                                  device_list):
    """Dot-product adjoint test for 3D sense NUFFT with sparse interp mats.

    On every device in ``device_list``, checks that <y, A x> equals
    <A^H y, x>, with inner products over dim 2.
    """
    keys = ("im_size", "numpoints", "x", "y", "ktraj", "smap")
    im_size, numpoints, x, y, ktraj, smap = (params_3d[k] for k in keys)

    for device in device_list:
        x = x.detach().to(dtype=testing_dtype, device=device)
        y = y.detach().to(dtype=testing_dtype, device=device)
        ktraj = ktraj.detach().to(dtype=testing_dtype, device=device)

        forward_op = MriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints
        ).to(dtype=testing_dtype, device=device)
        adjoint_op = AdjMriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints
        ).to(dtype=testing_dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, forward_op)
        mats = {"real_interp_mats": real_mat, "imag_interp_mats": imag_mat}

        ax = forward_op(x, ktraj, mats)
        ahy = adjoint_op(y, ktraj, mats)

        difference = (inner_product(y, ax, dim=2)
                      - inner_product(ahy, x, dim=2))
        assert torch.norm(difference) < testing_tol
# Example #12
# 0
def test_kb_matching(testing_tol):
    """All NUFFT-family objects must build the same interpolation ``table``.

    For every combination of ``kbwidth``, ``order`` and image size (one 2D,
    one 3D), the ``table`` of ``AdjKbNufft`` is the reference. The tables
    of KbNufft, KbInterpBack, KbInterpForw, MriSenseNufft and
    AdjMriSenseNufft must have the same number of entries and agree
    entry by entry.
    """
    norm_tol = testing_tol

    def check_tables(table1, table2):
        # Without the length check, extra entries in table2 would be ignored
        # and missing ones would surface as an IndexError instead of a clear
        # assertion failure.
        assert len(table1) == len(table2)
        for ref, cur in zip(table1, table2):
            assert np.linalg.norm(ref - cur) < norm_tol

    im_szs = [(256, 256), (10, 256, 256)]

    kbwidths = [2.34, 5]
    orders = [0, 2]

    for kbwidth in kbwidths:
        for order in orders:
            for im_sz in im_szs:
                smap = torch.randn(*((1,) + im_sz))
                kw = {"order": order, "kbwidth": kbwidth}

                base_table = AdjKbNufft(im_sz, **kw).table

                check_tables(base_table, KbNufft(im_sz, **kw).table)
                check_tables(base_table, KbInterpBack(im_sz, **kw).table)
                check_tables(base_table, KbInterpForw(im_sz, **kw).table)
                check_tables(base_table,
                             MriSenseNufft(smap, im_sz, **kw).table)
                check_tables(base_table,
                             AdjMriSenseNufft(smap, im_sz, **kw).table)
# Example #13
# 0
def test_toeplitz_mrisensenufft_2d(params_2d, testing_tol, testing_dtype, device_list):
    """Toeplitz normal operator must match A^H A computed with the NUFFTs.

    On every device in ``device_list``, the relative error between
    ``adj(forw(x))`` and ``ToepSenseNufft(x, kern)`` must stay below a fixed
    tolerance. NOTE: ``testing_tol`` is requested but not used (see below).
    """
    dtype = testing_dtype
    # this tolerance looks really bad, but toep struggles with random traj
    # for radial it's more like 1e-06
    norm_tol = 1e-3

    keys = ("im_size", "numpoints", "x", "y", "ktraj", "smap")
    im_size, numpoints, x, y, ktraj, smap = (params_2d[k] for k in keys)

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        # y is moved alongside x and ktraj but is not used by this test.
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        forward_op = MriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjoint_op = AdjMriSenseNufft(
            smap=smap, im_size=im_size, numpoints=numpoints
        ).to(dtype=dtype, device=device)
        toep_op = ToepSenseNufft(smap=smap).to(dtype=dtype, device=device)

        kern = calc_toep_kernel(adjoint_op, ktraj)

        via_nufft = adjoint_op(forward_op(x, ktraj), ktraj)
        via_toep = toep_op(x, kern)

        rel_err = torch.norm(via_nufft - via_toep) / torch.norm(via_nufft)
        assert rel_err < norm_tol
def test_2d_init_inputs():
    """Construct every 2D NUFFT-family object with scalar and tuple params.

    All object initializations have assertions, so this test fails if any
    constructor rejects a dimension/parameter combination. Two parameter
    sets are tried: kernel parameters given as scalars, then the same values
    given as explicit 2-tuples.
    """
    im_sz = (256, 256)
    grid_sz = (512, 512)
    n_shift = (128, 128)
    norm = "None"

    # Each entry is (numpoints, table_oversamp, kbwidth, order).
    kernel_param_sets = [
        # test 2d scalar inputs
        (6, 2 ** 10, 2.34, 0),
        # test 2d tuple inputs
        ((6, 6), (2 ** 10, 2 ** 10), (2.34, 2.34), (0, 0)),
    ]

    for numpoints, table_oversamp, kbwidth, order in kernel_param_sets:
        smap = torch.randn(*((1,) + im_sz))
        common = dict(
            im_size=im_sz,
            grid_size=grid_sz,
            n_shift=n_shift,
            numpoints=numpoints,
            table_oversamp=table_oversamp,
            kbwidth=kbwidth,
            order=order,
        )

        # The interpolators are constructed without ``norm``.
        KbInterpForw(**common)
        KbInterpBack(**common)

        KbNufft(**common, norm=norm)
        AdjKbNufft(**common, norm=norm)

        MriSenseNufft(smap=smap, **common, norm=norm)
        AdjMriSenseNufft(smap=smap, **common, norm=norm)
# Example #15
# 0
def _timed_calls(fn, num_calls, is_cuda, mem_label):
    """Call ``fn`` ``num_calls`` times; return (average seconds, last result).

    On CUDA, the peak-memory counter is reset and the device synchronized
    before the clock starts. The device is synchronized again before the
    clock stops, so queued kernels are included in the measured time. The
    peak memory is printed only after the clock has stopped.
    """
    if is_cuda:
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    result = None
    for _ in range(num_calls):
        result = fn()
    if is_cuda:
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    if is_cuda:
        max_mem = torch.cuda.max_memory_allocated()
        print("GPU {} max memory: {} GB".format(mem_label, max_mem / 1e9))
    return (end_time - start_time) / num_calls, result


def profile_torchkbnufft(image,
                         ktraj,
                         smap,
                         im_size,
                         device,
                         sparse_mats_flag=False,
                         use_toep=False):
    """Print average run times (and GPU peak memory) of sense-NUFFT operators.

    Args:
        image: Image tensor fed to the forward (or Toeplitz) operator.
        ktraj: k-space trajectory tensor.
        smap: Sensitivity maps used to build the operators.
        im_size: Image size passed to the operator constructors.
        device: ``torch.device`` (or device string) to run on. CPU runs in
            double precision; any other device runs in single precision.
        sparse_mats_flag: If True, precompute sparse interpolation matrices
            and pass them to the forward/adjoint operators.
        use_toep: If True, time the Toeplitz normal operator instead of the
            separate forward and adjoint operators.
    """
    # Compare device *types*, so that "cuda:0" also gets synchronized
    # timing and memory stats (equality with torch.device("cuda") is False
    # for an indexed device).
    device_type = torch.device(device).type
    is_cuda = device_type == "cuda"

    # run double precision for CPU, float for GPU
    # these seem to be present in reference implementations
    if device_type == "cpu":
        dtype = torch.double
        num_nuffts = 20 if use_toep else 5
    else:
        dtype = torch.float
        num_nuffts = 50 if use_toep else 20
    cpudevice = torch.device("cpu")

    image = image.to(dtype=dtype)
    ktraj = ktraj.to(dtype=dtype)
    smap = smap.to(dtype=dtype)

    kbsense_ob = MriSenseNufft(smap=smap, im_size=im_size).to(dtype=dtype,
                                                              device=device)
    adjkbsense_ob = AdjMriSenseNufft(smap=smap,
                                     im_size=im_size).to(dtype=dtype,
                                                         device=device)

    adjkbnufft_ob = AdjKbNufft(im_size=im_size).to(dtype=dtype, device=device)

    # precompute toeplitz kernel if using toeplitz
    if use_toep:
        print("using toeplitz for forward/backward")
        kern = calc_toep_kernel(adjkbsense_ob, ktraj)
        toep_ob = ToepSenseNufft(smap=smap).to(dtype=dtype, device=device)

    # precompute the sparse interpolation matrices
    if sparse_mats_flag:
        print("using sparse interpolation matrices")
        real_mat, imag_mat = precomp_sparse_mats(ktraj, adjkbnufft_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }
    else:
        print("not using sparse interpolation matrices")
        interp_mats = None

    if use_toep:
        def run_toep():
            return toep_ob(image.to(device=device), kern.to(device=device))

        # warm-up computation
        for _ in range(num_nuffts):
            run_toep().to(cpudevice)
        # run the speed tests
        avg_time, _ = _timed_calls(run_toep, num_nuffts, is_cuda, "forward")
        print("toeplitz forward/backward average time: {}".format(avg_time))
    else:
        def run_forward():
            return kbsense_ob(image.to(device=device),
                              ktraj.to(device=device), interp_mats)

        # warm-up computation
        for _ in range(num_nuffts):
            run_forward().to(cpudevice)

        # run the forward speed tests; keep the last output for the adjoint
        avg_time, y = _timed_calls(run_forward, num_nuffts, is_cuda,
                                   "forward")
        print("forward average time: {}".format(avg_time))

        def run_adjoint():
            return adjkbsense_ob(y.to(device), ktraj.to(device), interp_mats)

        # warm-up computation
        for _ in range(num_nuffts):
            run_adjoint()

        # run the adjoint speed tests
        avg_time, _ = _timed_calls(run_adjoint, num_nuffts, is_cuda,
                                   "adjoint")
        print("backward average time: {}".format(avg_time))
def main():
    """Simulate radial multi-coil k-space of a Shepp-Logan phantom and save it.

    Builds a golden-angle radial trajectory and simulated coil sensitivities,
    then applies MriSenseNufft. The k-space data, the trajectory (rescaled so
    its maximum is pi) and the coil maps resized to ``targ_size`` are written
    to ``demo_data.mat``.
    """
    dtype = torch.double
    spokelength = 512
    targ_size = (spokelength // 2, spokelength // 2)
    nspokes = 405

    # np.complex was only an alias of the builtin complex (complex128) and
    # was removed in NumPy 1.24; use the explicit dtype.
    image = shepp_logan_phantom().astype(np.complex128)
    im_size = image.shape
    grid_size = tuple(2 * np.array(im_size))

    # convert the phantom to a tensor and unsqueeze coil and batch dimension
    image = np.stack((np.real(image), np.imag(image)))
    image = torch.tensor(image).to(dtype).unsqueeze(0).unsqueeze(0)

    # create k-space trajectory: each spoke is the previous one rotated by
    # the golden angle ``ga``
    ga = np.deg2rad(180 / ((1 + np.sqrt(5)) / 2))
    kx = np.zeros(shape=(spokelength, nspokes))
    ky = np.zeros(shape=(spokelength, nspokes))
    kmax = np.pi * ((spokelength / 2) / im_size[0])
    ky[:, 0] = np.linspace(-kmax, kmax, spokelength)
    for i in range(1, nspokes):
        kx[:, i] = np.cos(ga) * kx[:, i - 1] - np.sin(ga) * ky[:, i - 1]
        ky[:, i] = np.sin(ga) * kx[:, i - 1] + np.cos(ga) * ky[:, i - 1]

    ky = np.transpose(ky)
    kx = np.transpose(kx)

    ktraj = np.stack((ky.flatten(), kx.flatten()), axis=0)
    ktraj = torch.tensor(ktraj).to(dtype).unsqueeze(0)

    # sensitivity maps (coil index on axis 0 after np.stack)
    smap = np.absolute(np.stack(mrisensesim(
        im_size, coil_width=64))).astype(np.complex128)
    # take the coil count from the simulated maps instead of hard-coding it,
    # so the k-space reshape below always matches
    ncoil = smap.shape[0]
    smap = np.stack((np.real(smap), np.imag(smap)), axis=1)
    smap = torch.tensor(smap).to(dtype).unsqueeze(0)

    # operators
    sensenufft_ob = MriSenseNufft(
        smap=smap, im_size=im_size, grid_size=grid_size).to(dtype)

    kdata = sensenufft_ob(image, ktraj)

    kdata = np.squeeze(kdata.numpy())
    kdata = np.reshape(kdata[:, 0] + 1j * kdata[:, 1],
                       (ncoil, nspokes, spokelength))

    ktraj = np.squeeze(ktraj.numpy())
    ktraj = ktraj / np.max(ktraj) * np.pi
    ktraj = np.reshape(ktraj, (2, nspokes, spokelength))

    smap = np.squeeze(smap.numpy())
    smap = smap[:, 0] + 1j * smap[:, 1]
    smap_new = np.array([
        resize(np.real(coil_map), targ_size)
        + 1j * resize(np.imag(coil_map), targ_size)
        for coil_map in smap
    ])

    data = {
        'kdata': kdata,
        'ktraj': ktraj,
        'smap': smap_new
    }

    sio.savemat('demo_data.mat', data)
    def __init__(self, im_size, ktraj, dcomp, csmap, norm='ortho'):
        """Build 2D NUFFT operators and per-frame Toeplitz kernels.

        If ``csmap`` is not None, the multi-coil MriSenseNufft,
        AdjMriSenseNufft and ToepSenseNufft are built from it. Otherwise the
        single-coil KbNufft, AdjKbNufft and ToepNufft are used. All operators
        are cast to ``torch.float`` and use a (2*Ny, 2*Ny) grid, with
        Ny = im_size[1].

        Args:
            im_size: (Nx, Ny, Nt). Only (Nx, Ny) is passed to the operators;
                Nt sets how many Toeplitz kernels are computed.
            ktraj: k-space trajectory. ``ktraj[..., kt]`` is used for frame
                kt, and ``ktraj.shape[2]`` is stored as ``self.Nrad``.
            dcomp: density compensation. ``dcomp[..., kt]`` is passed as
                ``weights`` for the kernels in ``self.AdagA_toep_kernel_list``.
            csmap: coil-sensitivity maps, or None for single-coil encoding.
            norm: normalization passed to the NUFFT and adjoint-NUFFT
                constructors.

        The expected tensor shapes are listed in the string literal that
        follows ``super().__init__()`` in this method.
        """
        super(Dyn2DRadEncObj, self).__init__()
        # NOTE(review): this string comes after super().__init__(), so Python
        # does not use it as the docstring; it is kept as the shape reference.
        """
		inputs:
		- im_size	... the shape of the object to be imaged
		- ktraj 	... the k-space tractories  (torch tensor)
		- csmap 	... the coil-sensivity maps (torch tensor)
		- dcomp 	... the density compensation (torch tensor)
		
		N.B. the shapes for the tensors are the following:
			
			image   (1,1,2,Nx,Ny,Nt),
			ktraj 	(1,2,Nrad,Nt),
			csmap   (1,Nc,2,Nx,Ny),
			dcomp 	(1,1,1,Nrad,Nt),
			kdata 	(1,Nc,2,Nrad,Nt),
			
		where 
		- Nx,Ny,Nt = im_size
		- Nrad = 2*Ny*n_spokes
		- n_spokes = the number of spokes per 2D frame
		- Nc = the number of receiver coils

		"""
        # all operators below are cast to single precision
        dtype = torch.float

        # only Nt (number of frames) is used below
        Nx, Ny, Nt = im_size
        self.im_size = im_size
        self.spokelength = im_size[1] * 2
        self.Nrad = ktraj.shape[2]
        self.ktraj_tensor = ktraj
        self.dcomp_tensor = dcomp

        #parameters for contstructing the operators
        # grid side length is 2*Ny (same value as self.spokelength)
        spokelength = im_size[1] * 2
        grid_size = (spokelength, spokelength)

        #single-coil/multi-coil NUFFTs
        if csmap is not None:
            self.NUFFT = MriSenseNufft(smap=csmap,
                                       im_size=im_size[:2],
                                       grid_size=grid_size,
                                       norm=norm).to(dtype)
            self.AdjNUFFT = AdjMriSenseNufft(smap=csmap,
                                             im_size=im_size[:2],
                                             grid_size=grid_size,
                                             norm=norm).to(dtype)
            self.ToepNUFFT = ToepSenseNufft(csmap).to(dtype)
        else:
            self.NUFFT = KbNufft(im_size=im_size[:2],
                                 grid_size=grid_size,
                                 norm=norm).to(dtype)
            self.AdjNUFFT = AdjKbNufft(im_size=im_size[:2],
                                       grid_size=grid_size,
                                       norm=norm).to(dtype)
            self.ToepNUFFT = ToepNufft().to(dtype)

        #calculate Toeplitz kernels
        # for E^{\dagger} \circ E
        # one kernel per frame kt, weighted by that frame's dcomp
        self.AdagA_toep_kernel_list = [
            calc_toep_kernel(self.AdjNUFFT,
                             ktraj[..., kt],
                             weights=dcomp[..., kt]) for kt in range(Nt)
        ]

        # for E^H \circ E
        # one kernel per frame kt, without density-compensation weights
        self.AHA_toep_kernel_list = [
            calc_toep_kernel(self.AdjNUFFT, ktraj[..., kt]) for kt in range(Nt)
        ]