def propagate(
    self,
    graph: bound_propagation.PropagationGraph,
    bounds: Nest[graph_traversal.GraphInput],
) -> Tuple[Nest[Bound], Dict[jax.core.Var, Union[Bound, Tensor]]]:
  """Propagates non-convex bounds through `graph` and concretizes the outputs.

  If a base bound-propagation transform was configured, it is run first and
  its environment is handed to the concretizer. Otherwise `None` is passed,
  which is only valid for concretizers that do not consult base bounds.

  Args:
    graph: Graph to propagate bounds through.
    bounds: Nest of graph inputs.

  Returns:
    output_bounds: Nest of bounds on the graph outputs, each concretized.
    env: Environment produced by the non-convex propagation.
  """
  base_env = None
  if self._base_boundprop is not None:
    # Run the 'base' propagation ahead of time; the concretizer consumes it.
    base_algorithm = bound_propagation.ForwardPropagationAlgorithm(
        self._base_boundprop)
    _, base_env = base_algorithm.propagate(graph, bounds)

  transform = self._nonconvex_transform_ctor(
      self._concretizer, graph, base_env)
  algorithm = bound_propagation.ForwardPropagationAlgorithm(transform)
  output_bounds, env = algorithm.propagate(graph, bounds)

  # Concretize every returned leaf so that `lower` / `upper` are accessible.
  pending = (leaf for leaf in jax.tree_util.tree_leaves(output_bounds)
             if leaf.requires_concretizing(None))
  for leaf in pending:
    leaf.concretize(self._concretizer, graph, base_env)
  return output_bounds, env
def get_bounds(self, fun, input_bounds):
  """Computes per-element LP-relaxation bounds on the outputs of `fun`.

  Each output element gets two CVXPY relaxation solves:
  - the one-hot objective is minimised to give its lower bound;
  - the negated objective is minimised, and the result negated, to give its
    upper bound.

  Args:
    fun: Function to bound.
    input_bounds: Bounds on the input of `fun`. Its `lower` attribute is also
      used to evaluate `fun` once, to learn the number of outputs.

  Returns:
    Tuple `(lower, upper)` of arrays stacking the per-element bounds.
  """
  num_outputs = fun(input_bounds.lower).size
  relaxation_transform = relaxation.RelaxationTransform(
      jax_verify.ibp_transform)
  algorithm = bound_propagation.ForwardPropagationAlgorithm(
      relaxation_transform)
  var, env = bound_propagation.bound_propagation(algorithm, fun, input_bounds)

  def minimise(objective):
    # Bias 0. and batch index 0, as for every query in this helper.
    value, _, _ = relaxation.solve_relaxation(
        cvxpy_relaxation_solver.CvxpySolver, objective, 0., var, env, 0)
    return value

  lowers, uppers = [], []
  all_indices = jnp.arange(num_outputs)
  for i in range(num_outputs):
    one_hot = (all_indices == i).astype(jnp.float32)
    lowers.append(minimise(one_hot))
    uppers.append(-minimise(-one_hot))
  return jnp.array(lowers), jnp.array(uppers)
def solve_with_jax_verify(self):
  """Maximises `self.objective` over the LP relaxation of the network.

  Inputs are constrained to the box `[inputs - eps, inputs + eps]`, clipped
  to `[0, 1]`. Intermediate bounds come from IBP.

  Returns:
    The maximal value of the objective (plus `self.objective_bias`) under
    the relaxation.
  """
  def clamp01(x):
    return jnp.minimum(jnp.maximum(x, 0.0), 1.0)

  init_bound = jax_verify.IntervalBound(clamp01(self.inputs - self.eps),
                                        clamp01(self.inputs + self.eps))
  logits_fn = make_model_fn(self.network_params)
  var, env = bound_propagation.bound_propagation(
      bound_propagation.ForwardPropagationAlgorithm(
          relaxation.RelaxationTransform(jax_verify.ibp_transform)),
      logits_fn, init_bound)
  # The solver minimises, so compute max(f) as -min(-f).
  neg_value_opt, _, _ = relaxation.solve_relaxation(
      cvxpy_relaxation_solver.CvxpySolver, -self.objective,
      -self.objective_bias, var, env, index=0, time_limit_millis=None)
  return -neg_value_opt
