def test_conv1d_fastlin(self):
        """Check IBP+forward-Fastlin bounds through a single `hk.Conv1D` layer.

        The layer has one output channel, a width-2 kernel of ones and a
        bias of 2. The input interval is `[3, 4] +/- 1`, and the test asserts
        that the propagated output bounds are 7 (lower) and 11 (upper).
        """
        @hk.without_apply_rng
        @hk.transform
        def conv1d_model(inp):
            layer = hk.Conv1D(
                output_channels=1,
                kernel_shape=2,
                padding='VALID',
                stride=1,
                with_bias=True,
            )
            return layer(inp)

        # Nominal input, laid out as a (1, 2, 1) array.
        center = jnp.reshape(jnp.array([3., 4.]), [1, 2, 1])

        # Hand-written parameters so the expected bounds can be computed
        # by hand: all-ones kernel and a constant bias.
        kernel = jnp.ones((2, 1, 1), dtype=jnp.float32)
        bias = jnp.array([2.])
        params = {'conv1_d': {'w': kernel, 'b': bias}}

        def apply_fn(inp):
            return conv1d_model.apply(params, inp)

        lower_in = center - 1.
        upper_in = center + 1.
        input_bounds = jax_verify.IntervalBound(lower_in, upper_in)
        output_bounds = jax_verify.ibpforwardfastlin_bound_propagation(
            apply_fn, input_bounds)

        self.assertAlmostEqual(7., output_bounds.lower, delta=1e-5)
        self.assertAlmostEqual(11., output_bounds.upper, delta=1e-5)
    def test_relu_fastlin(self):
        """Check IBP+forward-Fastlin bounds through a lone ReLU.

        The input interval is `[[-2, 3]] +/- 1`; the test asserts output
        bounds of `[[0, 2]]` (lower) and `[[0, 4]]` (upper).
        """
        def relu_model(x):
            return jax.nn.relu(x)

        center = jnp.array([[-2., 3.]])
        expected_lower = jnp.array([[0., 2.]])
        expected_upper = jnp.array([[0., 4.]])

        input_bounds = jax_verify.IntervalBound(center - 1., center + 1.)
        output_bounds = jax_verify.ibpforwardfastlin_bound_propagation(
            relu_model, input_bounds)

        self.assertArrayAlmostEqual(expected_lower, output_bounds.lower)
        self.assertArrayAlmostEqual(expected_upper, output_bounds.upper)
Пример #3
0
    def test_ibpfastlin(self, model_cls):
        """Smoke-test IBP+forward-Fastlin propagation on a `model_cls` network.

        The model is initialised on a zero batch of shape (4, 28, 28, 1) and
        bounds are propagated for inputs within +/- 1 of that batch. No
        assertions are made on the result; the test only checks that the
        propagation runs.

        Args:
            model_cls: Zero-argument callable that builds the model. The
                model is called as `model(inputs, is_training,
                test_local_stats)`.
        """
        @hk.transform_with_state
        def forward(inputs, is_training, test_local_stats=False):
            return model_cls()(inputs, is_training, test_local_stats)

        batch = jnp.zeros((4, 28, 28, 1), dtype=jnp.float32)
        rng = jax.random.PRNGKey(42)
        params, state = forward.init(rng, batch, is_training=True)

        def logits_fun(inputs):
            # `apply` returns an (output, new_state) pair; only the output
            # is needed, and no RNG key is supplied.
            outputs, _ = forward.apply(
                params, state, None, inputs, False, test_local_stats=False)
            return outputs

        input_bounds = jax_verify.IntervalBound(batch - 1.0, batch + 1.0)
        jax_verify.ibpforwardfastlin_bound_propagation(
            logits_fun, input_bounds)
Пример #4
0
def _compute_jv_bounds(
    input_bound: sdp_utils.IntBound,
    params: ModelParams,
    method: str,
) -> List[sdp_utils.IntBound]:
    """Propagate bounds through the model with a jax_verify routine.

    Args:
        input_bound: Bound on the network input; its `lb` and `ub` fields
            are used to build the jax_verify input interval.
        params: Model parameters. Entries with `has_bounds` set contribute
            their `(w_bound, b_bound)` pair as extra bounded arguments.
        method: Which propagation routine to use: one of `'ibp'`,
            `'fastlin'`, `'ibpfastlin'`, `'crown'` or `'nonconvex'`.

    Returns:
        One `sdp_utils.IntBound` per intermediate bound returned by
        jax_verify. `lb_pre`/`ub_pre` hold the propagated lower/upper
        bounds; `lb`/`ub` hold the same values clamped below at 0.

    Raises:
        ValueError: If `method` is not one of the supported names.
    """
    interval = jax_verify.IntervalBound(input_bound.lb, input_bound.ub)

    # Function of the input and of every bounded parameter (in the order
    # they appear in `params`) that returns all activations.
    activations_fn = _make_all_act_fn(params)

    # Bounds for the parameters that are themselves uncertain; these are
    # passed to jax_verify as extra arguments after the input interval.
    bounded_params = [
        (layer.w_bound, layer.b_bound)
        for layer in params
        if layer.has_bounds
    ]

    # Pick the propagation routine first, then invoke it once below; all
    # of them are called with the same arguments.
    if method == 'ibp':
        propagate = jax_verify.interval_bound_propagation
    elif method == 'fastlin':
        propagate = jax_verify.forward_fastlin_bound_propagation
    elif method == 'ibpfastlin':
        propagate = jax_verify.ibpforwardfastlin_bound_propagation
    elif method == 'crown':
        propagate = jax_verify.backward_crown_bound_propagation
    elif method == 'nonconvex':
        propagate = jax_verify.nonconvex_constopt_bound_propagation
    else:
        raise ValueError('Unsupported method.')

    # The first element of the result is discarded; the second is the
    # sequence of intermediate bounds.
    _, raw_bounds = propagate(activations_fn, interval, *bounded_params)

    # Convert to the internal convention: keep the raw bounds as `*_pre`
    # and store the versions clamped below at zero as `lb`/`ub`.
    return [
        sdp_utils.IntBound(
            lb_pre=raw.lower,
            ub_pre=raw.upper,
            lb=jnp.maximum(raw.lower, 0),
            ub=jnp.maximum(raw.upper, 0),
        )
        for raw in raw_bounds
    ]
    def test_fc_fastlin(self):
        """Check IBP+forward-Fastlin bounds through a single `hk.Linear` layer.

        The layer maps 3 inputs to 1 output with all-ones weights and a
        bias of 2. The input interval is `[[1, 2, 3]] +/- 1`, and the test
        asserts that the output bounds are 5 (lower) and 11 (upper).
        """
        @hk.without_apply_rng
        @hk.transform
        def linear_model(inp):
            layer = hk.Linear(1)
            return layer(inp)

        # Hand-written parameters so the expected bounds can be computed
        # by hand.
        weights = jnp.ones((3, 1), dtype=jnp.float32)
        bias = jnp.array([2.])
        params = {'linear': {'w': weights, 'b': bias}}

        center = jnp.array([[1., 2., 3.]])
        input_bounds = jax_verify.IntervalBound(center - 1., center + 1.)

        def apply_fn(inp):
            return linear_model.apply(params, inp)

        output_bounds = jax_verify.ibpforwardfastlin_bound_propagation(
            apply_fn, input_bounds)

        self.assertAlmostEqual(5., output_bounds.lower)
        self.assertAlmostEqual(11., output_bounds.upper)