def test_chunking(self, relaxer):
  """Chunked and unchunked backward concretization must give the same bounds.

  A small dense -> ReLU -> dense model is bounded twice with the same backward
  transform: once through a `ChunkedBackwardConcretizer` with
  `max_chunk_size=16` and once with `max_chunk_size=0` (the unchunked
  reference). Lower and upper output bounds are compared.

  Args:
    relaxer: Linear-bound relaxer under test. Parameterized relaxers are paired
      with the optimizing backward transform, all others with the plain one.
  """
  batch_size, input_size, hidden_size, final_size = 3, 2, 5, 4

  inp_lb, inp_ub = test_utils.sample_bounds(
      jax.random.PRNGKey(0), (batch_size, input_size),
      minval=-1., maxval=1.)
  inp_bound = jax_verify.IntervalBound(inp_lb, inp_ub)

  hidden_lay_weight = jax.random.uniform(
      jax.random.PRNGKey(1), (input_size, hidden_size))
  final_lay_weight = jax.random.uniform(
      jax.random.PRNGKey(2), (hidden_size, final_size))

  def model_fun(inp):
    return jax.nn.relu(inp @ hidden_lay_weight) @ final_lay_weight

  if isinstance(relaxer, linear_bound_utils.ParameterizedLinearBoundsRelaxer):
    # Parameterized relaxers go through the optimizing transform
    # (Adam, lr 1e-3, 10 steps).
    concretizing_transform = (
        backward_crown.OptimizingLinearBoundBackwardTransform(
            relaxer, backward_crown.CONCRETIZE_ARGS_PRIMITIVE,
            optax.adam(1.e-3), num_opt_steps=10))
  else:
    concretizing_transform = backward_crown.LinearBoundBackwardTransform(
        relaxer, backward_crown.CONCRETIZE_ARGS_PRIMITIVE)

  def _bound_with_chunk_size(max_chunk_size):
    # Same transform and inputs each time; only the chunk size varies.
    concretizer = backward_crown.ChunkedBackwardConcretizer(
        concretizing_transform, max_chunk_size=max_chunk_size)
    algorithm = bound_utils.BackwardConcretizingAlgorithm(concretizer)
    output_bound, _ = bound_propagation.bound_propagation(
        algorithm, model_fun, inp_bound)
    return output_bound

  chunked_bound = _bound_with_chunk_size(16)
  unchunked_bound = _bound_with_chunk_size(0)

  np.testing.assert_array_almost_equal(chunked_bound.lower,
                                       unchunked_bound.lower)
  np.testing.assert_array_almost_equal(chunked_bound.upper,
                                       unchunked_bound.upper)
def solve_with_jax_verify(self):
  """Maximizes the objective over the LP relaxation built by jax_verify.

  The input box is `self.inputs +/- self.eps`, clipped elementwise to [0, 1].
  A relaxation (with IBP supplying intermediate bounds) is propagated through
  the model built from `self.network_params` and solved with CVXPY for batch
  index 0, with no time limit.

  Returns:
    The maximum over the relaxation of `self.objective` applied to the model
    output, plus `self.objective_bias`.
  """
  def _clip_to_unit_interval(x):
    return jnp.minimum(jnp.maximum(x, 0.0), 1.0)

  init_bound = jax_verify.IntervalBound(
      _clip_to_unit_interval(self.inputs - self.eps),
      _clip_to_unit_interval(self.inputs + self.eps))
  logits_fn = make_model_fn(self.network_params)

  relaxation_algorithm = bound_propagation.ForwardPropagationAlgorithm(
      relaxation.RelaxationTransform(jax_verify.ibp_transform))
  var, env = bound_propagation.bound_propagation(
      relaxation_algorithm, logits_fn, init_bound)

  # The solver minimizes, so negate objective and bias and flip the sign of
  # the result: max(f) = -min(-f).
  neg_value_opt, _, _ = relaxation.solve_relaxation(
      cvxpy_relaxation_solver.CvxpySolver, -self.objective,
      -self.objective_bias, var, env, index=0, time_limit_millis=None)
  return -neg_value_opt
