Ejemplo n.º 1
0
def unwrap(phase, dim=None, bound='dct2', max_iter=0, tol=1e-5):
    """Laplacian unwrapping of the phase

    Parameters
    ----------
    phase : tensor
        Wrapped phase, in radian
    dim : int, default=phase.dim()
        Number of spatial dimensions
    bound : str, default='dct2'
        Boundary condition. Unless it is 'dct' or 'circulant', the phase
        is padded by half its size on each side of every spatial
        dimension (using `bound` as the padding mode) before filtering,
        and cropped back afterwards.
    max_iter : int, default=0
        Maximum number of unwrapping iterations.
        If 0, return the Laplacian filtered phase, which is not exactly
        equal to the input phase modulo 2 pi.
    tol : float, default=1e-5
        Tolerance for early stopping: iterations stop once the mean
        absolute correction (a multiple of 2 pi per voxel) is below `tol`.

    Returns
    -------
    unwrapped : tensor
        Same shape as `phase`.

    References
    ----------
    .. "Fast phase unwrapping algorithm for interferometric applications"
       Marvin A. Schofield and Yimei Zhu
       Optics Letters (2003)

    """
    # TODO: would be nice to use DCT/DST rather than padding once they
    #       are available in PyTorch.

    dim = dim or phase.dim()
    dims = list(range(-dim, 0))
    shape = bigshape = phase.shape[-dim:]

    # Filtering is done by FFT (circulant boundaries); other boundary
    # conditions are emulated by padding with the requested mode.
    # NOTE(review): 'dct' is exempted from padding together with
    # 'circulant'; it may have been meant as 'dft' -- confirm.
    padded = bound not in ('dct', 'circulant')
    if padded:
        phase = utils.pad(phase, [d//2 for d in shape], side='both', mode=bound)
        bigshape = phase.shape[-dim:]

    freq = _laplacian_freq(bigshape, **utils.backend(phase))
    phase = fft.ifftshift(phase, dim=dims)
    twopi = 2 * pymath.pi

    if max_iter == 0:
        phase = _laplacian_filter(phase, freq, dims)
    else:
        for _ in range(max_iter):
            # Round the discrepancy between the filtered and current
            # phase to the nearest multiple of 2 pi, and add it.
            correction = _laplacian_filter(phase, freq, dims)
            correction.sub_(phase).div_(twopi).round_().mul_(twopi)
            phase += correction

            # Corrections can be +2pi or -2pi: test their magnitude,
            # since a signed mean can cancel out or be negative.
            if correction.abs().mean() < tol:
                break

    phase = fft.fftshift(phase, dim=dims)

    if padded:
        # Crop back to the original field of view.
        slicer = [slice(d//2, d + d//2) for d in shape]
        phase = phase[(Ellipsis, *slicer)]
    return phase
Ejemplo n.º 2
0
def _laplacian_freq(shape, **backend):
    """Squared Fourier frequencies on a lattice, and their reciprocal.

    Parameters
    ----------
    shape : sequence of int
        Lattice shape.
    **backend
        `dtype` / `device` forwarded to tensor constructors.

    Returns
    -------
    freq2 : tensor
        Sum over dimensions of the squared frequencies (voxel offset from
        the centre divided by the lattice size), ifftshifted so that the
        zero frequency sits at the first voxel.
    ifreq2 : tensor
        Element-wise reciprocal of `freq2`, with the zero-frequency entry
        set to 0 instead of infinity.
    """
    ndim = len(shape)
    lattice = torch.as_tensor(shape, **backend)
    grid = spatial.identity_grid(lattice, **backend)

    # Offset from the centre voxel, normalised by the lattice size.
    offsets = (grid - lattice // 2) / lattice
    freq2 = offsets.square().sum(-1)

    if fft._torch_has_old_fft:
        # Extra trailing axis -- presumably required by the legacy
        # (real, imag) FFT layout; TODO confirm.
        freq2 = freq2.unsqueeze(-1)

    freq2 = fft.ifftshift(freq2, dim=list(range(ndim)))

    ifreq2 = freq2.reciprocal()
    origin = (0,) * ndim
    ifreq2[origin] = 0
    return freq2, ifreq2
Ejemplo n.º 3
0
def greens(shape,
           absolute=_default_absolute,
           membrane=_default_membrane,
           bending=_default_bending,
           lame=_default_lame,
           factor=1,
           voxel_size=1,
           dtype=None,
           device=None):
    """Generate the Greens function of a regulariser in Fourier space.

    The regulariser (with circulant 'dft' boundaries) is applied to a
    unit impulse at the centre of the lattice. The resulting kernel is
    Fourier transformed and inverted.

    Parameters
    ----------
    shape : [sequence of] int
        Output shape
    absolute : float, default=0.0001
        Penalty on absolute values. If zero, it is replaced by 1e-3 times
        the largest of the other penalties so that the kernel is invertible.
    membrane : float, default=0.001
        Penalty on membrane energy
    bending : float, default=0.2
        Penalty on bending energy
    lame : float or (float, float), default=(0.05, 0.2)
        Penalty on linear-elastic energy
    factor : float, default=1
        Currently unused (kept for backward compatibility).
    voxel_size : [sequence of] float, default=1
        Voxel size
    dtype : torch.dtype, optional
    device : torch.device, optional

    Returns
    -------
    greens : (*shape, [dim, dim]) tensor
        `(*shape, dim, dim)` if any Lame penalty is non-zero, else `(*shape)`.

    """
    # Authors
    # -------
    # .. John Ashburner <*****@*****.**> : original Matlab code
    # .. Yael Balbastre <*****@*****.**> : Python port
    #
    # License
    # -------
    # The original Matlab code is (C) 2012-2019 WCHN / John Ashburner
    # and was distributed as part of [SPM](https://www.fil.ion.ucl.ac.uk/spm)
    # under the GNU General Public Licence (version >= 2).

    backend = dict(dtype=dtype, device=device)
    shape = py.make_tuple(shape)
    dim = len(shape)
    lame1, lame2 = py.make_list(lame, 2)
    if not absolute:
        # some absolute penalty is needed for the kernel to be invertible
        absolute = max(absolute, max(membrane, bending, lame1, lame2) * 1e-3)
    # NOTE: `factor` is deliberately not forwarded to the regulariser.
    prm = dict(
        absolute=absolute,
        membrane=membrane,
        bending=bending,
        voxel_size=voxel_size,
        bound='dft')
    elastic = bool(lame1 or lame2)
    if elastic:
        prm['lame'] = lame

    # allocate
    if elastic:
        kernel = torch.zeros([*shape, dim, dim], **backend)
    else:
        kernel = torch.zeros([*shape, 1, 1], **backend)

    # Only a small window around the centre is needed to generate the
    # kernel: the bending stencil reaches 2 voxels, membrane reaches 1.
    # The window start is clamped at 0 so that lattices smaller than the
    # stencil do not produce a wrapped (negative) slice start; the local
    # index of the centre voxel is computed accordingly.
    radius = 2 if bending else 1
    window = tuple(slice(max(s // 2 - radius, 0), s // 2 + radius + 1)
                   for s in shape)
    subkernel = kernel[window]  # view: writes propagate to `kernel`
    center = tuple(min(s // 2, radius) for s in shape)

    # generate kernel
    if elastic:
        for d in range(dim):
            subkernel[center + (d, d)] = 1
            subkernel[..., :, d] = regulariser_grid(subkernel[..., :, d],
                                                    **prm)
    else:
        subkernel[center] = 1
        subkernel[..., 0, 0] = regulariser(subkernel[None, ..., 0, 0],
                                           **prm,
                                           dim=dim)[0]

    # move the centre voxel to the first voxel
    kernel = fft.ifftshift(kernel, dim=range(dim))

    # fourier transform
    #   symmetric kernel -> real coefficients (computed in double precision)
    dtype = kernel.dtype
    kernel = kernel.double()
    kernel = fft.real(fft.fftn(kernel, dim=list(range(dim)), real=True))
    kernel = kernel.to(dtype=dtype)

    if elastic:
        kernel = _inv(kernel)  # presumably a per-frequency matrix inverse
    else:
        kernel = kernel[..., 0, 0].reciprocal_()
    return kernel
Ejemplo n.º 4
0
def mrfield_greens(shape,
                   absolute=0,
                   membrane=0,
                   bending=0,
                   factor=1,
                   voxel_size=1,
                   dtype=None,
                   device=None):
    """Generate the Greens function of a regulariser in Fourier space.

    Parameters
    ----------
    shape : [sequence of] int
        Output shape
    absolute : float, default=0
        Penalty on absolute values. If zero, it is replaced by 1e-6 times
        the largest of the other penalties so that the kernel is
        invertible. If all penalties are zero, the result is infinite.
    membrane : float, default=0
        Penalty on membrane energy
    bending : float, default=0
        Penalty on bending energy
    factor : float, default=1
        Forwarded to the regulariser.
    voxel_size : [sequence of] float, default=1
        Voxel size
    dtype : torch.dtype, optional
    device : torch.device, optional

    Returns
    -------
    greens : (*shape) tensor

    """
    # Adapted from the geodesic shooting code

    backend = dict(dtype=dtype, device=device)
    shape = py.make_tuple(shape)
    dim = len(shape)
    if not absolute:
        # we need some regularization to invert
        absolute = max(absolute, max(membrane, bending) * 1e-6)
    prm = dict(absolute=absolute,
               membrane=membrane,
               bending=bending,
               factor=factor,
               voxel_size=voxel_size,
               bound='dft')

    # allocate
    kernel = torch.zeros(shape, **backend)

    # Only a small window around the centre is needed to generate the
    # kernel: the bending stencil reaches 2 voxels, membrane reaches 1.
    # The window start is clamped at 0 so that small lattices do not
    # produce a wrapped (negative) slice start.
    radius = 2 if bending else 1
    window = tuple(slice(max(s // 2 - radius, 0), s // 2 + radius + 1)
                   for s in shape)
    subkernel = kernel[window]  # view: writes propagate to `kernel`
    center = tuple(min(s // 2, radius) for s in shape)

    # generate kernel (a leading channel dimension is added then removed,
    # as done in `greens`)
    subkernel[center] = 1
    subkernel[...] = regulariser(subkernel[None], **prm, dim=dim)[0]

    # move the centre voxel to the first voxel
    kernel = ifftshift(kernel, dim=range(dim))

    # fourier transform
    #   symmetric kernel -> real coefficients

    if utils.torch_version('>=', (1, 8)):
        # transform over all spatial axes; `.real` is a property
        kernel = torch.fft.fftn(kernel, dim=list(range(dim))).real
    else:
        if torch.backends.mkl.is_available():
            # use rfft
            kernel = torch.rfft(kernel, dim, onesided=False)
        else:
            zero = kernel.new_zeros([]).expand(kernel.shape)
            kernel = torch.stack([kernel, zero], dim=-1)
            kernel = torch.fft(kernel, dim)
        kernel = kernel[..., 0]  # should be real

    kernel = kernel.reciprocal()
    return kernel
Ejemplo n.º 5
0
def mrfield_greens2(shape, zdim=-1, voxel_size=1, dtype=None, device=None):
    """Semi-analytical second derivative of the Greens kernel.

    This function implements exactly the solution from Jenkinson et al.
    (Same as in the FSL source code), with the assumption that
    no gradients are played out and the main field is constant and has
    no orthogonal components (Bz = B0, Bx = By = 0).

    The Greens kernel and its second derivatives are derived analytically
    and integrated over each voxel by evaluating a closed-form expression
    at the voxel corners, combined with alternating signs.

    The returned tensor has already been Fourier transformed and could
    be cached if multiple field simulations with the same lattice size
    must be performed in a row.

    Parameters
    ----------
    shape : sequence of int
        Lattice shape
    zdim : int, default=-1
        Dimension of the main magnetic field
    voxel_size : [sequence of] float, default=1
        Voxel size
    dtype : torch.dtype, optional
    device : torch.device, optional

    Returns
    -------
    kernel : (*shape) tensor
        Fourier transform of the (second derivatives of the) Greens kernel.

    Raises
    ------
    NotImplementedError
        If the lattice is not 3D.
    ValueError
        If `zdim` is not a valid dimension index of a 3D lattice.

    """
    import itertools

    def atan(num, den):
        # atan(num / den) wherever the denominator is not ~0 (note: range
        # (-pi/2, pi/2), unlike atan2); atan2 is only a fallback there.
        return torch.where(den.abs() > 1e-8, torch.atan_(num / den),
                           torch.atan2(num, den))

    dim = len(py.make_list(shape))
    g0 = identity_grid(shape, dtype=dtype, device=device)
    voxel_size = utils.make_vector(voxel_size, dim,
                                   dtype=torch.double).tolist()

    if dim == 3:
        if zdim in (-1, 2):
            odims = [-3, -2]
        elif zdim in (-2, 1):
            odims = [-3, -1]
        elif zdim in (-3, 0):
            odims = [-2, -1]
        else:
            raise ValueError('zdim must be in [-3, 2] for a 3D lattice, '
                             'got {}'.format(zdim))
    else:
        raise NotImplementedError

    def make_shifted(shift):
        # Voxel coordinates (in mm) relative to the centre voxel, shifted
        # by `shift` voxels (+/- 0.5 selects a voxel corner).
        g = g0.clone()
        for g1, s, v, t in zip(g.unbind(-1), shape, voxel_size, shift):
            g1 -= s // 2  # make center voxel zero
            g1 += t  # apply shift
            g1 *= v  # convert to mm
        return g

    g = 0
    for shift in itertools.product([-0.5, 0.5], repeat=dim):
        g1 = make_shifted(shift)
        r = g1.square().sum(-1).sqrt_()
        g1 = atan(g1[..., odims[0]] * g1[..., odims[1]], g1[..., zdim] * r)
        # corners with an odd number of negative offsets are subtracted
        if py.prod(shift) < 0:
            g -= g1
        else:
            g += g1

    g /= 4. * constants.pi
    g = ifftshift(g, range(dim))  # move center voxel to first voxel

    # fourier transform
    #   symmetric kernel -> real coefficients

    if utils.torch_version('>=', (1, 8)):
        # transform over all spatial axes; `.real` is a property
        g = torch.fft.fftn(g, dim=list(range(dim))).real
    else:
        if torch.backends.mkl.is_available():
            # use rfft
            g = torch.rfft(g, dim, onesided=False)
        else:
            zero = g.new_zeros([]).expand(g.shape)
            g = torch.stack([g, zero], dim=-1)
            g = torch.fft(g, dim)
        g = g[..., 0]  # should be real
    return g