Example 1
def _process_reg(dat, mat, mat_a, mat_fix, dim_fix, write):
    """Process registration results.

    Always reslices the moving images onto the fixed image's grid
    (returned as `rdat`). If `write == 'reslice'`, the input data and
    affines are replaced by the resliced versions; if `write == 'affine'`,
    the estimated affines are folded into the input orientation matrices.
    """
    N = len(dat)
    rdat = torch.zeros((N, ) + dim_fix,
                       dtype=dat[0].dtype,
                       device=dat[0].device)
    for n in range(N):  # loop over input images
        if torch.all(mat_a[n] - torch.eye(4, device=mat_a[n].device) == 0):
            rdat[n] = dat[n]
        else:
            mat_r = lmdiv(mat[n], mat_a[n].mm(mat_fix))
            rdat[n] = _reslice_dat_3d(dat[n], mat_r, dim_fix)
        if write == 'reslice':
            dat[n] = rdat[n]
            mat[n] = mat_fix
        elif write == 'affine':
            mat[n] = lmdiv(mat_a[n], mat[n])
    # Write output to disk?
    if write in ['reslice', 'affine']:
        write = True
    else:
        write = False

    return dat, mat, write, rdat
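
Note: throughout these examples, `lmdiv(A, B)` is used as a left matrix division, i.e. the solution of A @ X = B (MATLAB's `A \ B`; see the `mat0\mat1` comment in Example 2). Below is a minimal plain-PyTorch sketch of the reslicing map computed above, with `torch.linalg.solve` standing in for `lmdiv` (an assumption about its behaviour, not the library's implementation):

import torch

def lmdiv_sketch(A, B):
    # Assumed behaviour of lmdiv: solve A @ X = B for X, i.e. A^{-1} @ B.
    return torch.linalg.solve(A, B)

# Toy affines (identity placeholders).
mat_mov = torch.eye(4, dtype=torch.float64)   # moving image orientation matrix
mat_fix = torch.eye(4, dtype=torch.float64)   # fixed image orientation matrix
mat_a = torch.eye(4, dtype=torch.float64)     # estimated affine correction
# Fixed-voxel to moving-voxel map used for reslicing: mat_mov \ (mat_a @ mat_fix)
mat_r = lmdiv_sketch(mat_mov, mat_a @ mat_fix)
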
Example 2
def _msk_fov(dat, mat, mat0, dim0):
    """Mask field-of-view (FOV) of image data according to other image's
    FOV.

    Parameters
    ----------
    dat : (X, Y, Z), tensor
        Image data.
    mat : (4, 4), tensor
        Image's affine.
    mat0 : (4, 4), tensor
        Other image's affine.
    dim0 : (3, ), list/tuple
        Other image's dimensions.

    Returns
    -------
    dat : (X, Y, Z), tensor
        Masked image data.

    """
    dim = dat.shape
    M = lmdiv(mat0, mat)  # mat0\mat1
    grid = affine_grid(M, dim)
    msk = (grid[..., 0] >= 1) & (grid[..., 0] <= dim0[0]) & \
          (grid[..., 1] >= 1) & (grid[..., 1] <= dim0[1]) & \
          (grid[..., 2] >= 1) & (grid[..., 2] <= dim0[2])
    dat[~msk] = 0

    return dat
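
The kept voxels are those whose coordinates, once mapped into the other image's voxel space, fall inside its field of view (1-based indexing; notation mine):

\[
\mathrm{msk}(x) = \bigwedge_{k=1}^{3} \big(1 \le (Mx)_k \le \mathrm{dim0}_k\big),
\qquad M = \mathrm{mat0}^{-1}\,\mathrm{mat} .
\]
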
Example 3
def _imatrix(M):
    """Return the parameters for creating an affine transformation matrix.

    Args:
        M (torch.tensor): Affine transformation matrix (4, 4).

    Returns:
        P (torch.tensor): Affine parameters (<=12).

    Authors:
        John Ashburner & Stefan Kiebel, as part of the SPM12 software.

    """
    device = M.device
    dtype = M.dtype
    one = torch.tensor(1.0, device=device, dtype=dtype)
    # Translations and Zooms
    R = M[:-1, :-1]
    C = cholesky(R.t().mm(R))
    C = C.t()
    d = torch.diag(C)
    P = torch.tensor(
        [M[0, 3], M[1, 3], M[2, 3], 0, 0, 0, d[0], d[1], d[2], 0, 0, 0],
        device=device,
        dtype=dtype)
    if R.det() < 0:  # Fix for -ve determinants
        P[6] = -P[6]
    # Shears
    C = lmdiv(torch.diag(torch.diag(C)), C)
    P[9] = C[0, 1]
    P[10] = C[0, 2]
    P[11] = C[1, 2]
    R0 = affine_matrix_classic(
        torch.tensor([0, 0, 0, 0, 0, 0, P[6], P[7], P[8], P[9], P[10],
                      P[11]])).to(device)
    R0 = R0[:-1, :-1]
    R1 = R.mm(R0.inverse())  # This just leaves rotations in matrix R1
    # Correct rounding errors
    rang = lambda x: torch.min(torch.max(x, -one), one)
    P[4] = torch.asin(rang(R1[0, 2]))
    if (torch.abs(P[4]) - pi / 2)**2 < 1e-9:
        P[3] = 0
        P[5] = torch.atan2(-rang(R1[1, 0]), rang(-R1[2, 0] / R1[0, 2]))
    else:
        c = torch.cos(P[4])
        P[3] = torch.atan2(rang(R1[1, 2] / c), rang(R1[2, 2] / c))
        P[5] = torch.atan2(rang(R1[0, 1] / c), rang(R1[0, 0] / c))

    return P
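
A compact restatement of the decomposition performed above (notation mine). Writing R for the upper-left 3x3 block of M:

\[
R^\top R = U^\top U \ \ (\text{Cholesky, } U \text{ upper triangular}),\qquad
z = \operatorname{diag}(U),\qquad
S = \operatorname{diag}(z)^{-1} U,\qquad
R_{\mathrm{rot}} = R\,R_0^{-1},
\]

where R_0 is the zoom-and-shear matrix rebuilt from z and the off-diagonal entries of S. The returned parameters are the translations M[:3, 3], the Euler angles read off R_rot, the zooms z (with the first one negated when det R < 0), and the three shears of S.
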
Example 4
def _rescale(dat, mn_out=0, mx_out=511):
    """ Rescales image intensities between mn_out and mx_out.

    """
    backend = dict(dtype=dat.dtype, device=dat.device)
    msk = torch.isfinite(dat).bitwise_not_()
    msk = msk.bitwise_or_(dat == dat.min()).bitwise_or_(dat == dat.max())
    dat = dat.masked_fill_(msk, 0)
    # Make scaling to set image intensities between mn_out and mx_out
    mnmx_in = torch.as_tensor([[dat.min(), 1], [dat.max(), 1]], **backend)
    mnmx_out = torch.as_tensor([mn_out, mx_out], **backend)
    sf = linalg.lmdiv(mnmx_in, mnmx_out.unsqueeze(-1)).squeeze(-1)
    # Rescale
    dat = dat.mul_(sf[0]).add_(sf[1])
    # Clamp
    dat = dat.clamp_(mn_out, mx_out)

    return dat
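
The 2x2 `lmdiv` above simply solves for a slope and intercept; a minimal plain-PyTorch sketch of the equivalent closed form (variable names mine):

import torch

# Solving [[mn, 1], [mx, 1]] @ [a, b]^T = [mn_out, mx_out]^T gives
#   a = (mx_out - mn_out) / (mx - mn),  b = mn_out - a * mn
dat = torch.randn(8, 8)
mn_out, mx_out = 0., 511.
mn, mx = dat.min(), dat.max()
a = (mx_out - mn_out) / (mx - mn)
b = mn_out - a * mn
dat = dat.mul(a).add(b).clamp(mn_out, mx_out)
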
Example 5
    def search_direction(self, grad, hess):
        grad, hess = self._add_marquardt(grad, hess)
        step = linalg.lmdiv(hess, grad[..., None])[..., 0]
        step.mul_(-self.lr)
        return step
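
`_add_marquardt` is not shown here; below is a minimal plain-PyTorch sketch of the damped Newton step this method computes, assuming (as the name suggests) that the damping adds a multiple of the Hessian diagonal:

import torch

g = torch.randn(6)                      # gradient
A = torch.randn(6, 6)
H = A @ A.T + torch.eye(6)              # symmetric positive-definite Hessian
lr, lam = 1.0, 0.1                      # learning rate, damping factor (assumed)
H_damped = H + lam * torch.diag(H.diagonal())   # Marquardt-style damping (assumption)
step = -lr * torch.linalg.solve(H_damped, g)    # -lr * H^{-1} @ g
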
Example 6
    def __call__(self,
                 logaff,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False,
                 in_line_search=False):
        """
        logaff : (..., nb) tensor, Lie parameters
        grad : Whether to compute and return the gradient wrt `logaff`
        hess : Whether to compute and return the Hessian wrt `logaff`
        gradmov : Whether to compute and return the gradient wrt `moving`
        hessmov : Whether to compute and return the Hessian wrt `moving`

        Returns
        -------
        ll : () tensor, loss value (objective to minimize)
        g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
        h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
        gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
        hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

        """
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # jitter
        # if not hasattr(self, '_fixed'):
        #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
        #                                 jitter=True,
        #                                 **utils.backend(self.fixed))
        #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        #     del idj
        # fixed = self._fixed
        fixed = self.fixed

        # forward
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        with torch.no_grad():
            _, gaff = linalg._expm(logaff,
                                   self.basis,
                                   grad_X=True,
                                   hess_X=False)

        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        # /!\ derivatives are not "homogeneous" (they do not have a one
        # on the bottom right): we should *not* use affine_matmul and
        # such (I only lost a day...)
        gaff = torch.matmul(gaff, self.affine_fixed)
        gaff = linalg.lmdiv(self.affine_moving, gaff)
        # haff = torch.matmul(haff, self.affine_fixed)
        # haff = linalg.lmdiv(self.affine_moving, haff)
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape,
                                            **utils.backend(logaff),
                                            jitter=False)
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        if not grad and not hess:
            llx = self.loss.loss(warped, fixed)
        elif not hess:
            llx, grad = self.loss.loss_grad(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients + dot product with grid
        if grad is not False or hess is not False:
            mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
            grad = jg(mugrad, grad)
            if hess is not False:
                hess = jhj(mugrad, hess)
                grad, hess = regutils.affine_grid_backward(grad,
                                                           hess,
                                                           grid=self.id)
            else:
                grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
            dim2 = self.dim * (self.dim + 1)
            grad = grad.reshape([*grad.shape[:-2], dim2])
            gaff = gaff[..., :-1, :]
            gaff = gaff.reshape([*gaff.shape[:-2], dim2])
            grad = linalg.matvec(gaff, grad)
            if hess is not False:
                hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
                # haff = haff[..., :-1, :, :-1, :]
                # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
                hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
                hess = hess.abs().sum(-1).diag_embed()
            del mugrad

        # print objective
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
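
In compact form (notation mine), the objective and derivatives assembled above are

\[
A(q) = M_{\mathrm{mov}}^{-1}\exp\Big(\sum_k q_k B_k\Big) M_{\mathrm{fix}},\qquad
g_q = G\, g_A,\qquad
H_q \approx G\, H_A\, G^\top,\qquad
(H_q)_{ii} \leftarrow \sum_j |(H_q)_{ij}|,
\]

where G is the derivative of the flattened (non-homogeneous) affine with respect to the Lie parameters, g_A and H_A are the gradient and Hessian of the loss with respect to that affine (obtained via `affine_grid_backward`), and the last step replaces the Hessian by a diagonal majoriser.
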
Example 7
def fit_curve_cat(f,
                  s,
                  lam=0,
                  gamma=0,
                  vx=1,
                  max_iter=8,
                  tol=1e-8,
                  max_levels=4):
    """Fit the curve that maximizes the categorical likelihood

    Parameters
    ----------
    f : (*shape) tensor
        Observed grid of binary labels or smooth probabilities.
    s : BSplineCurve
        Initial curve (will be modified in-place)

    Other Parameters
    ----------------
    lam : float, default=0
        Centerline regularization (bending)
    gamma : float, default=0
        Radius regularization (membrane)
    vx : float, default=1
        Voxel size
    max_iter : int, default=8
        Maximum number of iterations per level
        (this is doubled every other resolution level, so that more
        iterations are used at coarser levels).
    tol : float, default=1e-8
        Unused
    max_levels : int, default=4
        Number of multi-resolution levels.

    Returns
    -------
    s : BSplineCurve
        Fitted curve

    """
    TINY = 1e-6
    fig = elem = None

    max_iter_position = 8
    max_iter_radius = 4

    backend = utils.backend(s.coeff)
    vx = utils.make_vector(vx, f.dim(), **backend)
    vx0 = vx.clone()
    n0 = f.numel()

    # Build pyramid by restriction
    shapes = [f.shape]
    images = [f]
    vxs = [vx]
    for n_level in range(max_levels - 1):
        shape = [pymath.ceil(s / 2) for s in shapes[-1]]
        if all(s == 1 for s in shape):
            break
        shapes.append(shape)
        images.append(restrict(f.unsqueeze(-1), shapes[-1]).squeeze(-1))
        s.restrict(shapes[-2], shapes[-1])
        vx = vx * (torch.as_tensor(shapes[-2], **backend) /
                   torch.as_tensor(shapes[-1], **backend))
        vxs.append(vx)

    start = time.time()

    shape = None
    level = len(images) + 1
    while images:
        level -= 1
        print('-' * 16, 'level', level, '-' * 16)

        if shape is not None:
            s.prolong(shape, shapes[-1])
        f, shape, vx = images.pop(-1), shapes.pop(-1), vxs.pop(-1)
        scl = vx.prod() / vx0.prod()
        x = identity_grid(f.shape, **backend)
        if lam:
            L = lam * bending3(len(s.coeff), **backend)
            reg = L.matmul(s.coeff).mul_(vx.square())
            reg = 0.5 * (s.coeff * reg).sum(dtype=torch.double)
        else:
            reg = 0
        if gamma:
            Lr = gamma * membrane3(len(s.coeff_radius), **backend)
            Lr /= vx.prod().pow_(1 / len(vx)).square_()
            reg_radius = Lr.matmul(s.coeff_radius)
            reg_radius = 0.5 * (s.coeff_radius *
                                reg_radius).sum(dtype=torch.double)
        else:
            reg_radius = 0

        def get_nll(e):
            ie = (1 - e).log()
            e = e.log()
            if f.dtype is torch.bool:
                ll = e[f].sum(dtype=torch.double) + ie[~f].sum(
                    dtype=torch.double)
            else:
                ll = (e * f).sum(dtype=torch.double) + (ie * (1 - f)).sum(
                    dtype=torch.double)
            ll = -ll
            return ll

        nll = float('inf')
        max_iter_level = max_iter * 2**((level - 1) // 2)
        for n_iter in range(max_iter_level):

            nll0_prev = nll

            for n_iter_position in range(max_iter_position):

                t, d = min_dist(x, s)
                p = s.eval_position(t).sub_(x)  # residuals
                r = s.eval_radius(t)
                r = torch.as_tensor(r, **utils.backend(x))
                e = dist_to_prob(d, r, tiny=TINY)
                nll_prev = nll
                nll = get_nll(e)
                prec = radius_to_prec(r)

                # gradient of the categorical term
                omf = (1 - f) if f.dtype.is_floating_point else f.bitwise_not()
                ome = (1 - e)
                g = (omf / ome - 1) * (-prec)
                h = omf * e / ome.square()
                g = g.unsqueeze(-1)
                h = h.unsqueeze(-1)
                prec = prec.unsqueeze(-1)

                acc = 0.5
                h = h * (prec * p).square()
                if acc != 1:
                    h += (1 - acc) * g.abs()
                g = g * p

                # push
                g = s.push_position(g, t)
                h = s.push_position(h, t)

                # resolution scale
                g *= scl
                h *= scl
                nll *= scl

                # regularisation + solve
                if lam:
                    reg = L.matmul(s.coeff).mul_(vx.square())
                    g += reg
                    reg = 0.5 * (s.coeff * reg).sum(dtype=torch.double)
                    # h += L[1, :].abs().sum()
                    g = torch.stack([
                        linalg.lmdiv(h1.diag() +
                                     (v1 * v1) * L, g1[:, None])[:, 0]
                        for v1, g1, h1 in zip(vx, g.T, h.T)
                    ], -1)
                else:
                    g.div_(h)
                    reg = 0
                s.coeff.sub_(g)
                # s.coeff.clamp_min_(0)
                # for d, sz in enumerate(f.shape):
                #     s.coeff[:, d].clamp_max_(sz-1)

                fig, elem = plot_nll([nll, reg, reg_radius], e, f, s.waypoints,
                                     fig, elem)
                nll = nll + reg + reg_radius
                print('position', n_iter, n_iter_position, nll.item(),
                      (nll_prev - nll).item() / n0)
                s.update_waypoints()
                # if nll_prev - nll < tol * f.numel():
                #     break

            if level < 3:
                max_iter_radius_level = max_iter_radius
            else:
                max_iter_radius_level = 0
            for n_iter_radius in range(max_iter_radius_level):

                alpha = (2.355 / 2)**2
                t, d = min_dist(x, s)
                r = s.eval_radius(t)
                r = torch.as_tensor(r, **utils.backend(x))
                e = dist_to_prob(d, r, TINY)
                d = d.square_()
                nll_prev = nll
                nll = get_nll(e)

                # gradient of the categorical term
                omf = (1 - f) if f.dtype.is_floating_point else f.bitwise_not()
                ome = (1 - e)
                alpha = alpha * d / r.pow(3)
                g = (omf / ome - 1) * alpha

                acc = 0
                h = omf * e / ome.square()
                h *= alpha.square()
                if acc != 1:
                    h += (1 - acc) * g.abs() * 3 / r

                # push
                g = s.push_radius(g, t)
                h = s.push_radius(h, t)

                # resolution scale
                g *= scl
                h *= scl
                nll *= scl

                # regularisation + solve
                if gamma:
                    reg_radius = Lr.matmul(s.coeff_radius)
                    g += reg_radius
                    reg_radius = 0.5 * (s.coeff_radius *
                                        reg_radius).sum(dtype=torch.double)
                    g = linalg.lmdiv(h.diag() + Lr, g[:, None])[:, 0]
                else:
                    g.div_(h)
                    reg_radius = 0

                # solve
                s.coeff_radius -= g
                s.coeff_radius.clamp_min_(0.5)

                fig, elem = plot_nll([nll, reg, reg_radius], e, f, s.waypoints,
                                     fig, elem)
                nll = nll + reg + reg_radius
                print('radius', n_iter, n_iter_radius, nll.item(),
                      (nll_prev - nll).item() / n0)
                s.update_radius()
                # if nll_prev - nll < tol * f.numel():
                #     break

            if not n_iter % 10:
                print(n_iter, nll.item(), (nll0_prev - nll).item() / n0)
            # if nll0_prev - nll < tol * f.numel():
            #     print('Converged')
            #     break

    stop = time.time()
    print(stop - start)
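
The two inner loops above are regularised Gauss-Newton updates. In compact form (notation mine), with L = lam * bending and L_r the scaled membrane regulariser:

\[
c_d \leftarrow c_d - \big(\operatorname{diag}(h_d) + v_d^2 L\big)^{-1}\big(g_d + v_d^2 L\, c_d\big),
\qquad
r \leftarrow \max\Big(r - \big(\operatorname{diag}(h_r) + L_r\big)^{-1}\big(g_r + L_r\, r\big),\ 0.5\Big),
\]

where c holds the centreline coefficients (one system per spatial coordinate d, with voxel size v_d), r the radius coefficients, and g, h the categorical gradients and Hessian diagonals pushed back onto the spline coefficients.
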
Example 8
def _mean_space(Mat, Dim, vx=None):
    """Compute a (mean) model space from individual spaces.

    Args:
        Mat (torch.tensor): N subjects' orientation matrices (N, 4, 4).
        Dim (torch.tensor): N subjects' dimensions (N, 3).
        vx (torch.tensor|tuple|float, optional): Voxel size (3,), defaults to None (estimate from input).

    Returns:
        mat (torch.tensor): Mean orientation matrix (4, 4).
        dim (torch.tensor): Mean dimensions (3,).
        vx (torch.tensor): Mean voxel size (3,).

    Authors:
        John Ashburner, as part of the SPM12 software.

    """
    device = Mat.device
    dtype = Mat.dtype
    N = Mat.shape[0]  # Number of subjects
    inf = float('inf')
    one = torch.tensor(1.0, device=device, dtype=dtype)
    if vx is None:
        vx = torch.tensor([inf, inf, inf], device=device, dtype=dtype)
    if isinstance(vx, float) or isinstance(vx, int):
        vx = (vx, ) * 3
    if isinstance(vx, tuple) and len(vx) == 3:
        vx = torch.tensor([vx[0], vx[1], vx[2]], device=device, dtype=dtype)
    # To float64
    Mat = Mat.type(dtype)
    Dim = Dim.type(dtype)
    # Get affine basis
    basis = 'SE'
    dim = 3 if Dim[0, 2] > 1 else 2
    B = affine_basis(basis, dim, device=device, dtype=dtype)

    # Find combination of 90 degree rotations and flips that brings all
    # the matrices closest to axial
    Mat0 = Mat.clone()
    pmatrix = torch.tensor(
        [[0, 1, 2], [1, 0, 2], [2, 0, 1], [2, 1, 0], [0, 2, 1], [1, 2, 0]],
        device=device)

    for n in range(N):  # Loop over subjects
        vx1 = voxel_size(Mat[n, ...])
        R = Mat[n, ...].mm(
            torch.diag(torch.cat((vx1, one[..., None]))).inverse())[:-1, :-1]
        minss = inf
        minR = torch.eye(3, dtype=dtype, device=device)
        for i in range(6):  # Permute (= 'rotate + flip') axes
            R1 = torch.zeros((3, 3), dtype=dtype, device=device)
            R1[pmatrix[i, 0], 0] = 1
            R1[pmatrix[i, 1], 1] = 1
            R1[pmatrix[i, 2], 2] = 1
            for j in range(8):  # Mirror (= 'flip') axes
                fd = [(j & 1) * 2 - 1, (j & 2) - 1, (j & 4) / 2 - 1]
                F = torch.diag(torch.tensor(fd, dtype=dtype, device=device))
                R2 = F.mm(R1)
                ss = torch.sum((R.mm(R2.inverse()) -
                                torch.eye(3, dtype=dtype, device=device))**2)
                if ss < minss:
                    minss = ss
                    minR = R2
        rdim = torch.abs(minR.mm(Dim[n, ...][..., None] - 1))
        R2 = minR.inverse()
        R22 = R2.mm((torch.div(
            torch.sum(R2, dim=0, keepdim=True).t(), 2, rounding_mode='floor') -
                     1) * rdim)
        minR = torch.cat((R2, R22), dim=1)
        minR = torch.cat(
            (minR, torch.tensor([0, 0, 0, 1], device=device,
                                dtype=dtype)[None, ...]),
            dim=0)
        Mat[n, ...] = Mat[n, ...].mm(minR)

    # Average of the matrices in Mat
    mat = meanm(Mat)

    # If average involves shears, then find the closest matrix that does not
    # require them.
    C_ix = torch.tensor(
        [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15],
        device=device)  # column-major ordering from (4, 4) tensor
    p = _imatrix(mat)
    if torch.sum(p[[9, 10, 11]]**2) > 1e-8:
        B2 = torch.zeros((3, 4, 4), device=device, dtype=dtype)
        B2[0, 0, 0] = 1
        B2[1, 1, 1] = 1
        B2[2, 2, 2] = 1

        p = torch.zeros(9, device=device, dtype=dtype)
        for n_iter in range(10000):
            # Rotations + Translations
            R, dR = _expm(p[[0, 1, 2, 3, 4, 5]], B, grad_X=True)
            # Zooms
            Z, dZ = _expm(p[[6, 7, 8]], B2, grad_X=True)

            M = R.mm(Z)
            dM = torch.zeros((4, 4, 9), device=device, dtype=dtype)
            for n in range(6):
                dM[..., n] = dR[n, ...].mm(Z)
            for n in range(3):
                dM[..., 6 + n] = R.mm(dZ[n, ...])
            dM = dM.reshape((16, 9))
            d = M.flatten() - mat.flatten()
            gr = dM.t().mm(d[..., None])
            Hes = dM.t().mm(dM)
            p = p - lmdiv(Hes, gr)[:, 0]
            if torch.sum(gr**2) < 1e-8:
                break
        mat = M.clone()

    # Set required voxel size
    vx_out = vx.clone()
    vx = voxel_size(mat)
    vx_out[~torch.isfinite(vx_out)] = vx[~torch.isfinite(vx_out)]
    mat = mat.mm(torch.cat((vx_out / vx, one[..., None])).diag())
    vx = voxel_size(mat)

    # Ensure that the FoV covers all images, with a few voxels to spare
    mn_all = torch.zeros([3, N], device=device, dtype=dtype)
    mx_all = torch.zeros([3, N], device=device, dtype=dtype)
    for n in range(N):
        dm = Dim[n, ...]
        corners = torch.tensor([[1, dm[0], 1, dm[0], 1, dm[0], 1, dm[0]],
                                [1, 1, dm[1], dm[1], 1, 1, dm[1], dm[1]],
                                [1, 1, 1, 1, dm[2], dm[2], dm[2], dm[2]],
                                [1, 1, 1, 1, 1, 1, 1, 1]],
                               device=device,
                               dtype=dtype)
        M = lmdiv(mat, Mat0[n])
        vx1 = M[:-1, :].mm(corners)
        mx_all[..., n] = torch.max(vx1, dim=1)[0]
        mn_all[..., n] = torch.min(vx1, dim=1)[0]
    mx = mx_all.max(dim=1)[0]
    mn = mn_all.min(dim=1)[0]
    mx = torch.ceil(mx)
    mn = torch.floor(mn)

    # Make output dimensions and orientation matrix
    dim = mx - mn + 1  # Output dimensions
    off = torch.tensor([0, 0, 0], device=device, dtype=dtype)
    mat = mat.mm(
        torch.tensor([[1, 0, 0, mn[0] -
                       (off[0] + 1)], [0, 1, 0, mn[1] - (off[1] + 1)],
                      [0, 0, 1, mn[2] - (off[2] + 1)], [0, 0, 0, 1]],
                     device=device,
                     dtype=dtype))

    return mat, dim, vx
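
The inner loop that removes shears from the mean matrix is a small Gauss-Newton fit (notation mine):

\[
M(p) = \exp\Big(\sum_{k=1}^{6} p_k B_k\Big)\exp\Big(\sum_{k=1}^{3} p_{6+k} B^z_k\Big),
\qquad
p \leftarrow p - \big(J^\top J\big)^{-1} J^\top\big(\operatorname{vec} M(p) - \operatorname{vec}\bar M\big),
\]

where the B_k are the rigid (SE(3)) generators, the B^z_k the three zoom generators, J = d vec M / d p, and \bar M the mean matrix returned by `meanm`.
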
Example 9
def slice_correct(x, dim=-1, nb_iter=20):
    """Correct slice-wise intensity jumps along `dim` by fitting per-slice
    log-scaling factors with iteratively reweighted least squares
    (zero-mean across slices)."""

    n = x.shape[dim]
    x = utils.movedim(x, dim, -1)
    shape = x.shape
    x = x.reshape([-1, n])

    vmax = x.max()

    m = x > 0
    x = x.log()
    x[~m] = 0
    a = x.new_zeros([n])
    g = x.new_zeros([n])
    h = x.new_zeros([n, n])

    for i in range(nb_iter):

        # compute forward differences
        d = a + x

        d = d[:, 1:] - d[:, :-1]
        d[~(m[:, 1:] & m[:, :-1])] = 0
        w = d.abs()
        print(w.mean().item())

        import matplotlib.pyplot as plt
        plt.subplot(1, 2, 1)
        plt.imshow((a + x).reshape(shape)[shape[0] // 2].exp(),
                   vmin=0,
                   vmax=vmax)
        plt.colorbar()
        plt.subplot(1, 2, 2)
        plt.imshow(w.reshape([*shape[:-1], n - 1]).mean(0))
        plt.colorbar()
        plt.show()

        w = w.clamp_min_(1e-5).reciprocal_()
        w[~(m[:, 1:] & m[:, :-1])] = 0

        # compute gradient
        g.zero_()
        g = g.reshape([n])
        g[1:] = (w * d).sum(0)
        g[:-1] -= (w * d).sum(0)

        # compute hessian
        h.zero_()
        h.diagonal(0, -1, -2)[1:] = w.sum(0)
        h.diagonal(0, -1, -2)[:-1] += w.sum(0)
        h.diagonal(1, -1, -2)[:] = -w.sum(0)
        h.diagonal(-1, -1, -2)[:] = h.diagonal(1, -1, -2)

        h = h.reshape([n, n])
        g = g.reshape([n, 1])
        g /= len(x)
        h /= len(x)
        h.diagonal(0, -1, -2).add_(h.diagonal(0, -1, -2).max() * 1e-6)
        a -= linalg.lmdiv(h, g).reshape([n])

        # zero center
        a -= a.mean()

    x = (a + x).exp()
    x = x.reshape(shape).movedim(-1, dim)
    return x, a
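
What the loop above fits, in compact form (notation mine): per-slice log-corrections a that make adjacent slices agree, with weights recomputed from the current residuals (an iteratively reweighted least-squares scheme):

\[
E(a) = \tfrac{1}{2}\sum_{\text{voxels}}\sum_{k} w_k
\big[(a_{k+1} + \log x_{k+1}) - (a_k + \log x_k)\big]^2,
\qquad
w_k = \frac{1}{\max(|d_k|, 10^{-5})},
\]

with non-positive voxels masked out, a damped Newton update on a at each iteration, and the zero-mean constraint a <- a - mean(a).
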
Example 10
def shim(fmap,
         max_order=2,
         mask=None,
         isocenter=None,
         dim=None,
         returns='corrected'):
    """Subtract a linear combination of spherical harmonics that minimize gradients

    Parameters
    ----------
    fmap : (..., *spatial) tensor
        Field map
    max_order : int, default=2
        Maximum order of the spherical harmonics
    mask : tensor, optional
        Mask of voxels to include (typically brain mask)
    isocenter : [sequence of] float, default=shape/2
        Coordinate of isocenter, in voxels
    dim : int, default=fmap.dim()
        Number of spatial dimensions
    returns : combination of {'corrected', 'correction', 'parameters'}, default='corrected'
        Components to return

    Returns
    -------
    corrected : (..., *spatial) tensor, if 'corrected' in `returns`
        Corrected field map (with spherical harmonics subtracted)
    correction : (..., *spatial) tensor, if 'correction' in `returns`
        Linear combination of spherical harmonics.
    parameters : (..., k) tensor, if 'parameters' in `returns`
        Parameters of the linear combination

    """
    fmap = torch.as_tensor(fmap)
    dim = dim or fmap.dim()
    shape = fmap.shape[-dim:]
    batch = fmap.shape[:-dim]
    backend = utils.backend(fmap)
    dims = list(range(-dim, 0))

    if mask is not None:
        mask = ~mask  # make it a mask of background voxels

    # compute gradients
    gmap = diff(fmap, dim=dims, side='f', bound='dct2')
    if mask is not None:
        gmap[..., mask, :] = 0
    gmap = gmap.reshape([*batch, -1])

    # compute basis of spherical harmonics
    basis = []
    for i in range(1, max_order + 1):
        b = spherical_harmonics(shape, i, isocenter, **backend)
        b = utils.movedim(b, -1, 0)
        b = diff(b, dim=dims, side='f', bound='dct2')
        if mask is not None:
            b[..., mask, :] = 0
        b = b.reshape([b.shape[0], *batch, -1])
        basis.append(b)
    basis = torch.cat(basis, 0)
    basis = utils.movedim(basis, 0, -1)  # (*batch, vox*dim, k)

    # solve system
    prm = linalg.lmdiv(basis, gmap[..., None], method='pinv')[..., 0]
    # > (*batch, k)

    # rebuild basis (without taking gradients)
    basis = []
    for i in range(1, max_order + 1):
        b = spherical_harmonics(shape, i, isocenter, **backend)
        b = utils.movedim(b, -1, 0)
        b = b.reshape([b.shape[0], *batch, *shape])
        basis.append(b)
    basis = torch.cat(basis, 0)
    basis = utils.movedim(basis, 0, -1)  # (*batch, *shape, k)

    comb = linalg.matvec(basis.unsqueeze(-2), utils.unsqueeze(prm, -2, dim))
    comb = comb[..., 0]
    fmap = fmap - comb

    returns = returns.split('+')
    out = []
    for ret in returns:
        if ret == 'corrected':
            out.append(fmap)
        elif ret == 'correction':
            out.append(comb)
        elif ret[0] == 'p':
            out.append(prm)
    return out[0] if len(out) == 1 else tuple(out)
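
In compact form (notation mine), the shim parameters solve a linear least-squares problem on the finite-difference gradients of the field map:

\[
w^\star = \arg\min_{w}\ \big\|\nabla f - (\nabla B)\, w\big\|^2 = (\nabla B)^{+}\,\nabla f,
\qquad
f_{\mathrm{corrected}} = f - B\, w^\star,
\]

where the columns of B are the spherical-harmonic basis fields up to `max_order`, the gradient is taken with forward differences and a dct2 boundary, and masked voxels are zeroed out of the fit.
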
Example 11
    def do_affine_only(self, logaff, grad=False, hess=False, in_line_search=False):
        """Forward pass for updating the affine component (nonlin is None)"""

        sumloss = None
        sumgrad = None
        sumhess = None

        # ==============================================================
        #                     EXPONENTIATE TRANSFORMS
        # ==============================================================
        logaff0 = logaff
        aff0, iaff0, gaff0, igaff0 = self.affine.exp2(logaff0, grad=True)

        has_printed = False
        for loss in self.losses:

            moving, fixed, factor = loss.moving, loss.fixed, loss.factor
            if loss.backward:
                aff00, gaff00 = iaff0, igaff0
            else:
                aff00, gaff00 = aff0, gaff0

            # ----------------------------------------------------------
            # build full transform
            # ----------------------------------------------------------
            aff = aff00 @ fixed.affine
            aff = linalg.lmdiv(moving.affine, aff)
            gaff = gaff00 @ fixed.affine
            gaff = linalg.lmdiv(moving.affine, gaff)
            phi = spatial.affine_grid(aff, fixed.shape)

            # ----------------------------------------------------------
            # forward pass
            # ----------------------------------------------------------
            warped, mask = moving.pull(phi, mask=True)
            if fixed.masked:
                if mask is None:
                    mask = fixed.mask
                else:
                    mask = mask * fixed.mask

            do_print = not (has_printed or self.verbose < 3 or in_line_search
                            or loss.backward)
            if do_print:
                has_printed = True
                if moving.previewed:
                    preview = moving.pull(phi, preview=True, dat=False)
                else:
                    preview = warped
                init = spatial.affine_lmdiv(moving.affine, fixed.affine)
                if _almost_identity(init) and moving.shape == fixed.shape:
                    init = moving.preview
                else:
                    init = spatial.affine_grid(init, fixed.shape)
                    init = moving.pull(init, preview=True, dat=False)
                self.mov2fix(fixed.preview, init, preview, dim=fixed.dim,
                             title=f'(affine) {self.n_iter:03d}')

            # ----------------------------------------------------------
            # derivatives wrt moving
            # ----------------------------------------------------------
            g = h = None
            loss_args = (warped, fixed.dat)
            loss_kwargs = dict(dim=fixed.dim, mask=mask)
            state = loss.loss.get_state()
            if not grad and not hess:
                llx = loss.loss.loss(*loss_args, **loss_kwargs)
            elif not hess:
                llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
            else:
                llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
            del loss_args, loss_kwargs
            if in_line_search:
                loss.loss.set_state(state)

            # ----------------------------------------------------------
            # chain rule -> derivatives wrt Lie parameters
            # ----------------------------------------------------------

            def compose_grad(g, h, g_mu, g_aff):
                """
                g, h : gradient/Hessian of loss wrt moving image
                g_mu : spatial gradients of moving image
                g_aff : gradient of affine matrix wrt Lie parameters
                returns g, h: gradient/Hessian of loss wrt Lie parameters
                """
                # Note that `h` can be `None`, but the functions I
                # use deal with this case correctly.
                dim = g_mu.shape[-1]
                g = jg(g_mu, g)
                h = jhj(g_mu, h)
                g, h = regutils.affine_grid_backward(g, h)
                dim2 = dim * (dim + 1)
                g = g.reshape([*g.shape[:-2], dim2])
                g_aff = g_aff[..., :-1, :]
                g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
                g = linalg.matvec(g_aff, g)
                if h is not None:
                    h = h.reshape([*h.shape[:-4], dim2, dim2])
                    h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                    # h = h.abs().sum(-1).diag_embed()
                return g, h

            # compose with spatial gradients
            if grad or hess:
                mugrad = moving.pull_grad(phi, rotate=False)
                g, h = compose_grad(g, h, mugrad, gaff)

                if loss.backward:
                    g = g.neg_()
                sumgrad = (g.mul_(factor) if sumgrad is None else
                           sumgrad.add_(g, alpha=factor))
                if hess:
                    sumhess = (h.mul_(factor) if sumhess is None else
                               sumhess.add_(h, alpha=factor))
            sumloss = (llx.mul_(factor) if sumloss is None else
                       sumloss.add_(llx, alpha=factor))

        # TODO add regularization term
        lla = 0

        # ==============================================================
        #                           VERBOSITY
        # ==============================================================
        llx = sumloss.item()
        sumloss += lla
        ll = sumloss.item()
        self.loss_value = ll
        if self.verbose and (self.verbose > 1 or not in_line_search):
            if in_line_search:
                line = '(search) | '
            else:
                line = '(affine) | '
            line += f'{self.n_iter:03d} | {llx:12.6g} + {lla:12.6g} = {ll:12.6g}'
            if not in_line_search:
                if self.ll_prev is not None:
                    gain = self.ll_prev - ll
                    # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                    line += f' | {gain:12.6g}'
                self.all_ll.append(ll)
                self.ll_prev = ll
                self.ll_max = max(self.ll_max, ll)
                self.n_iter += 1
            print(line, end='\r')

        # ==============================================================
        #                           RETURN
        # ==============================================================
        out = [sumloss]
        if grad:
            out.append(sumgrad)
        if hess:
            out.append(sumhess)
        return tuple(out) if len(out) > 1 else out[0]
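
The accumulation over losses can be summarised as (notation mine)

\[
\mathcal{L}(q) = \sum_i \gamma_i\, \ell_i\big(\mu_i \circ \phi_{A_i(q)},\ f_i\big),
\qquad
A_i(q) = M_{\mu_i}^{-1}\, E_i(q)\, M_{f_i},
\]

where E_i(q) is the exponentiated affine (its inverse for losses flagged `backward`), gamma_i is the loss `factor`, and the per-loss gradients and Hessians with respect to q are obtained with `compose_grad` and accumulated (gradients negated for backward losses).
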
Example 12
    def do_affine(self, logaff, grad=False, hess=False, in_line_search=False):
        """Forward pass for updating the affine component (nonlin is not None)"""

        sumloss = None
        sumgrad = None
        sumhess = None

        # ==============================================================
        #                     EXPONENTIATE TRANSFORMS
        # ==============================================================
        logaff0 = logaff
        aff_pos = self.affine.position[0].lower()
        if any(loss.backward for loss in self.losses):
            aff0, iaff0, gaff0, igaff0 = \
                self.affine.exp2(logaff0, grad=True,
                                 cache_result=not in_line_search)
            phi0, iphi0 = self.nonlin.exp2(cache_result=True, recompute=False)
        else:
            iaff0, igaff0, iphi0 = None, None, None
            aff0, gaff0 = self.affine.exp(logaff0, grad=True,
                                          cache_result=not in_line_search)
            phi0 = self.nonlin.exp(cache_result=True, recompute=False)

        has_printed = False
        for loss in self.losses:

            moving, fixed, factor = loss.moving, loss.fixed, loss.factor
            if loss.backward:
                phi00, aff00, gaff00 = iphi0, iaff0, igaff0
            else:
                phi00, aff00, gaff00 = phi0, aff0, gaff0

            # ----------------------------------------------------------
            # build left and right affine matrices
            # ----------------------------------------------------------
            aff_right, gaff_right = fixed.affine, None
            if aff_pos in 'fs':
                gaff_right = gaff00 @ aff_right
                gaff_right = linalg.lmdiv(self.nonlin.affine, gaff_right)
                aff_right = aff00 @ aff_right
            aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
            aff_left, gaff_left = self.nonlin.affine, None
            if aff_pos in 'ms':
                gaff_left = gaff00 @ aff_left
                gaff_left = linalg.lmdiv(moving.affine, gaff_left)
                aff_left = aff00 @ aff_left
            aff_left = linalg.lmdiv(moving.affine, aff_left)

            # ----------------------------------------------------------
            # build full transform
            # ----------------------------------------------------------
            if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
                right = None
                phi = spatial.add_identity_grid(phi00)
            else:
                right = spatial.affine_grid(aff_right, fixed.shape)
                phi = regutils.smart_pull_grid(phi00, right)
                phi += right
            phi_right = phi
            if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
                left = None
            else:
                left = spatial.affine_grid(aff_left, self.nonlin.shape)
                phi = spatial.affine_matvec(aff_left, phi)

            # ----------------------------------------------------------
            # forward pass
            # ----------------------------------------------------------
            warped, mask = moving.pull(phi, mask=True)
            if fixed.masked:
                if mask is None:
                    mask = fixed.mask
                else:
                    mask = mask * fixed.mask

            do_print = not (has_printed or self.verbose < 3 or in_line_search
                            or loss.backward)
            if do_print:
                has_printed = True
                if moving.previewed:
                    preview = moving.pull(phi, preview=True, dat=False)
                else:
                    preview = warped
                init = spatial.affine_lmdiv(moving.affine, fixed.affine)
                if _almost_identity(init) and moving.shape == fixed.shape:
                    init = moving.dat
                else:
                    init = spatial.affine_grid(init, fixed.shape)
                    init = moving.pull(init, preview=True, dat=False)
                self.mov2fix(fixed.dat, init, preview, dim=fixed.dim,
                             title=f'(affine) {self.n_iter:03d}')

            # ----------------------------------------------------------
            # derivatives wrt moving
            # ----------------------------------------------------------
            g = h = None
            loss_args = (warped, fixed.dat)
            loss_kwargs = dict(dim=fixed.dim, mask=mask)
            state = loss.loss.get_state()
            if not grad and not hess:
                llx = loss.loss.loss(*loss_args, **loss_kwargs)
            elif not hess:
                llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
            else:
                llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
            del loss_args, loss_kwargs
            if in_line_search:
                loss.loss.set_state(state)

            # ----------------------------------------------------------
            # chain rule -> derivatives wrt Lie parameters
            # ----------------------------------------------------------

            def compose_grad(g, h, g_mu, g_aff):
                """
                g, h : gradient/Hessian of loss wrt moving image
                g_mu : spatial gradients of moving image
                g_aff : gradient of affine matrix wrt Lie parameters
                returns g, h: gradient/Hessian of loss wrt Lie parameters
                """
                # Note that `h` can be `None`, but the functions I
                # use deal with this case correctly.
                dim = g_mu.shape[-1]
                g = jg(g_mu, g)
                h = jhj(g_mu, h)
                g, h = regutils.affine_grid_backward(g, h)
                dim2 = dim * (dim + 1)
                g = g.reshape([*g.shape[:-2], dim2])
                g_aff = g_aff[..., :-1, :]
                g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
                g = linalg.matvec(g_aff, g)
                if h is not None:
                    h = h.reshape([*h.shape[:-4], dim2, dim2])
                    h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                    # h = h.abs().sum(-1).diag_embed()
                return g, h

            if grad or hess:
                g0, g = g, None
                h0, h = h, None
                if aff_pos in 'ms':
                    g_left = regutils.smart_push(g0, phi_right, shape=self.nonlin.shape)
                    h_left = regutils.smart_push(h0, phi_right, shape=self.nonlin.shape)
                    mugrad = moving.pull_grad(left, rotate=False)
                    g_left, h_left = compose_grad(g_left, h_left, mugrad, gaff_left)
                    g, h = g_left, h_left
                if aff_pos in 'fs':
                    g_right, h_right = g0, h0
                    mugrad = moving.pull_grad(phi, rotate=False)
                    jac = spatial.grid_jacobian(phi0, right, type='disp', extrapolate=False)
                    jac = torch.matmul(aff_left[:-1, :-1], jac)
                    mugrad = linalg.matvec(jac.transpose(-1, -2), mugrad)
                    g_right, h_right = compose_grad(g_right, h_right, mugrad, gaff_right)
                    g = g_right if g is None else g.add_(g_right)
                    h = h_right if h is None else h.add_(h_right)

                if loss.backward:
                    g = g.neg_()
                sumgrad = (g.mul_(factor) if sumgrad is None else
                           sumgrad.add_(g, alpha=factor))
                if hess:
                    sumhess = (h.mul_(factor) if sumhess is None else
                               sumhess.add_(h, alpha=factor))
            sumloss = (llx.mul_(factor) if sumloss is None else
                       sumloss.add_(llx, alpha=factor))

        # TODO add regularization term
        lla = 0

        # ==============================================================
        #                           VERBOSITY
        # ==============================================================
        llx = sumloss.item()
        sumloss += lla
        sumloss += self.llv
        self.loss_value = sumloss.item()
        if self.verbose and (self.verbose > 1 or not in_line_search):
            ll = sumloss.item()
            llv = self.llv
            if in_line_search:
                line = '(search) | '
            else:
                line = '(affine) | '
            line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
            if not in_line_search:
                if self.ll_prev is not None:
                    gain = self.ll_prev - ll
                    # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                    line += f' | {gain:12.6g}'
                self.all_ll.append(ll)
                self.ll_prev = ll
                self.ll_max = max(self.ll_max, ll)
                self.n_iter += 1
            print(line, end='\r')

        # ==============================================================
        #                           RETURN
        # ==============================================================
        out = [sumloss]
        if grad:
            out.append(sumgrad)
        if hess:
            out.append(sumhess)
        return tuple(out) if len(out) > 1 else out[0]
Example 13
    def do_vel(self, vel, grad=False, hess=False, in_line_search=False):
        """Forward pass for updating the nonlinear component"""

        sumloss = None
        sumgrad = None
        sumhess = None

        # ==============================================================
        #                     EXPONENTIATE TRANSFORMS
        # ==============================================================
        if self.affine:
            aff0, iaff0 = self.affine.exp2(cache_result=True, recompute=False)
            aff_pos = self.affine.position[0].lower()
        else:
            aff_pos = 'x'
            aff0 = iaff0 = torch.eye(self.nonlin.dim + 1)
        vel0 = vel
        if any(loss.backward for loss in self.losses):
            phi0, iphi0 = self.nonlin.exp2(vel0,
                                           recompute=True,
                                           cache_result=not in_line_search)
            ivel0 = -vel0
        else:
            phi0 = self.nonlin.exp(vel0,
                                   recompute=True,
                                   cache_result=not in_line_search)
            iphi0 = ivel0 = None
        aff0 = aff0.to(phi0)
        iaff0 = iaff0.to(phi0)

        # ==============================================================
        #                     ACCUMULATE DERIVATIVES
        # ==============================================================

        has_printed = False
        for loss in self.losses:

            # ==========================================================
            #                     ONE LOSS COMPONENT
            # ==========================================================
            moving, fixed, factor = loss.moving, loss.fixed, loss.factor
            if loss.backward:
                phi00, aff00, vel00 = iphi0, iaff0, ivel0
            else:
                phi00, aff00, vel00 = phi0, aff0, vel0

            # ----------------------------------------------------------
            # build left and right affine
            # ----------------------------------------------------------
            aff_right = fixed.affine
            if aff_pos in 'fs':  # affine position: fixed or symmetric
                aff_right = aff00 @ aff_right
            aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
            aff_left = self.nonlin.affine
            if aff_pos in 'ms':  # affine position: moving or symmetric
                aff_left = aff00 @ self.nonlin.affine
            aff_left = linalg.lmdiv(moving.affine, aff_left)

            # ----------------------------------------------------------
            # build full transform
            # ----------------------------------------------------------
            if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
                aff_right = None
                phi = spatial.add_identity_grid(phi00)
                disp = phi00
            else:
                phi = spatial.affine_grid(aff_right, fixed.shape)
                disp = regutils.smart_pull_grid(phi00, phi)
                phi += disp
            if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
                aff_left = None
            else:
                phi = spatial.affine_matvec(aff_left, phi)

            # ----------------------------------------------------------
            # forward pass
            # ----------------------------------------------------------
            warped, mask = moving.pull(phi, mask=True)
            if fixed.masked:
                if mask is None:
                    mask = fixed.mask
                else:
                    mask = mask * fixed.mask

            do_print = not (has_printed or self.verbose < 3 or in_line_search
                            or loss.backward)
            if do_print:
                has_printed = True
                if moving.previewed:
                    preview = moving.pull(phi, preview=True, dat=False)
                else:
                    preview = warped
                init = spatial.affine_lmdiv(moving.affine, fixed.affine)
                if _almost_identity(init) and moving.shape == fixed.shape:
                    init = moving.dat
                else:
                    init = spatial.affine_grid(init, fixed.shape)
                    init = moving.pull(init, preview=True, dat=False)
                self.mov2fix(fixed.dat, init, preview, disp, dim=fixed.dim,
                             title=f'(nonlin) {self.n_iter:03d}')

            # ----------------------------------------------------------
            # derivatives wrt moving
            # ----------------------------------------------------------
            g = h = None
            loss_args = (warped, fixed.dat)
            loss_kwargs = dict(dim=fixed.dim, mask=mask)
            state = loss.loss.get_state()
            if not grad and not hess:
                llx = loss.loss.loss(*loss_args, **loss_kwargs)
            elif not hess:
                llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
            else:
                llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
            del loss_args, loss_kwargs
            if in_line_search:
                loss.loss.set_state(state)

            # ----------------------------------------------------------
            # chain rule -> derivatives wrt phi
            # ----------------------------------------------------------
            if grad or hess:

                g, h, mugrad = self.nonlin.propagate_grad(
                    g, h, moving, phi00, aff_left, aff_right,
                    inv=loss.backward)
                g = regutils.jg(mugrad, g)
                h = regutils.jhj(mugrad, h)
                if isinstance(self.nonlin, SVFModel):
                    # propagate backward by scaling and squaring
                    g, h = spatial.exp_backward(vel00, g, h,
                                                steps=self.nonlin.steps)

                sumgrad = (g.mul_(factor) if sumgrad is None else
                           sumgrad.add_(g, alpha=factor))
                if hess:
                    sumhess = (h.mul_(factor) if sumhess is None else
                               sumhess.add_(h, alpha=factor))
            sumloss = (llx.mul_(factor) if sumloss is None else
                       sumloss.add_(llx, alpha=factor))

        # ==============================================================
        #                       REGULARIZATION
        # ==============================================================
        vgrad = self.nonlin.regulariser(vel0)
        llv = 0.5 * vel0.flatten().dot(vgrad.flatten())
        if grad:
            sumgrad += vgrad
        del vgrad

        # ==============================================================
        #                           VERBOSITY
        # ==============================================================
        llx = sumloss.item()
        sumloss += llv
        sumloss += self.lla
        self.loss_value = sumloss.item()
        if self.verbose and (self.verbose > 1 or not in_line_search):
            llv = llv.item()
            ll = sumloss.item()
            lla = self.lla
            if in_line_search:
                line = '(search) | '
            else:
                line = '(nonlin) | '
            line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
            if not in_line_search:
                if self.ll_prev is not None:
                    gain = self.ll_prev - ll
                    # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                    line += f' | {gain:12.6g}'
                self.llv = llv
                self.all_ll.append(ll)
                self.ll_prev = ll
                self.ll_max = max(self.ll_max, ll)
                self.n_iter += 1
            print(line, end='\r')

        # ==============================================================
        #                           RETURN
        # ==============================================================
        out = [sumloss]
        if grad:
            out.append(sumgrad)
        if hess:
            out.append(sumhess)
        return tuple(out) if len(out) > 1 else out[0]
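
The full transform assembled in `do_vel` (and, with the same left/right split, in `do_affine`) can be written as (notation mine)

\[
\phi = L \circ (\mathrm{id} + u) \circ R,
\qquad
L = M_{\mathrm{mov}}^{-1}\,[A]\;M_{\mathrm{nl}},
\qquad
R = M_{\mathrm{nl}}^{-1}\,[A]\;M_{\mathrm{fix}},
\]

where u is the resampled displacement exponentiated from the velocity, M_nl is the affine of the nonlinear model's grid, and the optional affine A is inserted on the fixed side, the moving side, or both, depending on `aff_pos` ('f', 'm' or 's').
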
Example 14
def _compute_cost(q,
                  grid0,
                  dat_fix,
                  mat_fix,
                  dat,
                  mat,
                  mov,
                  cost_fun,
                  B,
                  mx_int,
                  fwhm,
                  return_res=False):
    """Compute registration cost function.

    Parameters
    ----------
    q : (N, Nq) tensor_like
        Lie algebra of affine registration fit.
    grid0 : (X1, Y1, Z1) tensor_like
        Sub-sampled image data's resampling grid.
    dat_fix : (X1, Y1, Z1) tensor_like
        Fixed image data.
    mat_fix : (4, 4) tensor_like
        Fixed affine matrix.
    dat : [N,] tensor_like
        List of input images.
    mat : [N,] tensor_like
        List of affine matrices.
    mov : [N,] int
        Indices of moving images.
    cost_fun : str
        Cost function to compute (see run_affine_reg).
    B : (Nq, 4, 4) tensor_like
        Affine basis.
    mx_int : int
        This parameter sets the max intensity in the images, which decides
        how many bins to use in the joint image histograms
        (e.g., mx_int=511 -> H.shape = (512, 512)).
    fwhm : float
        Full-width at half max of Gaussian kernel, for smoothing
        histogram.
    return_res : bool, default=False
        Return registration results for plotting.

    Returns
    -------
    c : float
        Cost of aligning images with the current estimate of q
        (a NumPy array if optimiser='powell', otherwise a tensor).
    res : tensor_like
        Registration results, for visualisation (only if return_res=True).

    """
    # Init
    device = grid0.device
    q = q.flatten()
    was_numpy = False
    if isinstance(q, np.ndarray):
        was_numpy = True
        q = torch.from_numpy(q).to(device)  # To torch tensor
    dm_fix = dat_fix.shape
    Nq = B.shape[0]
    N = torch.tensor(len(dat), device=device,
                     dtype=torch.float32)  # For modulating NJTV cost

    if cost_fun in _costs_edge:
        jtv = dat_fix.clone()
        if cost_fun == 'njtv':
            njtv = -dat_fix.sqrt()

    for i, m in enumerate(mov):  # Loop over moving images
        # Get affine matrix
        mat_a = expm(q[torch.arange(i * Nq, i * Nq + Nq)], B)
        # Compose matrices
        M = lmdiv(mat[m],
                  mat_a.mm(mat_fix)).to(grid0.dtype)  # mat_mov\mat_a*mat_fix
        # Transform fixed grid
        grid = affine_matvec(M, grid0)
        # Resample to fixed grid
        dat_new = grid_pull(dat[m],
                            grid,
                            bound='dft',
                            extrapolate=False,
                            interpolation=1)
        if cost_fun in _costs_edge:
            jtv += dat_new
            if cost_fun == 'njtv':
                njtv -= dat_new.sqrt()

    # Compute the cost function
    res = None
    if cost_fun in _costs_hist:
        # Histogram based costs
        # ----------
        # Compute joint histogram
        # OBS: This function expects both images to have the same max and min intensities,
        # which is ensured by the _data_loader() function.
        H = _hist_2d(dat_fix, dat_new, mx_int, fwhm)
        res = H

        # Get probabilities
        pxy = H / H.sum()
        px = pxy.sum(dim=0, keepdim=True)
        py = pxy.sum(dim=1, keepdim=True)

        # Compute cost
        if cost_fun == 'mi':
            # Mutual information
            mi = torch.sum(pxy * torch.log2(pxy / py.mm(px)))
            c = -mi
        elif cost_fun == 'ecc':
            # Entropy Correlation Coefficient
            # Maes, Collignon, Vandermeulen, Marchal & Suetens (1997).
            # "Multimodality image registration by maximisation of mutual
            # information". IEEE Transactions on Medical Imaging 16(2):187-198
            mi = torch.sum(pxy * torch.log2(pxy / py.mm(px)))
            ecc = -2 * mi / (torch.sum(px * px.log2()) +
                             torch.sum(py * py.log2()))
            c = -ecc
        elif cost_fun == 'nmi':
            # Normalised Mutual Information
            # Studholme,  Hill & Hawkes (1998).
            # "A normalized entropy measure of 3-D medical image alignment".
            # in Proc. Medical Imaging 1998, vol. 3338, San Diego, CA, pp. 132-143.
            nmi = (torch.sum(px * px.log2()) +
                   torch.sum(py * py.log2())) / torch.sum(pxy * pxy.log2())
            c = -nmi
        elif cost_fun == 'ncc':
            # Normalised Cross Correlation
            i = torch.arange(1,
                             pxy.shape[0] + 1,
                             device=device,
                             dtype=torch.float32)
            j = torch.arange(1,
                             pxy.shape[1] + 1,
                             device=device,
                             dtype=torch.float32)
            m1 = torch.sum(py * i[..., None])
            m2 = torch.sum(px * j[None, ...])
            sig1 = torch.sqrt(torch.sum(py[..., 0] * (i - m1)**2))
            sig2 = torch.sqrt(torch.sum(px[0, ...] * (j - m2)**2))
            i, j = torch.meshgrid(i - m1, j - m2)
            ncc = torch.sum(torch.sum(pxy * i * j)) / (sig1 * sig2)
            c = -ncc
    elif cost_fun in _costs_edge:
        # Normalised Joint Total Variation
        # M Brudfors, Y Balbastre, J Ashburner (2020).
        # "Groupwise Multimodal Image Registration using Joint Total Variation".
        # in MIUA 2020.
        jtv.sqrt_()
        if cost_fun == 'njtv':
            njtv += torch.sqrt(N) * jtv
            res = njtv
            c = torch.sum(njtv)
        else:
            res = jtv
            c = torch.sum(jtv)

    # _ = show_slices(res, fig_num=1, cmap='coolwarm')  # Can be uncommented for testing

    if was_numpy:
        # Back to numpy array
        c = c.cpu().numpy()

    if return_res:
        return c, res
    else:
        return c
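
The histogram-based costs computed above, in compact form (p_xy is the normalised joint histogram, p_x and p_y its marginals; notation mine):

\[
\mathrm{MI} = \sum_{x,y} p_{xy}\log_2\frac{p_{xy}}{p_x\, p_y},
\qquad
\mathrm{ECC} = \frac{-2\,\mathrm{MI}}{\sum_x p_x\log_2 p_x + \sum_y p_y\log_2 p_y},
\qquad
\mathrm{NMI} = \frac{\sum_x p_x\log_2 p_x + \sum_y p_y\log_2 p_y}{\sum_{x,y} p_{xy}\log_2 p_{xy}},
\]

and the returned cost is the negative of the chosen measure (or the summed [N]JTV for the edge-based costs).
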
Example 15
def kernel_fit(calib, kernel_size, patterns, lam=0.01):
    """Compute GRAPPA kernels

    All batch elements should have the same sampling pattern

    Parameters
    ----------
    calib : ([*batch], coils, *freq)
        Fully-sampled calibration data
    kernel_size : sequence[int]
        GRAPPA kernel size
    patterns : (N,) tensor[int]
        Code of patterns for which to learn a kernel.
        See `pattern_to_code`.
    lam : float, default=0.01
        Tikhonov regularization

    Returns
    -------
    kernels : dict of int -> ([*batch], coils, coils, nb_elem) tensor
        GRAPPA kernels

    """
    kernel_size = py.make_list(kernel_size)
    ndim = len(kernel_size)
    coils, *freq = calib.shape[-ndim - 1:]
    batch = calib.shape[:-ndim - 1]

    # find all possible patterns
    patterns = utils.as_tensor(patterns, device=calib.device)
    if patterns.dtype is torch.bool:
        patterns = pattern_to_code(patterns, ndim)
    patterns = patterns.flatten()

    # learn one kernel for each pattern
    calib = utils.unfold(calib, kernel_size, collapse=True)  # [*B, C, N, *K]
    calib = utils.movedim(calib, -ndim - 1, -ndim - 2)  # [*B, N, C, *K]

    def t(x):
        return x.transpose(-1, -2)

    def conjt(x):
        return t(x).conj()

    def diag(x):
        return x.diagonal(0, -1, -2)

    kernels = {}
    center = [(k - 1) // 2 for k in kernel_size]
    center = (Ellipsis, *center)
    for pattern_code in patterns:
        if code_has_center(pattern_code, kernel_size):
            continue
        pattern = code_to_pattern(pattern_code,
                                  kernel_size,
                                  device=calib.device)
        pattern_size = pattern.sum()
        if pattern_size == 0:
            continue

        calib_target = calib[center]  # [*B, N, C]
        calib_source = calib[..., pattern]  # [*B, N, C, P]
        calib_size = calib_target.shape[-2]
        flat_shape = [*batch, calib_size, pattern_size * coils]
        calib_source = calib_source.reshape(flat_shape)  # [*B, N, C*P]
        # solve
        H = conjt(calib_source).matmul(calib_source)  # [*B, C*P, C*P]
        diag(H).add_(lam * diag(H).abs().max(-1, keepdim=True).values)
        diag(H).add_(lam)
        g = conjt(calib_source).matmul(calib_target)  # [*B, C*P, C]
        k = linalg.lmdiv(H, g).transpose(-1, -2)  # [*B, C, C*P]
        k = k.reshape([*batch, coils, coils, pattern_size])  # [*B, C, C, P]
        kernels[pattern_code.item()] = k

    return kernels
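
The per-pattern solve above is a Tikhonov-regularised least squares (notation mine). With S the unfolded source samples ([N, C*P]) and T the kernel-centre targets ([N, C]),

\[
K^\top = \big(S^{\mathsf H} S + \lambda\,\max_i |(S^{\mathsf H} S)_{ii}|\, I + \lambda I\big)^{-1} S^{\mathsf H}\, T,
\]

after which K ([C, C*P]) is reshaped to (coils, coils, pattern_size) and stored per pattern code.
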