コード例 #1
0
ファイル: conversions.py プロジェクト: liamchalcroft/nitorch
    def __init__(self, *args):
        """Build a quaternion from its components.

        Parameters
        ----------
        Either
            quat : (..., 4) tensor
                Packed components, ordered as (i, j, k, r).
        Or
            orientation : (..., 3) tensor
                Imaginary components (i, j, k).
            attitude : (...) tensor
                Real component r.
        Or
            i, j, k, r : tensors
                Individual components.

        Raises
        ------
        ValueError
            If the number of positional arguments is not 1, 2 or 4.

        """
        nargs = len(args)
        if nargs not in (1, 2, 4):
            raise ValueError('Expected 1, 2 or 4 arguments')

        if nargs == 4:
            i, j, k, r = utils.to_max_backend(*args)
        else:
            if nargs == 1:
                packed = torch.as_tensor(args[0])
            else:
                packed, r = utils.to_max_backend(*args)
            # imaginary part lives in the first three entries of the last axis
            i, j, k = (packed[..., axis] for axis in range(3))
            if nargs == 1:
                r = packed[..., 3]

        self.i = i
        self.j = j
        self.k = k
        self.r = r
コード例 #2
0
def build_se(sqdist, sigma, lam, **backend):
    """Build a squared-exponential covariance matrix.

    cov = sigma**2 * exp(-0.5 * sqdist / lam**2)

    Parameters
    ----------
    sqdist : sequence[int] or (vox, vox) tensor
        Either a pre-computed squared-distance map (tensor), or a lattice
        shape from which the distance map is built with `dist_map`.
    sigma : (*batch) tensor_like
        Amplitude.
    lam : (*batch) tensor_like
        Length-scale.
    **backend
        dtype/device options forwarded to `utils.to_max_backend`.

    Returns
    -------
    cov : (*batch, vox, vox) tensor
        Covariance matrix.

    """
    lam, sigma = utils.to_max_backend(lam, sigma, **backend, force_float=True)
    backend = utils.backend(lam)

    if torch.is_tensor(sqdist):
        # not in-place below: `.to` may return the caller's own tensor
        dist = sqdist.to(**backend)
    else:
        dist = dist_map(sqdist, **backend)
    del sqdist

    # append two axes so that (*batch) parameters broadcast over (vox, vox)
    lam = lam[..., None, None]
    sigma = sigma[..., None, None]
    factor = -0.5 / (lam**2)
    cov = dist.mul(factor)
    cov = cov.exp_()
    return cov.mul_(sigma**2)
コード例 #3
0
def irls_tukey_reweight(moving,
                        fixed,
                        lam=1,
                        c=4.685,
                        joint=False,
                        dim=None,
                        mask=None):
    """Update iteratively reweighted least-squares weights for Tukey's biweight

    With standardized squared residuals ``u = lam * (moving - fixed)**2``,
    the weights are ``(1 - u / c**2)**2`` where ``u <= c**2`` and zero
    where ``u > c**2``.

    Parameters
    ----------
    moving : ([B], K, *spatial) tensor
        Moving image
    fixed : ([B], K, *spatial) tensor
        Fixed image
    lam : float or ([B], K|1, [*spatial]) tensor_like
        Equivalent to Gaussian noise precision
        (used to standardize the residuals)
    c : float, default=4.685
        Tukey's threshold.
        Approximately equal to a number of standard deviations above
        which the loss is capped.
    joint : bool, default=False
        If True, standardized squared residuals are summed across
        channels before computing the weights (one weight per voxel,
        shared by all channels).
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    mask : tensor, optional
        Multiplied into the standardized squared residuals (must
        broadcast against them). Voxels where the mask is zero get a
        zero residual and therefore a weight of one.

    Returns
    -------
    weights : (..., K|1, *spatial) tensor
        IRLS weights

    """
    if lam is None:
        lam = 1
    c = c * c  # compare against the squared threshold
    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    if mask is not None:
        mask = mask.to(fixed.device)
    dim = dim or (fixed.dim() - 1)
    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)  # pad spatial dimensions
    weights = (moving - fixed).square_().mul_(lam)
    if mask is not None:
        weights = weights.mul_(mask)
    if joint:
        weights = weights.sum(dim=-dim - 1, keepdim=True)
    zeromsk = weights > c
    weights = weights.div_(-c).add_(1).square()
    # Boolean indexing returns a *copy*, so `weights[zeromsk].zero_()`
    # would leave `weights` untouched. Fill in place instead.
    weights = weights.masked_fill_(zeromsk, 0)
    return weights
コード例 #4
0
def irls_laplace_reweight(moving,
                          fixed,
                          lam=1,
                          joint=False,
                          eps=1e-5,
                          dim=None,
                          mask=None):
    """Update iteratively reweighted least-squares weights for l1

    The weights are ``1 / max(sqrt(lam * (moving - fixed)**2), eps)``.

    Parameters
    ----------
    moving : ([B], K, *spatial) tensor
        Moving image
    fixed : ([B], K, *spatial) tensor
        Fixed image
    lam : float or ([B], K|1, [*spatial]) tensor_like
        Inverse-squared scale parameter of the Laplace distribution.
        (equivalent to Gaussian noise precision)
    joint : bool, default=False
        If True, standardized squared residuals are summed across
        channels before computing the weights.
    eps : float, default=1e-5
        Lower bound on the standardized absolute residual, which keeps
        the reciprocal finite.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    mask : tensor, optional
        Multiplied into the squared residuals; weights are set to zero
        wherever the mask equals zero.

    Returns
    -------
    weights : (..., K|1, *spatial) tensor
        IRLS weights

    """
    lam = 1 if lam is None else lam
    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    if mask is not None:
        mask = mask.to(fixed.device)
    dim = dim or (fixed.dim() - 1)

    # Pad the spatial dimensions of low-rank precisions (0-d goes to 1-d first)
    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)

    sqres = (moving - fixed).square_().mul_(lam)
    if mask is not None:
        sqres = sqres.mul_(mask)
    if joint:
        sqres = sqres.sum(dim=-dim - 1, keepdim=True)

    weights = sqres.sqrt_().clamp_min_(eps).reciprocal_()
    if mask is not None:
        # masked voxels would otherwise get the huge weight 1/eps
        weights = weights.masked_fill_(mask == 0, 0)
    return weights
コード例 #5
0
ファイル: _mrfield.py プロジェクト: balbasty/nitorch
def mrfield_greens_apply(mom, greens):
    """Apply the Greens function to a momentum field.

    The momentum is Fourier-transformed over its last `greens.dim()`
    dimensions, multiplied voxel-wise by `greens`, and transformed back.
    The real part of the result is returned.

    Parameters
    ----------
    mom : (..., *spatial) tensor
        Momentum
    greens : (*spatial) tensor
        Greens function, in the Fourier domain (it directly multiplies
        the Fourier coefficients of `mom`).

    Returns
    -------
    field : (..., *spatial) tensor
        Field

    """
    mom, greens = utils.to_max_backend(mom, greens)
    dim = greens.dim()

    if utils.torch_version('>=', (1, 8)):
        # `torch.fft` module: complex tensors, transform the last `dim` axes
        spatial_dims = list(range(-dim, 0))
        mom = torch.fft.fftn(mom, dim=spatial_dims)
        mom = mom * greens
        # `.real` is an attribute (a view), not a method
        mom = torch.fft.ifftn(mom, dim=spatial_dims).real
    else:
        # legacy API: complex numbers are stored as a trailing (real, imag) axis
        if torch.backends.mkl.is_available():
            mom = torch.rfft(mom, dim, onesided=False)
        else:
            zero = mom.new_zeros([]).expand(mom.shape)
            mom = torch.stack([mom, zero], dim=-1)
            mom = torch.fft(mom, dim)
        mom = mom * greens[..., None]
        mom = torch.ifft(mom, dim)[..., 0]

    return mom
