コード例 #1
0
    def test_conv2d_crown(self):
        """Backward CROWN on a single 2x2 VALID convolution gives exact bounds.

        The 2x2 input is covered once by the all-ones 2x2 kernel, so the single
        output equals sum(input) + bias. With every pixel of [1, 2, 3, 4]
        perturbed by +/-1 the bounds are (10 - 4) + 2 = 8 and (10 + 4) + 2 = 16.
        """
        @hk.without_apply_rng
        @hk.transform
        def conv2d_model(inp):
            conv = hk.Conv2D(output_channels=1,
                             kernel_shape=(2, 2),
                             padding='VALID',
                             stride=1,
                             with_bias=True)
            return conv(inp)

        # NHWC layout: one image, 2x2 spatial, one channel.
        centre = jnp.reshape(jnp.array([1., 2., 3., 4.]), [1, 2, 2, 1])

        conv_params = {
            'w': jnp.ones((2, 2, 1, 1), dtype=jnp.float32),
            'b': jnp.array([2.]),
        }
        fun = functools.partial(conv2d_model.apply, {'conv2_d': conv_params})

        # Standard interval bounds: every input pixel perturbed by +/-1.
        input_bounds = jax_verify.IntervalBound(centre - 1., centre + 1.)
        output_bounds = jax_verify.backward_crown_bound_propagation(
            fun, input_bounds)

        self.assertArrayAlmostEqual(8., output_bounds.lower)
        self.assertArrayAlmostEqual(16., output_bounds.upper)
コード例 #2
0
    def test_exp_crown(self):
        """Backward CROWN bounds for exp must enclose sampled outputs.

        Soundness is checked empirically: points drawn inside the input box by
        `test_utils.sample_bounded_points` are pushed through the model, and the
        CROWN upper/lower bounds must dominate the observed max/min.
        """
        def exp_model(inp):
            return jnp.exp(inp)

        exp_inp_shape = (4, 7)
        lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                          exp_inp_shape,
                                          minval=-10.,
                                          maxval=10.)

        input_bounds = jax_verify.IntervalBound(lb, ub)
        output_bounds = jax_verify.backward_crown_bound_propagation(
            exp_model, input_bounds)

        # 100 sample points; the leading axis indexes the samples, hence vmap
        # and the axis=0 reductions.
        uniform_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(1), (lb, ub), 100)
        uniform_outs = jax.vmap(exp_model)(uniform_inps)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)

        # Adjacent string literals are concatenated implicitly, so each
        # fragment must carry its own separating space.
        self.assertGreaterEqual(
            (output_bounds.upper - empirical_max).min(), 0.,
            'Invalid upper bound for Exponential. The gap '
            'between upper bound and empirical max is < 0')
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for Exponential. The gap '
            'between emp. min and lower bound is negative.')
コード例 #3
0
ファイル: model_zoo_test.py プロジェクト: deepmind/jax_verify
    def test_backward_crown(self, model_cls):
        """Smoke test: backward CROWN runs end to end on a model-zoo network.

        No numeric property of the bounds is asserted; the test passes as long
        as bound propagation completes without raising.
        """
        @hk.transform_with_state
        def model_pred(inputs, is_training, test_local_stats=False):
            return model_cls()(inputs, is_training, test_local_stats)

        # Four all-zero 28x28 single-channel images.
        inps = jnp.zeros((4, 28, 28, 1), dtype=jnp.float32)
        params, state = model_pred.init(jax.random.PRNGKey(42),
                                        inps,
                                        is_training=True)

        def logits_fun(inputs):
            # Positional args: params, state, rng (None), inputs,
            # is_training=False. apply returns (output, new_state); only the
            # output is kept.
            logits, _ = model_pred.apply(params, state, None, inputs, False,
                                         test_local_stats=False)
            return logits

        input_bounds = jax_verify.IntervalBound(inps - 1.0, inps + 1.0)
        jax_verify.backward_crown_bound_propagation(logits_fun, input_bounds)
コード例 #4
0
    def test_relu_crown(self):
        """Backward CROWN on ReLU with one inactive and one active unit.

        Centres [-2, 3] with radius 1 give input boxes [-3, -1] (never active,
        so the output is exactly 0) and [2, 4] (always active, so the output
        equals the input).
        """
        def relu_model(inp):
            return jax.nn.relu(inp)

        centre = jnp.array([[-2., 3.]])
        radius = 1.
        input_bounds = jax_verify.IntervalBound(centre - radius,
                                                centre + radius)
        output_bounds = jax_verify.backward_crown_bound_propagation(
            relu_model, input_bounds)

        expected_lower = jnp.array([[0., 2.]])
        expected_upper = jnp.array([[0., 4.]])
        self.assertArrayAlmostEqual(expected_lower, output_bounds.lower)
        self.assertArrayAlmostEqual(expected_upper, output_bounds.upper)
