예제 #1
0
def _world_reslice(dat, mat, interpolation=1, vx=None):
    """Reslice image data to world space.

    The output grid is axis-aligned (its affine is a diagonal scaling plus a
    translation, with the x axis flipped), covers the rounded world-space
    bounding box of the input volume, and has voxel size `vx`.

    Parameters
    ----------
    dat : (X0, Y0, Z0) tensor_like, dtype=float32
        Image data.
    mat : (4, 4) tensor_like, dtype=float64
        Affine matrix. Converted to float64 on the device of `dat`.
    interpolation : int, default=1 (linear)
        Interpolation order.
    vx : float | [float,] *3, optional
        Output voxel size. A single value is used for all three axes.
        Defaults to the voxel size of `mat`.

    Returns
    -------
    dat : (X1, Y1, Z1) tensor_like, dtype=float32
        New image data.
    mat : (4, 4) tensor_like, dtype=float64
        New affine matrix.

    """
    device = dat.device
    # Work in float64 on the data device (mm below needs matching dtypes)
    mat = torch.as_tensor(mat, dtype=torch.float64, device=device)
    # Get voxel size: from the input affine, or user-provided (1 or 3 values)
    if vx is None:
        vx = voxel_size(mat).type(torch.float64).to(device)
    else:
        vx = torch.as_tensor(vx, dtype=torch.float64, device=device).flatten()
        if vx.numel() == 1:
            vx = vx.expand(3)
    one = torch.ones(1, dtype=torch.float64, device=device)
    # Get corners; transposed so that each corner is a column, as required
    # by the (3, 4) @ (4, K) product below
    c = _get_corners_3d(dat.shape).type(torch.float64).to(device)
    c = c.t()
    # Corners in world space, with the x axis negated
    c_world = mat[:3, :4].mm(c)
    c_world[0, :] = -c_world[0, :]
    # Get (rounded) bounding box
    mx = c_world.max(dim=1)[0].round()
    mn = c_world.min(dim=1)[0].round()
    # Compute output affine: voxel (1, 1, 1) maps to the min corner `mn`
    # and voxels are spaced by `vx`
    mat_mn = affine_matrix_classic(mn).type(torch.float64).to(device)
    mat_vx = torch.diag(torch.cat((vx, one)))
    mat_1 = affine_matrix_classic(
        -1 * torch.ones(3, dtype=torch.float64, device=device))
    mat_out = mat_mn.mm(mat_vx.mm(mat_1))
    # Compute output image dimensions: voxel coordinates of the max corner,
    # rounded up
    dim_out = mat_out.inverse().mm(torch.cat((mx, one))[:, None])
    dim_out = dim_out[:3].ceil().flatten().int().tolist()
    # Undo the x-axis negation in the output affine
    flip = torch.diag(torch.ones(4, dtype=torch.float64, device=device))
    flip[0, 0] = -flip[0, 0]
    mat_out = flip.mm(mat_out)
    # Compute mapping from output to input voxels: mat^-1 @ mat_out.
    # torch.linalg.solve(A, B) solves A X = B; it replaces the deprecated
    # (and since removed) Tensor.solve, whose argument order was reversed.
    out2in = torch.linalg.solve(mat, mat_out)
    # Reslice image data
    dat = _reslice_dat_3d(dat, out2in, dim_out, interpolation=interpolation)

    return dat, mat_out
