Example #1
File: volumes.py Project: balbasty/nitorch
def _get_default_native(affines, shapes):
    """Get default native space

    Parameters
    ----------
    affines : [sequence of] (4, 4) tensor_like or None
    shapes : [sequence of] (3,) tensor_like

    Returns
    -------
    affines : (N, 4, 4) tensor
    shapes : (N, 3) tensor

    """
    shapes = utils.as_tensor(shapes).reshape([-1, 3])
    shapes = shapes.unbind(dim=0)
    if affines is None:
        affines = [None]
    elif torch.is_tensor(affines):
        affines = affines.reshape([-1, 4, 4])
        affines = affines.unbind(dim=0)
    affines = py.make_list(affines)
    nb_spaces = max(len(shapes), len(affines))
    shapes = py.make_list(shapes, nb_spaces)
    affines = py.make_list(affines, nb_spaces)

    # default affines
    affines = [spatial.affine_default(shape) if affine is None else affine
               for shape, affine in zip(shapes, affines)]

    affines = utils.as_tensor(affines)
    shapes = utils.as_tensor(shapes)
    affines, shapes = utils.to_max_device(affines, shapes)
    return affines, shapes
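A minimal usage sketch (hypothetical shapes; assumes the nitorch imports `torch`, `utils`, `py` and `spatial` used by volumes.py are in scope): with no affines given, default orientation matrices are built from the shapes.

shapes = [[64, 64, 64], [128, 128, 96]]   # two volumes, no affines
affines, shapes = _get_default_native(None, shapes)
print(affines.shape, shapes.shape)  # torch.Size([2, 4, 4]) torch.Size([2, 3])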
Example #2
    def bracket(self, f0, closure):
        """Bracket the minimum

        Parameters
        ----------
        f0 : Initial value
        closure : callable(a) -> evaluate function at `a0 + a`

        Returns
        -------
        (a0, f0), (a1, f1), (a2, f2)
        """
        a0, a1 = 0, self.lr
        f1 = closure(a1)

        # sort such that f1 < f0
        if f1 > f0:
            a0, f0, a1, f1 = a1, f1, a0, f0

        a2 = a1 + self.gold * (a1 - a0)
        f2 = closure(a2)

        while f1 > f2:
            # fit quadratic polynomial
            a = utils.as_tensor([a0 - a1, 0., a2 - a1])
            quad = torch.stack([torch.ones_like(a), a, a.square()], -1)
            f = utils.as_tensor([f0, f1, f2]).unsqueeze(-1)
            quad = quad.pinverse().matmul(f).squeeze(-1)

            if quad[2] > 0:  # There is a minimum
                delta = -0.5 * quad[1] / quad[2].clamp_min(self.tiny)
                delta = delta.clamp_max_((1 + self.gold) * (a2 - a1))
                delta = delta.item()
                a = a1 + delta
            else:  # No minimum -> we must go farther than a2
                delta = self.gold * (a2 - a1)
                a = a2 + delta

            # check progress and update bracket
            # f2 < f1 < f0 so (assuming unicity) the minimum is in
            # (a1, a2) or (a2, inf)
            f = closure(a)
            if a1 < a < a2 or a2 < a < a1:
                if f < f2:  # implicitly: f < f2 < f1
                    # valid bracket (a1, a, a2): minimum is in (a1, a2)
                    (a0, f0), (a1, f1), (a2, f2) = (a1, f1), (a, f), (a2, f2)
                    break
                elif f1 < f:  # implicitly: f2 < f1 < f (and f1 < f0)
                    # valid bracket (a0, a1, a): minimum is in (a0, a)
                    (a0, f0), (a1, f1), (a2, f2) = (a0, f0), (a1, f1), (a, f)
                    break
            # parabolic step inconclusive: shift the triplet by one point
            (a0, f0), (a1, f1), (a2, f2) = (a1, f1), (a2, f2), (a, f)

        return (a0, f0), (a1, f1), (a2, f2)
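The parabolic step at the heart of this loop can be checked in isolation. A minimal sketch in plain torch (no nitorch; the sample values are made up): fit f(a) = c0 + c1*a + c2*a^2 through three points and take the vertex as the next trial step.

import torch

a = torch.tensor([-1.0, 0.0, 2.0])               # offsets around the middle point
f = torch.tensor([3.0, 1.0, 5.0]).unsqueeze(-1)  # objective values at those offsets
quad = torch.stack([torch.ones_like(a), a, a.square()], -1)
coef = quad.pinverse().matmul(f).squeeze(-1)     # [c0, c1, c2]
if coef[2] > 0:                                  # convex fit: a minimum exists
    vertex = (-0.5 * coef[1] / coef[2]).item()
    print(vertex)  # 0.25: next candidate offset from the middle point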
Example #3
    def search_in_bracket(self, bracket, closure):
        b0, b1 = (bracket[0][0], bracket[2][0])
        if b1 < b0:
            b0, b1 = b1, b0
        # sort by values
        (a0, f0), (a1, f1), (a2, f2) = self.sort_bracket(bracket)

        d = d0 = float('inf')
        for n_iter in range(self.max_iter):

            # stop if the bracketing interval is small enough around the best point
            if abs(a0 - 0.5 * (b0 + b1)) + 0.5 * (b1 - b0) <= 2 * self.tol:
                return a0, f0

            d1, d0 = d0, d  # remember the previous two step sizes

            # fit quadratic polynomial
            a = utils.as_tensor([0., a1 - a0, a2 - a0])
            quad = torch.stack([torch.ones_like(a), a, a.square()], -1)
            f = utils.as_tensor([f0, f1, f2]).unsqueeze(-1)
            quad = quad.pinverse().matmul(f).squeeze(-1)

            d = -0.5 * quad[1] / quad[2].clamp_min(self.tiny)
            d = d.item()
            a = a0 + d

            tiny = self.tiny * (1 + 2 * abs(a0))
            # reject the quadratic step if it is too large, falls outside
            # the interval, or the fit is concave
            if (abs(d) > abs(d1) / 2 or not (b0 + tiny < a < b1 - tiny)
                    or quad[-1] < 0):
                if a0 > 0.5 * (b0 + b1):
                    d = self.igold * (b0 - a0)
                else:
                    d = self.igold * (b1 - a0)
                a = a0 + d

            # check progress and update bracket
            f = closure(a)
            if f < f0:
                # f < f0 < f1 < f2
                b0, b1 = (b0, a0) if a < a0 else (a0, b1)
                (a0, f0), (a1, f1), (a2, f2) = (a, f), (a0, f0), (a1, f1)
            else:
                b0, b1 = (a, b1) if a < a0 else (b0, a)
                if f < f1:
                    # f0 < f < f1 < f2
                    (a0, f0), (a1, f1), (a2, f2) = (a0, f0), (a, f), (a1, f1)
                elif f < f2:
                    # f0 < f1 < f < f2
                    (a0, f0), (a1, f1), (a2, f2) = (a0, f0), (a1, f1), (a, f)

        return a0, f0
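When the parabolic step is rejected, the code falls back to a golden-section step into the larger half of the bracketing interval. A standalone restatement (assumption: `self.igold` is the usual Brent constant 2 - phi, about 0.381966):

igold = 0.3819660

def golden_step(a0, b0, b1):
    # step from the current best point a0 into the larger side of (b0, b1)
    return igold * ((b0 - a0) if a0 > 0.5 * (b0 + b1) else (b1 - a0))

print(golden_step(0.7, 0.0, 1.0))  # about -0.267: the left half is larger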
Example #4
    def forward(self, x, noise=None, return_resolution=False):

        if noise is not None:
            noise = noise.expand(x.shape)

        dim = x.dim() - 2
        backend = utils.backend(x)
        resolution_exp = utils.make_vector(self.resolution_exp, x.shape[1],
                                           **backend)
        resolution_scale = utils.make_vector(self.resolution_scale, x.shape[1],
                                             **backend)

        all_resolutions = []
        out = torch.empty_like(x)
        for b in range(len(x)):
            for c in range(x.shape[1]):
                # sample a random acquisition resolution for this channel
                resolution = self.resolution(resolution_exp[c],
                                             resolution_scale[c]).sample()
                resolution = resolution.clamp_min(1)
                # blur to the target resolution before downsampling
                fwhm = [resolution] * dim
                y = smooth(x[b, c], fwhm=fwhm, dim=dim, padding='same', bound='dct2')
                if noise is not None:
                    y += noise[b, c]
                # downsample to the low-resolution lattice...
                factor = [1/resolution] * dim
                y = y[None, None]  # resize expects batch and channel dimensions
                y = resize(y, factor=factor, anchor='f')
                # ...then resample back onto the original lattice
                factor = [resolution] * dim
                all_resolutions.append(factor)
                y = resize(y, factor=factor, shape=x.shape[2:], anchor='f')
                out[b, c] = y[0, 0]

        all_resolutions = utils.as_tensor(all_resolutions, **backend)
        return (out, all_resolutions) if return_resolution else out
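The downsample/upsample round trip can be emulated without nitorch. A rough analogue in plain torch (assumption: nitorch's `resize` with a scale factor behaves broadly like `torch.nn.functional.interpolate`; the blur step is omitted here):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 32, 32)
resolution = 2.0  # simulated acquisition resolution, in voxels
# downsample to the low-resolution lattice...
y = F.interpolate(x, scale_factor=1 / resolution, mode='bilinear',
                  align_corners=False)
# ...then resample back onto the original lattice
y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False)
print(y.shape)  # torch.Size([1, 1, 32, 32])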
Example #5
File: volumes.py Project: balbasty/nitorch
def _get_default_space(affines, shapes, space=None, bbox=None):
    """Get default visualisation space

    Parameters
    ----------
    affines : [sequence of] (4, 4) tensor_like
    shapes : [sequence of] (3,) tensor_like
    space : (4, 4) tensor_like, optional
    bbox : (2, 3) tensor_like, optional

    Returns
    -------
    space : (4, 4) tensor
    mn, mx : (3,) tensor
        Bounding box (in voxels of `space`)

    """

    affines, shapes = _get_default_native(affines, shapes)
    voxel_size = spatial.voxel_size(affines)
    voxel_size = voxel_size.min()

    if space is None:
        space = torch.eye(4)
        space[:-1, :-1] *= voxel_size
    voxel_size = spatial.voxel_size(space)

    if bbox is None:
        shapes = torch.as_tensor(shapes)
        mn, mx = spatial.compute_fov(space, affines, shapes)
    else:
        mn, mx = utils.as_tensor(bbox)
        mn /= voxel_size
        mx /= voxel_size

    return space, mn, mx
Example #6
 def rotation(self):
     # rotation matrix of the quaternion (r, i, j, k)
     i = self.i
     j = self.j
     k = self.k
     r = self.r
     matrix = [
         [1 - 2 * (j**2 + k**2), 2 * (i * j - k * r), 2 * (i * k + j * r)],
         [2 * (i * j + k * r), 1 - 2 * (i**2 + k**2), 2 * (j * k - i * r)],
         [2 * (i * k - j * r), 2 * (j * k + i * r), 1 - 2 * (i**2 + j**2)]
     ]
     matrix = utils.as_tensor(matrix)
     matrix = utils.movedim(matrix, [0, 1], [-2, -1])
     return matrix
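A quick sanity check of the formula in plain torch: for a unit quaternion (r, i, j, k), the resulting matrix should be a rotation, i.e. orthogonal.

import torch

q = torch.randn(4)
q = q / q.norm()          # unit quaternion
r, i, j, k = q.tolist()
R = torch.tensor([
    [1 - 2 * (j**2 + k**2), 2 * (i * j - k * r), 2 * (i * k + j * r)],
    [2 * (i * j + k * r), 1 - 2 * (i**2 + k**2), 2 * (j * k - i * r)],
    [2 * (i * k - j * r), 2 * (j * k + i * r), 1 - 2 * (i**2 + j**2)],
])
print(torch.allclose(R @ R.T, torch.eye(3), atol=1e-6))  # True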
Example #7
File: gui.py Project: balbasty/nitorch
 def space(self, value):
     self._space = value
     if torch.is_tensor(value):
         if value.shape != (4, 4):
             raise ValueError('Expected 4x4 matrix')
         self._space_matrix = value
     elif isinstance(value, int):
         affines = [image.affine for image in self.images]
         self._space_matrix = affines[value]
     else:
         if value is not None:
             raise ValueError('Expected a 4x4 matrix or an int or None')
         affines = [image.affine for image in self.images]
         voxel_size = spatial.voxel_size(utils.as_tensor(affines))
         voxel_size = voxel_size.min()
         self._space_matrix = torch.eye(4)
         self._space_matrix[:-1, :-1] *= voxel_size
Example #8
def crop(inp,
         size=None,
         center=None,
         space='vox',
         like=None,
         bbox=False,
         output=None,
         transform=None):
    """Crop a ND volume, while preserving the orientation matrices.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, optional
        Size of the patch to extract.
        Its unit and axes are defined by `space`.
    center : [sequence of] int, optional
        Coordinate of the center of the patch.
        Its unit and axes are defined by `space`.
        By default, the center of the FOV is used.
    space : [sequence of] {'vox', 'ras'}, default='vox'
        The space in which the `size` and `center` parameters are expressed.
    bbox : bool or float, default=False
        Crop at the bounding box of `inp > threshold`.
        If `bbox` is a float, it is the threshold to use.
        If `bbox` is `True`, the threshold is 0.
    like : str or (tensor, tensor), optional
        Reference patch.
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the cropped data is not written
        to disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.crop{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    transform : [sequence of] str, optional
        Input or output filename(s) of the corresponding transforms.
        Not written by default.
        If a transform is provided and all other parameters
        (i.e., `size` and `like`) are None, the transform is considered
        as an input transform to apply.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the cropped data and orientation matrix are returned.

    """
    dir = ''
    base = ''
    ext = ''
    fname = None
    transform_in = False
    use_bbox = bool(bbox or isinstance(bbox, float))

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.data(numpy=True) if use_bbox else f, f.affine)
        if output is None:
            output = '{dir}{sep}{base}.crop{ext}'
        dir, base, ext = py.fileparts(fname)
    dat, aff0 = inp
    dim = aff0.shape[-1] - 1
    shape0 = dat.shape[:dim]
    layout0 = spatial.affine_to_layout(aff0)

    # save input space in case we reorient later
    aff00 = aff0
    shape00 = shape0

    if bool(size) + bool(like) + use_bbox > 1:
        raise ValueError('Can only use one of `size`, `like` and `bbox`.')

    # --- Open reference and compute size/center ---
    if like:
        like_is_file = isinstance(like, str)
        if like_is_file:
            f = io.volumes.map(like)
            like = (f.shape, f.affine)
        like_shape, like_aff = like
        like_layout = spatial.affine_to_layout(like_aff)
        if (layout0 != like_layout).any():
            aff0, dat = spatial.affine_reorient(aff0, dat, like_layout)
            shape0 = dat.shape[:dim]
        if torch.is_tensor(like_shape):
            like_shape = like_shape.shape
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)
        space = 'vox'

    elif use_bbox:
        if bbox is True:
            bbox = 0.
        mask = torch.as_tensor(dat > bbox)
        # collapse non-spatial dimensions
        while mask.dim() > dim:
            mask = mask.any(dim=-1)
        mins = []
        maxs = []
        for d in range(dim):
            n = mask.shape[d]
            idx = utils.movedim(mask, d, 0).reshape([n, -1]).any(-1)
            idx = idx.nonzero(as_tuple=False)
            mins.append(idx.min())
            maxs.append(idx.max())
        mins = utils.as_tensor(mins)
        maxs = utils.as_tensor(maxs)
        size = maxs + 1 - mins
        center = (maxs + 1 + mins).float() / 2
        space = 'vox'
        del mask

    # --- Open transformation file and compute size/center ---
    elif not size:
        if not transform:
            raise ValueError('At least one of size/like/transform must '
                             'be provided')
        transform_in = True
        t = io.transforms.map(transform)
        if not isinstance(t, io.transforms.LinearTransformArray):
            raise TypeError('Expected an LTA file')
        like_aff, like_shape = t.destination_space()
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)

    # --- use center of the FOV ---
    if not torch.is_tensor(center) and not center:
        center = torch.as_tensor(shape0[:dim], dtype=torch.float)
        center = center.sub_(1).mul_(0.5)

    # --- convert size/center to voxels ---
    size = utils.make_vector(size, dim, dtype=torch.long)
    center = utils.make_vector(center, dim, dtype=torch.float)
    space_size, space_center = py.make_list(space, 2)
    if space_center.lower() == 'ras':
        center = spatial.affine_matvec(spatial.affine_inv(aff0), center)
    if space_size.lower() == 'ras':
        perm = spatial.affine_to_layout(aff0)[:, 0]
        size = size[perm.long()]
        size = size / spatial.voxel_size(aff0)

    # --- compute first/last ---
    center = center.float()
    size = (size.ceil() if size.dtype.is_floating_point else size).long()
    first = center - size.float().sub_(1).mul_(0.5)
    first = first.round().long()
    last = (first + size).tolist()
    first = [max(f, 0) for f in first.tolist()]
    last = [min(l, s) for l, s in zip(last, shape0[:dim])]
    verb = 'Cropping patch ['
    verb += ', '.join([f'{f}:{l}' for f, l in zip(first, last)])
    verb += f'] from volume with shape {shape0[:dim]}'
    print(verb)
    slicer = tuple(slice(f, l) for f, l in zip(first, last))

    # --- do crop ---
    dat = dat[slicer]
    if callable(getattr(dat, 'data', None)):
        # still a memory-mapped volume: load the cropped view into memory
        dat = dat.data(numpy=True)
    aff, _ = spatial.affine_sub(aff0, shape0[:dim], slicer)
    shape = dat.shape[:dim]

    if output:
        if is_file:
            output = output.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff)

    if transform and not transform_in:
        if is_file:
            transform = transform.format(dir=dir or '.',
                                         base=base,
                                         ext=ext,
                                         sep=os.path.sep)
        else:
            transform = transform.format(sep=os.path.sep)
        trf = io.transforms.LinearTransformArray(transform, 'w')
        trf.set_source_space(aff00, shape00)
        trf.set_destination_space(aff, shape)
        trf.set_metadata({
            'src': {
                'filename': fname
            },
            'dst': {
                'filename': output
            },
            'type': 1
        })  # RAS_TO_RAS
        trf.set_fdata(torch.eye(4))
        trf.save()

    if is_file:
        return output
    else:
        return dat, aff
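The bounding-box step of `crop` is easy to isolate. A standalone sketch in plain torch (toy data): find the tight voxel bounding box of a thresholded mask.

import torch

dat = torch.zeros(8, 8)
dat[2:5, 3:7] = 1.0
mask = dat > 0.5
first, last = [], []
for d in range(mask.dim()):
    # collapse all dimensions except `d`, then find the occupied range
    idx = mask.movedim(d, 0).reshape(mask.shape[d], -1).any(-1)
    idx = idx.nonzero(as_tuple=False)
    first.append(idx.min().item())
    last.append(idx.max().item() + 1)
print(first, last)  # [2, 3] [5, 7]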
Example #9
def is_inside(points, vertices, faces=None):
    """Test if a point is inside a polygon/surface.

    The polygon or surface *must* be closed.

    Parameters
    ----------
    points : (..., dim) tensor
        Coordinates of points to test
    vertices : (nv, dim) tensor
        Vertex coordinates
    faces : (nf, dim) tensor[int], optional
        Faces, encoded by the indices of their vertices.
        By default, vertices are assumed to be ordered and to define
        a closed curve.

    Returns
    -------
    check : (...) tensor[bool]

    """
    # This function uses a ray-tracing technique:
    #
    #   A half-line is shot from each point. If it crosses an odd
    #   number of faces, the point is inside the shape; if it crosses
    #   an even number of faces, it is not.
    #
    #   In practice, we loop over faces (as we expect far fewer faces
    #   than points) and compute the intersections between all rays and
    #   each face in a batched fashion. We only want to shoot rays in
    #   one direction, so we only keep intersections that have a
    #   positive coordinate along the ray.

    points = torch.as_tensor(points)
    vertices = torch.as_tensor(vertices)
    if faces is None:
        faces = [(i, i + 1) for i in range(len(vertices) - 1)]
        faces += [(len(vertices) - 1, 0)]
        faces = utils.as_tensor(faces, dtype=torch.long)

    points, vertices = utils.to_max_dtype(points, vertices)
    points, vertices, faces = utils.to_max_device(points, vertices, faces)
    backend = utils.backend(points)
    batch = points.shape[:-1]
    dim = points.shape[-1]
    eps = constants.eps(points.dtype)
    cross = points.new_zeros(batch, dtype=torch.long)

    ray = torch.randn(dim, **backend)

    for face in faces:
        face = vertices[face]

        # compute a normal vector to the face
        origin = face[0]
        if dim == 3:
            u = face[1] - face[0]
            v = face[2] - face[0]
            norm = torch.stack([u[1] * v[2] - u[2] * v[1],   # cross product u x v
                                u[2] * v[0] - u[0] * v[2],
                                u[0] * v[1] - u[1] * v[0]])
        else:
            assert dim == 2
            u = face[1] - face[0]
            norm = torch.stack([-u[1], u[0]])  # 90-degree rotation of the edge

        # check co-linearity between face and ray
        cosine = linalg.dot(ray, norm).abs() / (ray.norm() * norm.norm())
        colinear = cosine < eps
        if colinear:
            continue

        # compute intersection between ray and plane
        #   plane: <norm, x - origin> = 0
        #   line: x = p + t*u
        #   => <norm, p + t*u - origin> = 0
        intersection = linalg.dot(norm, points - origin)
        intersection /= linalg.dot(norm, ray)
        halfmask = intersection >= 0  # we only want to shoot in one direction
        intersection = intersection[halfmask]
        halfpoints = points[halfmask]
        intersection = intersection[..., None] * (-ray)
        intersection += halfpoints

        # check if the intersection is inside the face
        #   first, we project it onto a frame of dimension `dim-1`
        #   defined by (origin, (u, v))
        intersection -= origin
        if dim == 3:
            interu = linalg.dot(intersection, u)
            interv = linalg.dot(intersection, v)
            intersection = (interu >= 0) & (interv > 0) & (interu + interv < 1)
        else:
            intersection = linalg.dot(intersection, u)
            intersection /= u.norm().square_()
            intersection = (intersection >= 0) & (intersection < 1)

        cross[halfmask] += intersection

    # a point is inside if its ray crossed an odd number of faces
    cross = cross.bitwise_and_(1).bool()
    return cross
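A hypothetical usage sketch (assumes the nitorch imports used above are in scope): with `faces=None`, ordered 2D vertices are treated as a closed polygon.

import torch

square = torch.tensor([[0., 0.], [1., 0.], [1., 1.], [0., 1.]])
points = torch.tensor([[0.5, 0.5], [2.0, 0.5]])
print(is_inside(points, square))  # expected: tensor([ True, False])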
Example #10
File: gui.py Project: balbasty/nitorch
 def _max_fov(self):
     affines = [image.affine for image in self.images]
     shapes = [image.shape for image in self.images]
     affines = utils.as_tensor(affines)
     shapes = utils.as_tensor(shapes)
     return spatial.compute_fov(self._space_matrix, affines, shapes)
Example #11
File: gui.py Project: balbasty/nitorch
 def _index_from_cursor(self, x, y, image, n_ax):
     p = utils.as_tensor([x, y, 0])
     mat = image._mats[n_ax]
     self.index = spatial.affine_matvec(mat, p)
Example #12
    def forward(self, batch=1, **overload):
        """

        Parameters
        ----------
        batch : int, default=1
            Batch size
        overload : dict
            All parameters defined at build time can be overridden at call time

        Returns
        -------
        affine : (batch, dim[+1], dim+1) tensor
            Affine matrix

        """
        dim = overload.get('dim', self.dim)
        translation = make_list(overload.get('translation', self.translation))
        rotation = make_list(overload.get('rotation', self.rotation))
        zoom = make_list(overload.get('zoom', self.zoom))
        shear = make_list(overload.get('shear', self.shear))
        dtype = overload.get('dtype', self.dtype)
        device = overload.get('device', self.device)

        # compute dimension
        dim = dim or max(len(translation), len(rotation), len(zoom),
                         len(shear))
        translation = make_list(translation, dim)
        rotation = make_list(rotation, dim * (dim - 1) // 2)
        zoom = make_list(zoom, dim)
        shear = make_list(shear, dim * (dim - 1) // 2)

        # sample values if needed:
        #   callable -> sample from it; True -> sample from the default
        #   sampler; None/False -> neutral value; otherwise -> fixed value
        translation = [x([batch]) if callable(x)
                       else self.default_translation([batch]) if x is True
                       else 0. if x is None or x is False
                       else x for x in translation]
        rotation = [x([batch]) if callable(x)
                    else self.default_rotation([batch]) if x is True
                    else 0. if x is None or x is False
                    else x for x in rotation]
        zoom = [x([batch]) if callable(x)
                else self.default_zoom([batch]) if x is True
                else 1. if x is None or x is False
                else x for x in zoom]
        shear = [x([batch]) if callable(x)
                 else self.default_shear([batch]) if x is True
                 else 0. if x is None or x is False
                 else x for x in shear]
        rotation = [x * math.pi / 180 for x in rotation]  # degree -> radian
        prm = [*translation, *rotation, *zoom, *shear]
        prm = [p.expand(batch) if torch.is_tensor(p) and p.shape[0] != batch
               else make_list(p, batch) if not torch.is_tensor(p)
               else p for p in prm]

        prm = utils.as_tensor(prm)
        prm = prm.transpose(0, 1)

        # generate affine matrix
        mat = affine_matrix_classic(prm, dim=dim)
        mat = mat.to(dtype=dtype, device=device)

        return mat
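The list comprehensions above all follow one convention per parameter. A standalone restatement in plain Python/torch (`resolve` is a hypothetical helper, not part of nitorch):

import torch

def resolve(x, default_sampler, neutral, batch=1):
    if callable(x):
        return x([batch])                # user-provided sampler
    if x is True:
        return default_sampler([batch])  # default sampler
    if x is None or x is False:
        return neutral                   # identity parameter
    return x                             # fixed value

sampler = lambda shape: torch.randn(shape)
print(resolve(True, sampler, 0.))   # sampled from the default
print(resolve(None, sampler, 1.))   # neutral (e.g. zoom = 1)
print(resolve(5.0, sampler, 0.))    # fixed value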
Example #13
def graph_atlas(velocities, nodes=None, latent=None):
    """Infer template-to-image SVFs from pairwise SVFs..

    References
    ----------
    ..[1] "Robust joint registration of multiple stains and MRI for
           multimodal 3D histology reconstruction: Application to the
           Allen human brain atlas"
          Adrià Casamitjana, Marco Lorenzi, Sebastiano Ferraris,
          Loic Peter, Marc Modat, Allison Stevens, Bruce Fischl,
          Tom Vercauteren, Juan Eugenio Iglesias
          https://arxiv.org/abs/2104.14873

    Parameters
    ----------
    velocities : (n, *spatial, dim) tensor
    nodes : n-sequence of (int, int), default=[(1, 2), (2, 3), ..., (N, 1)]
    latent : k-sequence of (int, int), default=[(0, 1), ..., (0, N)]

    Returns
    -------
    latent : (k, *spatial, dim) tensor

    """
    def default_nodes(n):
        # pairwise ring: (1, 2), (2, 3), ..., (n-1, n), (n, 1)
        nodes = [(i, i + 1) for i in range(1, n)]
        nodes += [(n, 1)]
        return nodes

    def default_latent(n):
        # star from the template node 0: (0, 1), ..., (0, n)
        return [(0, i) for i in range(1, n + 1)]

    velocities = utils.as_tensor(velocities)
    backend = utils.backend(velocities)

    # defaults
    observed_nodes = list(nodes or default_nodes(len(velocities)))
    latent_nodes = list(latent or default_latent(len(velocities)))

    # compute W
    connections = _build_matrix(observed_nodes, latent_nodes)
    coordinates = [(latent_nodes.index(j), i)
                   for i, c in enumerate(connections) for j in c]
    values = [v for c in connections for v in c.values()]
    values = torch.as_tensor(values, dtype=torch.float)
    coordinates = torch.as_tensor(coordinates, dtype=torch.long).T
    w = torch.sparse_coo_tensor(
        coordinates, values,
        [len(latent_nodes), len(observed_nodes)], **backend)

    # re-parameterise last velocity to enforce `sum of velocities = 0`
    w = w.to_dense()
    wlast = w[-1:, :]
    w = w[:-1, :]
    w -= wlast

    velocities = utils.movedim(velocities, 0, -1)[..., None]
    latent = w.transpose(-1, -2).pinverse().matmul(velocities)[..., 0]
    latent = utils.movedim(latent, -1, 0)
    latent = torch.cat([latent, -latent.sum(0, keepdim=True)], dim=0)

    return latent
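The linear-algebra core of `graph_atlas` can be demonstrated on scalars. A toy sketch in plain torch (the exact weight layout produced by `_build_matrix` may differ): recover per-node values from pairwise differences under a sum-to-zero constraint.

import torch

z_true = torch.tensor([0.5, -1.0, 0.5])        # latent values, sum to zero
edges = [(0, 1), (1, 2), (2, 0)]               # observed pairwise differences
v = torch.tensor([z_true[j] - z_true[i] for i, j in edges])

# design matrix: one row per edge, -1 on the source node, +1 on the target
w = torch.zeros(len(edges), len(z_true))
for row, (i, j) in enumerate(edges):
    w[row, i], w[row, j] = -1., 1.

# re-parameterise the last node as minus the sum of the others
w = w[:, :-1] - w[:, -1:]
z = w.pinverse().matmul(v)
z = torch.cat([z, -z.sum().reshape(1)])
print(z)  # tensor([ 0.5000, -1.0000,  0.5000])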
Example #14
File: volumes.py Project: balbasty/nitorch
def get_oriented_slice(image, dim=-1, index=None, affine=None,
                       space=None, bbox=None, interpolation=1,
                       transpose_sagittal=False, return_index=False,
                       return_mat=False):
    """Sample a slice in a RAS system

    Parameters
    ----------
    image : (..., *shape3)
    dim : int, default=-1
        Index of spatial dimension to sample in the visualization space
        If RAS: -1 = axial / -2 = coronal / -3 = sagittal
    index : int, default=shape//2
        Coordinate (in voxel) of the slice to extract
    affine : (4, 4) tensor, optional
        Orientation matrix of the image
    space : (4, 4) tensor, optional
        Orientation matrix of the visualisation space.
        Default: RAS with minimum voxel size of all inputs.
    bbox : (2, D) tensor_like, optional
        Bounding box: min and max coordinates (in millimetric
        visualisation space). Default: bounding box of the input image.
    interpolation : {0, 1}, default=1
        Interpolation order.
    transpose_sagittal : bool, default=False
        Transpose the sagittal slice (and flip one of its axes).
    return_index : bool, default=False
        Also return the cursor coordinates within the slice.
    return_mat : bool, default=False
        Also return the orientation matrix of the slice.

    Returns
    -------
    slice : (..., *shape2) tensor
        Slice in the visualisation space.
    index : (2,) tuple, optional
        Cursor coordinates within the slice (if `return_index`).
    mat : (4, 4) tensor, optional
        Orientation matrix of the slice (if `return_mat`).
    """
    # preproc dim
    if isinstance(dim, str):
        dim = dim.lower()[0]
        if dim == 'a':
            dim = -1
        if dim == 'c':
            dim = -2
        if dim == 's':
            dim = -3

    backend = utils.backend(image)

    # compute default space (mn/mx are in voxels)
    affine, shape = _get_default_native(affine, image.shape[-3:])
    space, mn, mx = _get_default_space(affine, [shape], space, bbox)
    affine, shape = (affine[0], shape[0])

    # compute default cursor (in voxels)
    if index is None:
        index = (mx + mn) / 2
    else:
        index = torch.as_tensor(index)
        index = spatial.affine_matvec(spatial.affine_inv(space), index)

    # include slice to volume matrix
    shape = tuple(((mx-mn) + 1).round().int().tolist())
    if dim == -1:
        # axial
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 1, 0, - mn[1] + 1],
                 [0, 0, 1, - index[2]],
                 [0, 0, 0, 1]]
        shift = utils.as_tensor(shift, **backend)
        shape = shape[:-1]
        index = (index[0] - mn[0] + 1, index[1] - mn[1] + 1)
    elif dim == -2:
        # coronal
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 0, 1, - mn[2] + 1],
                 [0, 1, 0, - index[1]],
                 [0, 0, 0, 1]]
        shift = utils.as_tensor(shift, **backend)
        shape = (shape[0], shape[2])
        index = (index[0] - mn[0] + 1, index[2] - mn[2] + 1)
    elif dim == -3:
        # sagittal
        if not transpose_sagittal:
            shift = [[0, 0, 1, - mn[2] + 1],
                     [0, 1, 0, - mn[1] + 1],
                     [1, 0, 0, - index[0]],
                     [0, 0, 0, 1]]
            shift = utils.as_tensor(shift, **backend)
            shape = (shape[2], shape[1])
            index = (index[2] - mn[2] + 1, index[1] - mn[1] + 1)
        else:
            shift = [[0, -1, 0,   mx[1] + 1],
                     [0,  0, 1, - mn[2] + 1],
                     [1,  0, 0, - index[0]],
                     [0,  0, 0, 1]]
            shift = utils.as_tensor(shift, **backend)
            shape = (shape[1], shape[2])
            index = (mx[1] + 1 - index[1], index[2] - mn[2] + 1)
    else:
        raise ValueError(f'Unknown dimension {dim}')

    # sample
    space = spatial.affine_rmdiv(space, shift)
    affine = spatial.affine_lmdiv(affine, space)
    affine = affine.to(**backend)
    grid = spatial.affine_grid(affine, [*shape, 1])
    *channel, s0, s1, s2 = image.shape
    imshape = (s0, s1, s2)
    image = image.reshape([1, -1, *imshape])
    image = spatial.grid_pull(image, grid[None], interpolation=interpolation,
                              bound='dct2', extrapolate=False)
    image = image.reshape([*channel, *shape])
    return ((image, index, space) if return_index and return_mat else
            (image, index) if return_index else
            (image, space) if return_mat else image)
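To see what the `shift` matrices do, apply the axial one to a homogeneous voxel coordinate in plain torch (the numbers are illustrative): a voxel on the cursor plane is mapped to the slice plane w = 0, with its in-plane coordinates shifted by the bounding box.

import torch

mn = torch.tensor([10., 20., 5.])
index = torch.tensor([0., 0., 32.])
shift = torch.tensor([[1., 0., 0., -mn[0] + 1],
                      [0., 1., 0., -mn[1] + 1],
                      [0., 0., 1., -index[2]],
                      [0., 0., 0., 1.]])
p = torch.tensor([15., 25., 32., 1.])   # voxel on the cursor plane
print(shift @ p)                        # tensor([6., 6., 0., 1.])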
Example #15
File: grappa.py Project: balbasty/nitorch
def kernel_fit(calib, kernel_size, patterns, lam=0.01):
    """Compute GRAPPA kernels

    All batch elements should have the same sampling pattern

    Parameters
    ----------
    calib : ([*batch], coils, *freq)
        Fully-sampled calibration data
    kernel_size : sequence[int]
        GRAPPA kernel size
    patterns : (N,) tensor[int]
        Code of patterns for which to learn a kernel.
        See `pattern_to_code`.
    lam : float, default=0.01
        Tikhonov regularization

    Returns
    -------
    kernels : dict of int -> ([*batch], coils, coils, nb_elem) tensor
        GRAPPA kernels

    """
    kernel_size = py.make_list(kernel_size)
    ndim = len(kernel_size)
    coils, *freq = calib.shape[-ndim - 1:]
    batch = calib.shape[:-ndim - 1]

    # find all possible patterns
    patterns = utils.as_tensor(patterns, device=calib.device)
    if patterns.dtype is torch.bool:
        patterns = pattern_to_code(patterns, ndim)
    patterns = patterns.flatten()

    # learn one kernel for each pattern
    calib = utils.unfold(calib, kernel_size, collapse=True)  # [*B, C, N, *K]
    calib = utils.movedim(calib, -ndim - 1, -ndim - 2)  # [*B, N, C, *K]

    def t(x):
        return x.transpose(-1, -2)

    def conjt(x):
        return t(x).conj()

    def diag(x):
        return x.diagonal(0, -1, -2)

    kernels = {}
    center = [(k - 1) // 2 for k in kernel_size]
    center = (Ellipsis, *center)
    for pattern_code in patterns:
        if code_has_center(pattern_code, kernel_size):
            continue
        pattern = code_to_pattern(pattern_code,
                                  kernel_size,
                                  device=calib.device)
        pattern_size = pattern.sum()
        if pattern_size == 0:
            continue

        calib_target = calib[center]  # [*B, N, C]
        calib_source = calib[..., pattern]  # [*B, N, C, P]
        calib_size = calib_target.shape[-2]
        flat_shape = [*batch, calib_size, pattern_size * coils]
        calib_source = calib_source.reshape(flat_shape)  # [*B, N, C*P]
        # solve the regularised normal equations: k = (S'S + reg)^{-1} S'y
        H = conjt(calib_source).matmul(calib_source)  # [*B, C*P, C*P]
        # Tikhonov: relative ridge (scaled by the largest diagonal)
        # plus a small absolute ridge
        diag(H).add_(lam * diag(H).abs().max(-1, keepdim=True).values)
        diag(H).add_(lam)
        g = conjt(calib_source).matmul(calib_target)  # [*B, C*P, C]
        k = linalg.lmdiv(H, g).transpose(-1, -2)  # [*B, C, C*P]
        k = k.reshape([*batch, coils, coils, pattern_size])  # [*B, C, C, P]
        kernels[pattern_code.item()] = k

    return kernels
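The solve at the end is plain ridge-regularised least squares via the normal equations. A self-contained restatement in torch (random stand-in data):

import torch

S = torch.randn(100, 8)    # source samples   [N, C*P]
y = torch.randn(100, 2)    # target samples   [N, C]
lam = 0.01
H = S.conj().T @ S
H.diagonal().add_(lam * H.diagonal().abs().max())  # relative ridge
H.diagonal().add_(lam)                             # absolute ridge
g = S.conj().T @ y
k = torch.linalg.solve(H, g).T   # one kernel row per target coil
print(k.shape)  # torch.Size([2, 8])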