def kernel_latent_time_graphical_lasso(
    emp_cov,
    alpha=0.01,
    tau=1.0,
    rho=1.0,
    kernel_psi=None,
    kernel_phi=None,
    max_iter=100,
    verbose=False,
    psi="laplacian",
    phi="laplacian",
    mode="admm",
    tol=1e-4,
    rtol=1e-4,
    assume_centered=False,
    n_samples=None,
    return_history=False,
    return_n_iter=True,
    update_rho_options=None,
    compute_objective=True,
    init="empirical",
):
    r"""Time-varying latent variable graphical lasso solver with kernels.

    Solves the following problem via ADMM:
        min sum_{i=1}^T -n_i log_likelihood(K_i-L_i) + alpha ||K_i||_{od,1}
            + tau ||L_i||_*
            + sum_{s>t}^T k_psi(s,t) Psi(K_s - K_t)
            + sum_{s>t}^T k_phi(s,t) Phi(L_s - L_t)

    where S is the empirical covariance of the data
    matrix D (training observations by features).

    Parameters
    ----------
    emp_cov : ndarray, shape (n_times, n_features, n_features)
        Empirical covariance of data, one matrix per time point.
    alpha, tau : float, optional
        Regularisation parameters.
    rho : float, optional
        Augmented Lagrangian parameter. It is re-estimated after every
        iteration through ``update_rho``.
    kernel_psi, kernel_phi : array-like, shape (n_times, n_times), optional
        Temporal kernels. The m-th superdiagonal holds the weights applied
        to the penalty between matrices that are m time steps apart.
        Defaults to the identity, for which all those weights are zero.
    max_iter : int, optional
        Maximum number of iterations. Also forwarded to the proximal
        operator when a node penalty is selected.
    verbose : bool, default False
        Print the convergence quantities at each iteration.
    psi, phi : str or callable, optional
        Temporal penalties, resolved through ``check_norm_prox``.
    mode : str, optional
        Not used by this function.
    tol : float, optional
        Absolute tolerance for convergence.
    rtol : float, optional
        Relative tolerance for convergence.
    assume_centered : bool, optional
        Not used by this function.
    n_samples : array-like, shape (n_times,), optional
        Number of samples available for each time point.
        Defaults to one sample per time point.
    return_history : bool, optional
        Return the history of computed values.
    return_n_iter : bool, optional
        Return the index of the last iteration performed.
    update_rho_options : dict, optional
        Extra keyword arguments forwarded to ``update_rho``.
    compute_objective : bool, default True
        Compute the objective value at each iteration; when False the
        objective is reported as NaN.
    init : optional
        Forwarded as ``mode`` to ``init_precision`` to initialise K.

    Returns
    -------
    list
        ``[K, L, covariance]`` followed by the optional entries below.
    K, L : numpy.array, 3-dimensional (T x d x d)
        Solution to the problem for each time t=1...T .
    covariance : numpy.array, 3-dimensional (T x d x d)
        Pseudo-inverse (``linalg.pinvh``) of each K_t.
    history : list
        If return_history, then also a structure that contains the
        objective value, the primal and dual residual norms, and tolerances
        for the primal and dual residual norms at each iteration.
    n_iter : int
        If return_n_iter, the zero-based index of the last iteration run.

    Warns
    -----
    A warning is issued with ``warnings.warn`` when ``max_iter`` iterations
    complete without meeting the stopping criterion.

    """
    psi, prox_psi, psi_node_penalty = check_norm_prox(psi)
    phi, prox_phi, phi_node_penalty = check_norm_prox(phi)
    n_times, _, n_features = emp_cov.shape

    if kernel_psi is None:
        kernel_psi = np.eye(n_times)
    if kernel_phi is None:
        kernel_phi = np.eye(n_times)

    Z_0 = init_precision(emp_cov, mode=init)
    W_0 = np.zeros_like(Z_0)
    X_0 = np.zeros_like(Z_0)
    R_old = np.zeros_like(Z_0)

    # For every lag m keep a (left, right) pair of auxiliary variables that
    # mirror Z_0[:-m] / Z_0[m:] (and likewise for W_0), their scaled duals
    # (Y for Z, U for W) and the previous iterates used by the dual residual.
    Z_M, Z_M_old = {}, {}
    Y_M = {}
    W_M, W_M_old = {}, {}
    U_M = {}
    for m in range(1, n_times):
        # slice first, then copy: avoids duplicating the full stack per lag
        Z_L = Z_0[:-m].copy()
        Z_R = Z_0[m:].copy()
        Z_M[m] = (Z_L, Z_R)

        W_L = np.zeros_like(Z_L)
        W_R = np.zeros_like(Z_R)
        W_M[m] = (W_L, W_R)

        Y_L = np.zeros_like(Z_L)
        Y_R = np.zeros_like(Z_R)
        Y_M[m] = (Y_L, Y_R)

        U_L = np.zeros_like(W_L)
        U_R = np.zeros_like(W_R)
        U_M[m] = (U_L, U_R)

        Z_L_old = np.zeros_like(Z_L)
        Z_R_old = np.zeros_like(Z_R)
        Z_M_old[m] = (Z_L_old, Z_R_old)

        W_L_old = np.zeros_like(W_L)
        W_R_old = np.zeros_like(W_R)
        W_M_old[m] = (W_L_old, W_R_old)

    if n_samples is None:
        n_samples = np.ones(n_times)
    else:
        # accept lists/tuples: n_samples is indexed with [:, None, None] below
        n_samples = np.asarray(n_samples)

    checks = []
    for iteration_ in range(max_iter):
        # update R: A = emp_cov - (rho / n_i) * sym(Z_0 - W_0 - X_0)
        A = Z_0 - W_0 - X_0
        A += A.transpose(0, 2, 1)
        A /= 2.0
        A *= -rho / n_samples[:, None, None]
        A += emp_cov

        R = np.array(
            [prox_logdet(a, lamda=ni / rho) for a, ni in zip(A, n_samples)])

        # update Z_0
        A = R + W_0 + X_0
        for m in range(1, n_times):
            A[:-m] += Z_M[m][0] - Y_M[m][0]
            A[m:] += Z_M[m][1] - Y_M[m][1]

        # every time index receives exactly n_times terms in the sum above
        A /= n_times
        Z_0 = soft_thresholding(A, lamda=alpha / (rho * n_times))

        # update W_0
        A = Z_0 - R - X_0
        for m in range(1, n_times):
            A[:-m] += W_M[m][0] - U_M[m][0]
            A[m:] += W_M[m][1] - U_M[m][1]

        A /= n_times
        A += A.transpose(0, 2, 1)
        A /= 2.0

        W_0 = np.array(
            [prox_trace_indicator(a, lamda=tau / (rho * n_times)) for a in A])

        # update residuals
        X_0 += R - Z_0 + W_0

        for m in range(1, n_times):
            # other Zs
            Y_L, Y_R = Y_M[m]
            A_L = Z_0[:-m] + Y_L
            A_R = Z_0[m:] + Y_R
            if not psi_node_penalty:
                prox_e = prox_psi(A_R - A_L,
                                  lamda=2.0 *
                                  np.diag(kernel_psi, m)[:, None, None] / rho)
                Z_L = 0.5 * (A_L + A_R - prox_e)
                Z_R = 0.5 * (A_L + A_R + prox_e)
            else:
                Z_L, Z_R = prox_psi(
                    np.concatenate((A_L, A_R), axis=1),
                    lamda=0.5 * np.diag(kernel_psi, m)[:, None, None] / rho,
                    rho=rho,
                    tol=tol,
                    rtol=rtol,
                    max_iter=max_iter,
                )
            Z_M[m] = (Z_L, Z_R)

            # update other residuals (in place: Y_M[m] holds these arrays)
            Y_L += Z_0[:-m] - Z_L
            Y_R += Z_0[m:] - Z_R

            # other Ws
            U_L, U_R = U_M[m]
            A_L = W_0[:-m] + U_L
            A_R = W_0[m:] + U_R
            if not phi_node_penalty:
                prox_e = prox_phi(A_R - A_L,
                                  lamda=2.0 *
                                  np.diag(kernel_phi, m)[:, None, None] / rho)
                W_L = 0.5 * (A_L + A_R - prox_e)
                W_R = 0.5 * (A_L + A_R + prox_e)
            else:
                W_L, W_R = prox_phi(
                    np.concatenate((A_L, A_R), axis=1),
                    lamda=0.5 * np.diag(kernel_phi, m)[:, None, None] / rho,
                    rho=rho,
                    tol=tol,
                    rtol=rtol,
                    max_iter=max_iter,
                )
            W_M[m] = (W_L, W_R)

            # update other residuals (in place: U_M[m] holds these arrays)
            U_L += W_0[:-m] - W_L
            U_R += W_0[m:] - W_R

        # diagnostics, reporting, termination checks
        rnorm = np.sqrt(
            squared_norm(R - Z_0 + W_0) + sum(
                squared_norm(Z_0[:-m] - Z_M[m][0]) +
                squared_norm(Z_0[m:] - Z_M[m][1]) +
                squared_norm(W_0[:-m] - W_M[m][0]) +
                squared_norm(W_0[m:] - W_M[m][1]) for m in range(1, n_times)))

        snorm = rho * np.sqrt(
            squared_norm(R - R_old) + sum(
                squared_norm(Z_M[m][0] - Z_M_old[m][0]) +
                squared_norm(Z_M[m][1] - Z_M_old[m][1]) +
                squared_norm(W_M[m][0] - W_M_old[m][0]) +
                squared_norm(W_M[m][1] - W_M_old[m][1])
                for m in range(1, n_times)))

        obj = (objective(emp_cov, n_samples, R, Z_0, Z_M, W_0, W_M, alpha, tau,
                         kernel_psi, kernel_phi, psi, phi)
               if compute_objective else np.nan)

        check = convergence(
            obj=obj,
            rnorm=rnorm,
            snorm=snorm,
            e_pri=n_features * np.sqrt(n_times * (2 * n_times - 1)) * tol +
            rtol * max(
                np.sqrt(
                    squared_norm(R) + sum(
                        squared_norm(Z_M[m][0]) + squared_norm(Z_M[m][1]) +
                        squared_norm(W_M[m][0]) + squared_norm(W_M[m][1])
                        for m in range(1, n_times))),
                np.sqrt(
                    squared_norm(Z_0 - W_0) + sum(
                        squared_norm(Z_0[:-m]) + squared_norm(Z_0[m:]) +
                        squared_norm(W_0[:-m]) + squared_norm(W_0[m:])
                        for m in range(1, n_times))),
            ),
            e_dual=n_features * np.sqrt(n_times * (2 * n_times - 1)) * tol +
            rtol * rho * np.sqrt(
                squared_norm(X_0) + sum(
                    squared_norm(Y_M[m][0]) + squared_norm(Y_M[m][1]) +
                    squared_norm(U_M[m][0]) + squared_norm(U_M[m][1])
                    for m in range(1, n_times))),
        )

        R_old = R.copy()
        for m in range(1, n_times):
            Z_M_old[m] = (Z_M[m][0].copy(), Z_M[m][1].copy())
            W_M_old[m] = (W_M[m][0].copy(), W_M[m][1].copy())

        if verbose:
            print("obj: %.4f, rnorm: %.4f, snorm: %.4f,"
                  "eps_pri: %.4f, eps_dual: %.4f" % check[:5])

        checks.append(check)
        if check.rnorm <= check.e_pri and check.snorm <= check.e_dual:
            break

        rho_new = update_rho(rho,
                             rnorm,
                             snorm,
                             iteration=iteration_,
                             **(update_rho_options or {}))
        # scaled dual variables should be also rescaled
        X_0 *= rho / rho_new
        for m in range(1, n_times):
            Y_L, Y_R = Y_M[m]
            Y_L *= rho / rho_new
            Y_R *= rho / rho_new

            U_L, U_R = U_M[m]
            U_L *= rho / rho_new
            U_R *= rho / rho_new
        rho = rho_new
    else:
        # reached only when the loop ran out of iterations without a break
        warnings.warn("Objective did not converge.")

    covariance_ = np.array([linalg.pinvh(x) for x in Z_0])
    return_list = [Z_0, W_0, covariance_]
    if return_history:
        return_list.append(checks)
    if return_n_iter:
        return_list.append(iteration_)
    return return_list
