def transform_pointset_dense(points, grid, type='grid', bound='dct2'):
    """Apply a dense transformation (or displacement) field to a point set.

    The points must already be expressed in the voxel coordinates of
    `grid` ("grid voxels").

    Parameters
    ----------
    points : (n, dim) tensor
        Coordinates of the points, in voxel space.
    grid : (*spatial, dim) tensor
        Dense transformation or displacement field, in voxel space.
    type : {'grid', 'disp'}, default='grid'
        Whether `grid` stores a transformation ('grid') or a
        displacement ('disp').
    bound : str, default='dct2'
        Boundary condition used when sampling out-of-bounds locations.

    Returns
    -------
    points : (n, dim) tensor
        Transformed coordinates.
    """
    ndim = grid.shape[-1]
    # Insert singleton leading dimensions so that the point set can be
    # handed to grid_pull as a sampling grid.
    coords = utils.unsqueeze(points, 0, ndim)
    # Channel-first field with a leading batch dimension.
    field = utils.movedim(grid, -1, 0)[None]
    sampled = spatial.grid_pull(field, coords, bound=bound, extrapolate=True)
    sampled = utils.movedim(sampled, 1, -1)
    # A displacement is added to the input points; a transformation
    # directly gives the new coordinates.
    coords = coords + sampled if type == 'disp' else sampled
    return utils.squeeze(coords, -2, ndim - 1).squeeze(0)
def _jhj(jac, hess):
    """Compute J*H*J', where H is symmetric and stored sparse.

    Parameters
    ----------
    jac : (..., dim, dim) tensor
        Matrix J (its last two dimensions hold the matrix).
    hess : (..., K) tensor
        Symmetric matrix H stored sparse, with its unique elements along
        the last dimension (see the symbolic expressions below for the
        elements used).

    Returns
    -------
    out : (..., K) tensor
        Symmetric product J*H*J', stored sparse along the last dimension.

    Raises
    ------
    NotImplementedError
        If `dim` is not 1, 2 or 3.
    """
    # Matlab symbolic toolbox
    #
    # 2D:
    # out[00] = h00*j00^2 + h11*j01^2 + 2*h01*j00*j01
    # out[11] = h00*j10^2 + h11*j11^2 + 2*h01*j10*j11
    # out[01] = h00*j00*j10 + h11*j01*j11 + h01*(j01*j10 + j00*j11)
    #
    # 3D:
    # out[00] = h00*j00^2 + 2*h01*j00*j01 + 2*h02*j00*j02 + h11*j01^2 + 2*h12*j01*j02 + h22*j02^2
    # out[11] = h00*j10^2 + 2*h01*j10*j11 + 2*h02*j10*j12 + h11*j11^2 + 2*h12*j11*j12 + h22*j12^2
    # out[22] = h00*j20^2 + 2*h01*j20*j21 + 2*h02*j20*j22 + h11*j21^2 + 2*h12*j21*j22 + h22*j22^2
    # out[01] = j10*(h00*j00 + h01*j01 + h02*j02) + j11*(h01*j00 + h11*j01 + h12*j02) + j12*(h02*j00 + h12*j01 + h22*j02)
    # out[02] = j20*(h00*j00 + h01*j01 + h02*j02) + j21*(h01*j00 + h11*j01 + h12*j02) + j22*(h02*j00 + h12*j01 + h22*j02)
    # out[12] = j20*(h00*j10 + h01*j11 + h02*j12) + j21*(h01*j10 + h11*j11 + h12*j12) + j22*(h02*j10 + h12*j11 + h22*j12)
    dim = jac.shape[-1]
    # Fail explicitly instead of hitting an UnboundLocalError on `out`.
    if dim not in (1, 2, 3):
        raise NotImplementedError(
            f"J*H*J' is only implemented for dim in (1, 2, 3), got {dim}")
    # Put the sparse elements / matrix indices first so the kernels can
    # index them directly.
    hess = utils.movedim(hess, -1, 0)
    jac = utils.movedim(jac, [-2, -1], [0, 1])
    if dim == 1:
        out = _jhj1(jac, hess)
    elif dim == 2:
        out = _jhj2(jac, hess)
    else:
        out = _jhj3(jac, hess)
    out = utils.movedim(out, 0, -1)
    return out
def bending_grid(grid, voxel_size=1, bound='dft', weights=None):
    """Apply the precision matrix of the bending energy to a deformation grid.

    Parameters
    ----------
    grid : (..., *spatial, dim) tensor
        Deformation field, with vector components in the last dimension.
    voxel_size : float or sequence[float], default=1
        Voxel size, broadcast to `dim` values.
    bound : str, default='dft'
        Boundary condition passed to `bending`.
    weights : (..., *spatial) tensor, optional
        Voxel-wise weights passed to `bending`.

    Returns
    -------
    field : (..., *spatial, dim) tensor
    """
    grid = torch.as_tensor(grid)
    ndim = grid.shape[-1]
    voxel_size = core.utils.make_vector(voxel_size, ndim,
                                        dtype=grid.dtype, device=grid.device)
    # NOTE(review): the grid is scaled by voxel_size here AND voxel_size is
    # also forwarded to `bending` -- confirm this is not double-counting.
    if (voxel_size != 1).any():
        grid = grid * voxel_size
    # Move the vector-component dimension in front of the spatial
    # dimensions before calling `bending`, then move it back.
    component_axis = -(ndim + 1)
    grid = movedim(grid, -1, component_axis)
    grid = bending(grid, weights=weights, voxel_size=voxel_size,
                   bound=bound, dim=ndim)
    return movedim(grid, component_axis, -1)
