def test_conv2d_ibp(self):
    """IBP through a 2x2 VALID convolution with unit weights and bias 2."""
    def conv2d_model(inp):
        return hk.Conv2D(output_channels=1, kernel_shape=(2, 2),
                         padding='VALID', stride=1, with_bias=True)(inp)

    image = jnp.reshape(jnp.array([1., 2., 3., 4.]), [1, 2, 2, 1])
    weights = {
        'conv2_d': {
            'w': jnp.ones((2, 2, 1, 1), dtype=jnp.float32),
            'b': jnp.array([2.]),
        }
    }
    apply_fn = functools.partial(
        hk.without_apply_rng(hk.transform(conv2d_model)).apply, weights)

    bounds_in = jax_verify.IntervalBound(image - 1., image + 1.)
    bounds_out = jax_verify.interval_bound_propagation(apply_fn, bounds_in)

    # The single output sums all four pixels (nominal 10) plus bias 2; each
    # pixel varies by +/-1, so the interval is [12 - 4, 12 + 4].
    self.assertAlmostEqual(8., bounds_out.lower)
    self.assertAlmostEqual(16., bounds_out.upper)
def _get_reciprocal_bound(l: jnp.ndarray, u: jnp.ndarray,
                          logits_params: LayerParams,
                          label: int) -> jnp.ndarray:
    """Helper for computing bound on label softmax given interval bounds on pre logits.

    Propagates interval bounds through the map x -> x @ (w[:, label] - w)
    + (b[label] - b), i.e. the logit differences against `label`, treating
    any layer parameters that carry their own interval bounds as additional
    bounded inputs.

    Args:
      l: lower bound on the pre-logit activations; flattened to 2D below.
      u: upper bound on the pre-logit activations; same shape as `l`.
      logits_params: final layer parameters; `w_bound`/`b_bound` are either
        None (parameter is exact) or interval bounds on the parameter.
      label: index of the reference class whose logit is subtracted.

    Returns:
      Interval bound on the logit differences, as produced by
      `jax_verify.interval_bound_propagation`.
    """
    def fwd(x, w, b):
        # Column `label` broadcast against all columns: logit-difference weights.
        wdiff = jnp.reshape(w[:, label], [-1, 1]) - w
        bdiff = b[label] - b
        return x @ wdiff + bdiff
    # Flatten everything after the batch dimension.
    x_bound = jax_verify.IntervalBound(
        lower_bound=jnp.reshape(l, [l.shape[0], -1]),
        upper_bound=jnp.reshape(u, [u.shape[0], -1]))
    params_bounds = []
    # Exact parameters are closed over via partial; bounded ones are passed
    # to bound propagation as extra (bounded) positional inputs.
    if logits_params.w_bound is None:
        fwd = functools.partial(fwd, w=logits_params.w)
    else:
        params_bounds.append(logits_params.w_bound)
    if logits_params.b_bound is None:
        fwd = functools.partial(fwd, b=logits_params.b)
    else:
        params_bounds.append(logits_params.b_bound)
    # NOTE(review): if w_bound is None but b_bound is set, the positional
    # bound would bind to parameter `w`, which is already supplied by keyword
    # — presumably callers always bound `w` whenever `b` is bounded; confirm.
    fwd_bound = jax_verify.interval_bound_propagation(fwd, x_bound,
                                                      *params_bounds)
    return fwd_bound
def test_sqrt(self, input_bounds, expected):
    """IBP through `jnp.sqrt` matches the expected endpoint values."""
    lo, hi = input_bounds
    # A zero second entry is appended so the zero edge case is always present.
    bounds_in = jax_verify.IntervalBound(
        np.array([lo, 0.0]), np.array([hi, 0.0]))
    bounds_out = jax_verify.interval_bound_propagation(jnp.sqrt, bounds_in)
    np.testing.assert_array_equal(
        np.array([expected[0], 0.0]), bounds_out.lower)
    np.testing.assert_array_equal(
        np.array([expected[1], 0.0]), bounds_out.upper)
