def test_exp_crown(self):
  """Backward-CROWN bounds on `exp` must contain sampled model outputs.

  Samples a random input box, propagates it through `jnp.exp` with
  `jax_verify.backward_crown_bound_propagation`, then evaluates the model on
  100 points sampled inside the box. The test checks that the computed
  upper/lower bounds are no tighter than the empirical max/min of those
  outputs.
  """
  def exp_model(inp):
    return jnp.exp(inp)

  exp_inp_shape = (4, 7)
  lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0), exp_inp_shape,
                                    minval=-10., maxval=10.)
  input_bounds = jax_verify.IntervalBound(lb, ub)
  output_bounds = jax_verify.backward_crown_bound_propagation(
      exp_model, input_bounds)

  uniform_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(1), (lb, ub), 100)
  uniform_outs = jax.vmap(exp_model)(uniform_inps)
  empirical_min = uniform_outs.min(axis=0)
  empirical_max = uniform_outs.max(axis=0)
  self.assertGreaterEqual((output_bounds.upper - empirical_max).min(), 0.,
                          'Invalid upper bound for Exponential. The gap '
                          'between upper bound and empirical max is < 0')
  self.assertGreaterEqual(
      (empirical_min - output_bounds.lower).min(), 0.,
      # Adjacent literals are joined verbatim, so the trailing space matters.
      'Invalid lower bound for Exponential. The gap '
      'between emp. min and lower bound is negative.')
def test_relu_random_fastlin(self):
  """Forward-FastLin bounds on ReLU must contain sampled model outputs.

  Samples a random input box, propagates it through `jax.nn.relu` with
  `jax_verify.forward_fastlin_bound_propagation`, then evaluates the model on
  100 points sampled inside the box. The test checks that the computed
  upper/lower bounds are no tighter than the empirical max/min of those
  outputs.
  """
  def relu_model(inp):
    return jax.nn.relu(inp)

  relu_inp_shape = (4, 7)
  lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0), relu_inp_shape,
                                    minval=-10., maxval=10.)
  input_bounds = jax_verify.IntervalBound(lb, ub)
  output_bounds = jax_verify.forward_fastlin_bound_propagation(
      relu_model, input_bounds)

  uniform_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(1), (lb, ub), 100)
  uniform_outs = jax.vmap(relu_model)(uniform_inps)
  empirical_min = uniform_outs.min(axis=0)
  empirical_max = uniform_outs.max(axis=0)
  self.assertGreaterEqual((output_bounds.upper - empirical_max).min(), 0.,
                          'Invalid upper bound for ReLU. The gap '
                          'between upper bound and empirical max is < 0')
  self.assertGreaterEqual(
      (empirical_min - output_bounds.lower).min(), 0.,
      # Adjacent literals are joined verbatim, so the trailing space matters.
      'Invalid lower bound for ReLU. The gap '
      'between emp. min and lower bound is negative.')
def test_leaky_relu(self, negative_slope):
  """Sanity-checks `leaky_relu_relaxation` for the given `negative_slope`.

  The relaxation of a random (5, 8) input box is checked for valid bounds
  via `_check_bounds`. Its lower function is then checked with
  `_check_convexity(..., True)` and its upper function with
  `_check_convexity(..., False)`.
  """
  shape = (5, 8)  # (batch_size, axis_dim)
  lower, upper = test_utils.sample_bounds(
      jax.random.PRNGKey(0), shape, minval=-10., maxval=10.)
  relax_lb, relax_ub = activation_relaxation.leaky_relu_relaxation(
      IntervalBound(lower, upper), negative_slope=negative_slope)

  def model(x):
    return jax.nn.leaky_relu(x, negative_slope)

  # The relaxation must bound the true activation.
  self._check_bounds(jax.random.PRNGKey(1), model, relax_lb, relax_ub,
                     [lower], [upper])
  # Lower relaxation checked as convex, upper as concave.
  self._check_convexity(jax.random.PRNGKey(2), relax_lb,
                        [lower], [upper], True)
  self._check_convexity(jax.random.PRNGKey(3), relax_ub,
                        [lower], [upper], False)
