Example #1
    def forward(self, x, v=None):

        dim = x.dim() - 2
        if dim not in (2, 3):
            raise ValueError(f'{type(self).__name__} only implemented '
                             f'in 2D or 3D.')

        radii = self.radii.to(**utils.backend(x))
        pradii = self.pradii.to(**utils.backend(x)).log()

        # compute joint log-likelihood `ln p(x, radius | v)`
        loss = x.new_zeros([len(radii), *x.shape])
        for i, (p, r) in enumerate(zip(pradii, radii)):
            # compute unsorted eigenvalues
            e = spatial.hessian_eig(x, r, dim=dim, sort=None)
            # soft sort
            P = math.softsort(e.abs(), tau=self.tau_sort, descending=True)
            e = linalg.matvec(P, e)
            e = utils.movedim(e, -1, 0)
            # compute penalties
            loss[i] = -self.tau_large * e[1:].sum(0)  # white ridges
            e = e.square().clamp_min_(1e-32).log()
            if dim == 3:
                loss[i] += self.tau_ratio1 * (e[1] - e[2])  # tubes
            loss[i] += self.tau_ratio0 * (e[1] - e[0])  # not plates
            loss[i] += p  # radius prior

        # compute (stable) log-sum-exp (== model evidence `ln p(x | v)`)
        loss = math.logsumexp(loss, dim=0)

        # weight by probability to be a vessel and return `E_v[ln p(x | v)]`
        if v is None:
            v = x
        return -(loss * v).sum() / (v.sum() + 1e-3)
Example #2
 def _set_weights(module, conv_keys, f, prefix='unet'):
     # print(prefix)
     if isinstance(module, Conv):
         if conv_keys:
             key = conv_keys.pop(0)
         else:
             # we might have reached the final "feat 2 class" conv
             key = 'vxm_dense_flow'
         try:
             kernel = torch.as_tensor(f[key][key]['kernel:0'],
                                      **utils.backend(module.weight))
         except KeyError:
             kernel = torch.as_tensor(f[key][key + '_1']['kernel:0'],
                                      **utils.backend(module.weight))
         kernel = utils.movedim(kernel, [-1, -2], [0, 1])
         module.weight.copy_(kernel)
         try:
             bias = torch.as_tensor(f[key][key]['bias:0'],
                                    **utils.backend(module.bias))
         except KeyError:
             bias = torch.as_tensor(f[key][key + '_1']['bias:0'],
                                    **utils.backend(module.bias))
         module.bias.copy_(bias)
     else:
         for name, child in module.named_children():
             _set_weights(child, conv_keys, f, f'{prefix}.{name}')
Example #3
    def forward(self, x, v=None):

        dim = x.dim() - 2
        if dim not in (2, 3):
            raise ValueError(f'{type(self).__name__} only implemented '
                             f'in 2D or 3D.')

        radii = self.radii.to(**utils.backend(x))
        pradii = self.pradii.to(**utils.backend(x)).log()

        # compute log-likelihood (vessel | radius, x)
        loss = x.new_zeros([len(radii), *x.shape])
        for i, (p, r) in enumerate(zip(pradii, radii)):
            # compute unsorted eigenvalues
            e = spatial.hessian_eig(x, r, dim=dim, sort=None)
            e = utils.movedim(e, -1, 0)
            if dim == 3:
                loss[i] = self.vesselness3d(e[0], e[1], e[2])

        # compute (stable) log-sum-exp (== model evidence)
        loss = math.logsumexp(loss, dim=0)

        # weight by probability to be a vessel and return
        if v is None:
            v = x
        return -(loss * v).sum() / (v.sum() + 1e-3)
Example #4
 def forward(self, image, **overload):
     factor = overload.get('factor', self.factor)
     if factor is None:
         factor = self.default_factor(len(image), **utils.backend(image))
     if callable(factor):
         factor = factor(image.shape[0])
     factor = torch.as_tensor(factor, **utils.backend(image))
     factor = unsqueeze(factor, -1, image.dim() - factor.dim())
     image = self.op(image, factor)
     return image
Example #5
def get_backend(x, prior, device=None):
    if torch.is_tensor(x):
        backend = utils.backend(x)
    elif torch.is_tensor(prior):
        backend = utils.backend(prior)
    else:
        backend = dict(dtype=torch.get_default_dtype(), device='cpu')
    if device:
        backend['device'] = device
    return backend
Example #6
 def restrict(self, from_shape, to_shape=None):
     """Apply transform == to a restriction of the underlying grid"""
     to_shape = to_shape or [pymath.ceil(s/2) for s in from_shape]
     shifts = [0.5 * (frm / to - 1)
               for frm, to in zip(from_shape, to_shape)]
     scales = [frm / to for frm, to in zip(from_shape, to_shape)]
     shifts = torch.as_tensor(shifts, **utils.backend(self.waypoints))
     scales = torch.as_tensor(scales, **utils.backend(self.waypoints))
     self.waypoints.sub_(shifts).div_(scales)
     self.coeff.sub_(shifts).div_(scales)
     self.radius.div_(scales.prod().pow_(1/len(scales)))
     self.coeff_radius.div_(scales.prod().pow_(1/len(scales)))
Example #7
 def prolong(self, from_shape, to_shape=None):
     """Apply transform == to a prolongation of the underlying grid"""
     to_shape = to_shape or [2*s for s in from_shape]
     from_shape, to_shape = to_shape, from_shape
     shifts = [0.5 * (frm / to - 1)
               for frm, to in zip(from_shape, to_shape)]
     scales = [frm / to for frm, to in zip(from_shape, to_shape)]
     shifts = torch.as_tensor(shifts, **utils.backend(self.waypoints))
     scales = torch.as_tensor(scales, **utils.backend(self.waypoints))
     self.waypoints.mul_(scales).add_(shifts)
     self.coeff.mul_(scales).add_(shifts)
     self.radius.mul_(scales.prod().pow_(1/len(scales)))
     self.coeff_radius.mul_(scales.prod().pow_(1/len(scales)))
Example #8
def smart_pull_grid(vel, grid, type='disp', *args, **kwargs):
    """Interpolate a velocity/grid/displacement field.

    Notes
    -----
    Defaults differ from grid_pull:
    - bound -> dft
    - extrapolate -> True

    Parameters
    ----------
    vel : ([batch], *spatial, ndim) tensor
        Velocity
    grid : ([batch], *spatial, ndim) tensor
        Transformation field
    kwargs : dict
        Options to ``grid_pull``

    Returns
    -------
    pulled_vel : ([batch], *spatial, ndim) tensor
        Velocity

    """
    if grid is None or vel is None:
        return vel
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    dim = vel.shape[-1]
    if type == 'grid':
        id = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                   **utils.backend(vel))
        vel = vel - id
    vel = utils.movedim(vel, -1, -dim - 1)
    vel_no_batch = vel.dim() == dim + 1
    grid_no_batch = grid.dim() == dim + 1
    if vel_no_batch:
        vel = vel[None]
    if grid_no_batch:
        grid = grid[None]
    vel = spatial.grid_pull(vel, grid, *args, **kwargs)
    vel = utils.movedim(vel, -dim - 1, -1)
    if vel_no_batch:
        vel = vel[0]
    if type == 'grid':
        id = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                   **utils.backend(vel))
        vel += id
    return vel
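The docstring above gives shapes but no call pattern, so here is a minimal usage sketch (not part of the original source). It only assumes that `smart_pull_grid` from the snippet is in scope; the identity sampling grid is built with plain PyTorch.

import torch

# toy 2D displacement field, shape (*spatial, ndim) = (64, 64, 2)
disp = 2 * torch.randn(64, 64, 2)

# sampling grid in voxel coordinates; here simply the identity transformation
ii, jj = torch.meshgrid(torch.arange(64.), torch.arange(64.), indexing='ij')
grid = torch.stack([ii, jj], dim=-1)                 # shape (64, 64, 2)

# resample the displacement along the grid; defaults switch to
# bound='dft' and extrapolate=True as documented above
pulled = smart_pull_grid(disp, grid, type='disp')
print(pulled.shape)                                  # torch.Size([64, 64, 2])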
Example #9
    def forward(self, x, noise=None, return_resolution=False):

        if noise is not None:
            noise = noise.expand(x.shape)

        dim = x.dim() - 2
        backend = utils.backend(x)
        resolution_exp = utils.make_vector(self.resolution_exp, x.shape[1],
                                           **backend)
        resolution_scale = utils.make_vector(self.resolution_scale, x.shape[1],
                                             **backend)

        all_resolutions = []
        out = torch.empty_like(x)
        for b in range(len(x)):
            for c in range(x.shape[1]):
                resolution = self.resolution(resolution_exp[c],
                                             resolution_scale[c]).sample()
                resolution = resolution.clamp_min(1)
                fwhm = [resolution] * dim
                y = smooth(x[b, c], fwhm=fwhm, dim=dim, padding='same', bound='dct2')
                if noise is not None:
                    y += noise[b, c]
                factor = [1/resolution] * dim
                y = y[None, None]  # need batch and channel for resize
                y = resize(y, factor=factor, anchor='f')
                factor = [resolution] * dim
                all_resolutions.append(factor)
                y = resize(y, factor=factor, shape=x.shape[2:], anchor='f')
                out[b, c] = y[0, 0]

        all_resolutions = utils.as_tensor(all_resolutions, **backend)
        return (out, all_resolutions) if return_resolution else out
Example #10
    def forward(self, image, gfactor=None):
        backend = utils.backend(image)

        sigma = utils.make_vector(self.sigma, image.shape[1], **backend)
        ncoils = utils.make_vector(self.ncoils,
                                   image.shape[1],
                                   device=backend['device'],
                                   dtype=torch.int)

        zero = torch.tensor(0, **backend)

        def sampler():
            shape = [len(image), *image.shape[2:]]
            noise = td.Normal(zero, sigma).sample(shape).square_()
            return utils.movedim(noise, -1, 1)

        # sample noise
        noise = sampler()
        for n in range(2 * ncoils.max() - 1):
            tmp = sampler()
            tmp[:, n >= 2 * ncoils - 1, ...] = 0  # zero channels that already have their 2*ncoils samples
            noise += tmp
        noise = noise.sqrt_()
        noise /= ncoils

        if gfactor is not None:
            noise *= gfactor

        image = image + noise
        return image
Example #11
    def forward(self, image, **overload):
        backend = utils.backend(image)
        sigma = overload.get('sigma', self.sigma)
        gfactor = overload.get('gfactor', self.gfactor)

        # sample sigma
        if sigma is None:
            sigma = self.default_sigma(*image.shape[:2], **backend)
        if callable(sigma):
            sigma = sigma(image.shape[:2])
        sigma = torch.as_tensor(sigma, **backend)
        sigma = unsqueeze(sigma, -1, 2 - sigma.dim())

        # sample gfactor
        if gfactor is True:
            gfactor = field.RandomMultiplicativeField()
        if callable(gfactor):
            gfactor = gfactor(image.shape)

        # sample noise
        zero = torch.tensor(0, **backend)
        noise = td.Normal(zero, sigma).sample(image.shape[2:])
        noise = utils.movedim(noise, [-1, -2], [0, 1])

        if torch.is_tensor(gfactor):
            noise *= gfactor

        image = image + noise
        return image
Example #12
    def forward(self, x):

        backend = utils.backend(x)

        # compute intensity bounds
        vmin = self.vmin
        if vmin is None:
            vmin = x.reshape([*x.shape[:2], -1]).min(dim=-1).values
        vmax = self.vmax
        if vmax is None:
            vmax = x.reshape([*x.shape[:2], -1]).max(dim=-1).values
        vmin = torch.as_tensor(vmin, **backend).expand(x.shape[:2])
        vmin = unsqueeze(vmin, -1, x.dim() - vmin.dim())
        vmax = torch.as_tensor(vmax, **backend).expand(x.shape[:2])
        vmax = unsqueeze(vmax, -1, x.dim() - vmax.dim())

        # sample factor
        factor_exp = utils.make_vector(self.factor_exp, x.shape[1], **backend)
        factor_scale = utils.make_vector(self.factor_scale, x.shape[1],
                                         **backend)
        factor = self.factor(factor_exp, factor_scale)
        factor = factor.sample([len(x)])
        factor = unsqueeze(factor, -1, x.dim() - 2)

        # apply correction
        x = (x - vmin) / (vmax - vmin)
        x = x.pow(factor)
        x = x * (vmax - vmin) + vmin
        return x
Example #13
def _composition_jac(jac, rhs, lhs=None, type='grid', identity=None, **kwargs):
    """Jacobian of the composition `(lhs)o(rhs)`

    Parameters
    ----------
    jac : ([batch], *spatial, ndim, ndim) tensor
        Jacobian of input RHS transformation
    rhs : ([batch], *spatial, ndim) tensor
        RHS transformation
    lhs : ([batch], *spatial, ndim) tensor, default=`rhs`
        LHS small displacement
    kwargs : dict
        Options to ``grid_pull``

    Returns
    -------
    composed_jac : ([batch], *spatial, ndim, ndim) tensor
        Jacobian of composition

    """
    if lhs is None:
        lhs = rhs
    dim = rhs.shape[-1]
    backend = utils.backend(rhs)
    typer, typel = py.make_list(type, 2)
    jac_left = grid_jacobian(lhs, type=typel)
    if typer != 'grid':
        if identity is None:
            identity = identity_grid(rhs.shape[-dim - 1:-1], **backend)
        rhs = rhs + identity
    jac_left = _pull_jac(jac_left, rhs)
    jac = torch.matmul(jac_left, jac)
    return jac
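A hypothetical usage sketch (not from the source), assuming `_composition_jac` and the `grid_jacobian` function shown in Example #29 live in the same scope. It composes two small 2D displacements and propagates the Jacobian of the inner one.

import torch

# two small 2D displacement fields, shape (*spatial, ndim) = (16, 16, 2)
rhs = 0.1 * torch.randn(16, 16, 2)      # inner (right-hand side) displacement
lhs = 0.1 * torch.randn(16, 16, 2)      # outer (left-hand side) small displacement

# Jacobian of the inner transformation (identity + rhs)
jac_rhs = grid_jacobian(rhs, type='disp')

# Jacobian of the composition (lhs) o (rhs); type='disp' marks both
# inputs as displacement fields rather than transformation grids
jac = _composition_jac(jac_rhs, rhs, lhs, type='disp')
print(jac.shape)                        # torch.Size([16, 16, 2, 2])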
Example #14
def unwrap(phase, dim=None, bound='dct2', max_iter=0, tol=1e-5):
    """Laplacian unwrapping of the phase

    Parameters
    ----------
    phase : tensor
        Wrapped phase, in radian
    dim : int, default=phase.dim()
        Number of spatial dimensions
    max_iter : int, default=0
        Maximum number of unwrapping iterations.
        If 0, return the Laplacian filtered phase, which is not exactly
        equal to the input phase modulo 2 pi.
    tol : float, default=1e-5
        Tolerance for early stopping


    Returns
    -------
    unwrapped : tensor

    References
    ----------
    .. "Fast phase unwrapping algorithm for interferometric applications"
       Marvin A. Schofield and Yimei Zhu
       Optics Letters (2003)

    """
    # TODO: would be nice to use DCT/DST rather than padding once they
    #       are available in PyTorch.

    dim = dim or phase.dim()
    dims = list(range(-dim, 0))
    shape = bigshape = phase.shape[-dim:]

    if bound not in ('dct', 'circulant'):
        phase = utils.pad(phase, [d//2 for d in shape], side='both', mode=bound)
        bigshape = phase.shape[-dim:]

    freq = _laplacian_freq(bigshape, **utils.backend(phase))
    phase = fft.ifftshift(phase, dim=dims)
    twopi = 2 * pymath.pi

    if max_iter == 0:
        phase = _laplacian_filter(phase, freq, dims)
    else:
        for n_iter in range(1, max_iter+1):
            filtered_phase = _laplacian_filter(phase, freq, dims)
            filtered_phase.sub_(phase).div_(twopi).round_().mul_(twopi)
            phase += filtered_phase

            if n_iter < max_iter and filtered_phase.mean() < tol:
                break

    phase = fft.fftshift(phase, dim=dims)

    if bound not in ('dct', 'circulant'):
        slicer = [slice(d//2, d+d//2) for d in shape]
        phase = phase[(Ellipsis, *slicer)]
    return phase
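A rough, self-contained illustration (not from the source): wrap a smooth synthetic 2D phase into (-pi, pi] and unwrap it again. It assumes `unwrap` from the snippet is importable.

import math as pymath
import torch

# smooth 2D phase ramp spanning several multiples of 2*pi
t = torch.linspace(0, 6 * pymath.pi, 128)
true_phase = t[:, None] + t[None, :]                       # shape (128, 128)

# wrap into [-pi, pi), as a measured interferometric phase would be
wrapped = (true_phase + pymath.pi) % (2 * pymath.pi) - pymath.pi

# Laplacian unwrapping; with max_iter > 0 the iterative correction is used
unwrapped = unwrap(wrapped, dim=2, max_iter=10)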
Example #15
def get_spm_prior(**backend):
    fname = path_spm_prior()
    f = io.map(fname).movedim(-1, 0)[:-1]  # drop background
    aff = f.affine
    dat = f.fdata(**backend)
    aff = aff.to(**utils.backend(dat))
    return dat, aff
Example #16
def build_se(sqdist, sigma, lam, **backend):
    """Build squared-exponential covariance matrix

    Parameters
    ----------
    sqdist : sequence[int] or (vox, vox) tensor
        If a tensor -> it is the pre-computed squared distance map
        If a tuple -> it is the shape and we build the distance map
    sigma : (*batch) tensor_like
        Amplitude
    lam : (*batch) tensor_like
        Length-scale

    Returns
    -------
    cov : (*batch, vox, vox) tensor
        Covariance matrix

    """
    lam, sigma = utils.to_max_backend(lam, sigma, **backend, force_float=True)
    backend = utils.backend(lam)

    # Build SE covariance matrix
    if not torch.is_tensor(sqdist):
        shape = sqdist
        e = dist_map(shape, **backend)
    else:
        e = sqdist.to(**backend)
    del sqdist
    lam = lam[..., None, None]
    sigma = sigma[..., None, None]
    e = e.mul(-0.5 / (lam**2)).exp_().mul_(sigma**2)
    return e
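A minimal usage sketch (an assumption, not from the source): pass a shape so that `build_se` builds the squared-distance map itself, with batched amplitudes and length-scales broadcasting to a (2, vox, vox) covariance.

import torch

# batch of two squared-exponential kernels over a 1D grid of 10 voxels
sigma = torch.tensor([1.0, 2.0])     # amplitudes,    shape (*batch) = (2,)
lam = torch.tensor([3.0, 5.0])       # length-scales, shape (*batch) = (2,)

# a sequence of ints is interpreted as the shape of the grid
cov = build_se([10], sigma, lam)
print(cov.shape)                     # torch.Size([2, 10, 10])

# e.g. draw correlated samples with a (jittered) Cholesky factor
chol = torch.linalg.cholesky(cov + 1e-6 * torch.eye(10))
samples = chol @ torch.randn(2, 10, 1)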
Example #17
    def __init__(self, waypoints, order=3, radius=1):
        """

        Parameters
        ----------
        waypoints : (N, D) tensor
            List of waypoints, that the curve will interpolate.
        order : int, default=3
            Order of the encoding B-splines
        radius : float or (N,) tensor
            Radius of the curve at each waypoint.
        """
        waypoints = torch.as_tensor(waypoints)
        if not waypoints.dtype.is_floating_point:
            waypoints = waypoints.to(torch.get_default_dtype())
        self.waypoints = waypoints
        self.order = order
        self.bound = 'dct2'
        self.coeff = spline_coeff(waypoints,
                                  interpolation=self.order,
                                  bound=self.bound,
                                  dim=0)
        if not isinstance(radius, (int, float)):
            radius = torch.as_tensor(radius, **utils.backend(waypoints))
        self.radius = radius
        if torch.is_tensor(radius):
            self.coeff_radius = spline_coeff(radius,
                                             interpolation=self.order,
                                             bound=self.bound,
                                             dim=0)
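A short construction sketch (not from the source), assuming this `__init__` belongs to the `BSplineCurve` class that Example #24 refers to.

import torch

# four 2D waypoints the curve must interpolate, with a per-waypoint radius
waypoints = torch.tensor([[ 4.,  4.],
                          [12., 10.],
                          [20.,  8.],
                          [28., 16.]])
radius = torch.tensor([1., 2., 2., 1.])

curve = BSplineCurve(waypoints, order=3, radius=radius)
print(curve.coeff.shape)         # one spline coefficient per waypoint: (4, 2)
print(curve.coeff_radius.shape)  # (4,)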
Example #18
 def __init__(self, dat, affine=None, dim=None, mask=None,
              bound='dct2', extrapolate=False, **backend):
     # I don't call super().__init__() on purpose
     if torch.is_tensor(affine):
         affine = [affine] * len(dat)
     elif affine is None:
         affine = []
         for dat1 in dat:
             dim1 = dim or dat1.dim
             if callable(dim1):
                 dim1 = dim1()
             if hasattr(dat1, 'affine'):
                 aff1 = dat1.affine
             else:
                 shape1 = dat1.shape[-dim1:]
                 aff1 = spatial.affine_default(shape1, **utils.backend(dat1))
             affine.append(aff1)
     affine = py.make_list(affine, len(dat))
     mask = py.make_list(mask, len(dat))
     self._dat = []
     for dat1, aff1, mask1 in zip(dat, affine, mask):
         if not isinstance(dat1, Image):
             dat1 = Image(dat1, aff1, mask=mask1, dim=dim,
                          bound=bound, extrapolate=extrapolate)
         self._dat.append(dat1)
Example #19
    def forward(self, affine):
        """

        Parameters
        ----------
        affine : (batch, dim+1, dim+1) tensor

        Returns
        -------
        logaff : (batch, nbprm) tensor

        """
        # When the affine is well conditioned, its log should be real.
        # Here, I take the real part just in case.
        # Another solution could be to regularise the affine (by loading
        # slightly the diagonal) until it is well conditioned -- but
        # how would that work with autograd?
        backend = utils.backend(affine)
        affine = core.linalg.logm(affine.double())
        if affine.is_complex():
            affine = affine.real
        affine = affine.to(**backend)
        basis = self.basis.to(**backend)
        affine = core.linalg.mdot(affine[:, None, ...], basis[None, ...])
        return affine
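The comment above argues that the matrix logarithm of a well-conditioned affine is numerically real, so taking the real part is safe. Below is a small standalone check of that reasoning (an illustration, not the original code), using SciPy's `logm`/`expm` in place of `core.linalg.logm`.

import numpy as np
from scipy.linalg import expm, logm

# a well-conditioned 2D rigid affine: small rotation + translation
theta = 0.1
aff = np.array([[np.cos(theta), -np.sin(theta), 2.0],
                [np.sin(theta),  np.cos(theta), 1.0],
                [0.0,            0.0,           1.0]])

L = logm(aff)                          # may be returned as a complex array
print(np.abs(L.imag).max())            # ~1e-16: imaginary part is numerical noise
print(np.allclose(expm(L.real), aff))  # True: the real part is a valid logarithm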
Example #20
def exp_forward(vel,
                inverse=False,
                steps=8,
                interpolation='linear',
                bound='dft',
                displacement=False,
                jacobian=False,
                _anagrad=False):
    """Exponentiate a stationary velocity field by scaling and squaring.

    This function always uses autodiff in the backward pass.
    It can also compute Jacobian fields on the fly.

    Parameters
    ----------
    vel : ([batch], *spatial, dim) tensor
        Stationary velocity field.
    inverse : bool, default=False
        Generate the inverse transformation instead of the forward.
    steps : int, default=8
        Number of scaling and squaring steps
        (corresponding to 2**steps integration steps).
    interpolation : {0..7}, default=1
        Interpolation order
    bound : str, default='dft'
        Boundary conditions
    displacement : bool, default=False
        Return a displacement field rather than a transformation field

    Returns
    -------
    grid : ([batch], *spatial, dim) tensor
        Exponentiated transformation

    """
    backend = utils.backend(vel)
    vel = -vel if inverse else vel.clone()

    # Precompute identity + aliases
    dim = vel.shape[-1]
    spatial = vel.shape[-1 - dim:-1]
    id = identity_grid(spatial, **backend)
    jac = torch.eye(dim, **backend).expand([*vel.shape[:-1], dim, dim])
    opt = {'interpolation': interpolation, 'bound': bound}

    if not _anagrad and vel.requires_grad:
        iadd = lambda x, y: x.add(y)
    else:
        iadd = lambda x, y: x.add_(y)

    vel /= (2**steps)
    for i in range(steps):
        if jacobian:
            jac = _composition_jac(jac, vel, type='displacement', identity=id)
        vel = iadd(vel, _pull_vel(vel, id + vel, **opt))

    if not displacement:
        vel += id
    return (vel, jac) if jacobian else vel
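A minimal usage sketch (not from the source), assuming `exp_forward` is importable as above: exponentiate a small random 3D stationary velocity field and also request the Jacobian.

import torch

# small 3D stationary velocity field, shape (*spatial, dim) = (8, 8, 8, 3)
vel = 0.1 * torch.randn(8, 8, 8, 3)

phi = exp_forward(vel)                               # transformation field
iphi = exp_forward(vel, inverse=True)                # inverse transformation
disp, jac = exp_forward(vel, displacement=True, jacobian=True)
print(phi.shape, jac.shape)                          # (8, 8, 8, 3) and (8, 8, 8, 3, 3)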
Example #21
 def set_kernel(self, kernel=None):
     if kernel is None:
         kernel = spatial.greens(self.shape, **self.penalty,
                                 factor=self.factor / py.prod(self.shape),
                                 voxel_size=self.voxel_size,
                                 **utils.backend(self.dat))
     self.kernel = kernel
     return self
Example #22
 def update_waypoints(self):
     """Convert coefficients into waypoints"""
     t = torch.linspace(0, 1, len(self.coeff), **utils.backend(self.coeff))
     p = self.eval_position(t)
     if p.shape == self.waypoints.shape:
         self.waypoints.copy_(p)
     else:
         self.waypoints = p
Example #23
    def __init__(self,
                 moving,
                 fixed,
                 loss,
                 basis='CSO',
                 dim=None,
                 affine_moving=None,
                 affine_fixed=None,
                 verbose=True,
                 plot=False,
                 max_iter=100,
                 bound='dct2',
                 extrapolate=True,
                 **prm):
        if dim is None:
            if affine_fixed is not None:
                dim = affine_fixed.shape[-1] - 1
            elif affine_moving is not None:
                dim = affine_moving.shape[-1] - 1
        dim = dim or fixed.dim() - 1
        self.dim = dim
        self.moving = moving  # moving image
        self.fixed = fixed  # fixed image
        self.loss = loss  # similarity loss (`OptimizationLoss`)
        self.verbose = verbose  # print stuff
        self.plot = plot  # plot stuff
        self.prm = prm  # dict of regularization parameters
        self.bound = bound
        self.extrapolate = extrapolate
        self.basis = basis
        if affine_fixed is None:
            affine_fixed = spatial.affine_default(fixed.shape[-dim:],
                                                  **utils.backend(fixed))
        if affine_moving is None:
            affine_moving = spatial.affine_default(moving.shape[-dim:],
                                                   **utils.backend(moving))
        self.affine_fixed = affine_fixed
        self.affine_moving = affine_moving

        # pretty printing
        self.max_iter = max_iter  # max number of iterations
        self.n_iter = 0  # current iteration
        self.ll_prev = None  # previous loss value
        self.ll_max = 0  # max loss value
        self.id = None
Example #24
def draw_curves(shape, s, mode='gaussian', tiny=0, **kwargs):
    """Draw multiple BSpline curves

    Parameters
    ----------
    shape : list[int]
    s : list[BSplineCurve]
    mode : {'binary', 'gaussian'}

    Returns
    -------
    x : (*shape) tensor
        Drawn curve
    lab : (*shape) tensor[int]
        Label of closest curve

    """
    s = list(s)
    x = identity_grid(shape, **utils.backend(s[0].waypoints))
    n = len(s)
    tiny = tiny / n
    l = x.new_zeros(shape, dtype=torch.long)
    if mode[0].lower() == 'b':
        s1 = s.pop(0)
        t, d = min_dist(x, s1, **kwargs)
        r = s1.eval_radius(t)
        c = d <= r
        l[c] = 1
        cnt = 1
        while s:
            cnt += 1
            s1 = s.pop(0)
            t, d = min_dist(x, s1, **kwargs)
            r = s1.eval_radius(t)
            c.bitwise_or_(d <= r)
            l[d <= r] = cnt
    else:
        s1 = s.pop(0)
        t, d = min_dist(x, s1, **kwargs)
        r = s1.eval_radius(t)
        c = dist_to_prob(d, r, tiny)
        l.fill_(1)
        cnt = 1
        p = c.clone()
        c = c.neg_().add_(1)
        while s:
            cnt += 1
            s1 = s.pop(0)
            t, d = min_dist(x, s1, **kwargs)
            r = s1.eval_radius(t)
            c1 = dist_to_prob(d, r, tiny)
            l[c1 > p] = cnt
            p = torch.maximum(c1, p)
            c.mul_(c1.neg_().add_(1))
        c = c.neg_().add_(1)
    return c, l
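A usage sketch (not from the source), assuming `draw_curves` and the `BSplineCurve` constructor from Example #17 are both in scope: render two curves on a 32x32 grid in soft and binary modes.

import torch

# two curves with per-waypoint radii on a 32x32 canvas
c1 = BSplineCurve(torch.tensor([[ 4.,  4.], [16., 16.], [28., 12.]]),
                  radius=torch.tensor([2., 2., 2.]))
c2 = BSplineCurve(torch.tensor([[ 4., 28.], [16., 12.], [28., 28.]]),
                  radius=torch.tensor([1., 1., 1.]))
shape = [32, 32]

# soft ('gaussian') rendering: probability map + label map of the closest curve
prob, lab = draw_curves(shape, [c1, c2], mode='gaussian')

# hard ('binary') rendering: a mask instead of probabilities
mask, lab = draw_curves(shape, [c1, c2], mode='binary')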
Example #25
    def step(self, param=None, closure=None, **kwargs):

        if param is None:
            param = self.param
        if closure is None:
            closure = self.closure
        closure_ls = get_closure_ls(closure)
        x = param

        if self.delta is None:
            self.delta = torch.eye(len(x), **utils.backend(x)).mul_(self.lr)

        f = kwargs.get('loss', closure(x)).item()
        x0, f0 = x.clone(), f
        i_largest_step = len(x)
        largest_step = 0
        for i in range(len(x)):
            fi = f
            closure_i = lambda a: closure_ls(x.add(self.delta[i], alpha=a)
                                             ).item()
            a, f = self.brent(0, closure_i, loss=f)
            if f < fi:
                x.add_(self.delta[i], alpha=a)
                step = abs(fi - f)
                if step > largest_step:
                    largest_step = step
                    i_largest_step = i
            else:
                f = fi
        if f < f0:
            # repeat the same step and see if we improve
            f1 = closure_ls(2 * x - x0).item()  # x + (x - x0)
            if f1 < f:
                delta1 = x - x0
                closure_i = lambda a: closure_ls(x.add(delta1, alpha=a)).item()
                a, f = self.brent(0, closure_i, loss=f)
                x.add_(delta1, alpha=a)
                self.delta[i_largest_step].copy_(delta1).mul_(a)
        # for verbosity only
        closure(x)

        f = torch.as_tensor(f, **utils.backend(x))
        return x, f
Example #26
 def __init__(self, dat, affine=None, dim=None, **backend):
     """
     Parameters
     ----------
     dat : ([C], *spatial) tensor
     affine : tensor, optional
     dim : int, default=`dat.dim() - 1`
     **backend : dtype, device
     """
     if isinstance(dat, str):
         dat = io.map(dat)[None]
     if isinstance(dat, io.MappedArray):
         if affine is None:
             affine = dat.affine
         dat = dat.fdata(rand=True, **backend)
     self.dim = dim or dat.dim() - 1
     self.dat = dat
     if affine is None:
         affine = spatial.affine_default(self.shape, **utils.backend(dat))
     self.affine = affine.to(utils.backend(self.dat)['device'])
Example #27
 def update_radius(self):
     """Convert coefficients into radii"""
     if not hasattr(self, 'coeff_radius'):
         return
     t = torch.linspace(0, 1, len(self.coeff_radius),
                        **utils.backend(self.coeff_radius))
     r = self.eval_radius(t)
     if torch.is_tensor(self.radius) and r.shape == self.radius.shape:
         self.radius.copy_(r)
     else:
         self.radius = r
Example #28
def get_spm_prior(**backend):
    url = 'https://github.com/spm/spm12/raw/master/tpm/TPM.nii'
    fname = os.path.join(cache_dir, 'SPM12_TPM.nii')
    if not os.path.exists(fname):
        os.makedirs(cache_dir, exist_ok=True)
        fname = download(url, fname)
    f = io.map(fname).movedim(-1, 0)  #[:-1]  # drop background
    aff = f.affine
    dat = f.fdata(**backend)
    aff = aff.to(**utils.backend(dat))
    return dat, aff
Example #29
def grid_jacobian(grid, sample=None, bound='dft', voxel_size=1, type='grid',
                  add_identity=True, extrapolate=True):
    """Compute the Jacobian of a transformation field

    Notes
    -----
    .. If `add_identity` is True, we compute the Jacobian
       of the transformation field (identity + displacement), even if
       a displacement is provided, by adding ones to the diagonal.
    .. If `sample` is not used, this function uses central finite
       differences to estimate the Jacobian.
    .. If 'sample' is provided, `grid_grad` is used to sample derivatives.

    Parameters
    ----------
    grid : (..., *spatial, dim) tensor
        Transformation or displacement field
    sample : (..., *spatial, dim) tensor, optional
        Coordinates to sample in the input grid.
    bound : str, default='dft'
        Boundary condition
    voxel_size : [sequence of] float, default=1
        Voxel size
    type : {'grid', 'disp'}, default='grid'
        Whether the input is a transformation ('grid') or displacement
        ('disp') field.
    add_identity : bool, default=True
        Adds the identity to the Jacobian of the displacement, making it
        the jacobian of the transformation.
    extrapolate : bool, default=True
        Extrapolate out-of-bound data (only useful if `sample` is used)

    Returns
    -------
    jac : (..., *spatial, dim, dim) tensor
        Jacobian. In each matrix: jac[i, j] = d psi[i] / d xj

    """
    grid = torch.as_tensor(grid)
    dim = grid.shape[-1]
    shape = grid.shape[-dim-1:-1]
    if type == 'grid':
        grid = grid - identity_grid(shape, **utils.backend(grid))
    if sample is None:
        dims = list(range(-dim-1, -1))
        jac = diff(grid, dim=dims, bound=bound, voxel_size=voxel_size, side='c')
    else:
        grid = utils.movedim(grid, -1, -dim-1)
        jac = grid_grad(grid, sample, bound=bound, extrapolate=extrapolate)
        jac = utils.movedim(jac, -dim-2, -2)
    if add_identity:
        torch.diagonal(jac, 0, -1, -2).add_(1)
    return jac
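A minimal usage sketch (not from the source), assuming `grid_jacobian` is importable as above: compute the Jacobian of a small 2D displacement field and its determinant map.

import torch

# small 2D displacement field, shape (*spatial, dim) = (32, 32, 2)
disp = 0.5 * torch.randn(32, 32, 2)

# Jacobian of the transformation (identity + disp), via central finite
# differences since `sample` is not provided
jac = grid_jacobian(disp, type='disp')
print(jac.shape)                     # torch.Size([32, 32, 2, 2])

# determinant map; values <= 0 indicate folding of the transformation
det = torch.linalg.det(jac)
print((det <= 0).sum().item())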
Example #30
 def resize(cls, x, affine, target_vx=1):
     target_vx = utils.make_vector(target_vx, x.dim(),
                                   **utils.backend(affine))
     vx = spatial.voxel_size(affine)
     factor = vx / target_vx
     fwhm = 0.25 * factor.reciprocal()
     fwhm[factor > 1] = 0
     x = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)
     x, affine = spatial.resize(x[None, None],
                                factor.tolist(),
                                affine=affine)
     x = x[0, 0]
     return x, affine