예제 #1
0
    def forward(self, x):
        """Apply a random power-law (gamma) intensity transform.

        Intensities are mapped to [0, 1] using per-(batch, channel)
        bounds `vmin`/`vmax` (computed from `x` when not set on the
        module), raised to a sampled exponent, and mapped back.

        Parameters
        ----------
        x : (batch, channel, *spatial) tensor

        Returns
        -------
        x : (batch, channel, *spatial) tensor
        """
        backend = utils.backend(x)

        # compute intensity bounds (per batch element and channel)
        vmin, vmax = self.vmin, self.vmax
        if vmin is None or vmax is None:
            flat = x.reshape([*x.shape[:2], -1])
            if vmin is None:
                vmin = flat.min(dim=-1).values
            if vmax is None:
                vmax = flat.max(dim=-1).values
        vmin = torch.as_tensor(vmin, **backend).expand(x.shape[:2])
        vmin = unsqueeze(vmin, -1, x.dim() - vmin.dim())
        vmax = torch.as_tensor(vmax, **backend).expand(x.shape[:2])
        vmax = unsqueeze(vmax, -1, x.dim() - vmax.dim())

        # sample one exponent per batch element and channel
        factor_exp = utils.make_vector(self.factor_exp, x.shape[1], **backend)
        factor_scale = utils.make_vector(self.factor_scale, x.shape[1],
                                         **backend)
        factor = self.factor(factor_exp, factor_scale)
        factor = factor.sample([len(x)])
        factor = unsqueeze(factor, -1, x.dim() - 2)

        # apply correction; a constant channel has a zero range, which
        # would give 0/0 = NaN, so divide by 1 there instead (the output
        # then equals vmin for non-negative exponents)
        vrange = vmax - vmin
        safe_range = vrange.masked_fill(vrange == 0, 1)
        x = (x - vmin) / safe_range
        x = x.pow(factor)
        x = x * vrange + vmin
        return x
예제 #2
0
    def forward(self, x, noise=None, return_resolution=False):
        """Simulate a lower acquisition resolution, channel by channel.

        For every (batch, channel) pair a resolution is sampled from
        `self.resolution(resolution_exp[c], resolution_scale[c])` and
        clamped to at least 1. The image is smoothed with that FWHM,
        optionally corrupted by `noise`, downsampled by `1/resolution`
        and resized back to the input spatial shape.

        Parameters
        ----------
        x : (batch, channel, *spatial) tensor
        noise : tensor, optional
            Expanded to `x.shape` and added after smoothing.
        return_resolution : bool, default=False
            If True, also return the sampled resolutions.

        Returns
        -------
        out : tensor with the same shape as `x`
        resolutions : tensor, if `return_resolution`
            One row of `dim` values per (batch, channel) pair
            (presumably (batch * channel, dim) -- depends on utils.as_tensor).
        """
        if noise is not None:
            noise = noise.expand(x.shape)

        nb_dim = x.dim() - 2
        nb_channels = x.shape[1]
        backend = utils.backend(x)
        res_exp = utils.make_vector(self.resolution_exp, nb_channels,
                                    **backend)
        res_scale = utils.make_vector(self.resolution_scale, nb_channels,
                                      **backend)

        sampled = []
        out = torch.empty_like(x)
        for batch_index in range(len(x)):
            for chan in range(nb_channels):
                res = self.resolution(res_exp[chan], res_scale[chan]).sample()
                res = res.clamp_min(1)
                blurred = smooth(x[batch_index, chan], fwhm=[res] * nb_dim,
                                 dim=nb_dim, padding='same', bound='dct2')
                if noise is not None:
                    blurred += noise[batch_index, chan]
                # resize needs leading batch and channel dimensions
                low = resize(blurred[None, None], factor=[1/res] * nb_dim,
                             anchor='f')
                up_factor = [res] * nb_dim
                sampled.append(up_factor)
                high = resize(low, factor=up_factor, shape=x.shape[2:],
                              anchor='f')
                out[batch_index, chan] = high[0, 0]

        sampled = utils.as_tensor(sampled, **backend)
        if return_resolution:
            return out, sampled
        return out
예제 #3
0
    def forward(self, image, gfactor=None):
        """Add noise built from `2 * ncoils` squared Gaussian samples.

        For each channel `c`, `2 * ncoils[c]` independent samples from
        `Normal(0, sigma[c])` are squared and summed. The square root of
        that sum is divided by `ncoils[c]`, optionally multiplied by
        `gfactor`, and added to `image`.

        Parameters
        ----------
        image : (batch, channel, *spatial) tensor
        gfactor : tensor, optional
            Multiplicative factor applied to the noise before it is added
            (assumed broadcastable to `image` -- TODO confirm).

        Returns
        -------
        image : (batch, channel, *spatial) tensor
        """
        backend = utils.backend(image)
        nb_channels = image.shape[1]

        sigma = utils.make_vector(self.sigma, nb_channels, **backend)
        ncoils = utils.make_vector(self.ncoils,
                                   nb_channels,
                                   device=backend['device'],
                                   dtype=torch.int)

        zero = torch.tensor(0, **backend)

        def sampler():
            # squared Normal(0, sigma) samples have a trailing channel
            # dimension: (batch, *spatial, channel) -> move it to dim 1
            shape = [len(image), *image.shape[2:]]
            noise = td.Normal(zero, sigma).sample(shape).square_()
            return utils.movedim(noise, -1, 1)

        # sum 2 * ncoils squared samples per channel: one here, then up to
        # 2 * max(ncoils) - 1 more below
        noise = sampler()
        for n in range(2 * int(ncoils.max()) - 1):
            tmp = sampler()
            # this is sample number n + 2: keep it only in channels where
            # n + 2 <= 2 * ncoils, i.e. zero it where 2 * ncoils <= n + 1
            tmp[:, 2 * ncoils <= n + 1, ...] = 0
            noise += tmp
        noise = noise.sqrt_()
        # ncoils is (channel,): reshape to (channel, 1, ..., 1) so that it
        # divides along the channel axis, not the last spatial axis
        noise /= ncoils.reshape([-1] + [1] * (image.dim() - 2))

        if gfactor is not None:
            noise *= gfactor

        image = image + noise
        return image
예제 #4
0
파일: spatial.py 프로젝트: balbasty/nitorch
    def __init__(self,
                 radii=1,
                 pradii=1,
                 tau_sort=1,
                 tau_large=1,
                 tau_ratio0=1,
                 tau_ratio1=1):
        """

        Parameters
        ----------
        radii : [sequence of] float, default=1
            List of possible vessel radii (= FWHM of the Gaussian filter)
        pradii : [sequence of] float, default=1
            Prior probability of each radius. Will be normalized to one.
        tau_sort : float, default=1
            Temperature of the soft-sorting operation
        tau_large : float, default=1
            Penalty that encourages the two main eigenvalues to be
            very large (= very curved) and negative (= white ridges)
        tau_ratio0 : float, default=1
            Penalty that encourages the smallest eigenvalue to be much
            smaller than the two main eigenvalues (= plate-like)
        tau_ratio1 : float, default=1
            Penalty that encourages the two main eigenvalues to be similar
            (= tube-like)
        """
        super().__init__()
        self.radii = utils.make_vector(radii)
        # one prior per radius, normalised so that the priors sum to one
        pradii = utils.make_vector(pradii, len(self.radii))
        self.pradii = pradii / pradii.sum()
        self.tau_sort = tau_sort
        self.tau_large = tau_large
        self.tau_ratio0 = tau_ratio0
        self.tau_ratio1 = tau_ratio1
