Exemple #1
0
    def test_conv2d_ibp(self):
        """IBP through a single 2x2 VALID Conv2D with an all-ones kernel.

        The 2x2 input exactly covers the kernel, so the single output is the
        sum of the four inputs plus the bias 2. Widening each input by -1/+1
        gives input sums 6 and 14, hence output bounds 8 and 16.
        """
        def conv2d_model(inp):
            conv = hk.Conv2D(output_channels=1,
                             kernel_shape=(2, 2),
                             padding='VALID',
                             stride=1,
                             with_bias=True)
            return conv(inp)

        center = jnp.reshape(jnp.array([1., 2., 3., 4.]), [1, 2, 2, 1])

        kernel = jnp.ones((2, 2, 1, 1), dtype=jnp.float32)
        bias = jnp.array([2.])
        params = {'conv2_d': {'w': kernel, 'b': bias}}

        apply_fn = hk.without_apply_rng(hk.transform(conv2d_model)).apply
        fun = functools.partial(apply_fn, params)

        input_bounds = jax_verify.IntervalBound(center - 1., center + 1.)
        output_bounds = jax_verify.interval_bound_propagation(fun, input_bounds)

        self.assertAlmostEqual(8., output_bounds.lower)
        self.assertAlmostEqual(16., output_bounds.upper)
Exemple #2
0
def _get_reciprocal_bound(l: jnp.array, u: jnp.array,
                          logits_params: LayerParams, label: int) -> jnp.array:
    """Helper for bounding logit differences given interval bounds on pre-logits.

    Propagates interval bounds through
    ``x @ (w[:, label:label+1] - w) + (b[label] - b)``. Column ``j`` of the
    result is the difference between the ``label`` logit and logit ``j`` of
    the final linear layer.

    Args:
      l: Lower bound on the pre-logit activations. Reshaped to
        ``[l.shape[0], -1]`` (presumably a leading batch axis - confirm).
      u: Upper bound on the pre-logit activations, same shape as ``l``.
      logits_params: Final-layer parameters. ``w``/``b`` are used as constants
        unless ``w_bound``/``b_bound`` is set. A set bound makes that
        parameter an additional interval-bounded input.
      label: Index of the logit that every other logit is subtracted from.

    Returns:
      The bound object returned by ``jax_verify.interval_bound_propagation``.
    """
    w_is_bounded = logits_params.w_bound is not None
    b_is_bounded = logits_params.b_bound is not None

    def fwd(x, *bounded_params):
        # Bounded parameters arrive positionally, in the same order (w, then b)
        # as they are appended to ``params_bounds`` below. Unbounded ones are
        # closed over as constants. Binding constants with
        # functools.partial(fwd, w=...) does not work: when only b is bounded,
        # its value would land in w's positional slot and raise a TypeError.
        remaining = list(bounded_params)
        w = remaining.pop(0) if w_is_bounded else logits_params.w
        b = remaining.pop(0) if b_is_bounded else logits_params.b
        wdiff = jnp.reshape(w[:, label], [-1, 1]) - w
        bdiff = b[label] - b
        return x @ wdiff + bdiff

    x_bound = jax_verify.IntervalBound(
        lower_bound=jnp.reshape(l, [l.shape[0], -1]),
        upper_bound=jnp.reshape(u, [u.shape[0], -1]))

    params_bounds = []
    if w_is_bounded:
        params_bounds.append(logits_params.w_bound)
    if b_is_bounded:
        params_bounds.append(logits_params.b_bound)

    return jax_verify.interval_bound_propagation(fwd, x_bound, *params_bounds)
Exemple #3
0
 def test_sqrt(self, input_bounds, expected):
     """IBP through jnp.sqrt on a two-element input.

     The first element uses the parameterized (lower, upper) interval. The
     second is pinned to the degenerate interval [0, 0], and its output bounds
     must be exactly 0.
     """
     lower_in, upper_in = input_bounds[0], input_bounds[1]
     bounds = jax_verify.IntervalBound(np.array([lower_in, 0.0]),
                                       np.array([upper_in, 0.0]))
     output_bounds = jax_verify.interval_bound_propagation(jnp.sqrt, bounds)

     expected_lower = np.array([expected[0], 0.0])
     np.testing.assert_array_equal(expected_lower, output_bounds.lower)
     expected_upper = np.array([expected[1], 0.0])
     np.testing.assert_array_equal(expected_upper, output_bounds.upper)
