def responsibilities(image, means, precisions, proportions):
    """Posterior class responsibilities under a Gaussian mixture.

    For every voxel and class k, the log-joint
        log p_k + 0.5 * logdet(A_k) - 0.5 * C * log(2*pi)
        - 0.5 * (x - mu_k)^T A_k (x - mu_k)
    is evaluated and normalised across classes with a (log-)softmax.

    Parameters
    ----------
    image : tensor
        Channel-first observations, assumed (B, C, *spatial).
    means : tensor
        Class means, (B, K, C) per the shape comments below.
    precisions : tensor
        Class precision matrices, (B, K, C, C).
    proportions : tensor
        Mixing proportions, (B, K).

    Returns
    -------
    resp : (B, K, *spatial) tensor
        Softmax over classes (dim 1).
    log_resp : (B, K, *spatial) tensor
        Log-softmax over classes (dim 1).
    """
    nb_spatial = image.dim() - 2

    def _spatialize(param):
        # insert singleton spatial axes after the batch axis
        return unsqueeze(param, dim=1, ndim=nb_spatial)

    prec = _spatialize(precisions)                                   # [B, ones, K, C, C]
    centered = channel2last(image).unsqueeze(-2) - _spatialize(means)  # [B, ..., K, C]

    # Mahalanobis (quadratic) term
    log_joint = -0.5 * (matvec(prec, centered) * centered).sum(dim=-1)  # [B, ..., K]

    # Gaussian log-normaliser plus log mixing proportion
    log_2pi = torch.as_tensor(2 * pi, dtype=prec.dtype, device=prec.device).log()
    log_norm = 0.5 * (torch.logdet(prec) - prec.shape[-1] * log_2pi)
    log_norm = log_norm + _spatialize(proportions).log()
    log_joint = last2channel(log_joint + log_norm)

    return (torch.nn.functional.softmax(log_joint, dim=1),
            torch.nn.functional.log_softmax(log_joint, dim=1))
def affine_grid(mat, shape):
    """Create a dense transformation grid from an affine matrix.

    Parameters
    ----------
    mat : (..., D[+1], D+1) tensor
        Affine matrix (or matrices). The last dimension must be D+1; the
        second-to-last may be D (homogeneous row omitted) or D+1.
    shape : (D,) sequence[int]
        Shape of the grid, with length D.

    Returns
    -------
    grid : (..., *shape, D) tensor
        Dense transformation grid

    Raises
    ------
    ValueError
        If the matrix dimension does not match ``len(shape)``, or if the
        matrix is neither (..., D, D+1) nor (..., D+1, D+1).
    """
    mat = torch.as_tensor(mat)
    shape = list(shape)
    nb_dim = mat.shape[-1] - 1
    if nb_dim != len(shape):
        raise ValueError('Dimension of the affine matrix ({}) and shape ({}) '
                         'are not the same.'.format(nb_dim, len(shape)))
    if mat.shape[-2] not in (nb_dim, nb_dim + 1):
        raise ValueError(
            'First argument should be matrices of shape '
            '(..., {0}, {1}) or (..., {1}, {1}) but got {2}.'.format(
                nb_dim, nb_dim + 1, mat.shape))
    batch_shape = mat.shape[:-2]
    grid = identity_grid(shape, mat.dtype, mat.device)
    # one leading singleton per batch dimension so the grid broadcasts
    grid = utils.unsqueeze(grid, dim=0, ndim=len(batch_shape))
    # singleton spatial dims between the batch and matrix dimensions
    mat = utils.unsqueeze(mat, dim=-3, ndim=nb_dim)
    lin = mat[..., :nb_dim, :nb_dim]
    off = mat[..., :nb_dim, -1]
    grid = linalg.matvec(lin, grid) + off
    return grid
def forward(self, x):
    """Random power-law (gamma) intensity transform.

    Each (batch, channel) slice of `x` is mapped to [0, 1] using
    `self.vmin`/`self.vmax` (or the slice's own min/max when these are
    None). It is raised to a random exponent drawn from
    ``self.factor(factor_exp, factor_scale)`` and mapped back to the
    original range.

    Parameters
    ----------
    x : tensor
        Assumed (batch, channel, *spatial).

    Returns
    -------
    x : tensor
        Transformed tensor, same shape as the input.
    """
    backend = utils.backend(x)
    batch_channel = x.shape[:2]
    flat = x.reshape([*batch_channel, -1])

    def _bound(value, reduce):
        # per-(batch, channel) bound, padded to broadcast against x
        if value is None:
            value = reduce(flat, dim=-1).values
        value = torch.as_tensor(value, **backend).expand(batch_channel)
        return unsqueeze(value, -1, x.dim() - value.dim())

    lo = _bound(self.vmin, torch.min)
    hi = _bound(self.vmax, torch.max)

    # one random exponent per (batch, channel)
    nb_channels = x.shape[1]
    distribution = self.factor(
        utils.make_vector(self.factor_exp, nb_channels, **backend),
        utils.make_vector(self.factor_scale, nb_channels, **backend))
    gamma = unsqueeze(distribution.sample([len(x)]), -1, x.dim() - 2)

    span = hi - lo
    return ((x - lo) / span).pow(gamma) * span + lo
def get_bins(x, min, max, nbins): """Compute the histogram bins.""" # TODO: It's suboptimal to have bin centers fall at the # min and max. Better to shift them slightly inside. if mask is not None: # we set masked values to nan so that we can exclude them when # computing min/max val_nan = torch.as_tensor(nan, **backend) x = torch.where(mask, val_nan, x) min_fn = nanmin max_fn = nanmax else: min_fn = lambda *a, **k: torch.min(*a, **k).values max_fn = lambda *a, **k: torch.max(*a, **k).values min = min_fn(x, dim=-1) if min is None else min min = torch.as_tensor(min, **backend) min = unsqueeze(min, dim=2, ndim=4 - min.dim()) # -> shape = [B, C, 1, 1] max = max_fn(x, dim=-1) if max is None else max max = torch.as_tensor(max, **backend) max = unsqueeze(max, dim=2, ndim=4 - max.dim()) # -> shape = [B, C, 1, 1] bins = torch.linspace(0, 1, nbins, **backend) bins = unsqueeze(bins, dim=0, ndim=3) # -> [1, 1, 1, nb_bins] bins = min + bins * (max - min) # -> [B, C, 1, nb_bins] binwidth = (max - min) / (nbins - 1) # -> [B, C, 1, 1] return bins, binwidth
def get_log_confusion(confusion, nb_classes_pred, nb_classes_ref, dim,
                      **backend):
    """Return a well formed (log) confusion matrix.

    Parameters
    ----------
    confusion : tensor or None
        Confusion matrix with layout (pred, ref). It may carry `dim`
        trailing spatial dimensions and a leading batch dimension:
        (pred, ref), (pred, ref, *spatial) or (batch, pred, ref, *spatial).
        If None, ``exp(eye(pred, ref))`` is used (weight ``e`` on the
        diagonal, 1 elsewhere).
    nb_classes_pred, nb_classes_ref : int
        Size of the default matrix (only used when `confusion` is None).
    dim : int
        Number of spatial dimensions.
    **backend
        dtype/device used to build the default matrix.

    Returns
    -------
    confusion : (batch|1, pred, ref, *spatial) tensor
        Matrix normalised to sum to one over the two class dimensions,
        clamped to [1e-7, 1 - 1e-7] and mapped through ``logit``.
    """
    if confusion is None:
        confusion = torch.eye(nb_classes_pred, nb_classes_ref, **backend).exp()
    if confusion.dim() == 2:
        # plain (pred, ref) matrix: append singleton spatial dimensions
        confusion = confusion.reshape([*confusion.shape, *([1] * dim)])
    if confusion.dim() < dim + 3:
        confusion = confusion.unsqueeze(0)  # batch dimension
    # Normalise over the two class dimensions, which sit just before the
    # `dim` trailing spatial dimensions (not simply the last two
    # dimensions, which are spatial when dim > 0).
    class_dims = [-dim - 2, -dim - 1]
    confusion = confusion / confusion.sum(dim=class_dims, keepdim=True)
    confusion = confusion.clamp(min=1e-7, max=1 - 1e-7).logit()
    return confusion
