def make_params(prng_key, dropout_rate=0.0, std=None):
  """Creates random parameters for a small two-layer fully-connected network.

  The first layer has an [8, 4] weight and [4] bias, the second an [4, 2]
  weight and [2] bias, all drawn with `jax.random.normal` from keys taken out
  of one `hk.PRNGSequence` in the order (w1, b1, w2, b2).

  Args:
    prng_key: Key used to seed the `hk.PRNGSequence`.
    dropout_rate: Passed as `dropout_rate` to the second layer's `FCParams`.
    std: Optional. When not None, the first layer's `w_std`/`b_std` are
      `std * ones(...)` and its `w_bound`/`b_bound` are intervals of
      +/- 3 * std around the sampled values; otherwise all four are None.

  Returns:
    A list of two `verify_utils.FCParams`, one per layer.
  """
  key_seq = hk.PRNGSequence(prng_key)
  # Keys are consumed in a fixed order: w1, b1, w2, b2.
  layer1_w, layer1_b, layer2_w, layer2_b = [
      jax.random.normal(next(key_seq), shape)
      for shape in ([8, 4], [4], [4, 2], [2])
  ]

  layer1_w_std = layer1_b_std = None
  layer1_w_bound = layer1_b_bound = None
  if std is not None:
    layer1_w_std = std * jnp.ones([8, 4])
    layer1_b_std = std * jnp.ones([4])
    layer1_w_bound = jax_verify.IntervalBound(
        layer1_w - 3 * layer1_w_std, layer1_w + 3 * layer1_w_std)
    layer1_b_bound = jax_verify.IntervalBound(
        layer1_b - 3 * layer1_b_std, layer1_b + 3 * layer1_b_std)

  first_layer = verify_utils.FCParams(
      w=layer1_w,
      b=layer1_b,
      w_std=layer1_w_std,
      b_std=layer1_b_std,
      w_bound=layer1_w_bound,
      b_bound=layer1_b_bound,
  )
  second_layer = verify_utils.FCParams(
      w=layer2_w,
      b=layer2_b,
      dropout_rate=dropout_rate,
  )
  return [first_layer, second_layer]
def _bounds_from_cnn_layer(self, index):
  """Returns the stored CNN interval bound for entry `index`.

  `self._cnn_layer_indices[index]` yields a `(layer_index, is_preact)` pair.
  When `is_preact` is truthy the layer's `lb_pre`/`ub_pre` fields are used,
  otherwise its `lb`/`ub` fields (presumably post-activation bounds).
  """
  layer_index, is_preact = self._cnn_layer_indices[index]
  layer_bounds = self._cnn_bounds[layer_index]
  if not is_preact:
    return jax_verify.IntervalBound(layer_bounds.lb, layer_bounds.ub)
  return jax_verify.IntervalBound(layer_bounds.lb_pre, layer_bounds.ub_pre)
def test_sdp_problem_equivalent_to_sdp_verify(self):
  """Checks SdpReluProblem builds the same instance as the `sdp_verify` code."""
  # Reference toy verification problem.
  verif_instance = test_utils.make_toy_verif_instance(label=2, target_label=1)

  # Spec function replicating the toy problem: MLP, ReLU, then a linear
  # objective plus a constant.
  def spec_fn(x):
    activations = jax.nn.relu(utils.predict_mlp(verif_instance.params, x))
    flat = jnp.reshape(activations, (-1,))
    return jnp.sum(flat * verif_instance.obj) + verif_instance.const

  inputs = jnp.zeros((1, 5))
  input_bounds = jax_verify.IntervalBound(
      jnp.zeros_like(inputs), jnp.ones_like(inputs))

  # Instance built by the code under test.
  sdp_problem_vi = problem_from_graph.SdpReluProblem(
      ibp.bound_transform, spec_fn, input_bounds,
  ).build_sdp_verification_instance()
  # Instance built by the existing `sdp_verify` code.
  sdp_verify_vi = problem.make_sdp_verif_instance(verif_instance)

  self._assert_verif_instances_equal(sdp_problem_vi, sdp_verify_vi)