Beispiel #2
0
def latent_time_matrix_decomposition(emp_cov,
                                     alpha=0.01,
                                     tau=1.,
                                     rho=1.,
                                     beta=1.,
                                     eta=1.,
                                     max_iter=100,
                                     verbose=False,
                                     psi='laplacian',
                                     phi='laplacian',
                                     mode='admm',
                                     tol=1e-4,
                                     rtol=1e-4,
                                     assume_centered=False,
                                     return_history=False,
                                     return_n_iter=True,
                                     update_rho_options=None,
                                     compute_objective=True):
    r"""Latent variable time-varying matrix decomposition solver.

    Solves the following problem via ADMM:
        min sum_{i=1}^T || S_i-(K_i-L_i)||^2 + alpha ||K_i||_{od,1}
            + tau ||L_i||_*
            + beta sum_{i=2}^T Psi(K_i - K_{i-1})
            + eta sum_{i=2}^T Phi(L_i - L_{i-1})

    where S is the matrix to decompose.

    Parameters
    ----------
    emp_cov : ndarray, shape (n_times, n_features, n_features)
        Matrices to decompose, one per time point.
    alpha, tau, beta, eta : float, optional
        Regularisation parameters.
    rho : float, optional
        Augmented Lagrangian parameter. It is re-estimated after every
        iteration through ``update_rho``.
    max_iter : int, optional
        Maximum number of iterations. Also forwarded to the proximal
        operator when a node penalty is selected.
    verbose : bool, default False
        Print the convergence quantities at each iteration.
    psi, phi : str or callable, optional
        Temporal penalties, resolved through ``check_norm_prox``.
    mode : str, optional
        Not used by this function.
    tol : float, optional
        Absolute tolerance for convergence.
    rtol : float, optional
        Relative tolerance for convergence.
    assume_centered : bool, optional
        Not used by this function.
    return_history : bool, optional
        Return the history of computed values.
    return_n_iter : bool, optional
        Return the index of the last iteration performed.
    update_rho_options : dict, optional
        Extra keyword arguments forwarded to ``update_rho``.
    compute_objective : bool, default True
        Compute the objective value at each iteration; when False the
        objective is reported as NaN.

    Returns
    -------
    list
        ``[K, L]`` followed by the optional entries below.
    K, L : numpy.array, 3-dimensional (T x d x d)
        Solution to the problem for each time t=1...T .
    history : list
        If return_history, then also a structure that contains the
        objective value, the primal and dual residual norms, and tolerances
        for the primal and dual residual norms at each iteration.
    n_iter : int
        If return_n_iter, the zero-based index of the last iteration run.

    Warns
    -----
    A warning is issued with ``warnings.warn`` when ``max_iter`` iterations
    complete without meeting the stopping criterion.

    """
    psi, prox_psi, psi_node_penalty = check_norm_prox(psi)
    phi, prox_phi, phi_node_penalty = check_norm_prox(phi)

    # Z_0 / W_0 are the consensus variables; (Z_1, Z_2) and (W_1, W_2) mirror
    # their [:-1] and [1:] slices so consecutive time points can be coupled.
    Z_0 = np.zeros_like(emp_cov)
    Z_1 = np.zeros_like(Z_0)[:-1]
    Z_2 = np.zeros_like(Z_0)[1:]
    W_0 = np.zeros_like(Z_0)
    W_1 = np.zeros_like(Z_1)
    W_2 = np.zeros_like(Z_2)

    # scaled dual variables
    X_0 = np.zeros_like(Z_0)
    X_1 = np.zeros_like(Z_1)
    X_2 = np.zeros_like(Z_2)
    U_1 = np.zeros_like(W_1)
    U_2 = np.zeros_like(W_2)

    # previous iterates, needed for the dual residual
    R_old = np.zeros_like(Z_0)
    Z_1_old = np.zeros_like(Z_1)
    Z_2_old = np.zeros_like(Z_2)
    W_1_old = np.zeros_like(W_1)
    W_2_old = np.zeros_like(W_2)

    # divisor for consensus variables, accounting for two less matrices
    divisor = np.full(emp_cov.shape[0], 3, dtype=float)
    divisor[0] -= 1
    divisor[-1] -= 1

    checks = []
    for iteration_ in range(max_iter):
        # update R (closed form for the squared-error data term)
        A = Z_0 - W_0 - X_0
        R = (rho * A + 2 * emp_cov) / (2 + rho)

        # update Z_0
        A = R + W_0 + X_0
        A[:-1] += Z_1 - X_1
        A[1:] += Z_2 - X_2
        A /= divisor[:, None, None]
        Z_0 = soft_thresholding(A,
                                lamda=alpha / (rho * divisor[:, None, None]))

        # update Z_1, Z_2
        A_1 = Z_0[:-1] + X_1
        A_2 = Z_0[1:] + X_2
        if not psi_node_penalty:
            prox_e = prox_psi(A_2 - A_1, lamda=2. * beta / rho)
            Z_1 = .5 * (A_1 + A_2 - prox_e)
            Z_2 = .5 * (A_1 + A_2 + prox_e)
        else:
            Z_1, Z_2 = prox_psi(np.concatenate((A_1, A_2), axis=1),
                                lamda=.5 * beta / rho,
                                rho=rho,
                                tol=tol,
                                rtol=rtol,
                                max_iter=max_iter)

        # update W_0
        A = Z_0 - R - X_0
        A[:-1] += W_1 - U_1
        A[1:] += W_2 - U_2
        A /= divisor[:, None, None]
        A += A.transpose(0, 2, 1)
        A /= 2.

        W_0 = np.array([
            prox_trace_indicator(a, lamda=tau / (rho * div))
            for a, div in zip(A, divisor)
        ])

        # update W_1, W_2
        A_1 = W_0[:-1] + U_1
        A_2 = W_0[1:] + U_2
        if not phi_node_penalty:
            prox_e = prox_phi(A_2 - A_1, lamda=2. * eta / rho)
            W_1 = .5 * (A_1 + A_2 - prox_e)
            W_2 = .5 * (A_1 + A_2 + prox_e)
        else:
            W_1, W_2 = prox_phi(np.concatenate((A_1, A_2), axis=1),
                                lamda=.5 * eta / rho,
                                rho=rho,
                                tol=tol,
                                rtol=rtol,
                                max_iter=max_iter)

        # update residuals
        X_0 += R - Z_0 + W_0
        X_1 += Z_0[:-1] - Z_1
        X_2 += Z_0[1:] - Z_2
        U_1 += W_0[:-1] - W_1
        U_2 += W_0[1:] - W_2

        # diagnostics, reporting, termination checks
        rnorm = np.sqrt(
            squared_norm(R - Z_0 + W_0) + squared_norm(Z_0[:-1] - Z_1) +
            squared_norm(Z_0[1:] - Z_2) + squared_norm(W_0[:-1] - W_1) +
            squared_norm(W_0[1:] - W_2))

        snorm = rho * np.sqrt(
            squared_norm(R - R_old) + squared_norm(Z_1 - Z_1_old) +
            squared_norm(Z_2 - Z_2_old) + squared_norm(W_1 - W_1_old) +
            squared_norm(W_2 - W_2_old))

        obj = (objective(emp_cov, R, Z_0, Z_1, Z_2, W_0, W_1, W_2,
                         alpha, tau, beta, eta, psi, phi)
               if compute_objective else np.nan)

        check = convergence(
            obj=obj,
            rnorm=rnorm,
            snorm=snorm,
            e_pri=np.sqrt(R.size + 4 * Z_1.size) * tol + rtol * max(
                np.sqrt(
                    squared_norm(R) + squared_norm(Z_1) + squared_norm(Z_2) +
                    squared_norm(W_1) + squared_norm(W_2)),
                np.sqrt(
                    squared_norm(Z_0 - W_0) + squared_norm(Z_0[:-1]) +
                    squared_norm(Z_0[1:]) + squared_norm(W_0[:-1]) +
                    squared_norm(W_0[1:]))),
            e_dual=np.sqrt(R.size + 4 * Z_1.size) * tol + rtol * rho *
            (np.sqrt(
                squared_norm(X_0) + squared_norm(X_1) + squared_norm(X_2) +
                squared_norm(U_1) + squared_norm(U_2))))

        R_old = R.copy()
        Z_1_old = Z_1.copy()
        Z_2_old = Z_2.copy()
        W_1_old = W_1.copy()
        W_2_old = W_2.copy()

        if verbose:
            # only the first five fields are formatted, as in the sibling
            # solvers, so a longer convergence record does not break the
            # five-placeholder format string
            print("obj: %.4f, rnorm: %.4f, snorm: %.4f,"
                  "eps_pri: %.4f, eps_dual: %.4f" % check[:5])

        checks.append(check)
        if check.rnorm <= check.e_pri and check.snorm <= check.e_dual:
            break

        rho_new = update_rho(rho,
                             rnorm,
                             snorm,
                             iteration=iteration_,
                             **(update_rho_options or {}))
        # scaled dual variables should be also rescaled
        X_0 *= rho / rho_new
        X_1 *= rho / rho_new
        X_2 *= rho / rho_new
        U_1 *= rho / rho_new
        U_2 *= rho / rho_new
        rho = rho_new
    else:
        # reached only when the loop ran out of iterations without a break
        warnings.warn("Objective did not converge.")

    return_list = [Z_0, W_0]
    if return_history:
        return_list.append(checks)
    if return_n_iter:
        return_list.append(iteration_)
    return return_list
