コード例 #1
0
def test_2d_kbnufft_backward():
    """Check that autograd through KbNufft matches the adjoint NUFFT.

    The loss ``sum(y**2) / 2`` with ``y = KbNufft(x)`` is back-propagated to
    ``x``; the resulting gradient is compared with ``AdjKbNufft(y)``.
    Relies on a module-level ``norm_tol`` tolerance.
    """
    dtype = torch.double

    nslice = 2
    ncoil = 4
    im_size = (33, 24)
    klength = 112

    # random complex image; real/imag parts stacked on dim 2
    x = np.random.normal(size=(nslice, ncoil) + im_size) + \
        1j*np.random.normal(size=(nslice, ncoil) + im_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2)).to(dtype)

    ktraj = torch.randn(*(nslice, 2, klength)).to(dtype)

    kbnufft_ob = KbNufft(im_size=im_size, numpoints=(4, 6))
    adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=(4, 6))

    # the k-space data is the forward output itself
    x.requires_grad = True
    y = kbnufft_ob.forward(x, ktraj)

    # d/dx of sum(y**2)/2 is the transposed forward Jacobian applied to y,
    # which the adjoint NUFFT is expected to reproduce
    ((y**2) / 2).sum().backward()
    x_grad = x.grad.clone().detach()

    x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj)

    assert torch.norm(x_grad - x_hat) < norm_tol
コード例 #2
0
def test_2d_kbnufft_backward(params_2d, testing_tol, testing_dtype,
                             device_list):
    """Check on every device that autograd through KbNufft matches AdjKbNufft.

    Back-propagates ``sum(y**2) / 2`` with ``y = KbNufft(x)`` and compares
    ``x.grad`` with ``AdjKbNufft(y)``.
    """
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]

    # only the image and trajectory are inputs: the k-space data ``y`` is
    # produced by the forward NUFFT inside the loop
    x = params_2d["x"]
    ktraj = params_2d["ktraj"]

    for device in device_list:
        # detach() yields a fresh leaf each iteration, so x.grad starts empty
        x = x.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size,
                             numpoints=numpoints).to(dtype=dtype,
                                                     device=device)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)

        x.requires_grad = True
        y = kbnufft_ob.forward(x, ktraj)

        # gradient of sum(y**2)/2 w.r.t. x should equal the adjoint applied
        # to y
        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj)

        assert torch.norm(x_grad - x_hat) < norm_tol
コード例 #3
0
def test_3d_kbnufft_adjoint_backward():
    """Check that autograd through AdjKbNufft matches the forward NUFFT.

    Back-propagates ``sum(x**2) / 2`` with ``x = AdjKbNufft(y)`` (using
    precomputed sparse interpolation matrices) and compares ``y.grad`` with
    ``KbNufft(x)``. Relies on a module-level ``norm_tol`` tolerance.
    """
    dtype = torch.double

    nslice = 2
    ncoil = 4
    im_size = (11, 33, 24)
    klength = 112

    # random complex k-space data; real/imag parts stacked on dim 2
    y = np.random.normal(size=(nslice, ncoil, klength)) + \
        1j*np.random.normal(size=(nslice, ncoil, klength))
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)

    ktraj = torch.randn(*(nslice, 3, klength)).to(dtype)

    kbnufft_ob = KbNufft(im_size=im_size, numpoints=(2, 4, 6))
    adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=(2, 4, 6))

    real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    # the image is the adjoint output itself
    y.requires_grad = True
    x = adjkbnufft_ob.forward(y, ktraj, interp_mats)

    # gradient of sum(x**2)/2 w.r.t. y should equal the forward NUFFT of x
    ((x**2) / 2).sum().backward()
    y_grad = y.grad.clone().detach()

    y_hat = kbnufft_ob.forward(x.clone().detach(), ktraj, interp_mats)

    assert torch.norm(y_grad - y_hat) < norm_tol
