Exemplo n.º 1
0
def _get_default_space(affines, shapes, space=None, bbox=None):
    """Get default visualisation space.

    Parameters
    ----------
    affines : [sequence of] (4, 4) tensor_like
        Orientation matrices of the volumes to display.
    shapes : [sequence of] (3,) tensor_like
        Spatial shapes of the volumes to display.
    space : (4, 4) tensor_like, optional
        Visualisation space. By default, an identity matrix whose linear
        part is scaled by the smallest voxel size found in `affines`.
    bbox : (2, 3) tensor_like, optional
        Bounding box (lower corner, upper corner). When provided, both
        corners are divided by the voxel size of `space`. By default, the
        field of view is computed with `spatial.compute_fov`.

    Returns
    -------
    space : (4, 4) tensor
    mn : tensor
        Lower corner of the bounding box.
    mx : tensor
        Upper corner of the bounding box.

    """
    affines, shapes = _get_default_native(affines, shapes)

    if space is None:
        # isotropic space at the smallest voxel size among all inputs
        min_vx = spatial.voxel_size(affines).min()
        space = torch.eye(4)
        space[:-1, :-1] *= min_vx
    voxel_size = spatial.voxel_size(space)

    if bbox is None:
        shapes = torch.as_tensor(shapes)
        mn, mx = spatial.compute_fov(space, affines, shapes)
    else:
        mn, mx = utils.as_tensor(bbox)
        # Out-of-place division: `mn`/`mx` may be views into the caller's
        # `bbox` tensor (in-place division would modify it, and fails
        # outright on integer boxes).
        mn = mn / voxel_size
        mx = mx / voxel_size

    return space, mn, mx
Exemplo n.º 2
0
 def build_from_target(target):
     """Compose all transformations, starting from the final orientation.

     A grid is built on the lattice of `target` (its affine and shape), then
     `options.transformations` are applied from last to first:
     - `Linear` transforms map the grid with their affine;
     - other transforms map the grid into the frame of their own affine,
       add a displacement pulled from `trf.dat`, and map it back. When
       `trf.inv` is set, the field is first resized (by the ratio of its
       voxel size to the target's) and inverted with `spatial.grid_inv`,
       and is then sampled with order 1.
     """
     target_affine = target.affine.to(**backend)
     grid = spatial.affine_grid(target_affine, target.shape)
     for trf in reversed(options.transformations):
         trf_affine = trf.affine.to(**backend)
         if isinstance(trf, Linear):
             grid = spatial.affine_matvec(trf_affine, grid)
             continue
         if not trf.inv:
             disp, order = trf.dat, trf.spline
         else:
             factor = (spatial.voxel_size(trf_affine)
                       / spatial.voxel_size(target_affine))
             disp, trf_affine = spatial.resize_grid(trf.dat[None],
                                                    factor,
                                                    affine=trf_affine,
                                                    interpolation=trf.spline)
             disp = spatial.grid_inv(disp[0], type='disp')
             order = 1
         # into the field's frame, add displacement, back out of it
         grid = spatial.affine_matvec(spatial.affine_inv(trf_affine), grid)
         grid += helpers.pull_grid(disp, grid, interpolation=order)
         grid = spatial.affine_matvec(trf_affine, grid)
     return grid
Exemplo n.º 3
0
def _patch(patch, affine, shape, level):
    """Compute the patch size in voxels.

    Parameters
    ----------
    patch : float or sequence of float, optionally followed by a unit str
        Patch size. The unit (last element, default 'pct') can be:
        - 'v...'                : voxels (divided by 2**level)
        - 'um', 'mm', 'cm', 'm' : world units, assuming RAS orientation
        - 'p...' or '%...'      : percentage of `shape`
    affine : (dim+1, dim+1) tensor
        Orientation matrix.
    shape : sequence of int
        Spatial shape (used by the percentage unit).
    level : int
        Pyramid level (used by the voxel unit).

    Returns
    -------
    patch : list of float
        Patch size in voxels; sizes below 1e-3 are set to 0.

    Raises
    ------
    ValueError
        If the unit is empty or unknown.
    """
    dim = affine.shape[-1] - 1
    patch = py.make_list(patch)
    unit = 'pct'
    if isinstance(patch[-1], str):
        *patch, unit = patch
    patch = py.make_list(patch, dim)
    unit = unit.lower()
    # `startswith` rather than `unit[0]`: an empty unit must reach the
    # ValueError below instead of raising an IndexError.
    if unit.startswith('v'):  # voxels
        patch = [float(p) / 2**level for p in patch]
    elif unit in ('m', 'mm', 'cm', 'um'):  # assume RAS orientation
        factor = {'um': 1e-3, 'mm': 1, 'cm': 1e1, 'm': 1e3}[unit]
        affine_ras = spatial.affine_reorient(affine, layout='RAS')
        vx_ras = spatial.voxel_size(affine_ras).tolist()
        patch = [factor * p / v for p, v in zip(patch, vx_ras)]
        patch = _ras_to_layout(patch, affine)
    elif unit.startswith(('p', '%')):  # percentage of shape
        patch = [0.01 * p * s for p, s in zip(patch, shape)]
    else:
        # single formatted message (passing `unit` as a second argument
        # made the exception message a tuple)
        raise ValueError(f'Unknown patch unit: {unit}')

    # round down to zero small patch sizes
    return [0 if p < 1e-3 else p for p in patch]
Exemplo n.º 4
0
def _precond(x, y, rho, sett):
    """Compute CG preconditioner.

    The (diagonal) preconditioner is
        M = tau * AtA(1) + 2 * rho * lam**2 * sum(1 / vx**2)
    where `vx` is the voxel size of `y.mat`. The returned function divides
    its input element-wise by M.

    Raises
    ------
    ValueError
        If `x` does not hold exactly one repeat.
    """
    if len(x) != 1:
        raise ValueError(
            'CG pre-conditioning only supports one repeat per contrast.')
    obs = x[0]
    vx = voxel_size(y.mat).float()
    ones = torch.ones(y.dim, device=sett.device, dtype=torch.float32)
    # tau*At(A(1))
    diag = obs.tau * _proj_apply('AtA',
                                 ones[None, None, ...],
                                 obs.po,
                                 method=sett.method,
                                 bound=sett.bound,
                                 interpolation=sett.interpolation)
    # + 2*rho*lam**2*sum(1/vx^2)  NOTE(review): lam**2 vs lam -- to confirm
    diag += 2 * rho * y.lam**2 * vx.square().reciprocal().sum()
    diag = diag[0, 0, ...]

    def precond(arr):
        return arr / diag

    return precond
