Ejemplo n.º 1
0
    def initial_linear_bound(index, lower_bound, upper_bound):
        """Build the identity linear bound on an input, expressed over itself.

        Args:
          index: Passed to `RefBound` together with the input's interval
            bound (presumably identifies the input node — confirm).
          lower_bound: Concrete lower bound on the input.
          upper_bound: Concrete upper bound on the input.
        Returns:
          A LinearBound holding a single LinearFunction whose lower and upper
          linear expressions are both the identity map with zero offset. Its
          concretized value is preset to [lower_bound, upper_bound].
        """
        act_shape = lower_bound.shape
        nb_inputs = np.prod(act_shape)
        # Coefficient slice k is one-hot on element k (flattened order), so the
        # expression maps the input to itself.
        identity_coeffs = jnp.eye(nb_inputs).reshape((nb_inputs, *act_shape))
        identity_lin = LinearExpression(identity_coeffs,
                                        jnp.zeros_like(lower_bound))

        reference = RefBound(index,
                             ibp.IntervalBound(lower_bound, upper_bound))
        lin_bound = LinearBound(
            [LinearFunction(identity_lin, identity_lin, reference)])
        # The identity map concretizes exactly to the input interval, so the
        # concretized value is set directly instead of being computed.
        lin_bound.set_concretized(ibp.IntervalBound(lower_bound, upper_bound))
        return lin_bound
Ejemplo n.º 2
0
def concretize_backward_bound(backward_bound, act_bound):
    """Evaluate a backward linear bound over an interval on its activation.

    For each coefficient, this picks the interval endpoint that minimises the
    product (for the lower linear function) or maximises it (for the upper
    linear function). It then sums over the trailing activation dimensions and
    adds the offsets.

    Args:
      backward_bound: a CrownBackwardBound whose `lower_lin` / `upper_lin` are
        linear functions of activations lower and upper bounding the output
        of the network.
      act_bound: Bound on the activation that `backward_bound` is a function
        of.
    Returns:
      bound: An ibp.IntervalBound holding the concretized lower and upper
        values.
    """
    lower_lin = backward_bound.lower_lin
    upper_lin = backward_bound.upper_lin

    # Both activation bounds are broadcast against the lower coefficients.
    coeff_template = lower_lin.lin_coeffs
    act_lower = _broadcast_alltargets(act_bound.lower, coeff_template)
    act_upper = _broadcast_alltargets(act_bound.upper, coeff_template)

    # Reduce over every trailing dimension that the offset does not have.
    extra_dims = act_lower.ndim - lower_lin.offset.ndim
    reduce_axes = tuple(range(-extra_dims, 0))

    def _neg_part(coeffs):
        return jnp.minimum(coeffs, 0.)

    def _pos_part(coeffs):
        return jnp.maximum(coeffs, 0.)

    lower_terms = (_neg_part(lower_lin.lin_coeffs) * act_upper +
                   _pos_part(lower_lin.lin_coeffs) * act_lower)
    upper_terms = (_pos_part(upper_lin.lin_coeffs) * act_upper +
                   _neg_part(upper_lin.lin_coeffs) * act_lower)

    return ibp.IntervalBound(
        lower_lin.offset + jnp.sum(lower_terms, reduce_axes),
        upper_lin.offset + jnp.sum(upper_terms, reduce_axes))
