Exemple #1
0
def test_3d_interp_backward():
    """Compare the autograd gradient of 3D forward interpolation to the adjoint.

    With ``y = kbinterp_ob(x, ktraj)`` and the loss ``sum(y**2) / 2``, the
    gradient accumulated in ``x.grad`` is compared against the output of the
    ``KbInterpBack`` operator applied to a detached copy of ``y``.

    ``norm_tol`` is not defined in this function; it is expected to exist at
    module scope.
    """
    dtype = torch.double

    nslice = 2
    ncoil = 4
    grid_size = (11, 33, 24)
    klength = 112

    # The input uses the shape that is passed to the operators as grid_size;
    # real and imaginary parts are stacked along axis 2.
    x = np.random.normal(size=(nslice, ncoil) + grid_size) + \
        1j*np.random.normal(size=(nslice, ncoil) + grid_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2)).to(dtype)

    ktraj = torch.randn(nslice, 3, klength).to(dtype)

    kbinterp_ob = KbInterpForw(im_size=(5, 20, 25),
                               grid_size=grid_size,
                               numpoints=(2, 4, 6))
    adjkbinterp_ob = KbInterpBack(im_size=(5, 20, 25),
                                  grid_size=grid_size,
                                  numpoints=(2, 4, 6))

    # Call the operators as modules (as the other tests in this file do)
    # rather than invoking .forward directly.
    x.requires_grad = True
    y = kbinterp_ob(x, ktraj)

    # d/dy of sum(y**2) / 2 is y, so x.grad holds y pushed back through
    # the forward operator by autograd.
    ((y**2) / 2).sum().backward()
    x_grad = x.grad.clone().detach()

    x_hat = adjkbinterp_ob(y.clone().detach(), ktraj)

    assert torch.norm(x_grad - x_hat) < norm_tol
Exemple #2
0
def test_2d_interp_adjoint_backward():
    """Compare the autograd gradient of 2D adjoint interpolation to the forward op.

    With ``x = adjkbinterp_ob(y, ktraj)`` and the loss ``sum(x**2) / 2``, the
    gradient accumulated in ``y.grad`` is compared against the output of the
    ``KbInterpForw`` operator applied to a detached copy of ``x``.

    ``norm_tol`` is not defined in this function; it is expected to exist at
    module scope.
    """
    dtype = torch.double

    nslice = 2
    ncoil = 4
    grid_size = (33, 24)
    klength = 112

    # Random complex k-space samples; real and imaginary parts are stacked
    # along axis 2.
    y = np.random.normal(size=(nslice, ncoil, klength)) + \
        1j*np.random.normal(size=(nslice, ncoil, klength))
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)

    ktraj = torch.randn(nslice, 2, klength).to(dtype)

    kbinterp_ob = KbInterpForw(im_size=(20, 25),
                               grid_size=grid_size,
                               numpoints=(4, 6))
    adjkbinterp_ob = KbInterpBack(im_size=(20, 25),
                                  grid_size=grid_size,
                                  numpoints=(4, 6))

    # Call the operators as modules (as the other tests in this file do)
    # rather than invoking .forward directly.
    y.requires_grad = True
    x = adjkbinterp_ob(y, ktraj)

    # d/dx of sum(x**2) / 2 is x, so y.grad holds x pushed back through
    # the adjoint operator by autograd.
    ((x**2) / 2).sum().backward()
    y_grad = y.grad.clone().detach()

    y_hat = kbinterp_ob(x.clone().detach(), ktraj)

    assert torch.norm(y_grad - y_hat) < norm_tol
def test_interp_2d_adjoint():
    """Adjoint check for the 2D table-based interpolation operators.

    Random gridded data ``x`` and random samples ``y`` are pushed through
    ``KbInterpForw`` and ``KbInterpBack`` respectively, and the two values
    returned by ``inner_product`` (over ``dim=2``) are required to agree
    within ``norm_tol`` (expected at module scope).
    """
    dtype = torch.double
    nslice, ncoil, klength = 2, 4, 112
    im_size = (33, 24)

    def _stack_real_imag(cplx):
        # Real/imaginary parts become a length-2 axis at position 2.
        return torch.tensor(np.stack((cplx.real, cplx.imag), axis=2)).to(dtype)

    grid_shape = (nslice, ncoil) + im_size
    x = _stack_real_imag(
        np.random.normal(size=grid_shape)
        + 1j * np.random.normal(size=grid_shape)
    )

    sample_shape = (nslice, ncoil, klength)
    y = _stack_real_imag(
        np.random.normal(size=sample_shape)
        + 1j * np.random.normal(size=sample_shape)
    )

    ktraj = torch.randn(nslice, 2, klength).to(dtype)

    op_args = dict(im_size=(20, 25), grid_size=im_size, numpoints=(4, 6))
    forw_op = KbInterpForw(**op_args)
    back_op = KbInterpBack(**op_args)

    forw_of_x = forw_op(x, ktraj)
    back_of_y = back_op(y, ktraj)

    lhs = inner_product(y, forw_of_x, dim=2)
    rhs = inner_product(back_of_y, x, dim=2)

    assert torch.norm(lhs - rhs) < norm_tol
