def depth_to_rgb(image, colormap=None):
    """Render a depth-resolved image as an RGB image.

    Each of the `D` depth slices is associated with one colour of the
    colormap. The colour of a pixel is the intensity-weighted average
    of these colours along the depth dimension, scaled by the maximum
    intensity found along depth and clamped to [0, 1].

    Parameters
    ----------
    image : (*batch, D, H, W)
        A (batch of) 3D image, with depth along the 'D' dimension.
    colormap : (D, 3) tensor or str, optional
        A colormap or the name of a matplotlib colormap.

    Returns
    -------
    image : (*batch, H, W, 3)
        A (batch of) RGB image.

    """
    # only the depth size is needed; the unpacking still enforces >= 3 dims
    *_, nb_depth, _, _ = image.shape
    cmap = _get_colormap_depth(colormap, nb_depth,
                               image.dtype, image.device)

    # depth goes last: (*batch, H, W, D)
    stack = utils.movedim(image, -3, -1)

    # weighted sum of colours (assumes `linalg.dot` contracts the last
    # dimension with broadcasting): (*batch, H, W, 1, D) . (3, D)
    rgb = linalg.dot(stack.unsqueeze(-2), cmap.T)
    # normalize by total intensity, then modulate by peak intensity
    rgb /= stack.sum(-1, keepdim=True)
    rgb *= stack.max(-1, keepdim=True).values
    return rgb.clamp_(0, 1)
def forward(self, q, k, v, **overload):
    """Apply patch-wise dot-product attention.

    Keys and values are (optionally) padded and unfolded into patches.
    Attention weights are the softmax, over each patch, of the
    query/key dot product, and the output is the corresponding
    weighted combination of the values of the patch.

    Parameters
    ----------
    q : (b, c, *spatial)
        Queries
    k : (b, c, *spatial)
        Keys
    v : (b, c, *spatial)
        Values
    **overload
        Per-call overrides of `kernel_size`, `stride`, `padding`
        and `padding_mode`.

    Returns
    -------
    x : (b, c, *spatial)

    """
    kernel_size = overload.pop('kernel_size', self.kernel_size)
    # NOTE(review): the fallback for `stride` is `self.kernel_size`,
    # not a dedicated stride attribute -- confirm this is intended.
    stride = overload.pop('stride', self.kernel_size)
    padding = overload.pop('padding', self.padding)
    padding_mode = overload.pop('padding_mode', self.padding_mode)
    nb_dim = q.dim() - 2

    # pad keys and values (queries are left untouched)
    if padding == 'auto':
        k, v = (spatial.pad_same(nb_dim, x, kernel_size, bound=padding_mode)
                for x in (k, v))
    elif padding:
        padsize = [0] * 2 + py.make_list(padding, nb_dim)
        k, v = (utils.pad(x, padsize, side='both', mode=padding_mode)
                for x in (k, v))

    kernel_size = py.make_list(kernel_size, nb_dim)

    def _patches(x):
        # extract patches and flatten the patch dimensions into one
        x = utils.unfold(x, kernel_size, stride)
        return x.reshape([*x.shape[:nb_dim + 2], -1])

    # compute weights by query/key dot product (channels moved last)
    keys = utils.movedim(_patches(k), 1, -1)
    queries = utils.movedim(q[..., None], 1, -1)
    weights = math.softmax(linalg.dot(keys, queries), dim=-1)
    weights = weights[:, None]  # add back channel dimension

    # compute new values by weight/value dot product
    return linalg.dot(weights, _patches(v))
def sample_prior_soft(prior, fixed, out=None):
    r"""Sample a prior at soft labels.

    out[k, n] = \sum_j prior[j, k] * fixed[j, n]

    Parameters
    ----------
    prior : (*B, J, K)
    fixed : (*B, J, N)
    out : (*B, K, N), optional
        Output placeholder, forwarded to `linalg.dot`.

    Returns
    -------
    out : (*B, K, N)

    """
    # NOTE: the docstring is a raw string because it contains `\sum`;
    # `\s` is an invalid escape sequence in a regular string literal.

    # Put the contracted dimension (J) last and insert singleton
    # dimensions so that K and N broadcast against each other.
    prior = prior.transpose(-1, -2).unsqueeze(-2)  # [*B, K, 1, J]
    fixed = fixed.transpose(-1, -2).unsqueeze(-3)  # [*B, 1, N, J]
    out = linalg.dot(prior, fixed, out=out)
    return out
def scatter_prior_soft(prior, fixed, z):
    """Accumulate responsibilities into a prior, using soft labels.

    prior[j, k] = sum_n fixed[j, n] * z[k, n]

    Parameters
    ----------
    prior : (*B, J, K)
        Passed as the `out` argument of `linalg.dot`.
    fixed : (*B, J, N)
    z : (*B, K, N)

    Returns
    -------
    prior : (*B, J, K)

    """
    # insert singleton dimensions so that J and K broadcast against
    # each other; the contracted dimension (N) is already last
    resp = z.unsqueeze(-3)         # [*B, 1, K, N]
    labels = fixed.unsqueeze(-2)   # [*B, J, 1, N]
    return linalg.dot(resp, labels, out=prior)
