def responsibilities(image, means, precisions, proportions):
    """Posterior class responsibilities of a Gaussian mixture model.

    Shapes follow the inline shape comments of this function.

    Parameters
    ----------
    image : (B, C, *spatial) tensor
        Observed multi-channel image.
    means : (B, K, C) tensor
        Mean of each mixture component.
    precisions : (B, K, C, C) tensor
        Precision matrix of each mixture component.
    proportions : (B, K) tensor
        Mixing proportions of each component.

    Returns
    -------
    resp : (B, K, *spatial) tensor
        Responsibilities (softmax over the class dimension).
    log_resp : (B, K, *spatial) tensor
        Log-responsibilities (log-softmax over the class dimension).
    """
    nb_spatial = image.dim() - 2

    # channels last + singleton class axis -> [B, ..., 1, C]
    obs = channel2last(image).unsqueeze(-2)
    # singleton spatial axes so that parameters broadcast against voxels
    prop = unsqueeze(proportions, dim=1, ndim=nb_spatial)    # [B, ones, K]
    mu = unsqueeze(means, dim=1, ndim=nb_spatial)            # [B, ones, K, C]
    prec = unsqueeze(precisions, dim=1, ndim=nb_spatial)     # [B, ones, K, C, C]

    # voxel-wise quadratic term: -0.5 * (x - mu)' A (x - mu)
    diff = obs - mu                                          # [B, ..., K, C]
    quad = (matvec(prec, diff) * diff).sum(dim=-1)           # [B, ..., K]
    log_lik = -0.5 * quad

    # constant term: 0.5 * (logdet(A) - C * log(2*pi)) + log(proportion)
    two_pi = torch.as_tensor(2 * pi, dtype=prec.dtype, device=prec.device)
    log_norm = torch.logdet(prec) - prec.shape[-1] * two_pi.log()
    log_lik = log_lik + (0.5 * log_norm + prop.log())

    # normalise across classes (class dimension moved back to position 1)
    log_lik = last2channel(log_lik)
    log_resp = torch.nn.functional.log_softmax(log_lik, dim=1)
    resp = torch.nn.functional.softmax(log_lik, dim=1)
    return resp, log_resp
def nll(image, resp, means, precisions):
    """Responsibility-weighted quadratic term of a Gaussian mixture NLL.

    Computes, at each voxel,
    ``0.5 * sum_k resp_k * (x - mu_k)' A_k (x - mu_k)``.
    The log-determinant and constant terms are NOT included.

    Parameters
    ----------
    image : (B, C, *spatial) tensor
        Observed multi-channel image.
    resp : (B, K, *spatial) tensor
        Class responsibilities.
    means : (B, K, C) tensor
        Mean of each mixture component.
    precisions : (B, K, C, C) tensor
        Precision matrix of each mixture component.

    Returns
    -------
    loss : (B, *spatial) tensor
    """
    nb_spatial = image.dim() - 2

    obs = channel2last(image).unsqueeze(-2)                  # [B, ..., 1, C]
    post = channel2last(resp)                                # [B, ..., K]
    mu = unsqueeze(means, dim=1, ndim=nb_spatial)            # [B, ones, K, C]
    prec = unsqueeze(precisions, dim=1, ndim=nb_spatial)     # [B, ones, K, C, C]

    diff = obs - mu
    quad = (matvec(prec, diff) * diff).sum(dim=-1)           # [B, ..., K]
    weighted = (quad * post).sum(dim=-1)                     # [B, ...]
    return weighted * 0.5