def _get_reciprocal_bound(l: jnp.array, u: jnp.array,
                          logits_params: LayerParams,
                          label: int) -> jnp.array:
  """Helper for bounding the label softmax given interval bounds on pre-logits.

  Propagates interval bounds through
  `x -> x @ (w[:, label:label+1] - w) + (b[label] - b)`, i.e. the difference
  between the `label` logit and every logit. Each of `w`/`b` that has an
  interval bound on `logits_params` (`w_bound`/`b_bound` not None) is treated
  as an additional bounded input; otherwise its fixed value is used.

  Args:
    l: Lower bounds on the pre-logit activations; flattened to
      `[l.shape[0], -1]`.
    u: Upper bounds on the pre-logit activations; flattened likewise.
    logits_params: Parameters of the logits layer (`w`, `b`, and optional
      `w_bound`, `b_bound`).
    label: Index of the label logit.

  Returns:
    The bound object returned by `jax_verify.interval_bound_propagation`.
  """
  w_is_bounded = logits_params.w_bound is not None
  b_is_bounded = logits_params.b_bound is not None

  def fwd(x, *bounded_params):
    # Bounded parameters arrive positionally in (w, b) order, each present only
    # if it has an interval bound. Binding `w` via functools.partial(w=...) and
    # then passing `b` positionally would collide on the `w` slot, so the
    # parameters are resolved explicitly here instead.
    remaining = list(bounded_params)
    w = remaining.pop(0) if w_is_bounded else logits_params.w
    b = remaining.pop(0) if b_is_bounded else logits_params.b
    wdiff = jnp.reshape(w[:, label], [-1, 1]) - w
    bdiff = b[label] - b
    return x @ wdiff + bdiff

  x_bound = jax_verify.IntervalBound(
      lower_bound=jnp.reshape(l, [l.shape[0], -1]),
      upper_bound=jnp.reshape(u, [u.shape[0], -1]))
  params_bounds = [
      bound for bound in (logits_params.w_bound, logits_params.b_bound)
      if bound is not None
  ]
  return jax_verify.interval_bound_propagation(fwd, x_bound, *params_bounds)
def test_exp_crown(self):
  """Checks backward-CROWN bounds on `exp` contain sampled outputs.

  Draws 100 sample points inside random input bounds (via
  `test_utils.sample_bounded_points`) and asserts the propagated upper bound
  is never below the empirical max and the lower bound never above the
  empirical min.
  """

  def exp_model(inp):
    return jnp.exp(inp)

  exp_inp_shape = (4, 7)
  lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0), exp_inp_shape,
                                    minval=-10., maxval=10.)
  input_bounds = jax_verify.IntervalBound(lb, ub)
  output_bounds = jax_verify.backward_crown_bound_propagation(
      exp_model, input_bounds)

  uniform_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(1), (lb, ub), 100)
  uniform_outs = jax.vmap(exp_model)(uniform_inps)
  empirical_min = uniform_outs.min(axis=0)
  empirical_max = uniform_outs.max(axis=0)
  self.assertGreaterEqual((output_bounds.upper - empirical_max).min(), 0.,
                          'Invalid upper bound for Exponential. The gap '
                          'between upper bound and empirical max is < 0')
  # Implicitly concatenated literals: keep the trailing space so the message
  # reads "gap between".
  self.assertGreaterEqual(
      (empirical_min - output_bounds.lower).min(), 0.,
      'Invalid lower bound for Exponential. The gap '
      'between emp. min and lower bound is negative.')
def test_nonconvex(self, model_cls):
  """Smoke-tests both nonconvex bound propagation entry points on a model."""

  def _forward(inputs, is_training, test_local_stats=False):
    return model_cls()(inputs, is_training, test_local_stats)

  model_pred = hk.transform_with_state(_forward)

  inps = jnp.zeros((4, 28, 28, 1), dtype=jnp.float32)
  params, state = model_pred.init(jax.random.PRNGKey(42), inps,
                                  is_training=True)

  def logits_fun(inputs):
    outputs, _ = model_pred.apply(params, state, None, inputs, False,
                                  test_local_stats=False)
    return outputs

  input_bounds = jax_verify.IntervalBound(inps - 1.0, inps + 1.0)
  # First with IBP for intermediate bounds, then with nonconvex bound
  # evaluation for intermediate bounds.
  for propagate in (jax_verify.nonconvex_ibp_bound_propagation,
                    jax_verify.nonconvex_constopt_bound_propagation):
    propagate(logits_fun, input_bounds)
def test_matching_output_structure(self, model):
  """Checks each bound propagation method mirrors the model output pytree."""

  def _check_matching_structures(output_tree, bound_tree):
    """Replace all bounds/arrays with True, then compare pytrees."""

    def _mark_arrays(node):
      return True if isinstance(node, jnp.ndarray) else None

    def _mark_bounds(node):
      return True if isinstance(node, bound_propagation.Bound) else None

    tree.assert_same_structure(tree.traverse(_mark_arrays, output_tree),
                               tree.traverse(_mark_bounds, bound_tree))

  z = jnp.array([[1., 2., 3.]])
  params = model.init(jax.random.PRNGKey(1), z)
  input_bounds = jax_verify.IntervalBound(z - 1.0, z + 1.0)
  model_output = model.apply(params, z)
  fun_to_prop = functools.partial(model.apply, params)

  boundprop_methods = (
      jax_verify.interval_bound_propagation,
      jax_verify.forward_crown_bound_propagation,
      jax_verify.backward_crown_bound_propagation,
      jax_verify.forward_fastlin_bound_propagation,
      jax_verify.backward_fastlin_bound_propagation,
      jax_verify.ibpforwardfastlin_bound_propagation,
  )
  for method in boundprop_methods:
    _check_matching_structures(model_output, method(fun_to_prop, input_bounds))