def interval_bound_propagation(function, *bounds):
  """Runs interval bound propagation (https://arxiv.org/abs/1810.12715).

  Args:
    function: Computation to bound; it takes the network inputs as its only
      arguments.
    *bounds: jax_verify.IntervalBound objects bounding the function inputs.

  Returns:
    output_bound: IBP bounds on the output of `function`.
  """
  ibp_algorithm = bound_propagation.ForwardPropagationAlgorithm(
      bound_transform)
  output_bound, _ = bound_propagation.bound_propagation(
      ibp_algorithm, function, *bounds)
  return output_bound
def propagate(self, graph: bound_propagation.PropagationGraph,
              inputs: Nest[GraphInput]):
  """Bounds the graph outputs by jointly optimising the needed relaxations.

  A `GraphInspector` pass records the graph structure. Relaxations are then
  collected, via backward passes with a `_RelaxationScanner`, for:
  - every graph-node argument of a primitive in `CONCRETIZE_ARGS_PRIMITIVE`;
  - every graph output.
  The collected relaxations are handed to `self.jointly_optimize_relaxations`.

  Args:
    graph: Graph to propagate bounds through.
    inputs: Nest of graph inputs; only the `Bound` leaves are treated as
      input nodes of the backward passes.

  Returns:
    outvals: List of final bounds, one per graph output.
    env_with_final_bounds: Environment returned by the joint optimisation.
  """
  # Inspect the graph to figure out which nodes need concretization.
  graph_inspector = bound_utils.GraphInspector()
  inspector_algorithm = bound_propagation.ForwardPropagationAlgorithm(
      graph_inspector)
  gn_outvals, env = inspector_algorithm.propagate(graph, inputs)

  flat_inputs, _ = jax.tree_util.tree_flatten(inputs)
  flat_bounds = [inp for inp in flat_inputs
                 if isinstance(inp, bound_propagation.Bound)]
  input_nodes_indices = [(i,) for i in range(len(flat_bounds))]

  # Maps node index -> relaxations collected by its RelaxationScanner.
  relaxations = {}

  def collect_relaxations(graph_node):
    """Records (once per node) the relaxations `graph_node` depends on."""
    if graph_node.index in relaxations:
      return
    index_to_concretize = graph_node.index
    jaxpr_node = graph.jaxpr_node(index_to_concretize)
    scanner = _RelaxationScanner(self._relaxer)
    graph.backward_propagation(
        scanner, env, {jaxpr_node: graph_node}, input_nodes_indices)
    relaxations[index_to_concretize] = scanner.node_relaxations

  for node in graph_inspector.nodes.values():
    node_primitive = node.primitive
    if node_primitive and node_primitive in CONCRETIZE_ARGS_PRIMITIVE:
      for node_arg in node.args:
        # Non-GraphNode arguments (presumably constants) have no index and
        # no relaxation to collect; passing them on would fail on `.index`.
        if isinstance(node_arg, bound_utils.GraphNode):
          collect_relaxations(node_arg)

  # Record the outputs' jaxpr nodes: they define the objective of the joint
  # optimisation, and their relaxations must be collected too.
  objective_nodes = []
  for gn in gn_outvals:
    collect_relaxations(gn)
    objective_nodes.append(graph.jaxpr_node(gn.index))

  env_with_final_bounds = self.jointly_optimize_relaxations(
      relaxations, graph, inputs, env, objective_nodes)
  outvals = [env_with_final_bounds[jaxpr_node_opted]
             for jaxpr_node_opted in objective_nodes]
  return outvals, env_with_final_bounds
def solve_planet_relaxation(logits_fn, initial_bounds, boundprop_transform,
                            objective, objective_bias, index,
                            solver=cvxpy_relaxation_solver.CvxpySolver):
  """Solves the "Planet" (Ehlers 17) / "triangle" LP relaxation.

  jax_verify generates the constraints, which are then given to a generic
  solver.

  Using CVXPY incurs a large overhead when defining the LP. Constraints are
  defined element-wise, so that convolutional layers are not represented as
  one (inefficient) matrix multiplication, and defining many constraints is
  slow in CVXPY.

  Args:
    logits_fn: Mapping from inputs (batch_size x input_size) to
      (batch_size, num_classes).
    initial_bounds: `IntervalBound` on the inputs; lower and upper have shape
      (batch_size x input_size).
    boundprop_transform: bound_propagation.BoundTransform instance (e.g.
      `jax_verify.ibp_transform`). It pre-computes the interval bounds on
      intermediate activations used to define the Planet relaxation.
    objective: Array of coefficients applied to the output of `logits_fn`,
      defining the objective to minimise.
    objective_bias: Bias added to the objective.
    index: Batch index for which to solve the relaxation.
    solver: relaxation.RelaxationSolver class used as the LP backend.

  Returns:
    val: Optimal value of the relaxation.
    solution: Optimal solution found by the solver.
    status: Status reported by the solver.
  """
  transform = relaxation.RelaxationTransform(boundprop_transform)
  algorithm = bound_propagation.ForwardPropagationAlgorithm(transform)
  variable, env = bound_propagation.bound_propagation(
      algorithm, logits_fn, initial_bounds)
  value, solution, status = relaxation.solve_relaxation(
      solver, objective, objective_bias, variable, env,
      index=index, time_limit_millis=None)
  return value, solution, status