コード例 #6
0
def mp2rage(pd,
            r1,
            r2s=None,
            transmit=None,
            receive=None,
            gfactor=None,
            tr=6.25,
            ti1=0.8,
            ti2=2.2,
            tx=None,
            te=None,
            fa=(4, 5),
            n=160,
            eff=0.96,
            sigma=None,
            device=None,
            return_combined=True):
    """Simulate data generated by a (simplified) MP2RAGE sequence.

    The default sequence parameters are those used at 3T in the original
    MP2RAGE paper. They did not give a nice image when applied to 3T maps
    computed with the hMRI toolbox. The following (unrealistic) parameters
    seem to give a decent contrast:
    tr=6.25, ti1=1.4, ti2=4.5, tx=5.8e-3, fa=(4, 5), n=160, eff=0.96

    Tissue parameters
    -----------------
    pd : tensor_like
        Proton density
    r1 : tensor_like
        Longitudinal relaxation rate, in 1/sec
    r2s : tensor_like, optional
        Transverse relaxation rate, in 1/sec.
        If not provided, T2*-bias is not included.
        Not used by the noise-free combined path.

    Fields
    ------
    transmit : tensor_like, optional
        Transmit B1 field
    receive : tensor_like, optional
        Receive B1 field (not used by the noise-free combined path)
    gfactor : tensor_like, optional
        G-factor map, forwarded to the noise sampler.

    Sequence parameters
    -------------------
    tr : float, default=6.25
        Full repetition time (between two inversion pulses), in sec.
    ti1 : float, default=0.8
        First inversion time (inversion pulse to middle of the first
        echo train), in sec.
    ti2 : float, default=2.2
        Second inversion time (inversion pulse to middle of the second
        echo train), in sec.
    tx : float, default=2*te, or 5.8e-3 if `te` is also None
        Excitation repetition time (between two excitation pulses within
        an echo train), in sec.
    te : float, default=tx/2
        Echo time, in sec.
    fa : float or (float, float), default=(4, 5)
        Flip angles of the first and second acquisition blocks, in deg.
        A single value is shared by both blocks.
    n : int, default=160
        Number of excitation pulses (= phase encoding steps) per train.
        If falsy, `min(pd.shape)` is used.
    eff : float, default=0.96
        Efficiency of the inversion pulse.

    Noise
    -----
    sigma : float, optional
        Standard-deviation of the sampled Rician noise (no sampling if `None`).

    Other
    -----
    device : optional
        Device on which all input maps are placed.
    return_combined : bool, default=True
        Return the combined MP2RAGE image rather than the two inversion images.

    Returns
    -------
    mp2rage : tensor, if return_combined is True
        Simulated MP2RAGE image
    image1 : tensor, if return_combined is False
        Image at first inversion time
    image2 : tensor, if return_combined is False
        Image at second inversion time

    Non-finite values in the returned images are replaced by zero.

    References
    ----------
    ..[1] "MP2RAGE, a self bias-field corrected sequence for improved
        segmentation and T1-mapping at high field."
        Marques JP, Kober T, Krueger G, van der Zwaag W, Van de Moortele PF, Gruetter R.
        Neuroimage. 2010 Jan 15;49(2):1271-81.
        doi: 10.1016/j.neuroimage.2009.10.002

    """
    maps = utils.to_max_backend(pd, r1, r2s, transmit, receive, gfactor)
    maps = utils.to(*maps, device=device)
    pd, r1, r2s, transmit, receive, gfactor = maps

    # --- timings -------------------------------------------------------
    if tx is None and te is None:
        tx = 5.8e-3
    if not tx:
        tx = 2 * te          # time between excitation pulses
    if not te:
        te = tx / 2          # echo time
    fa1, fa2 = py.make_list(fa, 2)
    fa1 = fa1 * constants.pi / 180   # first GRE block, in rad
    fa2 = fa2 * constants.pi / 180   # second GRE block, in rad
    if not n:
        n = min(pd.shape)    # readouts (PE steps) per loop
    tr1 = n * tx             # duration of first GRE block
    tr2 = n * tx             # duration of second GRE block
    tp = ti1 - tr1 / 2                   # preparation time
    tw = ti2 - tr2 / 2 - ti1 - tr1 / 2   # wait between GRE blocks
    td = tr - ti2 - tr2 / 2              # recovery time

    # --- noise-free combined image: closed form ------------------------
    if return_combined and not sigma:
        combined = mp2rage_nonoise(pd, r1, tx, tp, tw, td, tr, fa1, fa2, n,
                                   eff, transmit)
        return torch.where(torch.isfinite(combined), combined,
                           combined.new_zeros([1]))

    # --- simulate both inversion images, then add noise ----------------
    image1, image2 = mp2rage_uncombined(pd, r1, r2s, tx, tp, tw, td, tr, te,
                                        fa1, fa2, n, eff, transmit, receive)
    image1 = add_noise(image1, std=sigma, gfactor=gfactor)
    image2 = add_noise(image2, std=sigma, gfactor=gfactor)

    if not return_combined:
        image1 = torch.where(torch.isfinite(image1), image1,
                             image1.new_zeros([]))
        image2 = torch.where(torch.isfinite(image2), image2,
                             image2.new_zeros([]))
        return image1, image2

    combined = mp2rage_from_ir(image1, image2)
    return torch.where(torch.isfinite(combined), combined,
                       combined.new_zeros([1]))
コード例 #7
0
def fs_to_affine(shape,
                 voxel_size=1.,
                 x=None,
                 y=None,
                 z=None,
                 c=0.,
                 source='voxel',
                 dest='ras'):
    """Transform FreeSurfer orientation parameters into an affine matrix.

    The returned matrix is effectively a "<source> to <dest>" transform.

    Parameters
    ----------
    shape : sequence of int
        Lattice shape.
    voxel_size : [sequence of] float, default=1
    x : [sequence of] float, default=[1, 0, 0]
    y : [sequence of] float, default=[0, 1, 0]
    z : [sequence of] float, default=[0, 0, 1]
    c : [sequence of] float, default=0
        (presumably FreeSurfer's x_ras, y_ras, z_ras and c_ras; TODO confirm)
    source : {'voxel', 'physical', 'ras'} or layout, default='voxel'
    dest : {'voxel', 'physical', 'ras'} or layout, default='ras'
        Any other string is interpreted as an orientation layout and
        converted with `layout_matrix` (layout <-> RAS).

    Returns
    -------
    affine : (4, 4) tensor

    """
    dim = len(shape)
    # Fill in default axes *before* moving to a common backend, so that
    # they share the dtype/device of the other inputs.
    if x is None:
        x = [1, 0, 0]
    if y is None:
        y = [0, 1, 0]
    if z is None:
        z = [0, 0, 1]
    shape, voxel_size, x, y, z, c \
        = utils.to_max_backend(shape, voxel_size, x, y, z, c)
    backend = dict(dtype=shape.dtype, device=shape.device)
    voxel_size = utils.make_vector(voxel_size, dim)
    x = utils.make_vector(x, dim)
    y = utils.make_vector(y, dim)
    z = utils.make_vector(z, dim)
    c = utils.make_vector(c, dim)

    shift = shape / 2.
    shift = -shift * voxel_size
    vox2phys = Orientation(shift, voxel_size).affine()
    phys2ras = XYZC(x, y, z, c).affine()

    # lowercase copies for matching; the original strings are kept for
    # `layout_matrix`, which receives them unchanged.
    src = source.lower()
    dst = dest.lower()

    # First leg: source -> some "middle" space
    affines = []
    if src.startswith('vox'):
        affines.append(vox2phys)
        middle_space = 'phys'
    elif src.startswith('phys'):
        if dst.startswith('vox'):
            affines.append(affine_inv(vox2phys))
            middle_space = 'vox'
        else:
            affines.append(phys2ras)
            middle_space = 'ras'
    elif src == 'ras':
        affines.append(affine_inv(phys2ras))
        middle_space = 'phys'
    else:
        # We need a matrix to switch orientations (layout -> ras)
        affines.append(layout_matrix(source, **backend))
        middle_space = 'ras'

    # Second leg: middle space -> dest
    if dst.startswith('phys'):
        if middle_space == 'vox':
            affines.append(vox2phys)
        elif middle_space == 'ras':
            affines.append(affine_inv(phys2ras))
    elif dst.startswith('vox'):
        if middle_space == 'phys':
            affines.append(affine_inv(vox2phys))
        elif middle_space == 'ras':
            affines.append(affine_inv(phys2ras))
            affines.append(affine_inv(vox2phys))
    elif dst.startswith('ras'):
        if middle_space == 'phys':
            affines.append(phys2ras)
        elif middle_space == 'vox':
            affines.append(vox2phys)
            affines.append(phys2ras)
    else:
        # Go to RAS first, then from RAS to the requested layout.
        if middle_space == 'phys':
            affines.append(phys2ras)  # phys -> ras (was wrongly inverted)
        elif middle_space == 'vox':
            affines.append(vox2phys)
            affines.append(phys2ras)
        layout = layout_matrix(dest, **backend)
        affines.append(affine_inv(layout))

    # Compose left-to-right in application order
    affine, *affines = affines
    for aff in affines:
        affine = affine_matmul(aff, affine)
    return affine
コード例 #8
0
ファイル: _prototype_affine.py プロジェクト: balbasty/nitorch
def register(fixed=None,
             moving=None,
             dim=None,
             loss='mse',
             basis='CSO',
             optim='ogm',
             max_iter=500,
             lr=1,
             ls=6,
             plot=False,
             klosure=RegisterStep,
             logaff=None,
             verbose=True):
    """Affine registration between two images using Lie groups.

    Parameters
    ----------
    fixed : (..., K, *spatial) tensor, optional
        Fixed image. If `fixed` or `moving` is missing, a demo pair is
        generated with `phantoms.demo_register`.
    moving : (..., K, *spatial) tensor, optional
        Moving image
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    loss : {'mse', 'cat'} or OptimizationLoss, default='mse'
        'mse': Mean-squared error
        'cat': Categorical cross-entropy
    basis : str, default='CSO'
        Name of the Lie algebra basis, forwarded to `spatial.affine_basis`.
    optim : str, default='ogm'
        Optimizer, forwarded to `regutils.make_iteroptim_affine`.
        The original documentation lists, e.g.:
        'gn'        : Gauss-Newton
        'gd'        : Gradient descent
        'momentum'  : Gradient descent with momentum
        'nesterov'  : Nesterov-accelerated gradient descent
        'ogm'       : Optimized gradient descent (Kim & Fessler)
        'lbfgs'     : Limited-memory BFGS
    max_iter : int, default=500
        Maximum number of iterations
    lr : float, default=1
        Learning rate.
    ls : int, default=6
        Number of line search iterations.
    plot : bool, default=False
        Plot progress
    klosure : callable, default=RegisterStep
        Factory for the objective closure passed to the optimizer.
    logaff : (len(basis),) tensor, optional
        Initial log-affine parameters (zeros by default).
    verbose : bool, default=True
        Print a progress table.

    Returns
    -------
    logaff : tensor
        Optimized log-affine parameters (one per basis element),
        as returned by the optimizer.

    """
    # Without inputs, run the "circle to square" demo
    if fixed is None or moving is None:
        fixed, moving = phantoms.demo_register(cat=(loss == 'cat'))

    fixed, moving = utils.to_max_backend(fixed, moving)
    backend = utils.backend(fixed)
    dim = dim or (fixed.dim() - 1)
    basis = spatial.affine_basis(basis, dim, **backend)
    if logaff is None:
        logaff = torch.zeros(len(basis), **backend)

    optim = regutils.make_iteroptim_affine(optim, lr, ls, max_iter)
    loss = losses.make_loss(loss, dim)

    if verbose:
        header = f'{"it":3s} | {"fit":^12s} + {"reg":^12s} = {"obj":^12s} | {"gain":^12s}'
        print(header)
        print('-' * 63)

    closure = klosure(moving, fixed, loss,
                      basis=basis, verbose=verbose, plot=plot,
                      max_iter=optim.max_iter)
    logaff = optim.iter(logaff, closure)

    if verbose:
        print('')
    return logaff
