----------
    [1] Efficient Projections onto the ℓ1-Ball for Learning in High Dimensions
        John Duchi, Shai Shalev-Shwartz, Yoram Singer, and Tushar Chandra.
        International Conference on Machine Learning (ICML 2008)
        http://www.cs.berkeley.edu/~jduchi/projects/DuchiSiShCh08.pdf
    """
    assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s
    n, = v.shape  # will raise ValueError if v is not 1-D
    # check if we are already on the simplex
    if v.sum() == s and np.all(v >= 0):
        # best projection: itself!
        return v
    # get the array of cumulative sums of a sorted (decreasing) copy of v
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u)
    # get the number of > 0 components of the optimal solution
    rho = np.nonzero(u * np.arange(1, n + 1) > (cssv - s))[0][-1]
    # compute the Lagrange multiplier associated to the simplex constraint
    theta = (cssv[rho] - s) / (rho + 1.0)
    # compute the projection by thresholding v using theta
    w = (v - theta).clip(min=0)
    return w
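
# Example (illustrative): for v = [1.2, 0.3, -0.1] and s = 1 the sorted
# cumulative sums give rho = 1 and theta = (1.5 - 1) / 2 = 0.25, so:
#
#     w = euclidean_proj_simplex(np.array([1.2, 0.3, -0.1]))
#     # w ~= [0.95, 0.05, 0.0]; w sums to 1 with all entries >= 0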


import torch

from Utils import projector_helpers
from Utils.projector_helpers import l1_from_simplex, list_project


def proj_simplex_torch(v):
    # Bridge to the NumPy implementation above for 1-D torch tensors.
    return torch.from_numpy(euclidean_proj_simplex(v.numpy()))


l1_project = projector_helpers.l1_from_simplex(
    projector_helpers.list_project(proj_simplex_torch))
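
# Illustrative usage, assuming the Utils helpers behave as their names
# suggest (l1_from_simplex turns a simplex projector into an l1-ball
# projector via the usual sign/abs reduction; list_project maps a
# single-tensor projector over a list of tensors):
#
#     params = [torch.randn(10), torch.randn(5)]
#     projected = l1_project(params)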


def _michelot_list(y, a=1):
    # Michelot's algorithm: repeatedly discard entries at or below the
    # running pivot p until the active set v is stable; p is then the
    # threshold tau for projecting onto {w >= 0, sum(w) = a}.
    v = y
    p = (y.sum() - a) / y.shape[0]

    while (v > p).sum() != v.shape[0]:
        v = v[v > p]
        p = (v.sum() - a) / v.shape[0]

    tau = p
    return (y - tau).clamp(min=0)


michelot = l1_from_simplex(list_project(_michelot_list))
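
# Example (illustrative): _michelot_list projects a single nonnegative
# 1-D tensor onto the simplex; michelot is the l1-ball wrapper built
# from it with the Utils helpers.
#
#     w = _michelot_list(torch.tensor([1.2, 0.3, 0.0]))
#     # w ~= [0.95, 0.05, 0.0]; w sums to 1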


def descent_simplex_batch(y, s=1):
    # Initial guess for the threshold tau: the row-wise maximum,
    # shifted down by s when that keeps it positive.
    tau = y.max(dim=1)[0]
    tau = tau.unsqueeze(-1)
    tau = torch.where(tau - s > 0, tau - s, tau)
    step_size = torch.ones_like(tau)

    # If every row of the (nonnegative) input already has l1 norm s,
    # the loop is skipped and y is returned unchanged.
    y_ = y.clamp(min=0)
    norms_diff = y.norm(dim=1, p=1) - s
    while (norms_diff.abs() > 1e-7).any():
        y_ = (y - tau).clamp(min=0)
        norms_diff_ = y_.norm(dim=1, p=1) - s

        # Halve the step size on rows whose residual changed sign: the
        # previous update overshot the root of ||(y - tau)_+||_1 = s.
        slower = torch.sign(norms_diff_) != torch.sign(norms_diff)
        step_size = torch.where(slower[:, None], step_size * 0.5, step_size)

        tau += norms_diff_.unsqueeze(-1) * step_size
        norms_diff = norms_diff_

    return y_


def descent_simplex_single(y, s=1):
    return descent_simplex_batch(y[None], s)[0]

descent_l1 = \
projector_helpers.l1_from_simplex(projector_helpers.list_project(descent_simplex_single))
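
# Example (illustrative): the descent projector moves tau until each
# row of (y - tau)_+ has l1 norm s, within the 1e-7 tolerance.
#
#     w = descent_simplex_single(torch.tensor([1.2, 0.3, 0.0]))
#     # w ~= [0.95, 0.05, 0.0]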


def project_simplex(x, s=1):
    # Batched version of the sort-based projection above (Duchi et al.,
    # 2008): sort each row in decreasing order, accumulate, and take
    # rho as the last index where u * k > cssv - s holds.
    u, _ = torch.sort(x, dim=1, descending=True)
    cssv = torch.cumsum(u, dim=1)
    k = torch.arange(1, x.shape[1] + 1, dtype=x.dtype, device=x.device)
    rho = u * k > (cssv - s)

    rho = rho.split(dim=0, split_size=1)
    rho = [r[0].nonzero()[-1][0] for r in rho]

    offset = [cssv[i, r] for i, r in enumerate(rho)]
    offset = torch.stack(offset)

    rho = torch.stack(rho)

    # Lagrange multiplier of the simplex constraint, then threshold.
    theta = (offset - s) / (rho + 1)

    w = x - theta[:, None]
    w = w.clamp_(min=0)

    return w


def project_simplex_1(x, s=1):
    return project_simplex(x[None], s)[0]


def project_l1_ball(x: torch.Tensor, s=1):
    u = torch.abs(x)
    u_proj = project_simplex(u, s)

    x_proj = torch.sign(x) * u_proj

    # Rows already inside the l1 ball are their own projection.
    inside = u.sum(dim=1, keepdim=True) <= s
    return torch.where(inside, x, x_proj)
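
# Example (illustrative): batched l1-ball projection; rows already
# inside the ball come back unchanged.
#
#     x = torch.tensor([[1.2, -0.3, 0.1], [0.2, 0.1, 0.0]])
#     w = project_l1_ball(x)
#     # w[0] ~= [0.95, -0.05, 0.0]; w[1] == x[1] since its norm is 0.3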

project_l1_ball_serial = \
projector_helpers.l1_from_simplex(projector_helpers.list_project(project_simplex_1))


def condat_simplex(y, a=1):
    # Projection of a 1-D tensor onto {w >= 0, sum(w) = a} using the
    # algorithm of Condat, "Fast projection onto the simplex and the l1
    # ball" (Mathematical Programming, 2016); the numbered comments
    # follow the steps of the paper.
    # 1: start the candidate set with the first element.
    v = [y[0]]
    v_hat = []
    p = y[0] - a
    # 2: scan the remaining elements, setting the current candidates
    # aside whenever the pivot falls too far.
    for y_ in y[1:]:
        if y_ > p:
            p = p + (y_ - p) / (len(v) + 1)
            if p > y_ - a:
                v.append(y_)
            else:
                v_hat.extend(v)
                v = [y_]
                p = y_ - a
    # 3: reconsider the elements that were set aside.
    if len(v_hat) > 0:
        for y_ in v_hat:
            # 3.1
            if y_ > p:
                v.append(y_)
                p = p + (y_ - p) / len(v)
    # 4: sweep v, removing elements below the pivot, until a full
    # sweep removes nothing.
    while True:
        n_before = len(v)
        to_remove = []
        for i, y_ in enumerate(v):
            if y_ < p:
                to_remove.append(i)
                p = p + (p - y_) / (len(v) - len(to_remove))
        for index in reversed(to_remove):
            del v[index]
        if len(v) == n_before:
            break
    # 5: the pivot is the threshold tau.
    tau = p
    return (y - tau).clamp(min=0)

condat_l1 = \
projector_helpers.l1_from_simplex(projector_helpers.list_project(condat_simplex))
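

# Agreement check (illustrative sketch): every simplex projector in this
# file should reproduce the Duchi et al. reference implementation on the
# same nonnegative input, up to the descent projector's tolerance.
if __name__ == "__main__":
    torch.manual_seed(0)
    y = torch.rand(16, dtype=torch.float64)
    w_ref = torch.from_numpy(euclidean_proj_simplex(y.numpy()))
    for proj in (_michelot_list, descent_simplex_single,
                 project_simplex_1, condat_simplex):
        w = proj(y.clone())
        assert (w - w_ref).abs().max() < 1e-5, proj.__name__
    print("all simplex projectors agree")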