def ibpforwardfastlin_bound_propagation(function, *bounds):
  """Obtains the best of IBP and ForwardFastlin bounds.

  Both transforms are run together through an
  `intersection.IntersectionBoundTransform`, which combines the IBP bounds
  and the forward-Fastlin bounds.

  Args:
    function: Computation to bound; it takes the network inputs as its only
      arguments.
    *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.

  Returns:
    output_bound: Bounds on the output of the function, obtained by
      intersecting the IBP and forward-Fastlin bounds.
  """
  output_bound, _ = bound_propagation.bound_propagation(
      bound_propagation.ForwardPropagationAlgorithm(
          intersection.IntersectionBoundTransform(
              ibp.bound_transform, forward_fastlin_transform)),
      function, *bounds)
  return output_bound
def __init__(
    self,
    boundprop_transform: bound_propagation.BoundTransform,
    spec_fn: Callable[..., Tensor],
    *input_bounds: Bound,
):
  """Sets up a ReLU-based network SDP problem.

  Bounds are propagated through `spec_fn` with an `_SdpTransform` wrapping
  `boundprop_transform`. The resulting output node and environment are kept
  on the instance.

  Args:
    boundprop_transform: Transform supplying concrete bounds.
    spec_fn: Network to verify.
    *input_bounds: Concrete bounds on the network inputs.
  """
  sdp_algorithm = bound_propagation.ForwardPropagationAlgorithm(
      _SdpTransform(boundprop_transform))
  output_node, env = bound_propagation.bound_propagation(
      sdp_algorithm, spec_fn, *input_bounds)
  self._output_node = output_node
  self._env = env
def forward_fastlin_bound_propagation(function, *bounds):
  """Runs forward linear bound propagation with the Fastlin ReLU relaxation.

  See https://arxiv.org/abs/1804.09699 for the Fastlin relaxation.

  Args:
    function: Computation to bound; it takes the network inputs as its only
      arguments.
    *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.

  Returns:
    output_bound: Bounds on the output of the function obtained by FastLin.
  """
  fastlin_algorithm = bound_propagation.ForwardPropagationAlgorithm(
      forward_fastlin_transform)
  output_bound, _ = bound_propagation.bound_propagation(
      fastlin_algorithm, function, *bounds)
  return output_bound
def propagate(
    self, graph: bound_propagation.PropagationGraph,
    inputs: Nest[GraphInput]
) -> Tuple[Nest[LayerInput], Dict[jax.core.Var, LayerInput]]:
  """Propagates bounds, concretizing nodes with the backward concretizer.

  The steps are:
  1. A `GraphInspector` pass records the graph and yields an environment of
     unconcretized `GraphNode`s.
  2. Graph-node arguments of primitives for which
     `self._backward_concretizer.concretize_args` is true are concretized,
     in node order.
  3. The input bounds are written into the environment.
  4. Every graph output is concretized.

  Args:
    graph: Graph to propagate bounds through.
    inputs: Nest of graph inputs.

  Returns:
    outvals: List of concrete bounds, one per graph output.
    env: Environment mapping jaxpr variables to bounds. Nodes that never
      needed concretizing may still hold `GraphNode`s.
  """
  subgraph_decider = self._backward_concretizer.should_handle_as_subgraph
  graph_inspector = GraphInspector(subgraph_decider)
  inspector_algorithm = bound_propagation.ForwardPropagationAlgorithm(
      graph_inspector)
  gn_outvals, env = inspector_algorithm.propagate(graph, inputs)

  def concretize(jaxpr_node):
    """Returns the concrete bound for `jaxpr_node`, computing it if needed."""
    env_node = env[jaxpr_node]
    if isinstance(env_node, GraphNode):
      # Not yet concretized: compute the bound and cache it in `env`.
      env_node = self._backward_concretizer.concrete_bound(
          graph, inputs, env, jaxpr_node)
      env[jaxpr_node] = env_node
    return env_node

  for node in graph_inspector.nodes.values():
    # Nodes are visited in order, so intermediate bounds are ready by the
    # time later nodes need them.
    if not (node.primitive and
            self._backward_concretizer.concretize_args(node.primitive)):
      continue
    for node_arg in node.args:
      if isinstance(node_arg, GraphNode):
        concretize(graph.jaxpr_node(node_arg.index))

  # Fill in the bounds for the inputs. Match graph input variables against
  # the flattened leaves of `inputs`, so nested input structures line up
  # with `graph.inputs` instead of being zipped at the top level.
  flat_inputs, _ = jax.tree_util.tree_flatten(inputs)
  for in_jaxpr_node, in_bound in zip(graph.inputs, flat_inputs):
    env[in_jaxpr_node] = in_bound

  # Make sure every output is concretized.
  outvals = [concretize(graph.jaxpr_node(gn.index)) for gn in gn_outvals]
  return outvals, env
