Exemplo n.º 1
0
    def responsibilities(image, means, precisions, proportions):
        """Compute Gaussian-mixture responsibilities for every voxel.

        The unnormalized log-posterior of class k at voxel x is
            log p_k + 0.5 * logdet(A_k) - 0.5 * C * log(2*pi)
                    - 0.5 * (x - m_k)^T A_k (x - m_k)
        and the responsibilities are its softmax across classes.

        Parameters
        ----------
        image : (B, C, *spatial) tensor
            Observed (multi-channel) image.
        means : (B, K, C) tensor
            Class means.
        precisions : (B, K, C, C) tensor
            Class precision matrices.
        proportions : (B, K) tensor
            Mixing proportions.

        Returns
        -------
        z : (B, K, *spatial) tensor
            Responsibilities (softmax across classes).
        logz : (B, K, *spatial) tensor
            Log-responsibilities (log-softmax across classes).
        """
        # aliases (the parameter names are deleted so that only the
        # aliases reference the inputs, which are later rebound)
        x = image
        m = means
        A = precisions
        p = proportions
        nb_dim = image.dim() - 2
        del image, means, precisions, proportions

        # voxel-wise quadratic term
        x = channel2last(x).unsqueeze(-2)      # [B, ...,  1, C]
        p = unsqueeze(p, dim=1, ndim=nb_dim)   # [B, ones, K]
        m = unsqueeze(m, dim=1, ndim=nb_dim)   # [B, ones, K, C]
        A = unsqueeze(A, dim=1, ndim=nb_dim)   # [B, ones, K, C, C]
        x = x - m                              # [B, ..., K, C]
        z = matvec(A, x)
        z = (z * x).sum(dim=-1)                # [B, ..., K]
        z = -0.5 * z

        # constant term: Gaussian normalization + log mixing proportion
        twopi = torch.as_tensor(2 * pi, dtype=A.dtype, device=A.device)
        nrm = torch.logdet(A) - A.shape[-1] * twopi.log()
        nrm = 0.5 * nrm + p.log()
        z = z + nrm

        # normalize across classes; softmax(z) == exp(log_softmax(z)), so
        # the normalization only needs to be computed once
        z = last2channel(z)
        logz = torch.nn.functional.log_softmax(z, dim=1)
        z = logz.exp()

        return z, logz
Exemplo n.º 2
0
    def nll(image, resp, means, precisions):
        """Responsibility-weighted quadratic term of a Gaussian mixture.

        At every voxel, returns
            0.5 * sum_k resp_k * (x - m_k)^T A_k (x - m_k)
        Only the quadratic (Mahalanobis) term is computed: no log-determinant,
        2*pi or mixing-proportion term is included.

        Parameters
        ----------
        image : (B, C, *spatial) tensor
            Observed (multi-channel) image.
        resp : (B, K, *spatial) tensor
            Class responsibilities used as weights.
        means : (B, K, C) tensor
            Class means.
        precisions : (B, K, C, C) tensor
            Class precision matrices.

        Returns
        -------
        loss : (B, *spatial) tensor
            Voxel-wise weighted quadratic term.
        """
        spatial_ndim = image.dim() - 2

        # centered observations, broadcast against every class
        centered = (channel2last(image).unsqueeze(-2)                 # [B, ..., 1, C]
                    - unsqueeze(means, dim=1, ndim=spatial_ndim))     # [B, ..., K, C]
        weights = channel2last(resp)                                  # [B, ..., K]
        prec = unsqueeze(precisions, dim=1, ndim=spatial_ndim)        # [B, ones, K, C, C]

        mahalanobis = (matvec(prec, centered) * centered).sum(dim=-1)  # [B, ..., K]
        return 0.5 * (mahalanobis * weights).sum(dim=-1)               # [B, ...]