def test_relu_random_fastlin(self):
  """Checks forward-fastlin bounds on ReLU contain sampled outputs.

  Draws 100 sample points inside random input bounds (via
  `test_utils.sample_bounded_points`) and asserts the propagated upper bound
  is never below the empirical max and the lower bound never above the
  empirical min.
  """

  def relu_model(inp):
    return jax.nn.relu(inp)

  relu_inp_shape = (4, 7)
  lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0), relu_inp_shape,
                                    minval=-10., maxval=10.)
  input_bounds = jax_verify.IntervalBound(lb, ub)
  output_bounds = jax_verify.forward_fastlin_bound_propagation(
      relu_model, input_bounds)

  uniform_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(1), (lb, ub), 100)
  uniform_outs = jax.vmap(relu_model)(uniform_inps)
  empirical_min = uniform_outs.min(axis=0)
  empirical_max = uniform_outs.max(axis=0)
  self.assertGreaterEqual((output_bounds.upper - empirical_max).min(), 0.,
                          'Invalid upper bound for ReLU. The gap '
                          'between upper bound and empirical max is < 0')
  # Implicitly concatenated literals: keep the trailing space so the message
  # reads "gap between".
  self.assertGreaterEqual(
      (empirical_min - output_bounds.lower).min(), 0.,
      'Invalid lower bound for ReLU. The gap '
      'between emp. min and lower bound is negative.')
def test_conv1d_cvxpy_relaxation(self):
  """Checks relaxation bounds for an all-ones Conv1D with bias 2."""

  def conv1d_model(inp):
    conv = hk.Conv1D(output_channels=1, kernel_shape=2, padding='VALID',
                     stride=1, with_bias=True)
    return conv(inp)

  params = {
      'conv1_d': {
          'w': jnp.ones((2, 1, 1), dtype=jnp.float32),
          'b': jnp.array([2.]),
      }
  }
  transformed = hk.without_apply_rng(hk.transform(conv1d_model))
  fun = functools.partial(transformed.apply, params)

  z = jnp.reshape(jnp.array([3., 4.]), [1, 2, 1])
  input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
  lower_bounds, upper_bounds = self.get_bounds(fun, input_bounds)
  # Input box [2, 3]..[4, 5]: lower = 2 + 3 + 2 = 7, upper = 4 + 5 + 2 = 11.
  self.assertAlmostEqual(7., lower_bounds, delta=1e-5)
  self.assertAlmostEqual(11., upper_bounds, delta=1e-5)
def test_conv2d_ibp(self):
  """Checks IBP bounds for an all-ones 2x2 Conv2D with bias 2."""

  def conv2d_model(inp):
    conv = hk.Conv2D(output_channels=1, kernel_shape=(2, 2), padding='VALID',
                     stride=1, with_bias=True)
    return conv(inp)

  params = {
      'conv2_d': {
          'w': jnp.ones((2, 2, 1, 1), dtype=jnp.float32),
          'b': jnp.array([2.]),
      }
  }
  transformed = hk.without_apply_rng(hk.transform(conv2d_model))
  fun = functools.partial(transformed.apply, params)

  z = jnp.reshape(jnp.array([1., 2., 3., 4.]), [1, 2, 2, 1])
  output_bounds = jax_verify.interval_bound_propagation(
      fun, jax_verify.IntervalBound(z - 1., z + 1.))
  # Sum of the input box corners plus bias: 6 + 2 = 8 and 14 + 2 = 16.
  self.assertAlmostEqual(8., output_bounds.lower)
  self.assertAlmostEqual(16., output_bounds.upper)
def test_fc_fastlin(self):
  """Checks fastlin linear and concrete bounds for an all-ones Linear layer."""

  def _linear(inp):
    return hk.Linear(1)(inp)

  linear_model = hk.without_apply_rng(hk.transform(_linear))

  params = {
      'linear': {
          'w': jnp.ones((3, 1), dtype=jnp.float32),
          'b': jnp.array([2.]),
      }
  }
  z = jnp.array([[1., 2., 3.]])
  input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
  output_bounds = jax_verify.fastlin_bound_propagation(
      functools.partial(linear_model.apply, params), input_bounds)

  # Both linear relaxations should reproduce the layer: coeffs 1, offset 2.
  for lin in (output_bounds.lower_lin, output_bounds.upper_lin):
    self.assertTrue(jnp.all(lin.lin_coeffs == 1.))
    self.assertTrue(jnp.all(lin.offset == 2.))

  reference = output_bounds.reference
  self.assertArrayAlmostEqual(jnp.array([[0., 1., 2.]]), reference.lower)
  self.assertArrayAlmostEqual(jnp.array([[2., 3., 4.]]), reference.upper)
  self.assertAlmostEqual(5., output_bounds.lower)
  self.assertAlmostEqual(11., output_bounds.upper)