Exemple #4
0
    def test_model_structure_nostate(self, model):
        """IBP through a stateless model yields ordered bounds (upper >= lower)."""
        z = jnp.array([[1., 2., 3.]])

        params = model.init(jax.random.PRNGKey(1), z)
        input_bounds = jax_verify.IntervalBound(z - 1.0, z + 1.0)
        fun_to_prop = functools.partial(model.apply, params)
        output_bounds = jax_verify.interval_bound_propagation(
            fun_to_prop, input_bounds)
        # Reduce over every element with ``.all()``. The builtin ``all()`` only
        # iterates the leading axis, and ``bool`` of a multi-element row raises
        # instead of checking it.
        self.assertTrue((output_bounds.upper >= output_bounds.lower).all())
Exemple #5
0
    def test_passthrough_primitive(self, fn, inputs):
        """IBP through ``fn`` equals ``fn`` applied to each input bound.

        This only holds for elementwise non-decreasing ``fn``. The
        parameterized cases presumably supply only such functions.
        """
        center = jnp.array(inputs)
        bounds_in = jax_verify.IntervalBound(center - 1., center + 1.)
        bounds_out = jax_verify.interval_bound_propagation(fn, bounds_in)

        expected_lower = fn(bounds_in.lower)
        self.assertArrayAlmostEqual(expected_lower, bounds_out.lower)
        expected_upper = fn(bounds_in.upper)
        self.assertArrayAlmostEqual(expected_upper, bounds_out.upper)
Exemple #6
0
    def test_multioutput_model(self):
        """IBP on a multi-output model returns one bound per output (7 here)."""
        inp = jnp.array([[1., 2., 3.]])

        transformed = hk.without_apply_rng(
            hk.transform(residual_model_all_act))
        params = transformed.init(jax.random.PRNGKey(1), inp)

        bounds = jax_verify.IntervalBound(inp - 1.0, inp + 1.0)
        propagated = functools.partial(transformed.apply, params)
        output_bounds = jax_verify.interval_bound_propagation(
            propagated, bounds)

        self.assertLen(output_bounds, 7)