Ejemplo n.º 3
0
    def tight_bounds(self, variable: RelaxVariable) -> ibp.IntervalBound:
        """Compute tighter bounds on `variable` based on the LP relaxation.

        For every solver in `self.solvers`, each element of the variable is
        minimised and then maximised. The maximum is obtained by minimising
        the negated one-hot objective.

        Args:
          variable: Variable as created by the base boundprop transform. This
            is a RelaxVariable that has already been encoded into the solvers.
        Returns:
          tightened_base_bound: Bounds tightened by optimizing with the LP
            solver, reshaped to `variable.shape`. The outer axis comes from the
            solver list (presumably one solver per batch element — confirm).
        Raises:
          AssertionError: if a solve does not report optimality.
        """
        # The per-sample element count does not depend on the solver.
        nb_targets = np.prod(variable.shape[1:])
        all_lbs = []
        all_ubs = []
        for solver in self.solvers:
            per_target_lbs = []
            per_target_ubs = []
            for target_idx in range(nb_targets):
                # One-hot objective selecting element `target_idx`.
                one_hot = (jnp.arange(nb_targets) == target_idx).astype(
                    jnp.float32)
                lower, _, lb_is_optimal = solver.minimize_objective(
                    variable.name, one_hot, 0., self._time_limit_millis)
                assert lb_is_optimal
                neg_upper, _, ub_is_optimal = solver.minimize_objective(
                    variable.name, -one_hot, 0., self._time_limit_millis)
                assert ub_is_optimal
                per_target_lbs.append(lower)
                per_target_ubs.append(-neg_upper)
            all_lbs.append(per_target_lbs)
            all_ubs.append(per_target_ubs)

        return ibp.IntervalBound(
            jnp.reshape(jnp.array(all_lbs), variable.shape),
            jnp.reshape(jnp.array(all_ubs), variable.shape))
Ejemplo n.º 4
0
  def concrete_bound(
      self,
      graph: bound_propagation.PropagationGraph,
      inputs: Nest[GraphInput],
      env: Dict[jax.core.Var, LayerInput],
      node_ref: jax.core.Var,
  ) -> ibp.IntervalBound:
    """Perform backward linear bound computation for the node `node_ref`.

    Each chunk of objectives is stacked with its negation. A single call to
    the concretizing transform then yields both the lower bounds and the
    negated upper bounds for that chunk.

    Args:
      graph: Graph to perform Backward Propagation on.
      inputs: Bounds on the inputs.
      env: Environment containing intermediate bound and shape information.
      node_ref: Reference of the node to obtain a bound for.
    Returns:
      concrete_bound: IntervalBound on the activation at `node_ref`.
    """
    target_shape = graph_traversal.read_env(env, node_ref).shape
    concretizer = self._concretizing_transform

    def bound_fn(obj: Tensor) -> Tuple[Tensor, Tensor]:
      # A lower bound on -obj is the negation of an upper bound on obj.
      paired_obj = jnp.concatenate([obj, -obj], axis=0)
      paired_bounds = concretizer.concrete_bound_chunk(
          graph, inputs, env, node_ref, paired_obj)
      lower, neg_upper = jnp.split(paired_bounds, 2, axis=0)
      return lower, -neg_upper

    lower_bound, upper_bound = utils.chunked_bounds(
        target_shape, self._max_chunk_size, bound_fn)
    return ibp.IntervalBound(lower_bound, upper_bound)
Ejemplo n.º 5
0
 def concretize(self):
     """Return the interval bound obtained by summing all linear functions.

     Each function in `self._refbound_to_linfun` is concretized, and the lower
     and upper values are accumulated starting from scalar zeros. The result
     is cached in `self._concretized` and returned as-is on later calls.
     """
     if self._concretized is None:
         total_lb = jnp.zeros(())
         total_ub = jnp.zeros(())
         for lin_fun in self._refbound_to_linfun.values():
             fun_lb, fun_ub = lin_fun.concretize()
             total_lb = total_lb + fun_lb
             total_ub = total_ub + fun_ub
         self._concretized = ibp.IntervalBound(total_lb, total_ub)
     return self._concretized
Ejemplo n.º 6
0
    def initial_linear_bound(lower_bound, upper_bound):
        """Build the identity LinearBound for a batched input.

        Args:
          lower_bound: Concrete lower bound on the input. Its leading axis is
            treated as the batch axis.
          upper_bound: Concrete upper bound on the input.
        Returns:
          A LinearBound whose lower and upper linear expressions both map the
          input to itself with a zero offset. Its reference is None and its
          concretized value is preset to [lower_bound, upper_bound].
        """
        batch_size = lower_bound.shape[0]
        act_shape = lower_bound.shape[1:]
        nb_inputs = np.prod(act_shape)

        # Per-sample identity coefficients (slice k is one-hot on element k),
        # repeated along a new leading batch axis.
        sample_coeffs = jnp.eye(nb_inputs).reshape((nb_inputs, *act_shape))
        batched_coeffs = jnp.repeat(sample_coeffs[None], batch_size, axis=0)
        zero_offset = jnp.zeros((batch_size, *act_shape))
        identity_lin = LinearExpression(batched_coeffs, zero_offset)

        lin_bound = LinearBound(identity_lin, identity_lin, None)
        lin_bound.set_concretized(ibp.IntervalBound(lower_bound, upper_bound))
        return lin_bound
