def responsibilities(image, means, precisions, proportions):
    """Compute Gaussian-mixture responsibilities for every voxel.

    For each voxel value x and each class k, the unnormalised log-posterior is

        log pi_k + 0.5 * log|A_k| - 0.5 * C * log(2*pi)
                 - 0.5 * (x - mu_k)^T A_k (x - mu_k)

    and it is normalised across classes with a (log-)softmax.

    Parameters
    ----------
    image : tensor
        Channel-first image; ``image.dim() - 2`` is taken as the number of
        spatial dimensions (presumably (B, C, *spatial) -- confirm).
    means : tensor
        Class means, presumably (B, K, C).
    precisions : tensor
        Class precision matrices, presumably (B, K, C, C).
    proportions : tensor
        Class mixing proportions, presumably (B, K).

    Returns
    -------
    resp : tensor
        Responsibilities (softmax over the class axis, which is moved to
        dimension 1 by ``last2channel``).
    log_resp : tensor
        Log-responsibilities (log-softmax over the same axis).
    """
    nb_spatial = image.dim() - 2

    # Bring the channel axis last and add a singleton class axis so that the
    # observation broadcasts against the K mixture components.
    obs = channel2last(image).unsqueeze(-2)                         # [B, ..., 1, C]
    weights = unsqueeze(proportions, dim=1, ndim=nb_spatial)        # [B, ones, K]
    centers = unsqueeze(means, dim=1, ndim=nb_spatial)              # [B, ones, K, C]
    prec = unsqueeze(precisions, dim=1, ndim=nb_spatial)            # [B, ones, K, C, C]

    # Quadratic (Mahalanobis) term: -0.5 * (x - mu)^T A (x - mu)
    centred = obs - centers
    maha = (matvec(prec, centred) * centred).sum(dim=-1)            # [B, ..., K]
    log_post = -0.5 * maha

    # Gaussian normalising constant plus log mixing proportion.
    log_twopi = torch.as_tensor(2 * pi, dtype=prec.dtype, device=prec.device).log()
    log_norm = 0.5 * (torch.logdet(prec) - prec.shape[-1] * log_twopi) + weights.log()
    log_post = log_post + log_norm

    # Normalise across classes (class axis moved to dim 1).
    log_post = last2channel(log_post)
    log_resp = torch.nn.functional.log_softmax(log_post, dim=1)
    resp = torch.nn.functional.softmax(log_post, dim=1)
    return resp, log_resp
def _dense2prm_cnn(self, x):
    """CNN-based implementation of dense2prm.

    The input is converted with ``last2channel`` (presumably it arrives
    channel-last), an identity grid from ``self._identity`` is concatenated
    along the channel axis, and the result is fed to ``self.cnn``. The CNN
    output is flattened to its first two dimensions and post-processed by
    ``self._std_prm`` together with the input's spatial shape.
    """
    features = last2channel(x)
    spatial_shape = features.shape[2:]
    coords = self._identity(features)
    params = self.cnn(torch.cat([features, coords], dim=1))
    # Keep only (batch, n_params): assumes the CNN output has singleton
    # spatial dimensions -- NOTE(review): confirm against self.cnn.
    params = params.reshape(params.shape[:2])
    return self._std_prm(params, spatial_shape)