Exemple #7
0
def main(unused_args):
    """Compare boundprop and LP-relaxation bounds on one MNIST model output.

    Loads the model selected by ``FLAGS.model`` and bounds its outputs over an
    L-infinity box of radius ``eps`` (clipped to [0, 1]) around the first two
    MNIST test samples. It then logs the following for output 0 of the first
    sample:
      * lower/upper bounds from ``jax_verify.solve_planet_relaxation``, and
      * lower/upper bounds from the boundprop method selected by
        ``FLAGS.boundprop_method`` ('fastlin' or 'ibp').

    Raises:
      NotImplementedError: If ``FLAGS.boundprop_method`` is not supported.
    """
    # Load the parameters of an existing model.
    model_pred, params = load_model(FLAGS.model)
    logits_fn = functools.partial(model_pred, params)

    # Load some test samples.
    with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
        inputs = np.load(f)

    # Input box: eps-ball around the first two samples, clipped to [0, 1].
    eps = 0.1
    samples = inputs[:2, ...]
    lower_bound = jnp.minimum(jnp.maximum(samples - eps, 0.0), 1.0)
    upper_bound = jnp.minimum(jnp.maximum(samples + eps, 0.0), 1.0)
    init_bound = jax_verify.IntervalBound(lower_bound, upper_bound)

    # Boundprop bounds, plus the matching transform that is handed to the
    # LP solver below.
    if FLAGS.boundprop_method == 'fastlin':
        final_bound = jax_verify.fastlin_bound_propagation(
            logits_fn, init_bound)
        boundprop_transform = jax_verify.fastlin_transform
    elif FLAGS.boundprop_method == 'ibp':
        final_bound = jax_verify.interval_bound_propagation(
            logits_fn, init_bound)
        boundprop_transform = jax_verify.ibp_transform
    else:
        # Trailing space is required: adjacent string literals concatenate
        # verbatim.
        raise NotImplementedError('Only ibp/fastlin boundprop are '
                                  'currently supported')

    # Forward pass used only to shape the objective below.
    dummy_output = model_pred(params, inputs)

    # Run LP solver. The objective is one-hot on flat output index 0 of a
    # single sample. Comparing a 1-D arange against the output's shape
    # assumes the per-sample output is 1-D.
    objective = jnp.where(
        jnp.arange(dummy_output[0, ...].size) == 0,
        jnp.ones_like(dummy_output[0, ...]),
        jnp.zeros_like(dummy_output[0, ...]))
    objective_bias = 0.
    # The solver value for the objective is reported as the lower bound...
    value, status = jax_verify.solve_planet_relaxation(logits_fn,
                                                       init_bound,
                                                       boundprop_transform,
                                                       objective,
                                                       objective_bias,
                                                       index=0)
    logging.info('Relaxation LB is : %f, Status is %s', value, status)
    # ...and the negated value for the negated objective as the upper bound.
    value, status = jax_verify.solve_planet_relaxation(logits_fn,
                                                       init_bound,
                                                       boundprop_transform,
                                                       -objective,
                                                       objective_bias,
                                                       index=0)
    logging.info('Relaxation UB is : %f, Status is %s', -value, status)

    logging.info('Boundprop LB is : %f', final_bound.lower[0, 0])
    logging.info('Boundprop UB is : %f', final_bound.upper[0, 0])
    def test_ibp(self, model_cls):
        """IBP runs through ``model_cls`` in inference mode with ordered bounds."""
        @hk.transform_with_state
        def model_pred(inputs, is_training, test_local_stats=False):
            model = model_cls()
            return model(inputs, is_training, test_local_stats)

        inps = jnp.zeros((4, 28, 28, 1), dtype=jnp.float32)
        params, state = model_pred.init(jax.random.PRNGKey(42),
                                        inps,
                                        is_training=True)

        def logits_fun(inputs):
            # Inference mode (is_training=False) with the initial state. Only
            # the first element of apply's result (the network output) is
            # bounded.
            return model_pred.apply(params,
                                    state,
                                    None,
                                    inputs,
                                    False,
                                    test_local_stats=False)[0]

        input_bounds = jax_verify.IntervalBound(inps - 1.0, inps + 1.0)
        output_bounds = jax_verify.interval_bound_propagation(
            logits_fun, input_bounds)
        # Check more than "does not crash": sound interval bounds are ordered.
        self.assertTrue((output_bounds.upper >= output_bounds.lower).all())
Exemple #9
0
    def test_relu_ibp(self):
        """IBP through ReLU clamps negative bounds to zero.

        Inputs -2 and 3 widened by +/-1 give intervals [-3, -1] and [2, 4].
        ReLU maps them to [0, 0] and [2, 4].
        """
        center = jnp.array([[-2., 3.]])
        bounds_in = jax_verify.IntervalBound(center - 1., center + 1.)

        bounds_out = jax_verify.interval_bound_propagation(
            lambda inp: jax.nn.relu(inp), bounds_in)

        self.assertArrayAlmostEqual(jnp.array([[0., 2.]]), bounds_out.lower)
        self.assertArrayAlmostEqual(jnp.array([[0., 4.]]), bounds_out.upper)
Exemple #10
0
    def test_keywords_argument(self):
        """A static flag passed by keyword can be bound before propagation."""
        @hk.without_apply_rng
        @hk.transform
        def forward(inputs, use_2=False):
            return StaticArgumentModel()(inputs, use_2)

        inp = jnp.array([[1., 2., 3.]])
        params = forward.init(jax.random.PRNGKey(1), inp, use_2=True)

        input_bounds = jax_verify.IntervalBound(inp - 1.0, inp + 1.0)
        # Bind params and the keyword flag so that only the input is bounded.
        bound_apply = functools.partial(forward.apply, params, use_2=True)
        output_bounds = jax_verify.interval_bound_propagation(
            bound_apply, input_bounds)

        self.assertTrue((output_bounds.upper >= output_bounds.lower).all())