Ejemplo n.º 7
0
def _crown_max(out_bound, lhs, rhs):
    """Backward-propagate linear bounds through a ReLU (CROWN relaxation).

    Args:
      out_bound: CrownBackwardBound, linear functions bounding the network
        outputs with regard to the output of the ReLU.
      lhs: left input to the max, i.e. the inputs to the ReLU. Must be an
        ibp.IntervalBound.
      rhs: right input to the max; must be 0.
    Returns:
      lhs_backbound: CrownBackwardBound, linear functions bounding the network
        outputs with regard to the inputs of the ReLU.
      rhs_backbound: None, because the second argument is assumed to be 0.
    Raises:
      NotImplementedError: if `lhs` is not an IntervalBound or `rhs` is not 0.
    """
    is_relu = isinstance(lhs, ibp.IntervalBound) and rhs == 0.
    if not is_relu:
        raise NotImplementedError('Only ReLU implemented for now.')

    lower, upper = lhs.lower, lhs.upper
    zeros = jnp.zeros_like(lower)
    always_on = lower >= 0.
    ambiguous = jnp.logical_and(lower < 0., upper >= 0.)

    # Upper relaxation:
    #   - always on: identity;
    #   - ambiguous: the chord through (l, 0) and (u, u);
    #   - otherwise: zero.
    chord_slope = upper / jnp.maximum(upper - lower, 1e-12)
    ub_slope = (always_on.astype(jnp.float32) +
                jnp.where(ambiguous, chord_slope, zeros))
    ub_bias = jnp.where(ambiguous, -ub_slope * lower, zeros)

    # Lower relaxation (CROWN): slope 1 when the upper slope is >= 0.5,
    # otherwise 0, and never any bias.
    lb_slope = (ub_slope >= 0.5).astype(jnp.float32)
    lb_bias = jnp.zeros_like(ub_bias)

    lower_coeffs = out_bound.lower_lin.lin_coeffs
    upper_coeffs = out_bound.upper_lin.lin_coeffs
    ub_slope = _broadcast_alltargets(ub_slope, lower_coeffs)
    lb_slope = _broadcast_alltargets(lb_slope, lower_coeffs)

    # Negative coefficients take the upper relaxation for the lower bound and
    # the lower relaxation for the upper bound; positive ones the converse.
    new_lower_coeffs = (jnp.minimum(lower_coeffs, 0.) * ub_slope +
                        jnp.maximum(lower_coeffs, 0.) * lb_slope)
    new_upper_coeffs = (jnp.maximum(upper_coeffs, 0.) * ub_slope +
                        jnp.minimum(upper_coeffs, 0.) * lb_slope)

    # Fold the relaxation biases into the offsets by concretizing the incoming
    # bound over the interval [lb_bias, ub_bias].
    new_offset = concretize_backward_bound(
        out_bound, ibp.IntervalBound(lb_bias, ub_bias))

    lhs_backbound = CrownBackwardBound(
        LinearExpression(new_lower_coeffs, new_offset.lower),
        LinearExpression(new_upper_coeffs, new_offset.upper))
    return lhs_backbound, None
