def _check_convexity(self, key, fun, lbs, ubs, is_convex, nb_samples=100):
  """Empirically verify that `fun` is convex (or concave) on a box.

  For each input, two batches of endpoints `a` and `b` are sampled inside the
  box [lb, ub]. A shared batch of mixing coefficients `t` (uniform in [0, 1))
  is also drawn. The function is then evaluated at `a`, at `b` and at the
  blended point `t * a + (1 - t) * b`. That value is compared against the
  chord value `t * f(a) + (1 - t) * f(b)`:
    * convex:  f(blend) must not exceed the chord (up to TOL);
    * concave: f(blend) must not fall below the chord (up to TOL).

  Args:
    key: PRNG key for random number generation
    fun: Function to be checked.
    lbs: List of lower bounds on the inputs
    ubs: List of upper bounds on the inputs
    is_convex: Boolean, if True: check that the function is convex.
      if False: check that the function is concave.
    nb_samples: How many random samples to draw for testing.
  """
  assert len(lbs) == len(ubs)
  nb_inputs = len(lbs)
  # Input i uses keys 2*i and 2*i + 1 for its two endpoints; the final key is
  # reserved for the mixing coefficients shared by every input.
  keys = jax.random.split(key, 2 * nb_inputs + 1)
  mix = jax.random.uniform(keys[-1], (nb_samples,))

  def _broadcast_mix(trailing_ndim):
    # Reshape the per-sample coefficients so they broadcast against an array
    # that has `trailing_ndim` dimensions after the leading sample axis.
    return jnp.reshape(mix, (-1,) + (1,) * trailing_ndim)

  def _blend(left, right, coeffs):
    # Convex combination coeffs * left + (1 - coeffs) * right.
    return left * coeffs + right * (1. - coeffs)

  starts, ends, mids = [], [], []
  for idx, (lb, ub) in enumerate(zip(lbs, ubs)):
    box = (lb, ub)
    start = test_utils.sample_bounded_points(keys[2 * idx], box, nb_samples)
    end = test_utils.sample_bounded_points(keys[2 * idx + 1], box,
                                           nb_samples)
    starts.append(start)
    ends.append(end)
    mids.append(_blend(start, end, _broadcast_mix(lb.ndim)))

  batched_fun = jax.vmap(fun)
  f_start = batched_fun(*starts)
  f_end = batched_fun(*ends)
  f_mid = batched_fun(*mids)
  chord = _blend(f_start, f_end, _broadcast_mix(f_mid.ndim - 1))

  if is_convex:
    gap, failure_msg = chord - f_mid, 'Function is not convex'
  else:
    gap, failure_msg = f_mid - chord, 'Function is not concave'
  self.assertGreaterEqual(gap.min(), -TOL, msg=failure_msg)
def test_exp_crown(self):
  """Backward-CROWN bounds on exp must contain sampled outputs.

  Input bounds of shape (4, 7) are sampled in [-10, 10]. Bounds are
  propagated through `jnp.exp` with backward CROWN, and 100 uniform samples
  from the input box must map inside the resulting output bounds.
  """

  def exp_model(inp):
    return jnp.exp(inp)

  exp_inp_shape = (4, 7)
  lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0), exp_inp_shape,
                                    minval=-10., maxval=10.)
  input_bounds = jax_verify.IntervalBound(lb, ub)
  output_bounds = jax_verify.backward_crown_bound_propagation(
      exp_model, input_bounds)

  uniform_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(1), (lb, ub), 100)
  uniform_outs = jax.vmap(exp_model)(uniform_inps)
  # Element-wise extremes over the sample axis, compared with the bounds.
  empirical_min = uniform_outs.min(axis=0)
  empirical_max = uniform_outs.max(axis=0)
  self.assertGreaterEqual((output_bounds.upper - empirical_max).min(), 0.,
                          'Invalid upper bound for Exponential. The gap '
                          'between upper bound and empirical max is < 0')
  # The trailing space after "gap" fixes the previously run-together
  # message ("The gapbetween ...").
  self.assertGreaterEqual(
      (empirical_min - output_bounds.lower).min(), 0.,
      'Invalid lower bound for Exponential. The gap '
      'between emp. min and lower bound is negative.')
def test_relu_random_fastlin(self):
  """Forward-FastLin bounds on ReLU must contain sampled outputs.

  Input bounds of shape (4, 7) are sampled in [-10, 10]. Bounds are
  propagated through `jax.nn.relu` with forward FastLin, and 100 uniform
  samples from the input box must map inside the resulting output bounds.
  """

  def relu_model(inp):
    return jax.nn.relu(inp)

  relu_inp_shape = (4, 7)
  lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0), relu_inp_shape,
                                    minval=-10., maxval=10.)
  input_bounds = jax_verify.IntervalBound(lb, ub)
  output_bounds = jax_verify.forward_fastlin_bound_propagation(
      relu_model, input_bounds)

  uniform_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(1), (lb, ub), 100)
  uniform_outs = jax.vmap(relu_model)(uniform_inps)
  # Element-wise extremes over the sample axis, compared with the bounds.
  empirical_min = uniform_outs.min(axis=0)
  empirical_max = uniform_outs.max(axis=0)
  self.assertGreaterEqual((output_bounds.upper - empirical_max).min(), 0.,
                          'Invalid upper bound for ReLU. The gap '
                          'between upper bound and empirical max is < 0')
  # The trailing space after "gap" fixes the previously run-together
  # message ("The gapbetween ...").
  self.assertGreaterEqual(
      (empirical_min - output_bounds.lower).min(), 0.,
      'Invalid lower bound for ReLU. The gap '
      'between emp. min and lower bound is negative.')
