Example #1
 def greens(self, shape, **backend):
     opt = dict(absolute=self.absolute,
                membrane=self.membrane,
                bending=self.bending,
                lame=self.lame,
                factor=self.factor)
     kernel = spatial.greens(shape, **opt, **backend)
     return kernel
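For context, here is a minimal, self-contained sketch of what such a Greens kernel can look like for an `absolute + membrane` penalty: the pointwise inverse of the regulariser's eigenvalues on a Fourier basis. Periodic boundaries are assumed purely for simplicity; this illustrates the idea, not the library's actual implementation.

    import torch

    def greens_membrane(shape, absolute=1e-4, membrane=1e-3):
        # Eigenvalues of the discrete Laplacian under periodic boundaries:
        # sum over axes of 2 - 2*cos(2*pi*k/n).
        eig = torch.zeros(shape)
        for d, n in enumerate(shape):
            k = torch.arange(n, dtype=torch.float32)
            lam = 2 - 2 * torch.cos(2 * torch.pi * k / n)
            eig += lam.reshape([-1 if i == d else 1 for i in range(len(shape))])
        # Greens kernel = inverse of the regulariser in the frequency domain
        return 1 / (absolute + membrane * eig)

    kernel = greens_membrane((8, 8))  # (8, 8) frequency-domain kernel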
Example #2
 def set_kernel(self, kernel=None):
     if kernel is None:
         kernel = spatial.greens(self.shape, **self.penalty,
                                 factor=self.factor / py.prod(self.shape),
                                 voxel_size=self.voxel_size,
                                 **utils.backend(self.dat))
     self.kernel = kernel
     return self
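The twist relative to Example #1 is the division of `factor` by the number of voxels, which keeps the regularisation weight roughly independent of image size (the penalty accumulates over voxels). Illustrative arithmetic with hypothetical numbers:

    import math

    shape = (128, 128, 96)                         # hypothetical image shape
    factor = 1.0
    factor_per_voxel = factor / math.prod(shape)   # ~6.36e-07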
Example #3
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.dim1d is None:
            shape = self.spatial_shape
            vx = self.voxel_size
        else:
            shape = [self.spatial_shape[self.dim1d]]
            vx = self.voxel_size[self.dim1d]

        self.kernel = spatial.greens(shape,
                                     **self.reg_prm,
                                     factor=self.factor,
                                     voxel_size=vx,
                                     **utils.backend(self))
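The `dim1d` branch restricts the kernel to a single spatial axis: a one-dimensional shape with that axis's voxel size. The same selection in isolation, with hypothetical values:

    spatial_shape = (192, 192, 128)   # hypothetical 3D shape
    voxel_size = (1.0, 1.0, 1.5)
    dim1d = 2
    shape = [spatial_shape[dim1d]]    # [128]: a 1D problem along axis 2
    vx = voxel_size[dim1d]            # 1.5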
Example #4
def register(fixed=None,
             moving=None,
             dim=None,
             lam=1.,
             loss='mse',
             optim='nesterov',
             hilbert=None,
             max_iter=500,
             sub_iter=16,
             lr=None,
             ls=0,
             plot=False,
             klosure=RegisterStep,
             kernel=None,
             **prm):
    """Nonlinear registration between two images using smooth displacements.

    Parameters
    ----------
    fixed : (..., K, *spatial) tensor
        Fixed image
    moving : (..., K, *spatial) tensor
        Moving image
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    lam : float, default=1
        Modulate regularisation
    loss : {'mse', 'cat'} or OptimizationLoss, default='mse'
        'mse': Mean-squared error
        'cat': Categorical cross-entropy
    optim : {'relax', 'cg', 'gd', 'momentum', 'nesterov', 'lbfgs'}, default='nesterov'
        'relax'     : Gauss-Newton (linear system solved by relaxation)
        'cg'        : Gauss-Newton (linear system solved by conjugate gradient)
        'gd'        : Gradient descent
        'momentum'  : Gradient descent with momentum
        'nesterov'  : Nesterov-accelerated gradient descent
        'lbfgs'     : Limited-memory BFGS
    hilbert : bool, default=`not optim.requires_hess`
        Use Hilbert preconditioning (not used if optim is second order)
    max_iter : int, default=500
        Maximum number of Gauss-Newton or gradient descent iterations
    sub_iter : int, default=16
        Number of relax/cg iterations per GN step
    lr : float, default=None
        Learning rate.
    ls : int, default=0
        Number of line search iterations.
    absolute : float, default=1e-4
        Penalty on absolute displacements
    membrane : float, default=1e-3
        Penalty on first derivatives
    bending : float, default=0.2
        Penalty on second derivatives
    lame : (float, float), default=(0.05, 0.2)
        Penalty on zooms and shears

    Returns
    -------
    disp : (..., *spatial, dim) tensor
        Displacement field.

    """
    defaults_velocity(prm)

    # If no inputs provided: demo "circle to square"
    if fixed is None or moving is None:
        fixed, moving = phantoms.demo_register(cat=(loss == 'cat'))

    # init tensors
    fixed, moving = utils.to_max_backend(fixed, moving)
    dim = dim or (fixed.dim() - 1)
    shape = fixed.shape[-dim:]
    lam = lam / py.prod(shape)
    prm['factor'] = lam
    velshape = [*fixed.shape[:-dim - 1], *shape, dim]
    vel = torch.zeros(velshape, **utils.backend(fixed))

    # init optimizer
    optim = regutils.make_iteroptim_grid(optim, lr, ls, max_iter, sub_iter,
                                         **prm)
    if hilbert is None:
        hilbert = not optim.requires_hess
    if hilbert and kernel is None:
        kernel = spatial.greens(shape, **prm, **utils.backend(fixed))
    if kernel is not None:
        optim.preconditioner = lambda x: spatial.greens_apply(x, kernel)

    # init loss
    loss = losses.make_loss(loss, dim)

    print(
        f'{"it":3s} | {"fit":^12s} + {"reg":^12s} = {"obj":^12s} | {"gain":^12s}'
    )
    print('-' * 63)
    closure = klosure(moving,
                      fixed,
                      loss,
                      plot=plot,
                      max_iter=optim.max_iter,
                      **prm)
    vel = optim.iter(vel, closure)
    print('')
    return vel
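Hilbert preconditioning, as wired up through `optim.preconditioner` above, replaces the raw gradient by the gradient smoothed with the Greens kernel of the regulariser. A minimal sketch, assuming a periodic frequency-domain kernel such as the one sketched under Example #1; `greens_apply` and `preconditioned_gd_step` here are illustrative stand-ins, not the library's functions:

    import torch

    def greens_apply(x, kernel):
        # x: (*spatial,) field; kernel: (*spatial,) frequency-domain Greens kernel
        return torch.fft.ifftn(torch.fft.fftn(x) * kernel).real

    def preconditioned_gd_step(vel, grad, kernel, lr=0.1):
        # Step along the smoothed (Sobolev/Hilbert) gradient, not the raw one
        return vel - lr * greens_apply(grad, kernel)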
Example #5
    def __call__(self,
                 vel,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False):
        # This closure, called at each iteration, performs the forward
        # pass and computes derivatives along the way.

        dim = vel.shape[-1]
        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        in_line_search = not grad and not hess
        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # forward
        if self.kernel is None:
            self.kernel = spatial.greens(vel.shape[-dim - 1:-1], **self.prm,
                                         **utils.backend(vel))
        grid = spatial.shoot(vel, self.kernel, steps=self.steps, **self.prm)
        warped = spatial.grid_pull(self.moving,
                                   grid,
                                   bound='dct2',
                                   extrapolate=True)

        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        vel,
                        cat=iscat,
                        dim=dim)

        # gradient/Hessian of the log-likelihood in observed space
        if not grad and not hess:
            llx = self.loss.loss(warped, self.fixed)
        elif not hess:
            llx, grad = self.loss.loss_grad(warped, self.fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, self.fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients
        if grad is not False or hess is not False:
            if self.mugrad is None:
                self.mugrad = spatial.diff(self.moving,
                                           dim=list(range(-dim, 0)),
                                           bound='dct2')
            if grad is not False:
                grad = grad.neg_()  # "final inverse" to "initial"
                grad = spatial.grid_push(grad, grid)
                grad = jg(self.mugrad, grad)
            if hess is not False:
                hess = spatial.grid_push(hess, grid)
                hess = jhj(self.mugrad, hess)

        # add regularization term
        vgrad = spatial.regulariser_grid(vel, **self.prm, kernel=True)
        llv = 0.5 * (vel * vgrad).sum()
        if grad is not False:
            grad += vgrad
        del vgrad

        # print objective
        llx = llx.item()
        llv = llv.item()
        ll = llx + llv
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g}',
                    end='\r')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\r')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
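The regularisation bookkeeping above uses a quadratic-form identity: with `vgrad = L v` (the regulariser applied to the velocity), the penalty energy is `0.5 * v' L v = 0.5 * (vel * vgrad).sum()` and its gradient is exactly `vgrad`, which is why the same tensor is summed into the objective and added to the gradient. A tiny dense check, with a symmetric positive-definite matrix standing in for the regulariser:

    import torch

    L = torch.tensor([[2., -1.], [-1., 2.]])  # SPD stand-in for the regulariser
    v = torch.tensor([0.3, -0.7], requires_grad=True)
    energy = 0.5 * v @ (L @ v)
    energy.backward()
    assert torch.allclose(v.grad, L @ v.detach())  # gradient of 0.5 v'Lv is Lv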
Example #6
    def __call__(self, vel, grad=False):
        # This closure, called at each iteration, performs the forward
        # pass and computes derivatives along the way.

        # select correct gradient mode
        if grad:
            vel.requires_grad_()
            if vel.grad is not None:
                vel.grad.zero_()
        if grad and not torch.is_grad_enabled():
            with torch.enable_grad():
                return self(vel, grad)
        elif not grad and torch.is_grad_enabled():
            with torch.no_grad():
                return self(vel, grad)

        dim = vel.shape[-1]
        nvox = py.prod(vel.shape[-dim - 1:-1])

        in_line_search = not grad
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % 20 == 0

        # forward
        if self.kernel is None:
            self.kernel = spatial.greens(vel.shape[-dim - 1:-1], **self.prm,
                                         **utils.backend(vel))
        grid = spatial.shoot(vel, self.kernel, steps=self.steps, **self.prm)
        warped = spatial.grid_pull(self.moving,
                                   grid,
                                   bound='dct2',
                                   extrapolate=True)

        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        vel,
                        cat=iscat,
                        dim=dim)

        # log-likelihood in observed space
        llx = self.loss.loss(warped, self.fixed)
        del warped

        # add regularization term
        vgrad = spatial.regulariser_grid(vel, **self.prm).div_(nvox)
        llv = 0.5 * (vel * vgrad).sum()
        lll = llx + llv
        del vgrad

        # print objective
        llx = llx.item()
        llv = llv.item()
        ll = lll.item()
        if not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g}',
                    end='\r')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\r')

        out = [lll]
        if grad:
            lll.backward()
            out.append(vel.grad)
        vel.requires_grad_(False)
        return tuple(out) if len(out) > 1 else out[0]
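The grad-mode dispatch at the top of this `__call__` is a reusable pattern: re-enter the same function under the correct autograd mode so the body only has to be written once. The pattern in isolation, with a hypothetical objective:

    import torch

    def objective(x, grad=False):
        # Re-enter under the right autograd mode, then fall through to the body
        if grad and not torch.is_grad_enabled():
            with torch.enable_grad():
                return objective(x, grad)
        if not grad and torch.is_grad_enabled():
            with torch.no_grad():
                return objective(x, grad)
        x.requires_grad_(grad)
        loss = (x ** 2).sum()
        if grad:
            loss.backward()
            return loss.detach(), x.grad
        return loss

    val, g = objective(torch.ones(3), grad=True)  # g == 2 * x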
Example #7
    def forward(self, batch=1, **overload):
        """

        Parameters
        ----------
        batch : int, default=1
            Batch size
        overload : dict

        Returns
        -------
        field : (batch, channel, *shape) tensor
            Generated random field

        """

        # get arguments
        shape = overload.get('shape', self.shape)
        mean = overload.get('mean', self.mean)
        voxel_size = overload.get('voxel_size', self.voxel_size)
        dtype = overload.get('dtype', self.dtype)
        device = overload.get('device', self.device)
        backend = dict(dtype=dtype, device=device)

        # sample if parameters are callable
        nb_dim = len(shape)
        voxel_size = utils.make_vector(voxel_size, nb_dim, **backend)
        voxel_size = voxel_size.tolist()
        lame = py.make_list(self.lame, 2)

        if (hasattr(self, '_greens')
                and self._voxel_size == voxel_size
                and self._shape == shape):
            greens = self._greens.to(dtype=dtype, device=device)
        else:
            greens = spatial.greens(
                shape,
                absolute=self.absolute,
                membrane=self.membrane,
                bending=self.bending,
                lame=self.lame,
                voxel_size=voxel_size,
                device=device,
                dtype=dtype)
            if any(lame):
                greens, scale, _ = torch.svd(greens)
                scale = scale.sqrt_()
                greens *= scale.unsqueeze(-1)
            else:
                greens = greens.sqrt_()

            if self.cache_greens:
                self._greens = greens
                self._voxel_size = voxel_size
                self._shape = shape

        sample = torch.randn([2, batch, *shape, nb_dim], **backend)

        # multiply by square root of greens
        if greens.dim() > nb_dim:  # lame
            sample = linalg.matvec(greens, sample)
        else:
            sample = sample * greens.unsqueeze(-1)
            voxel_size = utils.make_vector(voxel_size, nb_dim, **backend)
            sample = sample / voxel_size.sqrt()
        sample = fft.complex(sample[0], sample[1])

        # inverse Fourier transform
        dims = list(range(-nb_dim-1, -1))
        sample = fft.real(fft.ifftn(sample, dim=dims))
        sample *= py.prod(shape)

        # add mean
        sample += mean

        return sample
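The sampling logic above is the standard colouring trick: a zero-mean Gaussian field with covariance L^{-1} can be drawn by scaling frequency-domain white noise with the square root of the Greens kernel and inverse-transforming. A single-field sketch of the scalar-kernel branch only (the Lame/SVD case is omitted, and normalisation is kept approximate):

    import torch

    def sample_field(greens_kernel):
        # greens_kernel: (*spatial,) frequency-domain kernel (see Example #1's sketch)
        shape = greens_kernel.shape
        white = torch.complex(torch.randn(shape), torch.randn(shape))
        coloured = white * greens_kernel.sqrt()  # colour by sqrt of Greens
        return torch.fft.ifftn(coloured).real   # up to a global normalisation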
Example #8
    def forward(self, batch=1, **overload):
        """

        Parameters
        ----------
        batch : int, default=1
            Batch size

        Other Parameters
        ----------------
        shape : sequence[int], optional
        channel : int, optional
        voxel_size : float or (dim,) vector_like, optional
        device : torch.device, optional
        dtype : torch.dtype, optional

        Returns
        -------
        field : (batch, channel, *shape) tensor
            Generated random field

        """

        # get arguments
        shape = overload.get('shape', self.shape)
        channel = overload.get('channel', self.channel)
        voxel_size = overload.get('voxel_size', self.voxel_size)
        dtype = overload.get('dtype', self.dtype)
        device = overload.get('device', self.device)
        backend = dict(dtype=dtype, device=device)

        # sample if parameters are callable
        nb_dim = len(shape)
        voxel_size = utils.make_vector(voxel_size, nb_dim, **backend)
        voxel_size = voxel_size.tolist()

        if (hasattr(self, '_greens')
                and self._voxel_size == voxel_size
                and self._channel == channel
                and self._shape == shape):
            greens = self._greens.to(dtype=dtype, device=device)
        else:
            absolute = utils.make_vector(self.absolute, channel, **backend)
            membrane = utils.make_vector(self.membrane, channel, **backend)
            bending = utils.make_vector(self.bending, channel, **backend)

            greens = []
            for c in range(channel):
                greens.append(spatial.greens(
                    shape,
                    absolute=absolute[c],
                    membrane=membrane[c],
                    bending=bending[c],
                    lame=0,
                    voxel_size=voxel_size,
                    device=device,
                    dtype=dtype))
            greens = torch.stack(greens)
            greens = greens.sqrt_()

            if self.cache_greens:
                self._greens = greens
                self._voxel_size = voxel_size
                self._channel = channel
                self._shape = shape

        # channel-wise mean (computed outside the cached branch above so
        # that it is also defined on cache hits)
        mean = utils.make_vector(self.mean, channel, **backend)

        # sample white noise
        sample = torch.randn([2, batch, channel, *shape], **backend)
        sample *= greens  # (channel, *shape) broadcasts against (2, batch, channel, *shape)
        sample = fft.complex(sample[0], sample[1])

        # inverse Fourier transform
        dims = list(range(-nb_dim, 0))
        sample = fft.real(fft.ifftn(sample, dim=dims))
        sample *= py.prod(shape)

        # add mean
        sample += utils.unsqueeze(mean, -1, len(shape))

        return sample
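Relative to Example #7, the kernel here is scalar per channel, so the stacked `(channel, *shape)` kernel broadcasts directly against the `(2, batch, channel, *shape)` noise; there is no trailing vector dimension, hence no `unsqueeze`. Shape check in isolation:

    import torch

    channel, shape = 3, (8, 8)
    greens = torch.rand(channel, *shape)         # stacked per-channel kernels
    sample = torch.randn(2, 4, channel, *shape)  # (2, batch, channel, *shape)
    sample = sample * greens                     # right-aligned broadcast
    print(sample.shape)                          # torch.Size([2, 4, 3, 8, 8])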
Example #9
    def forward(self, velocity, fwd=None, inv=None, voxel_size=None):
        """

        Parameters
        ----------
        velocity : (batch, *spatial, dim) tensor
            Initial velocity field.
        fwd : bool, default=self.fwd
        inv : bool, default=self.inv
        voxel_size : sequence[float], default=self.voxel_size

        Returns
        -------
        forward : (batch, *spatial, dim) tensor, if `fwd is True`
            Forward displacement (if `self.displacement is True`) or
            transformation (if `self.displacement is False`) field.
        inverse : (batch, *spatial, dim) tensor, if `inv is True`
            Inverse displacement (if `self.displacement is True`) or
            transformation (if `self.displacement is False`) field.

        """
        fwd = fwd if fwd is not None else self.fwd
        inv = inv if inv is not None else self.inv
        voxel_size = voxel_size if voxel_size is not None else self.voxel_size

        shoot_opt = {
            'steps': self.steps,
            'displacement': self.displacement,
            'voxel_size': voxel_size,
            'absolute': self.absolute,
            'membrane': self.membrane,
            'bending': self.bending,
            'lame': self.lame,
            'factor': self.factor,
        }
        greens_prm = {
            'absolute': self.absolute,
            'membrane': self.membrane,
            'bending': self.bending,
            'lame': self.lame,
            'factor': self.factor,
            'voxel_size': voxel_size,
            'shape': velocity.shape[1:-1],
        }

        if self.cache_greens:
            if getattr(self, '_greens_prm', None) == greens_prm:
                greens = self._greens.to(**utils.backend(velocity))
            else:
                greens = spatial.greens(**greens_prm,
                                        **utils.backend(velocity))
                self._greens = greens
                self._greens_prm = greens_prm
        else:
            greens = spatial.greens(**greens_prm, **utils.backend(velocity))

        shoot_fn = spatial.shoot_approx if self.approx else spatial.shoot

        output = []
        if inv:
            y, iy = shoot_fn(velocity,
                             greens,
                             return_inverse=True,
                             **shoot_opt)
            if fwd:
                output.append(y)
            output.append(iy)
        elif fwd:
            y = shoot_fn(velocity, greens, **shoot_opt)
            output.append(y)

        return output if len(output) > 1 else \
               output[0] if len(output) == 1 else \
               None
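The caching in `forward` keys the stored kernel on the exact parameter dict used to build it and rebuilds only on a mismatch. The same pattern in isolation, where `build_kernel` is a hypothetical stand-in for the kernel constructor:

    class GreensCache:
        def __init__(self, build_kernel):
            self.build_kernel = build_kernel
            self._greens = None
            self._greens_prm = None

        def get(self, **prm):
            # Rebuild only when the parameters differ from the cached ones
            if self._greens_prm != prm:
                self._greens = self.build_kernel(**prm)
                self._greens_prm = prm
            return self._greens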