コード例 #9
0
ファイル: _conv.py プロジェクト: balbasty/nitorch
def conv(dim, tensor, kernel, bias=None, stride=1, padding=0, bound='zero',
         dilation=1, groups=1):
    """Perform a convolution

    Parameters
    ----------
    dim : {1, 2, 3}
        Number of spatial dimensions
    tensor : (*batch, [channel_in,] *spatial_in) tensor
        Input tensor
    kernel : ([channel_in, channel_out,] *kernel_size) tensor
        Convolution kernel. Channel axes are ordered (input, output),
        i.e., transposed with respect to torch's weight layout.
    bias : ([channel_out,]) tensor, optional
        Bias tensor. A single value is broadcast to all output channels.
    stride : int or sequence[int], default=1
        Strides between output elements.
    padding : {'same', 'valid', 'auto'} or int or sequence, default=0
        Padding performed before the convolution.
        'valid' means no padding; 'auto' is an alias for 'same'
        (case-insensitive). With 'same', the padding is chosen such that
        the output shape is `ceil(spatial_in / stride)`; this requires
        odd kernel sizes.
    bound : str, default='zero'
        Boundary conditions used in the padding.
    dilation : int or sequence[int], default=1
        Dilation of the kernel.
    groups : int, default=1
        Forwarded to torch's convolution.

    Returns
    -------
    convolved : (*batch, [channel_out], *spatial_out) tensor

    Raises
    ------
    ValueError
        On inconsistent kernel/tensor/bias shapes, unknown padding
        modes, or 'same' padding with an even-sized kernel.
    NotImplementedError
        If `dim` is not 1, 2 or 3.

    """
    # move everything to the same dtype/device
    tensor, kernel, bias = utils.to_max_backend(tensor, kernel, bias)

    # sanity checks + reshape for torch's conv
    if kernel.dim() not in (dim, dim + 2):
        raise ValueError('Kernel shape should be (*kernel_size) or '
                         '(channel_in, channel_out, *kernel_size) but '
                         'got {}'.format(kernel.shape))
    has_channels = kernel.dim() == dim + 2
    channels_in = kernel.shape[0] if has_channels else 1
    channels_out = kernel.shape[1] if has_channels else 1
    kernel_size = kernel.shape[(2*has_channels):]
    kernel = kernel.reshape([channels_in, channels_out, *kernel_size])
    # torch expects weights laid out as (channel_out, channel_in, *kernel_size)
    kernel = kernel.transpose(0, 1)
    batch = tensor.shape[:-(dim+has_channels)]
    spatial_in = tensor.shape[(-dim):]
    if has_channels and tensor.shape[-(dim+has_channels)] != channels_in:
        raise ValueError('Number of input channels not consistent: '
                         'Got {} (kernel) and {} (tensor).' .format(
                         channels_in, tensor.shape[-(dim+has_channels)]))
    tensor = tensor.reshape([-1, channels_in, *spatial_in])
    # `if bias:` would raise on multi-element tensors and skip a zero bias
    if bias is not None:
        bias = bias.flatten()
        if bias.numel() == 1:
            bias = bias.expand(channels_out)
        elif bias.numel() != channels_out:
            raise ValueError('Number of output channels not consistent: '
                             'Got {} (kernel) and {} (bias).' .format(
                             channels_out, bias.numel()))

    # Resolve padding modes into explicit sizes
    dilation = make_list(dilation, dim)
    padding = list(make_list(padding, dim))
    for i, pad_mode in enumerate(padding):
        if not isinstance(pad_mode, str):
            continue
        pad_mode = pad_mode.lower()
        if pad_mode == 'valid':
            padding[i] = 0
            continue
        if pad_mode not in ('same', 'auto'):
            raise ValueError('Unknown padding mode {!r}'.format(padding[i]))
        if kernel_size[i] % 2 == 0:
            raise ValueError('Cannot compute "same" padding '
                             'for even-sized kernels.')
        padding[i] = dilation[i] * (kernel_size[i] // 2)
    # torch's conv only zero-pads: pad explicitly for other boundaries
    if bound != 'zero' and sum(padding) > 0:
        tensor = core.utils.pad(tensor, padding, bound, side='both')
        padding = 0

    conv_fn = (F.conv1d if dim == 1 else
               F.conv2d if dim == 2 else
               F.conv3d if dim == 3 else None)
    if not conv_fn:
        raise NotImplementedError('Convolution is only implemented in '
                                  'dimension 1, 2 or 3.')
    tensor = conv_fn(tensor, kernel, bias, stride=stride, padding=padding,
                     dilation=dilation, groups=groups)
    spatial_out = tensor.shape[(-dim):]
    out_channel_shape = [channels_out] if has_channels else []
    tensor = tensor.reshape([*batch, *out_channel_shape, *spatial_out])
    return tensor
コード例 #10
0
ファイル: cc.py プロジェクト: balbasty/nitorch
def lcc(moving,
        fixed,
        dim=None,
        patch=20,
        stride=1,
        lam=1,
        mode='g',
        grad=True,
        hess=True,
        mask=None):
    """Local correlation coefficient (squared)

    This function implements a squared version of Cachier and
    Pennec's local correlation coefficient, so that anti-correlations
    are not penalized.

    Parameters
    ----------
    moving : (..., K, *spatial) tensor
        Moving image with K channels.
    fixed : (..., K, *spatial) tensor
        Fixed image with K channels.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    patch : [sequence of] float, default=20
        Patch size, converted to floats and forwarded to `local_mean`.
    stride : [sequence of] int, default=1
        Stride between patches, forwarded to `local_mean`
        (falsy entries are replaced by 0).
    lam : float or ([B], K|1, [*spatial]) tensor_like, default=1
        Precision of the NCC distribution
    mode : str, default='g'
        Windowing mode forwarded to `local_mean`
        (presumably 'g' = Gaussian window; TODO confirm).
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return approximate Hessian
    mask : tensor, optional
        Voxel mask forwarded to `local_mean`.
        Defaults to ones of shape (*spatial).

    Returns
    -------
    ll : () tensor
        sum(mom0 * log(1 - corr**2)), where `corr` is the local
        correlation and `mom0` the normalized local weights times `lam`.
    grad : tensor, if `grad`
        Gradient (presumably w.r.t. `moving`, same shape; TODO confirm).
    hess : tensor, if `hess`
        Approximate Hessian (same caveat as `grad`).

    References
    ----------
    ..[1] "3D Non-Rigid Registration by Gradient Descent on a Gaussian-
           Windowed Similarity Measure using Convolutions"
          Pascal Cachier, Xavier Pennec
          MMBIA (2000)

    """
    # In-place variants are only used when autograd does not track `moving`.
    if moving.requires_grad:
        sqrt_ = torch.sqrt
        div_ = torch.div
    else:
        sqrt_ = torch.sqrt_
        div_ = lambda x, y: x.div_(y)

    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    dim = dim or (fixed.dim() - 1)
    shape = fixed.shape[-dim:]
    if mask is not None:
        mask = mask.to(**utils.backend(fixed))
    else:
        mask = fixed.new_ones(fixed.shape[-dim:])

    # pad spatial dimensions of low-rank precisions
    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)

    patch = list(map(float, py.ensure_list(patch)))
    stride = py.ensure_list(stride)
    stride = [s or 0 for s in stride]
    # NOTE(review): `local_cache` is not defined in this function; it must
    # be a module-level name, otherwise `fwd`/`bwd` raise NameError. Confirm.
    fwd = lambda x: local_mean(
        x, patch, stride, dim=dim, mode=mode, mask=mask, cache=local_cache)
    # adjoint of `fwd`, mapping patch statistics back onto the image lattice
    bwd = lambda x: local_mean(x,
                               patch,
                               stride,
                               dim=dim,
                               mode=mode,
                               mask=mask,
                               backward=True,
                               shape=shape,
                               cache=local_cache)
    sumall = lambda x: x.sum(list(range(-dim, 0)), keepdim=True)

    # compute ncc within each patch
    # As used below: moving_std/fixed_std hold local second moments and
    # corr the local cross moment until they are centered and normalized.
    mom0, moving_mean, fixed_mean, moving_std, fixed_std, corr = \
        _suffstat(fwd, moving, fixed)
    mom0 = mom0.div_(sumall(mom0).clamp_min_(1e-5)).mul_(lam)
    # std = sqrt(E[x^2] - E[x]^2)
    moving_std = sqrt_(moving_std.addcmul_(moving_mean, moving_mean, value=-1))
    fixed_std = sqrt_(fixed_std.addcmul_(fixed_mean, fixed_mean, value=-1))
    # NOTE(review): these in-place clamps (and `corr2.log_()` below) also run
    # when `moving.requires_grad`; they may conflict with autograd saving
    # sqrt/div outputs for backward. Confirm.
    moving_std.clamp_min_(1e-5)
    fixed_std.clamp_min_(1e-5)
    # corr = (E[xy] - E[x]E[y]) / (std_x * std_y)
    corr = div_(
        div_(corr.addcmul_(moving_mean, fixed_mean, value=-1), moving_std),
        fixed_std)
    # 1 - corr^2, kept strictly positive for the log below
    corr2 = corr.square().neg_().add_(1).clamp_min_(1e-8)

    out = []
    if grad or hess:
        # shared term G' * (mom0 * (corr / xstd)^2 / (1 - corr^2))
        h = (corr / moving_std).square_().mul_(mom0).div_(corr2)
        h = bwd(h)

        if grad:
            # g = G' * (corr.*(corr.*xmean./xstd - ymean./ystd)./xstd)
            #   - x .* (G' * (corr./ xstd).^2)
            #   + y .* (G' * (corr ./ (xstd.*ystd)))
            # g = -2 * g
            fixed_mean = fixed_mean.div_(fixed_std)
            moving_mean = moving_mean.div_(moving_std)
            g = fixed_mean.addcmul_(corr, moving_mean, value=-1)
            fixed_mean = moving_mean = None
            g = g.mul_(corr).div_(moving_std).mul_(mom0).div_(corr2)
            g = bwd(g)
            g = g.addcmul_(h, moving)
            # `corr` is modified in place here; it is not needed afterwards
            g = g.addcmul_(bwd(
                corr.div_(moving_std).div_(fixed_std).mul_(mom0).div_(corr2)),
                           fixed,
                           value=-1)
            g = g.mul_(2)
            out.append(g)

        if hess:
            # h = 2 * (G' * (corr./ xstd).^2)
            h = h.mul_(2)
            out.append(h)

    # return stuff
    # ll = sum(mom0 * log(1 - corr^2))
    corr = corr2.log_().mul_(mom0)
    corr = corr.sum()
    out = [corr, *out]
    return tuple(out) if len(out) > 1 else out[0]
