Example 1
def smart_pull_grid(vel, grid, type='disp', *args, **kwargs):
    """Interpolate a velocity/grid/displacement field.

    Notes
    -----
    Defaults differ from grid_pull:
    - bound -> dft
    - extrapolate -> True

    Parameters
    ----------
    vel : ([batch], *spatial, ndim) tensor
        Velocity
    grid : ([batch], *spatial, ndim) tensor
        Transformation field
    type : {'disp', 'grid'}, default='disp'
        Whether `vel` is a displacement field or a transformation grid
    kwargs : dict
        Options to ``grid_pull``

    Returns
    -------
    pulled_vel : ([batch], *spatial, ndim) tensor
        Velocity

    """
    if grid is None or vel is None:
        return vel
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    dim = vel.shape[-1]
    if type == 'grid':
        id = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                   **utils.backend(vel))
        vel = vel - id
    vel = utils.movedim(vel, -1, -dim - 1)
    vel_no_batch = vel.dim() == dim + 1
    grid_no_batch = grid.dim() == dim + 1
    if vel_no_batch:
        vel = vel[None]
    if grid_no_batch:
        grid = grid[None]
    vel = spatial.grid_pull(vel, grid, *args, **kwargs)
    vel = utils.movedim(vel, -dim - 1, -1)
    if vel_no_batch:
        vel = vel[0]
    if type == 'grid':
        id = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                   **utils.backend(vel))
        vel += id
    return vel
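# For reference, the `identity_grid` helper used throughout these examples can
# be sketched in pure PyTorch. This is a minimal stand-in (voxel coordinates,
# channels-last), not nitorch's actual implementation:
import torch

def identity_grid_sketch(shape, dtype=None, device=None):
    """Return a grid g such that g[i, j, ..., :] == (i, j, ...)."""
    dtype = dtype or torch.get_default_dtype()
    axes = [torch.arange(s, dtype=dtype, device=device) for s in shape]
    return torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)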
def make_data(shape, device, dtype):
    id = identity_grid(shape, dtype=dtype, device=device)
    id = id[None, ...]  # add batch dimension
    disp = torch.rand(id.shape, device=device, dtype=dtype)
    grid = id + disp
    vol = torch.rand((1, 1) + shape, device=device, dtype=dtype)
    return vol, grid
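# Quick shape check for `make_data` (assuming the `identity_grid` used above
# is importable):
vol, grid = make_data((8, 8, 8), device='cpu', dtype=torch.float32)
assert vol.shape == (1, 1, 8, 8, 8)
assert grid.shape == (1, 8, 8, 8, 3)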
Example 3
def smart_grad(tensor, grid, **opt):
    """Pull gradients iff grid is defined (+ add/remove batch dim).

    Parameters
    ----------
    tensor : (channels, *input_shape) tensor
        Input volume
    grid : (*output_shape, D) tensor or None
        Sampling grid

    Returns
    -------
    pulled : (channels, *output_shape) tensor
        Sampled volume

    """
    # if grid is None:
    #     opt.pop('extrapolate', None)
    #     opt.pop('interpolation', None)
    #     return spatial.diff(tensor, dim=3, **opt)
    if grid is None:
        grid = spatial.identity_grid(tensor.shape[-3:],
                                     dtype=tensor.dtype,
                                     device=tensor.device)
    out = spatial.grid_grad(tensor, grid, **opt)
    return out
Example 4
    def exp(self, velocity, displacement=False):
        """Generate a deformation grid from tangent parameters.

        Parameters
        ----------
        velocity : (batch, *spatial, nb_dim)
            Stationary velocity field
        displacement : bool, default=False
            Return a displacement field (voxel to shift) rather than
            a transformation field (voxel to voxel).

        Returns
        -------
        grid : (batch, *spatial, nb_dim)
            Deformation grid (transformation or displacement).

        """
        backend = dict(dtype=velocity.dtype, device=velocity.device)

        # generate grid
        shape = velocity.shape[1:-1]
        velocity_small = self.resize(velocity, type='displacement')
        grid = self.velexp(velocity_small)
        grid = self.resize(grid, shape=shape, type='grid')

        if displacement:
            grid = grid - spatial.identity_grid(grid.shape[1:-1], **backend)
        return grid
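# The `velexp` step above integrates a stationary velocity field, typically by
# scaling and squaring. Below is an illustrative pure-PyTorch sketch built on
# torch.nn.functional.grid_sample and its normalized [-1, 1] coordinates
# (nitorch works in voxel coordinates, so this is a conceptual stand-in, not
# the actual `velexp`):
import torch
import torch.nn.functional as F

def svf_exp_sketch(vel, steps=8):
    """vel : (batch, *spatial, dim) velocity in normalized units.
    Returns the transformation grid exp(vel), same shape as `vel`."""
    axes = [torch.linspace(-1, 1, s, dtype=vel.dtype, device=vel.device)
            for s in vel.shape[1:-1]]
    id_grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), -1)
    disp = vel / (2 ** steps)              # scaling step
    for _ in range(steps):                 # squaring: disp <- disp + disp o (id + disp)
        warped = F.grid_sample(disp.movedim(-1, 1),        # to channels-first
                               (id_grid + disp).flip(-1),  # grid_sample wants xyz order
                               align_corners=True, padding_mode='border')
        disp = disp + warped.movedim(1, -1)
    return id_grid + disp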
Example 5
 def exp(prm):
     disp = spatial.resize_grid(prm,
                                type='displacement',
                                shape=target.shape[2:],
                                interpolation=3,
                                bound='dft')
     grid = disp + spatial.identity_grid(target.shape[2:], **backend)
     return disp, grid
Example 6
def draw_curves(shape, s, mode='gaussian', tiny=0, **kwargs):
    """Draw multiple BSpline curves

    Parameters
    ----------
    shape : list[int]
    s : list[BSplineCurve]
    mode : {'binary', 'gaussian'}

    Returns
    -------
    x : (*shape) tensor
        Drawn curve
    lab : (*shape) tensor[int]
        Label of closest curve

    """
    s = list(s)
    x = identity_grid(shape, **utils.backend(s[0].waypoints))
    n = len(s)
    tiny = tiny / n
    l = x.new_zeros(shape, dtype=torch.long)
    if mode[0].lower() == 'b':
        s1 = s.pop(0)
        t, d = min_dist(x, s1, **kwargs)
        r = s1.eval_radius(t)
        c = d <= r
        l[c] = 1
        cnt = 1
        while s:
            cnt += 1
            s1 = s.pop(0)
            t, d = min_dist(x, s1, **kwargs)
            r = s1.eval_radius(t)
            c.bitwise_or_(d <= r)
            l[d <= r] = cnt
    else:
        s1 = s.pop(0)
        t, d = min_dist(x, s1, **kwargs)
        r = s1.eval_radius(t)
        c = dist_to_prob(d, r, tiny)
        l.fill_(1)
        cnt = 1
        p = c.clone()
        c = c.neg_().add_(1)
        while s:
            cnt += 1
            s1 = s.pop(0)
            t, d = min_dist(x, s1, **kwargs)
            r = s1.eval_radius(t)
            c1 = dist_to_prob(d, r, tiny)
            l[c1 > p] = cnt
            p = torch.maximum(c1, p)
            c.mul_(c1.neg_().add_(1))
        c = c.neg_().add_(1)
    return c, l
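# `dist_to_prob` maps a point-to-curve distance to a membership probability.
# The constant 2.355 ~ 2*sqrt(2*ln 2) used in the fitting code further below
# suggests a Gaussian tube profile whose full width at half maximum equals the
# tube diameter; a plausible sketch (not nitorch's exact code):
import torch

def dist_to_prob_sketch(d, r, tiny=0):
    sigma = r / (2.355 / 2)                 # FWHM = 2 * radius
    e = (-0.5 * (d / sigma).square()).exp()
    if tiny:
        e = e * (1 - 2 * tiny) + tiny       # keep probabilities away from {0, 1}
    return e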
Example 7
def gauss_kernel(f, dim):
    """Build a normalized `dim`-D Gaussian kernel with FWHM `f` (in voxels)."""
    s = f / math.sqrt(8. * math.log(2.)) + 1E-7   # FWHM -> standard deviation
    shape = math.ceil(4 * s)
    shape = shape + (shape % 2 == 0)              # force an odd kernel size
    g = identity_grid([shape] * dim)
    g -= shape // 2                               # center on the middle voxel
    g = g.square_().sum(-1)                       # squared distance to center
    g *= (-0.5 / (s**2))
    g.exp_()
    g /= g.sum()                                  # normalize to unit mass
    return g
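# Sanity check (assuming `identity_grid` is importable): the kernel integrates
# to one and is symmetric about its central voxel:
k = gauss_kernel(4.0, dim=2)   # FWHM of 4 voxels, 2-D kernel
assert abs(k.sum().item() - 1) < 1e-6
assert torch.allclose(k, k.flip(0))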
Example 8
 def _identity(x):
     """Build an identity grid with same shape/backend as a tensor.
     The grid is built such that coordinate zero is at the center of 
     the FOV."""
     shape = x.shape[2:]
     backend = dict(dtype=x.dtype, device=x.device)
     grid = spatial.identity_grid(shape, **backend)
     grid -= torch.as_tensor(shape, **backend) / 2.
     grid /= torch.as_tensor(shape, **backend) / 2.
     grid = last2channel(grid[None, ...])
     return grid
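# The normalization above maps voxel indices to grid_sample-style coordinates:
# along an axis of length n, values run from -1 to 1 - 2/n. A quick check
# (assuming `_identity` is reachable as a plain function and `last2channel`
# only permutes dimensions):
x = torch.zeros(1, 1, 4, 6)
g = _identity(x)
assert g.min().item() == -1.0            # first voxel of each axis
assert g.max().item() == 1.0 - 2.0 / 6   # last voxel of the longest axis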
Example 9
    def forward(self, batch=1, **overload):
        """

        Parameters
        ----------
        batch : int, default=1
            Batch shape.

        Other Parameters
        ----------------
        shape : sequence[int], optional
        device : torch.device, optional
        dtype : torch.dtype, optional

        Returns
        -------
        grid : (batch, *shape, 3) tensor
            Resampling grid

        """
        shape = overload.get('shape', self.grid.velocity.field.shape)
        dtype = overload.get('dtype', self.grid.velocity.field.dtype)
        device = overload.get('device', self.grid.velocity.field.device)
        backend = dict(dtype=dtype, device=device)

        if self.grid.velocity.field.amplitude == 0:
            grid = identity_grid(shape, **backend)
        else:
            grid = self.grid(batch, shape=shape, **backend)
        dtype = grid.dtype
        device = grid.device
        backend = dict(dtype=dtype, device=device)

        shape = grid.shape[1:-1]
        dim = len(shape)
        aff = self.affine(batch, dim=dim, **backend)

        # shift center of rotation
        aff_shift = torch.cat((
            torch.eye(dim, **backend),
            torch.as_tensor(shape, **backend)[:, None].sub_(1).div_(-2)),
            dim=1)
        aff_shift = as_euclidean(aff_shift)

        aff = affine_matmul(aff, aff_shift)
        aff = affine_lmdiv(aff_shift, aff)

        # compose
        aff = utils.unsqueeze(aff, dim=-3, ndim=dim)
        lin = aff[..., :dim, :dim]
        off = aff[..., :dim, -1]
        grid = linalg.matvec(lin, grid) + off

        return grid
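# The `aff_shift` conjugation above makes the affine act about the centre of
# the field of view rather than the corner voxel. The same trick in plain
# matrix form (a 2-D sketch with arbitrary shape and angle):
import math
import torch

shape = (64, 48)
c, s = math.cos(0.1), math.sin(0.1)
rot = torch.tensor([[c, -s, 0.], [s, c, 0.], [0., 0., 1.]])
shift = torch.eye(3)
shift[:2, -1] = -(torch.tensor(shape, dtype=torch.float) - 1) / 2
aff = torch.linalg.solve(shift, rot @ shift)  # == affine_lmdiv(shift, rot @ shift)
# `aff` now rotates coordinates about (shape - 1) / 2, the FOV centre.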
Example 10
    def forward(self, source, target, source_affine=None, target_affine=None):
        """

        Parameters
        ----------
        source : (sX, sY, sZ) tensor or str
        target : (tX, tY, tZ) tensor or str
        source_affine : (4, 4) tensor, optional
        target_affine : (4, 4) tensor, optional

        Returns
        -------
        warped : (tX, tY, tZ) tensor
            Source warped to target
        velocity : (vX, vY, vZ, 3) tensor
            Stationary velocity field
        affine : (4, 4) tensor, optional
            Affine of the velocity space

        """
        if self.verbose:
            print('Preprocessing... ', end='', flush=True)
        source, source_affine, source_orig, source_affine_orig \
            = self.load(source, source_affine)
        target, target_affine, target_orig, target_affine_orig \
            = self.load(target, target_affine)
        source = spatial.reslice(source, source_affine, target_affine,
                                 target.shape)
        if self.verbose:
            print('done.', flush=True)
            print('Registering... ', end='', flush=True)
        source = source[None, None]
        target = target[None, None]
        warped, vel, grid = super().forward(source, target)
        if self.verbose:
            print('done.', flush=True)
        del source, target, warped
        vel = vel[0]
        grid = grid[0]
        grid -= spatial.identity_grid(grid.shape[:-1],
                                      dtype=grid.dtype,
                                      device=grid.device)
        right_affine = target_affine.inverse() @ target_affine_orig
        right_affine = spatial.affine_grid(right_affine, target_orig.shape)
        grid = spatial.grid_pull(utils.movedim(grid, -1, 0),
                                 right_affine,
                                 bound='nearest',
                                 extrapolate=True)
        grid = utils.movedim(grid, 0, -1).add_(right_affine)
        left_affine = source_affine_orig.inverse() @ target_affine
        grid = spatial.affine_matvec(left_affine, grid)
        warped = spatial.grid_pull(source_orig, grid)
        return warped, vel, target_affine
Example 11
    def exp(self, velocity, affine=None, displacement=False):
        """Generate a deformation grid from tangent parameters.

        Parameters
        ----------
        velocity : (batch, *spatial, nb_dim)
            Stationary velocity field
        affine : (batch, nb_prm)
            Affine parameters
        displacement : bool, default=False
            Return a displacement field (voxel to shift) rather than
            a transformation field (voxel to voxel).

        Returns
        -------
        grid : (batch, *spatial, nb_dim)
            Deformation grid (transformation or displacement).

        """
        info = {'dtype': velocity.dtype, 'device': velocity.device}

        # generate grid
        shape = velocity.shape[1:-1]
        velocity_small = self.resize(velocity, type='displacement')
        grid = self.velexp(velocity_small)
        grid = self.resize(grid, shape=shape, type='grid')

        if affine is not None:
            # exponentiate
            affine_prm = affine
            affine = []
            for prm in affine_prm:
                affine.append(self.affexp(prm))
            affine = torch.stack(affine, dim=0)

            # shift center of rotation
            affine_shift = torch.cat(
                (torch.eye(self.dim, **info),
                 -torch.as_tensor(shape, **info)[:, None] / 2),
                dim=1)
            affine = spatial.affine_matmul(affine, affine_shift)
            affine = spatial.affine_lmdiv(affine_shift, affine)

            # compose
            affine = unsqueeze(affine, dim=-3, ndim=self.dim)
            lin = affine[..., :self.dim, :self.dim]
            off = affine[..., :self.dim, -1]
            grid = matvec(lin, grid) + off

        if displacement:
            grid = grid - spatial.identity_grid(grid.shape[1:-1], **info)

        return grid
Example 12
def smalldef(disp):
    """Transform a displacement grid into a transformation grid

    Parameters
    ----------
    disp : (*shape, dim)

    Returns
    -------
    grid : (*shape, dim)

    """
    id = identity_grid(disp.shape[:-1], dtype=disp.dtype, device=disp.device)
    return disp + id
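# Round-trip usage (assuming `identity_grid` is importable): subtracting the
# identity recovers the displacement exactly:
disp = torch.randn(8, 8, 2)
grid = smalldef(disp)
assert torch.allclose(grid - identity_grid((8, 8), dtype=disp.dtype), disp)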
Example 13
def _draw_curves_inv(shape, s, tiny=0):
    """prod_k (1 - p_k)"""
    s = list(s)
    x = identity_grid(shape, **utils.backend(s[0].waypoints))
    s1 = s.pop(0)
    t, d = min_dist(x, s1)
    r = s1.eval_radius(t)
    c = dist_to_prob(d, r, tiny=tiny).neg_().add_(1)
    while s:
        s1 = s.pop(0)
        t, d = min_dist(x, s1)
        r = s1.eval_radius(t)
        c.mul_(dist_to_prob(d, r, tiny=tiny).neg_().add_(1))
    return c
Example 14
def _laplacian_freq(shape, **backend):
    """
    Compute Fourier squared frequency on the lattice and its inverse.
    """
    dim = len(shape)
    shape = torch.as_tensor(shape, **backend)
    g = spatial.identity_grid(shape, **backend)
    g -= shape // 2
    g /= shape
    g = g.square_().sum(-1)
    if fft._torch_has_old_fft:
        g = g.unsqueeze(-1)
    g = fft.ifftshift(g, dim=list(range(dim)))
    ig = g.reciprocal()
    ig[(0,) * dim] = 0
    return g, ig
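# A typical use of the inverse frequencies: apply an inverse-Laplacian
# smoother spectrally. A minimal sketch (assuming a recent PyTorch, where the
# module's `fft` wrapper falls through to torch.fft):
import torch

g, ig = _laplacian_freq((32, 32), dtype=torch.float32)
y = torch.randn(32, 32)
# Solve (-Laplacian) x = y up to scale: dividing by the squared frequency is
# multiplying by its zero-safe reciprocal `ig`.
x = torch.fft.ifftn(torch.fft.fftn(y) * ig).real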
Example 15
def roi_closing(label, radius=10, dim=None):
    """Performs a multi-label morphological closing.

    Parameters
    ----------
    label : (..., *spatial) tensor[int]
        Volume of labels.
    radius : float, default=10
        Radius of the structuring element (in voxels)
    dim : int, default=label.dim()
        Number of spatial dimensions

    Returns
    -------
    closed_label : tensor[int]

    """
    from scipy.ndimage import distance_transform_edt, binary_closing

    dim = dim or label.dim()
    closest_label = torch.zeros_like(label)
    closest_dist = label.new_full(label.shape, float('inf'), dtype=torch.float)
    dist = torch.empty_like(closest_dist)

    for l in label.unique():
        if l == 0:
            continue
        if label.dim() == dim:
            dist = torch.as_tensor(distance_transform_edt(label != l))
        elif label.dim() == dim + 1:
            for z in range(len(dist)):
                dist[z] = torch.as_tensor(
                    distance_transform_edt(label[z] != l))
        else:
            raise NotImplementedError
        closest_label[dist < closest_dist] = l
        closest_dist = torch.min(closest_dist, dist)

    struct = spatial.identity_grid([2 * radius + 1] * dim).sub_(radius)
    struct = struct.square().sum(-1).sqrt() <= radius
    struct = utils.unsqueeze(struct, 0, label.dim() - dim)
    mask = binary_closing(label > 0, struct)
    mask = torch.as_tensor(mask).bitwise_not_()
    closest_label[mask] = 0

    return closest_label
Example 16
def voxelize_rois(rois, shape, roi_to_vox=None, device=None):
    """Create a volume of labels from a parametric ROI.

    Parameters
    ----------
    rois : dict
        Object returned by `read_asc`
    shape : sequence[int]
    roi_to_vox : (d+1, d+1) tensor

    Returns
    -------
    roi : (*shape) tensor[int]
    names : list[str]

    """
    out = torch.zeros(shape, dtype=torch.long)  # unlabelled voxels stay background (0)
    grid = spatial.identity_grid(shape[:2], device=device)
    if roi_to_vox is not None:
        roi_to_vox = roi_to_vox.to(device=device)

    names = list(rois['regions'].keys())

    for l, (name, shapes) in enumerate(rois['regions'].items()):
        print(name)
        label = l + 1
        for i, shape in enumerate(shapes):
            print(i + 1, '/', len(shapes), end='\r')
            vertices = [[p['x'], p['y'], p['z']] for p in shape['points']]
            vertices = torch.as_tensor(vertices, device=device)
            if roi_to_vox is not None:
                vertices = spatial.affine_matvec(roi_to_vox, vertices)
            z = vertices[0, 2].round().int().item()
            if not (0 <= z < out.shape[-1]):
                print('Contour not in FOV. Skipping it...')
                continue
            vertices = vertices[:, :2]
            faces = [(i, i + 1 if i + 1 < len(vertices) else 0)
                     for i in range(len(vertices))]

            mask = is_inside(grid, vertices, faces).cpu()
            out[..., z][mask] = label
        print('')

    return out, names
Example 17
def affine_grid_backward(*grad_hess, grid=None):
    """Converts ∇ wrt dense displacement into ∇ wrt affine matrix

    g = affine_grid_backward(g, [grid=None])
    g, h = affine_grid_backward(g, h, [grid=None])

    Parameters
    ----------
    grad : (..., *spatial, dim) tensor
        Gradient with respect to a dense displacement.
    hess : (..., *spatial, dim*(dim+1)//2) tensor, optional
        Hessian with respect to a dense displacement.
    grid : (*spatial, dim) tensor, optional
        Pre-computed identity grid

    Returns
    -------
    grad : (..., dim, dim+1) tensor
        Gradient with respect to an affine matrix
    hess : (..., dim, dim+1, dim, dim+1) tensor, optional
        Hessian with respect to an affine matrix

    """
    has_hess = len(grad_hess) > 1
    grad, *hess = grad_hess
    hess = hess.pop(0) if hess else None
    del grad_hess

    dim = grad.shape[-1]
    shape = grad.shape[-dim - 1:-1]
    batch = grad.shape[:-dim - 1]
    nvox = py.prod(shape)
    if grid is None:
        grid = spatial.identity_grid(shape, **utils.backend(grad))
    grid = grid.reshape([1, nvox, dim])
    grad = grad.reshape([-1, nvox, dim])
    if hess is not None:
        hess = hess.reshape([-1, nvox, dim * (dim + 1) // 2])
        grad, hess = _affine_grid_backward_gh(grid, grad, hess)
        hess = hess.reshape([*batch, dim, dim + 1, dim, dim + 1])
    else:
        grad = _affine_grid_backward_g(grid, grad)
    grad = grad.reshape([*batch, dim, dim + 1])
    return (grad, hess) if has_hess else grad
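# The gradient-only path has a simple closed form: if the displacement at
# voxel x is the top rows of A applied to the homogeneous coordinate of x, the
# gradient wrt A is the sum over voxels of outer products between the voxel
# gradients and those homogeneous coordinates. A self-contained sketch (not
# the `_affine_grid_backward_g` used above):
import torch

def affine_grid_backward_g_sketch(grad, grid):
    """grad : (batch, nvox, dim), grid : (1|batch, nvox, dim) -> (batch, dim, dim+1)"""
    hom = torch.cat([grid, grid.new_ones(*grid.shape[:-1], 1)], dim=-1)
    return grad.transpose(-1, -2) @ hom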
Example 18
def ffd_exp(prm, shape, order=3, bound='dft', returns='disp'):
    """Transform FFD parameters into a displacement or transformation grid.

    Parameters
    ----------
    prm : (..., *spatial, dim)
        FFD parameters
    shape : sequence[int]
        Exponentiated shape
    order : int, default=3
        Spline order
    bound : str, default='dft'
        Boundary condition
    returns : {'disp', 'grid', 'disp+grid'}, default='disp'
        What to return:
        - 'disp' -> displacement grid
        - 'grid' -> transformation grid
        - 'disp+grid' -> both

    Returns
    -------
    disp : (..., *shape, dim), optional
        Displacement grid
    grid : (..., *shape, dim), optional
        Transformation grid

    """
    backend = dict(dtype=prm.dtype, device=prm.device)
    dim = prm.shape[-1]
    batch = prm.shape[:-(dim + 1)]
    prm = prm.reshape([-1, *prm.shape[-(dim + 1):]])
    disp = resize_grid(prm,
                       type='displacement',
                       shape=shape,
                       interpolation=order,
                       bound=bound)
    disp = disp.reshape(batch + disp.shape[1:])
    grid = disp + identity_grid(shape, **backend)
    if 'disp' in returns and 'grid' in returns:
        return disp, grid
    elif 'disp' in returns:
        return disp
    elif 'grid' in returns:
        return grid
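# Usage sketch (assuming `resize_grid` and `identity_grid` are importable): a
# zero lattice of control points exponentiates to a zero displacement and an
# identity transformation.
prm = torch.zeros(5, 5, 2)   # 2-D FFD lattice of control points
disp, grid = ffd_exp(prm, (64, 64), returns='disp+grid')
assert disp.abs().max() == 0
assert torch.allclose(grid, identity_grid((64, 64), dtype=prm.dtype))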
Example 19
    def propagate_grad(self, g, h, moving, phi, left=None, right=None, inv=False):
        """Convert derivatives wrt warped image in loss space to
        to derivatives wrt parameters
        parameters:
            g (tensor) : gradient wrt warped image
            h (tensor) : hessian wrt warped image
            moving (Image) : moving image
            phi (tensor) : dense (exponentiated) displacement field
            left (matrix) : left affine
            right (matrix) : right affine
            inv (bool) : whether we're in a backward symmetric pass
        returns:
            g (tensor) : pushed gradient
            h (tensor) : pushed hessian
            gmu (tensor) : rotated spatial gradients
        """
        if inv:
            g = g.neg_()

        # build bits of warp
        dim = phi.shape[-1]
        fixed_shape = g.shape[-dim:]
        moving_shape = moving.shape

        # differentiate wrt δ in: Left o Phi o (Id + δ) o Right
        # we'll then propagate them through Phi by scaling and squaring
        if right is not None:
            right = spatial.affine_grid(right, fixed_shape)
        g = regutils.smart_push(g, right, shape=self.shape)
        h = regutils.smart_push(h, right, shape=self.shape)
        del right

        phi_left = spatial.identity_grid(self.shape, **utils.backend(phi))
        phi_left += phi
        if left is not None:
            phi_left = spatial.affine_matvec(left, phi_left)
        mugrad = moving.pull_grad(phi_left, rotate=False)
        del phi_left

        mugrad = _rotate_grad(mugrad, left, phi)

        return g, h, mugrad
Example 20
def draw_curves(shape, s, mode='gaussian', tiny=0, **kwargs):
    """Draw multiple BSpline curves

    Parameters
    ----------
    shape : list[int]
    s : list[BSplineCurve]
    mode : {'binary', 'gaussian'}

    Returns
    -------
    x : (*shape) tensor
        Drawn curve

    """
    s = list(s)
    x = identity_grid(shape, **utils.backend(s[0].waypoints))
    n = len(s)
    tiny = tiny / n
    if mode[0].lower() == 'b':
        s1 = s.pop(0)
        t, d = min_dist(x, s1, **kwargs)
        r = s1.eval_radius(t)
        c = d <= r
        while s:
            s1 = s.pop(0)
            t, d = min_dist(x, s1, **kwargs)
            r = s1.eval_radius(t)
            c.bitwise_or_(d <= r)
    else:
        s1 = s.pop(0)
        t, d = min_dist(x, s1, **kwargs)
        r = s1.eval_radius(t)
        c = dist_to_prob(d, r, tiny).neg_().add_(1)
        while s:
            s1 = s.pop(0)
            t, d = min_dist(x, s1, **kwargs)
            r = s1.eval_radius(t)
            c.mul_(dist_to_prob(d, r, tiny).neg_().add_(1))
        c = c.neg_().add_(1)
    return c
Example 21
def _get_dat_grid(dat, vx, samp, jitter=True, device='cpu'):
    """Get sub-sampled image data, and resampling grid.

    Parameters
    ----------
    dat : (X0, Y0, Z0) tensor_like
        Fixed image data.
    vx : (3,) tensor_like.
        Fixed voxel size.
    samp : int|float
        Sub-sampling level.
    jitter : bool, default=True
        Add random jittering to identity grid.

    Returns
    -------
    dat_samp : (X1, Y1, Z1) tensor_like
        Sub-sampled fixed image data.
    grid : (X1, Y1, Z1, 3) tensor_like
        Sub-sampled image data's resampling grid.

    """
    if isinstance(dat, (list, tuple)):
        dat = torch.zeros(dat, dtype=torch.float32, device=device)
    # Modulate samp with voxel size
    device = dat.device
    samp = torch.tensor((samp,) * 3).float().to(device)
    samp = torch.clamp(samp / vx, 1)
    # Create grid of fixed image, possibly sub-sampled
    grid = identity_grid(dat.shape,
        dtype=torch.float32, device=device)
    if jitter:
        torch.manual_seed(0)
        grid += torch.rand_like(grid)*samp
    # Sub-sampled
    samp = samp.round().int().tolist()
    grid = grid[::samp[0], ::samp[1], ::samp[2], ...]
    dat_samp = dat[::samp[0], ::samp[1], ::samp[2]]

    return dat_samp, grid
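# Usage sketch (assuming `identity_grid` is importable): sub-sampling keeps
# one grid point per kept voxel.
dat = torch.rand(64, 64, 64)
vx = torch.ones(3)
dat_samp, grid = _get_dat_grid(dat, vx, samp=3, jitter=False)
assert dat_samp.shape == grid.shape[:-1]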
Example 22
def dist_map(shape, dtype=None, device=None):
    """Return the squared distance between all pairs in a FOV.

    Parameters
    ----------
    shape : sequence[int]
    dtype : optional
    device : optional

    Returns
    -------
    dist : (prod(shape), prod(shape)) tensor
        Squared distance map

    """
    backend = dict(dtype=dtype, device=device)
    shape = py.make_tuple(shape)
    dim = len(shape)
    g = spatial.identity_grid(shape, **backend)
    g = g.reshape([-1, dim])
    g = (g[:, None, :] - g[None, :, :]).square_().sum(-1)
    return g
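# Worked example (assuming the nitorch-style `spatial`/`py` helpers above are
# importable): for a 2x2 FOV the flattened voxels are (0,0), (0,1), (1,0),
# (1,1), so opposite corners sit at squared distance 2.
d = dist_map((2, 2), dtype=torch.float32)
assert d.shape == (4, 4)
assert d[0, 3].item() == 2.0   # (0,0) vs (1,1)
assert d[0, 1].item() == 1.0   # (0,0) vs (0,1)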
Example 23
def draw_curve(shape, s, mode='gaussian', tiny=0, **kwargs):
    """Draw a BSpline curve

    Parameters
    ----------
    shape : list[int]
    s : BSplineCurve
    mode : {'binary', 'gaussian'}

    Returns
    -------
    x : (*shape) tensor
        Drawn curve

    """
    x = identity_grid(shape, **utils.backend(s.waypoints))
    t, d = min_dist(x, s, **kwargs)
    r = s.eval_radius(t)
    if mode[0].lower() == 'b':
        return d <= r
    else:
        return dist_to_prob(d, r, tiny)
Example 24
 def add_identity(cls, disp):
     dim = disp.shape[-1]
     shape = disp.shape[-dim-1:-1]
     return spatial.identity_grid(shape, **utils.backend(disp)).add_(disp)
Example 25
    def __call__(self,
                 logaff,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False,
                 in_line_search=False):
        """
        logaff : (..., nb) tensor, Lie parameters
        grad : Whether to compute and return the gradient wrt `logaff`
        hess : Whether to compute and return the Hessian wrt `logaff`
        gradmov : Whether to compute and return the gradient wrt `moving`
        hessmov : Whether to compute and return the Hessian wrt `moving`

        Returns
        -------
        ll : () tensor, loss value (objective to minimize)
        g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
        h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
        gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
        hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

        """
        # This method performs the forward pass, and computes
        # derivatives along the way.

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # jitter
        # if not hasattr(self, '_fixed'):
        #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
        #                                 jitter=True,
        #                                 **utils.backend(self.fixed))
        #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        #     del idj
        # fixed = self._fixed
        fixed = self.fixed

        # forward
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        with torch.no_grad():
            _, gaff = linalg._expm(logaff,
                                   self.basis,
                                   grad_X=True,
                                   hess_X=False)

        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        # /!\ derivatives are not "homogeneous" (they do not have a one
        # on the bottom right): we should *not* use affine_matmul and
        # such (I only lost a day...)
        gaff = torch.matmul(gaff, self.affine_fixed)
        gaff = linalg.lmdiv(self.affine_moving, gaff)
        # haff = torch.matmul(haff, self.affine_fixed)
        # haff = linalg.lmdiv(self.affine_moving, haff)
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape,
                                            **utils.backend(logaff),
                                            jitter=False)
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        if not grad and not hess:
            llx = self.loss.loss(warped, fixed)
        elif not hess:
            llx, grad = self.loss.loss_grad(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients + dot product with grid
        if grad is not False or hess is not False:
            mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
            grad = jg(mugrad, grad)
            if hess is not False:
                hess = jhj(mugrad, hess)
                grad, hess = regutils.affine_grid_backward(grad,
                                                           hess,
                                                           grid=self.id)
            else:
                grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
            dim2 = self.dim * (self.dim + 1)
            grad = grad.reshape([*grad.shape[:-2], dim2])
            gaff = gaff[..., :-1, :]
            gaff = gaff.reshape([*gaff.shape[:-2], dim2])
            grad = linalg.matvec(gaff, grad)
            if hess is not False:
                hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
                # haff = haff[..., :-1, :, :-1, :]
                # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
                hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
                hess = hess.abs().sum(-1).diag_embed()
            del mugrad

        # print objective
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
Example 26
    def __call__(self, logaff, grad=False, hess=False, in_line_search=False):
        """
        logaff : (..., nb) tensor, Lie parameters
        grad : Whether to compute and return the gradient wrt `logaff`
        hess : Whether to compute and return the Hessian wrt `logaff`

        Returns
        -------
        ll : () tensor, loss value (objective to minimize)
        g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
        h : (..., logaff) tensor, optional, Hessian wrt Lie parameters

        """
        # This method performs the forward pass, and computes
        # derivatives along the way.

        # select correct gradient mode
        if grad:
            logaff.requires_grad_()
            if logaff.grad is not None:
                logaff.grad.zero_()
        if grad and not torch.is_grad_enabled():
            with torch.enable_grad():
                return self(logaff, grad, in_line_search=in_line_search)
        elif not grad and torch.is_grad_enabled():
            with torch.no_grad():
                return self(logaff, grad, in_line_search=in_line_search)

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                   and (self.n_iter - 1) % logplot == 0

        # jitter
        # idj = spatial.identity_grid(self.fixed.shape[-self.dim:], jitter=True,
        #                             **utils.backend(self.fixed))
        # fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        # del idj
        fixed = self.fixed

        # forward
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape, **utils.backend(logaff))
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        llx = self.loss.loss(warped, fixed)
        del warped

        # print objective
        lll = llx
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [lll]
        if grad is not False:
            lll.backward()
            grad = logaff.grad.clone()
            out.append(grad)
        logaff.requires_grad_(False)
        return tuple(out) if len(out) > 1 else out[0]
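# The grad-mode re-entry trick at the top of this method (the function calls
# itself under enable_grad/no_grad so its body always runs in the right
# autograd mode) is useful on its own. The same pattern in isolation, as a
# hypothetical helper for a scalar-valued `fn`:
import torch

def value_and_grad(fn, x, grad=False):
    if grad and not torch.is_grad_enabled():
        with torch.enable_grad():
            return value_and_grad(fn, x, grad)
    if not grad and torch.is_grad_enabled():
        with torch.no_grad():
            return value_and_grad(fn, x, grad)
    x = x.detach().requires_grad_(grad)
    y = fn(x)
    if not grad:
        return y
    y.backward()
    return y.detach(), x.grad

val, g = value_and_grad(lambda t: (t ** 2).sum(), torch.tensor([1., 2.]), grad=True)
assert torch.equal(g, torch.tensor([2., 4.]))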
Example 27
def fit_curves_cat(f, s, vx=1, max_iter=8, tol=1e-8, max_levels=4):
    """Fit the set of curves that maximizes a Categorial likelihood

    Parameters
    ----------
    f : (*shape) tensor
        Observed grid of binary labels or smooth probabilities.
    s : list[BSplineCurve]
        Initial curves (will be modified in-place)

    Returns
    -------
    s : list[BSplineCurve]
        Fitted curves

    """

    TINY = 1e-6
    fig = elem = None
    backend = utils.backend(s[0].coeff)

    max_iter_position = 8
    max_iter_radius = 4

    vx = utils.make_vector(vx, f.dim(), **backend)
    vx0 = vx.clone()
    n0 = f.numel()

    # Build pyramid by restriction
    shapes = [f.shape]
    images = [f]
    vxs = [vx]
    for n_level in range(max_levels - 1):
        shape = [pymath.ceil(s / 2) for s in shapes[-1]]
        if all(s == 1 for s in shape):
            break
        shapes.append(shape)
        images.append(restrict(f.unsqueeze(-1), shapes[-1]).squeeze(-1))
        vx = vx * (torch.as_tensor(shapes[-2], **backend) /
                   torch.as_tensor(shapes[-1], **backend))
        vxs.append(vx)
        for s1 in s:
            s1.restrict(shapes[-2], shapes[-1])

    start = time.time()

    shape = None
    level = len(images) + 1
    while images:
        level -= 1
        print('-' * 16, 'level', level, '-' * 16)

        if shape is not None:
            for s1 in s:
                s1.prolong(shape, shapes[-1])
        f, shape, vx = images.pop(-1), shapes.pop(-1), vxs.pop(-1)
        x = identity_grid(f.shape, **backend)
        scl = vx.prod() / vx0.prod()

        def get_nll(e):
            ie = (1 - e).log()
            e = e.log()
            if f.dtype is torch.bool:
                ll = e[f].sum(dtype=torch.double) + ie[~f].sum(
                    dtype=torch.double)
            else:
                ll = (e * f).sum(dtype=torch.double) + (ie * (1 - f)).sum(
                    dtype=torch.double)
            return -ll

        nll = float('inf')
        max_iter_level = max_iter * 2**((level - 1) // 2)
        for n_iter in range(max_iter_level):

            nll0_prev = nll

            for n_curve in range(len(s)):

                s0 = s[n_curve]
                s1 = s[:n_curve] + s[n_curve + 1:]
                ie1 = _draw_curves_inv(f.shape, s1, TINY)

                for n_iter_position in range(max_iter_position):

                    t, d = min_dist(x, s0)
                    p = s0.eval_position(t).sub_(x)  # residuals
                    r = s0.eval_radius(t)
                    r = torch.as_tensor(r, **utils.backend(x))
                    e0 = dist_to_prob(d, r, TINY)
                    ome0 = 1 - e0
                    e = 1 - ome0 * ie1
                    nll_prev = nll
                    nll = get_nll(e)
                    lam = radius_to_prec(r)

                    # gradient of the categorical term
                    g = (1 - f / e) * e0 / ome0 * (-lam)
                    h = (e0 / ome0).square() * (1 - e) / e
                    g = g.unsqueeze(-1)
                    h = h.unsqueeze(-1)
                    lam = lam.unsqueeze(-1)

                    acc = 0.5
                    h = h * (lam * p).square()
                    if acc != 1:
                        h += (1 - acc) * g.abs()
                    g = g * p

                    # push
                    g = s0.push_position(g, t)
                    h = s0.push_position(h, t)

                    g *= scl
                    h *= scl
                    nll *= scl

                    g.div_(h)
                    s0.coeff -= g

                    wp = [ss.waypoints for ss in s]
                    fig, elem = plot_nll(nll, e, f, wp, fig, elem)
                    print('position', n_iter, n_curve, n_iter_position,
                          nll.item(), (nll_prev - nll).item() / n0)
                    s0.update_waypoints()
                    # if nll_prev - nll < tol * f.numel():
                    #     break

                if level < 3:
                    max_iter_radius_level = max_iter_radius
                else:
                    max_iter_radius_level = 0
                for n_iter_radius in range(max_iter_radius_level):

                    alpha = (2.355 / 2)**2
                    t, d = min_dist(x, s0)
                    r = s0.eval_radius(t)
                    r = torch.as_tensor(r, **utils.backend(x))
                    e0 = dist_to_prob(d, r)
                    ome0 = 1 - e0
                    e = 1 - ome0 * ie1
                    d = d.square_()
                    nll_prev = nll
                    nll = get_nll(e)

                    # gradient of the categorical term
                    alpha = alpha * d / r.pow(3)
                    g = (1 - f / e) * e0 / ome0 * alpha
                    h = e0 / ome0.square()

                    acc = 0
                    h *= alpha.square()
                    if acc != 1:
                        h += (1 - acc) * g.abs() * 3 / r

                    # push
                    g = s0.push_radius(g, t)
                    h = s0.push_radius(h, t)

                    g *= scl
                    h *= scl
                    nll *= scl

                    g.div_(h)
                    s0.coeff_radius -= g
                    s0.coeff_radius.clamp_min_(0.5)

                    wp = [ss.waypoints for ss in s]
                    fig, elem = plot_nll(nll, e, f, wp, fig, elem)
                    print('radius', n_iter, n_curve, n_iter_radius, nll.item(),
                          (nll_prev - nll).item() / n0)
                    s0.update_radius()
                    # if nll_prev - nll < tol * f.numel():
                    #     break

            if not n_iter % 10:
                print(n_iter, nll.item(), (nll0_prev - nll).item() / n0)
            # if abs(nll0_prev - nll) < tol * f.numel():
            #     print('Converged')
            #     break

        stop = time.time()
        print(stop - start)
Example 28
def fit_curve_cat(f,
                  s,
                  lam=0,
                  gamma=0,
                  vx=1,
                  max_iter=8,
                  tol=1e-8,
                  max_levels=4):
    """Fit the curve that maximizes the categorical likelihood

    Parameters
    ----------
    f : (*shape) tensor
        Observed grid of binary labels or smooth probabilities.
    s : BSplineCurve
        Initial curve (will be modified in-place)

    Other Parameters
    ----------------
    lam : float, default=0
        Centerline regularization (bending)
    gamma : float, default=0
        Radius regularization (membrane)
    vx : float, default=1
        Voxel size
    max_iter : int, default=8
        Maximum number of iterations per level
        (This will be multiplied by 2 at each resolution level, such that
        more iterations are used at coarser levels).
    tol : float, default=1e-8
        Unused
    max_levels : int, default=4
        Number of multi-resolution levels.

    Returns
    -------
    s : BSplineCurve
        Fitted curve

    """
    TINY = 1e-6
    fig = elem = None

    max_iter_position = 8
    max_iter_radius = 4

    backend = utils.backend(s.coeff)
    vx = utils.make_vector(vx, f.dim(), **backend)
    vx0 = vx.clone()
    n0 = f.numel()

    # Build pyramid by restriction
    shapes = [f.shape]
    images = [f]
    vxs = [vx]
    for n_level in range(max_levels - 1):
        shape = [pymath.ceil(s / 2) for s in shapes[-1]]
        if all(s == 1 for s in shape):
            break
        shapes.append(shape)
        images.append(restrict(f.unsqueeze(-1), shapes[-1]).squeeze(-1))
        s.restrict(shapes[-2], shapes[-1])
        vx = vx * (torch.as_tensor(shapes[-2], **backend) /
                   torch.as_tensor(shapes[-1], **backend))
        vxs.append(vx)

    start = time.time()

    shape = None
    level = len(images) + 1
    while images:
        level -= 1
        print('-' * 16, 'level', level, '-' * 16)

        if shape is not None:
            s.prolong(shape, shapes[-1])
        f, shape, vx = images.pop(-1), shapes.pop(-1), vxs.pop(-1)
        scl = vx.prod() / vx0.prod()
        x = identity_grid(f.shape, **backend)
        if lam:
            L = lam * bending3(len(s.coeff), **backend)
            reg = L.matmul(s.coeff).mul_(vx.square())
            reg = 0.5 * (s.coeff * reg).sum(dtype=torch.double)
        else:
            reg = 0
        if gamma:
            Lr = gamma * membrane3(len(s.coeff_radius), **backend)
            Lr /= vx.prod().pow_(1 / len(vx)).square_()
            reg_radius = Lr.matmul(s.coeff_radius)
            reg_radius = 0.5 * (s.coeff_radius *
                                reg_radius).sum(dtype=torch.double)
        else:
            reg_radius = 0

        def get_nll(e):
            ie = (1 - e).log()
            e = e.log()
            if f.dtype is torch.bool:
                ll = e[f].sum(dtype=torch.double) + ie[~f].sum(
                    dtype=torch.double)
            else:
                ll = (e * f).sum(dtype=torch.double) + (ie * (1 - f)).sum(
                    dtype=torch.double)
            ll = -ll
            return ll

        nll = float('inf')
        max_iter_level = max_iter * 2**((level - 1) // 2)
        for n_iter in range(max_iter_level):

            nll0_prev = nll

            for n_iter_position in range(max_iter_position):

                t, d = min_dist(x, s)
                p = s.eval_position(t).sub_(x)  # residuals
                r = s.eval_radius(t)
                r = torch.as_tensor(r, **utils.backend(x))
                e = dist_to_prob(d, r, tiny=TINY)
                nll_prev = nll
                nll = get_nll(e)
                prec = radius_to_prec(r)

                # gradient of the categorical term
                omf = (1 - f) if f.dtype.is_floating_point else f.bitwise_not()
                ome = (1 - e)
                g = (omf / ome - 1) * (-prec)
                h = omf * e / ome.square()
                g = g.unsqueeze(-1)
                h = h.unsqueeze(-1)
                prec = prec.unsqueeze(-1)

                acc = 0.5
                h = h * (prec * p).square()
                if acc != 1:
                    h += (1 - acc) * g.abs()
                g = g * p

                # push
                g = s.push_position(g, t)
                h = s.push_position(h, t)

                # resolution scale
                g *= scl
                h *= scl
                nll *= scl

                # regularisation + solve
                if lam:
                    reg = L.matmul(s.coeff).mul_(vx.square())
                    g += reg
                    reg = 0.5 * (s.coeff * reg).sum(dtype=torch.double)
                    # h += L[1, :].abs().sum()
                    g = torch.stack([
                        linalg.lmdiv(h1.diag() +
                                     (v1 * v1) * L, g1[:, None])[:, 0]
                        for v1, g1, h1 in zip(vx, g.T, h.T)
                    ], -1)
                else:
                    g.div_(h)
                    reg = 0
                s.coeff.sub_(g)
                # s.coeff.clamp_min_(0)
                # for d, sz in enumerate(f.shape):
                #     s.coeff[:, d].clamp_max_(sz-1)

                fig, elem = plot_nll([nll, reg, reg_radius], e, f, s.waypoints,
                                     fig, elem)
                nll = nll + reg + reg_radius
                print('position', n_iter, n_iter_position, nll.item(),
                      (nll_prev - nll).item() / n0)
                s.update_waypoints()
                # if nll_prev - nll < tol * f.numel():
                #     break

            if level < 3:
                max_iter_radius_level = max_iter_radius
            else:
                max_iter_radius_level = 0
            for n_iter_radius in range(max_iter_radius_level):

                alpha = (2.355 / 2)**2
                t, d = min_dist(x, s)
                r = s.eval_radius(t)
                r = torch.as_tensor(r, **utils.backend(x))
                e = dist_to_prob(d, r, TINY)
                d = d.square_()
                nll_prev = nll
                nll = get_nll(e)

                # gradient of the categorical term
                omf = (1 - f) if f.dtype.is_floating_point else f.bitwise_not()
                ome = (1 - e)
                alpha = alpha * d / r.pow(3)
                g = (omf / ome - 1) * alpha

                acc = 0
                h = omf * e / ome.square()
                h *= alpha.square()
                if acc != 1:
                    h += (1 - acc) * g.abs() * 3 / r

                # push
                g = s.push_radius(g, t)
                h = s.push_radius(h, t)

                # resolution scale
                g *= scl
                h *= scl
                nll *= scl

                # regularisation + solve
                if gamma:
                    reg_radius = Lr.matmul(s.coeff_radius)
                    g += reg_radius
                    reg_radius = 0.5 * (s.coeff_radius *
                                        reg_radius).sum(dtype=torch.double)
                    g = linalg.lmdiv(h.diag() + Lr, g[:, None])[:, 0]
                else:
                    g.div_(h)
                    reg_radius = 0

                # solve
                s.coeff_radius -= g
                s.coeff_radius.clamp_min_(0.5)

                fig, elem = plot_nll([nll, reg, reg_radius], e, f, s.waypoints,
                                     fig, elem)
                nll = nll + reg + reg_radius
                print('radius', n_iter, n_iter_radius, nll.item(),
                      (nll_prev - nll).item() / n0)
                s.update_radius()
                # if nll_prev - nll < tol * f.numel():
                #     break

            if not n_iter % 10:
                print(n_iter, nll.item(), (nll0_prev - nll).item() / n0)
            # if nll0_prev - nll < tol * f.numel():
            #     print('Converged')
            #     break

    stop = time.time()
    print(stop - start)
Example 29
def fit_curve_joint(f, s, max_iter=128, tol=1e-8):
    """Fit the curve that maximizes the joint probability p(f) * p(s)

    Parameters
    ----------
    f : (*shape) tensor
        Observed grid of binary labels or smooth probabilities.
    s : BSplineCurve
        Initial curve (will be modified in-place)
    max_iter : int, default=128
    tol : float, default=1e-8

    Returns
    -------
    s : BSplineCurve
        Fitted curve

    """

    x = identity_grid(f.shape, **utils.backend(s.coeff))

    max_iter_position = 10
    max_iter_radius = 3
    sumf = f.sum(dtype=torch.double)

    def get_nll(e):
        if f.dtype is torch.bool:
            return sumf + e.sum(
                dtype=torch.double) - 2 * e[f].sum(dtype=torch.double)
        else:
            return sumf + e.sum(
                dtype=torch.double) - 2 * (e * f).sum(dtype=torch.double)

    start = time.time()
    nll = float('inf')
    for n_iter in range(max_iter):

        nll0_prev = nll

        for n_iter_position in range(max_iter_position):

            t, d = min_dist(x, s)
            p = s.eval_position(t).sub_(x)  # residuals
            r = s.eval_radius(t)
            r = torch.as_tensor(r, **utils.backend(x))
            e = dist_to_prob(d, r)
            nll_prev = nll
            nll = get_nll(e)
            lam = radius_to_prec(r)

            # gradient of the categorical term
            g = e * (1 - 2 * f) * (-lam)
            g = g.unsqueeze(-1)
            lam = lam.unsqueeze(-1)
            # e = e.unsqueeze(-1)
            # h = g.abs() + e * (lam * p).square()
            h = g.abs() * (1 + lam * p.square())
            g = g * p

            # push
            g = s.push_position(g, t)
            h = s.push_position(h, t)
            g.div_(h)
            s.coeff -= g

            # print('position', n_iter, n_iter_position,
            #       nll.item(), (nll_prev - nll).item() / f.numel())
            if nll_prev - nll < tol * f.numel():
                break

        for n_iter_position in range(max_iter_radius):

            alpha = (2.355 / 2)**2
            t, d = min_dist(x, s)
            r = s.eval_radius(t)
            r = torch.as_tensor(r, **utils.backend(x))
            e = dist_to_prob(d, r)
            d = d.square_()
            nll_prev = nll
            nll = get_nll(e)

            # gradient of the categorical term
            g = e * (1 - 2 * f) * (alpha * d / r.pow(3))
            h = g.abs() * (alpha * d / r.pow(3)) * (1 + 3 / r)

            # push
            g = s.push_radius(g, t)
            h = s.push_radius(h, t)
            g.div_(h)
            s.coeff_radius -= g
            s.coeff_radius.clamp_min_(0.5)

            # print('radius', n_iter, n_iter_position,
            #       nll.item(), (nll_prev - nll).item() / f.numel())
            if nll_prev - nll < tol * f.numel():
                break

        if not n_iter % 10:
            print(n_iter, nll.item(), (nll0_prev - nll).item() / f.numel())
        if abs(nll0_prev - nll) < tol * f.numel():
            print('Converged')
            break

    stop = time.time()
    print(stop - start)
Example 30
    def forward(self, grid, **overload):
        """

        Parameters
        ----------
        grid : (N, *spatial, dim)
            Displacement grid
        overload : dict

        Returns
        -------
        aff : (N, dim+1, dim+1)
            Affine matrix that is closest to grid in the least square sense

        """
        shift = overload.get('shift', self.shift)
        grid = torch.as_tensor(grid)
        info = dict(dtype=grid.dtype, device=grid.device)
        nb_dim = grid.shape[-1]
        shape = grid.shape[1:-1]

        if shift:
            affine_shift = torch.cat((torch.eye(
                nb_dim, **info), -torch.as_tensor(shape, **info)[:, None] / 2),
                                     dim=1)
            affine_shift = spatial.as_euclidean(affine_shift)

        # the forward model is:
        #   phi(x) = M\A*M*x
        # where phi is a *transformation* field, M is the shift matrix
        # and A is the affine matrix.
        # We can decompose phi(x) = x + d(x), where d is a *displacement*
        # field, yielding:
        #   d(x) = M\A*M*x - x = (M\A*M - I)*x := B*x
        # If we write `d(x)` and `x` as large vox*(dim+1) matrices `D`
        # and `G`, we have:
        #   D = G*B'
        # Therefore, the least squares B is obtained as:
        #   B' = inv(G'*G) * (G'*D)
        # Then, A is
        #   A = M*(B + I)/M
        #
        # Finally, we project the affine matrix to its tangent space:
        #   prm[k] = <log(A), B[k]>
        # where <X,Y> = trace(X'*Y) is the Frobenius inner product.

        def igg(identity):
            # Compute inv(g*g'), where g has homogeneous coordinates.
            #   Instead of appending ones, we compute each element of
            #   the block matrix ourselves:
            #       [[g'*g,   g'*1],
            #        [1'*g,   1'*1]]
            #    where 1'*1 = N, the number of voxels.
            g = identity.reshape([identity.shape[0], -1, nb_dim])
            nb_vox = torch.as_tensor([[[g.shape[1]]]], **info)
            sumg = g.sum(dim=1, keepdim=True)
            gg = torch.matmul(g.transpose(-1, -2), g)
            gg = torch.cat((gg, sumg), dim=1)
            sumg = sumg.transpose(-1, -2)
            sumg = torch.cat((sumg, nb_vox), dim=1)
            gg = torch.cat((gg, sumg), dim=2)
            return gg.inverse()

        def gd(identity, disp):
            # compute g'*d, where g and d have homogeneous coordinates.
            #       [[g'*d,   g'*1],
            #        [1'*d,   1'*1]]
            g = identity.reshape([identity.shape[0], -1, nb_dim])
            d = disp.reshape([disp.shape[0], -1, nb_dim])
            nb_vox = torch.as_tensor([[[g.shape[1]]]], **info)
            sumg = g.sum(dim=1, keepdim=True)
            sumd = d.sum(dim=1, keepdim=True)
            gd = torch.matmul(g.transpose(-1, -2), d)
            gd = torch.cat((gd, sumd), dim=1)
            sumg = sumg.transpose(-1, -2)
            sumg = torch.cat((sumg, nb_vox), dim=1)
            sumg = sumg.expand([d.shape[0], sumg.shape[1], sumg.shape[2]])
            gd = torch.cat((gd, sumg), dim=2)
            return gd

        def eye(d):
            x = torch.eye(d, **info)
            z = x.new_zeros([1, d], **info)
            x = torch.cat((x, z), dim=0)
            z = x.new_zeros([d + 1, 1], **info)
            x = torch.cat((x, z), dim=1)
            return x

        identity = spatial.identity_grid(shape, **info)[None, ...]
        affine = torch.matmul(igg(identity), gd(identity, grid))
        affine = affine.transpose(-1, -2) + eye(nb_dim)
        affine = affine[..., :-1, :]
        if shift:
            affine = spatial.as_euclidean(affine)
            affine = spatial.affine_matmul(affine_shift, affine)
            affine = spatial.as_euclidean(affine)
            affine = spatial.affine_rmdiv(affine, affine_shift)
        affine = spatial.affine_make_square(affine)

        return affine
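# The normal equations described in the comments above can be sanity-checked
# in a few lines: build a displacement grid from a known matrix B and recover
# it by least squares (standalone, without the shift handling):
import torch

shape = (8, 8)
axes = [torch.arange(s, dtype=torch.float32) for s in shape]
g = torch.stack(torch.meshgrid(*axes, indexing='ij'), -1).reshape(-1, 2)
g = torch.cat([g, torch.ones(len(g), 1)], 1)            # homogeneous G, (N, 3)
B = torch.tensor([[0.1, 0.0, 2.0], [0.0, -0.2, 1.0]])   # known (dim, dim+1)
d = g @ B.T                                              # displacements D = G * B'
B_hat = torch.linalg.solve(g.T @ g, g.T @ d).T           # B' = inv(G'G) * (G'D)
assert torch.allclose(B, B_hat, atol=1e-5)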