def modulate_prior(M, G):
    """Modulate a prior by weights `G` and renormalize it.

    Parameters
    ----------
    M : tensor
        Prior, whose first dimension is the one it is normalized over.
    G : tensor or None
        Weights, broadcast against `M` once its first dimension has been
        moved last. If None, `M` is returned unchanged.

    Returns
    -------
    M : tensor
        `M * G`, normalized to sum to one along the first dimension.
    """
    if G is None:
        return M
    modulated = utils.movedim(M, 0, -1) * G
    modulated /= modulated.sum(-1, keepdim=True)
    return utils.movedim(modulated, -1, 0)
def _inv(A):
    """Invert a field of small square matrices.

    Parameters
    ----------
    A : (..., n, n) tensor
        Matrices to invert, with n in {1, 2, 3}.

    Returns
    -------
    iA : (..., n, n) tensor
        Matrix inverses.

    Raises
    ------
    NotImplementedError
        If n is not 1, 2 or 3.
    """
    # Put the matrix indices first so the closed-form kernels can index them.
    A = utils.movedim(A, [-2, -1], [0, 1])
    size = len(A)
    if size == 3:
        A = _inv3(A)
    elif size == 2:
        A = _inv2(A)
    elif size == 1:
        # The inverse of a 1x1 matrix is the reciprocal of its entry.
        A = A.reciprocal()
    else:
        raise NotImplementedError(
            f'Matrix inversion only implemented for 1x1, 2x2 and 3x3 '
            f'matrices, got {size}x{size}')
    A = utils.movedim(A, [0, 1], [-2, -1])
    return A
def _deform_1d(img, disp, grad=False):
    """Warp images with a displacement along the last dimension.

    Parameters
    ----------
    img : ([C], *spatial) tensor
        Images; the leading dimension is moved next to the last one
        before sampling.
    disp : (*spatial) tensor
        Displacement along the last dimension (a single vector component
        is added with `unsqueeze(-1)`).
    grad : bool, default=False
        Also return the spatial gradient of the warped images.

    Returns
    -------
    warped : tensor
        Images sampled at identity + displacement.
    gradient : tensor or None
        Gradient of the images at the same locations, or None if
        `grad` is False.
    """
    channels_near_last = utils.movedim(img, 0, -2)
    sampling_grid = spatial.add_identity_grid(disp.unsqueeze(-1))
    warped = spatial.grid_pull(channels_near_last, sampling_grid,
                               bound=BND, extrapolate=True)
    warped = utils.movedim(warped, -2, 0)
    gradient = None
    if grad:
        gradient = spatial.grid_grad(channels_near_last, sampling_grid,
                                     bound=BND, extrapolate=True)
        gradient = utils.movedim(gradient.squeeze(-1), -2, 0)
    return warped, gradient
def topup_apply(pos, neg, vel, dim=-1, model='smalldef', modulation=True):
    """Apply a topup correction

    Parameters
    ----------
    pos : ([C], *spatial) tensor
        Images with positive readout polarity
    neg : ([C], *spatial) tensor
        Images with negative readout polarity
    vel : (*spatial) tensor
        1D displacement or velocity field
    dim : int, default=-1
        Readout dimension
    model : {'smalldef', 'svf'}, default='smalldef'
        Deformation model
    modulation : bool, default=True
        Multiply each unwarped image by the Jacobian returned by
        `_exp_1d` for its deformation.

    Returns
    -------
    pos : ([C], *spatial) tensor
        Images with positive polarity, unwarped
    neg : ([C], *spatial) tensor
        Images with negative polarity, unwarped
    """
    ndim = vel.dim()
    # Use a negative index so that it refers to the same spatial axis
    # whether or not a channel dimension is present.
    if dim >= 0:
        dim = dim - ndim

    pos_has_channels = pos.dim() != ndim
    if not pos_has_channels:
        pos = pos[None]
    pos = utils.movedim(pos, dim, -1)

    neg_has_channels = neg.dim() != ndim
    if not neg_has_channels:
        neg = neg[None]
    neg = utils.movedim(neg, dim, -1)

    vel = utils.movedim(vel, dim, -1)
    phi, iphi, jac, ijac = _exp_1d(vel, model=model)
    # Positive polarity is unwarped with the forward map, negative with
    # its inverse.
    pos = _deform_1d(pos, phi)[0]
    neg = _deform_1d(neg, iphi)[0]
    if modulation:
        pos *= jac
        neg *= ijac
    del phi, iphi, jac, ijac

    pos = utils.movedim(pos, -1, dim)
    neg = utils.movedim(neg, -1, dim)
    if not pos_has_channels:
        pos = pos[0]
    if not neg_has_channels:
        neg = neg[0]
    return pos, neg
