Ejemplo n.º 1
0
def _orient_is_default(value):
    """Return True if `value` asks for the default ('like') behaviour.

    `None`, the strings `''` and `'like'`, and any empty sequence count as
    "default". Objects without a length (e.g. a single float) never do.
    """
    if value is None:
        return True
    if isinstance(value, str):
        return value in ('', 'like')
    try:
        return len(value) == 0
    except TypeError:
        # scalars have no length: they are explicit user values
        return False


def orient(inp, affine=None, layout=None, voxel_size=None, center=None,
           like=None, output=None, output_transform=None):
    """Overwrite the orientation matrix

    Parameters
    ----------
    inp : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    affine : {'self', 'like', 'standard'} or tensor_like, default=None
        Target affine matrix. `None`, `'like'` or an empty value use the
        affine of `like` (or of the input if `like` is not provided),
        `'self'` uses the affine of the input and `'standard'` uses the
        identity. Explicit values must hold `dim*(dim+1)` or `(dim+1)**2`
        elements.
    layout : {'self', 'like', 'standard'} or layout-like, default=None
        Target orientation. `None` or `'like'` use the layout of `like`,
        `'self'` the layout of the input and `'standard'` means 'RAS'.
    voxel_size : {'self', 'like', 'standard'} or [sequence of] float, default=None
        Target voxel size. `None` or `'like'` use the voxel size of `like`,
        `'self'` the voxel size of the input and `'standard'` means 1.
    center : {'self', 'like', 'standard'} or [sequence of] float, default=None
        World coordinate of the center of the field of view.
        `None` or `'like'` use the center of `like`, `'self'` the center
        of the input and `'standard'` means 0.
    like : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix. If not provided, the input is used.
    output : str, optional
        Output filename.
        If the input is not a path, nothing is written by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.{layout}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    output_transform : str, optional
        Filename of output transform.
        If the input is not a path, the transform is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}_to_{layout}.lta', where `dir` and `base`
        are the directory and base name of the input file.

    Returns
    -------
    output : str or (tuple, tensor)
        If the input is a path, the output path is returned.
        Else, the input shape and new orientation matrix are returned.

    Raises
    ------
    ValueError
        If the affine does not have a compatible number of elements, or
        if `output` is requested although the input is not a file path
        (there is then no volume data to write).

    """
    # `dirname` rather than `dir` so that the builtin is not shadowed
    dirname = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        dim = f.affine.shape[-1] - 1
        inp = (f.shape[:dim], f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{layout}{ext}'
        if output_transform is None:
            output_transform = '{dir}{sep}{base}_to_{layout}.lta'
        dirname, base, ext = py.fileparts(fname)

    like_is_file = isinstance(like, str) and like
    if like_is_file:
        f = io.volumes.map(like)
        dim = f.affine.shape[-1] - 1
        like = (f.shape[:dim], f.affine)

    shape, aff0 = inp
    dim = aff0.shape[-1] - 1
    if like:
        shape_like, aff_like = like
    else:
        # no reference volume: 'like' falls back to the input itself
        shape_like, aff_like = (shape, aff0)

    # --- voxel size ---
    if _orient_is_default(voxel_size):
        voxel_size = spatial.voxel_size(aff_like)
    elif isinstance(voxel_size, str) and voxel_size == 'self':
        voxel_size = spatial.voxel_size(aff0)
    elif isinstance(voxel_size, str) and voxel_size == 'standard':
        voxel_size = 1.
    voxel_size = utils.make_vector(voxel_size, dim)

    # --- layout ---
    if not layout or layout == 'like':
        layout = spatial.affine_to_layout(aff_like)
    elif layout == 'self':
        layout = spatial.affine_to_layout(aff0)
    elif layout == 'standard':
        layout = 'RAS'
    layout = spatial.volume_layout(layout)

    # --- center of the field of view ---
    if _orient_is_default(center):
        # voxel coordinate of the FOV center, mapped through the affine
        center = (torch.as_tensor(shape_like, dtype=torch.float) - 1) * 0.5
        center = spatial.affine_matvec(aff_like, center)
    elif isinstance(center, str) and center == 'self':
        center = (torch.as_tensor(shape, dtype=torch.float) - 1) * 0.5
        center = spatial.affine_matvec(aff0, center)
    elif isinstance(center, str) and center == 'standard':
        center = 0.
    center = utils.make_vector(center, dim)

    # --- base affine ---
    if _orient_is_default(affine):
        affine = aff_like
    elif isinstance(affine, str) and affine == 'self':
        affine = aff0
    elif isinstance(affine, str) and affine == 'standard':
        affine = torch.eye(dim+1, dim+1)
    affine = torch.as_tensor(affine, dtype=torch.float)
    if affine.numel() == dim*(dim+1):
        affine = spatial.affine_make_rect(affine.reshape(dim, dim+1))
    elif affine.numel() == (dim+1)**2:
        affine = affine.reshape(dim+1, dim+1)
    else:
        raise ValueError(f'Input affine should have {dim*(dim+1)} or '
                         f'{(dim+1)**2} element but got {affine.numel()}.')

    affine = spatial.affine_modify(affine, shape, voxel_size=voxel_size,
                                   layout=layout, center=center)
    affine = affine.double()

    # The layout name is needed by both filename templates, so it must be
    # computed even when only the transform is written.
    layout_name = None
    if output or output_transform:
        layout_name = spatial.volume_layout_to_name(layout)

    if output:
        if not is_file:
            # only (shape, affine) was provided: there is no data to save
            raise ValueError('Cannot write a reoriented volume when the '
                             'input is not a path to a volume file.')
        dat = io.volumes.load(fname, numpy=True)
        output = output.format(dir=dirname or '.', base=base, ext=ext,
                               sep=os.path.sep, layout=layout_name)
        io.volumes.save(dat, output, like=fname, affine=affine)

    if output_transform:
        transform = spatial.affine_rmdiv(affine, aff0)
        output_transform = output_transform.format(
            dir=dirname or '.', base=base, sep=os.path.sep,
            layout=layout_name)
        io.transforms.savef(transform.cpu(), output_transform, type=2)

    if is_file:
        return output
    else:
        return shape, affine
Ejemplo n.º 2
0
    def forward(self, grid, **overload):
        """Fit an affine matrix to a displacement grid by least squares.

        Parameters
        ----------
        grid : (N, *spatial, dim)
            Displacement grid
        overload : dict
            Per-call overrides. Only the key `shift` is read here; it
            defaults to `self.shift`.

        Returns
        -------
        aff : (N, dim+1, dim+1)
            Affine matrix that is closest to grid in the least square sense

        """
        use_shift = overload.get('shift', self.shift)
        grid = torch.as_tensor(grid)
        backend = dict(dtype=grid.dtype, device=grid.device)
        ndim = grid.shape[-1]
        spatial_shape = grid.shape[1:-1]

        if use_shift:
            # M = [I | -shape/2]
            linear = torch.eye(ndim, **backend)
            offset = -torch.as_tensor(spatial_shape, **backend)[:, None] / 2
            affine_shift = torch.cat((linear, offset), dim=1)
            affine_shift = spatial.as_euclidean(affine_shift)

        # Model:
        #   phi(x) = M\A*M*x            (phi: transformation, M: shift,
        #                                A: affine)
        #   d(x)   = phi(x) - x = (M\A*M - I)*x := B*x   (d: displacement)
        # Stacking all voxels in matrices D (displacements) and G
        # (homogeneous coordinates) gives D = G*B', whose least-squares
        # solution is
        #   B' = inv(G'*G) * (G'*D)
        # and the affine is recovered as
        #   A = M*(B + I)/M

        def flatten(field):
            # collapse all spatial dimensions: (B, *spatial, dim) -> (B, vox, dim)
            return field.reshape([field.shape[0], -1, ndim])

        def voxel_count(flat):
            # 1'*1 = number of voxels, as a (1, 1, 1) tensor
            return torch.as_tensor([[[flat.shape[1]]]], **backend)

        def inverse_normal_matrix(coords):
            # inv(G'*G), with the homogeneous row/column assembled by hand:
            #       [[g'*g,   g'*1],
            #        [1'*g,   1'*1]]
            g = flatten(coords)
            col_sums = g.sum(dim=1, keepdim=True)
            left = torch.matmul(g.transpose(-1, -2), g)
            left = torch.cat((left, col_sums), dim=1)
            right = torch.cat((col_sums.transpose(-1, -2), voxel_count(g)),
                              dim=1)
            return torch.cat((left, right), dim=2).inverse()

        def cross_moments(coords, disp):
            # G'*D, assembled block-wise:
            #       [[g'*d,   g'*1],
            #        [1'*d,   1'*1]]
            g = flatten(coords)
            d = flatten(disp)
            left = torch.matmul(g.transpose(-1, -2), d)
            left = torch.cat((left, d.sum(dim=1, keepdim=True)), dim=1)
            right = g.sum(dim=1, keepdim=True).transpose(-1, -2)
            right = torch.cat((right, voxel_count(g)), dim=1)
            # the coordinate block has batch size 1: match the batch of d
            right = right.expand([d.shape[0], right.shape[1],
                                  right.shape[2]])
            return torch.cat((left, right), dim=2)

        def padded_eye(n):
            # identity of size n, padded with one row and column of zeros
            out = torch.zeros([n + 1, n + 1], **backend)
            out[:n, :n] = torch.eye(n, **backend)
            return out

        identity = spatial.identity_grid(spatial_shape, **backend)[None, ...]
        solution = torch.matmul(inverse_normal_matrix(identity),
                                cross_moments(identity, grid))
        # B' -> B, then B + I; the last (homogeneous) row is dropped
        affine = solution.transpose(-1, -2) + padded_eye(ndim)
        affine = affine[..., :-1, :]
        if use_shift:
            # A = M*(B + I)/M
            affine = spatial.as_euclidean(affine)
            affine = spatial.affine_matmul(affine_shift, affine)
            affine = spatial.as_euclidean(affine)
            affine = spatial.affine_rmdiv(affine, affine_shift)
        return spatial.affine_make_square(affine)
Ejemplo n.º 3
0
def get_oriented_slice(image, dim=-1, index=None, affine=None,
                       space=None, bbox=None, interpolation=1,
                       transpose_sagittal=False, return_index=False,
                       return_mat=False):
    """Resample one 2D slice of a volume in a visualisation space.

    Parameters
    ----------
    image : (..., *shape3)
        Input volume; leading dimensions are treated as channels.
    dim : {-1, -2, -3} or str, default=-1
        Axis of the visualisation space along which to slice:
        -1 = axial / -2 = coronal / -3 = sagittal.
        A string is reduced to its first lowercase letter
        ('a', 'c' or 's').
    index : tensor_like, optional
        Cursor position. When provided, it is mapped through the inverse
        of `space` (NOTE(review): so it looks like a coordinate in the
        visualisation world rather than a voxel index -- confirm with
        callers). Default: middle of the bounding box.
    affine : (4, 4) tensor, optional
        Orientation matrix of the image
    space : (4, 4) tensor, optional
        Orientation matrix of the visualisation space.
    bbox : (2, D) tensor_like, optional
        Bounding box passed on to the default-space helper.
    interpolation : {0, 1}, default=1
        Interpolation order.
    transpose_sagittal : bool, default=False
        Use the alternative in-plane axis arrangement for sagittal slices.
    return_index : bool, default=False
        Also return the in-plane cursor position.
    return_mat : bool, default=False
        Also return the slice-to-world matrix.

    Returns
    -------
    slice : (..., *shape2) tensor
        Slice in the visualisation space.
    index : tuple, if `return_index`
        In-plane position of the cursor.
    space : tensor, if `return_mat`
        Orientation matrix of the extracted slice.

    Raises
    ------
    ValueError
        If `dim` does not designate one of the three views.

    """
    # named views: only the first letter matters
    if isinstance(dim, str):
        initial = dim.lower()[0]
        dim = {'a': -1, 'c': -2, 's': -3}.get(initial, initial)

    backend = utils.backend(image)

    # default native/visualisation spaces (mn/mx are in voxels)
    affine, shape = _get_default_native(affine, image.shape[-3:])
    space, mn, mx = _get_default_space(affine, [shape], space, bbox)
    affine, shape = (affine[0], shape[0])

    # cursor position (in voxels of the visualisation space)
    if index is not None:
        cursor = torch.as_tensor(index)
        cursor = spatial.affine_matvec(spatial.affine_inv(space), cursor)
    else:
        cursor = (mx + mn) / 2

    # slice-to-volume matrix, slice shape and in-plane cursor
    fov = tuple(((mx-mn) + 1).round().int().tolist())
    if dim == -1:
        # axial
        shift = utils.as_tensor([[1, 0, 0, - mn[0] + 1],
                                 [0, 1, 0, - mn[1] + 1],
                                 [0, 0, 1, - cursor[2]],
                                 [0, 0, 0, 1]], **backend)
        slice_shape = fov[:-1]
        cursor2d = (cursor[0] - mn[0] + 1, cursor[1] - mn[1] + 1)
    elif dim == -2:
        # coronal
        shift = utils.as_tensor([[1, 0, 0, - mn[0] + 1],
                                 [0, 0, 1, - mn[2] + 1],
                                 [0, 1, 0, - cursor[1]],
                                 [0, 0, 0, 1]], **backend)
        slice_shape = (fov[0], fov[2])
        cursor2d = (cursor[0] - mn[0] + 1, cursor[2] - mn[2] + 1)
    elif dim == -3 and not transpose_sagittal:
        # sagittal
        shift = utils.as_tensor([[0, 0, 1, - mn[2] + 1],
                                 [0, 1, 0, - mn[1] + 1],
                                 [1, 0, 0, - cursor[0]],
                                 [0, 0, 0, 1]], **backend)
        slice_shape = (fov[2], fov[1])
        cursor2d = (cursor[2] - mn[2] + 1, cursor[1] - mn[1] + 1)
    elif dim == -3:
        # sagittal, transposed
        shift = utils.as_tensor([[0, -1, 0,   mx[1] + 1],
                                 [0,  0, 1, - mn[2] + 1],
                                 [1,  0, 0, - cursor[0]],
                                 [0,  0, 0, 1]], **backend)
        slice_shape = (fov[1], fov[2])
        cursor2d = (mx[1] + 1 - cursor[1], cursor[2] - mn[2] + 1)
    else:
        raise ValueError(f'Unknown dimension {dim}')

    # sample the slice: build the slice-to-native-voxel map and pull
    space = spatial.affine_rmdiv(space, shift)
    to_native = spatial.affine_lmdiv(affine, space)
    to_native = to_native.to(**backend)
    grid = spatial.affine_grid(to_native, [*slice_shape, 1])
    *channels, nx, ny, nz = image.shape
    flat = image.reshape([1, -1, nx, ny, nz])
    pulled = spatial.grid_pull(flat, grid[None], interpolation=interpolation,
                               bound='dct2', extrapolate=False)
    pulled = pulled.reshape([*channels, *slice_shape])

    if return_index and return_mat:
        return pulled, cursor2d, space
    if return_index:
        return pulled, cursor2d
    if return_mat:
        return pulled, space
    return pulled