Exemplo n.º 1
0
def _smooth_for_reg(dat, mat, samp):
    """Smoothing for image registration. FWHM is computed from voxel size
       and sub-sampling amount.

    Parameters
    ----------
    dat : (X, Y, Z) tensor_like
        3D image volume.
    mat : (4, 4) tensor_like
        Affine matrix.
    samp : float
        Amount of sub-sampling (in mm).

    Returns
    -------
    dat : (Nx, Ny, Nz) tensor_like
        Smoothed 3D image volume.

    """
    if samp <= 0:
        return dat
    samp = torch.tensor((samp, ) * 3, dtype=dat.dtype, device=dat.device)
    # Make smoothing kernel
    vx = voxel_size(mat).to(dat.device).type(dat.dtype)
    fwhm = torch.sqrt(
        torch.max(samp**2 - vx**2,
                  torch.zeros(3, device=dat.device, dtype=dat.dtype))) / vx
    smo = smooth(('gauss', ) * 3,
                 fwhm=fwhm,
                 device=dat.device,
                 dtype=dat.dtype,
                 sep=True)
    # Padding amount for subsequent convolution
    size_pad = (smo[0].shape[2], smo[1].shape[3], smo[2].shape[4])
    size_pad = (torch.tensor(size_pad) - 1) // 2
    size_pad = tuple(size_pad.int().tolist())
    # Smooth deformation with Gaussian kernel (by separable convolution)
    dat = pad(dat, size_pad, side='both')
    dat = dat[None, None, ...]
    dat = F.conv3d(dat, smo[0])
    dat = F.conv3d(dat, smo[1])
    dat = F.conv3d(dat, smo[2])[0, 0, ...]

    return dat
Exemplo n.º 2
0
    def backward2(self, h, x, w=None, min=None, max=None):
        """Second-order backward pass from histogram bins to sample coordinates.

        ``h`` holds second-order terms defined over the joint-histogram
        bins. They are passed back through the optional Gaussian smoothing
        of the histogram (``self.fwhm``), pushed onto the samples by
        ``_jhistc_backward2``, and finally rescaled to account for the
        affine map from ``[min, max]`` to bin coordinates.

        Parameters
        ----------
        h : (..., *bins, [*bins]) tensor
            Second-order terms over the bins. If ``h.shape[:-2]`` equals the
            batch shape, ``h`` is treated as diagonal; if ``h.shape[:-4]``
            does, it is treated as a full ``(*bins, *bins)`` block.
        x : (..., n, 2) tensor
            Sample coordinates.
        w : (..., n) tensor, optional
            Sample weights, broadcast against ``x[..., 0]``.
        min : tensor_like broadcastable to (..., 2), optional
            Lower bound of the histogram range.
            Default: minimum of ``x`` over the samples.
        max : tensor_like broadcastable to (..., 2), optional
            Upper bound of the histogram range.
            Default: maximum of ``x`` over the samples.

        Returns
        -------
        h : (..., n, 2) tensor

        Raises
        ------
        ValueError
            If the shape of ``h`` matches neither the diagonal nor the
            full layout.

        """
        backend = dict(dtype=x.dtype, device=x.device)
        n = x.shape[-2]
        xbatch = x.shape[:-2]
        if w is not None:
            # Broadcast weights against samples; the batch shape may grow
            _, w = torch.broadcast_tensors(x[..., 0], w)
            batch = w.shape[:-1]
            x = x.expand([*batch, *x.shape[-2:]])
            w = w.reshape([-1, n])
        else:
            batch = xbatch
        # Flatten batch dimensions: x is (B, n, 2) from here on
        x = x.reshape([-1, n, 2])

        if h.shape[:-2] == batch:
            is_diag = True
        elif h.shape[:-4] == batch:
            is_diag = False
        else:
            raise ValueError('Don\'t know what to do with that shape')

        # Histogram range with shape (B, 1, 2), so that it broadcasts
        # against the (B, n, 2) samples. User-provided bounds are expanded
        # to the full (possibly weight-broadcast) batch shape; `xbatch`
        # always broadcasts to `batch`, so this accepts everything the
        # x-batch expansion did.
        if min is None:
            min = x.min(-2, keepdim=True).values
        else:
            min = torch.as_tensor(min, **backend)
            min = min.expand([*batch, 2]).reshape([-1, 1, 2])
        if max is None:
            max = x.max(-2, keepdim=True).values
        else:
            max = torch.as_tensor(max, **backend)
            max = max.expand([*batch, 2]).reshape([-1, 1, 2])

        # Affine map to continuous bin coordinates: min -> -0.5 and
        # max -> bins - 0.5. Written as (x - min) * scale instead of an
        # expression dividing by `min`, which relies on IEEE infinities
        # when min == 0.
        bins = torch.as_tensor(self.bins, **backend)
        scale = bins / (max - min)  # (B, 1, 2)
        x = (x - min).mul_(scale).sub_(0.5)

        if is_diag:
            h = h.reshape([-1, *self.bins])
        else:
            h = h.reshape([-1, *self.bins, *self.bins])

        # smooth backward
        if any(self.fwhm):
            ker = kernels.smooth(fwhm=self.fwhm)
            if is_diag:
                # Diagonal terms are smoothed with the squared kernel
                ker = [k.square_() for k in ker]
                h = smooth(h, kernel=ker, bound=self.bound, dim=2)
            else:
                # Smooth the first pair of bin axes, swap pairs, smooth
                # the second pair, then swap back
                h = smooth(h, kernel=ker, bound=self.bound, dim=2)
                h = h.transpose(-4, -2).transpose(-3, -1)
                h = smooth(h, kernel=ker, bound=self.bound, dim=2)
                h = h.transpose(-4, -2).transpose(-3, -1)

        # push data into the histogram
        h = _jhistc_backward2(h, x, w, self.order, self.bound,
                              self.extrapolate)
        # Chain rule through the affine map x -> (x - min) * scale - 0.5:
        # second-order terms pick up scale**2. `scale` is (B, 1, 2) so it
        # broadcasts over the n samples.
        # NOTE(review): assumes _jhistc_backward2 returns (B, n, 2), like x.
        h = h.mul_(scale.square())

        # reshape
        h = h.reshape([*batch, n, 2])
        return h
