class NaiveCONESTA(bases.ExplicitAlgorithm,
                   bases.IterativeAlgorithm,
                   bases.InformationAlgorithm):
    """A naïve implementation of COntinuation with NEsterov smoothing in a
    Soft-Thresholding Algorithm, or CONESTA for short.

    Parameters
    ----------
    mu_start : Non-negative float. An optional initial value of mu. If None
            (default), it is estimated from the function being minimised.

    mu_min : Non-negative float. A "very small" mu to use when computing
            the stopping criterion. Default is consts.TOLERANCE.

    tau : Float, 0 < tau < 1. The factor by which the continuation
            tolerance is decreased at each continuation step. Default is 0.5.

    eps : Positive float. Tolerance for the stopping criterion. Default is
            consts.TOLERANCE.

    info : List or tuple of utils.consts.Info. What, if any, extra run
            information should be stored. Default is an empty list, which
            means that no run information is neither computed nor returned.

    max_iter : Non-negative integer. Maximum allowed total number of (inner)
            iterations. Default is 10000.

    min_iter : Non-negative integer less than or equal to max_iter. Minimum
            number of iterations that must be performed. Default is 1.
    """
    # Function-property interfaces this algorithm works with (presumably
    # verified by bases.check_compatibility before running -- confirm).
    INTERFACES = [
        nesterov_properties.NesterovFunction,
        properties.Gradient,
        properties.StepSize,
        properties.ProximalOperator,
        properties.Continuation,
    ]

    # Run-information keys this algorithm is able to report via `info`.
    INFO_PROVIDED = [
        Info.ok,
        Info.num_iter,
        Info.time,
        Info.fvalue,
        Info.mu,
        Info.converged,
    ]

    def __init__(self, mu_start=None, mu_min=consts.TOLERANCE,
                 tau=0.5,

                 eps=consts.TOLERANCE,
                 info=None, max_iter=10000, min_iter=1):
        """Store the continuation parameters and build the inner FISTA solver.

        See the class docstring for a description of every parameter. Passing
        ``info=None`` (the default) is equivalent to passing an empty list.
        """
        # Build a fresh list per call: a literal `[]` default would be a single
        # list object shared by every instance created without `info`.
        if info is None:
            info = []

        super(NaiveCONESTA, self).__init__(info=info,
                                           max_iter=max_iter,
                                           min_iter=min_iter)

        self.mu_start = mu_start
        self.mu_min = mu_min
        self.tau = tau
        self.eps = eps

        # Forward to FISTA only the info keys FISTA itself provides, and always
        # ask it for its iteration count.
        fista_info = [nfo for nfo in self.info_copy()
                      if nfo in FISTA.INFO_PROVIDED]
        if Info.num_iter not in fista_info:
            fista_info.append(Info.num_iter)

        self.algorithm = FISTA(eps=eps, max_iter=max_iter, min_iter=min_iter,
                               info=fista_info)

        # Total number of inner (FISTA) iterations used so far.
        self.num_iter = 0

    @bases.force_reset
    @bases.check_compatibility
    def run(self, function, beta):
        """Minimise ``function`` starting from ``beta`` by naive continuation.

        The inner FISTA solver (``self.algorithm``) is run repeatedly. After
        each inner run, both the FISTA tolerance eps and the smoothing
        parameter mu are multiplied by ``self.tau``. They are bounded below by
        ``consts.TOLERANCE`` and ``self.mu_min`` respectively. The loop stops
        when a single ISTA step taken at ``mu = self.mu_min`` moves beta by
        less than ``self.eps`` (after scaling by 1 / step), or when
        ``self.max_iter`` inner iterations have been used in total.

        Parameters
        ----------
        function : The function to minimise. It must provide step, prox,
                grad, set_mu, estimate_mu and eps_max.

        beta : The starting point.

        Returns
        -------
        beta : The final iterate.
        """
        import logging

        logger = logging.getLogger(__name__)

        if self.info_requested(Info.ok):
            self.info_set(Info.ok, False)

        if self.mu_start is None:
            mu = function.estimate_mu(beta)
        else:
            mu = self.mu_start

        # We use 2x as in Chen et al. (2012).
        eps = 2.0 * function.eps_max(mu)

        # Step size at the smallest mu; used by the ISTA stopping criterion.
        function.set_mu(self.mu_min)
        tmin = function.step(beta)
        function.set_mu(mu)

        # The current mu is kept as a scalar. Its history (reported as Info.mu)
        # is a separate list, so the continuation never reads mu back from a
        # list that may not have been extended.
        mu_hist = [mu]
        t = []
        f = []
        fval = None
        if self.info_requested(Info.converged):
            self.info_set(Info.converged, False)

        i = 0
        while True:
            tnew = function.step(beta)
            self.algorithm.set_params(step=tnew, eps=eps,
                                      max_iter=self.max_iter - self.num_iter)
            beta = self.algorithm.run(function, beta)

            self.num_iter += self.algorithm.num_iter

            if Info.time in self.algorithm.info:
                tval = self.algorithm.info_get(Info.time)
            if Info.fvalue in self.algorithm.info:
                fval = self.algorithm.info_get(Info.fvalue)

            if self.info_requested(Info.time):
                t = t + tval
            if self.info_requested(Info.fvalue):
                f = f + fval

            old_mu = function.set_mu(self.mu_min)
            # Take one ISTA step for use in the stopping criterion.
            beta_tilde = function.prox(beta - tmin * function.grad(beta),
                                       tmin)
            function.set_mu(old_mu)

            if (1.0 / tmin) * maths.norm(beta - beta_tilde) < self.eps:
                if self.info_requested(Info.converged):
                    self.info_set(Info.converged, True)
                break

            if self.num_iter >= self.max_iter:
                break

            eps = max(self.tau * eps, consts.TOLERANCE)
            mu = max(self.mu_min, self.tau * mu)

            if self.info_requested(Info.mu):
                # Extend by one entry per recorded function value of this pass.
                # Fall back to the FISTA iteration count when no function
                # values were recorded; previously that case was a NameError.
                n_inner = (len(fval) if fval is not None
                           else self.algorithm.num_iter)
                mu_hist.extend([mu] * n_inner)

            # Progress output; silent unless debug logging is enabled.
            logger.debug("eps: %s, mu: %s", eps, mu)
            function.set_mu(mu)

            i += 1

        if self.info_requested(Info.num_iter):
            self.info_set(Info.num_iter, i + 1)
        if self.info_requested(Info.time):
            self.info_set(Info.time, t)
        if self.info_requested(Info.fvalue):
            self.info_set(Info.fvalue, f)
        if self.info_requested(Info.mu):
            self.info_set(Info.mu, mu_hist)
        if self.info_requested(Info.ok):
            self.info_set(Info.ok, True)

        return beta
    def run(self, function, beta):
        """Minimise ``function`` starting from ``beta`` with gap-driven continuation.

        An inner FISTA solver is run repeatedly, each time with tolerance
        ``min(eps_max(mu_0), function.eps_opt(mu))``. After each inner run, a
        gap estimate G is updated:

        - If ``self.dynamic`` is true, G becomes ``|function.gap(...)|`` when
          that is smaller than G, and ``self.tau * G`` otherwise.
        - Otherwise G becomes ``self.tau * G``.

        The next mu is ``min(mu, function.mu_opt(G))``. The loop stops when one
        ISTA step at the smallest mu seen moves beta by less than ``self.eps``
        (after scaling by 1 / step). It also stops when ``self.max_iter`` inner
        iterations have been used, or when both G and mu are at most
        ``consts.TOLERANCE``.

        NOTE(review): this definition comes after the decorated ``run`` in
        the class body, so it replaces that one. It is not wrapped in
        ``bases.force_reset`` / ``bases.check_compatibility``, so
        ``self.num_iter`` may carry over between calls -- confirm. It also
        reads ``self.dynamic``, which ``NaiveCONESTA.__init__`` never sets,
        and reports ``Info.gap``, which is not in ``INFO_PROVIDED``. It looks
        like it belongs to a different CONESTA variant -- TODO confirm, then
        move or rename.

        Parameters
        ----------
        function : The function to minimise. It must provide step, prox,
                grad, set_mu, estimate_mu, eps_max, eps_opt, mu_opt and, in
                dynamic mode, gap.

        beta : The starting point.

        Returns
        -------
        beta : The final iterate.
        """
        # Forward to FISTA only the info keys FISTA itself provides.
        fista_info = [nfo for nfo in self.info_copy()
                      if nfo in FISTA.INFO_PROVIDED]
        # Create the inner algorithm.
        algorithm = FISTA(eps=self.eps,
                          max_iter=self.max_iter, min_iter=self.min_iter,
                          info=fista_info)

        if self.info_requested(Info.ok):
            self.info_set(Info.ok, False)

        # Track the smallest mu in a local variable. Writing it back to
        # self.mu_min would permanently lower the configured value for every
        # later call.
        mu_min = self.mu_min

        if self.mu_start is None:
            mu = function.estimate_mu(beta)
        else:
            mu = self.mu_start
        # History of mu values, reported as Info.mu. The current mu is the
        # scalar `mu`.
        mu_hist = [mu]

        function.set_mu(mu_min)
        tmin = function.step(beta)
        function.set_mu(mu)

        max_eps = function.eps_max(mu)

        G = min(max_eps, function.eps_opt(mu))

        t = []
        f = []
        Gval = []
        fval = None
        if self.info_requested(Info.converged):
            self.info_set(Info.converged, False)

        i = 0
        while True:
            stop = False

            tnew = function.step(beta)
            eps_plus = min(max_eps, function.eps_opt(mu))
            algorithm.set_params(step=tnew, eps=eps_plus,
                                 max_iter=self.max_iter - self.num_iter,
                                 conesta_stop=None)
            beta = algorithm.run(function, beta)

            self.num_iter += algorithm.num_iter

            if Info.time in algorithm.info:
                tval = algorithm.info_get(Info.time)
            if Info.fvalue in algorithm.info:
                fval = algorithm.info_get(Info.fvalue)

            mu_min = min(mu_min, mu)
            tmin = min(tmin, tnew)
            old_mu = function.set_mu(mu_min)
            # Take one ISTA step for use in the stopping criterion.
            beta_tilde = function.prox(beta - tmin * function.grad(beta),
                                       tmin)
            function.set_mu(old_mu)

            if (1.0 / tmin) * maths.norm(beta - beta_tilde) < self.eps:
                if self.info_requested(Info.converged):
                    self.info_set(Info.converged, True)
                stop = True

            if self.num_iter >= self.max_iter:
                stop = True

            if self.info_requested(Info.time):
                gap_time = utils.time_cpu()

            if self.dynamic:
                G_new = function.gap(beta, eps=eps_plus,
                                     max_iter=self.max_iter - self.num_iter)

                # TODO: Warn if G_new < 0.
                G_new = abs(G_new)  # Just in case ...

                G = G_new if G_new < G else self.tau * G

            else:  # Static
                G = self.tau * G

            if self.info_requested(Info.time):
                # Charge the gap computation to the last inner iteration.
                tval[-1] += utils.time_cpu() - gap_time
                t = t + tval
            if self.info_requested(Info.fvalue):
                f = f + fval
            if self.info_requested(Info.gap):
                Gval.append(G)

            if (G <= consts.TOLERANCE and mu <= consts.TOLERANCE) or stop:
                break

            mu = min(mu, function.mu_opt(G))
            # After this update mu_min <= mu, so mu itself is the value to
            # record; the old `max(mu_min, mu)` was always equal to mu.
            mu_min = min(mu_min, mu)
            if self.info_requested(Info.mu):
                # Extend by one entry per recorded function value of this pass.
                # Fall back to the FISTA iteration count when no function
                # values were recorded; previously that case was a NameError.
                n_inner = (len(fval) if fval is not None
                           else algorithm.num_iter)
                mu_hist.extend([mu] * n_inner)
            function.set_mu(mu)

            i += 1

        if self.info_requested(Info.num_iter):
            self.info_set(Info.num_iter, i + 1)
        if self.info_requested(Info.time):
            self.info_set(Info.time, t)
        if self.info_requested(Info.fvalue):
            self.info_set(Info.fvalue, f)
        if self.info_requested(Info.gap):
            self.info_set(Info.gap, Gval)
        if self.info_requested(Info.mu):
            self.info_set(Info.mu, mu_hist)
        if self.info_requested(Info.ok):
            self.info_set(Info.ok, True)

        return beta