Exemplo n.º 3
0
    def forward(self, prior, **overload):
        """Sample class labels from a voxel-wise Categorical distribution.

        Parameters
        ----------
        prior : (batch, channel, *shape) tensor or callable
            Prior probabilities or log-odds of the Categorical distribution.
            If `shape` is set (at build or call time), `prior` must instead
            be a (batch, channel) tensor, which is expanded to `shape`.
            If callable, it is called without arguments to obtain the prior.
        overload : dict
            All parameters defined at build time (`shape`, `logits`,
            `implicit`) can be overridden at call time.

        Returns
        -------
        sample : (batch, 1, *shape) tensor
            Sampled class indices.

        Raises
        ------
        ValueError
            If `shape` is provided and `prior` is not a 2D tensor.
        """
        # read arguments
        shape = overload.get('shape', self.shape)
        logits = overload.get('logits', self.logits)
        implicit = overload.get('implicit', self.implicit)

        # call prior in case it is a random parameter
        prior = prior() if callable(prior) else torch.as_tensor(prior)

        # repeat prior if shape provided
        if shape is not None:
            if prior.dim() != 2:
                raise ValueError('Expected tensor with shape (batch, channel) '
                                 'but got {}'.format(prior.shape))
            prior = expand(prior, [*prior.shape, *shape], side='right')

        # add implicit class
        if implicit:
            if logits:
                # reference class: log-odds of zero
                extra = prior.new_zeros([prior.shape[0], 1, *prior.shape[2:]])
            else:
                # probabilities: the implicit class receives the remaining
                # mass (a zero column would make it impossible to sample)
                extra = (1 - prior.sum(dim=1, keepdim=True)).clamp_min(0)
            prior = torch.cat((prior, extra), dim=1)

        # move the class dimension last, as expected by Categorical
        batch, _, *spatial = prior.shape
        prior = channel2last(prior)
        kwargs = {('logits' if logits else 'probs'): prior}

        # sample
        sample = torch.distributions.Categorical(**kwargs).sample()
        return sample.reshape([batch, 1, *spatial])
Exemplo n.º 4
0
    def forward(self, batch=1, **overload):
        """Sample a random velocity field.

        Parameters
        ----------
        batch : int, default=1
            Batch size
        overload : dict
            Parameters defined at build time can be overridden at call time:
            `dim` (number of channels), `shape`, `amplitude`, `fwhm`,
            `dtype` and `device`.

        Returns
        -------
        vel : (batch, *shape, dim) tensor
            Velocity field
        """
        field = self.field
        opt = {'channel': overload.get('dim', field.channel)}
        for key in ('shape', 'amplitude', 'fwhm', 'dtype', 'device'):
            opt[key] = overload.get(key, getattr(field, key))

        def right_pad(amp):
            # RandomField broadcasts amplitude to (channel, *shape) by padding
            # on the left, so a 1d amplitude would become (1, ..., dim)
            # instead of (dim, ..., 1). Appending trailing singleton
            # dimensions avoids that left-side padding.
            amp = torch.as_tensor(amp)
            return unsqueeze(amp, dim=-1, ndim=opt['channel'] + 1 - amp.dim())

        raw_amplitude = opt['amplitude']
        if callable(raw_amplitude):
            # random amplitude: pad whatever the sampler returns
            def padded_amplitude(*args, **kwargs):
                return right_pad(raw_amplitude(*args, **kwargs))
            opt['amplitude'] = padded_amplitude
        else:
            opt['amplitude'] = right_pad(raw_amplitude)

        return channel2last(self.field(batch, **opt))
Exemplo n.º 5
0
    def forward(self, source, target, *, _loss=None, _metric=None):
        """Affinely register a source image to a target image.

        Parameters
        ----------
        source : tensor (batch, channel, *spatial)
            Source/moving image
        target : tensor (batch, channel, *spatial)
            Target/fixed image

        _loss : dict, optional
            If provided, all registered losses are computed and appended.
        _metric : dict, optional
            If provided, all registered metrics are computed and appended.

        Returns
        -------
        deformed_source : tensor (batch, channel, *spatial)
            Deformed source image
        affine_prm : tensor (batch, *nb_prm)
            Affine Lie/Classic parameters
        dense : tensor
            UNet output with its channel dimension moved last.
        """
        # sanity checks: same dimensionality, broadcastable batch,
        # identical spatial shape
        check.dim(self.dim, source, target)
        check.shape(target, source, dims=[0], broadcast_ok=True)
        check.shape(target, source, dims=range(2, self.dim + 2))

        # dense features from the concatenated pair -> affine parameters
        pair = torch.cat((source, target), dim=1)
        dense = channel2last(self.unet(pair))
        affprm = self.dense2prm(dense)

        # the exponential is evaluated in double precision, then cast back
        # to the dtype of the network output
        affine = self.exp(affprm.double()).to(dense.dtype)
        sampling_grid = self.grid(affine, shape=target.shape[2:])
        deformed_source = self.pull(source, sampling_grid)

        # compute loss and metrics
        self.compute(_loss, _metric,
                     image=[deformed_source, target],
                     affine=[affprm],
                     dense=[dense])

        return deformed_source, affprm, dense