def forward(self, source, target, source_affine=None, target_affine=None):
    """Register `source` to `target` and warp the original source.

    Parameters
    ----------
    source : (sX, sY, sZ) tensor or str
    target : (tX, tY, tZ) tensor or str
    source_affine : (4, 4) tensor, optional
    target_affine : (4, 4) tensor, optional

    Returns
    -------
    warped : (tX, tY, tZ) tensor
        Source warped to target
    velocity : (vX, vY, vZ, 3) tensor
        Stationary velocity field
    affine : (4, 4) tensor, optional
        Affine of the velocity space
    """
    if self.verbose:
        print('Preprocessing... ', end='', flush=True)
    # `self.load` returns the data/affine used for registration and the
    # original data/affine (presumably a preprocessed vs. raw pair --
    # confirm against `load`).
    source, source_affine, source_orig, source_affine_orig \
        = self.load(source, source_affine)
    target, target_affine, target_orig, target_affine_orig \
        = self.load(target, target_affine)
    # Resample the source onto the target lattice before registration.
    source = spatial.reslice(source, source_affine, target_affine,
                             target.shape)
    if self.verbose:
        print('done.', flush=True)
        print('Registering... ', end='', flush=True)
    # Add two leading singleton dimensions (presumably batch and channel)
    # for the parent class.
    source = source[None, None]
    target = target[None, None]
    warped, vel, grid = super().forward(source, target)
    if self.verbose:
        print('done.', flush=True)
    del source, target, warped
    vel = vel[0]
    grid = grid[0]
    # Convert the returned transformation into a displacement field.
    grid -= spatial.identity_grid(grid.shape[:-1], dtype=grid.dtype,
                                  device=grid.device)
    # Affine from the original target lattice to the registration
    # lattice (assuming the affines are voxel-to-world -- confirm),
    # expanded to a dense grid over the original target shape.
    right_affine = target_affine.inverse() @ target_affine_orig
    right_affine = spatial.affine_grid(right_affine, target_orig.shape)
    # Sample the displacement at those locations and add them back to get
    # a transformation defined on the original target lattice.
    grid = spatial.grid_pull(utils.movedim(grid, -1, 0), right_affine,
                             bound='nearest', extrapolate=True)
    grid = utils.movedim(grid, 0, -1).add_(right_affine)
    # Map the transformed coordinates into the original source's voxels.
    left_affine = source_affine_orig.inverse() @ target_affine
    grid = spatial.affine_matvec(left_affine, grid)
    warped = spatial.grid_pull(source_orig, grid)
    return warped, vel, target_affine
def grid_jacobian(grid, sample=None, bound='dft', voxel_size=1, type='grid',
                  add_identity=True, extrapolate=True):
    """Compute the Jacobian of a transformation field

    Notes
    -----
    .. If `add_identity` is True, we compute the Jacobian of the
       transformation field (identity + displacement), even if a
       displacement is provided, by adding ones to the diagonal.
    .. If `sample` is not used, this function uses central finite
       differences to estimate the Jacobian.
    .. If `sample` is provided, `grid_grad` is used to sample derivatives.

    Parameters
    ----------
    grid : (..., *spatial, dim) tensor
        Transformation or displacement field
    sample : (..., *spatial, dim) tensor, optional
        Coordinates to sample in the input grid.
    bound : str, default='dft'
        Boundary condition
    voxel_size : [sequence of] float, default=1
        Voxel size (only used by the finite-difference path)
    type : {'grid', 'disp'}, default='grid'
        Whether the input is a transformation ('grid') or displacement
        ('disp') field.
    add_identity : bool, default=True
        Adds the identity to the Jacobian of the displacement, making it
        the Jacobian of the transformation.
    extrapolate : bool, default=True
        Extrapolate out-of-bound data (only useful if `sample` is used)

    Returns
    -------
    jac : (..., *spatial, dim, dim) tensor
        Jacobian. In each matrix: jac[i, j] = d psi[i] / d xj
    """
    grid = torch.as_tensor(grid)
    ndim = grid.shape[-1]
    if type == 'grid':
        # Differentiate the displacement only; the identity's Jacobian
        # is added back on the diagonal below when requested.
        spatial_shape = grid.shape[-ndim-1:-1]
        grid = grid - identity_grid(spatial_shape, **utils.backend(grid))

    if sample is not None:
        grid = utils.movedim(grid, -1, -ndim-1)
        jac = grid_grad(grid, sample, bound=bound, extrapolate=extrapolate)
        jac = utils.movedim(jac, -ndim-2, -2)
    else:
        spatial_dims = list(range(-ndim-1, -1))
        jac = diff(grid, dim=spatial_dims, bound=bound,
                   voxel_size=voxel_size, side='c')

    if add_identity:
        torch.diagonal(jac, 0, -1, -2).add_(1)
    return jac
def smart_pull_grid(vel, grid, type='disp', *args, **kwargs):
    """Interpolate a velocity/grid/displacement field.

    Notes
    -----
    Defaults differ from grid_pull:
    - bound -> dft
    - extrapolate -> True

    Parameters
    ----------
    vel : ([batch], *spatial, ndim) tensor
        Velocity
    grid : ([batch], *spatial, ndim) tensor
        Transformation field
    type : {'disp', 'grid'}, default='disp'
        If 'grid', `vel` is a transformation: the identity is removed
        before interpolation and added back afterwards.
    kwargs : dict
        Options to ``grid_pull``

    Returns
    -------
    pulled_vel : ([batch], *spatial, ndim) tensor
        Velocity. It has a batch dimension if either input had one.
    """
    if grid is None or vel is None:
        return vel
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    dim = vel.shape[-1]
    if type == 'grid':
        identity = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                         **utils.backend(vel))
        vel = vel - identity
    # grid_pull expects channel-first, batched inputs.
    vel = utils.movedim(vel, -1, -dim - 1)
    vel_no_batch = vel.dim() == dim + 1
    grid_no_batch = grid.dim() == dim + 1
    if vel_no_batch:
        vel = vel[None]
    if grid_no_batch:
        grid = grid[None]
    vel = spatial.grid_pull(vel, grid, *args, **kwargs)
    vel = utils.movedim(vel, -dim - 1, -1)
    # Only drop the batch dimension if neither input had one: a batched
    # `grid` yields one output per batch element, and keeping only the
    # first would silently discard the others.
    if vel_no_batch and grid_no_batch:
        vel = vel[0]
    if type == 'grid':
        identity = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                         **utils.backend(vel))
        vel += identity
    return vel
