Exemple #1
0
    def forward(self, source, target, source_affine=None, target_affine=None):
        """Register `source` to `target` and warp the original source.

        Both inputs go through `self.load`, which returns a working
        image/affine pair and the original image/affine pair. The working
        source is resliced onto the working target lattice, the parent
        class' `forward` performs the registration, and the resulting
        transformation is then re-expressed so that the ORIGINAL source is
        sampled on the ORIGINAL target lattice.

        Parameters
        ----------
        source : (sX, sY, sZ) tensor or str
            Moving image (or a str handled by `self.load`).
        target : (tX, tY, tZ) tensor or str
            Fixed image (or a str handled by `self.load`).
        source_affine : (4, 4) tensor, optional
            Orientation matrix of the source (presumably voxel-to-world).
        target_affine : (4, 4) tensor, optional
            Orientation matrix of the target (presumably voxel-to-world).

        Returns
        -------
        warped : (tX, tY, tZ) tensor
            Source warped to target (on the lattice of the original target)
        velocity : (vX, vY, vZ, 3) tensor
            Stationary velocity field
        affine : (4, 4) tensor, optional
            Affine of the velocity space (the working target affine)

        """
        if self.verbose:
            print('Preprocessing... ', end='', flush=True)
        # `self.load` returns (working image, working affine,
        # original image, original affine) for each input.
        source, source_affine, source_orig, source_affine_orig \
            = self.load(source, source_affine)
        target, target_affine, target_orig, target_affine_orig \
            = self.load(target, target_affine)
        # Put the working source on the working target lattice so that the
        # parent registration sees two images with the same shape.
        source = spatial.reslice(source, source_affine, target_affine,
                                 target.shape)
        if self.verbose:
            print('done.', flush=True)
            print('Registering... ', end='', flush=True)
        # Add batch and channel axes expected by the parent `forward`.
        source = source[None, None]
        target = target[None, None]
        warped, vel, grid = super().forward(source, target)
        if self.verbose:
            print('done.', flush=True)
        # Free the working images: only the original ones are used below.
        del source, target, warped
        vel = vel[0]
        grid = grid[0]
        # Transformation -> displacement (assumes `grid` holds voxel
        # coordinates on the working target lattice -- TODO confirm).
        grid -= spatial.identity_grid(grid.shape[:-1],
                                      dtype=grid.dtype,
                                      device=grid.device)
        # Original-target voxels -> working-target voxels, as a dense grid
        # defined on the original target lattice.
        right_affine = target_affine.inverse() @ target_affine_orig
        right_affine = spatial.affine_grid(right_affine, target_orig.shape)
        # Sample the displacement at those locations (channels first for
        # grid_pull), then add the locations back to get a transformation.
        grid = spatial.grid_pull(utils.movedim(grid, -1, 0),
                                 right_affine,
                                 bound='nearest',
                                 extrapolate=True)
        grid = utils.movedim(grid, 0, -1).add_(right_affine)
        # Working-target (= resliced source) voxels -> original source voxels.
        left_affine = source_affine_orig.inverse() @ target_affine
        grid = spatial.affine_matvec(left_affine, grid)
        # Sample the original source on the original target lattice.
        warped = spatial.grid_pull(source_orig, grid)
        return warped, vel, target_affine
Exemple #2
0
def pull1d(img, grid, grad=False, **kwargs):
    """Sample an image along its last dimension with a 1D transform.

    Parameters
    ----------
    img : (K, *spatial) tensor
        Image to sample.
    grid : (*spatial) tensor or None
        Sampling coordinates along the last dimension. If None, no
        sampling is performed.
    grad : bool, default=False
        Also return the gradient of the sampled image.
    **kwargs
        Options forwarded to `grid_pull` / `grid_grad`; `bound` defaults
        to 'dft' and `extrapolate` to True.

    Returns
    -------
    warped_img : (K, *spatial) tensor
    warped_grad : (K, *spatial) tensor, if `grad`

    NOTE(review): when `grid` is None a 2-tuple is always returned
    (``(img, None)`` if not `grad`), whereas with a grid and
    ``grad=False`` only `warped_img` is returned.
    """
    if grid is None:
        # No transform: return the image itself (and a centered finite
        # difference along the last axis if a gradient is requested).
        if not grad:
            return img, None
        return img, diff1d(img, dim=-1, bound=kwargs.get('bound', 'dft'),
                           side='c')

    options = dict(kwargs)
    options.setdefault('extrapolate', True)
    options.setdefault('bound', 'dft')
    # img gains a singleton axis at -2, grid a trailing coordinate axis.
    source = img.unsqueeze(-2)
    coords = grid.unsqueeze(-1)
    warped = grid_pull(source, coords, **options).squeeze(-2)
    if grad:
        gradient = grid_grad(source, coords, **options)
        return warped, gradient.squeeze(-1).squeeze(-2)
    return warped