Exemple #11
0
    def test_softplus_ibp(self):
        """IBP through softplus matches softplus applied to each input bound.

        softplus(x) = logaddexp(x, 0) is increasing, so the expected output
        bounds are softplus(center - 1) and softplus(center + 1).
        """
        def softplus_model(inp):
            return jax.nn.softplus(inp)

        center = jnp.array([[-2., 3.]])
        lower_in = center - 1.
        upper_in = center + 1.

        output_bounds = jax_verify.interval_bound_propagation(
            softplus_model, jax_verify.IntervalBound(lower_in, upper_in))

        self.assertArrayAlmostEqual(jnp.logaddexp(lower_in, 0),
                                    output_bounds.lower)
        self.assertArrayAlmostEqual(jnp.logaddexp(upper_in, 0),
                                    output_bounds.upper)
Exemple #12
0
    def test_staticargument_last(self):
        """A trailing positional static argument can be closed over."""
        @hk.without_apply_rng
        @hk.transform
        def forward(inputs, use_2):
            return StaticArgumentModel()(inputs, use_2)

        inp = jnp.array([[1., 2., 3.]])
        params = forward.init(jax.random.PRNGKey(1), inp, True)
        input_bounds = jax_verify.IntervalBound(inp - 1.0, inp + 1.0)

        def apply_with_flag(inputs):
            # use_2=True is fixed positionally; only ``inputs`` is bounded.
            return forward.apply(params, inputs, True)

        output_bounds = jax_verify.interval_bound_propagation(
            apply_with_flag, input_bounds)
        self.assertTrue((output_bounds.upper >= output_bounds.lower).all())
Exemple #13
0
def _compute_jv_bounds(
    input_bound: sdp_utils.IntBound,
    params: ModelParams,
    method: str,
) -> List[sdp_utils.IntBound]:
    """Compute intermediate bounds on the model with jax_verify.

    Args:
      input_bound: Interval on the model input (its ``lb``/``ub`` are used).
      params: Model parameters. Entries with ``has_bounds`` contribute their
        ``(w_bound, b_bound)`` pair as additional bounded inputs.
      method: One of 'ibp', 'fastlin', 'ibpfastlin', 'crown', 'nonconvex'.

    Returns:
      One ``sdp_utils.IntBound`` per bound in the second output of the
      propagation (presumably one per activation). Each holds the propagated
      bounds as ``lb_pre``/``ub_pre`` and their ReLU images as ``lb``/``ub``.

    Raises:
      ValueError: If ``method`` is not one of the supported names.
    """
    jv_input_bound = jax_verify.IntervalBound(input_bound.lb, input_bound.ub)

    # Function of the input and of every bounded parameter that returns all
    # activations.
    all_act_fun = _make_all_act_fn(params)

    # Bounded parameters are propagated alongside the input (bilinear IBP).
    jv_param_bounds = [(p.w_bound, p.b_bound) for p in params if p.has_bounds]

    # Method -> jax_verify entry-point name. Names are resolved lazily, so
    # only the selected function is looked up.
    method_to_fn_name = (
        ('ibp', 'interval_bound_propagation'),
        ('fastlin', 'forward_fastlin_bound_propagation'),
        ('ibpfastlin', 'ibpforwardfastlin_bound_propagation'),
        ('crown', 'backward_crown_bound_propagation'),
        ('nonconvex', 'nonconvex_constopt_bound_propagation'),
    )
    fn_name = next(
        (name for key, name in method_to_fn_name if method == key), None)
    if fn_name is None:
        raise ValueError('Unsupported method.')

    propagate = getattr(jax_verify, fn_name)
    _, jv_bounds = propagate(all_act_fun, jv_input_bound, *jv_param_bounds)

    # Internal convention: post-activation bounds are the ReLU of the
    # propagated pre-activation bounds.
    return [
        sdp_utils.IntBound(lb_pre=bound.lower,
                           ub_pre=bound.upper,
                           lb=jnp.maximum(bound.lower, 0),
                           ub=jnp.maximum(bound.upper, 0))
        for bound in jv_bounds
    ]