Exemplo n.º 5
0
def _to_gradient_magnitudes(dat, mat, scl):
    """Compute squared gradient magnitudes (modulated with scaling and voxel size).

    Forward differences (zero boundary) of `dat` are computed with the voxel
    size of `mat`, multiplied by `scl`, squared, and summed over dimension 0
    (presumably the gradient-direction axis).

    Parameters
    ----------
    dat : (X, Y, Z) tensor_like
        Image data.
    mat : (4, 4) tensor_like
        Affine matrix.
    scl : (N, ) tensor_like
        Gradient scaling parameter.

    Returns
    -------
    dat : (X, Y, Z) tensor_like
        Squared gradient magnitudes, meant to replace the image data. This
        function performs no in-place assignment on `dat` itself.

    """
    scaled_grad = scl * im_gradient(dat, vx=voxel_size(mat),
                                    which='forward', bound='zero')
    return torch.sum(scaled_grad**2, dim=0)
Exemplo n.º 6
0
def _reset_origin(dat, mat, interpolation=1):
    """Reset affine matrix.

    The image is resliced to its world field of view (`_world_reslice`),
    then a default affine (`affine_default`) is built for the resliced shape
    and voxel size. When the resliced affine has a negative determinant, the
    first voxel size is negated so that the flip is preserved.

    Parameters
    ----------
    dat : (X0, Y0, Z0) tensor_like, dtype=float32
        Image data.
    mat : (4, 4) tensor_like, dtype=float64
        Affine matrix.
    interpolation : int, default=1 (linear)
        Interpolation order.

    Returns
    -------
    dat : (X1, Y1, Z1) tensor_like, dtype=float32
        New image data.
    mat : (4, 4) tensor_like, dtype=float64
        New affine matrix.

    """
    device = dat.device
    # Reslice image data to world FOV
    dat, mat = _world_reslice(dat, mat, interpolation=interpolation)
    # Compute new, reset, affine matrix; a negative determinant (flipped
    # handedness) is carried by the sign of the first voxel size
    vx = voxel_size(mat)
    if mat[:3, :3].det() < 0:
        vx[0] = -vx[0]
    mat = affine_default(dat.shape, vx.tolist(), dtype=torch.float64,
                         device=device)

    return dat, mat
Exemplo n.º 7
0
def get_kernel(kernel, affine, shape, level):
    """Convert the provided kernel size (RAS mm or pct) to native voxels.

    Parameters
    ----------
    kernel : float or sequence of float, optionally followed by a unit str
        Kernel size. The unit (last element, default 'pct') can be:
        - 'v...'                : voxels (divided by 2**level)
        - 'um', 'mm', 'cm', 'm' : world units, assuming RAS orientation
        - 'p...' or '%...'      : percentage of `shape`
    affine : (dim+1, dim+1) tensor
        Orientation matrix.
    shape : sequence of int
        Spatial shape (used by the percentage unit).
    level : int
        Pyramid level (used by the voxel unit).

    Returns
    -------
    kernel : list of int
        Kernel size in voxels, rounded up and at least 2.

    Raises
    ------
    ValueError
        If the unit is empty or unknown.
    """
    dim = affine.shape[-1] - 1
    kernel = py.make_list(kernel)
    unit = 'pct'
    if isinstance(kernel[-1], str):
        *kernel, unit = kernel
    kernel = py.make_list(kernel, dim)
    unit = unit.lower()
    # `startswith` rather than `unit[0]`: an empty unit must reach the
    # ValueError below instead of raising an IndexError.
    if unit.startswith('v'):  # voxels
        kernel = [p / 2**level for p in kernel]
    elif unit in ('m', 'mm', 'cm', 'um'):  # assume RAS orientation
        factor = {'um': 1e-3, 'mm': 1, 'cm': 1e1, 'm': 1e3}[unit]
        affine_ras = spatial.affine_reorient(affine, layout='RAS')
        vx_ras = spatial.voxel_size(affine_ras).tolist()
        kernel = [factor * p / v for p, v in zip(kernel, vx_ras)]
        kernel = ras_to_layout(kernel, affine)
    elif unit.startswith(('p', '%')):  # percentage of shape
        kernel = [0.01 * p * s for p, s in zip(kernel, shape)]
    else:
        # single formatted message (passing `unit` as a second argument
        # made the exception message a tuple)
        raise ValueError(f'Unknown kernel unit: {unit}')

    # ensure kernel size is an integer >= 2 (else, no gradients)
    return [max(int(pymath.ceil(k)), 2) for k in kernel]
Exemplo n.º 8
0
def _world_reslice(dat, mat, interpolation=1, vx=None):
    """Reslice image data to world space.

    The output voxel size is `vx` (default: the voxel size of `mat`) and the
    output field of view is derived from the rounded world-space bounding
    box of the input's corners.

    Parameters
    ----------
    dat : (X0, Y0, Z0) tensor_like, dtype=float32
        Image data.
    mat : (4, 4) tensor_like, dtype=float64
        Affine matrix.
    interpolation : int, default=1 (linear)
        Interpolation order.
    vx : float | [float,] *3, optional
        Output voxel size.

    Returns
    -------
    dat : (X1, Y1, Z1) tensor_like, dtype=float32
        New image data.
    mat : (4, 4) tensor_like, dtype=float64
        New affine matrix.

    """
    device = dat.device
    # Get voxel size
    if vx is None:
        vx = voxel_size(mat).type(torch.float64).to(device)
    else:
        if not isinstance(vx, (list, tuple)):
            vx = (vx, ) * 3
        vx = torch.as_tensor(vx).type(torch.float64).to(device)
    # Get corners
    c = _get_corners_3d(dat.shape).type(torch.float64).to(device)
    c = c.t()
    # Corners in world space (x flipped while building the output affine;
    # the flip is undone on `mat_out` below)
    c_world = mat[:3, :4].mm(c)
    c_world[0, :] = -c_world[0, :]
    # Get bounding box
    mx = c_world.max(dim=1)[0].round()
    mn = c_world.min(dim=1)[0].round()
    # Compute output affine
    mat_mn = affine_matrix_classic(mn).type(torch.float64).to(device)
    mat_vx = torch.diag(
        torch.cat((vx, torch.ones(1, dtype=torch.float64, device=device))))
    mat_1 = affine_matrix_classic(
        -1 * torch.ones(3, dtype=torch.float64, device=device))
    mat_out = mat_mn.mm(mat_vx.mm(mat_1))
    # Compute output image dimensions
    dim_out = mat_out.inverse().mm(
        torch.cat((mx, torch.ones(1, dtype=torch.float64,
                                  device=device)))[:, None])
    dim_out = dim_out[:3].ceil().flatten().int().tolist()
    I = torch.diag(torch.ones(4, dtype=torch.float64, device=device))
    I[0, 0] = -I[0, 0]
    mat_out = I.mm(mat_out)
    # Compute mapping from output to input: inv(mat) @ mat_out.
    # `Tensor.solve` has been removed from PyTorch; torch.linalg.solve(A, B)
    # solves A @ X = B, which is what `mat_out.solve(mat)[0]` computed.
    mat = torch.linalg.solve(mat, mat_out)
    # Reslice image data
    dat = _reslice_dat_3d(dat, mat, dim_out, interpolation=interpolation)

    return dat, mat_out