def test_fc_fastlin(self, name, elision):
  """Checks the single linear function produced for an all-ones Linear layer.

  The bound propagation is selected with `get_boundprop(name, elision)`.
  """

  def _linear(inp):
    return hk.Linear(1)(inp)

  linear_model = hk.without_apply_rng(hk.transform(_linear))

  params = {
      'linear': {
          'w': jnp.ones((3, 1), dtype=jnp.float32),
          'b': jnp.array([2.]),
      }
  }
  z = jnp.array([[1., 2., 3.]])
  input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
  bound_prop = get_boundprop(name, elision)
  output_bounds = bound_prop(
      functools.partial(linear_model.apply, params), input_bounds)

  linear_functions = list(output_bounds.linear_functions())
  self.assertLen(linear_functions, 1)
  (linear_fun,) = linear_functions

  # Both linear relaxations should reproduce the layer: coeffs 1, offset 2.
  for lin in (linear_fun.lower_lin, linear_fun.upper_lin):
    self.assertTrue(jnp.all(lin.lin_coeffs == 1.))
    self.assertTrue(jnp.all(lin.offset == 2.))

  reference = linear_fun.reference_bound.bound
  self.assertArrayAlmostEqual(jnp.array([[0., 1., 2.]]), reference.lower)
  self.assertArrayAlmostEqual(jnp.array([[2., 3., 4.]]), reference.upper)
  self.assertAlmostEqual(5., output_bounds.lower)
  self.assertAlmostEqual(11., output_bounds.upper)
def test_conv1d_fastlin(self):
  """Checks IBP+forward-fastlin bounds for an all-ones Conv1D with bias 2."""

  def _conv1d(inp):
    return hk.Conv1D(output_channels=1, kernel_shape=2, padding='VALID',
                     stride=1, with_bias=True)(inp)

  conv1d_model = hk.without_apply_rng(hk.transform(_conv1d))

  params = {
      'conv1_d': {
          'w': jnp.ones((2, 1, 1), dtype=jnp.float32),
          'b': jnp.array([2.]),
      }
  }
  fun = functools.partial(conv1d_model.apply, params)

  z = jnp.reshape(jnp.array([3., 4.]), [1, 2, 1])
  output_bounds = jax_verify.ibpforwardfastlin_bound_propagation(
      fun, jax_verify.IntervalBound(z - 1., z + 1.))
  # Input box [2, 3]..[4, 5]: lower = 2 + 3 + 2 = 7, upper = 4 + 5 + 2 = 11.
  self.assertAlmostEqual(7., output_bounds.lower, delta=1e-5)
  self.assertAlmostEqual(11., output_bounds.upper, delta=1e-5)
def _crown_ibp_boundprop(params, x, epsilon, input_bounds):
  """Runs CROWN-IBP for each layer separately.

  The initial box is `x` +/- `epsilon` intersected with
  `[input_bounds[0], input_bounds[1]]`. For every prefix length `k` in
  `1..len(params)`, CROWN-IBP is run from the initial box through
  `utils.predict_cnn(params[:k], .)`.

  Returns:
    A list of `IntBound`. The first entry holds the initial box (with no
    pre-activation bounds). Each later entry holds the CROWN-IBP bounds as
    `lb_pre`/`ub_pre` and their elementwise `max(0, .)` as `lb`/`ub`.
  """
  initial_bound = jax_verify.IntervalBound(
      jnp.maximum(x - epsilon, input_bounds[0]),
      jnp.minimum(x + epsilon, input_bounds[1]))

  def _prefix_activations(num_layers, inputs):
    return utils.predict_cnn(params[:num_layers], inputs)

  layer_bounds = [
      IntBound(lb_pre=None, ub_pre=None,
               lb=initial_bound.lower, ub=initial_bound.upper)
  ]
  for num_layers in range(1, len(params) + 1):
    prefix_fn = functools.partial(_prefix_activations, num_layers)
    pre = jax_verify.crownibp_bound_propagation(prefix_fn, initial_bound)
    layer_bounds.append(
        IntBound(lb_pre=pre.lower, ub_pre=pre.upper,
                 lb=jnp.maximum(0, pre.lower), ub=jnp.maximum(0, pre.upper)))
  return layer_bounds