def jacobian(warp, bound='circular'):
    """Compute the Jacobian of a 'vox' warp.

    This function estimates the field of Jacobian matrices of a deformation
    field using central finite differences: (next-previous)/2.

    Note that for Neumann boundary conditions, symmetric padding is usually
    used (symmetry w.r.t. voxel edge). When computing Jacobian fields,
    reflection padding is more appropriate (symmetry w.r.t. voxel centre),
    so that derivatives are zero at the edges of the FOV.

    Note that voxel sizes are not considered here. The flow field should be
    expressed in voxels and so will the Jacobian.

    Args:
        warp (torch.Tensor): flow field (N, *spatial, dim) with dim in
            {1, 2, 3}, e.g. (N, W, H, D, 3) in 3D.
        bound (str, optional): Boundary conditions, one of
            'circular'/'fft', 'reflect1'/'dct1', 'reflect2'/'dct2',
            'constant'/'zero'/'zeros'. Defaults to 'circular'.

    Returns:
        jac (torch.Tensor): Field of Jacobian matrices (N, *spatial, dim, dim).
            jac[..., i, j] contains the derivative of the i-th component of
            the deformation field with respect to the j-th axis.

    Raises:
        ValueError: if `dim` is not 1, 2 or 3, or if `bound` is unknown.
    """
    warp = torch.as_tensor(warp)
    shape = warp.shape
    dim = shape[-1]
    # Spatial shape only: the trailing `dim` of the channel-last input must
    # NOT be included when unflattening the convolution output below.
    spatial_shape = tuple(shape[1:-1])

    if dim == 1:
        conv = _F.conv1d
    elif dim == 2:
        conv = _F.conv2d
    elif dim == 3:
        conv = _F.conv3d
    else:
        raise ValueError(
            'Warps must be of dimension 1, 2 or 3. Got {}.'.format(dim))

    ker = kernels.imgrad(dim, device=warp.device, dtype=warp.dtype)
    ker = kernels.make_separable(ker, dim)

    warp = utils.last2channel(warp)
    # Explicit padding for non-zero boundaries; zero boundaries are handled
    # by the convolution's own zero padding.
    if bound in ('circular', 'fft'):
        warp = utils.pad(warp, (1, ) * dim, mode='circular', side='both')
        pad = 0
    elif bound in ('reflect1', 'dct1'):
        warp = utils.pad(warp, (1, ) * dim, mode='reflect1', side='both')
        pad = 0
    elif bound in ('reflect2', 'dct2'):
        warp = utils.pad(warp, (1, ) * dim, mode='reflect2', side='both')
        pad = 0
    elif bound in ('constant', 'zero', 'zeros'):
        pad = 1
    else:
        raise ValueError('Unknown bound {}.'.format(bound))

    # One group per field component; each group is expected to yield `dim`
    # derivative channels (one per axis), i.e. dim*dim output channels in
    # (component, axis) order -- assumes make_separable builds that layout.
    jac = conv(warp, ker, padding=pad, groups=dim)
    jac = jac.reshape((shape[0], dim, dim) + spatial_shape)
    # (N, comp, axis, *spatial) -> (N, *spatial, comp, axis)
    jac = jac.permute((0, ) + tuple(range(3, 3 + dim)) + (1, 2))
    return jac
def _identity(x):
    """Build an identity grid with the same spatial shape and backend as `x`.

    The grid is shifted and scaled so that coordinate zero sits at the centre
    of the field of view: along each axis of length n, voxel i maps to
    (i - n/2) / (n/2). The result is passed through ``last2channel`` after
    adding a leading batch axis (presumably giving (1, ndim, *spatial)).
    """
    vol_shape = x.shape[2:]
    opts = dict(dtype=x.dtype, device=x.device)
    half_extent = torch.as_tensor(vol_shape, **opts) / 2.
    coords = spatial.identity_grid(vol_shape, **opts)
    coords -= half_extent
    coords /= half_extent
    return last2channel(coords[None, ...])
def _pull_vel(vel, grid, *args, **kwargs):
    """Interpolate a velocity/grid/displacement field.

    Parameters
    ----------
    vel : (batch, ..., ndim) tensor
        Channel-last field to interpolate.
    grid : (batch, ..., ndim) tensor
        Transformation field at which `vel` is sampled.
    *args, **kwargs
        Forwarded unchanged to ``grid_pull``.

    Returns
    -------
    pulled_vel : (batch, ..., ndim) tensor
        Interpolated field, back in channel-last layout.
    """
    # grid_pull works on channel-first input: convert, pull, convert back.
    channel_first = last2channel(vel)
    pulled = grid_pull(channel_first, grid, *args, **kwargs)
    return channel2last(pulled)
