def proximal_optimal_step_size_subproblem(self, additional_coeffs, dual_vars, primal_vars, n_layers, eta):
        """
        Compute the proximal optimal step size, knowing that only the conditional gradient of subproblem
        ``self.k`` was updated.

        As a function of the step size gamma, the change in value modelled here is
        ``-0.5 * lower * gamma**2 + upper * gamma``. It is maximised at ``gamma = upper / lower``, which is
        then clamped to [0, 1].

        Reads ``self.k``, ``self.zahat_kp1`` and ``self.zbhat_k`` (presumably the freshly computed conditional
        gradient of subproblem k -- confirm against the caller).

        :param additional_coeffs: container supporting ``in`` and indexing; only key ``k + 1`` is looked up,
            and only when ``k != 0``.
        :param dual_vars: object whose ``primal_grad`` is indexed at ``k - 1`` and/or ``k``.
        :param primal_vars: object whose ``zahats`` / ``zbhats`` hold the current values that the new ones are
            compared against.
        :param n_layers: the ``zahat`` contribution is skipped when ``k == n_layers - 1``.
        :param eta: the squared norms of the differences are weighted by ``1 / eta``.
        :return: tuple ``(opt_step_size, decrease)``.
        """
        k = self.k
        new_zahat = self.zahat_kp1
        new_zbhat = self.zbhat_k

        def scaled_sq_norm(diff):
            # (1/eta) * squared L2 norm taken over every dimension except the first one.
            return (1/eta) * diff.view(diff.shape[0], -1).pow(2).sum(dim=-1)

        a_diff = new_zahat - primal_vars.zahats[k]

        if k == 0:
            # First subproblem: only the zahat difference contributes.
            upper = bdot(dual_vars.primal_grad[0], a_diff)
            lower = scaled_sq_norm(a_diff)
        else:
            b_diff = primal_vars.zbhats[k - 1] - new_zbhat
            upper = bdot(dual_vars.primal_grad[k - 1], b_diff)
            lower = scaled_sq_norm(b_diff)

            if k != (n_layers - 1):
                # Every subproblem but the last one also moves zahat.
                upper += bdot(dual_vars.primal_grad[k], a_diff)
                lower += scaled_sq_norm(a_diff)
            if (k + 1) in additional_coeffs:
                # Note the reversed sign of the difference (current minus new) for this term.
                upper += bdot(additional_coeffs[k + 1], primal_vars.zahats[k] - new_zahat)

        # Entries whose quadratic coefficient is not strictly positive get a zero step. This already covers
        # the 0/0 case, so no separate masking of those entries is needed afterwards.
        opt_step_size = torch.where(lower > 0, upper / lower, torch.zeros_like(lower))
        opt_step_size = torch.clamp(opt_step_size, min=0, max=1)

        decrease = -0.5 * lower * opt_step_size.pow(2) + upper * opt_step_size

        return opt_step_size, decrease
    def proximal_optimal_step_size(additional_coeffs, diff_grad, primal_vars,
                                   cond_grad, eta):
        """
        Compute the optimal step size between the current primal variables and a conditional gradient,
        together with the associated change in objective value.

        Written as a function of the step size gamma, the change in value computed here is
            -0.5 * lower * gamma^2 + upper * gamma
        (the constant term is unnecessary), whose maximiser is gamma_opt = upper / lower, with
            lower = weighted_squared_norm(1 / eta) of (primal_vars - cond_grad), floored at 1e-8
            upper = bdot of (primal_vars - cond_grad) with diff_grad, plus one term per additional_coeffs entry.
        The difference is built with ``.add(other, -1)`` on the dual subgradients, presumably
        ``self + (-1) * other`` -- confirm against the definition of ``add``.

        The step size is not clamped to [0, 1] here.

        :param additional_coeffs: mapping from layer index to coefficients; ``zahats`` is read at ``layer - 1``.
        :param diff_grad: object passed to ``bdot`` against the primal/conditional-gradient difference.
        :param primal_vars: current primal variables (must expose ``as_dual_subgradient()`` and ``zahats``).
        :param cond_grad: conditional gradient (same interface as ``primal_vars``).
        :param eta: the squared norm is weighted by ``1 / eta``.
        :return: tuple ``(opt_step_size, decrease)``.
        """
        var_to_cond = primal_vars.as_dual_subgradient().add(
            cond_grad.as_dual_subgradient(), -1)
        upper = var_to_cond.bdot(diff_grad)
        for layer, add_coeff in additional_coeffs.items():
            # TODO: Check if this is the correct computation ON PAPER
            upper += bdot(
                add_coeff,
                primal_vars.zahats[layer - 1] - cond_grad.zahats[layer - 1])

        lower = var_to_cond.weighted_squared_norm(1 / eta)
        # Floor the quadratic coefficient in place so that the division below is always well defined.
        # Because of this floor, lower can never be 0 and no 0/0 entries can arise.
        torch.clamp(lower, min=1e-8, out=lower)

        opt_step_size = upper / lower

        decrease = -0.5 * lower * opt_step_size.pow(2) + upper * opt_step_size

        return opt_step_size, decrease