def test_abs(self):
  """Sanity-checks `convex_fn_relaxation` applied to `lax.abs_p`.

  The relaxation of a random (5, 8) input box is checked for valid bounds
  against `jnp.abs` via `_check_bounds`. Its lower function is then checked
  with `_check_convexity(..., True)` and its upper function with
  `_check_convexity(..., False)`.
  """
  shape = (5, 8)  # (batch_size, axis_dim)
  lower, upper = test_utils.sample_bounds(
      jax.random.PRNGKey(0), shape, minval=-10., maxval=10.)
  relax_lb, relax_ub = activation_relaxation.convex_fn_relaxation(
      lax.abs_p, IntervalBound(lower, upper))

  def model(x):
    return jnp.abs(x)

  # The relaxation must bound the true function.
  self._check_bounds(jax.random.PRNGKey(1), model, relax_lb, relax_ub,
                     [lower], [upper])
  # Lower relaxation checked as convex, upper as concave.
  self._check_convexity(jax.random.PRNGKey(2), relax_lb,
                        [lower], [upper], True)
  self._check_convexity(jax.random.PRNGKey(3), relax_ub,
                        [lower], [upper], False)
def test_chunking(self, relaxer):
  """Chunked and unchunked backward concretization must give the same bounds.

  A two-layer ReLU network with random uniform weights is bounded twice.
  One run uses a `ChunkedBackwardConcretizer` with `max_chunk_size=16`. The
  other uses `max_chunk_size=0`, which is the "unchunked" reference. The
  lower and upper bounds of the two runs must be almost equal.

  Args:
    relaxer: Linear-bound relaxer. A `ParameterizedLinearBoundsRelaxer`
      drives an `OptimizingLinearBoundBackwardTransform` (Adam 1e-3,
      10 steps). Any other relaxer drives a plain
      `LinearBoundBackwardTransform`.
  """
  batch, n_in, n_hidden, n_out = 3, 2, 5, 4

  inp_lb, inp_ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                            (batch, n_in),
                                            minval=-1., maxval=1.)
  inp_bound = jax_verify.IntervalBound(inp_lb, inp_ub)
  w_hidden = jax.random.uniform(jax.random.PRNGKey(1), (n_in, n_hidden))
  w_final = jax.random.uniform(jax.random.PRNGKey(2), (n_hidden, n_out))

  def model_fun(inp):
    return jax.nn.relu(inp @ w_hidden) @ w_final

  if isinstance(relaxer, linear_bound_utils.ParameterizedLinearBoundsRelaxer):
    transform = backward_crown.OptimizingLinearBoundBackwardTransform(
        relaxer, backward_crown.CONCRETIZE_ARGS_PRIMITIVE,
        optax.adam(1.e-3), num_opt_steps=10)
  else:
    transform = backward_crown.LinearBoundBackwardTransform(
        relaxer, backward_crown.CONCRETIZE_ARGS_PRIMITIVE)

  def propagate(max_chunk_size):
    # Output bound when the shared transform is concretized in chunks of
    # `max_chunk_size`.
    concretizer = backward_crown.ChunkedBackwardConcretizer(
        transform, max_chunk_size=max_chunk_size)
    algorithm = bound_utils.BackwardConcretizingAlgorithm(concretizer)
    bound, _ = bound_propagation.bound_propagation(
        algorithm, model_fun, inp_bound)
    return bound

  chunked = propagate(16)
  unchunked = propagate(0)
  np.testing.assert_array_almost_equal(chunked.lower, unchunked.lower)
  np.testing.assert_array_almost_equal(chunked.upper, unchunked.upper)