def resize_grid(grid, factor=None, shape=None, type='grid', affine=None,
                *args, **kwargs):
    """Resize a displacement grid by a factor.

    The grid is resized *and* its values are rescaled, so that coordinates
    (or displacements) are expressed in the new voxel referential.

    Notes
    -----
    .. At least one of `factor` and `shape` must be specified.
    .. If `anchor in ('centers', 'edges')`, and both `factor` and `shape`
       are specified, `factor` is discarded.
    .. If `anchor in ('first', 'last')`, `factor` must be provided even
       if `shape` is specified.
    .. Because of rounding, it is in general not assured that
       `resize(resize(x, f), 1/f)` returns a tensor with the same shape as x.

    Parameters
    ----------
    grid : (batch, ..., ndim) tensor
        Grid to resize.
    factor : float or list[float], optional
        Resizing factor
        * > 1 : larger image <-> smaller voxels
        * < 1 : smaller image <-> larger voxels
    shape : (ndim,) sequence[int], optional
        Output shape.
    type : {'grid', 'displacement'}, default='grid'
        Grid type:
        * 'grid' corresponds to dense grids of coordinates.
        * 'displacement' corresponds to dense grids of relative
          displacements.
        Only the first letter is inspected: 'g' (case-insensitive) selects
        the grid rescaling, anything else the displacement rescaling.
        Both types are not rescaled in the same way.
    affine : (batch, ndim[+1], ndim+1), optional
        Orientation matrix of the input grid. If provided, the orientation
        matrix of the resized grid is returned as well.
    anchor : {'centers', 'edges', 'first', 'last'}, default='centers'
        Passed through ``**kwargs`` to ``resize``.
        * In cases 'c' and 'e', the volume shape is multiplied by the
          zoom factor (and eventually truncated), and two anchor points
          are used to determine the voxel size.
        * In cases 'f' and 'l', a single anchor point is used so that
          the voxel size is exactly divided by the zoom factor.
          This case with an integer factor corresponds to subslicing
          the volume (e.g., `vol[::f, ::f, ::f]`).
        * A list of anchors (one per dimension) can also be provided.
    *args, **kwargs
        Forwarded to ``resize`` (interpolation options for ``grid_pull``,
        `anchor`, ...). ``_return_trf`` is always forced to True.

    Returns
    -------
    resized : (batch, ..., ndim) tensor
        Resized grid.
    affine : (batch, ndim[+1], ndim+1) tensor, optional
        Orientation matrix (only returned if `affine` was provided).
    """
    # Ask `resize` to also return the per-axis (scale, shift) it applied.
    kwargs['_return_trf'] = True
    outputs = resize(utils.last2channel(grid), factor, shape, affine,
                     *args, **kwargs)
    if affine is None:
        resized, (scales, shifts) = outputs
    else:
        resized, affine, (scales, shifts) = outputs
    resized = utils.channel2last(resized)

    # `resize` maps resized coordinates to original ones:
    #     original = scale * resized + shift
    # so values are brought into the resized referential with
    #     resized = (original - shift) / scale
    # Displacements are relative, hence only scaled, never shifted.
    components = []
    for axis, (scl, shft) in enumerate(zip(scales, shifts)):
        comp = utils.slice_tensor(resized, axis, dim=-1)
        if type[0].lower() == 'g':
            comp = comp - shft
        components.append(comp / scl)
    resized = torch.stack(components, -1)

    return resized if affine is None else (resized, affine)