예제 #2
0
def _bb_atlas(name, fov, dtype=torch.float64, device='cpu'):
    """Bounding-box NITorch atlas data to specific field-of-view.

    Parameters
    ----------
    name : str
        Name of nitorch data, available are:
        * atlas_t1: MRI T1w intensity atlas, 1 mm resolution.
        * atlas_t2: MRI T2w intensity atlas, 1 mm resolution.
        * atlas_pd: MRI PDw intensity atlas, 1 mm resolution.
        * atlas_t1_mni: MRI T1w intensity atlas, in MNI space, 1 mm resolution.
        * atlas_t2_mni: MRI T2w intensity atlas, in MNI space, 1 mm resolution.
        * atlas_pd_mni: MRI PDw intensity atlas, in MNI space, 1 mm resolution.
    fov : str
        Field-of-view, specific to 'name':
        * 'atlas_t1' | 'atlas_t2' | 'atlas_pd':
            * 'brain': Brain FOV.
            * 'head': Head FOV (keeps more slices at the low end of the
              third axis than 'brain').
        Any other combination of `name` and `fov` applies no cropping.
    dtype : torch.dtype, default=torch.float64
        Data type of the returned tensors.
    device : torch.device or str, default='cpu'
        Device of the returned tensors.

    Returns
    -------
    mat_mu : (4, 4) tensor, dtype=dtype
        Output affine matrix.
    dim_mu : (3, ) tensor, dtype=dtype
        Output dimensions.

    """
    # Get atlas information
    file_mu = map(fetch_data(name))
    dim_mu = file_mu.shape
    mat_mu = file_mu.affine.type(dtype).to(device)
    # Margins (in voxels) cropped from the low and high end of each axis
    lo = [0, 0, 0]
    hi = [0, 0, 0]
    if name in ('atlas_t1', 'atlas_t2', 'atlas_pd') and fov in ('brain', 'head'):
        lo = [18, 52, 25 if fov == 'head' else 120]
        hi = [18, 48, 58]
    # Get bounding box, in 1-based inclusive voxel indices
    bb = torch.tensor(
        [[1 + lo[0], 1 + lo[1], 1 + lo[2]],
         [dim_mu[0] - hi[0], dim_mu[1] - hi[1], dim_mu[2] - hi[2]]],
        dtype=dtype, device=device)
    # Output dimensions
    dim_mu = bb[1, ...] - bb[0, ...] + 1
    # Bounding-box atlas affine: output voxel (1, 1, 1) maps to atlas voxel
    # bb[0]. Cast to the atlas affine's dtype/device so `mm` below matches.
    mat_bb = affine_matrix_classic(bb[0, ...] - 1).to(mat_mu)
    # Modulate atlas affine with bb affine
    mat_mu = mat_mu.mm(mat_bb)

    return mat_mu, dim_mu
예제 #3
0
def _imatrix(M):
    """Return the parameters for creating an affine transformation matrix.

    Decomposes ``M`` into 12 parameters laid out as
    ``[*translations, *rotations, *zooms, *shears]``, the same layout that
    is passed to ``affine_matrix_classic`` below to rebuild the zoom/shear
    part of the matrix.

    Args:
        M (torch.tensor): Affine transformation matrix (4, 4).

    Returns:
        P (torch.tensor): Affine parameters (12,), with the same dtype and
            device as ``M``. Rotations are expressed in radians.

    Authors:
        John Ashburner & Stefan Kiebel, as part of the SPM12 software.

    """
    device = M.device
    dtype = M.dtype
    one = torch.tensor(1.0, device=device, dtype=dtype)
    # Translations and Zooms
    R = M[:-1, :-1]
    # NOTE(review): assumes `cholesky` returns the lower factor, so that C
    # is upper-triangular after the transpose (as MATLAB's chol in SPM).
    C = cholesky(R.t().mm(R))
    C = C.t()
    d = torch.diag(C)
    P = torch.tensor(
        [M[0, 3], M[1, 3], M[2, 3], 0, 0, 0, d[0], d[1], d[2], 0, 0, 0],
        device=device,
        dtype=dtype)
    if R.det() < 0:  # Fix for -ve determinants
        P[6] = -P[6]
    # Shears
    C = lmdiv(torch.diag(torch.diag(C)), C)
    P[9] = C[0, 1]
    P[10] = C[0, 2]
    P[11] = C[1, 2]
    # Zoom/shear-only matrix (no translation, no rotation), built directly in
    # M's dtype/device instead of round-tripping elements through the host.
    prm0 = torch.cat((torch.zeros(6, device=device, dtype=dtype), P[6:]))
    R0 = affine_matrix_classic(prm0).to(device)
    R0 = R0[:-1, :-1]
    R1 = R.mm(R0.inverse())  # This just leaves rotations in matrix R1
    # Correct rounding errors
    rang = lambda x: torch.min(torch.max(x, -one), one)
    P[4] = torch.asin(rang(R1[0, 2]))
    if (torch.abs(P[4]) - pi / 2)**2 < 1e-9:
        # Gimbal lock: the first rotation is undetermined, set it to zero
        P[3] = 0
        P[5] = torch.atan2(-rang(R1[1, 0]), rang(-R1[2, 0] / R1[0, 2]))
    else:
        c = torch.cos(P[4])
        P[3] = torch.atan2(rang(R1[1, 2] / c), rang(R1[2, 2] / c))
        P[5] = torch.atan2(rang(R1[0, 1] / c), rang(R1[0, 0] / c))

    return P
