def test_conv2d_crown(self):
  """Checks CROWN on a single 2x2 VALID convolution with known weights."""

  @hk.without_apply_rng
  @hk.transform
  def conv2d_model(inp):
    conv = hk.Conv2D(output_channels=1, kernel_shape=(2, 2), padding='VALID',
                     stride=1, with_bias=True)
    return conv(inp)

  # One 2x2 single-channel image (NHWC) holding the values 1..4.
  z = jnp.reshape(jnp.array([1., 2., 3., 4.]), [1, 2, 2, 1])
  kernel = jnp.ones((2, 2, 1, 1), dtype=jnp.float32)
  bias = jnp.array([2.])
  fun = functools.partial(conv2d_model.apply,
                          {'conv2_d': {'w': kernel, 'b': bias}})

  # Standard interval bounds: every pixel perturbed by +/-1 around z.
  input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
  output_bounds = jax_verify.backward_crown_bound_propagation(
      fun, input_bounds)

  # The all-ones kernel sums the four pixels, then adds the bias of 2:
  # (0+1+2+3)+2 = 8 and (2+3+4+5)+2 = 16.
  self.assertArrayAlmostEqual(8., output_bounds.lower)
  self.assertArrayAlmostEqual(16., output_bounds.upper)
def test_exp_crown(self):
  """Checks that CROWN bounds on exp enclose outputs at sampled inputs."""

  def exp_model(inp):
    return jnp.exp(inp)

  exp_inp_shape = (4, 7)
  lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0), exp_inp_shape,
                                    minval=-10., maxval=10.)
  input_bounds = jax_verify.IntervalBound(lb, ub)
  output_bounds = jax_verify.backward_crown_bound_propagation(
      exp_model, input_bounds)

  # Empirical soundness check: evaluate the model on 100 points drawn
  # (presumably uniformly) from [lb, ub]. The bounds must enclose the
  # elementwise min / max of every observed output.
  uniform_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(1), (lb, ub), 100)
  uniform_outs = jax.vmap(exp_model)(uniform_inps)
  empirical_min = uniform_outs.min(axis=0)
  empirical_max = uniform_outs.max(axis=0)
  self.assertGreaterEqual((output_bounds.upper - empirical_max).min(), 0.,
                          'Invalid upper bound for Exponential. The gap '
                          'between upper bound and empirical max is < 0')
  # Fixed: the two literals below previously concatenated to "The gapbetween".
  self.assertGreaterEqual(
      (empirical_min - output_bounds.lower).min(), 0.,
      'Invalid lower bound for Exponential. The gap '
      'between emp. min and lower bound is negative.')
def test_backward_crown(self, model_cls):
  """Smoke test: backward CROWN runs end-to-end on `model_cls` in eval mode."""

  @hk.transform_with_state
  def model_pred(inputs, is_training, test_local_stats=False):
    return model_cls()(inputs, is_training, test_local_stats)

  # A batch of four all-zero 28x28 single-channel images.
  inps = jnp.zeros((4, 28, 28, 1), dtype=jnp.float32)
  rng = jax.random.PRNGKey(42)
  params, state = model_pred.init(rng, inps, is_training=True)

  def logits_fun(inputs):
    # Evaluation-mode forward pass; keep only the output, dropping the state.
    outputs_and_state = model_pred.apply(
        params, state, None, inputs, False, test_local_stats=False)
    return outputs_and_state[0]

  radius = 1.0
  input_bounds = jax_verify.IntervalBound(inps - radius, inps + radius)
  jax_verify.backward_crown_bound_propagation(logits_fun, input_bounds)
def test_relu_crown(self):
  """Checks CROWN on ReLU with one always-off and one always-on unit."""

  def relu_model(inp):
    return jax.nn.relu(inp)

  center = jnp.array([[-2., 3.]])
  input_bounds = jax_verify.IntervalBound(center - 1., center + 1.)
  output_bounds = jax_verify.backward_crown_bound_propagation(
      relu_model, input_bounds)

  # Unit 0 lives in [-3, -1] (always clipped to 0); unit 1 lives in [2, 4]
  # (ReLU acts as the identity there).
  expected_lower = jnp.array([[0., 2.]])
  expected_upper = jnp.array([[0., 4.]])
  self.assertArrayAlmostEqual(expected_lower, output_bounds.lower)
  self.assertArrayAlmostEqual(expected_upper, output_bounds.upper)
