def forward(self, x, **overload):
    """Compute finite differences of `x` along one or several dimensions.

    Parameters
    ----------
    x : tensor
        Input tensor with shape (batch, channel, *spatial)
    overload : dict
        All parameters defined at build time (`order`, `side`, `dim`,
        `voxel_size`, `bound`) can be overridden at call time.

    Returns
    -------
    g : tensor
        Finite differences with shape
        (batch, channel, *spatial, len(dim), len(side))
        If `dim` or `side` are scalars, not lists, their respective
        dimension is dropped in the output tensor.
        E.g., if `side='c'`, the output shape is
        (batch, channel, *spatial, len(dim))
    """
    order = overload.get('order', self.order)

    # The scalar check must happen *before* make_list: converting first
    # made `drop_side_dim` always False, so a scalar `side` never had its
    # axis dropped (contrary to the documented behavior).
    side = overload.get('side', self.side)
    drop_side_dim = not isinstance(side, (tuple, list))
    side = make_list(side)

    dim = overload.get('dim', self.dim)
    # Default: every dimension after (batch, channel).
    dim = list(range(2, x.dim())) if dim is None else dim
    drop_dim_dim = not isinstance(dim, (tuple, list))
    dim = make_list(dim)
    nb_dim = len(dim)

    # One voxel size / boundary condition per differentiated dimension.
    voxel_size = make_list(overload.get('voxel_size', self.voxel_size), nb_dim)
    bound = make_list(overload.get('bound', self.bound), nb_dim)

    diffs = []
    for d, vx, bnd in zip(dim, voxel_size, bound):
        sides = [diff1d(x, order=order, dim=d, voxel_size=vx, side=s, bound=bnd)
                 for s in side]
        diffs.append(torch.stack(sides, dim=-1))   # (..., len(side))
    diffs = torch.stack(diffs, dim=-2)             # (..., len(dim), len(side))

    if drop_dim_dim:
        diffs = slice_tensor(diffs, 0, dim=-2)
    if drop_side_dim:
        diffs = slice_tensor(diffs, 0, dim=-1)
    return diffs
def _softmax_fwd(input, dim=-1, implicit=False):
    """Numerically stable softmax (forward pass).

    Parameters
    ----------
    input : torch.tensor
        Tensor of logits.
    dim : int, default=-1
        Dimension along which the softmax is taken.
    implicit : bool or (bool, bool), default=False
        The first value relates to the input tensor and the second
        relates to the output tensor.
        - implicit[0] == True assumes that an additional (hidden) channel
          with value zero exists.
        - implicit[1] == True drops the last class from the softmaxed
          tensor.

    Returns
    -------
    Z : torch.tensor
        Soft-maxed tensor.
    """
    has_hidden_in, drop_last_out = py.make_list(implicit, 2)

    # Subtract the per-slice maximum so that exp() cannot overflow.
    shift, _ = torch.max(input, dim=dim, keepdim=True)
    if has_hidden_in:
        # The hidden class has logit 0, so it participates in the maximum.
        shift.clamp_min_(0)

    probs = input.clone().sub_(shift).exp_()

    # Without a hidden class, `shift` is not read again, so its storage
    # can receive the normaliser.
    normaliser = torch.sum(probs, dim=dim, keepdim=True,
                           out=None if has_hidden_in else shift)
    if has_hidden_in:
        # exp(0 - shift): contribution of the hidden zero-logit class.
        normaliser += shift.neg().exp()
    probs *= normaliser.reciprocal_()

    if has_hidden_in and not drop_last_out:
        # Materialise the hidden class: its probability is 1 - sum(others).
        hidden = probs.sum(dim, keepdim=True).neg_().add_(1)
        probs = torch.cat((probs, hidden), dim=dim)
    elif drop_last_out and not has_hidden_in:
        probs = utils.slice_tensor(probs, slice(-1), dim)
    return probs