def forward(self, prior, **overload):
    """Draw a sample from a Categorical distribution.

    Parameters
    ----------
    prior : (batch, channel, *shape) tensor or callable
        Prior probabilities or log-odds of the Categorical distribution.
        If callable, it is called without arguments to obtain the tensor.
        If `shape` is set, `prior` must be (batch, channel) and is
        expanded on the right to (batch, channel, *shape).
    overload : dict
        All parameters defined at build time (`shape`, `logits`,
        `implicit`) can be overridden at call time.

    Returns
    -------
    sample : (batch, 1, *shape) tensor
        Sampled class indices.

    Raises
    ------
    ValueError
        If `shape` is set and `prior` is not 2D.
    """
    # read arguments
    shape = overload.get('shape', self.shape)
    logits = overload.get('logits', self.logits)
    implicit = overload.get('implicit', self.implicit)

    # call prior in case it is a random parameter
    prior = prior() if callable(prior) else torch.as_tensor(prior)

    # repeat prior if shape provided
    if shape is not None:
        if prior.dim() != 2:
            raise ValueError('Expected tensor with shape (batch, channel) '
                             'but got {}'.format(prior.shape))
        prior = expand(prior, [*prior.shape, *shape], side='right')

    # add implicit class
    if implicit:
        if logits:
            # the implicit class is the reference class: log-odds of 0
            extra_shape = list(prior.shape)
            extra_shape[1] = 1
            extra = torch.zeros(extra_shape, dtype=prior.dtype,
                                device=prior.device)
        else:
            # the implicit class receives the remaining probability mass
            # (a zero here would make it impossible to ever sample it);
            # clamp guards against small negative values from rounding
            extra = (1 - prior.sum(dim=1, keepdim=True)).clamp(min=0)
        prior = torch.cat((prior, extra), dim=1)

    # Categorical expects the class dimension last
    batch, _, *spatial = prior.shape
    prior = channel2last(prior)
    dist_kwargs = {('logits' if logits else 'probs'): prior}

    # sample
    sample = torch.distributions.Categorical(**dist_kwargs).sample()
    return sample.reshape([batch, 1, *spatial])
def forward(self, batch=1, **overload):
    """Sample a random velocity field.

    Parameters
    ----------
    batch : int, default=1
        Batch size
    overload : dict
        All parameters defined at build time can be overridden at call
        time: `dim` (number of output channels), `shape`, `amplitude`,
        `fwhm`, `dtype`, `device`.

    Returns
    -------
    vel : (batch, *shape, dim) tensor
        Velocity field
    """
    field = self.field
    opt = dict(
        channel=overload.get('dim', field.channel),
        shape=overload.get('shape', field.shape),
        amplitude=overload.get('amplitude', field.amplitude),
        fwhm=overload.get('fwhm', field.fwhm),
        dtype=overload.get('dtype', field.dtype),
        device=overload.get('device', field.device),
    )
    nb_channels = opt['channel']

    # The random field broadcasts `amplitude` to (channel, *shape) by
    # padding on the LEFT, so a 1D amplitude would land on the last
    # spatial axis instead of the channel axis. Pad it on the right
    # instead so that it ends up with `channel + 1` dimensions.
    def pad_right(a):
        a = torch.as_tensor(a)
        return unsqueeze(a, dim=-1, ndim=nb_channels + 1 - a.dim())

    amp = opt['amplitude']
    if not callable(amp):
        opt['amplitude'] = pad_right(amp)
    else:
        # random amplitude: pad whatever the callable returns
        def padded_amplitude(*args, **kwargs):
            return pad_right(amp(*args, **kwargs))
        opt['amplitude'] = padded_amplitude

    return channel2last(self.field(batch, **opt))
def forward(self, source, target, *, _loss=None, _metric=None):
    """Predict affine parameters from an image pair and warp the source.

    Parameters
    ----------
    source : tensor (batch, channel, *spatial)
        Source/moving image
    target : tensor (batch, channel, *spatial)
        Target/fixed image
    _loss : dict, optional
        If provided, all registered losses are computed and appended.
    _metric : dict, optional
        If provided, all registered metrics are computed and appended.

    Returns
    -------
    deformed_source : tensor (batch, channel, *spatial)
        Deformed source image
    affine_prm : tensor (batch, *nb_prm)
        Affine Lie/Classic parameters
    dense : tensor
        Output of the U-Net with the channel dimension moved last; the
        features from which `affine_prm` is computed.
    """
    # sanity checks
    check.dim(self.dim, source, target)
    check.shape(target, source, dims=[0], broadcast_ok=True)
    check.shape(target, source, dims=range(2, self.dim + 2))

    # predict parameters from the concatenated pair
    source_and_target = torch.cat((source, target), dim=1)
    dense = channel2last(self.unet(source_and_target))
    affprm = self.dense2prm(dense)
    # build the affine matrix in double precision, then cast back
    affine = self.exp(affprm.double()).to(dense.dtype)
    # sampling grid defined on the target lattice
    grid = self.grid(affine, shape=target.shape[2:])
    deformed_source = self.pull(source, grid)

    # compute loss and metrics
    self.compute(_loss, _metric,
                 image=[deformed_source, target],
                 affine=[affprm],
                 dense=[dense])

    return deformed_source, affprm, dense
