Exemple #1
0
def depth_to_rgb(image, colormap=None):
    """Convert a soft depth map (probabilities along depth) to an RGB image.

    Each depth index is assigned a colour from `colormap`. The output pixel
    is the probability-weighted mean of these colours, multiplied by the
    maximum probability along depth (so low-confidence pixels are darker).

    Parameters
    ----------
    image : (*batch, D, H, W) tensor
        A (batch of) 3D image, with depth along the 'D' dimension.
        Presumably non-negative (soft probabilities) — TODO confirm.
    colormap : (D, 3) tensor or str, optional
        A colormap or the name of a matplotlib colormap.

    Returns
    -------
    image : (*batch, H, W, 3) tensor
        A (batch of) RGB image, clamped to [0, 1].

    """

    *batch, depth, height, width = image.shape
    colormap = _get_colormap_depth(colormap, depth, image.dtype, image.device)

    # move depth last: (*batch, H, W, D)
    image = utils.movedim(image, -3, -1)
    # probability-weighted sum of colours: (*batch, H, W, 3)
    cimage = linalg.dot(image.unsqueeze(-2), colormap.T)
    # Normalize by the total mass along depth. Where the mass is exactly
    # zero, divide by one instead: otherwise 0/0 = NaN, which `clamp_`
    # below would not remove.
    mass = image.sum(-1, keepdim=True)
    cimage /= mass.masked_fill_(mass == 0, 1)
    cimage *= image.max(-1, keepdim=True).values

    return cimage.clamp_(0, 1)
Exemple #2
0
    def forward(self, q, k, v, **overload):
        """Patch-wise dot-product attention.

        Keys and values are unfolded into patches of size `kernel_size`.
        Each query is compared (dot product) with every key of its patch.
        The scores are turned into weights by a softmax over the patch, and
        the output is the weighted sum of the patch values.

        Parameters
        ----------
        q : (b, c, *spatial)
            Queries
        k : (b, c, *spatial)
            Keys
        v : (b, c, *spatial)
            Values
        **overload
            Optional per-call values for `kernel_size`, `stride`,
            `padding` and `padding_mode` (default: module attributes).

        Returns
        -------
        x : (b, c, *spatial)

        """
        kernel_size = overload.pop('kernel_size', self.kernel_size)
        # NOTE(review): the stride falls back to `self.kernel_size`, not to a
        # dedicated stride attribute — confirm this is intended.
        stride = overload.pop('stride', self.kernel_size)
        padding = overload.pop('padding', self.padding)
        padding_mode = overload.pop('padding_mode', self.padding_mode)

        ndim = q.dim() - 2

        # Only keys and values are padded; queries keep their shape.
        if padding == 'auto':
            k, v = [spatial.pad_same(ndim, x, kernel_size, bound=padding_mode)
                    for x in (k, v)]
        elif padding:
            pad_size = [0, 0] + py.make_list(padding, ndim)
            k, v = [utils.pad(x, pad_size, side='both', mode=padding_mode)
                    for x in (k, v)]

        kernel_size = py.make_list(kernel_size, ndim)

        def patches(x):
            # unfold, then merge all trailing patch dimensions into one
            x = utils.unfold(x, kernel_size, stride)
            return x.reshape([*x.shape[:ndim + 2], -1])

        # attention weights from query/key dot products
        keys = utils.movedim(patches(k), 1, -1)
        queries = utils.movedim(q[..., None], 1, -1)
        attn = math.softmax(linalg.dot(keys, queries), dim=-1)
        attn = attn[:, None]  # add back channel dimension

        # new values from weight/value dot products
        return linalg.dot(attn, patches(v))
Exemple #3
0
def sample_prior_soft(prior, fixed, out=None):
    r"""Evaluate a (J, K) prior at soft fixed labels.

    out[..., k, n] = \sum_j prior[..., j, k] * fixed[..., j, n]

    Parameters
    ----------
    prior : (*B, J, K) tensor
    fixed : (*B, J, N) tensor
    out : (*B, K, N) tensor, optional
        Output placeholder, forwarded to `linalg.dot(..., out=out)`.

    Returns
    -------
    out : (*B, K, N) tensor
    """
    # (raw docstring: "\s" is an invalid escape in a regular string literal)
    prior = prior.transpose(-1, -2).unsqueeze(-2)  # [*B, K, 1, J]
    fixed = fixed.transpose(-1, -2).unsqueeze(-3)  # [*B, 1, N, J]
    # contract over the shared last (J) dimension -> [*B, K, N]
    out = linalg.dot(prior, fixed, out=out)
    return out
Exemple #4
0
def scatter_prior_soft(prior, fixed, z):
    """Accumulate soft responsibilities against soft fixed labels.

    prior[..., j, k] = sum_n z[..., k, n] * fixed[..., j, n]

    Parameters
    ----------
    prior : (*B, J, K) tensor
        Output placeholder, passed as `out` to `linalg.dot`.
    fixed : (*B, J, N) tensor
    z : (*B, K, N) tensor

    Returns
    -------
    prior : (*B, J, K) tensor
    """
    responsibilities = z.unsqueeze(-3)   # [*B, 1, K, N]
    labels = fixed.unsqueeze(-2)         # [*B, J, 1, N]
    # broadcast to [*B, J, K, N] and contract over N
    return linalg.dot(responsibilities, labels, out=prior)