コード例 #4
0
def test_nufft_3d_adjoint():
    """Dot-product adjoint test for 3D KbNufft/AdjKbNufft with sparse matrices.

    Compares ``inner_product(y, A x)`` with ``inner_product(A^H y, x)`` on
    random complex data (real/imag stacked on dim 2), using precomputed
    interpolation matrices for both operators. Relies on a module-level
    ``norm_tol``.
    """
    dtype = torch.double
    nslice, ncoil, klength = 2, 4, 112
    im_size = (11, 33, 24)
    numpoints = (2, 4, 6)

    def _random_complex(shape):
        # complex Gaussian sample split into real/imag channels on dim 2
        z = np.random.normal(size=shape) + 1j*np.random.normal(size=shape)
        return torch.tensor(np.stack((z.real, z.imag), axis=2)).to(dtype)

    x = _random_complex((nslice, ncoil) + im_size)
    y = _random_complex((nslice, ncoil, klength))
    ktraj = torch.randn(nslice, 3, klength).to(dtype)

    forward_op = KbNufft(im_size=im_size, numpoints=numpoints)
    adjoint_op = AdjKbNufft(im_size=im_size, numpoints=numpoints)

    real_mat, imag_mat = precomp_sparse_mats(ktraj, forward_op)
    interp_mats = dict(real_interp_mats=real_mat, imag_interp_mats=imag_mat)

    lhs = inner_product(y, forward_op(x, ktraj, interp_mats), dim=2)
    rhs = inner_product(adjoint_op(y, ktraj, interp_mats), x, dim=2)

    assert torch.norm(lhs - rhs) < norm_tol
コード例 #5
0
def setup():
    """Create a golden-angle radial trajectory and the NUFFT objects under test.

    Returns a tuple ``(ktraj, nufft_ob, torch_forward, torch_backward)``:
    ``ktraj`` is a numpy array of shape (2, nspokes * spokelength) whose rows
    are (ky, kx), with the spokes laid out one after another. The other items
    are a KbNufftModule, a KbNufft and an AdjKbNufft, all built with
    ``norm='ortho'``.
    """
    spokelength = 400
    nspokes = 10
    grid_size = (spokelength, spokelength)
    im_size = (200, 200)

    golden_angle = np.deg2rad(180 / ((1 + np.sqrt(5)) / 2))
    cos_ga = np.cos(golden_angle)
    sin_ga = np.sin(golden_angle)

    # the first spoke runs along ky from -pi to pi; each later spoke is the
    # previous one rotated by the golden angle
    ky_spoke = np.linspace(-np.pi, np.pi, spokelength)
    kx_spoke = np.zeros(spokelength)
    ky_spokes = [ky_spoke]
    kx_spokes = [kx_spoke]
    for _ in range(nspokes - 1):
        kx_spoke, ky_spoke = (cos_ga * kx_spoke - sin_ga * ky_spoke,
                              sin_ga * kx_spoke + cos_ga * ky_spoke)
        kx_spokes.append(kx_spoke)
        ky_spokes.append(ky_spoke)

    ktraj = np.stack(
        (np.concatenate(ky_spokes), np.concatenate(kx_spokes)), axis=0)

    nufft_ob = KbNufftModule(im_size=im_size,
                             grid_size=grid_size,
                             norm='ortho')
    torch_forward = KbNufft(im_size=im_size, grid_size=grid_size, norm='ortho')
    torch_backward = AdjKbNufft(im_size=im_size,
                                grid_size=grid_size,
                                norm='ortho')
    return ktraj, nufft_ob, torch_forward, torch_backward