Exemple #3
0
def load_and_pull(volume, aff, shape, dtype=None, device=None):
    """Load a volume's data and resample it onto a target lattice.

    Parameters
    ----------
    volume : Volume3D
        Object exposing ``fdata(cache=..., dtype=..., device=...)`` and
        ``affine``.
    aff : (D+1,D+1) tensor
        Orientation matrix of the target lattice.
    shape : (D,) tuple
        Shape of the target lattice.
    dtype : torch.dtype, optional
        Defaults to ``aff.dtype``.
    device : torch.device, optional
        Defaults to ``aff.device``.

    Returns
    -------
    dat : tensor
        The loaded data as is if the input already lives on the target
        lattice (same shape, identity voxel-to-voxel matrix), otherwise
        the data resampled with ``spatial.grid_pull``.

    """
    opt = {'dtype': dtype or aff.dtype, 'device': device or aff.device}
    aff = aff.to(**opt)
    eye = torch.eye(aff.shape[-1], **opt)
    data = volume.fdata(cache=False, **opt)
    in_shape = data.shape
    in_aff = volume.affine.to(**opt)
    # Target voxels -> input voxels (left matrix division: in_aff \ aff).
    vox2vox = core.linalg.lmdiv(in_aff, aff)
    if not torch.allclose(vox2vox, eye) or tuple(shape) != tuple(in_shape):
        grid = spatial.affine_grid(vox2vox, shape)
        return spatial.grid_pull(data[None, None, ...], grid[None, ...])[0, 0]
    return data
Exemple #4
0
def _resample_inplane(x, sett):
    """Force in-plane resolution of observed data to be greater or equal to recon vx.

    Only active when ``sett.force_inplane_res`` is set and
    ``sett.max_iter > 0``. For each observation ``x[c][n]``, every axis
    whose voxel size is smaller than ``sett.vx`` is resampled (nearest
    neighbour, zero boundary) so that its voxel size becomes ``sett.vx``.
    ``dat``, ``mat`` and ``dim`` (and ``label[0]`` when a label is
    attached) are updated in place.

    Returns
    -------
    x : the same container, modified in place.
    """
    if not (sett.force_inplane_res and sett.max_iter > 0):
        return x
    eye4 = torch.eye(4, device=sett.device, dtype=torch.float64)
    for channel in x:
        for obs in channel:
            img = obs.dat[None, None, ...]
            mat = obs.mat
            dim = torch.as_tensor(obs.dim, device=sett.device, dtype=torch.float64)
            vx = voxel_size(mat)
            # Diagonal voxel-scaling matrix; factors below 1 are reset to 1
            # so voxels are only ever enlarged.
            scale = eye4.clone()
            for axis in range(3):
                scale[axis, axis] = sett.vx / vx[axis]
                if scale[axis, axis] < 1.0:
                    scale[axis, axis] = 1
            if float((eye4 - scale).abs().sum()) < 1e-4:
                # Already at (or coarser than) the requested resolution.
                continue
            mat = mat.matmul(scale)
            dim = scale[:3, :3].inverse().mm(dim[:, None]).floor().squeeze().cpu().int().tolist()
            grid = affine_grid(scale.type(img.dtype), dim)
            img = grid_pull(img, grid[None, ...], bound='zero',
                            extrapolate=False, interpolation=0)
            if obs.label is not None:
                obs.label[0] = _warp_label(obs.label[0], grid)
            obs.dat = img[0, 0, ...]
            obs.mat = mat
            obs.dim = dim

    return x
Exemple #5
0
def transform_pointset_dense(points, grid, type='grid', bound='dct2'):
    """Transform a pointset through a dense transformation field.

    Points must already be expressed in "grid voxels" coordinates.

    Parameters
    ----------
    points : (n, dim) tensor
        Set of coordinates, in voxel space
    grid : (*spatial, dim) tensor
        Dense transformation or displacement grid, in voxel space
    type : {'grid', 'disp'}, default='grid'
        Whether `grid` stores absolute coordinates ('grid') or
        displacements ('disp') that are added to the input points.
    bound : str, default='dct2'
        Boundary conditions for out-of-bounds data

    Returns
    -------
    points : (n, dim) tensor
        Transformed coordinates

    """
    in_shape = points.shape
    dim = grid.shape[-1]
    # View the pointset as a (1, 1, ..., n, dim) sampling grid: one batch
    # axis plus `dim` spatial axes, as expected by grid_pull.
    points = utils.unsqueeze(points, 0, dim)
    grid = utils.movedim(grid, -1, 0)[None]          # (1, dim, *spatial)
    delta = spatial.grid_pull(grid, points, bound=bound, extrapolate=True)
    delta = utils.movedim(delta, 1, -1)              # (1, ..., n, dim)
    if type == 'disp':
        points = points + delta
    else:
        points = delta
    # Remove the singleton axes added above by restoring the input shape.
    # Squeezing at -2 (previous code) targets the `n` axis rather than
    # the added singletons, yielding a (1, ..., n, dim) result for n > 1
    # if utils.squeeze behaves like repeated torch.squeeze.
    return points.reshape(in_shape)
Exemple #6
0
def _init_y_dat(x, y, sett):
    """ Make initial guesses of reconstructed image(s) using b-spline interpolation,
        with averaging if more than one observation per channel.

    Every observation ``x[c][n]`` is resampled (interpolation order 1,
    zero boundary, no extrapolation) onto the lattice of ``y[0]``
    (``dim``/``mat``) and clipped to that observation's [min, max]
    intensity range. Per channel, the resampled images are summed and
    divided by the number of observations whose rounded resampled value
    is non-zero at each voxel (at least 1). ``y[c].dat`` is overwritten
    and ``y`` is returned.
    """
    dim_y = y[0].dim
    mat_y = y[0].mat
    for c in range(len(x)):
        dat_y = torch.zeros(dim_y, dtype=torch.float32, device=sett.device)
        sm = torch.zeros_like(dat_y)
        for obs in x[c]:
            # Get image data
            dat = obs.dat[None, None, ...]
            # Voxel-to-voxel map y -> observation: mat_x \ mat_y.
            # `torch.linalg.solve` replaces the deprecated/removed
            # `Tensor.solve` (which returned a (solution, LU) tuple).
            mat = torch.linalg.solve(obs.mat, mat_y)
            grid = affine_grid(mat.type(dat.dtype), dim_y)
            # Do resampling
            mn = torch.min(dat)
            mx = torch.max(dat)
            dat = grid_pull(dat, grid[None, ...],
                            bound='zero', extrapolate=False, interpolation=1)
            # Clip resampled values to the observation's intensity range.
            dat[dat < mn] = mn
            dat[dat > mx] = mx
            # Count contributing observations per voxel for the average.
            sm = sm + (dat[0, 0, ...].round() != 0)
            dat_y = dat_y + dat[0, 0, ...]
        sm[sm == 0] = 1
        y[c].dat = dat_y / sm

    return y