def _build_kernel(dim, **backend):
    """Build separable 1D smoothing kernels, one per spatial axis.

    The base kernel [0.75, 1, 0.75] is divided by the `dim`-th root of the
    sum of its `dim`-fold outer product, so the full separable kernel
    sums to one.

    Returns
    -------
    kernels : list of `dim` tensors
        The d-th kernel has its 3 taps along axis d and singleton axes
        elsewhere (assuming ``utils.unsqueeze(t, pos, n)`` inserts `n`
        singleton dims at `pos`).
    """
    base = torch.as_tensor([0.75, 1., 0.75], **backend)

    # sum of base ⊗ base ⊗ ... (dim factors)
    outer = base
    for _ in range(dim - 1):
        outer = outer.unsqueeze(-1) * base
    base /= outer.sum() ** (1 / dim)

    return [utils.unsqueeze(utils.unsqueeze(base, 0, axis), -1, dim - 1 - axis)
            for axis in range(dim)]
def forward(self, image, **overload):
    """Add (optionally spatially modulated) Gaussian noise to an image.

    Parameters
    ----------
    image : (batch, channel, *spatial) tensor

    Other Parameters
    ----------------
    sigma : float or tensor or callable, default=self.sigma
        Noise standard deviation, interpreted as (batch, channel) after
        padding with trailing singleton dims. If None,
        ``self.default_sigma(batch, channel, **backend)`` is used; if
        callable, it is called with ``image.shape[:2]``.
    gfactor : bool or tensor or callable, default=self.gfactor
        Multiplicative noise modulation. True means a
        ``field.RandomMultiplicativeField()``; a callable is called with
        ``image.shape``; a resulting tensor multiplies the noise.

    Returns
    -------
    image : tensor
        ``image + noise``
    """
    backend = utils.backend(image)
    sigma = overload.get('sigma', self.sigma)
    gfactor = overload.get('gfactor', self.gfactor)

    # sample sigma
    if sigma is None:
        sigma = self.default_sigma(*image.shape[:2], **backend)
    if callable(sigma):
        sigma = sigma(image.shape[:2])
    sigma = torch.as_tensor(sigma, **backend)
    sigma = unsqueeze(sigma, -1, 2 - sigma.dim())  # -> (batch, channel)

    # sample gfactor
    if gfactor is True:
        gfactor = field.RandomMultiplicativeField()
    if callable(gfactor):
        gfactor = gfactor(image.shape)

    # sample noise: (*spatial, batch, channel) -> (batch, channel, *spatial)
    # (batch, at -2, goes to 0 and channel, at -1, goes to 1)
    zero = torch.tensor(0, **backend)
    noise = td.Normal(zero, sigma).sample(image.shape[2:])
    noise = utils.movedim(noise, [-2, -1], [0, 1])
    if torch.is_tensor(gfactor):
        noise *= gfactor
    image = image + noise
    return image
def load(self, fname, dtype=None, device=None):
    """Load a volume from disk.

    Parameters
    ----------
    fname : str
        Path to the file.
    dtype : torch.dtype, optional
        Output data type (default: ``self.dtype``). Floating-point or
        unspecified types are read with ``io.loadf`` and passed through
        ``self.rescale``; other types are read with ``io.load``.
    device : torch.device, optional
        Output device (default: ``self.device``).

    Returns
    -------
    dat : (channels, *spatial) tensor
        All non-spatial dimensions folded into a leading channel axis,
        then passed through ``self.to_shape``.
    """
    dtype = dtype or self.dtype
    device = device or self.device

    if not dtype or dtype.is_floating_point:
        volume = self.rescale(io.loadf(fname, dtype=dtype, device=device))
    else:
        volume = io.load(fname, dtype=dtype, device=device)

    volume = volume.squeeze()
    nb_spatial = self.dim or volume.dim()
    # add trailing singletons if squeezing left fewer than nb_spatial dims
    volume = utils.unsqueeze(volume, -1, max(0, nb_spatial - volume.dim()))
    # fold everything after the spatial dims into one trailing axis
    volume = volume.reshape([*volume.shape[:nb_spatial], -1])
    return self.to_shape(utils.movedim(volume, -1, 0))
def transform_pointset_dense(points, grid, type='grid', bound='dct2'):
    """Transform a pointset through a dense transformation/displacement.

    Points must already be expressed in "grid voxels" coordinates.

    Parameters
    ----------
    points : (n, dim) tensor
        Set of coordinates, in voxel space
    grid : (*spatial, dim) tensor
        Dense transformation or displacement grid, in voxel space
    type : {'grid', 'disp'}, default='grid'
        'disp' adds the sampled displacement to the points; any other
        value replaces the points by the sampled transformation.
    bound : str, default='dct2'
        Boundary conditions for out-of-bounds data

    Returns
    -------
    points : (n, dim) tensor
        Transformed coordinates
    """
    ndim = grid.shape[-1]
    # treat the points as a sampling grid with `ndim` leading singletons
    query = utils.unsqueeze(points, 0, ndim)
    # treat the grid as a single-batch image whose channels are its components
    image = utils.movedim(grid, -1, 0)[None]
    sampled = spatial.grid_pull(image, query, bound=bound, extrapolate=True)
    sampled = utils.movedim(sampled, 1, -1)
    moved = query + sampled if type == 'disp' else sampled
    return utils.squeeze(moved, -2, ndim - 1).squeeze(0)
def forward(self, batch=1, **overload):
    """Generate a smooth random field from upsampled random spline nodes.

    Roughly one node is placed per FWHM along each axis (rounded up).
    Node values are standard normal, scaled per channel by `amplitude`,
    upsampled to `shape` with ``spatial.resize`` (``self.basis``
    interpolation, no prefilter), and offset per channel by `mean`.

    Parameters
    ----------
    batch : int, default=1
        Batch size

    Other Parameters
    ----------------
    shape : sequence[int], optional
    channel : int, optional
    device : torch.device, optional
    dtype : torch.dtype, optional

    Returns
    -------
    field : (batch, channel, *shape) tensor
        Generated random field
    """
    shape = overload.get('shape', self.shape)
    channel = overload.get('channel', self.channel)
    backend = dict(dtype=overload.get('dtype', self.dtype),
                   device=overload.get('device', self.device))

    nb_dim = len(shape)
    mean = utils.make_vector(self.mean, channel, **backend)
    amplitude = utils.make_vector(self.amplitude, channel, **backend)
    fwhm = utils.make_vector(self.fwhm, nb_dim, **backend)

    # number of spline nodes per axis
    nodes = [(size / width).ceil().int().item()
             for size, width in zip(shape, fwhm)]

    coeffs = torch.randn([batch, channel, *nodes], **backend)
    coeffs *= utils.unsqueeze(amplitude, -1, nb_dim)
    out = spatial.resize(coeffs, shape=shape, interpolation=self.basis,
                         bound='dct2', prefilter=False)
    out += utils.unsqueeze(mean, -1, nb_dim)
    return out
def nll(image, resp, means, precisions):
    """Responsibility-weighted Gaussian quadratic term, per voxel.

    Computes ``0.5 * sum_k z_k * (x - mu_k)^T A_k (x - mu_k)``. No
    log-determinant or 2*pi constant is included.

    Returns
    -------
    loss : tensor
        One value per voxel, [B, ...] per the shape comments below.
    """
    nb_spatial = image.dim() - 2
    prec = unsqueeze(precisions, dim=1, ndim=nb_spatial)               # [B, ones, K, C, C]
    centered = (channel2last(image).unsqueeze(-2)                      # [B, ..., 1, C]
                - unsqueeze(means, dim=1, ndim=nb_spatial))            # [B, ones, K, C]
    quadratic = (matvec(prec, centered) * centered).sum(dim=-1)        # [B, ..., K]
    weighted = (quadratic * channel2last(resp)).sum(dim=-1)            # [B, ...]
    return weighted * 0.5