def test_model_structure_nostate(self, model):
    """IBP through a stateless model produces correctly ordered bounds."""
    x0 = jnp.array([[1., 2., 3.]])
    weights = model.init(jax.random.PRNGKey(1), x0)
    apply_fn = functools.partial(model.apply, weights)
    bounds_in = jax_verify.IntervalBound(x0 - 1.0, x0 + 1.0)
    bounds_out = jax_verify.interval_bound_propagation(apply_fn, bounds_in)
    self.assertTrue(all(bounds_out.upper >= bounds_out.lower))
def test_passthrough_primitive(self, fn, inputs):
    """For pass-through primitives, IBP maps interval endpoints exactly."""
    center = jnp.array(inputs)
    bounds_in = jax_verify.IntervalBound(center - 1., center + 1.)
    bounds_out = jax_verify.interval_bound_propagation(fn, bounds_in)
    # The output bounds should be exactly the function applied to each endpoint.
    self.assertArrayAlmostEqual(fn(bounds_in.lower), bounds_out.lower)
    self.assertArrayAlmostEqual(fn(bounds_in.upper), bounds_out.upper)
def test_multioutput_model(self):
    """Bound propagation returns one bound per output of a multi-output net."""
    x0 = jnp.array([[1., 2., 3.]])
    transformed = hk.without_apply_rng(hk.transform(residual_model_all_act))
    weights = transformed.init(jax.random.PRNGKey(1), x0)
    apply_fn = functools.partial(transformed.apply, weights)
    bounds_in = jax_verify.IntervalBound(x0 - 1.0, x0 + 1.0)
    bounds_out = jax_verify.interval_bound_propagation(apply_fn, bounds_in)
    # residual_model_all_act exposes 7 activations; expect 7 bounds back.
    self.assertLen(bounds_out, 7)
def main(unused_args):
    """Compares bound-propagation bounds against the planet LP relaxation.

    Loads a model and the first MNIST test images, computes interval or
    fastlin bounds on the logits for an L-inf perturbation of radius 0.1,
    then solves the planet LP relaxation for the first logit and logs the
    lower/upper bounds obtained by both approaches.

    Args:
      unused_args: absl positional arguments (ignored; config via FLAGS).

    Raises:
      NotImplementedError: if FLAGS.boundprop_method is not 'ibp'/'fastlin'.
    """
    # Load the parameters of an existing model.
    model_pred, params = load_model(FLAGS.model)
    logits_fn = functools.partial(model_pred, params)

    # Load some test samples.
    with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
        inputs = np.load(f)

    # Compute boundprop bounds on the first two images: an eps-ball around
    # each input, clipped to the valid pixel range [0, 1].
    eps = 0.1
    lower_bound = jnp.minimum(jnp.maximum(inputs[:2, ...] - eps, 0.0), 1.0)
    upper_bound = jnp.minimum(jnp.maximum(inputs[:2, ...] + eps, 0.0), 1.0)
    init_bound = jax_verify.IntervalBound(lower_bound, upper_bound)

    if FLAGS.boundprop_method == 'fastlin':
        final_bound = jax_verify.fastlin_bound_propagation(
            logits_fn, init_bound)
        boundprop_transform = jax_verify.fastlin_transform
    elif FLAGS.boundprop_method == 'ibp':
        final_bound = jax_verify.interval_bound_propagation(
            logits_fn, init_bound)
        boundprop_transform = jax_verify.ibp_transform
    else:
        # BUG FIX: the implicit string concatenation was missing a separating
        # space, producing "...boundprop arecurrently supported".
        raise NotImplementedError('Only ibp/fastlin boundprop are '
                                  'currently supported')

    dummy_output = model_pred(params, inputs)

    # Run LP solver with an objective selecting the first logit.
    objective = jnp.where(
        jnp.arange(dummy_output[0, ...].size) == 0,
        jnp.ones_like(dummy_output[0, ...]),
        jnp.zeros_like(dummy_output[0, ...]))
    objective_bias = 0.
    value, status = jax_verify.solve_planet_relaxation(
        logits_fn, init_bound, boundprop_transform, objective,
        objective_bias, index=0)
    logging.info('Relaxation LB is : %f, Status is %s', value, status)
    # Upper bound: minimize the negated objective, then negate the optimum.
    value, status = jax_verify.solve_planet_relaxation(
        logits_fn, init_bound, boundprop_transform, -objective,
        objective_bias, index=0)
    logging.info('Relaxation UB is : %f, Status is %s', -value, status)
    logging.info('Boundprop LB is : %f', final_bound.lower[0, 0])
    logging.info('Boundprop UB is : %f', final_bound.upper[0, 0])