예제 #5
0
def regulariser_grid(v,
                     absolute=0,
                     membrane=0,
                     bending=0,
                     lame=0,
                     voxel_size=1,
                     bound='dft',
                     weights=None):
    """Precision matrix for a mixture of energies for a deformation grid.

    Each non-zero penalty weight adds the corresponding operator applied
    to `v`; if every weight is zero, a zero tensor shaped like `v` is
    returned.

    Parameters
    ----------
    v : (..., *spatial, dim) tensor
    absolute : float, default=0
    membrane : float, default=0
    bending : float, default=0
    lame : (float, float), default=0
        Weights of the divergence and shear terms, respectively.
    voxel_size : [sequence of] float, default=1
    bound : str, default='dft'
    weights : [dict of] (..., *spatial) tensor, optional
        If a dict: keys must be in {'absolute', 'membrane', 'bending', 'lame'}
        Else: the same weight map is shared across penalties.

    Returns
    -------
    Lv : (..., *spatial, dim) tensor

    """
    v = torch.as_tensor(v)
    backend = dict(dtype=v.dtype, device=v.device)
    dim = v.shape[-1]

    voxel_size = make_vector(voxel_size, dim, **backend)
    lame = make_vector(lame, 2, **backend)
    fdopt = dict(bound=bound, voxel_size=voxel_size)
    if isinstance(weights, dict):
        wa = weights.get('absolute', None)
        wm = weights.get('membrane', None)
        wb = weights.get('bending', None)
        wl = weights.get('lame', None)
    else:
        wa = wm = wb = wl = weights
    # the Lamé term has two components (div, shear), each with its own map
    wl = make_list(wl, 2)

    y = 0
    if absolute:
        y += absolute_grid(v, weights=wa) * absolute
    if membrane:
        y += membrane_grid(v, weights=wm, **fdopt) * membrane
    if bending:
        y += bending_grid(v, weights=wb, **fdopt) * bending
    if lame[0]:
        y += lame_div(v, weights=wl[0], **fdopt) * lame[0]
    if lame[1]:
        y += lame_shear(v, weights=wl[1], **fdopt) * lame[1]

    # `y` is still the int 0 when no penalty was active; test the type
    # rather than `y is 0` (identity comparison with a literal)
    if not torch.is_tensor(y):
        y = torch.zeros_like(v)
    return y
예제 #6
0
def membrane_weights(field,
                     lam=1,
                     voxel_size=1,
                     bound='dct2',
                     dim=None,
                     joint=True,
                     return_sum=False):
    """Update the (L1) weights of the membrane energy.

    The field is differentiated with forward and backward finite
    differences. The squared differences, each multiplied by `lam`, are
    averaged over the two sides and summed over the difference
    directions (and over channels when `joint`), then square-rooted.
    The returned weights are the reciprocal of this norm, which is first
    clamped below at 1e-5.

    Parameters
    ----------
    field : (..., K, *spatial) tensor
        Field
    lam : float or (K,) sequence[float], default=1
        Regularisation factor
    voxel_size : float or sequence[float], default=1
        Voxel size
    bound : str, default='dct2'
        Boundary condition.
    dim : int, optional
        Number of spatial dimensions
    joint : bool, default=True
        Joint norm across channels. When True, `lam` is also multiplied
        by the number of channels K.
    return_sum : bool, default=False
        If True, also return the sum of the norms (computed before
        clamping and inversion).

    Returns
    -------
    weight : (..., 1 or K, *spatial) tensor
        Weights for the reweighted least squares scheme
    sum : scalar tensor, if `return_sum`
    """
    field = torch.as_tensor(field)
    backend = core.utils.backend(field)
    dim = dim or field.dim() - 1
    nb_prm = field.shape[-dim - 1]
    voxel_size = make_vector(voxel_size, dim, **backend)
    lam = make_vector(lam, nb_prm, **backend)
    lam = core.utils.unsqueeze(lam, -1, dim + 1)
    if joint:
        lam = lam * nb_prm
    dims = list(range(field.dim() - dim, field.dim()))
    fieldb = diff(field,
                  dim=dims,
                  voxel_size=voxel_size,
                  side='b',
                  bound=bound)
    field = diff(field, dim=dims, voxel_size=voxel_size, side='f', bound=bound)
    # average of forward and backward squared differences
    field.square_().mul_(lam)
    field += fieldb.square_().mul_(lam)
    field /= 2.
    # reduce over difference directions (last axis) and, if joint, channels
    dims = [-1] + ([-dim - 2] if joint else [])
    field = field.sum(dim=dims, keepdims=True)[..., 0].sqrt_()
    if return_sum:
        ll = field.sum()
        return field.clamp_min_(1e-5).reciprocal_(), ll
    else:
        return field.clamp_min_(1e-5).reciprocal_()
예제 #7
0
파일: field.py 프로젝트: balbasty/nitorch
 def _make_sampler(self, name, **backend):
     """Build the sampler for the parameter family `name`.

     Reads `<name>_exp`, `<name>_scale` and the distribution `<name>`
     from `self`. Returns `dist(exp, scale)` when a distribution is set
     and every scale is strictly positive; otherwise returns the 'dirac'
     distribution from `_get_dist` built on `exp`.
     """
     exp = utils.make_vector(getattr(self, name + '_exp'), **backend)
     scale = utils.make_vector(getattr(self, name + '_scale'), **backend)
     dist = getattr(self, name)
     if not dist or not (scale > 0).all():
         return _get_dist('dirac')(exp)
     return dist(exp, scale)
예제 #8
0
    def forward(self, x):
        """Smooth each batch element with a randomly sampled FWHM.

        One FWHM is sampled per batch element, from
        `self.fwhm(fwhm_exp, fwhm_scale)` (a single value when `self.iso`,
        one per spatial dimension otherwise), clamped at 0, and used to
        smooth that element.

        Parameters
        ----------
        x : (batch, channel, *spatial) tensor

        Returns
        -------
        out : (batch, channel, *spatial) tensor
            A new tensor; `x` is left unchanged.
        """
        dim = x.dim() - 2
        backend = dict(dtype=x.dtype, device=x.device)

        nb_fwhm = 1 if self.iso else dim
        fwhm_exp = utils.make_vector(self.fwhm_exp, nb_fwhm, **backend)
        fwhm_scale = utils.make_vector(self.fwhm_scale, nb_fwhm, **backend)

        # write into a fresh tensor so the caller's `x` is not overwritten
        out = torch.empty_like(x)
        for b in range(len(x)):
            fwhm = self.fwhm(fwhm_exp, fwhm_scale).sample()
            # broadcast an isotropic value to every spatial dimension
            fwhm = fwhm.clamp_min_(0).expand([dim]).clone()
            out[b] = smooth(x[b], fwhm=fwhm, dim=dim, padding='same',
                            bound='dct2')
        return out