Exemple #7
0
def warp_label(label, grid):
    """Warp label image according to grid.

    Hard labels (non floating-point dtype) are one-hot encoded (one
    channel per unique value), warped with linear interpolation, and
    converted back to label values by arg-max. Soft labels (half, float
    or double) are warped directly and renormalised across channels.

    Parameters
    ----------
    label : (batch, channel, *spatial_in) tensor
        Hard labels must have a single channel.
    grid : tensor
        Sampling grid forwarded to ``spatial.grid_pull``.

    Returns
    -------
    label_w : tensor
        Hard labels: (batch, 1, *spatial_out) map holding the original
        label values, with the dtype of `label`.
        Soft labels: warped probabilities normalised to sum to one
        along the channel axis.
    """
    dtype_seg = label.dtype
    is_soft = dtype_seg in (torch.half, torch.float, torch.double)
    if not is_soft:
        # hard labels to one-hot labels
        u_labels = label.unique()
        label_w = torch.zeros(
            (label.shape[0], len(u_labels)) + tuple(label.shape[2:]),
            device=label.device,
            dtype=torch.float32)
        for i, value in enumerate(u_labels):
            # A single ellipsis only: `[..., i, ...]` raises IndexError.
            label_w[:, i:i + 1, ...] = label == value
    else:
        label_w = label
    # warp
    label_w = spatial.grid_pull(label_w,
                                grid,
                                bound='dct2',
                                extrapolate=True,
                                interpolation=1)
    if not is_soft:
        # one-hot labels to hard labels: argmax gives a channel index,
        # which must be mapped back to the corresponding label value.
        label_w = u_labels[label_w.argmax(dim=1, keepdim=True)]
    else:
        # normalise one-hot labels
        label_w = label_w / (label_w.sum(dim=1, keepdim=True) + eps())

    return label_w
Exemple #8
0
 def pull(q, grid):
     """Sample `source` through the affine ``expm(q, basis)``.

     Relies on `basis`, `target_aff`, `source_aff`, `dim`, `source` and
     `pull_opt` from the enclosing scope. The matrix
     ``source_aff \\ (expm(q, basis) @ target_aff)`` is applied to `grid`
     and `source` is sampled at the resulting coordinates.
     """
     mat = spatial.affine_lmdiv(
         source_aff,
         spatial.affine_matmul(core.linalg.expm(q, basis), target_aff))
     # Insert `dim` singleton axes after the first axis so the matrix
     # broadcasts over the spatial grid (assumes a leading batch axis).
     bcast = (slice(None),) + (None,) * dim + (slice(None),) * 2
     coords = spatial.affine_matvec(mat[bcast], grid)
     return spatial.grid_pull(source, coords, **pull_opt)
Exemple #9
0
 def pull(q, vel):
     """Sample `source` through ``exp(vel)`` composed with ``expm(q, basis)``.

     Relies on `basis`, `target_aff`, `source_aff`, `source` and
     `pull_opt` from the enclosing scope.
     """
     coords = spatial.exp(vel)
     mat = spatial.affine_lmdiv(
         source_aff,
         spatial.affine_matmul(core.linalg.expm(q, basis), target_aff))
     coords = spatial.affine_matvec(mat, coords)
     return spatial.grid_pull(source, coords, **pull_opt)
Exemple #10
0
def warp_image(image, grid):
    """Warp image according to grid.

    Thin wrapper around ``spatial.grid_pull`` with interpolation order 1,
    'dct2' boundary conditions and extrapolation enabled.
    """
    options = {'bound': 'dct2', 'extrapolate': True, 'interpolation': 1}
    warped = spatial.grid_pull(image, grid, **options)
    return warped
Exemple #11
0
def _deform_1d(img, disp, grad=False):
    """Deform `img` with a 1D displacement field.

    The first axis of `img` is moved to position -2 for sampling and
    moved back afterwards. `disp` gets a trailing coordinate axis and the
    identity added before sampling (boundary `BND`, extrapolation on).

    Returns
    -------
    (warped, gradient) where `gradient` is None unless `grad` is True.
    """
    options = {'bound': BND, 'extrapolate': True}
    moved = utils.movedim(img, 0, -2)
    coords = spatial.add_identity_grid(disp.unsqueeze(-1))
    warped = utils.movedim(spatial.grid_pull(moved, coords, **options), -2, 0)
    if grad:
        gradient = spatial.grid_grad(moved, coords, **options)
        return warped, utils.movedim(gradient.squeeze(-1), -2, 0)
    return warped, None