예제 #4
0
def _subvol(dat, mat, bb=None):
    """Extract a sub-volume.

    Parameters
    ----------
    dat : (X0, Y0, Z0) tensor_like
        Image volume.
    mat : (4, 4) tensor_like, dtype=float64
        Image affine matrix.
    bb : (2, 3) sequence or tensor, optional
        Bounding box in 1-based, inclusive voxel indices:
        ``[[x_start, y_start, z_start], [x_stop, y_stop, z_stop]]``.
        Values are rounded, each column is sorted so that start <= stop,
        and the box is clipped to ``[1, dat.shape]``.
        Defaults to the full volume.

    Returns
    ----------
    dat : (X1, Y1, Z1) tensor_like
        Image sub-volume.
    mat : (4, 4) tensor_like, dtype=float64
        Sub-volume affine matrix.

    """
    device = dat.device
    dim_in = dat.shape
    if bb is None:
        bb = [[1, 1, 1], list(dim_in)]
    # Accept plain sequences as well as tensors of any dtype/device
    bb = torch.as_tensor(bb, dtype=torch.float64, device=device)
    # Process bounding-box (round, order each column, clip to the volume)
    bb = bb.round()
    bb = bb.sort(dim=0)[0]
    bb[0, ...] = torch.max(bb[0, ...],
                           torch.ones(3, device=device, dtype=torch.float64))
    bb[1, ...] = torch.min(
        bb[1, ...], torch.tensor(dim_in, device=device, dtype=torch.float64))
    # Output dimensions
    dim_bb = bb[1, ...] - bb[0, ...] + 1
    # Bounding-box affine: output voxel (1, 1, 1) maps to input voxel bb[0]
    mat_bb = affine_matrix_classic(bb[0, ...] - 1)
    # Output data
    dat = _reslice_dat_3d(dat,
                          mat_bb,
                          dim_bb,
                          interpolation='nearest',
                          bound='zero',
                          extrapolate=False)
    # Output affine
    mat = mat.mm(mat_bb)

    return dat, mat
예제 #5
0
    def forward(self, batch=1, **overload):
        """Draw a batch of random affine matrices.

        One sampler is built per parameter group (translation, rotation,
        zoom, shear) with ``self._make_sampler``. Each sampler is sampled
        ``batch`` times, rotations are converted from degrees to radians,
        and the concatenated parameters are turned into matrices by
        ``affine_matrix_classic``.

        Parameters
        ----------
        batch : int, default=1
            Batch size

        Other Parameters
        ----------------
        dim : int, optional
            Overrides ``self.dim``.
        device : torch.device, optional
            Overrides ``self.device``.
        dtype : torch.dtype, optional
            Overrides ``self.dtype``.

        Returns
        -------
        affine : (batch, dim+1, dim+1) tensor
            Affine matrix

        """
        ndim = overload.get('dim', self.dim)
        factory = dict(dtype=overload.get('dtype', self.dtype),
                       device=overload.get('device', self.device))

        groups = ('translation', 'rotation', 'zoom', 'shear')
        # Build every sampler first, then draw from them in the same order
        samplers = [self._make_sampler(group, ndim, **factory)
                    for group in groups]

        draws = []
        for group, sampler in zip(groups, samplers):
            values = sampler.sample([batch])
            if group == 'rotation':
                values = values.mul_(math.pi / 180)  # degrees -> radians
            draws.append(values)

        return affine_matrix_classic(torch.cat(draws, dim=-1), dim=ndim)