Exemple #14
0
    def test_stateful_model(self):
        """IBP through a stateful Haiku model evaluated in inference mode."""
        @hk.transform_with_state
        def forward(inputs, is_training, test_local_stats=False):
            return ModelWithState()(inputs, is_training, test_local_stats)

        inp = jnp.array([[1., 2., 3.]])
        params, state = forward.init(jax.random.PRNGKey(1), inp, True, False)

        def fun_to_prop(inputs):
            # Params, state, rng and both flags are static. Keep only the
            # network output (index 0) and drop everything else apply returns.
            return forward.apply(params, state, jax.random.PRNGKey(1), inputs,
                                 False, False)[0]

        input_bounds = jax_verify.IntervalBound(inp - 1.0, inp + 1.0)
        output_bounds = jax_verify.interval_bound_propagation(
            fun_to_prop, input_bounds)
        self.assertTrue((output_bounds.upper >= output_bounds.lower).all())
Exemple #15
0
    def test_linear_ibp(self):
        """IBP through hk.Linear(1) with all-ones weights and bias 2.

        Inputs [1, 2, 3] widened by -1/+1 sum to 3 and 9, so the output
        bounds are 3 + 2 = 5 and 9 + 2 = 11.
        """
        def linear_model(inp):
            layer = hk.Linear(1)
            return layer(inp)

        center = jnp.array([[1., 2., 3.]])
        weights = jnp.ones((3, 1), dtype=jnp.float32)
        bias = jnp.array([2.])
        params = {'linear': {'w': weights, 'b': bias}}

        apply_fn = hk.without_apply_rng(hk.transform(linear_model)).apply
        fun = functools.partial(apply_fn, params)

        input_bounds = jax_verify.IntervalBound(center - 1., center + 1.)
        output_bounds = jax_verify.interval_bound_propagation(fun, input_bounds)

        self.assertAlmostEqual(5., output_bounds.lower)
        self.assertAlmostEqual(11., output_bounds.upper)
Exemple #16
0
    def test_tight_bounds_nostate(self, model):
        """Degenerate input bounds (lower == upper) reproduce the forward pass."""
        import numpy as np  # Local import: used only for tolerant comparisons.

        z = jnp.array([[1., 2., 3.]])

        params = model.init(jax.random.PRNGKey(1), z)

        tight_input_bounds = jax_verify.IntervalBound(z, z)
        fun_to_prop = functools.partial(model.apply, params)
        tight_output_bounds = jax_verify.interval_bound_propagation(
            fun_to_prop, tight_input_bounds)

        model_eval = model.apply(params, z)

        # Because the input lower bound equals the input upper bound, the
        # output bounds should coincide with each other and with the forward
        # pass. Compare with a tolerance. assertAlmostEqual on (nested) lists
        # passes only on exact equality, and otherwise raises TypeError because
        # it cannot subtract lists. Bound propagation may also order float32
        # arithmetic differently from a plain forward pass.
        np.testing.assert_allclose(tight_output_bounds.lower,
                                   tight_output_bounds.upper,
                                   rtol=1e-5, atol=1e-6)
        np.testing.assert_allclose(tight_output_bounds.lower, model_eval,
                                   rtol=1e-5, atol=1e-6)