コード例 #11
0
ファイル: fse.py プロジェクト: balbasty/nitorch
def fse(pd,
        r1,
        r2=None,
        receive=None,
        gfactor=None,
        te=0.02,
        tr=5,
        sigma=None,
        device=None):
    """Simulate data generated by a (simplified) Fast Spin-Echo (FSE) sequence.

    Signal model: pd * receive * (1 - exp(-tr * r1)) * exp(-te * r2)

    Tissue parameters
    -----------------
    pd : tensor_like
        Proton density
    r1 : tensor_like
        Longitudinal relaxation rate, in 1/sec
    r2 : tensor_like, optional
        Transverse relaxation rate, in 1/sec.
        If not provided, T2 decay is not included.

    Fields
    ------
    receive : tensor_like, optional
        Receive B1 field
    gfactor : tensor_like, optional
        G-factor map.
        If provided and `sigma` is not `None`, the g-factor map is used
        to sample non-stationary noise.

    Sequence parameters
    -------------------
    te : float, default=0.02
        Echo time, in sec
    tr : float, default=5
        Repetition time, in sec.

    Noise
    -----
    sigma : float, optional
        Standard-deviation of the sampled noise (no sampling if `None`)

    Other
    -----
    device : optional
        Device on which all input maps are placed.

    Returns
    -------
    sim : tensor
        Simulated FSE image

    """

    pd, r1, r2, receive, gfactor \
        = utils.to_max_backend(pd, r1, r2, receive, gfactor)
    pd, r1, r2, receive, gfactor \
        = utils.to(pd, r1, r2, receive, gfactor, device=device)

    if receive is not None:
        pd = pd * receive
    del receive

    e1 = r1.mul(tr).neg_().exp_()
    # r2 is optional: without it, no T2 weighting is applied (e2 == 1)
    e2 = r2.mul(te).neg_().exp_() if r2 is not None else 1

    signal = pd * (1 - e1) * e2

    # noise (the documented g-factor was previously never forwarded)
    signal = add_noise(signal, std=sigma, gfactor=gfactor)
    return signal
コード例 #12
0
def physio_sample(shape=None,
                  sigma_p=0.008,
                  lam_p=0.4,
                  sigma_0=2.,
                  sigma_r=1.,
                  lam_r=0.2,
                  signal=100.,
                  repeats=100,
                  sampler='svd',
                  **backend):
    """Sample from the fMRI physiological model.

    Parameters
    ----------
    shape : list[int], default=[32, 32]
        Shape of the field of view
    sigma_p : float, default=0.008
        Amplitude of the physiological noise.
    lam_p : float, default=0.4
        Length-scale of the physiological noise (i.e., smoothness).
    sigma_0 : float, default=2.0
        Amplitude of the thermal noise.
    sigma_r : float, default=1.
        Amplitude of the reconstruction filter.
    lam_r : float, default=0.2
        Length-scale of the reconstruction filter.
    signal : float, default=100.
        Mean signal.
    repeats : int, default=100
        Number of repeats in the time series
    sampler : str, default='svd'
        Sampling method, forwarded to `se_sample`.
    **backend
        dtype/device options forwarded to `utils.to_max_backend`.

    Returns
    -------
    time_series : (repeats, *shape) tensor[dtype]
        fMRI time series. Forward model: ReconFilter(Signal * Physio + Thermal)
    replicate_series : (repeats, *shape) tensor[dtype]
        Replicate series. Forward model: ReconFilter(Signal + Thermal)

    """
    if shape is None:
        shape = [32, 32]
    dim = len(shape)

    # Amplitudes / length-scales of each noise process once convolved
    # with the reconstruction filter.
    sigma_p_recon = sigma_p * sigma_r * (1 + 2 * (lam_r**2) / (lam_p**2))**(-dim / 4)
    sigma_0_recon = sigma_0 * sigma_r * (4. * constants.pi * lam_r**2)**(-dim / 4)
    lam_p_recon = (lam_p**2 + 2. * lam_r**2)**0.5
    lam_0_recon = (2.**0.5) * lam_r

    sigma_p_recon, sigma_0_recon, lam_p_recon, lam_0_recon = utils.to_max_backend(
        sigma_p_recon, sigma_0_recon, lam_p_recon, lam_0_recon, **backend)
    backend = utils.backend(sigma_p_recon)

    def draw(amplitude, length_scale):
        return se_sample(shape,
                         amplitude,
                         length_scale,
                         repeats=repeats,
                         sampler=sampler,
                         **backend)

    # Draw order matters for reproducibility: physio, thermal, thermal.
    physio = draw(sigma_p_recon, lam_p_recon)
    thermal = draw(sigma_0_recon, lam_0_recon)
    replicate_thermal = draw(sigma_0_recon, lam_0_recon)

    time_series = signal * (1. + physio) + thermal
    replicate_series = signal + replicate_thermal
    return time_series, replicate_series
コード例 #13
0
ファイル: segmentation.py プロジェクト: liamchalcroft/nitorch
    def compose(self,
                orient_in,
                deformation,
                orient_mean,
                affine=None,
                orient_out=None,
                shape_out=None):
        """Composes a deformation defined in a mean space to an image space.

        Parameters
        ----------
        orient_in : (4, 4) tensor
            Orientation of the input image
        deformation : (*shape_mean, 3) tensor
            Random deformation
        orient_mean : (4, 4) tensor
            Orientation of the mean space (where the deformation is)
        affine : (4, 4) tensor, default=identity
            Random affine
        orient_out : (4, 4) tensor, default=orient_in
            Orientation of the output image
        shape_out : sequence[int], default=shape_mean
            Shape of the output image

        Returns
        -------
        grid : (*shape_out, 3)
            Voxel-to-voxel transform

        """
        if orient_out is None:
            orient_out = orient_in
        if shape_out is None:
            shape_out = deformation.shape[:-1]
        if affine is None:
            affine = torch.eye(4,
                               4,
                               device=orient_in.device,
                               dtype=orient_in.dtype)
        shape_mean = deformation.shape[:-1]

        orient_in, affine, deformation, orient_mean, orient_out \
            = utils.to_max_backend(orient_in, affine, deformation, orient_mean, orient_out)
        backend = utils.backend(deformation)
        eye = torch.eye(4, **backend)

        # Compose deformation on the right
        right_affine = spatial.affine_lmdiv(orient_mean, orient_out)
        # Compare as tuples: a torch.Size never equals a list.
        same_space = (tuple(shape_mean) == tuple(shape_out)
                      and torch.allclose(right_affine, eye))
        if same_space:
            # output lattice == mean lattice: use the deformation as is
            trf = deformation
        else:
            # the mean space and native space are not the same
            # we must compose the diffeo with a dense affine transform
            # we write the diffeo as an identity plus a displacement
            # (id + disp)(aff) = aff + disp(aff)
            # -------
            # to displacement
            deformation = deformation - spatial.identity_grid(
                deformation.shape[:-1], **backend)
            trf = spatial.affine_grid(right_affine, shape_out)
            deformation = spatial.grid_pull(utils.movedim(deformation, -1,
                                                          0)[None],
                                            trf[None],
                                            bound='dft',
                                            extrapolate=True)
            deformation = utils.movedim(deformation[0], 0, -1)
            trf = trf + deformation  # add displacement

        # Compose deformation on the left
        #   the output of the diffeo(right) are mean_space voxels
        #   we must compose on the left with `in\(aff(mean))`
        # -------
        left_affine = spatial.affine_matmul(spatial.affine_inv(orient_in),
                                            affine)
        left_affine = spatial.affine_matmul(left_affine, orient_mean)
        trf = spatial.affine_matvec(left_affine, trf)

        return trf
コード例 #14
0
def register(fixed=None,
             moving=None,
             dim=None,
             lam=1.,
             loss='mse',
             optim='nesterov',
             hilbert=None,
             max_iter=500,
             sub_iter=16,
             lr=None,
             ls=0,
             plot=False,
             klosure=RegisterStep,
             kernel=None,
             **prm):
    """Nonlinear registration between two images using smooth displacements.

    If either `fixed` or `moving` is missing, a demo pair produced by
    `phantoms.demo_register` is registered instead.

    Parameters
    ----------
    fixed : (..., K, *spatial) tensor, optional
        Fixed image
    moving : (..., K, *spatial) tensor, optional
        Moving image
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    lam : float, default=1
        Modulate regularisation. It is divided by the number of voxels
        and forwarded as `prm['factor']`.
    loss : {'mse', 'cat'} or OptimizationLoss, default='mse'
        'mse': Mean-squared error
        'cat': Categorical cross-entropy
    optim : {'relax', 'cg', 'gd', 'momentum', 'nesterov', 'lbfgs'}, default='nesterov'
        'relax'     : Gauss-Newton (linear system solved by relaxation)
        'cg'        : Gauss-Newton (linear system solved by conjugate gradient)
        'gd'        : Gradient descent
        'momentum'  : Gradient descent with momentum
        'nesterov'  : Nesterov-accelerated gradient descent
        'lbfgs'     : Limited-memory BFGS
    hilbert : bool, optional
        Use Hilbert preconditioning.
        Default: True if the optimizer does not require a Hessian
        (`not optim.requires_hess`), False otherwise.
    max_iter : int, default=500
        Maximum number of Gauss-Newton or Gradient descent iterations
    sub_iter : int, default=16
        Number of relax/cg iterations per GN step
    lr : float, optional
        Learning rate. Forwarded to `regutils.make_iteroptim_grid`,
        which picks the default.
    ls : int, default=0
        Number of line search iterations.
    plot : bool, default=False
        Forwarded to the closure.
    klosure : callable, default=RegisterStep
        Closure factory, called as
        `klosure(moving, fixed, loss, plot=plot, max_iter=..., **prm)`.
    kernel : tensor, optional
        Precomputed Greens kernel. When provided it is always used as
        preconditioner. When absent and `hilbert` is True, it is computed
        with `spatial.greens(shape, **prm)`.

    Other Parameters
    ----------------
    **prm
        Regularisation parameters, completed by `defaults_velocity` and
        forwarded to the optimizer factory, `spatial.greens` and the
        closure. Documented defaults (see `defaults_velocity`):
        absolute : float, default=1e-4
            Penalty on absolute displacements
        membrane : float, default=1e-3
            Penalty on first derivatives
        bending : float, default=0.2
            Penalty on second derivatives
        lame : (float, float), default=(0.05, 0.2)
            Penalty on zooms and shears

    Returns
    -------
    disp : (..., *spatial, dim) tensor
        Displacement field (the field optimised by `optim.iter`).

    """
    defaults_velocity(prm)

    # If no inputs provided: demo "circle to square"
    if fixed is None or moving is None:
        fixed, moving = phantoms.demo_register(cat=(loss == 'cat'))

    # init tensors
    fixed, moving = utils.to_max_backend(fixed, moving)
    dim = dim or (fixed.dim() - 1)
    shape = fixed.shape[-dim:]
    # normalise the regularisation by the number of voxels
    lam = lam / py.prod(shape)
    prm['factor'] = lam
    velshape = [*fixed.shape[:-dim - 1], *shape, dim]
    vel = torch.zeros(velshape, **utils.backend(fixed))

    # init optimizer
    optim = regutils.make_iteroptim_grid(optim, lr, ls, max_iter, sub_iter,
                                         **prm)
    if hilbert is None:
        # preconditioning only makes sense for first-order optimizers
        hilbert = not optim.requires_hess
    if hilbert and kernel is None:
        kernel = spatial.greens(shape, **prm, **utils.backend(fixed))
    if kernel is not None:
        optim.preconditioner = lambda x: spatial.greens_apply(x, kernel)

    # init loss
    loss = losses.make_loss(loss, dim)

    print(
        f'{"it":3s} | {"fit":^12s} + {"reg":^12s} = {"obj":^12s} | {"gain":^12s}'
    )
    print('-' * 63)
    closure = klosure(moving,
                      fixed,
                      loss,
                      plot=plot,
                      max_iter=optim.max_iter,
                      **prm)
    vel = optim.iter(vel, closure)
    print('')
    return vel
