Example #1
    def regulariser(self, v=None):
        """Apply the penalty operator to a displacement field
        (defaults to the stored field `self.dat`)."""
        if v is None:
            v = self.dat
        return spatial.regulariser_grid(v, **self.penalty,
                                        factor=self.factor / py.prod(self.shape),
                                        voxel_size=self.voxel_size)
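For reference, the penalty energy that the later examples minimise is recovered from this operator as half the inner product of the field with its regularised image. A minimal sketch, assuming a hypothetical instance `model` of the surrounding class:

    # Hypothetical usage: `model` is an instance of the surrounding class.
    v = model.dat                   # stored displacement field
    Lv = model.regulariser()        # penalty operator applied to it
    energy = 0.5 * (v * Lv).sum()   # same pattern as `llv` in Examples #2-#4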
Example #2
    def __call__(self, vel, grad=False):
        # This function performs the forward pass and computes
        # derivatives along the way.

        # make sure we are in the correct autograd context; if not,
        # recurse into this function under the right context manager
        if grad and not torch.is_grad_enabled():
            with torch.enable_grad():
                return self(vel, grad)
        if not grad and torch.is_grad_enabled():
            with torch.no_grad():
                return self(vel, grad)

        # select the correct gradient mode
        if grad:
            vel.requires_grad_()
            if vel.grad is not None:
                vel.grad.zero_()

        dim = vel.shape[-1]

        in_line_search = not grad
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % 20 == 0

        # forward
        if self.id is None:
            self.id = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                            **utils.backend(vel))
        grid = self.id + vel
        warped = spatial.grid_pull(self.moving,
                                   grid,
                                   bound='dct2',
                                   extrapolate=True)

        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        vel,
                        cat=iscat,
                        dim=dim)

        # log-likelihood in observed space
        llx = self.loss.loss(warped, self.fixed)
        del warped

        # add regularization term
        vgrad = spatial.regulariser_grid(vel, **self.prm)
        llv = 0.5 * (vel * vgrad).sum()
        lll = llx + llv
        del vgrad

        # print objective
        llx = llx.item()
        llv = llv.item()
        ll = lll.item()
        if not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g}',
                    end='\r')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\r')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [lll]
        if grad:
            lll.backward()
            out.append(vel.grad)
        vel.requires_grad_(False)
        return tuple(out) if len(out) > 1 else out[0]
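A minimal driver sketch for this autograd-based objective, assuming `obj` is an instance of the class above and `vel` an initialised displacement field of shape `(*spatial, dim)`; the names `obj`, `vel`, `lr` and `n_iter` are assumptions:

    # Hypothetical plain gradient descent on the displacement field.
    lr, n_iter = 0.1, 100
    for _ in range(n_iter):
        ll, g = obj(vel, grad=True)   # 0-dim loss tensor + d(loss)/d(vel)
        with torch.no_grad():
            vel -= lr * g             # descent step
    ll = obj(vel)                     # final evaluation (runs under no_grad)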
Example #3
    def __call__(self,
                 vel,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False,
                 in_line_search=False):
        """
        vel : (..., *spatial, dim) tensor, Displacement
        grad : Whether to compute and return the gradient wrt `vel`
        hess : Whether to compute and return the Hessian wrt `vel`
        gradmov : Whether to compute and return the gradient wrt `moving`
        hessmov : Whether to compute and return the Hessian wrt `moving`

        Returns
        -------
        ll : () tensor, loss value (objective to minimize)
        g : (..., *spatial, dim) tensor, optional, Gradient wrt velocity
        h : (..., *spatial, ?) tensor, optional, Hessian wrt velocity
        gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
        hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

        """
        # This function performs the forward pass and computes
        # derivatives along the way.

        dim = vel.shape[-1]
        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # forward
        if self.id is None:
            self.id = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                            **utils.backend(vel))
        grid = self.id + vel
        warped = spatial.grid_pull(self.moving, grid, **pullopt)

        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        vel,
                        cat=iscat,
                        dim=dim)

        # gradient/Hessian of the log-likelihood in observed space.
        # `gradmov`/`hessmov` require the observed-space derivatives too,
        # so remember what the caller actually asked for wrt `vel`.
        grad_requested, hess_requested = grad, hess
        if not (grad or hess or gradmov or hessmov):
            llx = self.loss.loss(warped, self.fixed)
        elif not hess and not hessmov:
            llx, grad = self.loss.loss_grad(warped, self.fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, self.fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped
        if not grad_requested:
            grad = False
        if not hess_requested:
            hess = False

        # compose with spatial gradients
        if grad is not False or hess is not False:
            mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
            if grad is not False:
                grad = jg(mugrad, grad)
            if hess is not False:
                hess = jhj(mugrad, hess)

        # add regularization term
        vgrad = spatial.regulariser_grid(vel, **self.prm, kernel=False)
        llv = 0.5 * (vel * vgrad).sum()
        if grad is not False:
            grad += vgrad
        del vgrad

        # print objective
        llx = llx.item()
        llv = llv.item()
        ll = llx + llv
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g}',
                    end='\r')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\r')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
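A sketch of how such an objective is typically consumed by a Gauss-Newton loop. The names `obj` and `vel` are assumptions, and the solve of the regularised linear system is crudely approximated by a diagonal preconditioner (assuming the compact Hessian stores its diagonal in the first `dim` channels):

    # Hypothetical Gauss-Newton-style update with a line-search check.
    ll, g, h = obj(vel, grad=True, hess=True)
    dim = vel.shape[-1]
    step = g / (h[..., :dim] + 1e-3)          # crude diagonal preconditioning
    ll_new = obj(vel - step, in_line_search=True)
    if ll_new < ll:                           # accept only if we improved
        vel = vel - step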
Example #4
    def __call__(self,
                 vel,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False):
        # This function performs the forward pass and computes
        # derivatives along the way.

        dim = vel.shape[-1]
        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        in_line_search = not grad and not hess
        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # forward
        if self.kernel is None:
            self.kernel = spatial.greens(vel.shape[-dim - 1:-1], **self.prm,
                                         **utils.backend(vel))
        grid = spatial.shoot(vel, self.kernel, steps=self.steps, **self.prm)
        warped = spatial.grid_pull(self.moving,
                                   grid,
                                   bound='dct2',
                                   extrapolate=True)

        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        vel,
                        cat=iscat,
                        dim=dim)

        # gradient/Hessian of the log-likelihood in observed space.
        # `gradmov`/`hessmov` require the observed-space derivatives too,
        # so remember what the caller actually asked for wrt `vel`.
        grad_requested, hess_requested = grad, hess
        if not (grad or hess or gradmov or hessmov):
            llx = self.loss.loss(warped, self.fixed)
        elif not hess and not hessmov:
            llx, grad = self.loss.loss_grad(warped, self.fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, self.fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped
        if not grad_requested:
            grad = False
        if not hess_requested:
            hess = False

        # compose with spatial gradients
        if grad is not False or hess is not False:
            if self.mugrad is None:
                self.mugrad = spatial.diff(self.moving,
                                           dim=list(range(-dim, 0)),
                                           bound='dct2')
            if grad is not False:
                grad = grad.neg_()  # "final inverse" to "initial"
                grad = spatial.grid_push(grad, grid)
                grad = jg(self.mugrad, grad)
            if hess is not False:
                hess = spatial.grid_push(hess, grid)
                hess = jhj(self.mugrad, hess)

        # add regularization term
        vgrad = spatial.regulariser_grid(vel, **self.prm, kernel=True)
        llv = 0.5 * (vel * vgrad).sum()
        if grad is not False:
            grad += vgrad
        del vgrad

        # print objective
        llx = llx.item()
        llv = llv.item()
        ll = llx + llv
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g}',
                    end='\r')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\r')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
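Unlike Example #3, this variant parameterises the deformation by an initial velocity that `spatial.shoot` integrates into a diffeomorphism, which is why the gradient is negated and pushed back to the initial frame. The `gradmov`/`hessmov` outputs also make a template update possible in group-wise settings; a hypothetical sketch, assuming a single-channel `moving` image so the compact Hessian matches the gradient's shape:

    # Hypothetical template step driven by the moving-image derivatives.
    ll, g, h, gm, hm = obj(vel, grad=True, hess=True,
                           gradmov=True, hessmov=True)
    template = obj.moving - gm / (hm + 1e-3)   # one Newton-ish update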