Exemple #17
0
def elide_adversarial_spec(
    params: ModelParams,
    data_spec: DataSpec,
) -> ModelParamsElided:
    """Elide params so the last layer is merged with the adversarial objective.

    The objective ``logit[target_label] - logit[true_label]`` is folded into
    the final layer. This produces a single-output layer with weight
    ``w @ (e_target - e_true)`` (reshaped to a column) and bias
    ``<e_target - e_true, b>``. Using ``w.shape[-1]`` as the number of
    classes suggests the last layer computes ``x @ w + b`` with ``w`` of
    shape (in, out). TODO: confirm.

    Args:
      params: parameters of the model under verification.
      data_spec: data specification (its ``true_label`` and ``target_label``
        are used).

    Returns:
      params_elided: elided parameters with the adversarial objective folded
        into the last layer (and bounds adapted accordingly).
    """
    def elide_fn(w_fin, b_fin):
        num_classes = w_fin.shape[-1]
        label_onehot = jnp.eye(num_classes)[data_spec.true_label]
        target_onehot = jnp.eye(num_classes)[data_spec.target_label]
        # Direction in logit space: target logit minus true-label logit.
        obj_orig = target_onehot - label_onehot
        # Pull the objective back through the final weights.
        obj_bp = jnp.matmul(w_fin, obj_orig)
        weight = jnp.reshape(obj_bp, (obj_bp.size, 1))
        bias = jnp.expand_dims(jnp.vdot(obj_orig, b_fin), axis=-1)
        return weight, bias

    last_params = params[-1]
    w_elided, b_elided = elide_fn(last_params.w, last_params.b)
    last_params_elided = verify_utils.FCParams(w_elided, b_elided)

    if last_params.has_bounds:
        # Push the parameter intervals through the same folding.
        w_bound_elided, b_bound_elided = jax_verify.interval_bound_propagation(
            elide_fn, last_params.w_bound, last_params.b_bound)
        last_params_elided = dataclasses.replace(
            last_params_elided, w_bound=w_bound_elided, b_bound=b_bound_elided)

    return params[:-1] + [last_params_elided]
Exemple #18
0
    def test_jittable_input_bounds(self):
        """Bound propagation can be jitted when given jittable input bounds.

        Also checks that the jitted result matches the non-jitted one.
        """
        model = sequential_model
        z = jnp.array([[1., 2., 3.]])
        params = model.init(jax.random.PRNGKey(1), z)
        fun_to_prop = functools.partial(model.apply, params)

        non_jittable_bounds = jax_verify.IntervalBound(z - 1.0, z + 1.0)
        jittable_input_bounds = non_jittable_bounds.to_jittable()

        @jax.jit
        def bound_prop_fun(inp_bound):
            bounds = jax_verify.interval_bound_propagation(
                fun_to_prop, inp_bound)
            return bounds.lower, bounds.upper

        # Check that we can jit the bound prop and pass in jittable bounds.
        out_lb, out_ub = bound_prop_fun(jittable_input_bounds)
        # ``.all()`` checks every element. The builtin ``all()`` only iterates
        # the leading axis and raises on multi-element rows.
        self.assertTrue((out_ub >= out_lb).all())

        # Check that this gives the same result as without the jit.
        bounds = jax_verify.interval_bound_propagation(fun_to_prop,
                                                       non_jittable_bounds)
        chex.assert_trees_all_close(out_lb, bounds.lower)
        chex.assert_trees_all_close(out_ub, bounds.upper)
Exemple #19
0
 def _compute_bounds(lower, upper):
     """Return IBP (lower, upper) bounds of ``x ** exponent`` over [lower, upper].

     ``exponent`` is not defined here. It is presumably a variable of the
     enclosing scope, read at call time.
     """
     def power_fn(x):
         return x**exponent

     box = jax_verify.IntervalBound(lower, upper)
     propagated = jax_verify.interval_bound_propagation(power_fn, box)
     return propagated.lower, propagated.upper
Exemple #20
0
 def bound_prop_fun(inp_bound):
     """Propagate ``inp_bound`` through ``fun_to_prop`` with IBP.

     ``fun_to_prop`` comes from the enclosing scope. Returns the output
     bounds as a ``(lower, upper)`` tuple.
     """
     out = jax_verify.interval_bound_propagation(fun_to_prop, inp_bound)
     lower, upper = out.lower, out.upper
     return lower, upper