Exemplo n.º 3
0
    def forward(self, batch=1, **overload):
        """Generate a smooth Gaussian random field.

        Gaussian noise with mean ``mean`` is drawn on a padded grid and
        convolved with a Gaussian kernel, one batch element and channel at
        a time. The squared-exponential (SE) parameters ``amplitude`` and
        ``fwhm`` are first converted into noise and kernel parameters.

        Parameters
        ----------
        batch : int, default=1
            Batch size
        overload : dict
            Per-call overrides of the attributes of ``self`` with the same
            name: ``shape``, ``mean``, ``amplitude``, ``fwhm``, ``channel``,
            ``basis``, ``dtype`` and ``device``. ``mean``, ``amplitude``
            and ``fwhm`` may be callables, in which case they are called
            once to sample a value.

        Returns
        -------
        field : (batch, channel, *shape) tensor
            Generated random field

        Raises
        ------
        ValueError
            If ``shape`` does not have 1, 2 or 3 dimensions.

        """

        # get arguments
        shape = overload.get('shape', self.shape)
        mean = overload.get('mean', self.mean)
        amplitude = overload.get('amplitude', self.amplitude)
        fwhm = overload.get('fwhm', self.fwhm)
        channel = overload.get('channel', self.channel)
        basis = overload.get('basis', self.basis)
        dtype = overload.get('dtype', self.dtype)
        device = overload.get('device', self.device)

        # Only 1D, 2D and 3D convolutions are available: fail early (before
        # sampling any callable parameter) instead of calling `None` later.
        nb_dim = len(shape)
        conv = {
            1: torch.nn.functional.conv1d,
            2: torch.nn.functional.conv2d,
            3: torch.nn.functional.conv3d,
        }.get(nb_dim)
        if conv is None:
            raise ValueError(f'Random fields can only be generated in 1, 2 '
                             f'or 3 dimensions, got {nb_dim}.')

        # sample if parameters are callable
        mean = mean() if callable(mean) else mean
        amplitude = amplitude() if callable(amplitude) else amplitude
        fwhm = fwhm() if callable(fwhm) else fwhm

        # device/dtype
        mean = torch.as_tensor(mean, dtype=dtype, device=device)
        amplitude = torch.as_tensor(amplitude, dtype=dtype, device=device)
        fwhm = torch.as_tensor(fwhm, dtype=dtype, device=device)

        # reshape
        full_shape = [batch, channel, *shape]
        mean = mean.expand(full_shape)
        amplitude = amplitude.expand(full_shape)
        fwhm = fwhm.expand([batch, channel, nb_dim])

        # convert SE parameters to noise/kernel parameters
        # (FWHM = sigma * sqrt(8 * ln 2))
        sigma_se = fwhm / math.sqrt(8 * math.log(2))
        sigma_se = unsqueeze(sigma_se.prod(dim=-1), dim=-1, ndim=nb_dim)
        amplitude = amplitude * (2 * pi)**(nb_dim / 4) * sigma_se.sqrt()
        fwhm = fwhm * math.sqrt(2)

        # smooth
        samples_b = []
        for b in range(batch):
            samples_c = []
            for c in range(channel):
                kernel = smooth('gauss',
                                fwhm[b, c],
                                basis=basis,
                                device=device,
                                dtype=dtype)

                # compute input shape: enlarged so that the unpadded
                # convolutions below return `shape`
                pad_shape = [
                    shape[d] + kernel[d].shape[d + 2] - 1
                    for d in range(nb_dim)
                ]
                mean1 = ensure_shape(mean[b, c],
                                     pad_shape,
                                     mode='reflect2',
                                     side='both')
                amplitude1 = ensure_shape(amplitude[b, c],
                                          pad_shape,
                                          mode='reflect2',
                                          side='both')

                # generate sample
                sample = torch.distributions.Normal(mean1, amplitude1).sample()
                sample = sample[None, None, ...]

                # convolve (separable kernel, one axis at a time)
                for ker in kernel:
                    sample = conv(sample, ker)

                samples_c.append(sample)

            samples_b.append(torch.cat(samples_c, dim=1))

        sample = torch.cat(samples_b, dim=0)

        return sample