Exemplo n.º 9
0
def _resample_inplane(x, sett):
    """Force in-plane resolution of observed data to be greater or equal to recon vx.

    Only active when `sett.force_inplane_res` is set and `sett.max_iter > 0`.
    For each observation, each axis gets the scaling `sett.vx / vx` clamped
    below at 1; observations whose scaling is (almost) the identity are
    skipped. Otherwise the data is resampled with order-0 interpolation and
    zero boundary, the label (if any) is warped, and `dat`, `mat` and `dim`
    are replaced in place. Returns `x`.
    """
    if not (sett.force_inplane_res and sett.max_iter > 0):
        return x
    eye = torch.eye(4, device=sett.device, dtype=torch.float64)
    for channel in x:
        for obs in channel:
            # get image data
            dat = obs.dat[None, None, ...]
            mat_x = obs.mat
            dim_x = torch.as_tensor(obs.dim, device=sett.device,
                                    dtype=torch.float64)
            vx_x = voxel_size(mat_x)
            # per-axis scaling matrix, never below 1
            scl = eye.clone()
            for i in range(3):
                scl[i, i] = sett.vx / vx_x[i]
                if scl[i, i] < 1.0:
                    scl[i, i] = 1
            if float((eye - scl).abs().sum()) < 1e-4:
                continue
            mat_x = mat_x.matmul(scl)
            new_dim = scl[:3, :3].inverse().mm(dim_x[:, None]).floor()
            dim_x = new_dim.squeeze().cpu().int().tolist()
            grid = affine_grid(scl.type(dat.dtype), dim_x)
            # resample
            dat = grid_pull(dat, grid[None, ...], bound='zero',
                            extrapolate=False, interpolation=0)
            # do label
            if obs.label is not None:
                obs.label[0] = _warp_label(obs.label[0], grid)
            # assign
            obs.dat = dat[0, 0, ...]
            obs.mat = mat_x
            obs.dim = dim_x
    return x
Exemplo n.º 10
0
def set_affine(header, affine, shape=None):
    """Write an orientation matrix into a header.

    The header zooms are first updated with the voxel size of `affine`
    (extra zooms the affine does not describe are kept). Then:
    - MGHHeader: `delta`, `Mdc` and `Pxyz_c` are set (requires `shape`;
      otherwise a RuntimeWarning is issued and nothing else is written);
    - Nifti1Header: sform and qform are set;
    - Spm99AnalyzeHeader: the origin is set from the affine;
    - any other header: a RuntimeWarning is issued.

    Parameters
    ----------
    header : nibabel header
    affine : (3, 4) or (4, 4) array_like or tensor
    shape : sequence of int, optional
        Data shape (only used for MGH headers).

    Returns
    -------
    header
        The same header object, modified in place.

    Raises
    ------
    ValueError
        For MGH headers, if `affine` is not (3, 4) or (4, 4).
    """
    if torch.is_tensor(affine):
        affine = affine.detach().cpu()
    affine = np.asanyarray(affine)
    vx = np.asanyarray(voxel_size(affine))
    vx0 = header.get_zooms()
    vx = [vx[i] if i < len(vx) else vx0[i] for i in range(len(vx0))]
    header.set_zooms(vx)
    if isinstance(header, MGHHeader):
        if shape is None:
            warn('Cannot set the affine of a MGH file without '
                 'knowing the data shape', RuntimeWarning)
        elif affine.shape not in ((3, 4), (4, 4)):
            raise ValueError('Expected a (3, 4) or (4, 4) affine matrix. '
                             'Got {}'.format(affine.shape))
        else:
            # MGH stores exactly 3 spatial voxel sizes. `vx` may hold a 4th
            # zoom (nibabel reports TR for 4D MGH files), which must not
            # enter Mdc/delta (broadcasting error / wrong field size).
            vx3 = np.asarray(vx[:3])
            Mdc = affine[:3, :3] / vx3
            shape = np.asarray(shape[:3])
            c_ras = affine.dot(np.hstack((shape / 2.0, [1])))[:3]

            # Assign after we've had a chance to raise exceptions
            header['delta'] = vx3
            header['Mdc'] = Mdc.T
            header['Pxyz_c'] = c_ras
    elif isinstance(header, Nifti1Header):
        header.set_sform(affine)
        header.set_qform(affine)
    elif isinstance(header, Spm99AnalyzeHeader):
        header.set_origin_from_affine(affine)
    else:
        warn('Format {} does not accept orientation matrices. '
             'It will be discarded.'.format(type(header).__name__),
             RuntimeWarning)
    return header