def _softmax_bwd(output, output_grad, dim=-1, implicit=False):
    """Backward pass of the (safe) softmax.

    With ``z = softmax(x)`` and upstream gradient ``g = dL/dz``, this
    computes ``dL/dx = z * (g - sum(g * z))``, adding or removing the
    hidden/dropped class according to `implicit`.

    Parameters
    ----------
    output : tensor
        Output of the forward softmax.
    output_grad : tensor
        Gradient with respect to the output of the forward pass.
    dim : int, default=-1
        Dimension along which the softmax was taken.
    implicit : bool or (bool, bool), default=False
        The first value relates to the input tensor and the second
        relates to the output tensor.
        - implicit[0] == True assumes that an additional (hidden) channel
          with value zero exists.
        - implicit[1] == True drops the last class from the softmaxed
          tensor.

    Returns
    -------
    grad : tensor
        Gradient with respect to the input of the forward pass.
    """
    hidden_in, dropped_out = py.make_list(implicit, 2)

    grad = output_grad.clone()
    grad *= output                                   # g * z
    inner = grad.sum(dim=dim, keepdim=True)          # <g, z>
    grad -= inner * output                           # z * (g - <g, z>)

    if dropped_out and not hidden_in:
        # The forward pass dropped its last class, which thus received no
        # upstream gradient: its input gradient is -z_last * <g, z>,
        # where z_last = 1 - sum(visible z).
        last = output.sum(dim=dim, keepdim=True).neg().add(1)
        last.mul_(inner).neg_()
        grad = torch.cat((grad, last), dim=dim)
    elif hidden_in and not dropped_out:
        # The last output class came from the hidden zero-logit input,
        # which has no input slot: discard its gradient.
        grad = utils.slice_tensor(grad, slice(-1), dim)
    return grad
def get_slice(image, dim=-1, index=None):
    """Extract a 2d slice from a (batch of) 3d volume(s).

    Parameters
    ----------
    image : (..., *shape3) tensor
        A (batch of) 3d volume. Converted with `torch.as_tensor`.
    dim : int, default=-1
        Spatial dimension along which to slice.
    index : int, default=shape//2
        Voxel coordinate of the slice along `dim`. Defaults to the
        middle slice.

    Returns
    -------
    slice : (..., *shape2) tensor
        A (batch of) 2d slice.
    """
    volume = torch.as_tensor(image)
    position = volume.shape[dim] // 2 if index is None else index
    return utils.slice_tensor(volume, position, dim=dim)
def resize_grid(grid, factor=None, shape=None, type='grid', affine=None,
                *args, **kwargs):
    """Resize a grid of coordinates or displacements by a factor.

    The grid is resampled *and* its values are rescaled, so that the
    resulting coordinates/displacements are expressed in the voxel
    referential of the resized grid.

    Notes
    -----
    .. At least one of `factor` and `shape` must be specified.
    .. If `anchor in ('centers', 'edges')`, and both `factor` and `shape`
       are specified, `factor` is discarded.
    .. If `anchor in ('first', 'last')`, `factor` must be provided even
       if `shape` is specified.
    .. Because of rounding, it is in general not assured that
       `resize(resize(x, f), 1/f)` returns a tensor with the same shape
       as x.

    Parameters
    ----------
    grid : (batch, ..., ndim) tensor
        Grid to resize.
    factor : float or list[float], optional
        Resizing factor
        * > 1 : larger image <-> smaller voxels
        * < 1 : smaller image <-> larger voxels
    shape : (ndim,) sequence[int], optional
        Output shape.
    type : {'grid', 'displacement'}, default='grid'
        Grid type. Only the first letter is inspected (case-insensitive):
        * 'grid' : dense grid of absolute coordinates (shifted, then
          scaled).
        * any other value, e.g. 'displacement' : dense grid of relative
          displacements (scaled only).
    affine : (batch, ndim[+1], ndim+1), optional
        Orientation matrix of the input grid. If provided, the
        orientation matrix of the resized grid is returned as well.
    anchor : {'centers', 'edges', 'first', 'last'}, default='centers'
        Keyword argument forwarded to `resize`.
        * In cases 'c' and 'e', the volume shape is multiplied by the
          zoom factor (and eventually truncated), and two anchor points
          are used to determine the voxel size.
        * In cases 'f' and 'l', a single anchor point is used so that
          the voxel size is exactly divided by the zoom factor.
          This case with an integer factor corresponds to subslicing
          the volume (e.g., `vol[::f, ::f, ::f]`).
        * A list of anchors (one per dimension) can also be provided.
    *args, **kwargs
        Additional arguments forwarded to `resize`.

    Returns
    -------
    resized : (batch, ..., ndim) tensor
        Resized grid.
    affine : (batch, ndim[+1], ndim+1) tensor, optional
        Orientation matrix (only returned when `affine` is provided).
    """
    # Ask `resize` to also report the coordinate transform it applied.
    kwargs['_return_trf'] = True
    outputs = resize(utils.last2channel(grid), factor, shape, affine,
                     *args, **kwargs)
    if affine is None:
        resized, (scales, shifts) = outputs
    else:
        resized, affine, (scales, shifts) = outputs
    resized = utils.channel2last(resized)

    # `resize` maps new coordinates to old ones as
    #     old = scale * new + shift
    # so the stored values are converted back with
    #     new = (old - shift) / scale
    # Displacements are relative, hence only scaled.
    components = []
    for axis, (scale, shift) in enumerate(zip(scales, shifts)):
        component = utils.slice_tensor(resized, axis, dim=-1)
        if type[0].lower() == 'g':
            component = component - shift
        components.append(component / scale)
    resized = torch.stack(components, -1)

    return resized if affine is None else (resized, affine)