예제 #9
0
 def _make_sampler(self, name, dim, **backend):
     """Build the sampler for the transformation parameter family `name`.

     'translation' and 'zoom' have `dim` parameters; any other family has
     `dim * (dim - 1) // 2`. Returns `self.<name>(exp, scale)` when that
     distribution is set and every scale is strictly positive; otherwise
     returns the 'dirac' distribution from `_get_dist` built on `exp`.
     """
     exp = getattr(self, name + '_exp')
     scale = getattr(self, name + '_scale')
     dist = getattr(self, name)
     if name in ('translation', 'zoom'):
         nb_prm = dim
     else:
         nb_prm = dim * (dim - 1) // 2
     exp = utils.make_vector(exp, nb_prm, **backend)
     scale = utils.make_vector(scale, nb_prm, **backend)
     if not dist or not (scale > 0).all():
         return _get_dist('dirac')(exp)
     return dist(exp, scale)
예제 #10
0
파일: field.py 프로젝트: balbasty/nitorch
    def forward(self, batch=1, **overload):
        """Sample a smooth random field by smoothing white noise.

        For each (batch, channel), white Gaussian noise with the channel's
        mean and a rescaled amplitude is smoothed with `spatial.smooth`,
        using that channel's kernel width.

        Parameters
        ----------
        batch : int, default=1
            Batch size

        Other Parameters
        ----------------
        shape : sequence[int], optional
        channel : int, optional
        device : torch.device, optional
        dtype : torch.dtype, optional

        Returns
        -------
        field : (batch, channel, *shape) tensor
            Generated random field

        """

        # get arguments
        shape = overload.get('shape', self.shape)
        channel = overload.get('channel', self.channel)
        dtype = overload.get('dtype', self.dtype)
        device = overload.get('device', self.device)
        backend = dict(dtype=dtype, device=device)
        nb_dim = len(shape)

        # one value per channel
        mean = utils.make_vector(self.mean, channel, **backend)
        amplitude = utils.make_vector(self.amplitude, channel, **backend)
        fwhm = utils.make_vector(self.fwhm, channel, **backend)

        # convert SE parameters to noise/kernel parameters
        # (fwhm / sqrt(8 ln 2) is the standard deviation of the Gaussian)
        sigma_se = fwhm / math.sqrt(8 * math.log(2))
        amplitude = amplitude * (2 * math.pi) ** (nb_dim / 4) * sigma_se.sqrt()
        fwhm = fwhm * math.sqrt(2)

        # smooth
        out = torch.empty([batch, channel, *shape], **backend)
        for b in range(batch):
            for c in range(channel):
                sample = torch.distributions.Normal(
                    mean[c], amplitude[c]).sample(shape)
                # `fwhm` holds one width per channel: pass only this
                # channel's value, not the whole per-channel vector
                out[b, c] = spatial.smooth(
                    sample, 'gauss', fwhm[c].item(),
                    basis=self.basis, bound='dct2', dim=nb_dim, padding='same')
        return out
예제 #11
0
파일: spatial.py 프로젝트: balbasty/nitorch
    def __init__(self, radii=1, pradii=1):
        """

        Parameters
        ----------
        radii : [sequence of] float, default=1
            List of possible vessel radii (= FWHM of the Gaussian filter)
        pradii : [sequence of] float, default=1
            Prior probability of each radius. Will be normalized to one.
        """
        super().__init__()
        self.radii = utils.make_vector(radii)
        # one prior per radius, normalised so that the priors sum to one
        pradii = utils.make_vector(pradii, len(self.radii))
        self.pradii = pradii / pradii.sum()
예제 #12
0
def mrf_covariance(Z, W=None, vx=1):
    r"""Compute the covariance of the MRF term

    Notes
    -----
    .. This function returns
            V = \sum_n (diag(Z[:, n]) - Z[:, n] @ Z[:, n].T) * W[n]
                * \sum_{m \in Neighbours(n)} 1/square(d(n,m))
       where Z are the input responsibilities, W are the voxels weights
       and d(n,m) is the distance between voxels n and m (i.e., voxel size).
    .. Only first order neighbors are used (4 in 2D, 6 in 3D).

    Parameters
    ----------
    Z : (K, *spatial) tensor
        Responsibilities
    W : (*spatial) tensor, optional
        Voxel weights
    vx : [sequence of] float, default=1
        Voxel size

    Returns
    -------
    V : (K, K) tensor
        Covariance

    """
    def reduce(P, Q):
        # (K, *spatial) x (K, *spatial) -> (K, K): sum over all voxels
        P = P.reshape([len(P), -1])
        Q = Q.reshape([len(Q), -1])
        return P.matmul(Q.T)

    dim = Z.dim() - 1
    vx = utils.make_vector(vx, dim, dtype=Z.dtype, device='cpu')
    ivx2 = vx.reciprocal().square_()

    # build weights responsibilities: each voxel accumulates 1/vx[d]**2
    # for every in-bounds first-order neighbour along each dimension d
    V = Z.new_zeros(Z.shape[1:])
    for d in range(dim):
        V = V.transpose(d, 0)
        V[1:] += ivx2[d]
        V[:-1] += ivx2[d]
        V = V.transpose(d, 0)
    if W is not None:
        # NOTE(review): a non-boolean W is applied three times here
        # (W**3) while a boolean mask is applied once, and the docstring
        # formula shows a single W[n] -- confirm which is intended.
        V *= W
        if W.dtype is not torch.bool:
            V *= W
            V *= W
    V *= 2  # overcounting

    # Compute (weighted) covariance
    V = reduce(Z, Z * V).neg_()
    # `diagonal` returns a view, so the in-place additions update V
    # NOTE(review): the diagonal term below does not include the
    # neighbour factor that appears in the formula above -- confirm.
    Vdiag = V.diagonal(0, -1, -2)
    if W is None:
        Vdiag += Z.reshape([len(Z), -1]).sum(-1)
    elif W.dtype is torch.bool:
        Vdiag += Z[:, W].sum(-1)
    else:
        Vdiag += Z.reshape([len(Z), -1]).matmul(W.reshape([-1, 1]))[:, 0]
    return V