Exemplo n.º 11
0
def _autoreg(argv=None):
    """Autograd Registration

    This is a command-line utility.

    It parses the arguments (default: `sys.argv`) and adds default losses to
    non-linear transforms that have none. It adds an Adam optimizer when
    none is given, then loads data and transforms and prints a summary. It
    optimises until `all_optimized(options)` holds and finally frees data
    and writes transforms and data.
    """
    # parse arguments
    options = parse(list(argv or sys.argv))
    if not options:
        return

    # add a couple of defaults
    for trf in options.transformations:
        if isinstance(trf, struct.NonLinear) and not trf.losses:
            default_losses = (
                struct.AbsoluteLoss(factor=0.0001),
                struct.MembraneLoss(factor=0.001),
                struct.BendingLoss(factor=0.2),
                struct.LinearElasticLoss(factor=(0.05, 0.2)),
            )
            for loss in default_losses:
                trf.losses.append(loss)
        # replace the list of losses by a single collapsed loss
        trf.losses = [collapse_losses(trf.losses)]
    if not options.optimizers:
        options.optimizers.append(struct.Adam())

    options.propagate_defaults()
    options.read_info()
    options.propagate_filenames()

    if options.verbose >= 2:
        print(repr(options))

    load_data(options)
    load_transforms(options)

    # summary: shapes and voxel sizes of each loss' (moving, fixed) pairs
    print('Losses:')
    for loss in options.losses:
        print(f' - {loss.name}')
        for fixed, moving in zip(loss.fixed.dat, loss.moving.dat):
            moving_vx = spatial.voxel_size(moving[1]).tolist()
            print(f'   -| {list(moving[0].shape)}, {moving_vx}')
            fixed_vx = spatial.voxel_size(fixed[1]).tolist()
            print(f'   -> {list(fixed[0].shape)}, {fixed_vx}')

    # summary: lattice of each non-linear transform at every pyramid level
    print('Transforms')
    for trf in options.transformations:
        print(f' - {trf.name}')
        if not isinstance(trf, struct.NonLinear):
            continue
        last_level = trf.pyramid[-1]
        for level in reversed(trf.pyramid):
            scale = 2 ** (last_level - level)
            level_shape = [s * scale for s in trf.dat.shape]
            level_vx = spatial.voxel_size(trf.affine) / scale
            print(f'   - {list(level_shape)}, {level_vx.tolist()}')

    while not all_optimized(options):
        add_freedom(options)
        init_optimizers(options)
        optimize(options)

    free_data(options)
    detach_transforms(options)
    write_transforms(options)
    write_data(options)
Exemplo n.º 12
0
def _nonlin_rls(maps, lam=1., norm='jtv'):
    """Update the (L1) weights.

    Parameters
    ----------
    maps : (P, *shape) ParameterMaps
        Parameter maps (a single ParameterMap is only accepted for the
        internal `norm='__internal__'` call).
    lam : float or (P,) sequence[float], default=1
        Regularisation factor. A scalar is used for every map.
    norm : {'tv', 'jtv'}, default='jtv'
        Any other value makes the function return None.

    Returns
    -------
    rls : ([P], *shape) tensor or None
        Weights from the reweighted least squares scheme: one map per
        parameter for 'tv', a single shared map for 'jtv'.
    """
    import itertools

    if norm not in ('tv', 'jtv', '__internal__'):
        return None

    if isinstance(maps, ParameterMap):
        # single map
        # this should only be an internal call
        # -> we return the (lam-weighted) squared gradient map
        assert norm == '__internal__'
        vx = spatial.voxel_size(maps.affine)
        dat = maps.fdata()
        grad_fwd = spatial.diff(dat, dim=[0, 1, 2], voxel_size=vx, side='f')
        grad_bwd = spatial.diff(dat, dim=[0, 1, 2], voxel_size=vx, side='b')

        grad = grad_fwd.square_().sum(-1)
        grad += grad_bwd.square_().sum(-1)
        grad *= lam / 2.  # average across sides (2)
        return grad

    # multiple maps
    # Broadcast a scalar `lam` to all maps: zipping over a bare float (e.g.
    # the default lam=1.) raised a TypeError.
    try:
        lams = list(lam)
    except TypeError:
        lams = itertools.repeat(lam)

    if norm == 'tv':
        rls = [_nonlin_rls(prm, lam1, '__internal__').sqrt_()
               for prm, lam1 in zip(maps, lams)]
        return torch.stack(rls, dim=0)

    assert norm == 'jtv'
    rls = 0
    for prm, lam1 in zip(maps, lams):
        rls += _nonlin_rls(prm, lam1, '__internal__')
    return rls.sqrt_()
Exemplo n.º 13
0
def upsample_vel(v, aff_in, aff_out, shape, readout):
    """
    Upsample a 1D displacement field (by a potentially non-integer factor) using
    second order spline interpolation.
    Scales the displacement field appropriately to take into account the
    change of voxel size.

    Parameters
    ----------
    v : tensor
        Displacement field (presumably in voxels along `readout` -- TODO
        confirm).
    aff_in, aff_out : (4, 4) tensor
        Input and output orientation matrices.
    shape : sequence of int
        Output shape. If `v` already has this shape, it is returned as is.
    readout : int
        Index of the readout dimension, used to pick the voxel sizes.

    Returns
    -------
    v : tensor
        Resliced field multiplied by vx_in[readout] / vx_out[readout].
    """
    # `torch.Size` only compares equal to tuples: normalise both sides so
    # that a list `shape` also triggers the no-op path.
    if tuple(v.shape) == tuple(shape):
        return v
    vx_down = spatial.voxel_size(aff_in)[readout]
    vx_up = spatial.voxel_size(aff_out)[readout]
    factor = vx_down / vx_up
    v = spatial.reslice(v, aff_in, aff_out, shape,
                        bound='dct2', interpolation=2, prefilter=True,
                        extrapolate=True)
    v *= factor
    return v
Exemplo n.º 14
0
    def forward(self, x, affine=None):
        """Segment a volume.

        Parameters
        ----------
        x : (X, Y, Z) tensor or str
            Input volume, or a file name (mapped with `io.map`). Mapped
            arrays are always loaded and reshaped to their first 3 dims.
        affine : (4, 4) tensor, optional
            Orientation matrix. Read from the file when `x` is mapped and
            `affine` is not given. When an affine is available, the input
            is reoriented to RAS, smoothed and resized before segmentation.

        Returns
        -------
        seg : (oX, oY, oZ) tensor
            Label map: argmax over the network output channels, relabelled
            with `self.relabel`.
        resliced : (oX, oY, oZ) tensor
            Input resliced to 1 mm RAS
        affine : (4, 4) tensor
            Output orientation matrix

        """
        if self.verbose:
            print('Preprocessing... ', end='', flush=True)
        if isinstance(x, str):
            x = io.map(x)
        if isinstance(x, io.MappedArray):
            if affine is None:
                affine = x.affine
            # Load the data even when an explicit `affine` is given (it used
            # to stay an unloaded MappedArray in that case).
            x = x.fdata()
            x = x.reshape(x.shape[:3])
        x = SynthPreproc.addnoise(x)
        if affine is not None:
            # reorient to RAS, smooth axes whose voxel size is <= 1,
            # then resize by the voxel size
            affine, x = spatial.affine_reorient(affine, x, 'RAS')
            vx = spatial.voxel_size(affine)
            fwhm = 0.25 * vx.reciprocal()
            fwhm[vx > 1] = 0
            x = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)
            x, affine = spatial.resize(x[None, None],
                                       vx.tolist(),
                                       affine=affine)
            x = x[0, 0]
        oshape = x.shape
        x, crop = SynthPreproc.crop(x)
        x = SynthPreproc.preproc(x)[None, None]
        if self.verbose:
            print('done.', flush=True)
            print('Segmenting... ', end='', flush=True)
        s, x = super().forward(x)[0], x[0, 0]
        if self.verbose:
            print('done.', flush=True)
            print('Postprocessing... ', end='', flush=True)
        s = self.relabel(s.argmax(0))
        # undo the crop so outputs match the pre-crop shape
        x = SynthPreproc.pad(x, oshape, crop)
        s = SynthPreproc.pad(s, oshape, crop)
        if self.verbose:
            print('done.', flush=True)
        return s, x, affine