def softmax_lse(input, dim=-1, lse=False, weights=None, implicit=False):
    """SoftMax (safe), optionally with the log-sum-exp.

    Parameters
    ----------
    input : torch.tensor
        Tensor with values.
    dim : int, default=-1
        Dimension to take softmax, defaults to last dimensions.
    lse : bool, default=False
        Compute log-sum-exp as well.
    weights : torch.tensor, optional:
        Observation weights (only used in the log-sum-exp).
    implicit : bool or (bool, bool), default=False
        The first value relates to the input tensor and the second
        relates to the output tensor.
        - implicit[0] == True assumes that an additional (hidden) channel
          with value zero exists.
        - implicit[1] == True drops the last class from the softmaxed
          tensor.

    Returns
    -------
    Z : torch.tensor
        Soft-maxed tensor with values.
    lse : () tensor, only returned if `lse=True`
        Log-sum-exp along `dim`, multiplied by `weights` if provided,
        and summed (in float64).
    """
    def sumto(x, *a, out=None, **k):
        # out= is skipped when `x` requires grad: PyTorch does not
        # differentiate through out= calls.
        if out is None or x.requires_grad:
            return torch.sum(x, *a, **k)
        else:
            return torch.sum(x, *a, **k, out=out)

    implicit_in, implicit_out = py.make_list(implicit, 2)

    maxval, _ = torch.max(input, dim=dim, keepdim=True)
    if implicit_in:
        maxval.clamp_min_(0)  # don't forget the class full of zeros
    input = (input - maxval).exp()

    # `maxval`'s storage may only be recycled to hold the sum if `maxval`
    # is never read again. It IS read again when `lse=True` (log-sum-exp)
    # and when `implicit_in=True` (exp(-maxval) term of the hidden class).
    # Recycling it in the implicit case used to add exp(-sum) instead of
    # exp(-max), corrupting the softmax.
    reuse_maxval = not (lse or implicit_in)
    sumval = sumto(input, dim=dim, keepdim=True,
                   out=maxval if reuse_maxval else None)
    if implicit_in:
        sumval += maxval.neg().exp()  # don't forget the class full of zeros
    input = input / sumval

    if lse:
        # Compute log-sum-exp
        #   maxval = max(logit)
        #   lse = maxval + log[sum(exp(logit - maxval))]
        # If implicit
        #   maxval = max(max(logit),0)
        #   lse = maxval + log[sum(exp(logit - maxval)) + exp(-maxval)]
        sumval = sumval.log()
        maxval += sumval
        if weights is not None:
            maxval = maxval * weights
        maxval = maxval.sum(dtype=torch.float64)
    else:
        maxval = None

    if implicit_in and not implicit_out:
        # Materialise the hidden class: probability 1 - sum(others).
        background = input.sum(dim, keepdim=True).neg().add(1)
        input = torch.cat((input, background), dim=dim)
    elif implicit_out and not implicit_in:
        input = utils.slice_tensor(input, slice(-1), dim)

    if lse:
        return input, maxval
    else:
        return input
