Exemplo n.º 1
0
    def test_exp_crown(self):
        """Checks that backward CROWN bounds on `jnp.exp` enclose samples.

        Random input intervals of shape (4, 7) in [-10, 10] are propagated
        through `jnp.exp` with `backward_crown_bound_propagation`. The
        resulting output bounds must contain the empirical min/max of the
        model evaluated on 100 points sampled within those intervals.
        """
        def exp_model(inp):
            return jnp.exp(inp)

        exp_inp_shape = (4, 7)
        lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                          exp_inp_shape,
                                          minval=-10.,
                                          maxval=10.)

        input_bounds = jax_verify.IntervalBound(lb, ub)
        output_bounds = jax_verify.backward_crown_bound_propagation(
            exp_model, input_bounds)

        # Empirical extrema over sampled inputs; sound bounds must enclose them.
        uniform_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(1), (lb, ub), 100)
        uniform_outs = jax.vmap(exp_model)(uniform_inps)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)

        self.assertGreaterEqual((output_bounds.upper - empirical_max).min(),
                                0.,
                                'Invalid upper bound for Exponential. The gap '
                                'between upper bound and empirical max is < 0')
        # Adjacent literals are concatenated, so the trailing space matters.
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for Exponential. The gap '
            'between emp. min and lower bound is negative.')

    def test_relu_random_fastlin(self):
        """Checks that forward Fast-Lin bounds on ReLU enclose samples.

        Random input intervals of shape (4, 7) in [-10, 10] are propagated
        through `jax.nn.relu` with `forward_fastlin_bound_propagation`. The
        resulting output bounds must contain the empirical min/max of the
        model evaluated on 100 points sampled within those intervals.
        """
        def relu_model(inp):
            return jax.nn.relu(inp)

        relu_inp_shape = (4, 7)
        lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                          relu_inp_shape,
                                          minval=-10.,
                                          maxval=10.)

        input_bounds = jax_verify.IntervalBound(lb, ub)
        output_bounds = jax_verify.forward_fastlin_bound_propagation(
            relu_model, input_bounds)

        # Empirical extrema over sampled inputs; sound bounds must enclose them.
        uniform_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(1), (lb, ub), 100)
        uniform_outs = jax.vmap(relu_model)(uniform_inps)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)
        self.assertGreaterEqual((output_bounds.upper - empirical_max).min(),
                                0., 'Invalid upper bound for ReLU. The gap '
                                'between upper bound and empirical max is < 0')
        # Adjacent literals are concatenated, so the trailing space matters.
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for ReLU. The gap '
            'between emp. min and lower bound is negative.')
Exemplo n.º 3
0
  def test_leaky_relu(self, negative_slope):
    """Validates soundness and convexity of the leaky-ReLU relaxation.

    Args:
      negative_slope: Slope forwarded both to `jax.nn.leaky_relu` (the
        reference model) and to `leaky_relu_relaxation` (the code under test).
    """
    inp_shape = (5, 8)  # (batch_size, axis_dim)

    def model(x):
      return jax.nn.leaky_relu(x, negative_slope)

    lower, upper = test_utils.sample_bounds(jax.random.PRNGKey(0), inp_shape,
                                            minval=-10., maxval=10.)
    lower_fn, upper_fn = activation_relaxation.leaky_relu_relaxation(
        IntervalBound(lower, upper), negative_slope=negative_slope)

    # The relaxation must enclose the true function on sampled points.
    self._check_bounds(jax.random.PRNGKey(1), model, lower_fn, upper_fn,
                       [lower], [upper])

    # Sanity-check the shape of each relaxation: the lower function is checked
    # with flag True (convex), the upper one with flag False (concave).
    for seed, relax_fn, expect_convex in ((2, lower_fn, True),
                                          (3, upper_fn, False)):
      self._check_convexity(jax.random.PRNGKey(seed), relax_fn,
                            [lower], [upper], expect_convex)