def forward(self, image, **overload):
    """Apply ``self.op(image, factor)`` with a (possibly random) factor.

    `factor` defaults to ``self.factor``. If it is None,
    ``self.default_factor(batch, **backend)`` provides it. If it is
    callable, it is called with the batch size. It is then padded with
    trailing singleton dims to broadcast against `image`.
    """
    backend = utils.backend(image)
    factor = overload.get('factor', self.factor)
    if factor is None:
        factor = self.default_factor(len(image), **backend)
    if callable(factor):
        factor = factor(image.shape[0])
    factor = torch.as_tensor(factor, **backend)
    factor = unsqueeze(factor, -1, image.dim() - factor.dim())
    return self.op(image, factor)
def forward(self, batch=1, **overload):
    """Sample a resampling grid: a nonlinear warp composed with an affine.

    Parameters
    ----------
    batch : int, default=1
        Batch shape.

    Other Parameters
    ----------------
    shape : sequence[int], optional
    device : torch.device, optional
    dtype : torch.dtype, optional

    Returns
    -------
    grid : (batch, *shape, dim) tensor
        Resampling grid
    """
    vel = self.grid.velocity.field
    shape = overload.get('shape', vel.shape)
    dtype = overload.get('dtype', vel.dtype)
    device = overload.get('device', vel.device)
    backend = dict(dtype=dtype, device=device)

    if vel.amplitude == 0:
        # No nonlinear component. Expand the identity to carry an explicit
        # batch dimension like the sampled grid; otherwise
        # `grid.shape[1:-1]` below would drop the first spatial dimension.
        grid = identity_grid(shape, **backend)
        grid = grid.expand([batch, *grid.shape])
    else:
        grid = self.grid(batch, shape=shape, **backend)
    backend = dict(dtype=grid.dtype, device=grid.device)
    shape = grid.shape[1:-1]
    dim = len(shape)

    aff = self.affine(batch, dim=dim, **backend)

    # shift center of rotation to the centre of the field of view
    # ((shape - 1) / 2): aff <- shift^{-1} @ aff @ shift
    aff_shift = torch.cat((
        torch.eye(dim, **backend),
        torch.as_tensor(shape, **backend)[:, None].sub_(1).div_(-2)), dim=1)
    aff_shift = as_euclidean(aff_shift)
    aff = affine_matmul(aff, aff_shift)
    aff = affine_lmdiv(aff_shift, aff)

    # compose: apply the affine to every grid point
    aff = utils.unsqueeze(aff, dim=-3, ndim=dim)
    lin = aff[..., :dim, :dim]
    off = aff[..., :dim, -1]
    grid = linalg.matvec(lin, grid) + off
    return grid
def exp(self, velocity, affine=None, displacement=False):
    """Generate a deformation grid from tangent parameters.

    Parameters
    ----------
    velocity : (batch, *spatial, nb_dim)
        Stationary velocity field
    affine : (batch, nb_prm)
        Affine parameters
    displacement : bool, default=False
        Return a displacement field (voxel to shift) rather than a
        transformation field (voxel to voxel).

    Returns
    -------
    grid : (batch, *spatial, nb_dim)
        Deformation grid (transformation or displacement).
    """
    info = {'dtype': velocity.dtype, 'device': velocity.device}
    shape = velocity.shape[1:-1]

    # exponentiate the resized velocity, then resize back to `shape`
    grid = self.velexp(self.resize(velocity, type='displacement'))
    grid = self.resize(grid, shape=shape, type='grid')

    if affine is not None:
        # exponentiate each batch element's parameters
        affine = torch.stack([self.affexp(prm) for prm in affine], dim=0)

        # conjugate by a translation so the centre of rotation is shape / 2
        center = -torch.as_tensor(shape, **info)[:, None] / 2
        shift = torch.cat((torch.eye(self.dim, **info), center), dim=1)
        affine = spatial.affine_matmul(affine, shift)
        affine = spatial.affine_lmdiv(shift, affine)

        # compose with the dense grid
        affine = unsqueeze(affine, dim=-3, ndim=self.dim)
        linear = affine[..., :self.dim, :self.dim]
        translation = affine[..., :self.dim, -1]
        grid = matvec(linear, grid) + translation

    if displacement:
        grid = grid - spatial.identity_grid(grid.shape[1:-1], **info)
    return grid
def kl(resp, log_resp, proportions):
    """Per-voxel KL divergence between responsibilities and proportions.

    Computes ``sum_k z_k * (log z_k - log p_k)`` over the class dim (1).

    Returns
    -------
    loss : tensor
        [B, ...], one value per voxel.
    """
    nb_spatial = resp.dim() - 2
    log_prop = unsqueeze(proportions, dim=-1, ndim=nb_spatial).log()  # [B, K, ones]
    return (resp * (log_resp - log_prop)).sum(dim=1)                  # [B, ...]
def _process_weights(weighted, dim, nb_classes, **backend):
    """Normalise a user-provided class-weighting specification.

    Parameters
    ----------
    weighted : False or tensor_like
        ``False`` disables weighting. Anything else is converted with
        ``torch.as_tensor(weighted, **backend)``.
        NOTE(review): None is not treated like False here;
        ``torch.as_tensor(None)`` is expected to raise. Confirm callers
        never pass None.
    dim : int
        Number of spatial dimensions (used to pad 1D weights).
    nb_classes : int
        Number of classes.

    Returns
    -------
    weighted : tensor or None
        None if `weighted` is False. If the weights have exactly
        `nb_classes` elements they are flattened to a (nb_classes,)
        vector; otherwise the (possibly padded) tensor is returned as is.
    weighted_channelwise : bool
        True when the returned weights hold one value per class.
    """
    weighted_channelwise = False
    if weighted is not False:
        weighted = torch.as_tensor(weighted, **backend)
        if weighted.dim() == 1:
            # presumably appends `dim` trailing singleton (spatial) dims
            weighted = utils.unsqueeze(weighted, -1, dim)
        if weighted.numel() == nb_classes:
            weighted_channelwise = True
            weighted = weighted.flatten()
    else:
        weighted = None
    return weighted, weighted_channelwise
def irls_tukey_reweight(moving, fixed, lam=1, c=4.685, joint=False, dim=None,
                        mask=None):
    """Update iteratively reweighted least-squares weights for Tukey's biweight

    Parameters
    ----------
    moving : ([B], K, *spatial) tensor
        Moving image
    fixed : ([B], K, *spatial) tensor
        Fixed image
    lam : float or ([B], K|1, [*spatial]) tensor_like
        Equivalent to Gaussian noise precision
        (used to standardize the residuals)
    c : float, default=4.685
        Tukey's threshold.
        Approximately equal to a number of standard deviations above
        which the loss is capped.
    joint : bool, default=False
        Sum the standardized squared residuals across channels and
        compute a single weight per voxel.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    mask : tensor, optional
        Voxels where `mask` is zero receive a zero weight.

    Returns
    -------
    weights : (..., K|1, *spatial) tensor
        IRLS weights: ``(1 - r^2/c^2)^2`` where the standardized squared
        residual ``r^2 <= c^2``, and 0 elsewhere.
    """
    if lam is None:
        lam = 1
    c = c * c  # compare squared residuals against the squared threshold
    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    if mask is not None:
        mask = mask.to(fixed.device)
    dim = dim or (fixed.dim() - 1)
    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)  # pad spatial dimensions
    weights = (moving - fixed).square_().mul_(lam)
    if mask is not None:
        weights = weights.mul_(mask)
    if joint:
        weights = weights.sum(dim=-dim - 1, keepdim=True)
    zeromsk = weights > c
    weights = weights.div_(-c).add_(1).square()
    # Beyond the threshold the biweight is exactly zero. Boolean-mask
    # indexing returns a copy, so zero the entries in place instead.
    weights = weights.masked_fill_(zeromsk, 0)
    if mask is not None:
        # masked-out voxels had their residual zeroed, which would
        # otherwise give them the maximal weight 1 (the l1 variant,
        # irls_laplace_reweight, also zeroes them)
        weights = weights.masked_fill_(mask == 0, 0)
    return weights