Exemplo n.º 15
0
    def resize(self):
        """Resize the parameter maps (and RLS weights) to the current level.

        The lattice (`self.affine0`, `self.shape0`) is resized by a factor
        1 / 2**(self.level - 1). `self.lam_scale` is set to the ratio of
        voxel volumes (new / original). Maps and `self.rls` are resized to
        the new shape, and `self.nll['rls']` is recomputed from `self.rls`.
        """
        affine, shape = spatial.affine_resize(self.affine0, self.shape0,
                                              1 / (2**(self.level - 1)))

        scl0 = spatial.voxel_size(self.affine0).prod()
        self.lam_scale = spatial.voxel_size(affine).prod() / scl0

        for prm in self.maps:
            prm.volume = spatial.resize(prm.volume[None, None, ...],
                                        shape=shape)[0, 0]
            prm.affine = affine
        self.maps.affine = affine
        if self.rls is not None:
            if self.rls.dim() == len(shape):
                # no leading dimension: add both batch and channel dims
                # (`shape=` keyword: was misspelled `hape=`, a TypeError)
                self.rls = spatial.resize(self.rls[None, None],
                                          shape=shape)[0, 0]
            else:
                # leading dimension already present: add the batch dim only
                self.rls = spatial.resize(self.rls[None], shape=shape)[0]
            self.nll['rls'] = self.rls.reciprocal().sum(dtype=torch.double)
Exemplo n.º 16
0
def set_voxel_size(header, vx, shape=None):
    """Set the voxel size of a header and rescale its affine accordingly.

    Parameters
    ----------
    header : nibabel header
    vx : sequence of float
        New voxel size. Missing trailing entries are taken from the
        header's current zooms.
    shape : sequence of int, optional
        Data shape, forwarded to `set_affine`.

    Returns
    -------
    header
        Header with updated zooms and affine (see `set_affine`).
    """
    zooms = header.get_zooms()
    nb_dim = max(len(zooms), len(vx))
    vx = [vx[i] if i < len(vx) else zooms[i] for i in range(nb_dim)]
    header.set_zooms(vx)
    aff = torch.as_tensor(header.get_best_affine())
    vx_aff = voxel_size(aff)
    # Only the spatial voxel sizes described by the affine (the header may
    # hold extra zooms, which previously caused a shape mismatch).
    vx = torch.as_tensor(vx[:len(vx_aff)], dtype=aff.dtype, device=aff.device)
    # Rescale the voxel axes (columns of the linear part): lin -> lin @ D.
    # The whole top block is left-multiplied by lin @ D @ inv(lin) so that
    # the voxel coordinates of the world origin are unchanged. For a
    # diagonal linear part this is exactly the former row scaling by
    # `vx / vx_aff`, but it stays correct for permuted/oblique affines.
    lin = aff[:-1, :-1]
    scale = lin.matmul(torch.diag(vx / vx_aff)).matmul(lin.inverse())
    aff[:-1, :] = scale.matmul(aff[:-1, :])
    header = set_affine(header, aff, shape)
    return header
Exemplo n.º 17
0
class SpatialTensor:
    """Base class for tensors with an orientation"""

    def __init__(self, dat, affine=None, dim=None, **backend):
        """
        Parameters
        ----------
        dat : ([C], *spatial) tensor or str or io.MappedArray
            Data. A file name is mapped with `io.map` and given a leading
            dimension; a mapped array is loaded with `fdata`.
        affine : tensor, optional
            Orientation matrix. Taken from the mapped array when available,
            else `spatial.affine_default(self.shape)`.
        dim : int, default=`dat.dim() - 1`
            Number of spatial dimensions.
        **backend : dtype, device
            Forwarded to `fdata` when loading a mapped array.
        """
        if isinstance(dat, str):
            dat = io.map(dat)[None]
        if isinstance(dat, io.MappedArray):
            affine = dat.affine if affine is None else affine
            dat = dat.fdata(rand=True, **backend)
        self.dim = dim or dat.dim() - 1
        self.dat = dat
        if affine is None:
            affine = spatial.affine_default(self.shape, **utils.backend(dat))
        self.affine = affine.to(utils.backend(self.dat)['device'])

    def to(self, *args, **kwargs):
        """Return a shallow copy with data and affine converted by `.to`."""
        duplicate = copy.copy(self)
        return duplicate.to_(*args, **kwargs)

    def to_(self, *args, **kwargs):
        """Convert data and affine with `.to` in place; return self."""
        self.dat = self.dat.to(*args, **kwargs)
        self.affine = self.affine.to(*args, **kwargs)
        return self

    @property
    def voxel_size(self):
        return spatial.voxel_size(self.affine)

    @property
    def shape(self):
        # trailing `dim` dimensions of the data
        return self.dat.shape[-self.dim:]

    @property
    def dtype(self):
        return self.dat.dtype

    @property
    def device(self):
        return self.dat.device

    def _prm_as_str(self):
        """List of 'key=value' strings used by `__repr__`."""
        parts = [f'shape={list(self.shape)}']
        voxel = ', '.join(f'{v:.2g}' for v in self.voxel_size.tolist())
        parts.append(f'voxel_size=[{voxel}]')
        if self.dtype != torch.float32:
            parts.append(f'dtype={self.dtype}')
        if self.device.type != 'cpu':
            parts.append(f'device={self.device}')
        return parts

    def __repr__(self):
        args = ', '.join(self._prm_as_str())
        return f'{self.__class__.__name__}({args})'

    __str__ = __repr__