Exemplo n.º 4
0
  def test_abs(self):
    """Validates soundness and convexity of the `abs` convex-fn relaxation.

    `abs` is relaxed through the generic `convex_fn_relaxation` keyed on
    `lax.abs_p`, and the result is checked against `jnp.abs` on sampled
    points from random intervals in [-10, 10].
    """
    inp_shape = (5, 8)  # (batch_size, axis_dim)

    def model(x):
      return jnp.abs(x)

    lower, upper = test_utils.sample_bounds(jax.random.PRNGKey(0), inp_shape,
                                            minval=-10., maxval=10.)
    lower_fn, upper_fn = activation_relaxation.convex_fn_relaxation(
        lax.abs_p, IntervalBound(lower, upper))

    # The relaxation must enclose the true function on sampled points.
    self._check_bounds(jax.random.PRNGKey(1), model, lower_fn, upper_fn,
                       [lower], [upper])

    # Sanity-check the shape of each relaxation: the lower function is checked
    # with flag True (convex), the upper one with flag False (concave).
    for seed, relax_fn, expect_convex in ((2, lower_fn, True),
                                          (3, upper_fn, False)):
      self._check_convexity(jax.random.PRNGKey(seed), relax_fn,
                            [lower], [upper], expect_convex)
Exemplo n.º 5
0
    def test_chunking(self, relaxer):
        """Checks that chunked and unchunked backward concretization agree.

        A small two-layer ReLU network (input 2 -> hidden 5 -> output 4,
        batch of 3) is bounded twice: once with `max_chunk_size=16` and once
        with `max_chunk_size=0`, which serves as the unchunked reference. The
        lower and upper output bounds must be almost equal.

        Args:
          relaxer: Linear-bound relaxer under test. Parameterized relaxers
            are wrapped in an optimizing backward transform; others use the
            plain linear-bound backward transform.
        """
        batch_size, input_size, hidden_size, final_size = 3, 2, 5, 4

        in_lb, in_ub = test_utils.sample_bounds(
            jax.random.PRNGKey(0), (batch_size, input_size),
            minval=-1., maxval=1.)
        in_bound = jax_verify.IntervalBound(in_lb, in_ub)

        # Weights drawn with jax.random.uniform's default range, [0, 1).
        w_hidden = jax.random.uniform(jax.random.PRNGKey(1),
                                      (input_size, hidden_size))
        w_final = jax.random.uniform(jax.random.PRNGKey(2),
                                     (hidden_size, final_size))

        def model_fun(inp):
            return jax.nn.relu(inp @ w_hidden) @ w_final

        uses_params = isinstance(
            relaxer, linear_bound_utils.ParameterizedLinearBoundsRelaxer)
        if not uses_params:
            transform = backward_crown.LinearBoundBackwardTransform(
                relaxer, backward_crown.CONCRETIZE_ARGS_PRIMITIVE)
        else:
            transform = backward_crown.OptimizingLinearBoundBackwardTransform(
                relaxer,
                backward_crown.CONCRETIZE_ARGS_PRIMITIVE,
                optax.adam(1.e-3),
                num_opt_steps=10)

        # Index 0 is the chunked setup, index 1 the unchunked reference.
        concretizers = [
            backward_crown.ChunkedBackwardConcretizer(
                transform, max_chunk_size=chunk_size)
            for chunk_size in (16, 0)
        ]
        algorithms = [
            bound_utils.BackwardConcretizingAlgorithm(concretizer)
            for concretizer in concretizers
        ]
        chunked_bound, unchunked_bound = [
            bound_propagation.bound_propagation(algo, model_fun, in_bound)[0]
            for algo in algorithms
        ]

        for side in ('lower', 'upper'):
            np.testing.assert_array_almost_equal(
                getattr(chunked_bound, side), getattr(unchunked_bound, side))