def irls_laplace_reweight(moving, fixed, lam=1, joint=False, eps=1e-5,
                          dim=None, mask=None):
    """Update iteratively reweighted least-squares weights for l1

    Parameters
    ----------
    moving : ([B], K, *spatial) tensor
        Moving image
    fixed : ([B], K, *spatial) tensor
        Fixed image
    lam : float or ([B], K|1, [*spatial]) tensor_like
        Inverse-squared scale parameter of the Laplace distribution.
        (equivalent to Gaussian noise precision)
    joint : bool, default=False
        Sum the standardized squared residuals across channels and
        compute a single weight per voxel.
    eps : float, default=1e-5
        Lower bound on the absolute residual (avoids division by zero).
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    mask : tensor, optional
        Voxels where `mask` is zero receive a zero weight.

    Returns
    -------
    weights : (..., K|1, *spatial) tensor
        IRLS weights ``1 / max(eps, |r|)``.
    """
    lam = 1 if lam is None else lam
    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    if mask is not None:
        mask = mask.to(fixed.device)
    nb_spatial = dim or (fixed.dim() - 1)
    if lam.dim() <= 2:
        # scalars become 1-vectors; then trailing spatial dims are added
        lam = utils.unsqueeze(lam.flatten() if lam.dim() == 0 else lam,
                              -1, nb_spatial)

    sq_res = (moving - fixed).square_().mul_(lam)
    if mask is not None:
        sq_res = sq_res.mul_(mask)
    if joint:
        sq_res = sq_res.sum(dim=-nb_spatial - 1, keepdims=True)

    weights = sq_res.sqrt_().clamp_min_(eps).reciprocal_()
    if mask is not None:
        weights = weights.masked_fill_(mask == 0, 0)
    return weights
def roi_closing(label, radius=10, dim=None):
    """Performs a multi-label morphological closing.

    Every voxel is first assigned the label of its nearest non-zero
    labelled voxel (Euclidean distance). Voxels outside the binary
    closing of the foreground (``label > 0``) are then reset to 0.

    Parameters
    ----------
    label : (..., *spatial) tensor[int]
        Volume of labels (0 is background). At most one leading batch
        dimension is supported.
    radius : float, default=10
        Radius of the spherical structuring element (in voxels).
        Non-integer radii are accepted.
    dim : int, default=label.dim()
        Number of spatial dimensions

    Returns
    -------
    closed_label : tensor[int]

    Raises
    ------
    NotImplementedError
        If ``label.dim()`` is neither ``dim`` nor ``dim + 1``.
    """
    import math
    from scipy.ndimage import distance_transform_edt, binary_closing
    dim = dim or label.dim()
    closest_label = torch.zeros_like(label)
    closest_dist = label.new_full(label.shape, float('inf'), dtype=torch.float)
    dist = torch.empty_like(closest_dist)
    for lab in label.unique():
        if lab == 0:
            continue
        # distance from every voxel to the nearest voxel labelled `lab`
        if label.dim() == dim:
            dist = torch.as_tensor(distance_transform_edt(label != lab))
        elif label.dim() == dim + 1:
            for z in range(len(dist)):
                dist[z] = torch.as_tensor(
                    distance_transform_edt(label[z] != lab))
        else:
            raise NotImplementedError
        closest_label[dist < closest_dist] = lab
        closest_dist = torch.min(closest_dist, dist)

    # Spherical structuring element. The grid half-width is rounded up so
    # that non-integer radii still give an integer grid size.
    half = int(math.ceil(radius))
    struct = spatial.identity_grid([2 * half + 1] * dim).sub_(half)
    struct = struct.square().sum(-1).sqrt() <= radius
    # no closing across the batch dimension
    struct = utils.unsqueeze(struct, 0, label.dim() - dim)
    mask = binary_closing(label > 0, struct)
    mask = torch.as_tensor(mask).bitwise_not_()
    closest_label[mask] = 0
    return closest_label
def forward(self, x, output_padding=None, output_shape=None):
    """Scatter the input into a zero-initialised, larger output tensor.

    The output shape is computed by ``self.shape``. When ``self.fill`` is
    false, input values are written at positions ``offset + k * stride``
    along each spatial dimension and all other voxels stay zero. When
    ``self.fill`` is true, the input is written through the tensor
    returned by ``utils.unfold(y, stride)``. Presumably this is a view
    with one stride-sized patch per input voxel, so each value is
    broadcast over its patch. TODO confirm against ``utils.unfold``.

    Parameters
    ----------
    x : (batch, channel, *in_spatial) tensor
    output_padding : [sequence of] int, default=self.output_padding
    output_shape : [sequence of] int, default=self.output_shape

    Returns
    -------
    x : (batch, channel, *out_spatial) tensor
    """
    dim = x.dim() - 2
    offset = py.make_list(self.offset, dim)
    stride = py.make_list(self.stride, dim)
    new_shape = self.shape(x, output_padding=output_padding,
                           output_shape=output_shape)
    y = x.new_zeros(new_shape)
    if self.fill:
        # NOTE(review): writes reach `y` only if `utils.unfold` returns a
        # view of `y`; confirm.
        z = utils.unfold(y, stride)
        # trailing singleton dims so `x` broadcasts over the patch dims
        x = utils.unsqueeze(x, -1, dim)
        slicer = [slice(o, o+sz*st) for sz, st, o in
                  zip(x.shape[2:], stride, offset)]
        slicer = [slice(None)]*2 + slicer
        subz = z[tuple(slicer)]
        # crop the input to what fits in the selected region
        slicer = [slice(mx) for mx in subz.shape[2:]]
        slicer = [slice(None)]*2 + slicer
        subz.copy_(x[tuple(slicer)])
    else:
        # strided view of the output starting at `offset`
        slicer = [slice(o, None, s) for o, s in zip(offset, stride)]
        slicer = [slice(None)]*2 + slicer
        suby = y[tuple(slicer)]
        # crop the input to the number of strided positions available
        slicer = [slice(mx) for mx in suby.shape[2:]]
        slicer = [slice(None)]*2 + slicer
        suby.copy_(x[tuple(slicer)])
    return y