Exemple #12
0
def _crop_y(y, sett):
    """ Crop output images FOV to a fixed dimension

    The FOV comes from the 'atlas_t1' bounding box (``sett.fov``),
    rescaled to the voxel size of ``y[0]``. Each ``y[c].dat`` (and
    ``y[c].label`` when set) is resampled onto it with nearest-neighbour
    interpolation and zero boundary. ``mat`` and ``dim`` are updated.

    Args:
        y (_output()): _output data.
        sett: Settings; ``crop``, ``device`` and ``fov`` are read.

    Returns:
        y (_output()): Cropped output data (modified in place).

    """
    if not sett.crop:
        return y
    device = sett.device
    # Output image information
    mat_y = y[0].mat
    vx_y = voxel_size(mat_y)
    # Define cropped FOV
    mat_mu, dim_mu = _bb_atlas('atlas_t1',
        fov=sett.fov, dtype=torch.float64, device=device)
    # Modulate atlas with voxel size
    mat_vx = torch.diag(torch.cat((
        vx_y, torch.ones(1, dtype=torch.float64, device=device))))
    mat_mu = mat_mu.mm(mat_vx)
    dim_mu = mat_vx[:3, :3].inverse().mm(dim_mu[:, None]).floor().squeeze()
    # Integer dimensions, used both for the grid and for `y[c].dim`.
    dim_mu = tuple(dim_mu.int().tolist())
    # Voxel-to-voxel map cropped FOV -> y: mat_y \ mat_mu.
    # `torch.linalg.solve` replaces the deprecated/removed `Tensor.solve`.
    M = torch.linalg.solve(mat_y, mat_mu).type(y[0].dat.dtype)
    grid = affine_grid(M, dim_mu)[None, ...]

    def _resample(dat):
        # Nearest-neighbour resampling onto the cropped FOV.
        return grid_pull(dat[None, None, ...], grid,
                         bound='zero', extrapolate=False,
                         interpolation=0)[0, 0, ...]

    # Crop
    for c in range(len(y)):
        y[c].dat = _resample(y[c].dat)
        # Do labels?
        if y[c].label is not None:
            y[c].label = _resample(y[c].label)
        y[c].mat = mat_mu
        y[c].dim = dim_mu

    return y
Exemple #13
0
    def eval_position(self, t):
        """Evaluate the position at a given (batched) time.

        `t` is clamped to [0, 1] and rescaled to [0, len(waypoints) - 1]
        before the spline coefficients ``self.coeff`` are interpolated.

        Returns
        -------
        (*t.shape, D) tensor
        """
        batch_shape = t.shape
        scale = len(self.waypoints) - 1
        coords = (t.flatten().clamp(0, 1) * scale).unsqueeze(-1)  # [N, 1]
        coeff = self.coeff.T  # [D, K]
        pos = grid_pull(coeff, coords, interpolation=self.order,
                        bound=self.bound)
        pos = pos.T  # [N, D]
        return pos.reshape([*batch_shape, pos.shape[-1]])
Exemple #14
0
def smart_pull_grid(vel, grid, type='disp', *args, **kwargs):
    """Interpolate a velocity/grid/displacement field.

    Notes
    -----
    Defaults differ from grid_pull:
    - bound -> dft
    - extrapolate -> True

    With ``type='grid'`` the identity is subtracted before interpolation
    and added back (on the output lattice) afterwards, so only the
    displacement part is interpolated.

    Parameters
    ----------
    vel : ([batch], *spatial, ndim) tensor
        Velocity
    grid : ([batch], *spatial, ndim) tensor
        Transformation field
    type : {'disp', 'grid'}, default='disp'
        Whether `vel` holds displacements or absolute coordinates.
    args, kwargs
        Options to ``grid_pull``

    Returns
    -------
    pulled_vel : ([batch], *spatial, ndim) tensor
        Velocity. The batch axis is dropped only if neither `vel` nor
        `grid` had one. `vel` is returned as is if either input is None.

    """
    if grid is None or vel is None:
        return vel
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    dim = vel.shape[-1]
    if type == 'grid':
        identity = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                         **utils.backend(vel))
        vel = vel - identity
    vel = utils.movedim(vel, -1, -dim - 1)
    vel_no_batch = vel.dim() == dim + 1
    grid_no_batch = grid.dim() == dim + 1
    if vel_no_batch:
        vel = vel[None]
    if grid_no_batch:
        grid = grid[None]
    vel = spatial.grid_pull(vel, grid, *args, **kwargs)
    vel = utils.movedim(vel, -dim - 1, -1)
    # An unbatched field pulled by a batched grid yields one field per
    # grid: only drop the batch axis when no input had one (previously
    # every batch element but the first was discarded).
    if vel_no_batch and grid_no_batch:
        vel = vel[0]
    if type == 'grid':
        identity = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                         **utils.backend(vel))
        vel += identity
    return vel
Exemple #15
0
    def eval_radius(self, t):
        """Evaluate the radius at a given (batched) time.

        If ``self.radius`` is not a tensor it is returned as is.
        Otherwise ``self.coeff_radius`` is interpolated at `t` (clamped to
        [0, 1] and rescaled to [0, len(waypoints) - 1]) and the result
        has the shape of `t`.
        """
        if torch.is_tensor(self.radius):
            batch_shape = t.shape
            coords = t.flatten().clamp(0, 1) * (len(self.waypoints) - 1)
            coords = coords.unsqueeze(-1)  # [N, 1]
            radius = grid_pull(self.coeff_radius, coords,
                               interpolation=self.order, bound=self.bound)
            return radius.reshape(batch_shape)
        return self.radius