コード例 #6
0
def test_nufft_2d_adjoint(params_2d, testing_tol, testing_dtype, device_list):
    """Dot-product adjoint test for 2D KbNufft/AdjKbNufft on every device.

    For each device, compares ``inner_product(y, A x)`` with
    ``inner_product(A^H y, x)`` using the shared ``params_2d`` data.
    """
    im_size, numpoints = params_2d["im_size"], params_2d["numpoints"]
    x, y, ktraj = params_2d["x"], params_2d["y"], params_2d["ktraj"]

    for dev in device_list:
        x, y, ktraj = [
            t.detach().to(dtype=testing_dtype, device=dev)
            for t in (x, y, ktraj)
        ]

        forward_op = KbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=testing_dtype, device=dev
        )
        adjoint_op = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=testing_dtype, device=dev
        )

        lhs = inner_product(y, forward_op(x, ktraj), dim=2)
        rhs = inner_product(adjoint_op(y, ktraj), x, dim=2)

        assert torch.norm(lhs - rhs) < testing_tol
コード例 #7
0
def test_toeplitz_nufft_3d(params_3d, testing_tol, testing_dtype, device_list):
    """Check that the 3D Toeplitz operator approximates AdjKbNufft(KbNufft(x)).

    ``testing_tol`` is intentionally not used; a looser relative tolerance
    applies (see the comment below).
    """
    dtype = testing_dtype
    # this tolerance looks really bad, but toep struggles with random traj
    # for radial it's more like 1e-06
    norm_tol = 1e-1

    im_size = params_3d["im_size"]
    numpoints = params_3d["numpoints"]

    # only the image and trajectory are needed for the normal operator
    x = params_3d["x"]
    ktraj = params_3d["ktraj"]

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )
        adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )
        # keep the Toeplitz module on the same dtype/device as the data
        toep_ob = ToepNufft().to(dtype=dtype, device=device)

        kern = calc_toep_kernel(adjkbnufft_ob, ktraj)

        # reference: explicit adjoint(forward(x))
        normal_forw = adjkbnufft_ob(kbnufft_ob(x, ktraj), ktraj)

        toep_forw = toep_ob(x, kern)

        diff = torch.norm(normal_forw - toep_forw) / torch.norm(normal_forw)

        assert diff < norm_tol
コード例 #8
0
def test_nufft_2d_adjoint():
    """Dot-product adjoint test for 2D KbNufft/AdjKbNufft.

    Compares ``inner_product(y, A x)`` with ``inner_product(A^H y, x)`` on
    random complex data (real/imag stacked on dim 2). Relies on a
    module-level ``norm_tol``.
    """
    dtype = torch.double
    nslice, ncoil, klength = 2, 4, 112
    im_size = (33, 24)
    numpoints = (4, 6)

    def _random_complex(shape):
        # complex Gaussian sample split into real/imag channels on dim 2
        z = np.random.normal(size=shape) + 1j*np.random.normal(size=shape)
        return torch.tensor(np.stack((z.real, z.imag), axis=2)).to(dtype)

    x = _random_complex((nslice, ncoil) + im_size)
    y = _random_complex((nslice, ncoil, klength))
    ktraj = torch.randn(nslice, 2, klength).to(dtype)

    forward_op = KbNufft(im_size=im_size, numpoints=numpoints)
    adjoint_op = AdjKbNufft(im_size=im_size, numpoints=numpoints)

    lhs = inner_product(y, forward_op(x, ktraj), dim=2)
    rhs = inner_product(adjoint_op(y, ktraj), x, dim=2)

    assert torch.norm(lhs - rhs) < norm_tol
コード例 #9
0
def test_toepnufft_2d_adjoint(params_2d, testing_tol, testing_dtype, device_list):
    """Check that the Toeplitz normal operator is self-adjoint.

    With ``T`` being ToepNufft applied with the kernel from calc_toep_kernel,
    compares ``inner_product(y, T x)`` with ``inner_product(T y, x)`` for a
    random ``y`` on every device.
    """
    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]

    x = params_2d["x"]
    y = torch.randn(x.shape)  # random probe with the same shape as x
    ktraj = params_2d["ktraj"]

    for dev in device_list:
        x, y, ktraj = [
            t.detach().to(dtype=testing_dtype, device=dev)
            for t in (x, y, ktraj)
        ]

        adjoint_op = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=testing_dtype, device=dev
        )
        toeplitz_op = ToepNufft().to(dtype=testing_dtype, device=dev)

        kern = calc_toep_kernel(adjoint_op, ktraj)

        lhs = inner_product(y, toeplitz_op(x, kern), dim=2)
        rhs = inner_product(toeplitz_op(y, kern), x, dim=2)

        assert torch.norm(lhs - rhs) < testing_tol