예제 #6
0
def _format_y(x, sett):
    """ Construct algorithm output struct. See _output() dataclass.

    Also decides which method to run ('denoising' when all inputs share the
    output voxel size, 'super-resolution' otherwise) and whether the inputs
    must be projected onto a common output grid. These decisions are
    written back into ``sett``.

    Args:
        x (list): Input data, indexed as ``x[c][n]`` (channel c,
            observation n). Items are read for their ``.mu`` and ``.ct``
            attributes.
        sett: Algorithm settings. Reads ``vx``, ``device``, ``crop``,
            ``fov``, ``pow`` and ``unified_rigid``. Sets ``do_proj`` and
            ``method``, and possibly ``unified_rigid``, ``clean_fov`` and
            ``scaling``.

    Returns:
        y (list of _output()): Algorithm output struct(s), one per channel.
        sett: The updated settings.

    """
    # Requested output voxel size (0 or None: decide from the inputs)
    vx_y = sett.vx
    if vx_y == 0:
        vx_y = None
    if vx_y is not None:
        if isinstance(vx_y, int):
            vx_y = float(vx_y)
        if isinstance(vx_y, float):
            vx_y = (vx_y,) * 3
        vx_y = torch.tensor(vx_y, dtype=torch.float64, device=sett.device)

    # Get all orientation matrices and dimensions
    all_mat, all_dim, all_vx = _all_mat_dim_vx(x, sett)
    N = all_mat.shape[0]  # Total number of observations

    if N == 1:
        # Disable unified rigid registration
        sett.unified_rigid = False
        sett.clean_fov = True

    # Check if all input images have the same fov/vx (compared after
    # rounding to 3 decimals)
    mat_same = True
    dim_same = True
    vx_same = True
    for n in range(1, N):
        mat_same = mat_same & \
            torch.equal(round(all_mat[n - 1, ...], 3), round(all_mat[n, ...], 3))
        dim_same = dim_same & \
            torch.equal(round(all_dim[n - 1, ...], 3), round(all_dim[n, ...], 3))
        vx_same = vx_same & \
            torch.equal(round(all_vx[n - 1, ...], 3), round(all_vx[n, ...], 3))

    # Decide if super-resolving and/or projection is necessary
    do_sr = True
    sett.do_proj = True
    if vx_y is None and ((N == 1) or vx_same):
        # Output voxel size not given and all inputs share one voxel size:
        # use it as the output voxel size
        vx_y = all_vx[0, ...]

    if vx_same and (torch.abs(all_vx[0, ...] - vx_y) < 1e-3).all():
        # All input images have same voxel size, and output voxel size is the also the same
        do_sr = False
        if mat_same and dim_same and not sett.unified_rigid:
            # All input images have the same FOV
            mat = all_mat[0, ...]
            dim = all_dim[0, ...]
            sett.do_proj = False

    if do_sr or sett.do_proj:
        # Get FOV of mean space
        mat, dim, vx_y = _mean_space(all_mat, all_dim, vx_y)

        if sett.crop:
            # Crop output to atlas field-of-view
            vx_y = voxel_size(mat)
            mat_mu, dim = _bb_atlas('atlas_t1',
                fov=sett.fov, dtype=torch.float64, device=sett.device)
            # Modulate atlas with voxel size
            mat_vx = torch.diag(torch.cat((
                vx_y, torch.ones(1, dtype=torch.float64, device=sett.device))))
            mat = mat_mu.mm(mat_vx)
            dim = mat_vx[:3, :3].inverse().mm(dim[:, None]).floor().squeeze()

        if sett.pow:
            # Ensure output image dimensions are compatible with encode/decode
            # architecture: per axis, keep the smaller of the two candidates
            dim2 = ceil_pow(dim, p=2.0, l=2.0, mx=256)
            dim3 = ceil_pow(dim, p=2.0, l=3.0, mx=256)
            ndim = torch.where(dim3 < dim2, dim3, dim2)
            # Modulate output affine: shift by half the size change so the
            # new grid stays centred on the old one
            mat_bb = affine_matrix_classic(-((ndim - dim)/2).round())\
                .type(torch.float64).to(sett.device)
            mat = mat.mm(mat_bb)
            dim = ndim

    # Set method
    if do_sr:
        sett.method = 'super-resolution'
    else:
        sett.method = 'denoising'

    # Optimise even/odd scaling parameter?
    if sett.method == 'denoising' or (N == 1 and x[0][0].ct):
        sett.scaling = False

    dim = tuple(dim.int().tolist())
    _print_info('mean-space', sett, dim, mat)

    # Assign output
    y = []
    for x_c in x:
        y_c = _output()
        # Regularisation (lambda) for this channel
        mu_c = torch.zeros(len(x_c), dtype=torch.float32, device=sett.device)
        for n, x_cn in enumerate(x_c):
            mu_c[n] = x_cn.mu
            if x_cn.ct and sett.method == 'super-resolution':
                mu_c[n] /= 4
        lam = math.sqrt(1/len(x)) / torch.mean(mu_c)
        y_c.lam0 = lam
        # Separate tensor, to facilitate rescaling without touching lam0
        y_c.lam = lam.clone()
        # Output image(s) dimension and orientation matrix
        y_c.dim = dim
        y_c.mat = mat.double().to(sett.device)
        y.append(y_c)

    return y, sett