def test_interp_3d_adjoint():
    """Adjoint check for the 3D interpolation operators with precomputed matrices.

    The matrices returned by ``precomp_sparse_mats`` are handed to both
    ``KbInterpForw`` and ``KbInterpBack``; the two values returned by
    ``inner_product`` (over ``dim=2``) must agree within ``norm_tol``
    (expected at module scope).
    """
    dtype = torch.double
    nslice, ncoil, klength = 2, 4, 112
    im_size = (11, 33, 24)

    def _stack_real_imag(cplx):
        # Real/imaginary parts become a length-2 axis at position 2.
        return torch.tensor(np.stack((cplx.real, cplx.imag), axis=2)).to(dtype)

    grid_shape = (nslice, ncoil) + im_size
    x = _stack_real_imag(
        np.random.normal(size=grid_shape)
        + 1j * np.random.normal(size=grid_shape)
    )

    sample_shape = (nslice, ncoil, klength)
    y = _stack_real_imag(
        np.random.normal(size=sample_shape)
        + 1j * np.random.normal(size=sample_shape)
    )

    ktraj = torch.randn(nslice, 3, klength).to(dtype)

    op_args = dict(im_size=(5, 20, 25), grid_size=im_size, numpoints=(2, 4, 6))
    forw_op = KbInterpForw(**op_args)
    back_op = KbInterpBack(**op_args)

    mats = precomp_sparse_mats(ktraj, forw_op)
    interp_mats = dict(zip(('real_interp_mats', 'imag_interp_mats'), mats))

    forw_of_x = forw_op(x, ktraj, interp_mats)
    back_of_y = back_op(y, ktraj, interp_mats)

    lhs = inner_product(y, forw_of_x, dim=2)
    rhs = inner_product(back_of_y, x, dim=2)

    assert torch.norm(lhs - rhs) < norm_tol
def test_interp_2d_adjoint(params_2d, testing_tol, testing_dtype, device_list):
    """Adjoint check for the 2D interpolation operators on every listed device.

    Args:
        params_2d: Mapping providing ``batch_size``, ``im_size``,
            ``grid_size``, ``numpoints``, ``y`` and ``ktraj``.
        testing_tol: Upper bound on the norm of the inner-product mismatch.
        testing_dtype: dtype the tensors and operators are converted to.
        device_list: Devices to run the check on.
    """
    op_args = dict(
        im_size=params_2d["im_size"],
        grid_size=params_2d["grid_size"],
        numpoints=params_2d["numpoints"],
    )

    # Random complex gridded input; real/imag parts stacked along axis 2.
    grid_shape = (params_2d["batch_size"], 1) + params_2d["grid_size"]
    x_cplx = np.random.normal(size=grid_shape) + 1j * np.random.normal(
        size=grid_shape
    )
    x = torch.tensor(np.stack((x_cplx.real, x_cplx.imag), axis=2))
    y = params_2d["y"]
    ktraj = params_2d["ktraj"]

    for device in device_list:
        target = dict(dtype=testing_dtype, device=device)

        x = x.detach().to(**target)
        y = y.detach().to(**target)
        ktraj = ktraj.detach().to(**target)

        forw_op = KbInterpForw(**op_args).to(**target)
        back_op = KbInterpBack(**op_args).to(**target)

        forw_of_x = forw_op(x, ktraj)
        back_of_y = back_op(y, ktraj)

        lhs = inner_product(y, forw_of_x, dim=2)
        rhs = inner_product(back_of_y, x, dim=2)

        assert torch.norm(lhs - rhs) < testing_tol