Ejemplo n.º 8
0
    def tightened_variable_bounds(self,
                                  variable: RelaxVariable) -> RelaxVariable:
        """Compute tighter bounds on `variable` based on the LP relaxation.

        For every solver in `self.solvers`, each element of the variable is
        minimised and then maximised. The maximum is obtained by minimising
        the negated one-hot objective.

        Args:
          variable: Variable as created by the base boundprop transform. This
            is a RelaxVariable that has already been encoded into the solvers.
        Returns:
          tightened_variable: New variable, with the same name and the same
            constraints, but whose bounds have been optimized by the LP
            solver.
        Raises:
          AssertionError: if a solve does not report optimality.
        """
        # The per-sample element count does not depend on the solver.
        nb_targets = np.prod(variable.shape[1:])
        all_lbs = []
        all_ubs = []
        for solver in self.solvers:
            per_target_lbs = []
            per_target_ubs = []
            for target_idx in range(nb_targets):
                # One-hot objective selecting element `target_idx`.
                one_hot = (jnp.arange(nb_targets) == target_idx).astype(
                    jnp.float32)
                # NOTE(review): this unpacks (value, is_optimal); confirm that
                # matches the return arity of the solvers used here.
                lower, lb_is_optimal = solver.minimize_objective(
                    variable.name, one_hot, 0., 0)
                assert lb_is_optimal
                neg_upper, ub_is_optimal = solver.minimize_objective(
                    variable.name, -one_hot, 0., 0)
                assert ub_is_optimal
                per_target_lbs.append(lower)
                per_target_ubs.append(-neg_upper)
            all_lbs.append(per_target_lbs)
            all_ubs.append(per_target_ubs)

        tightened_bound = ibp.IntervalBound(
            jnp.reshape(jnp.array(all_lbs), variable.shape),
            jnp.reshape(jnp.array(all_ubs), variable.shape))
        tightened_variable = RelaxVariable(variable.name, tightened_bound)
        tightened_variable.set_constraints(variable.constraints)
        # TODO: Now that we have obtained tighter bounds, we could encode them
        # into the LP solver, which might make the problems easier to solve.
        # This, however, would not change the strength of the relaxation.
        return tightened_variable
Ejemplo n.º 9
0
    def concretize(self):
        """Concretize the linear bounds over `self.reference`.

        The result is computed once and cached in `self._concretized`.

        Returns:
          An ibp.IntervalBound on the bounded activation.
        """
        if self._concretized is not None:
            return self._concretized

        # Reshape the reference bounds to (batch, -1, 1, ..., 1), with one
        # trailing singleton per activation dim, so that they broadcast against
        # the coefficients and can be summed over axis 1 (presumably the
        # coefficients are (batch, nb_ref, *act_shape) — confirm).
        nb_act_dims = len(self.shape) - 1
        ref_shape = (self.shape[0], -1) + (1,) * nb_act_dims
        ref_lb = jnp.reshape(self.reference.lower, ref_shape)
        ref_ub = jnp.reshape(self.reference.upper, ref_shape)

        def _concretize_side(lin, pos_ref, neg_ref):
            # Positive coefficients take `pos_ref`; negative ones `neg_ref`.
            pos_part = (jnp.maximum(lin.lin_coeffs, 0.) * pos_ref).sum(axis=1)
            neg_part = (jnp.minimum(lin.lin_coeffs, 0.) * neg_ref).sum(axis=1)
            return pos_part + neg_part + lin.offset

        self._concretized = ibp.IntervalBound(
            _concretize_side(self.lower_lin, ref_lb, ref_ub),
            _concretize_side(self.upper_lin, ref_ub, ref_lb))
        return self._concretized
Ejemplo n.º 10
0
def _chunked_optimization(bound_shape, max_parallel_nodes, optimize_chunk):
  """Perform optimization of the target in chunks.

  Args:
    bound_shape: Shape of the bound to compute.
    max_parallel_nodes: How many activations to optimize at once. If 0, all
      the nodes are optimized simultaneously in a single call.
    optimize_chunk: Function `(chunk_index, (lbs, ubs)) -> (lbs, ubs)` that
      optimizes one chunk and returns the updated bounds.
  Returns:
    bounds: Optimized bounds, as an ibp.IntervalBound.
  """
  init_carry = (jnp.zeros(bound_shape, jnp.float32),
                jnp.zeros(bound_shape, jnp.float32))
  if max_parallel_nodes == 0:
    # A single call covers every node at once.
    lbs, ubs = optimize_chunk(0, init_carry)
  else:
    # Nodes per sample, split into chunks of at most max_parallel_nodes.
    nb_nodes = np.prod(bound_shape[1:])
    nb_chunks = math.ceil(nb_nodes / max_parallel_nodes)
    lbs, ubs = jax.lax.fori_loop(0, nb_chunks, optimize_chunk, init_carry)
  return ibp.IntervalBound(lbs, ubs)