Beispiel #3
0
def infimal_convolution(
    S,
    alpha=1.0,
    tau=1.0,
    rho=1.0,
    max_iter=100,
    verbose=False,
    tol=1e-4,
    rtol=1e-2,
    return_history=False,
    return_n_iter=True,
    update_rho_options=None,
    compute_objective=True,
):
    r"""Latent variable graphical lasso solver.

    Solves the following problem via ADMM:
        min - log_likelihood(S, K-L) + alpha ||K||_{od,1} + tau ||L_i||_*

    where S is the empirical covariance of the data
    matrix D (training observations by features).

    NOTE(review): the R-update below goes through ``prox_laplacian`` rather
    than a log-determinant proximal operator, so the data-fit term actually
    minimised may differ from the log-likelihood written above -- confirm
    against ``objective``.

    Parameters
    ----------
    S : array-like, 2-dimensional
        Empirical covariance matrix.
    alpha, tau : float, optional
        Regularisation parameters.
    rho : float, optional
        Augmented Lagrangian parameter. It is re-estimated after every
        iteration through ``update_rho``.
    max_iter : int, optional
        Maximum number of iterations.
    verbose : bool, default False
        Print info at each iteration.
    tol : float, optional
        Absolute tolerance for convergence.
    rtol : float, optional
        Relative tolerance for convergence.
    return_history : bool, optional
        Return the history of computed values.
    return_n_iter : bool, optional
        Return the index of the last iteration performed.
    update_rho_options : dict, optional
        Extra keyword arguments forwarded to ``update_rho``.
    compute_objective : bool, default True
        Compute the objective value at each iteration; when False the
        objective is reported as NaN.

    Returns
    -------
    list
        ``[K, L, covariance]`` followed by the optional entries below.
    K, L : np.array, 2-dimensional, size (d x d)
        Solution to the problem.
    covariance : np.array, 2 dimensional
        Pseudo-inverse (``linalg.pinvh``) of K.
    history : list
        If return_history, then also a structure that contains the
        objective value, the primal and dual residual norms, and tolerances
        for the primal and dual residual norms at each iteration.
    n_iter : int
        If return_n_iter, the zero-based index of the last iteration run.

    """
    K, L, U, R_prev = (np.zeros_like(S) for _ in range(4))

    report_fmt = ("obj: %.4f, rnorm: %.4f, snorm: %.4f,"
                  "eps_pri: %.4f, eps_dual: %.4f")
    rho_kwargs = update_rho_options or {}
    norm = np.linalg.norm

    history = []
    for iteration_ in range(max_iter):
        # R-step: symmetrise K - L - U in place, then apply the prox
        sym = K - L - U
        sym += sym.T
        sym /= 2.0
        R = prox_laplacian(S + rho * sym, lamda=rho / 2.0)

        # K-step: element-wise shrinkage
        K = soft_thresholding(L + R + U, lamda=alpha / rho)

        # L-step: symmetrise K - R - U in place, then apply the prox
        sym = K - R - U
        sym += sym.T
        sym /= 2.0
        L = prox_trace_indicator(sym, lamda=tau / rho)

        # scaled dual ascent
        U += R - K + L

        # diagnostics, reporting, termination checks
        if compute_objective:
            obj = objective(S, R, K, L, alpha, tau)
        else:
            obj = np.nan
        rnorm = norm(R - K + L)
        snorm = rho * norm(R - R_prev)
        status = convergence(
            obj=obj,
            rnorm=rnorm,
            snorm=snorm,
            e_pri=np.sqrt(R.size) * tol + rtol * max(norm(R), norm(K - L)),
            e_dual=np.sqrt(R.size) * tol + rtol * rho * norm(U),
        )
        R_prev = R.copy()

        if verbose:
            print(report_fmt % status[:5])

        history.append(status)
        within_tolerance = (status.rnorm <= status.e_pri
                            and status.snorm <= status.e_dual)
        # stop on convergence, or bail out when the objective blew up
        if within_tolerance or status.obj == np.inf:
            break

        next_rho = update_rho(rho, rnorm, snorm, iteration=iteration_,
                              **rho_kwargs)
        # scaled dual variables should be also rescaled
        U *= rho / next_rho
        rho = next_rho
    else:
        # reached only when the loop ran out of iterations without a break
        warnings.warn("Objective did not converge.")

    outputs = [K, L, linalg.pinvh(K)]
    if return_history:
        outputs.append(history)
    if return_n_iter:
        outputs.append(iteration_)
    return outputs