def get_bounds(self, fun, input_bounds):
  """Computes per-output LP-relaxation bounds of `fun` over `input_bounds`.

  A relaxation (with IBP supplying intermediate bounds) is propagated through
  `fun`. Then, for every flattened output coordinate, the CVXPY solver is run
  twice on a one-hot objective: once to minimize it (lower bound) and once to
  minimize its negation (negated upper bound).

  Args:
    fun: Function to bound.
    input_bounds: Interval bounds on the input of `fun`.

  Returns:
    (lower_bounds, upper_bounds): arrays with one entry per element of
    `fun(input_bounds.lower)`.
  """
  # Evaluated only to learn how many output coordinates there are.
  output = fun(input_bounds.lower)
  relaxation_transform = relaxation.RelaxationTransform(
      jax_verify.ibp_transform)
  var, env = bound_propagation.bound_propagation(
      bound_propagation.ForwardPropagationAlgorithm(relaxation_transform),
      fun, input_bounds)

  objective_bias = 0.
  index = 0
  lower_bounds = []
  upper_bounds = []
  for output_idx in range(output.size):
    # One-hot objective selecting a single (flattened) output coordinate.
    objective = (jnp.arange(output.size) == output_idx).astype(jnp.float32)
    lower_bound, _, _ = relaxation.solve_relaxation(
        cvxpy_relaxation_solver.CvxpySolver, objective, objective_bias,
        var, env, index)
    # The solver minimizes: max(c.y + b) = -min(-c.y - b), so the bias must be
    # negated together with the objective. The original passed +bias here,
    # which was only harmless because the bias is zero.
    neg_upper_bound, _, _ = relaxation.solve_relaxation(
        cvxpy_relaxation_solver.CvxpySolver, -objective, -objective_bias,
        var, env, index)
    lower_bounds.append(lower_bound)
    upper_bounds.append(-neg_upper_bound)
  return jnp.array(lower_bounds), jnp.array(upper_bounds)
def build_nonconvex_formulation(bound_cls, concretizer_ctor, function, *bounds):
  """Constructs the nonconvex bound on `function`'s output, ready to optimize.

  Args:
    bound_cls: Bound class to use. This determines what dual formulation will
      be computed.
    concretizer_ctor: Constructor for the concretizer to use to obtain
      intermediate bounds.
    function: Function performing computation to obtain bounds for. Takes as
      only arguments the network inputs.
    *bounds: jax_verify.NonConvexBound, bounds on the inputs of the function.
      These can be created using
      jax_verify.NonConvexBound.initial_nonconvex_bound.

  Returns:
    output_bound: NonConvex bound that can be optimized with a solver.
  """
  # Bind `bound_cls` as the first argument of every registered primitive
  # transform.
  primitive_transform = {}
  for primitive, transform in _nonconvex_primitive_transform.items():
    primitive_transform[primitive] = functools.partial(transform, bound_cls)

  bound_transform = bound_propagation.OpwiseBoundTransform(
      functools.partial(bound_cls.initial_nonconvex_bound, concretizer_ctor),
      primitive_transform)
  output_bound, _ = bound_propagation.bound_propagation(
      bound_transform, function, *bounds)
  return output_bound
def crown_bound_propagation(function, *bounds):
  """Performs CROWN as described in https://arxiv.org/abs/1811.00866.

  Args:
    function: Function performing computation to obtain bounds for. Takes as
      only argument the network inputs.
    *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.

  Returns:
    output_bound: Bounds on the output of the function obtained by CROWN.
  """
  output_bound, _ = bound_propagation.bound_propagation(
      _crown_transform, function, *bounds)
  return output_bound
def interval_bound_propagation(function, *bounds):
  """Runs interval bound propagation (IBP), https://arxiv.org/abs/1810.12715.

  Args:
    function: Function performing computation to obtain bounds for. Takes as
      only argument the network inputs.
    *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.

  Returns:
    output_bound: Bounds on the output of the function obtained by IBP.
  """
  ibp_output, _unused_env = bound_propagation.bound_propagation(
      bound_transform, function, *bounds)
  return ibp_output
