Example #1
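These snippets reference helper modules (py, utils, fft, spatial) and helper functions (identity_grid, regulariser, regulariser_grid, _inv, ifftshift) whose import lines are not shown in the source. nitorch-style imports along the following lines are an assumption, given for context only:

import torch
from nitorch.core import py, utils, fft                               # assumed: tuple/list helpers, FFT wrappers
from nitorch import spatial                                           # assumed
from nitorch.spatial import identity_grid, regulariser, regulariser_grid  # assumed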
def dist_map(shape, dtype=None, device=None):
    """Return the squared distance between all pairs in a FOV.

    Parameters
    ----------
    shape : sequence[int]
    dtype : optional
    device : optional

    Returns
    -------
    dist : (prod(shape), prod(shape)) tensor
        Squared distance map

    """
    backend = dict(dtype=dtype, device=device)
    shape = py.make_tuple(shape)
    dim = len(shape)
    g = spatial.identity_grid(shape, **backend)
    g = g.reshape([-1, dim])
    g = (g[:, None, :] - g[None, :, :]).square_().sum(-1)
    return g
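A minimal self-contained sketch of the same computation in plain PyTorch, with torch.meshgrid standing in for the identity-grid helper (the name dist_map_plain and the inlined grid construction are illustrative, not from the original):

import torch

def dist_map_plain(shape, dtype=torch.float32, device=None):
    # Dense identity grid: one (dim,) coordinate vector per voxel.
    axes = [torch.arange(s, dtype=dtype, device=device) for s in shape]
    g = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)
    g = g.reshape(-1, len(shape))
    # All-pairs squared Euclidean distances via broadcasting.
    return (g[:, None, :] - g[None, :, :]).square().sum(-1)

d = dist_map_plain((3, 3))
print(d.shape)   # torch.Size([9, 9])
print(d[0, 4])   # tensor(2.) -- voxel (0, 0) to voxel (1, 1)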
Example #2
def susceptibility_phantom(shape, radius=None, dtype=None, device=None):
    """Generate a circle/sphere susceptibility phantom

    Parameters
    ----------
    shape : sequence of int
    radius : float, default=min(shape)/4
    dtype : optional
    device : optional

    Returns
    -------
    f : (*shape) tensor[bool]
        susceptibility delta map

    """

    shape = py.make_tuple(shape)
    radius = radius or (min(shape) / 4.)
    f = identity_grid(shape, dtype=dtype, device=device)
    for comp, s in zip(f.unbind(-1), shape):
        comp -= s / 2
    f = f.square().sum(-1).sqrt() <= radius
    return f
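A usage sketch; the printed values assume identity_grid behaves like the meshgrid construction above (the mask is a disk of radius min(shape)/4 centred in the FOV):

mask = susceptibility_phantom((64, 64))   # disk of radius 16
print(mask.shape, mask.dtype)             # torch.Size([64, 64]) torch.bool
print(mask.float().mean())                # ~0.196, i.e. pi * 16**2 / 64**2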
Example #3
def greens(shape,
           absolute=_default_absolute,
           membrane=_default_membrane,
           bending=_default_bending,
           lame=_default_lame,
           factor=1,
           voxel_size=1,
           dtype=None,
           device=None):
    """Generate the Greens function of a regulariser in Fourier space.

    Parameters
    ----------
    shape : tuple[int]
        Output shape
    absolute : float, default=0.0001
        Penalty on absolute values
    membrane : float, default=0.001
        Penalty on membrane energy
    bending : float, default=0.2
        Penalty on bending energy
    lame : float or (float, float), default=(0.05, 0.2)
        Penalty on linear-elastic energy
    factor : float, default=1
    voxel_size : [sequence of] float, default=1
        Voxel size
    dtype : torch.dtype, optional
    device : torch.device, optional

    Returns
    -------
    greens : (*shape, [dim, dim]) tensor

    """
    # Authors
    # -------
    # .. John Ashburner <*****@*****.**> : original Matlab code
    # .. Yael Balbastre <*****@*****.**> : Python port
    #
    # License
    # -------
    # The original Matlab code is (C) 2012-2019 WCHN / John Ashburner
    # and was distributed as part of [SPM](https://www.fil.ion.ucl.ac.uk/spm)
    # under the GNU General Public Licence (version >= 2).

    backend = dict(dtype=dtype, device=device)
    shape = py.make_tuple(shape)
    dim = len(shape)
    lame1, lame2 = py.make_list(lame, 2)
    if not absolute:
        # some absolute penalty is needed to make the kernel invertible
        absolute = max(absolute, max(membrane, bending, lame1, lame2) * 1e-3)
    prm = dict(
        absolute=absolute,
        membrane=membrane,
        bending=bending,
        # factor=factor,
        voxel_size=voxel_size,
        bound='dft')
    if lame1 or lame2:
        prm['lame'] = lame

    # allocate
    if lame1 or lame2:
        kernel = torch.zeros([*shape, dim, dim], **backend)
    else:
        kernel = torch.zeros([*shape, 1, 1], **backend)

    # only use center to generate kernel
    if bending:
        subkernel = kernel[tuple(slice(s // 2 - 2, s // 2 + 3) for s in shape)]
        subsize = 5
    else:
        subkernel = kernel[tuple(slice(s // 2 - 1, s // 2 + 2) for s in shape)]
        subsize = 3

    # generate kernel
    if lame1 or lame2:
        for d in range(dim):
            center = (subsize // 2, ) * dim + (d, d)
            subkernel[center] = 1
            subkernel[..., :, d] = regulariser_grid(subkernel[..., :, d],
                                                    **prm)
    else:
        center = (subsize // 2, ) * dim
        subkernel[center] = 1
        subkernel[..., 0, 0] = regulariser(subkernel[None, ..., 0, 0],
                                           **prm,
                                           dim=dim)[0]

    kernel = fft.ifftshift(kernel, dim=range(dim))

    # fourier transform
    #   symmetric kernel -> real coefficients

    dtype = kernel.dtype
    kernel = kernel.double()
    kernel = fft.real(fft.fftn(kernel, dim=list(range(dim)), real=True))

    # if utils.torch_version('>=', (1, 8)):
    #     kernel = utils.movedim(kernel, [-2, -1], [0, 1])
    #     kernel = torch.fft.fftn(kernel, dim=list(range(-dim, 0))).real
    #     if callable(kernel):
    #         kernel = kernel()
    #     kernel = utils.movedim(kernel, [0, 1], [-2, -1])
    # else:
    #     kernel = utils.movedim(kernel, [-2, -1], [0, 1])
    #     if torch.backends.mkl.is_available:
    #         # use rfft
    #         kernel = torch.rfft(kernel, dim, onesided=False)
    #     else:
    #         zero = kernel.new_zeros([]).expand(kernel.shape)
    #         kernel = torch.stack([kernel, zero], dim=-1)
    #         kernel = torch.fft(kernel, dim)
    #     kernel = kernel[..., 0]  # should be real
    #     kernel = utils.movedim(kernel, [0, 1], [-2, -1])

    kernel = kernel.to(dtype=dtype)

    if lame1 or lame2:
        # batched inverse of the dim x dim matrix at each frequency
        kernel = _inv(kernel)  # equivalent to kernel.inverse()
    else:
        kernel = kernel[..., 0, 0].reciprocal_()
    return kernel
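The trick is easiest to see in a toy setting: with bound='dft' the regulariser is a circular convolution, so it diagonalises under the DFT and inverting it reduces to a pointwise reciprocal in Fourier space. A self-contained 1-D sketch with an absolute + membrane (negative Laplacian) penalty; all names and values here are illustrative, not from the original:

import torch

n = 16
absolute, membrane = 1e-3, 1.0
# Circular stencil of (absolute * I - membrane * Laplacian):
# centre weight at index 0, neighbours wrapped around (already ifftshift-ed)
kernel = torch.zeros(n)
kernel[0] = absolute + 2 * membrane
kernel[1] = kernel[-1] = -membrane
greens = torch.fft.fftn(kernel).real.reciprocal()

# Solve (absolute * I - membrane * Laplacian) x = b by a pointwise divide
b = torch.randn(n)
x = torch.fft.ifftn(torch.fft.fftn(b) * greens).real

# Check: applying the regulariser to x recovers b
Lx = absolute * x + membrane * (2 * x - x.roll(1) - x.roll(-1))
print(torch.allclose(Lx, b, atol=1e-5))   # True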
Example #4
def empirical_cov(series,
                  nb_dim=1,
                  dim=None,
                  subtract_mean=True,
                  flatten=False,
                  keepdim=False,
                  return_mean=False):
    """Compute an empirical covariance

    Parameters
    ----------
    series : (..., *dims) tensor_like
        Sample series
    nb_dim : int, default=1
        Number of spatial dimensions.
    dim : [sequence of] int, default=None
        Dimensions that are reduced when computing the covariance.
        If None: all but the last `nb_dim`.
    subtract_mean : bool, default=True
        Subtract empirical mean before computing the covariance.
    flatten : bool, default=False
        If True, flatten the 'covariance' dimensions.
    keepdim : bool, default=False
        Keep reduced dimensions.
    return_mean : bool, default=False
        Also return the empirical mean.

    Returns
    -------
    cov : (..., *dims, *dims) or (..., prod(dims), prod(dims)) tensor
        Covariance.
    mean : (..., *dims) or (..., prod(dims)) tensor, if `return_mean`
        Mean.

    """

    # Convert to tensor
    series = torch.as_tensor(series)
    prespatial = series.shape[:-nb_dim]
    spatial = series.shape[-nb_dim:]

    if dim is None:
        dim = range(series.dim() - nb_dim)
    dim = py.make_tuple(dim)
    dim = [series.dim() + d if d < 0 else d for d in dim]

    reduced = [prespatial[d] for d in dim]
    batch = [
        prespatial[d] for d in range(series.dim() - nb_dim) if d not in dim
    ]

    # Subtract mean (also computed when only the mean is requested,
    # so that `return_mean=True` works without `subtract_mean`)
    if subtract_mean or return_mean:
        mean = series.mean(dim=dim, keepdim=True)
    if subtract_mean:
        series = series - mean

    # Compute empirical covariance.
    series = series.reshape([*series.shape[:-nb_dim], -1])
    series = utils.movedim(series, dim, -2)
    series = series.reshape([*batch, -1, series.shape[-1]])
    n_reduced = series.shape[-2]
    n_vox = series.shape[-1]
    # (*batch, reduced, spatial)

    # Torch's matmul just uses too much memory
    # We don't expect to have more than about 100 time frames,
    # so it is better to unroll the loop in python.
    # cov = torch.matmul(series.transpose(-1, -2), series)
    cov = None
    buf = series.new_empty([*batch, n_vox, n_vox])
    for i in range(n_reduced):
        buf = torch.mul(series.transpose(-1, -2)[..., :, i, None],
                        series[..., i, None, :],
                        out=buf)
        if cov is None:
            cov = buf.clone()
        else:
            cov += buf
    cov /= py.prod(reduced)

    if keepdim:
        outshape = [1 if d in dim else s for d, s in enumerate(prespatial)]
    else:
        outshape = list(batch)
    if flatten:
        outshape_mean = outshape + [py.prod(spatial)]
        outshape += [py.prod(spatial)] * 2
    else:
        outshape_mean = outshape + list(spatial)
        outshape += list(spatial) * 2

    cov = cov.reshape(outshape)
    if return_mean:
        mean = mean.reshape(outshape_mean)
        return cov, mean
    return cov
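A usage sketch (shapes assume the fix above, so the mean exists whenever return_mean=True): the covariance of a 10-voxel signal over 100 time frames, reducing the leading dimension:

import torch

series = torch.randn(100, 10)             # 100 frames, 10 voxels
cov = empirical_cov(series)               # reduces dim 0 by default
print(cov.shape)                          # torch.Size([10, 10])
cov, mean = empirical_cov(series, return_mean=True)
print(mean.shape)                         # torch.Size([10])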
Example #5
def mrfield_greens(shape,
                   absolute=0,
                   membrane=0,
                   bending=0,
                   factor=1,
                   voxel_size=1,
                   dtype=None,
                   device=None):
    """Generate the Greens function of a regulariser in Fourier space.

    Parameters
    ----------
    shape : tuple[int]
        Output shape
    absolute : float, default=0
        Penalty on absolute values
    membrane : float, default=0
        Penalty on membrane energy
    bending : float, default=0
        Penalty on bending energy
    factor : float, default=1
    voxel_size : [sequence of] float, default=1
        Voxel size
    dtype : torch.dtype, optional
    device : torch.device, optional

    Returns
    -------
    greens : (*shape) tensor

    """
    # Adapted from the geodesic shooting code

    backend = dict(dtype=dtype, device=device)
    shape = py.make_tuple(shape)
    dim = len(shape)
    if not absolute:
        # we need some regularization to invert
        absolute = max(absolute, max(membrane, bending) * 1e-6)
    prm = dict(absolute=absolute,
               membrane=membrane,
               bending=bending,
               factor=factor,
               voxel_size=voxel_size,
               bound='dft')

    # allocate
    kernel = torch.zeros(shape, **backend)

    # only use center to generate kernel
    if bending:
        subkernel = kernel[tuple(slice(s // 2 - 2, s // 2 + 3) for s in shape)]
        subsize = 5
    else:
        subkernel = kernel[tuple(slice(s // 2 - 1, s // 2 + 2) for s in shape)]
        subsize = 3

    # generate kernel
    center = (subsize // 2, ) * dim
    subkernel[center] = 1
    subkernel[...] = regulariser(subkernel, **prm, dim=dim)

    kernel = ifftshift(kernel, dim=range(dim))

    # fourier transform
    #   symmetric kernel -> real coefficients

    if utils.torch_version('>=', (1, 8)):
        kernel = torch.fft.fftn(kernel, dim=list(range(dim))).real
    else:
        if torch.backends.mkl.is_available():
            # use rfft
            kernel = torch.rfft(kernel, dim, onesided=False)
        else:
            zero = kernel.new_zeros([]).expand(kernel.shape)
            kernel = torch.stack([kernel, zero], dim=-1)
            kernel = torch.fft(kernel, dim)
        kernel = kernel[..., 0]  # should be real

    kernel = kernel.reciprocal_()
    return kernel
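Applying the kernel: with bound='dft' the regulariser's inverse acts by pointwise multiplication with the Greens function in Fourier space. A hypothetical helper sketching that step (solve_field is not part of the original):

import torch

def solve_field(b, greens_kernel):
    # b : (*shape) tensor; greens_kernel : (*shape) tensor from mrfield_greens
    dims = list(range(b.dim()))
    b_ft = torch.fft.fftn(b, dim=dims)
    x = torch.fft.ifftn(b_ft * greens_kernel, dim=dims)
    return x.real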