Exemple #16
0
def _jhistc_backward(g,
                     x,
                     w=None,
                     order=0,
                     bound='replicate',
                     extrapolate=True,
                     gradx=True,
                     gradw=False):
    """Compute derivative of the joint histogram.

    The input must already be a soft mapping to bins indices.
    The incoming gradient `g` is differentiated (``grid_grad``) or
    sampled (``grid_pull``) at the bin coordinates `x`.

    Parameters
    ----------
    g : (b, bins, bins) tensor
        Gradient with respect to the joint histogram.
    x : (b, n, 2) tensor
        Bin coordinates of the samples.
    w : ([b], n) tensor, optional
        Sample weights.
    order : int, default=0
        Interpolation order.
    bound : str, default='replicate'
        Boundary condition forwarded to grid_pull/grid_grad.
    extrapolate : bool, default=True
        Forwarded as ``extrapolate=1`` if True, ``extrapolate=2`` if False.
    gradx : bool, default=True
    gradw : bool, default=False

    Returns
    -------
    gx : (b, n, 2) tensor, if gradx
    gw : ([b], n) tensor, if gradw (None when `w` is None)
        A single requested output is returned alone, two are returned as
        a tuple ``(gx, gw)``, and None is returned if neither is requested.

    """
    extrapolate = 1 if extrapolate else 2
    opt = dict(interpolation=order, bound=bound, extrapolate=extrapolate)
    x = x.unsqueeze(-3)  # make 2d spatial
    g = g.unsqueeze(-3)  # add channel dimension
    out = []
    if gradx:
        gx = grid_grad(g, x, **opt)
        gx = gx.squeeze(-3).squeeze(-3)
        if w is not None:
            gx *= w.unsqueeze(-1)
        out.append(gx)
    if gradw and w is not None:
        gw = grid_pull(g, x, **opt)
        gw = gw.squeeze(-2).squeeze(-2)  # drop spatial + channel
        out.append(gw)
    elif gradw:
        out.append(None)
    if not out:
        # Nothing requested (previously raised IndexError on out[0]).
        return None
    return out[0] if len(out) == 1 else tuple(out)
Exemple #17
0
def intensity_to_rgb(image,
                     min=None,
                     max=None,
                     colormap='gray',
                     n=256,
                     eq=False):
    """Colormap an intensity image

    Parameters
    ----------
    image : (*batch, H, W) tensor
        A (batch of) 2d image
    min : tensor_like, optional
        Minimum value. Should be broadcastable to batch.
        Default: min of image for each batch element.
    max : tensor_like, optional
        Maximum value. Should be broadcastable to batch.
        Default: max of image for each batch element.
    colormap : str or (K, 3) tensor, default='gray'
        A colormap or the name of a matplotlib colormap.
    n : int, default=256
        Number of color levels to use.
    eq : bool or {'linear', 'quadratic', 'log', None}, default=False
        Apply histogram equalization.
        If 'quadratic' or 'log', the histogram of the transformed signal
        is equalized.

    Returns
    -------
    rgb : (*batch, H, W, 3) tensor
        A (batch of) of RGB image.

    """
    image = intensity_preproc(torch.as_tensor(image).detach(),
                              min=min, max=max, eq=eq)
    lut = _get_colormap_intensity(colormap, n, image.dtype, image.device)
    batch_shape = image.shape
    # Scale to level indices in [0, n - 1] (in place on the preprocessed
    # image -- NOTE(review): may alias the input if intensity_preproc
    # returns it unchanged).
    levels = image.mul_(n - 1).clamp_(0, n - 1).reshape([1, -1, 1])
    # Colour lookup: sample each RGB channel of the table at every level.
    lut = lut.T.reshape([1, 3, -1])
    rgb = spatial.grid_pull(lut, levels)
    rgb = rgb.reshape([3, *batch_shape])
    return utils.movedim(rgb, 0, -1)
Exemple #18
0
    def forward(self, x, grid):
        """Sample `x` at the coordinates stored in `grid`.

        Uses the interpolation, bound and extrapolate options stored on
        the module.

        Parameters
        ----------
        x : (batch, channel, *spatial_in) tensor
            Input image to deform
        grid : (batch, *spatial_out, len(spatial_in)) tensor
            Transformation grid

        Returns
        -------
        pulled : (batch, channel, *spatial_out) tensor
            Deformed image.

        """
        interpolation = self.interpolation
        bound = self.bound
        extrapolate = self.extrapolate
        pulled = spatial.grid_pull(x, grid, interpolation, bound, extrapolate)
        return pulled
Exemple #19
0
def pull1d(img, grid, dim, grad=False, **kwargs):
    """Sample `img` along axis `dim` at the 1D coordinates in `grid`.

    Parameters
    ----------
    img : tensor
        Image to sample.
    grid : tensor or None
        Coordinates along axis `dim`. If None, no sampling is done.
    dim : int
        Axis along which to sample.
    grad : bool, default=False
        Also return the gradient of the sampled image along `dim`.
    **kwargs
        Options to ``grid_pull``/``grid_grad``; `bound` defaults to 'dft'
        and `extrapolate` to True.

    Returns
    -------
    warped : tensor
    grad : tensor or None
        None unless `grad` is True.
    """
    if grid is None:
        if not grad:
            return img, None
        bound = kwargs.get('bound', 'dft')
        return img, spatial.diff1d(img, dim=dim, bound=bound, side='c')
    options = dict(kwargs)
    options.setdefault('extrapolate', True)
    options.setdefault('bound', 'dft')
    # Bring `dim` last; img gains a singleton axis, grid a coordinate axis.
    source = core.utils.movedim(img, dim, -1).unsqueeze(-2)
    coords = core.utils.movedim(grid, dim, -1).unsqueeze(-1)
    warped = spatial.grid_pull(source, coords, **options).squeeze(-2)
    warped = core.utils.movedim(warped, -1, dim)
    gradient = None
    if grad:
        gradient = spatial.grid_grad(source, coords, **options)
        gradient = core.utils.movedim(gradient.squeeze(-1).squeeze(-2), -1, dim)
    return warped, gradient