Exemplo n.º 4
0
def estimate_fwhm(dat, vx=None, verbose=0, mn=-inf, mx=inf):
    """Estimates full width at half maximum (FWHM) and noise standard
    deviation (sd) of a 2D or 3D image.

    It is assumed that the image has been generated as:
        dat = Ky + n,
    where K is Gaussian smoothing with some FWHM and n is
    additive Gaussian noise. FWHM and the sd of n are estimated.

    Only voxels whose intensity lies in ``(mn, mx]`` contribute to the
    estimates, and the outermost layer of voxels is discarded.

    Parameters
    ----------
    dat : str or (*spatial) tensor
        Image data or path to nifti file. Data that is not floating-point
        is converted to the default floating-point dtype.
    vx : [sequence of] float, default=1
        Voxel size. If None and `dat` is a path or mapped array, it is
        read from the file's affine.
    verbose : {0, 1, 2}, default=0
        Verbosity level:
            * 0: No verbosity
            * 1: Print FWHM and sd to screen
            * 2: 1 + show mask
    mn : float, optional
        Exclude values below (values <= mn are excluded)
    mx : float, optional
        Exclude values above (values > mx are excluded)

    Returns
    -------
    fwhm : (dim,) tensor
        Estimated FWHM
    sd : scalar tensor
        Estimated noise standard deviation.

    References
    ----------
    ..[1] "Linked independent component analysis for multimodal data fusion."
          Appendix A
          Groves AR, Beckmann CF, Smith SM, Woolrich MW.
          Neuroimage. 2011 Feb 1;54(3):2198-217.

    """
    if isinstance(dat, str):
        dat = io.map(dat)
    if isinstance(dat, io.MappedArray):
        if vx is None:
            vx = get_voxel_size(dat.affine)
        dat = dat.fdata(rand=True, missing=0)
    dat = torch.as_tensor(dat)
    if not dat.dtype.is_floating_point:
        # The statistics below (mean, gradients) need floating-point data
        dat = dat.to(torch.get_default_dtype())

    dim = dat.dim()
    if vx is None:
        vx = 1
    vx = utils.make_vector(vx, dim)
    backend = utils.backend(dat)
    # Make mask of voxels to include: mn < dat <= mx
    msk = (dat > mn).bitwise_and_(dat <= mx)
    dat = dat.masked_fill(~msk, 0)
    # TODO: we should erode the mask so that only voxels whose neighbours
    #       are in the mask are considered when computing gradients.
    if verbose >= 2:
        show_slices(msk)
    # Compute absolute image gradients: (*spatial, dim)
    g = diff(dat, dim=range(dim), side='central', voxel_size=vx,
             bound='dft').abs_()
    # Drop the outermost voxels, whose central differences depend on the
    # 'dft' boundary condition
    slicer = (slice(1, -1), ) * dim
    g = g[slicer]
    # Discard gradients of voxels OUTSIDE the mask, so that the gradient
    # sum and the intensity sum below cover the same set of voxels.
    g[~msk[slicer]] = 0
    g = g.reshape([-1, dim]).sum(0, dtype=torch.double)
    # Make dat have zero mean
    dat = dat[slicer]
    dat = dat[msk[slicer]]
    x0 = dat - dat.mean()
    # Compute FWHM
    fwhm = pymath.sqrt(4 * pymath.log(2)) * x0.abs().sum(dtype=torch.double)
    fwhm = fwhm / g
    if verbose >= 1:
        print(f'FWHM={fwhm.tolist()}')
    # Compute noise standard deviation
    # (presumably sx/sy/sz are the Gaussian kernel values at x=0)
    sx = smooth('gauss', fwhm[0], x=0, **backend)[0][0, 0, 0]
    sy = smooth('gauss', fwhm[1], x=0, **backend)[0][0, 0, 0]
    sz = 1.0
    if dim == 3:
        sz = smooth('gauss', fwhm[2], x=0, **backend)[0][0, 0, 0]
    sc = (sx * sy * sz) / dim
    sc.clamp_min_(1)
    sd = torch.sqrt(x0.square().sum(dtype=torch.double) / (x0.numel() * sc))
    if verbose >= 1:
        print(f'sd={sd.tolist()}')
    return fwhm, sd