コード例 #5
0
ファイル: bounding.py プロジェクト: deepmind/jax_verify
def _compute_jv_bounds(
    input_bound: sdp_utils.IntBound,
    params: ModelParams,
    method: str,
) -> List[sdp_utils.IntBound]:
  """Compute intermediate activation bounds with jax_verify.

  Args:
    input_bound: Bounds on the network input; only its `lb` and `ub` fields
      are read.
    params: Model parameters. Every entry with `has_bounds` set contributes
      its `(w_bound, b_bound)` pair as an extra bounded input to the
      propagation.
    method: Name of the jax_verify propagation algorithm: one of 'ibp',
      'fastlin', 'ibpfastlin', 'crown' or 'nonconvex'.

  Returns:
    One `sdp_utils.IntBound` per bound returned by jax_verify. `lb_pre` and
    `ub_pre` hold the raw bounds; `lb` and `ub` hold those bounds clamped
    below at zero (presumably post-ReLU activation bounds; confirm against
    `_make_all_act_fn`).

  Raises:
    ValueError: If `method` is not one of the supported names.
  """
  # Supported method names -> jax_verify propagation routine.
  propagators = {
      'ibp': jax_verify.interval_bound_propagation,
      'fastlin': jax_verify.forward_fastlin_bound_propagation,
      'ibpfastlin': jax_verify.ibpforwardfastlin_bound_propagation,
      'crown': jax_verify.backward_crown_bound_propagation,
      'nonconvex': jax_verify.nonconvex_constopt_bound_propagation,
  }
  # Validate before doing any work, and name the offending value.
  if method not in propagators:
    raise ValueError(
        'Unsupported method: {!r}. Expected one of {}.'.format(
            method, sorted(propagators)))

  jv_input_bound = jax_verify.IntervalBound(input_bound.lb, input_bound.ub)

  # Create a function that takes as arguments the input and all parameters
  # that have bounds (as specified in param_bounds) and returns all
  # activations.
  all_act_fun = _make_all_act_fn(params)

  # Bounded parameters are passed as additional bounded inputs, so jax_verify
  # performs (bilinear) bound propagation.
  jv_param_bounds = [(p.w_bound, p.b_bound) for p in params if p.has_bounds]

  # The propagator returns a pair. Only its second element, the iterable of
  # bounds re-formatted below, is used.
  _, jv_bounds = propagators[method](
      all_act_fun, jv_input_bound, *jv_param_bounds)

  # Re-format bounds with the internal convention.
  return [
      sdp_utils.IntBound(lb_pre=bound.lower,
                         ub_pre=bound.upper,
                         lb=jnp.maximum(bound.lower, 0),
                         ub=jnp.maximum(bound.upper, 0))
      for bound in jv_bounds
  ]
コード例 #6
0
    def test_fc_crown(self):
        """Backward CROWN on a single linear layer gives the exact interval.

        With all-ones weights and bias 2 the output is sum(inp) + 2, so a +/-1
        box around [1, 2, 3] yields [(0+1+2) + 2, (2+3+4) + 2] = [5, 11].
        """
        @hk.without_apply_rng
        @hk.transform
        def linear_model(inp):
            layer = hk.Linear(1)
            return layer(inp)

        centre = jnp.array([[1., 2., 3.]])
        linear_params = dict(w=jnp.ones((3, 1), dtype=jnp.float32),
                             b=jnp.array([2.]))
        fun = functools.partial(linear_model.apply, {'linear': linear_params})

        # Standard interval bounds: every input perturbed by +/-1.
        input_bounds = jax_verify.IntervalBound(centre - 1., centre + 1.)
        output_bounds = jax_verify.backward_crown_bound_propagation(
            fun, input_bounds)

        self.assertArrayAlmostEqual(5., output_bounds.lower)
        self.assertArrayAlmostEqual(11., output_bounds.upper)
コード例 #7
0
    def test_multiply_crown(self):
        """Backward CROWN bounds for elementwise multiply enclose sampled outputs.

        Both operands get independent boxes from `test_utils.sample_bounds`
        (minval=-10, maxval=10). Sample points from each box are multiplied
        pairwise, and the CROWN upper/lower bounds must dominate the observed
        max/min.
        """
        def multiply_model(lhs, rhs):
            return lhs * rhs

        mul_inp_shape = (4, 7)
        lhs_lb, lhs_ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                                  mul_inp_shape,
                                                  minval=-10.,
                                                  maxval=10.)
        rhs_lb, rhs_ub = test_utils.sample_bounds(jax.random.PRNGKey(1),
                                                  mul_inp_shape,
                                                  minval=-10.,
                                                  maxval=10.)

        lhs_bounds = jax_verify.IntervalBound(lhs_lb, lhs_ub)
        rhs_bounds = jax_verify.IntervalBound(rhs_lb, rhs_ub)
        output_bounds = jax_verify.backward_crown_bound_propagation(
            multiply_model, lhs_bounds, rhs_bounds)

        # Separate PRNG keys so lhs and rhs samples are drawn independently;
        # the leading axis indexes the 100 samples.
        uniform_lhs_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(2), (lhs_lb, lhs_ub), 100)
        uniform_rhs_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(3), (rhs_lb, rhs_ub), 100)

        uniform_outs = jax.vmap(multiply_model)(uniform_lhs_inps,
                                                uniform_rhs_inps)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)

        # Adjacent string literals are concatenated implicitly, so each
        # fragment must carry its own separating space.
        self.assertGreaterEqual(
            (output_bounds.upper - empirical_max).min(), 0.,
            'Invalid upper bound for Multiply. The gap '
            'between upper bound and empirical max is negative')
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for Multiply. The gap '
            'between emp. min and lower bound is negative.')