def test_ibp(self, model_cls):
    """Smoke test: IBP runs through a stateful (transform_with_state) model."""
    @hk.transform_with_state
    def model_pred(inputs, is_training, test_local_stats=False):
        net = model_cls()
        return net(inputs, is_training, test_local_stats)

    images = jnp.zeros((4, 28, 28, 1), dtype=jnp.float32)
    params, state = model_pred.init(
        jax.random.PRNGKey(42), images, is_training=True)

    def logits_fun(inputs):
        # Evaluation mode, no RNG; keep only the logits, drop the new state.
        return model_pred.apply(params, state, None, inputs, False,
                                test_local_stats=False)[0]

    bounds_in = jax_verify.IntervalBound(images - 1.0, images + 1.0)
    jax_verify.interval_bound_propagation(logits_fun, bounds_in)
def test_relu_ibp(self):
    """IBP through ReLU clamps negative interval endpoints at zero."""
    center = jnp.array([[-2., 3.]])
    bounds_in = jax_verify.IntervalBound(center - 1., center + 1.)
    # jax.nn.relu is passed directly; the original trivial wrapper is inlined.
    bounds_out = jax_verify.interval_bound_propagation(jax.nn.relu, bounds_in)
    self.assertArrayAlmostEqual(jnp.array([[0., 2.]]), bounds_out.lower)
    self.assertArrayAlmostEqual(jnp.array([[0., 4.]]), bounds_out.upper)
def test_keywords_argument(self):
    """A static keyword argument can be closed over with functools.partial."""
    @hk.without_apply_rng
    @hk.transform
    def forward(inputs, use_2=False):
        model = StaticArgumentModel()
        return model(inputs, use_2)

    x0 = jnp.array([[1., 2., 3.]])
    weights = forward.init(jax.random.PRNGKey(1), x0, use_2=True)
    bounds_in = jax_verify.IntervalBound(x0 - 1.0, x0 + 1.0)
    # Bind params and the static keyword so only the bounded input remains.
    apply_fn = functools.partial(forward.apply, weights, use_2=True)
    bounds_out = jax_verify.interval_bound_propagation(apply_fn, bounds_in)
    self.assertTrue((bounds_out.upper >= bounds_out.lower).all())
def test_softplus_ibp(self):
    """IBP through softplus maps each endpoint through log(1 + exp(x))."""
    center = jnp.array([[-2., 3.]])
    bounds_in = jax_verify.IntervalBound(center - 1., center + 1.)
    # jax.nn.softplus is passed directly; the trivial wrapper is inlined.
    bounds_out = jax_verify.interval_bound_propagation(
        jax.nn.softplus, bounds_in)
    # softplus(x) == logaddexp(x, 0), applied endpoint-wise (monotone).
    self.assertArrayAlmostEqual(jnp.logaddexp(center - 1., 0),
                                bounds_out.lower)
    self.assertArrayAlmostEqual(jnp.logaddexp(center + 1., 0),
                                bounds_out.upper)
def test_staticargument_last(self):
    """A trailing static argument can be fixed with a small wrapper function."""
    @hk.without_apply_rng
    @hk.transform
    def forward(inputs, use_2):
        model = StaticArgumentModel()
        return model(inputs, use_2)

    x0 = jnp.array([[1., 2., 3.]])
    weights = forward.init(jax.random.PRNGKey(1), x0, True)
    bounds_in = jax_verify.IntervalBound(x0 - 1.0, x0 + 1.0)

    def fun_to_prop(inputs):
        # Pins the static flag; only `inputs` is propagated as a bound.
        return forward.apply(weights, inputs, True)

    bounds_out = jax_verify.interval_bound_propagation(fun_to_prop, bounds_in)
    self.assertTrue((bounds_out.upper >= bounds_out.lower).all())