예제 #7
0
    def forward(self, batch=1, **overload):
        """Build a batch of affine matrices from fixed or random parameters.

        Each of `translation`, `rotation`, `zoom` and `shear` holds one entry
        per parameter. An entry can be:
        * a callable: called with ``[batch]`` to draw values;
        * True: drawn with the matching ``self.default_*`` sampler;
        * None or False: the neutral value (0 for translations, rotations
          and shears; 1 for zooms);
        * anything else (number, tensor, ...): used as is.
        Rotations are expressed in degrees and converted to radians.

        Parameters
        ----------
        batch : int, default=1
            Batch size
        overload : dict
            All parameters defined at build time (`dim`, `translation`,
            `rotation`, `zoom`, `shear`, `dtype`, `device`) can be
            overridden at call time

        Returns
        -------
        affine : (batch, dim+1, dim+1) tensor
            Affine matrix

        """
        dim = overload.get('dim', self.dim)
        translation = make_list(overload.get('translation', self.translation))
        rotation = make_list(overload.get('rotation', self.rotation))
        zoom = make_list(overload.get('zoom', self.zoom))
        shear = make_list(overload.get('shear', self.shear))
        # dtype/device are single values used for the final cast
        dtype = overload.get('dtype', self.dtype)
        device = overload.get('device', self.device)

        # compute dimension
        dim = dim or max(len(translation), len(rotation), len(zoom),
                         len(shear))
        nb_rot = dim * (dim - 1) // 2
        translation = make_list(translation, dim)
        rotation = make_list(rotation, nb_rot)
        zoom = make_list(zoom, dim)
        shear = make_list(shear, nb_rot)

        def resolve(specs, draw_default, neutral):
            # Turn each user specification into a value (see docstring)
            values = []
            for x in specs:
                if callable(x):
                    values.append(x([batch]))
                elif x is True:
                    values.append(draw_default())
                elif x is None or x is False:
                    values.append(neutral)
                else:
                    values.append(x)
            return values

        def broadcast(p):
            # Bring one parameter to length `batch` (0-dim tensors included)
            if not torch.is_tensor(p):
                return make_list(p, batch)
            if p.dim() > 0 and p.shape[0] == batch:
                return p
            return p.expand(batch)

        # sample values if needed (same order as the parameter layout)
        translation = resolve(
            translation, lambda: self.default_translation([batch]), 0.)
        rotation = resolve(
            rotation, lambda: self.default_rotation([batch]), 0.)
        zoom = resolve(zoom, lambda: self.default_zoom([batch]), 1.)
        shear = resolve(shear, lambda: self.default_shear([batch]), 0.)
        rotation = [x * math.pi / 180 for x in rotation]  # degree -> radian

        prm = [broadcast(p) for p in (*translation, *rotation, *zoom, *shear)]
        prm = utils.as_tensor(prm)
        prm = prm.transpose(0, 1)

        # generate affine matrix (honouring the dtype/device overloads)
        mat = affine_matrix_classic(prm, dim=dim).type(dtype).to(device)

        return mat
