def test_fc_fastlin(self):
    """Fastlin through a single dense layer reproduces the layer exactly.

    The layer has all-ones weights and bias 2, i.e. it computes
    ``x0 + x1 + x2 + 2``. Over the box ``[z - 1, z + 1]`` with
    ``z = [1, 2, 3]`` the output therefore ranges over ``[5, 11]``.
    """
    @hk.without_apply_rng
    @hk.transform
    def dense_net(x):
        return hk.Linear(1)(x)

    center = jnp.array([[1., 2., 3.]])
    weights = {
        'linear': {
            'w': jnp.ones((3, 1), dtype=jnp.float32),
            'b': jnp.array([2.]),
        }
    }
    box = jax_verify.IntervalBound(center - 1., center + 1.)
    bounded_fn = functools.partial(dense_net.apply, weights)

    result = jax_verify.fastlin_bound_propagation(bounded_fn, box)

    # Both linear bounds must coincide with the layer itself
    # (coefficients 1, offset 2). Lower bound is checked first, then upper.
    for lin_bound in (result.lower_lin, result.upper_lin):
        self.assertTrue(jnp.all(lin_bound.lin_coeffs == 1.))
        self.assertTrue(jnp.all(lin_bound.offset == 2.))

    # The reference bounds equal the input box [z - 1, z + 1].
    self.assertArrayAlmostEqual(jnp.array([[0., 1., 2.]]),
                                result.reference.lower)
    self.assertArrayAlmostEqual(jnp.array([[2., 3., 4.]]),
                                result.reference.upper)

    # Concrete output bounds: 0+1+2+2 = 5 and 2+3+4+2 = 11.
    self.assertAlmostEqual(5., result.lower)
    self.assertAlmostEqual(11., result.upper)
def test_relu_fastlin(self):
    """Fastlin bounds of an elementwise ReLU over a box.

    With ``z = [-2, 3]`` the box ``[z - 1, z + 1]`` has lower corner
    ``[-3, 2]`` and upper corner ``[-1, 4]``. The first unit is always
    inactive, giving bounds ``[0, 0]``. The second is always active, giving
    bounds ``[2, 4]``.
    """
    def relu_net(x):
        return jax.nn.relu(x)

    center = jnp.array([[-2., 3.]])
    box = jax_verify.IntervalBound(center - 1., center + 1.)
    result = jax_verify.fastlin_bound_propagation(relu_net, box)

    expected_lower = jnp.array([[0., 2.]])
    expected_upper = jnp.array([[0., 4.]])
    self.assertArrayAlmostEqual(expected_lower, result.lower)
    self.assertArrayAlmostEqual(expected_upper, result.upper)
def main(unused_args):
    """Compares boundprop bounds with the planet LP relaxation on MNIST.

    Steps:
      1. Loads the model selected by ``FLAGS.model``.
      2. Builds an L-infinity box of radius ``eps``, clipped to ``[0, 1]``,
         around the first two samples of ``mnist/x_test_first100.npy``.
      3. Propagates the box with the method chosen by
         ``FLAGS.boundprop_method``.
      4. Solves the planet relaxation twice, once for ``objective`` and once
         for ``-objective``, to get a lower and an upper bound on the first
         output element. The solver is called with ``index=0``, which is
         presumably the batch index; confirm against
         ``solve_planet_relaxation``.

    The relaxation bounds and the boundprop bounds for element ``[0, 0]`` are
    all logged.

    Args:
      unused_args: Command-line arguments; ignored.

    Raises:
      NotImplementedError: If ``FLAGS.boundprop_method`` is neither
        ``'fastlin'`` nor ``'ibp'``.
    """
    # Load the parameters of an existing model.
    model_pred, params = load_model(FLAGS.model)
    logits_fn = functools.partial(model_pred, params)

    # Load some test samples.
    with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
        inputs = np.load(f)
    # Only the first two samples are verified.
    samples = inputs[:2, ...]

    # Compute boundprop bounds on an eps-box clipped to [0, 1].
    eps = 0.1
    lower_bound = jnp.clip(samples - eps, 0.0, 1.0)
    upper_bound = jnp.clip(samples + eps, 0.0, 1.0)
    init_bound = jax_verify.IntervalBound(lower_bound, upper_bound)

    if FLAGS.boundprop_method == 'fastlin':
        final_bound = jax_verify.fastlin_bound_propagation(
            logits_fn, init_bound)
        boundprop_transform = jax_verify.fastlin_transform
    elif FLAGS.boundprop_method == 'ibp':
        final_bound = jax_verify.interval_bound_propagation(
            logits_fn, init_bound)
        boundprop_transform = jax_verify.ibp_transform
    else:
        # Note the trailing space: the two literals are concatenated.
        raise NotImplementedError('Only ibp/fastlin boundprop are '
                                  'currently supported')

    # Only the per-sample output shape is needed to build the objective, so
    # evaluate the model on the verified samples instead of every input.
    sample_output = logits_fn(samples)[0, ...]

    # One-hot objective selecting the first (flattened) output element.
    # Reshaping the index grid keeps this correct for outputs of any rank.
    flat_index = jnp.arange(sample_output.size).reshape(sample_output.shape)
    objective = jnp.where(flat_index == 0,
                          jnp.ones_like(sample_output),
                          jnp.zeros_like(sample_output))
    objective_bias = 0.

    # Run LP solver. Judging by the log labels, the solver minimises: the
    # objective yields a lower bound, and the negated objective yields minus
    # the upper bound.
    value, status = jax_verify.solve_planet_relaxation(
        logits_fn, init_bound, boundprop_transform, objective,
        objective_bias, index=0)
    logging.info('Relaxation LB is : %f, Status is %s', value, status)
    value, status = jax_verify.solve_planet_relaxation(
        logits_fn, init_bound, boundprop_transform, -objective,
        objective_bias, index=0)
    logging.info('Relaxation UB is : %f, Status is %s', -value, status)

    logging.info('Boundprop LB is : %f', final_bound.lower[0, 0])
    logging.info('Boundprop UB is : %f', final_bound.upper[0, 0])