Exemplo n.º 6
0
def _pull_vel(vel, grid, *args, **kwargs):
    """Interpolate a velocity/grid/displacement field at the locations in `grid`.

    Parameters
    ----------
    vel : (batch, ..., ndim) tensor
        Velocity (components in the last dimension).
    grid : (batch, ..., ndim) tensor
        Transformation field.
    *args, **kwargs
        Forwarded unchanged to ``grid_pull``.

    Returns
    -------
    pulled_vel : (batch, ..., ndim) tensor
        Velocity (components in the last dimension).

    """
    # move the vector components to the channel axis for grid_pull,
    # then move them back to the last dimension
    vel_channel_first = last2channel(vel)
    pulled = grid_pull(vel_channel_first, grid, *args, **kwargs)
    return channel2last(pulled)
Exemplo n.º 7
0
    def forward(self, batch=1, **overload):
        """Sample a random velocity field.

        Parameters
        ----------
        batch : int, default=1
            Batch size

        Other Parameters
        ----------------
        shape : sequence[int], optional
        device : torch.device, optional
        dtype : torch.dtype, optional

        Returns
        -------
        vel : (batch, *shape, dim) tensor
            Velocity field

        """
        # one velocity component per spatial dimension: the number of
        # channels is always derived from `shape` (overriding any
        # `channel` passed by the caller)
        shape = overload.get('shape', self.field.shape)
        field_opt = {**overload, 'channel': len(shape)}
        return utils.channel2last(self.field(batch, **field_opt))
Exemplo n.º 8
0
def resize_grid(grid,
                factor=None,
                shape=None,
                type='grid',
                affine=None,
                *args,
                **kwargs):
    """Resize a displacement grid by a factor.

    The displacement grid is resized *and* rescaled, so that
    displacements are expressed in the new voxel referential.

    Notes
    -----
    .. At least one of `factor` and `shape` must be specified.
    .. If `anchor in ('centers', 'edges')`, and both `factor` and `shape`
       are specified, `factor` is discarded.
    .. If `anchor in ('first', 'last')`, `factor` must be provided even
       if `shape` is specified.
    .. Because of rounding, it is in general not assured that
       `resize(resize(x, f), 1/f)` returns a tensor with the same shape as x.

    Parameters
    ----------
    grid : (batch, ..., ndim) tensor
        Grid to resize
    factor : float or list[float], optional
        Resizing factor
        * > 1 : larger image <-> smaller voxels
        * < 1 : smaller image <-> larger voxels
    shape : (ndim,) sequence[int], optional
        Output shape
    type : {'grid', 'displacement'}, default='grid'
        Grid type:
        * 'grid' corresponds to dense grids of coordinates.
        * 'displacement' corresponds to dense grids of relative displacements.
        Both types are not rescaled in the same way.
    affine : (batch, ndim[+1], ndim+1), optional
        Orientation matrix of the input grid.
        If provided, the orientation matrix of the resized image is
        returned as well.
    *args
        Additional positional arguments forwarded to `resize`.
    **kwargs
        Additional keyword arguments forwarded to `resize`, e.g.:
        anchor : {'centers', 'edges', 'first', 'last'}, default='centers'
            * In cases 'c' and 'e', the volume shape is multiplied by the
              zoom factor (and eventually truncated), and two anchor points
              are used to determine the voxel size.
            * In cases 'f' and 'l', a single anchor point is used so that
              the voxel size is exactly divided by the zoom factor.
              This case with an integer factor corresponds to subslicing
              the volume (e.g., `vol[::f, ::f, ::f]`).
            * A list of anchors (one per dimension) can also be provided.
        Interpolation options are presumably passed on by `resize` to
        `grid_pull` (not visible here).

    Returns
    -------
    resized : (batch, ..., ndim) tensor
        Resized grid.
    affine : (batch, ndim[+1], ndim+1) tensor, optional
        Orientation matrix

    """
    # ask `resize` to also return the (scales, shifts) that map resized
    # voxel coordinates back to original voxel coordinates
    kwargs['_return_trf'] = True
    resized = resize(utils.last2channel(grid), factor, shape, affine,
                     *args, **kwargs)
    if affine is None:
        grid, (scales, shifts) = resized
    else:
        grid, affine, (scales, shifts) = resized
    grid = utils.channel2last(grid)

    def to_resized_space(component, scale, shift):
        # resize maps:    original = scale * resized + shift
        # coordinates:    resized  = (original - shift) / scale
        # displacements:  resized  = original / scale   (shifts cancel out)
        if type[0].lower() == 'g':
            component = component - shift
        return component / scale

    grid = torch.stack(
        [to_resized_space(utils.slice_tensor(grid, d, dim=-1), scl, shft)
         for d, (scl, shft) in enumerate(zip(scales, shifts))],
        -1)

    return (grid, affine) if affine is not None else grid