コード例 #15
0
ファイル: mprage.py プロジェクト: balbasty/nitorch
def mprage(pd,
           r1,
           r2s=None,
           transmit=None,
           receive=None,
           gfactor=None,
           tr=2.3,
           ti=0.9,
           tx=None,
           te=None,
           fa=9,
           n=160,
           eff=0.96,
           sigma=None,
           device=None):
    """Simulate data generated by a (simplified) MP-RAGE sequence.

    Default parameters mimic the ADNI-3 protocol on 3T Siemens scanners.
    Our implementation is based on the MP2RAGE paper, where the sequence
    is stripped from the second GRE readout block.

    Tissue parameters
    -----------------
    pd : tensor_like
        Proton density
    r1 : tensor_like
        Longitudinal relaxation rate, in 1/sec
    r2s : tensor_like, optional
        Transverse relaxation rate, in 1/sec.
        If not provided, T2*-bias is not included.

    Fields
    ------
    transmit : tensor_like, optional
        Transmit B1 field
    receive : tensor_like, optional
        Receive B1 field
    gfactor : tensor_like, optional
        G-factor map.
        If provided and `sigma` is not `None`, the g-factor map is used
        to sample non-stationary noise.

    Sequence parameters
    -------------------
    tr : float default=2.3
        Repetition time, in sec.
        (Time between two inversion pulses)
    ti : float, default=0.9
        Inversion time, in sec.
        (Time between inversion pulse and middle of the echo train)
    tx : float, default=2*te or 6e-3
        Excitation repetition time, in sec
        (Time between two excitation pulses within the echo train)
    te : float, default=tx/2
        Echo time, in sec
    fa : float, default=9
        Flip angle, in deg
    n : int, default=160
        Number of excitation pulses (= phase encoding steps) per train.
    eff : float, default=0.96
        Efficiency of the inversion pulse.

    Noise
    -----
    sigma : float, optional
        Standard-deviation of the sampled Rician noise (no sampling if `None`)

    Other
    -----
    device : torch.device, optional
        Device to which the input maps are moved.

    Returns
    -------
    sim : tensor
        Simulated MPRAGE image

    References
    ----------
    ..[1] "MP2RAGE, a self bias-field corrected sequence for improved
        segmentation and T1-mapping at high field."
        Marques JP, Kober T, Krueger G, van der Zwaag W, Van de Moortele PF, Gruetter R.
        Neuroimage. 2010 Jan 15;49(2):1271-81.
        doi: 10.1016/j.neuroimage.2009.10.002

    """

    pd, r1, r2s, transmit, receive, gfactor \
        = utils.to_max_backend(pd, r1, r2s, transmit, receive, gfactor)
    pd, r1, r2s, transmit, receive, gfactor \
        = utils.to(pd, r1, r2s, transmit, receive, gfactor, device=device)
    backend = utils.backend(pd)

    if tx is None and te is None:
        tx = 6e-3
    tx = tx or 2 * te  # Time between excitation pulses
    te = te or tx / 2  # Echo time
    fa = fa * constants.pi / 180  # Flip angle of GRE block
    n = n or min(pd.shape)  # Number of readouts (PE steps) per loop
    tr1 = n * tx  # GRE block
    tp = ti - tr1 / 2  # Preparation time
    td = tr - ti - tr1 / 2  # Recovery time
    m = n // 2  # Middle of echo train

    if transmit is not None:
        fa = transmit * fa
    del transmit
    fa = torch.as_tensor(fa, **backend)

    # precompute exponential terms
    ex = r1.mul(-tx).exp()
    ep = r1.mul(-tp).exp()
    ed = r1.mul(-td).exp()
    e1 = r1.mul(-tr).exp()
    c = fa.cos()
    cex = c * ex  # per-pulse longitudinal attenuation

    # Steady-state magnetization, normalised by PD (i.e. mss / pd).
    # Working with the normalised quantity avoids multiplying by pd and
    # dividing it back out, which yields 0/0 = NaN where pd == 0.
    s = (1 - ep) * cex.pow(n)
    s = s + (1 - ex) * (1 - cex.pow(n)) / (1 - cex)
    s = s * ed + (1 - ed)
    s = s / (1 + eff * c.pow(n) * e1)

    # IR component
    s = -eff * s * ep + (1 - ep)
    s = s * cex.pow(m - 1)
    s = s + (1 - ex) * (1 - cex.pow(m - 1)) / (1 - cex)
    s = s * fa.sin()
    s = s.abs()

    # Modulation (PD, B1-, R2*)
    if receive is not None:
        pd = pd * receive
    del receive

    s = s * pd

    if r2s is not None:
        e2 = r2s.mul(-te).exp_()
        s = s * e2
    del r2s

    # noise
    s = add_noise(s, std=sigma, gfactor=gfactor)
    return s
