示例#1
0
def test_interp_autograd(shape, kdata_shape, is_complex):
    """Run ``nufft_autograd_test`` on the interpolation pair with sparse matrices.

    The global default dtype is switched to double for the duration of the
    test and is restored afterwards, including when the test fails.

    Args:
        shape: Shape used to create the image input; entries from index 2
            onward are taken as the image size (minus the last entry when
            ``is_complex`` is false).
        kdata_shape: Shape used to create the k-space input; entry 2 is
            passed to ``create_ktraj``.
        is_complex: Forwarded to ``create_input_plus_noise``; also selects
            how ``im_size`` is sliced out of ``shape``.
    """
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    try:
        torch.manual_seed(123)
        # Non-complex inputs carry one extra trailing dimension that is not
        # part of the image size (presumably real/imag -- confirm with
        # create_input_plus_noise).
        im_size = shape[2:] if is_complex else shape[2:-1]

        image = create_input_plus_noise(shape, is_complex)
        kdata = create_input_plus_noise(kdata_shape, is_complex)
        ktraj = create_ktraj(len(im_size), kdata_shape[2])

        forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size)
        adj_ob = tkbn.KbInterpAdjoint(im_size=im_size, grid_size=im_size)

        # test with sparse matrices
        spmat = tkbn.calc_tensor_spmatrix(
            ktraj,
            im_size,
            grid_size=im_size,
        )

        nufft_autograd_test(image, kdata, ktraj, forw_ob, adj_ob, spmat)
    finally:
        # Restore global state even if the check above raised, so a failure
        # here cannot change the dtype seen by later tests.
        torch.set_default_dtype(default_dtype)
示例#2
0
def test_interp_complex_real_match(shape, kdata_shape, is_complex):
    """Check that ``KbInterp`` agrees on complex input and its real view.

    The operator is applied to the image and to ``torch.view_as_real`` of the
    image; the second result is viewed back as complex and compared with
    ``torch.allclose``. The comparison is made once without and once with
    precomputed sparse matrices.

    The global default dtype is switched to double for the duration of the
    test and is restored afterwards, including when an assertion fails.

    Args:
        shape: Shape used to create the image; entries from index 2 onward
            are taken as the image size.
        kdata_shape: Only entry 2 is used, as an argument to ``create_ktraj``.
        is_complex: Forwarded to ``create_input_plus_noise``.
            NOTE(review): ``torch.view_as_real`` requires a complex tensor,
            so this test presumably only runs with ``is_complex=True``.
    """
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    try:
        torch.manual_seed(123)
        im_size = shape[2:]

        image = create_input_plus_noise(shape, is_complex)
        ktraj = create_ktraj(len(im_size), kdata_shape[2])

        forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size)

        kdata_complex = forw_ob(image, ktraj)
        kdata_real = torch.view_as_complex(
            forw_ob(torch.view_as_real(image), ktraj))

        assert torch.allclose(kdata_complex, kdata_real)

        # test with sparse matrices
        spmat = tkbn.calc_tensor_spmatrix(
            ktraj,
            im_size,
            grid_size=im_size,
        )

        kdata_complex = forw_ob(image, ktraj, spmat)
        kdata_real = torch.view_as_complex(
            forw_ob(torch.view_as_real(image), ktraj, spmat))

        assert torch.allclose(kdata_complex, kdata_real)
    finally:
        # Restore global state even if an assertion above failed.
        torch.set_default_dtype(default_dtype)