def spconv(input, kernel, step=1, start=0, stop=None, inplace=False,
           bound='dct2', dim=None):
    r"""Convolution with a sparse kernel.

    Notes
    -----
    .. This convolution does not support padding or dilation. Output
       subsampling is only available through `start`/`stop`/`step`.
    .. Without `start`/`stop`/`step`, the output spatial shape is the
       same as the input spatial shape.
    .. The output batch shape is the same as the input batch shape.
    .. Data outside the field-of-view is extrapolated according to `bound`
    .. It is implemented as a linear combination of views into the input
       tensor and should therefore be relatively memory-efficient.

    Parameters
    ----------
    input : (..., [channel_in], *spatial) tensor
        Input tensor, to convolve.
    kernel : ([channel_in, [channel_out]], *kernel_size) sparse tensor
        Convolution kernel.
    start : [sequence of] int, default=0
    stop : [sequence of] int, default=None
    step : [sequence of] int, default=1
        Equivalent to spconv(x)[start:stop:step]
    inplace : bool, default=False
        Write the result into the selected region of `input` (which is
        overwritten) instead of allocating a new output.
    bound : [sequence of] str, default='dct2'
        Boundary condition (per spatial dimension).
    dim : int, default=kernel.dim()
        Number of spatial dimensions.

    Returns
    -------
    output : (..., [channel_out or channel_in], *spatial) tensor

        * If the kernel shape is (channel_in, channel_out, *kernel_size),
          the output shape is (..., channel_out, *spatial) and
          cross-channel convolution happens:
            out[co] = \sum_{ci} conv(inp[ci], ker[ci, co])
        * If the kernel shape is (channel_in, *kernel_size), independent
          single-channel convolutions are applied to each channel:
            out[c] = conv(inp[c], ker[c])
        * If the kernel shape is (*kernel_size), the same convolution
          is applied to all input channels:
            out[c] = conv(inp[c], ker)

    Raises
    ------
    ValueError
        If the kernel has more than `dim + 2` dimensions, or if the input
        channel count is neither 1 nor `channel_in`.
    """
    # get kernel dimensions
    dim = dim or kernel.dim()
    if kernel.dim() == dim + 2:
        channel_in, channel_out, *kernel_size = kernel.shape
    elif kernel.dim() == dim + 1:
        channel_in, *kernel_size = kernel.shape
        channel_out = None
    elif kernel.dim() == dim:
        kernel_size = kernel.shape
        channel_in = channel_out = None
    else:
        raise ValueError('Incompatible kernel shape: too many dimensions')

    start = core.py.ensure_list(start or 0, dim)
    stop = core.py.ensure_list(stop, dim)
    step = core.py.ensure_list(step, dim)

    # check input dimensions
    added_dims = max(0, dim + 1 - input.dim())
    input = unsqueeze(input, 0, added_dims)
    if channel_in is not None:
        if input.shape[-dim - 1] not in (1, channel_in):
            raise ValueError('Incompatible kernel shape: input channels')
        spatial_shape = input.shape[-dim:]
        batch_shape = input.shape[:-dim - 1]
        output_shape = tuple(
            [*batch_shape, channel_out or channel_in, *spatial_shape])
    else:
        # add a fake channel dimension
        spatial_shape = input.shape[-dim:]
        batch_shape = input.shape[:-dim]
        input = input.reshape([*batch_shape, 1, *spatial_shape])
        output_shape = input.shape
    output_spatial_shape = spatial_shape

    # resolve negative / default start and stop indices
    # (note: `str` is used as a local loop name, shadowing the builtin)
    start = [
        0 if not str else str + sz if str < 0 else str
        for str, sz in zip(start, spatial_shape)
    ]
    stop = [
        sz if stp is None else stp + sz if stp < 0 else stp
        for stp, sz in zip(stop, spatial_shape)
    ]
    stop = [stp - 1 for stp in stop
            ]  # we use an inclusive stop in the rest of the code
    step = [st or 1 for st in step]
    if step:
        output_spatial_shape = [
            int(pymath.floor((stp - str) / float(st) + 1))
            for stp, st, str in zip(stop, step, start)
        ]
        output_shape = [*output_shape[:-dim], *output_spatial_shape]

    # `identity`: the input restricted to the output sampling positions
    slicer = [
        slice(str, stp + 1, st) for str, stp, st in zip(start, stop, step)
    ]
    slicer = tuple([Ellipsis, *slicer])
    identity = input[slicer]
    assert identity.shape[-dim:] == tuple(output_shape[-dim:]), "oops"

    if inplace:
        # the output aliases the selected region of `input`, which is
        # zeroed; a copy of the original values is kept in `identity`
        output = identity
        identity = identity.clone()
        output.zero_()
    else:
        output = input.new_zeros(output_shape)

    # move channel + spatial dimensions to the front
    for d in range(dim + 1):  # +1 for channel dim
        input = core.utils.fast_movedim(input, -1, 0)
        output = core.utils.fast_movedim(output, -1, 0)
        identity = core.utils.fast_movedim(identity, -1, 0)

    # prepare other stuff
    bound = core.py.ensure_list(bound, dim)
    bound = [getattr(_bounds, b, None) for b in bound]
    # shift = torch.as_tensor([int(pymath.floor(k/2)) for k in kernel_size],
    #                         dtype=torch.long, device=kernel.device)
    shift = [int(pymath.floor(k / 2)) for k in kernel_size]
    # every combination of "inside"/"outside" the field of view per dim
    sides = list(itertools.product([True, False], repeat=dim))

    # Numeric magic to (hopefully) avoid floating point inaccuracy.
    # Presumably `_split_kernel` removes the central weight(s) `w0`, which
    # are added back at the end as `w0 * identity`; confirm.
    subw0 = True
    if subw0:
        kernel, w0 = _split_kernel(kernel, dim)
    else:
        identity = None

    split_idx = _get_idx_split(kernel.dim(), dim)

    # loop across weights in the sparse kernel
    indices = kernel._indices().t().tolist()
    values = kernel._values()
    for idx, weight in zip(indices, values):

        # map input and output channels
        ci, co, idx = split_idx(idx)
        # kernel index -> offset relative to the kernel centre
        idx = [i - s for i, s in zip(idx, shift)]
        inp = input[ci]
        out = output[co]
        if identity is not None:
            idt = identity[co]
        else:
            idt = None

        # generate slicers
        (input_center_slice, input_side_slice, output_center_slice,
         output_side_slice, transfo_side) = \
            _make_slicers(idx, start, stop, step, output_spatial_shape,
                          spatial_shape, bound)

        # Iterate all combinations of in/out of bounds
        for side in sides:
            input_slicer = tuple(input_center_slice[d] if inside else
                                 input_side_slice[d]
                                 for d, inside in enumerate(side))
            output_slicer = tuple(output_center_slice[d] if inside else
                                  output_side_slice[d]
                                  for d, inside in enumerate(side))
            transfo = tuple(None if inside else transfo_side[d]
                            for d, inside in enumerate(side))

            # a None slicer means this side combination is empty
            if any(sl is None for sl in input_slicer):
                continue
            if any(sl is None for sl in output_slicer):
                continue

            _accumulate(out, inp, output_slicer, input_slicer, transfo,
                        weight, idt=idt, diag=(ci == co))

    # add weighted identity
    if subw0:
        w0 = core.utils.unsqueeze(w0, -1, output.dim() - 1)
        output.addcmul_(identity, w0)

    # move spatial dimensions to the back
    for d in range(dim + 1):
        output = core.utils.fast_movedim(output, 0, -1)

    # remove fake channels
    if channel_in is None:
        output = output.squeeze(len(batch_shape))
    # remove added dimensions
    for _ in range(added_dims):
        output = output.squeeze(-dim - 1)
    return output
def forward(self, score, truth, mask=None):
    """Categorical log-likelihood loss.

    Parameters
    ----------
    score : (nb_batch, nb_class, *spatial) tensor
        Pre-transformed score vector.
    truth : (nb_batch, nb_class[-1]|1, *spatial) tensor
        Observed classes (or their expectation).
            * If `truth` has a floating point data type (`half`,
              `float`, `double`) it is assumed to hold one-hot or
              soft labels, and its channel dimension should be
              `nb_class` or `nb_class - 1`.
            * If `truth` has an integer or boolean data type, it is
              assumed to hold hard labels and its channel dimension
              should be 1.
    mask : (nb_batch, 1, *spatial) tensor, optional
        Loss mask

    Returns
    -------
    loss : scalar or tensor
        The output shape depends on the type of reduction used.
        If 'mean' or 'sum', this function returns a scalar.

    Raises
    ------
    ValueError
        If soft labels have neither `nb_class` nor `nb_class - 1` channels.
    """
    weighted = self.weighted
    score = torch.as_tensor(score)
    truth = torch.as_tensor(truth, device=score.device)
    nb_classes = score.shape[1]  # (includes background)

    if truth.dtype.is_floating_point:  # soft labels
        truth = truth.to(score.dtype)
        truth_implicit = truth.shape[1] == nb_classes - 1
        truth = get_prob_explicit(truth, implicit=truth_implicit)
        if truth.shape[1] != nb_classes:
            raise ValueError('Number of classes not consistent. '
                             'Expected {} or {} but got {}.'.format(
                                 nb_classes, nb_classes - 1, truth.shape[1]))

        loss = score * truth
        if weighted is True:
            weighted = _auto_weighted_soft(truth)
            if mask is not None:
                weighted = weighted * mask
            loss *= weighted
        elif weighted not in (None, False):
            dim = truth.dim() - 2
            weighted = utils.make_vector(weighted, nb_classes,
                                         **utils.backend(loss))
            weighted = utils.unsqueeze(weighted, -1, dim)
            if mask is not None:
                weighted = weighted * mask
            loss *= weighted
        elif mask is not None:
            loss *= mask

    else:  # hard labels
        channelwise = True
        if weighted is True:
            channelwise = False
            weighted = _auto_weighted_hard(truth, nb_classes,
                                           **utils.backend(score))
        elif weighted not in (None, False):
            weighted = utils.make_vector(weighted, **utils.backend(score))
        else:
            weighted = None

        truth = truth.squeeze(1).long()

        # If weights are a list of length C (or none), use nll_loss
        if channelwise and isinstance(self.reduction, str) and mask is None:
            return F.nll_loss(score, truth, weighted,
                              reduction=self.reduction or 'none').neg_()

        # Otherwise, use our own implementation
        if weighted is not None:
            if channelwise:
                # per-class vector: align it with the class dimension (1)
                weighted = utils.unsqueeze(weighted, -1, score.dim() - 2)
            score = score * weighted
        # gather needs an index with as many dims as `score`
        loss = score.gather(dim=1, index=truth.unsqueeze(1))
        if mask is not None:
            loss *= mask  # both are (nb_batch, 1, *spatial)

    if mask is not None and self.reduction == 'mean':
        return loss.sum() / mask.sum()
    return super().forward(loss)