Exemplo n.º 6
0
    def test_nobatch_batch_inputs(self):
        """Checks backward CROWN on a bilinear model mixing batched inputs.

        The model is `einsum('bh,hH->bH', inp_1, inp_2)`. Only the first
        operand carries the leading `b` axis, so the bound inputs mix a
        batched tensor of shape (3, 2) with an unbatched one of shape (2, 4).
        The output bounds must contain the empirical min/max over 100 jointly
        sampled input pairs.
        """

        batch_shape = (3, 2)
        unbatch_shape = (2, 4)

        def bilinear_model(inp_1, inp_2):
            return jnp.einsum('bh,hH->bH', inp_1, inp_2)

        lb_1, ub_1 = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                              batch_shape,
                                              minval=-10.,
                                              maxval=10.)
        lb_2, ub_2 = test_utils.sample_bounds(jax.random.PRNGKey(1),
                                              unbatch_shape,
                                              minval=-10.,
                                              maxval=10.)
        bound_1 = jax_verify.IntervalBound(lb_1, ub_1)
        bound_2 = jax_verify.IntervalBound(lb_2, ub_2)

        output_bounds = backward_crown.backward_crown_bound_propagation(
            bilinear_model, bound_1, bound_2)

        uniform_1 = test_utils.sample_bounded_points(jax.random.PRNGKey(2),
                                                     (lb_1, ub_1), 100)
        uniform_2 = test_utils.sample_bounded_points(jax.random.PRNGKey(3),
                                                     (lb_2, ub_2), 100)

        # Empirical extrema over sampled inputs; sound bounds must enclose them.
        uniform_outs = jax.vmap(bilinear_model)(uniform_1, uniform_2)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)

        # Adjacent literals are concatenated, so the trailing spaces matter.
        self.assertGreaterEqual(
            (output_bounds.upper - empirical_max).min(), 0.,
            'Invalid upper bound for mix of batched/unbatched '
            'input bounds.')
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for mix of batched/unbatched '
            'input bounds.')
Exemplo n.º 7
0
    def test_multiply_crown(self):
        """Checks that backward CROWN bounds on `lhs * rhs` enclose samples.

        Two independent random input intervals of shape (4, 7) in [-10, 10]
        are propagated through an elementwise product with
        `backward_crown_bound_propagation`. The output bounds must contain
        the empirical min/max over 100 sampled (lhs, rhs) pairs.
        """
        def multiply_model(lhs, rhs):
            return lhs * rhs

        mul_inp_shape = (4, 7)
        lhs_lb, lhs_ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                                  mul_inp_shape,
                                                  minval=-10.,
                                                  maxval=10.)
        rhs_lb, rhs_ub = test_utils.sample_bounds(jax.random.PRNGKey(1),
                                                  mul_inp_shape,
                                                  minval=-10.,
                                                  maxval=10.)

        lhs_bounds = jax_verify.IntervalBound(lhs_lb, lhs_ub)
        rhs_bounds = jax_verify.IntervalBound(rhs_lb, rhs_ub)
        output_bounds = jax_verify.backward_crown_bound_propagation(
            multiply_model, lhs_bounds, rhs_bounds)

        uniform_lhs_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(2), (lhs_lb, lhs_ub), 100)
        uniform_rhs_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(3), (rhs_lb, rhs_ub), 100)

        # Empirical extrema over sampled inputs; sound bounds must enclose them.
        uniform_outs = jax.vmap(multiply_model)(uniform_lhs_inps,
                                                uniform_rhs_inps)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)

        self.assertGreaterEqual(
            (output_bounds.upper - empirical_max).min(), 0.,
            'Invalid upper bound for Multiply. The gap '
            'between upper bound and empirical max is negative')
        # Adjacent literals are concatenated, so the trailing space matters.
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for Multiply. The gap '
            'between emp. min and lower bound is negative.')
Exemplo n.º 8
0
  def test_sigmoid(self, scale):
    """Validates soundness and convexity of the sigmoid relaxation.

    Args:
      scale: Half-width of the sampling range; input bounds are drawn from
        [-scale, scale].
    """
    inp_shape = (5, 8)  # (batch_size, axis_dim)
    model = jax.nn.sigmoid

    lower, upper = test_utils.sample_bounds(jax.random.PRNGKey(0), inp_shape,
                                            minval=-scale, maxval=scale)
    lower_fn, upper_fn = activation_relaxation.sigmoid_relaxation(
        IntervalBound(lower, upper))

    # The relaxation must enclose the true function on sampled points.
    self._check_bounds(jax.random.PRNGKey(1), model, lower_fn, upper_fn,
                       [lower], [upper])

    # Sanity-check the shape of each relaxation: the lower function is checked
    # with flag True (convex), the upper one with flag False (concave).
    for seed, relax_fn, expect_convex in ((2, lower_fn, True),
                                          (3, upper_fn, False)):
      self._check_convexity(jax.random.PRNGKey(seed), relax_fn,
                            [lower], [upper], expect_convex)