예제 #13
0
def l1_distance_transform(x, dim=None, vx=1):
    """Compute the L1 distance transform of a binary image

    Positive voxels are initialised to +inf (other voxels keep their
    value) in a floating-point copy of `x`. `_l1dt_1d_` is then applied
    successively along each of the last `dim` axes, with that axis'
    voxel size.

    Parameters
    ----------
    x : (..., *spatial) tensor
        Input tensor
    dim : int, default=`x.dim()`
        Number of spatial dimensions
    vx : [sequence of] float, default=1
        Voxel size

    Returns
    -------
    d : (..., *spatial) tensor
        Distance map

    References
    ----------
    ..[1] "Distance Transforms of Sampled Functions"
          Pedro F. Felzenszwalb & Daniel P. Huttenlocher
          Theory of Computing (2012)
          https://www.theoryofcomputing.org/articles/v008a019/v008a019.pdf
    """
    if x.dtype.is_floating_point:
        dtype = x.dtype
    else:
        dtype = torch.get_default_dtype()
    dist = x.to(dtype, copy=True)
    dist.masked_fill_(dist > 0, float('inf'))
    nb_dim = dim or dist.dim()
    voxel_sizes = utils.make_vector(vx, nb_dim, dtype=torch.float).tolist()
    for axis, width in enumerate(voxel_sizes):
        # negative index: count spatial axes from the end
        dist = _l1dt_1d_(dist, axis - nb_dim, width)
    return dist
예제 #14
0
    def forward(self, *image, **overload):
        """Randomly flip one or several images along spatial dimensions.

        One random draw is made per spatial dimension; every dimension
        whose draw succeeds (with probability `prob`) is flipped, and the
        same flips are applied to all input images.

        Parameters
        ----------
        image : (batch, channel, *spatial)
        overload
            May override `prob` and `dim`.

        Returns
        -------
        image : tensor or tuple of tensors
            A single tensor if one image was given, else a tuple.

        """
        images = list(image)
        first = images[0]
        device = first.device
        nb_dim = first.dim() - 2

        prob = utils.make_vector(overload.get('prob', self.prob),
                                 dtype=torch.float, device=device)
        dims = overload.get('dim', self.dim)
        dims = py.make_list(dims or range(-nb_dim, 0), nb_dim)

        # sample flips: one uniform draw per spatial dimension
        draws = torch.rand((nb_dim,), device=device) > (1 - prob)
        dims = [d for d, keep in zip(dims, draws) if keep]

        if dims:
            images = [img.flip(dims) for img in images]
        if len(images) == 1:
            return images[0]
        return tuple(images)
예제 #15
0
파일: mixture.py 프로젝트: balbasty/nitorch
    def forward(self, x, **overload):
        """Sample a Gaussian-mixture intensity image from labels.

        For each batch element, (channel, class) means and scales and a
        per-channel FWHM are sampled, then `RandomGaussianMixture` is
        applied to that element.

        Parameters
        ----------
        x : (batch, 1 or classes[-1], *shape) tensor
            Labels or probabilities

        Other Parameters
        ----------------
        nb_classes : int, optional
        nb_channels : int, optional

        Returns
        -------
        x : (batch, channel, *shape) tensor

        """
        batch, _, *shape = x.shape

        # non floating-point (label) inputs use the module's dtype
        dtype = x.dtype if x.dtype.is_floating_point else self.dtype
        backend = dict(dtype=dtype, device=x.device)

        nb_classes = overload.get('nb_classes', self.nb_classes)
        nb_channels = overload.get('nb_channels', self.nb_channels)

        # broadcast every mean/scale hyper-parameter to (channels, classes)
        hyper = [torch.as_tensor(value, **backend)
                 for value in (self.means_exp, self.means_scale,
                               self.scales_exp, self.scales_scale)]
        means_exp, means_scale, scales_exp, scales_scale = [
            value.expand([nb_channels, nb_classes]).clone()
            for value in hyper]
        fwhm_exp = utils.make_vector(self.fwhm_exp, nb_channels, **backend)
        fwhm_scale = utils.make_vector(self.fwhm_scale, nb_channels, **backend)

        out = torch.zeros([batch, nb_channels, *shape], **backend)
        for index in range(batch):
            means = self.means(means_exp, means_scale).sample()
            scales = self.scales(scales_exp, scales_scale).sample()
            scales = scales.clamp_min_(0)
            if self.background_zero:
                # class 0: zero mean and a fixed scale of 0.1
                means[:, 0] = 0
                scales[:, 0] = 0.1
            fwhm = self.fwhm(fwhm_exp, fwhm_scale).sample().clamp_min_(0)
            mixture = RandomGaussianMixture(means, scales, fwhm=fwhm)
            out[index] = mixture(x[None, index])[0]
        return out
예제 #16
0
def spherical_harmonics(shape, order=2, isocenter=None, **backend):
    """Generate a basis of spherical harmonics on a lattice

    Notes
    -----
    .. This should be checked!
    .. Only orders 1 and 2 implemented
    .. I tried to implement some sort of "circular" harmonics in
       dimension 2 but I don't know what I am doing.
    .. The basis is not orthogonal

    Parameters
    ----------
    shape : sequence of int
    order : {1, 2}, default=2
    isocenter : [sequence of] float, default=shape/2
        A scalar is used for every dimension.
    dtype : torch.dtype, optional
    device : torch.device, optional

    Returns
    -------
    b : (*shape, K) tensor
        Basis, with K = dim for order 1, and K = 5 (3D) or K = 2 (2D)
        for order 2.

    """
    shape = py.make_list(shape)
    dim = len(shape)
    if dim not in (2, 3):
        raise ValueError('Dimension must be 2 or 3')
    if order not in (1, 2):
        raise ValueError('Order must be 1 or 2')

    if isocenter is None:
        isocenter = [s / 2 for s in shape]
    # pass `dim` so that a scalar isocenter is broadcast to every axis
    # (otherwise isocenter[i] fails for i > 0)
    isocenter = utils.make_vector(isocenter, dim, **backend)

    # centred ramps, scaled by the half-size of each axis
    ramps = identity_grid(shape, **backend)
    for i, ramp in enumerate(ramps.unbind(-1)):
        ramp -= isocenter[i]
        ramp /= shape[i] / 2

    if order == 1:
        return ramps
    # order == 2
    if dim == 3:
        basis = [
            ramps[..., 0] * ramps[..., 1], ramps[..., 0] * ramps[..., 2],
            ramps[..., 1] * ramps[..., 2],
            ramps[..., 0].square() - ramps[..., 1].square(),
            ramps[..., 0].square() - ramps[..., 2].square()
        ]
        return torch.stack(basis, -1)
    else:  # dim == 2
        basis = [
            ramps[..., 0] * ramps[..., 1],
            ramps[..., 0].square() - ramps[..., 1].square()
        ]
        return torch.stack(basis, -1)