def compute_grad(dat):
    """Squared finite differences of median-normalised volumes.

    Each volume along the first dimension is divided by half its median
    (over all remaining elements). ``spatial.diff`` along dims 1, 2, 3
    is then squared and summed over the last dimension of its output.
    The input tensor is not modified.

    Parameters
    ----------
    dat : (N, X, Y, Z) tensor
        Stack of N 3D volumes.

    Returns
    -------
    grad : tensor
        Sum of squared differences (presumably one value per voxel;
        depends on the output layout of ``spatial.diff``).
    """
    med = dat.reshape([dat.shape[0], -1]).median(dim=-1).values
    med = utils.unsqueeze(med, -1, 3)
    # out-of-place: `/=` would silently rescale the caller's tensor
    dat = dat / (0.5 * med)
    return spatial.diff(dat, dim=[1, 2, 3]).square().sum(-1)
def shim(fmap, max_order=2, mask=None, isocenter=None, dim=None,
         returns='corrected'):
    """Subtract a linear combination of spherical harmonics that minimize gradients

    Parameters
    ----------
    fmap : (..., *spatial) tensor
        Field map
    max_order : int, default=2
        Maximum order of the spherical harmonics
    mask : (*spatial) tensor, optional
        Mask of voxels to include (typically brain mask). Converted to a
        boolean tensor (non-zero = included).
    isocenter : [sequence of] float, default=shape/2
        Coordinate of isocenter, in voxels
    dim : int, default=fmap.dim()
        Number of spatial dimensions
    returns : combination of {'corrected', 'correction', 'parameters'}, default='corrected'
        Components to return, joined with '+'.

    Returns
    -------
    corrected : (..., *spatial) tensor, if 'corrected' in `returns`
        Corrected field map (with spherical harmonics subtracted)
    correction : (..., *spatial) tensor, if 'correction' in `returns`
        Linear combination of spherical harmonics.
    parameters : (..., k) tensor, if 'parameters' in `returns`
        Parameters of the linear combination
    """
    fmap = torch.as_tensor(fmap)
    dim = dim or fmap.dim()
    shape = fmap.shape[-dim:]
    batch = fmap.shape[:-dim]
    backend = utils.backend(fmap)
    dims = list(range(-dim, 0))
    # The basis does not depend on the batch: give it singleton batch
    # dims so that it broadcasts against the batched field map.
    ones = [1] * len(batch)

    if mask is not None:
        # make it a (boolean) mask of background voxels
        mask = ~torch.as_tensor(mask, dtype=torch.bool, device=fmap.device)

    # compute gradients
    gmap = diff(fmap, dim=dims, side='f', bound='dct2')
    if mask is not None:
        gmap[..., mask, :] = 0
    gmap = gmap.reshape([*batch, -1])

    # compute basis of spherical harmonics
    basis = []
    for i in range(1, max_order + 1):
        b = spherical_harmonics(shape, i, isocenter, **backend)
        b = utils.movedim(b, -1, 0)
        b = diff(b, dim=dims, side='f', bound='dct2')
        if mask is not None:
            b[..., mask, :] = 0
        b = b.reshape([b.shape[0], *ones, -1])
        basis.append(b)
    basis = torch.cat(basis, 0)
    basis = utils.movedim(basis, 0, -1)  # (*ones, vox*dim, k)

    # solve system
    prm = linalg.lmdiv(basis, gmap[..., None], method='pinv')[..., 0]
    # > (*batch, k)

    # rebuild basis (without taking gradients)
    basis = []
    for i in range(1, max_order + 1):
        b = spherical_harmonics(shape, i, isocenter, **backend)
        b = utils.movedim(b, -1, 0)
        b = b.reshape([b.shape[0], *ones, *shape])
        basis.append(b)
    basis = torch.cat(basis, 0)
    basis = utils.movedim(basis, 0, -1)  # (*ones, *shape, k)

    comb = linalg.matvec(basis.unsqueeze(-2), utils.unsqueeze(prm, -2, dim))
    comb = comb[..., 0]
    fmap = fmap - comb

    returns = returns.split('+')
    out = []
    for ret in returns:
        if ret == 'corrected':
            out.append(fmap)
        elif ret == 'correction':
            out.append(comb)
        elif ret.startswith('p'):
            out.append(prm)
    return out[0] if len(out) == 1 else tuple(out)
def zcorrect_exp_const(x, decay=None, sigma=None, lam=10, mask=None,
                       max_iter=128, tol=1e-6, verbose=False, snr=5):
    """Correct the z signal decay in a SPIM image.

    The signal is modelled as:
        f(z) = s * exp(-b * z) + eps
    where z=0 is (arbitrarily) the middle slice, s is the intercept
    and b is the decay coefficient. log(s) and b are maps (one value per
    position orthogonal to z), regularised with a membrane energy and
    fitted by Gauss-Newton.

    Parameters
    ----------
    x : (..., nz) tensor
        SPIM image with the z dimension last (slices ordered along z)
    decay : float, optional
        Initial guess for decay parameter. Default: educated guess.
    sigma : float, optional
        Noise standard deviation (intensity units). Default: educated
        guess from `snr`.
    lam : float or (float, float), default=10
        Regularisation (for log-intercept and decay).
    mask : tensor, optional
        Voxels to fit. It is combined with the z-first view of `x`, so it
        must broadcast against shape (nz, ...). Non-finite and
        non-positive voxels are always excluded.
    max_iter : int, default=128
    tol : float, default=1e-6
    verbose : int or bool, default=False
    snr : float, default=5
        Signal-to-noise ratio at the middle slice, used to guess `sigma`.

    Returns
    -------
    y : tensor
        Intercept map s (x.shape[:-1])
    b : tensor
        Decay map (x.shape[:-1])
    x : tensor
        Corrected image (same shape as the input)
    """
    import math
    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    backend = utils.backend(x)
    shape = x.shape
    dim = x.dim() - 1
    nz = shape[-1]
    b = decay

    x = utils.movedim(x, -1, 0).clone()
    if mask is None:
        mask = torch.isfinite(x) & (x > 0)
    else:
        mask = mask & (torch.isfinite(x) & (x > 0))
    x[~mask] = 0

    # decay educated guess: closed form from two values
    if b is None:
        z1 = 2 * nz // 5
        z2 = 3 * nz // 5
        x1 = x[z1]
        x1 = x1[x1 > 0].median()
        x2 = x[z2]
        x2 = x2[x2 > 0].median()
        z1 = float(z1)
        z2 = float(z2)
        b = (x2.log() - x1.log()) / (z1 - z2)
    # log-intercept educated guess: median log-signal of the middle slice
    y = x[(nz - 1) // 2]
    y = y[y > 0].median().log()
    b = b.item() if torch.is_tensor(b) else b
    y = y.item()
    if verbose:
        print(f'init: y = {y}, b = {b}')

    # noise educated guess: assume SNR=`snr` at the middle slice.
    # `y` is a log-intensity, so exponentiate before dividing.
    sigma = sigma or (math.exp(y) / snr)
    lam_y, lam_b = py.make_list(lam, 2)
    lam_y = lam_y**2 * sigma**2
    lam_b = lam_b**2 * sigma**2
    reg = lambda t: spatial.regulariser(
        t, membrane=1, dim=dim, factor=(lam_y, lam_b))
    solve = lambda h, g: spatial.solve_field_fmg(
        h, g, membrane=1, dim=dim, factor=(lam_y, lam_b))

    # init
    z = torch.arange(nz, **backend) - (nz - 1) / 2
    z = utils.unsqueeze(z, -1, dim)
    theta = z.new_empty([2, *x.shape[1:]], **backend)
    # `logy` and `b` are views into `theta`: updating theta updates them
    logy = theta[0].fill_(y)
    b = theta[1].fill_(b)
    y = logy.exp()
    ll0 = (mask * y * (-b * z).exp_() - x).square_().sum() \
        + (theta * reg(theta)).sum()
    ll1 = ll0
    g = torch.zeros_like(theta)
    h = theta.new_zeros([3, *theta.shape[1:]])
    for it in range(max_iter):

        # exponentiate
        y = torch.exp(logy, out=y)
        fit = (b * z).neg_().exp_().mul_(y).mul_(mask)
        res = fit - x

        # compute objective
        reg_theta = reg(theta)
        ll = res.square().sum() + (theta * reg_theta).sum()
        gain = (ll1 - ll) / ll0
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}', end=end)
        if it > 0 and gain < tol:
            break
        ll1 = ll

        # gradient and (robust) Hessian wrt (log-intercept, decay);
        # h = [d2/dlogy2, d2/db2, d2/dlogy db]
        g[0] = (fit * res).sum(0)
        g[1] = -(fit * res * z).sum(0)
        h[0] = (fit * (fit + res.abs())).sum(0)
        h[1] = (fit * (fit + res.abs()) * (z * z)).sum(0)
        h[2] = -(z * fit * fit).sum(0)
        g += reg_theta
        theta -= solve(h, g)

    if verbose and not verbose > 1:
        print('')  # terminate the '\r' progress line

    y = torch.exp(logy, out=y)
    x = x * (b * z).exp_()
    x = utils.movedim(x, 0, -1)
    x = x.reshape(shape)
    return y, b, x