def ibpfastlin_bound_propagation(function, *bounds):
  """Obtains the best of IBP and FastLin bounds.

  Both transforms are run together through an `IntersectionBoundTransform`.

  Args:
    function: Function performing computation to obtain bounds for. Takes as
      only argument the network inputs.
    *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.

  Returns:
    output_bound: Bounds on the output of the function obtained by
      intersecting the IBP and FastLin bounds.
  """
  output_bound, _ = bound_propagation.bound_propagation(
      intersection.IntersectionBoundTransform(ibp.bound_transform,
                                              fastlin_transform),
      function, *bounds)
  return output_bound
def solve_planet_relaxation(logits_fn, initial_bounds, boundprop_transform,
                            objective, objective_bias, index,
                            solver=cvxpy_relaxation_solver.CvxpySolver):
  """Solves the "Planet" (Ehlers 17) or "triangle" relaxation.

  jax_verify generates the relaxation constraints, which are then handed to a
  generic solver. With CVXPY, defining the LP carries a large overhead:
  constraints are defined element-wise (to avoid turning convolutional layers
  into one big, inefficient matrix multiplication), and CVXPY is slow at
  defining many constraints.

  Args:
    logits_fn: Mapping from inputs (batch_size x input_size) ->
      (batch_size, num_classes)
    initial_bounds: `IntervalBound` with initial bounds on inputs, with lower
      and upper bounds of dimension (batch_size x input_size).
    boundprop_transform: bound_propagation.BoundTransform instance, such as
      `jax_verify.ibp_transform`. Used to pre-compute interval bounds for
      intermediate activations used in defining the Planet relaxation.
    objective: Objective to optimize, given as an array of coefficients to be
      applied to the output of logits_fn defining the objective to minimize
    objective_bias: Bias to add to objective
    index: Index in the batch for which to solve the relaxation
    solver: A relaxation.RelaxationSolver, which specifies the backend to solve
      the resulting LP.

  Returns:
    val: The optimal value from the relaxation
    solution: The optimal solution found by the solver
    status: The status of the relaxation solver
  """
  relaxation_algorithm = bound_propagation.ForwardPropagationAlgorithm(
      relaxation.RelaxationTransform(boundprop_transform))
  output_var, relaxation_env = bound_propagation.bound_propagation(
      relaxation_algorithm, logits_fn, initial_bounds)
  opt_value, opt_solution, solver_status = relaxation.solve_relaxation(
      solver, objective, objective_bias, output_var, relaxation_env,
      index=index, time_limit_millis=None)
  return opt_value, opt_solution, solver_status
def __init__(
    self,
    boundprop_transform: bound_propagation.BoundTransform,
    spec_fn: Callable[..., Tensor],
    *input_bounds: Bound,
):
  """Initialises a ReLU-based network SDP problem.

  Propagates `input_bounds` through `spec_fn` with an `_SdpTransform` and keeps
  the resulting output node and environment.

  Args:
    boundprop_transform: Transform to supply concrete bounds.
    spec_fn: Network to verify.
    *input_bounds: Concrete bounds on the network inputs.
  """
  sdp_algorithm = bound_propagation.ForwardPropagationAlgorithm(
      _SdpTransform(boundprop_transform))
  output_node, env = bound_propagation.bound_propagation(
      sdp_algorithm, spec_fn, *input_bounds)
  self._output_node = output_node
  self._env = env
def backward_fastlin_bound_propagation(
    function: Callable[..., Nest[Tensor]],
    *bounds: Nest[GraphInput]) -> Nest[LayerInput]:
  """Runs FastLin (https://arxiv.org/abs/1804.09699) as a backward pass.

  Args:
    function: Function performing computation to obtain bounds for. Takes as
      only argument the network inputs.
    *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.

  Returns:
    output_bound: Bounds on the output of the function obtained by FastLin
  """
  result, _unused_env = bound_propagation.bound_propagation(
      bound_utils.BackwardConcretizingAlgorithm(backward_fastlin_concretizer),
      function, *bounds)
  return result
def forward_fastlin_bound_propagation(function, *bounds):
  """Runs forward linear bound propagation with the FastLin ReLU relaxation.

  FastLin is described in https://arxiv.org/abs/1804.09699.

  Args:
    function: Function performing computation to obtain bounds for. Takes as
      only argument the network inputs.
    *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.

  Returns:
    output_bound: Bounds on the output of the function obtained by FastLin
  """
  forward_algorithm = bound_propagation.ForwardPropagationAlgorithm(
      forward_fastlin_transform)
  result, _unused_env = bound_propagation.bound_propagation(
      forward_algorithm, function, *bounds)
  return result