Exemplo n.º 9
0
def compose(*args, interpolation='linear', bound='dft'):
    """Compose multiple spatial deformations (affine matrices or flow fields).

    Deformations are composed right-to-left: ``compose(a, b)`` applies
    ``b`` first and then ``a`` (i.e., it returns ``a o b``).

    Parameters
    ----------
    *args : tensor_like
        Each deformation is either
        * an affine matrix of shape (ndim[+1], ndim+1), or
        * a transformation field of shape (batch, *spatial, ndim).
        Unless all deformations are matrices, the last one must be a field.
    interpolation : default='linear'
        Interpolation option forwarded to ``grid_pull``.
    bound : default='dft'
        Boundary condition forwarded to ``grid_pull``.

    Returns
    -------
    composed : tensor or None
        The composed transformation field, or -- if only matrices are
        provided -- the composed square affine matrix
        (``None`` if no deformation is provided).

    Raises
    ------
    ValueError
        If the deformations do not all have the same dimensionality, or if
        the last deformation is a matrix while at least one field is given.
    """

    # TODO:
    # . add shape/dim argument to generate (if needed) an identity field
    #   at the end of the chain.
    # . possibility to provide fields that have an orientation matrix?
    #   (or keep it the responsibility of the user?)
    # . For higher order (> 1) interpolation: convert to spline coefficients.

    # accept array-likes; existing tensors are passed through unchanged
    args = [torch.as_tensor(arg) for arg in args]

    def ismatrix(x):
        """Check that a tensor is a matrix (ndim == 2)."""
        return x.dim() == 2

    # Pre-pass: check dimensionality
    dim = None
    last_affine = False
    at_least_one_field = False
    for arg in args:
        if ismatrix(arg):
            last_affine = True
            # (ndim[+1], ndim+1) matrix: the last column is the translation
            dim1 = arg.shape[1] - 1
        else:
            last_affine = False
            at_least_one_field = True
            # (batch, *spatial, ndim) field
            dim1 = arg.dim() - 2
        if dim is None:
            dim = dim1
        elif dim != dim1:
            raise ValueError("All deformations should have the same "
                             "dimensionality (2D/3D).")
    if at_least_one_field and last_affine:
        raise ValueError("The last deformation cannot be an affine matrix. "
                         "Use affine_field to transform it first.")

    # First pass: compose all sequential affine matrices
    args1 = []
    last_affine = None
    for arg in args:
        if ismatrix(arg):
            if last_affine is None:
                last_affine = _make_square(arg)
            else:
                last_affine = last_affine.matmul(_make_square(arg))
        else:
            if last_affine is not None:
                args1.append(last_affine)
                last_affine = None
            args1.append(arg)

    if not at_least_one_field:
        return last_affine

    # Second pass: apply each (composed) matrix to the field that follows
    # it, i.e. A(phi(x)) = phi(x) @ A[:D, :D]^T + A[:D, D]
    # (args1 always ends with a field, so every matrix gets consumed)
    args2 = []
    last_affine = None
    for arg in args1:
        if ismatrix(arg):
            last_affine = arg
            continue
        if last_affine is not None:
            linear = last_affine[:dim, :dim]
            translation = last_affine[:dim, dim]
            arg = (arg.matmul(linear.transpose(0, 1))
                   + translation.reshape((1,) * (dim + 1) + (dim,)))
            # the matrix is consumed: it must not be applied to later fields
            last_affine = None
        args2.append(arg)

    # Third pass: compose all flow fields, from right to left:
    #   (arg o field)(x) = field(x) + [arg - id](field(x))
    field = args2[-1]
    for arg in reversed(args2[:-1]):
        disp = arg - identity_grid(arg.shape[1:-1], arg.dtype, arg.device)
        disp = utils.last2channel(disp)
        field = field + utils.channel2last(
            grid_pull(disp, field, interpolation, bound))

    # /!\ (TODO) The very first field (the first one being interpolated)
    # potentially contains a multiplication with an affine matrix (i.e.,
    # it might not be expressed in voxels). This affine transformation should
    # be removed prior to subtracting the identity, and added back at the end.
    # A candidate is the matrix M that minimizes, in the least-squares sense,
    # (F*M - I), where F is NbVox*D (augmented with a column of ones) and
    # contains the deformation field, and I is NbVox*D and contains the
    # identity field (expressed in voxels). It has the closed form
    # (F'*F)\(F'*I), but needs F'*F to be invertible and well-conditioned.
    # For better stability, M could be encoded in gl(D), the Lie algebra of
    # invertible matrices, and optimised with Gauss-Newton.

    return field