def _pull_vel(vel, grid, *args, **kwargs):
    """Interpolate a channel-last velocity/grid/displacement field.

    Parameters
    ----------
    vel : (batch, ..., ndim) tensor
        Field to interpolate (vector components in the last dimension).
    grid : (batch, ..., ndim) tensor
        Transformation field (sampling locations).
    *args, **kwargs
        Forwarded unchanged to ``grid_pull``.

    Returns
    -------
    pulled_vel : (batch, ..., ndim) tensor
        Interpolated field, components in the last dimension.
    """
    # grid_pull operates on channel-first tensors
    vel_channel_first = last2channel(vel)
    pulled = grid_pull(vel_channel_first, grid, *args, **kwargs)
    return channel2last(pulled)
def forward(self, batch=1, **overload):
    """Sample a random velocity field.

    The number of channels is always set to the number of spatial
    dimensions; any `channel` entry in `overload` is overwritten.

    Parameters
    ----------
    batch : int, default=1
        Batch size

    Other Parameters
    ----------------
    shape : sequence[int], optional
    device : torch.device, optional
    dtype : torch.dtype, optional

    Returns
    -------
    vel : (batch, *shape, dim) tensor
        Velocity field
    """
    spatial_shape = overload.get('shape', self.field.shape)
    overload['channel'] = len(spatial_shape)
    velocity = self.field(batch, **overload)
    return utils.channel2last(velocity)
def resize_grid(grid, factor=None, shape=None, type='grid', affine=None,
                *args, **kwargs):
    """Resize a displacement grid by a factor.

    The displacement grid is resized *and* rescaled, so that
    displacements are expressed in the new voxel referential.

    Notes
    -----
    .. At least one of `factor` and `shape` must be specified.
    .. If `anchor in ('centers', 'edges')`, and both `factor` and `shape`
       are specified, `factor` is discarded.
    .. If `anchor in ('first', 'last')`, `factor` must be provided even
       if `shape` is specified.
    .. Because of rounding, it is in general not assured that
       `resize(resize(x, f), 1/f)` returns a tensor with the same shape
       as x.

    Parameters
    ----------
    grid : (batch, ..., ndim) tensor
        Grid to resize
    factor : float or list[float], optional
        Resizing factor
        * > 1 : larger image <-> smaller voxels
        * < 1 : smaller image <-> larger voxels
    shape : (ndim,) sequence[int], optional
        Output shape
    type : {'grid', 'displacement'}, default='grid'
        Grid type:
        * 'grid' corresponds to dense grids of coordinates.
        * 'displacement' corresponds to dense grids of relative
          displacements.
        Both types are not rescaled in the same way.
    affine : (batch, ndim[+1], ndim+1), optional
        Orientation matrix of the input grid.
        If provided, the orientation matrix of the resized image is
        returned as well.
    anchor : {'centers', 'edges', 'first', 'last'}, default='centers'
        (passed through `**kwargs`)
        * In cases 'c' and 'e', the volume shape is multiplied by the
          zoom factor (and eventually truncated), and two anchor points
          are used to determine the voxel size.
        * In cases 'f' and 'l', a single anchor point is used so that
          the voxel size is exactly divided by the zoom factor.
          This case with an integer factor corresponds to subslicing
          the volume (e.g., `vol[::f, ::f, ::f]`).
        * A list of anchors (one per dimension) can also be provided.
    **kwargs
        Other options forwarded to `resize` (e.g., interpolation
        parameters of `grid_pull`).

    Returns
    -------
    resized : (batch, ..., ndim) tensor
        Resized grid.
    affine : (batch, ndim[+1], ndim+1) tensor, optional
        Orientation matrix
    """
    # resize the grid, asking `resize` for the coordinate transform it used
    kwargs['_return_trf'] = True
    channel_first = utils.last2channel(grid)
    result = resize(channel_first, factor, shape, affine, *args, **kwargs)
    if affine is None:
        resized, (scales, shifts) = result
    else:
        resized, affine, (scales, shifts) = result
    resized = utils.channel2last(resized)

    # `scales`/`shifts` map resized coordinates to original ones:
    #     original = scale * resized + shift
    # so values are brought into the resized frame with:
    #     resized = (original - shift) / scale
    # Displacements are relative, so only the scaling applies to them.
    def rescale(d, scl, shft):
        component = utils.slice_tensor(resized, d, dim=-1)
        if type[0].lower() == 'g':
            component = component - shft
        return component / scl

    components = [rescale(d, scl, shft)
                  for d, (scl, shft) in enumerate(zip(scales, shifts))]
    resized = torch.stack(components, -1)

    if affine is None:
        return resized
    return resized, affine