コード例 #16
0
def mp2rage_old(pd,
                r1,
                r2s=None,
                transmit=None,
                receive=None,
                gfactor=None,
                tr=6.25,
                ti1=0.8,
                ti2=2.2,
                tx=None,
                te=None,
                fa=(4, 5),
                n=160,
                eff=0.96,
                sigma=None,
                device=None,
                return_combined=True):
    """Simulate data generated by a (simplified) MP2RAGE sequence.

    The defaults are parameters used at 3T in the original MP2RAGE paper.
    However, I don't get a nice image with these parameters when applied
    to maps obtained at 3T with the hmri toolbox.
    Here are (unrealistic) parameters that seem to give a decent contrast:
    tr=6.25, ti1=1.4, ti2=4.5, tx=5.8e-3, fa=(4, 5), n=160, eff=0.96

    Tissue parameters
    -----------------
    pd : tensor_like
        Proton density
    r1 : tensor_like
        Longitudinal relaxation rate, in 1/sec
    r2s : tensor_like, optional
        Transverse relaxation rate, in 1/sec.
        If not provided, T2*-bias is not included.

    Fields
    ------
    transmit : tensor_like, optional
        Transmit B1 field
    receive : tensor_like, optional
        Receive B1 field
    gfactor : tensor_like, optional
        G-factor map.
        If provided and `sigma` is not `None`, the g-factor map is used
        to sample non-stationary noise.

    Sequence parameters
    -------------------
    tr : float default=6.25
        Full Repetition time, in sec.
        (Time between two inversion pulses)
    ti1 : float, default=0.8
        First inversion time, in sec.
        (Time between inversion pulse and middle of the first echo train)
    ti2 : float, default=2.2
        Second inversion time, in sec.
        (Time between inversion pulse and middle of the second echo train)
    tx : float, default=te*2 or 5.8e-3
        Excitation repetition time, in sec.
        (Time between two excitation pulses within the echo train)
    te : float, default=tx/2
        Echo time, in sec.
    fa : float or (float, float), default=(4, 5)
        Flip angle of the first and second acquisition block, in deg
        If only one value, it is shared between the blocks.
    n : int, default=160
        Number of excitation pulses (= phase encoding steps) per train.
    eff : float, default=0.96
        Efficiency of the inversion pulse.

    Noise
    -----
    sigma : float, optional
        Standard-deviation of the sampled Rician noise (no sampling if `None`)

    Other
    -----
    device : torch.device, optional
        Device to which the input maps are moved.
    return_combined : bool, default=True
        If True, return the combined image
        `mi1 * mi2 / (mi1**2 + mi2**2)`; otherwise return both
        inversion images. Non-finite values are replaced by zero.

    Returns
    -------
    mp2rage : tensor, if return_combined is True
        Simulated MP2RAGE image

    image1 : tensor, if return_combined is False
        Image at first inversion time
    image2 : tensor, if return_combined is False
        Image at second inversion time

    References
    ----------
    ..[1] "MP2RAGE, a self bias-field corrected sequence for improved
        segmentation and T1-mapping at high field."
        Marques JP, Kober T, Krueger G, van der Zwaag W, Van de Moortele PF, Gruetter R.
        Neuroimage. 2010 Jan 15;49(2):1271-81.
        doi: 10.1016/j.neuroimage.2009.10.002

    """

    pd, r1, r2s, transmit, receive, gfactor \
        = utils.to_max_backend(pd, r1, r2s, transmit, receive, gfactor)
    pd, r1, r2s, transmit, receive, gfactor \
        = utils.to(pd, r1, r2s, transmit, receive, gfactor, device=device)
    backend = utils.backend(pd)

    if tx is None and te is None:
        tx = 5.8e-3
    tx = tx or 2 * te  # Time between excitation pulses
    te = te or tx / 2  # Echo time
    fa1, fa2 = py.make_list(fa, 2)
    fa1 = fa1 * constants.pi / 180  # Flip angle of first GRE block
    fa2 = fa2 * constants.pi / 180  # Flip angle of second GRE block
    n = n or min(pd.shape)  # Number of readouts (PE steps) per loop
    tr1 = n * tx  # First GRE block
    tr2 = n * tx  # Second GRE block
    tp = ti1 - tr1 / 2  # Preparation time
    tw = ti2 - tr2 / 2 - ti1 - tr1 / 2  # Wait time between GRE blocks
    td = tr - ti2 - tr2 / 2  # Recovery time
    m = n // 2  # Middle of echo train

    # effective flip angles (nominal angle scaled by the transmit field)
    if transmit is not None:
        fa1 = transmit * fa1
        fa2 = transmit * fa2
    del transmit
    fa1 = torch.as_tensor(fa1, **backend)
    fa2 = torch.as_tensor(fa2, **backend)

    # precompute exponential terms
    ex = r1.mul(-tx).exp()
    ep = r1.mul(-tp).exp()
    ew = r1.mul(-tw).exp()
    ed = r1.mul(-td).exp()
    e1 = r1.mul(-tr).exp()
    c1 = fa1.cos()
    c2 = fa2.cos()

    # steady state
    # NOTE(review): `mss` is multiplied by pd here and divided by pd again
    # in both IR components below. Where pd == 0 this gives 0/0 = NaN,
    # which is only replaced by zero by the `torch.where` calls at the end
    # (so such voxels receive no noise). Consider working with mss / pd.
    mss = (1 - ep) * (c1 * ex).pow(n)
    mss = mss + (1 - ex) * (1 - (c1 * ex).pow(n)) / (1 - c1 * ex)
    mss = mss * ew + (1 - ew)
    mss = mss * (c2 * ex).pow(n)
    mss = mss + (1 - ex) * (1 - (c2 * ex).pow(n)) / (1 - c2 * ex)
    mss = mss * ed + (1 - ed)
    mss = mss * pd / (1 + eff * (c1 * c2).pow(n) * e1)

    # IR components
    # first inversion image: propagated forward from the inversion pulse
    mi1 = -eff * mss * ep / pd + (1 - ep)
    mi1 = mi1 * (c1 * ex).pow(m - 1)
    mi1 = mi1 + (1 - ex) * (1 - (c1 * ex).pow(m - 1)) / (1 - c1 * ex)
    mi1 = mi1 * fa1.sin()
    mi1 = mi1.abs()

    # second inversion image: propagated backward from the steady state
    mi2 = (mss / pd - (1 - ed)) / (ed * (c2 * ex).pow(m))
    mi2 = mi2 + (1 - ex) * (1 - (c2 * ex).pow(-m)) / (1 - c2 * ex)
    mi2 = mi2 * fa2.sin()
    mi2 = mi2.abs()

    # Without noise, the combined image is invariant to a common scaling
    # of mi1 and mi2, so the PD/receive/T2* modulation can be skipped.
    if return_combined and not sigma:
        m = (mi1 * mi2) / (mi1.square() + mi2.square())
        m = torch.where(~torch.isfinite(m), m.new_zeros([]), m)
        return m

    # Common component (pd, B1-, R2*)
    if receive is not None:
        pd = pd * receive
    del receive

    mi1 = mi1 * pd
    mi2 = mi2 * pd

    if r2s is not None:
        e2 = r2s.mul(-te).exp_()
        mi1 = mi1 * e2
        mi2 = mi2 * e2
    del r2s

    # noise
    mi1 = add_noise(mi1, std=sigma, gfactor=gfactor)
    mi2 = add_noise(mi2, std=sigma, gfactor=gfactor)

    if return_combined:
        m = (mi1 * mi2) / (mi1.square() + mi2.square())
        m = torch.where(~torch.isfinite(m), m.new_zeros([]), m)
        return m
    else:
        mi1 = torch.where(~torch.isfinite(mi1), mi1.new_zeros([]), mi1)
        mi2 = torch.where(~torch.isfinite(mi2), mi2.new_zeros([]), mi2)
        return mi1, mi2
コード例 #17
0
ファイル: mse.py プロジェクト: balbasty/nitorch
def mse(moving, fixed, lam=1, dim=None, grad=True, hess=True, mask=None):
    """Mean-squared error loss for optimisation-based registration.

    (A factor 1/2 is included, and the loss is averaged across voxels,
    but not across channels or batches)

    Parameters
    ----------
    moving : ([B], K, *spatial) tensor
        Moving image
    fixed : ([B], K, *spatial) tensor
        Fixed image
    lam : float or ([B], K|1, [*spatial]) tensor_like
        Gaussian noise precision (or IRLS weights)
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return Hessian
    mask : tensor, optional
        Voxel-wise weights (e.g. a binary mask), broadcastable to the
        images. Each voxel's squared error is multiplied by its weight,
        consistently in the loss, gradient and Hessian.

    Returns
    -------
    ll : () tensor
        Negative log-likelihood
    g : (..., K, *spatial) tensor, optional
        Gradient with respect to the moving image
    h : (..., K, *spatial) tensor, optional
        (Diagonal) Hessian with respect to the moving image

    """
    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    if mask is not None:
        mask = mask.to(fixed.device)
    dim = dim or (fixed.dim() - 1)
    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)  # pad spatial dimensions
    nvox = py.prod(fixed.shape[-dim:])

    # ll = 0.5 * sum(mask * lam * (moving - fixed) ** 2) / nvox
    # The mask weights the *squared* residual so that the loss is the
    # function whose derivatives are `g` and `h` below.
    ll = moving - fixed
    # in-place square is only safe when autograd does not track the
    # residual (either `moving` or `fixed` may require grad)
    ll = ll.square() if ll.requires_grad else ll.square_()
    if mask is not None:
        ll = ll * mask
    ll = ll.mul_(lam).sum() / (2 * nvox)

    out = [ll]
    if grad:
        g = moving - fixed
        if mask is not None:
            g = g * mask
        g = g.mul_(lam).div_(nvox)
        out.append(g)
    if hess:
        h = lam / nvox
        if mask is not None:
            h = mask * h
        out.append(h)

    return tuple(out) if len(out) > 1 else out[0]
コード例 #18
0
ファイル: conversions.py プロジェクト: liamchalcroft/nitorch
 def __init__(self, x, y, z, c):
     """Store the components `x`, `y`, `z` and `c` on a common backend.

     The four inputs are converted jointly with `utils.to_max_backend`
     before being kept as attributes of the same name.
     """
     converted = utils.to_max_backend(x, y, z, c)
     self.x, self.y, self.z, self.c = converted
コード例 #19
0
def intensity_preproc(*images, min=None, max=None, eq=None):
    """(Joint) rescaling and intensity equalizing.

    Parameters
    ----------
    *images : (*batch, H, W) tensor
        Input (batch of) 2d images.
        All batch shapes should be broadcastable together.
    min : tensor_like, optional
        Minimum value. Should be broadcastable to batch.
        Default: 5th percentile of each batch element.
    max : tensor_like, optional
        Maximum value. Should be broadcastable to batch.
        Default: 95th percentile of each batch element.
    eq : {'linear', 'quadratic', 'log', None} or float, default=None
        Apply histogram equalization.
        If 'quadratic' or 'log', the histogram of the transformed signal
        is equalized.
        If float, the signal is taken to that power before being equalized.

    Returns
    -------
    *images : (*batch, H, W) tensor
        Preprocessed images.
        Intensities are scaled within [0, 1].

    """

    if len(images) == 1:
        images = [utils.to_max_backend(*images)]
    else:
        images = utils.to_max_backend(*images)
    backend = utils.backend(images[0])
    eps = constants.eps(images[0].dtype)

    # rescale min/max
    # User-provided bounds are (*batch)-shaped: trailing singleton spatial
    # dimensions make them broadcast like the (*batch, 1, 1) quantiles.
    min = py.make_list(min, len(images))
    max = py.make_list(max, len(images))
    min = [
        utils.quantile(image, 0.05, bins=2048, dim=[-1, -2], keepdim=True)
        if mn is None else torch.as_tensor(mn, **backend)[..., None, None]
        for image, mn in zip(images, min)
    ]
    min, *othermin = min
    for mn in othermin:
        min = torch.min(min, mn)
    del othermin
    max = [
        utils.quantile(image, 0.95, bins=2048, dim=[-1, -2], keepdim=True)
        if mx is None else torch.as_tensor(mx, **backend)[..., None, None]
        for image, mx in zip(images, max)
    ]
    max, *othermax = max
    for mx in othermax:
        max = torch.max(max, mx)
    del othermax
    images = [torch.max(torch.min(image, max), min) for image in images]
    # map [min, max] -> [0, 1]; eps guards against max == min
    images = [image.sub_(min).div_(max - min + eps) for image in images]

    if not eq:
        return tuple(images) if len(images) > 1 else images[0]

    # reshape and concatenate
    batch = utils.expanded_shape(*[image.shape[:-2] for image in images])
    images = [image.expand([*batch, *image.shape[-2:]]) for image in images]
    shapes = [image.shape[-2:] for image in images]
    chunks = [py.prod(s) for s in shapes]
    images = [image.reshape([*batch, c]) for image, c in zip(images, chunks)]
    images = torch.cat(images, dim=-1)

    if eq is True:
        eq = 'linear'
    if not isinstance(eq, str):
        if eq >= 0:
            images = images.pow(eq)
        else:
            images = images.clamp_min_(constants.eps(images.dtype)).pow(eq)
    elif eq.startswith('q'):
        images = images.square()
    elif eq.startswith('log'):
        images = images.clamp_min_(constants.eps(images.dtype)).log()

    images = histeq(images, dim=-1)

    if not (isinstance(eq, str) and eq.startswith('lin')):
        # rescale min/max
        images -= math.min(images, dim=-1, keepdim=True)
        images /= math.max(images, dim=-1, keepdim=True)

    images = images.split(chunks, dim=-1)
    images = [image.reshape(*batch, *s) for image, s in zip(images, shapes)]

    return tuple(images) if len(images) > 1 else images[0]
