示例#1
0
文件: vae.py 项目: balbasty/nitorch
 def __init__(self, dim, img_size, in_channels, latent_dim, out_channels=None,
              encoder=(16, 32, 64, 128, 256), decoder=None, neurite=False,
              final_activation='Tanh', **kwargs):
     """Build the encoder, decoder and latent projections of the VAE.

     Parameters
     ----------
     dim : int
         Number of spatial dimensions.
     img_size : int or sequence[int]
         Spatial shape of the input image. An int is repeated `dim` times.
     in_channels : int
         Number of input channels.
     latent_dim : int
         Number of latent variables.
     out_channels : int, default=in_channels
         Number of output channels.
     encoder : sequence[int], default=(16, 32, 64, 128, 256)
         Number of features at each encoder level.
     decoder : sequence[int], optional
         Number of features at each decoder level.
         Default: the encoder features in reverse order.
     neurite : bool, default=False
         Use `NeuriteEncoder`/`NeuriteDecoder`/`NeuriteConv` instead of
         `Encoder`/`Decoder`/`ConvBlock`.
     final_activation : str, default='Tanh'
         Activation of the final convolution.
     **kwargs
         Forwarded to the encoder and decoder constructors.
     """
     super().__init__()
     self.dim = dim
     if isinstance(img_size, int):
         img_size = [img_size] * dim
     out_channels = out_channels or in_channels
     encoder = list(encoder)
     if not decoder:
         decoder = list(reversed(encoder))
     # TODO: add options for full stack like UNet
     self.decoder_in_channels = decoder[0]
     if neurite:
         self.encoder = NeuriteEncoder(dim, in_channels, encoder, **kwargs)
         self.decoder = NeuriteDecoder(dim, decoder[0], decoder[1:], **kwargs)
         self.final = NeuriteConv(dim, decoder[-1], out_channels, activation=final_activation)
     else:
         self.encoder = Encoder(dim, in_channels, encoder, **kwargs)
         self.decoder = Decoder(dim, decoder[0], decoder[1:], **kwargs)
         self.final = ConvBlock(dim, decoder[-1], out_channels, activation=final_activation)
     # Propagate the input shape through the encoder layers to get the
     # spatial shape of the encoder output (assumes `layer.shape` maps an
     # input shape to the layer's output shape — TODO confirm).
     shape = torch.tensor(img_size).unsqueeze(0).unsqueeze(0)
     for layer in self.encoder:
         shape = layer.shape(shape)
     self.out_shape = shape[2:]
     # Number of elements of the flattened encoder output. This must use
     # the shape computed just above: `self.shape` is never set here.
     nb_flat = encoder[-1] * py.prod(self.out_shape)
     self.latent_mu = Linear(nb_flat, latent_dim)
     self.latent_sigma = Linear(nb_flat, latent_dim)
     # NOTE(review): the hard-coded factor 4 presumably stands for the
     # number of voxels of the decoder input feature map — confirm.
     self.decoder_input = Linear(latent_dim, decoder[0]*4)
示例#2
0
def _pyramid_levels(vxs, shapes, opt):
    """
    Map global pyramid levels to per-image pyramid levels.

    Resolutions are matched as closely as possible across images at each
    global level. The match uses log2 of the geometric mean voxel size of
    each image's first pyramid level. As a result, one image's 3rd level
    may be paired with another image's 5th level.

    Parameters
    ----------
    vxs : sequence
        Voxel size of each image.
    shapes : sequence
        Spatial shape of each image.
    opt : object
        Options forwarded to `_pyramid_levels1`. `opt.levels` lists the
        global levels to keep; each entry is an int or an iterable of ints.

    Returns
    -------
    map_levels : list[list]
        For each image, its own pyramid level at each selected global level.

    Raises
    ------
    ValueError
        If some images have no pyramid level in common.
    """
    ndim = len(shapes[0])

    # 1) approximate voxel size and shape at each level of each image
    #    (each pyramid is a list of (level, vx, shape) tuples)
    pyramids = [_pyramid_levels1(vx, shape, opt)
                for vx, shape in zip(vxs, shapes)]

    # 2) offset (in levels) of each image relative to the finest one
    finest = min([py.prod(pyr[0][1]) for pyr in pyramids])
    finest = pymath.log2(finest ** (1 / ndim))
    offsets = []
    for pyr in pyramids:
        _, first_vx, _ = pyr[0]
        log_vx = pymath.log2(py.prod(first_vx) ** (1 / ndim))
        offsets.append(round(log_vx - finest))

    # 3) keep only the levels shared by every image
    nb_shared = min(off + len(pyr) for off, pyr in zip(offsets, pyramids))
    pyramids = [pyr[:nb_shared - off] for off, pyr in zip(offsets, pyramids)]
    if any(len(pyr) == 0 for pyr in pyramids):
        raise ValueError(f'Some images do not overlap in the pyramid.')

    # 4) flatten the requested global levels
    wanted = []
    for entry in opt.levels:
        if isinstance(entry, int):
            wanted.append(entry)
        else:
            wanted.extend(entry)

    # 5) per-image level at each global level; coarser images repeat
    #    their first level to fill the levels they do not have
    map_levels = []
    for pyr, off in zip(pyramids, offsets):
        per_image = [pyr[0][0]] * off + [item[0] for item in pyr]
        if wanted:
            per_image = [per_image[idx] for idx in wanted
                         if idx < len(per_image)]
        map_levels.append(per_image)
    return map_levels
示例#3
0
文件: prod.py 项目: balbasty/nitorch
def squeezed_prod(moving,
                  fixed,
                  lam=1,
                  dim=None,
                  grad=True,
                  hess=True,
                  mask=None):
    """Exponentially squeezed product loss.

        ll = sum(1 - exp(-lam/2 * moving * fixed)) / nvox

    Parameters
    ----------
    moving : (..., *spatial) tensor
    fixed : (..., *spatial) tensor
    lam : float, default=1
        Scaling factor inside the exponential.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    grad : bool, default=True
        Compute and return the gradient with respect to `moving`.
    hess : bool, default=True
        Compute and return a positive diagonal Hessian approximation.
    mask : tensor, optional
        Voxel weights, multiplied into the loss, gradient and Hessian.

    Returns
    -------
    ll : () tensor
    g : tensor, optional
    h : tensor, optional
    """
    dim = dim or (fixed.dim() - 1)
    nvox = py.prod(fixed.shape[-dim:])

    e = (moving * fixed).mul_(-lam / 2).exp_()

    ll = 1 - e
    if mask is not None:
        ll *= mask
    ll = ll.sum() / nvox
    out = [ll]

    if grad:
        # d/dm (1 - e) = (lam/2) * fixed * e, averaged over voxels
        g = (e * fixed).mul_(lam / (2 * nvox))
        if mask is not None:
            g *= mask
        out.append(g)

    if hess:
        # The true second derivative is -(lam/2)**2 * fixed**2 * e; its
        # magnitude is used so the Hessian stays positive. Like the loss
        # and gradient, it is divided (not multiplied) by nvox.
        h = (e * fixed).mul_(fixed).mul_((lam / 2) ** 2 / nvox)
        if mask is not None:
            h *= mask
        out.append(h)

    return tuple(out) if len(out) > 1 else out[0]
示例#4
0
文件: prod.py 项目: balbasty/nitorch
def prod(moving, fixed, dim=None, grad=True, hess=True, mask=None):
    """Voxel-averaged product loss ``sum(moving * fixed) / nvox``.

    Parameters
    ----------
    moving : (..., *spatial) tensor
    fixed : (..., *spatial) tensor
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    grad : bool, default=True
        Compute and return the gradient with respect to `moving`.
    hess : bool, default=True
        Compute and return a constant diagonal Hessian surrogate.
    mask : tensor, optional
        Voxel weights.

    Returns
    -------
    ll : () tensor
    g : tensor, optional
        ``fixed / nvox`` (times `mask`).
    h : tensor, optional
        ``1 / nvox`` (times `mask`). The loss is linear in `moving`, so
        this is a surrogate rather than its (zero) second derivative.
    """
    ndim = dim or (fixed.dim() - 1)
    nb_vox = py.prod(fixed.shape[-ndim:])

    loss = moving * fixed
    if mask is not None:
        loss *= mask
    outputs = [loss.sum() / nb_vox]

    if grad:
        gradient = fixed / nb_vox
        if mask is not None:
            gradient *= mask
        outputs.append(gradient)

    if hess:
        if mask is None:
            hessian = moving.new_full([1] * (ndim + 1), 1 / nb_vox)
        else:
            hessian = mask.to(moving.dtype, copy=True)
            hessian = hessian.unsqueeze(-ndim - 1).div_(nb_vox)
        outputs.append(hessian)

    if len(outputs) == 1:
        return outputs[0]
    return tuple(outputs)