def min_dist(x, s, max_iter=2**16, tol=1e-6, steps=100):
    """Compute the minimum distance from a (set of) point(s) to a curve.

    The curve parameter is initialized by a discrete search over
    `steps` values in [0, 1] and then refined by Gauss-Newton
    iterations with a backtracking line search. A single step size is
    shared by all points and a step is accepted if it decreases the
    sum of squared distances.

    Parameters
    ----------
    x : (..., dim) tensor
        Coordinates
    s : BSplineCurve
        Parameterized curve
    max_iter : int, default=2**16
        Maximum number of Gauss-Newton iterations
    tol : float, default=1e-6
        Stop when the decrease of the sum of squared distances is
        below `tol` times the number of points.
    steps : int, default=100
        Number of curve parameters tested in the discrete search

    Returns
    -------
    t : (...) tensor
        Coordinate of the closest point
    d : (...) tensor
        Minimum distance between each point and the curve

    """
    # initialize using a discrete search
    all_t = torch.linspace(0, 1, steps, **utils.backend(x))
    t = x.new_zeros(x.shape[:-1])
    d = x.new_empty(x.shape[:-1]).fill_(float('inf'))
    for t1 in all_t:
        x1 = s.eval_position(t1)
        d1 = x1 - x
        d1 = d1.square_().sum(-1).sqrt_()
        t = torch.where(d1 < d, t1, t)
        d = torch.min(d, d1)

    # Fine tune using Gauss-Newton optimization
    nll = d.square_().sum()
    for n_iter in range(max_iter):
        # compute the distance between x and s(t) + gradients
        d, g = s.eval_grad_position(t)
        d.sub_(x)
        # The Gauss-Newton Hessian is the squared norm of the tangent
        # s'(t). It must be computed *before* `g` is overwritten by the
        # (scalar) gradient <s'(t), s(t) - x>.
        h = linalg.dot(g, g)
        g = linalg.dot(g, d)
        h.add_(1e-3)  # damping, keeps the step finite
        g.div_(h)

        # Perform GN step (with line search)
        # TODO: I could get rid of the line search
        armijo = 1
        t0 = t.clone()
        nll0 = nll
        success = False
        for n_ls in range(12):
            t = torch.sub(t0, g, alpha=armijo, out=t)
            t.clamp_(0, 1)
            d = s.eval_position(t).sub_(x)
            nll = d.square().sum(dtype=torch.double)
            if nll < nll0:
                success = True
                break
            armijo /= 2
        if not success:
            # no step size improved the objective: restore and stop
            t = t0
            break

        if (nll0 - nll) < tol * t.numel():
            break

    d = s.eval_position(t).sub_(x)
    d = d.square_().sum(-1).sqrt_()
    return t, d