def compose(*args, interpolation='linear', bound='dft'):
    """Compose multiple spatial deformations (affine matrices or flow fields).

    Deformations are composed right-to-left: ``compose(a, b, c)`` is
    ``a o b o c`` (``c`` is applied first).

    Parameters
    ----------
    *args : tensor
        Each argument is either
        * a 2D tensor: an affine matrix with ``dim + 1`` columns (the last
          column holds the translation), or
        * a field of shape ``(batch, *spatial, dim)``. Fields hold absolute
          (voxel) coordinates, not displacements, since the identity grid is
          subtracted before interpolation.
        The last argument cannot be an affine matrix if any field is present.
    interpolation : str, default='linear'
        Interpolation order, forwarded to ``grid_pull``.
    bound : str, default='dft'
        Boundary condition, forwarded to ``grid_pull``.

    Returns
    -------
    tensor
        The composed field, or the composed (square) affine matrix if only
        matrices were provided (``None`` if no argument was given).

    Raises
    ------
    ValueError
        If the deformations do not all share the same dimensionality, or if
        a field is present and the last deformation is an affine matrix.
    """
    # TODO:
    # . add shape/dim argument to generate (if needed) an identity field
    #   at the end of the chain.
    # . possibility to provide fields that have an orientation matrix?
    #   (or keep it the responsibility of the user?)
    # . For higher order (> 1) interpolation: convert to spline coefficients.

    def ismatrix(x):
        """Check that a tensor is a matrix (ndim == 2)."""
        return torch.as_tensor(x).dim() == 2

    # Pre-pass: check dimensionality
    dim = None
    last_affine = False
    at_least_one_field = False
    for arg in args:
        if ismatrix(arg):
            last_affine = True
            # An affine matrix has `dim + 1` columns (linear part +
            # translation), so its spatial dimensionality is shape[1] - 1.
            dim1 = arg.shape[1] - 1
        else:
            last_affine = False
            at_least_one_field = True
            dim1 = arg.dim() - 2
        if dim is not None and dim != dim1:
            raise ValueError("All deformations should have the same "
                             "dimensionality (2D/3D).")
        elif dim is None:
            dim = dim1
    if at_least_one_field and last_affine:
        raise ValueError("The last deformation cannot be an affine matrix. "
                         "Use affine_field to transform it first.")

    # First pass: compose all sequential affine matrices
    args1 = []
    last_affine = None
    for arg in args:
        if ismatrix(arg):
            if last_affine is None:
                last_affine = _make_square(arg)
            else:
                last_affine = last_affine.matmul(_make_square(arg))
        else:
            if last_affine is not None:
                args1.append(last_affine)
                last_affine = None
            args1.append(arg)

    if not at_least_one_field:
        return last_affine

    # Second pass: perform all possible "field x matrix" compositions.
    # After the first pass, every matrix in args1 is immediately followed by
    # a field, so each matrix is folded into exactly that one field.
    args2 = []
    last_affine = None
    for arg in args1:
        if ismatrix(arg):
            last_affine = arg
            continue
        if last_affine is not None:
            # Apply A to every coordinate vector: y = A[:dim, :dim] x + t
            linear = last_affine[:dim, :dim]
            translation = last_affine[:dim, dim].reshape(
                (1,) * (dim + 1) + (dim,))
            arg = arg.matmul(linear.transpose(0, 1)) + translation
            # The matrix must only affect the field that follows it.
            last_affine = None
        args2.append(arg)
    # No trailing matrix can remain: the pre-pass guarantees that the last
    # deformation is a field.

    # Third pass: compose all flow fields, innermost (last) first.
    field = args2[-1]
    for arg in reversed(args2[:-1]):
        disp = arg - identity_grid(arg.shape[1:-1], arg.dtype, arg.device)
        disp = utils.last2channel(disp)
        field = field + utils.channel2last(
            grid_pull(disp, field, interpolation, bound))

    # /!\ (TODO) The very first field (the first one being interpolated)
    # potentially contains a multiplication with an affine matrix (i.e.,
    # it might not be expressed in voxels). This affine transformation should
    # be removed prior to subtracting the identity, and added back at the end.
    # However, I don't know how to 'guess' this matrix.
    #
    # After further thought, I think we can find the matrix that minimizes in
    # the least-square sense (F*M-I), where F is NbVox*D and contains the
    # deformation field, I is NbVox*D and contains the identity field
    # (expressed in voxels) and M is the inverse of the unknown matrix.
    # This problem has a closed form solution: (F'*F)\(F'*I).
    # For better stability, we could encode M in gl(D), the Lie
    # algebra of invertible matrices, and use Gauss-Newton to optimise
    # the problem.
    #
    # Below is a tentative (untested) implementation of the linear version.
    # > Needs F'F to be invertible and well-conditioned
    # NOTE(review): the RHS line below presumably should use `Id`
    # (F'*I), not `arg` a second time.
    #
    # For the last field, we factor out a possible affine transformation
    # arg = args2[0]
    # shape = arg.shape
    # N = shape[0]                                 # Batch size
    # D = shape[-1]                                # Dimension
    # V = torch.as_tensor(shape[1:-1]).prod()      # Nb of voxels
    # Id = identity(arg.shape[-2:0:-1], arg.dtype, arg.device).reshape(V, D)
    # arg = arg.reshape(N, V, D)                   # Field as a matrix
    # one = torch.ones((N, V, 1), dtype=arg.dtype, device=arg.device)
    # arg = cat((arg, one), 2)
    # Id = cat((Id, one))
    # AA = arg.transpose(1, 2).bmm(arg)            # LHS of linear system
    # AI = arg.transpose(1, 2).bmm(arg)            # RHS of linear system
    # M, _ = torch.solve(AI, AA)                   # Solution
    # arg = arg.bmm(M) - Id                        # Closest displacement
    # arg = arg[..., :-1].reshape(shape)
    # arg = utils.last2channel(arg)
    # field = grid_pull(arg, field, interpolation, bound)     # Interpolate
    # field = field + channel2grid(grid_pull(arg, field, interpolation, bound))
    # shape = field.shape
    # V = torch.as_tensor(shape[1:-1]).prod()
    # field = field.reshape(N, V, D)
    # one = torch.ones((N, V, 1), dtype=field.dtype, device=field.device)
    # field, _ = torch.solve(field.transpose(1, 2), M.transpose(1, 2))
    # field = field.transpose(1, 2)[..., :-1].reshape(shape)

    return field