예제 #17
0
    def forward(self, x):
        """Zero out one randomly chosen patch, `self.nb_drop` times.

        At each iteration a patch size is sampled per spatial dimension
        (clamped to [4, size]), the image is padded to a multiple of the
        patch size, split into patches, one patch is zeroed, and the
        image is folded back and cropped to its original shape.

        Parameters
        ----------
        x : (batch, channel, *spatial) tensor

        Returns
        -------
        x : (batch, channel, *spatial) tensor
        """
        dim = x.dim() - 2
        backend = utils.backend(x)
        kernel_exp = utils.make_vector(self.kernel_exp, dim, **backend)
        kernel_scale = utils.make_vector(self.kernel_scale, dim, **backend)

        shape = x.shape[2:]
        for _ in range(self.nb_drop):
            kernel = [self.kernel(k_e, k_s).sample()
                      for k_e, k_s in zip(kernel_exp, kernel_scale)]
            kernel = [torch.clamp(k, min=4, max=shape[i]).int().item()
                      for i, k in enumerate(kernel)]
            # round each size up to the next multiple of the patch size
            # (ceil division: no extra all-padding patch when divisible)
            pshape = [-(-size // k) * k for size, k in zip(shape, kernel)]
            x = utils.ensure_shape(x, (*x.shape[:2], *pshape))
            x = utils.unfold(x, kernel, collapse=True)
            # randint's upper bound is exclusive: every patch is eligible
            i1 = torch.randint(low=0, high=x.shape[2], size=(1,)).item()
            # out-of-place zeroing: when no padding is needed the unfolded
            # tensor may share memory with the caller's input
            x = x.index_fill(2, torch.tensor([i1], device=x.device), 0)
            x = utils.fold(x, dim=dim, stride=kernel, collapsed=True,
                           shape=pshape)
            x = utils.ensure_shape(x, (*x.shape[:2], *shape))
        return x
예제 #18
0
def mrf_suffstat(Z, W=None, vx=1):
    r""" Compute the sum of probabilities across neighbors

    Notes
    -----
    .. This function returns
            Z[k, m] = \sum_{n \in Neighbours(m)} Z[k, n] * W[n] / d(n,m)
       where Z are the input responsibilities, W are the voxels weights
       and d(n,m) is the distance between voxels n and m (i.e., voxel size).
    .. Only first order neighbors are used (4 in 2D, 6 in 3D).

    Parameters
    ----------
    Z : (K, *spatial) tensor
        Responsibilities
    W : (*spatial) tensor, optional
        Voxel weights
    vx : [sequence of] float, default=1
        Voxel size

    Returns
    -------
    E : (K, *spatial) tensor
        Sufficient statistics

    """
    dim = Z.dim() - 1
    vx = utils.make_vector(vx, dim, dtype=Z.dtype, device='cpu')
    ivx = vx.reciprocal().tolist()

    S = torch.zeros_like(Z)
    if W is not None:
        W = W.to(Z.dtype)

    # iterate across first order neighbors
    for d in range(dim):
        # bring spatial axis d to the front (axis 1 of Z/S, axis 0 of W)
        Z = Z.transpose(d + 1, 1)
        S = S.transpose(d + 1, 1)
        if W is not None:
            W = W.transpose(d, 0)
        ivx1 = ivx[d]

        # all K classes at once: W has no class axis and broadcasts over it
        if W is not None:
            S[:, 1:].addcmul_(W[:-1], Z[:, :-1], value=ivx1)
            S[:, :-1].addcmul_(W[1:], Z[:, 1:], value=ivx1)
        else:
            S[:, 1:].add_(Z[:, :-1], alpha=ivx1)
            S[:, :-1].add_(Z[:, 1:], alpha=ivx1)

        Z = Z.transpose(1, d + 1)
        S = S.transpose(1, d + 1)
        if W is not None:
            W = W.transpose(0, d)

    return S
예제 #19
0
    def forward(self, x):
        """Randomly shuffle the patches of an image.

        A patch size is sampled per spatial dimension (clamped to
        [4, size]). The image is padded to a multiple of the patch size,
        split into patches, the patches are randomly permuted, and the
        image is folded back and cropped to its original shape.

        Parameters
        ----------
        x : (batch, channel, *spatial) tensor

        Returns
        -------
        x : (batch, channel, *spatial) tensor
        """
        dim = x.dim() - 2
        backend = utils.backend(x)
        kernel_exp = utils.make_vector(self.kernel_exp, dim,
                                       **backend)
        kernel_scale = utils.make_vector(self.kernel_scale, dim,
                                         **backend)

        kernel = [self.kernel(k_e, k_s).sample()
                  for k_e, k_s in zip(kernel_exp, kernel_scale)]
        shape = x.shape[2:]
        kernel = [torch.clamp(k, min=4, max=shape[i]).int().item()
                  for i, k in enumerate(kernel)]
        # round each size up to the next multiple of the patch size (ceil
        # division). Previously a full extra patch of padding was added
        # when the size was divisible, and it could be shuffled into view.
        # NOTE(review): any remaining padding is shuffled with the data.
        pshape = [-(-size // k) * k for size, k in zip(shape, kernel)]
        x = utils.ensure_shape(x, (*x.shape[:2], *pshape))
        x = utils.unfold(x, kernel, collapse=True)
        x = x[:, :, torch.randperm(x.shape[2])]
        x = utils.fold(x, dim=dim, stride=kernel, collapsed=True, shape=pshape)
        x = utils.ensure_shape(x, (*x.shape[:2], *shape))
        return x
예제 #20
0
파일: field.py 프로젝트: balbasty/nitorch
    def forward(self, batch=1, **overload):
        """Sample a smooth random field by upsampling random spline nodes.

        Random coefficients are drawn on a coarse grid with about one
        node per `fwhm` voxels along each dimension, scaled per channel
        by `amplitude`, resized to `shape` with the `self.basis`
        interpolation (no prefiltering), and shifted by `mean`.

        Parameters
        ----------
        batch : int, default=1
            Batch size

        Other Parameters
        ----------------
        shape : sequence[int], optional
        channel : int, optional
        device : torch.device, optional
        dtype : torch.dtype, optional

        Returns
        -------
        field : (batch, channel, *shape) tensor
            Generated random field

        """
        shape = overload.get('shape', self.shape)
        channel = overload.get('channel', self.channel)
        backend = dict(dtype=overload.get('dtype', self.dtype),
                       device=overload.get('device', self.device))
        nb_dim = len(shape)

        # per-channel mean/amplitude, per-dimension fwhm
        mean = utils.make_vector(self.mean, channel, **backend)
        amplitude = utils.make_vector(self.amplitude, channel, **backend)
        fwhm = utils.make_vector(self.fwhm, nb_dim, **backend)

        # number of nodes along each dimension: ceil(size / fwhm)
        nodes = []
        for size, width in zip(shape, fwhm):
            nodes.append((size / width).ceil().int().item())

        coeffs = torch.randn([batch, channel, *nodes], **backend)
        coeffs *= utils.unsqueeze(amplitude, -1, nb_dim)
        field = spatial.resize(coeffs, shape=shape, interpolation=self.basis,
                               bound='dct2', prefilter=False)
        field += utils.unsqueeze(mean, -1, nb_dim)
        return field
예제 #21
0
def bending(field, voxel_size=1, bound='dct2', dim=None, weights=None):
    """Precision matrix for the Bending energy

    For every pair of spatial dimensions (i, j) with j >= i, and every
    combination of forward/backward finite differences along i and j,
    the mixed second difference of `field` is computed. It is optionally
    multiplied by `weights`, and `div1d` is then applied along j and
    then along i. Off-diagonal pairs (i != j) are counted twice. The sum
    is divided by 4, which is the number of (side_i, side_j)
    combinations.

    Note
    ----
    .. This is exactly equivalent to SPM's bending energy

    Parameters
    ----------
    field : (..., *spatial) tensor
        Input field; its last `dim` dimensions are spatial.
    voxel_size : float or sequence[float], default=1
        Voxel size, broadcast to `dim` values.
    bound : str, default='dct2'
        Boundary condition, broadcast to `dim` values.
    dim : int, default=field.dim()
        Number of spatial dimensions.
    weights : (..., *spatial) tensor, optional
        Voxel-wise weights applied to the second differences.

    Returns
    -------
    field : (..., *spatial) tensor

    """
    field = torch.as_tensor(field)
    dim = dim or field.dim()
    # per-dimension options
    voxel_size = make_vector(voxel_size, dim)
    bound = make_list(bound, dim)
    # indices of the spatial dimensions (the last `dim` ones)
    dims = list(range(field.dim() - dim, field.dim()))
    if weights is not None:
        backend = dict(dtype=field.dtype, device=field.device)
        weights = torch.as_tensor(weights, **backend)

    mom = 0
    for i in range(dim):
        for side_i in ('f', 'b'):
            opti = dict(dim=dims[i],
                        bound=bound[i],
                        side=side_i,
                        voxel_size=voxel_size[i])
            # first difference along i
            di = diff1d(field, **opti)
            # j starts at i: only the upper triangle is visited
            for j in range(i, dim):
                for side_j in ('f', 'b'):
                    optj = dict(dim=dims[j],
                                bound=bound[j],
                                side=side_j,
                                voxel_size=voxel_size[j])
                    # second difference along j, then divergence along j, i
                    dj = diff1d(di, **optj)
                    if weights is not None:
                        dj = dj * weights
                    dj = div1d(dj, **optj)
                    dj = div1d(dj, **opti)
                    if i != j:
                        # off diagonal -> x2  (upper + lower element)
                        dj = dj * 2
                    mom += dj
    # average over the 4 (side_i, side_j) combinations
    mom = mom / 4.
    return mom
예제 #22
0
 def resize(cls, x, affine, target_vx=1):
     """Resample a volume to a target voxel size.

     The volume is first smoothed with FWHM = 0.25 / factor along axes
     whose resampling factor (current / target voxel size) is <= 1, and
     with FWHM 0 elsewhere. It is then resized by `factor`.

     Parameters
     ----------
     x : (*spatial) tensor
         Volume (`spatial.smooth` is called with `dim=3`).
     affine : tensor
         Affine matrix of `x`, used by `spatial.voxel_size` and
         `spatial.resize`.
     target_vx : [sequence of] float, default=1
         Target voxel size.

     Returns
     -------
     x : (*new_spatial) tensor
     affine : tensor
         Affine returned by `spatial.resize`.
     """
     backend = utils.backend(affine)
     target_vx = utils.make_vector(target_vx, x.dim(), **backend)
     factor = spatial.voxel_size(affine) / target_vx
     fwhm = (0.25 * factor.reciprocal()).masked_fill(factor > 1, 0)
     smoothed = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)
     # spatial.resize expects leading batch and channel dimensions
     resized, affine = spatial.resize(smoothed[None, None],
                                      factor.tolist(),
                                      affine=affine)
     return resized[0, 0], affine
예제 #23
0
파일: main.py 프로젝트: balbasty/nitorch
def downsample(x, aff_in, vx_out):
    """
    Downsample an image (by an integer factor) to approximately
    match a target voxel size

    The per-axis factor is floor(vx_out / vx_in), clamped to at least 1,
    so the image is never upsampled. If every factor is 1, the inputs are
    returned unchanged.
    """
    vx_in = spatial.voxel_size(aff_in)
    nb_dim = len(vx_in)
    vx_out = utils.make_vector(vx_out, nb_dim)
    ratio = vx_out / vx_in
    factor = ratio.clamp_min(1).floor().long()
    if (factor == 1).all():
        # nothing to pool
        return x, aff_in
    pooled, aff_out = spatial.pool(nb_dim, x, factor.tolist(), affine=aff_in)
    return pooled, aff_out
예제 #24
0
파일: struct.py 프로젝트: balbasty/nitorch
 def call(self, v):
     """Weighted sum of the Lamé divergence and shear penalties of `v`.

     `self.factor` is broadcast to two weights. The first scales the
     `spatial.lame_div` term and the second scales the
     `spatial.lame_shear` term. Each term is `(v * L(v)).sum(-1).mean()`,
     and a zero weight skips that term. Returns 0 if both weights are
     zero.
     """
     factor = make_vector(self.factor, 2, dtype=v.dtype, device=v.device)
     loss = 0
     if factor[0]:
         m = spatial.lame_div(v)
         div_loss = (v * m).sum(-1).mean()
         if factor[0] != 1:
             div_loss = div_loss * factor[0]
         loss = loss + div_loss
     if factor[1]:
         m = spatial.lame_shear(v)
         # weight only the shear term; the divergence term above already
         # carries its own weight and must not be rescaled by factor[1]
         shear_loss = (v * m).sum(-1).mean()
         if factor[1] != 1:
             shear_loss = shear_loss * factor[1]
         loss = loss + shear_loss
     return loss
예제 #25
0
    def _init_par(self, X, W=None):
        """  Initialise CMM specific parameters: dof, sig

        """
        K = self.K
        dtype = torch.float64

        # Init mixing prop
        self._init_mp(dtype)

        if self.sig is None:
            if W is None:
                self.sig = torch.mean(X) * 5
                self.sig = torch.sum((self.sig - X).square()) / torch.numel(X)
                self.sig = torch.sqrt(
                    self.sig / (K + 1) *
                    (torch.arange(1, K + 1, dtype=dtype, device=self.dev)))
            else:
                self.sig = 5 * (W * X).sum() / W.sum()
                self.sig = (W * (self.sig - X).square()).sum() / W.sum()
                self.sig = torch.sqrt(
                    self.sig / (K + 1) *
                    (torch.arange(1, K + 1, dtype=dtype, device=self.dev)))
        else:
            self.sig = utils.make_vector(self.sig,
                                         K,
                                         dtype=dtype,
                                         device=self.dev)

        if self.dof is None:
            self.dof = torch.full([K], 3, dtype=dtype, device=self.dev)
        else:
            self.dof = utils.make_vector(self.dof,
                                         K,
                                         dtype=dtype,
                                         device=self.dev)
        return
예제 #26
0
def _membrane_l2(field, voxel_size=1, bound='dct2', dim=None):
    """Precision matrix for the Membrane energy

    Note
    ----
    .. Specialized implementation for the l2 version (i.e., no weight map).
    .. This is exactly equivalent to SPM's membrane energy

    The operator is applied as a sparse convolution with a 3**dim
    stencil. The stencil holds `2 * sum(1/vx**2)` at the centre and
    `-1/vx[d]**2` at the two neighbours along each axis `d`.

    Parameters
    ----------
    field : (..., *spatial) tensor
    voxel_size : float or sequence[float], default=1
    bound : str, default='dct2'
    dim : int, default=field.dim()

    Returns
    -------
    field : (..., *spatial) tensor

    """
    field = torch.as_tensor(field)
    backend = core.utils.backend(field)
    dim = dim or field.dim()
    voxel_size = make_vector(voxel_size, dim, **backend)
    inv_vx2 = voxel_size.square().reciprocal()

    # sparse 3x...x3 stencil: centre first, then both neighbours per axis
    centre = [1] * dim
    values = [2 * inv_vx2.sum()]
    coords = [list(centre)]
    for axis in range(dim):
        for offset in (0, 2):
            neighbour = list(centre)
            neighbour[axis] = offset
            coords.append(neighbour)
            values.append(-inv_vx2[axis])
    coords = torch.as_tensor(coords, dtype=torch.long, device=field.device)
    values = torch.as_tensor(values, **backend)
    kernel = torch.sparse_coo_tensor(coords.t(), values, [3] * dim)

    # perform convolution
    return spconv(field, kernel, bound=bound, dim=dim)
예제 #27
0
def read_info(options):
    """Load affine transforms and space info of other volumes (in place).

    Mutates `options`:
      - each path in `options.files` is replaced by a `struct.FileWithInfo`
        holding its path components (`fname`, `dir`, `base`, `ext`), whether
        its dtype is floating point, its 3D spatial `shape`, its number of
        `channels` and its `affine` (as float);
      - each transformation gets its `affine` loaded from `trf.file`; for
        non-`struct.Linear` transformations the spatial `shape` of the
        field file is loaded too;
      - `options.target`, when set, is loaded like the files and, if
        `options.voxel_size` is also set, its affine and shape are resized
        with `spatial.affine_resize` (anchor 'f') using the ratio of the
        target's voxel size to the requested one.

    Returns None.
    """
    def load_volume(path):
        info = struct.FileWithInfo()
        info.fname = path
        info.dir = os.path.dirname(path) or '.'
        stem, suffix = os.path.splitext(os.path.basename(path))
        if suffix in ('.gz', '.bz2'):
            # compressed file (e.g. "x.nii.gz"): keep both extensions
            stem, inner_suffix = os.path.splitext(stem)
            suffix = inner_suffix + suffix
        info.base = stem
        info.ext = suffix
        volume = io.volumes.map(path)
        info.float = nitype(volume.dtype).is_floating_point
        full_shape = squeeze_to_nd(volume.shape, dim=3, channels=1)
        info.channels = full_shape[-1]
        info.shape = full_shape[:3]
        info.affine = volume.affine.float()
        return info

    options.files = [load_volume(path) for path in options.files]

    for trf in options.transformations:
        if isinstance(trf, struct.Linear):
            matrix = io.transforms.loadf(trf.file).float()
            trf.affine = squeeze_to_nd(matrix, 0, 2)
        else:
            field = io.volumes.map(trf.file)
            trf.affine, trf.shape = field.affine.float(), field.shape[:3]

    if not options.target:
        return
    options.target = load_volume(options.target)
    target = options.target

    if not options.voxel_size:
        return
    options.voxel_size = utils.make_vector(
        options.voxel_size, 3, dtype=target.affine.dtype)
    # ratio between the target's current voxel size and the requested one
    factor = spatial.voxel_size(target.affine) / options.voxel_size
    target.affine, target.shape = spatial.affine_resize(
        target.affine, target.shape, factor=factor, anchor='f')
예제 #28
0
파일: spatial.py 프로젝트: balbasty/nitorch
    def shape(self, image, affine=None, output_shape=None):
        """Compute the output shape of the resizing operation.

        Parameters
        ----------
        image : tensor or sequence[int]
            Input tensor, or its shape, laid out as (B, C, *spatial).
        affine : optional
            Not used by this method; kept for interface compatibility.
        output_shape : [sequence of] int or None, optional
            Output spatial shape. Entries that are None are computed as
            `int(input_size * factor)` with `factor = self.factor or 0`.
            When not provided, `self.shape` is used if it holds a value.

        Returns
        -------
        shape : tuple
            `(*batch, *output_spatial_shape)`, where `batch` is the first
            two entries of the input shape.
        """
        if not output_shape:
            # NOTE(review): this method is itself named `shape`, so unless an
            # instance attribute shadows it, `self.shape` is this bound method.
            # The original `output_shape or self.shape` then put method objects
            # in the returned shape. Treat a callable default as "no default".
            default_shape = self.shape
            output_shape = None if callable(default_shape) else default_shape

        # read parameters
        if torch.is_tensor(image):
            inshape = tuple(image.shape)
        else:
            inshape = image
        nb_dim = len(inshape) - 2
        batch = inshape[:2]
        inshape = inshape[2:]
        factor = utils.make_vector(self.factor or 0., nb_dim).tolist()
        output_shape = make_list(output_shape or [None], nb_dim)

        # compute output shape: explicit sizes win, others follow the factor
        output_shape = [
            int(inshp * f) if outshp is None else outshp
            for inshp, outshp, f in zip(inshape, output_shape, factor)
        ]
        return (*batch, *output_shape)
예제 #29
0
def absolute_grid(grid, voxel_size=1, weights=None):
    """Precision matrix for the Absolute energy of a deformation grid

    Each displacement component `d` is scaled by `voxel_size[d] ** 2`
    and, if given, by the voxel-wise `weights`.

    Parameters
    ----------
    grid : (..., *spatial, dim) tensor
    voxel_size : float or sequence[float], default=1
        Broadcast to `dim` values.
    weights : (..., *spatial) tensor, optional

    Returns
    -------
    field : (..., *spatial, dim) tensor

    """
    grid = torch.as_tensor(grid)
    dim = grid.shape[-1]
    # `make_vector` is not told about the grid's device, so move the result
    # there explicitly: a CPU vector of length `dim` cannot be multiplied
    # with a grid living on another device (e.g. CUDA).
    voxel_size = make_vector(voxel_size, dim).to(device=grid.device)
    grid = grid * voxel_size.square()
    if weights is not None:
        backend = dict(dtype=grid.dtype, device=grid.device)
        weights = torch.as_tensor(weights, **backend)
        grid = grid * weights[..., None]
    return grid
예제 #30
0
def crop(inp,
         size=None,
         center=None,
         space='vx',
         like=None,
         bbox=False,
         output=None,
         transform=None):
    """Crop a ND volume, while preserving the orientation matrices.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int or float, optional
        Size of the patch to extract.
        Its unit is defined by `space` (voxels, or mm if 'ras').
    center : [sequence of] int or float, optional
        Coordinate of the center of the patch.
        Its unit is defined by `space`.
        By default, the center of the FOV is used.
    space : [sequence of] {'vx', 'ras'}, default='vx'
        The space in which the `size` and `center` parameters are expressed.
        If two values are given, the first applies to `size` and the second
        to `center`. Any value other than 'ras' is treated as voxels.
    bbox : bool or float, default=False
        Crop at the bounding box of `inp > threshold`.
            If `bbox` is a float, it is the threshold to use.
            If `bbox` is `True`, the threshold is 0.
    like : str or (tensor, tensor), optional
        Reference patch.
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    output : str, optional
        Output filename.
        If the input is not a path, the cropped data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}{sep}{base}.crop{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    transform : str, optional
        Input or output filename of the corresponding transform.
        Not written by default.
        If a transform is provided and all other parameters
        (i.e., `size`, `like` and `bbox`) are unset, the transform is
        considered as an input transform to apply.

    Returns
    -------
    output : str or (tensor or array, tensor)
        If the input is a path, the output path is returned.
        Else, the cropped data and its orientation matrix are returned.

    Raises
    ------
    ValueError
        If more than one of `size`, `like` and `bbox` is used, if none of
        `size`, `like`, `bbox` and `transform` is provided, or if no voxel
        is above the `bbox` threshold.
    TypeError
        If `transform` is used as input but is not an LTA file.

    """
    dirname = ''
    base = ''
    ext = ''
    fname = None
    transform_in = False
    use_bbox = bool(bbox or isinstance(bbox, float))

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        # the bounding-box search needs the voxel values in memory
        inp = (f.data(numpy=True) if use_bbox else f, f.affine)
        if output is None:
            output = '{dir}{sep}{base}.crop{ext}'
        dirname, base, ext = py.fileparts(fname)
    dat, aff0 = inp
    dim = aff0.shape[-1] - 1
    shape0 = dat.shape[:dim]
    layout0 = spatial.affine_to_layout(aff0)

    # save input space in case we reorient later
    aff00 = aff0
    shape00 = shape0

    if bool(size) + bool(like) + use_bbox > 1:
        raise ValueError('Can only use one of `size`, `like` and `bbox`.')

    # --- Open reference and compute size/center ---
    if like:
        like_is_file = isinstance(like, str)
        if like_is_file:
            f = io.volumes.map(like)
            like = (f.shape, f.affine)
        like_shape, like_aff = like
        like_layout = spatial.affine_to_layout(like_aff)
        if (layout0 != like_layout).any():
            aff0, dat = spatial.affine_reorient(aff0, dat, like_layout)
            shape0 = dat.shape[:dim]
        if torch.is_tensor(like_shape):
            like_shape = like_shape.shape
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)
        space = 'vox'

    # --- Compute size/center from the bounding box of `dat > threshold` ---
    elif use_bbox:
        threshold = 0. if bbox is True else bbox
        mask = torch.as_tensor(dat > threshold)
        # collapse trailing non-spatial (e.g. channel) axes only; comparing
        # with `dim` (not a hard-coded 3) keeps every spatial axis intact
        while mask.dim() > dim:
            mask = mask.any(dim=-1)
        mins = []
        maxs = []
        for d in range(dim):
            n = mask.shape[d]
            idx = utils.movedim(mask, d, 0).reshape([n, -1]).any(-1)
            idx = idx.nonzero(as_tuple=False)
            if idx.numel() == 0:
                raise ValueError(f'No voxel above the bounding box '
                                 f'threshold ({threshold}).')
            mins.append(idx.min())
            maxs.append(idx.max())
        mins = utils.as_tensor(mins)
        maxs = utils.as_tensor(maxs)
        size = maxs + 1 - mins
        # Voxel centers sit on integer coordinates, so the center of the
        # inclusive range [mins, maxs] is their midpoint. This makes `first`
        # below exactly `mins` (the former `maxs + 1 + mins` gave
        # `mins + 0.5`, which round-half-to-even shifted by one voxel).
        center = (maxs + mins).float() / 2
        space = 'vox'
        del mask

    # --- Open transformation file and compute size/center ---
    elif not size:
        if not transform:
            raise ValueError('At least one of size/like/transform must '
                             'be provided')
        transform_in = True
        t = io.transforms.map(transform)
        if not isinstance(t, io.transforms.LinearTransformArray):
            raise TypeError('Expected an LTA file')
        like_aff, like_shape = t.destination_space()
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)

    # --- use center of the FOV ---
    if not torch.is_tensor(center) and not center:
        center = torch.as_tensor(shape0[:dim], dtype=torch.float)
        center = center.sub_(1).mul_(0.5)

    # --- convert size/center to voxels ---
    space_size, space_center = py.make_list(space, 2)
    size_in_ras = space_size.lower() == 'ras'
    # physical (RAS) sizes may be fractional: keep them as floats until
    # they are converted to voxels, instead of truncating them to integers
    size_dtype = torch.float if size_in_ras else torch.long
    size = utils.make_vector(size, dim, dtype=size_dtype)
    center = utils.make_vector(center, dim, dtype=torch.float)
    if space_center.lower() == 'ras':
        center = spatial.affine_matvec(spatial.affine_inv(aff0), center)
    if size_in_ras:
        perm = spatial.affine_to_layout(aff0)[:, 0]
        size = size[perm.long()]
        size = size / spatial.voxel_size(aff0)

    # --- compute first/last ---
    center = center.float()
    size = (size.ceil() if size.dtype.is_floating_point else size).long()
    first = center - size.float().sub_(1).mul_(0.5)
    first = first.round().long()
    last = (first + size).tolist()
    first = [max(f, 0) for f in first.tolist()]
    last = [min(l, s) for l, s in zip(last, shape0[:dim])]
    verb = 'Cropping patch ['
    verb += ', '.join([f'{f}:{l}' for f, l in zip(first, last)])
    verb += f'] from volume with shape {shape0[:dim]}'
    print(verb)
    slicer = tuple(slice(f, l) for f, l in zip(first, last))

    # --- do crop ---
    dat = dat[slicer]
    # Mapped (on-disk) arrays expose a callable `data()` that loads the
    # values. Tensors and numpy arrays (the bbox path) are already in
    # memory; note that `ndarray.data` is a memoryview, not a method.
    if not torch.is_tensor(dat) and callable(getattr(dat, 'data', None)):
        dat = dat.data(numpy=True)
    aff, _ = spatial.affine_sub(aff0, shape0[:dim], slicer)
    shape = dat.shape[:dim]

    if output:
        if is_file:
            output = output.format(dir=dirname or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff)

    if transform and not transform_in:
        if is_file:
            transform = transform.format(dir=dirname or '.',
                                         base=base,
                                         ext=ext,
                                         sep=os.path.sep)
        else:
            transform = transform.format(sep=os.path.sep)
        trf = io.transforms.LinearTransformArray(transform, 'w')
        trf.set_source_space(aff00, shape00)
        trf.set_destination_space(aff, shape)
        trf.set_metadata({
            'src': {
                'filename': fname
            },
            'dst': {
                'filename': output
            },
            'type': 1
        })  # RAS_TO_RAS
        trf.set_fdata(torch.eye(4))
        trf.save()

    if is_file:
        return output
    else:
        return dat, aff