示例#3
0
def test_interp_adjoint_gpu(shape, kdata_shape, is_complex):
    """Run ``nufft_adjoint_test`` on the interpolation pair on a CUDA device.

    The test is skipped when CUDA is unavailable. The global default dtype is
    switched to double for the duration of the test and is restored
    afterwards, including when the test fails.

    Args:
        shape: Shape used to create the image input; entries from index 2
            onward are taken as the image size (minus the last entry when
            ``is_complex`` is false).
        kdata_shape: Shape used to create the k-space input; entry 2 is
            passed to ``create_ktraj``.
        is_complex: Forwarded to ``create_input_plus_noise``; also selects
            how ``im_size`` is sliced out of ``shape``.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    device = torch.device("cuda")

    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    try:
        torch.manual_seed(123)
        # Non-complex inputs carry one extra trailing dimension that is not
        # part of the image size (presumably real/imag -- confirm with
        # create_input_plus_noise).
        im_size = shape[2:] if is_complex else shape[2:-1]

        image = create_input_plus_noise(shape, is_complex).to(device)
        kdata = create_input_plus_noise(kdata_shape, is_complex).to(device)
        ktraj = create_ktraj(len(im_size), kdata_shape[2]).to(device)

        forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size).to(device)
        adj_ob = tkbn.KbInterpAdjoint(im_size=im_size,
                                      grid_size=im_size).to(device)

        # test with sparse matrices
        spmat = tkbn.calc_tensor_spmatrix(
            ktraj,
            im_size,
            grid_size=im_size,
        )

        nufft_adjoint_test(image, kdata, ktraj, forw_ob, adj_ob, spmat)
    finally:
        # Restore global state even if the check above raised.
        torch.set_default_dtype(default_dtype)
示例#4
0
def profile_torchkbnufft(
    image,
    ktraj,
    smap,
    im_size,
    grid_size,
    device,
    sparse_mats_flag=False,
    toep_flag=False,
):
    """Time torchkbnufft operators on ``device`` and print a one-line report.

    With ``toep_flag`` set, the Toeplitz operator is timed. Otherwise the
    forward NUFFT is timed and then the adjoint NUFFT is timed on the last
    forward output. Every timed section is preceded by a warm-up of the same
    number of calls. Inputs are moved to ``device`` inside each call, so any
    transfer cost is part of the measured time.

    CPU runs use double precision, all other devices use single precision;
    the number of repetitions also depends on the device and ``toep_flag``.
    On CUDA devices the peak allocated memory of each timed section is
    included in the report.

    Args:
        image: Image tensor; cast to the complex dtype chosen for ``device``.
        ktraj: Trajectory tensor; cast to the matching real dtype.
        smap: Tensor passed to the operators as ``smaps``; cast to the
            complex dtype.
        im_size: Forwarded to the tkbn constructors and precompute helpers.
        grid_size: Forwarded to the tkbn constructors and precompute helpers.
        device: Target device; a ``torch.device`` or anything
            ``torch.device(...)`` accepts (e.g. ``"cuda:0"``).
        sparse_mats_flag: If true, precompute sparse interpolation matrices
            and pass them to the forward/adjoint NUFFT calls.
        toep_flag: If true, profile ``tkbn.ToepNufft`` instead of the
            forward/adjoint pair.
    """
    # Normalise so indexed devices such as "cuda:0" are still recognised as
    # CUDA; comparing against torch.device("cuda") would not match them.
    device = torch.device(device)
    on_cuda = device.type == "cuda"
    cpudevice = torch.device("cpu")

    # run double precision for CPU, float for GPU
    # these seem to be present in reference implementations
    if device.type == "cpu":
        complex_dtype = torch.complex128
        real_dtype = torch.double
        num_nuffts = 20 if toep_flag else 5
    else:
        complex_dtype = torch.complex64
        real_dtype = torch.float
        num_nuffts = 50 if toep_flag else 20

    def _timed(run):
        """Call ``run`` ``num_nuffts`` times.

        Returns the last output, the average wall time per call in seconds,
        and the peak CUDA memory in bytes (``None`` when not on CUDA).
        """
        if on_cuda:
            torch.cuda.reset_peak_memory_stats(device)
            torch.cuda.synchronize(device)
        start_time = time.perf_counter()
        out = None
        for _ in range(num_nuffts):
            out = run()
        if on_cuda:
            # Wait for queued kernels so the timestamp covers the real work.
            torch.cuda.synchronize(device)
        # Stop the clock before any bookkeeping so it is not counted.
        avg_time = (time.perf_counter() - start_time) / num_nuffts
        max_mem = torch.cuda.max_memory_allocated(device) if on_cuda else None
        return out, avg_time, max_mem

    res = ""
    image = image.to(dtype=complex_dtype)
    ktraj = ktraj.to(dtype=real_dtype)
    smap = smap.to(dtype=complex_dtype)
    interp_mats = None

    forw_ob = tkbn.KbNufft(im_size=im_size,
                           grid_size=grid_size,
                           dtype=complex_dtype,
                           device=device)
    adj_ob = tkbn.KbNufftAdjoint(im_size=im_size,
                                 grid_size=grid_size,
                                 dtype=complex_dtype,
                                 device=device)

    # precompute toeplitz kernel if using toeplitz
    if toep_flag:
        kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size, grid_size=grid_size)
        toep_ob = tkbn.ToepNufft()

    # precompute the sparse interpolation matrices
    if sparse_mats_flag:
        interp_mats = tkbn.calc_tensor_spmatrix(ktraj,
                                                im_size,
                                                grid_size=grid_size)
        interp_mats = tuple(t.to(device) for t in interp_mats)

    if toep_flag:

        def run_toeplitz():
            return toep_ob(
                image.to(device=device),
                kernel.to(device=device),
                smaps=smap.to(device=device),
            )

        # warm-up computation
        for _ in range(num_nuffts):
            run_toeplitz().to(cpudevice)

        # run the speed tests
        _, avg_time, max_mem = _timed(run_toeplitz)
        if max_mem is not None:
            res += "GPU forward max memory: {} GB, ".format(max_mem / 1e9)
        res += "toeplitz forward/backward average time: {}".format(avg_time)
    else:

        def run_forward():
            return forw_ob(
                image.to(device=device),
                ktraj.to(device=device),
                interp_mats,
                smaps=smap.to(device),
            )

        # warm-up computation
        for _ in range(num_nuffts):
            run_forward().to(cpudevice)

        # run the forward speed tests; keep the last output for the adjoint
        y, avg_time, max_mem = _timed(run_forward)
        if max_mem is not None:
            res += "GPU forward max memory: {} GB, ".format(max_mem / 1e9)
        res += "forward average time: {}, ".format(avg_time)

        def run_adjoint():
            return adj_ob(
                y.to(device),
                ktraj.to(device),
                interp_mats,
                smaps=smap.to(device),
            )

        # warm-up computation
        for _ in range(num_nuffts):
            run_adjoint()

        # run the adjoint speed tests
        _, avg_time, max_mem = _timed(run_adjoint)
        if max_mem is not None:
            res += "GPU adjoint max memory: {} GB, ".format(max_mem / 1e9)
        res += "backward average time: {}".format(avg_time)

    print(res)