def test_3d_kbnufft_backward(params_3d, testing_tol, testing_dtype,
                             device_list):
    """Check on every device that autograd through KbNufft matches AdjKbNufft.

    Back-propagates ``sum(y**2) / 2`` with ``y = KbNufft(x)`` (using
    precomputed sparse interpolation matrices) and compares ``x.grad`` with
    ``AdjKbNufft(y)``.
    """
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_3d['im_size']
    numpoints = params_3d['numpoints']

    # only the image and trajectory are inputs: the k-space data ``y`` is
    # produced by the forward NUFFT inside the loop
    x = params_3d['x']
    ktraj = params_3d['ktraj']

    for device in device_list:
        # detach() yields a fresh leaf each iteration, so x.grad starts empty
        x = x.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size,
                             numpoints=numpoints).to(dtype=dtype,
                                                     device=device)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        x.requires_grad = True
        y = kbnufft_ob.forward(x, ktraj, interp_mats)

        # gradient of sum(y**2)/2 w.r.t. x should equal the adjoint applied
        # to y
        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj, interp_mats)

        assert torch.norm(x_grad - x_hat) < norm_tol
コード例 #11
0
def test_3d_kbnufft_adjoint_backward(params_3d, testing_tol, testing_dtype,
                                     device_list):
    """Check on every device that autograd through AdjKbNufft matches KbNufft.

    Back-propagates ``sum(x**2) / 2`` with ``x = AdjKbNufft(y)`` (using
    precomputed sparse interpolation matrices) and compares ``y.grad`` with
    ``KbNufft(x)``.
    """
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_3d["im_size"]
    numpoints = params_3d["numpoints"]

    # only the k-space data and trajectory are inputs: the image ``x`` is
    # produced by the adjoint NUFFT inside the loop
    y = params_3d["y"]
    ktraj = params_3d["ktraj"]

    for device in device_list:
        # detach() yields a fresh leaf each iteration, so y.grad starts empty
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size,
                             numpoints=numpoints).to(dtype=dtype,
                                                     device=device)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }

        y.requires_grad = True
        x = adjkbnufft_ob.forward(y, ktraj, interp_mats)

        # gradient of sum(x**2)/2 w.r.t. y should equal the forward NUFFT of x
        ((x**2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()

        y_hat = kbnufft_ob.forward(x.clone().detach(), ktraj, interp_mats)

        assert torch.norm(y_grad - y_hat) < norm_tol
コード例 #12
0
def test_nufft_2d_matadj():
    """Check that AdjKbNufft with ``matadj=True`` matches the default path.

    Both adjoint operators are applied to the same random complex k-space
    data; their relative difference must stay below a module-level
    ``norm_tol``.
    """
    dtype = torch.double
    nslice, ncoil, klength = 2, 4, 112
    im_size = (33, 24)
    numpoints = (4, 6)

    # complex Gaussian k-space data, real/imag channels stacked on dim 2
    kdata = np.random.normal(size=(nslice, ncoil, klength)) + \
        1j*np.random.normal(size=(nslice, ncoil, klength))
    kdata = torch.tensor(np.stack((kdata.real, kdata.imag), axis=2)).to(dtype)

    ktraj = torch.randn(nslice, 2, klength).to(dtype)

    adj_default = AdjKbNufft(im_size=im_size, numpoints=numpoints)
    adj_matadj = AdjKbNufft(im_size=im_size, numpoints=numpoints, matadj=True)

    reference = adj_default(kdata, ktraj)
    candidate = adj_matadj(kdata, ktraj)

    rel_err = torch.norm(reference - candidate) / torch.norm(reference)
    assert rel_err < norm_tol
コード例 #13
0
    def _nufft(self, freq_domain_data, iflag=1, eps=1E-7, inv_transfo=False):
        """
        Rotate the sampling coordinates and apply an adjoint KB-NUFFT.

        :param freq_domain_data: 4D tensor whose last axis holds (real, imag);
            the three leading axes are used as the NUFFT image size
        :param iflag: unused by this implementation (presumably kept from a
            finufftpy-based version for signature compatibility)
        :param eps: unused by this implementation (same note as ``iflag``)
        :param inv_transfo: forwarded to ``self._rotate_coordinates``
        :return: adjoint NUFFT of freq_domain_data at the rotated
            coordinates, with the real/imag axis moved to the end
        """
        new_grid_coords = torch.from_numpy(
            np.asarray(self._rotate_coordinates(
                inv_transfo=inv_transfo)[0])).unsqueeze(0)
        adj_nufft = AdjKbNufft(im_size=freq_domain_data.shape[:-1],
                               n_shift=(0, 0, 0))
        # move real/imag first and flatten the spatial axes into a single
        # batch-1, coil-1 k-space vector of shape (1, 1, 2, npoints);
        # reshape (not view) so non-contiguous inputs are accepted as well
        freq_domain_data = freq_domain_data.permute(3, 0, 1, 2).reshape(
            (1, 1, 2, -1))
        if self.cuda:
            adj_nufft = adj_nufft.cuda()
            new_grid_coords = new_grid_coords.cuda()
        im_out = adj_nufft(freq_domain_data, new_grid_coords)
        # drop the batch and coil axes explicitly (assumes the output mirrors
        # the (1, 1, 2, ...) input layout); a bare squeeze() would also drop
        # any spatial axis of size 1
        im_out = im_out[0, 0]
        # move the real/imag axis to the end
        im_out = torch.stack(torch.unbind(im_out, 0), -1)
        return im_out
コード例 #14
0
def test_kb_matching(testing_tol):
    """Check that every operator builds the same ``table`` as AdjKbNufft.

    For each (kbwidth, order, im_sz) combination, the ``table`` attribute of
    AdjKbNufft is the reference. The tables of KbNufft, KbInterpBack,
    KbInterpForw, MriSenseNufft and AdjMriSenseNufft must match it entry by
    entry.
    """
    import itertools

    norm_tol = testing_tol

    def check_tables(table1, table2):
        # a different number of entries is a mismatch too; without this
        # check, extra entries in table2 would be silently ignored
        assert len(table1) == len(table2)
        for ref_entry, cur_entry in zip(table1, table2):
            assert np.linalg.norm(ref_entry - cur_entry) < norm_tol

    im_szs = [(256, 256), (10, 256, 256)]

    kbwidths = [2.34, 5]
    orders = [0, 2]

    # kbwidth is the outermost and im_sz the innermost iteration variable
    for kbwidth, order, im_sz in itertools.product(kbwidths, orders, im_szs):
        smap = torch.randn(*((1,) + im_sz))

        base_table = AdjKbNufft(
            im_sz, order=order, kbwidth=kbwidth).table

        cur_table = KbNufft(im_sz, order=order, kbwidth=kbwidth).table
        check_tables(base_table, cur_table)

        cur_table = KbInterpBack(
            im_sz, order=order, kbwidth=kbwidth).table
        check_tables(base_table, cur_table)

        cur_table = KbInterpForw(
            im_sz, order=order, kbwidth=kbwidth).table
        check_tables(base_table, cur_table)

        cur_table = MriSenseNufft(
            smap, im_sz, order=order, kbwidth=kbwidth).table
        check_tables(base_table, cur_table)

        cur_table = AdjMriSenseNufft(
            smap, im_sz, order=order, kbwidth=kbwidth).table
        check_tables(base_table, cur_table)
コード例 #15
0
def test_nufft_3d_adjoint(params_3d, testing_tol, testing_dtype, device_list):
    """Dot-product adjoint test for 3D KbNufft/AdjKbNufft on every device.

    Uses precomputed sparse interpolation matrices and compares
    ``inner_product(y, A x)`` with ``inner_product(A^H y, x)``.
    """
    im_size, numpoints = params_3d["im_size"], params_3d["numpoints"]
    x, y, ktraj = params_3d["x"], params_3d["y"], params_3d["ktraj"]

    for dev in device_list:
        x, y, ktraj = [
            t.detach().to(dtype=testing_dtype, device=dev)
            for t in (x, y, ktraj)
        ]

        forward_op = KbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=testing_dtype, device=dev)
        adjoint_op = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=testing_dtype, device=dev)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, forward_op)
        interp_mats = dict(real_interp_mats=real_mat,
                           imag_interp_mats=imag_mat)

        lhs = inner_product(y, forward_op(x, ktraj, interp_mats), dim=2)
        rhs = inner_product(adjoint_op(y, ktraj, interp_mats), x, dim=2)

        assert torch.norm(lhs - rhs) < testing_tol