Exemple #20
0
def _warp_label(label, grid):
    """Warp a label image.

    Each unique label value is turned into a binary mask and resampled
    (linear interpolation, zero boundary, no extrapolation). Every output
    voxel takes the label whose resampled mask value is largest. Ties go
    to the label that comes first in ``label.unique()``. Voxels where all
    masks resample to 0 get 0.

    Parameters
    ----------
    label : (*spatial_in) tensor
    grid : (*spatial_out, D) tensor
        Sampling grid.

    Returns
    -------
    (*spatial_out) tensor with the dtype of `label`.

    Raises
    ------
    ValueError
        If `label` has more than 255 distinct values.
    """
    u = label.unique()
    if u.numel() > 255:
        raise ValueError('Too many label values.')
    out_shape = grid.shape[:-1]
    f1 = torch.zeros(out_shape, device=label.device, dtype=label.dtype)
    # Best resampled mask value so far. Must be floating point: with an
    # integer `label`, storing values in [0, 1) in label.dtype would
    # truncate them to 0.
    p1 = torch.zeros(out_shape, device=label.device, dtype=torch.float32)
    for u1 in u:
        g0 = (label == u1).float()
        tmp = grid_pull(g0[None, None, ...], grid[None, ...],
                        bound='zero', extrapolate=False,
                        interpolation=1)[0, 0, ...]
        msk1 = tmp > p1
        p1[msk1] = tmp[msk1].to(p1.dtype)
        f1[msk1] = u1

    return f1
Exemple #21
0
def _reslice_dat_3d(dat,
                    affine,
                    dim_out,
                    interpolation='linear',
                    bound='zero',
                    extrapolate=False):
    """Resample 3D image data onto a new lattice.

    Parameters
    ----------
    dat : (Xi, Yi, Zi), tensor_like
        Input image data.
    affine : (4, 4), tensor_like
        Affine transformation mapping output voxels to input voxels.
    dim_out : (Xo, Yo, Zo), list or tuple
        Output image dimensions.
    interpolation : str, default='linear'
        Interpolation order.
    bound : str, default='zero'
        Boundary condition.
    extrapolate : bool, default=False
        Extrapolate out-of-bounds data.

    Returns
    -------
    dat : (dim_out), tensor_like
        Resliced image data.

    Raises
    ------
    ValueError
        If `dat` is not 3-dimensional.

    """
    if len(dat.shape) != 3:
        raise ValueError('Input error: len(dat.shape) != 3')
    # Add batch axes to both the grid and the image, as grid_pull expects.
    sampling = affine_grid(affine, dim_out).type(dat.dtype)[None, ...]
    resliced = grid_pull(dat[None, None, ...],
                         sampling,
                         bound=bound,
                         interpolation=interpolation,
                         extrapolate=extrapolate)
    return resliced[0, 0, ...]
Exemple #22
0
def smart_pull(tensor, grid):
    """Sample `tensor` at `grid` if a grid is given; otherwise return it.

    A batch axis is added to both inputs before ``spatial.grid_pull`` and
    removed from the result.

    Parameters
    ----------
    tensor : (channels, *input_shape) tensor
        Input volume
    grid : (*output_shape, D) tensor or None
        Sampling grid

    Returns
    -------
    pulled : (channels, *output_shape) tensor
        Sampled volume

    """
    if grid is not None:
        batched = spatial.grid_pull(tensor[None, ...], grid[None, ...])
        return batched[0]
    return tensor
Exemple #23
0
    def eval_grad_position(self, t):
        """Evaluate the position and its derivative with respect to time.

        Returns
        -------
        x : (*t.shape, D) tensor
            Position.
        g : (*t.shape, D) tensor
            Derivative of the position with respect to `t`, including the
            (len(waypoints) - 1) factor of the time rescaling. The clamp
            to [0, 1] is not differentiated.
        """
        batch_shape = t.shape
        scale = len(self.waypoints) - 1
        coords = (t.flatten().clamp(0, 1) * scale).unsqueeze(-1)  # [N, 1]
        coeff = self.coeff.T  # [D, K]
        options = dict(interpolation=self.order, bound=self.bound)
        pos = grid_pull(coeff, coords, **options).T  # [N, D]
        deriv = grid_grad(coeff, coords, **options).squeeze(-1).T  # [N, D]

        pos = pos.reshape([*batch_shape, pos.shape[-1]])
        deriv = deriv.reshape([*batch_shape, deriv.shape[-1]])
        deriv *= scale
        return pos, deriv
Exemple #24
0
 def slice_to(self, stack, cache_result=False, recompute=True):
     """Resample this volume into the orientation of `stack` and smooth it.

     NOTE(review): work in progress -- the per-slice loop below is
     unfinished.

     Parameters
     ----------
     stack : object
         Provides ``layout``, ``slice_width``, ``affine`` and ``slices``.
     cache_result : bool, default=False
         Store the result in ``self._sliced``.
     recompute : bool, default=True
         If False and ``self._sliced`` exists, return the cached result.

     Returns
     -------
     sliced : tensor
     """
     aff = self.exp(cache_result=cache_result, recompute=recompute)
     if not recompute and hasattr(self, '_sliced'):
         # Reuse the cached result (previously `sliced` was unbound on
         # this path, raising NameError).
         sliced = self._sliced
     else:
         aff = spatial.affine_matmul(aff, self.affine)
         aff_reorient = spatial.affine_reorient(self.affine, self.shape, stack.layout)
         aff = spatial.affine_lmdiv(aff_reorient, aff)
         aff = spatial.affine_grid(aff, self.shape)
         sliced = spatial.grid_pull(self.dat, aff, bound=self.bound,
                                    extrapolate=self.extrapolate)
         # Smooth along the last axis only, with a FWHM equal to the slice
         # width expressed in voxels of the reoriented space.
         fwhm = [0] * self.dim
         fwhm[-1] = stack.slice_width / spatial.voxel_size(aff_reorient)[-1]
         sliced = spatial.smooth(sliced, fwhm, dim=self.dim, bound=self.bound)
         slices = []
         for stack_slice in stack.slices:
             # TODO: unfinished -- the second operands are missing and the
             # results are never used.
             aff = spatial.affine_matmul(stack.affine, )
             aff = spatial.affine_lmdiv(aff_reorient, )
     if cache_result:
         self._sliced = sliced
     return sliced