Exemplo n.º 18
0
 def resize(cls, x, affine, target_vx=1):
     """Resize a 3D volume to a target voxel size.

     Parameters
     ----------
     x : (X, Y, Z) tensor
     affine : (4, 4) tensor
     target_vx : float or sequence of float, default=1
         Target voxel size.

     Returns
     -------
     x : tensor
         Resized volume.
     affine : tensor
         Orientation matrix of the resized volume.
     """
     target_vx = utils.make_vector(target_vx, x.dim(),
                                   **utils.backend(affine))
     factor = spatial.voxel_size(affine) / target_vx
     # Smooth (FWHM = 0.25 / factor voxels) along axes whose factor is <= 1
     # before resizing -- presumably to limit aliasing.
     fwhm = 0.25 * factor.reciprocal()
     fwhm[factor > 1] = 0
     smoothed = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)
     resized, affine = spatial.resize(smoothed[None, None],
                                      factor.tolist(),
                                      affine=affine)
     return resized[0, 0], affine
Exemplo n.º 19
0
def downsample(x, aff_in, vx_out):
    """
    Downsample an image (by an integer factor) to approximately
    match a target voxel size.

    Parameters
    ----------
    x : tensor
        Image to pool.
    aff_in : tensor
        Input orientation matrix.
    vx_out : float or sequence of float
        Target voxel size.

    Returns
    -------
    x : tensor
        Pooled image (the input itself if no axis needs pooling).
    aff_out : tensor
        Output orientation matrix.
    """
    vx_in = spatial.voxel_size(aff_in)
    dim = len(vx_in)
    # Build the target on the same dtype/device as `vx_in`; otherwise the
    # division below fails when the affine lives on a GPU.
    vx_out = utils.make_vector(vx_out, dim, **utils.backend(vx_in))
    # integer pooling factor per axis, at least 1
    factor = (vx_out / vx_in).clamp_min(1).floor().long()
    if (factor == 1).all():
        return x, aff_in
    x, aff_out = spatial.pool(dim, x, factor.tolist(), affine=aff_in)
    return x, aff_out
Exemplo n.º 20
0
 def space(self, value):
     """Set the visualisation space.

     Parameters
     ----------
     value : (4, 4) tensor or int or None
         - tensor: used directly as the space matrix;
         - int: index into `self.images`, whose affine is used;
         - None: identity scaled by the smallest voxel size of all images.
         The raw value is kept in `self._space`; the resulting matrix is
         stored in `self._space_matrix`.

     Raises
     ------
     ValueError
         If a tensor is not 4x4, or `value` has an unsupported type.
     """
     self._space = value
     if torch.is_tensor(value):
         if value.shape != (4, 4):
             raise ValueError('Expected 4x4 matrix')
         self._space_matrix = value
         return
     if isinstance(value, int):
         self._space_matrix = [image.affine for image in self.images][value]
         return
     if value is not None:
         raise ValueError('Expected a 4x4 matrix or an int or None')
     all_affines = [image.affine for image in self.images]
     min_vx = spatial.voxel_size(utils.as_tensor(all_affines)).min()
     self._space_matrix = torch.eye(4)
     self._space_matrix[:-1, :-1] *= min_vx
Exemplo n.º 21
0
 def load(self, x, affine=None):
     """Load a volume and bring it to RAS orientation.

     Parameters
     ----------
     x : (X, Y, Z) tensor or str or io.MappedArray
         Volume, or a file name (mapped with `io.map`). Mapped arrays are
         always loaded and reshaped to their first 3 dimensions.
     affine : (4, 4) tensor, optional
         Orientation matrix. Read from the file when `x` is mapped and
         `affine` is not given.

     Returns
     -------
     x : tensor
         Data; reoriented to RAS and resized by its voxel size when an
         affine is known.
     affine : tensor or None
         Orientation matrix of the returned data.
     x_original : torch.Size
         Shape of the data before reorientation.
     affine_original : tensor or None
         Affine before reorientation.
     """
     if isinstance(x, str):
         x = io.map(x)
     if isinstance(x, io.MappedArray):
         if affine is None:
             affine = x.affine
         # Always load the data: with an explicit `affine`, `x` used to stay
         # an unloaded MappedArray, which the tensor operations below
         # presumably do not accept.
         x = x.fdata()
         x = x.reshape(x.shape[:3])
     affine_original = affine
     x_original = x.shape
     if affine is not None:
         affine, x = spatial.affine_reorient(affine, x, 'RAS')
         vx = spatial.voxel_size(affine)
         x, affine = spatial.resize(x[None, None],
                                    vx.tolist(),
                                    affine=affine)
         x = x[0, 0]
     return x, affine, x_original, affine_original
Exemplo n.º 22
0
def read_info(options):
    """Load affine transforms and space info of other volumes.

    `options` is modified in place:
    - `options.files` is replaced by `struct.FileWithInfo` objects (path
      parts, dtype kind, shape, channels, affine);
    - linear transformations get `affine`; others get `affine` and `shape`;
    - `options.target`, if set, is read the same way as files. If
      `options.voxel_size` is also set, the target lattice is resized
      (`spatial.affine_resize`, anchor 'f') to that voxel size.
    """
    def read_file(fname):
        info = struct.FileWithInfo()
        info.fname = fname
        info.dir = os.path.dirname(fname) or '.'
        base, ext = os.path.splitext(os.path.basename(fname))
        if ext in ('.gz', '.bz2'):
            # compressed: keep the double extension (e.g. '.nii.gz')
            base, inner_ext = os.path.splitext(base)
            ext = inner_ext + ext
        info.base, info.ext = base, ext
        mapped = io.volumes.map(fname)
        info.float = nitype(mapped.dtype).is_floating_point
        full_shape = squeeze_to_nd(mapped.shape, dim=3, channels=1)
        info.channels = full_shape[-1]
        info.shape = full_shape[:3]
        info.affine = mapped.affine.float()
        return info

    def read_affine(fname):
        return squeeze_to_nd(io.transforms.loadf(fname).float(), 0, 2)

    def read_field(fname):
        mapped = io.volumes.map(fname)
        return mapped.affine.float(), mapped.shape[:3]

    options.files = [read_file(fname) for fname in options.files]
    for trf in options.transformations:
        if isinstance(trf, struct.Linear):
            trf.affine = read_affine(trf.file)
        else:
            trf.affine, trf.shape = read_field(trf.file)
    if not options.target:
        return
    target = read_file(options.target)
    options.target = target
    if options.voxel_size:
        options.voxel_size = utils.make_vector(
            options.voxel_size, 3, dtype=target.affine.dtype)
        factor = spatial.voxel_size(target.affine) / options.voxel_size
        target.affine, target.shape = spatial.affine_resize(
            target.affine, target.shape, factor=factor, anchor='f')