def forward(self, batch=1, **overload):
    """Generate a batch of smooth random fields.

    White Gaussian noise with voxel-wise mean `mean` and standard
    deviation `amplitude` is drawn on a padded grid, then convolved
    with the sequence of per-dimension kernels returned by
    `smooth('gauss', ...)`, so that the output has exactly `shape`.

    Parameters
    ----------
    batch : int, default=1
        Batch size
    overload : dict
        Any of `shape`, `mean`, `amplitude`, `fwhm`, `channel`, `basis`,
        `dtype` and `device` overrides the value set at build time.
        `mean`, `amplitude` and `fwhm` may be callables, in which case
        they are called once to draw a value.

    Returns
    -------
    field : (batch, channel, *shape) tensor
        Generated random field

    Raises
    ------
    ValueError
        If `shape` does not have 1, 2 or 3 dimensions.
    """
    # get arguments
    shape = overload.get('shape', self.shape)
    mean = overload.get('mean', self.mean)
    amplitude = overload.get('amplitude', self.amplitude)
    fwhm = overload.get('fwhm', self.fwhm)
    channel = overload.get('channel', self.channel)
    basis = overload.get('basis', self.basis)
    dtype = overload.get('dtype', self.dtype)
    device = overload.get('device', self.device)

    # fail early (before consuming any RNG) on unsupported dimensions
    nb_dim = len(shape)
    conv = {
        1: torch.nn.functional.conv1d,
        2: torch.nn.functional.conv2d,
        3: torch.nn.functional.conv3d,
    }.get(nb_dim)
    if conv is None:
        raise ValueError(f'Only 1D, 2D and 3D fields are supported, '
                         f'but got shape {list(shape)}.')

    # sample if parameters are callable
    mean = mean() if callable(mean) else mean
    amplitude = amplitude() if callable(amplitude) else amplitude
    fwhm = fwhm() if callable(fwhm) else fwhm

    # device/dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=device)
    amplitude = torch.as_tensor(amplitude, dtype=dtype, device=device)
    fwhm = torch.as_tensor(fwhm, dtype=dtype, device=device)

    # reshape
    full_shape = [batch, channel, *shape]
    mean = mean.expand(full_shape)
    amplitude = amplitude.expand(full_shape)
    fwhm = fwhm.expand([batch, channel, nb_dim])

    # convert SE parameters to noise/kernel parameters
    sigma_se = fwhm / math.sqrt(8 * math.log(2))
    sigma_se = unsqueeze(sigma_se.prod(dim=-1), dim=-1, ndim=nb_dim)
    amplitude = amplitude * (2 * pi)**(nb_dim / 4) * sigma_se.sqrt()
    fwhm = fwhm * math.sqrt(2)

    # smooth
    samples_b = []
    for b in range(batch):
        samples_c = []
        for c in range(channel):
            kernel = smooth('gauss', fwhm[b, c], basis=basis,
                            device=device, dtype=dtype)

            # pad so that 'valid' convolutions return exactly `shape`
            pad_shape = [shape[d] + kernel[d].shape[d + 2] - 1
                         for d in range(nb_dim)]
            mean1 = ensure_shape(mean[b, c], pad_shape,
                                 mode='reflect2', side='both')
            amplitude1 = ensure_shape(amplitude[b, c], pad_shape,
                                      mode='reflect2', side='both')

            # generate sample
            sample = torch.distributions.Normal(mean1, amplitude1).sample()
            sample = sample[None, None, ...]

            # convolve with each per-dimension kernel in turn
            for ker in kernel:
                sample = conv(sample, ker)

            samples_c.append(sample)

        samples_b.append(torch.cat(samples_c, dim=1))

    sample = torch.cat(samples_b, dim=0)
    return sample
def mse(moving, fixed, lam=1, dim=None, grad=True, hess=True, mask=None):
    """Mean-squared error loss for optimisation-based registration.

    (A factor 1/2 is included, and the loss is averaged across voxels,
    but not across channels or batches)

    Parameters
    ----------
    moving : ([B], K, *spatial) tensor
        Moving image
    fixed : ([B], K, *spatial) tensor
        Fixed image
    lam : float or ([B], K|1, [*spatial]) tensor_like
        Gaussian noise precision (or IRLS weights)
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return Hessian
    mask : tensor broadcastable to `moving - fixed`, optional
        Binary mask or voxel-wise weights. It multiplies the squared
        residuals, the gradient and the Hessian alike, so that all three
        are consistent even for soft weights.

    Returns
    -------
    ll : () tensor
        Negative log-likelihood
        0.5 * sum(mask * lam * (moving - fixed)**2) / nvox
    g : (..., K, *spatial) tensor, optional
        Gradient with respect to the moving image
    h : (..., K, *spatial) tensor, optional
        (Diagonal) Hessian with respect to the moving image

    """
    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    if mask is not None:
        mask = mask.to(fixed.device)
    dim = dim or (fixed.dim() - 1)
    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)  # pad spatial dimensions
    nvox = py.prod(fixed.shape[-dim:])

    # The residual is a fresh tensor, but `square_` would overwrite the
    # input that autograd needs for backward, so square out-of-place when
    # `moving` is tracked. Masking happens AFTER squaring so that the
    # loss weight (mask) matches the gradient/Hessian weight (mask).
    ll = moving - fixed
    ll = ll.square() if moving.requires_grad else ll.square_()
    if mask is not None:
        ll = ll.mul_(mask)
    ll = ll.mul_(lam).sum() / (2 * nvox)

    out = [ll]
    if grad:
        g = moving - fixed
        if mask is not None:
            g = g.mul_(mask)
        g = g.mul_(lam).div_(nvox)
        out.append(g)
    if hess:
        h = lam / nvox
        if mask is not None:
            h = mask * h
        out.append(h)

    return tuple(out) if len(out) > 1 else out[0]
