Example 1
def _to_gradient_magnitudes(dat, mat, scl):
    """ Compute squared gradient magnitudes (modulated with scaling and voxel size).

    OBS: Replaces the image data in dat.

    Parameters
    ----------
    dat : (X, Y, Z) tensor_like
        Image data.
    mat : (4, 4) tensor_like
        Affine matrix.
    scl : (N, ) tensor_like
        Gradient scaling parameter.

    Returns
    -------
    dat : (X, Y, Z) tensor_like
        Squared gradient magnitudes.

    """
    # Get voxel size
    vx = voxel_size(mat)
    # Scaled forward-difference gradient
    gr = scl * im_gradient(dat, vx=vx, which='forward', bound='zero')
    # Square gradients
    gr = torch.sum(gr**2, dim=0)
    dat = gr

    return dat
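A minimal, self-contained sketch of the same computation in plain PyTorch, for checking shapes without the im_gradient helper. The forward_gradient function below is hypothetical (it is not part of the library): forward differences with a zero boundary, scaled by voxel size, followed by the squared-magnitude reduction; the scalar scl is illustrative.

import torch

def forward_gradient(dat, vx=(1.0, 1.0, 1.0)):
    """Forward differences along each axis, zero boundary, scaled by voxel size."""
    grads = []
    for d in range(3):
        pad_shape = list(dat.shape)
        pad_shape[d] = 1
        pad = torch.zeros(pad_shape, dtype=dat.dtype, device=dat.device)
        grads.append(torch.diff(dat, dim=d, append=pad) / vx[d])
    return torch.stack(grads)              # (3, X, Y, Z)

dat = torch.rand(8, 8, 8)
scl = 2.0                                  # gradient scaling (illustrative)
gr = scl * forward_gradient(dat, vx=(1.0, 1.0, 1.0))
mag2 = torch.sum(gr ** 2, dim=0)           # squared gradient magnitudes, (8, 8, 8)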
Example 2
def _DtD(dat, vx_y, bound='zero', diff='forward'):
    """ Computes the divergence of the gradient.

    Args:
        dat (torch.Tensor): Image data (dim_y).
        vx_y (tuple(float)): Output voxel size.
        bound (str, optional): Bound for gradient/divergence calculation, defaults to
            constant zero.
        diff (str, optional): Gradient difference operator, defaults to 'forward'.

    Returns:
        div (torch.Tensor): Dt(D(dat)) (dim_y).

    """
    dat = im_gradient(dat, vx=vx_y, bound=bound, which=diff)
    dat = im_divergence(dat, vx=vx_y, bound=bound, which=diff)

    return dat
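A hedged sketch of what _DtD composes, reusing the hypothetical forward_gradient helper from the sketch after Example 1: gradient_adjoint below is the adjoint (transpose) Dt of the forward-difference gradient, which is what im_divergence with matching bound and difference scheme amounts to here, up to the library's exact sign and voxel-size conventions. The adjointness can be checked numerically.

import torch

def gradient_adjoint(g, vx=(1.0, 1.0, 1.0)):
    """Adjoint Dt of forward_gradient, so that <D u, q> == <u, Dt q>."""
    out = torch.zeros(g.shape[1:], dtype=g.dtype)
    for d in range(3):
        pad_shape = list(g.shape[1:])
        pad_shape[d] = 1
        pad = torch.zeros(pad_shape, dtype=g.dtype)
        out -= torch.diff(g[d], dim=d, prepend=pad) / vx[d]
    return out

u = torch.rand(6, 6, 6, dtype=torch.float64)
q = torch.rand(3, 6, 6, 6, dtype=torch.float64)
print(torch.allclose((forward_gradient(u) * q).sum(),
                     (u * gradient_adjoint(q)).sum()))   # True: Dt is the adjoint of D
dtd = gradient_adjoint(forward_gradient(u))              # the Dt(D(u)) of _DtD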
Example 3
def _compute_nll(x, y, sett, rho, sum_dtype=torch.float64):
    """ Compute negative model log-likelihood.

    Args:
        x (list): Per-channel lists of observed image objects (with dat, tau).
        y (list): Per-channel reconstructed image objects (with dat, lam, mat).
        sett: Algorithm settings.
        rho (torch.Tensor): ADMM step size.
        sum_dtype (torch.dtype): Defaults to torch.float64.

    Returns:
        nll_yx (torch.Tensor): Negative log-posterior.
        nll_xy (torch.Tensor): Negative log-likelihood.
        nll_y (torch.Tensor): Negative log-prior.

    """
    vx_y = voxel_size(y[0].mat).float()
    nll_xy = torch.tensor(0, device=sett.device, dtype=torch.float64)
    for c in range(len(x)):
        # Neg. log-likelihood term
        for n in range(len(x[c])):
            msk = x[c][n].dat != 0
            Ay = _proj('A',
                       y[c].dat,
                       x[c],
                       y[c],
                       n=n,
                       method=sett.method,
                       do=sett.do_proj,
                       bound=sett.bound,
                       interpolation=sett.interpolation)
            nll_xy += 0.5 * x[c][n].tau * torch.sum(
                (x[c][n].dat[msk] - Ay[msk])**2, dtype=sum_dtype)
        # Neg. log-prior term
        Dy = y[c].lam * im_gradient(
            y[c].dat, vx=vx_y, bound=sett.bound, which=sett.diff)
        if c > 0:
            nll_y += torch.sum(Dy**2, dim=0)
        else:
            nll_y = torch.sum(Dy**2, dim=0)

    nll_y = torch.sum(torch.sqrt(nll_y), dtype=sum_dtype)

    return nll_xy + nll_y, nll_xy, nll_y
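Spelling out the objective this evaluates: the likelihood term sums 0.5 * tau * ||x - A(y)||^2 over channels, observations and observed (non-zero) voxels, and the prior term is a joint total variation across channels, i.e. the sum over voxels of sqrt(sum_c |lam_c * D(y_c)|^2). A minimal sketch of the prior term alone, reusing the hypothetical forward_gradient helper from the sketch after Example 1, with illustrative tensors:

import torch

ys = [torch.rand(8, 8, 8), torch.rand(8, 8, 8)]   # two channels
lams = [1.0, 2.0]                                  # regularisation weights

sq = torch.zeros(8, 8, 8)
for y_c, lam in zip(ys, lams):
    Dy = lam * forward_gradient(y_c)               # (3, X, Y, Z)
    sq += torch.sum(Dy ** 2, dim=0)                # accumulate |lam_c D y_c|^2
nll_y = torch.sum(torch.sqrt(sq), dtype=torch.float64)   # joint total variation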
Example 4
def _update_admm(x, y, z, w, rho, tmp, obj, n_iter, sett):
    """ Performs one ADMM iteration: updates the reconstruction y (by conjugate
    gradients), the auxiliary variable z (by group soft-thresholding) and the
    dual variable w, optionally with over/under-relaxation (sett.alpha).

    Args:
        x (list): Per-channel lists of observed image objects.
        y (list): Per-channel reconstructed image objects.
        z (torch.Tensor): ADMM auxiliary variable.
        w (torch.Tensor): ADMM dual variable.
        rho (torch.Tensor): ADMM step size.
        tmp (torch.Tensor): Pre-allocated work buffer.
        obj (torch.Tensor): Objective function values, per iteration.
        n_iter (int): Current iteration number.
        sett: Algorithm settings.

    Returns:
        y, z, w, tmp, obj

    """
    # Parameters
    vx_y = voxel_size(y[0].mat).float()  # Output voxel size
    # Constants
    tiny = torch.tensor(1e-7, dtype=torch.float32, device=sett.device)
    one = torch.tensor(1, dtype=torch.float32, device=sett.device)
    # Over/under-relaxation parameter
    alpha = torch.tensor(sett.alpha, device=sett.device, dtype=torch.float32)

    # ----------
    # UPDATE: y
    # ----------
    t0 = _print_info('fit-update', sett, 'y', n_iter)  # PRINT
    for c in range(len(x)):  # Loop over channels
        # RHS
        tmp[:] = 0
        for n in range(len(x[c])):  # Loop over observations of channel 'c'
            # _ = _print_info('int', sett, n)  # PRINT
            tmp += x[c][n].tau * _proj('At',
                                       x[c][n].dat,
                                       x[c],
                                       y[c],
                                       method=sett.method,
                                       do=sett.do_proj,
                                       n=n,
                                       bound=sett.bound,
                                       interpolation=sett.interpolation)

        # Divergence
        div = w[c, ...] - rho * z[c, ...]
        div = im_divergence(div, vx=vx_y, bound=sett.bound, which=sett.diff)
        tmp -= y[c].lam * div

        # Get CG preconditioner
        # precond = _precond(x[c], y[c], rho, sett)
        precond = lambda dat: dat  # identity preconditioner

        # Invert y = lhs\tmp by conjugate gradients
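        # Schematically (up to _proj's exact conventions), this step solves
        #   (sum_n tau_n At_n A_n + rho * lam^2 * Dt D) y = tmp,
        # where tmp = sum_n tau_n At_n x_n + lam * Dt (rho * z - w) from above.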
        lhs = lambda dat: _proj('AtA',
                                dat,
                                x[c],
                                y[c],
                                method=sett.method,
                                do=sett.do_proj,
                                rho=rho,
                                vx_y=vx_y,
                                bound=sett.bound,
                                interpolation=sett.interpolation,
                                diff=sett.diff)
        cg(A=lhs,
           b=tmp,
           x=y[c].dat,
           verbose=sett.cgs_verbose,
           max_iter=sett.cgs_max_iter,
           stop='residuals',
           inplace=True,
           precond=precond,
           tolerance=sett.cgs_tol)  # OBS: y[c].dat is updated in-place here

        _ = _print_info('int', sett, c)  # PRINT

    _ = _print_info('fit-done', sett, t0)  # PRINT

    # ----------
    # Compute model objective function
    # ----------
    if sett.tolerance > 0:
        obj[n_iter, 0], obj[n_iter, 1], obj[n_iter, 2] = \
            _compute_nll(x, y, sett, rho)  # nl_pyx, nl_pxy, nl_py

    # ----------
    # UPDATE: z
    # ----------
    if alpha != 1:  # Use over/under-relaxation
        z_old = z.clone()
    t0 = _print_info('fit-update', sett, 'z', n_iter)  # PRINT
    tmp[:] = 0
    for c in range(len(x)):
        Dy = y[c].lam * im_gradient(
            y[c].dat, vx=vx_y, bound=sett.bound, which=sett.diff)
        if alpha != 1:  # Use over/under-relaxation
            Dy = alpha * Dy + (one - alpha) * z_old[c, ...]
        tmp += torch.sum((w[c, ...] / rho + Dy)**2, dim=0)
    tmp.sqrt_()  # in-place
    tmp = ((tmp - one / rho).clamp_min(0)) / (tmp + tiny)

    for c in range(len(x)):
        Dy = y[c].lam * im_gradient(
            y[c].dat, vx=vx_y, bound=sett.bound, which=sett.diff)
        if alpha != 1:  # Use over/under-relaxation
            Dy = alpha * Dy + (one - alpha) * z_old[c, ...]
        for d in range(Dy.shape[0]):
            z[c, d, ...] = tmp * (w[c, d, ...] / rho + Dy[d, ...])
    _ = _print_info('fit-done', sett, t0)  # PRINT

    # ----------
    # UPDATE: w
    # ----------
    t0 = _print_info('fit-update', sett, 'w', n_iter)  # PRINT
    for c in range(len(x)):  # Loop over channels
        Dy = y[c].lam * im_gradient(
            y[c].dat, vx=vx_y, bound=sett.bound, which=sett.diff)
        if alpha != 1:  # Use over/under-relaxation
            Dy = alpha * Dy + (one - alpha) * z_old[c, ...]
        w[c, ...] += rho * (Dy - z[c, ...])
        _ = _print_info('int', sett, c)  # PRINT
    _ = _print_info('fit-done', sett, t0)  # PRINT

    return y, z, w, tmp, obj
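The z-update above is a group soft-thresholding (vectorial shrinkage): stacking v_c = lam_c * D(y_c) + w_c / rho over channels and gradient directions, the new z is max(||v|| - 1/rho, 0) / ||v|| * v, with the norm taken jointly over channels and directions at every voxel. A stand-alone sketch with illustrative tensors (the function above accumulates the norm channel by channel into tmp, and mixes in z_old when over/under-relaxation is used):

import torch

rho, tiny = 0.5, 1e-7
# Illustrative stack of lam_c * D(y_c) + w_c / rho: (channels, directions, X, Y, Z)
v = torch.rand(2, 3, 8, 8, 8)

norm = torch.sqrt(torch.sum(v ** 2, dim=(0, 1)))           # joint norm per voxel
shrink = (norm - 1.0 / rho).clamp_min(0) / (norm + tiny)   # group soft-threshold
z = shrink * v                                             # broadcasts over (2, 3)

This is the proximal operator of the joint total-variation prior, which is what keeps the non-smooth term tractable inside ADMM.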