Example #1
    def _build_affine_pyramid(self, affine, shape, levels, method):
        """Compute the orientation matrix and shape at each pyramid level.

        Returns one affine per requested level, in the original order
        of `levels`.
        """
        levels = list(levels)
        indexed_levels = list(enumerate(levels))
        indexed_levels.sort(key=lambda x: x[1])
        nb_levels = max(levels)
        affines = [affine] * levels.count(0)
        for level in range(1, nb_levels+1):
            if method[0] == 's':  # stride pyramid: keep every other voxel
                slicer = (slice(None, None, 2),) * len(shape)
                affine, shape = spatial.affine_sub(affine, shape, slicer)
            else:  # conv pyramid: smooth, then downsample by a factor 2
                if method[0] == 'g':  # gaussian pyramid
                    padding = 'auto'
                    kernel = 3
                else:  # moving window
                    padding = 0
                    kernel = [min(2, s) for s in shape]
                affine, shape = spatial.affine_conv(affine, shape, kernel, 2,
                                                    padding=padding)
            affines += [affine] * levels.count(level)
        # restore the original (possibly unsorted) order of `levels`
        reordered_affines = [None] * len(levels)
        for (i, level), affine in zip(indexed_levels, affines):
            reordered_affines[i] = affine
        return reordered_affines
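
For context, here is a minimal sketch (not from the original source) of what a single stride-pyramid step does: `spatial.affine_sub` with a stride-2 slicer halves the shape and updates the orientation matrix so that the subsampled grid stays in the same world space. The import path and shapes below are assumptions.

import torch
from nitorch import spatial

affine = torch.eye(4)                            # illustrative orientation matrix
shape = (64, 64, 64)                             # illustrative volume shape
slicer = (slice(None, None, 2),) * len(shape)    # keep every other voxel
affine, shape = spatial.affine_sub(affine, shape, slicer)
print(shape)                                     # (32, 32, 32): one level down
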
Example #2
def unstack(inp, dim=-1, output=None, transform=None):
    """Unstack a ND volume, while preserving the orientation matrices.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    dim : int, default=-1
        Dimension along which to unstack.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the unstacked data is not written
        to disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.{i}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file,
        and `i` is the index (starting at 1) of the slice.
    transform : [sequence of] str, optional
        Output filename(s) of the corresponding transforms.
        Not written by default.

    Returns
    -------
    output : list[str] or (list[(tensor, tensor)], list[tensor])
        If the input is a path, the output paths are returned.
        Else, the unstacked `(data, affine)` pairs and the orientation
        matrices are returned.

    """
    dir = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.data(), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{i}{ext}'
        dir, base, ext = py.fileparts(fname)

    dat, aff0 = inp
    ndim = aff0.shape[-1] - 1
    if dim < 0:
        dim = dim + dat.dim()
    if dim >= ndim:
        # we didn't touch the spatial dimensions
        aff = [aff0.clone() for _ in range(dat.shape[dim])]
    else:
        aff = []
        slicer = [slice(None) for _ in range(ndim)]
        shape = dat.shape[:ndim]
        for z in range(dat.shape[dim]):
            slicer[dim] = slice(z, z + 1)
            aff1, _ = spatial.affine_sub(aff0, shape, tuple(slicer))
            aff.append(aff1)
    dat = torch.unbind(dat, dim)
    dat = [d.unsqueeze(dim) for d in dat]
    dat = list(zip(dat, aff))

    formatted_output = []
    if output:
        output = py.make_list(output, len(dat))
        for i, ((dat1, aff1), out1) in enumerate(zip(dat, output)):
            if is_file:
                out1 = out1.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep,
                                   i=i + 1)
                io.volumes.save(dat1, out1, like=fname, affine=aff1)
            else:
                out1 = out1.format(sep=os.path.sep, i=i + 1)
                io.volumes.save(dat1, out1, affine=aff1)
            formatted_output.append(out1)

    if transform:
        transform = py.make_list(transform, len(dat))
        for i, ((_, aff1), trf1) in enumerate(zip(dat, transform)):
            if is_file:
                trf1 = trf1.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep,
                                   i=i + 1)
            else:
                trf1 = trf1.format(sep=os.path.sep, i=i + 1)
            io.transforms.savef(torch.eye(4), trf1, source=aff0, target=aff1)

    if is_file:
        return formatted_output
    else:
        return dat, aff
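
A hedged usage sketch for the in-memory path of `unstack` (the tensor and identity affine below are made up for illustration):

import torch

dat = torch.rand(32, 32, 8)    # illustrative 3D volume
aff = torch.eye(4)             # illustrative orientation matrix
pairs, affines = unstack((dat, aff), dim=-1)
# 8 slices of shape (32, 32, 1); each affine is a 4x4 matrix whose
# translation points at the corresponding slice of the parent volume.
assert len(pairs) == 8 and affines[0].shape == (4, 4)
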
Example #3
def crop(inp,
         size=None,
         center=None,
         space='vox',
         like=None,
         bbox=False,
         output=None,
         transform=None):
    """Crop a ND volume, while preserving the orientation matrices.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, optional
        Size of the patch to extract.
        Its unit and axes are defined by `space`.
    center : [sequence of] int, optional
        Coordinate of the center of the patch.
        Its unit and axes are defined by `space`.
        By default, the center of the FOV is used.
    space : [sequence of] {'vox', 'ras'}, default='vox'
        The space in which the `size` and `center` parameters are expressed.
    bbox : bool or float, default=False
        Crop at the bounding box of `inp > threshold`.
            If `bbox` is a float, it is the threshold to use.
            If `bbox` is `True`, the threshold is 0.
    like : str or (tensor, tensor), optional
        Reference patch.
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the cropped data is not written
        to disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.crop{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    transform : [sequence of] str, optional
        Input or output filename(s) of the corresponding transforms.
        Not written by default.
        If a transform is provided and all other parameters
        (i.e., `size` and `like`) are None, the transform is considered
        as an input transform to apply.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the cropped data and orientation matrix are returned.

    """
    dir = ''
    base = ''
    ext = ''
    fname = None
    transform_in = False
    use_bbox = bool(bbox or isinstance(bbox, float))

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.data(numpy=True) if use_bbox else f, f.affine)
        if output is None:
            output = '{dir}{sep}{base}.crop{ext}'
        dir, base, ext = py.fileparts(fname)
    dat, aff0 = inp
    dim = aff0.shape[-1] - 1
    shape0 = dat.shape[:dim]
    layout0 = spatial.affine_to_layout(aff0)

    # save input space in case we reorient later
    aff00 = aff0
    shape00 = shape0

    if bool(size) + bool(like) + use_bbox > 1:
        raise ValueError('Can only use one of `size`, `like` and `bbox`.')

    # --- Open reference and compute size/center ---
    if like:
        like_is_file = isinstance(like, str)
        if like_is_file:
            f = io.volumes.map(like)
            like = (f.shape, f.affine)
        like_shape, like_aff = like
        like_layout = spatial.affine_to_layout(like_aff)
        if (layout0 != like_layout).any():
            aff0, dat = spatial.affine_reorient(aff0, dat, like_layout)
            shape0 = dat.shape[:dim]
        if torch.is_tensor(like_shape):
            like_shape = like_shape.shape
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)
        space = 'vox'

    elif use_bbox:
        if bbox is True:
            bbox = 0.
        mask = torch.as_tensor(dat > bbox)
        while mask.dim() > 3:
            mask = mask.any(dim=-1)
        mins = []
        maxs = []
        for d in range(dim):
            n = mask.shape[d]
            idx = utils.movedim(mask, d, 0)
            idx = idx.reshape([n, -1]).any(-1).nonzero(as_tuple=False)
            mins.append(idx.min())
            maxs.append(idx.max())
        mins = utils.as_tensor(mins)
        maxs = utils.as_tensor(maxs)
        size = maxs + 1 - mins
        center = (maxs + 1 + mins).float() / 2
        space = 'vox'
        del mask

    # --- Open transformation file and compute size/center ---
    elif not size:
        if not transform:
            raise ValueError('At least one of size/like/transform must '
                             'be provided')
        transform_in = True
        t = io.transforms.map(transform)
        if not isinstance(t, io.transforms.LinearTransformArray):
            raise TypeError('Expected an LTA file')
        like_aff, like_shape = t.destination_space()
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)

    # --- use center of the FOV ---
    if not torch.is_tensor(center) and not center:
        center = torch.as_tensor(shape0[:dim], dtype=torch.float)
        center = center.sub_(1).mul_(0.5)

    # --- convert size/center to voxels ---
    size = utils.make_vector(size, dim, dtype=torch.long)
    center = utils.make_vector(center, dim, dtype=torch.float)
    space_size, space_center = py.make_list(space, 2)
    if space_center.lower() == 'ras':
        center = spatial.affine_matvec(spatial.affine_inv(aff0), center)
    if space_size.lower() == 'ras':
        perm = spatial.affine_to_layout(aff0)[:, 0]
        size = size[perm.long()]
        size = size / spatial.voxel_size(aff0)

    # --- compute first/last ---
    center = center.float()
    size = (size.ceil() if size.dtype.is_floating_point else size).long()
    first = center - size.float().sub_(1).mul_(0.5)
    first = first.round().long()
    last = (first + size).tolist()
    first = [max(f, 0) for f in first.tolist()]
    last = [min(l, s) for l, s in zip(last, shape0[:dim])]
    verb = 'Cropping patch ['
    verb += ', '.join([f'{f}:{l}' for f, l in zip(first, last)])
    verb += f'] from volume with shape {shape0[:dim]}'
    print(verb)
    slicer = tuple(slice(f, l) for f, l in zip(first, last))

    # --- do crop ---
    if use_bbox and torch.is_tensor(dat):
        dat = dat.numpy()
    dat = dat[slicer]
    if not torch.is_tensor(dat) and callable(getattr(dat, 'data', None)):
        # still a mapped file object: load the cropped block into memory
        dat = dat.data(numpy=True)
    aff, _ = spatial.affine_sub(aff0, shape0[:dim], slicer)
    shape = dat.shape[:dim]

    if output:
        if is_file:
            output = output.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff)

    if transform and not transform_in:
        if is_file:
            transform = transform.format(dir=dir or '.',
                                         base=base,
                                         ext=ext,
                                         sep=os.path.sep)
        else:
            transform = transform.format(sep=os.path.sep)
        trf = io.transforms.LinearTransformArray(transform, 'w')
        trf.set_source_space(aff00, shape00)
        trf.set_destination_space(aff, shape)
        trf.set_metadata({
            'src': {
                'filename': fname
            },
            'dst': {
                'filename': output
            },
            'type': 1
        })  # RAS_TO_RAS
        trf.set_fdata(torch.eye(4))
        trf.save()

    if is_file:
        return output
    else:
        return dat, aff
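
A brief, hypothetical call of the in-memory path (tensor and affine are illustrative): crop a centered 32-voxel cube out of a 64-voxel cube.

import torch

dat = torch.rand(64, 64, 64)   # illustrative volume
aff = torch.eye(4)             # illustrative orientation matrix
patch, patch_aff = crop((dat, aff), size=32)
print(patch.shape)             # torch.Size([32, 32, 32])
# `patch_aff` is shifted so that `patch` keeps its world-space position.
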
Example #4
def crop(inp,
         size=None,
         center=None,
         space='vox',
         like=None,
         output=None,
         transform=None):
    """Crop a ND volume, while preserving the orientation matrices.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, optional
        Size of the patch to extract.
        Its unit and axes are defined by `space`.
    center : [sequence of] int, optional
        Coordinate of the center of the patch.
        Its unit and axes are defined by `space`.
        By default, the center of the FOV is used.
    space : [sequence of] {'vox', 'ras'}, default='vox'
        The space in which the `size` and `center` parameters are expressed.
    like : str or (tensor, tensor), optional
        Reference patch.
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the cropped data is not written
        to disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.crop{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    transform : [sequence of] str, optional
        Input or output filename(s) of the corresponding transforms.
        Not written by default.
        If a transform is provided and all other parameters
        (i.e., `size` and `like`) are None, the transform is considered
        as an input transform to apply.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the cropped data and orientation matrix are returned.

    """
    dir = ''
    base = ''
    ext = ''
    fname = None
    transform_in = False

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.data(numpy=True), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.crop{ext}'
        dir, base, ext = py.fileparts(fname)
    dat, aff0 = inp
    dim = aff0.shape[-1] - 1
    shape0 = dat.shape[:dim]

    if size and like:
        raise ValueError('Cannot use both `size` and `like`.')

    # --- Open reference and compute size/center ---
    if like:
        like_is_file = isinstance(like, str)
        if like_is_file:
            f = io.volumes.map(like)
            like = (f.shape, f.affine)
        like_shape, like_aff = like
        if torch.is_tensor(like_shape):
            like_shape = like_shape.shape
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)

    # --- Open transformation file and compute size/center ---
    elif not size:
        if not transform:
            raise ValueError('At least one of size/like/transform must '
                             'be provided')
        transform_in = True
        t = io.transforms.map(transform)
        if not isinstance(t, io.transforms.LinearTransformArray):
            raise TypeError('Expected an LTA file')
        like_aff, like_shape = t.destination_space()
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)

    # --- use center of the FOV ---
    if not torch.is_tensor(center) and not center:
        center = torch.as_tensor(shape0[:dim], dtype=torch.float)
        center = center.sub_(1).mul_(0.5)

    # --- convert size/center to voxels ---
    size = utils.make_vector(size, dim, dtype=torch.long)
    center = utils.make_vector(center, dim, dtype=torch.float)
    space_size, space_center = py.make_list(space, 2)
    if space_center.lower() == 'ras':
        center = spatial.affine_matvec(spatial.affine_inv(aff0), center)
    if space_size.lower() == 'ras':
        perm = spatial.affine_to_layout(aff0)[:, 0]
        size = size[perm.long()]
        size = size / spatial.voxel_size(aff0)

    # --- compute first/last ---
    center = center.float()
    size = (size.ceil() if size.dtype.is_floating_point else size).long()
    first = center - size.float().sub_(1).mul_(0.5)
    first = first.round().long()
    last = (first + size).tolist()
    first = [max(f, 0) for f in first.tolist()]
    last = [min(l, s) for l, s in zip(last, shape0[:dim])]
    verb = 'Cropping patch ['
    verb += ', '.join([f'{f}:{l}' for f, l in zip(first, last)])
    verb += f'] from volume with shape {shape0[:dim]}'
    print(verb)
    slicer = tuple(slice(f, l) for f, l in zip(first, last))

    # --- do crop ---
    dat = dat[slicer]
    aff, _ = spatial.affine_sub(aff0, shape0[:dim], slicer)
    shape = dat.shape[:dim]

    if output:
        if is_file:
            output = output.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff)

    if transform and not transform_in:
        if is_file:
            transform = transform.format(dir=dir or '.',
                                         base=base,
                                         ext=ext,
                                         sep=os.path.sep)
        else:
            transform = transform.format(sep=os.path.sep)
        trf = io.transforms.LinearTransformArray(transform, 'w')
        trf.set_source_space(aff0, shape0)
        trf.set_destination_space(aff, shape)
        trf.set_metadata({
            'src': {
                'filename': fname
            },
            'dst': {
                'filename': output
            },
            'type': 1
        })  # RAS_TO_RAS
        trf.set_fdata(torch.eye(4))
        trf.save()

    if is_file:
        return output
    else:
        return dat, aff
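
And a sketch of the `like` path of this variant, cropping to the field of view of a reference volume (everything here is illustrative, and `_crop_to_param` is a helper defined elsewhere in the module):

import torch

big = torch.rand(64, 64, 64)   # illustrative source volume
ref = torch.rand(32, 32, 32)   # illustrative reference patch
aff = torch.eye(4)             # shared orientation matrix, for simplicity
patch, patch_aff = crop((big, aff), like=(ref, aff))
# `patch` covers the same world-space extent as `ref`.
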
Example #5
def extract_patches(inp, size=64, stride=None, output=None, transform=None):
    """Extracgt patches from a 3D volume.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, default=64
        Patch size.
    stride : [sequence of] int, default=size
        Stride between patches.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the extracted patches are not written
        to disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.{i}_{j}_{k}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file,
        and `i`, `j`, `k` are the coordinates (starting at 1) of the patch.
    transform : [sequence of] str, optional
        Output filename(s) of the corresponding transforms.
        Not written by default.

    Returns
    -------
    output : list[str] or (tensor, tensor)
        If the input is a path, the output paths are returned.
        Else, the unfolded data and orientation matrices are returned.
            Data will have shape (nx, ny, nz, *size, *channels).
            Affines will have shape (nx, ny, nz, 4, 4).

    """
    dir = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.fdata(), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{i}_{j}_{k}{ext}'
        dir, base, ext = py.fileparts(fname)

    dat, aff0 = inp

    shape = dat.shape[:3]
    size = py.make_list(size, 3)
    stride = py.make_list(stride, 3)
    stride = [st or sz for st, sz in zip(stride, size)]

    dat = utils.movedim(dat, [0, 1, 2], [-3, -2, -1])
    dat = utils.unfold(dat, size, stride)
    dat = utils.movedim(dat, [-6, -5, -4, -3, -2, -1], [0, 1, 2, 3, 4, 5])

    aff = aff0.new_empty(dat.shape[:3] + aff0.shape)
    for i in range(dat.shape[0]):
        for j in range(dat.shape[1]):
            for k in range(dat.shape[2]):
                index = (i, j, k)
                sub = [slice(st*idx, st*idx + sz)
                       for st, sz, idx in zip(stride, size, index)]
                aff[i, j, k], _ = spatial.affine_sub(aff0, shape, tuple(sub))

    formatted_output = []
    if output:
        output = py.make_list(output, py.prod(dat.shape[:3]))
        for i in range(dat.shape[0]):
            for j in range(dat.shape[1]):
                for k in range(dat.shape[2]):
                    out1 = output.pop(0)
                    if is_file:
                        out1 = out1.format(dir=dir or '.', base=base, ext=ext,
                                           sep=os.path.sep, i=i+1, j=j+1, k=k+1)
                        io.volumes.savef(dat[i, j, k], out1, like=fname,
                                         affine=aff[i, j, k])
                    else:
                        out1 = out1.format(sep=os.path.sep, i=i+1, j=j+1, k=k+1)
                        io.volumes.savef(dat[i, j, k], out1, affine=aff[i, j, k])
                    formatted_output.append(out1)

    if transform:
        transform = py.make_list(transform, py.prod(dat.shape[:3]))
        for i in range(dat.shape[0]):
            for j in range(dat.shape[1]):
                for k in range(dat.shape[2]):
                    trf1 = transform.pop(0)
                    if is_file:
                        trf1 = trf1.format(dir=dir or '.', base=base, ext=ext,
                                           sep=os.path.sep, i=i+1, j=j+1, k=k+1)
                    else:
                        trf1 = trf1.format(sep=os.path.sep, i=i+1, j=j+1, k=k+1)
                    io.transforms.savef(torch.eye(4), trf1,
                                        source=aff0, target=aff[i, j, k])

    if is_file:
        return formatted_output
    else:
        return dat, aff
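
A hedged in-memory example (illustrative tensor and affine): unfolding a 128-voxel cube into eight non-overlapping 64-voxel patches.

import torch

dat = torch.rand(128, 128, 128)   # illustrative volume
aff = torch.eye(4)                # illustrative orientation matrix
patches, affines = extract_patches((dat, aff), size=64)
print(patches.shape)              # torch.Size([2, 2, 2, 64, 64, 64])
print(affines.shape)              # torch.Size([2, 2, 2, 4, 4])
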
Example #6
def align_tpm(dat,
              tpm=None,
              weights=None,
              spacing=(8, 4),
              device=None,
              basis='affine',
              joint=False,
              progressive=False,
              bins=256,
              fwhm=None,
              max_iter_gn=100,
              max_iter_em=32,
              max_line_search=6,
              verbose=1):
    """Align a Tissue Probability Map to an image

    Input Parameters
    ----------------
    dat : file(s) or tensor or (tensor, affine)
        Input image(s)
    tpm : file(s) or tensor or (tensor, affine), optional
        Input tissue probability map. Uses SPM's TPM by default.
    weights : file(s) or tensor, optional
        Input mask or weight map
    device : torch.device, optional
        Specify device
    verbose : int, default=1
        0 = Write nothing
        1 = Outer loop
        2 = Line search
        3 = Inner loop

    Option Parameters
    -----------------
    spacing : float(s), default=(8, 4)
        Sampling distance in mm. If multiple values are given, a
        coarse-to-fine fit is performed.
        Larger is faster but less accurate.
    basis : {'trans', 'rot', 'rigid', 'sim', 'affine'}, default='affine'
        Transformation model
    joint : bool, default=False
        Estimate a single affine for all images
    progressive : bool, default=False
        Fit parameters progressively (translation, then rigid, then affine)

    Optimization parameters
    -----------------------
    bins : int, default=256
        Number of bins to use to discretize the input image
    fwhm : float, default=bins/64
        Full-width at half-maximum used to smooth the joint histogram
    max_iter_gn : int, default=100
        Maximum number of Gauss-Newton iterations
    max_iter_em : int, default=32
        Maximum number of EM iterations
    max_line_search : int, default=6
        Maximum number of line search steps

    Returns
    -------
    aff : ([B], 4, 4) tensor
        Affine matrix.
        Can be applied to the TPM by `aff \ tpm.affine`
        or to the image by `aff @ dat.affine`.

    """
    # ------------------------------------------------------------------
    #       LOAD DATA
    # ------------------------------------------------------------------
    affine_dat = affine_tpm = None
    if isinstance(dat, (list, tuple)) and torch.is_tensor(dat[0]):
        affine_dat = dat[1] if len(dat) > 1 else None
        dat = dat[0]
    if isinstance(tpm, (list, tuple)) and torch.is_tensor(tpm[0]):
        affine_tpm = tpm[1] if len(tpm) > 1 else None
        tpm = tpm[0]
    backend = get_backend(dat, tpm, device)
    tpm, affine_tpm = get_prior(tpm, affine_tpm, **backend)
    dim = tpm.dim() - 1
    dat, weights, affine_dat = get_data(dat, weights, affine_dat, dim,
                                        **backend)

    if weights is None:
        weights = 1
    weights = weights * torch.isfinite(dat)

    # ------------------------------------------------------------------
    #       DEFAULT ORIENTATION MATRICES
    # ------------------------------------------------------------------
    if affine_tpm is not None:
        affine_tpm = affine_tpm.to(dat.dtype)
    else:
        affine_tpm = spatial.affine_default(tpm.shape[-dim:], **backend)
    if affine_dat is None:
        affine_dat = spatial.affine_default(dat.shape[-dim:], **backend)

    dat = dat.unsqueeze(1)  # [B, 1, *spatial]
    weights = weights.unsqueeze(1)  # [B, 1, *spatial]
    tpm = tpm.unsqueeze(0)  # [1, K, *spatial]

    # ------------------------------------------------------------------
    #       DISCRETIZE
    # ------------------------------------------------------------------
    dat = discretize(dat, nbins=bins, mask=weights)

    # ------------------------------------------------------------------
    #       OPTIONS
    # ------------------------------------------------------------------
    opt = dict(
        basis=basis,
        joint=joint,
        progressive=progressive,
        fwhm=fwhm,
        max_iter_gn=max_iter_gn,
        max_iter_em=max_iter_em,
        max_line_search=max_line_search,
        verbose=verbose,
    )

    # ------------------------------------------------------------------
    #       SPACING
    # ------------------------------------------------------------------
    spacing = py.make_list(spacing) or [0]
    dat0, affine_dat0, weights0 = dat, affine_dat, weights
    vx = spatial.voxel_size(affine_dat0).tolist()
    prm = None
    for sp in spacing:

        if sp:
            sp = [max(1, int(pymath.floor(sp / vx1))) for vx1 in vx]
            sp = [slice(None, None, sp1) for sp1 in sp]
            affine_dat, _ = spatial.affine_sub(affine_dat0, dat0.shape[-dim:],
                                               tuple(sp))
            dat = dat0[(Ellipsis, *sp)]
            if weights is not None:
                weights = weights0[(Ellipsis, *sp)]

        _, aff, prm = fit_affine_tpm(dat,
                                     tpm,
                                     affine_dat,
                                     affine_tpm,
                                     weights,
                                     **opt,
                                     prm=prm)

    return aff.squeeze()
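
A hypothetical call (the image tensor and affine are made up; leaving `tpm=None` assumes SPM's default TPM is available to `get_prior`):

import torch

img = torch.rand(64, 64, 64)      # illustrative image
aff = torch.eye(4)                # illustrative orientation matrix
A = align_tpm((img, aff), basis='rigid', spacing=8, verbose=0)
# Compose `A` with either orientation matrix as described in the
# Returns section above.
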
Example #7
    def slice(self, index, new_shape=None, _pre_expanded=False):
        """Extract a sub-part of the array.

        Indices can only be slices, ellipses, integers or None.

        Parameters
        ----------
        index : tuple[slice or ellipsis or int or None]

        Other Parameters
        ----------------
        new_shape : sequence[int], optional
            Output shape of the sliced object
        _pre_expanded : bool, default=False
            Set to True if `expand_index` has already been called on `index`

        Returns
        -------
        subarray : type(self)
            MappedArray object, with the indexing operations and affine
            matrix relating to the new sub-array.

        """
        if not _pre_expanded:
            index = expand_index(index, self.shape)
        if new_shape is None:
            new_shape = guess_shape(index, self.shape)
        if any(isinstance(idx, list) for idx in index):
            raise ValueError('List indices not currently supported '
                             '(otherwise we enter advanced indexing '
                             'territory and it becomes too complicated).')
        new = copy(self)
        new.shape = new_shape

        # compute new affine
        if self.affine is not None:
            spatial_shape = [
                sz for sz, msk in zip(self.shape, self.spatial) if msk
            ]
            spatial_index = [idx for idx in index if not is_newaxis(idx)]
            spatial_index = [
                idx for idx, msk in zip(spatial_index, self.spatial) if msk
            ]
            affine, _ = affine_sub(self.affine, spatial_shape,
                                   tuple(spatial_index))
        else:
            affine = None
        new.affine = affine

        # compute new slicer
        perm_shape = [self._shape[d] for d in self.permutation]
        new.slicer = compose_index(self.slicer, index, perm_shape)

        # compute new spatial mask
        spatial = []
        i = 0
        for idx in new.slicer:
            if is_newaxis(idx):
                spatial.append(False)
            else:
                # original axis
                if not is_droppedaxis(idx):
                    spatial.append(self._spatial[self.permutation[i]])
                i += 1
        new.spatial = tuple(spatial)

        return new
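
A short, hypothetical usage of this method on a mapped file (the file name is made up; `io.volumes.map` is assumed to return a MappedArray exposing `slice`):

f = io.volumes.map('volume.nii')             # e.g. shape (X, Y, Z)
sub = f.slice((slice(10, 20), Ellipsis))     # keep x in [10, 20)
# `sub.shape` becomes (10, Y, Z) and `sub.affine` still maps the
# sub-array's voxels to the same world space as the parent array.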