Ejemplo n.º 11
0
    def get_bounds(
            self,
            to_opt_bound: nonconvex.NonConvexBound) -> bound_propagation.Bound:
        """Compute optimized interval bounds on `to_opt_bound`.

        Objectives are processed chunk by chunk, via `utils.chunked_bounds`
        over `to_opt_bound.shape` with `self._max_parallel_nodes`. For each
        chunk, the primal variables are optimized with the function returned
        by `self._optimizer.optimize_fun`. The dual bound of `to_opt_bound` is
        then evaluated at the optimized point.

        Some entries of `self._branching_constraints` may apply to a node with
        index <= `to_opt_bound.index`. Those constraints are lifted into the
        objective with one non-negative Lagrangian dual variable per
        constraint and per target. The dual variables are optimized for
        `self._branching_opt_number_steps` steps with
        `self._branching_optimizer`. The bound is finally re-evaluated at the
        dual variables that produced the best bound seen.

        Args:
          to_opt_bound: The nonconvex bound to optimize.
        Returns:
          An ibp.IntervalBound assembled from the chunked lower and upper
          bounds.
        """
        optimize_fun = self._optimizer.optimize_fun(to_opt_bound)

        def bound_fn(obj: Tensor) -> Tuple[Tensor, Tensor]:
            var_shapes, chunk_objectives = _create_opt_problems(
                to_opt_bound, obj)
            ini_var_set = {
                key: 0.5 * jnp.ones(shape)
                for key, shape in var_shapes.items()
            }

            def solve_problem(objectives: ParamSet) -> Tensor:
                # Optimize the bound for primal variables.
                opt_var_set = optimize_fun(objectives, ini_var_set)
                # Compute the resulting bound. Gradients are stopped through
                # the primal solution, so only `objectives` is differentiated.
                _, bound_vals = to_opt_bound.dual(
                    jax.lax.stop_gradient(opt_var_set), objectives)
                return bound_vals

            if any(node_idx <= to_opt_bound.index
                   for ((node_idx, *_), _) in self._branching_constraints):
                # There exists constraints that needs to be taken into account.

                # The dual vars per constraint are scalars, but we need to apply them
                # for each of the optimization objective.
                nb_targets = chunk_objectives[to_opt_bound.index].shape[0]
                # Create the dual variables for them.
                active_branching_constraints = [
                    (node_idx, neuron_idx, val, side)
                    for (node_idx, neuron_idx,
                         val), side in self._branching_constraints
                    if node_idx <= to_opt_bound.index
                ]
                nb_constraints = len(active_branching_constraints)
                # Repeating the same array is safe: JAX arrays are immutable.
                dual_vars = [jnp.zeros([nb_targets])] * nb_constraints

                # Define the objective function to optimize. The branching constraints
                # are lifted into the objective function.
                def unbranched_objective(
                        dual_vars: ParamSet) -> Tuple[float, Tensor]:
                    objectives = chunk_objectives.copy()
                    base_term = jnp.zeros([nb_targets])
                    for ((node_idx, neuron_idx, val, side),
                         branch_dvar) in zip(active_branching_constraints,
                                             dual_vars):
                        # Adjust the objective function to incorporate the dual variables.
                        if node_idx not in objectives:
                            objectives[node_idx] = jnp.zeros(
                                var_shapes[node_idx])

                        # The branching constraint is encoded as:
                        #   side * neuron >= side * val
                        # (when side==1, this is neuron >= lb,
                        #  and when side==-1, this is -neuron >= -ub )
                        # To put in a canonical form \lambda_b() <= 0, this is:
                        # \lambda_b() = side * val - side * neuron

                        # Lifting the branching constraints takes us from the problem:
                        #   min_{z} f(z)
                        #   s.t.    \mu_i() <= z_i <= \eta_i()  \forall i
                        #           \lambda_b() <= 0            \forall b
                        #
                        # to
                        #   max_{\rho_b} min_{z} f(z) + \rho_b \lambda_b()
                        #                s.t \mu_i() <= z_i <= \eta_i()  \forall i
                        #   s.t  rho_b >= 0

                        # Add the term corresponding to the dual variables to the linear
                        # objective function.
                        coeff_to_add = -side * branch_dvar
                        index_to_update = jnp.index_exp[:, neuron_idx]
                        flat_node_obj = jnp.reshape(objectives[node_idx],
                                                    (nb_targets, -1))
                        flat_updated_node_obj = flat_node_obj.at[
                            index_to_update].add(coeff_to_add)
                        updated_node_obj = jnp.reshape(flat_updated_node_obj,
                                                       var_shapes[node_idx])
                        objectives[node_idx] = updated_node_obj

                        # Don't forget the terms based on the bound.
                        base_term = base_term + (side * val * branch_dvar)

                    network_term = solve_problem(objectives)
                    bound = network_term + base_term

                    return bound.sum(), bound

                def evaluate_bound(ini_dual_vars: List[Tensor]) -> Tensor:
                    ini_state = self._branching_optimizer.init(ini_dual_vars)
                    # With has_aux=True, this returns (gradients, per-target
                    # bound) rather than the value of the summed objective.
                    eval_and_grad_fun = jax.grad(unbranched_objective,
                                                 argnums=0,
                                                 has_aux=True)

                    # The carry consists of:
                    # - The best set of dual variables seen so far.
                    # - The current set of dual variables.
                    # - The best bound obtained so far.
                    # - The state of the optimizer.
                    # For each of the step, we will:
                    # - Evaluate the bounds by the current set of dual variables.
                    # - Update the best set of dual variables if progress was achieved.
                    # - Do an optimization step on the current set of dual variables.
                    # This way, we are guaranteed that we keep track of the dual variables
                    # producing the best bound at the end.
                    def opt_step(
                        carry: Tuple[List[Tensor], List[Tensor], Tensor,
                                     optax.OptState], _
                    ) -> Tuple[Tuple[List[Tensor], List[Tensor], Tensor,
                                     optax.OptState], None]:
                        best_lagdual, lagdual, best_bound, state = carry
                        # Compute the bound and their gradients.
                        lagdual_grads, new_bound = eval_and_grad_fun(lagdual)

                        # Update the lagrangian dual variables for the best bound seen.
                        improve_best = new_bound > best_bound
                        new_best_lagdual = []
                        for best_dvar, new_dvar in zip(best_lagdual, lagdual):
                            new_best_lagdual.append(
                                jnp.where(improve_best, new_dvar, best_dvar))
                        # Update the best bound seen
                        new_best_bound = jnp.maximum(best_bound, new_bound)

                        # Perform optimization step.
                        # NOTE(review): standard optax transformations produce
                        # updates that descend the given gradient, while the
                        # lifted bound is meant to be maximized over the dual
                        # variables. Confirm that `self._branching_optimizer`
                        # is configured for ascent (e.g. a negative scale).
                        updates, new_state = self._branching_optimizer.update(
                            lagdual_grads, state, lagdual)
                        unc_dual = optax.apply_updates(lagdual, updates)
                        # Project back onto the feasible set rho_b >= 0.
                        new_lagdual = jax.tree_util.tree_map(
                            lambda x: jnp.maximum(x, 0.), unc_dual)
                        return ((new_best_lagdual, new_lagdual, new_best_bound,
                                 new_state), None)

                    dummy_bound = -float('inf') * jnp.ones([nb_targets])
                    initial_carry = (ini_dual_vars, ini_dual_vars, dummy_bound,
                                     ini_state)

                    (best_lagdual, *_), _ = jax.lax.scan(
                        opt_step,
                        initial_carry,
                        None,
                        length=self._branching_opt_number_steps)

                    _, bound_vals = unbranched_objective(
                        jax.lax.stop_gradient(best_lagdual))

                    return bound_vals

                bound_vals = evaluate_bound(dual_vars)
            else:
                bound_vals = solve_problem(chunk_objectives)

            chunk_lbs, chunk_ubs = _unpack_opt_problem(bound_vals)
            return chunk_lbs, chunk_ubs

        return ibp.IntervalBound(*utils.chunked_bounds(
            to_opt_bound.shape, self._max_parallel_nodes, bound_fn))