コード例 #16
0
def test_2d_init_inputs():
    """Construct every 2D operator with scalar and with per-dimension inputs.

    All object initializations have assertions, so this raises an error if
    any argument's dimensions don't match. The first configuration passes
    scalar ``numpoints``/``table_oversamp``/``kbwidth``/``order``; the
    second passes one value per image dimension.
    """
    im_sz = (256, 256)
    grid_sz = (512, 512)
    n_shift = (128, 128)
    norm = "None"

    # (numpoints, table_oversamp, kbwidth, order): scalar form, then tuple form
    configs = [
        (6, 2 ** 10, 2.34, 0),
        ((6, 6), (2 ** 10, 2 ** 10), (2.34, 2.34), (0, 0)),
    ]

    for numpoints, table_oversamp, kbwidth, order in configs:
        smap = torch.randn(*((1,) + im_sz))

        interp_kwargs = dict(
            im_size=im_sz,
            grid_size=grid_sz,
            n_shift=n_shift,
            numpoints=numpoints,
            table_oversamp=table_oversamp,
            kbwidth=kbwidth,
            order=order,
        )
        # the interpolators are not given a norm argument
        KbInterpForw(**interp_kwargs)
        KbInterpBack(**interp_kwargs)

        nufft_kwargs = dict(interp_kwargs, norm=norm)
        KbNufft(**nufft_kwargs)
        AdjKbNufft(**nufft_kwargs)

        MriSenseNufft(smap=smap, **nufft_kwargs)
        AdjMriSenseNufft(smap=smap, **nufft_kwargs)