def load(files, is_label=False):
    """Load one multi-channel multi-file volume.

    Parameters
    ----------
    files : sequence
        File descriptors. Each one must expose the attributes read below:
        `fname`, `shape`, `channels` and `subchannels`.
    is_label : bool, default=False
        If True, load int32 label maps. Label maps are not
        intensity-normalized, and only the first channel is kept.

    Returns
    -------
    dat : (channels, *spatial) tensor
        Channels of all files concatenated. Intensity images are rescaled
        per channel so that their 1st and 95th percentiles map to 0 and 1.
    """
    dats = []
    for file in files:
        if is_label:
            dat = io.volumes.load(file.fname, dtype=torch.int32, device=device)
        else:
            dat = io.volumes.loadf(file.fname, rand=True,
                                   dtype=torch.float32, device=device)
        dat = dat.reshape([*file.shape, file.channels])
        dat = dat[..., file.subchannels]
        dat = utils.movedim(dat, -1, 0)
        if not is_label:
            # Robust per-channel rescaling (quantiles over spatial dims).
            # Label maps are skipped: rescaling label values is meaningless
            # and the in-place division would fail on an int32 tensor.
            dim = dat.dim() - 1
            qt = utils.quantile(dat, (0.01, 0.95), dim=range(-dim, 0),
                                keepdim=True)
            mn, mx = qt.unbind(-1)
            # Avoid 0/0 -> NaN on channels whose two quantiles coincide.
            rng = mx - mn
            rng = rng.masked_fill(rng == 0, 1)
            dat = dat.sub_(mn).div_(rng)
        dats.append(dat)
        del dat
    dats = torch.cat(dats, dim=0)
    if is_label and len(dats) > 1:
        warn('Multi-channel label images are not accepted. '
             'Using only the first channel')
        dats = dats[:1]
    return dats
def _read_conv_param(f, key, param, like):
    """Read `param` (e.g. 'kernel:0') of layer `key` as a tensor like `like`.

    The value is looked up at f[key][key][param], falling back to
    f[key][key + '_1'][param] when the first path does not exist.
    """
    try:
        value = f[key][key][param]
    except KeyError:
        value = f[key][key + '_1'][param]
    return torch.as_tensor(value, **utils.backend(like))


def _set_weights(module, conv_keys, f, prefix='unet'):
    """Recursively copy convolution weights from `f` into `module`.

    Parameters
    ----------
    module : module
        Module whose `Conv` submodules receive the weights, visited in
        `named_children` order.
    conv_keys : list[str]
        Layer names in `f`, consumed (popped from the front) one per
        `Conv` encountered. Once exhausted, 'vxm_dense_flow' is used.
    f : mapping
        Nested weight store (presumably an HDF5 checkpoint -- confirm),
        indexed as f[key][key]['kernel:0'] / ['bias:0'].
    prefix : str, default='unet'
        Dotted path of `module`, extended during recursion.
    """
    if isinstance(module, Conv):
        # After the named keys, we might have reached the final
        # "feat 2 class" conv.
        key = conv_keys.pop(0) if conv_keys else 'vxm_dense_flow'
        kernel = _read_conv_param(f, key, 'kernel:0', module.weight)
        # Move the last two (channel) dimensions to the front, presumably
        # converting (*spatial, in, out) to (out, in, *spatial).
        kernel = utils.movedim(kernel, [-1, -2], [0, 1])
        module.weight.copy_(kernel)
        bias = _read_conv_param(f, key, 'bias:0', module.bias)
        module.bias.copy_(bias)
    else:
        for name, child in module.named_children():
            _set_weights(child, conv_keys, f, f'{prefix}.{name}')
def movedim(x, source, target):
    """Move dimension(s) `source` of an array-like to position(s) `target`.

    numpy arrays are handled by `np.moveaxis` and torch tensors by
    `utils.movedim`; any other object is expected to provide its own
    `movedim` method (e.g. a MappedArray).
    """
    if torch.is_tensor(x):
        return utils.movedim(x, source, target)
    if isinstance(x, np.ndarray):
        return np.moveaxis(x, source, target)
    # MappedArray?
    return x.movedim(source, target)
def load(self, fname, dtype=None, device=None):
    """Load a volume from disk

    Parameters
    ----------
    fname : str
    dtype : torch.dtype, optional
        Defaults to `self.dtype`. Floating-point (or unset) dtypes are
        loaded with `io.loadf` and passed through `self.rescale`.
    device : torch.device, optional
        Defaults to `self.device`.

    Returns
    -------
    dat : (channels, *spatial) tensor
    """
    dtype = dtype or self.dtype
    device = device or self.device
    as_float = not dtype or dtype.is_floating_point
    if as_float:
        dat = self.rescale(io.loadf(fname, dtype=dtype, device=device))
    else:
        dat = io.load(fname, dtype=dtype, device=device)
    dat = dat.squeeze()
    ndim = self.dim or dat.dim()
    # Add back singleton dimensions if squeezing left fewer than `ndim`.
    dat = utils.unsqueeze(dat, -1, max(0, ndim - dat.dim()))
    # Everything beyond the first `ndim` dimensions is flattened into a
    # single channel axis, which is then moved to the front.
    dat = dat.reshape([*dat.shape[:ndim], -1])
    dat = utils.movedim(dat, -1, 0)
    return self.to_shape(dat)