Exemple #25
0
def smart_pull_jac(jac, grid, *args, **kwargs):
    """Interpolate a jacobian field.

    Notes
    -----
    Defaults differ from grid_pull:
    - bound -> dft
    - extrapolate -> True

    Parameters
    ----------
    jac : ([batch], *spatial_in, ndim, ndim) tensor
        Jacobian field
    grid : ([batch], *spatial_out, ndim) tensor
        Transformation field
    args, kwargs
        Options to ``grid_pull``

    Returns
    -------
    pulled_jac : ([batch], *spatial_out, ndim, ndim) tensor
        Jacobian field. The batch axis is dropped only if neither `jac`
        nor `grid` had one. `jac` is returned as is if either input is None.

    """
    if grid is None or jac is None:
        return jac
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    dim = jac.shape[-1]
    jac = jac.reshape([*jac.shape[:-2], dim * dim])  # collapse matrix
    jac = utils.movedim(jac, -1, -dim - 1)
    jac_no_batch = jac.dim() == dim + 1
    grid_no_batch = grid.dim() == dim + 1
    if jac_no_batch:
        jac = jac[None]
    if grid_no_batch:
        grid = grid[None]
    jac = spatial.grid_pull(jac, grid, *args, **kwargs)
    jac = utils.movedim(jac, -dim - 1, -1)
    jac = jac.reshape([*jac.shape[:-1], dim, dim])
    # An unbatched field pulled by a batched grid yields one field per
    # grid: only drop the batch axis when no input had one (previously
    # every batch element but the first was discarded).
    if jac_no_batch and grid_no_batch:
        jac = jac[0]
    return jac
Exemple #26
0
def pull(image, grid, interpolation=1, bound='dct2', extrapolate=False):
    """Sample a multi-channel image

    A batch axis is added to `image` and `grid` for ``grid_pull`` and
    removed from the result.

    Parameters
    ----------
    image : (channel, *inshape) tensor
    grid : (*outshape, dim) tensor
    interpolation : int, default=1
    bound : str, default='dct2'
    extrapolate : bool, default=False

    Returns
    -------
    imageout : (channel, *outshape)

    """
    options = dict(interpolation=interpolation, bound=bound,
                   extrapolate=extrapolate)
    batched = grid_pull(image[None], grid[None], **options)
    return batched[0]
Exemple #27
0
def pull_grid(gridin, grid, interpolation=1, bound='dft', extrapolate=True):
    """Sample a displacement field.

    The vector components of `gridin` are treated as channels during
    sampling and moved back to the last axis afterwards.

    Parameters
    ----------
    gridin : (*inshape, dim) tensor
    grid : (*outshape, dim) tensor
    interpolation : int, default=1
    bound : str, default='dft'
    extrapolate : bool, default=True

    Returns
    -------
    gridout : (*outshape, dim) tensor

    """
    channels_first = movedim(gridin, -1, 0)[None]  # (1, dim, *inshape)
    options = dict(interpolation=interpolation, bound=bound,
                   extrapolate=extrapolate)
    pulled = grid_pull(channels_first, grid[None], **options)
    return movedim(pulled[0], 0, -1)
Exemple #28
0
    def forward(self, x, grid, **overload):
        """Sample `x` at the coordinates stored in `grid`.

        Parameters
        ----------
        x : (batch, channel, *spatial_in) tensor
            Input image to deform
        grid : (batch, *spatial_out, len(spatial_in)) tensor
            Transformation grid
        overload : dict
            `interpolation`, `bound` and `extrapolate` given here take
            precedence over the values stored on the module.

        Returns
        -------
        pulled : (batch, channel, *spatial_out) tensor
            Deformed image.

        """
        options = [overload.get(key, getattr(self, key))
                   for key in ('interpolation', 'bound', 'extrapolate')]
        return spatial.grid_pull(x, grid, *options)