Exemple #5
0
def min_dist(x, s, max_iter=2**16, tol=1e-6, steps=100):
    """Compute the minimum distance from a (set of) point(s) to a curve.

    A coarse exhaustive search over `steps` equally spaced parameters in
    [0, 1] is refined by Gauss-Newton iterations. The backtracking line
    search acts on the *total* squared distance.

    Parameters
    ----------
    x : (..., dim) tensor
        Coordinates
    s : BSplineCurve
        Parameterized curve. Must provide `eval_position(t)` and
        `eval_grad_position(t)`; the latter is expected to return the
        position and its derivative with respect to `t` (assumed shape
        (..., dim) — TODO confirm).
    max_iter : int, default=2**16
        Maximum number of Gauss-Newton iterations.
    tol : float, default=1e-6
        Stop when the decrease of the total squared distance, per point,
        falls below this value.
    steps : int, default=100
        Number of samples used by the initial discrete search.

    Returns
    -------
    t : (...) tensor
        Coordinate of the closest point
    d : (...) tensor
        Minimum distance between each point and the curve

    """
    # initialize using a discrete search
    all_t = torch.linspace(0, 1, steps, **utils.backend(x))
    t = x.new_zeros(x.shape[:-1])
    d = x.new_empty(x.shape[:-1]).fill_(float('inf'))
    for t1 in all_t:
        x1 = s.eval_position(t1)
        d1 = x1 - x
        d1 = d1.square_().sum(-1).sqrt_()
        t = torch.where(d1 < d, t1, t)
        d = torch.min(d, d1)

    # Fine tune using Gauss-Newton optimization.
    # Total squared distance, accumulated in double precision like the
    # objective evaluated inside the line search.
    nll = d.square_().sum(dtype=torch.double)
    for n_iter in range(max_iter):
        # residual r = s(t) - x and curve derivative s'(t)
        d, jac = s.eval_grad_position(t)
        d.sub_(x)
        # Gauss-Newton step <s', r> / <s', s'>, computed per point. The
        # Hessian must be built from the derivative itself, not from the
        # already-reduced gradient.
        g = linalg.dot(jac, d)
        h = linalg.dot(jac, jac)
        h.add_(1e-3)  # small damping term (avoids division by ~0)
        g.div_(h)

        # Perform GN step (with line search)
        # TODO: I could get rid of the line search
        armijo = 1
        t0 = t.clone()
        nll0 = nll
        success = False
        for n_ls in range(12):
            t = torch.sub(t0, g, alpha=armijo, out=t)
            t.clamp_(0, 1)
            d = s.eval_position(t).sub_(x)
            nll = d.square().sum(dtype=torch.double)
            if nll < nll0:
                success = True
                break
            armijo /= 2
        if not success:
            t = t0
            break

        if (nll0 - nll) < tol * t.numel():
            break

    d = s.eval_position(t).sub_(x)
    d = d.square_().sum(-1).sqrt_()
    return t, d