示例#5
0
 def exp2(self, v=None, jacobian=False, add_identity=False,
          cache_result=False, recompute=True):
     """Exponentiate both forward and inverse transforms.

     Parameters
     ----------
     v : tensor, optional
         Velocity field. Default: `self.dat.dat`.
     jacobian : bool, default=False
         Also return the Jacobians of both displacement grids.
     add_identity : bool, default=False
         Add the identity to the returned grids. The Jacobians are always
         computed from the displacements.
     cache_result : bool, default=False
         Store the displacement grids in `self._cache` / `self._icache`.
     recompute : bool, default=True
         Shoot even if cached grids exist. If False and both caches are
         filled, the cached grids are used and `v` is ignored.

     Returns
     -------
     grid, igrid : tensor
     jac, ijac : tensor, only if `jacobian`
     """
     if v is None:
         v = self.dat.dat
     if recompute or self._cache is None or self._icache is None:
         grid, igrid = spatial.shoot(v, self.kernel, steps=self.steps,
                                     factor=self.factor / py.prod(self.shape),
                                     voxel_size=self.voxel_size, **self.penalty,
                                     return_inverse=True, displacement=True)
     else:
         # Reuse the cached displacement grids.
         grid, igrid = self._cache, self._icache
     if cache_result:
         self._cache = grid
         self._icache = igrid
     if jacobian:
         # Computed on displacements, before the identity is added.
         jac = spatial.grid_jacobian(grid, type='displacement')
         ijac = spatial.grid_jacobian(igrid, type='displacement')
     if add_identity:
         grid = self.add_identity(grid)
         igrid = self.add_identity(igrid)
     if jacobian:
         return grid, igrid, jac, ijac
     return grid, igrid
示例#6
0
    def data(self,
             dtype=None,
             device=None,
             casting='unsafe',
             rand=True,
             cutoff=None,
             dim=None,
             numpy=False):
        """Read the data of the current view and cast it to `dtype`.

        Parameters
        ----------
        dtype : data type, default=`self.dtype`
            Output data type, resolved through `dtypes.dtype`.
        device : torch device, optional
            Device of the output tensor (unused when `numpy=True`).
        casting : str, default='unsafe'
            Casting rule passed to `volutils.cast` for the intermediate cast.
        rand : bool, default=True
            Add uniform noise in the uncertainty interval. Only effective
            when `self.dtype` is not a floating point type.
        cutoff : optional
            Passed to `volutils.cutoff` (together with `dim`).
        dim : optional
            Passed to `volutils.cutoff`.
        numpy : bool, default=False
            Return a numpy array instead of a torch tensor.

        Returns
        -------
        dat : np.ndarray or torch.Tensor

        Raises
        ------
        TypeError
            If `numpy=False` and `dtype` has no PyTorch equivalent.
        """

        # --- sanity check before reading ---
        dtype = self.dtype if dtype is None else dtype
        dtype = dtypes.dtype(dtype)
        if not numpy and dtype.torch is None:
            raise TypeError(
                'Data type {} does not exist in PyTorch.'.format(dtype))

        # --- check that view is not empty ---
        # nothing to read: return zeros of the requested type directly
        if py.prod(self.shape) == 0:
            if numpy:
                return np.zeros(self.shape, dtype=dtype.numpy)
            else:
                return torch.zeros(self.shape,
                                   dtype=dtype.torch,
                                   device=device)

        # --- read native data ---
        # apply the view's slicing, then its permutation / new axes
        slicer, perm, newdim = split_operation(self.permutation, self.slicer,
                                               'r')
        with self.tiffobj() as f:
            dat = self._read_data_raw(slicer, tiffobj=f)
        dat = dat.transpose(perm)[newdim]
        indtype = dtypes.dtype(self.dtype)

        # --- cutoff ---
        dat = volutils.cutoff(dat, cutoff, dim)

        # --- cast ---
        # noise is only meaningful when the stored type is not floating point
        rand = rand and not indtype.is_floating_point
        if rand and not dtype.is_floating_point:
            # integer output with noise: go through float64 first; the
            # final cast below converts to the requested integer type
            tmpdtype = dtypes.float64
        else:
            tmpdtype = dtype
        dat, scale = volutils.cast(dat,
                                   tmpdtype.numpy,
                                   casting,
                                   with_scale=True)

        # --- random sample ---
        # uniform noise in the uncertainty interval
        # (skipped when scale == 1 and the output type is an integer type)
        if rand and not (scale == 1 and not dtype.is_floating_point):
            dat = volutils.addnoise(dat, scale)

        # --- final cast ---
        dat = volutils.cast(dat, dtype.numpy, 'unsafe')

        # convert to torch if needed
        if not numpy:
            dat = torch.as_tensor(dat, device=device)
        return dat
示例#7
0
def _auto_weighted_hard(truth, nb_classes, **backend):
    """Inverse class-frequency weights from a hard label map.

    Parameters
    ----------
    truth : (B, 1, *spatial) tensor
        Integer labels. The last `truth.dim() - 2` dimensions are treated
        as spatial (a single channel is assumed — TODO confirm).
    nb_classes : int
        Number of classes.
    **backend
        Passed to `Tensor.to` (e.g. dtype, device).

    Returns
    -------
    weights : (B, nb_classes, 1, ..., 1) tensor
        ``nvox / max(count, 0.5)`` for each class.
    """
    ndim = truth.dim() - 2
    nb_vox = py.prod(truth.shape[-ndim:])

    def count(label):
        return (truth == label).sum(dim=list(range(-ndim, 0)), keepdim=True)

    counts = torch.cat([count(label) for label in range(nb_classes)], dim=1)
    counts = counts.to(**backend)
    # clamping at 0.5 keeps absent classes from dividing by zero
    frequencies = counts.clamp_min_(0.5).div_(nb_vox)
    return frequencies.reciprocal_()
示例#8
0
 def set_kernel(self, kernel=None):
     """Set `self.kernel`, computing it with `spatial.greens` if needed.

     Parameters
     ----------
     kernel : tensor, optional
         Precomputed kernel. Default: `spatial.greens` evaluated with this
         object's shape, penalty, factor (divided by the number of voxels)
         and voxel size, on the backend of `self.dat`.

     Returns
     -------
     self
     """
     if kernel is not None:
         self.kernel = kernel
         return self
     scaled_factor = self.factor / py.prod(self.shape)
     self.kernel = spatial.greens(self.shape, **self.penalty,
                                  factor=scaled_factor,
                                  voxel_size=self.voxel_size,
                                  **utils.backend(self.dat))
     return self
示例#9
0
 def _make_chunks(self, x):
     """Cut output of hypernetwork into weights with correct shape.

     Consecutive slices of `x` are reshaped, in order, to the shapes of the
     weights returned by `self._get_weights(self.network)`. One tensor is
     yielded per weight.
     """
     # collect all shapes up front, before the first chunk is yielded
     weight_shapes = [w.shape for w in self._get_weights(self.network)]
     start = 0
     for weight_shape in weight_shapes:
         stop = start + py.prod(weight_shape)
         yield x[start:stop].reshape(weight_shape)
         start = stop
示例#10
0
文件: array.py 项目: balbasty/nitorch
    def raw_data(self):
        """Read the data of the current view without any dtype conversion.

        Returns
        -------
        dat : array
            Data as returned by `_read_data_raw`, with the view's slicing
            and permutation applied. If the view is empty, a zero-filled
            numpy array of type `self.dtype` is returned.
        """
        if py.prod(self.shape) == 0:
            # empty view: nothing to read from disk
            return np.zeros(self.shape, dtype=self.dtype)

        slicer, perm, newdim = split_operation(self.permutation, self.slicer, 'r')
        with self.tiffobj() as tiff:
            raw = self._read_data_raw(slicer, tiffobj=tiff)
        permuted = raw.transpose(perm)
        return permuted[newdim]