def _compute_jv_bounds(
    input_bound: sdp_utils.IntBound,
    params: ModelParams,
    method: str,
) -> List[sdp_utils.IntBound]:
    """Compute bounds with jax_verify.

    Args:
      input_bound: interval bound (lb/ub) on the network input.
      params: model parameters; entries with `has_bounds` contribute their
        own parameter intervals to the propagation.
      method: one of 'ibp', 'fastlin', 'ibpfastlin', 'crown', 'nonconvex'.

    Returns:
      One IntBound per intermediate activation, with pre-activation bounds
      and their ReLU (clamped-at-zero) counterparts.

    Raises:
      ValueError: if `method` is not one of the supported names.
    """
    jv_input_bound = jax_verify.IntervalBound(input_bound.lb, input_bound.ub)

    # Function of the input and of every bounded parameter, returning all
    # intermediate activations.
    all_act_fun = _make_all_act_fn(params)

    # Interval bounds for the parameters that are themselves uncertain.
    jv_param_bounds = [(p.w_bound, p.b_bound) for p in params if p.has_bounds]

    # Dispatch table replacing the original if/elif chain.
    propagators = {
        'ibp': jax_verify.interval_bound_propagation,
        'fastlin': jax_verify.forward_fastlin_bound_propagation,
        'ibpfastlin': jax_verify.ibpforwardfastlin_bound_propagation,
        'crown': jax_verify.backward_crown_bound_propagation,
        'nonconvex': jax_verify.nonconvex_constopt_bound_propagation,
    }
    if method not in propagators:
        raise ValueError('Unsupported method.')
    _, jv_bounds = propagators[method](
        all_act_fun, jv_input_bound, *jv_param_bounds)

    # Re-format bounds with the internal convention: keep pre-activation
    # bounds and derive post-ReLU bounds by clamping at zero.
    return [
        sdp_utils.IntBound(
            lb_pre=b.lower,
            ub_pre=b.upper,
            lb=jnp.maximum(b.lower, 0),
            ub=jnp.maximum(b.upper, 0))
        for b in jv_bounds
    ]
def test_stateful_model(self):
    """State, RNG and mode flags are treated as static during boundprop."""
    @hk.transform_with_state
    def forward(inputs, is_training, test_local_stats=False):
        model = ModelWithState()
        return model(inputs, is_training, test_local_stats)

    x0 = jnp.array([[1., 2., 3.]])
    weights, state = forward.init(jax.random.PRNGKey(1), x0, True, False)

    def fun_to_prop(inputs):
        outs = forward.apply(weights, state, jax.random.PRNGKey(1), inputs,
                             False, False)
        # Ignore the outputs that are not the network outputs.
        return outs[0]

    bounds_in = jax_verify.IntervalBound(x0 - 1.0, x0 + 1.0)
    # Only `inputs` varies; everything else is closed over as static.
    bounds_out = jax_verify.interval_bound_propagation(fun_to_prop, bounds_in)
    self.assertTrue((bounds_out.upper >= bounds_out.lower).all())
def test_linear_ibp(self):
    """IBP through a dense layer with unit weights and bias 2."""
    def linear_model(inp):
        return hk.Linear(1)(inp)

    x0 = jnp.array([[1., 2., 3.]])
    weights = {
        'linear': {
            'w': jnp.ones((3, 1), dtype=jnp.float32),
            'b': jnp.array([2.]),
        }
    }
    apply_fn = functools.partial(
        hk.without_apply_rng(hk.transform(linear_model)).apply, weights)

    bounds_in = jax_verify.IntervalBound(x0 - 1., x0 + 1.)
    bounds_out = jax_verify.interval_bound_propagation(apply_fn, bounds_in)

    # Inputs sum to 6 with +/-1 slack per coordinate (3 of them); bias adds 2,
    # giving [6 - 3 + 2, 6 + 3 + 2] = [5, 11].
    self.assertAlmostEqual(5., bounds_out.lower)
    self.assertAlmostEqual(11., bounds_out.upper)