def compose(*args, interpolation='linear', bound='dft'):
    """Compose multiple spatial deformations (affine matrices or flow fields).

    The composition is read right-to-left: ``compose(a, b, c)`` returns
    the transformation ``x -> a(b(c(x)))``.

    Parameters
    ----------
    *args : matrix or field
        Either an affine matrix (2D tensor with D+1 columns: linear part
        followed by translation) or a dense transformation field of
        shape (batch, *spatial, D).
    interpolation : default='linear'
        Interpolation order, forwarded to `grid_pull`.
    bound : default='dft'
        Boundary condition, forwarded to `grid_pull`.

    Returns
    -------
    composed : matrix or (batch, *spatial, D) tensor
        If only affine matrices are provided, their (square) product.
        Otherwise, a dense transformation field.

    Raises
    ------
    ValueError
        If the deformations do not all have the same dimensionality, or
        if the last deformation is an affine matrix while at least one
        field is present.
    """
    # TODO:
    #   . add shape/dim argument to generate (if needed) an identity field
    #     at the end of the chain.
    #   . possibility to provide fields that have an orientation matrix?
    #     (or keep it the responsibility of the user?)
    #   . For higher order (> 1) interpolation: convert to spline
    #     coefficients.

    def ismatrix(x):
        """Check that a tensor is a matrix (ndim == 2)."""
        return torch.as_tensor(x).dim() == 2

    # Pre-pass: check dimensionality
    dim = None
    last_affine = False
    at_least_one_field = False
    for arg in args:
        if ismatrix(arg):
            last_affine = True
            # an affine matrix has D+1 columns (linear part + translation),
            # so its spatial dimension is one less than its column count
            dim1 = arg.shape[1] - 1
        else:
            last_affine = False
            at_least_one_field = True
            dim1 = arg.dim() - 2
        if dim is None:
            dim = dim1
        elif dim != dim1:
            raise ValueError("All deformations should have the same "
                             "dimensionality (2D/3D).")
    if at_least_one_field and last_affine:
        raise ValueError("The last deformation cannot be an affine matrix. "
                         "Use affine_field to transform it first.")

    # First pass: compose all sequential affine matrices
    args1 = []
    last_affine = None
    for arg in args:
        if ismatrix(arg):
            if last_affine is None:
                last_affine = _make_square(arg)
            else:
                last_affine = last_affine.matmul(_make_square(arg))
        else:
            if last_affine is not None:
                args1.append(last_affine)
                last_affine = None
            args1.append(arg)

    if not at_least_one_field:
        return last_affine

    # Second pass: perform all possible "field x matrix" compositions.
    # After the first pass, every matrix in `args1` is immediately followed
    # by a field (the pre-pass forbids a trailing matrix), so each matrix
    # is folded into exactly the field that follows it.
    args2 = []
    pending_affine = None
    for arg in args1:
        if ismatrix(arg):
            pending_affine = arg
            continue
        if pending_affine is not None:
            lin = pending_affine[:dim, :dim]
            trl = pending_affine[:dim, dim].reshape((1,) * (dim + 1) + (dim,))
            arg = arg.matmul(lin.transpose(0, 1)) + trl
            # the matrix is consumed: it must not be applied to later fields
            pending_affine = None
        args2.append(arg)

    # Third pass: compose all flow fields, right to left:
    #     (f o g)(x) = g(x) + (f - id)(g(x))
    field = args2[-1]
    for arg in args2[-2::-1]:
        disp = arg - identity_grid(arg.shape[1:-1], arg.dtype, arg.device)
        disp = utils.last2channel(disp)
        field = field + utils.channel2last(
            grid_pull(disp, field, interpolation, bound))

    # /!\ (TODO) The very first field (the first one being interpolated)
    # potentially contains a multiplication with an affine matrix (i.e.,
    # it might not be expressed in voxels). This affine transformation
    # should be removed prior to subtracting the identity, and added back
    # at the end. However, I don't know how to 'guess' this matrix.
    #
    # After further thought, I think we can find the matrix that minimizes
    # in the least-square sense (F*M-I), where F is NbVox*D and contains
    # the deformation field, I is NbVox*D and contains the identity field
    # (expressed in voxels) and M is the inverse of the unknown matrix.
    # This problem has a closed form solution: (F'*F)\(F'*I).
    # For better stability, we could encode M in gl(D), the Lie
    # algebra of invertible matrices, and use gauss-newton to optimise
    # the problem.
    #
    # Below is a tentative implementation of the linear version
    # > Needs F'F to be invertible and well-conditioned
    #
    # # For the last field, we factor out a possible affine transformation
    # arg = args2[0]
    # shape = arg.shape
    # N = shape[0]                                 # Batch size
    # D = shape[-1]                                # Dimension
    # V = torch.as_tensor(shape[1:-1]).prod()      # Nb of voxels
    # Id = identity(arg.shape[-2:0:-1], arg.dtype, arg.device).reshape(V, D)
    # arg = arg.reshape(N, V, D)                   # Field as a matrix
    # one = torch.ones((N, V, 1), dtype=arg.dtype, device=arg.device)
    # arg = cat((arg, one), 2)
    # Id = cat((Id, one))
    # AA = arg.transpose(1, 2).bmm(arg)            # LHS of linear system
    # AI = arg.transpose(1, 2).bmm(Id)             # RHS of linear system (F'*I)
    # M, _ = torch.solve(AI, AA)                   # Solution
    # arg = arg.bmm(M) - Id                        # Closest displacement
    # arg = arg[..., :-1].reshape(shape)
    # arg = utils.last2channel(arg)
    # field = grid_pull(arg, field, interpolation, bound)     # Interpolate
    # field = field + channel2grid(grid_pull(arg, field, interpolation, bound))
    # shape = field.shape
    # V = torch.as_tensor(shape[1:-1]).prod()
    # field = field.reshape(N, V, D)
    # one = torch.ones((N, V, 1), dtype=field.dtype, device=field.device)
    # field, _ = torch.solve(field.transpose(1, 2), M.transpose(1, 2))
    # field = field.transpose(1, 2)[..., :-1].reshape(shape)

    return field