コード例 #17
0
    #     uv_wavelengths[:, 0] / (1.0 / (2.0 * grid.pixel_scales[0] * units.arcsec.to(units.rad))) * np.pi
    # ]).T
    #
    # uv_wavelengths_temp = np.stack(
    #     (
    #         np.transpose(uv_wavelengths_temp[:, 0]).flatten(),
    #         np.transpose(uv_wavelengths_temp[:, 1]).flatten()
    #     ),
    #     axis=0
    # )
    #
    # ktraj = torch.tensor(uv_wavelengths_temp).to(dtype).unsqueeze(0)


    nufft_ob = KbNufft(im_size=im_size, grid_size=grid_size, norm='ortho').to(dtype)
    adjnufft_ob = AdjKbNufft(im_size=im_size, grid_size=grid_size, norm='ortho').to(dtype)

    # calculate k-space data
    #print(image.shape, ktraj.shape);exit()
    kdata = nufft_ob(_image, ktraj)
    #print(kdata[0, 0, 0, :])
    #exit()

    # real_visibilities__from__kbnufft_transformer = kdata[0, 0, 0, :]
    # real_visibilities__from__kbnufft_transformer *= shift.real
    # imag_visibilities__from__kbnufft_transformer = kdata[0, 0, 1, :]
    # imag_visibilities__from__kbnufft_transformer *= shift.imag

    # image_blurry = adjnufft_ob(kdata, ktraj)
    #
    # # show the images