コード例 #20
0
ファイル: conversions.py プロジェクト: liamchalcroft/nitorch
 def __init__(self, axis, angle):
     """Store `axis` (as `ax`) and `angle` (as `theta`) on a common backend."""
     self.ax, self.theta = utils.to_max_backend(axis, angle)
コード例 #21
0
ファイル: cc.py プロジェクト: balbasty/nitorch
def cc(moving, fixed, dim=None, grad=True, hess=True, mask=None):
    """Log of one minus the squared Pearson's correlation coefficient.

        ll = sum log(1 - corr ** 2)
        corr = E[(x - mu_x)(y - mu_y)] / (s_x * s_y)

    Expectations are taken over the spatial dimensions (restricted to /
    weighted by `mask` when provided), giving one correlation per
    channel and batch element; their log-terms are summed.

    Parameters
    ----------
    moving : (..., K, *spatial) tensor
        Moving image with K channels.
    fixed : (..., K, *spatial) tensor
        Fixed image with K channels.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return approximate Hessian
    mask : tensor, optional
        Voxel weights (e.g. a binary mask) broadcastable to the images.
        Means are computed as weighted averages over the mask.

    Returns
    -------
    ll : () tensor
        Loss value.
    g : (..., K, *spatial) tensor, optional
        Gradient with respect to the moving image.
    h : (..., K, *spatial) tensor, optional
        Approximate (diagonal) Hessian with respect to the moving image.

    """
    moving, fixed = utils.to_max_backend(moving, fixed)
    moving = moving.clone()
    fixed = fixed.clone()
    dim = dim or (fixed.dim() - 1)
    dims = list(range(-dim, 0))

    if mask is not None:
        mask = mask.to(fixed.device)
        # Effective number of voxels. It is the same normaliser as the one
        # used by `mean`, so that `g` and `h` are derivatives of the loss.
        n = mask.sum(dim=dims, keepdim=True)

        def mean(x):
            return (x * mask).sum(dim=dims, keepdim=True).div_(n)
    else:
        n = py.prod(fixed.shape[-dim:])

        def mean(x):
            return x.mean(dim=dims, keepdim=True)

    # standardise both images
    moving -= mean(moving)
    fixed -= mean(fixed)
    sigm = mean(moving.square()).sqrt_()
    sigf = mean(fixed.square()).sqrt_()
    moving = moving.div_(sigm)
    fixed = fixed.div_(sigf)

    corr = mean(moving * fixed)
    corr2 = 1 - corr.square()
    corr2.clamp_min_(1e-8)

    out = []
    if grad:
        g = 2 * corr * (moving * corr - fixed) / (n * sigm)
        g /= corr2  # chain rule for log
        if mask is not None:
            g = g.mul_(mask)
        out.append(g)

    if hess:
        # approximate hessian
        h = 2 * (corr / sigm).square() / n
        h /= corr2  # chain rule for log
        if mask is not None:
            h = h * mask
        out.append(h)

    # return stuff
    corr = corr2.log_().sum()
    out = [corr, *out]
    return tuple(out) if len(out) > 1 else out[0]
コード例 #22
0
ファイル: conversions.py プロジェクト: liamchalcroft/nitorch
    def __init__(self, shift, scale, orientation='RAS'):
        """Store `shift` and `scale` on a common backend, plus `orientation`.

        `shift` and `scale` are converted jointly with
        `utils.to_max_backend`; `orientation` is kept as given.
        """
        self.shift, self.scale = utils.to_max_backend(shift, scale)
        self.orientation = orientation
コード例 #23
0
def greens_apply(mom, greens, factor=1, voxel_size=1):
    """Apply the Greens function to a momentum field.

    Parameters
    ----------
    mom : (..., *spatial, dim) tensor
        Momentum
    greens : (*spatial, [dim, dim]) tensor
        Greens function
    factor : float, default=1
        The resulting velocity is divided by this factor.
    voxel_size : [sequence of] float, default=1
        Voxel size. Only used when `greens` has no matrix dimensions
        (i.e., when no penalty is put on linear-elasticity).

    Returns
    -------
    vel : (..., *spatial, dim) tensor
        Velocity

    """
    # Authors
    # -------
    # .. John Ashburner <*****@*****.**> : original Matlab code
    # .. Yael Balbastre <*****@*****.**> : Python port
    #
    # License
    # -------
    # The original Matlab code is (C) 2012-2019 WCHN / John Ashburner
    # and was distributed as part of [SPM](https://www.fil.ion.ucl.ac.uk/spm)
    # under the GNU General Public Licence (version >= 2).

    mom, greens = utils.to_max_backend(mom, greens)
    dim = mom.shape[-1]
    # spatial axes: the last axis holds the vector components
    spatial_dims = list(range(-dim - 1, -1))

    # fourier transform
    mom = fft.fftn(mom, dim=spatial_dims, real=True)

    # voxel-wise multiplication in the frequency domain
    if greens.dim() == dim:
        # one kernel value per frequency, shared by all components and
        # rescaled by each component's inverse squared voxel size
        voxel_size = utils.make_vector(voxel_size, dim, **utils.backend(mom))
        voxel_size = voxel_size.square().reciprocal()
        greens = greens.unsqueeze(-1)
        mom = fft.mul(mom, greens, real=(False, True))
        mom = fft.mul(mom, voxel_size, real=(False, True))
    else:
        # NOTE(review): `greens` is (*spatial, dim, dim) here. An earlier
        # implementation applied it as a voxel-wise matrix-vector product
        # (`linalg.matvec` on the real and imaginary parts); confirm that
        # `fft.mul` performs the intended product for this layout.
        mom = fft.mul(mom, greens, real=(False, True))

    # inverse fourier transform
    mom = fft.real(fft.ifftn(mom, dim=spatial_dims))
    mom /= factor

    return mom
コード例 #24
0
ファイル: spgr.py プロジェクト: balbasty/nitorch
def spgr(pd,
         r1,
         r2s=None,
         mt=None,
         transmit=None,
         receive=None,
         gfactor=None,
         te=0,
         tr=25e-3,
         fa=20,
         sigma=None,
         device=None):
    """Simulate data generated by a Spoiled Gradient-Echo (SPGR/FLASH) sequence.

    Tissue parameters
    -----------------
    pd : tensor_like
        Proton density
    r1 : tensor_like
        Longitudinal relaxation rate, in 1/sec
    r2s : tensor_like, optional
        Transverse relaxation rate, in 1/sec.
        If not provided, T2*-decay is not modelled (needed if `te > 0`).
    mt : tensor_like, optional
        MTsat. If not provided, no magnetization transfer saturation
        is modelled.

    Fields
    ------
    transmit : tensor_like, optional
        Transmit B1 field
    receive : tensor_like, optional
        Receive B1 field
    gfactor : tensor_like, optional
        G-factor map.
        If provided and `sigma` is not `None`, the g-factor map is used
        to sample non-stationary noise.

    Sequence parameters
    -------------------
    te : float, default=0
        Echo time, in sec
    tr : float default=25e-3
        Repetition time, in sec
    fa : float, default=20
        Flip angle, in deg

    Noise
    -----
    sigma : float, optional
        Standard-deviation of the sampled Rician noise (no sampling if `None`)

    Other
    -----
    device : torch.device, optional
        Device to which the input maps are moved.

    Returns
    -------
    sim : tensor
        Simulated SPGR image

    """
    pd, r1, r2s, mt, transmit, receive, gfactor \
        = utils.to_max_backend(pd, r1, r2s, mt, transmit, receive, gfactor)
    pd, r1, r2s, mt, transmit, receive, gfactor \
        = utils.to(pd, r1, r2s, mt, transmit, receive, gfactor, device=device)
    backend = utils.backend(pd)

    # effective flip angle, in radians
    fa = fa * constants.pi / 180.
    if transmit is not None:
        fa = fa * transmit
    del transmit
    fa = torch.as_tensor(fa, **backend)

    # amplitude: PD modulated by the receive field and sin(flip angle)
    if receive is not None:
        pd = pd * receive
    del receive
    pd = pd * fa.sin()
    cosfa = fa.cos()
    del fa

    e1 = r1.mul(tr).neg_().exp()
    del r1
    signal = pd * (1 - e1)

    if mt is not None:
        omt = mt.neg().add_(1)  # 1 - MTsat
        signal *= omt
        signal /= (1 - cosfa * omt * e1)
        del omt
    else:
        signal /= (1 - cosfa * e1)

    if r2s is not None:
        e2 = r2s.mul(te).neg_().exp()
        signal *= e2
        del e2
    del r2s

    # noise (non-stationary if a g-factor map is provided)
    signal = add_noise(signal, std=sigma, gfactor=gfactor)
    return signal