예제 #8
0
    def forward(self, prm, **overload):
        """Build affine matrices from parameters expressed in a group basis.

        Zooms are exponentiated when `logzooms` is true, otherwise clamped
        from below by ``core.constants.eps(prm.dtype)``. The parameters are
        then re-packed in the ``(*translations, *rotations, *zooms,
        *shears)`` layout read by ``spatial.affine_matrix_classic``, with
        zeros inserted for the components the group does not have.

        Parameters
        ----------
        prm : (batch, nb_prm) tensor
            Affine parameters. Expected layout per basis:
            * 'T'   : (*translations)
            * 'SO'  : (*rotations)
            * 'SE'  : (*translations, *rotations)
            * 'D'   : (*translations, isotropic_zoom)
            * 'CSO' : (*translations, *rotations, isotropic_zoom)
            * 'GL+' : (*rotations, *zooms, *shears), (dim-1)*(dim+1) in total
            * 'Aff+': (*translations, *rotations, *zooms, *shears)
        overload : dict
            All parameters of the module (`dim`, `basis`, `logzooms`) can
            be overridden at call time.

        Returns
        -------
        affine : (batch, dim+1, dim+1) tensor
            Affine matrix

        Raises
        ------
        ValueError
            If the number of parameters does not match the basis, or the
            basis is unknown.

        """
        dim = overload.get('dim', self.dim)
        basis = overload.get('basis', self.basis)
        logzooms = overload.get('logzooms', self.logzooms)

        nb_prm = prm.shape[-1]
        nb_rot = dim * (dim - 1) // 2
        eps = core.constants.eps(prm.dtype)

        def checkdim(expected):
            if nb_prm != expected:
                raise ValueError('Expected {} parameters for group {}({}) but '
                                 'got {}.'.format(expected, basis, dim, nb_prm))

        def positive(zooms):
            return zooms.exp() if logzooms else zooms.clamp_min(eps)

        def isotropic(zoom):
            # (...) -> (..., dim): replicate along a NEW trailing axis. A bare
            # `zoom.expand([*zoom.shape, dim])` aligns the batch axis with the
            # new one: it fails unless batch is 1 or dim, and is silently
            # transposed when batch == dim.
            return zoom[..., None].expand([*zoom.shape, dim])

        def zeros(n):
            # Placeholder for a component the group does not have
            return prm.new_zeros([*prm.shape[:-1], n])

        if basis == 'T':
            checkdim(dim)
        elif basis == 'SO':
            checkdim(nb_rot)
            # no translations: keep rotations in the rotation slot
            prm = torch.cat((zeros(dim), prm), dim=-1)
        elif basis == 'SE':
            checkdim(dim + nb_rot)
        elif basis == 'D':
            checkdim(dim + 1)
            translations = prm[..., :dim]
            zooms = positive(isotropic(prm[..., -1]))
            # no rotations: keep zooms in the zoom slot
            prm = torch.cat((translations, zeros(nb_rot), zooms), dim=-1)
        elif basis == 'CSO':
            checkdim(dim + nb_rot + 1)
            rigid = prm[..., :-1]
            zooms = positive(isotropic(prm[..., -1]))
            prm = torch.cat((rigid, zooms), dim=-1)
        elif basis == 'GL+':
            # NOTE(review): (dim-1)*(dim+1) leaves nb_rot - 1 shear
            # parameters (one fewer than 'Aff+') -- confirm the intended size.
            checkdim((dim - 1) * (dim + 1))
            rotations = prm[..., :nb_rot]
            zooms = positive(prm[..., nb_rot:dim + nb_rot])
            shears = prm[..., dim + nb_rot:]
            # no translations: keep rotations in the rotation slot
            prm = torch.cat((zeros(dim), rotations, zooms, shears), dim=-1)
        elif basis == 'Aff+':
            checkdim(dim * (dim + 1))
            rigid = prm[..., :dim + nb_rot]
            zooms = positive(prm[..., dim + nb_rot:2 * dim + nb_rot])
            shears = prm[..., 2 * dim + nb_rot:]
            prm = torch.cat((rigid, zooms, shears), dim=-1)
        else:
            raise ValueError(f'Unknown basis {basis}')

        return spatial.affine_matrix_classic(prm, dim=dim)