def test_tight_bounds_nostate(self, model):
    """A zero-width input interval collapses to the plain forward pass."""
    x0 = jnp.array([[1., 2., 3.]])
    weights = model.init(jax.random.PRNGKey(1), x0)
    point_bounds = jax_verify.IntervalBound(x0, x0)
    apply_fn = functools.partial(model.apply, weights)
    out_bounds = jax_verify.interval_bound_propagation(apply_fn, point_bounds)
    reference = model.apply(weights, x0)
    # Because the input lower bound is equal to the input upper bound, the
    # output bounds should coincide and equal the forward-pass value.
    self.assertAlmostEqual(out_bounds.lower.tolist(),
                           out_bounds.upper.tolist())
    self.assertAlmostEqual(out_bounds.lower.tolist(), reference.tolist())
def elide_adversarial_spec(
    params: ModelParams,
    data_spec: DataSpec,
) -> ModelParamsElided:
  """Elide params to have last layer merged with the adversarial objective.

  Args:
    params: parameters of the model under verification.
    data_spec: data specification.

  Returns:
    params_elided: elided parameters with the adversarial objective folded in
      the last layer (and bounds adapted accordingly).
  """
  def elide_fn(w_fin, b_fin):
    # One-hot vectors for the true and target classes.
    label_onehot = jnp.eye(w_fin.shape[-1])[data_spec.true_label]
    target_onehot = jnp.eye(w_fin.shape[-1])[data_spec.target_label]
    # Margin direction: (target logit) - (true logit).
    obj_orig = target_onehot - label_onehot
    # Fold the objective through the final linear layer's weights and bias.
    obj_bp = jnp.matmul(w_fin, obj_orig)
    const = jnp.expand_dims(jnp.vdot(obj_orig, b_fin), axis=-1)
    # Reshape to a single-output weight column.
    obj = jnp.reshape(obj_bp, (obj_bp.size, 1))
    return obj, const
  last_params = params[-1]
  w_elided, b_elided = elide_fn(last_params.w, last_params.b)
  last_params_elided = verify_utils.FCParams(w_elided, b_elided)
  if last_params.has_bounds:
    # Parameter intervals must be transformed the same way as the parameters;
    # interval propagation through elide_fn yields sound elided bounds.
    w_bound_elided, b_bound_elided = jax_verify.interval_bound_propagation(
        elide_fn, last_params.w_bound, last_params.b_bound)
    last_params_elided = dataclasses.replace(last_params_elided,
                                             w_bound=w_bound_elided,
                                             b_bound=b_bound_elided)
  # All layers but the last are kept as-is; the last is the elided one.
  params_elided = params[:-1] + [last_params_elided]
  return params_elided
def test_jittable_input_bounds(self):
    """Bound propagation is jittable when fed jittable bounds."""
    model = sequential_model
    x0 = jnp.array([[1., 2., 3.]])
    weights = model.init(jax.random.PRNGKey(1), x0)
    apply_fn = functools.partial(model.apply, weights)

    eager_bounds = jax_verify.IntervalBound(x0 - 1.0, x0 + 1.0)
    jittable_input_bounds = eager_bounds.to_jittable()

    @jax.jit
    def bound_prop_fun(inp_bound):
        bounds = jax_verify.interval_bound_propagation(apply_fn, inp_bound)
        return bounds.lower, bounds.upper

    # Check that we can jit the bound prop and pass in jittable bounds.
    out_lb, out_ub = bound_prop_fun(jittable_input_bounds)
    self.assertTrue(all(out_ub >= out_lb))

    # Check that this gives the same result as without the jit.
    reference = jax_verify.interval_bound_propagation(apply_fn, eager_bounds)
    chex.assert_trees_all_close(out_lb, reference.lower)
    chex.assert_trees_all_close(out_ub, reference.upper)
def _compute_bounds(lower, upper):
    """Runs IBP through x ** exponent (closed over) for the given interval."""
    bounds_in = jax_verify.IntervalBound(lower, upper)
    bounds_out = jax_verify.interval_bound_propagation(
        lambda x: x**exponent, bounds_in)
    return bounds_out.lower, bounds_out.upper
def bound_prop_fun(inp_bound):
    """Propagates `inp_bound` through `fun_to_prop`, returning (lower, upper)."""
    result = jax_verify.interval_bound_propagation(fun_to_prop, inp_bound)
    return result.lower, result.upper