def forward(self, x, v=None):
    """Vesselness-based loss marginalized over radii.

    For each radius in `self.radii`, the (unsorted) Hessian eigenvalues of
    `x` at that scale are scored with `self.vesselness3d`. Scores are
    combined across radii with a log-sum-exp, then averaged with
    weights `v`.

    Parameters
    ----------
    x : tensor
        Input image; the number of spatial dimensions is `x.dim() - 2`
        (presumably (batch, channel, *spatial) -- confirm).
    v : tensor, optional
        Weights of the final average (presumably vessel probabilities).
        Defaults to `x`.

    Returns
    -------
    loss : scalar tensor
        Negative weighted mean of the log-sum-exp score.

    Raises
    ------
    ValueError
        If the input is not 2D or 3D.
    """
    dim = x.dim() - 2
    if dim not in (2, 3):
        raise ValueError(f'{type(self).__name__} only implemented '
                         f'in 2D or 3D.')
    radii = self.radii.to(**utils.backend(x))
    pradii = self.pradii.to(**utils.backend(x)).log()
    # compute log-likelihood (vessel | radius, x)
    loss = x.new_zeros([len(radii), *x.shape])
    # NOTE(review): the log radius prior `p` is never added to loss[i];
    # confirm a uniform prior over radii is intended.
    for i, (p, r) in enumerate(zip(pradii, radii)):
        # compute unsorted eigenvalues
        e = spatial.hessian_eig(x, r, dim=dim, sort=None)
        e = utils.movedim(e, -1, 0)
        # NOTE(review): only 3D is scored. In 2D, loss[i] stays zero, so
        # the returned loss does not depend on the image structure.
        if dim == 3:
            loss[i] = self.vesselness3d(e[0], e[1], e[2])
    # compute (stable) log-sum-exp (== model evidence)
    loss = math.logsumexp(loss, dim=0)
    # weight by probability to be a vessel and return
    if v is None:
        v = x
    # The 1e-3 term keeps the denominator away from zero.
    return -(loss * v).sum() / (v.sum() + 1e-3)
def forward(self, image, **overload):
    """Add Gaussian noise, optionally modulated by a g-factor map.

    Parameters
    ----------
    image : (batch, channel, *spatial) tensor
    **overload
        `sigma` and/or `gfactor`, overriding `self.sigma` and
        `self.gfactor` for this call.

    Returns
    -------
    image : (batch, channel, *spatial) tensor
        Noisy image (a new tensor; the input is not modified).
    """
    backend = utils.backend(image)
    sigma = overload.get('sigma', self.sigma)
    gfactor = overload.get('gfactor', self.gfactor)

    # sample sigma: one value per (batch, channel), broadcastable
    if sigma is None:
        sigma = self.default_sigma(*image.shape[:2], **backend)
    if callable(sigma):
        sigma = sigma(image.shape[:2])
    sigma = torch.as_tensor(sigma, **backend)
    sigma = unsqueeze(sigma, -1, 2 - sigma.dim())

    # sample gfactor
    if gfactor is True:
        gfactor = field.RandomMultiplicativeField()
    if callable(gfactor):
        gfactor = gfactor(image.shape)

    # sample noise: Distribution.sample returns sample_shape + batch_shape,
    # i.e. (*spatial, batch, channel).
    zero = torch.tensor(0, **backend)
    noise = td.Normal(zero, sigma).sample(image.shape[2:])
    # Move (batch, channel) to the front *in that order*. Moving [-1, -2]
    # to [0, 1] would yield (channel, batch, *spatial) and mismatch `image`.
    noise = utils.movedim(noise, [-2, -1], [0, 1])
    if torch.is_tensor(gfactor):
        noise *= gfactor
    image = image + noise
    return image
def depth_to_rgb(image, colormap=None):
    """Convert soft probabilities to an RGB image.

    Parameters
    ----------
    image : (*batch, D, H, W)
        A (batch of) 3D image, with depth along the 'D' dimension.
    colormap : (D, 3) tensor or str, optional
        A colormap or the name of a matplotlib colormap.

    Returns
    -------
    image : (*batch, H, W, 3)
        A (batch of) RGB image. Pixels whose probabilities sum to zero
        (e.g. all-zero columns) are 0 instead of NaN.
    """
    depth = image.shape[-3]
    colormap = _get_colormap_depth(colormap, depth, image.dtype, image.device)
    # Depth last: (*batch, H, W, D)
    image = utils.movedim(image, -3, -1)
    # Probability-weighted sum of colors: (*batch, H, W, 3)
    cimage = linalg.dot(image.unsqueeze(-2), colormap.T)
    # Normalize by the total mass; a zero total would produce 0/0 = NaN,
    # which clamp_ cannot remove, so divide by 1 there instead.
    total = image.sum(-1, keepdim=True)
    total = total.masked_fill(total == 0, 1)
    cimage /= total
    # Brightness follows the largest probability along depth.
    cimage *= image.max(-1, keepdim=True).values
    return cimage.clamp_(0, 1)