def dice_nolog(moving, fixed, dim=None, grad=True, hess=True, mask=None,
               add_background=False, weighted=False):
    """Dice loss for optimisation-based registration.

    Parameters
    ----------
    moving : (..., K, *spatial) tensor
        Moving image of probabilities (post-softmax).
        The background class should be omitted.
    fixed : (..., K, *spatial) tensor
        Fixed image of probabilities. Channels beyond the K of `moving`
        (e.g. a background class) are discarded.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return Hessian
    mask : (..., *spatial) tensor, optional
        Mask of voxels to include in the loss (all by default)
    add_background : bool, default=False
        Include the Dice of the (implicit) background class in the loss.
    weighted : bool or tensor, default=False
        Weights for each class. If True, weight by positive rate.

    Returns
    -------
    ll : () tensor
        Loss: sum over classes of (1 - [weighted] Dice).
    g : (..., K, *spatial) tensor, optional
        Gradient with respect to the moving image.
    h : (..., K, *spatial) or (..., K*(K+1)//2, *spatial) tensor, optional
        Hessian (approximation) with respect to the moving image.
        Without background it has K channels. With `add_background` it has
        K*(K+1)//2 channels: the first K are diagonal terms, and the rest
        presumably hold the off-diagonal terms of a packed symmetric matrix.
    """
    fixed, moving = utils.to_max_backend(fixed, moving)
    dim = dim or (fixed.dim() - 1)
    nc = moving.shape[-dim-1]                                # nb classes - bck
    fixed = utils.slice_tensor(fixed, slice(nc), -dim-1)     # remove bkg class
    if mask is not None:
        mask = mask.to(moving.device)
        # NOTE(review): `range(-dim-1)` is empty for dim >= 0, so this
        # relies on an empty dim list reducing over *all* dimensions.
        # The intent may have been `range(-dim, 0)` (spatial dims) — confirm.
        nvox = mask.sum(range(-dim-1), keepdim=True)
    else:
        nvox = py.prod(fixed.shape[-dim:])

    @torch.jit.script
    def rescale(x, dim_channel: int, add_background: bool = False):
        """Ensure that a tensor is in [0, 1]"""
        x = x.clamp_min(0)
        x = x / x.sum(dim_channel, keepdim=True).clamp_min_(1)
        if add_background:
            # Append the background class as an extra channel: K -> K+1.
            # (torch.stack would require equal shapes and add a new dim.)
            x = torch.cat([x, 1 - x.sum(dim_channel, keepdim=True)],
                          dim_channel)
        return x

    moving = rescale(moving, -dim-1, add_background)
    fixed = rescale(fixed, -dim-1, add_background)
    if mask is not None:
        moving *= mask
        fixed *= mask

    if weighted is True:
        # Weight each class by its positive rate in the fixed image.
        weighted = fixed.sum(list(range(-dim, 0)), keepdim=True).div_(nvox)
    elif weighted is not False:
        weighted = torch.as_tensor(weighted, **utils.backend(moving))
        # Append singleton spatial dims so weights broadcast per class.
        for _ in range(dim):
            weighted = weighted.unsqueeze(-1)
    else:
        weighted = None

    @torch.jit.script
    def loss_components(moving, fixed, dim: int,
                        weighted: Optional[Tensor] = None):
        """Compute the (negative) DiceLoss, (positive) Dice and union"""
        dims = [d for d in range(-dim, 0)]
        overlap = (moving * fixed).sum(dims, keepdim=True)
        union = (moving + fixed).sum(dims, keepdim=True)
        union += 1e-5
        dice = 2 * overlap / union
        if weighted is not None:
            ll = 1 - weighted * dice
        else:
            ll = 1 - dice
        ll = ll.sum()
        return ll, dice, union

    ll, dice, union = loss_components(moving, fixed, dim, weighted)
    out = [ll]

    # gradient
    if grad:
        @torch.jit.script
        def do_grad(dice, fixed, union):
            # d(1 - 2*overlap/union)/d(moving)
            return (dice - 2 * fixed) / union
        g = do_grad(dice, fixed, union)
        if weighted is not None:
            g *= weighted
        if add_background:
            # background = 1 - sum(foreground): chain rule through it.
            g_last = utils.slice_tensor(g, slice(-1, None), -dim-1)
            g = utils.slice_tensor(g, slice(-1), -dim-1)
            g -= g_last
        if mask is not None:
            g *= mask
        out.append(g)

    # hessian
    if hess:
        @torch.jit.script
        def do_hess(dice, fixed, union, nvox, dim: int):
            dims = [d for d in range(-dim, 0)]
            positive_rate = fixed.sum(dims, keepdim=True) / nvox
            h = (dice - fixed - positive_rate).abs()
            h = 2 * nvox * h / union.square()
            return h

        nvox = torch.as_tensor(nvox, device=moving.device)
        h = do_hess(dice, fixed, union, nvox, dim)
        if weighted is not None:
            h *= weighted
        if add_background:
            h_foreground = utils.slice_tensor(h, slice(-1), -dim-1)
            h = utils.slice_tensor(h, slice(-1, None), -dim-1)  # h background
            hshape = list(h.shape)
            hshape[-dim-1] = nc*(nc+1)//2
            h = h.expand(hshape).clone()
            # Use a basic slice (a view) so the in-place add writes into `h`;
            # a `range` index may go through advanced indexing (a copy).
            diag = utils.slice_tensor(h, slice(nc), -dim-1)
            diag += h_foreground
        if mask is not None:
            h *= mask
        out.append(h)

    return tuple(out) if len(out) > 1 else out[0]