Example 1
    def forward(self, x, **overload):
        """

        Parameters
        ----------
        x : tensor
            Input tensor with shape (batch, channel, *spatial)
        overload : dict
            All parameters defined at build time can be overridden
            at call time.

        Returns
        -------
        g : tensor
            Finite differences with shape
            (batch, channel, *spatial, len(dim), len(side))

            If `dim` or `side` is a scalar rather than a list, the
            corresponding dimension is dropped from the output tensor.
            E.g., if `side='c'`, the output shape is
            (batch, channel, *spatial, len(dim))

        """
        order = overload.get('order', self.order)
        side = overload.get('side', self.side)
        drop_side_dim = not isinstance(side, (tuple, list))
        side = make_list(side)
        dim = overload.get('dim', self.dim)
        dim = list(range(2, x.dim())) if dim is None else dim
        drop_dim_dim = not isinstance(dim, (tuple, list))
        dim = make_list(dim)
        nb_dim = len(dim)
        voxel_size = overload.get('voxel_size', self.voxel_size)
        voxel_size = make_list(voxel_size, nb_dim)
        bound = make_list(overload.get('bound', self.bound), nb_dim)

        diffs = []
        for d, vx, bnd in zip(dim, voxel_size, bound):
            sides = []
            for s in side:
                grad = diff1d(x,
                              order=order,
                              dim=d,
                              voxel_size=vx,
                              side=s,
                              bound=bnd)
                sides.append(grad)
            sides = torch.stack(sides, dim=-1)
            diffs.append(sides)
        diffs = torch.stack(diffs, dim=-2)

        if drop_dim_dim:
            diffs = slice_tensor(diffs, 0, dim=-2)
        if drop_side_dim:
            diffs = slice_tensor(diffs, 0, dim=-1)
        return diffs
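A minimal usage sketch for the method above. The enclosing module class is not shown here, so the name `FiniteDiff` and its constructor arguments are hypothetical placeholders; only the call-time behaviour follows the docstring.

import torch

x = torch.randn(2, 1, 32, 32, 32)                  # (batch, channel, *spatial)
layer = FiniteDiff(order=1, side=['f', 'b'], bound='dct2')   # hypothetical ctor
g = layer(x)                                        # (2, 1, 32, 32, 32, 3, 2)
g_c = layer(x, side='c')                            # scalar side -> side dim dropped
print(g.shape, g_c.shape)                           # (..., 3, 2) and (..., 3)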
Example 2
def _softmax_fwd(input, dim=-1, implicit=False):
    """ SoftMax (safe).

    Parameters
    ----------
    input : torch.tensor
        Input tensor (logits).
    dim : int, default=-1
        Dimension along which to take the softmax (the last one by default).
    implicit : bool or (bool, bool), default=False
        The first value relates to the input tensor and the second
        relates to the output tensor.
        - implicit[0] == True assumes that an additional (hidden) channel
          with value zero exists.
        - implicit[1] == True drops the last class from the
          softmaxed tensor.

    Returns
    -------
    Z : torch.tensor
        Soft-maxed tensor.

    """

    implicit_in, implicit_out = py.make_list(implicit, 2)

    maxval, _ = torch.max(input, dim=dim, keepdim=True)
    if implicit_in:
        maxval.clamp_min_(0)  # don't forget the class full of zeros

    input = input.clone().sub_(maxval).exp_()
    sumval = torch.sum(input,
                       dim=dim,
                       keepdim=True,
                       out=maxval if not implicit_in else None)
    if implicit_in:
        sumval += maxval.neg().exp()  # don't forget the class full of zeros
    input *= sumval.reciprocal_()

    if implicit_in and not implicit_out:
        background = input.sum(dim, keepdim=True).neg_().add_(1)
        input = torch.cat((input, background), dim=dim)
    elif implicit_out and not implicit_in:
        input = utils.slice_tensor(input, slice(-1), dim)

    return input
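A small sanity-check sketch for `_softmax_fwd`: with `implicit=False` it should agree with `torch.softmax`, and with an implicit input class it returns one extra channel for the hidden zero-logit background.

import torch

logits = torch.randn(4, 3, 8, 8)                     # (batch, K, *spatial)
p = _softmax_fwd(logits, dim=1)
assert torch.allclose(p, torch.softmax(logits, dim=1), atol=1e-6)

# Treat the K logits as foreground classes plus a hidden zero-logit
# background: K + 1 probabilities are returned.
p_full = _softmax_fwd(logits, dim=1, implicit=(True, False))
print(p_full.shape)                                  # (4, 4, 8, 8)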
Example 3
def _softmax_bwd(output, output_grad, dim=-1, implicit=False):
    """ SoftMax backward pass

    Parameters
    ----------
    output : tensor
        Output of the forward softmax.
    output_grad : tensor
        Gradient with respect to the output of the forward pass.
    dim : int, default=-1
        Dimension along which the softmax was taken (the last one by default).
    implicit : bool or (bool, bool), default=False
        The first value relates to the input tensor and the second
        relates to the output tensor.
        - implicit[0] == True assumes that an additional (hidden) channel
          with value zero exists.
        - implicit[1] == True drops the last class from the
          softmaxed tensor.

    Returns
    -------
    grad : tensor
        Gradient with respect to the input of the forward pass

    """
    implicit = py.make_list(implicit, 2)
    add_dim = implicit[1] and not implicit[0]
    drop_dim = implicit[0] and not implicit[1]

    grad = output_grad.clone()
    del output_grad
    grad *= output
    gradsum = grad.sum(dim=dim, keepdim=True)
    grad -= gradsum * output
    if add_dim:
        grad_background = output.sum(dim=dim, keepdim=True).neg().add(1)
        grad_background.mul_(gradsum).neg_()
        grad = torch.cat((grad, grad_background), dim=dim)
    elif drop_dim:
        grad = utils.slice_tensor(grad, slice(-1), dim)

    return grad
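A quick check sketch for `_softmax_bwd`: in the non-implicit case the manual vector-Jacobian product should match autograd applied to `torch.softmax`.

import torch

logits = torch.randn(2, 5, requires_grad=True)
g_out = torch.randn(2, 5)

out = _softmax_fwd(logits.detach(), dim=-1)
g_manual = _softmax_bwd(out, g_out, dim=-1)

out_ref = torch.softmax(logits, dim=-1)
(out_ref * g_out).sum().backward()
assert torch.allclose(g_manual, logits.grad, atol=1e-6)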
Example 4
def get_slice(image, dim=-1, index=None):
    """Extract a 2d slice from a 3d volume

    Parameters
    ----------
    image : (..., *shape3) tensor
        A (batch of) 3d volume
    dim : int, default=-1
        Index of the spatial dimension to slice
    index : int, default=shape//2
        Coordinate (in voxels) of the slice to extract

    Returns
    -------
    slice : (..., *shape2) tensor
        A (batch of) 2d slice

    """
    image = torch.as_tensor(image)
    if index is None:
        index = image.shape[dim] // 2
    return utils.slice_tensor(image, index, dim=dim)
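A usage sketch for `get_slice` on a hypothetical (batch, x, y, z) volume:

import torch

vol = torch.randn(1, 64, 64, 32)
mid = get_slice(vol)                      # middle z slice -> (1, 64, 64)
first = get_slice(vol, dim=-3, index=0)   # first x slice  -> (1, 64, 32)
print(mid.shape, first.shape)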
Example 5
def resize_grid(grid,
                factor=None,
                shape=None,
                type='grid',
                affine=None,
                *args,
                **kwargs):
    """Resize a displacement grid by a factor.

    The displacement grid is resized *and* rescaled, so that
    displacements are expressed in the new voxel referential.

    Notes
    -----
    .. At least one of `factor` and `shape` must be specified.
    .. If `anchor in ('centers', 'edges')`, and both `factor` and `shape`
       are specified, `factor` is discarded.
    .. If `anchor in ('first', 'last')`, `factor` must be provided even
       if `shape` is specified.
    .. Because of rounding, it is in general not guaranteed that
       `resize(resize(x, f), 1/f)` returns a tensor with the same shape as `x`.

    Parameters
    ----------
    grid : (batch, ..., ndim) tensor
        Grid to resize
    factor : float or list[float], optional
        Resizing factor
        * > 1 : larger image <-> smaller voxels
        * < 1 : smaller image <-> larger voxels
    shape : (ndim,) sequence[int], optional
        Output shape
    type : {'grid', 'displacement'}, default='grid'
        Grid type:
        * 'grid' corresponds to dense grids of coordinates.
        * 'displacement' corresponds to dense grids of relative displacements.
        The two types are not rescaled in the same way.
    affine : (batch, ndim[+1], ndim+1), optional
        Orientation matrix of the input grid.
        If provided, the orientation matrix of the resized image is
        returned as well.
    anchor : {'centers', 'edges', 'first', 'last'}, default='centers'
        * In cases 'c' and 'e', the volume shape is multiplied by the
          zoom factor (and possibly truncated), and two anchor points
          are used to determine the voxel size.
        * In cases 'f' and 'l', a single anchor point is used so that
          the voxel size is exactly divided by the zoom factor.
          With an integer factor, this case corresponds to subslicing
          the volume (e.g., `vol[::f, ::f, ::f]`).
        * A list of anchors (one per dimension) can also be provided.
    **kwargs
        Parameters of `grid_pull`.

    Returns
    -------
    resized : (batch, ..., ndim) tensor
        Resized grid.
    affine : (batch, ndim[+1], ndim+1) tensor, optional
        Orientation matrix

    """
    # resize grid
    kwargs['_return_trf'] = True
    grid = utils.last2channel(grid)
    outputs = resize(grid, factor, shape, affine, *args, **kwargs)
    if affine is not None:
        grid, affine, (scales, shifts) = outputs
    else:
        grid, (scales, shifts) = outputs
    grid = utils.channel2last(grid)

    # rescale each component
    # scales and shifts map resized coordinates to original coordinates:
    #   original = scale * resized + shift
    # here we want to transform original coordinates into resized ones:
    #   resized = (original - shift) / scale
    grids = []
    for d, (scl, shft) in enumerate(zip(scales, shifts)):
        grid1 = utils.slice_tensor(grid, d, dim=-1)
        if type[0].lower() == 'g':
            grid1 = grid1 - shft
        grid1 = grid1 / scl
        grids.append(grid1)
    grid = torch.stack(grids, -1)

    # return
    if affine is not None:
        return grid, affine
    else:
        return grid
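A usage sketch for `resize_grid`: upsampling a random displacement field by a factor of 2 (the exact output shape depends on the `anchor` and rounding rules described in the notes above).

import torch

disp = torch.randn(1, 16, 16, 16, 3)                      # (batch, *spatial, ndim)
disp2 = resize_grid(disp, factor=2, type='displacement')
print(disp2.shape)                                         # about (1, 32, 32, 32, 3)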
Example 6
def softmax_lse(input, dim=-1, lse=False, weights=None, implicit=False):
    """ SoftMax (safe).

    Parameters
    ----------
    input : torch.tensor
        Input tensor (logits).
    dim : int, default=-1
        Dimension along which to take the softmax (the last one by default).
    lse : bool, default=False
        Compute log-sum-exp as well.
    weights : torch.tensor, optional
        Observation weights (only used in the log-sum-exp).
    implicit : bool or (bool, bool), default=False
        The first value relates to the input tensor and the second
        relates to the output tensor.
        - implicit[0] == True assumes that an additional (hidden) channel
          with value zero exists.
        - implicit[1] == True drops the last class from the
          softmaxed tensor.

    Returns
    -------
    Z : torch.tensor
        Soft-maxed tensor.
    lse : () tensor, optional
        Log-sum-exp (weighted by `weights` if provided), returned
        only if `lse=True`.

    """
    def sumto(x, *a, out=None, **k):
        if out is None or x.requires_grad:
            return torch.sum(x, *a, **k)
        else:
            return torch.sum(x, *a, **k, out=out)

    implicit_in, implicit_out = py.make_list(implicit, 2)

    maxval, _ = torch.max(input, dim=dim, keepdim=True)
    if implicit_in:
        maxval.clamp_min_(0)  # don't forget the class full of zeros

    input = (input - maxval).exp()
    sumval = sumto(input,
                   dim=dim,
                   keepdim=True,
                   out=maxval if not lse else None)
    if implicit_in:
        sumval += maxval.neg().exp()  # don't forget the class full of zeros
    input = input / sumval

    if lse:
        # Compute log-sum-exp
        #   maxval = max(logit)
        #   lse = maxval + log[sum(exp(logit - maxval))]
        # If implicit
        #   maxval = max(max(logit),0)
        #   lse = maxval + log[sum(exp(logit - maxval)) + exp(-maxval)]
        sumval = sumval.log()
        maxval += sumval
        if weights is not None:
            maxval = maxval * weights
        maxval = maxval.sum(dtype=torch.float64)
    else:
        maxval = None

    if implicit_in and not implicit_out:
        background = input.sum(dim, keepdim=True).neg().add(1)
        input = torch.cat((input, background), dim=dim)
    elif implicit_out and not implicit_in:
        input = utils.slice_tensor(input, slice(-1), dim)

    if lse:
        return input, maxval
    else:
        return input
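A sanity-check sketch for `softmax_lse`: in the plain (non-implicit, unweighted) case the softmax should match `torch.softmax` and the returned scalar should match a direct log-sum-exp reduction.

import torch

logits = torch.randn(2, 4, 8)
p, lse_val = softmax_lse(logits, dim=1, lse=True)
ref = torch.logsumexp(logits, dim=1).sum(dtype=torch.float64)
assert torch.allclose(p, torch.softmax(logits, dim=1), atol=1e-6)
assert torch.allclose(lse_val, ref, atol=1e-4)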
Example 7
def dice_nolog(moving, fixed, dim=None, grad=True, hess=True, mask=None,
               add_background=False, weighted=False):
    """Dice loss for optimisation-based registration.

    Parameters
    ----------
    moving : (..., K, *spatial) tensor
        Moving image of probabilities (post-softmax).
        The background class should be omitted.
    fixed : (..., K, *spatial) tensor
        Fixed image of probabilities.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return Hessian
    mask : (..., *spatial) tensor, optional
        Mask of voxels to include in the loss (all by default)
    add_background : bool, default=False
        Include the Dice of the (implicit) background class in the loss.
    weighted : bool or tensor, default=False
        Weights for each class. If True, weight by positive rate.

    Returns
    -------
    ll : () tensor
        Negative log-likelihood
    g : (..., K, *spatial) tensor, optional
        Gradient with respect to the moving image.
    h : (..., K, *spatial) tensor, optional
        Hessian with respect to the moving image.

    """
    fixed, moving = utils.to_max_backend(fixed, moving)
    dim = dim or (fixed.dim() - 1)
    nc = moving.shape[-dim-1]                               # nb classes - bck
    fixed = utils.slice_tensor(fixed, slice(nc), -dim-1)    # remove bkg class
    if mask is not None:
        mask = mask.to(moving.device)
        nvox = mask.sum(list(range(-dim, 0)), keepdim=True)
    else:
        nvox = py.prod(fixed.shape[-dim:])

    @torch.jit.script
    def rescale(x, dim_channel: int, add_background: bool = False):
        """Ensure that a tensor is in [0, 1]"""
        x = x.clamp_min(0)
        x = x / x.sum(dim_channel, keepdim=True).clamp_min_(1)
        if add_background:
            x = torch.cat([x, 1 - x.sum(dim_channel, keepdim=True)], dim_channel)
        return x

    moving = rescale(moving, -dim-1, add_background)
    fixed = rescale(fixed, -dim-1, add_background)
    if mask is not None:
        moving *= mask
        fixed *= mask

    if weighted is True:
        weighted = fixed.sum(list(range(-dim, 0)), keepdim=True).div_(nvox)
    elif weighted is not False:
        weighted = torch.as_tensor(weighted, **utils.backend(moving))
        for _ in range(dim):
            weighted = weighted.unsqueeze(-1)
    else:
        weighted = None

    @torch.jit.script
    def loss_components(moving, fixed, dim: int, weighted: Optional[Tensor] = None):
        """Compute the (negative) DiceLoss, (positive) Dice and union"""
        dims = [d for d in range(-dim, 0)]
        overlap = (moving * fixed).sum(dims, keepdim=True)
        union = (moving + fixed).sum(dims, keepdim=True)
        union += 1e-5
        dice = 2 * overlap / union
        if weighted is not None:
            ll = 1 - weighted * dice
        else:
            ll = 1 - dice
        ll = ll.sum()
        return ll, dice, union

    ll, dice, union = loss_components(moving, fixed, dim, weighted)
    out = [ll]

    # gradient
    if grad:
        @torch.jit.script
        def do_grad(dice, fixed, union):
            return (dice - 2 * fixed) / union
        g = do_grad(dice, fixed, union)
        if weighted is not None:
            g *= weighted
        if add_background:
            g_last = utils.slice_tensor(g, slice(-1, None), -dim-1)
            g = utils.slice_tensor(g, slice(-1), -dim-1)
            g -= g_last
        if mask is not None:
            g *= mask
        out.append(g)

    # hessian
    if hess:
        @torch.jit.script
        def do_hess(dice, fixed, union, nvox, dim: int):
            dims = [d for d in range(-dim, 0)]
            positive_rate = fixed.sum(dims, keepdim=True) / nvox
            h = (dice - fixed - positive_rate).abs()
            h = 2 * nvox * h / union.square()
            return h
        nvox = torch.as_tensor(nvox, device=moving.device)
        h = do_hess(dice, fixed, union, nvox, dim)
        if weighted is not None:
            h *= weighted
        if add_background:
            h_foreground = utils.slice_tensor(h, slice(-1), -dim-1)
            h = utils.slice_tensor(h, slice(-1, None), -dim-1)  # h background
            hshape = list(h.shape)
            hshape[-dim-1] = nc*(nc+1)//2
            h = h.expand(hshape).clone()
            # view of the first nc (diagonal) channels, so the in-place
            # add below writes into `h`
            diag = utils.slice_tensor(h, slice(nc), -dim-1)
            diag += h_foreground
        if mask is not None:
            h *= mask
        out.append(h)

    return tuple(out) if len(out) > 1 else out[0]
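A usage sketch for `dice_nolog` on random probability maps (one batch, two foreground classes, 3-d volumes; shapes follow the docstring above):

import torch

moving = torch.rand(1, 2, 16, 16, 16)
fixed = torch.rand(1, 2, 16, 16, 16)
ll, g, h = dice_nolog(moving, fixed, dim=3)
print(float(ll), g.shape, h.shape)        # scalar loss, (1, 2, 16, 16, 16) x 2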