Example #1
def transform_pointset_dense(points, grid, type='grid', bound='dct2'):
    """Transform a pointset

    Points must already be expressed in "grid voxels" coordinates.

    Parameters
    ----------
    points : (n, dim) tensor
        Set of coordinates, in voxel space
    grid : (*spatial, dim) tensor
        Dense transformation or displacement grid, in voxel space
    type : {'grid', 'disp'}, default='grid'
        Transformation or displacement
    bound : str, default='dct2'
        Boundary conditions for out-of-bounds data

    Returns
    -------
    points : (n, dim) tensor
        Transformed coordinates

    """

    dim = grid.shape[-1]
    points = utils.unsqueeze(points, 0, dim)
    grid = utils.movedim(grid, -1, 0)[None]
    delta = spatial.grid_pull(grid, points, bound=bound, extrapolate=True)
    delta = utils.movedim(delta, 1, -1)
    if type == 'disp':
        points = points + delta
    else:
        points = delta
    points = utils.squeeze(points, -2, dim - 1).squeeze(0)
    return points
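For reference, here is a minimal self-contained sketch of the same sampling pattern written with plain PyTorch's `torch.nn.functional.grid_sample`, which, unlike the helper above, expects coordinates normalized to [-1, 1] and ordered x-before-y. All names and shapes below are illustrative only, not part of the library.

import torch
import torch.nn.functional as F

H, W = 32, 32
disp = torch.randn(H, W, 2)              # dense displacement field, voxel units, (x, y) channels
points = torch.rand(10, 2) * (W - 1)     # points in voxel space, ordered (x, y)

# normalize voxel coordinates to [-1, 1] as expected by grid_sample
norm = points / torch.tensor([W - 1., H - 1.]) * 2 - 1
norm = norm.view(1, 1, -1, 2)            # (batch=1, H_out=1, W_out=n, 2)

field = disp.permute(2, 0, 1)[None]      # (1, 2, H, W), channels first
delta = F.grid_sample(field, norm, align_corners=True)   # (1, 2, 1, n)
delta = delta[0, :, 0].T                 # back to (n, 2)

new_points = points + delta              # the 'disp' behaviour of the helper above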
Example #2
def _jhj(jac, hess):
    """J*H*J', where H is symmetric and stored sparse"""

    # Matlab symbolic toolbox
    #
    # 2D:
    # out[00] = h00*j00^2 + h11*j01^2 + 2*h01*j00*j01
    # out[11] = h00*j10^2 + h11*j11^2 + 2*h01*j10*j11
    # out[01] = h00*j00*j10 + h11*j01*j11 + h01*(j01*j10 + j00*j11)
    #
    # 3D:
    # out[00] = h00*j00^2 + 2*h01*j00*j01 + 2*h02*j00*j02 + h11*j01^2 + 2*h12*j01*j02 + h22*j02^2
    # out[11] = h00*j10^2 + 2*h01*j10*j11 + 2*h02*j10*j12 + h11*j11^2 + 2*h12*j11*j12 + h22*j12^2
    # out[22] = h00*j20^2 + 2*h01*j20*j21 + 2*h02*j20*j22 + h11*j21^2 + 2*h12*j21*j22 + h22*j22^2
    # out[01] = j10*(h00*j00 + h01*j01 + h02*j02) + j11*(h01*j00 + h11*j01 + h12*j02) + j12*(h02*j00 + h12*j01 + h22*j02)
    # out[02] = j20*(h00*j00 + h01*j01 + h02*j02) + j21*(h01*j00 + h11*j01 + h12*j02) + j22*(h02*j00 + h12*j01 + h22*j02)
    # out[12] = j20*(h00*j10 + h01*j11 + h02*j12) + j21*(h01*j10 + h11*j11 + h12*j12) + j22*(h02*j10 + h12*j11 + h22*j12)

    dim = jac.shape[-1]
    hess = utils.movedim(hess, -1, 0)
    jac = utils.movedim(jac, [-2, -1], [0, 1])
    if dim == 1:
        out = _jhj1(jac, hess)
    elif dim == 2:
        out = _jhj2(jac, hess)
    elif dim == 3:
        out = _jhj3(jac, hess)
    else:
        raise NotImplementedError('_jhj is only implemented in dimensions 1, 2 and 3')
    out = utils.movedim(out, 0, -1)
    return out
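As a quick self-contained sanity check (not part of the library), the 2D formulas in the comment above should agree with the dense product J @ H @ J.T computed with plain PyTorch:

import torch

J = torch.randn(2, 2)
h00, h11, h01 = 0.7, -0.3, 0.2
H = torch.tensor([[h00, h01], [h01, h11]])

dense = J @ H @ J.T
out00 = h00*J[0, 0]**2 + h11*J[0, 1]**2 + 2*h01*J[0, 0]*J[0, 1]
out11 = h00*J[1, 0]**2 + h11*J[1, 1]**2 + 2*h01*J[1, 0]*J[1, 1]
out01 = h00*J[0, 0]*J[1, 0] + h11*J[0, 1]*J[1, 1] + h01*(J[0, 1]*J[1, 0] + J[0, 0]*J[1, 1])

assert torch.allclose(dense[0, 0], out00)
assert torch.allclose(dense[1, 1], out11)
assert torch.allclose(dense[0, 1], out01)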
Example #3
def bending_grid(grid, voxel_size=1, bound='dft', weights=None):
    """Precision matrix for the Bending energy of a deformation grid

    Parameters
    ----------
    grid : (..., *spatial, dim) tensor
    voxel_size : float or sequence[float], default=1
    bound : str, default='dft'
    weights : (..., *spatial) tensor, optional

    Returns
    -------
    field : (..., *spatial, dim) tensor

    """
    grid = torch.as_tensor(grid)
    backend = dict(dtype=grid.dtype, device=grid.device)
    dim = grid.shape[-1]
    voxel_size = core.utils.make_vector(voxel_size, dim, **backend)
    if (voxel_size != 1).any():
        grid = grid * voxel_size
    grid = movedim(grid, -1, -(dim + 1))
    grid = bending(grid,
                   weights=weights,
                   voxel_size=voxel_size,
                   bound=bound,
                   dim=dim)
    grid = movedim(grid, -(dim + 1), -1)
    return grid
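The bending energy penalizes squared second derivatives. As a minimal self-contained illustration (not the library code), its 1D analogue applies a biharmonic operator: with a periodic ('dft'-like) boundary, multiplying by the precision matrix amounts to taking the discrete Laplacian twice.

import torch

f = torch.randn(32)
lap = f.roll(-1) - 2 * f + f.roll(1)                 # second differences, periodic boundary
precision_f = lap.roll(-1) - 2 * lap + lap.roll(1)   # (L.T @ L) @ f, i.e. bending precision applied to f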
Example #4
def modulate_prior(M, G):
    if G is None:
        return M
    M = utils.movedim(M, 0, -1)
    M = M * G
    M /= M.sum(-1, keepdim=True)
    M = utils.movedim(M, -1, 0)
    return M
Example #5
def _inv(A):
    A = utils.movedim(A, [-2, -1], [0, 1])
    if len(A) == 3:
        A = _inv3(A)
    elif len(A) == 2:
        A = _inv2(A)
    else:
        raise NotImplementedError
    A = utils.movedim(A, [0, 1], [-2, -1])
    return A
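For comparison, a self-contained alternative using plain PyTorch: when the matrices live in the two trailing dimensions, `torch.linalg.inv` already broadcasts over any number of leading batch/spatial dimensions, so no movedim is needed.

import torch

A = torch.randn(8, 16, 16, 3, 3, dtype=torch.double) + 3 * torch.eye(3, dtype=torch.double)
Ainv = torch.linalg.inv(A)
assert torch.allclose(A @ Ainv, torch.eye(3, dtype=torch.double).expand_as(A))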
Example #6
def _deform_1d(img, disp, grad=False):
    img = utils.movedim(img, 0, -2)
    disp = disp.unsqueeze(-1)
    disp = spatial.add_identity_grid(disp)
    wrp = spatial.grid_pull(img, disp, bound=BND, extrapolate=True)
    wrp = utils.movedim(wrp, -2, 0)
    if not grad:
        return wrp, None
    grd = spatial.grid_grad(img, disp, bound=BND, extrapolate=True)
    grd = utils.movedim(grd.squeeze(-1), -2, 0)
    return wrp, grd
Example #7
def topup_apply(pos, neg, vel, dim=-1, model='smalldef', modulation=True):
    """Apply a topup correction

    Parameters
    ----------
    pos : ([C], *spatial) tensor
        Images with positive readout polarity
    neg : ([C], *spatial) tensor
        Images with negative readout polarity
    vel : (*spatial) tensor
        1D displacement or velocity field
    dim : int, default=-1
        Readout dimension
    model : {'smalldef', 'svf'}, default='smalldef'
        Deformation model
    modulation : bool, default=True
        Modulate the unwarped intensities by the Jacobian determinant of the 1D transform

    Returns
    -------
    pos : ([C], *spatial) tensor
        Images with positive polarity, unwarped
    neg : ([C], *spatial) tensor
        Images with negative polarity, unwarped

    """
    ndim = vel.dim()
    dim = (dim - ndim) if dim >= 0 else dim

    no_batch_pos = pos.dim() == ndim
    if no_batch_pos:
        pos = pos[None]
    pos = utils.movedim(pos, dim, -1)
    no_batch_neg = neg.dim() == ndim
    if no_batch_neg:
        neg = neg[None]
    neg = utils.movedim(neg, dim, -1)
    vel = utils.movedim(vel, dim, -1)

    phi, iphi, jac, ijac = _exp_1d(vel, model=model)
    pos, _ = _deform_1d(pos, phi)
    neg, _ = _deform_1d(neg, iphi)
    if modulation:
        pos *= jac
        neg *= ijac
    del phi, iphi, jac, ijac

    pos = utils.movedim(pos, -1, dim)
    neg = utils.movedim(neg, -1, dim)
    if no_batch_pos:
        pos = pos[0]
    if no_batch_neg:
        neg = neg[0]

    return pos, neg
Example #8
    def forward(self, source, target, source_affine=None, target_affine=None):
        """

        Parameters
        ----------
        source : (sX, sY, sZ) tensor or str
        target : (tX, tY, tZ) tensor or str
        source_affine : (4, 4) tensor, optional
        target_affine : (4, 4) tensor, optional

        Returns
        -------
        warped : (tX, tY, tZ) tensor
            Source warped to target
        velocity : (vX, vY, vZ, 3) tensor
            Stationary velocity field
        affine : (4, 4) tensor, optional
            Affine of the velocity space

        """
        if self.verbose:
            print('Preprocessing... ', end='', flush=True)
        source, source_affine, source_orig, source_affine_orig \
            = self.load(source, source_affine)
        target, target_affine, target_orig, target_affine_orig \
            = self.load(target, target_affine)
        source = spatial.reslice(source, source_affine, target_affine,
                                 target.shape)
        if self.verbose:
            print('done.', flush=True)
            print('Registering... ', end='', flush=True)
        source = source[None, None]
        target = target[None, None]
        warped, vel, grid = super().forward(source, target)
        if self.verbose:
            print('done.', flush=True)
        del source, target, warped
        vel = vel[0]
        grid = grid[0]
        grid -= spatial.identity_grid(grid.shape[:-1],
                                      dtype=grid.dtype,
                                      device=grid.device)
        right_affine = target_affine.inverse() @ target_affine_orig
        right_affine = spatial.affine_grid(right_affine, target_orig.shape)
        grid = spatial.grid_pull(utils.movedim(grid, -1, 0),
                                 right_affine,
                                 bound='nearest',
                                 extrapolate=True)
        grid = utils.movedim(grid, 0, -1).add_(right_affine)
        left_affine = source_affine_orig.inverse() @ target_affine
        grid = spatial.affine_matvec(left_affine, grid)
        warped = spatial.grid_pull(source_orig, grid)
        return warped, vel, target_affine
Example #9
def grid_jacobian(grid, sample=None, bound='dft', voxel_size=1, type='grid',
                  add_identity=True, extrapolate=True):
    """Compute the Jacobian of a transformation field

    Notes
    -----
    .. If `add_identity` is True, we compute the Jacobian
       of the transformation field (identity + displacement), even if
       a displacement is provided, by adding ones to the diagonal.
    .. If `sample` is not provided, this function uses central finite
       differences to estimate the Jacobian.
    .. If `sample` is provided, `grid_grad` is used to sample derivatives.

    Parameters
    ----------
    grid : (..., *spatial, dim) tensor
        Transformation or displacement field
    sample : (..., *spatial, dim) tensor, optional
        Coordinates to sample in the input grid.
    bound : str, default='dft'
        Boundary condition
    voxel_size : [sequence of] float, default=1
        Voxel size
    type : {'grid', 'disp'}, default='grid'
        Whether the input is a transformation ('grid') or displacement
        ('disp') field.
    add_identity : bool, default=True
        Adds the identity to the Jacobian of the displacement, making it
        the jacobian of the transformation.
    extrapolate : bool, default=True
        Extrapolate out-of-bounds data (only used if `sample` is provided)

    Returns
    -------
    jac : (..., *spatial, dim, dim) tensor
        Jacobian. In each matrix: jac[i, j] = d psi[i] / d xj

    """
    grid = torch.as_tensor(grid)
    dim = grid.shape[-1]
    shape = grid.shape[-dim-1:-1]
    if type == 'grid':
        grid = grid - identity_grid(shape, **utils.backend(grid))
    if sample is None:
        dims = list(range(-dim-1, -1))
        jac = diff(grid, dim=dims, bound=bound, voxel_size=voxel_size, side='c')
    else:
        grid = utils.movedim(grid, -1, -dim-1)
        jac = grid_grad(grid, sample, bound=bound, extrapolate=extrapolate)
        jac = utils.movedim(jac, -dim-2, -2)
    if add_identity:
        torch.diagonal(jac, 0, -1, -2).add_(1)
    return jac
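A self-contained sketch of the central-difference branch above, using `torch.gradient` instead of the library's `diff` helper (boundary handling differs: one-sided edges here versus the 'dft' boundary above; names are illustrative only):

import torch

H, W = 16, 16
disp = torch.randn(H, W, 2)                          # 2D displacement field

# rows of the Jacobian: d disp[..., i] / d x_j, stacked into (H, W, 2, 2)
rows = [torch.stack(torch.gradient(disp[..., i], dim=[0, 1]), dim=-1) for i in range(2)]
jac = torch.stack(rows, dim=-2)                      # jac[..., i, j] = d disp_i / d x_j

# add the identity to obtain the Jacobian of the full transformation
jac.diagonal(dim1=-2, dim2=-1).add_(1)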
Example #10
def smart_pull_grid(vel, grid, type='disp', *args, **kwargs):
    """Interpolate a velocity/grid/displacement field.

    Notes
    -----
    Defaults differ from grid_pull:
    - bound -> dft
    - extrapolate -> True

    Parameters
    ----------
    vel : ([batch], *spatial, ndim) tensor
        Velocity
    grid : ([batch], *spatial, ndim) tensor
        Transformation field
    type : {'disp', 'grid'}, default='disp'
        Whether `vel` is a displacement ('disp') or a transformation ('grid') field
    kwargs : dict
        Options to ``grid_pull``

    Returns
    -------
    pulled_vel : ([batch], *spatial, ndim) tensor
        Velocity

    """
    if grid is None or vel is None:
        return vel
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    dim = vel.shape[-1]
    if type == 'grid':
        id = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                   **utils.backend(vel))
        vel = vel - id
    vel = utils.movedim(vel, -1, -dim - 1)
    vel_no_batch = vel.dim() == dim + 1
    grid_no_batch = grid.dim() == dim + 1
    if vel_no_batch:
        vel = vel[None]
    if grid_no_batch:
        grid = grid[None]
    vel = spatial.grid_pull(vel, grid, *args, **kwargs)
    vel = utils.movedim(vel, -dim - 1, -1)
    if vel_no_batch:
        vel = vel[0]
    if type == 'grid':
        id = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                   **utils.backend(vel))
        vel += id
    return vel
Example #11
 def load(files, is_label=False):
     """Load one multi-channel multi-file volume.
     Returns a (channels, *spatial) tensor
     """
     dats = []
     for file in files:
         if is_label:
             dat = io.volumes.load(file.fname,
                                   dtype=torch.int32, device=device)
         else:
             dat = io.volumes.loadf(file.fname, rand=True,
                                    dtype=torch.float32, device=device)
         dat = dat.reshape([*file.shape, file.channels])
         dat = dat[..., file.subchannels]
         dat = utils.movedim(dat, -1, 0)
         dim = dat.dim() - 1
         qt = utils.quantile(dat, (0.01, 0.95), dim=range(-dim, 0), keepdim=True)
         mn, mx = qt.unbind(-1)
         dat = dat.sub_(mn).div_(mx-mn)
         dats.append(dat)
         del dat
     dats = torch.cat(dats, dim=0)
     if is_label and len(dats) > 1:
         warn('Multi-channel label images are not accepted. '
              'Using only the first channel')
         dats = dats[:1]
     return dats
Example #12
 def _set_weights(module, conv_keys, f, prefix='unet'):
     # print(prefix)
     if isinstance(module, Conv):
         if conv_keys:
             key = conv_keys.pop(0)
         else:
             # we might have reached the final "feat 2 class" conv
             key = 'vxm_dense_flow'
         try:
             kernel = torch.as_tensor(f[key][key]['kernel:0'],
                                      **utils.backend(module.weight))
         except:
             kernel = torch.as_tensor(f[key][key + '_1']['kernel:0'],
                                      **utils.backend(module.weight))
         kernel = utils.movedim(kernel, [-1, -2], [0, 1])
         module.weight.copy_(kernel)
         try:
             bias = torch.as_tensor(f[key][key]['bias:0'],
                                    **utils.backend(module.bias))
         except:
             bias = torch.as_tensor(f[key][key + '_1']['bias:0'],
                                    **utils.backend(module.bias))
         module.bias.copy_(bias)
     else:
         for name, child in module.named_children():
             _set_weights(child, conv_keys, f, f'{prefix}.{name}')
Example #13
def movedim(x, source, target):
    if isinstance(x, np.ndarray):
        return np.moveaxis(x, source, target)
    elif torch.is_tensor(x):
        return utils.movedim(x, source, target)
    else:  # MappedArray?
        return x.movedim(source, target)
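A small self-contained check of the dispatch above for the numpy and torch branches (the MappedArray branch is not exercised here):

import numpy as np
import torch

assert np.moveaxis(np.zeros((2, 3, 4)), -1, 0).shape == (4, 2, 3)
assert torch.zeros(2, 3, 4).movedim(-1, 0).shape == (4, 2, 3)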
Example #14
    def load(self, fname, dtype=None, device=None):
        """Load a volume from disk

        Parameters
        ----------
        fname : str
        dtype : torch.dtype, optional
        device : torch.device, optional

        Returns
        -------
        dat : (channels, *spatial) tensor

        """
        dtype = dtype or self.dtype
        device = device or self.device
        if not dtype or dtype.is_floating_point:
            dat = io.loadf(fname, dtype=dtype, device=device)
            dat = self.rescale(dat)
        else:
            dat = io.load(fname, dtype=dtype, device=device)
        dat = dat.squeeze()
        dim = self.dim or dat.dim()
        dat = utils.unsqueeze(dat, -1, max(0, dim - dat.dim()))
        dat = dat.reshape([*dat.shape[:dim], -1])
        dat = utils.movedim(dat, -1, 0)
        dat = self.to_shape(dat)
        return dat
Example #15
    def forward(self, x, v=None):

        dim = x.dim() - 2
        if dim not in (2, 3):
            raise ValueError(f'{type(self).__name__} only implemented '
                             f'in 2D or 3D.')

        radii = self.radii.to(**utils.backend(x))
        pradii = self.pradii.to(**utils.backend(x)).log()

        # compute log-likelihood (vessel | radius, x)
        loss = x.new_zeros([len(radii), *x.shape])
        for i, (p, r) in enumerate(zip(pradii, radii)):
            # compute unsorted eigenvalues
            e = spatial.hessian_eig(x, r, dim=dim, sort=None)
            e = utils.movedim(e, -1, 0)
            if dim == 3:
                loss[i] = self.vesselness3d(e[0], e[1], e[2])

        # compute (stable) log-sum-exp (== model evidence)
        loss = math.logsumexp(loss, dim=0)

        # weight by probability to be a vessel and return
        if v is None:
            v = x
        return -(loss * v).sum() / (v.sum() + 1e-3)
Example #16
    def forward(self, image, **overload):
        backend = utils.backend(image)
        sigma = overload.get('sigma', self.sigma)
        gfactor = overload.get('gfactor', self.gfactor)

        # sample sigma
        if sigma is None:
            sigma = self.default_sigma(*image.shape[:2], **backend)
        if callable(sigma):
            sigma = sigma(image.shape[:2])
        sigma = torch.as_tensor(sigma, **backend)
        sigma = unsqueeze(sigma, -1, 2 - sigma.dim())

        # sample gfactor
        if gfactor is True:
            gfactor = field.RandomMultiplicativeField()
        if callable(gfactor):
            gfactor = gfactor(image.shape)

        # sample noise
        zero = torch.tensor(0, **backend)
        noise = td.Normal(zero, sigma).sample(image.shape[2:])
        noise = utils.movedim(noise, [-1, -2], [0, 1])

        if torch.is_tensor(gfactor):
            noise *= gfactor

        image = image + noise
        return image
Example #17
def depth_to_rgb(image, colormap=None):
    """Convert soft probabilities to an RGB image.

    Parameters
    ----------
    image : (*batch, D, H, W)
        A (batch of) 3D image, with depth along the 'D' dimension.
    colormap : (D, 3) tensor or str, optional
        A colormap or the name of a matplotlib colormap.

    Returns
    -------
    image : (*batch, H, W, 3)
        A (batch of) RGB image.

    """

    *batch, depth, height, width = image.shape
    colormap = _get_colormap_depth(colormap, depth, image.dtype, image.device)

    image = utils.movedim(image, -3, -1)
    cimage = linalg.dot(image.unsqueeze(-2), colormap.T)
    cimage /= image.sum(-1, keepdim=True)
    cimage *= image.max(-1, keepdim=True).values

    return cimage.clamp_(0, 1)
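A self-contained sketch of the same idea with plain PyTorch: collapse the depth axis as a colormap-weighted mix. The colormap here is a random stand-in; names are illustrative only.

import torch

D, H, W = 8, 32, 32
prob = torch.rand(D, H, W).softmax(0)            # soft assignment over depth
cmap = torch.rand(D, 3)                          # one RGB colour per depth slice
rgb = torch.einsum('dhw,dc->hwc', prob, cmap)    # (H, W, 3) weighted colour mix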
Example #18
    def forward(self, x, v=None):

        dim = x.dim() - 2
        if dim not in (2, 3):
            raise ValueError(f'{type(self).__name__} only implemented '
                             f'in 2D or 3D.')

        radii = self.radii.to(**utils.backend(x))
        pradii = self.pradii.to(**utils.backend(x)).log()

        # compute joint log-likelihood `ln p(x, radius | v)`
        loss = x.new_zeros([len(radii), *x.shape])
        for i, (p, r) in enumerate(zip(pradii, radii)):
            # compute unsorted eigenvalues
            e = spatial.hessian_eig(x, r, dim=dim, sort=None)
            # soft sort
            P = math.softsort(e.abs(), tau=self.tau_sort, descending=True)
            e = linalg.matvec(P, e)
            e = utils.movedim(e, -1, 0)
            # compute penalties
            loss[i] = -self.tau_large * e[1:].sum(0)  # white ridges
            e = e.square().clamp_min_(1e-32).log()
            if dim == 3:
                loss[i] += self.tau_ratio1 * (e[1] - e[2])  # tubes
            loss[i] += self.tau_ratio0 * (e[1] - e[0])  # not plates
            loss[i] += p  # radius prior

        # compute (stable) log-sum-exp (== model evidence `ln p(x | v)`)
        loss = math.logsumexp(loss, dim=0)

        # weight by probability to be a vessel and return `E_v[ln p(x | v)]`
        if v is None:
            v = x
        return -(loss * v).sum() / (v.sum() + 1e-3)
Example #19
    def forward(self, q, k, v, **overload):
        """

        Parameters
        ----------
        q : (b, c, *spatial)
            Queries
        k : (b, c, *spatial)
            Keys
        v : (b, c, *spatial)
            Values

        Returns
        -------
        x : (b, c, *spatial)

        """
        kernel_size = overload.pop('kernel_size', self.kernel_size)
        stride = overload.pop('stride', self.kernel_size)
        padding = overload.pop('padding', self.padding)
        padding_mode = overload.pop('padding_mode', self.padding_mode)

        dim = q.dim() - 2
        if padding == 'auto':
            k = spatial.pad_same(dim, k, kernel_size, bound=padding_mode)
            v = spatial.pad_same(dim, v, kernel_size, bound=padding_mode)
        elif padding:
            padding = [0] * 2 + py.make_list(padding, dim)
            k = utils.pad(k, padding, side='both', mode=padding_mode)
            v = utils.pad(v, padding, side='both', mode=padding_mode)

        # compute weights by query/key dot product
        kernel_size = py.make_list(kernel_size, dim)
        k = utils.unfold(k, kernel_size, stride)
        k = k.reshape([*k.shape[:dim + 2], -1])
        k = utils.movedim(k, 1, -1)
        q = utils.movedim(q[..., None], 1, -1)
        k = math.softmax(linalg.dot(k, q), dim=-1)
        k = k[:, None]  # add back channel dimension

        # compute new values by weight/value dot product
        v = utils.unfold(v, kernel_size, stride)
        v = v.reshape([*v.shape[:dim + 2], -1])
        v = linalg.dot(k, v)

        return v
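A self-contained sketch of the core computation above (2D, single head, stride 1), using plain PyTorch only: unfold keys and values into local patches, softmax the query-key dot products over each window, then average the values. Names and shapes are illustrative.

import torch
import torch.nn.functional as F

B, C, H, W, ksz = 1, 8, 16, 16, 3
q, k, v = (torch.randn(B, C, H, W) for _ in range(3))

ku = F.unfold(k, ksz, padding=ksz // 2).reshape(B, C, ksz * ksz, H, W)
vu = F.unfold(v, ksz, padding=ksz // 2).reshape(B, C, ksz * ksz, H, W)
w = torch.softmax((ku * q.unsqueeze(2)).sum(1), dim=1)     # (B, k*k, H, W) attention weights
out = (vu * w.unsqueeze(1)).sum(2)                         # (B, C, H, W) attended values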
Example #20
    def forward(self, x, **overload):
        """

        Parameters
        ----------
        x : (batch, in_channels, *spatial_in)

        Returns
        -------
        y : (batch, out_channels, *spatial_out)

        """

        x = utils.movedim(self.linear(utils.movedim(x, 1, -1)), -1, 1)
        q, k, v = torch.chunk(x, 3, dim=1)
        x = self.dot(q, k, v, **overload)
        return x
Example #21
def _warp_image1(image,
                 target,
                 shape=None,
                 affine=None,
                 nonlin=None,
                 backward=False,
                 reslice=False):
    """Returns the warped image, with channel dimension last"""
    # build transform
    aff_right = target
    aff_left = spatial.affine_inv(image.affine)
    aff = None
    if affine:
        # exp = affine.iexp if backward else affine.exp
        exp = affine.exp
        aff = exp(recompute=False, cache_result=True)
        if backward:
            aff = spatial.affine_inv(aff)
    if nonlin:
        if affine:
            if affine.position[0].lower() in ('ms' if backward else 'fs'):
                aff_right = spatial.affine_matmul(aff, aff_right)
            if affine.position[0].lower() in ('fs' if backward else 'ms'):
                aff_left = spatial.affine_matmul(aff_left, aff)
        exp = nonlin.iexp if backward else nonlin.exp
        phi = exp(recompute=False, cache_result=True)
        aff_left = spatial.affine_matmul(aff_left, nonlin.affine)
        aff_right = spatial.affine_lmdiv(nonlin.affine, aff_right)
        if _almost_identity(aff_right) and nonlin.shape == shape:
            phi = nonlin.add_identity(phi)
        else:
            tmp = spatial.affine_grid(aff_right, shape)
            phi = regutils.smart_pull_grid(phi, tmp).add_(tmp)
            del tmp
        if not _almost_identity(aff_left):
            phi = spatial.affine_matvec(aff_left, phi)
    else:
        # no nonlin: single affine even if position == 'symmetric'
        if reslice:
            aff = spatial.affine_matmul(aff, aff_right)
            aff = spatial.affine_matmul(aff_left, aff)
            phi = spatial.affine_grid(aff, shape)
        else:
            phi = None

    # warp image
    if phi is not None:
        warped = image.pull(phi)
    else:
        warped = image.dat

    # return with the channel dimension last (ready to write to disk)
    if len(warped) == 1:
        warped = warped[0]
    else:
        warped = utils.movedim(warped, 0, -1)
    return warped
Example #22
def smart_pull_jac(jac, grid, *args, **kwargs):
    """Interpolate a jacobian field.

    Notes
    -----
    Defaults differ from grid_pull:
    - bound -> dft
    - extrapolate -> True

    Parameters
    ----------
    jac : ([batch], *spatial_in, ndim, ndim) tensor
        Jacobian field
    grid : ([batch], *spatial_out, ndim) tensor
        Transformation field
    kwargs : dict
        Options to ``grid_pull``

    Returns
    -------
    pulled_jac : ([batch], *spatial_out, ndim, ndim) tensor
        Jacobian field

    """
    if grid is None or jac is None:
        return jac
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    dim = jac.shape[-1]
    jac = jac.reshape([*jac.shape[:-2], dim * dim])  # collapse matrix
    jac = utils.movedim(jac, -1, -dim - 1)
    jac_no_batch = jac.dim() == dim + 1
    grid_no_batch = grid.dim() == dim + 1
    if jac_no_batch:
        jac = jac[None]
    if grid_no_batch:
        grid = grid[None]
    jac = spatial.grid_pull(jac, grid, *args, **kwargs)
    jac = utils.movedim(jac, -dim - 1, -1)
    jac = jac.reshape([*jac.shape[:-1], dim, dim])
    if jac_no_batch:
        jac = jac[0]
    return jac
Example #23
def _pull_jac(jac, grid, **kwargs):
    """Interpolate a Jacobian field.

    Notes
    -----
    Defaults differ from grid_pull:
    - bound -> dft
    - extrapolate -> True

    Parameters
    ----------
    jac : ([batch], *spatial, ndim, ndim) tensor
        Jacobian matrix
    grid : ([batch], *spatial, ndim) tensor
        Transformation field
    kwargs : dict
        Options to ``grid_pull``

    Returns
    -------
    pulled_jac : ([batch], *spatial, ndim, ndim) tensor
        Jacobian field

    """
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    dim = grid.shape[-1]

    jac = jac.reshape([*jac.shape[:-2], -1])
    jac = utils.movedim(jac, -1, -dim - 1)
    jac_no_batch = jac.dim() == dim + 1
    grid_no_batch = grid.dim() == dim + 1
    if jac_no_batch:
        jac = jac[None]
    if grid_no_batch:
        grid = grid[None]
    jac = grid_pull(jac, grid, **kwargs)
    jac = utils.movedim(jac, -dim - 1, -1)
    jac = jac.reshape([*jac.shape[:-1], dim, dim])
    if jac_no_batch and grid_no_batch:
        jac = jac[0]
    return jac
Example #24
def jg(jac, grad, dim=None):
    """Jacobian-gradient product: J*g

    Parameters
    ----------
    jac : (..., K, *spatial, D)
    grad : (..., K, *spatial)
    dim : int, default=`grad.dim() - 1`
        Number of spatial dimensions

    Returns
    -------
    new_grad : (..., *spatial, D)

    """
    if grad is None:
        return None
    dim = dim or (grad.dim() - 1)
    grad = utils.movedim(grad, -dim - 1, -1)
    jac = utils.movedim(jac, -dim - 2, -1)
    grad = linalg.matvec(jac, grad)
    return grad
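A self-contained shape check (2D spatial case): the movedim + matvec pattern above is equivalent to contracting over the K dimension with einsum.

import torch

K, H, W, D = 4, 8, 8, 3
jac = torch.randn(K, H, W, D)
grad = torch.randn(K, H, W)

ref = torch.einsum('khwd,khw->hwd', jac, grad)
alt = (jac.movedim(0, -1) @ grad.movedim(0, -1).unsqueeze(-1)).squeeze(-1)
assert torch.allclose(ref, alt, atol=1e-6)
assert ref.shape == (H, W, D)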
Example #25
    def do_apply(fnames, phi, jac):
        """Correct files with a given polarity"""
        for fname in fnames:
            dir, base, ext = py.fileparts(fname)
            ofname = options.output
            ofname = ofname.format(dir=dir or '.', sep=os.sep, base=base,
                                   ext=ext)
            if options.verbose:
                print(f'unwarp {fname} \n'
                      f'    -> {ofname}')

            f = io.map(fname)
            d = f.fdata(device=device)
            d = utils.movedim(d, readout, -1)
            d = _deform1d(d, phi)
            if jac is not None:
                d *= jac
            d = utils.movedim(d, -1, readout)

            io.savef(d, ofname, like=fname)
Example #26
def smart_push_grid(vel, grid, *args, **kwargs):
    """Push a velocity/grid/displacement field.

    Notes
    -----
    Defaults differ from grid_push:
    - bound -> dft
    - extrapolate -> True

    Parameters
    ----------
    vel : ([batch], *spatial, ndim) tensor
        Velocity
    grid : ([batch], *spatial, ndim) tensor
        Transformation field
    kwargs : dict
        Options to ``grid_push``

    Returns
    -------
    pushed_vel : ([batch], *spatial, ndim) tensor
        Velocity

    """
    if grid is None or vel is None:
        return vel
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    dim = vel.shape[-1]
    vel = utils.movedim(vel, -1, -dim - 1)
    vel_no_batch = vel.dim() == dim + 1
    grid_no_batch = grid.dim() == dim + 1
    if vel_no_batch:
        vel = vel[None]
    if grid_no_batch:
        grid = grid[None]
    vel = spatial.grid_push(vel, grid, *args, **kwargs)
    vel = utils.movedim(vel, -dim - 1, -1)
    if vel_no_batch and grid_no_batch:
        vel = vel[0]
    return vel
Example #27
def pull_grid(gridin, grid, interpolation=1, bound='dft', extrapolate=True):
    """Sample a displacement field.

    Parameters
    ----------
    gridin : (*inshape, dim) tensor
    grid : (*outshape, dim) tensor

    Returns
    -------
    gridout : (*outshape, dim) tensor

    """
    gridin = movedim(gridin, -1, 0)[None]
    grid = grid[None]
    gridout = grid_pull(gridin,
                        grid,
                        interpolation=interpolation,
                        bound=bound,
                        extrapolate=extrapolate)
    gridout = movedim(gridout[0], 0, -1)
    return gridout
Example #28
 def rotation(self):
     i = self.i
     j = self.j
     k = self.k
     r = self.r
     matrix = [
         [1 - 2 * (j**2 + k**2), 2 * (i * j - k * r), 2 * (i * k + j * r)],
         [2 * (i * j + k * r), 1 - 2 * (i**2 + k**2), 2 * (j * k - i * r)],
         [2 * (i * k - j * r), 2 * (j * k + i * r), 1 - 2 * (i**2 + j**2)]
     ]
     matrix = utils.as_tensor(matrix)
     matrix = utils.movedim(matrix, [0, 1], [-2, -1])
     return matrix
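Self-contained check of the quaternion-to-rotation formula above: for a unit quaternion (r, i, j, k) the matrix should be orthonormal with determinant +1.

import torch

q = torch.randn(4, dtype=torch.double)
r, i, j, k = (q / q.norm()).tolist()
R = torch.tensor([
    [1 - 2*(j**2 + k**2), 2*(i*j - k*r),       2*(i*k + j*r)],
    [2*(i*j + k*r),       1 - 2*(i**2 + k**2), 2*(j*k - i*r)],
    [2*(i*k - j*r),       2*(j*k + i*r),       1 - 2*(i**2 + j**2)],
], dtype=torch.double)

assert torch.allclose(R @ R.T, torch.eye(3, dtype=torch.double))
assert torch.allclose(torch.linalg.det(R), torch.tensor(1.0, dtype=torch.double))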
Example #29
    def forward(self, x, **overload):
        """

        Parameters
        ----------
        x : (batch, in_channels, *spatial_in)

        Returns
        -------
        y : (batch, out_channels, *spatial_out)

        """

        out = None
        for i, head in enumerate(self.heads):
            y = head(x, **overload)
            if out is None:
                out_shape = list(y.shape)
                out_shape[1] *= len(self.heads)
                out = y.new_empty(out_shape)
            out[:, i * y.shape[1]:(i + 1) * y.shape[1]] = y
            del y
        out = utils.movedim(self.linear(utils.movedim(out, 1, -1)), -1, 1)
        return out
Example #30
def intensity_to_rgb(image,
                     min=None,
                     max=None,
                     colormap='gray',
                     n=256,
                     eq=False):
    """Colormap an intensity image

    Parameters
    ----------
    image : (*batch, H, W) tensor
        A (batch of) 2d image
    min : tensor_like, optional
        Minimum value. Should be broadcastable to batch.
        Default: min of image for each batch element.
    max : tensor_like, optional
        Maximum value. Should be broadcastable to batch.
        Default: max of image for each batch element.
    colormap : str or (K, 3) tensor, default='gray'
        A colormap or the name of a matplotlib colormap.
    n : int, default=256
        Number of color levels to use.
    eq : bool or {'linear', 'quadratic', 'log', None}, default=False
        Apply histogram equalization.
        If 'quadratic' or 'log', the histogram of the transformed signal
        is equalized.

    Returns
    -------
    rgb : (*batch, H, W, 3) tensor
        A (batch of) RGB images.

    """
    image = torch.as_tensor(image).detach()
    image = intensity_preproc(image, min=min, max=max, eq=eq)

    # map
    colormap = _get_colormap_intensity(colormap, n, image.dtype, image.device)
    shape = image.shape
    image = image.mul_(n - 1).clamp_(0, n - 1)
    image = image.reshape([1, -1, 1])
    colormap = colormap.T.reshape([1, 3, -1])
    image = spatial.grid_pull(colormap, image)
    image = image.reshape([3, *shape])
    image = utils.movedim(image, 0, -1)

    return image
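A self-contained sketch of the lookup step above, using plain integer indexing (nearest-neighbour) instead of grid_pull's linear interpolation; the colormap is a random stand-in.

import torch

n = 256
image = torch.rand(32, 32)                      # already rescaled to [0, 1]
cmap = torch.rand(n, 3)                         # stand-in (n, 3) colormap
idx = image.mul(n - 1).round().long().clamp(0, n - 1)
rgb = cmap[idx]                                 # (32, 32, 3) RGB image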