Exemplo n.º 10
0
    def forward(self,
                source,
                target,
                source_seg=None,
                target_seg=None,
                *,
                _loss=None,
                _metric=None):
        """Register a source image to a target image (affine + velocity).

        Parameters
        ----------
        source : tensor (batch, channel, *spatial)
            Source/moving image
        target : tensor (batch, channel, *spatial)
            Target/fixed image
        source_seg : tensor (batch, classes, *spatial_seg), optional
            Source segmentation, deformed with the same transformation.
            If its spatial shape differs from `source`, the deformation grid
            is resized to match it.
        target_seg : tensor (batch, classes, *spatial_seg), optional
            Target segmentation (only used by losses/metrics).

        _loss : dict, optional
            If provided, all registered losses are computed and appended.
        _metric : dict, optional
            If provided, all registered metrics are computed and appended.

        Returns
        -------
        deformed_source : tensor (batch, channel, *spatial)
            Deformed source image
        deformed_source_seg : tensor (batch, classes, *spatial_seg)
            Deformed source segmentation (only returned if `source_seg`
            is provided).
        velocity : tensor (batch, *spatial, channels)
            UNet output with its channel dimension moved last
            (presumably channels == len(spatial)).
        affine_prm : tensor (batch, nb_prm)
            Affine Lie parameters

        """
        # sanity checks
        check.dim(self.dim, source, target, source_seg, target_seg)
        check.shape(target, source, dims=[0], broadcast_ok=True)
        check.shape(target, source, dims=range(2, self.dim + 2))
        check.shape(target_seg, source_seg, dims=[0], broadcast_ok=True)
        check.shape(target_seg, source_seg, dims=range(2, self.dim + 2))

        pair = torch.cat((source, target), dim=1)

        # affine parameters: flatten the CNN output to (batch, nb_prm)
        affine_prm = self.cnn(pair)
        affine_prm = affine_prm.reshape(affine_prm.shape[:2])

        # velocity field, channels last
        velocity = channel2last(self.unet(pair))

        # deformation grid and warped image
        grid = self.exp(velocity, affine_prm)
        deformed_source = self.pull(source, grid)

        # warp the segmentation, resampling the grid if it lives on a
        # different lattice than the image
        deformed_source_seg = None
        if source_seg is not None:
            seg_grid = grid
            if source_seg.shape[2:] != source.shape[2:]:
                seg_grid = spatial.resize_grid(grid, shape=source_seg.shape[2:])
            deformed_source_seg = self.pull(source_seg, seg_grid)

        # compute loss and metrics
        self.compute(_loss,
                     _metric,
                     image=[deformed_source, target],
                     velocity=[velocity],
                     segmentation=[deformed_source_seg, target_seg],
                     affine=[affine_prm])

        if deformed_source_seg is not None:
            return deformed_source, deformed_source_seg, velocity, affine_prm
        return deformed_source, velocity, affine_prm