def latent_time_graphical_lasso(emp_cov,
                                alpha=0.01,
                                tau=1.,
                                rho=1.,
                                beta=1.,
                                eta=1.,
                                max_iter=100,
                                n_samples=None,
                                verbose=False,
                                psi='laplacian',
                                phi='laplacian',
                                mode='admm',
                                tol=1e-4,
                                rtol=1e-4,
                                return_history=False,
                                return_n_iter=True,
                                update_rho_options=None,
                                compute_objective=True,
                                init='empirical'):
    r"""Latent variable time-varying graphical lasso solver.

    Solves the following problem via ADMM:
      min sum_{i=1}^T -n_i log_likelihood(S_i, K_i-L_i) + alpha ||K_i||_{od,1}
          + tau ||L_i||_*
          + beta sum_{i=2}^T Psi(K_i - K_{i-1})
          + eta sum_{i=2}^T Phi(L_i - L_{i-1})

    where S_i = (1/n_i) X_i^T \times X_i is the empirical covariance of data
    matrix X (training observations by features).

    Parameters
    ----------
    emp_cov : ndarray, shape (n_times, n_features, n_features)
        Empirical covariance of data, one matrix per time point.
    alpha, tau, beta, eta : float, optional
        Regularisation parameters.
    rho : float, optional
        Augmented Lagrangian parameter. It is re-estimated after every
        iteration through ``update_rho``.
    max_iter : int, optional
        Maximum number of iterations. Also forwarded to the proximal
        operator when a node penalty is selected.
    n_samples : array-like, shape (n_times,), optional
        Number of samples available for each time point.
        Defaults to one sample per time point.
    verbose : bool, default False
        Print info at each iteration.
    psi, phi : str or callable, optional
        Temporal penalties, resolved through ``check_norm_prox``.
    mode : str, optional
        Not used by this function.
    tol : float, optional
        Absolute tolerance for convergence.
    rtol : float, optional
        Relative tolerance for convergence.
    return_history : bool, optional
        Return the history of computed values.
    return_n_iter : bool, optional
        Return the index of the last iteration performed.
    update_rho_options : dict, optional
        Arguments for the rho update.
        See regain.update_rules.update_rho function for more information.
    compute_objective : bool, default True
        Choose to compute the objective value.
    init : {'empirical', 'zeros', ndarray}, default 'empirical'
        How to initialise the inverse covariance matrix. Default is take
        the empirical covariance and inverting it.

    Returns
    -------
    list
        ``[K, L, covariance]`` followed by the optional entries below.
    K, L : numpy.array, 3-dimensional (T x d x d)
        Solution to the problem for each time t=1...T .
    covariance : numpy.array, 3-dimensional (T x d x d)
        Pseudo-inverse (``linalg.pinvh``) of each K_t.
    history : list
        If return_history, then also a structure that contains the
        objective value, the primal and dual residual norms, and tolerances
        for the primal and dual residual norms at each iteration.
    n_iter : int
        If return_n_iter, the zero-based index of the last iteration run.

    Warns
    -----
    A warning is issued with ``warnings.warn`` when ``max_iter`` iterations
    complete without meeting the stopping criterion.

    """
    psi, prox_psi, psi_node_penalty = check_norm_prox(psi)
    phi, prox_phi, phi_node_penalty = check_norm_prox(phi)

    # Z_0 / W_0 are the consensus variables; (Z_1, Z_2) and (W_1, W_2) mirror
    # their [:-1] and [1:] slices so consecutive time points can be coupled.
    Z_0 = init_precision(emp_cov, mode=init)
    # slice first, then copy: avoids duplicating the full stack twice
    Z_1 = Z_0[:-1].copy()
    Z_2 = Z_0[1:].copy()
    W_0 = np.zeros_like(Z_0)
    W_1 = np.zeros_like(Z_1)
    W_2 = np.zeros_like(Z_2)

    # scaled dual variables
    X_0 = np.zeros_like(Z_0)
    X_1 = np.zeros_like(Z_1)
    X_2 = np.zeros_like(Z_2)
    U_1 = np.zeros_like(W_1)
    U_2 = np.zeros_like(W_2)

    # previous iterates, needed for the dual residual
    R_old = np.zeros_like(Z_0)
    Z_1_old = np.zeros_like(Z_1)
    Z_2_old = np.zeros_like(Z_2)
    W_1_old = np.zeros_like(W_1)
    W_2_old = np.zeros_like(W_2)

    # divisor for consensus variables, accounting for two less matrices
    divisor = np.full(emp_cov.shape[0], 3, dtype=float)
    divisor[0] -= 1
    divisor[-1] -= 1

    if n_samples is None:
        n_samples = np.ones(emp_cov.shape[0])
    else:
        # accept lists/tuples: n_samples is indexed with [:, None, None] below
        n_samples = np.asarray(n_samples)

    checks = []
    for iteration_ in range(max_iter):
        # update R: A = emp_cov - (rho / n_i) * sym(Z_0 - W_0 - X_0)
        A = Z_0 - W_0 - X_0
        A += A.transpose(0, 2, 1)
        A /= 2.
        A *= -rho / n_samples[:, None, None]
        A += emp_cov

        R = np.array(
            [prox_logdet(a, lamda=ni / rho) for a, ni in zip(A, n_samples)])

        # update Z_0
        A = R + W_0 + X_0
        A[:-1] += Z_1 - X_1
        A[1:] += Z_2 - X_2
        A /= divisor[:, None, None]
        Z_0 = soft_thresholding(A,
                                lamda=alpha / (rho * divisor[:, None, None]))

        # update Z_1, Z_2
        A_1 = Z_0[:-1] + X_1
        A_2 = Z_0[1:] + X_2
        if not psi_node_penalty:
            prox_e = prox_psi(A_2 - A_1, lamda=2. * beta / rho)
            Z_1 = .5 * (A_1 + A_2 - prox_e)
            Z_2 = .5 * (A_1 + A_2 + prox_e)
        else:
            Z_1, Z_2 = prox_psi(np.concatenate((A_1, A_2), axis=1),
                                lamda=.5 * beta / rho,
                                rho=rho,
                                tol=tol,
                                rtol=rtol,
                                max_iter=max_iter)

        # update W_0
        A = Z_0 - R - X_0
        A[:-1] += W_1 - U_1
        A[1:] += W_2 - U_2
        A /= divisor[:, None, None]
        A += A.transpose(0, 2, 1)
        A /= 2.

        W_0 = np.array([
            prox_trace_indicator(a, lamda=tau / (rho * div))
            for a, div in zip(A, divisor)
        ])

        # update W_1, W_2
        A_1 = W_0[:-1] + U_1
        A_2 = W_0[1:] + U_2
        if not phi_node_penalty:
            prox_e = prox_phi(A_2 - A_1, lamda=2. * eta / rho)
            W_1 = .5 * (A_1 + A_2 - prox_e)
            W_2 = .5 * (A_1 + A_2 + prox_e)
        else:
            W_1, W_2 = prox_phi(np.concatenate((A_1, A_2), axis=1),
                                lamda=.5 * eta / rho,
                                rho=rho,
                                tol=tol,
                                rtol=rtol,
                                max_iter=max_iter)

        # update residuals
        X_0 += R - Z_0 + W_0
        X_1 += Z_0[:-1] - Z_1
        X_2 += Z_0[1:] - Z_2
        U_1 += W_0[:-1] - W_1
        U_2 += W_0[1:] - W_2

        # diagnostics, reporting, termination checks
        rnorm = np.sqrt(
            squared_norm(R - Z_0 + W_0) + squared_norm(Z_0[:-1] - Z_1) +
            squared_norm(Z_0[1:] - Z_2) + squared_norm(W_0[:-1] - W_1) +
            squared_norm(W_0[1:] - W_2))

        snorm = rho * np.sqrt(
            squared_norm(R - R_old) + squared_norm(Z_1 - Z_1_old) +
            squared_norm(Z_2 - Z_2_old) + squared_norm(W_1 - W_1_old) +
            squared_norm(W_2 - W_2_old))

        obj = (objective(emp_cov, n_samples, R, Z_0, Z_1, Z_2, W_0, W_1, W_2,
                         alpha, tau, beta, eta, psi, phi)
               if compute_objective else np.nan)

        check = convergence(
            obj=obj,
            rnorm=rnorm,
            snorm=snorm,
            e_pri=np.sqrt(R.size + 4 * Z_1.size) * tol + rtol * max(
                np.sqrt(
                    squared_norm(R) + squared_norm(Z_1) + squared_norm(Z_2) +
                    squared_norm(W_1) + squared_norm(W_2)),
                np.sqrt(
                    squared_norm(Z_0 - W_0) + squared_norm(Z_0[:-1]) +
                    squared_norm(Z_0[1:]) + squared_norm(W_0[:-1]) +
                    squared_norm(W_0[1:]))),
            e_dual=np.sqrt(R.size + 4 * Z_1.size) * tol + rtol * rho *
            (np.sqrt(
                squared_norm(X_0) + squared_norm(X_1) + squared_norm(X_2) +
                squared_norm(U_1) + squared_norm(U_2))))

        R_old = R.copy()
        Z_1_old = Z_1.copy()
        Z_2_old = Z_2.copy()
        W_1_old = W_1.copy()
        W_2_old = W_2.copy()

        if verbose:
            print("obj: %.4f, rnorm: %.4f, snorm: %.4f,"
                  "eps_pri: %.4f, eps_dual: %.4f" % check[:5])

        checks.append(check)
        if check.rnorm <= check.e_pri and check.snorm <= check.e_dual:
            break

        rho_new = update_rho(rho,
                             rnorm,
                             snorm,
                             iteration=iteration_,
                             **(update_rho_options or {}))
        # scaled dual variables should be also rescaled
        X_0 *= rho / rho_new
        X_1 *= rho / rho_new
        X_2 *= rho / rho_new
        U_1 *= rho / rho_new
        U_2 *= rho / rho_new
        rho = rho_new
    else:
        # reached only when the loop ran out of iterations without a break
        warnings.warn("Objective did not converge.")

    covariance_ = np.array([linalg.pinvh(x) for x in Z_0])
    return_list = [Z_0, W_0, covariance_]
    if return_history:
        return_list.append(checks)
    if return_n_iter:
        return_list.append(iteration_)
    return return_list