def get_boundprop(
    name: str,
    elision: bool) -> Callable[..., forward_linear_bounds.LinearBound]:
  """Builds a forward linear bound-propagation function.

  Args:
    name: Relaxer to use: 'fastlin' or 'crown'.
    elision: Passed to `forward_linear_bounds.ForwardLinearBoundTransform`.

  Returns:
    A function `bound_prop(function, *bounds)` that returns the output bound
    of `function`.

  Raises:
    ValueError: If `name` is not one of the supported relaxers. Previously an
      unknown name surfaced as an opaque `UnboundLocalError`.
  """
  if name == 'fastlin':
    relaxer = linear_bound_utils.fastlin_rvt_relaxer
  elif name == 'crown':
    relaxer = linear_bound_utils.crown_rvt_relaxer
  else:
    raise ValueError(
        f"Unknown boundprop method {name!r}; expected 'fastlin' or 'crown'.")
  transform = forward_linear_bounds.ForwardLinearBoundTransform(
      relaxer, elision)
  algorithm = bound_propagation.ForwardPropagationAlgorithm(transform)

  def bound_prop(function, *bounds) -> forward_linear_bounds.LinearBound:
    output_bound, _ = bound_propagation.bound_propagation(
        algorithm, function, *bounds)
    return output_bound

  return bound_prop
def forward_crown_bound_propagation(function, *bounds):
  """Performs forward linear bound propagation.

  This uses the ReLU relaxation of CROWN
  (https://arxiv.org/abs/1811.00866).

  Args:
    function: Computation to bound; it takes the network inputs as its only
      arguments.
    *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.

  Returns:
    output_bound: Bounds on the output of the function obtained by forward
      CROWN.
  """
  forward_crown_transform = ForwardLinearBoundTransform(
      linear_bound_utils.crown_rvt_relaxer)
  output_bound, _ = bound_propagation.bound_propagation(
      bound_propagation.ForwardPropagationAlgorithm(forward_crown_transform),
      function, *bounds)
  return output_bound
def test_cvxpy_relaxation(self, model_cls):
  """Checks that the CVXPY LP lower bound does not exceed the model output."""

  @hk.transform_with_state
  def model_pred(inputs, is_training, test_local_stats=False):
    return model_cls()(inputs, is_training, test_local_stats)

  inps = jnp.zeros((4, 28, 28, 1), dtype=jnp.float32)
  params, state = model_pred.init(jax.random.PRNGKey(42), inps,
                                  is_training=True)

  def logits_fun(inputs):
    # `apply` returns (output, state); only the output is bounded.
    return model_pred.apply(params, state, None, inputs, False,
                            test_local_stats=False)[0]

  output = logits_fun(inps)
  input_bounds = jax_verify.IntervalBound(inps - 1.0, inps + 1.0)
  algorithm = bound_propagation.ForwardPropagationAlgorithm(
      relaxation.RelaxationTransform(jax_verify.ibp_transform))
  var, env = bound_propagation.bound_propagation(
      algorithm, logits_fun, input_bounds)

  index = 0
  # One-hot objective selecting output element 0 of the chosen batch entry.
  objective = jnp.zeros(output.shape[1:]).at[0].set(1)
  lower_bound, _, _ = relaxation.solve_relaxation(
      cvxpy_relaxation_solver.CvxpySolver, objective, 0., var, env, index)
  self.assertLessEqual(lower_bound, output[index, 0])