def forward(self, image, **overload):
    """Warp an image with a random affine composed with a dense field.

    Parameters
    ----------
    image : (batch, channel, *shape) tensor
        Input image
    overload : dict
        All parameters defined at build time can be overridden at call
        time. Dense-field options are read from `vel_amplitude`,
        `vel_fwhm` and `vel_bound`; affine options from `translation`,
        `rotation`, `zoom` and `shear`; the resampling boundary from
        `image_bound`. `interpolation` is shared by the dense field and
        the resampling; `dtype` and `device` by the dense field and the
        affine.

    Returns
    -------
    warped : (batch, channel, *shape) tensor
        Deformed image
    grid : (batch, *shape, dim) tensor
        Resampling grid
    """
    image = torch.as_tensor(image)
    nb_dim = image.dim() - 2
    nb_batch, _, *spatial = image.shape
    backend = dict(dtype=image.dtype, device=image.device)

    grid_options = dict(
        dim=nb_dim,
        shape=spatial,
        amplitude=overload.get('vel_amplitude', self.grid.amplitude),
        fwhm=overload.get('vel_fwhm', self.grid.fwhm),
        bound=overload.get('vel_bound', self.grid.bound),
        interpolation=overload.get('interpolation', self.grid.interpolation),
        dtype=overload.get('dtype', self.grid.dtype),
        device=overload.get('device', self.grid.device),
    )
    affine_options = dict(
        dim=nb_dim,
        translation=overload.get('translation', self.affine.translation),
        rotation=overload.get('rotation', self.affine.rotation),
        zoom=overload.get('zoom', self.affine.zoom),
        shear=overload.get('shear', self.affine.shear),
        dtype=overload.get('dtype', self.affine.dtype),
        device=overload.get('device', self.affine.device),
    )
    pull_options = dict(
        bound=overload.get('image_bound', self.pull.bound),
        interpolation=overload.get('interpolation', self.pull.interpolation),
    )

    # draw the dense field first, then the affine (keeps RNG order)
    grid = self.grid(nb_batch, **grid_options)
    aff = self.affine(nb_batch, **affine_options)

    # move the centre of rotation to shape / 2
    offset = -torch.as_tensor(spatial, **backend)[:, None] / 2
    to_center = torch.cat((torch.eye(nb_dim, **backend), offset), dim=1)
    aff = affine_lmdiv(to_center, affine_matmul(aff, to_center))

    # apply the affine to every point of the dense grid
    aff = unsqueeze(aff, dim=-3, ndim=nb_dim)
    grid = matvec(aff[..., :nb_dim, :nb_dim], grid) + aff[..., :nb_dim, -1]

    warped = self.pull(image, grid, **pull_options)
    return warped, grid
def lcc(moving, fixed, dim=None, patch=20, stride=1, lam=1, mode='g',
        grad=True, hess=True, mask=None):
    """Local correlation coefficient (squared)

    This function implements a squared version of Cachier and
    Pennec's local correlation coefficient, so that anti-correlations
    are not penalized.

    Parameters
    ----------
    moving : (..., K, *spatial) tensor
        Moving image with K channels.
    fixed : (..., K, *spatial) tensor
        Fixed image with K channels.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    patch : int or list[int], default=20
        Patch size (converted to float and forwarded to `local_mean`).
    stride : int or list[int], default=1
        Stride forwarded to `local_mean` (falsy entries become 0).
    lam : float or ([B], K|1, [*spatial]) tensor_like, default=1
        Precision of the NCC distribution
    mode : str, default='g'
        Windowing mode forwarded to `local_mean`
        (presumably 'g' selects a Gaussian window -- TODO confirm).
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return approximate Hessian
    mask : tensor, optional
        Mask/weights forwarded to `local_mean`.
        Default: ones of shape `fixed.shape[-dim:]`.

    Returns
    -------
    ll : () tensor
        sum(w * log(1 - corr**2)), where `corr` is the local correlation
        coefficient and `w` the local weights normalised to sum to one
        over space and multiplied by `lam`. (1 - corr**2) is clamped
        to at least 1e-8.
    g : tensor, optional
        Gradient with respect to the moving image (returned if `grad`).
    h : tensor, optional
        Approximate Hessian with respect to the moving image
        (returned if `hess`).

    References
    ----------
    ..[1] "3D Non-Rigid Registration by Gradient Descent on a Gaussian-
           Windowed Similarity Measure using Convolutions"
          Pascal Cachier, Xavier Pennec
          MMBIA (2000)

    """
    # Out-of-place sqrt/div when autograd tracks `moving`; in-place
    # variants otherwise, to save memory.
    # NOTE(review): several ops below (addcmul_, clamp_min_, div_, log_)
    # are in-place in both cases and may conflict with autograd -- confirm.
    if moving.requires_grad:
        sqrt_ = torch.sqrt
        div_ = torch.div
    else:
        sqrt_ = torch.sqrt_
        div_ = lambda x, y: x.div_(y)

    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    dim = dim or (fixed.dim() - 1)
    shape = fixed.shape[-dim:]
    if mask is not None:
        mask = mask.to(**utils.backend(fixed))
    else:
        mask = fixed.new_ones(fixed.shape[-dim:])

    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        # append singleton spatial dimensions so that lam broadcasts
        lam = utils.unsqueeze(lam, -1, dim)

    patch = list(map(float, py.ensure_list(patch)))
    stride = py.ensure_list(stride)
    stride = [s or 0 for s in stride]
    # NOTE(review): `local_cache` is not defined in this function; it is
    # presumably a module-level cache -- confirm.
    fwd = lambda x: local_mean(
        x, patch, stride, dim=dim, mode=mode, mask=mask, cache=local_cache)
    bwd = lambda x: local_mean(x, patch, stride, dim=dim, mode=mode, mask=mask,
                               backward=True, shape=shape, cache=local_cache)
    # sum over the spatial dimensions, keeping them as singletons
    sumall = lambda x: x.sum(list(range(-dim, 0)), keepdim=True)

    # compute ncc within each patch
    mom0, moving_mean, fixed_mean, moving_std, fixed_std, corr = \
        _suffstat(fwd, moving, fixed)
    # normalise the local weights to sum to one over space, scale by lam
    mom0 = mom0.div_(sumall(mom0).clamp_min_(1e-5)).mul_(lam)
    # *_std presumably hold second moments E[x^2] on return from
    # _suffstat; turn them into sqrt(E[x^2] - E[x]^2) (overwritten in place)
    moving_std = sqrt_(moving_std.addcmul_(moving_mean, moving_mean, value=-1))
    fixed_std = sqrt_(fixed_std.addcmul_(fixed_mean, fixed_mean, value=-1))
    moving_std.clamp_min_(1e-5)
    fixed_std.clamp_min_(1e-5)
    # corr presumably holds E[xy]; center it and normalise by both stds
    corr = div_(
        div_(corr.addcmul_(moving_mean, fixed_mean, value=-1), moving_std),
        fixed_std)
    # 1 - corr^2, clamped away from zero before taking its log below
    corr2 = corr.square().neg_().add_(1).clamp_min_(1e-8)

    out = []
    if grad or hess:
        # h is needed by both the gradient and the Hessian
        h = (corr / moving_std).square_().mul_(mom0).div_(corr2)
        h = bwd(h)

        if grad:
            # g = G' * (corr.*(corr.*xmean./xstd - ymean./ystd)./xstd)
            #   - x .* (G' * (corr./ xstd).^2)
            #   + y .* (G' * (corr ./ (xstd.*ystd)))
            # g = -2 * g
            fixed_mean = fixed_mean.div_(fixed_std)
            moving_mean = moving_mean.div_(moving_std)
            g = fixed_mean.addcmul_(corr, moving_mean, value=-1)
            fixed_mean = moving_mean = None  # release references
            g = g.mul_(corr).div_(moving_std).mul_(mom0).div_(corr2)
            g = bwd(g)
            g = g.addcmul_(h, moving)
            # NOTE: corr is modified in place here
            g = g.addcmul_(bwd(
                corr.div_(moving_std).div_(fixed_std).mul_(mom0).div_(corr2)),
                fixed, value=-1)
            g = g.mul_(2)
            out.append(g)

        if hess:
            # h = 2 * (G' * (corr./ xstd).^2)
            h = h.mul_(2)
            out.append(h)

    # return stuff
    corr = corr2.log_().mul_(mom0)
    corr = corr.sum()
    out = [corr, *out]
    return tuple(out) if len(out) > 1 else out[0]
def preprocess(a):
    """Convert `a` to a tensor padded with trailing singleton dimensions.

    ``opt['channel'] + 1 - a.dim()`` dimensions are appended, where `opt`
    is taken from the enclosing scope.
    """
    tensor = torch.as_tensor(a)
    missing = opt['channel'] + 1 - tensor.dim()
    return unsqueeze(tensor, dim=-1, ndim=missing)