Exemplo n.º 23
0
def _smooth_for_reg(dat, mat, samp):
    """Smoothing for image registration. FWHM is computed from voxel size
       and sub-sampling amount.

    The per-axis FWHM (in voxels) is sqrt(max(samp**2 - vx**2, 0)) / vx,
    i.e. widths combined in quadrature. A separable Gaussian is applied by
    three successive 3D convolutions on a padded copy of the volume.

    Parameters
    ----------
    dat : (X, Y, Z) tensor_like
        3D image volume.
    mat : (4, 4) tensor_like
        Affine matrix.
    samp : float
        Amount of sub-sampling (in mm). If <= 0, `dat` is returned as is.

    Returns
    -------
    dat : (Nx, Ny, Nz) tensor_like
        Smoothed 3D image volume.

    """
    if samp <= 0:
        return dat
    device, dtype = dat.device, dat.dtype
    samp = torch.tensor((samp, ) * 3, dtype=dtype, device=device)
    vx = voxel_size(mat).to(device).type(dtype)
    fwhm = (samp**2 - vx**2).clamp_min(0).sqrt() / vx
    kernels = smooth(('gauss', ) * 3, fwhm=fwhm, device=device,
                     dtype=dtype, sep=True)
    # Pad each side by half the (odd) kernel length so the unpadded
    # convolutions below keep the spatial size
    lengths = (kernels[0].shape[2], kernels[1].shape[3], kernels[2].shape[4])
    size_pad = tuple((length - 1) // 2 for length in lengths)
    out = pad(dat, size_pad, side='both')[None, None, ...]
    for axis in range(3):
        out = F.conv3d(out, kernels[axis])
    return out[0, 0, ...]
Exemplo n.º 24
0
 def slice_to(self, stack, cache_result=False, recompute=True):
     """Resample this volume into the layout of `stack`.

     The transform returned by `self.exp` is composed with `self.affine`,
     expressed relative to `self.affine` reoriented to `stack.layout`, and
     used to pull `self.dat`. The result is then smoothed along its last
     axis with a FWHM of `stack.slice_width` (converted to voxels).

     Parameters
     ----------
     stack : object
         Must provide `layout` and `slice_width`.
     cache_result : bool, default=False
         Store the result in `self._sliced` (also forwarded to `self.exp`).
     recompute : bool, default=True
         If False and a cached result exists, return the cached result
         (also forwarded to `self.exp`).

     Returns
     -------
     sliced : tensor
     """
     aff = self.exp(cache_result=cache_result, recompute=recompute)
     if recompute or not hasattr(self, '_sliced'):
         aff = spatial.affine_matmul(aff, self.affine)
         aff_reorient = spatial.affine_reorient(self.affine, self.shape,
                                                stack.layout)
         aff = spatial.affine_lmdiv(aff_reorient, aff)
         grid = spatial.affine_grid(aff, self.shape)
         sliced = spatial.grid_pull(self.dat, grid, bound=self.bound,
                                    extrapolate=self.extrapolate)
         # smooth along the last axis by the slice width (in voxels)
         fwhm = [0] * self.dim
         fwhm[-1] = stack.slice_width / spatial.voxel_size(aff_reorient)[-1]
         sliced = spatial.smooth(sliced, fwhm, dim=self.dim, bound=self.bound)
         # TODO: per-slice resampling onto `stack.slices` is not implemented.
         # The previous draft called `affine_matmul`/`affine_lmdiv` with a
         # missing argument and never used the result.
     else:
         # Reuse the cached result (`sliced` used to be unbound here).
         sliced = self._sliced
     if cache_result:
         self._sliced = sliced
     return sliced
Exemplo n.º 25
0
def _estimate_hyperpar(x, sett):
    """ Estimate noise precision (tau) and mean brain
        intensity (mu) of each observed image.

    CT images: the noise sd is taken from `estimate_fwhm` and the
    foreground intensity is fixed to 4096. Other images: `estimate_noise`
    with two classes gives the noise sd and foreground mean. In both cases
    the background mean is taken as 0.

    Args:
        x (_input()): Input data.
        sett: Settings (`show_hyperpar` is used for MR fits).

    Returns:
        x (_input()): The same input data, where every observation now has
            `sd` (noise standard deviation), `tau` (1 / sd**2) and `mu`
            (|foreground mean - background mean|) set in place.

    """
    # Print info to screen
    t0 = _print_info('hyper_par', sett)
    # Do estimation
    fig_num = 100
    for xc in x:
        for obs in xc:
            dat = obs.dat
            mu_bg = torch.tensor(0.0, device=dat.device, dtype=dat.dtype)
            if obs.ct:
                # Estimate noise sd from estimate of FWHM
                sd_bg = estimate_fwhm(dat, voxel_size(obs.mat), mn=20, mx=50)[1]
                mu_fg = torch.tensor(4096, device=dat.device, dtype=dat.dtype)
            else:
                # Get noise and foreground statistics (the estimated
                # background mean is not used)
                sd_bg, _, _, mu_fg = estimate_noise(
                    dat, num_class=2, show_fit=sett.show_hyperpar,
                    fig_num=fig_num)
            # Set values
            sd = sd_bg.float()
            obs.sd = sd
            obs.tau = 1 / sd ** 2
            obs.mu = torch.abs(mu_fg.float() - mu_bg.float())
            fig_num += 1

    # Print info to screen
    _print_info('hyper_par', sett, x, t0)

    return x
Exemplo n.º 26
0
def _crop_y(y, sett):
    """ Crop output images FOV to a fixed dimension

    The field of view is the 'atlas_t1' bounding box (`_bb_atlas` with
    `sett.fov`), rescaled to the voxel size of `y[0]`. Data (and labels,
    when present) are resampled onto it with order-0 interpolation and zero
    boundary. Nothing is done when `sett.crop` is false.

    Args:
        y (_output()): _output data.
        sett: Settings (uses `crop`, `device`, `fov`).

    Returns:
        y (_output()): Cropped output data (elements updated in place).

    """
    if not sett.crop:
        return y
    device = sett.device
    # Output image information
    mat_y = y[0].mat
    vx_y = voxel_size(mat_y)
    # Define cropped FOV
    mat_mu, dim_mu = _bb_atlas('atlas_t1',
        fov=sett.fov, dtype=torch.float64, device=device)
    # Modulate atlas with voxel size
    mat_vx = torch.diag(torch.cat((
        vx_y, torch.ones(1, dtype=torch.float64, device=device))))
    mat_mu = mat_mu.mm(mat_vx)
    dim_mu = mat_vx[:3, :3].inverse().mm(dim_mu[:, None]).floor().squeeze()
    # Make output grid: inv(mat_y) @ mat_mu. `Tensor.solve` has been removed
    # from PyTorch; torch.linalg.solve(A, B) solves A @ X = B, which is what
    # `mat_mu.solve(mat_y)[0]` computed.
    M = torch.linalg.solve(mat_y, mat_mu).type(y[0].dat.dtype)
    grid = affine_grid(M, dim_mu)[None, ...]
    # Crop
    for yc in y:
        yc.dat = grid_pull(yc.dat[None, None, ...], grid,
                           bound='zero', extrapolate=False,
                           interpolation=0)[0, 0, ...]
        # Do labels?
        if yc.label is not None:
            yc.label = grid_pull(yc.label[None, None, ...], grid,
                                 bound='zero', extrapolate=False,
                                 interpolation=0)[0, 0, ...]
        yc.mat = mat_mu
        yc.dim = tuple(dim_mu.int().tolist())

    return y
Exemplo n.º 27
0
def _cli(args):
    """Command-line interface for `smooth` without exception handling.

    The FWHM may end with a unit string (default 'mm'). In 'mm' it is
    divided by each file's voxel size; any other unit is used as given,
    truncated to the number of spatial dimensions. Each input file is
    smoothed and saved to its output name. The name is formatted with
    `dir`, `base`, `ext` and `sep`.
    """
    args = args or sys.argv[1:]

    options = parser(args)
    if options.help:
        print(help)
        return

    fwhm = options.fwhm
    unit = 'mm'
    if isinstance(fwhm[-1], str):
        *fwhm, unit = fwhm
    fwhm = make_list(fwhm, 3)

    options.output = make_list(options.output, len(options.files))
    for fname, ofname in zip(options.files, options.output):
        mapped = io.map(fname)
        vx = voxel_size(mapped.affine).tolist()
        ndim = len(vx)
        if unit == 'mm':
            # millimetres -> voxels
            fwhm_vox = [width / size for width, size in zip(fwhm, vx)]
        else:
            fwhm_vox = fwhm[:ndim]

        dat = movedim_front2back(mapped.fdata(), ndim)
        dat = smooth(dat, type=options.method, fwhm=fwhm_vox,
                     basis=options.basis, bound=options.padding, dim=ndim)
        dat = movedim_back2front(dat, ndim)

        folder, base, ext = fileparts(fname)
        out_name = ofname.format(dir=folder or '.', base=base, ext=ext,
                                 sep=os.path.sep)
        io.savef(dat, out_name, like=mapped)
Exemplo n.º 28
0
def _compute_nll(x, y, sett, rho, sum_dtype=torch.float64):
    """ Compute negative model log-likelihood.

    Likelihood: 0.5 * tau * sum of squared residuals between observed data
    and the projection of `y` (`_proj('A', ...)`), over non-zero voxels.
    Prior: sum over voxels of sqrt(sum over channels of |lam * grad y|**2),
    using the voxel size of `y[0].mat`.

    Args:
        rho (torch.Tensor): ADMM step size.
        sum_dtype (torch.dtype): Defaults to torch.float64.

    Returns:
        nll_yx (torch.tensor()): Negative log-posterior
        nll_xy (torch.tensor()): Negative log-likelihood.
        nll_y (torch.tensor()): Negative log-prior.

    """
    vx_y = voxel_size(y[0].mat).float()
    nll_xy = torch.tensor(0, device=sett.device, dtype=torch.float64)
    for c, xc in enumerate(x):
        yc = y[c]
        # Neg. log-likelihood term (only where the observation is non-zero)
        for n, obs in enumerate(xc):
            msk = obs.dat != 0
            Ay = _proj('A', yc.dat, xc, yc, n=n,
                       method=sett.method,
                       do=sett.do_proj,
                       bound=sett.bound,
                       interpolation=sett.interpolation)
            residual = obs.dat[msk] - Ay[msk]
            nll_xy += 0.5 * obs.tau * torch.sum(residual**2, dtype=sum_dtype)
        # Neg. log-prior term: squared gradient magnitude, summed over channels
        Dy = yc.lam * im_gradient(yc.dat, vx=vx_y, bound=sett.bound,
                                  which=sett.diff)
        sq_grad = torch.sum(Dy**2, dim=0)
        if c == 0:
            nll_y = sq_grad
        else:
            nll_y += sq_grad

    nll_y = torch.sum(torch.sqrt(nll_y), dtype=sum_dtype)

    return nll_xy + nll_y, nll_xy, nll_y
Exemplo n.º 29
0
def _all_mat_dim_vx(x, sett):
    """ Get all images affine matrices, dimensions and voxel sizes.

    Observations are collected channel by channel, in order. All outputs
    are float64 torch tensors on `sett.device`; N is the total number of
    observations.

    Returns:
        all_mat (torch.tensor): Image orientation matrices (N, 4, 4).
        all_dim (torch.tensor): Image dimensions (N, 3).
        all_vx (torch.tensor): Image voxel sizes (N, 3).

    """
    opts = dict(device=sett.device, dtype=torch.float64)
    observations = [obs for xc in x for obs in xc]
    N = len(observations)
    all_mat = torch.zeros((N, 4, 4), **opts)
    all_dim = torch.zeros((N, 3), **opts)
    all_vx = torch.zeros((N, 3), **opts)

    for i, obs in enumerate(observations):
        all_mat[i, ...] = obs.mat
        all_dim[i, ...] = torch.tensor(obs.dim, **opts)
        all_vx[i, ...] = voxel_size(obs.mat)

    return all_mat, all_dim, all_vx
Exemplo n.º 30
0
 def __repr__(self):
     """Show the class name, shape and voxel size (from `self.affine`)."""
     voxel = spatial.voxel_size(self.affine).tolist()
     name = type(self).__name__
     return f'{name}(shape={self.shape}, vx={voxel})'