Exemplo n.º 5
0
def _hist_2d(img0, img1, mx_int, fwhm):
    """Make a smoothed 2D (joint) intensity histogram, requires:
        * Images same size.
        * Images same min and max intensities (non-negative).

    Parameters
    ----------
    img0 : (X, Y, Z) tensor_like
        First image volume.
    img1 : (X, Y, Z) tensor_like
        Second image volume.
    mx_int : int
        This parameter sets the max intensity in the images, which decides
        how many bins to use in the joint image histograms
        (e.g, mx_int=511 -> H.shape = (512, 512)).
    fwhm : float
        Full-width at half max of Gaussian kernel, for smoothing
        histogram.

    Returns
    ----------
    H : (mx_int + 1, mx_int + 1) tensor_like
        Joint intensity histogram: smoothed, clamped to be non-negative
        and offset by 1e-7. Its dtype follows `img0` when that is
        floating-point, float32 otherwise.

    Notes
    ----------
    Intensities are floored to integer bin indices, and the count for the
    pair ``(img0[n], img1[n])`` lands in ``H[img1[n], img0[n]]`` (rows
    index `img1`, columns index `img0`). Naive equivalent:
    H = torch.zeros((mx_int + 1, mx_int + 1))
    for n in range(num_vox):
        H[img1[n], img0[n]] += 1

    """
    nbins = mx_int + 1
    device = img0.device
    # F.conv2d needs the histogram and the kernel to share a dtype: follow
    # the images, falling back to float32 for integer images.
    dtype = img0.dtype if img0.dtype.is_floating_point else torch.float32

    def _bin_index(img):
        # Flatten and floor intensities to integer bin indices
        img = img.flatten()
        if img.dtype.is_floating_point:
            img = img.floor()
        return img.long()

    # Linear index into the row-major (nbins, nbins) histogram
    # (replicates the sub2ind function of MATLAB)
    ind = _bin_index(img0) + _bin_index(img1) * nbins
    # Build histogram H by adding up counts according to the indices in ind
    H = torch.zeros(nbins, nbins, device=device, dtype=dtype)
    H.put_(ind, torch.ones_like(ind, dtype=dtype), accumulate=True)
    # Separable Gaussian smoothing kernels, one per histogram axis
    smo = smooth(('gauss', ) * 2,
                 fwhm=(fwhm, ) * 2,
                 device=device,
                 dtype=dtype,
                 sep=True)
    # Pad by (kernel length - 1) // 2 on both sides of each axis
    p = ((smo[0].shape[2] - 1) // 2, (smo[1].shape[3] - 1) // 2)
    H = pad(H, p, side='both')
    # Smooth (unpadded convolutions on the padded histogram)
    H = H[None, None, ...]
    H = F.conv2d(H, smo[0])
    H = F.conv2d(H, smo[1])
    H = H[0, 0, ...]
    # Clamp: counts cannot be negative
    H = H.clamp_min(0.0)
    # Add a small epsilon so that no bin is exactly zero
    # (NOTE(review): presumably to keep log(H) finite downstream)
    H = H + 1e-7

    return H
Exemplo n.º 6
0
def _proj_info(dim_y,
               mat_y,
               dim_x,
               mat_x,
               rigid=None,
               prof_ip=0,
               prof_tp=0,
               gap=0.0,
               device='cpu',
               scl=0.0,
               samp=0):
    """ Define projection operator object, to be used with _proj_apply.

    Args:
        dim_y ((int, int, int)): High-res image dimensions (3,).
        mat_y (torch.tensor): High-res affine matrix (4, 4).
        dim_x ((int, int, int)): Low-res image dimensions (3,).
        mat_x (torch.tensor): Low-res affine matrix (4, 4).
        rigid (torch.tensor, optional): Rigid transformation aligning x to y (4, 4), defaults to eye(4).
        prof_ip (int, optional): In-plane slice profile (0=rect|1=tri|2=gauss), defaults to 0.
        prof_tp (int, optional): Through-plane slice profile (0=rect|1=tri|2=gauss), defaults to 0.
            The through-plane axis is the one with the largest low-res voxel size.
        gap (float, optional): Slice-gap between 0 and 1, defaults to 0.
        device (torch.device, optional): Device. Defaults to 'cpu'.
        scl (float, optional): Odd/even slice scaling, defaults to 0.
        samp (float, optional): Sub-sampling distance (in mm), defaults to 0 (no sub-sampling).
            When > 0, the low-res grid is coarsened per axis by the integer step
            max(1, round(samp / vx_x)); when the low- and high-res voxel sizes
            differ, the high-res grid is coarsened in the same way.

    Returns:
        po (_proj_op()): Projection operator object. Attributes set here:
            dim_y, mat_y, vx_y, dim_x, mat_x, vx_x, rigid, dim_thick, mat_yx,
            dim_yx, smo_ker, scl, ratio; plus D_x (and possibly D_y) when samp > 0.

    """
    # Get projection operator object
    po = _proj_op()
    # Data types
    dtype = torch.float64
    dtype_smo_ker = torch.float32
    # Output properties
    if not isinstance(dim_y, torch.Tensor):
        dim_y = torch.tensor(dim_y, device=device, dtype=dtype)
    po.dim_y = dim_y
    po.mat_y = mat_y
    po.vx_y = voxel_size(mat_y)
    # Input properties
    if not isinstance(dim_x, torch.Tensor):
        dim_x = torch.tensor(dim_x, device=device, dtype=dtype)
    po.dim_x = dim_x
    po.mat_x = mat_x
    po.vx_x = voxel_size(mat_x)
    # Number of dimensions
    ndim = len(dim_y)
    one = torch.tensor((1, ) * ndim, device=device, dtype=torch.float64)
    if rigid is None:
        po.rigid = torch.eye(ndim + 1, device=device, dtype=dtype)
    else:
        po.rigid = rigid.type(dtype).to(device)
    # Slice-profile: the thick-slice axis (largest low-res voxel size) gets
    # the through-plane profile and the slice gap
    gap_cn = torch.zeros(ndim, device=device, dtype=dtype)
    profile = torch.tensor((prof_ip, ) * ndim, device=device, dtype=dtype)
    dim_thick = torch.max(po.vx_x, dim=0)[1]
    gap_cn[dim_thick] = gap
    profile[dim_thick] = prof_tp
    po.dim_thick = dim_thick
    if samp > 0:
        # Sub-sampling
        samp = torch.tensor((samp, ) * ndim,
                            device=device,
                            dtype=torch.float64)
        # Intermediate to lowres: integer step max(1, round(samp / vx_x))
        sk = torch.max(one, torch.floor(samp * one / po.vx_x + 0.5))
        D_x = torch.diag(torch.cat((sk, one[0, None])))
        po.D_x = D_x
        # Modulate lowres
        po.mat_x = po.mat_x.mm(D_x)
        po.dim_x = D_x.inverse()[:ndim, :ndim].mm(
            po.dim_x[..., None]).floor().squeeze()
        # Compare the original low- and high-res voxel sizes (po.vx_x is
        # only refreshed below): the high-res grid is sub-sampled only
        # when they differ (superres).
        if torch.sum(torch.abs(po.vx_x - po.vx_y)) > 1e-4:
            # Intermediate to highres (only for superres)
            sk = torch.max(one, torch.floor(samp * one / po.vx_y + 0.5))
            D_y = torch.diag(torch.cat((sk, one[0, None])))
            po.D_y = D_y
            # Modulate highres
            po.mat_y = po.mat_y.mm(D_y)
            po.vx_y = voxel_size(po.mat_y)
            po.dim_y = D_y.inverse()[:ndim, :ndim].mm(
                po.dim_y[..., None]).floor().squeeze()
        po.vx_x = voxel_size(po.mat_x)
    # Make intermediate: per-axis ratio of low- to high-res voxel size
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve'):
        ratio = torch.linalg.solve(po.mat_y, po.mat_x)  # mat_y\mat_x
    else:
        # Older PyTorch without torch.linalg.solve
        ratio = torch.solve(po.mat_x, po.mat_y)[0]  # mat_y\mat_x
    ratio = (ratio[:ndim, :ndim]**2).sum(0).sqrt()
    ratio = ratio.ceil().clamp(1)  # ratio low/high >= 1
    mat_yx = torch.cat((ratio, torch.ones(1, device=device,
                                          dtype=dtype))).diag()
    po.mat_yx = po.mat_x.matmul(mat_yx.inverse())  # mat_x/mat_yx
    po.dim_yx = (po.dim_x - 1) * ratio + 1
    # Make elements with ratio <= 1 use dirac profile
    profile[ratio == 1] = -1
    profile = profile.int().tolist()
    # Make smoothing kernel (slice-profile)
    fwhm = (1. - gap_cn) * ratio
    smo_ker = smooth(profile,
                     fwhm,
                     sep=False,
                     dtype=dtype_smo_ker,
                     device=device)
    po.smo_ker = smo_ker
    # Add offset to intermediate space
    off = torch.tensor(smo_ker.shape[-ndim:], dtype=dtype, device=device)
    off = -(off - 1) // 2  # set offset
    mat_off = torch.eye(ndim + 1, dtype=torch.float64, device=device)
    mat_off[:ndim, -1] = off
    po.dim_yx = po.dim_yx + 2 * torch.abs(off)
    po.mat_yx = torch.matmul(po.mat_yx, mat_off)
    # Odd/even slice scaling
    if isinstance(scl, torch.Tensor):
        po.scl = scl
    else:
        po.scl = torch.tensor(scl, dtype=torch.float32, device=device)
    # To tuple of ints
    po.dim_y = tuple(po.dim_y.int().tolist())
    po.dim_yx = tuple(po.dim_yx.int().tolist())
    po.dim_x = tuple(po.dim_x.int().tolist())
    po.ratio = tuple(ratio.int().tolist())

    return po