def test_equal_bounds_parameterized(self):
  """A degenerate input box (lower == upper) must give lower == upper output.

  A ReLU is bounded with the parameterized relaxer under the optimizing
  backward transform (Adam, lr 1e-3, 10 steps), concretized by a
  `ChunkedBackwardConcretizer` with its default chunk size.
  """
  point = jnp.array([-1., 1.])
  inp_bound = jax_verify.IntervalBound(point, point)

  optimizing_transform = backward_crown.OptimizingLinearBoundBackwardTransform(
      linear_bound_utils.parameterized_relaxer,
      backward_crown.CONCRETIZE_ARGS_PRIMITIVE,
      optax.adam(1.e-3), num_opt_steps=10)
  algorithm = bound_utils.BackwardConcretizingAlgorithm(
      backward_crown.ChunkedBackwardConcretizer(optimizing_transform))

  bound, _ = bound_propagation.bound_propagation(
      algorithm, jax.nn.relu, inp_bound)
  np.testing.assert_array_almost_equal(bound.lower, bound.upper)
def forward_crown_bound_propagation(function, *bounds):
  """Performs forward linear bound propagation.

  This is using the relu relaxation of CROWN.
  (https://arxiv.org/abs/1811.00866)

  Args:
    function: Function performing computation to obtain bounds for. Takes as
      only argument the network inputs.
    *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.

  Returns:
    output_bound: Bounds on the output of the function obtained by forward
      linear propagation with the CROWN relaxation.
  """
  forward_crown_transform = ForwardLinearBoundTransform(
      linear_bound_utils.crown_rvt_relaxer)
  output_bound, _ = bound_propagation.bound_propagation(
      bound_propagation.ForwardPropagationAlgorithm(forward_crown_transform),
      function, *bounds)
  return output_bound
def crownibp_bound_propagation(
    function: Callable[..., Nest[Tensor]],
    *bounds: Nest[GraphInput]) -> Nest[LayerInput]:
  """Runs Crown-IBP (https://arxiv.org/abs/1906.06316).

  Intermediate bounds are obtained by a forward IBP pass. Linear bounds are
  then propagated backwards with the CROWN concretizer.

  Args:
    function: Function performing computation to obtain bounds for. Takes as
      only argument the network inputs.
    *bounds: jax_verify.IntervalBounds, bounds on the inputs of the function.

  Returns:
    output_bounds: Bounds on the outputs of the function obtained by Crown-IBP
  """
  algorithm = bound_utils.BackwardAlgorithmForwardConcretization(
      ibp.bound_transform, backward_crown_concretizer)
  result, _unused_env = bound_propagation.bound_propagation(
      algorithm, function, *bounds)
  return result
def nonconvex_ibp_bound_propagation(
    function: Callable[..., Nest[Tensor]],
    *bounds: Nest[graph_traversal.GraphInput],
    graph_simplifier=synthetic_primitives.default_simplifier,
) -> Nest[nonconvex.NonConvexBound]:
  """Builds the non-convex objective, with IBP as the base bound propagation.

  Args:
    function: Function performing computation to obtain bounds for. Takes as
      only arguments the network inputs.
    *bounds: Bounds on the inputs of the function.
    graph_simplifier: What graph simplifier to use.

  Returns:
    output_bounds: NonConvex bounds that can be optimized with a solver.
  """
  # Wolfe dual formulation, plain base concretizer, IBP as `base_boundprop`.
  ibp_nonconvex_algorithm = nonconvex.nonconvex_algorithm(
      duals.WolfeNonConvexBound,
      nonconvex.BaseBoundConcretizer(),
      base_boundprop=ibp.bound_transform)
  result, _unused_env = bound_propagation.bound_propagation(
      ibp_nonconvex_algorithm, function, *bounds,
      graph_simplifier=graph_simplifier)
  return result
