Example No. 1
def test_nufft_autograd(shape, kdata_shape, is_complex):
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    torch.manual_seed(123)
    if is_complex:
        im_size = shape[2:]
    else:
        im_size = shape[2:-1]

    image = create_input_plus_noise(shape, is_complex)
    kdata = create_input_plus_noise(kdata_shape, is_complex)
    ktraj = create_ktraj(len(im_size), kdata_shape[2])

    forw_ob = tkbn.KbNufft(im_size=im_size)
    adj_ob = tkbn.KbNufftAdjoint(im_size=im_size)

    # test with sparse matrices
    spmat = tkbn.calc_tensor_spmatrix(
        ktraj,
        im_size,
    )

    nufft_autograd_test(image, kdata, ktraj, forw_ob, adj_ob, spmat)

    torch.set_default_dtype(default_dtype)
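
The helpers create_input_plus_noise, create_ktraj, and nufft_autograd_test are not shown here. Below is a minimal, self-contained sketch of the forward/adjoint pattern the test exercises; the shapes and the random trajectory are illustrative assumptions, not values taken from the test:

import math
import torch
import torchkbnufft as tkbn

im_size = (64, 64)
# batch x coil x height x width complex image
image = torch.randn(1, 1, *im_size, dtype=torch.complex64, requires_grad=True)
# k-space trajectory in radians/voxel, shape (ndim, klength)
ktraj = torch.rand(2, 1024) * 2 * math.pi - math.pi

forw_ob = tkbn.KbNufft(im_size=im_size)
adj_ob = tkbn.KbNufftAdjoint(im_size=im_size)

kdata = forw_ob(image, ktraj)   # image -> non-Cartesian k-space
recon = adj_ob(kdata, ktraj)    # adjoint (not an inverse) back to image space
recon.abs().sum().backward()    # gradients flow through both operators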
Example No. 2
def calc_one_batch_toeplitz_kernel(
    omega: Tensor,
    im_size: Sequence[int],
    weights: Optional[Tensor] = None,
    norm: Optional[str] = None,
    grid_size: Optional[Sequence[int]] = None,
    numpoints: Union[int, Sequence[int]] = 6,
    n_shift: Optional[Sequence[int]] = None,
    table_oversamp: Union[int, Sequence[int]] = 2**10,
    kbwidth: float = 2.34,
    order: Union[float, Sequence[float]] = 0.0,
) -> Tensor:
    """See calc_toeplitz_kernel()."""
    device = omega.device
    normalized = True if norm == "ortho" else False

    adj_ob = tkbn.KbNufftAdjoint(
        im_size=im_size,
        grid_size=grid_size,
        numpoints=numpoints,
        n_shift=[0 for _ in range(omega.shape[0])],
        table_oversamp=table_oversamp,
        kbwidth=kbwidth,
        order=order,
        dtype=omega.dtype,
        device=omega.device,
    )

    # if we don't have any weights, just use ones
    assert isinstance(adj_ob.table_0, Tensor)
    if weights is None:
        weights = torch.ones(omega.shape[-1],
                             dtype=adj_ob.table_0.dtype,
                             device=device)
        weights = weights.unsqueeze(0).unsqueeze(0)
    else:
        weights = weights.to(adj_ob.table_0)

    # apply adjoints to n-1 dimensions
    if omega.shape[0] > 1:
        kernel = adjoint_flip_and_concat(1, omega, weights, adj_ob, norm)
    else:
        kernel = adj_ob(weights, omega, norm=norm)

    # now that we have half the kernel
    # we can use Hermitian symmetry
    kernel = reflect_conj_concat(kernel, 2)

    # make sure kernel is Hermitian symmetric
    kernel = hermitify(kernel, 2)

    # put the kernel in fft space
    return fft_fn(kernel, omega.shape[0], normalized=normalized)[0, 0]
Example No. 3
def test_dcomp_run(shape, kdata_shape, is_complex):
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    torch.manual_seed(123)
    if is_complex:
        im_size = shape[2:]
    else:
        im_size = shape[2:-1]

    kdata = create_input_plus_noise(kdata_shape, is_complex)
    ktraj = create_ktraj(len(im_size), kdata_shape[2])

    adj_ob = tkbn.KbNufftAdjoint(im_size=im_size)
    dcomp = tkbn.calc_density_compensation_function(ktraj=ktraj, im_size=im_size)

    if not is_complex:
        dcomp = torch.view_as_real(dcomp)

    _ = adj_ob(kdata * dcomp, ktraj)

    torch.set_default_dtype(default_dtype)
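
The density compensation function returned by calc_density_compensation_function is simply multiplied into the k-space data before the adjoint NUFFT, as the test above does. A minimal standalone sketch, with illustrative (assumed) shapes and trajectory:

import math
import torch
import torchkbnufft as tkbn

im_size = (64, 64)
klength = 1024
ktraj = torch.rand(2, klength) * 2 * math.pi - math.pi
kdata = torch.randn(1, 1, klength, dtype=torch.complex64)

adj_ob = tkbn.KbNufftAdjoint(im_size=im_size)
# dcomp has shape (1, 1, klength); it reweights samples by local sampling density
dcomp = tkbn.calc_density_compensation_function(ktraj=ktraj, im_size=im_size)
image = adj_ob(kdata * dcomp, ktraj)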
Example No. 4
def test_toeplitz_nufft_accuracy(shape, kdata_shape, is_complex):
    norm_diff_tol = 1e-4  # toeplitz is only approximate
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    torch.manual_seed(123)
    if is_complex:
        im_size = shape[2:]
    else:
        im_size = shape[2:-1]
    im_shape = [s for s in shape]
    im_shape[1] = 1

    image = create_input_plus_noise(im_shape, is_complex)
    smaps = create_input_plus_noise(shape, is_complex)
    ktraj = create_ktraj(len(im_size), kdata_shape[2])

    forw_ob = tkbn.KbNufft(im_size=im_size)
    adj_ob = tkbn.KbNufftAdjoint(im_size=im_size)
    toep_ob = tkbn.ToepNufft()

    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size, norm="ortho")

    fbn = adj_ob(
        forw_ob(image, ktraj, smaps=smaps, norm="ortho"),
        ktraj,
        smaps=smaps,
        norm="ortho",
    )
    fbt = toep_ob(image, kernel, smaps=smaps, norm="ortho")

    if is_complex:
        fbn = torch.view_as_real(fbn)
        fbt = torch.view_as_real(fbt)

    norm_diff = torch.norm(fbn - fbt) / torch.norm(fbn)

    assert norm_diff < norm_diff_tol

    torch.set_default_dtype(default_dtype)
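
Because the test relies on helpers that are not shown, here is a hedged, self-contained version of the same check: apply A'A once explicitly via forward/adjoint NUFFT and once via the Toeplitz kernel, then compare. The sizes, coil count, and trajectory are illustrative assumptions:

import math
import torch
import torchkbnufft as tkbn

im_size = (64, 64)
image = torch.randn(1, 1, *im_size, dtype=torch.complex64)
smaps = torch.randn(1, 8, *im_size, dtype=torch.complex64)
ktraj = torch.rand(2, 1024) * 2 * math.pi - math.pi

forw_ob = tkbn.KbNufft(im_size=im_size)
adj_ob = tkbn.KbNufftAdjoint(im_size=im_size)
toep_ob = tkbn.ToepNufft()
kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size, norm="ortho")

# explicit interpolation path vs. FFT-only Toeplitz path
fbn = adj_ob(forw_ob(image, ktraj, smaps=smaps, norm="ortho"), ktraj, smaps=smaps, norm="ortho")
fbt = toep_ob(image, kernel, smaps=smaps, norm="ortho")
print((torch.norm(fbn - fbt) / torch.norm(fbn)).item())  # small, since Toeplitz is only approximate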
Example No. 5
    def radial(self,
               out_ksp,
               full_ksp,
               under_ksp,
               metadict,
               device="cpu"):  #TODO: switch to cuda
        # om = torch.from_numpy(metadict['om'].transpose()).to(torch.float).to(device)
        # invom = torch.from_numpy(metadict['invom'].transpose()).to(torch.float).to(device)
        # fullom = torch.from_numpy(metadict['fullom'].transpose()).to(torch.float).to(device)
        # # dcf = torch.from_numpy(metadict['dcf'].squeeze())
        # dcfFullRes = torch.from_numpy(metadict['dcfFullRes'].squeeze()).to(torch.float).to(device)
        baseresolution = out_ksp.shape[0] * 2
        Nd = (baseresolution, baseresolution)
        imsize = out_ksp.shape[:2]

        nufft_ob = tkbn.KbNufft(
            im_size=imsize,
            grid_size=Nd,
        ).to(torch.complex64).to(device)
        adjnufft_ob = tkbn.KbNufftAdjoint(
            im_size=imsize,
            grid_size=Nd,
        ).to(torch.complex64).to(device)

        # intrp_ob = tkbn.KbInterp(
        #     im_size=imsize,
        #     grid_size=Nd,
        # ).to(torch.complex64).to(device)

        out_img = ifftNc(data=out_ksp, dim=(0, 1), norm="ortho").to(device)
        full_img = ifftNc(data=full_ksp, dim=(0, 1), norm="ortho").to(device)

        if len(out_img.shape) == 3:
            out_img = torch.permute(out_img, dims=(2, 0, 1)).unsqueeze(1)
            full_img = torch.permute(full_img, dims=(2, 0, 1)).unsqueeze(1)
        else:
            out_img = out_img.unsqueeze(0).unsqueeze(0)
            full_img = full_img.unsqueeze(0).unsqueeze(0)

        # out_img = torch.permute(out_ksp, dims=(2,0,1)).unsqueeze(1).to(device)
        # full_img = torch.permute(full_ksp, dims=(2,0,1)).unsqueeze(1).to(device)

        spokelength = full_img.shape[-1] * 2
        grid_size = (spokelength, spokelength)
        nspokes = 512

        ga = np.deg2rad(180 / ((1 + np.sqrt(5)) / 2))
        kx = np.zeros(shape=(spokelength, nspokes))
        ky = np.zeros(shape=(spokelength, nspokes))
        ky[:, 0] = np.linspace(-np.pi, np.pi, spokelength)
        for i in range(1, nspokes):
            kx[:, i] = np.cos(ga) * kx[:, i - 1] - np.sin(ga) * ky[:, i - 1]
            ky[:, i] = np.sin(ga) * kx[:, i - 1] + np.cos(ga) * ky[:, i - 1]

        ky = np.transpose(ky)
        kx = np.transpose(kx)

        fullom = torch.from_numpy(
            np.stack((ky.flatten(), kx.flatten()),
                     axis=0)).to(torch.float).to(device)
        # hard-coded split of the full trajectory: the first 30720 samples are
        # treated as the acquired ("under") spokes, the remainder as the missing ones
        om = fullom[:, :30720]
        invom = fullom[:, 30720:]
        dcfFullRes = tkbn.calc_density_compensation_function(
            ktraj=fullom, im_size=imsize).to(device)

        yUnder = nufft_ob(full_img, om, norm="ortho")
        yMissing = nufft_ob(out_img, invom, norm="ortho")
        # yUnder = intrp_ob(full_img, om)
        # yMissing = intrp_ob(out_img, invom)
        yCorrected = torch.concat((yUnder, yMissing), dim=-1)
        yCorrected = dcfFullRes * yCorrected
        out_corrected_img = adjnufft_ob(yCorrected, fullom,
                                        norm="ortho").squeeze()

        out_corrected_img = torch.abs(out_corrected_img)
        out_corrected_img = (out_corrected_img - out_corrected_img.min()) / \
            (out_corrected_img.max() - out_corrected_img.min())

        if len(out_corrected_img.shape) == 3:
            out_corrected_img = torch.permute(out_corrected_img,
                                              dims=(1, 2, 0))

        out_corrected_ksp = fftNc(data=out_corrected_img,
                                  dim=(0, 1),
                                  norm="ortho").cpu()
        return out_corrected_ksp
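
The golden-angle trajectory built inline above can be factored into a small standalone helper. This is a hypothetical refactoring for reuse, not part of the original class:

import numpy as np
import torch

def golden_angle_radial_traj(spokelength: int, nspokes: int) -> torch.Tensor:
    """Return a (2, spokelength * nspokes) golden-angle radial trajectory."""
    ga = np.deg2rad(180 / ((1 + np.sqrt(5)) / 2))  # golden angle in radians
    kx = np.zeros((spokelength, nspokes))
    ky = np.zeros((spokelength, nspokes))
    ky[:, 0] = np.linspace(-np.pi, np.pi, spokelength)
    # rotate each spoke by the golden angle relative to the previous one
    for i in range(1, nspokes):
        kx[:, i] = np.cos(ga) * kx[:, i - 1] - np.sin(ga) * ky[:, i - 1]
        ky[:, i] = np.sin(ga) * kx[:, i - 1] + np.cos(ga) * ky[:, i - 1]
    return torch.from_numpy(
        np.stack((ky.T.flatten(), kx.T.flatten()), axis=0)).to(torch.float)

fullom = golden_angle_radial_traj(spokelength=512, nspokes=512)  # e.g., a 256-wide full_img gives spokelength 512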
Example No. 6
def calc_toeplitz_kernel(
    omega: Tensor,
    im_size: Sequence[int],
    weights: Optional[Tensor] = None,
    norm: Optional[str] = None,
    grid_size: Optional[Sequence[int]] = None,
    numpoints: Union[int, Sequence[int]] = 6,
    n_shift: Optional[Sequence[int]] = None,
    table_oversamp: Union[int, Sequence[int]] = 2 ** 10,
    kbwidth: float = 2.34,
    order: Union[float, Sequence[float]] = 0.0,
) -> Tensor:
    r"""Calculates an FFT kernel for Toeplitz embedding.

    The kernel is calculated using an adjoint NUFFT object. If the adjoint
    applies :math:`A'`, then this function calculates :math:`D` where
    :math:`F'DF \approx A'WA`, where :math:`F` is a DFT matrix and :math:`W` is
    a set of non-Cartesian k-space weights. :math:`D` can then be used to
    approximate :math:`A'WA` without any interpolation operations.

    For details on Toeplitz embedding, see
    `Efficient numerical methods in non-uniform sampling theory
    (Feichtinger et al.)
    <https://link.springer.com/article/10.1007/s002110050101>`_.

    This function has optional parameters for initializing a NUFFT object. See
    :py:class:`~torchkbnufft.KbNufftAdjoint` for details.

    Note:

        This function is intended to be used in conjunction with
        :py:class:`~torchkbnufft.ToepNufft` for forward operations.

    * :attr:`omega` should be of size ``(len(im_size), klength)``,
      where ``klength`` is the length of the k-space trajectory.

    Args:
        omega: k-space trajectory (in radians/voxel).
        im_size: Size of image with length being the number of dimensions.
        weights: Non-Cartesian k-space weights (e.g., density compensation).
            Default: ``torch.ones(omega.shape[1])``
        norm: Whether to apply normalization with the FFT operation. Options
            are ``"ortho"`` or ``None``.
        grid_size: Size of grid to use for interpolation, typically 1.25 to 2
            times ``im_size``. Default: ``2 * im_size``
        numpoints: Number of neighbors to use for interpolation in each
            dimension.
        n_shift: Size for fftshift. Default: ``im_size // 2``.
        table_oversamp: Table oversampling factor.
        kbwidth: Size of Kaiser-Bessel kernel.
        order: Order of Kaiser-Bessel kernel.

    Returns:
        The FFT kernel for approximating the forward/adjoint operation.

    Examples:

        >>> image = torch.randn(1, 1, 8, 8) + 1j * torch.randn(1, 1, 8, 8)
        >>> omega = torch.rand(2, 12) * 2 * np.pi - np.pi
        >>> toep_ob = tkbn.ToepNufft()
        >>> kernel = tkbn.calc_toeplitz_kernel(omega, im_size=(8, 8))
        >>> image = toep_ob(image, kernel)
    """
    device = omega.device
    normalized = True if norm == "ortho" else False

    adj_ob = tkbn.KbNufftAdjoint(
        im_size=im_size,
        grid_size=grid_size,
        numpoints=numpoints,
        n_shift=[0 for _ in range(omega.shape[0])],
        table_oversamp=table_oversamp,
        kbwidth=kbwidth,
        order=order,
        dtype=omega.dtype,
        device=omega.device,
    )

    # if we don't have any weights, just use ones
    assert isinstance(adj_ob.table_0, Tensor)
    if weights is None:
        weights = torch.ones(omega.shape[-1], dtype=adj_ob.table_0.dtype, device=device)
        weights = weights.unsqueeze(0).unsqueeze(0)
    else:
        weights = weights.to(adj_ob.table_0)

    # apply adjoints to n-1 dimensions
    if omega.shape[0] > 1:
        kernel = adjoint_flip_and_concat(1, omega, weights, adj_ob, norm)
    else:
        kernel = adj_ob(weights, omega, norm=norm)

    # now that we have half the kernel
    # we can use Hermitian symmetry
    kernel = reflect_conj_concat(kernel, 2)

    # make sure kernel is Hermitian symmetric
    kernel = hermitify(kernel, 2)

    # put the kernel in fft space
    return fft_fn(kernel, omega.shape[0], normalized=normalized)
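
As the docstring notes, weights models W in F'DF ≈ A'WA; a common choice is a density compensation function. A minimal sketch of that pairing, with assumed sizes and a random trajectory:

import math
import torch
import torchkbnufft as tkbn

im_size = (64, 64)
ktraj = torch.rand(2, 1024) * 2 * math.pi - math.pi
image = torch.randn(1, 1, *im_size, dtype=torch.complex64)

# use density compensation weights as W in the Toeplitz kernel
dcomp = tkbn.calc_density_compensation_function(ktraj=ktraj, im_size=im_size)
kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size, weights=dcomp, norm="ortho")

toep_ob = tkbn.ToepNufft()
out = toep_ob(image, kernel, norm="ortho")  # approximates A'WA applied to image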
Example No. 7
def profile_torchkbnufft(
    image,
    ktraj,
    smap,
    im_size,
    grid_size,
    device,
    sparse_mats_flag=False,
    toep_flag=False,
):
    # run double precision for CPU, float for GPU
    # these seem to be present in reference implementations
    if device == torch.device("cpu"):
        complex_dtype = torch.complex128
        real_dtype = torch.double
        if toep_flag:
            num_nuffts = 20
        else:
            num_nuffts = 5
    else:
        complex_dtype = torch.complex64
        real_dtype = torch.float
        if toep_flag:
            num_nuffts = 50
        else:
            num_nuffts = 20
    cpudevice = torch.device("cpu")

    res = ""
    image = image.to(dtype=complex_dtype)
    ktraj = ktraj.to(dtype=real_dtype)
    smap = smap.to(dtype=complex_dtype)
    interp_mats = None

    forw_ob = tkbn.KbNufft(im_size=im_size,
                           grid_size=grid_size,
                           dtype=complex_dtype,
                           device=device)
    adj_ob = tkbn.KbNufftAdjoint(im_size=im_size,
                                 grid_size=grid_size,
                                 dtype=complex_dtype,
                                 device=device)

    # precompute toeplitz kernel if using toeplitz
    if toep_flag:
        kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size, grid_size=grid_size)
        toep_ob = tkbn.ToepNufft()

    # precompute the sparse interpolation matrices
    if sparse_mats_flag:
        interp_mats = tkbn.calc_tensor_spmatrix(ktraj,
                                                im_size,
                                                grid_size=grid_size)
        interp_mats = tuple([t.to(device) for t in interp_mats])
    if toep_flag:
        # warm-up computation
        for _ in range(num_nuffts):
            x = toep_ob(
                image.to(device=device),
                kernel.to(device=device),
                smaps=smap.to(device=device),
            ).to(cpudevice)
        # run the speed tests
        if device == torch.device("cuda"):
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_nuffts):
            x = toep_ob(image.to(device=device),
                        kernel.to(device=device),
                        smaps=smap.to(device))
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
            max_mem = torch.cuda.max_memory_allocated()
            res += "GPU forward max memory: {} GB, ".format(max_mem / 1e9)
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_nuffts
        res += "toeplitz forward/backward average time: {}".format(avg_time)
    else:
        # warm-up computation
        for _ in range(num_nuffts):
            y = forw_ob(
                image.to(device=device),
                ktraj.to(device=device),
                interp_mats,
                smaps=smap.to(device),
            ).to(cpudevice)

        # run the forward speed tests
        if device == torch.device("cuda"):
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_nuffts):
            y = forw_ob(
                image.to(device=device),
                ktraj.to(device=device),
                interp_mats,
                smaps=smap.to(device),
            )
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
            max_mem = torch.cuda.max_memory_allocated()
            res += "GPU forward max memory: {} GB, ".format(max_mem / 1e9)
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_nuffts
        res += "forward average time: {}, ".format(avg_time)

        # warm-up computation
        for _ in range(num_nuffts):
            x = adj_ob(y.to(device),
                       ktraj.to(device),
                       interp_mats,
                       smaps=smap.to(device))

        # run the adjoint speed tests
        if device == torch.device("cuda"):
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_nuffts):
            x = adj_ob(y.to(device),
                       ktraj.to(device),
                       interp_mats,
                       smaps=smap.to(device))
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
            max_mem = torch.cuda.max_memory_allocated()
            res += "GPU adjoint max memory: {} GB, ".format(max_mem / 1e9)
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_nuffts
        res += "backward average time: {}".format(avg_time)

    print(res)
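
A hedged example of driving the profiler above on CPU; the shapes, coil count, and trajectory length are illustrative assumptions, and time, torch, and torchkbnufft as tkbn are assumed to be imported at module level as the function requires:

import math
import torch

im_size = (128, 128)
grid_size = (256, 256)
image = torch.randn(1, 1, *im_size, dtype=torch.complex64)
smap = torch.randn(1, 8, *im_size, dtype=torch.complex64)
ktraj = torch.rand(2, 4096) * 2 * math.pi - math.pi

cpu = torch.device("cpu")
# table interpolation, sparse-matrix interpolation, and Toeplitz variants
profile_torchkbnufft(image, ktraj, smap, im_size, grid_size, device=cpu)
profile_torchkbnufft(image, ktraj, smap, im_size, grid_size, device=cpu, sparse_mats_flag=True)
profile_torchkbnufft(image, ktraj, smap, im_size, grid_size, device=cpu, toep_flag=True)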