def forward(self, x, v=None):
    """Vessel loss built from soft-sorted Hessian eigenvalues.

    For each radius in `self.radii`:
    - the Hessian eigenvalues of `x` at that scale are soft-sorted by
      decreasing magnitude;
    - they are turned into a score (white-ridge term, tube term in 3D,
      not-plate term) plus the log radius prior.

    Scores are combined across radii with a log-sum-exp, then averaged
    with weights `v`.

    Parameters
    ----------
    x : tensor
        Input image; the number of spatial dimensions is `x.dim() - 2`
        (presumably (batch, channel, *spatial) -- confirm).
    v : tensor, optional
        Weights of the final average (presumably vessel probabilities).
        Defaults to `x`.

    Returns
    -------
    loss : scalar tensor
        Negative weighted mean of the log-sum-exp score.

    Raises
    ------
    ValueError
        If the input is not 2D or 3D.
    """
    dim = x.dim() - 2
    if dim not in (2, 3):
        raise ValueError(f'{type(self).__name__} only implemented '
                         f'in 2D or 3D.')
    radii = self.radii.to(**utils.backend(x))
    pradii = self.pradii.to(**utils.backend(x)).log()
    # compute joint log-likelihood `ln p(x, radius | v)`
    loss = x.new_zeros([len(radii), *x.shape])
    for i, (p, r) in enumerate(zip(pradii, radii)):
        # compute unsorted eigenvalues
        e = spatial.hessian_eig(x, r, dim=dim, sort=None)
        # soft sort: permutation-like matrix ordering |e| decreasingly
        P = math.softsort(e.abs(), tau=self.tau_sort, descending=True)
        e = linalg.matvec(P, e)
        e = utils.movedim(e, -1, 0)
        # compute penalties
        loss[i] = -self.tau_large * e[1:].sum(0)  # white ridges
        # From here on `e` holds log squared eigenvalues (clamped to avoid
        # log(0)), so the differences below are log-ratios.
        e = e.square().clamp_min_(1e-32).log()
        if dim == 3:
            loss[i] += self.tau_ratio1 * (e[1] - e[2])  # tubes
        loss[i] += self.tau_ratio0 * (e[1] - e[0])  # not plates
        loss[i] += p  # radius prior
    # compute (stable) log-sum-exp (== model evidence `ln p(x | v)`)
    loss = math.logsumexp(loss, dim=0)
    # weight by probability to be a vessel and return `E_v[ln p(x | v)]`
    if v is None:
        v = x
    # The 1e-3 term keeps the denominator away from zero.
    return -(loss * v).sum() / (v.sum() + 1e-3)
def forward(self, q, k, v, **overload):
    """Local dot-product attention over neighborhoods of keys/values.

    Parameters
    ----------
    q : (b, c, *spatial)
        Queries
    k : (b, c, *spatial)
        Keys
    v : (b, c, *spatial)
        Values
    **overload
        Optional overrides: `kernel_size`, `stride`, `padding`,
        `padding_mode`.

    Returns
    -------
    x : (b, c, *spatial)
    """
    ksize = overload.pop('kernel_size', self.kernel_size)
    # NOTE(review): the default stride is `self.kernel_size`. With a
    # stride > 1 the unfolded key grid no longer matches the query grid --
    # confirm this default is intended.
    step = overload.pop('stride', self.kernel_size)
    pad = overload.pop('padding', self.padding)
    pad_mode = overload.pop('padding_mode', self.padding_mode)
    ndim = q.dim() - 2

    # Pad keys and values (never the queries).
    if pad == 'auto':
        k = spatial.pad_same(ndim, k, ksize, bound=pad_mode)
        v = spatial.pad_same(ndim, v, ksize, bound=pad_mode)
    elif pad:
        pad = [0, 0] + py.make_list(pad, ndim)
        k = utils.pad(k, pad, side='both', mode=pad_mode)
        v = utils.pad(v, pad, side='both', mode=pad_mode)

    ksize = py.make_list(ksize, ndim)

    def _neighborhoods(t):
        # Extract patches and flatten each one: (b, c, *spatial_out, K)
        t = utils.unfold(t, ksize, step)
        return t.reshape([*t.shape[:ndim + 2], -1])

    # Attention weights: softmax over each neighborhood of query.key.
    keys = utils.movedim(_neighborhoods(k), 1, -1)
    query = utils.movedim(q[..., None], 1, -1)
    weights = math.softmax(linalg.dot(keys, query), dim=-1)
    weights = weights[:, None]  # add back channel dimension

    # New values: weighted sum of the neighboring values.
    return linalg.dot(weights, _neighborhoods(v))
def forward(self, x, **overload):
    """Project `x` to queries/keys/values and apply `self.dot`.

    Parameters
    ----------
    x : (batch, in_channels, *spatial_in)
    **overload
        Forwarded to `self.dot`.

    Returns
    -------
    y : (batch, out_channels, *spatial_out)
    """
    # `self.linear` acts on the last dimension, so go channel-last and back.
    projected = self.linear(utils.movedim(x, 1, -1))
    projected = utils.movedim(projected, -1, 1)
    # Split channels into three equal parts: queries, keys, values.
    queries, keys, values = torch.chunk(projected, 3, dim=1)
    return self.dot(queries, keys, values, **overload)
def _warp_image1(image, target, shape=None, affine=None, nonlin=None,
                 backward=False, reslice=False):
    """Returns the warped image, with channel dimension last

    Parameters
    ----------
    image : object
        Exposes `.affine`, `.dat` and `.pull(grid)`.
    target : tensor
        Affine of the output space (initial right-hand affine).
    shape : sequence[int], optional
        Output spatial shape used to build sampling grids.
    affine : object, optional
        Affine transform (uses `.exp(...)` and `.position`).
    nonlin : object, optional
        Nonlinear transform (uses `.exp`/`.iexp`, `.affine`, `.shape`,
        `.add_identity`).
    backward : bool, default=False
        Use the inverse transforms.
    reslice : bool, default=False
        Without `nonlin`, resample onto `shape` instead of returning the
        image data as is.

    Returns
    -------
    warped : tensor
        Warped image. A singleton channel dimension is dropped; otherwise
        channels are moved last.
    """
    # build transform
    aff_right = target
    aff_left = spatial.affine_inv(image.affine)
    aff = None
    if affine:
        exp = affine.exp
        aff = exp(recompute=False, cache_result=True)
        if backward:
            aff = spatial.affine_inv(aff)
    if nonlin:
        if affine:
            # Attach the affine to the side selected by its position.
            if affine.position[0].lower() in ('ms' if backward else 'fs'):
                aff_right = spatial.affine_matmul(aff, aff_right)
            if affine.position[0].lower() in ('fs' if backward else 'ms'):
                aff_left = spatial.affine_matmul(aff_left, aff)
        exp = nonlin.iexp if backward else nonlin.exp
        phi = exp(recompute=False, cache_result=True)
        aff_left = spatial.affine_matmul(aff_left, nonlin.affine)
        aff_right = spatial.affine_lmdiv(nonlin.affine, aff_right)
        if _almost_identity(aff_right) and nonlin.shape == shape:
            phi = nonlin.add_identity(phi)
        else:
            tmp = spatial.affine_grid(aff_right, shape)
            phi = regutils.smart_pull_grid(phi, tmp).add_(tmp)
            del tmp
        if not _almost_identity(aff_left):
            phi = spatial.affine_matvec(aff_left, phi)
    else:
        # no nonlin: single affine even if position == 'symmetric'
        if reslice:
            # Without an affine transform there is nothing to compose:
            # previously this called affine_matmul(None, ...) and crashed.
            if aff is None:
                aff = aff_right
            else:
                aff = spatial.affine_matmul(aff, aff_right)
            aff = spatial.affine_matmul(aff_left, aff)
            phi = spatial.affine_grid(aff, shape)
        else:
            phi = None

    # warp image
    if phi is not None:
        warped = image.pull(phi)
    else:
        warped = image.dat

    # drop a singleton channel dimension, otherwise move channels last
    if len(warped) == 1:
        warped = warped[0]
    else:
        warped = utils.movedim(warped, 0, -1)
    return warped