コード例 #18
0
    def __init__(self, im_size, ktraj, dcomp, csmap, norm='ortho'):
        """
        Build the NUFFT operators for a dynamic 2D radial acquisition and
        precompute per-frame Toeplitz kernels.

        inputs:
        - im_size   ... the shape of the object to be imaged, (Nx, Ny, Nt)
        - ktraj     ... the k-space trajectories (torch tensor)
        - dcomp     ... the density compensation (torch tensor)
        - csmap     ... the coil-sensitivity maps (torch tensor), or None to
                        build single-coil operators
        - norm      ... normalization passed to every NUFFT operator

        N.B. the shapes for the tensors are the following:

            image   (1,1,2,Nx,Ny,Nt),
            ktraj   (1,2,Nrad,Nt),
            csmap   (1,Nc,2,Nx,Ny),
            dcomp   (1,1,1,Nrad,Nt),
            kdata   (1,Nc,2,Nrad,Nt),

        where
        - Nx,Ny,Nt = im_size
        - Nrad = 2*Ny*n_spokes
        - n_spokes = the number of spokes per 2D frame
        - Nc = the number of receiver coils
        """
        super(Dyn2DRadEncObj, self).__init__()
        dtype = torch.float

        # unpacking also enforces that im_size has exactly three entries
        _, _, Nt = im_size
        spokelength = im_size[1] * 2
        # parameters for constructing the operators: square grid of side 2*Ny
        grid_size = (spokelength, spokelength)

        self.im_size = im_size
        self.spokelength = spokelength
        self.Nrad = ktraj.shape[2]
        self.ktraj_tensor = ktraj
        self.dcomp_tensor = dcomp

        # single-coil/multi-coil NUFFTs on the 2D spatial size (Nx, Ny)
        if csmap is not None:
            self.NUFFT = MriSenseNufft(smap=csmap,
                                       im_size=im_size[:2],
                                       grid_size=grid_size,
                                       norm=norm).to(dtype)
            self.AdjNUFFT = AdjMriSenseNufft(smap=csmap,
                                             im_size=im_size[:2],
                                             grid_size=grid_size,
                                             norm=norm).to(dtype)
            self.ToepNUFFT = ToepSenseNufft(csmap).to(dtype)
        else:
            self.NUFFT = KbNufft(im_size=im_size[:2],
                                 grid_size=grid_size,
                                 norm=norm).to(dtype)
            self.AdjNUFFT = AdjKbNufft(im_size=im_size[:2],
                                       grid_size=grid_size,
                                       norm=norm).to(dtype)
            self.ToepNUFFT = ToepNufft().to(dtype)

        # calculate Toeplitz kernels, one per frame kt
        # for E^{\dagger} \circ E (weighted by the density compensation)
        self.AdagA_toep_kernel_list = [
            calc_toep_kernel(self.AdjNUFFT,
                             ktraj[..., kt],
                             weights=dcomp[..., kt]) for kt in range(Nt)
        ]

        # for E^H \circ E (unweighted)
        self.AHA_toep_kernel_list = [
            calc_toep_kernel(self.AdjNUFFT, ktraj[..., kt]) for kt in range(Nt)
        ]
コード例 #19
0
def _time_repeated(op, num_iters, is_cuda, mem_label):
    """Run ``op()`` ``num_iters`` times; return (last output, mean seconds).

    On CUDA, the max-memory counter is reset and the device synchronized
    before the clock starts, and the device is synchronized again before the
    clock stops, so queued kernels are included in the measurement. The peak
    memory is printed after the clock has stopped.
    """
    if is_cuda:
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    out = None
    for _ in range(num_iters):
        out = op()
    if is_cuda:
        torch.cuda.synchronize()
    # stop the clock before querying/printing memory so that is not timed
    end_time = time.perf_counter()
    if is_cuda:
        max_mem = torch.cuda.max_memory_allocated()
        print("GPU {} max memory: {} GB".format(mem_label, max_mem / 1e9))
    return out, (end_time - start_time) / num_iters