def solve_with_jax_verify(self):
  """Solves the LP relaxation with jax_verify and returns the maximum.

  The input box is `self.inputs` +/- `self.eps`, clipped to [0, 1]. Bounds are
  propagated with an IBP-based `RelaxationTransform`, and the relaxation is
  solved with `CvxpySolver` on the negated objective and bias.
  """

  def _clip_unit(values):
    return jnp.minimum(jnp.maximum(values, 0.0), 1.0)

  init_bound = jax_verify.IntervalBound(
      _clip_unit(self.inputs - self.eps), _clip_unit(self.inputs + self.eps))
  logits_fn = make_model_fn(self.network_params)

  algorithm = bound_propagation.ForwardPropagationAlgorithm(
      relaxation.RelaxationTransform(jax_verify.ibp_transform))
  var, env = bound_propagation.bound_propagation(
      algorithm, logits_fn, init_bound)

  # The solver minimizes, so max(objective) = -min(-objective).
  neg_value_opt, _, _ = relaxation.solve_relaxation(
      cvxpy_relaxation_solver.CvxpySolver, -self.objective,
      -self.objective_bias, var, env, index=0, time_limit_millis=None)
  return -neg_value_opt
def test_chunking(self, relaxer):
  """Checks chunked and unchunked backward concretization give equal bounds.

  A two-layer ReLU network with random uniform weights is bounded twice with
  the same concretizing transform: once with `max_chunk_size=16` and once with
  `max_chunk_size=0`.
  """
  batch_size, input_size, hidden_size, final_size = 3, 2, 5, 4

  inp_lb, inp_ub = test_utils.sample_bounds(
      jax.random.PRNGKey(0), (batch_size, input_size), minval=-1., maxval=1.)
  inp_bound = jax_verify.IntervalBound(inp_lb, inp_ub)

  hidden_w = jax.random.uniform(jax.random.PRNGKey(1),
                                (input_size, hidden_size))
  final_w = jax.random.uniform(jax.random.PRNGKey(2),
                               (hidden_size, final_size))

  def model_fun(inp):
    return jax.nn.relu(inp @ hidden_w) @ final_w

  if isinstance(relaxer, linear_bound_utils.ParameterizedLinearBoundsRelaxer):
    # Parameterized relaxers are optimized with a few Adam steps.
    concretizing_transform = (
        backward_crown.OptimizingLinearBoundBackwardTransform(
            relaxer, backward_crown.CONCRETIZE_ARGS_PRIMITIVE,
            optax.adam(1.e-3), num_opt_steps=10))
  else:
    concretizing_transform = backward_crown.LinearBoundBackwardTransform(
        relaxer, backward_crown.CONCRETIZE_ARGS_PRIMITIVE)

  def _propagate(max_chunk_size):
    concretizer = backward_crown.ChunkedBackwardConcretizer(
        concretizing_transform, max_chunk_size=max_chunk_size)
    algorithm = bound_utils.BackwardConcretizingAlgorithm(concretizer)
    bound, _ = bound_propagation.bound_propagation(
        algorithm, model_fun, inp_bound)
    return bound

  chunked_bound = _propagate(16)
  # max_chunk_size=0 serves as the unchunked reference.
  unchunked_bound = _propagate(0)

  np.testing.assert_array_almost_equal(chunked_bound.lower,
                                       unchunked_bound.lower)
  np.testing.assert_array_almost_equal(chunked_bound.upper,
                                       unchunked_bound.upper)
def test_model_structure_nostate(self, model):
  """Checks IBP on a stateless model yields upper >= lower for every output."""
  z = jnp.array([[1., 2., 3.]])
  params = model.init(jax.random.PRNGKey(1), z)
  input_bounds = jax_verify.IntervalBound(z - 1.0, z + 1.0)
  fun_to_prop = functools.partial(model.apply, params)
  output_bounds = jax_verify.interval_bound_propagation(
      fun_to_prop, input_bounds)
  # Reduce over every element with the array method. The builtin `all()` only
  # iterates the leading axis and calls bool() on each row, which raises for
  # rows with more than one element.
  self.assertTrue((output_bounds.upper >= output_bounds.lower).all())
def test_sqrt(self, input_bounds, expected):
  """Checks IBP through `jnp.sqrt` for a parameterized interval.

  A second element fixed at [0, 0] is appended to both the input interval and
  the expected output interval.
  """

  def _with_zero(pair):
    return np.array([pair[0], 0.0]), np.array([pair[1], 0.0])

  lower_in, upper_in = _with_zero(input_bounds)
  bounds = jax_verify.IntervalBound(lower_in, upper_in)
  output_bounds = jax_verify.interval_bound_propagation(jnp.sqrt, bounds)

  expected_lower, expected_upper = _with_zero(expected)
  np.testing.assert_array_equal(expected_lower, output_bounds.lower)
  np.testing.assert_array_equal(expected_upper, output_bounds.upper)
