# Shared imports assumed by the snippets below (exact module paths may vary;
# `test_utils`, `forward_linear_bounds`, `sdp_utils` and `FLAGS` come from the
# surrounding projects):
import functools

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from absl import logging

import jax_verify

    def test_exp_fastlin(self):
        def exp_model(inp):
            return jnp.exp(inp)

        exp_inp_shape = (4, 7)
        lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                          exp_inp_shape,
                                          minval=-10.,
                                          maxval=10.)

        input_bounds = jax_verify.IntervalBound(lb, ub)
        output_bounds = jax_verify.forward_fastlin_bound_propagation(
            exp_model, input_bounds)

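        # Empirical soundness check: outputs of points sampled inside the input
        # bounds must lie within the propagated output bounds.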
        uniform_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(1), (lb, ub), 100)
        uniform_outs = jax.vmap(exp_model)(uniform_inps)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)
        self.assertGreaterEqual((output_bounds.upper - empirical_max).min(),
                                0.,
                                'Invalid upper bound for Exponential. The gap '
                                'between upper bound and empirical max is < 0')
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for Exponential. The gap '
            'between emp. min and lower bound is negative.')
Example #2
    def test_fastlin(self, model_cls):
        @hk.transform_with_state
        def model_pred(inputs, is_training, test_local_stats=False):
            model = model_cls()
            return model(inputs, is_training, test_local_stats)

        inps = jnp.zeros((4, 28, 28, 1), dtype=jnp.float32)
        params, state = model_pred.init(jax.random.PRNGKey(42),
                                        inps,
                                        is_training=True)

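        # Evaluation-mode forward pass: rng is None (the model is deterministic
        # at test time) and [0] drops the state returned by apply.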
        def logits_fun(inputs):
            return model_pred.apply(params,
                                    state,
                                    None,
                                    inputs,
                                    False,
                                    test_local_stats=False)[0]

        input_bounds = jax_verify.IntervalBound(inps - 1.0, inps + 1.0)
        jax_verify.forward_fastlin_bound_propagation(logits_fun, input_bounds)
Example #3
    def test_relu_fixed_fastlin(self):
        def relu_model(inp):
            return jax.nn.relu(inp)

        z = jnp.array([[-2., 3.]])

        input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
        output_bounds = jax_verify.forward_fastlin_bound_propagation(
            relu_model, input_bounds)

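        # The input intervals are [-3, -1] and [2, 4]; relu maps them to [0, 0]
        # and [2, 4], so the exact output bounds are lower=[0, 2], upper=[0, 4].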
        self.assertArrayAlmostEqual(jnp.array([[0., 2.]]), output_bounds.lower)
        self.assertArrayAlmostEqual(jnp.array([[0., 4.]]), output_bounds.upper)
Example #4
def main(unused_args):

  # Load the parameters of an existing model.
  model_pred, params = load_model(FLAGS.model)
  logits_fn = functools.partial(model_pred, params)

  # Load some test samples
  with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
    inputs = np.load(f)

  # Compute bound-propagation bounds for an epsilon-ball around the inputs,
  # clipped to the valid pixel range [0, 1].
  eps = 0.1
  lower_bound = jnp.minimum(jnp.maximum(inputs[:2, ...] - eps, 0.0), 1.0)
  upper_bound = jnp.minimum(jnp.maximum(inputs[:2, ...] + eps, 0.0), 1.0)
  init_bound = jax_verify.IntervalBound(lower_bound, upper_bound)

  if FLAGS.boundprop_method == 'forwardfastlin':
    final_bound = jax_verify.forward_fastlin_bound_propagation(logits_fn,
                                                               init_bound)
    boundprop_transform = forward_linear_bounds.forward_fastlin_transform
  elif FLAGS.boundprop_method == 'ibp':
    final_bound = jax_verify.interval_bound_propagation(logits_fn, init_bound)
    boundprop_transform = jax_verify.ibp_transform
  else:
    raise NotImplementedError('Only ibp/fastlin boundprop are '
                              'currently supported.')

  dummy_output = model_pred(params, inputs)

  # Run LP solver
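  # The objective is a one-hot vector selecting the first logit (index 0).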
  objective = jnp.where(jnp.arange(dummy_output[0, ...].size) == 0,
                        jnp.ones_like(dummy_output[0, ...]),
                        jnp.zeros_like(dummy_output[0, ...]))
  objective_bias = 0.
  value, _, status = jax_verify.solve_planet_relaxation(
      logits_fn, init_bound, boundprop_transform, objective,
      objective_bias, index=0)
  logging.info('Relaxation LB is : %f, Status is %s', value, status)
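  # Minimizing the negated objective yields the negative of the upper bound.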
  value, _, status = jax_verify.solve_planet_relaxation(
      logits_fn, init_bound, boundprop_transform, -objective,
      objective_bias, index=0)
  logging.info('Relaxation UB is : %f, Status is %s', -value, status)

  logging.info('Boundprop LB is : %f', final_bound.lower[0, 0])
  logging.info('Boundprop UB is : %f', final_bound.upper[0, 0])
Example #5
def _compute_jv_bounds(
    input_bound: sdp_utils.IntBound,
    params: ModelParams,
    method: str,
) -> List[sdp_utils.IntBound]:
    """Compute bounds with jax_verify."""

    jv_input_bound = jax_verify.IntervalBound(input_bound.lb, input_bound.ub)

    # create a function that takes as arguments the input and all parameters
    # that have bounds (as collected in jv_param_bounds below) and returns
    # all activations
    all_act_fun = _make_all_act_fn(params)

    # collect the interval bounds of every parameter that has them
    jv_param_bounds = [(p.w_bound, p.b_bound) for p in params if p.has_bounds]

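    # use jax_verify to propagate bounds through all activations; with bounded
    # parameters as well as inputs, the propagation is bilinear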
    if method == 'ibp':
        _, jv_bounds = jax_verify.interval_bound_propagation(
            all_act_fun, jv_input_bound, *jv_param_bounds)
    elif method == 'fastlin':
        _, jv_bounds = jax_verify.forward_fastlin_bound_propagation(
            all_act_fun, jv_input_bound, *jv_param_bounds)
    elif method == 'ibpfastlin':
        _, jv_bounds = jax_verify.ibpforwardfastlin_bound_propagation(
            all_act_fun, jv_input_bound, *jv_param_bounds)
    elif method == 'crown':
        _, jv_bounds = jax_verify.backward_crown_bound_propagation(
            all_act_fun, jv_input_bound, *jv_param_bounds)
    elif method == 'nonconvex':
        _, jv_bounds = jax_verify.nonconvex_constopt_bound_propagation(
            all_act_fun, jv_input_bound, *jv_param_bounds)
    else:
        raise ValueError('Unsupported method.')

    # re-format bounds with internal convention
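    # the post-activation bounds assume ReLU activations: the pre-activation
    # bounds are simply clamped at zero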
    bounds = []
    for intermediate_bound in jv_bounds:
        bounds.append(
            sdp_utils.IntBound(lb_pre=intermediate_bound.lower,
                               ub_pre=intermediate_bound.upper,
                               lb=jnp.maximum(intermediate_bound.lower, 0),
                               ub=jnp.maximum(intermediate_bound.upper, 0)))

    return bounds
Example #6
    def test_nobatch_batch_inputs(self):
        batch_shape = (3, 2)
        unbatch_shape = (2, 4)

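        # The first operand carries a batch axis ('b'); the second is shared
        # across the batch, exercising mixed batched/unbatched bounds.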
        def bilinear_model(inp_1, inp_2):
            return jnp.einsum('bh,hH->bH', inp_1, inp_2)

        lb_1, ub_1 = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                              batch_shape,
                                              minval=-10.,
                                              maxval=10.)
        lb_2, ub_2 = test_utils.sample_bounds(jax.random.PRNGKey(1),
                                              unbatch_shape,
                                              minval=-10.,
                                              maxval=10.)
        bound_1 = jax_verify.IntervalBound(lb_1, ub_1)
        bound_2 = jax_verify.IntervalBound(lb_2, ub_2)

        output_bounds = jax_verify.forward_fastlin_bound_propagation(
            bilinear_model, bound_1, bound_2)

        uniform_1 = test_utils.sample_bounded_points(jax.random.PRNGKey(2),
                                                     (lb_1, ub_1), 100)
        uniform_2 = test_utils.sample_bounded_points(jax.random.PRNGKey(3),
                                                     (lb_2, ub_2), 100)

        uniform_outs = jax.vmap(bilinear_model)(uniform_1, uniform_2)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)

        self.assertGreaterEqual(
            (output_bounds.upper - empirical_max).min(), 0.,
            'Invalid upper bound for mix of batched/unbatched '
            'input bounds.')
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for mix of batched/unbatched '
            'input bounds.')
Example #7
    def test_multiply_fastlin(self):
        def multiply_model(lhs, rhs):
            return lhs * rhs

        mul_inp_shape = (4, 7)
        lhs_lb, lhs_ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                                  mul_inp_shape,
                                                  minval=-10.,
                                                  maxval=10.)
        rhs_lb, rhs_ub = test_utils.sample_bounds(jax.random.PRNGKey(1),
                                                  mul_inp_shape,
                                                  minval=-10.,
                                                  maxval=10.)

        lhs_bounds = jax_verify.IntervalBound(lhs_lb, lhs_ub)
        rhs_bounds = jax_verify.IntervalBound(rhs_lb, rhs_ub)
        output_bounds = jax_verify.forward_fastlin_bound_propagation(
            multiply_model, lhs_bounds, rhs_bounds)

        uniform_lhs_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(2), (lhs_lb, lhs_ub), 100)
        uniform_rhs_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(3), (rhs_lb, rhs_ub), 100)

        uniform_outs = jax.vmap(multiply_model)(uniform_lhs_inps,
                                                uniform_rhs_inps)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)

        self.assertGreaterEqual(
            (output_bounds.upper - empirical_max).min(), 0.,
            'Invalid upper bound for Multiply. The gap '
            'between upper bound and empirical max is negative.')
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for Multiply. The gap '
            'between emp. min and lower bound is negative.')