Exemple #29
0
    def __call__(self,
                 logaff,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False,
                 in_line_search=False):
        """Evaluate the affine registration objective and its derivatives.

        The moving image is sampled with the voxel-to-voxel affine
        ``affine_moving \\ (expm(logaff, basis) @ affine_fixed)`` on the
        fixed lattice and compared to the fixed image with ``self.loss``.
        Derivatives are chained analytically: loss derivatives x spatial
        gradients of the moving image x derivative of the matrix exponential.

        logaff : (..., nb) tensor, Lie parameters
        grad : Whether to compute and return the gradient wrt `logaff`
        hess : Whether to compute and return the Hessian wrt `logaff`
            (requesting `hess` also computes and returns the gradient)
        gradmov : Whether to compute and return the gradient wrt `moving`
        hessmov : Whether to compute and return the Hessian wrt `moving`
            (only computed when `hess` is requested)
        in_line_search : If True, do not plot, do not print and do not
            update the iteration counter.

        NOTE(review): if `gradmov`/`hessmov` is requested but the
        corresponding branch does not run, the flag value (True) itself is
        returned in its slot.

        Returns
        -------
        ll : float, loss value (objective to minimize)
        g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
        h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
            (diagonal matrix built from absolute row sums, see below)
        gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
        hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

        """
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        # Options shared by every pull / push / grad call below.
        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        # Plot roughly every 1/20th of `max_iter` iterations.
        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # jitter
        # if not hasattr(self, '_fixed'):
        #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
        #                                 jitter=True,
        #                                 **utils.backend(self.fixed))
        #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        #     del idj
        # fixed = self._fixed
        fixed = self.fixed

        # forward
        # Lazily build the Lie basis (as a tensor) on first call.
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        # Derivative of the matrix exponential wrt the Lie parameters
        # (second output of `_expm` with grad_X=True -- presumably one
        # matrix per basis element; confirm).
        with torch.no_grad():
            _, gaff = linalg._expm(logaff,
                                   self.basis,
                                   grad_X=True,
                                   hess_X=False)

        # Fixed voxels -> moving voxels.
        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        # /!\ derivatives are not "homogeneous" (they do not have a one
        # on the bottom right): we should *not* use affine_matmul and
        # such (I only lost a day...)
        gaff = torch.matmul(gaff, self.affine_fixed)
        gaff = linalg.lmdiv(self.affine_moving, gaff)
        # haff = torch.matmul(haff, self.affine_fixed)
        # haff = linalg.lmdiv(self.affine_moving, haff)
        # Identity grid of the fixed lattice, built once and cached.
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape,
                                            **utils.backend(logaff),
                                            jitter=False)
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        # (derivatives wrt `moving` are obtained with grid_push,
        # presumably the adjoint of grid_pull).
        if not grad and not hess:
            llx = self.loss.loss(warped, fixed)
        elif not hess:
            llx, grad = self.loss.loss_grad(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients + dot product with grid
        if grad is not False or hess is not False:
            mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
            # Chain rule through the spatial gradients of the moving image
            # (jg / jhj: presumably J^T g and J^T H J).
            grad = jg(mugrad, grad)
            if hess is not False:
                hess = jhj(mugrad, hess)
                grad, hess = regutils.affine_grid_backward(grad,
                                                           hess,
                                                           grid=self.id)
            else:
                grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
            # Flatten the matrix derivatives (top `dim` rows only) and
            # chain with d(affine)/d(logaff).
            dim2 = self.dim * (self.dim + 1)
            grad = grad.reshape([*grad.shape[:-2], dim2])
            gaff = gaff[..., :-1, :]
            gaff = gaff.reshape([*gaff.shape[:-2], dim2])
            grad = linalg.matvec(gaff, grad)
            if hess is not False:
                hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
                # haff = haff[..., :-1, :, :-1, :]
                # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
                hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
                # Replace the Hessian by a diagonal matrix of absolute row
                # sums (presumably a majoriser, for a more robust step).
                hess = hess.abs().sum(-1).diag_embed()
            del mugrad

        # print objective
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                # Relative improvement since the previous evaluation.
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        # Return the loss alone, or a tuple with every requested output.
        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
Exemple #30
0
    def __call__(self, logaff, grad=False, hess=False, in_line_search=False):
        """Evaluate the affine registration objective (gradient via autograd).

        logaff : (..., nb) tensor, Lie parameters
        grad : Whether to compute and return the gradient wrt `logaff`
            (obtained by calling ``backward()`` on the loss)
        hess : Accepted for interface compatibility; not used here.
        in_line_search : If True, do not plot, do not print and do not
            update the iteration counter.

        Returns
        -------
        ll : tensor, loss value (objective to minimize), as returned by
            ``self.loss.loss`` (not converted to a Python float)
        g : (..., nb) tensor, optional, Gradient wrt Lie parameters
            (a clone of ``logaff.grad``)

        """
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        # select correct gradient mode
        # If the current autograd mode does not match `grad`, re-enter
        # this method under the appropriate context manager.
        if grad:
            logaff.requires_grad_()
            if logaff.grad is not None:
                logaff.grad.zero_()
        if grad and not torch.is_grad_enabled():
            with torch.enable_grad():
                return self(logaff, grad, in_line_search=in_line_search)
        elif not grad and torch.is_grad_enabled():
            with torch.no_grad():
                return self(logaff, grad, in_line_search=in_line_search)

        # Options shared by the pull calls below.
        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        # Plot roughly every 1/20th of `max_iter` iterations.
        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                   and (self.n_iter - 1) % logplot == 0

        # jitter
        # idj = spatial.identity_grid(self.fixed.shape[-self.dim:], jitter=True,
        #                             **utils.backend(self.fixed))
        # fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        # del idj
        fixed = self.fixed

        # forward
        # Lazily build the Lie basis (as a tensor) on first call.
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        # Fixed voxels -> moving voxels.
        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        # Identity grid of the fixed lattice, built once and cached.
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape, **utils.backend(logaff))
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        llx = self.loss.loss(warped, fixed)
        del warped

        # print objective
        # Keep the tensor (`lll`) for backward(); `ll` is a Python float
        # used for logging only.
        lll = llx
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                # Relative improvement since the previous evaluation.
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [lll]
        if grad is not False:
            lll.backward()
            grad = logaff.grad.clone()
            out.append(grad)
        # Leave `logaff` without autograd tracking for the caller.
        logaff.requires_grad_(False)
        return tuple(out) if len(out) > 1 else out[0]