def test_passthrough_primitive(self, fn, inputs):
  """Checks IBP output bounds equal `fn` applied to the input box endpoints."""
  center = jnp.array(inputs)
  input_bounds = jax_verify.IntervalBound(center - 1., center + 1.)
  output_bounds = jax_verify.interval_bound_propagation(fn, input_bounds)
  for side in ('lower', 'upper'):
    self.assertArrayAlmostEqual(fn(getattr(input_bounds, side)),
                                getattr(output_bounds, side))
def test_multiinput_concatenate_fastlin(self, name, elision):
  """Checks bounds for summing the concatenation of two bounded inputs.

  The bound propagation is selected with `get_boundprop(name, elision)`.
  """

  def concatenate_and_sum_model(inp_1, inp_2):
    interm = jnp.concatenate((inp_1, inp_2), axis=1)
    return interm.sum(axis=1)

  z_1 = jnp.array([[-1., 1.]])
  z_2 = jnp.array([[-1., 1.]])
  bound_1 = jax_verify.IntervalBound(z_1 - 1., z_1 + 1.)
  bound_2 = jax_verify.IntervalBound(z_2 - 1., z_2 + 1.)

  # The model computes sum(inp_1) + sum(inp_2), which is increasing in every
  # input. The exact bounds are therefore reached with both inputs at their
  # lower ends (z - 1) or both at their upper ends (z + 1).
  out_lower = (z_1 + z_2 - 2.).sum()
  out_upper = (z_1 + z_2 + 2.).sum()

  bound_prop = get_boundprop(name, elision)
  output_bounds = bound_prop(concatenate_and_sum_model, bound_1, bound_2)
  self.assertArrayAlmostEqual(out_lower, output_bounds.lower)
  self.assertArrayAlmostEqual(out_upper, output_bounds.upper)
def primitive_transform(self, context: TransformContext,
                        primitive: jax.core.Primitive,
                        *args, **kwargs) -> bound_propagation.Bound:
  """Returns the fixed bound for this node, or derives it from a subgraph.

  If `self._fixed_bounds` contains `context.index`, the stored value is
  unpacked into a `jax_verify.IntervalBound`. Otherwise, for a synthetic
  `FakePrimitive`, the bound is computed by passing the primitive's
  `jax_verify_subgraph` kwarg to `context.subgraph_handler`. Any other missing
  index propagates whatever the `self._fixed_bounds[...]` lookup raises.
  """
  fixed_bounds = self._fixed_bounds
  has_fixed = context.index in fixed_bounds
  if has_fixed or not isinstance(primitive,
                                 synthetic_primitives.FakePrimitive):
    return jax_verify.IntervalBound(*fixed_bounds[context.index])
  # No stored bound at the synthetic-primitive level: infer it from the
  # primitive's sub-graph instead.
  return context.subgraph_handler(self, kwargs['jax_verify_subgraph'], *args)
def test_relu_crownibp(self):
  """Checks CROWN-IBP bounds of ReLU on the box [[-3, 2]]..[[-1, 4]]."""

  def relu_model(inp):
    return jax.nn.relu(inp)

  center = jnp.array([[-2., 3.]])
  output_bounds = jax_verify.crownibp_bound_propagation(
      relu_model, jax_verify.IntervalBound(center - 1., center + 1.))
  # First unit is always inactive; second is always active.
  self.assertArrayAlmostEqual(jnp.array([[0., 2.]]), output_bounds.lower)
  self.assertArrayAlmostEqual(jnp.array([[0., 4.]]), output_bounds.upper)
def test_relu_cvxpy_relaxation(self):
  """Checks relaxation bounds of ReLU on the box [[-3, 2]]..[[-1, 4]]."""

  def relu_model(inp):
    return jax.nn.relu(inp)

  center = jnp.array([[-2., 3.]])
  lower_bounds, upper_bounds = self.get_bounds(
      relu_model, jax_verify.IntervalBound(center - 1., center + 1.))
  # First unit is always inactive; second is always active.
  self.assertArrayAlmostEqual(jnp.array([[0., 2.]]), lower_bounds)
  self.assertArrayAlmostEqual(jnp.array([[0., 4.]]), upper_bounds)