def profile_torchkbnufft(image,
                         ktraj,
                         smap,
                         im_size,
                         device,
                         sparse_mats_flag=False,
                         use_toep=False):
    """Print average run times of the SENSE NUFFT operators on ``device``.

    Args:
        image: image tensor fed to the forward (or Toeplitz) operator.
        ktraj: k-space trajectory tensor.
        smap: sensitivity-map tensor for the SENSE operators.
        im_size: image size passed to the operators.
        device: torch.device (or device string) to run on.
        sparse_mats_flag: precompute sparse interpolation matrices.
        use_toep: time the Toeplitz operator instead of forward + adjoint.

    Each measurement is preceded by the same number of warm-up calls.
    Nothing is returned; results are printed.
    """
    # run double precision for CPU, float for GPU
    # these seem to be present in reference implementations
    # compare device *types* so that e.g. "cuda:0" is treated as CUDA too
    device_type = torch.device(device).type
    is_cuda = device_type == "cuda"
    if device_type == "cpu":
        dtype = torch.double
        num_nuffts = 20 if use_toep else 5
    else:
        dtype = torch.float
        num_nuffts = 50 if use_toep else 20
    cpudevice = torch.device("cpu")

    image = image.to(dtype=dtype)
    ktraj = ktraj.to(dtype=dtype)
    smap = smap.to(dtype=dtype)

    kbsense_ob = MriSenseNufft(smap=smap, im_size=im_size).to(dtype=dtype,
                                                              device=device)
    adjkbsense_ob = AdjMriSenseNufft(smap=smap,
                                     im_size=im_size).to(dtype=dtype,
                                                         device=device)

    adjkbnufft_ob = AdjKbNufft(im_size=im_size).to(dtype=dtype, device=device)

    # precompute toeplitz kernel if using toeplitz
    if use_toep:
        print("using toeplitz for forward/backward")
        kern = calc_toep_kernel(adjkbsense_ob, ktraj)
        toep_ob = ToepSenseNufft(smap=smap).to(dtype=dtype, device=device)

    # precompute the sparse interpolation matrices
    if sparse_mats_flag:
        print("using sparse interpolation matrices")
        real_mat, imag_mat = precomp_sparse_mats(ktraj, adjkbnufft_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }
    else:
        print("not using sparse interpolation matrices")
        interp_mats = None

    if use_toep:
        # warm-up computation (copying back to CPU forces completion)
        for _ in range(num_nuffts):
            toep_ob(image.to(device=device),
                    kern.to(device=device)).to(cpudevice)
        _, avg_time = _time_repeated(
            lambda: toep_ob(image.to(device=device), kern.to(device=device)),
            num_nuffts, is_cuda, "forward")
        print("toeplitz forward/backward average time: {}".format(avg_time))
    else:
        # warm-up computation
        for _ in range(num_nuffts):
            kbsense_ob(image.to(device=device), ktraj.to(device=device),
                       interp_mats).to(cpudevice)

        # run the forward speed tests; keep the last output for the adjoint
        y, avg_time = _time_repeated(
            lambda: kbsense_ob(image.to(device=device),
                               ktraj.to(device=device), interp_mats),
            num_nuffts, is_cuda, "forward")
        print("forward average time: {}".format(avg_time))

        # warm-up computation
        for _ in range(num_nuffts):
            adjkbsense_ob(y.to(device), ktraj.to(device), interp_mats)

        # run the adjoint speed tests
        _, avg_time = _time_repeated(
            lambda: adjkbsense_ob(y.to(device), ktraj.to(device),
                                  interp_mats),
            num_nuffts, is_cuda, "adjoint")
        print("backward average time: {}".format(avg_time))