def smart_pull_jac(jac, grid, *args, **kwargs):
    """Interpolate a jacobian field.

    Notes
    -----
    Defaults differ from grid_pull:
    - bound -> dft
    - extrapolate -> True

    Parameters
    ----------
    jac : ([batch], *spatial_in, ndim, ndim) tensor
        Jacobian field
    grid : ([batch], *spatial_out, ndim) tensor
        Transformation field
    kwargs : dict
        Options to ``grid_pull``

    Returns
    -------
    pulled_jac : ([batch], *spatial_out, ndim, ndim) tensor
        Jacobian field. It has a batch dimension if either input had one.
    """
    if grid is None or jac is None:
        return jac
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    dim = jac.shape[-1]
    # Flatten each matrix into ndim*ndim channels, channel-first.
    jac = jac.reshape([*jac.shape[:-2], dim * dim])
    jac = utils.movedim(jac, -1, -dim - 1)
    jac_no_batch = jac.dim() == dim + 1
    grid_no_batch = grid.dim() == dim + 1
    if jac_no_batch:
        jac = jac[None]
    if grid_no_batch:
        grid = grid[None]
    jac = spatial.grid_pull(jac, grid, *args, **kwargs)
    jac = utils.movedim(jac, -dim - 1, -1)
    jac = jac.reshape([*jac.shape[:-1], dim, dim])
    # Only drop the batch dimension if neither input had one; a batched
    # `grid` produces one output per batch element.
    if jac_no_batch and grid_no_batch:
        jac = jac[0]
    return jac
def _pull_jac(jac, grid, **kwargs):
    """Interpolate a Jacobian field.

    Notes
    -----
    Defaults differ from grid_pull:
    - bound -> dft
    - extrapolate -> True

    Parameters
    ----------
    jac : ([batch], *spatial, ndim, ndim) tensor
        Jacobian matrix
    grid : ([batch], *spatial, ndim) tensor
        Transformation field
    kwargs : dict
        Options to ``grid_pull``

    Returns
    -------
    pulled_jac : ([batch], *spatial, ndim, ndim) tensor
        Jacobian field. The batch dimension is dropped only if neither
        input had one.
    """
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    ndim = grid.shape[-1]
    channel_axis = -ndim - 1
    # Flatten each (ndim, ndim) matrix into channels, channel-first.
    flat = jac.reshape([*jac.shape[:-2], -1])
    flat = utils.movedim(flat, -1, channel_axis)
    jac_unbatched = flat.dim() == ndim + 1
    grid_unbatched = grid.dim() == ndim + 1
    if jac_unbatched:
        flat = flat[None]
    if grid_unbatched:
        grid = grid[None]
    pulled = utils.movedim(grid_pull(flat, grid, **kwargs), channel_axis, -1)
    pulled = pulled.reshape([*pulled.shape[:-1], ndim, ndim])
    if jac_unbatched and grid_unbatched:
        pulled = pulled[0]
    return pulled
def jg(jac, grad, dim=None):
    """Jacobian-gradient product: J*g

    Parameters
    ----------
    jac : (..., K, *spatial, D)
    grad : (..., K, *spatial)
    dim : int, optional
        Number of spatial dimensions. Default: `grad.dim() - 1`.

    Returns
    -------
    new_grad : (..., *spatial, D)
        None if `grad` is None.
    """
    if grad is None:
        return None
    ndim = dim or (grad.dim() - 1)
    g = utils.movedim(grad, -ndim - 1, -1)  # (..., *spatial, K)
    J = utils.movedim(jac, -ndim - 2, -1)   # (..., *spatial, D, K)
    return linalg.matvec(J, g)
def do_apply(fnames, phi, jac):
    """Correct files with a given polarity

    For each file:
    - build the output path from the `options.output` template, with
      fields `dir`, `sep`, `base` and `ext`;
    - unwarp the data along the `readout` dimension with `phi`;
    - multiply by `jac` when it is not None;
    - save the result.
    """
    for fname in fnames:
        folder, base, ext = py.fileparts(fname)
        ofname = options.output.format(dir=folder or '.', sep=os.sep,
                                       base=base, ext=ext)
        if options.verbose:
            print(f'unwarp {fname} \n'
                  f' -> {ofname}')
        data = io.map(fname).fdata(device=device)
        # Readout dimension last for the 1D deformation, then back.
        data = _deform1d(utils.movedim(data, readout, -1), phi)
        if jac is not None:
            data *= jac
        data = utils.movedim(data, -1, readout)
        io.savef(data, ofname, like=fname)