def __init__(self,
             forward_transform: bound_propagation.BoundTransform,
             backward_concretizer: BackwardConcretizer):
  """Stores the forward propagation algorithm and the backward concretizer.

  Args:
    forward_transform: Transform wrapped in a `ForwardPropagationAlgorithm`.
    backward_concretizer: Concretizer stored on the instance.
  """
  forward_algorithm = bound_propagation.ForwardPropagationAlgorithm(
      forward_transform)
  self._forward_algorithm = forward_algorithm
  self._backward_concretizer = backward_concretizer
def concrete_bound_chunk(
    self,
    graph: bound_propagation.PropagationGraph,
    inputs: Nest[GraphInput],
    env: Dict[jax.core.Var, LayerInput],
    node_ref: jax.core.Var,
    obj: Tensor,
) -> Tensor:
  """Computes bounds for a chunk of objectives with optimised relaxations.

  Each entry along the leading axis of `obj` is handled independently (via
  `jax.vmap`), so each gets its own relaxation parameters. The parameters
  are updated with `self._opt` for `self._num_opt_steps` steps. The
  parameters giving the lowest value of `sum(-bound)` are kept, and the
  bound is re-evaluated at them.

  Args:
    graph: Graph being bounded.
    inputs: Nest of graph inputs.
    env: Environment mapping jaxpr variables to bounds.
    node_ref: Jaxpr variable of the node whose bound is computed.
    obj: Objectives, one per leading-axis entry. Each is given a leading axis
      of size 1 before being passed to the bound computation.

  Returns:
    The optimised bounds, mapped over the leading axis of `obj`.
  """
  # Analyse the relevant parts of the graph.
  flat_inputs, _ = jax.tree_util.tree_flatten(inputs)
  bound_inputs = [inp for inp in flat_inputs
                  if isinstance(inp, bound_propagation.Bound)]
  input_nodes_indices = [(i,) for i in range(len(bound_inputs))]
  scanner = _RelaxationScanner(self._relaxer)
  graph.backward_propagation(
      scanner, env, {node_ref: env[node_ref]}, input_nodes_indices)

  # Allow lookup of any node's input bounds, for parameter initialisation.
  graph_inspector = bound_utils.GraphInspector()
  bound_propagation.ForwardPropagationAlgorithm(
      graph_inspector).propagate(graph, inputs)

  def input_bounds(index: Index) -> Sequence[LayerInput]:
    graph_node = graph_inspector.nodes[index]
    return [env[graph.jaxpr_node(arg.index)]
            if isinstance(arg, bound_utils.GraphNode) else arg
            for arg in graph_node.args]

  # Define optimisation for a single neuron's bound. (We'll vmap afterwards.)
  # This ensures that each neuron uses independent relaxation parameters.
  def optimized_concrete_bound(one_obj):
    def concrete_bound(relax_params):
      return self._bind(
          scanner.node_relaxations, relax_params).concrete_bound_chunk(
              graph, inputs, env, node_ref, jnp.expand_dims(one_obj, 0))

    # Function to optimise: summary tightness of the guaranteed bounds.
    def objective(relax_params):
      lb_min = concrete_bound(relax_params)
      return jnp.sum(-lb_min)
    val_and_grad_fn = jax.value_and_grad(objective)

    # Optimise the relaxation parameters.
    initial_params = self._initial_params(scanner, input_bounds)
    initial_state = (initial_params, self._opt.init(initial_params),
                     initial_params, jnp.inf)

    def update_state(_, state):
      params, opt_state, best_params, best_val = state
      params_val, params_grad = val_and_grad_fn(params)
      # Compute the next step in the optimization process.
      updates, next_opt_state = self._opt.update(params_grad, opt_state)
      next_params = optax.apply_updates(params, updates)
      next_params = self._project_params(scanner, next_params)
      # Update the best params seen. `params_val` was evaluated at `params`
      # (not `next_params`), so `params` is the candidate to keep.
      params_improved = params_val < best_val
      update_best_params = lambda p, best: jnp.where(params_improved, p, best)
      # `jax.tree_multimap` was removed from JAX; `jax.tree_util.tree_map`
      # accepts several trees and is the drop-in replacement.
      next_best_params = jax.tree_util.tree_map(
          update_best_params, params, best_params)
      next_best_val = jnp.minimum(best_val, params_val)
      return next_params, next_opt_state, next_best_params, next_best_val

    _, _, relax_params, _ = jax.lax.fori_loop(
        0, self._num_opt_steps, update_state, initial_state)
    # Evaluate the relaxation at these parameters.
    return concrete_bound(jax.lax.stop_gradient(relax_params))

  return jax.vmap(optimized_concrete_bound)(obj)