def test_fastlin(self, model_cls):
    """Fastlin propagates sound, correctly shaped bounds through `model_cls`.

    The model is initialised in training mode on a zero batch of shape
    ``(4, 28, 28, 1)`` and then applied in evaluation mode. The input box is
    ``[-1, 1]`` around that batch. The resulting bounds must match the shape
    of the network output and must contain the output at the box centre.

    Args:
      model_cls: Zero-argument constructor of a Haiku module called as
        ``model(inputs, is_training, test_local_stats)``.
    """
    @hk.transform_with_state
    def model_pred(inputs, is_training, test_local_stats=False):
        model = model_cls()
        return model(inputs, is_training, test_local_stats)

    inps = jnp.zeros((4, 28, 28, 1), dtype=jnp.float32)
    params, state = model_pred.init(
        jax.random.PRNGKey(42), inps, is_training=True)

    def logits_fun(inputs):
        # `apply` returns (output, new_state); only the output is bounded.
        return model_pred.apply(params, state, None, inputs, False,
                                test_local_stats=False)[0]

    input_bounds = jax_verify.IntervalBound(inps - 1.0, inps + 1.0)
    output_bounds = jax_verify.fastlin_bound_propagation(
        logits_fun, input_bounds)

    # `inps` lies inside the input box, so sound bounds must enclose the
    # network's output there. A small relative tolerance absorbs
    # floating-point round-off.
    nominal_output = logits_fun(inps)
    self.assertEqual(nominal_output.shape, output_bounds.lower.shape)
    self.assertEqual(nominal_output.shape, output_bounds.upper.shape)
    tol = 1e-4 * (1.0 + jnp.abs(nominal_output))
    self.assertTrue(jnp.all(output_bounds.lower <= nominal_output + tol))
    self.assertTrue(jnp.all(nominal_output - tol <= output_bounds.upper))
def test_conv1d_fastlin(self):
    """Fastlin bounds through a width-2 1D convolution.

    The kernel is all ones and the bias is 2, so the single VALID output
    position computes ``x0 + x1 + 2``. With ``z = [3, 4]`` and box
    ``[z - 1, z + 1]`` that output ranges over ``[2 + 3 + 2, 4 + 5 + 2]``,
    i.e. ``[7, 11]``.
    """
    @hk.without_apply_rng
    @hk.transform
    def conv_net(x):
        conv = hk.Conv1D(output_channels=1, kernel_shape=2, padding='VALID',
                         stride=1, with_bias=True)
        return conv(x)

    # Reshape the two values into a rank-3 input of shape (1, 2, 1).
    center = jnp.reshape(jnp.array([3., 4.]), [1, 2, 1])
    weights = {
        'conv1_d': {
            'w': jnp.ones((2, 1, 1), dtype=jnp.float32),
            'b': jnp.array([2.]),
        }
    }
    box = jax_verify.IntervalBound(center - 1., center + 1.)
    bounded_fn = functools.partial(conv_net.apply, weights)

    result = jax_verify.fastlin_bound_propagation(bounded_fn, box)

    self.assertAlmostEqual(7., result.lower)
    self.assertAlmostEqual(11., result.upper)