def _nonconvex_boundprop(params, x, epsilon, input_bounds,
                         nonconvex_boundprop_steps=100,
                         nonconvex_boundprop_nodes=128):
  """Wrapper for nonconvex bound propagation.

  Args:
    params: Network parameters passed to `utils.predict_cnn`.
    x: Network input around which the initial bound is built.
    epsilon: Perturbation radius passed to `utils.init_bound`.
    input_bounds: Extra input-domain bounds passed to `utils.init_bound`.
    nonconvex_boundprop_steps: Number of FISTA optimization steps.
    nonconvex_boundprop_nodes: `max_parallel_nodes` for the concretizer.

  Returns:
    A list whose first element is the initial bound and whose remaining
    elements are `utils.IntBound`s, one per intermediate nonconvex bound. Each
    holds the pre-activation bounds and their values after `max(., 0)`.
  """
  init_bounds = utils.init_bound(x, epsilon, input_bounds=input_bounds,
                                 add_batch_dim=False)
  # The network function also emits its pre-activations, so propagation yields
  # bounds on every intermediate layer.
  all_act_fun = functools.partial(utils.predict_cnn, params,
                                  include_preactivations=True)

  fista_concretizer = optimizers.OptimizingConcretizer(
      FistaOptimizer(num_steps=nonconvex_boundprop_steps),
      max_parallel_nodes=nonconvex_boundprop_nodes)
  algorithm = nonconvex.nonconvex_algorithm(
      duals.WolfeNonConvexBound, fista_concretizer)

  (_, intermediate_nonconvex_bounds), _ = bound_propagation.bound_propagation(
      algorithm, all_act_fun,
      jax_verify.IntervalBound(init_bounds.lb, init_bounds.ub))

  return [init_bounds] + [
      utils.IntBound(lb_pre=layer_bound.lower,
                     ub_pre=layer_bound.upper,
                     lb=jnp.maximum(layer_bound.lower, 0),
                     ub=jnp.maximum(layer_bound.upper, 0))
      for layer_bound in intermediate_nonconvex_bounds
  ]
def crownibp_bound_propagation(function, bounds):
  """Runs Crown-IBP (https://arxiv.org/abs/1906.06316) on a single input.

  An IBP pass gives intermediate bounds and the computation graph. A backward
  pass starting from an identity linear expression on the output then yields
  linear bounds in terms of the input, which are concretized over `bounds`.

  Args:
    function: Function performing computation to obtain bounds for. Takes as
      only argument the network inputs. Must return a single array.
    bounds: jax_verify.IntervalBounds, bounds on the inputs of the function.

  Returns:
    output_bound: Bounds on the output of the function obtained by Crown-IBP
  """
  ibp_bound, graph = bound_propagation.bound_propagation(
      ibp.bound_transform, function, bounds)

  assert hasattr(ibp_bound, 'shape'), (
      'crownibp_bound_propagation requires `function` to output a single array '
      'as opposed to an arbitrary pytree')
  batch_size = ibp_bound.shape[0]
  act_shape = ibp_bound.shape[1:]

  # Identity map over the flattened activations, reshaped to
  # act_shape + act_shape and tiled along a new leading batch axis.
  num_activations = np.prod(act_shape)
  identity_coeffs = jnp.reshape(jnp.eye(num_activations),
                                act_shape + act_shape)
  batched_identity = jnp.repeat(identity_coeffs[None], batch_size, axis=0)
  zero_offsets = jnp.zeros_like(ibp_bound.lower)

  initial_backward_bound = CrownBackwardBound(
      LinearExpression(batched_identity, zero_offsets),
      LinearExpression(batched_identity, zero_offsets))

  (input_fun,) = graph.backward_propagation(
      _primitive_transform, sum, initial_backward_bound)
  return concretize_backward_bound(input_fun, bounds)