def _compute_jv_bounds(
    input_bound: sdp_utils.IntBound,
    params: ModelParams,
    method: str,
) -> List[sdp_utils.IntBound]:
  """Compute bounds with jax_verify.

  Args:
    input_bound: Bound on the network input; only its `lb` and `ub` fields
      are read.
    params: Iterable of layer parameters. Every entry whose `has_bounds` is
      truthy contributes its `(w_bound, b_bound)` pair as an additional
      bounded argument to the propagation.
    method: jax_verify algorithm to use: 'ibp', 'fastlin', 'ibpfastlin',
      'crown' or 'nonconvex'.

  Returns:
    One `sdp_utils.IntBound` per bound returned by jax_verify. The raw bounds
    are stored in `lb_pre` / `ub_pre`; `lb` / `ub` hold them clipped at zero.

  Raises:
    ValueError: if `method` is not one of the supported names.
  """
  jv_input_bound = jax_verify.IntervalBound(input_bound.lb, input_bound.ub)

  # Function of the input and of all bounded parameters that returns all
  # activations, so jax_verify can perform (bilinear) interval bound
  # propagation through every layer.
  all_act_fun = _make_all_act_fn(params)
  jv_param_bounds = [(p.w_bound, p.b_bound) for p in params if p.has_bounds]

  # Pick the propagation entry point. Validation happens only after the
  # objects above are built, matching the established order of operations.
  if method == 'ibp':
    propagate = jax_verify.interval_bound_propagation
  elif method == 'fastlin':
    propagate = jax_verify.forward_fastlin_bound_propagation
  elif method == 'ibpfastlin':
    propagate = jax_verify.ibpforwardfastlin_bound_propagation
  elif method == 'crown':
    propagate = jax_verify.backward_crown_bound_propagation
  elif method == 'nonconvex':
    propagate = jax_verify.nonconvex_constopt_bound_propagation
  else:
    raise ValueError('Unsupported method.')

  # Only the second element of the result (the per-activation bounds) is used.
  _, jv_bounds = propagate(all_act_fun, jv_input_bound, *jv_param_bounds)

  # Re-express in the internal convention: raw bounds as pre-activation
  # bounds, and the same bounds clipped at zero as post-activation bounds.
  return [
      sdp_utils.IntBound(
          lb_pre=act_bound.lower,
          ub_pre=act_bound.upper,
          lb=jnp.maximum(act_bound.lower, 0),
          ub=jnp.maximum(act_bound.upper, 0))
      for act_bound in jv_bounds
  ]
def test_fc_crown(self):
  """Checks CROWN on a single fully connected layer with known weights."""

  @hk.without_apply_rng
  @hk.transform
  def linear_model(inp):
    layer = hk.Linear(1)
    return layer(inp)

  z = jnp.array([[1., 2., 3.]])
  weights = jnp.ones((3, 1), dtype=jnp.float32)
  bias = jnp.array([2.])
  fun = functools.partial(linear_model.apply,
                          {'linear': {'w': weights, 'b': bias}})

  # Standard interval bounds: each input perturbed by +/-1 around z.
  input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
  output_bounds = jax_verify.backward_crown_bound_propagation(
      fun, input_bounds)

  # Unit weights plus bias 2: (0+1+2)+2 = 5 and (2+3+4)+2 = 11.
  self.assertArrayAlmostEqual(5., output_bounds.lower)
  self.assertArrayAlmostEqual(11., output_bounds.upper)
def test_multiply_crown(self):
  """Checks that CROWN bounds on lhs * rhs enclose outputs at sampled inputs."""

  def multiply_model(lhs, rhs):
    return lhs * rhs

  mul_inp_shape = (4, 7)
  lhs_lb, lhs_ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                            mul_inp_shape,
                                            minval=-10., maxval=10.)
  rhs_lb, rhs_ub = test_utils.sample_bounds(jax.random.PRNGKey(1),
                                            mul_inp_shape,
                                            minval=-10., maxval=10.)
  lhs_bounds = jax_verify.IntervalBound(lhs_lb, lhs_ub)
  rhs_bounds = jax_verify.IntervalBound(rhs_lb, rhs_ub)
  output_bounds = jax_verify.backward_crown_bound_propagation(
      multiply_model, lhs_bounds, rhs_bounds)

  # Empirical soundness check: draw 100 points (presumably uniformly) from
  # each operand's box, with independent keys, evaluate the product, and
  # verify the bounds enclose the elementwise min / max of the outputs.
  uniform_lhs_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(2), (lhs_lb, lhs_ub), 100)
  uniform_rhs_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(3), (rhs_lb, rhs_ub), 100)
  uniform_outs = jax.vmap(multiply_model)(uniform_lhs_inps, uniform_rhs_inps)
  empirical_min = uniform_outs.min(axis=0)
  empirical_max = uniform_outs.max(axis=0)
  self.assertGreaterEqual(
      (output_bounds.upper - empirical_max).min(), 0.,
      'Invalid upper bound for Multiply. The gap '
      'between upper bound and empirical max is negative')
  # Fixed: the two literals below previously concatenated to "The gapbetween".
  self.assertGreaterEqual(
      (empirical_min - output_bounds.lower).min(), 0.,
      'Invalid lower bound for Multiply. The gap '
      'between emp. min and lower bound is negative.')