def test_multioutput_model(self):
  """Checks IBP on `residual_model_all_act` returns 7 output bounds."""
  z = jnp.array([[1., 2., 3.]])
  transformed = hk.without_apply_rng(hk.transform(residual_model_all_act))
  params = transformed.init(jax.random.PRNGKey(1), z)

  output_bounds = jax_verify.interval_bound_propagation(
      functools.partial(transformed.apply, params),
      jax_verify.IntervalBound(z - 1.0, z + 1.0))
  self.assertLen(output_bounds, 7)
def main(unused_args):
  """Compares boundprop bounds with planet-relaxation LP bounds.

  Loads the model selected by `FLAGS.model`, builds an input box of radius 0.1
  (clipped to [0, 1]) around the first two samples of
  'mnist/x_test_first100.npy', and bounds the logits with the method named by
  `FLAGS.boundprop_method`. It then solves the planet relaxation for a one-hot
  objective on output 0 (with `index=0` passed to the solver) in both
  directions, and logs the relaxation and boundprop lower/upper bounds.

  Raises:
    NotImplementedError: If `FLAGS.boundprop_method` is not 'fastlin' or
      'ibp'.
  """
  # Load the parameters of an existing model.
  model_pred, params = load_model(FLAGS.model)
  logits_fn = functools.partial(model_pred, params)

  # Load some test samples.
  with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
    inputs = np.load(f)

  # Compute boundprop bounds on an eps-box around the first two samples,
  # clipped to the [0, 1] range.
  eps = 0.1
  lower_bound = jnp.minimum(jnp.maximum(inputs[:2, ...] - eps, 0.0), 1.0)
  upper_bound = jnp.minimum(jnp.maximum(inputs[:2, ...] + eps, 0.0), 1.0)
  init_bound = jax_verify.IntervalBound(lower_bound, upper_bound)

  if FLAGS.boundprop_method == 'fastlin':
    final_bound = jax_verify.fastlin_bound_propagation(logits_fn, init_bound)
    boundprop_transform = jax_verify.fastlin_transform
  elif FLAGS.boundprop_method == 'ibp':
    final_bound = jax_verify.interval_bound_propagation(logits_fn, init_bound)
    boundprop_transform = jax_verify.ibp_transform
  else:
    # Implicitly concatenated literals: the trailing space keeps "are" and
    # "currently" apart.
    raise NotImplementedError('Only ibp/fastlin boundprop are '
                              'currently supported')

  # Only the shape/dtype of the first output row is used, to build a one-hot
  # objective selecting output 0.
  dummy_output = model_pred(params, inputs)
  first_output = dummy_output[0, ...]
  objective = jnp.where(
      jnp.arange(first_output.size) == 0,
      jnp.ones_like(first_output),
      jnp.zeros_like(first_output))
  objective_bias = 0.

  # Run LP solver. The result with `objective` is logged as the LB and the
  # negated result with `-objective` as the UB.
  value, status = jax_verify.solve_planet_relaxation(
      logits_fn, init_bound, boundprop_transform, objective,
      objective_bias, index=0)
  logging.info('Relaxation LB is : %f, Status is %s', value, status)
  value, status = jax_verify.solve_planet_relaxation(
      logits_fn, init_bound, boundprop_transform, -objective,
      objective_bias, index=0)
  logging.info('Relaxation UB is : %f, Status is %s', -value, status)

  logging.info('Boundprop LB is : %f', final_bound.lower[0, 0])
  logging.info('Boundprop UB is : %f', final_bound.upper[0, 0])
def test_nobatch_batch_inputs(self):
  """Checks backward CROWN on a bilinear model with mixed input shapes.

  The first input has shape (3, 2) and the second (2, 4). The output is
  their `einsum('bh,hH->bH')` product. Bounds are compared against the
  elementwise empirical min/max over 100 sampled input pairs.
  """
  batch_shape = (3, 2)
  unbatch_shape = (2, 4)

  def bilinear_model(inp_1, inp_2):
    return jnp.einsum('bh,hH->bH', inp_1, inp_2)

  lb_1, ub_1 = test_utils.sample_bounds(jax.random.PRNGKey(0), batch_shape,
                                        minval=-10, maxval=10.)
  lb_2, ub_2 = test_utils.sample_bounds(jax.random.PRNGKey(1), unbatch_shape,
                                        minval=-10, maxval=10.)
  bound_1 = jax_verify.IntervalBound(lb_1, ub_1)
  bound_2 = jax_verify.IntervalBound(lb_2, ub_2)

  output_bounds = backward_crown.backward_crown_bound_propagation(
      bilinear_model, bound_1, bound_2)

  uniform_1 = test_utils.sample_bounded_points(jax.random.PRNGKey(2),
                                               (lb_1, ub_1), 100)
  uniform_2 = test_utils.sample_bounded_points(jax.random.PRNGKey(3),
                                               (lb_2, ub_2), 100)
  uniform_outs = jax.vmap(bilinear_model)(uniform_1, uniform_2)
  empirical_min = uniform_outs.min(axis=0)
  empirical_max = uniform_outs.max(axis=0)

  # Implicitly concatenated literals: keep the trailing space so the messages
  # read "batched/unbatched input bounds."
  self.assertGreaterEqual(
      (output_bounds.upper - empirical_max).min(), 0.,
      'Invalid upper bound for mix of batched/unbatched '
      'input bounds.')
  self.assertGreaterEqual(
      (empirical_min - output_bounds.lower).min(), 0.,
      'Invalid lower bound for mix of batched/unbatched '
      'input bounds.')