예제 #9
0
파일: spatial.py 프로젝트: balbasty/nitorch
    def forward(self, prm):
        """Build affine matrices from parameters expressed in a group basis.

        The group is given by ``self.basis`` and ``self.dim``. Zooms are
        exponentiated when ``self.logzooms`` is true, otherwise clamped from
        below by ``core.constants.eps(prm.dtype)``. The parameters are then
        re-packed in the ``(*translations, *rotations, *zooms, *shears)``
        layout read by ``spatial.affine_matrix_classic``, with zeros
        inserted for the components the group does not have.

        Parameters
        ----------
        prm : (batch, nb_prm) tensor
            Affine parameters. Expected layout per basis:
            * 'T'   : (*translations)
            * 'SO'  : (*rotations)
            * 'SE'  : (*translations, *rotations)
            * 'D'   : (*translations, isotropic_zoom)
            * 'CSO' : (*translations, *rotations, isotropic_zoom)
            * 'GL+' : (*rotations, *zooms, *shears), (dim-1)*(dim+1) in total
            * 'Aff+': (*translations, *rotations, *zooms, *shears)

        Returns
        -------
        affine : (batch, dim+1, dim+1) tensor
            Affine matrix

        Raises
        ------
        ValueError
            If the number of parameters does not match the basis, or the
            basis is unknown.

        """
        dim = self.dim
        nb_prm = prm.shape[-1]
        nb_rot = dim * (dim - 1) // 2
        eps = core.constants.eps(prm.dtype)

        def checkdim(expected):
            if nb_prm != expected:
                raise ValueError(f'Expected {expected} parameters for '
                                 f'group {self.basis}({self.dim}) but '
                                 f'got {nb_prm}.')

        def positive(zooms):
            return zooms.exp() if self.logzooms else zooms.clamp_min(eps)

        def isotropic(zoom):
            # (...) -> (..., dim): replicate along a NEW trailing axis. A bare
            # `zoom.expand([*zoom.shape, dim])` aligns the batch axis with the
            # new one: it fails unless batch is 1 or dim, and is silently
            # transposed when batch == dim.
            return zoom[..., None].expand([*zoom.shape, dim])

        def zeros(n):
            # Placeholder for a component the group does not have
            return prm.new_zeros([*prm.shape[:-1], n])

        if self.basis == 'T':
            checkdim(dim)
        elif self.basis == 'SO':
            checkdim(nb_rot)
            # no translations: keep rotations in the rotation slot
            prm = torch.cat((zeros(dim), prm), dim=-1)
        elif self.basis == 'SE':
            checkdim(dim + nb_rot)
        elif self.basis == 'D':
            checkdim(dim + 1)
            translations = prm[..., :dim]
            zooms = positive(isotropic(prm[..., -1]))
            # no rotations: keep zooms in the zoom slot
            prm = torch.cat((translations, zeros(nb_rot), zooms), dim=-1)
        elif self.basis == 'CSO':
            checkdim(dim + nb_rot + 1)
            rigid = prm[..., :-1]
            zooms = positive(isotropic(prm[..., -1]))
            prm = torch.cat((rigid, zooms), dim=-1)
        elif self.basis == 'GL+':
            # NOTE(review): (dim-1)*(dim+1) leaves nb_rot - 1 shear
            # parameters (one fewer than 'Aff+') -- confirm the intended size.
            checkdim((dim - 1) * (dim + 1))
            rotations = prm[..., :nb_rot]
            zooms = positive(prm[..., nb_rot:dim + nb_rot])
            shears = prm[..., dim + nb_rot:]
            # no translations: keep rotations in the rotation slot
            prm = torch.cat((zeros(dim), rotations, zooms, shears), dim=-1)
        elif self.basis == 'Aff+':
            checkdim(dim * (dim + 1))
            rigid = prm[..., :dim + nb_rot]
            zooms = positive(prm[..., dim + nb_rot:2 * dim + nb_rot])
            shears = prm[..., 2 * dim + nb_rot:]
            prm = torch.cat((rigid, zooms, shears), dim=-1)
        else:
            raise ValueError(f'Unknown basis {self.basis}')

        return spatial.affine_matrix_classic(prm, dim=self.dim)