示例#11
0
    def __init__(self, shape, in_channels, out_channels, nb_levels=0,
                 decoder=(32, 32, 32, 32), kernel_size=3,
                 activation=tnn.LeakyReLU(0.2), unpool=None):
        """Dense layer followed by an optional decoder and convolutions.

        Parameters
        ----------
        shape : sequence[int]
            Output spatial shape
        in_channels : int
            Number of input channels (= meta variables)
        out_channels : int
            Number of output channels
        nb_levels : int, default=0
            Number of levels in the decoder.
            If 0: directly generate the image using a dense layer.
        decoder : sequence[int], default=(32, 32, 32, 32)
            Number of features after each layer.
            The first value is the number of features produced by the
            dense layer, whose output is reshaped to
            (-1, decoder[0], *[s // 2**nb_levels for s in shape]).
            The next `nb_levels` values are used by the decoder.
            If len(decoder) is larger than the number of levels, additional
            stride-1 convolutions are applied.
        kernel_size : [sequence of] int, default=3
        activation : str or callable, default=LeakyReLU(0.2)
        unpool : {'conv', 'up', None}, default=None
                'conv' -> 2x2x2 strided convolution (no bias, no activation)
                'up'   -> linear upsampling
                 None  -> use strided convolutions in the decoder
        """
        super().__init__()
        shape = py.make_list(shape)
        dim = len(shape)
        small_shape = [s // 2**nb_levels for s in shape]
        in_feat, *decoder = decoder
        self.dense = Linear(in_channels, py.prod(small_shape)*in_feat)
        # `in_feat` is reassigned below, and closures bind late, so the
        # lambda must not read it: freeze the dense output shape now.
        dense_shape = [in_feat, *small_shape]
        self.reshape = lambda x: x.reshape([-1, *dense_shape])
        decoder, stack = decoder[:nb_levels], decoder[nb_levels:]
        if decoder:
            self.decoder = Decoder(dim, in_feat, decoder,
                                   kernel_size=kernel_size,
                                   activation=activation,
                                   unpool=unpool)
            in_feat = decoder[-1]
        else:
            self.decoder = lambda x: x
        if stack:
            self.stack = StackedConv(dim, in_feat, stack,
                                     kernel_size=kernel_size,
                                     activation=activation)
            in_feat = stack[-1]
        else:
            self.stack = lambda x: x
        self.final = StackedConv(dim, in_feat, out_channels,
                                 kernel_size=kernel_size,
                                 activation=None)
示例#12
0
def cc_sample(shape, sigma, alpha, mu=None, repeats=1, **backend):
    """Sample random fields with a constant correlation.

    Every pair of distinct voxels has correlation `alpha`. The covariance
    matrix is therefore ``sigma**2 * ((1 - alpha) * I + alpha * ones)``.
    Its square root is computed by SVD. A dense (N, N) matrix is built,
    with N = prod(shape), so this is only suitable for small shapes.

    Parameters
    ----------
    shape : sequence[int]
        Shape of the image / volume.
    sigma : float
        Standard deviation. The unit-variance field is multiplied by
        `sigma`, so the variance is ``sigma**2``.
    alpha : float
        Correlation between any two distinct voxels.
    mu : () or (*shape) tensor_like, optional
        Mean added to the sampled fields.
    repeats : int, default=1
        Number of sampled fields.
    **backend
        dtype/device of the covariance matrix and samples.

    Returns
    -------
    field : (repeats, *shape) tensor
        Sampled random fields.

    """
    # Build constant-correlation matrix (ones on the diagonal, alpha elsewhere)
    n = py.prod(shape)
    e = torch.full([n, n], alpha, **backend)
    e.diagonal(0, -1, -2).add_(1 - alpha)
    backend = utils.backend(e)

    # SVD of covariance: e = U S U^T, so U sqrt(S) is a square root
    u, s, _ = torch.svd(e)
    s = s.sqrt_()

    # Sample white noise and apply transform
    full_shape = (repeats, *shape)
    field = torch.randn(full_shape, **backend).reshape([repeats, -1])
    field = linalg.matvec(u, field.mul_(s))
    field.mul_(sigma)
    field = field.reshape(full_shape)

    # Add mean
    if mu is not None:
        mu = torch.as_tensor(mu, **backend)
        field += mu

    return field
示例#13
0
 def is_virtual():
     """Return True if the ImageJ series is a virtual hyperstack.

     ImageJ virtual hyperstacks store all image metadata in the first page.
     The image data are stored contiguously before the second page, if any.
     Relies on `page`, `pages`, `meta` and `self` from the enclosing scope.

     Raises
     ------
     ValueError
         If the contiguous byte count does not match the page shape, or the
         data would extend past the end of the file.
     """
     if not page.is_final:
         return False
     nb_images = meta.get('images', 0)
     if nb_images <= 1:
         return False
     offset, count = page.is_contiguous
     expected_count = py.prod(page.shape) * page.bitspersample // 8
     data_end = offset + count * nb_images
     if count != expected_count or data_end > self.filehandle.size:
         raise ValueError()
     # the next page, if any, must be stored after the data
     next_page_overlaps = len(pages) > 1 and data_end > pages[1].offset
     return not next_page_overlaps
示例#14
0
def fit_se_log(log_cov, sqdist):
    """Fit the amplitude and length-scale of a squared-exponential kernel

    Model: ``cov = sig**2 * exp(-sqdist / (2 * lam**2))``, i.e.
    ``log_cov = 2 * log(sig) - sqdist / (2 * lam**2)``. This is fitted by
    least-squares linear regression of `log_cov` on `sqdist`. Non-finite
    entries of `log_cov` are excluded from the fit.

    Parameters
    ----------
    log_cov : (*batch, vox, vox)
        Log of the empirical covariance matrix
    sqdist : tuple[int] or (vox, vox) tensor
        If a tensor -> it is the pre-computed squared distance map
        If a tuple -> it is the shape and we build the distance map

    Returns
    -------
    sig : (prod(batch),) tensor
        Amplitude of the kernel (batch dimensions are flattened)
    lam : (prod(batch),) tensor
        Length-scale of the kernel (batch dimensions are flattened)

    """
    log_cov = torch.as_tensor(log_cov).clone()
    backend = utils.backend(log_cov)
    if not torch.is_tensor(sqdist):
        shape = sqdist
        sqdist = dist_map(shape, **backend)
    else:
        sqdist = sqdist.to(**backend).clone()

    # linear regression restricted to finite observations
    eps = constants.eps(log_cov.dtype)
    y = log_cov.reshape([-1, py.prod(sqdist.shape)])
    finite = torch.isfinite(y)
    y[~finite] = 0
    weight = finite.to(y.dtype)
    nb_obs = weight.sum(-1, keepdim=True)
    y0 = y.sum(-1, keepdim=True) / nb_obs
    x = sqdist.flatten() * weight
    x0 = x.sum(-1, keepdim=True) / nb_obs
    # center, then re-zero excluded entries so they do not contribute
    # (-x0) * (-y0) terms to the sums below
    y = (y - y0).mul_(weight)
    x = (x - x0).mul_(weight)
    b = (x * y).sum(-1) / x.square().sum(-1).clamp_min_(eps)
    # per-batch intercept (y0 and x0 are (N, 1), b is (N,))
    a = y0[..., 0] - b * x0[..., 0]

    # slope b = -1 / (2 * lam**2), intercept a = 2 * log(sig)
    lam = b.reciprocal_().mul_(-0.5).sqrt_()
    sig = a.div_(2).exp_()
    return sig, lam
示例#15
0
文件: utils.py 项目: balbasty/nitorch
def affine_grid_backward(*grad_hess, grid=None):
    """Converts ∇ wrt dense displacement into ∇ wrt affine matrix

    g = affine_grid_backward(g, [grid=None])
    g, h = affine_grid_backward(g, h, [grid=None])

    Parameters
    ----------
    grad : (..., *spatial, dim) tensor
        Gradient with respect to a dense displacement.
    hess : (..., *spatial, dim*(dim+1)//2) tensor, optional
        Hessian with respect to a dense displacement.
    grid : (*spatial, dim) tensor, optional
        Pre-computed identity grid

    Returns
    -------
    grad : (..., dim, dim+1) tensor
        Gradient with respect to an affine matrix
    hess : (..., dim, dim+1, dim, dim+1) tensor, optional
        Hessian with respect to an affine matrix.
        Returned whenever a second positional argument was passed.

    """
    grad, *extra = grad_hess
    return_hess = bool(extra)
    hess = extra[0] if extra else None

    dim = grad.shape[-1]
    spatial_shape = grad.shape[-dim - 1:-1]
    batch_shape = grad.shape[:-dim - 1]
    nb_vox = py.prod(spatial_shape)

    if grid is None:
        grid = spatial.identity_grid(spatial_shape, **utils.backend(grad))
    # flatten spatial dimensions (and batch dimensions of grad/hess)
    grid = grid.reshape([1, nb_vox, dim])
    grad = grad.reshape([-1, nb_vox, dim])

    if hess is None:
        grad = _affine_grid_backward_g(grid, grad)
    else:
        nb_sym = dim * (dim + 1) // 2
        hess = hess.reshape([-1, nb_vox, nb_sym])
        grad, hess = _affine_grid_backward_gh(grid, grad, hess)
        hess = hess.reshape([*batch_shape, dim, dim + 1, dim, dim + 1])
    grad = grad.reshape([*batch_shape, dim, dim + 1])

    if return_hess:
        return grad, hess
    return grad
示例#16
0
def code_has_center(code, kernel_size):
    """Return True if the pattern corresponding to code has the center sampled

    Bit `i` of `code` corresponds to the `i`-th element of a kernel of size
    `kernel_size`, in row-major (C) order. The pattern "has the center" if
    the bit of the central element (index ``(k - 1) // 2`` along each
    dimension) is set.

    Parameters
    ----------
    code : int or tensor[int]
    kernel_size : [sequence of] int

    Returns
    -------
    mask : bool or tensor[bool]
        A Python bool if `code` is an int, else a boolean tensor.

    """
    kernel_size = py.make_list(kernel_size)
    center = [(k - 1) // 2 for k in kernel_size]
    # Row-major linear index of the central element. Keep it a Python int:
    # shifting an int by a tensor would return a tensor, not a bool.
    center_index = 0
    for k, c in zip(kernel_size, center):
        center_index = center_index * k + c
    code = (code >> center_index) & 1
    return code.bool() if torch.is_tensor(code) else bool(code)
示例#17
0
    def irls(self, moving, fixed, lam, mask, joint, dim, **kwargs):
        """Compute IRLS weights and, optionally, re-estimate the precision.

        If `lam` is given and `self.compute_lam` is false, the weights are
        computed once. Otherwise, weights and precision are updated
        alternately, for at most 32 iterations or until the precision term
        changes by less than 1e-4.

        Returns
        -------
        lll : number or tensor
            Precision term: 0 if `lam` is fixed, else ``-0.5 * sum(log(lam))``.
        llw : tensor
            ``sum(1 / weights) / (2 * nvox)`` over weights above 1e-9.
        lam : tensor
            Precision multiplied by the IRLS weights.
        """
        nb_vox = py.prod(fixed.shape[-dim:])
        spatial_dims = list(range(-dim, 0))

        def compute_weights(precision):
            return self.reweight(moving,
                                 fixed,
                                 lam=precision,
                                 joint=joint,
                                 dim=dim,
                                 mask=mask,
                                 **kwargs)

        def weight_term(w):
            return w[w > 1e-9].reciprocal_().sum().div_(2 * nb_vox)

        # --- Fixed lam -> no reweighting loop ---------------------------
        if not (lam is None or self.compute_lam):
            weights = compute_weights(lam)
            return 0, weight_term(weights), lam * weights

        # --- Estimated lam -> IRLS loop ---------------------------------
        if lam is None:
            lam = weighted_precision(moving, fixed, dim=dim, weights=mask)
        lll = float('inf')
        for _ in range(32):
            previous = lll
            weights = compute_weights(lam)
            lam = weighted_precision(moving, fixed, dim=dim, weights=weights)
            lam /= weights.mean(spatial_dims)
            lll = -0.5 * lam.log().sum()
            llw = weight_term(weights)
            if abs(previous - lll) < 1e-4:
                break
        if self.cache:
            self.lam = lam
        return lll, llw, lam * weights
示例#18
0
 def regulariser(self, v=None):
     """Return `spatial.regulariser_grid` applied to a field.

     Parameters
     ----------
     v : tensor, optional
         Field to regularise. Default: `self.dat`.

     The penalty, voxel size and factor (divided by the number of voxels)
     are taken from this object.
     """
     field = self.dat if v is None else v
     scaled_factor = self.factor / py.prod(self.shape)
     return spatial.regulariser_grid(field, **self.penalty,
                                     factor=scaled_factor,
                                     voxel_size=self.voxel_size)
示例#19
0
文件: cc.py 项目: balbasty/nitorch
def cc(moving, fixed, dim=None, grad=True, hess=True, mask=None):
    """Log of one minus the squared Pearson's correlation coefficient

        ll = sum(log(1 - corr ** 2))
        corr = E[(x - mu_x)'(y - mu_y)] / (s_x * s_y)

    Expectations are taken over the spatial dimensions, as weighted
    averages if `mask` is given. The log terms are summed over all
    remaining dimensions.

    Parameters
    ----------
    moving : (..., K, *spatial) tensor
        Moving image with K channels.
    fixed : (..., K, *spatial) tensor
        Fixed image with K channels.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return approximate Hessian
    mask : tensor, optional
        Voxel weights, broadcastable to `fixed`.

    Returns
    -------
    ll : () tensor
        Loss.
    g : tensor, optional
        Gradient with respect to `moving`.
    h : tensor, optional
        Approximate diagonal Hessian, broadcastable to `moving`.

    """
    moving, fixed = utils.to_max_backend(moving, fixed)
    moving = moving.clone()
    fixed = fixed.clone()
    dim = dim or (fixed.dim() - 1)
    dims = list(range(-dim, 0))

    if mask is not None:
        mask = mask.to(fixed.device)
        # Effective number of voxels = total mask weight. It must match
        # the normalization of the weighted mean used in the loss.
        n = mask.sum(dim=dims, keepdim=True)
        mean = lambda x: (x * mask).sum(dim=dims, keepdim=True).div_(n)
    else:
        n = py.prod(fixed.shape[-dim:])
        mean = lambda x: x.mean(dim=dims, keepdim=True)

    moving -= mean(moving)
    fixed -= mean(fixed)
    sigm = mean(moving.square()).sqrt_()
    sigf = mean(fixed.square()).sqrt_()
    moving = moving.div_(sigm)
    fixed = fixed.div_(sigf)

    corr = mean(moving * fixed)
    corr2 = 1 - corr.square()
    corr2.clamp_min_(1e-8)

    out = []
    if grad:
        # d(1 - corr^2)/dm = 2 * corr * (corr * x - y) / (n * s_x)
        g = 2 * corr * (moving * corr - fixed) / (n * sigm)
        g /= corr2  # chain rule for log
        if mask is not None:
            g = g.mul_(mask)
        out.append(g)

    if hess:
        # approximate hessian
        h = 2 * (corr / sigm).square() / n
        h /= corr2  # chain rule for log
        if mask is not None:
            h = h * mask
        out.append(h)

    # return stuff
    corr = corr2.log_().sum()
    out = [corr, *out]
    return tuple(out) if len(out) > 1 else out[0]
示例#20
0
def intensity_preproc(*images, min=None, max=None, eq=None):
    """(Joint) rescaling and intensity equalizing.

    Parameters
    ----------
    *images : (*batch, H, W) tensor
        Input (batch of) 2d images.
        All batch shapes should be broadcastable together.
    min : [sequence of] tensor_like, optional
        Minimum value. Should be broadcastable to batch.
        Default: 5th percentile of each batch element.
        With several images, the element-wise smallest minimum is used.
    max : [sequence of] tensor_like, optional
        Maximum value. Should be broadcastable to batch.
        Default: 95th percentile of each batch element.
        With several images, the element-wise largest maximum is used.
    eq : {'linear', 'quadratic', 'log', None} or float, default=None
        Apply histogram equalization.
        If 'quadratic' or 'log', the histogram of the transformed signal
        is equalized.
        If float, the signal is taken to that power before being equalized.

    Returns
    -------
    *images : (*batch, H, W) tensor
        Preprocessed images.
        Intensities are scaled within [0, 1].

    """

    if len(images) == 1:
        images = [utils.to_max_backend(*images)]
    else:
        images = utils.to_max_backend(*images)
    backend = utils.backend(images[0])
    eps = constants.eps(images[0].dtype)

    def bound(value, q, image):
        # Explicit bounds are broadcastable to the batch shape: append the
        # two spatial dimensions (H, W) so they broadcast against images.
        if value is None:
            return utils.quantile(image, q, bins=2048, dim=[-1, -2],
                                  keepdim=True)
        return torch.as_tensor(value, **backend)[..., None, None]

    # rescale min/max (joint across images)
    lo = [bound(mn, 0.05, image)
          for image, mn in zip(images, py.make_list(min, len(images)))]
    lo, *others = lo
    for other in others:
        lo = torch.min(lo, other)
    hi = [bound(mx, 0.95, image)
          for image, mx in zip(images, py.make_list(max, len(images)))]
    hi, *others = hi
    for other in others:
        hi = torch.max(hi, other)
    del others
    images = [torch.max(torch.min(image, hi), lo) for image in images]
    # (x - lo) / (hi - lo): stays finite when lo == 0 or lo == hi
    images = [image.sub_(lo).div_(hi - lo + eps) for image in images]

    if not eq:
        return tuple(images) if len(images) > 1 else images[0]

    # reshape and concatenate
    batch = utils.expanded_shape(*[image.shape[:-2] for image in images])
    images = [image.expand([*batch, *image.shape[-2:]]) for image in images]
    shapes = [image.shape[-2:] for image in images]
    chunks = [py.prod(s) for s in shapes]
    images = [image.reshape([*batch, c]) for image, c in zip(images, chunks)]
    images = torch.cat(images, dim=-1)

    if eq is True:
        eq = 'linear'
    if not isinstance(eq, str):
        if eq >= 0:
            images = images.pow(eq)
        else:
            images = images.clamp_min_(constants.eps(images.dtype)).pow(eq)
    elif eq.startswith('q'):
        images = images.square()
    elif eq.startswith('log'):
        images = images.clamp_min_(constants.eps(images.dtype)).log()

    images = histeq(images, dim=-1)

    if not (isinstance(eq, str) and eq.startswith('lin')):
        # rescale min/max (guard against constant images)
        images -= math.min(images, dim=-1, keepdim=True)
        images /= math.max(images, dim=-1, keepdim=True).clamp_min(eps)

    images = images.split(chunks, dim=-1)
    images = [image.reshape(*batch, *s) for image, s in zip(images, shapes)]

    return tuple(images) if len(images) > 1 else images[0]
示例#21
0
文件: mse.py 项目: balbasty/nitorch
def mse(moving, fixed, lam=1, dim=None, grad=True, hess=True, mask=None):
    """Mean-squared error loss for optimisation-based registration.

    (A factor 1/2 is included, and the loss is averaged across voxels,
    but not across channels or batches)

    Parameters
    ----------
    moving : ([B], K, *spatial) tensor
        Moving image
    fixed : ([B], K, *spatial) tensor
        Fixed image
    lam : float or ([B], K|1, [*spatial]) tensor_like
        Gaussian noise precision (or IRLS weights)
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return Hessian
    mask : tensor, optional
        Voxel weights. Each voxel's squared error, gradient and Hessian
        are multiplied by its weight.

    Returns
    -------
    ll : () tensor
        Negative log-likelihood
    g : (..., K, *spatial) tensor, optional
        Gradient with respect to the moving imaged
    h : (..., K, *spatial) tensor, optional
        (Diagonal) Hessian with respect to the moving image

    """
    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    if mask is not None:
        mask = mask.to(fixed.device)
    dim = dim or (fixed.dim() - 1)
    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)  # pad spatial dimensions
    nvox = py.prod(fixed.shape[-dim:])

    ll = moving - fixed
    if moving.requires_grad:
        # out-of-place: autograd needs the un-squared residual
        ll = ll.square()
    else:
        ll = ll.square_()
    if mask is not None:
        # weight the squared residual (not the residual) so that the loss
        # is consistent with g and h below for non-binary masks
        ll = ll.mul_(mask)
    ll = ll.mul_(lam).sum() / (2 * nvox)

    out = [ll]
    if grad:
        g = moving - fixed
        if mask is not None:
            g = g.mul_(mask)
        g = g.mul_(lam).div_(nvox)
        out.append(g)
    if hess:
        h = lam / nvox
        if mask is not None:
            h = mask * h
        out.append(h)

    return tuple(out) if len(out) > 1 else out[0]
示例#22
0
def extract_patches(inp, size=64, stride=None, output=None, transform=None):
    """Extract patches from a 3D volume.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, default=64
        Patch size.
    stride : [sequence of] int, default=size
        Stride between patches.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the unstacked data is not written
        on disk by default, and templates may only use `{sep}`, `{i}`,
        `{j}` and `{k}`.
        If the input is a path, the default output filename is
        '{dir}/{base}.{i}_{j}_{k}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
        In all cases, `i`, `j`, `k` are the coordinates (starting at 1)
        of the patch.
    transform : [sequence of] str, optional
        Output filename(s) of the corresponding transforms.
        Not written by default. Same templating as `output`.

    Returns
    -------
    output : list[str] or (tensor, tensor)
        If the input is a path, the output paths are returned.
        Else, the unfolded data and orientation matrices are returned.
            Data will have shape (nx, ny, nz, *size, *channels).
            Affines will have shape (nx, ny, nz, 4, 4).

    """
    dirname = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.fdata(), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{i}_{j}_{k}{ext}'
        dirname, base, ext = py.fileparts(fname)

    dat, aff0 = inp

    shape = dat.shape[:3]
    size = py.make_list(size, 3)
    stride = py.make_list(stride, 3)
    stride = [st or sz for st, sz in zip(stride, size)]

    # unfold the three spatial dimensions -> (nx, ny, nz, *size, *channels)
    dat = utils.movedim(dat, [0, 1, 2], [-3, -2, -1])
    dat = utils.unfold(dat, size, stride)
    dat = utils.movedim(dat, [-6, -5, -4, -3, -2, -1], [0, 1, 2, 3, 4, 5])

    # patch indices, in (i, j, k) row-major order
    nx, ny, nz = dat.shape[:3]
    indices = [(i, j, k) for i in range(nx)
               for j in range(ny) for k in range(nz)]

    def format_name(template, index):
        # File names always use 1-based patch coordinates.
        i, j, k = (idx + 1 for idx in index)
        if is_file:
            return template.format(dir=dirname or '.', base=base, ext=ext,
                                   sep=os.path.sep, i=i, j=j, k=k)
        return template.format(sep=os.path.sep, i=i, j=j, k=k)

    # orientation matrix of each patch
    aff = aff0.new_empty(dat.shape[:3] + aff0.shape)
    for index in indices:
        sub = tuple(slice(st * idx, st * idx + sz)
                    for st, sz, idx in zip(stride, size, index))
        aff[index], _ = spatial.affine_sub(aff0, shape, sub)

    formatted_output = []
    if output:
        output = py.make_list(output, len(indices))
        for index, template in zip(indices, output):
            out1 = format_name(template, index)
            if is_file:
                io.volumes.savef(dat[index], out1, like=fname,
                                 affine=aff[index])
            else:
                io.volumes.savef(dat[index], out1, affine=aff[index])
            formatted_output.append(out1)

    if transform:
        transform = py.make_list(transform, len(indices))
        for index, template in zip(indices, transform):
            io.transforms.savef(torch.eye(4), format_name(template, index),
                                source=aff0, target=aff[index])

    if is_file:
        return formatted_output
    else:
        return dat, aff
示例#23
0
def slicer_sub2ind(slicer, shape):
    """Convert a multi-dimensional slicer into a linear slicer.

    The returned index selects, in the flattened C-ordered view (last
    dimension varying fastest) of an array of shape `shape`, the same
    elements that `slicer` selects in the multi-dimensional view, in the
    same order. Fully-selected trailing dimensions are merged with the
    first partially-selected one so that a single `slice` is returned
    whenever possible; otherwise an explicit list of linear indices is
    built.

    Parameters
    ----------
    slicer : sequence[slice or int]
        Should not have new axes.
        Should have only positive strides.
    shape : sequence[int]
        Should have the same length as slicer

    Returns
    -------
    index : slice or int or list[int]

    Raises
    ------
    ValueError
        If `slicer` contains a negative stride or a new axis.

    """

    slicer = expand_index(slicer, shape)
    shape_out = guess_shape(slicer, shape)
    if any(
            isinstance(idx, slice) and idx.step and idx.step < 0
            for idx in slicer):
        raise ValueError('sub2ind does not like negative strides')
    if any(is_newaxis(idx) for idx in slicer):
        raise ValueError('sub2ind does not like new axes')

    shape0 = shape

    # 1) collapse slices
    #    Dimensions are visited from the fastest (last) to the slowest
    #    (first). Fully-selected dimensions are merged into a single
    #    linear dimension of size `new_shape` until the first
    #    partially-selected dimension is met. After this step,
    #    `new_slicer` and `new_shape` are ordered FASTEST-FIRST.
    slicer = list(reversed(slicer))
    shape = list(reversed(shape))
    new_slicer = slice(None)
    new_shape = 1
    while len(slicer) > 0:
        idx, *slicer = slicer
        shp, *shape = shape

        if isinstance(idx, slice):
            if idx == slice(None):
                # merge full slices
                new_shape *= shp
                continue
            else:
                # stop trying to merge
                if idx.step in (1, None):
                    # A contiguous range in this dimension is a contiguous
                    # range in the merged dimension. `indices` resolves
                    # None bounds and keeps empty ranges (e.g. stop=0)
                    # empty.
                    start, stop, _ = idx.indices(shp)
                    new_slicer = slice(start * new_shape, stop * new_shape)
                    new_shape *= shp
                    new_slicer = simplify_slice(new_slicer, new_shape)
                    new_slicer = [new_slicer] + slicer
                    new_shape = [new_shape] + shape
                else:
                    if new_shape != 1:
                        new_slicer = [new_slicer, idx] + slicer
                        new_shape = [new_shape, shp] + shape
                    else:
                        new_slicer = [idx] + slicer
                        new_shape = [shp] + shape
                break

        elif isinstance(idx, int):
            if shp == 1:
                continue
            else:
                # selecting one index of this dimension selects a
                # contiguous block of the merged (faster) dimensions
                new_slicer = slice(idx * new_shape, (idx + 1) * new_shape)
                new_shape *= shp
                new_slicer = simplify_slice(new_slicer, new_shape)
                if new_shape != 1:
                    new_slicer = [new_slicer] + slicer
                    new_shape = [new_shape] + shape
                else:
                    new_slicer = [idx] + slicer
                    new_shape = [shp] + shape
                break

    new_slicer = py.make_list(new_slicer)
    new_shape = py.make_list(new_shape)

    assert py.prod(shape0) == py.prod(new_shape), \
           "Oops: lost something: {} vs {}".format(py.prod(shape0),
                                                   py.prod(new_shape))

    # 2) If we have a unique index, we can stop here
    if len(new_slicer) == 1:
        return new_slicer[0]

    # 3) Extract linear indices
    #    `new_slicer` is fastest-first, so the stride of entry k is the
    #    product of the sizes of entries 0..k-1 (NOT of entries k+1..).
    new_index = [0]
    stride = 1
    for idx, shp in zip(new_slicer, new_shape):
        if isinstance(idx, slice):
            idx = range(*idx.indices(shp))
        else:
            idx = [idx]
        # the slower dimension is the outer loop -> C order
        new_index = [i * stride + j for i in idx for j in new_index]
        stride *= shp

    assert len(new_index) == py.prod(shape_out), \
           "Oops: lost something: {} vs {}".format(len(new_index),
                                                   py.prod(shape_out))

    return new_index
示例#24
0
def _auto_weighted_soft(truth):
    """Inverse-volume class weights for a soft target.

    Parameters
    ----------
    truth : tensor
        Soft labels. All dimensions from the third onwards are summed.
        (Presumably laid out as (batch, classes, *spatial) -- confirm
        against callers.)

    Returns
    -------
    weights : tensor
        `nvox / max(volume, 0.5)`, where `volume` is `truth` summed over
        the trailing dimensions (kept as singletons) and `nvox` is the
        number of elements in those dimensions.
    """
    nb_spatial = truth.dim() - 2
    spatial_dims = list(range(2, nb_spatial + 2))
    nb_vox = py.prod(truth.shape[-nb_spatial:])
    volume = truth.sum(dim=spatial_dims, keepdim=True)
    # floor the volume at 0.5 so that absent classes get a finite weight
    frequency = volume.clamp_min_(0.5).div_(nb_vox)
    return frequency.reciprocal_()
示例#25
0
    def __init__(self, dim, unet=None, pull=None, exp=None,
                 *, in_channels=2):
        """Initialize the U-Net, exponentiation, resize and pull layers.

        The U-Net maps `in_channels` input channels to `dim` output
        channels. How these layers are chained is defined elsewhere in
        the class (presumably in `forward`).

        Parameters
        ----------
        dim : int
            Dimensionality of the input (1|2|3)
        unet : dict
            Dictionary of U-Net parameters with fields:
                encoder : sequence[int], default=[16, 32, 32, 32, 32]
                decoder : sequence[int], default=[32, 32, 32, 32, 16, 16]
                conv_per_layer : int, default=1
                kernel_size : int, default=3
                activation : str or callable, default=LeakyReLU(0.2)
                pool : {'max', 'conv', 'down', None}, default=None
                    'max'  -> 2x2x2 max-pooling
                    'conv' -> 2x2x2 strided convolution (no bias, no activation)
                    'down' -> downsampling
                     None  -> use strided convolutions in the encoder
                unpool : {'conv', 'up', None}, default=None
                    'conv' -> 2x2x2 strided convolution (no bias, no activation)
                    'up'   -> linear upsampling
                     None  -> use strided convolutions in the decoder
            `skip_decoder_level` is always overwritten (see `downsample`).
        pull : dict
            Dictionary of Transformer parameters with fields:
                interpolation : {0..7}, default=1
                bound : str, default='dct2'
                extrapolate : bool, default=False
        exp : dict
            Dictionary of Exponentiation parameters with fields:
                interpolation : {0..7}, default=1
                bound : str, default='dft'
                steps : int, default=8
                shoot : bool, default=False
                downsample : float, default=2
                    Also sets the U-Net's `skip_decoder_level` to
                    floor(log(downsample) / log(2)) and is the resize factor.
            If shoot is True, these fields are also present:
                absolute : float, default=0.0001
                membrane : float, default=0.001
                bending : float, default=0.2
                lame : (float, float), default=(0.05, 0.2)
                factor : float, default=1
                    Multiplied by `py.prod(downsample)`.
            `interpolation` and `bound` are also used by the resize layer.
        in_channels : int, default=2
            Number of input channels of the U-Net (keyword-only).
        """
        # default parameters
        # (each dict is copied, so the caller's dicts are not mutated)
        unet = dict(unet or {})
        unet.setdefault('encoder', [16, 32, 32, 32, 32])
        unet.setdefault('decoder', [32, 32, 32, 32, 16, 16])
        unet.setdefault('kernel_size', 3)
        unet.setdefault('pool', None)
        unet.setdefault('unpool', None)
        unet.setdefault('activation', tnn.LeakyReLU(0.2))
        pull = dict(pull or {})
        pull.setdefault('interpolation', 1)
        pull.setdefault('bound', 'dct2')
        pull.setdefault('extrapolate', False)
        exp = dict(exp or {})
        exp.setdefault('interpolation', 1)
        exp.setdefault('bound', 'dft')
        exp.setdefault('steps', 8)
        exp.setdefault('shoot', False)
        exp.setdefault('downsample', 2)
        exp.setdefault('absolute', 0.0001)
        exp.setdefault('membrane', 0.001)
        exp.setdefault('bending', 0.2)
        exp.setdefault('lame', (0.05, 0.2))
        exp.setdefault('factor', 1)
        # `shoot` and `downsample` are consumed here; they are not
        # forwarded to the exponentiation layer
        do_shoot = exp.pop('shoot')
        downsample_vel = exp.pop('downsample')
        # read before they are (possibly) popped below: also used by
        # the resize layer
        vel_inter = exp['interpolation']
        vel_bound = exp['bound']
        if do_shoot:
            # GridShoot is not given `interpolation` / `bound`
            exp.pop('interpolation')
            exp.pop('bound')
            # NOTE(review): `dict.pop` discards its default, so this line
            # only removes a user-supplied `voxel_size`. If the intent was
            # to default it to `downsample_vel`, `setdefault` was probably
            # meant -- confirm.
            exp.pop('voxel_size', downsample_vel)
            # NOTE(review): `downsample_vel` is a scalar here (it is also
            # passed to `pymath.log` below); confirm `py.prod` accepts a
            # non-sequence argument.
            exp['factor'] *= py.prod(downsample_vel)
        else:
            # GridExp does not receive the regularization parameters
            exp.pop('absolute')
            exp.pop('membrane')
            exp.pop('bending')
            exp.pop('lame')
            exp.pop('factor')
        exp['displacement'] = True
        # floor(log2(downsample)): presumably the number of final decoder
        # levels skipped so that the U-Net output is at 1/downsample
        # resolution -- confirm against UNet2
        unet['skip_decoder_level'] = int(pymath.floor(pymath.log(downsample_vel) / pymath.log(2)))

        # prepare layers
        super().__init__()
        self.unet = UNet2(dim, in_channels, dim, **unet,)
        self.velexp = GridShoot(**exp) if do_shoot else GridExp(**exp)
        # `factor=downsample_vel` presumably brings the low-resolution
        # field back to full resolution
        self.resize = GridResize(interpolation=vel_inter, bound=vel_bound,
                                 factor=downsample_vel, anchor='f',
                                 type='displacement')
        self.pull = GridPull(**pull)
        self.dim = dim

        # register losses/metrics
        self.tags = ['image', 'velocity', 'segmentation']
示例#26
0
    def __init__(self, dim, unet=None, pull=None, exp=None, template=None):
        """Initialize a learnable template and the U-Net, resize,
        exponentiation and pull layers.

        The template is a zero-initialized `tnn.Parameter`. The U-Net takes
        `template_channels + 1` input channels and outputs `dim` channels.
        How these layers are chained is defined elsewhere in the class
        (presumably in `forward`).

        Parameters
        ----------
        dim : int
            Dimensionality of the input (1|2|3)
        unet : dict
            Dictionary of U-Net parameters with fields:
                encoder : sequence[int], default=[16, 32, 32, 32, 32]
                decoder : sequence[int], default=[32, 32, 32, 32, 16, 16]
                conv_per_layer : int, default=1
                kernel_size : int, default=3
                activation : str or callable, default=LeakyReLU(0.2)
                pool : {'max', 'conv', 'down', None}, default=None
                    'max'  -> 2x2x2 max-pooling
                    'conv' -> 2x2x2 strided convolution (no bias, no activation)
                    'down' -> downsampling
                     None  -> use strided convolutions in the encoder
                unpool : {'conv', 'up', None}, default=None
                    'conv' -> 2x2x2 strided convolution (no bias, no activation)
                    'up'   -> linear upsampling
                     None  -> use strided convolutions in the decoder
        pull : dict
            Dictionary of Transformer parameters with fields:
                interpolation : {0..7}, default=1
                bound : str, default='dct2'
                extrapolate : bool, default=False
        exp : dict
            Dictionary of Exponentiation parameters with fields:
                interpolation : {0..7}, default=1
                bound : str, default='dft'
                steps : int, default=8
                shoot : bool or 'approx', default=False
                    If 'approx', `approx=True` is forwarded to GridShoot.
                downsample : float, default=2
                    Expanded to `dim` values with `utils.make_vector`.
            If shoot is True, these fields are also present:
                absolute : float, default=0.0001
                membrane : float, default=0.001
                bending : float, default=0.2
                lame : (float, float), default=(0.05, 0.2)
                factor : float, default=1
                    Multiplied by the product of `downsample`.
            `interpolation` and `bound` are also used by the resize layer.
        template : dict
            Dictionary of Template parameters with fields:
                shape : tuple[int], default=(192,) * dim
                mom : float or int, default=100
                    If in (0, 1), momentum of the running mean.
                    The mean is updated according to:
                        `new_mean = (1-mom) * old_mean + mom * new_sample`
                    If 0, use cumulative average:
                        `new_n = old_n + 1`
                        `mom = 1/new_n`
                    If > 1, cap the weight of a new sample in the average:
                        `new_n = min(cap, old_n + 1)`
                        `mom = 1/new_n`
                cat : bool or int, default=False
                    Build a categorical template.
                implicit : bool, default=True
                    Whether the template has an implicit background class.

        """
        # default parameters
        # (each dict is copied, so the caller's dicts are not mutated)
        unet = dict(unet or {})
        unet.setdefault('encoder', [16, 32, 32, 32, 32])
        unet.setdefault('decoder', [32, 32, 32, 32, 16, 16])
        unet.setdefault('kernel_size', 3)
        unet.setdefault('pool', None)
        unet.setdefault('unpool', None)
        unet.setdefault('activation', tnn.LeakyReLU(0.2))
        pull = dict(pull or {})
        pull.setdefault('interpolation', 1)
        pull.setdefault('bound', 'dct2')
        pull.setdefault('extrapolate', False)
        exp = dict(exp or {})
        exp.setdefault('interpolation', 1)
        exp.setdefault('bound', 'dft')
        exp.setdefault('steps', 8)
        exp.setdefault('shoot', False)
        exp.setdefault('downsample', 2)
        exp.setdefault('absolute', 0.0001)
        exp.setdefault('membrane', 0.001)
        exp.setdefault('bending', 0.2)
        exp.setdefault('lame', (0.05, 0.2))
        exp.setdefault('factor', 1)
        # `shoot` and `downsample` are consumed here; they are not
        # forwarded to the exponentiation layer
        do_shoot = exp.pop('shoot')
        # one downsampling factor per dimension, as a Python list
        downsample_vel = utils.make_vector(exp.pop('downsample'), dim).tolist()
        # read before they are (possibly) popped below: also used by
        # the resize layer
        vel_inter = exp['interpolation']
        vel_bound = exp['bound']
        if do_shoot:
            # GridShoot is not given `interpolation` / `bound`
            exp.pop('interpolation')
            exp.pop('bound')
            # NOTE(review): `dict.pop` discards its default, so this line
            # only removes a user-supplied `voxel_size`. If the intent was
            # to default it to `downsample_vel`, `setdefault` was probably
            # meant -- confirm.
            exp.pop('voxel_size', downsample_vel)
            exp['factor'] *= py.prod(downsample_vel)
            if do_shoot == 'approx':
                exp['approx'] = True
        else:
            # GridExp does not receive the regularization parameters
            exp.pop('absolute')
            exp.pop('membrane')
            exp.pop('bending')
            exp.pop('lame')
            exp.pop('factor')
        exp['displacement'] = True
        template = dict(template or {})
        template.setdefault('shape', (192,)*dim)
        template.setdefault('mom', 100)
        template.setdefault('cat', False)
        template.setdefault('implicit', True)

        self.cat = template['cat']
        self.implicit = template['implicit']

        # prepare layers
        super().__init__()
        # `cat` is False, True or a number of classes; one extra channel
        # is added when the background class is explicit
        # (True + 0 -> 1 channel, True + 1 -> 2 channels)
        template_channels = (self.cat + (not self.implicit)) if self.cat else 1
        self.template = tnn.Parameter(torch.zeros([template_channels, *template['shape']]))
        # one extra input channel, presumably for the image being matched
        self.unet = UNet2(dim, template_channels + 1, dim, **unet)
        # NOTE(review): the resize factor is 1/downsample here (presumably
        # a downsampling); confirm the direction matches how `resize` is used
        self.resize = GridResize(interpolation=vel_inter, bound=vel_bound,
                                 factor=[1 / f for f in downsample_vel],
                                 type='displacement')
        self.velexp = GridShoot(**exp) if do_shoot else GridExp(**exp)
        self.pull = GridPull(**pull)
        self.dim = dim
        self.mom = template['mom']

        # register losses/metrics
        self.tags = ['match', 'velocity',  'template', 'mean']
示例#27
0
def _build_nonlin(options, can_use_2nd_order, affine, image_dict):
    """Build the nonlinear transformation model and its optimizer.

    Parameters
    ----------
    options : object
        Parsed options. `options.nonlin` holds the nonlinear settings
        (mean-space definition, regularization, optimizer); when it is
        falsy, no nonlinear component is built. `options.optim.name` is
        read to choose the default number of iterations.
    can_use_2nd_order : bool
        Only used to resolve an 'unset' optimizer name: 'gn' if True,
        'lbfgs' otherwise. The resolved name is written back into
        `options.nonlin.optim.name`.
    affine : object
        Only its truthiness is used: with an affine component and the
        'interleaved' strategy, the default iteration count is 10
        instead of 50.
    image_dict : dict
        Images keyed by name. The first image's `.dat.device` sets the
        device of the displacement field. The images listed in
        `options.nonlin.fov` (all images if unset) define the mean space.

    Returns
    -------
    nonlin : NonLinModel or None
    nonlin_optim : optimizer or None
        Both are None when `options.nonlin` is falsy.

    Raises
    ------
    ValueError
        If the optimizer name or the Gauss-Newton solver is unknown.
    """
    ndim = 3
    # evaluated unconditionally, even when no nonlinear component is built
    device = next(iter(image_dict.values())).dat.device

    opt_nl = options.nonlin
    if not opt_nl:
        return None, None

    # --- mean space -------------------------------------------------------
    voxel_size, voxel_unit = _split_trailing_unit(opt_nl.voxel_size, 'mm')
    padding, padding_unit = _split_trailing_unit(opt_nl.pad, '%')
    fov_keys = opt_nl.fov or image_dict
    space = objects.MeanSpace(
        [image_dict[key] for key in fov_keys],
        voxel_size=py.make_list(voxel_size, ndim),
        vx_unit=voxel_unit,
        pad=py.make_list(padding, ndim),
        pad_unit=padding_unit)
    print(space)

    # --- transformation model ---------------------------------------------
    regularization = dict(absolute=opt_nl.absolute,
                          membrane=opt_nl.membrane,
                          bending=opt_nl.bending,
                          lame=opt_nl.lame)
    displacement = objects.Displacement(space.shape,
                                        affine=space.affine,
                                        dim=ndim,
                                        device=device)
    model_class = objects.NonLinModel.subclass(opt_nl.name)
    nonlin = model_class(dat=displacement,
                         factor=opt_nl.factor,
                         prm=regularization,
                         steps=getattr(opt_nl, 'steps', None))

    # --- optimizer --------------------------------------------------------
    opt = opt_nl.optim
    max_iter = opt.max_iter
    if not max_iter:
        interleaved = affine and options.optim.name == 'interleaved'
        max_iter = 10 if interleaved else 50
    if opt.name == 'unset':
        opt.name = 'gn' if can_use_2nd_order else 'lbfgs'

    nonlin_optim = _build_nonlin_optim(opt, nonlin)
    if opt.line_search:
        nonlin_optim.search = opt.line_search
    nonlin_optim.iter = optim.OptimIterator(max_iter=max_iter,
                                            tol=opt.tolerance)
    return nonlin, nonlin_optim


def _split_trailing_unit(values, default_unit):
    """Split an optional trailing unit string off a list of values."""
    if isinstance(values[-1], str):
        *values, unit = values
        return values, unit
    return values, default_unit


def _build_nonlin_optim(opt, nonlin):
    """Instantiate the optimizer named `opt.name` for the model `nonlin`.

    First-order optimizers are preconditioned with `nonlin.greens_apply`;
    Gauss-Newton ('gn') uses a grid solver instead.
    """
    name = opt.name
    if name == 'gn':
        return _build_nonlin_gn(opt, nonlin)

    factories = {
        'gd': lambda: optim.GradientDescent(lr=opt.lr),
        'cg': lambda: optim.ConjugateGradientDescent(
            lr=opt.lr, beta=opt.beta),
        'mom': lambda: optim.Momentum(
            lr=opt.lr, momentum=opt.momentum),
        'nes': lambda: optim.Nesterov(
            lr=opt.lr, momentum=opt.momentum, auto_restart=opt.restart),
        'ogm': lambda: optim.OGM(
            lr=opt.lr, momentum=opt.momentum, relax=opt.relax,
            auto_restart=opt.restart),
        # TODO: tolerance?
        'lbfgs': lambda: optim.LBFGS(lr=opt.lr, history=opt.history),
    }
    if name not in factories:
        raise ValueError(name)
    optimizer = factories[name]()
    optimizer.preconditioner = nonlin.greens_apply
    return optimizer


def _build_nonlin_gn(opt, nonlin):
    """Gauss-Newton optimizer whose linear steps use a grid solver."""
    marquardt = getattr(opt, 'marquardt', None)
    sub_iter = getattr(opt, 'sub_iter', None) or (2 if opt.fmg else 16)
    solver_prm = {
        'factor': nonlin.factor / py.prod(nonlin.shape),
        'voxel_size': nonlin.voxel_size,
        **nonlin.prm
    }
    solver = getattr(opt, 'solver', 'cg')
    if solver == 'cg':
        solver_class = optim.GridCG
    elif solver == 'relax':
        solver_class = optim.GridRelax
    else:
        raise ValueError(solver)
    return solver_class(lr=opt.lr,
                        marquardt=marquardt,
                        max_iter=sub_iter,
                        **solver_prm)
示例#28
0
def empirical_cov(series,
                  nb_dim=1,
                  dim=None,
                  subtract_mean=True,
                  flatten=False,
                  keepdim=False,
                  return_mean=False):
    """Compute an empirical covariance

    Parameters
    ----------
    series : (..., *dims) tensor_like
        Sample series
    nb_dim : int, default=1
        Number of spatial dimensions.
    dim : [sequence of] int, default=None
        Dimensions that are reduced when computing the covariance.
        If None: all but the last `nb_dim`.
        Negative values are counted from the end of `series` (i.e.,
        spatial dimensions included).
    subtract_mean : bool, default=True
        Subtract empirical mean before computing the covariance.
    flatten : bool, default=False
        If True, flatten the 'covariance' dimensions.
    keepdim : bool, default=False
        Keep reduced dimensions.
    return_mean : bool, default=False
        Also return the empirical mean (computed even when
        `subtract_mean` is False).

    Returns
    -------
    cov : (..., *dims, *dims) or (..., prod(dims), prod(dims)) tensor
        Covariance (normalized by the number of samples, not N-1).
    mean : (..., *dims) or (..., prod(dims)) tensor, if `return_mean`
        Mean.

    """

    # Convert to tensor
    series = torch.as_tensor(series)
    prespatial = series.shape[:-nb_dim]
    spatial = series.shape[-nb_dim:]
    nb_pre = len(prespatial)

    # Normalize `dim` to a list of non-negative indices
    if dim is None:
        dim = range(nb_pre)
    try:
        dim = list(dim)
    except TypeError:  # single index
        dim = [dim]
    dim = [series.dim() + d if d < 0 else d for d in dim]

    kept = [d for d in range(nb_pre) if d not in dim]
    batch = [prespatial[d] for d in kept]

    # Empirical mean (also needed when only `return_mean` is requested)
    mean = None
    if subtract_mean or return_mean:
        mean = series.mean(dim=dim, keepdim=True)
    if subtract_mean:
        series = series - mean

    # Reshape to (*batch, reduced, spatial): batch dimensions first (in
    # their original order), then all reduced dimensions merged, then all
    # spatial dimensions merged.
    series = series.reshape([*prespatial, -1])
    series = series.permute([*kept, *dim, nb_pre])
    series = series.reshape([*batch, -1, series.shape[-1]])
    n_reduced = series.shape[-2]
    n_vox = series.shape[-1]

    # Torch's matmul just uses too much memory
    # We don't expect to have more than about 100 time frames,
    # so it is better to unroll the loop in python.
    # cov = torch.matmul(series.transpose(-1, -2), series)
    series_t = series.transpose(-1, -2)
    cov = series.new_zeros([*batch, n_vox, n_vox])
    buf = torch.empty_like(cov)
    for i in range(n_reduced):
        # outer product of the i-th (centered) sample with itself
        torch.mul(series_t[..., :, i, None], series[..., i, None, :], out=buf)
        cov += buf
    cov /= n_reduced

    if keepdim:
        outshape = [1 if d in dim else s for d, s in enumerate(prespatial)]
    else:
        outshape = list(batch)
    if flatten:
        outshape_mean = outshape + [spatial.numel()]
        outshape = outshape + [spatial.numel()] * 2
    else:
        outshape_mean = outshape + list(spatial)
        outshape = outshape + list(spatial) * 2

    cov = cov.reshape(outshape)
    if return_mean:
        return cov, mean.reshape(outshape_mean)
    return cov
示例#29
0
文件: gmm.py 项目: balbasty/nitorch
def lgmmh(moving, fixed, dim=None, bins=3, patch=7, stride=1,
          grad=True, hess=True, mode='g', max_iter=128,
          theta=None, return_theta=False):
    """Loss from local bivariate GMMs fitted to (moving, fixed) patches.

    `fit_lgmm2` fits the mixture model locally (patch-wise). This function
    returns its negative log-likelihood and, optionally, a gradient and a
    Hessian approximation. Both use the moving-image statistics (`xvar`),
    so they are presumably taken with respect to `moving`.

    Parameters
    ----------
    moving, fixed : tensor
        Images to compare. By default `dim = fixed.dim() - 1`, i.e. one
        leading non-spatial dimension is assumed (presumably channels --
        confirm against callers).
    dim : int, optional
        Number of spatial dimensions.
    bins : int, default=3
        Forwarded to `fit_lgmm2` (presumably the number of mixture
        components).
    patch : int or list[int], default=7
        Patch size, forwarded to `fit_lgmm2`, `Fwd` and `Bwd`.
    stride : int or list[int], default=1
        Patch stride; falsy entries are replaced with 0.
    grad : bool, default=True
        Return the gradient.
    hess : bool, default=True
        Return the Hessian approximation (only computed if `grad` is True).
    mode : str, default='g'
        Forwarded to `fit_lgmm2`, `Fwd` and `Bwd`.
    max_iter : int, default=128
        Forwarded to `fit_lgmm2`.
    theta : optional
        Forwarded to `fit_lgmm2` (presumably initial parameters).
    return_theta : bool, default=False
        Also return the dictionary of fitted parameters (without the
        'resp' and 'nll' entries, which are always removed).

    Returns
    -------
    nll, [grad], [hess], [theta]
        A bare value if only `nll` is returned, otherwise a tuple.
        `grad` and `hess` are divided by the number of voxels in the last
        `dim` dimensions of the responsibilities.
    """

    fixed, moving = utils.to_max_backend(fixed, moving)
    dim = dim or (fixed.dim() - 1)
    shape = fixed.shape[-dim:]

    if not isinstance(patch, (list, tuple)):
        patch = [patch]
    patch = list(patch)
    if not isinstance(stride, (list, tuple)):
        stride = [stride]
    stride = [s or 0 for s in stride]

    fwd = Fwd(patch, stride, dim, mode)
    bwd = Bwd(patch, stride, dim, mode, shape)

    gmmfit = fit_lgmm2(moving, fixed, bins, max_iter, dim,
                       patch=patch, stride=stride, mode=mode, theta=theta)

    # drop unused variables
    # (with `return_theta`, parameters are only read (`get`) so that they
    # stay in `gmmfit`; otherwise they are removed (`pop`). 'resp' and
    # 'nll' are always removed.)
    get = gmmfit.get if return_theta else gmmfit.pop
    pop = gmmfit.pop

    z = pop('resp')  # presumably the responsibilities
    moving_mean = get('xmean')
    fixed_mean = get('ymean')
    moving_var = get('xvar')
    fixed_var = get('yvar')
    corr = get('corr')
    prior = get('prior')
    out = [pop('nll')]
    nvox = py.prod(z.shape[-dim:])

    # insert a singleton dimension before the spatial dimensions,
    # presumably to broadcast against per-class statistics
    moving = moving.unsqueeze(-dim-1)
    fixed = fixed.unsqueeze(-dim-1)

    # NOTE(review): `make_grad` / `make_hess` are decorated with
    # `torch.jit.script` inside this function, i.e. on every call (a new
    # function object each time), which likely re-compiles them per call.
    # Consider hoisting them to module level if `Bwd` is scriptable at
    # import time.
    if grad:
        # clamped, presumably to avoid divisions by zero inside `bwd`
        z0 = fwd(z, None).clamp_min_(1e-10)

        # gradient of the GMM entropy
        # L = 0.5 * \sum_k pi_k log|\Sigma_k| + cte
        @torch.jit.script
        def make_grad(bwd: Bwd, z, z0, moving, fixed, moving_mean, fixed_mean,
                      moving_var, fixed_var, corr, prior) -> Tensor:
            cov = corr * (moving_var * fixed_var).sqrt()
            idet = moving_var * fixed_var * (1 - corr * corr)
            idet = prior / idet
            # gradient of determinant + chain rule of log
            g = moving * bwd(fixed_var * idet, z, z0) - fixed * bwd(cov * idet, z, z0)
            g -= bwd((moving_mean * fixed_var - fixed_mean * cov) * idet, z, z0)
            g = g.sum(-bwd.dim-1)
            return g

        g = make_grad(bwd, z, z0, moving, fixed, moving_mean, fixed_mean,
                      moving_var, fixed_var, corr, prior)
        g.div_(nvox)
        out.append(g)

        if hess:
            # Alternative (Fisher-scoring) Hessian, kept for reference:
            # # hessian of (1 - corr^2)
            # imoving_var = moving_var.reciprocal()
            # corr2 = corr * corr
            # h = corr2 * imoving_var
            # # chain rule (with Fisher's scoring)
            # h /= 1 - corr2
            # # hessian of log(moving_var)
            # h += imoving_var
            # # weight by proportion and sum
            # h = h * (z * prior)
            # h = h.sum(-1)

            @torch.jit.script
            def make_hess(bwd: Bwd, z, z0, moving_var, corr, prior) -> Tensor:
                h = (1 - z0) * prior / (moving_var * (1 - corr * corr))
                h = bwd(h, z, z0)
                h = h.sum(-bwd.dim-1)
                return h

            h = make_hess(bwd, z, z0, moving_var, corr, prior)
            h.div_(nvox)
            out.append(h)

    if return_theta:
        out.append(gmmfit)

    return out[0] if len(out) == 1 else tuple(out)
示例#30
0
文件: array.py 项目: balbasty/nitorch
    def _shape_split_imagej(self, tiffobj):
        """Split the shape into different components (ImageJ format).

        This is largely copied from tifffile (`TiffFile._series_imagej`).

        Side effect: sets `tiffobj.pages.useframes = True` and
        `tiffobj.pages.keyframe = 0`.

        Parameters
        ----------
        tiffobj : tifffile.TiffFile
            Opened TIFF file carrying ImageJ metadata.

        Returns
        -------
        prefix : tuple[int]
            `page.shape[0:1]` for colour-mapped contiguous images (page
            axes starting with 'SI'), else empty.
        collection : tuple[int]
            Shape of the collection of pages (T, Z, C and I axes, each
            only when its size is larger than one).
        page_shape : tuple[int]
            Remaining per-page dimensions.

        Raises
        ------
        ValueError
            If a virtual hyperstack's metadata is inconsistent with the
            page data size or with the file size.
        """

        pages = tiffobj.pages
        pages.useframes = True
        pages.keyframe = 0
        page = pages[0]
        meta = tiffobj.imagej_metadata

        def is_virtual():
            # ImageJ virtual hyperstacks store all image metadata in the first
            # page and image data are stored contiguously before the second
            # page, if any
            if not page.is_final:
                return False
            images = meta.get('images', 0)
            if images <= 1:
                return False
            offset, count = page.is_contiguous
            if (
                count != py.prod(page.shape) * page.bitspersample // 8
                # the file size belongs to the TiffFile's handle; `self`
                # is not the TiffFile here (unlike in tifffile)
                or offset + count * images > tiffobj.filehandle.size
            ):
                raise ValueError('ImageJ series: invalid metadata or '
                                 'corrupted file')
            # check that next page is stored after data
            if len(pages) > 1 and offset + count * images > pages[1].offset:
                return False
            return True

        isvirtual = is_virtual()
        if isvirtual:
            # no need to read other pages
            pages = [page]
        else:
            pages = pages[:]

        images = meta.get('images', len(pages))
        frames = meta.get('frames', 1)
        slices = meta.get('slices', 1)
        channels = meta.get('channels', 1)

        # compute shape of the collection of pages
        shape = []
        axes = []
        if frames > 1:
            shape.append(frames)
            axes.append('T')
        if slices > 1:
            shape.append(slices)
            axes.append('Z')
        if channels > 1 and (py.prod(shape) if shape else 1) != images:
            shape.append(channels)
            axes.append('C')

        # any remaining images form an extra 'I' axis
        remain = images // (py.prod(shape) if shape else 1)
        if remain > 1:
            shape.append(remain)
            axes.append('I')

        if page.axes[0] == 'S' and 'C' in axes:
            # planar storage, S == C, saved by Bio-Formats
            return tuple(), tuple(shape), tuple(page.shape[1:])
        elif page.axes[0] == 'I':
            # contiguous multiple images
            return tuple(), tuple(shape), tuple(page.shape[1:])
        elif page.axes[:2] == 'SI':
            # color-mapped contiguous multiple images
            return tuple(page.shape[0:1]), tuple(shape), tuple(page.shape[2:])
        else:
            return tuple(), tuple(shape), tuple(page.shape)