Exemple #6
0
def emmi_hard(moving,
              fixed,
              dim=None,
              prior=None,
              fwhm=None,
              max_iter=32,
              weights=None,
              grad=True,
              hess=True,
              return_prior=False):
    """EM-based mutual information between a soft moving image and a
    discretized (hard) fixed image.

    A joint prior `prior` (J fixed classes x K moving classes) is estimated
    by expectation-maximization. The negative mutual information it implies
    is returned, optionally with derivatives with respect to `moving`.

    Parameters
    ----------
    moving : tensor
        Moving probabilities with K classes. Inputs are reshaped by
        `emmi_prepare`; after preparation `moving` is used as
        (*batch, K, N). The raw input layout is presumably
        (K, *spatial) — TODO confirm against `emmi_prepare`.
    fixed : tensor
        Discretized fixed image (label j[n] at each voxel).
    dim : int, optional
        Number of spatial dimensions. Default: `moving.dim() - 1`.
    prior : tensor, optional
        Initial joint prior; after `emmi_prepare` it is (*batch, J, K).
    fwhm : float, optional
        If set, the prior is smoothed with `spatial.smooth` after
        transposing it to (*batch, K, J) (presumably along J — TODO confirm).
    max_iter : int, default=32
        Maximum number of EM iterations.
    weights : tensor, optional
        Voxel weights (processed by `emmi_prepare`).
    grad, hess : bool, default=True
        Also return the gradient / Hessian with respect to `moving`.
    return_prior : bool, default=False
        Also return the final (MI-normalized) prior.

    Returns
    -------
    ll : scalar tensor
        Negative mutual information, divided by the total weight `Nm`.
    grad : tensor, if `grad`
    hess : tensor, if `hess`
        Built with `sym_outer(g, -2)` (presumably a packed symmetric
        matrix — TODO confirm).
    prior : tensor, if `return_prior`
    """
    # ------------------------------------------------------------------
    #           PREPARATION
    # ------------------------------------------------------------------
    tiny = 1e-16
    dim = dim or (moving.dim() - 1)
    moving, fixed, weights, prior, shape = emmi_prepare(
        moving, fixed, weights, prior, dim)

    *batch, J, K = prior.shape
    Nb = len(batch)  # number of batch dimensions (forwarded to add_tiny_)
    N = moving.shape[-1]
    Nm = weights.sum()  # total weight, used to normalize ll / derivatives

    # ------------------------------------------------------------------
    #           EM LOOP
    # ------------------------------------------------------------------
    ll = -float('inf')
    z = moving.new_empty([*batch, K, N])
    prior0 = torch.empty_like(prior)
    for n_iter in range(max_iter):
        ll_prev = ll
        # --------------------------------------------------------------
        # E-step
        # ------
        # estimate responsibilities of each moving cluster for each
        # fixed voxel using Bayes' rule:
        #   p(z[n] == k | x[n] == j[n]) ∝ p(x[n] == j[n] | z[n] == k) p(z[n] == k)
        #
        # . j[n] is the discretized fixed image
        # . p(z[n] == k) is the moving template
        # . p(x[n] == j[n] | z[n] == k) is the conditional prior evaluated at (j[n], k)
        # --------------------------------------------------------------
        z = sample_prior(prior, fixed, z)
        z *= moving

        # --------------------------------------------------------------
        # compute log-likelihood (log_sum of the posterior)
        # ll = Σ_n log p(x[n] == j[n])
        #    = Σ_n log Σ_k p(x[n] == j[n] | z[n] == k)  p(z[n] == k)
        #    = Σ_n log Σ_k p(z[n] == k | x[n] == j) + constant{\z}
        # --------------------------------------------------------------
        ll = z.sum(-2, keepdim=True)
        ll = add_tiny_(ll, Nb)
        z /= ll  # normalize responsibilities over the K classes
        ll = ll.log_().mul_(weights).sum([-1, -2], dtype=torch.double)

        z *= weights

        # --------------------------------------------------------------
        # M-step
        # ------
        # estimate joint prior by maximizing Q = E_{Z;H,mu}[ln p(X, Z; H)]
        # => H_jk = p(x == j, z == k) ∝ Σ_n p(z[n] == k | x[n] == j) 𝛿(x[n] == j)
        # --------------------------------------------------------------
        prior0 = scatter_prior(prior0, fixed, z)
        prior.copy_(prior0).add_(tiny)
        # make it a joint distribution
        prior /= add_tiny_(prior.sum(dim=[-1, -2], keepdim=True), Nb)

        if fwhm:
            # smooth "prior" for the prior
            prior = prior.transpose(-1, -2)
            prior = spatial.smooth(prior,
                                   dim=1,
                                   basis=0,
                                   fwhm=fwhm,
                                   bound='replicate')
            prior = prior.transpose(-1, -2)

        # prior /= prior.sum(dim=[-1, -2], keepdim=True)
        # MI-like normalization
        prior /= add_tiny_(
            prior.sum(dim=-1, keepdim=True) * prior.sum(dim=-2, keepdim=True),
            Nb)
        # stop when the weighted log-likelihood gain is negligible
        if ll - ll_prev < 1e-5 * Nm:
            break

    # compute mutual information (times number of observations)
    # > prior contains p(x,y)/(p(x) p(y))
    # > prior0 contains N * p(x,y)
    # >> 1/N \sum_{j,k} prior0[j,k] * log(prior[j,k])
    #    = \sum_{x,y} p(x,y) * (log p(x,y) - log p(x) - log p(y)
    #    = \sum_{x,y} p(x,y) * log p(x,y)
    #       - \sum_{xy} p(x,y) log p(x)
    #       - \sum_{xy} p(x,y) log p(y)
    #    = \sum_{x,y} p(x,y) * log p(x,y)
    #       - \sum_{x} p(x) log p(x)
    #       - \sum_{y} p(y) log p(y)
    #    = -H[x,y] + H[x] + H[y]
    #    = MI[x, y]
    # NOTE(review): if `add_tiny_` works in place (as its trailing underscore
    # suggests), this also modifies `prior`, which is returned below when
    # `return_prior` is set — confirm this is acceptable.
    ll = -(prior0 * add_tiny_(prior, Nb).log()).sum() / Nm
    out = [ll]

    # ------------------------------------------------------------------
    #           GRADIENTS
    # ------------------------------------------------------------------
    # compute gradients
    # Keeping only terms that depend on y, the mutual information is H[y]-H[x,y]
    # The objective function is \sum_n E[y_n]
    # > ll = \sum_n log p(x[n] == j[n], h)
    #      = \sum_n log \sum_k p(x[n] == j[n] | z[n] == k, h) p(z[n] == k)
    if grad or hess:

        # g[k, n] = prior[j[n], k] / Σ_k' prior[j[n], k'] moving[k', n]
        g = sample_prior(prior, fixed)
        norm = linalg.dot(g.transpose(-1, -2), moving.transpose(-1, -2))
        norm = add_tiny_(norm, Nb).unsqueeze(-2).reciprocal_()
        g *= norm
        if hess:
            # computed before `g` is weighted/negated below
            h = sym_outer(g, -2)

        if grad:
            g *= weights
            g /= -Nm
            # g.neg_()
            g = g.reshape([*g.shape[:-1], *shape])
            out.append(g)
        if hess:
            h *= weights
            h /= Nm
            h = h.reshape([*h.shape[:-1], *shape])
            out.append(h)

    if return_prior:
        out.append(prior)

    return out[0] if len(out) == 1 else tuple(out)