def emmi_hard(moving, fixed, dim=None, prior=None, fwhm=None, max_iter=32,
              weights=None, grad=True, hess=True, return_prior=False):
    """Fit a joint histogram by EM and return a mutual-information objective.

    The "prior" (joint histogram between the `J` classes of the fixed
    image and the `K` classes of the moving image) is fitted by
    alternating an E-step (`sample_prior`) and an M-step
    (`scatter_prior`). The inputs are first passed through
    `emmi_prepare`; afterwards the code relies on `prior` being
    (*batch, J, K) and on the last dimension of `moving` having size N.

    Parameters
    ----------
    moving : tensor
        Moving image (presumably class probabilities -- see
        `emmi_prepare` for the accepted layout).
    fixed : tensor
        Fixed image (presumably hard labels, given the function name --
        confirm against `sample_prior`).
    dim : int, optional
        Forwarded to `emmi_prepare`. Default: `moving.dim() - 1`.
    prior : tensor, optional
        Initial prior, forwarded to `emmi_prepare`.
    fwhm : float, optional
        If truthy, the prior is smoothed with `spatial.smooth` at each
        iteration.
    max_iter : int, default=32
        Maximum number of EM iterations.
    weights : tensor, optional
        Voxel weights, forwarded to `emmi_prepare`.
    grad : bool, default=True
        Return the gradient.
    hess : bool, default=True
        Return the Hessian (as built by `sym_outer`).
    return_prior : bool, default=False
        Return the fitted prior.

    Returns
    -------
    ll : tensor
        Objective (negated mutual information, see comments below).
    g : tensor, if `grad`
    h : tensor, if `hess`
    prior : tensor, if `return_prior`

    A single value is returned bare; several values are returned as a
    tuple.

    """

    # ------------------------------------------------------------------
    #       PREPARATION
    # ------------------------------------------------------------------
    tiny = 1e-16
    dim = dim or (moving.dim() - 1)
    moving, fixed, weights, prior, shape = emmi_prepare(
        moving, fixed, weights, prior, dim)

    *batch, J, K = prior.shape
    Nb = len(batch)
    N = moving.shape[-1]
    Nm = weights.sum()

    # ------------------------------------------------------------------
    #       EM LOOP
    # ------------------------------------------------------------------
    ll = -float('inf')
    z = moving.new_empty([*batch, K, N])
    prior0 = torch.empty_like(prior)
    for n_iter in range(max_iter):
        ll_prev = ll
        # --------------------------------------------------------------
        # E-step
        # ------
        # estimate responsibilities of each moving cluster for each
        # fixed voxel using Bayes' rule:
        #   p(z[n] == k | x[n] == j[n]) ∝ p(x[n] == j[n] | z[n] == k) p(z[n] == k)
        #
        # . j[n] is the discretized fixed image
        # . p(z[n] == k) is the moving template
        # . p(x[n] == j[n] | z[n] == k) is the conditional prior evaluated at (j[n], k)
        # --------------------------------------------------------------
        z = sample_prior(prior, fixed, z)
        z *= moving

        # --------------------------------------------------------------
        # compute log-likelihood (log_sum of the posterior)
        # ll = Σ_n log p(x[n] == j[n])
        #    = Σ_n log Σ_k p(x[n] == j[n] | z[n] == k)  p(z[n] == k)
        #    = Σ_n log Σ_k p(z[n] == k | x[n] == j) + constant{\z}
        # --------------------------------------------------------------
        ll = z.sum(-2, keepdim=True)
        ll = add_tiny_(ll, Nb)
        z /= ll  # normalize responsibilities across the K classes
        ll = ll.log_().mul_(weights).sum([-1, -2], dtype=torch.double)

        z *= weights

        # --------------------------------------------------------------
        # M-step
        # ------
        # estimate joint prior by maximizing Q = E_{Z;H,mu}[ln p(X, Z; H)]
        # => H_jk = p(x == j, z == k) ∝ Σ_n p(z[n] == k | x[n] == j) 𝛿(x[n] == j)
        # --------------------------------------------------------------
        # `prior0` keeps the un-normalized counts; it is reused after
        # the loop to compute the final objective.
        prior0 = scatter_prior(prior0, fixed, z)
        prior.copy_(prior0).add_(tiny)
        # make it a joint distribution
        prior /= add_tiny_(prior.sum(dim=[-1, -2], keepdim=True), Nb)

        if fwhm:
            # smooth "prior" for the prior
            prior = prior.transpose(-1, -2)
            prior = spatial.smooth(prior, dim=1, basis=0, fwhm=fwhm,
                                   bound='replicate')
            prior = prior.transpose(-1, -2)
            # prior /= prior.sum(dim=[-1, -2], keepdim=True)

        # MI-like normalization
        prior /= add_tiny_(
            prior.sum(dim=-1, keepdim=True) * prior.sum(dim=-2, keepdim=True),
            Nb)

        # stop when the log-likelihood gain is small relative to the
        # sum of weights
        if ll - ll_prev < 1e-5 * Nm:
            break

    # compute mutual information (times number of observations)
    # > prior contains p(x,y)/(p(x) p(y))
    # > prior0 contains N * p(x,y)
    # >> 1/N \sum_{j,k} prior0[j,k] * log(prior[j,k])
    #    = \sum_{x,y} p(x,y) * (log p(x,y) - log p(x) - log p(y)
    #    = \sum_{x,y} p(x,y) * log p(x,y)
    #       - \sum_{xy} p(x,y) log p(x)
    #       - \sum_{xy} p(x,y) log p(y)
    #    = \sum_{x,y} p(x,y) * log p(x,y)
    #       - \sum_{x} p(x) log p(x)
    #       - \sum_{y} p(y) log p(y)
    #    = -H[x,y] + H[x] + H[y]
    #    = MI[x, y]
    # (the sign is flipped: the returned value is the negated quantity)
    ll = -(prior0 * add_tiny_(prior, Nb).log()).sum() / Nm
    out = [ll]

    # ------------------------------------------------------------------
    #       GRADIENTS
    # ------------------------------------------------------------------

    # compute gradients
    # Keeping only terms that depend on y, the mutual information is H[y]-H[x,y]
    # The objective function is \sum_n E[y_n]
    # > ll = \sum_n log p(x[n] == j[n], h)
    #      = \sum_n log \sum_k p(x[n] == j[n] | z[n] == k, h) p(z[n] == k)
    if grad or hess:
        g = sample_prior(prior, fixed)
        norm = linalg.dot(g.transpose(-1, -2), moving.transpose(-1, -2))
        norm = add_tiny_(norm, Nb).unsqueeze(-2).reciprocal_()
        g *= norm
        if hess:
            # built from the un-weighted gradient; weights are applied
            # to `h` separately below
            h = sym_outer(g, -2)

    if grad:
        g *= weights
        g /= -Nm  # g.neg_()
        g = g.reshape([*g.shape[:-1], *shape])
        out.append(g)

    if hess:
        h *= weights
        h /= Nm
        h = h.reshape([*h.shape[:-1], *shape])
        out.append(h)

    if return_prior:
        out.append(prior)

    return out[0] if len(out) == 1 else tuple(out)
