Example #1
  def _check_convexity(self, key, fun, lbs, ubs, is_convex, nb_samples=100):
    """Check that the function is convex or concave.

    We do this by sanity-checking that the function is below its chord.

    Args:
      key: PRNG key for random number generation
      fun: Function to be checked.
      lbs: List of lower bounds on the inputs
      ubs: List of upper bounds on the inputs
      is_convex: Boolean, if True:  check that the function is convex.
                          if False: check that the function is concave.
      nb_samples: How many random samples to draw for testing.
    """
    assert len(lbs) == len(ubs)
    keys = jax.random.split(key, 2*len(lbs) + 1)

    a_inps = []
    b_inps = []
    interp_inps = []
    interp_coeffs = jax.random.uniform(keys[-1], (nb_samples,))
    for inp_idx, bounds in enumerate(zip(lbs, ubs)):
      interp_coeffs_shape = (-1,) + (1,)*bounds[0].ndim
      broad_interp_coeffs = jnp.reshape(interp_coeffs, interp_coeffs_shape)

      a_inp = test_utils.sample_bounded_points(keys[2*inp_idx], bounds,
                                               nb_samples)
      b_inp = test_utils.sample_bounded_points(keys[2*inp_idx + 1], bounds,
                                               nb_samples)
      interp_inp = (a_inp * broad_interp_coeffs +
                    b_inp * (1. - broad_interp_coeffs))
      a_inps.append(a_inp)
      b_inps.append(b_inp)
      interp_inps.append(interp_inp)

    vmap_fun = jax.vmap(fun)

    a_eval = vmap_fun(*a_inps)
    b_eval = vmap_fun(*b_inps)
    interp_eval = vmap_fun(*interp_inps)
    interp_coeffs_shape = (-1,) + (1,)*(interp_eval.ndim - 1)
    broad_interp_coeffs = jnp.reshape(interp_coeffs, interp_coeffs_shape)
    chord_eval = (a_eval * broad_interp_coeffs +
                  b_eval * (1. - broad_interp_coeffs))

    if is_convex:
      self.assertGreaterEqual(
          (chord_eval - interp_eval).min(), -TOL,
          msg='Function is not convex')
    else:
      self.assertGreaterEqual(
          (interp_eval - chord_eval).min(), -TOL,
          msg='Function is not concave')
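For concreteness, a hypothetical test method (not from the library; the bounds are illustrative, and TOL / test_utils are assumed to be in scope as in the snippet above) showing how this helper might be invoked:

  def test_exp_is_convex(self):
    # jnp.exp is convex, so it should lie below its chord on any interval.
    lb = -2. * jnp.ones((4,))
    ub = 2. * jnp.ones((4,))
    self._check_convexity(jax.random.PRNGKey(0), jnp.exp,
                          [lb], [ub], is_convex=True)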
Example #2
    def test_exp_crown(self):
        def exp_model(inp):
            return jnp.exp(inp)

        exp_inp_shape = (4, 7)
        lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                          exp_inp_shape,
                                          minval=-10.,
                                          maxval=10.)

        input_bounds = jax_verify.IntervalBound(lb, ub)
        output_bounds = jax_verify.backward_crown_bound_propagation(
            exp_model, input_bounds)

        uniform_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(1), (lb, ub), 100)
        uniform_outs = jax.vmap(exp_model)(uniform_inps)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)

        self.assertGreaterEqual((output_bounds.upper - empirical_max).min(),
                                0.,
                                'Invalid upper bound for Exponential. The gap '
                                'between upper bound and empirical max is < 0')
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for Exponential. The gap '
            'between empirical min and lower bound is negative.')
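All of these examples rely on test_utils.sample_bounded_points. As a rough sketch of the semantics assumed throughout (uniform samples inside the box, with a leading sample axis), a hypothetical stand-in could look like the following; this is an assumption about the helper's behaviour, not its actual implementation:

def sample_bounded_points_sketch(key, bounds, nb_samples):
    # Assumed behaviour: draw nb_samples points uniformly in [lb, ub],
    # returning an array of shape (nb_samples,) + lb.shape.
    lb, ub = bounds
    unif = jax.random.uniform(key, (nb_samples,) + lb.shape)
    return lb + unif * (ub - lb)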
Example #3
    def test_relu_random_fastlin(self):
        def relu_model(inp):
            return jax.nn.relu(inp)

        relu_inp_shape = (4, 7)
        lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                          relu_inp_shape,
                                          minval=-10.,
                                          maxval=10.)

        input_bounds = jax_verify.IntervalBound(lb, ub)
        output_bounds = jax_verify.forward_fastlin_bound_propagation(
            relu_model, input_bounds)

        uniform_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(1), (lb, ub), 100)
        uniform_outs = jax.vmap(relu_model)(uniform_inps)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)
        self.assertGreaterEqual((output_bounds.upper - empirical_max).min(),
                                0., 'Invalid upper bound for ReLU. The gap '
                                'between upper bound and empirical max is < 0')
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for ReLU. The gap '
            'between empirical min and lower bound is negative.')
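The exponential and ReLU checks above follow the same pattern and differ only in the model and the propagation routine. Assuming the test class uses absl's parameterized test machinery (an assumption; it is not visible in these snippets), a deduplicated sketch for the ReLU case might read:

    @parameterized.named_parameters(
        ('forward_fastlin', jax_verify.forward_fastlin_bound_propagation),
        ('backward_crown', jax_verify.backward_crown_bound_propagation),
    )
    def test_relu_bound_soundness(self, propagate):
        # Hypothetical test: same empirical soundness check, shared across
        # propagation routines.
        def relu_model(inp):
            return jax.nn.relu(inp)

        lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0), (4, 7),
                                          minval=-10., maxval=10.)
        output_bounds = propagate(relu_model, jax_verify.IntervalBound(lb, ub))

        uniform_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(1), (lb, ub), 100)
        uniform_outs = jax.vmap(relu_model)(uniform_inps)
        self.assertGreaterEqual(
            (output_bounds.upper - uniform_outs.max(axis=0)).min(), 0.)
        self.assertGreaterEqual(
            (uniform_outs.min(axis=0) - output_bounds.lower).min(), 0.)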
Example #4
    def test_nobatch_batch_inputs(self):

        batch_shape = (3, 2)
        unbatch_shape = (2, 4)

        def bilinear_model(inp_1, inp_2):
            return jnp.einsum('bh,hH->bH', inp_1, inp_2)

        lb_1, ub_1 = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                              batch_shape,
                                              minval=-10.,
                                              maxval=10.)
        lb_2, ub_2 = test_utils.sample_bounds(jax.random.PRNGKey(1),
                                              unbatch_shape,
                                              minval=-10.,
                                              maxval=10.)
        bound_1 = jax_verify.IntervalBound(lb_1, ub_1)
        bound_2 = jax_verify.IntervalBound(lb_2, ub_2)

        output_bounds = backward_crown.backward_crown_bound_propagation(
            bilinear_model, bound_1, bound_2)

        uniform_1 = test_utils.sample_bounded_points(jax.random.PRNGKey(2),
                                                     (lb_1, ub_1), 100)
        uniform_2 = test_utils.sample_bounded_points(jax.random.PRNGKey(3),
                                                     (lb_2, ub_2), 100)

        uniform_outs = jax.vmap(bilinear_model)(uniform_1, uniform_2)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)

        self.assertGreaterEqual(
            (output_bounds.upper - empirical_max).min(), 0.,
            'Invalid upper bound for mix of batched/unbatched '
            'input bounds.')
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for mix of batched/unbatched '
            'input bounds.')
Example #5
    def test_multiply_crown(self):
        def multiply_model(lhs, rhs):
            return lhs * rhs

        mul_inp_shape = (4, 7)
        lhs_lb, lhs_ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                                  mul_inp_shape,
                                                  minval=-10.,
                                                  maxval=10.)
        rhs_lb, rhs_ub = test_utils.sample_bounds(jax.random.PRNGKey(1),
                                                  mul_inp_shape,
                                                  minval=-10.,
                                                  maxval=10.)

        lhs_bounds = jax_verify.IntervalBound(lhs_lb, lhs_ub)
        rhs_bounds = jax_verify.IntervalBound(rhs_lb, rhs_ub)
        output_bounds = jax_verify.backward_crown_bound_propagation(
            multiply_model, lhs_bounds, rhs_bounds)

        uniform_lhs_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(2), (lhs_lb, lhs_ub), 100)
        uniform_rhs_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(3), (rhs_lb, rhs_ub), 100)

        uniform_outs = jax.vmap(multiply_model)(uniform_lhs_inps,
                                                uniform_rhs_inps)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)

        self.assertGreaterEqual(
            (output_bounds.upper - empirical_max).min(), 0.,
            'Invalid upper bound for Multiply. The gap '
            'between upper bound and empirical max is negative')
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for Multiply. The gap '
            'between empirical min and lower bound is negative.')
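As an aside (a sketch, not part of the test above): the exact output interval of an elementwise product is spanned by the four corner products of the input intervals, which gives the tightest reference that any sound bound, such as the CROWN result being tested, must enclose:

corners = jnp.stack([lhs_lb * rhs_lb, lhs_lb * rhs_ub,
                     lhs_ub * rhs_lb, lhs_ub * rhs_ub])
exact_lower = corners.min(axis=0)   # tightest possible lower bound
exact_upper = corners.max(axis=0)   # tightest possible upper bound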
Example #6
  def _check_bounds(self, key, fun, lb_fun, ub_fun, lbs, ubs, nb_samples=1000):
    """Check that lb_fun and ub_fun actually bound the function fun.

    This is evaluated at a number of random samples.

    Args:
      key: PRNG key for random number generation
      fun: Function to be bounded.
      lb_fun: Lower bound function
      ub_fun: Upper bound function
      lbs: List of lower bounds on the inputs
      ubs: List of upper bounds on the inputs
      nb_samples: How many random samples to draw for testing.
    """
    assert len(lbs) == len(ubs)
    keys = jax.random.split(key, len(lbs))

    # Build the uniform samples.
    inps = []
    for inp_idx, bounds in enumerate(zip(lbs, ubs)):
      inps.append(test_utils.sample_bounded_points(
          keys[inp_idx], bounds, nb_samples))
    vmap_fun = jax.vmap(fun)
    vmap_lbfun = jax.vmap(lb_fun)
    vmap_ubfun = jax.vmap(ub_fun)

    samples_eval = vmap_fun(*inps)
    lb_eval = vmap_lbfun(*inps)
    ub_eval = vmap_ubfun(*inps)

    self.assertGreaterEqual(
        (samples_eval - lb_eval).min(), -TOL,
        msg='Lower Bound is invalid')
    self.assertGreaterEqual(
        (ub_eval - samples_eval).min(), -TOL,
        msg='Upper Bound is invalid')
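A hypothetical companion test (illustrative only; it assumes the same class context, TOL, and test_utils as above) showing how this helper could be exercised with the usual linear ReLU relaxation, using the chord as the upper bound and a slope-1/2 line through the origin as a valid lower bound:

  def test_relu_linear_relaxation(self):
    # Illustrative bounds straddling zero, so the chord formula below applies.
    lb = -3. * jnp.ones((5,))
    ub = 2. * jnp.ones((5,))

    def relu_upper(x):
      # Chord of the (convex) ReLU over [lb, ub]; since lb < 0 < ub, the chord
      # is ub * (x - lb) / (ub - lb) and upper-bounds the ReLU on that range.
      return ub * (x - lb) / (ub - lb)

    def relu_lower(x):
      # Any line a * x with a in [0, 1] lower-bounds the ReLU everywhere.
      return 0.5 * x

    self._check_bounds(jax.random.PRNGKey(0), jax.nn.relu,
                       relu_lower, relu_upper, [lb], [ub])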