def compute_objective(dual_vars, primal_vars, additional_coeffs):
    """
    Evaluate the objective for the given dual and primal variables.

    The value is the ``bdot`` of the dual variables with the dual subgradient of the primal variables, plus,
    for every entry of ``additional_coeffs``, the ``bdot`` of its coefficients with the matching ``zahat``.

    We assume that all the constraints are satisfied.
    """
    subgradient = primal_vars.as_dual_subgradient()
    objective = dual_vars.bdot(subgradient)
    for layer_idx, coeff in additional_coeffs.items():
        # Layer keys are shifted by one with respect to positions in zahats, hence the "- 1".
        zahat = primal_vars.zahats[layer_idx - 1]
        objective += bdot(coeff, zahat)
    return objective
def compute_objective(weights, additional_coeffs, primal_vars, dual_vars):
    """
    Compute the objective function value of this derivation from the network layers, the cost coefficients of
    the final layer, and the primal and dual variables.

    The value is made of:
      * the final-layer cost: the first entry of ``additional_coeffs`` propagated through
        ``weights[-1].backward`` and dotted with ``primal_vars.xs[-1]``, plus its dot product with the
        final-layer bias;
      * for every layer but the last, ``mus`` dotted with ``zs - forward(xs)`` and ``lambdas`` dotted with
        ``xs[next] - max(zs, 0)``.

    :param weights: sequence of layers exposing ``forward``; the last one also needs ``backward`` and ``bias``.
    :param additional_coeffs: mapping of cost coefficients; only its first value is used.
    :param primal_vars: object exposing the sequences ``xs`` and ``zs``.
    :param dual_vars: object exposing the sequences ``mus`` and ``lambdas``.
    :return: the objective value. NOTE(review): the original documentation describes it as a bound tensor of
        size 2*opt_layer_width (first half negated upper bounds, second half lower bounds); this cannot be
        confirmed from this function alone.
    """
    add_coeff = next(iter(additional_coeffs.values()))
    last_layer = weights[-1]
    final_term = utils.bdot(last_layer.backward(add_coeff), primal_vars.xs[-1])
    bias_term = utils.bdot(add_coeff, last_layer.bias)
    obj = final_term + bias_term

    n_inner_layers = len(weights) - 1
    for idx in range(n_inner_layers):
        pre_activation = primal_vars.zs[idx]

        # Gap between z and the layer applied to x, weighted by mu.
        linear_gap = pre_activation - weights[idx].forward(primal_vars.xs[idx])
        obj += utils.bdot(dual_vars.mus[idx], linear_gap)

        # Gap between the next x and max(z, 0), weighted by lambda.
        relu_gap = primal_vars.xs[idx + 1] - torch.clamp(pre_activation, min=0)
        obj += utils.bdot(dual_vars.lambdas[idx], relu_gap)
    return obj
def compute_proximal_objective(primal_vars, current_dual_vars,
                               anchor_dual_vars, additional_coeffs, eta):
    """
    Compute the value of the objective of the proximal problem (Wolfe dual of proximal on dual variables).

    The value is the ``bdot`` of the current dual variables with the dual subgradient of the primal
    variables, plus the ``additional_coeffs`` terms on the matching ``zahats``, minus the squared norm of
    (current dual variables - anchor dual variables) weighted by ``eta / 2``.

    :param primal_vars: primal variables, exposing ``as_dual_subgradient()`` and ``zahats``.
    :param current_dual_vars: dual variables at which the objective is evaluated.
    :param anchor_dual_vars: dual anchor of the proximal term.
    :param additional_coeffs: mapping from layer index to coefficients; ``zahats`` is read at ``layer - 1``.
    :param eta: weight of the proximal term.
    :return: the objective value. NOTE(review): documented originally as a tensor of size 2 x n_neurons of
        the layer to optimize; not verifiable from this function alone.
    """
    objective = current_dual_vars.bdot(primal_vars.as_dual_subgradient())
    for layer_idx, coeff in additional_coeffs.items():
        # Layer keys are shifted by one with respect to positions in zahats, hence the "- 1".
        objective += bdot(coeff, primal_vars.zahats[layer_idx - 1])

    # Proximal penalty on the distance between the current dual variables and their anchor.
    dual_offset = current_dual_vars.subtract(1, anchor_dual_vars)
    objective -= dual_offset.weighted_squared_norm(eta / 2)

    return objective
 def bdot(self, other):
     """
     Sum the pairwise ``bdot`` of this object's ``rhos`` with those of ``other``.

     Pairs are formed with ``zip``, so iteration stops at the shorter of the two sequences.
     Returns the integer 0 when there is no pair.
     """
     total = 0
     for own_rho, other_rho in zip(self.rhos, other.rhos):
         # The bare name ``bdot`` is looked up at module level; it is not a recursive call to this method.
         total += bdot(own_rho, other_rho)
     return total