Exemple #7
0
def emmi_soft(moving,
              fixed,
              dim=None,
              prior=None,
              fwhm=None,
              max_iter=32,
              weights=None,
              grad=True,
              hess=True,
              return_prior=False):
    """EM-based mutual information between a soft moving image and a
    soft (probabilistic) fixed image.

    A joint prior `prior` (J fixed classes x K moving classes) is estimated
    by expectation-maximization (computed in log space in the E-step). The
    negative mutual information it implies is returned, optionally with
    derivatives with respect to `moving`.

    Parameters
    ----------
    moving : tensor
        Moving probabilities with K classes; after `emmi_prepare` it is
        used as (*batch, K, N).
    fixed : tensor
        Soft fixed labels; after `emmi_prepare` it is presumably
        (*batch, J, N) — TODO confirm against `emmi_prepare`.
    dim : int, optional
        Number of spatial dimensions. Default: `moving.dim() - 1`.
    prior : tensor, optional
        Initial joint prior; after `emmi_prepare` it is (*batch, J, K).
    fwhm : float, optional
        If set, smooth the prior (transposed to (*batch, K, J)) with
        `spatial.smooth`.
    max_iter : int, default=32
        Maximum number of EM iterations.
    weights : tensor, optional
        Voxel weights (processed by `emmi_prepare`).
    grad, hess : bool, default=True
        Also return the gradient / Hessian with respect to `moving`.
    return_prior : bool, default=False
        Also return the final (MI-normalized) prior.

    Returns
    -------
    ll : scalar tensor
        Negative mutual information, divided by the total weight `Nm`.
    grad : tensor, if `grad`
        Gradient of `ll`.
    hess : tensor, if `hess`
        K*(K+1)//2 packed Hessian of `ll`: diagonal first, then the
        off-diagonal pairs (k, kk>k) in row-major order.
    prior : tensor, if `return_prior`
    """
    # ------------------------------------------------------------------
    #           PREPARATION
    # ------------------------------------------------------------------
    tiny = 1e-16
    dim = dim or (moving.dim() - 1)
    moving, fixed, weights, prior, shape = emmi_prepare(
        moving, fixed, weights, prior, dim)

    *batch, J, K = prior.shape
    Nb = len(batch)
    N = moving.shape[-1]
    Nm = weights.sum()

    # ------------------------------------------------------------------
    #           EM LOOP
    # ------------------------------------------------------------------
    ll = -float('inf')
    z = moving.new_empty([*batch, K, N])
    prior0 = torch.empty_like(prior)
    for n_iter in range(max_iter):
        ll_prev = ll
        # --------------------------------------------------------------
        # E-step (log space): log z ∝ Σ_j fixed[j] log prior[j, k] + log mu_k
        # --------------------------------------------------------------
        z = sample_prior(prior.log(), fixed, z)
        z += moving.log()
        z, ll = math.softmax_lse(z, -2, lse=True, weights=weights)

        # --------------------------------------------------------------
        # M-step
        # ------
        # estimate joint prior by maximizing Q = E_{Z;H,mu}[ln p(X, Z; H)]
        # => H_jk = p(x == j, z == k) ∝ Σ_n p(z[n] == k | x[n] == j) 𝛿(x[n] == j)
        # --------------------------------------------------------------
        z *= weights
        prior0 = scatter_prior(prior0, fixed, z)
        prior.copy_(prior0).add_(tiny)
        # make it a joint distribution
        prior /= add_tiny_(prior.sum(dim=[-1, -2], keepdim=True), Nb)
        if fwhm:
            # smooth "prior" for the prior
            prior = prior.transpose(-1, -2)
            prior = spatial.smooth(prior,
                                   dim=1,
                                   basis=0,
                                   fwhm=fwhm,
                                   bound='replicate')
            prior = prior.transpose(-1, -2)
        # MI-like normalization
        prior /= prior.sum(dim=-1, keepdim=True) * prior.sum(dim=-2,
                                                             keepdim=True)
        if ll - ll_prev < 1e-5 * Nm:
            break

    # compute mutual information (times number of observations)
    # > prior contains p(x,y)/(p(x) p(y))
    # > prior0 contains N * p(x,y)
    # >> 1/N Σ_{j,k} prior0[j,k] * log(prior[j,k]) = MI[x, y]
    ll = -(prior0 * prior.log()).sum() / Nm
    out = [ll]

    # ------------------------------------------------------------------
    #           GRADIENTS
    # ------------------------------------------------------------------
    # Objective (up to constants), with S[j, n] = Σ_k prior[j, k] mu[k, n]:
    #   L = Σ_n Σ_j fixed[j, n] log S[j, n]
    #   dL/dmu[k, n]        = Σ_j fixed[j, n] prior[j, k] / S[j, n]
    #  -d²L/dmu[k]dmu[kk]   = Σ_j fixed[j, n] prior[j, k] prior[j, kk] / S[j, n]²
    if grad or hess:

        # norm[..., j, n] = 1 / S[j, n]:
        #   (*B, J, 1, K) . (*B, 1, N, K) -> (*B, J, N)
        norm = linalg.dot(
            prior.unsqueeze(-2),
            moving.transpose(-1, -2).unsqueeze(-3))
        norm = norm.add_(tiny).reciprocal_()
        g = sample_prior(prior, fixed * norm)

        if hess:
            # w[..., j, n] = fixed[j, n] / S[j, n]²
            norm = norm.square_().mul_(fixed)
            h = moving.new_zeros([*g.shape[:-2], K * (K + 1) // 2, N])
            for j in range(J):
                norm1 = norm[..., j, None, :]     # (*B, 1, N)
                prior1 = prior[..., j, :, None]   # (*B, K, 1)
                h[..., :K, :] += prior1.square() * norm1
                c = K
                for k in range(K):
                    for kk in range(k + 1, K):
                        h[..., c, :] += (prior1[..., k, :] *
                                         prior1[..., kk, :] *
                                         norm1[..., 0, :])
                        c += 1

        # `ll` is the negated objective divided by Nm: scale derivatives
        # the same way so that they are consistent with the returned loss.
        if grad:
            g *= weights
            g /= -Nm
            g = g.reshape([*g.shape[:-1], *shape])
            out.append(g)
        if hess:
            h *= weights
            h /= Nm
            h = h.reshape([*h.shape[:-1], *shape])
            out.append(h)

    if return_prior:
        out.append(prior)

    return out[0] if len(out) == 1 else tuple(out)
Exemple #8
0
    def forward(self, predicted, reference, mask=None):
        """Compute the (soft) Dice loss.

        Parameters
        ----------
        predicted : (batch, nb_class[-1], *spatial) tensor
            Predicted classes.
        reference : (batch, nb_class[-1]|1, *spatial) tensor
            Reference classes (or their expectation).
                * If `reference` has a floating point data type (`half`,
                  `float`, `double`) it is assumed to hold one-hot or
                  soft labels, and its channel dimension should be
                  `nb_class` or `nb_class - 1`.
                * If `reference` has an integer or boolean data type, it is
                  assumed to hold hard labels and its channel dimension
                  should be 1.
        mask : (nb_batch, 1, *spatial) tensor, optional
            Loss mask

        Returns
        -------
        loss : scalar or tensor
            The output shape depends on the type of reduction used.
            If 'mean' or 'sum', this function returns a scalar.

        Notes
        -----
        When `self.weighted` is neither a boolean nor falsy, it is expanded
        to one weight per class, *including* the background class. The
        background weight is dropped when `exclude_background` is set.

        """
        logit = self.logit
        implicit = self.implicit
        weighted = self.weighted
        exclude_background = self.exclude_background

        predicted = torch.as_tensor(predicted)
        reference = torch.as_tensor(reference, device=predicted.device)
        backend = dict(dtype=predicted.dtype, device=predicted.device)
        dim = predicted.dim() - 2

        # if only one predicted class -> must be implicit
        implicit = implicit or (predicted.shape[1] == 1)

        # take softmax if needed
        predicted = get_prob_explicit(predicted,
                                      logit=logit,
                                      implicit=implicit)

        nb_classes = predicted.shape[1]
        spatial_dims = list(range(2, predicted.dim()))
        first_index = 1 if exclude_background else 0

        # prepare weights: False, True (weights from reference volumes) or
        # a (1, nb_classes) tensor of user-provided class weights
        if not torch.is_tensor(weighted) and not weighted:
            weighted = False
        if not isinstance(weighted, bool):
            weighted = utils.make_vector(weighted, nb_classes, **backend)[None]

        # preprocess reference
        if reference.dtype.is_floating_point:
            # one-hot labels
            reference = reference.to(predicted.dtype)
            implicit_ref = reference.shape[1] == nb_classes - 1
            reference = get_prob_explicit(reference, implicit=implicit_ref)
            if reference.shape[1] != nb_classes:
                raise ValueError('Number of classes not consistent. '
                                 'Expected {} or {} but got {}.'.format(
                                     nb_classes, nb_classes - 1,
                                     reference.shape[1]))

            if exclude_background:
                predicted = predicted[:, 1:]
                reference = reference[:, 1:]
            if mask is not None:
                predicted = predicted * mask
                reference = reference * mask
            predicted = predicted.reshape([*predicted.shape[:-dim], -1])
            reference = reference.reshape([*reference.shape[:-dim], -1])
            inter = linalg.dot(predicted, reference)
            sumpred = predicted.sum(-1)
            sumref = reference.sum(-1)
            union = sumpred + sumref
            loss = -2 * inter / union.clamp_min_(1e-5)
            del inter, union
            if weighted is not False:
                if weighted is True:
                    weights = sumref / sumref.sum(dim=1, keepdim=True)
                else:
                    # user weights cover every class: drop the background
                    # weight when the background channel was dropped
                    weights = weighted[:, first_index:]
                loss = loss * weights

        else:
            # hard labels
            loss = []
            weights = []
            for index in range(first_index, nb_classes):
                pred1 = predicted[:, None, index, ...]
                ref1 = reference == index
                if mask is not None:
                    pred1 = pred1 * mask
                    ref1 = ref1 * mask

                inter = math.sum(pred1 * ref1, dim=spatial_dims)
                union = math.sum(pred1 + ref1, dim=spatial_dims)
                loss1 = -2 * inter / union.clamp_min_(1e-5)
                del inter, union
                if weighted is not False:
                    if weighted is True:
                        weight1 = ref1.sum()
                    else:
                        # `weighted` has shape (1, nb_classes)
                        weight1 = float(weighted[0, index])
                    loss1 = loss1 * weight1
                    weights.append(weight1)
                loss.append(loss1)

            loss = torch.cat(loss, dim=1)
            if weighted is True:
                weights = sum(weights)
                loss = loss / weights

        loss += 1
        return super().forward(loss)
Exemple #9
0
def is_inside(points, vertices, faces=None):
    """Test if a point is inside a polygon/surface.

    The polygon or surface *must* be closed.

    Parameters
    ----------
    points : (..., dim) tensor
        Coordinates of points to test
    vertices : (nv, dim) tensor
        Vertex coordinates
    faces : (nf, dim) tensor[int]
        Faces are encoded by the indices of its vertices.
        By default, assume that vertices are ordered and define a closed curve

    Returns
    -------
    check : (...) tensor[bool]

    """
    # This function uses a ray-tracing technique:
    #
    #   A half-line is started in each point. If it crosses an odd
    #   number of faces, the point is inside the shape. If it crosses an
    #   even number of faces, it is not.
    #
    #   In practice, we loop through faces (as we expect there are far
    #   fewer faces than points) and compute intersection points between
    #   all lines and each face in a batched fashion. We only want to
    #   send these rays in one direction, so we only keep points whose
    #   intersection lies on one side of the ray (see `halfmask`).
    #   The ray direction is drawn at random (`torch.randn`).

    points = torch.as_tensor(points)
    vertices = torch.as_tensor(vertices)
    if faces is None:
        faces = [(i, i + 1) for i in range(len(vertices) - 1)]
        faces += [(len(vertices) - 1, 0)]
        faces = utils.as_tensor(faces, dtype=torch.long)

    points, vertices = utils.to_max_dtype(points, vertices)
    points, vertices, faces = utils.to_max_device(points, vertices, faces)
    backend = utils.backend(points)
    batch = points.shape[:-1]
    dim = points.shape[-1]
    eps = constants.eps(points.dtype)
    cross = points.new_zeros(batch, dtype=torch.long)

    ray = torch.randn(dim, **backend)

    for face in faces:
        face = vertices[face]

        # compute normal vector
        origin = face[0]
        if dim == 3:
            u = face[1] - face[0]
            v = face[2] - face[0]
            norm = torch.stack([
                u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2],
                u[0] * v[1] - u[1] * v[0]
            ])
        else:
            assert dim == 2
            u = face[1] - face[0]
            norm = torch.stack([-u[1], u[0]])

        # Skip faces (nearly) parallel to the ray. Written as
        # `not (cosine >= eps)` so that degenerate faces (zero normal,
        # hence NaN cosine) are skipped as well.
        cosine = linalg.dot(ray, norm).abs() / (ray.norm() * norm.norm())
        if not (cosine >= eps):
            continue

        # compute intersection between ray and plane
        #   plane: <norm, x - origin> = 0
        #   line: x = p + t*ray
        #   => t = -<norm, p - origin> / <norm, ray>
        # (`intersection` below first holds -t)
        intersection = linalg.dot(norm, points - origin)
        intersection /= linalg.dot(norm, ray)
        halfmask = intersection >= 0  # we only want to shoot in one direction
        intersection = intersection[halfmask]
        halfpoints = points[halfmask]
        intersection = intersection[..., None] * (-ray)
        intersection += halfpoints

        # check if the intersection is inside the face, by expressing it
        # in the (non-orthonormal) frame (origin; u[, v])
        intersection -= origin
        if dim == 3:
            # barycentric coordinates (a, b) such that w = a*u + b*v:
            #   <w, u> = a <u, u> + b <u, v>
            #   <w, v> = a <u, v> + b <v, v>
            # det = |u x v|^2 > 0 since degenerate faces were skipped
            uu = linalg.dot(u, u)
            uv = linalg.dot(u, v)
            vv = linalg.dot(v, v)
            wu = linalg.dot(intersection, u)
            wv = linalg.dot(intersection, v)
            det = uu * vv - uv * uv
            a = (vv * wu - uv * wv) / det
            b = (uu * wv - uv * wu) / det
            intersection = (a >= 0) & (b > 0) & (a + b < 1)
        else:
            intersection = linalg.dot(intersection, u)
            intersection /= u.norm().square_()
            intersection = (intersection >= 0) & (intersection < 1)

        cross[halfmask] += intersection

    # a point is inside iff its number of crossings is odd
    cross = cross.bitwise_and_(1).bool()
    return cross
Exemple #10
0
def derivatives_intensity(moving, fixed, prior, weights=None):
    """Gradient and Hessian of the negative EM-MI objective w.r.t. `moving`.

    Objective (up to constants), with S[j, n] = Σ_k prior[j, k] mu[k, n]:
        L = Σ_n log S[j[n], n]                  (discrete fixed labels)
        L = Σ_n Σ_j fixed[j, n] log S[j, n]      (soft fixed labels)

    Parameters
    ----------
    moving : (B, K, N) tensor
    fixed : (B, J|1, N) tensor
        Integer labels (discrete version) or floating-point soft labels.
    prior : (B, J, K) tensor
    weights : (B, 1, N) tensor, optional

    Returns
    -------
    grad : (B, K, *spatial) tensor
        Weighted gradient of -L.
    hess : (B, K*(K+1)//2, *spatial) tensor
        Weighted Hessian of -L, packed as: the K diagonal terms, then the
        off-diagonal terms (k, kk > k) in row-major order.

    """
    # NOTE: `tiny` and `t` are expected from the enclosing module scope.
    # ------------------------------------------------------------------
    #       PREPARATION
    # ------------------------------------------------------------------
    moving, fixed, weights = spatial_prepare(moving, fixed, weights)
    B, K, J, spatial = spatial_shapes(moving, fixed, prior)
    N = spatial.numel()

    # Flatten
    moving = moving.reshape([*moving.shape[:2], -1])
    fixed = fixed.reshape([*fixed.shape[:2], -1])
    weights = weights.reshape([*weights.shape[:2], -1])

    g = moving.new_zeros([B, K, N])
    K2 = K * (K + 1) // 2
    h = moving.new_zeros([B, K2, N])

    # ------------------------------------------------------------------
    #       VERSION 1: DISCRETE LABELS
    # ------------------------------------------------------------------
    if not fixed.dtype.is_floating_point:
        # g[k, n] = prior[j[n], k] / S[j[n], n]
        sample_prior(prior, fixed, g)
        norm = linalg.dot(t(g), t(moving)).unsqueeze(1)
        norm = norm.add_(tiny).reciprocal_()
        g *= norm

        # h = g g^T (packed)
        torch.mul(g[:, :K], g[:, :K], out=h[:, :K])
        c = K
        for k in range(K):
            for kk in range(k + 1, K):
                torch.mul(g[:, k], g[:, kk], out=h[:, c])
                c += 1

    # ------------------------------------------------------------------
    #       VERSION 2: SOFT LABELS
    # ------------------------------------------------------------------
    else:
        for j in range(J):
            prior_j = prior[:, j, :, None]                  # (B, K, 1)
            # S[j, n], kept as (B, 1, N) so that it broadcasts along the
            # class axis (a (B, N) tensor would align B with K)
            norm = (prior_j * moving).sum(1, keepdim=True)
            tmp = prior_j / norm.add_(tiny)                 # (B, K, N)
            fixed_j = fixed[:, j, None, :]                  # (B, 1, N)
            g += tmp * fixed_j

            h[:, :K, :] += tmp.square() * fixed_j
            c = K
            for k in range(K):
                for kk in range(k + 1, K):
                    h[:, c, :] += tmp[:, k, :] * tmp[:, kk, :] \
                                    * fixed_j[:, 0, :]
                    c += 1

    g *= weights
    g.neg_()
    g = g.reshape([B, K, *spatial])
    h *= weights
    h = h.reshape([B, K2, *spatial])

    return g, h
Exemple #11
0
    def __call__(self, vel, grad=False, hess=False, in_line_search=False):
        """Evaluate the symmetric (forward + inverse) objective.

        The positive image is deformed by the exponentiated velocity and the
        negative image by its inverse. The similarity loss is evaluated both
        ways and averaged, and the regularization `0.5 * <vel, reg(vel)>`
        is added.

        Parameters
        ----------
        vel : tensor
            Velocity (or displacement) field with a leading singleton
            dimension (it is indexed with `vel[0]`).
        grad : bool, default=False
            Also return the gradient with respect to `vel`.
        hess : bool, default=False
            Also return the Hessian (as computed below). Only meaningful
            together with `grad`.
        in_line_search : bool, default=False
            Restore the loss state afterwards and do not update the
            iteration counter / stored objective.

        Returns
        -------
        ll : tensor
            Objective value.
        g : tensor, if `grad`
        h : tensor, if `hess`

        """
        vel = vel[0]

        phi, iphi, jac, ijac = _exp_1d(vel, model=self.model)

        pos, grad_pos = _deform_1d(self.pos, phi, grad=grad)
        del phi

        neg, grad_neg = _deform_1d(self.neg, iphi, grad=grad)
        del iphi

        if self.modulation:
            pos *= jac
            neg *= ijac
            if self.model == 'svf':
                grad_pos *= jac
                grad_neg *= ijac

        g = ig = h = ih = None
        state = self.loss.get_state()
        if grad and hess:
            ll, g, h = self.loss.loss_grad_hess(pos, neg, mask=self.mask)
            ill, ig, ih = self.loss.loss_grad_hess(neg, pos, mask=self.mask)
        elif grad:
            ll, g = self.loss.loss_grad(pos, neg, mask=self.mask)
            ill, ig = self.loss.loss_grad(neg, pos, mask=self.mask)
        else:
            ll = self.loss.loss(pos, neg, mask=self.mask)
            ill = self.loss.loss(neg, pos, mask=self.mask)
        if in_line_search:
            self.loss.set_state(state)

        ll += ill
        ll /= 2
        if grad:

            if self.modulation:
                pos /= jac
                neg /= ijac

            # move channel channels to the end so that we can use `dot`
            g = utils.movedim(g, 0, -1)
            ig = utils.movedim(ig, 0, -1)
            pos = utils.movedim(pos, 0, -1)
            neg = utils.movedim(neg, 0, -1)
            grad_pos = utils.movedim(grad_pos, 0, -1)
            grad_neg = utils.movedim(grad_neg, 0, -1)

            g0, ig0 = g, ig

            g = linalg.dot(g0, grad_pos)
            if self.modulation:
                g = g.mul_(jac)
                g0 = linalg.dot(g0, pos)
                if self.model == 'svf':
                    g0.mul_(jac)
                g += _div_1d(g0)
            ig = linalg.dot(ig0, grad_neg)
            if self.modulation:
                ig = ig.mul_(ijac)
                ig0 = linalg.dot(ig0, neg)
                if self.model == 'svf':
                    ig0.mul_(ijac)
                ig += _div_1d(ig0)
            del g0, ig0

            if hess:
                h = utils.movedim(h, 0, -1)
                ih = utils.movedim(ih, 0, -1)
                h0, ih0 = h, ih

                grad_pos.square_()
                grad_neg.square_()

                h = linalg.dot(grad_pos, h0)
                if self.modulation:
                    jac.square_()
                    h = h.mul_(jac)
                    h0 = linalg.dot(pos, h0).square_()
                    if self.model == 'svf':
                        # `jac` is already squared: the gradient factor
                        # `jac` becomes `jac**2` in the Hessian
                        h0.mul_(jac)
                    h += _div_1d(_div_1d(h0))
                ih = linalg.dot(grad_neg, ih0)
                if self.modulation:
                    ijac.square_()
                    ih = ih.mul_(ijac)
                    ih0 = linalg.dot(neg, ih0).square_()
                    if self.model == 'svf':
                        # same as the forward term: `ijac` is already
                        # squared, so multiply once (ijac**2, not ijac**4)
                        ih0.mul_(ijac)
                    ih += _div_1d(_div_1d(ih0))
                del h0, ih0

            if self.model == 'svf':
                g, h = spatial.exp1d_backward(vel, g, h, bound=BND)
                ig, ih = spatial.exp1d_backward(-vel, ig, ih, bound=BND)

            g = g.sub_(ig).div_(2)
            g = g[None]
            if hess:
                h = h.add_(ih).div_(2)
                h = h[None]

        del ig, ih, grad_pos, grad_neg, jac, ijac
        vel = vel[None]

        # add regularization term
        vgrad = self.reg(vel)
        llv = 0.5 * vel.flatten().dot(vgrad.flatten())
        if grad:
            g += vgrad
        del vgrad

        # print objective
        if self.verbose and (self.verbose > 1 or not in_line_search):
            ll_prev = self.ll
            if in_line_search:
                line = '(search) | '
            else:
                line = '(topup)  | '
            line += f'{self.n_iter:03d} | {ll.item():12.6g} + {llv.item():12.6g} = {ll.item() + llv.item():12.6g}'
            if not in_line_search:
                self.ll = ll.item() + llv.item()
                self.n_iter += 1
                gain = (ll_prev - self.ll)
                line += f' | {gain:12.6g}'
            print(line, end='\r')

        ll += llv
        out = [ll]
        if grad:
            out.append(g)
        if hess:
            out.append(h)
        return tuple(out) if len(out) > 1 else out[0]
Exemple #12
0
def min_dist(x, s, max_iter=2**16, tol=1e-6, steps=100):
    """Compute the minimum distance from a (set of) point(s) to a curve.

    A coarse exhaustive search over `steps` equally spaced parameters in
    [0, 1] is refined by Gauss-Newton iterations with a per-point
    backtracking line search.

    Parameters
    ----------
    x : (..., dim) tensor
        Coordinates
    s : BSplineCurve
        Parameterized curve. Must provide `eval_position(t)` and
        `eval_grad_position(t)`; the latter is expected to return the
        position and its derivative with respect to `t` (assumed shape
        (..., dim) — TODO confirm).
    max_iter : int, default=2**16
        Maximum number of Gauss-Newton iterations.
    tol : float, default=1e-6
        Stop when the mean decrease of the squared distance falls below
        this value.
    steps : int, default=100
        Number of samples used by the initial discrete search.

    Returns
    -------
    t : (...) tensor
        Coordinate of the closest point
    d : (...) tensor
        Minimum distance between each point and the curve

    """
    # initialize using a discrete search
    all_t = torch.linspace(0, 1, steps, **utils.backend(x))
    t = x.new_zeros(x.shape[:-1])
    d = x.new_empty(x.shape[:-1]).fill_(float('inf'))
    for t1 in all_t:
        x1 = s.eval_position(t1)
        d1 = x1 - x
        d1 = d1.square_().sum(-1).sqrt_()
        t = torch.where(d1 < d, t1, t)
        d = torch.min(d, d1)

    # Fine tune using Gauss-Newton optimization.
    # `nll` is the squared distance of each point (shape x.shape[:-1]);
    # `d` already holds per-point distances, so no reduction here.
    nll = d.square()
    for n_iter in range(1, max_iter + 1):
        # residual r = s(t) - x and curve derivative s'(t)
        d, jac = s.eval_grad_position(t)
        d.sub_(x)
        # Gauss-Newton step <s', r> / <s', s'>, computed per point. The
        # Hessian must be built from the derivative itself, not from the
        # already-reduced gradient.
        g = linalg.dot(jac, d)
        h = linalg.dot(jac, jac)
        h.add_(1e-3)  # small damping term (avoids division by ~0)
        g.div_(h)

        # Perform GN step (with per-point backtracking line search)
        # TODO: I could get rid of the line search
        t0 = t.clone()
        nll0: torch.Tensor = nll
        armijo = torch.full_like(t, 1024)
        success = torch.zeros_like(t, dtype=torch.bool)
        for n_ls in range(12):
            t = torch.where(success, t, t0 - armijo * g)
            t.clamp_(0, 1)
            d = s.eval_position(t).sub_(x)
            nll = d.square().sum(-1)
            success = success.logical_or_(nll < nll0)
            if success.all():
                break
            armijo = torch.where(success, armijo, armijo / 2)
        # Points whose line search failed keep their previous parameter
        # *and* objective, so the next iteration (and the tolerance check)
        # compares against the state actually kept.
        t = torch.where(success, t, t0)
        nll = torch.where(success, nll, nll0)
        if not success.any():
            break

        if (nll0 - nll).sum() < tol * t.numel():
            break

    d = s.eval_position(t).sub_(x)
    d = d.square_().sum(-1).sqrt_()
    return t, d