def initial_linear_bound(index, lower_bound, upper_bound):
  input_dim = np.prod(lower_bound.shape)
  lin_coeffs = jnp.reshape(jnp.eye(input_dim), (input_dim, *lower_bound.shape))
  offsets = jnp.zeros_like(lower_bound)
  identity_lin = LinearExpression(lin_coeffs, offsets)
  reference_bound = ibp.IntervalBound(lower_bound, upper_bound)
  lin_function = LinearFunction(identity_lin, identity_lin,
                                RefBound(index, reference_bound))
  lin_bound = LinearBound([lin_function])
  lin_bound.set_concretized(ibp.IntervalBound(lower_bound, upper_bound))
  return lin_bound
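# A minimal sketch of the identity coefficients built above, assuming a
# hypothetical input of shape (2, 2): each row of the identity matrix is
# reshaped to the input shape, so coefficient slice k selects exactly the
# k-th input coordinate.
import numpy as np
import jax.numpy as jnp

example_lower = jnp.zeros((2, 2))
example_input_dim = np.prod(example_lower.shape)  # 4
example_coeffs = jnp.reshape(jnp.eye(example_input_dim),
                             (example_input_dim, *example_lower.shape))
# example_coeffs[0] == [[1., 0.], [0., 0.]]: output 0 is input coordinate 0.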
def concretize_backward_bound(backward_bound, act_bound):
  """Compute the value of a backward bound.

  Args:
    backward_bound: A CrownBackwardBound, representing linear functions of an
      activation that lower and upper bound the output of the network.
    act_bound: Bound on the activation that the backward_bound is a function
      of.
  Returns:
    bound: A concretized IntervalBound.
  """
  act_lower = _broadcast_alltargets(act_bound.lower,
                                    backward_bound.lower_lin.lin_coeffs)
  act_upper = _broadcast_alltargets(act_bound.upper,
                                    backward_bound.lower_lin.lin_coeffs)

  nb_dims_to_reduce = act_lower.ndim - backward_bound.lower_lin.offset.ndim
  dims_to_reduce = tuple(range(-nb_dims_to_reduce, 0))

  lower_lin = backward_bound.lower_lin
  upper_lin = backward_bound.upper_lin
  lower_bound = (lower_lin.offset
                 + jnp.sum(jnp.minimum(lower_lin.lin_coeffs, 0.) * act_upper
                           + jnp.maximum(lower_lin.lin_coeffs, 0.) * act_lower,
                           dims_to_reduce))
  upper_bound = (upper_lin.offset
                 + jnp.sum(jnp.maximum(upper_lin.lin_coeffs, 0.) * act_upper
                           + jnp.minimum(upper_lin.lin_coeffs, 0.) * act_lower,
                           dims_to_reduce))
  return ibp.IntervalBound(lower_bound, upper_bound)
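# For reference, the interval-arithmetic rule implemented above, applied to a
# single linear function f(x) = c @ x + off with x in [l, u], is:
#   min_x f(x) = off + max(c, 0) @ l + min(c, 0) @ u
#   max_x f(x) = off + max(c, 0) @ u + min(c, 0) @ l
# A self-contained sketch on made-up numbers (f(x) = 2*x0 - 3*x1 + 1 over
# the unit box):
import jax.numpy as jnp

c, off = jnp.array([2., -3.]), 1.
l, u = jnp.array([0., 0.]), jnp.array([1., 1.])
lower = off + jnp.maximum(c, 0.) @ l + jnp.minimum(c, 0.) @ u  # 1 - 3 = -2
upper = off + jnp.maximum(c, 0.) @ u + jnp.minimum(c, 0.) @ l  # 1 + 2 = 3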
def tight_bounds(self, variable: RelaxVariable) -> ibp.IntervalBound:
  """Compute tighter bounds based on the LP relaxation.

  Args:
    variable: Variable as created by the base boundprop transform. This is a
      RelaxVariable that has already been encoded into the solvers.
  Returns:
    tightened_base_bound: Bounds tightened by optimizing with the LP solver.
  """
  lbs = []
  ubs = []
  for solver in self.solvers:
    nb_targets = np.prod(variable.shape[1:])
    sample_lbs = []
    sample_ubs = []
    for target_idx in range(nb_targets):
      objective = (jnp.arange(nb_targets) == target_idx).astype(jnp.float32)
      lb, _, optimal_lb = solver.minimize_objective(
          variable.name, objective, 0., self._time_limit_millis)
      assert optimal_lb
      neg_ub, _, optimal_ub = solver.minimize_objective(
          variable.name, -objective, 0., self._time_limit_millis)
      assert optimal_ub
      sample_lbs.append(lb)
      sample_ubs.append(-neg_ub)
    lbs.append(sample_lbs)
    ubs.append(sample_ubs)
  tightened_base_bound = ibp.IntervalBound(
      jnp.reshape(jnp.array(lbs), variable.shape),
      jnp.reshape(jnp.array(ubs), variable.shape))
  return tightened_base_bound
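# A quick illustration of the one-hot objectives used above, assuming a
# hypothetical layer with 4 neurons: minimizing `objective @ x` bounds
# neuron 2 from below, while minimizing `(-objective) @ x` returns the
# negated upper bound, which is why the result is negated before storage.
import jax.numpy as jnp

example_nb_targets = 4
example_objective = (jnp.arange(example_nb_targets) == 2).astype(jnp.float32)
# example_objective == [0., 0., 1., 0.]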
def concrete_bound(
    self,
    graph: bound_propagation.PropagationGraph,
    inputs: Nest[GraphInput],
    env: Dict[jax.core.Var, LayerInput],
    node_ref: jax.core.Var,
) -> ibp.IntervalBound:
  """Perform backward linear bound computation for the node `node_ref`.

  Args:
    graph: Graph to perform backward propagation on.
    inputs: Bounds on the inputs.
    env: Environment containing intermediate bound and shape information.
    node_ref: Reference of the node to obtain a bound for.
  Returns:
    concrete_bound: IntervalBound on the activation at `node_ref`.
  """
  node = graph_traversal.read_env(env, node_ref)

  def bound_fn(obj: Tensor) -> Tuple[Tensor, Tensor]:
    # Handle lower bounds and upper bounds independently in the same chunk.
    obj = jnp.concatenate([obj, -obj], axis=0)
    all_bounds = self._concretizing_transform.concrete_bound_chunk(
        graph, inputs, env, node_ref, obj)
    # Separate out the lower and upper bounds.
    lower_bound, neg_upper_bound = jnp.split(all_bounds, 2, axis=0)
    upper_bound = -neg_upper_bound
    return lower_bound, upper_bound

  return ibp.IntervalBound(
      *utils.chunked_bounds(node.shape, self._max_chunk_size, bound_fn))
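# A minimal sketch of the objective-stacking trick used in `bound_fn` above,
# with a hypothetical stand-in for the concretizer that returns
# min_x obj @ x for each objective row over x in [l, u]. Stacking [obj; -obj]
# lets a single call produce both bounds, since min(-obj @ x) == -max(obj @ x).
import jax.numpy as jnp

def fake_concrete_bound_chunk(obj, l, u):  # hypothetical stand-in
  return jnp.maximum(obj, 0.) @ l + jnp.minimum(obj, 0.) @ u

example_obj = jnp.eye(2)
stacked = jnp.concatenate([example_obj, -example_obj], axis=0)
all_bounds = fake_concrete_bound_chunk(stacked, jnp.array([-1., 0.]),
                                       jnp.array([1., 2.]))
example_lower, neg_upper = jnp.split(all_bounds, 2, axis=0)
example_upper = -neg_upper  # lower == [-1., 0.], upper == [1., 2.]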
def concretize(self):
  if self._concretized is not None:
    return self._concretized

  lb = jnp.zeros(())
  ub = jnp.zeros(())
  for lin_fun in self._refbound_to_linfun.values():
    lin_fun_lb, lin_fun_ub = lin_fun.concretize()
    lb = lb + lin_fun_lb
    ub = ub + lin_fun_ub

  self._concretized = ibp.IntervalBound(lb, ub)
  return self._concretized
def initial_linear_bound(lower_bound, upper_bound):
  batch_size = lower_bound.shape[0]
  act_shape = lower_bound.shape[1:]
  input_dim = np.prod(act_shape)
  sp_lin = jnp.reshape(jnp.eye(input_dim), (input_dim, *act_shape))
  batch_lin = jnp.repeat(jnp.expand_dims(sp_lin, 0), batch_size, axis=0)
  identity_lin = LinearExpression(batch_lin,
                                  jnp.zeros((batch_size, *act_shape)))
  lin_bound = LinearBound(identity_lin, identity_lin, None)
  lin_bound.set_concretized(ibp.IntervalBound(lower_bound, upper_bound))
  return lin_bound
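# A shape check for the batched variant above, assuming a made-up batch of 3
# inputs of shape (2, 2): the coefficients come out with shape
# (batch_size, input_dim, *act_shape).
import numpy as np
import jax.numpy as jnp

example_batch, example_act_shape = 3, (2, 2)
example_dim = np.prod(example_act_shape)
example_sp = jnp.reshape(jnp.eye(example_dim), (example_dim, *example_act_shape))
example_batched = jnp.repeat(jnp.expand_dims(example_sp, 0), example_batch,
                             axis=0)
assert example_batched.shape == (3, 4, 2, 2)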
def _crown_max(out_bound, lhs, rhs):
  """Backward propagation of linear bounds through a ReLU.

  Args:
    out_bound: CrownBackwardBound, linear function of the network output
      bounds with respect to the output of the ReLU.
    lhs: Left input to the max, the input to the ReLU.
    rhs: Right input to the max; we assume this to be 0.
  Returns:
    lhs_backbound: CrownBackwardBound, linear function of the network output
      bounds with respect to the input of the ReLU.
    rhs_backbound: None, because we assume the second argument to be 0.
  """
  if not (isinstance(lhs, ibp.IntervalBound) and rhs == 0.):
    raise NotImplementedError('Only ReLU implemented for now.')

  relu_on = (lhs.lower >= 0.)
  relu_amb = jnp.logical_and(lhs.lower < 0., lhs.upper >= 0.)
  ub_slope = relu_on.astype(jnp.float32)
  ub_slope += jnp.where(relu_amb,
                        lhs.upper / jnp.maximum(lhs.upper - lhs.lower, 1e-12),
                        jnp.zeros_like(lhs.lower))
  ub_bias = jnp.where(relu_amb, -ub_slope * lhs.lower,
                      jnp.zeros_like(lhs.lower))
  # CROWN ReLU propagation: the lower bound uses a slope of 0 or 1,
  # whichever is closest to the upper-bound slope.
  lb_slope = (ub_slope >= 0.5).astype(jnp.float32)
  lb_bias = jnp.zeros_like(ub_bias)

  lower_lin_coeffs = out_bound.lower_lin.lin_coeffs
  upper_lin_coeffs = out_bound.upper_lin.lin_coeffs

  ub_slope = _broadcast_alltargets(ub_slope, lower_lin_coeffs)
  lb_slope = _broadcast_alltargets(lb_slope, lower_lin_coeffs)

  new_lower_lin_coeffs = (jnp.minimum(lower_lin_coeffs, 0.) * ub_slope
                          + jnp.maximum(lower_lin_coeffs, 0.) * lb_slope)
  new_upper_lin_coeffs = (jnp.maximum(upper_lin_coeffs, 0.) * ub_slope
                          + jnp.minimum(upper_lin_coeffs, 0.) * lb_slope)

  bias_to_conc = ibp.IntervalBound(lb_bias, ub_bias)
  new_offset = concretize_backward_bound(out_bound, bias_to_conc)

  lhs_backbound = CrownBackwardBound(
      LinearExpression(new_lower_lin_coeffs, new_offset.lower),
      LinearExpression(new_upper_lin_coeffs, new_offset.upper))
  rhs_backbound = None
  return lhs_backbound, rhs_backbound
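# A worked instance of the ReLU relaxation coefficients computed above, for
# an ambiguous neuron with made-up pre-activation bounds l = -1, u = 3:
#   upper: relu(x) <= u / (u - l) * (x - l), i.e. slope 0.75, bias 0.75
#          (the chord through (l, 0) and (u, u));
#   lower: relu(x) >= slope * x with slope in {0, 1}; here ub_slope >= 0.5,
#          so the lower bound chosen is relu(x) >= x.
import jax.numpy as jnp

ex_l, ex_u = jnp.array(-1.), jnp.array(3.)
ex_ub_slope = ex_u / jnp.maximum(ex_u - ex_l, 1e-12)       # 0.75
ex_ub_bias = -ex_ub_slope * ex_l                           # 0.75
ex_lb_slope = (ex_ub_slope >= 0.5).astype(jnp.float32)     # 1.0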
def tightened_variable_bounds(self, variable: RelaxVariable) -> RelaxVariable:
  """Compute tighter bounds based on the LP relaxation.

  Args:
    variable: Variable as created by the base boundprop transform. This is a
      RelaxVariable that has already been encoded into the solvers.
  Returns:
    tightened_variable: New variable, with the same name, referring to the
      same activation, but whose bounds have been optimized by the LP solver.
  """
  lbs = []
  ubs = []
  for solver in self.solvers:
    nb_targets = np.prod(variable.shape[1:])
    sample_lbs = []
    sample_ubs = []
    for target_idx in range(nb_targets):
      objective = (jnp.arange(nb_targets) == target_idx).astype(jnp.float32)
      lb, optimal_lb = solver.minimize_objective(
          variable.name, objective, 0., 0)
      assert optimal_lb
      neg_ub, optimal_ub = solver.minimize_objective(
          variable.name, -objective, 0., 0)
      assert optimal_ub
      sample_lbs.append(lb)
      sample_ubs.append(-neg_ub)
    lbs.append(sample_lbs)
    ubs.append(sample_ubs)
  tightened_base_bound = ibp.IntervalBound(
      jnp.reshape(jnp.array(lbs), variable.shape),
      jnp.reshape(jnp.array(ubs), variable.shape))
  tightened_variable = RelaxVariable(variable.name, tightened_base_bound)
  tightened_variable.set_constraints(variable.constraints)
  # TODO: Now that we have obtained tighter bounds, we could decide to encode
  # them into the LP solver, which might make the problems easier to solve.
  # This would, however, not change the strength of the relaxation.
  return tightened_variable
def concretize(self):
  if self._concretized is not None:
    return self._concretized

  batch_size = self.shape[0]
  nb_act = len(self.shape) - 1
  broad_shape = (batch_size, -1) + (1,) * nb_act
  flat_ref_lb = jnp.reshape(self.reference.lower, broad_shape)
  flat_ref_ub = jnp.reshape(self.reference.upper, broad_shape)
  concrete_lb = (
      (jnp.maximum(self.lower_lin.lin_coeffs, 0.) * flat_ref_lb).sum(axis=1)
      + (jnp.minimum(self.lower_lin.lin_coeffs, 0.) * flat_ref_ub).sum(axis=1)
      + self.lower_lin.offset)
  concrete_ub = (
      (jnp.maximum(self.upper_lin.lin_coeffs, 0.) * flat_ref_ub).sum(axis=1)
      + (jnp.minimum(self.upper_lin.lin_coeffs, 0.) * flat_ref_lb).sum(axis=1)
      + self.upper_lin.offset)
  self._concretized = ibp.IntervalBound(concrete_lb, concrete_ub)
  return self._concretized
def _chunked_optimization(bound_shape, max_parallel_nodes, optimize_chunk):
  """Perform optimization of the target in chunks.

  Args:
    bound_shape: Shape of the bound to compute.
    max_parallel_nodes: How many activations to optimize at once. If 0,
      optimize all the nodes simultaneously.
    optimize_chunk: Function to optimize a chunk and return updated bounds.
  Returns:
    bounds: Optimized bounds.
  """
  ini_lbs = jnp.zeros(bound_shape, jnp.float32)
  ini_ubs = jnp.zeros(bound_shape, jnp.float32)
  if max_parallel_nodes == 0:
    lbs, ubs = optimize_chunk(0, (ini_lbs, ini_ubs))
  else:
    nb_opt_chunk = math.ceil(np.prod(bound_shape[1:]) / max_parallel_nodes)
    lbs, ubs = jax.lax.fori_loop(0, nb_opt_chunk, optimize_chunk,
                                 (ini_lbs, ini_ubs))
  bounds = ibp.IntervalBound(lbs, ubs)
  return bounds
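# A minimal sketch of the fori_loop chunking pattern used above, with a toy
# `optimize_chunk` stand-in that fills `chunk_size` entries per iteration
# using dynamic_update_slice; the real function instead runs one bound
# optimization per chunk of activations.
import jax
import jax.numpy as jnp

toy_chunk_size = 2

def toy_optimize_chunk(chunk_idx, carry):
  lbs, ubs = carry
  start = chunk_idx * toy_chunk_size
  update = jnp.ones((toy_chunk_size,))
  lbs = jax.lax.dynamic_update_slice(lbs, -update, (start,))
  ubs = jax.lax.dynamic_update_slice(ubs, update, (start,))
  return lbs, ubs

toy_ini = (jnp.zeros((6,)), jnp.zeros((6,)))
toy_lbs, toy_ubs = jax.lax.fori_loop(0, 3, toy_optimize_chunk, toy_ini)
# toy_lbs == [-1.] * 6, toy_ubs == [1.] * 6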
def get_bounds(
    self, to_opt_bound: nonconvex.NonConvexBound) -> bound_propagation.Bound:
  optimize_fun = self._optimizer.optimize_fun(to_opt_bound)

  def bound_fn(obj: Tensor) -> Tuple[Tensor, Tensor]:
    var_shapes, chunk_objectives = _create_opt_problems(to_opt_bound, obj)
    ini_var_set = {key: 0.5 * jnp.ones(shape)
                   for key, shape in var_shapes.items()}

    def solve_problem(objectives: ParamSet) -> Tensor:
      # Optimize the primal variables for the given objectives.
      opt_var_set = optimize_fun(objectives, ini_var_set)
      # Compute the resulting bound.
      _, bound_vals = to_opt_bound.dual(jax.lax.stop_gradient(opt_var_set),
                                        objectives)
      return bound_vals

    if any(node_idx <= to_opt_bound.index
           for ((node_idx, *_), _) in self._branching_constraints):
      # There exist constraints that need to be taken into account. The dual
      # variables per constraint are scalars, but we need to apply them for
      # each of the optimization objectives.
      nb_targets = chunk_objectives[to_opt_bound.index].shape[0]
      # Create the dual variables for them.
      active_branching_constraints = [
          (node_idx, neuron_idx, val, side)
          for (node_idx, neuron_idx, val), side in self._branching_constraints
          if node_idx <= to_opt_bound.index
      ]
      nb_constraints = len(active_branching_constraints)
      dual_vars = [jnp.zeros([nb_targets])] * nb_constraints

      # Define the objective function to optimize. The branching constraints
      # are lifted into the objective function.
      def unbranched_objective(dual_vars: ParamSet) -> Tuple[float, Tensor]:
        objectives = chunk_objectives.copy()
        base_term = jnp.zeros([nb_targets])
        for ((node_idx, neuron_idx, val, side),
             branch_dvar) in zip(active_branching_constraints, dual_vars):
          # Adjust the objective function to incorporate the dual variables.
          if node_idx not in objectives:
            objectives[node_idx] = jnp.zeros(var_shapes[node_idx])
          # The branching constraint is encoded as:
          #   side * neuron >= side * val
          # (when side == 1, this is neuron >= lb,
          #  and when side == -1, this is -neuron >= -ub).
          # To put it in the canonical form \lambda_b() <= 0, this is:
          #   \lambda_b() = side * val - side * neuron
          # Lifting the branching constraints takes us from the problem:
          #   min_{z} f(z)
          #   s.t. \mu_i() <= z_i <= \eta_i()  \forall i
          #        \lambda_b() <= 0            \forall b
          # to:
          #   max_{\rho_b} min_{z} f(z) + \rho_b \lambda_b()
          #   s.t. \mu_i() <= z_i <= \eta_i()  \forall i
          #        \rho_b >= 0
          # Add the term corresponding to the dual variables to the linear
          # objective function.
          coeff_to_add = -side * branch_dvar
          index_to_update = jnp.index_exp[:, neuron_idx]
          flat_node_obj = jnp.reshape(objectives[node_idx], (nb_targets, -1))
          flat_updated_node_obj = flat_node_obj.at[index_to_update].add(
              coeff_to_add)
          updated_node_obj = jnp.reshape(flat_updated_node_obj,
                                         var_shapes[node_idx])
          objectives[node_idx] = updated_node_obj

          # Don't forget the terms based on the bound.
          base_term = base_term + (side * val * branch_dvar)

        network_term = solve_problem(objectives)
        bound = network_term + base_term
        return bound.sum(), bound

      def evaluate_bound(ini_dual_vars: List[Tensor]) -> Tensor:
        ini_state = self._branching_optimizer.init(ini_dual_vars)
        eval_and_grad_fun = jax.grad(unbranched_objective, argnums=0,
                                     has_aux=True)

        # The carry consists of:
        #   - The best set of dual variables seen so far.
        #   - The current set of dual variables.
        #   - The best bound obtained so far.
        #   - The state of the optimizer.
        # At each step, we:
        #   - Evaluate the bound given by the current set of dual variables.
        #   - Update the best set of dual variables if progress was achieved.
        #   - Do an optimization step on the current set of dual variables.
        # This way, we are guaranteed to keep track of the dual variables
        # producing the best bound at the end.
        def opt_step(
            carry: Tuple[List[Tensor], List[Tensor], Tensor, optax.OptState],
            _
        ) -> Tuple[Tuple[List[Tensor], List[Tensor], Tensor, optax.OptState],
                   None]:
          best_lagdual, lagdual, best_bound, state = carry
          # Compute the bound and its gradients.
          lagdual_grads, new_bound = eval_and_grad_fun(lagdual)

          # Update the lagrangian dual variables for the best bound seen.
          improve_best = new_bound > best_bound
          new_best_lagdual = []
          for best_dvar, new_dvar in zip(best_lagdual, lagdual):
            new_best_lagdual.append(
                jnp.where(improve_best, new_dvar, best_dvar))
          # Update the best bound seen.
          new_best_bound = jnp.maximum(best_bound, new_bound)

          # Perform the optimization step, projecting the dual variables
          # back onto the non-negative orthant.
          updates, new_state = self._branching_optimizer.update(
              lagdual_grads, state, lagdual)
          unc_dual = optax.apply_updates(lagdual, updates)
          new_lagdual = jax.tree_map(lambda x: jnp.maximum(x, 0.), unc_dual)
          return ((new_best_lagdual, new_lagdual, new_best_bound, new_state),
                  None)

        dummy_bound = -float('inf') * jnp.ones([nb_targets])
        initial_carry = (ini_dual_vars, ini_dual_vars, dummy_bound, ini_state)

        (best_lagdual, *_), _ = jax.lax.scan(
            opt_step, initial_carry, None,
            length=self._branching_opt_number_steps)

        _, bound_vals = unbranched_objective(
            jax.lax.stop_gradient(best_lagdual))
        return bound_vals

      bound_vals = evaluate_bound(dual_vars)
    else:
      bound_vals = solve_problem(chunk_objectives)

    chunk_lbs, chunk_ubs = _unpack_opt_problem(bound_vals)
    return chunk_lbs, chunk_ubs

  return ibp.IntervalBound(*utils.chunked_bounds(
      to_opt_bound.shape, self._max_parallel_nodes, bound_fn))
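# The projection step at the heart of `opt_step` above, as a self-contained
# sketch: gradient ascent on a Lagrangian dual variable followed by a clamp
# to rho >= 0. This uses a plain optax SGD optimizer and a made-up concave
# objective; the surrounding code uses self._branching_optimizer instead.
import jax
import jax.numpy as jnp
import optax

toy_objective = lambda rho: -(rho - 2.0) ** 2  # maximized at rho == 2
toy_opt = optax.sgd(learning_rate=0.1)
toy_rho = jnp.array(0.5)
toy_state = toy_opt.init(toy_rho)
for _ in range(100):
  toy_grad = jax.grad(toy_objective)(toy_rho)
  # optax minimizes, so ascend by feeding the negated gradient.
  toy_updates, toy_state = toy_opt.update(-toy_grad, toy_state)
  toy_rho = jnp.maximum(optax.apply_updates(toy_rho, toy_updates), 0.)
# toy_rho converges to 2.0, and would stay at 0.0 had the optimum been
# negative: the maximum over feasible rho >= 0.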