def smart_push_grid(vel, grid, *args, **kwargs):
    """Push a velocity/grid/displacement field with `grid_push`.

    Notes
    -----
    Defaults differ from grid_push:
    - bound -> dft
    - extrapolate -> True

    Parameters
    ----------
    vel : ([batch], *spatial, ndim) tensor
        Velocity
    grid : ([batch], *spatial, ndim) tensor
        Transformation field
    kwargs : dict
        Options to ``grid_push``

    Returns
    -------
    pushed_vel : ([batch], *spatial, ndim) tensor
        Velocity. The batch dimension is dropped only if neither input
        had one.
    """
    if vel is None or grid is None:
        return vel
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    ndim = vel.shape[-1]
    channel_axis = -ndim - 1
    vel = utils.movedim(vel, -1, channel_axis)
    vel_unbatched = vel.dim() == ndim + 1
    grid_unbatched = grid.dim() == ndim + 1
    if vel_unbatched:
        vel = vel[None]
    if grid_unbatched:
        grid = grid[None]
    pushed = spatial.grid_push(vel, grid, *args, **kwargs)
    pushed = utils.movedim(pushed, channel_axis, -1)
    return pushed[0] if (vel_unbatched and grid_unbatched) else pushed
def pull_grid(gridin, grid, interpolation=1, bound='dft', extrapolate=True):
    """Sample a displacement field.

    Parameters
    ----------
    gridin : (*inshape, dim) tensor
        Field to sample.
    grid : (*outshape, dim) tensor
        Sampling coordinates.
    interpolation : default=1
        Interpolation option forwarded to `grid_pull`.
    bound : str, default='dft'
        Boundary condition forwarded to `grid_pull`.
    extrapolate : bool, default=True
        Forwarded to `grid_pull`.

    Returns
    -------
    gridout : (*outshape, dim) tensor
    """
    # grid_pull works on batched, channel-first inputs.
    source = movedim(gridin, -1, 0)[None]
    sampled = grid_pull(source, grid[None], interpolation=interpolation,
                        bound=bound, extrapolate=extrapolate)
    return movedim(sampled[0], 0, -1)
def rotation(self):
    """Rotation matrix built from the quaternion components (r, i, j, k).

    Uses the standard unit-quaternion formula with `r` as the real part;
    no normalization is performed here.

    Returns
    -------
    matrix : (..., 3, 3) tensor
    """
    i, j, k, r = self.i, self.j, self.k, self.r
    row0 = [1 - 2 * (j**2 + k**2), 2 * (i * j - k * r), 2 * (i * k + j * r)]
    row1 = [2 * (i * j + k * r), 1 - 2 * (i**2 + k**2), 2 * (j * k - i * r)]
    row2 = [2 * (i * k - j * r), 2 * (j * k + i * r), 1 - 2 * (i**2 + j**2)]
    matrix = utils.as_tensor([row0, row1, row2])
    # The 3x3 axes come first after stacking; move them to the end so any
    # dimensions carried by the components lead.
    return utils.movedim(matrix, [0, 1], [-2, -1])
def forward(self, x, **overload):
    """Run every head, concatenate along channels and project.

    Parameters
    ----------
    x : (batch, in_channels, *spatial_in)
    **overload
        Forwarded to each head.

    Returns
    -------
    y : (batch, out_channels, *spatial_out)
    """
    nb_heads = len(self.heads)
    out = None
    for index, head in enumerate(self.heads):
        y = head(x, **overload)
        width = y.shape[1]
        if out is None:
            # Allocate once, from the first head's output shape.
            shape = list(y.shape)
            shape[1] = width * nb_heads
            out = y.new_empty(shape)
        out[:, index * width:(index + 1) * width] = y
        del y
    # `self.linear` acts on the last dimension: go channel-last and back.
    out = self.linear(utils.movedim(out, 1, -1))
    return utils.movedim(out, -1, 1)
def intensity_to_rgb(image, min=None, max=None, colormap='gray', n=256, eq=False):
    """Colormap an intensity image

    Parameters
    ----------
    image : (*batch, H, W) tensor
        A (batch of) 2d image
    min : tensor_like, optional
        Minimum value. Should be broadcastable to batch.
        Default: min of image for each batch element.
    max : tensor_like, optional
        Maximum value. Should be broadcastable to batch.
        Default: max of image for each batch element.
    colormap : str or (K, 3) tensor, default='gray'
        A colormap or the name of a matplotlib colormap.
    n : int, default=256
        Number of color levels to use.
    eq : bool or {'linear', 'quadratic', 'log', None}, default=False
        Apply histogram equalization.
        If 'quadratic' or 'log', the histogram of the transformed signal
        is equalized.

    Returns
    -------
    rgb : (*batch, H, W, 3) tensor
        A (batch of) of RGB image.
    """
    image = torch.as_tensor(image).detach()
    image = intensity_preproc(image, min=min, max=max, eq=eq)
    colormap = _get_colormap_intensity(colormap, n, image.dtype, image.device)
    shape = image.shape
    # Scale to (fractional) indices into the n color levels, then lay all
    # pixels out as a 1D sampling grid: (1, N, 1).
    index = image.mul_(n - 1).clamp_(0, n - 1).reshape([1, -1, 1])
    # Colormap as a 3-channel 1D signal: (1, 3, K).
    lut = colormap.T.reshape([1, 3, -1])
    rgb = spatial.grid_pull(lut, index).reshape([3, *shape])
    return utils.movedim(rgb, 0, -1)