def forward(self, source, target, source_seg=None, target_seg=None,
            *, _loss=None, _metric=None):
    """Predict an affine + velocity transform and warp the source.

    Parameters
    ----------
    source : tensor (batch, channel, *spatial)
        Source/moving image
    target : tensor (batch, channel, *spatial)
        Target/fixed image
    source_seg : tensor (batch, classes, *spatial_seg), optional
        Source/moving segmentation. If its spatial shape differs from
        that of `source`, the sampling grid is resized to it before
        warping.
    target_seg : tensor (batch, classes, *spatial_seg), optional
        Target/fixed segmentation (only used by losses/metrics).
    _loss : dict, optional
        If provided, all registered losses are computed and appended.
    _metric : dict, optional
        If provided, all registered metrics are computed and appended.

    Returns
    -------
    deformed_source : tensor (batch, channel, *spatial)
        Deformed source image
    deformed_source_seg : tensor (batch, classes, *spatial_seg)
        Deformed source segmentation (only returned if `source_seg`
        is provided).
    velocity : tensor
        Output of the U-Net with the channel dimension moved last
        (presumably (batch, *spatial, dim) -- TODO confirm).
    affine_prm : tensor (batch, nb_prm)
        Affine Lie parameters
    """
    # sanity checks
    check.dim(self.dim, source, target, source_seg, target_seg)
    check.shape(target, source, dims=[0], broadcast_ok=True)
    check.shape(target, source, dims=range(2, self.dim + 2))
    check.shape(target_seg, source_seg, dims=[0], broadcast_ok=True)
    check.shape(target_seg, source_seg, dims=range(2, self.dim + 2))

    # chain operations
    source_and_target = torch.cat((source, target), dim=1)

    # generate affine parameters, flattened to (batch, nb_prm)
    affine_prm = self.cnn(source_and_target)
    affine_prm = affine_prm.reshape(affine_prm.shape[:2])

    # generate velocity
    velocity = self.unet(source_and_target)
    velocity = channel2last(velocity)

    # generate deformation grid
    grid = self.exp(velocity, affine_prm)

    # deform
    deformed_source = self.pull(source, grid)
    if source_seg is not None:
        if source_seg.shape[2:] != source.shape[2:]:
            # segmentation lives on a different lattice: resize (and
            # rescale) the grid to match it
            grid = spatial.resize_grid(grid, shape=source_seg.shape[2:])
        deformed_source_seg = self.pull(source_seg, grid)
    else:
        deformed_source_seg = None

    # compute loss and metrics
    self.compute(_loss, _metric,
                 image=[deformed_source, target],
                 velocity=[velocity],
                 segmentation=[deformed_source_seg, target_seg],
                 affine=[affine_prm])

    if deformed_source_seg is None:
        return deformed_source, velocity, affine_prm
    else:
        return deformed_source, deformed_source_seg, velocity, affine_prm