def emmi_soft(moving, fixed, dim=None, prior=None, fwhm=None, max_iter=32,
              weights=None, grad=True, hess=True, return_prior=False):
    """Fit a joint histogram by EM and return a mutual-information objective.

    Variant of the EM fit whose E-step is performed in the log domain
    (`math.softmax_lse`). The inputs are first passed through
    `emmi_prepare`; afterwards the code relies on `prior` being
    (*batch, J, K) and on the last dimension of `moving` having size N.

    Parameters
    ----------
    moving : tensor
        Moving image (presumably class probabilities -- see
        `emmi_prepare` for the accepted layout).
    fixed : tensor
        Fixed image (presumably soft labels, given the function name --
        confirm against `sample_prior`).
    dim : int, optional
        Forwarded to `emmi_prepare`. Default: `moving.dim() - 1`.
    prior : tensor, optional
        Initial prior, forwarded to `emmi_prepare`.
    fwhm : float, optional
        If truthy, the prior is smoothed with `spatial.smooth` at each
        iteration.
    max_iter : int, default=32
        Maximum number of EM iterations.
    weights : tensor, optional
        Voxel weights, forwarded to `emmi_prepare`.
    grad : bool, default=True
        Return the gradient.
    hess : bool, default=True
        Return the Hessian, stored as K diagonal elements followed by
        K*(K-1)/2 off-diagonal elements.
    return_prior : bool, default=False
        Return the fitted prior.

    Returns
    -------
    ll : tensor
        Objective (negated mutual information, see comments below).
    g : tensor, if `grad`
    h : tensor, if `hess`
    prior : tensor, if `return_prior`

    A single value is returned bare; several values are returned as a
    tuple.

    """

    # ------------------------------------------------------------------
    #       PREPARATION
    # ------------------------------------------------------------------
    tiny = 1e-16
    dim = dim or (moving.dim() - 1)
    moving, fixed, weights, prior, shape = emmi_prepare(
        moving, fixed, weights, prior, dim)

    *batch, J, K = prior.shape
    Nb = len(batch)
    N = moving.shape[-1]
    Nm = weights.sum()

    # ------------------------------------------------------------------
    #       EM LOOP
    # ------------------------------------------------------------------
    ll = -float('inf')
    z = moving.new_empty([*batch, K, N])
    prior0 = torch.empty_like(prior)
    for n_iter in range(max_iter):
        ll_prev = ll
        # --------------------------------------------------------------
        # E-step
        # --------------------------------------------------------------
        # log-responsibilities, normalized across the K classes by a
        # softmax; `ll` is the associated log-sum-exp
        z = sample_prior(prior.log(), fixed, z)
        z += moving.log()
        z, ll = math.softmax_lse(z, -2, lse=True, weights=weights)

        # --------------------------------------------------------------
        # M-step
        # ------
        # estimate joint prior by maximizing Q = E_{Z;H,mu}[ln p(X, Z; H)]
        # => H_jk = p(x == j, z == k) ∝ Σ_n p(z[n] == k | x[n] == j) 𝛿(x[n] == j)
        # --------------------------------------------------------------
        z *= weights
        # `prior0` keeps the un-normalized counts; it is reused after
        # the loop to compute the final objective.
        prior0 = scatter_prior(prior0, fixed, z)
        prior.copy_(prior0).add_(tiny)
        # make it a joint distribution
        prior /= add_tiny_(prior.sum(dim=[-1, -2], keepdim=True), Nb)

        if fwhm:
            # smooth "prior" for the prior
            prior = prior.transpose(-1, -2)
            prior = spatial.smooth(prior, dim=1, basis=0, fwhm=fwhm,
                                   bound='replicate')
            prior = prior.transpose(-1, -2)

        # MI-like normalization
        prior /= prior.sum(dim=-1, keepdim=True) * prior.sum(dim=-2, keepdim=True)

        # stop when the log-likelihood gain is small relative to the
        # sum of weights
        if ll - ll_prev < 1e-5 * Nm:
            break

    # compute mutual information (times number of observations)
    # > prior contains p(x,y)/(p(x) p(y))
    # > prior0 contains N * p(x,y)
    # >> 1/N Σ_{j,k} prior0[j,k] * log(prior[j,k])
    #    = Σ_{x,y} p(x,y) * (log p(x,y) - log p(x) - log p(y)
    #    = Σ_{x,y} p(x,y) * log p(x,y)
    #       - Σ_{xy} p(x,y) log p(x)
    #       - Σ_{xy} p(x,y) log p(y)
    #    = Σ_{x,y} p(x,y) * log p(x,y)
    #       - Σ_{x} p(x) log p(x)
    #       - Σ_{y} p(y) log p(y)
    #    = -H[x,y] + H[x] + H[y]
    #    = MI[x, y]
    # (the sign is flipped: the returned value is the negated quantity)
    ll = -(prior0 * prior.log()).sum() / Nm
    out = [ll]

    # ------------------------------------------------------------------
    #       GRADIENTS
    # ------------------------------------------------------------------

    # compute gradients
    # Keeping only terms that depend on y, the mutual information is H[y]-H[x,y]
    # The objective function is \sum_n E[y_n]
    # > ll = Σ_n log p(x[n] == j[n], h)
    #      = Σ_nj \sum_j q(x[n] == j) log \sum_k p(x[n] == j | z[n] == k, h) p(z[n] == k)
    if grad or hess:
        # NOTE(review): the shapes produced here depend on the
        # broadcasting rules of `linalg.dot` and `sample_prior`, which
        # are not visible from this function -- confirm.
        norm = linalg.dot(
            prior.transpose(-1, -2).unsqueeze(-1),
            moving.transpose(-1, -2).unsqueeze(-3))
        norm = norm.add_(tiny).reciprocal_()
        g = sample_prior(prior, fixed * norm)
        if hess:
            norm = norm.square_().mul_(fixed).unsqueeze(-1)
            # symmetric Hessian: K diagonal elements first, then the
            # upper-triangular elements in row-major order
            h = moving.new_zeros([*g.shape[:-2], K * (K + 1) // 2, N])
            for j in range(J):
                h[..., :K, :] += prior[..., j, :K, None].square() * norm
                c = K  # index of the next off-diagonal element
                for k in range(K):
                    for kk in range(k + 1, K):
                        h[..., c, :] += (prior[..., j, k, None] *
                                         prior[..., j, kk, None] * norm)
                        c += 1

    if grad:
        g *= weights
        g.neg_()
        g = g.reshape([*g.shape[:-1], *shape])
        out.append(g)

    if hess:
        h *= weights
        h = h.reshape([*h.shape[:-1], *shape])
        out.append(h)

    if return_prior:
        out.append(prior)

    return out[0] if len(out) == 1 else tuple(out)
def forward(self, predicted, reference, mask=None):
    """Compute a (soft) Dice loss between predicted and reference classes.

    The per-class loss is `1 - 2 * intersection / union`, optionally
    weighted per class, and is then passed to the `forward` method of
    the parent class (which handles the reduction).

    Parameters
    ----------
    predicted : (batch, nb_class[-1], *spatial) tensor
        Predicted classes.
    reference : (batch, nb_class[-1]|1, *spatial) tensor
        Reference classes (or their expectation).
            * If `reference` has a floating point data type (`half`,
              `float`, `double`) it is assumed to hold one-hot or soft
              labels, and its channel dimension should be `nb_class`
              or `nb_class - 1`.
            * If `reference` has an integer or boolean data type, it is
              assumed to hold hard labels and its channel dimension
              should be 1.
    mask : (nb_batch, 1, *spatial) tensor, optional
        Loss mask

    Returns
    -------
    loss : scalar or tensor
        The output shape depends on the type of reduction used.
        If 'mean' or 'sum', this function returns a scalar.

    Raises
    ------
    ValueError
        If soft references do not have `nb_class` or `nb_class - 1`
        channels.

    """
    logit = self.logit
    implicit = self.implicit
    weighted = self.weighted
    exclude_background = self.exclude_background

    predicted = torch.as_tensor(predicted)
    reference = torch.as_tensor(reference, device=predicted.device)
    backend = dict(dtype=predicted.dtype, device=predicted.device)
    dim = predicted.dim() - 2

    # if only one predicted class -> must be implicit
    implicit = implicit or (predicted.shape[1] == 1)

    # take softmax if needed
    predicted = get_prob_explicit(predicted, logit=logit, implicit=implicit)

    nb_classes = predicted.shape[1]
    spatial_dims = list(range(2, predicted.dim()))

    # prepare weights
    # `weighted` ends up being either False (no weighting), True
    # (weights computed from the reference) or a (1, nb_classes)
    # tensor of user-defined class weights.
    if not torch.is_tensor(weighted) and not weighted:
        weighted = False
    if not isinstance(weighted, bool):
        weighted = utils.make_vector(weighted, nb_classes, **backend)[None]

    # preprocess reference
    if reference.dtype.is_floating_point:
        # one-hot labels
        reference = reference.to(predicted.dtype)
        implicit_ref = reference.shape[1] == nb_classes - 1
        reference = get_prob_explicit(reference, implicit=implicit_ref)
        if reference.shape[1] != nb_classes:
            raise ValueError('Number of classes not consistent. '
                             'Expected {} or {} but got {}.'.format(
                             nb_classes, nb_classes - 1, reference.shape[1]))

        if exclude_background:
            predicted = predicted[:, 1:]
            reference = reference[:, 1:]
        if mask is not None:
            predicted = predicted * mask
            reference = reference * mask

        # flatten spatial dimensions so that `dot` can contract them
        predicted = predicted.reshape([*predicted.shape[:-dim], -1])
        reference = reference.reshape([*reference.shape[:-dim], -1])
        inter = linalg.dot(predicted, reference)
        sumpred = predicted.sum(-1)
        sumref = reference.sum(-1)
        union = sumpred + sumref
        loss = -2 * inter / union.clamp_min_(1e-5)
        del inter, union
        if weighted is not False:
            if weighted is True:
                weights = sumref / sumref.sum(dim=1, keepdim=True)
            elif exclude_background:
                # user weights are indexed by class: drop the
                # background weight along with the background channel
                weights = weighted[:, 1:]
            else:
                weights = weighted
            loss = loss * weights

    else:
        # hard labels
        loss = []
        weights = []
        first_index = 1 if exclude_background else 0
        for index in range(first_index, nb_classes):
            pred1 = predicted[:, None, index, ...]
            ref1 = reference == index
            if mask is not None:
                pred1 = pred1 * mask
                ref1 = ref1 * mask

            inter = math.sum(pred1 * ref1, dim=spatial_dims)
            union = math.sum(pred1 + ref1, dim=spatial_dims)
            loss1 = -2 * inter / union.clamp_min_(1e-5)
            del inter, union
            if weighted is not False:
                if weighted is True:
                    weight1 = ref1.sum()
                else:
                    # `weighted` has a leading singleton dimension:
                    # classes are indexed along its second axis
                    weight1 = float(weighted[0, index])
                loss1 = loss1 * weight1
                weights.append(weight1)
            loss.append(loss1)

        loss = torch.cat(loss, dim=1)
        if weighted is True:
            weights = sum(weights)
            loss = loss / weights

    # loss is in [-1, 0] at this point; shift it to [0, 1]
    loss += 1
    return super().forward(loss)
def is_inside(points, vertices, faces=None):
    """Test if a point is inside a polygon/surface.

    The polygon or surface *must* be closed.

    Parameters
    ----------
    points : (..., dim) tensor
        Coordinates of points to test
    vertices : (nv, dim) tensor
        Vertex coordinates
    faces : (nf, dim) tensor[int]
        Faces are encoded by the indices of its vertices.
        By default, assume that vertices are ordered and define a
        closed curve

    Returns
    -------
    check : (...) tensor[bool]
        True for points that lie inside the shape.

    """
    # This function uses a ray-tracing technique:
    #
    #   A half-line is started in each point. If it crosses an odd
    #   number of faces, it is inside the shape. If it crosses an even
    #   number of faces, it is not.
    #
    #   In practice, we loop through faces (as we expect there are much
    #   less vertices than voxels) and compute intersection points between
    #   all lines and each face in a batched fashion. We only want to
    #   send these rays in one direction, so we keep aside points whose
    #   intersection have a positive coordinate along the ray.

    points = torch.as_tensor(points)
    vertices = torch.as_tensor(vertices)
    if faces is None:
        # ordered vertices: consecutive pairs + closing segment
        faces = [(i, i + 1) for i in range(len(vertices) - 1)]
        faces += [(len(vertices) - 1, 0)]
    faces = utils.as_tensor(faces, dtype=torch.long)

    points, vertices = utils.to_max_dtype(points, vertices)
    points, vertices, faces = utils.to_max_device(points, vertices, faces)
    backend = utils.backend(points)
    batch = points.shape[:-1]
    dim = points.shape[-1]
    eps = constants.eps(points.dtype)
    cross = points.new_zeros(batch, dtype=torch.long)

    # the same (random) direction is used for all points
    ray = torch.randn(dim, **backend)

    for face in faces:
        face = vertices[face]

        # compute normal vector
        origin = face[0]
        if dim == 3:
            u = face[1] - face[0]
            v = face[2] - face[0]
            norm = torch.stack([
                u[1] * v[2] - u[2] * v[1],
                u[2] * v[0] - u[0] * v[2],
                u[0] * v[1] - u[1] * v[0]
            ])
        else:
            assert dim == 2
            u = face[1] - face[0]
            norm = torch.stack([-u[1], u[0]])

        # check co-linearity between face and ray
        colinear = linalg.dot(ray, norm).abs() / (ray.norm() * norm.norm()) < eps
        if colinear:
            continue

        # compute intersection between ray and plane
        #   plane: <norm, x - origin> = 0
        #   line: x = p + t*u
        #   => <norm, p + t*u - origin> = 0
        intersection = linalg.dot(norm, points - origin)
        intersection /= linalg.dot(norm, ray)
        halfmask = intersection >= 0  # we only want to shoot in one direction
        intersection = intersection[halfmask]
        halfpoints = points[halfmask]
        intersection = intersection[..., None] * (-ray)
        intersection += halfpoints

        # check if the intersection is inside the face
        #   first, we project it onto a frame of dimension `dim-1`
        #   defined by (origin, (u, v))
        intersection -= origin
        if dim == 3:
            # Barycentric coordinates (a, b) such that
            #   intersection = a * u + b * v
            # obtained by solving the 2x2 normal equations. Comparing
            # the raw projections on u and v with 1 would only be valid
            # if (u, v) were orthonormal.
            uu = linalg.dot(u, u)
            vv = linalg.dot(v, v)
            uv = linalg.dot(u, v)
            interu = linalg.dot(intersection, u)
            interv = linalg.dot(intersection, v)
            det = uu * vv - uv * uv
            coordu = (vv * interu - uv * interv) / det
            coordv = (uu * interv - uv * interu) / det
            intersection = (coordu >= 0) & (coordv > 0) & (coordu + coordv < 1)
        else:
            # normalized coordinate along the segment
            intersection = linalg.dot(intersection, u)
            intersection /= u.norm().square_()
            intersection = (intersection >= 0) & (intersection < 1)

        cross[halfmask] += intersection

    # check that the number of crossings is odd
    cross = cross.bitwise_and_(1).bool()
    return cross
def derivatives_intensity(moving, fixed, prior, weights=None):
    """Compute the gradient and Hessian with respect to the moving image.

    Both derivatives are always computed and returned. The Hessian is
    symmetric and stored in compact form: K diagonal elements followed
    by K*(K-1)/2 off-diagonal elements, in row-major order.

    Parameters
    ----------
    moving : (B, K, N) tensor
    fixed : (B, J|1, N) tensor
        If its data type is not floating point, it is handled as
        discrete labels; otherwise, as soft labels.
    prior : (B, J, K) tensor
    weights : (B, 1, N) tensor, optional

    Returns
    -------
    grad : (B, K, *spatial) tensor
    hess : (B, K*(K+1)//2, *spatial) tensor

    """

    # ------------------------------------------------------------------
    #       PREPARATION
    # ------------------------------------------------------------------
    # NOTE: the local `spatial` (a shape) shadows any module-level name
    # `spatial` inside this function.
    moving, fixed, weights = spatial_prepare(moving, fixed, weights)
    B, K, J, spatial = spatial_shapes(moving, fixed, prior)
    N = spatial.numel()

    # Flatten
    moving = moving.reshape([*moving.shape[:2], -1])
    fixed = fixed.reshape([*fixed.shape[:2], -1])
    weights = weights.reshape([*weights.shape[:2], -1])

    # compute gradients
    # Keeping only terms that depend on y, the mutual information is H[y]-H[x,y]
    # The objective function is \sum_n E[y_n]
    # > ll = \sum_n log p(x[n] == j[n] ; H, mu)
    #      = \sum_n log \sum_k p(x[n] == j[n] | z[n] == k; H) p(z[n] == k; mu)
    g = moving.new_zeros([B, K, N])
    K2 = K * (K + 1) // 2
    h = moving.new_zeros([B, K2, N])

    # ------------------------------------------------------------------
    #       VERSION 1: DISCRETE LABELS
    # ------------------------------------------------------------------
    if not fixed.dtype.is_floating_point:
        # `sample_prior` writes its result into `g`
        # (`t` and `tiny` are module-level names)
        sample_prior(prior, fixed, g)
        norm = linalg.dot(t(g), t(moving)).unsqueeze(1)
        norm = norm.add_(tiny).reciprocal_()
        g *= norm
        # diagonal of the Hessian
        torch.mul(g[:, :K], g[:, :K], out=h[:, :K])
        # off-diagonal elements
        c = K
        for k in range(K):
            for kk in range(k + 1, K):
                torch.mul(g[:, k], g[:, kk], out=h[:, c])
                c += 1

    # ------------------------------------------------------------------
    #       VERSION 2: SOFT LABELS
    # ------------------------------------------------------------------
    else:
        for j in range(J):
            norm = 0
            tmp = torch.zeros_like(g)
            for k in range(K):
                prior1 = prior[:, j, k, None]
                norm += prior1 * moving[:, k, :]
                tmp[:, k, :] = prior1
            # NOTE(review): `norm` is (B, N) whereas `tmp` is (B, K, N);
            # this in-place division seems to broadcast as intended
            # only when B == 1 -- confirm.
            tmp /= norm.add_(tiny)
            g += tmp * fixed[:, j, None, :]
            # diagonal of the Hessian
            h[:, :K, :] += tmp.square() * fixed[:, j, None, :]
            # off-diagonal elements
            c = K
            for k in range(K):
                for kk in range(k + 1, K):
                    h[:, c, :] += tmp[:, k, :] * tmp[:, kk, :] \
                                  * fixed[:, j, :]
                    c += 1

    g *= weights
    g.neg_()
    g = g.reshape([B, K, *spatial])
    h *= weights
    h = h.reshape([B, K2, *spatial])

    return g, h
def __call__(self, vel, grad=False, hess=False, in_line_search=False):
    """Evaluate the symmetric objective (and its derivatives) at `vel`.

    `self.pos` is deformed by the transformation generated by `vel`,
    `self.neg` is deformed by its inverse, and the loss is evaluated
    in both directions (pos -> neg and neg -> pos) and averaged. A
    regularization term computed from `self.reg(vel)` is then added.

    NOTE(review): the nesting of the conditional blocks below was
    reconstructed from a source whose indentation was lost -- confirm
    against the reference implementation.

    Parameters
    ----------
    vel : tensor
        Velocity/displacement. Its first dimension is dropped
        (`vel[0]`) before use, and the derivatives get a leading
        singleton dimension back (`[None]`).
    grad : bool, default=False
        Compute and return the gradient.
    hess : bool, default=False
        Compute and return the Hessian. It is only computed if `grad`
        is also True.
    in_line_search : bool, default=False
        If True, the state of the loss is restored after evaluation,
        and the iteration counter and previous objective are left
        untouched.

    Returns
    -------
    ll : tensor
        Objective value (data term + regularization).
    g : tensor, if `grad`
    h : tensor, if `hess`

    A single value is returned bare; several values are returned as a
    tuple.

    """
    vel = vel[0]
    phi, iphi, jac, ijac = _exp_1d(vel, model=self.model)
    pos, grad_pos = _deform_1d(self.pos, phi, grad=grad)
    del phi
    neg, grad_neg = _deform_1d(self.neg, iphi, grad=grad)
    del iphi
    if self.modulation:
        pos *= jac
        neg *= ijac
        # NOTE(review): this multiplies `grad_pos`/`grad_neg` even when
        # `grad` is False -- confirm what `_deform_1d` returns then.
        if self.model == 'svf':
            grad_pos *= jac
            grad_neg *= ijac

    g = ig = h = ih = None
    # save the state of the loss so that it can be restored if this
    # evaluation is part of a line search
    state = self.loss.get_state()
    if grad and hess:
        ll, g, h = self.loss.loss_grad_hess(pos, neg, mask=self.mask)
        ill, ig, ih = self.loss.loss_grad_hess(neg, pos, mask=self.mask)
    elif grad:
        ll, g = self.loss.loss_grad(pos, neg, mask=self.mask)
        ill, ig = self.loss.loss_grad(neg, pos, mask=self.mask)
    else:
        ll = self.loss.loss(pos, neg, mask=self.mask)
        ill = self.loss.loss(neg, pos, mask=self.mask)
    if in_line_search:
        self.loss.set_state(state)
    # average both directions
    ll += ill
    ll /= 2

    if grad:
        if self.modulation:
            # undo the modulation applied above
            pos /= jac
            neg /= ijac
        # move channel channels to the end so that we can use `dot`
        g = utils.movedim(g, 0, -1)
        ig = utils.movedim(ig, 0, -1)
        pos = utils.movedim(pos, 0, -1)
        neg = utils.movedim(neg, 0, -1)
        grad_pos = utils.movedim(grad_pos, 0, -1)
        grad_neg = utils.movedim(grad_neg, 0, -1)

        # chain rule: derivatives with respect to the images are
        # converted into derivatives with respect to the deformation
        g0, ig0 = g, ig
        g = linalg.dot(g0, grad_pos)
        if self.modulation:
            g = g.mul_(jac)
            g0 = linalg.dot(g0, pos)
            if self.model == 'svf':
                g0.mul_(jac)
            g += _div_1d(g0)
        ig = linalg.dot(ig0, grad_neg)
        if self.modulation:
            ig = ig.mul_(ijac)
            ig0 = linalg.dot(ig0, neg)
            if self.model == 'svf':
                ig0.mul_(ijac)
            ig += _div_1d(ig0)
        del g0, ig0

        if hess:
            h = utils.movedim(h, 0, -1)
            ih = utils.movedim(ih, 0, -1)
            h0, ih0 = h, ih
            # `grad_pos`, `grad_neg`, `jac` and `ijac` are squared
            # in-place from here on
            grad_pos.square_()
            grad_neg.square_()
            h = linalg.dot(grad_pos, h0)
            if self.modulation:
                jac.square_()
                h = h.mul_(jac)
                h0 = linalg.dot(pos, h0).square_()
                if self.model == 'svf':
                    h0.mul_(jac)
                h += _div_1d(_div_1d(h0))
            ih = linalg.dot(grad_neg, ih0)
            if self.modulation:
                ijac.square_()
                ih = ih.mul_(ijac)
                ih0 = linalg.dot(neg, ih0).square_()
                if self.model == 'svf':
                    # NOTE(review): `ih0` is multiplied twice by `ijac`
                    # whereas `h0` is multiplied once by `jac` --
                    # confirm that this asymmetry is intended.
                    ih0.mul_(ijac).mul_(ijac)
                ih += _div_1d(_div_1d(ih0))
            del h0, ih0

        if self.model == 'svf':
            g, h = spatial.exp1d_backward(vel, g, h, bound=BND)
            ig, ih = spatial.exp1d_backward(-vel, ig, ih, bound=BND)
        # the inverse branch contributes with an opposite sign to the
        # gradient, and with the same sign to the Hessian
        g = g.sub_(ig).div_(2)
        g = g[None]
        if hess:
            h = h.add_(ih).div_(2)
            h = h[None]
        del ig, ih, grad_pos, grad_neg, jac, ijac
    vel = vel[None]

    # add regularization term
    vgrad = self.reg(vel)
    llv = 0.5 * vel.flatten().dot(vgrad.flatten())
    if grad:
        g += vgrad
    del vgrad

    # print objective
    # (note that `self.ll` and `self.n_iter` are only updated inside
    # this block, i.e., when `self.verbose` is set)
    if self.verbose and (self.verbose > 1 or not in_line_search):
        ll_prev = self.ll
        if in_line_search:
            line = '(search) | '
        else:
            line = '(topup) | '
        line += f'{self.n_iter:03d} | {ll.item():12.6g} + {llv.item():12.6g} = {ll.item() + llv.item():12.6g}'
        if not in_line_search:
            self.ll = ll.item() + llv.item()
            self.n_iter += 1
            gain = (ll_prev - self.ll)
            line += f' | {gain:12.6g}'
        print(line, end='\r')

    ll += llv
    out = [ll]
    if grad:
        out.append(g)
    if hess:
        out.append(h)
    return tuple(out) if len(out) > 1 else out[0]
def min_dist(x, s, max_iter=2**16, tol=1e-6, steps=100):
    """Compute the minimum distance from a (set of) point(s) to a curve.

    The curve parameter is initialized by a discrete search over
    `steps` values in [0, 1] and then refined by Gauss-Newton
    iterations with a backtracking line search. Each point has its own
    step size, and a step is accepted for a point if it decreases the
    squared distance of that point.

    Parameters
    ----------
    x : (..., dim) tensor
        Coordinates
    s : BSplineCurve
        Parameterized curve
    max_iter : int, default=2**16
        Maximum number of Gauss-Newton iterations
    tol : float, default=1e-6
        Stop when the total decrease of the squared distances is
        below `tol` times the number of points.
    steps : int, default=100
        Number of curve parameters tested in the discrete search

    Returns
    -------
    t : (...) tensor
        Coordinate of the closest point
    d : (...) tensor
        Minimum distance between each point and the curve

    """
    # initialize using a discrete search
    all_t = torch.linspace(0, 1, steps, **utils.backend(x))
    t = x.new_zeros(x.shape[:-1])
    d = x.new_empty(x.shape[:-1]).fill_(float('inf'))
    for t1 in all_t:
        x1 = s.eval_position(t1)
        d1 = x1 - x
        d1 = d1.square_().sum(-1).sqrt_()
        t = torch.where(d1 < d, t1, t)
        d = torch.min(d, d1)

    # Fine tune using Gauss-Newton optimization
    # `d` already holds one distance per point: the per-point objective
    # is simply its square (no reduction, so that its shape matches the
    # objective computed in the line search).
    nll = d.square_()
    for n_iter in range(1, max_iter+1):
        # compute the distance between x and s(t) + gradients
        d, g = s.eval_grad_position(t)
        d.sub_(x)
        # The Gauss-Newton Hessian is the squared norm of the tangent
        # s'(t). It must be computed *before* `g` is overwritten by the
        # (scalar) gradient <s'(t), s(t) - x>.
        h = linalg.dot(g, g)
        g = linalg.dot(g, d)
        h.add_(1e-3)  # damping, keeps the step finite
        g.div_(h)

        # Perform GN step (with line search)
        # TODO: I could get rid of the line search
        t0 = t.clone()
        nll0 = nll
        armijo = torch.full_like(t, 1024)
        success = torch.zeros_like(t, dtype=torch.bool)
        for n_ls in range(12):
            # points that already succeeded keep their accepted value
            t = torch.where(success, t, t0 - armijo * g)
            t.clamp_(0, 1)
            d = s.eval_position(t).sub_(x)
            nll = d.square().sum(-1)
            success = success.logical_or_(nll < nll0)
            if success.all():
                break
            armijo = torch.where(success, armijo, armijo/2)
        # points whose line search failed are restored, along with
        # their objective value (otherwise a rejected, larger value
        # would become the reference of the next iteration)
        t = torch.where(success, t, t0)
        nll = torch.where(success, nll, nll0)
        if not success.any():
            break

        if (nll0 - nll).sum() < tol * t.numel():
            break

    d = s.eval_position(t).sub_(x)
    d = d.square_().sum(-1).sqrt_()
    return t, d