def test_softplus_ibp(self):
  """Checks IBP on softplus maps box endpoints to logaddexp(endpoint, 0)."""

  def softplus_model(inp):
    return jax.nn.softplus(inp)

  z = jnp.array([[-2., 3.]])
  lower_in, upper_in = z - 1., z + 1.
  output_bounds = jax_verify.interval_bound_propagation(
      softplus_model, jax_verify.IntervalBound(lower_in, upper_in))
  self.assertArrayAlmostEqual(jnp.logaddexp(lower_in, 0), output_bounds.lower)
  self.assertArrayAlmostEqual(jnp.logaddexp(upper_in, 0), output_bounds.upper)
def test_keywords_argument(self):
  """Checks IBP through a model whose static flag is passed as `use_2=True`."""

  def _forward(inputs, use_2=False):
    return StaticArgumentModel()(inputs, use_2)

  forward = hk.without_apply_rng(hk.transform(_forward))

  z = jnp.array([[1., 2., 3.]])
  params = forward.init(jax.random.PRNGKey(1), z, use_2=True)
  fun_to_prop = functools.partial(forward.apply, params, use_2=True)
  output_bounds = jax_verify.interval_bound_propagation(
      fun_to_prop, jax_verify.IntervalBound(z - 1.0, z + 1.0))
  self.assertTrue((output_bounds.upper >= output_bounds.lower).all())
def test_multiply_crown(self):
  """Checks backward-CROWN bounds on elementwise multiply contain samples.

  Both operands get random (4, 7) bounds. Bounds are compared against the
  elementwise empirical min/max over 100 sampled operand pairs.
  """

  def multiply_model(lhs, rhs):
    return lhs * rhs

  mul_inp_shape = (4, 7)
  lhs_lb, lhs_ub = test_utils.sample_bounds(
      jax.random.PRNGKey(0), mul_inp_shape, minval=-10., maxval=10.)
  rhs_lb, rhs_ub = test_utils.sample_bounds(
      jax.random.PRNGKey(1), mul_inp_shape, minval=-10., maxval=10.)
  lhs_bounds = jax_verify.IntervalBound(lhs_lb, lhs_ub)
  rhs_bounds = jax_verify.IntervalBound(rhs_lb, rhs_ub)
  output_bounds = jax_verify.backward_crown_bound_propagation(
      multiply_model, lhs_bounds, rhs_bounds)

  uniform_lhs_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(2), (lhs_lb, lhs_ub), 100)
  uniform_rhs_inps = test_utils.sample_bounded_points(
      jax.random.PRNGKey(3), (rhs_lb, rhs_ub), 100)

  uniform_outs = jax.vmap(multiply_model)(uniform_lhs_inps, uniform_rhs_inps)
  empirical_min = uniform_outs.min(axis=0)
  empirical_max = uniform_outs.max(axis=0)
  self.assertGreaterEqual(
      (output_bounds.upper - empirical_max).min(), 0.,
      'Invalid upper bound for Multiply. The gap '
      'between upper bound and empirical max is negative')
  # Implicitly concatenated literals: keep the trailing space so the message
  # reads "gap between".
  self.assertGreaterEqual(
      (empirical_min - output_bounds.lower).min(), 0.,
      'Invalid lower bound for Multiply. The gap '
      'between emp. min and lower bound is negative.')
def test_staticargument_last(self):
  """Checks IBP through a model whose static flag is the last positional arg."""

  def _forward(inputs, use_2):
    return StaticArgumentModel()(inputs, use_2)

  forward = hk.without_apply_rng(hk.transform(_forward))

  z = jnp.array([[1., 2., 3.]])
  params = forward.init(jax.random.PRNGKey(1), z, True)

  def fun_to_prop(inputs):
    return forward.apply(params, inputs, True)

  output_bounds = jax_verify.interval_bound_propagation(
      fun_to_prop, jax_verify.IntervalBound(z - 1.0, z + 1.0))
  self.assertTrue((output_bounds.upper >= output_bounds.lower).all())