def test_cvxpy_relaxation(self, model_cls):
  """The LP lower bound on logit 0 must not exceed the actual logit 0.

  The model is initialised on an all-zero batch and bounded over a +/-1 box
  around it. The relaxation (with IBP supplying intermediate bounds) is
  minimized with CVXPY for batch element 0.

  Args:
    model_cls: Haiku module class to build and test.
  """

  @hk.transform_with_state
  def model_pred(inputs, is_training, test_local_stats=False):
    return model_cls()(inputs, is_training, test_local_stats)

  inps = jnp.zeros((4, 28, 28, 1), dtype=jnp.float32)
  params, state = model_pred.init(
      jax.random.PRNGKey(42), inps, is_training=True)

  def logits_fun(inputs):
    apply_result = model_pred.apply(params, state, None, inputs, False,
                                    test_local_stats=False)
    return apply_result[0]

  output = logits_fun(inps)
  input_bounds = jax_verify.IntervalBound(inps - 1.0, inps + 1.0)
  var, env = bound_propagation.bound_propagation(
      bound_propagation.ForwardPropagationAlgorithm(
          relaxation.RelaxationTransform(jax_verify.ibp_transform)),
      logits_fun, input_bounds)

  index = 0
  # One-hot objective selecting logit 0, with zero bias.
  objective = jnp.zeros(output.shape[1:]).at[0].set(1)
  lower_bound, _, _ = relaxation.solve_relaxation(
      cvxpy_relaxation_solver.CvxpySolver, objective, 0., var, env, index)
  self.assertLessEqual(lower_bound, output[index, 0])
def nonconvex_constopt_bound_propagation(
    function: Callable[..., Nest[Tensor]],
    *bounds: Nest[graph_traversal.GraphInput],
    graph_simplifier=synthetic_primitives.default_simplifier,
) -> Nest[nonconvex.NonConvexBound]:
  """Builds the optimizable nonconvex objective without optimizing it.

  Args:
    function: Function performing computation to obtain bounds for. Takes as
      only arguments the network inputs.
    *bounds: Bounds on the inputs of the function.
    graph_simplifier: What graph simplifier to use.

  Returns:
    output_bounds: NonConvex bounds that can be optimized with a solver.
  """
  # PGDOptimizer(0, 0., ...) is used as a no-step optimizer; the positional
  # args are presumably (num_steps, step_size) -- confirm in `optimizers`.
  concretizer = optimizers.OptimizingConcretizer(
      optimizers.PGDOptimizer(0, 0., optimize_dual=False),
      max_parallel_nodes=512)
  result, _unused_env = bound_propagation.bound_propagation(
      nonconvex.nonconvex_algorithm(duals.WolfeNonConvexBound, concretizer),
      function, *bounds, graph_simplifier=graph_simplifier)
  return result
def backward_rvt_bound_propagation(
    function: Callable[..., Nest[Tensor]],
    *bounds: Nest[GraphInput]) -> Nest[LayerInput]:
  """Performs backward CROWN (https://arxiv.org/abs/1811.00866) with softmax expanded.

  Same backward CROWN concretizing algorithm as the plain CROWN entry point.
  The difference is that the graph is simplified with an explicit chain that
  includes `synthetic_primitives.expand_softmax_simplifier` (presumably
  decomposing softmax into simpler primitives before bounding).

  Args:
    function: Function performing computation to obtain bounds for. Takes as
      only argument the network inputs.
    *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.

  Returns:
    output_bound: Bounds on the output of the function obtained by CROWN.
  """
  backward_crown_algorithm = bound_utils.BackwardConcretizingAlgorithm(
      backward_crown_concretizer)
  # Simplifiers are applied in the order listed.
  expand_softmax_simplifier_chain = synthetic_primitives.simplifier_composition(
      synthetic_primitives.activation_simplifier,
      synthetic_primitives.hoist_constant_computations,
      synthetic_primitives.expand_softmax_simplifier,
      synthetic_primitives.group_linear_sequence,
      synthetic_primitives.group_posbilinear)
  output_bound, _ = bound_propagation.bound_propagation(
      backward_crown_algorithm, function, *bounds,
      graph_simplifier=expand_softmax_simplifier_chain)
  return output_bound
def bound_prop(function, *bounds) -> forward_linear_bounds.LinearBound:
  """Propagates `bounds` through `function` with `algorithm`.

  `algorithm` is a free name resolved outside this function (enclosing or
  module scope, not visible here). The propagation environment is discarded
  and only the output bound is returned.
  """
  propagated, _unused_env = bound_propagation.bound_propagation(
      algorithm, function, *bounds)
  return propagated