def test_2d_interp_adjoint_backward(params_2d, testing_tol, testing_dtype,
                                    device_list):
    """Compare the autograd gradient of 2D adjoint interpolation to the forward op.

    On each device, ``x = adjkbinterp_ob(y, ktraj, interp_mats)`` is computed
    with the matrices from ``precomp_sparse_mats``. The gradient of
    ``sum(x**2) / 2`` with respect to ``y`` is then compared against the
    ``KbInterpForw`` operator applied to a detached copy of ``x``.

    Args:
        params_2d: Mapping providing ``im_size``, ``grid_size``,
            ``numpoints``, ``y`` and ``ktraj``.
        testing_tol: Upper bound on the norm of the gradient mismatch.
        testing_dtype: dtype the tensors and operators are converted to.
        device_list: Devices to run the check on.
    """
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_2d["im_size"]
    grid_size = params_2d["grid_size"]
    numpoints = params_2d["numpoints"]

    for device in device_list:
        # Start every device from the fixture tensors so that no state
        # (outputs, requires_grad) leaks from a previous iteration.
        y = params_2d["y"].detach().to(dtype=dtype, device=device)
        ktraj = params_2d["ktraj"].detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(im_size=im_size,
                                   grid_size=grid_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)
        adjkbinterp_ob = KbInterpBack(im_size=im_size,
                                      grid_size=grid_size,
                                      numpoints=numpoints).to(dtype=dtype,
                                                              device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }

        # Call the operators as modules rather than invoking .forward directly.
        y.requires_grad = True
        x = adjkbinterp_ob(y, ktraj, interp_mats)

        # d/dx of sum(x**2) / 2 is x, so y.grad holds x pushed back through
        # the adjoint operator by autograd.
        ((x**2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()

        y_hat = kbinterp_ob(x.clone().detach(), ktraj, interp_mats)

        assert torch.norm(y_grad - y_hat) < norm_tol
def test_3d_interp_backward(params_3d, testing_tol, testing_dtype,
                            device_list):
    """Compare the autograd gradient of 3D forward interpolation to the adjoint.

    On each device, ``y = kbinterp_ob(x, ktraj, interp_mats)`` is computed
    with the matrices from ``precomp_sparse_mats``. The gradient of
    ``sum(y**2) / 2`` with respect to ``x`` is then compared against the
    ``KbInterpBack`` operator applied to a detached copy of ``y``.

    Args:
        params_3d: Mapping providing ``batch_size``, ``im_size``,
            ``grid_size``, ``numpoints`` and ``ktraj``.
        testing_tol: Upper bound on the norm of the gradient mismatch.
        testing_dtype: dtype the tensors and operators are converted to.
        device_list: Devices to run the check on.
    """
    dtype = testing_dtype
    norm_tol = testing_tol

    batch_size = params_3d['batch_size']
    im_size = params_3d['im_size']
    grid_size = params_3d['grid_size']
    numpoints = params_3d['numpoints']

    # Random complex gridded input; real/imag parts stacked along axis 2.
    x_np = np.random.normal(size=(batch_size, 1) + grid_size) + \
        1j*np.random.normal(size=(batch_size, 1) + grid_size)
    x_init = torch.tensor(np.stack((np.real(x_np), np.imag(x_np)), axis=2))

    for device in device_list:
        # Start every device from the initial tensors so that no state
        # (outputs, requires_grad) leaks from a previous iteration.
        x = x_init.detach().to(dtype=dtype, device=device)
        ktraj = params_3d['ktraj'].detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(im_size=im_size,
                                   grid_size=grid_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)
        adjkbinterp_ob = KbInterpBack(im_size=im_size,
                                      grid_size=grid_size,
                                      numpoints=numpoints).to(dtype=dtype,
                                                              device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        # Call the operators as modules rather than invoking .forward directly.
        x.requires_grad = True
        y = kbinterp_ob(x, ktraj, interp_mats)

        # d/dy of sum(y**2) / 2 is y, so x.grad holds y pushed back through
        # the forward operator by autograd.
        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbinterp_ob(y.clone().detach(), ktraj, interp_mats)

        assert torch.norm(x_grad - x_hat) < norm_tol
Exemple #8
0
def test_3d_interp_backward(params_3d, testing_tol, testing_dtype,
                            device_list):
    """Compare the autograd gradient of 3D forward interpolation to the adjoint.

    On each device, ``y = kbinterp_ob(x, ktraj)`` is computed without
    precomputed matrices. The gradient of ``sum(y**2) / 2`` with respect to
    ``x`` is then compared against the ``KbInterpBack`` operator applied to a
    detached copy of ``y``.

    Args:
        params_3d: Mapping providing ``batch_size``, ``im_size``,
            ``grid_size``, ``numpoints`` and ``ktraj``.
        testing_tol: Upper bound on the norm of the gradient mismatch.
        testing_dtype: dtype the tensors and operators are converted to.
        device_list: Devices to run the check on.
    """
    dtype = testing_dtype
    norm_tol = testing_tol

    batch_size = params_3d["batch_size"]
    im_size = params_3d["im_size"]
    grid_size = params_3d["grid_size"]
    numpoints = params_3d["numpoints"]

    # Random complex gridded input; real/imag parts stacked along axis 2.
    x_np = np.random.normal(
        size=(batch_size, 1) +
        grid_size) + 1j * np.random.normal(size=(batch_size, 1) + grid_size)
    x_init = torch.tensor(np.stack((np.real(x_np), np.imag(x_np)), axis=2))

    for device in device_list:
        # Start every device from the initial tensors so that no state
        # (outputs, requires_grad) leaks from a previous iteration.
        x = x_init.detach().to(dtype=dtype, device=device)
        ktraj = params_3d["ktraj"].detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(im_size=im_size,
                                   grid_size=grid_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)
        adjkbinterp_ob = KbInterpBack(im_size=im_size,
                                      grid_size=grid_size,
                                      numpoints=numpoints).to(dtype=dtype,
                                                              device=device)

        # Call the operators as modules rather than invoking .forward directly.
        x.requires_grad = True
        y = kbinterp_ob(x, ktraj)

        # d/dy of sum(y**2) / 2 is y, so x.grad holds y pushed back through
        # the forward operator by autograd.
        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbinterp_ob(y.clone().detach(), ktraj)

        assert torch.norm(x_grad - x_hat) < norm_tol
Exemple #9
0
def test_3d_interp_adjoint_backward():
    """Compare the autograd gradient of 3D adjoint interpolation to the forward op.

    ``x = adjkbinterp_ob(y, ktraj, interp_mats)`` is computed with the
    matrices from ``precomp_sparse_mats``. The gradient of ``sum(x**2) / 2``
    with respect to ``y`` is then compared against the ``KbInterpForw``
    operator applied to a detached copy of ``x``.

    ``norm_tol`` is not defined in this function; it is expected to exist at
    module scope.
    """
    dtype = torch.double

    nslice = 2
    ncoil = 4
    grid_size = (11, 33, 24)
    klength = 112

    # Random complex k-space samples; real and imaginary parts are stacked
    # along axis 2.
    y = np.random.normal(size=(nslice, ncoil, klength)) + \
        1j*np.random.normal(size=(nslice, ncoil, klength))
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)

    ktraj = torch.randn(nslice, 3, klength).to(dtype)

    kbinterp_ob = KbInterpForw(im_size=(5, 20, 25),
                               grid_size=grid_size,
                               numpoints=(2, 4, 6))
    adjkbinterp_ob = KbInterpBack(im_size=(5, 20, 25),
                                  grid_size=grid_size,
                                  numpoints=(2, 4, 6))

    real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    # Call the operators as modules (as the other tests in this file do)
    # rather than invoking .forward directly.
    y.requires_grad = True
    x = adjkbinterp_ob(y, ktraj, interp_mats)

    # d/dx of sum(x**2) / 2 is x, so y.grad holds x pushed back through
    # the adjoint operator by autograd.
    ((x**2) / 2).sum().backward()
    y_grad = y.grad.clone().detach()

    y_hat = kbinterp_ob(x.clone().detach(), ktraj, interp_mats)

    assert torch.norm(y_grad - y_hat) < norm_tol
def test_interp_3d_adjoint(params_3d, testing_tol, testing_dtype, device_list):
    """Adjoint check for the 3D interpolation operators on every listed device.

    Both operators are given the matrices returned by
    ``precomp_sparse_mats``; the two values returned by ``inner_product``
    (over ``dim=2``) must agree within ``testing_tol``.

    Args:
        params_3d: Mapping providing ``batch_size``, ``im_size``,
            ``grid_size``, ``numpoints``, ``y`` and ``ktraj``.
        testing_tol: Upper bound on the norm of the inner-product mismatch.
        testing_dtype: dtype the tensors and operators are converted to.
        device_list: Devices to run the check on.
    """
    op_args = dict(
        im_size=params_3d['im_size'],
        grid_size=params_3d['grid_size'],
        numpoints=params_3d['numpoints'],
    )

    # Random complex gridded input; real/imag parts stacked along axis 2.
    grid_shape = (params_3d['batch_size'], 1) + params_3d['grid_size']
    x_cplx = np.random.normal(size=grid_shape) + \
        1j * np.random.normal(size=grid_shape)
    x = torch.tensor(np.stack((x_cplx.real, x_cplx.imag), axis=2))
    y = params_3d['y']
    ktraj = params_3d['ktraj']

    for device in device_list:
        target = dict(dtype=testing_dtype, device=device)

        x = x.detach().to(**target)
        y = y.detach().to(**target)
        ktraj = ktraj.detach().to(**target)

        forw_op = KbInterpForw(**op_args).to(**target)
        back_op = KbInterpBack(**op_args).to(**target)

        mats = precomp_sparse_mats(ktraj, forw_op)
        interp_mats = dict(zip(('real_interp_mats', 'imag_interp_mats'), mats))

        forw_of_x = forw_op(x, ktraj, interp_mats)
        back_of_y = back_op(y, ktraj, interp_mats)

        lhs = inner_product(y, forw_of_x, dim=2)
        rhs = inner_product(back_of_y, x, dim=2)

        assert torch.norm(lhs - rhs) < testing_tol
Exemple #11
0
def test_kb_matching(testing_tol):
    """Check that every operator class exposes the same ``table`` as AdjKbNufft.

    For each combination of ``kbwidth``, ``order`` and image size, the
    ``table`` attribute of ``AdjKbNufft`` is the reference; each table of the
    other operators must be within ``testing_tol`` (norm of the difference)
    of the reference entry at the same index.

    Args:
        testing_tol: Upper bound on the norm of each table difference.
    """
    def assert_tables_match(expected, actual):
        # Walks the reference tables and looks up the counterpart by index.
        for idx, ref in enumerate(expected):
            assert np.linalg.norm(ref - actual[idx]) < testing_tol

    # Operators built from the image size alone vs. those that also take
    # sensitivity maps as their first argument.
    plain_ops = (KbNufft, KbInterpBack, KbInterpForw)
    sense_ops = (MriSenseNufft, AdjMriSenseNufft)

    for kbwidth in [2.34, 5]:
        for order in [0, 2]:
            for im_sz in [(256, 256), (10, 256, 256)]:
                kb_args = dict(order=order, kbwidth=kbwidth)
                smap = torch.randn(*((1,) + im_sz))

                reference = AdjKbNufft(im_sz, **kb_args).table

                for op_cls in plain_ops:
                    assert_tables_match(
                        reference, op_cls(im_sz, **kb_args).table)

                for op_cls in sense_ops:
                    assert_tables_match(
                        reference, op_cls(smap, im_sz, **kb_args).table)
def test_2d_init_inputs():
    """Construct every 2D operator with scalar and then tuple-valued options.

    Per the original author's note, the object initializers contain
    assertions, so a dimension mismatch is expected to surface as an error
    during construction. The test itself makes no explicit assertions.
    """
    im_sz = (256, 256)
    grid_sz = (512, 512)
    n_shift = (128, 128)
    norm = "None"

    # First row: scalar inputs; second row: one value per dimension.
    option_sets = (
        (6, 2 ** 10, 2.34, 0),
        ((6, 6), (2 ** 10, 2 ** 10), (2.34, 2.34), (0, 0)),
    )

    for numpoints, table_oversamp, kbwidth, order in option_sets:
        smap = torch.randn(*((1,) + im_sz))

        shared = dict(
            im_size=im_sz,
            grid_size=grid_sz,
            n_shift=n_shift,
            numpoints=numpoints,
            table_oversamp=table_oversamp,
            kbwidth=kbwidth,
            order=order,
        )

        # Interpolation-only operators take no norm argument.
        for op_cls in (KbInterpForw, KbInterpBack):
            ob = op_cls(**shared)

        for op_cls in (KbNufft, AdjKbNufft):
            ob = op_cls(norm=norm, **shared)

        # The SENSE operators additionally take the sensitivity maps.
        for op_cls in (MriSenseNufft, AdjMriSenseNufft):
            ob = op_cls(smap=smap, norm=norm, **shared)