コード例 #25
0
ファイル: gmm.py プロジェクト: balbasty/nitorch
def lgmmh(moving, fixed, dim=None, bins=3, patch=7, stride=1,
          grad=True, hess=True, mode='g', max_iter=128,
          theta=None, return_theta=False):
    """Loss based on a local (patch-wise) Gaussian mixture.

    A mixture is fitted jointly to `moving` and `fixed` by `fit_lgmm2`;
    the loss value is the fit's negative log-likelihood.

    Parameters
    ----------
    moving : tensor
        Moving image. By default `dim = fixed.dim() - 1`, i.e. a single
        leading (channel) dimension is assumed — TODO confirm shapes.
    fixed : tensor
        Fixed image, same layout as `moving`.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    bins : int, default=3
        Forwarded to `fit_lgmm2` (presumably the number of mixture
        components).
    patch : [sequence of] int, default=7
        Patch size, forwarded to `fit_lgmm2`, `Fwd` and `Bwd`.
    stride : [sequence of] int, default=1
        Patch stride (falsy values become 0).
    grad : bool, default=True
        Compute and return the gradient with respect to `moving`.
    hess : bool, default=True
        Compute and return an approximate Hessian with respect to `moving`.
    mode : str, default='g'
        Forwarded to `fit_lgmm2`, `Fwd` and `Bwd`.
    max_iter : int, default=128
        Maximum number of fitting iterations, forwarded to `fit_lgmm2`.
    theta : dict, optional
        Initial parameters, forwarded to `fit_lgmm2`.
    return_theta : bool, default=False
        Also return the fitted parameters.

    Returns
    -------
    ll : tensor
        The 'nll' entry of the fit.
    g : tensor, if `grad`
        Gradient with respect to the moving image.
    h : tensor, if `hess`
        Approximate Hessian with respect to the moving image.
    theta : dict, if `return_theta`
        Dictionary returned by `fit_lgmm2`, without its 'resp' and 'nll'
        entries.

    """
    fixed, moving = utils.to_max_backend(fixed, moving)
    dim = dim or (fixed.dim() - 1)
    shape = fixed.shape[-dim:]

    if not isinstance(patch, (list, tuple)):
        patch = [patch]
    patch = list(patch)
    if not isinstance(stride, (list, tuple)):
        stride = [stride]
    stride = [s or 0 for s in stride]

    fwd = Fwd(patch, stride, dim, mode)
    bwd = Bwd(patch, stride, dim, mode, shape)

    gmmfit = fit_lgmm2(moving, fixed, bins, max_iter, dim,
                       patch=patch, stride=stride, mode=mode, theta=theta)

    # drop unused variables (parameters are kept only if returned)
    get = gmmfit.get if return_theta else gmmfit.pop
    pop = gmmfit.pop

    z = pop('resp')
    moving_mean = get('xmean')
    fixed_mean = get('ymean')
    moving_var = get('xvar')
    fixed_var = get('yvar')
    corr = get('corr')
    prior = get('prior')
    out = [pop('nll')]
    nvox = py.prod(z.shape[-dim:])

    moving = moving.unsqueeze(-dim-1)
    fixed = fixed.unsqueeze(-dim-1)

    # `z0` is needed by both the gradient and the Hessian, so it is
    # computed whenever either is requested (they are independent options)
    if grad or hess:
        z0 = fwd(z, None).clamp_min_(1e-10)

    if grad:
        # gradient of the GMM entropy
        # L = 0.5 * \sum_k pi_k log|\Sigma_k| + cte
        @torch.jit.script
        def make_grad(bwd: Bwd, z, z0, moving, fixed, moving_mean, fixed_mean,
                      moving_var, fixed_var, corr, prior) -> Tensor:
            cov = corr * (moving_var * fixed_var).sqrt()
            idet = moving_var * fixed_var * (1 - corr * corr)
            idet = prior / idet
            # gradient of determinant + chain rule of log
            g = moving * bwd(fixed_var * idet, z, z0) - fixed * bwd(cov * idet, z, z0)
            g -= bwd((moving_mean * fixed_var - fixed_mean * cov) * idet, z, z0)
            g = g.sum(-bwd.dim-1)
            return g

        g = make_grad(bwd, z, z0, moving, fixed, moving_mean, fixed_mean,
                      moving_var, fixed_var, corr, prior)
        g.div_(nvox)
        out.append(g)

    if hess:
        # approximate Hessian: per-component precision of `moving` given
        # `fixed`, weighted by (1 - z0) and the mixing proportions
        @torch.jit.script
        def make_hess(bwd: Bwd, z, z0, moving_var, corr, prior) -> Tensor:
            h = (1 - z0) * prior / (moving_var * (1 - corr * corr))
            h = bwd(h, z, z0)
            h = h.sum(-bwd.dim-1)
            return h

        h = make_hess(bwd, z, z0, moving_var, corr, prior)
        h.div_(nvox)
        out.append(h)

    if return_theta:
        out.append(gmmfit)

    return out[0] if len(out) == 1 else tuple(out)
コード例 #26
0
def dice_nolog(moving, fixed, dim=None, grad=True, hess=True, mask=None,
               add_background=False, weighted=False):
    """Dice loss for optimisation-based registration.

    The loss is ``sum_k (1 - w_k * Dice_k)`` (with ``w_k = 1`` when
    unweighted), where ``Dice_k = 2 * sum(moving_k * fixed_k) /
    sum(moving_k + fixed_k)`` is computed over the spatial dimensions.
    The Dice is not log-transformed (hence "nolog").

    Parameters
    ----------
    moving : (..., K, *spatial) tensor
        Moving image of probabilities (post-softmax).
        The background class should be omitted.
    fixed : (..., K[+1], *spatial) tensor
        Fixed image of probabilities. Only its first K classes are used;
        any extra (background) channel is discarded.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return Hessian
    mask : (..., *spatial) tensor, optional
        Mask of voxels to include in the loss (all by default).
        A mask that already has a singleton channel axis,
        (..., 1, *spatial), is also accepted.
    add_background : bool, default=False
        Include the Dice of the (implicit) background class in the loss.
        The background probability is ``1 - sum_k p_k`` after rescaling.
    weighted : bool or tensor, default=False
        Weights for each class. If True, weight by positive rate
        (the proportion of the fixed image occupied by each class).
        A tensor of weights is expected to hold one value per class
        (K, or K+1 when `add_background`).

    Returns
    -------
    ll : () tensor
        Dice loss, summed over classes (and batch elements).
    g : (..., K, *spatial) tensor, optional
        Gradient with respect to the moving image.
    h : (..., K, *spatial) or (..., K*(K+1)//2, *spatial) tensor, optional
        Approximate (non-negative) Hessian with respect to the moving
        image. It is diagonal (K channels) when `add_background` is False.
        Otherwise it is a symmetric matrix stored compactly, whose first
        K channels are the diagonal.

    """
    fixed, moving = utils.to_max_backend(fixed, moving)
    dim = dim or (fixed.dim() - 1)
    nc = moving.shape[-dim-1]                               # nb classes - bck
    fixed = utils.slice_tensor(fixed, slice(nc), -dim-1)    # remove bkg class
    spatial_dims = list(range(-dim, 0))
    if mask is not None:
        mask = mask.to(moving.device)
        if mask.dim() < moving.dim():
            # (..., *spatial) -> (..., 1, *spatial): the mask must broadcast
            # against the channel axis, not against a batch axis.
            mask = mask.unsqueeze(-dim-1)
        # number of voxels inside the mask, reduced over spatial axes only
        nvox = mask.sum(spatial_dims, keepdim=True)
    else:
        nvox = py.prod(fixed.shape[-dim:])

    @torch.jit.script
    def rescale(x, dim_channel: int, add_background: bool = False):
        """Clamp to >= 0, rescale so that channels sum to at most 1, and
        optionally append the background class ``1 - sum(x)``."""
        x = x.clamp_min(0)
        x = x / x.sum(dim_channel, keepdim=True).clamp_min_(1)
        if add_background:
            # The background has a singleton channel axis, so it must be
            # concatenated (stack would require equal shapes and add an axis).
            x = torch.cat([x, 1 - x.sum(dim_channel, keepdim=True)],
                          dim_channel)
        return x

    moving = rescale(moving, -dim-1, add_background)
    fixed = rescale(fixed, -dim-1, add_background)
    if mask is not None:
        # `rescale` returned new tensors: in-place ops do not touch inputs
        moving *= mask
        fixed *= mask

    if weighted is True:
        # positive rate of each class in the fixed image
        weighted = fixed.sum(spatial_dims, keepdim=True).div_(nvox)
    elif weighted is not False:
        weighted = torch.as_tensor(weighted, **utils.backend(moving))
        # (K,) -> (K, 1, ..., 1) so that weights broadcast over space
        for _ in range(dim):
            weighted = weighted.unsqueeze(-1)
    else:
        weighted = None

    @torch.jit.script
    def loss_components(moving, fixed, dim: int, weighted: Optional[Tensor] = None):
        """Compute the (negative) DiceLoss, (positive) Dice and union"""
        dims = [d for d in range(-dim, 0)]
        overlap = (moving * fixed).sum(dims, keepdim=True)
        union = (moving + fixed).sum(dims, keepdim=True)
        union += 1e-5  # avoid division by zero for empty classes
        dice = 2 * overlap / union
        if weighted is not None:
            ll = 1 - weighted * dice
        else:
            ll = 1 - dice
        ll = ll.sum()
        return ll, dice, union

    ll, dice, union = loss_components(moving, fixed, dim, weighted)
    out = [ll]

    # gradient
    if grad:
        @torch.jit.script
        def do_grad(dice, fixed, union):
            # d(1 - dice)/d(moving) = (dice - 2 * fixed) / union
            return (dice - 2 * fixed) / union
        g = do_grad(dice, fixed, union)
        if weighted is not None:
            g *= weighted
        if add_background:
            # background = 1 - sum(foreground): chain rule subtracts its
            # gradient from each foreground class
            g_last = utils.slice_tensor(g, slice(-1, None), -dim-1)
            g = utils.slice_tensor(g, slice(-1), -dim-1)
            g -= g_last
        if mask is not None:
            g *= mask
        out.append(g)

    # hessian
    if hess:
        @torch.jit.script
        def do_hess(dice, fixed, union, nvox, dim: int):
            dims = [d for d in range(-dim, 0)]
            positive_rate = fixed.sum(dims, keepdim=True) / nvox
            h = (dice - fixed - positive_rate).abs()
            h = 2 * nvox * h / union.square()
            return h
        nvox = torch.as_tensor(nvox, device=moving.device)
        h = do_hess(dice, fixed, union, nvox, dim)
        if weighted is not None:
            h *= weighted
        if add_background:
            # The background term contributes to every entry of the
            # (compact symmetric) Hessian; foreground terms only to the
            # diagonal.
            h_foreground = utils.slice_tensor(h, slice(-1), -dim-1)
            h = utils.slice_tensor(h, slice(-1, None), -dim-1)  # h background
            hshape = list(h.shape)
            hshape[-dim-1] = nc*(nc+1)//2
            h = h.expand(hshape).clone()
            # basic slicing returns a view, so the in-place add updates `h`
            diag = utils.slice_tensor(h, slice(nc), -dim-1)
            diag += h_foreground
        if mask is not None:
            h *= mask
        out.append(h)

    return tuple(out) if len(out) > 1 else out[0]