def test_nobatch_batch_inputs(self):
  """Backward CROWN must handle a mix of batched and unbatched inputs.

  The model is a bilinear einsum 'bh,hH->bH' between a (3, 2) operand and a
  (2, 4) operand, each with its own interval bounds sampled in [-10, 10].
  100 joint uniform samples of both operands must map inside the output
  bounds.
  """
  batch_shape = (3, 2)
  unbatch_shape = (2, 4)

  def bilinear_model(inp_1, inp_2):
    return jnp.einsum('bh,hH->bH', inp_1, inp_2)

  lb_1, ub_1 = test_utils.sample_bounds(jax.random.PRNGKey(0), batch_shape,
                                        minval=-10, maxval=10.)
  lb_2, ub_2 = test_utils.sample_bounds(jax.random.PRNGKey(1),
                                        unbatch_shape, minval=-10,
                                        maxval=10.)
  bound_1 = jax_verify.IntervalBound(lb_1, ub_1)
  bound_2 = jax_verify.IntervalBound(lb_2, ub_2)

  output_bounds = backward_crown.backward_crown_bound_propagation(
      bilinear_model, bound_1, bound_2)

  uniform_1 = test_utils.sample_bounded_points(jax.random.PRNGKey(2),
                                               (lb_1, ub_1), 100)
  uniform_2 = test_utils.sample_bounded_points(jax.random.PRNGKey(3),
                                               (lb_2, ub_2), 100)
  uniform_outs = jax.vmap(bilinear_model)(uniform_1, uniform_2)
  # Element-wise extremes over the sample axis, compared with the bounds.
  empirical_min = uniform_outs.min(axis=0)
  empirical_max = uniform_outs.max(axis=0)

  # The trailing space after "batched/unbatched" fixes the previously
  # run-together messages ("batched/unbatchedinput bounds.").
  self.assertGreaterEqual(
      (output_bounds.upper - empirical_max).min(), 0.,
      'Invalid upper bound for mix of batched/unbatched '
      'input bounds.')
  self.assertGreaterEqual(
      (empirical_min - output_bounds.lower).min(), 0.,
      'Invalid lower bound for mix of batched/unbatched '
      'input bounds.')
def test_multiply_crown(self):
  """Backward-CROWN bounds on element-wise products must hold empirically.

  Two independent (4, 7) operands get interval bounds sampled in [-10, 10].
  Bounds are propagated through `lhs * rhs` with backward CROWN, and 100
  joint uniform samples must map inside the resulting output bounds.
  """

  def multiply_model(lhs, rhs):
    return lhs * rhs

  mul_inp_shape = (4, 7)
  lhs_lb, lhs_ub = test_utils.sample_bounds(
      jax.random.PRNGKey(0), mul_inp_shape, minval=-10., maxval=10.)
  rhs_lb, rhs_ub = test_utils.sample_bounds(
      jax.random.PRNGKey(1), mul_inp_shape, minval=-10., maxval=10.)
  lhs_bounds = jax_verify.IntervalBound(lhs_lb, lhs_ub)
  rhs_bounds = jax_verify.IntervalBound(rhs_lb, rhs_ub)
  output_bounds = jax_verify.backward_crown_bound_propagation(
      multiply_model, lhs_bounds, rhs_bounds)

  uniform_lhs_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(2), (lhs_lb, lhs_ub), 100)
  uniform_rhs_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(3), (rhs_lb, rhs_ub), 100)

  uniform_outs = jax.vmap(multiply_model)(uniform_lhs_inps, uniform_rhs_inps)
  # Element-wise extremes over the sample axis, compared with the bounds.
  empirical_min = uniform_outs.min(axis=0)
  empirical_max = uniform_outs.max(axis=0)

  self.assertGreaterEqual(
      (output_bounds.upper - empirical_max).min(), 0.,
      'Invalid upper bound for Multiply. The gap '
      'between upper bound and empirical max is negative')
  # The trailing space after "gap" fixes the previously run-together
  # message ("The gapbetween ...").
  self.assertGreaterEqual(
      (empirical_min - output_bounds.lower).min(), 0.,
      'Invalid lower bound for Multiply. The gap '
      'between emp. min and lower bound is negative.')
def _check_bounds(self, key, fun, lb_fun, ub_fun, lbs, ubs, nb_samples=1000):
  """Empirically verify that lb_fun <= fun <= ub_fun on a box.

  `fun`, `lb_fun` and `ub_fun` are each evaluated (through `jax.vmap`) on the
  same batch of uniform samples. The samples are drawn independently for
  every input from its interval [lb, ub]. The test fails if either
  inequality is violated by more than TOL at any sample.

  Args:
    key: PRNG key for random number generation
    fun: Function to be bounded.
    lb_fun: Lower bound function
    ub_fun: Upper bound function
    lbs: List of lower bounds on the inputs
    ubs: List of upper bounds on the inputs
    nb_samples: How many random samples to draw for testing.
  """
  assert len(lbs) == len(ubs)
  sub_keys = jax.random.split(key, len(lbs))

  # One batch of uniform samples per input, each drawn with its own key.
  samples = [
      test_utils.sample_bounded_points(sub_keys[idx], (lb, ub), nb_samples)
      for idx, (lb, ub) in enumerate(zip(lbs, ubs))
  ]

  true_vals, lower_vals, upper_vals = [
      jax.vmap(f)(*samples) for f in (fun, lb_fun, ub_fun)
  ]

  lower_slack = true_vals - lower_vals
  upper_slack = upper_vals - true_vals
  self.assertGreaterEqual(lower_slack.min(), -TOL,
                          msg='Lower Bound is invalid')
  self.assertGreaterEqual(upper_slack.min(), -TOL,
                          msg='Upper Bound is invalid')