def test_nobatch_batch_inputs(self):
  """Backward CROWN must handle a mix of batched and unbatched input bounds.

  A bilinear model `einsum('bh,hH->bH')` takes one input of shape (3, 2) and
  a second of shape (2, 4) that has no batch axis. Bounds from
  `backward_crown.backward_crown_bound_propagation` must contain the
  empirical min/max of the model evaluated on 100 sampled pairs of inputs.
  """
  batch_shape = (3, 2)
  unbatch_shape = (2, 4)

  def bilinear_model(inp_1, inp_2):
    return jnp.einsum('bh,hH->bH', inp_1, inp_2)

  lb_1, ub_1 = test_utils.sample_bounds(jax.random.PRNGKey(0), batch_shape,
                                        minval=-10, maxval=10.)
  lb_2, ub_2 = test_utils.sample_bounds(jax.random.PRNGKey(1), unbatch_shape,
                                        minval=-10, maxval=10.)
  bound_1 = jax_verify.IntervalBound(lb_1, ub_1)
  bound_2 = jax_verify.IntervalBound(lb_2, ub_2)

  output_bounds = backward_crown.backward_crown_bound_propagation(
      bilinear_model, bound_1, bound_2)

  uniform_1 = test_utils.sample_bounded_points(jax.random.PRNGKey(2),
                                               (lb_1, ub_1), 100)
  uniform_2 = test_utils.sample_bounded_points(jax.random.PRNGKey(3),
                                               (lb_2, ub_2), 100)

  uniform_outs = jax.vmap(bilinear_model)(uniform_1, uniform_2)
  empirical_min = uniform_outs.min(axis=0)
  empirical_max = uniform_outs.max(axis=0)

  # Adjacent literals are joined verbatim, so the trailing spaces matter.
  self.assertGreaterEqual(
      (output_bounds.upper - empirical_max).min(), 0.,
      'Invalid upper bound for mix of batched/unbatched '
      'input bounds.')
  self.assertGreaterEqual(
      (empirical_min - output_bounds.lower).min(), 0.,
      'Invalid lower bound for mix of batched/unbatched '
      'input bounds.')
def test_multiply_crown(self):
  """Backward-CROWN bounds on elementwise product must contain sampled outputs.

  Samples random boxes for both operands of `lhs * rhs` and propagates them
  with `jax_verify.backward_crown_bound_propagation`. The resulting bounds
  are then compared against the empirical min/max over 100 sampled operand
  pairs.
  """
  def multiply_model(lhs, rhs):
    return lhs * rhs

  mul_inp_shape = (4, 7)
  lhs_lb, lhs_ub = test_utils.sample_bounds(
      jax.random.PRNGKey(0), mul_inp_shape, minval=-10., maxval=10.)
  rhs_lb, rhs_ub = test_utils.sample_bounds(
      jax.random.PRNGKey(1), mul_inp_shape, minval=-10., maxval=10.)

  lhs_bounds = jax_verify.IntervalBound(lhs_lb, lhs_ub)
  rhs_bounds = jax_verify.IntervalBound(rhs_lb, rhs_ub)
  output_bounds = jax_verify.backward_crown_bound_propagation(
      multiply_model, lhs_bounds, rhs_bounds)

  uniform_lhs_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(2), (lhs_lb, lhs_ub), 100)
  uniform_rhs_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(3), (rhs_lb, rhs_ub), 100)

  uniform_outs = jax.vmap(multiply_model)(uniform_lhs_inps, uniform_rhs_inps)
  empirical_min = uniform_outs.min(axis=0)
  empirical_max = uniform_outs.max(axis=0)

  self.assertGreaterEqual(
      (output_bounds.upper - empirical_max).min(), 0.,
      'Invalid upper bound for Multiply. The gap '
      'between upper bound and empirical max is negative')
  self.assertGreaterEqual(
      (empirical_min - output_bounds.lower).min(), 0.,
      # Adjacent literals are joined verbatim, so the trailing space matters.
      'Invalid lower bound for Multiply. The gap '
      'between emp. min and lower bound is negative.')
def test_sigmoid(self, scale):
  """Sanity-checks `sigmoid_relaxation` on inputs in [-scale, scale].

  The relaxation of a random (5, 8) input box is checked for valid bounds
  against `jax.nn.sigmoid` via `_check_bounds`. Its lower function is then
  checked with `_check_convexity(..., True)` and its upper function with
  `_check_convexity(..., False)`.
  """
  shape = (5, 8)  # (batch_size, axis_dim)
  lower, upper = test_utils.sample_bounds(
      jax.random.PRNGKey(0), shape, minval=-scale, maxval=scale)
  relax_lb, relax_ub = activation_relaxation.sigmoid_relaxation(
      IntervalBound(lower, upper))

  # The relaxation must bound the true activation.
  self._check_bounds(jax.random.PRNGKey(1), jax.nn.sigmoid,
                     relax_lb, relax_ub, [lower], [upper])
  # Lower relaxation checked as convex, upper as concave.
  self._check_convexity(jax.random.PRNGKey(2), relax_lb,
                        [lower], [upper], True)
  self._check_convexity(jax.random.PRNGKey(3), relax_ub,
                        [lower], [upper], False)