Example #1
def make_params(prng_key, dropout_rate=0.0, std=None):
    prng_key_seq = hk.PRNGSequence(prng_key)

    w1 = jax.random.normal(next(prng_key_seq), [8, 4])
    b1 = jax.random.normal(next(prng_key_seq), [4])

    w2 = jax.random.normal(next(prng_key_seq), [4, 2])
    b2 = jax.random.normal(next(prng_key_seq), [2])

    if std is not None:
        w1_std = std * jnp.ones([8, 4])
        b1_std = std * jnp.ones([4])
        w1_bound = jax_verify.IntervalBound(w1 - 3 * w1_std, w1 + 3 * w1_std)
        b1_bound = jax_verify.IntervalBound(b1 - 3 * b1_std, b1 + 3 * b1_std)
    else:
        w1_std, b1_std, w1_bound, b1_bound = None, None, None, None

    params = [
        verify_utils.FCParams(
            w=w1,
            b=b1,
            w_std=w1_std,
            b_std=b1_std,
            w_bound=w1_bound,
            b_bound=b1_bound,
        ),
        verify_utils.FCParams(
            w=w2,
            b=b2,
            dropout_rate=dropout_rate,
        )
    ]
    return params
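A minimal usage sketch (my assumption, reusing the snippet's imports: jax, haiku as hk, jax.numpy as jnp, jax_verify and verify_utils):

key = jax.random.PRNGKey(0)
# With std=None, only point weights are created; passing std also attaches
# +/- 3 * std interval bounds to the first layer's parameters.
point_params = make_params(key)
bounded_params = make_params(key, dropout_rate=0.1, std=0.05)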
Example #2
    def _bounds_from_cnn_layer(self, index):
        layer_index, is_preact = self._cnn_layer_indices[index]
        if is_preact:
            return jax_verify.IntervalBound(
                self._cnn_bounds[layer_index].lb_pre,
                self._cnn_bounds[layer_index].ub_pre)
        else:
            return jax_verify.IntervalBound(self._cnn_bounds[layer_index].lb,
                                            self._cnn_bounds[layer_index].ub)
Example #3
    def test_sdp_problem_equivalent_to_sdp_verify(self):
        # Set up a verification problem for test purposes.
        verif_instance = test_utils.make_toy_verif_instance(label=2,
                                                            target_label=1)

        # Set up a spec function that replicates the test problem.
        inputs = jnp.zeros((1, 5))
        input_bounds = jax_verify.IntervalBound(jnp.zeros_like(inputs),
                                                jnp.ones_like(inputs))
        boundprop_transform = ibp.bound_transform

        def spec_fn(x):
            x = utils.predict_mlp(verif_instance.params, x)
            x = jax.nn.relu(x)
            return jnp.sum(jnp.reshape(x, (-1, )) *
                           verif_instance.obj) + verif_instance.const

        # Build an SDP verification instance using the code under test.
        sdp_relu_problem = problem_from_graph.SdpReluProblem(
            boundprop_transform, spec_fn, input_bounds)
        sdp_problem_vi = sdp_relu_problem.build_sdp_verification_instance()

        # Build an SDP verification instance using existing `sdp_verify` code.
        sdp_verify_vi = problem.make_sdp_verif_instance(verif_instance)

        self._assert_verif_instances_equal(sdp_problem_vi, sdp_verify_vi)
Example #4
def _get_reciprocal_bound(l: jnp.ndarray, u: jnp.ndarray,
                          logits_params: LayerParams, label: int) -> jnp.ndarray:
    """Helper for computing a bound on the label softmax given interval bounds on the pre-logits."""
    def fwd(x, w, b):
        wdiff = jnp.reshape(w[:, label], [-1, 1]) - w
        bdiff = b[label] - b
        return x @ wdiff + bdiff

    x_bound = jax_verify.IntervalBound(
        lower_bound=jnp.reshape(l, [l.shape[0], -1]),
        upper_bound=jnp.reshape(u, [u.shape[0], -1]))

    params_bounds = []
    if logits_params.w_bound is None:
        fwd = functools.partial(fwd, w=logits_params.w)
    else:
        params_bounds.append(logits_params.w_bound)

    if logits_params.b_bound is None:
        fwd = functools.partial(fwd, b=logits_params.b)
    else:
        params_bounds.append(logits_params.b_bound)

    fwd_bound = jax_verify.interval_bound_propagation(fwd, x_bound,
                                                      *params_bounds)

    return fwd_bound
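A hedged invocation sketch; the shapes and the LayerParams instance below are illustrative assumptions, setting only the fields the helper reads:

l = jnp.zeros((1, 4))  # lower bounds on the pre-logit activations
u = jnp.ones((1, 4))   # upper bounds on the pre-logit activations
logits_params = LayerParams(w=jnp.ones((4, 3)), b=jnp.zeros((3,)),
                            w_bound=None, b_bound=None)  # hypothetical instance
bound = _get_reciprocal_bound(l, u, logits_params, label=0)
print(bound.lower, bound.upper)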
Example #5
    def test_exp_crown(self):
        def exp_model(inp):
            return jnp.exp(inp)

        exp_inp_shape = (4, 7)
        lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                          exp_inp_shape,
                                          minval=-10.,
                                          maxval=10.)

        input_bounds = jax_verify.IntervalBound(lb, ub)
        output_bounds = jax_verify.backward_crown_bound_propagation(
            exp_model, input_bounds)

        uniform_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(1), (lb, ub), 100)
        uniform_outs = jax.vmap(exp_model)(uniform_inps)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)

        self.assertGreaterEqual((output_bounds.upper - empirical_max).min(),
                                0.,
                                'Invalid upper bound for Exponential. The gap '
                                'between upper bound and empirical max is < 0')
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for Exponential. The gap '
            'between emp. min and lower bound is negative.')
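The same soundness check in standalone form, a minimal sketch using only the public jax_verify API; since exp is monotonic, its exact output range is [exp(lb), exp(ub)]:

import jax.numpy as jnp
import jax_verify

lb, ub = -jnp.ones((2, 3)), jnp.ones((2, 3))
out = jax_verify.backward_crown_bound_propagation(
    jnp.exp, jax_verify.IntervalBound(lb, ub))
# Sound bounds must bracket the true range.
assert (out.lower <= jnp.exp(lb)).all()
assert (out.upper >= jnp.exp(ub)).all()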
Example #6
    def test_nonconvex(self, model_cls):
        @hk.transform_with_state
        def model_pred(inputs, is_training, test_local_stats=False):
            model = model_cls()
            return model(inputs, is_training, test_local_stats)

        inps = jnp.zeros((4, 28, 28, 1), dtype=jnp.float32)
        params, state = model_pred.init(jax.random.PRNGKey(42),
                                        inps,
                                        is_training=True)

        def logits_fun(inputs):
            return model_pred.apply(params,
                                    state,
                                    None,
                                    inputs,
                                    False,
                                    test_local_stats=False)[0]

        input_bounds = jax_verify.IntervalBound(inps - 1.0, inps + 1.0)
        # Test with IBP for intermediate bounds
        jax_verify.nonconvex_ibp_bound_propagation(logits_fun, input_bounds)

        # Test with nonconvex bound evaluation for intermediate bounds
        jax_verify.nonconvex_constopt_bound_propagation(
            logits_fun, input_bounds)
Example #7
    def test_matching_output_structure(self, model):
        def _check_matching_structures(output_tree, bound_tree):
            """Replace all bounds/arrays with True, then compare pytrees."""
            output_struct = tree.traverse(
                lambda x: True
                if isinstance(x, jnp.ndarray) else None, output_tree)
            bound_struct = tree.traverse(
                lambda x: True
                if isinstance(x, bound_propagation.Bound) else None,
                bound_tree)
            tree.assert_same_structure(output_struct, bound_struct)

        z = jnp.array([[1., 2., 3.]])
        params = model.init(jax.random.PRNGKey(1), z)
        input_bounds = jax_verify.IntervalBound(z - 1.0, z + 1.0)
        model_output = model.apply(params, z)
        fun_to_prop = functools.partial(model.apply, params)
        for boundprop_method in [
                jax_verify.interval_bound_propagation,
                jax_verify.forward_crown_bound_propagation,
                jax_verify.backward_crown_bound_propagation,
                jax_verify.forward_fastlin_bound_propagation,
                jax_verify.backward_fastlin_bound_propagation,
                jax_verify.ibpforwardfastlin_bound_propagation,
        ]:
            output_bounds = boundprop_method(fun_to_prop, input_bounds)
            _check_matching_structures(model_output, output_bounds)
Example #8
    def test_relu_random_fastlin(self):
        def relu_model(inp):
            return jax.nn.relu(inp)

        relu_inp_shape = (4, 7)
        lb, ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                          relu_inp_shape,
                                          minval=-10.,
                                          maxval=10.)

        input_bounds = jax_verify.IntervalBound(lb, ub)
        output_bounds = jax_verify.forward_fastlin_bound_propagation(
            relu_model, input_bounds)

        uniform_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(1), (lb, ub), 100)
        uniform_outs = jax.vmap(relu_model)(uniform_inps)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)
        self.assertGreaterEqual((output_bounds.upper - empirical_max).min(),
                                0., 'Invalid upper bound for ReLU. The gap '
                                'between upper bound and empirical max is < 0')
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for ReLU. The gap '
            'between emp. min and lower bound is negative.')
Example #9
    def test_conv1d_cvxpy_relaxation(self):
        def conv1d_model(inp):
            return hk.Conv1D(output_channels=1,
                             kernel_shape=2,
                             padding='VALID',
                             stride=1,
                             with_bias=True)(inp)

        z = jnp.array([3., 4.])
        z = jnp.reshape(z, [1, 2, 1])

        params = {
            'conv1_d': {
                'w': jnp.ones((2, 1, 1), dtype=jnp.float32),
                'b': jnp.array([2.])
            }
        }

        fun = functools.partial(
            hk.without_apply_rng(hk.transform(conv1d_model)).apply, params)
        input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)

        lower_bounds, upper_bounds = self.get_bounds(fun, input_bounds)

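        # With unit weights and bias 2, the output is x1 + x2 + 2; for
        # x1 in [2, 4] and x2 in [3, 5] the exact range is [7, 11].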
        self.assertAlmostEqual(7., lower_bounds, delta=1e-5)
        self.assertAlmostEqual(11., upper_bounds, delta=1e-5)
Example #10
    def test_conv2d_ibp(self):
        def conv2d_model(inp):
            return hk.Conv2D(output_channels=1,
                             kernel_shape=(2, 2),
                             padding='VALID',
                             stride=1,
                             with_bias=True)(inp)

        z = jnp.array([1., 2., 3., 4.])
        z = jnp.reshape(z, [1, 2, 2, 1])

        params = {
            'conv2_d': {
                'w': jnp.ones((2, 2, 1, 1), dtype=jnp.float32),
                'b': jnp.array([2.])
            }
        }

        fun = functools.partial(
            hk.without_apply_rng(hk.transform(conv2d_model)).apply, params)
        input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
        output_bounds = jax_verify.interval_bound_propagation(
            fun, input_bounds)

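        # With unit weights and bias 2, the output is the sum of the four
        # inputs plus 2; for inputs 1..4 each +/- 1 the exact range is [8, 16].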
        self.assertAlmostEqual(8., output_bounds.lower)
        self.assertAlmostEqual(16., output_bounds.upper)
Example #11
  def test_fc_fastlin(self):

    @hk.without_apply_rng
    @hk.transform
    def linear_model(inp):
      return hk.Linear(1)(inp)

    z = jnp.array([[1., 2., 3.]])
    params = {'linear':
              {'w': jnp.ones((3, 1), dtype=jnp.float32),
               'b': jnp.array([2.])}}
    input_bounds = jax_verify.IntervalBound(z-1., z+1.)
    fun = functools.partial(linear_model.apply, params)
    output_bounds = jax_verify.fastlin_bound_propagation(fun, input_bounds)

    self.assertTrue(jnp.all(output_bounds.lower_lin.lin_coeffs == 1.))
    self.assertTrue(jnp.all(output_bounds.lower_lin.offset == 2.))
    self.assertTrue(jnp.all(output_bounds.upper_lin.lin_coeffs == 1.))
    self.assertTrue(jnp.all(output_bounds.upper_lin.offset == 2.))
    self.assertArrayAlmostEqual(jnp.array([[0., 1., 2.]]),
                                output_bounds.reference.lower)
    self.assertArrayAlmostEqual(jnp.array([[2., 3., 4.]]),
                                output_bounds.reference.upper)

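    # With unit weights and bias 2, the output is the sum of the inputs plus 2;
    # for [1, 2, 3] each +/- 1 the exact range is [5, 11].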
    self.assertAlmostEqual(5., output_bounds.lower)
    self.assertAlmostEqual(11., output_bounds.upper)
Example #12
    def test_fc_fastlin(self, name, elision):
        @hk.without_apply_rng
        @hk.transform
        def linear_model(inp):
            return hk.Linear(1)(inp)

        z = jnp.array([[1., 2., 3.]])
        params = {
            'linear': {
                'w': jnp.ones((3, 1), dtype=jnp.float32),
                'b': jnp.array([2.])
            }
        }
        input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
        fun = functools.partial(linear_model.apply, params)
        bound_prop = get_boundprop(name, elision)
        output_bounds = bound_prop(fun, input_bounds)

        all_linear_functions = list(output_bounds.linear_functions())
        self.assertLen(all_linear_functions, 1)
        linear_fun = all_linear_functions[0]
        self.assertTrue(jnp.all(linear_fun.lower_lin.lin_coeffs == 1.))
        self.assertTrue(jnp.all(linear_fun.lower_lin.offset == 2.))
        self.assertTrue(jnp.all(linear_fun.upper_lin.lin_coeffs == 1.))
        self.assertTrue(jnp.all(linear_fun.upper_lin.offset == 2.))
        self.assertArrayAlmostEqual(jnp.array([[0., 1., 2.]]),
                                    linear_fun.reference_bound.bound.lower)
        self.assertArrayAlmostEqual(jnp.array([[2., 3., 4.]]),
                                    linear_fun.reference_bound.bound.upper)

        self.assertAlmostEqual(5., output_bounds.lower)
        self.assertAlmostEqual(11., output_bounds.upper)
Example #13
    def test_conv1d_fastlin(self):
        @hk.without_apply_rng
        @hk.transform
        def conv1d_model(inp):
            return hk.Conv1D(output_channels=1,
                             kernel_shape=2,
                             padding='VALID',
                             stride=1,
                             with_bias=True)(inp)

        z = jnp.array([3., 4.])
        z = jnp.reshape(z, [1, 2, 1])

        params = {
            'conv1_d': {
                'w': jnp.ones((2, 1, 1), dtype=jnp.float32),
                'b': jnp.array([2.])
            }
        }

        fun = functools.partial(conv1d_model.apply, params)
        input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
        output_bounds = jax_verify.ibpforwardfastlin_bound_propagation(
            fun, input_bounds)

        self.assertAlmostEqual(7., output_bounds.lower, delta=1e-5)
        self.assertAlmostEqual(11., output_bounds.upper, delta=1e-5)
Example #14
def _crown_ibp_boundprop(params, x, epsilon, input_bounds):
    """Runs CROWN-IBP for each layer separately."""
    def get_layer_act(layer_idx, inputs):
        act = utils.predict_cnn(params[:layer_idx], inputs)
        return act

    initial_bound = jax_verify.IntervalBound(
        jnp.maximum(x - epsilon, input_bounds[0]),
        jnp.minimum(x + epsilon, input_bounds[1]))

    out_bounds = [
        IntBound(lb_pre=None,
                 ub_pre=None,
                 lb=initial_bound.lower,
                 ub=initial_bound.upper)
    ]
    for i in range(1, len(params) + 1):
        fwd = functools.partial(get_layer_act, i)
        bound = jax_verify.crownibp_bound_propagation(fwd, initial_bound)
        out_bounds.append(
            IntBound(lb_pre=bound.lower,
                     ub_pre=bound.upper,
                     lb=jnp.maximum(0, bound.lower),
                     ub=jnp.maximum(0, bound.upper)))
    return out_bounds
Example #15
    def solve_with_jax_verify(self):
        lower_bound = jnp.minimum(jnp.maximum(self.inputs - self.eps, 0.0),
                                  1.0)
        upper_bound = jnp.minimum(jnp.maximum(self.inputs + self.eps, 0.0),
                                  1.0)
        init_bound = jax_verify.IntervalBound(lower_bound, upper_bound)

        logits_fn = make_model_fn(self.network_params)

        solver = cvxpy_relaxation_solver.CvxpySolver
        relaxation_transform = relaxation.RelaxationTransform(
            jax_verify.ibp_transform)

        var, env = bound_propagation.bound_propagation(
            bound_propagation.ForwardPropagationAlgorithm(
                relaxation_transform), logits_fn, init_bound)

        # This solver minimizes the objective -> get max with -min(-objective)
        neg_value_opt, _, _ = relaxation.solve_relaxation(
            solver,
            -self.objective,
            -self.objective_bias,
            var,
            env,
            index=0,
            time_limit_millis=None)
        value_opt = -neg_value_opt

        return value_opt
Example #16
    def test_chunking(self, relaxer):
        batch_size = 3
        input_size = 2
        hidden_size = 5
        final_size = 4

        input_shape = (batch_size, input_size)
        hidden_lay_weight_shape = (input_size, hidden_size)
        final_lay_weight_shape = (hidden_size, final_size)

        inp_lb, inp_ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                                  input_shape,
                                                  minval=-1.,
                                                  maxval=1.)
        inp_bound = jax_verify.IntervalBound(inp_lb, inp_ub)

        hidden_lay_weight = jax.random.uniform(jax.random.PRNGKey(1),
                                               hidden_lay_weight_shape)
        final_lay_weight = jax.random.uniform(jax.random.PRNGKey(2),
                                              final_lay_weight_shape)

        def model_fun(inp):
            hidden = inp @ hidden_lay_weight
            act = jax.nn.relu(hidden)
            final = act @ final_lay_weight
            return final

        if isinstance(relaxer,
                      linear_bound_utils.ParameterizedLinearBoundsRelaxer):
            concretizing_transform = (
                backward_crown.OptimizingLinearBoundBackwardTransform(
                    relaxer,
                    backward_crown.CONCRETIZE_ARGS_PRIMITIVE,
                    optax.adam(1.e-3),
                    num_opt_steps=10))
        else:
            concretizing_transform = backward_crown.LinearBoundBackwardTransform(
                relaxer, backward_crown.CONCRETIZE_ARGS_PRIMITIVE)

        chunked_concretizer = backward_crown.ChunkedBackwardConcretizer(
            concretizing_transform, max_chunk_size=16)
        unchunked_concretizer = backward_crown.ChunkedBackwardConcretizer(
            concretizing_transform, max_chunk_size=0)

        chunked_algorithm = bound_utils.BackwardConcretizingAlgorithm(
            chunked_concretizer)
        unchunked_algorithm = bound_utils.BackwardConcretizingAlgorithm(
            unchunked_concretizer)

        chunked_bound, _ = bound_propagation.bound_propagation(
            chunked_algorithm, model_fun, inp_bound)
        unchunked_bound, _ = bound_propagation.bound_propagation(
            unchunked_algorithm, model_fun, inp_bound)

        np.testing.assert_array_almost_equal(chunked_bound.lower,
                                             unchunked_bound.lower)
        np.testing.assert_array_almost_equal(chunked_bound.upper,
                                             unchunked_bound.upper)
Example #17
    def test_model_structure_nostate(self, model):
        z = jnp.array([[1., 2., 3.]])

        params = model.init(jax.random.PRNGKey(1), z)
        input_bounds = jax_verify.IntervalBound(z - 1.0, z + 1.0)
        fun_to_prop = functools.partial(model.apply, params)
        output_bounds = jax_verify.interval_bound_propagation(
            fun_to_prop, input_bounds)
        self.assertTrue((output_bounds.upper >= output_bounds.lower).all())
Example #18
    def test_sqrt(self, input_bounds, expected):
        input_bounds = jax_verify.IntervalBound(
            np.array([input_bounds[0], 0.0]), np.array([input_bounds[1], 0.0]))
        output_bounds = jax_verify.interval_bound_propagation(
            jnp.sqrt, input_bounds)
        np.testing.assert_array_equal(np.array([expected[0], 0.0]),
                                      output_bounds.lower)
        np.testing.assert_array_equal(np.array([expected[1], 0.0]),
                                      output_bounds.upper)
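A standalone variant of the same check, as a minimal sketch against the public jax_verify API; sqrt is monotonic, so IBP maps the interval endpoints through it exactly:

import numpy as np
import jax.numpy as jnp
import jax_verify

bounds = jax_verify.interval_bound_propagation(
    jnp.sqrt, jax_verify.IntervalBound(np.array([4.0]), np.array([9.0])))
np.testing.assert_array_almost_equal(np.array([2.0]), bounds.lower)
np.testing.assert_array_almost_equal(np.array([3.0]), bounds.upper)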
Example #19
    def test_passthrough_primitive(self, fn, inputs):
        z = jnp.array(inputs)
        input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
        output_bounds = jax_verify.interval_bound_propagation(fn, input_bounds)

        self.assertArrayAlmostEqual(fn(input_bounds.lower),
                                    output_bounds.lower)
        self.assertArrayAlmostEqual(fn(input_bounds.upper),
                                    output_bounds.upper)
Example #20
    def test_multiinput_concatenate_fastlin(self, name, elision):
        def concatenate_and_sum_model(inp_1, inp_2):
            interm = jnp.concatenate((inp_1, inp_2), axis=1)
            return interm.sum(axis=1)

        z_1 = jnp.array([[-1., 1.]])
        z_2 = jnp.array([[-1., 1.]])

        bound_1 = jax_verify.IntervalBound(z_1 - 1., z_1 + 1.)
        bound_2 = jax_verify.IntervalBound(z_2 - 1., z_2 + 1.)

        out_lower = (z_1 + z_2 - 2.).sum()
        out_upper = (z_1 - z_2 + 2.).sum()

        bound_prop = get_boundprop(name, elision)
        output_bounds = bound_prop(concatenate_and_sum_model, bound_1, bound_2)

        self.assertArrayAlmostEqual(out_lower, output_bounds.lower)
        self.assertArrayAlmostEqual(out_upper, output_bounds.upper)
Example #21
    def primitive_transform(self, context: TransformContext,
                            primitive: jax.core.Primitive, *args,
                            **kwargs) -> bound_propagation.Bound:
        if (context.index not in self._fixed_bounds
                and isinstance(primitive, synthetic_primitives.FakePrimitive)):
            # Bound is missing at the synthetic primitive level.
            # Try and infer the bound from its sub-graph.
            subgraph = kwargs['jax_verify_subgraph']
            return context.subgraph_handler(self, subgraph, *args)

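        # Otherwise a bound for this index was supplied up front; return it
        # as an interval.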
        return jax_verify.IntervalBound(*self._fixed_bounds[context.index])
Example #22
  def test_relu_crownibp(self):
    def relu_model(inp):
      return jax.nn.relu(inp)
    z = jnp.array([[-2., 3.]])

    input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
    output_bounds = jax_verify.crownibp_bound_propagation(
        relu_model, input_bounds)

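    # Neither input coordinate straddles zero ([-3, -1] and [2, 4]), so the
    # ReLU bounds are exact: relu([-3, -1]) = [0, 0] and relu([2, 4]) = [2, 4].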
    self.assertArrayAlmostEqual(jnp.array([[0., 2.]]), output_bounds.lower)
    self.assertArrayAlmostEqual(jnp.array([[0., 4.]]), output_bounds.upper)
Example #23
    def test_relu_cvxpy_relaxation(self):
        def relu_model(inp):
            return jax.nn.relu(inp)

        z = jnp.array([[-2., 3.]])

        input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
        lower_bounds, upper_bounds = self.get_bounds(relu_model, input_bounds)

        self.assertArrayAlmostEqual(jnp.array([[0., 2.]]), lower_bounds)
        self.assertArrayAlmostEqual(jnp.array([[0., 4.]]), upper_bounds)
Example #24
    def test_multioutput_model(self):
        z = jnp.array([[1., 2., 3.]])

        fun = hk.without_apply_rng(hk.transform(residual_model_all_act))
        params = fun.init(jax.random.PRNGKey(1), z)
        input_bounds = jax_verify.IntervalBound(z - 1.0, z + 1.0)
        fun_to_prop = functools.partial(fun.apply, params)
        output_bounds = jax_verify.interval_bound_propagation(
            fun_to_prop, input_bounds)

        self.assertLen(output_bounds, 7)
Example #25
def main(unused_args):

    # Load the parameters of an existing model.
    model_pred, params = load_model(FLAGS.model)
    logits_fn = functools.partial(model_pred, params)

    # Load some test samples
    with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
        inputs = np.load(f)

    # Compute boundprop bounds
    eps = 0.1
    lower_bound = jnp.minimum(jnp.maximum(inputs[:2, ...] - eps, 0.0), 1.0)
    upper_bound = jnp.minimum(jnp.maximum(inputs[:2, ...] + eps, 0.0), 1.0)
    init_bound = jax_verify.IntervalBound(lower_bound, upper_bound)

    if FLAGS.boundprop_method == 'fastlin':
        final_bound = jax_verify.fastlin_bound_propagation(
            logits_fn, init_bound)
        boundprop_transform = jax_verify.fastlin_transform
    elif FLAGS.boundprop_method == 'ibp':
        final_bound = jax_verify.interval_bound_propagation(
            logits_fn, init_bound)
        boundprop_transform = jax_verify.ibp_transform
    else:
        raise NotImplementedError('Only ibp/fastlin boundprop are '
                                  'currently supported')

    dummy_output = model_pred(params, inputs)

    # Run LP solver
    objective = jnp.where(
        jnp.arange(dummy_output[0, ...].size) == 0,
        jnp.ones_like(dummy_output[0, ...]),
        jnp.zeros_like(dummy_output[0, ...]))
    objective_bias = 0.
    value, status = jax_verify.solve_planet_relaxation(logits_fn,
                                                       init_bound,
                                                       boundprop_transform,
                                                       objective,
                                                       objective_bias,
                                                       index=0)
    logging.info('Relaxation LB is : %f, Status is %s', value, status)
    value, status = jax_verify.solve_planet_relaxation(logits_fn,
                                                       init_bound,
                                                       boundprop_transform,
                                                       -objective,
                                                       objective_bias,
                                                       index=0)
    logging.info('Relaxation UB is : %f, Status is %s', -value, status)

    logging.info('Boundprop LB is : %f', final_bound.lower[0, 0])
    logging.info('Boundprop UB is : %f', final_bound.upper[0, 0])
Example #26
    def test_nobatch_batch_inputs(self):

        batch_shape = (3, 2)
        unbatch_shape = (2, 4)

        def bilinear_model(inp_1, inp_2):
            return jnp.einsum('bh,hH->bH', inp_1, inp_2)

        lb_1, ub_1 = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                              batch_shape,
                                              minval=-10,
                                              maxval=10.)
        lb_2, ub_2 = test_utils.sample_bounds(jax.random.PRNGKey(1),
                                              unbatch_shape,
                                              minval=-10,
                                              maxval=10.)
        bound_1 = jax_verify.IntervalBound(lb_1, ub_1)
        bound_2 = jax_verify.IntervalBound(lb_2, ub_2)

        output_bounds = backward_crown.backward_crown_bound_propagation(
            bilinear_model, bound_1, bound_2)

        uniform_1 = test_utils.sample_bounded_points(jax.random.PRNGKey(2),
                                                     (lb_1, ub_1), 100)
        uniform_2 = test_utils.sample_bounded_points(jax.random.PRNGKey(3),
                                                     (lb_2, ub_2), 100)

        uniform_outs = jax.vmap(bilinear_model)(uniform_1, uniform_2)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)

        self.assertGreaterEqual(
            (output_bounds.upper - empirical_max).min(), 0.,
            'Invalid upper bound for mix of batched/unbatched '
            'input bounds.')
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for mix of batched/unbatched '
            'input bounds.')
Example #27
    def test_softplus_ibp(self):
        def softplus_model(inp):
            return jax.nn.softplus(inp)

        z = jnp.array([[-2., 3.]])

        input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
        output_bounds = jax_verify.interval_bound_propagation(
            softplus_model, input_bounds)

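        # softplus(x) = log(1 + exp(x)) = logaddexp(x, 0) is monotonic, so
        # IBP maps the interval endpoints through it exactly.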
        self.assertArrayAlmostEqual(jnp.logaddexp(z - 1., 0),
                                    output_bounds.lower)
        self.assertArrayAlmostEqual(jnp.logaddexp(z + 1., 0),
                                    output_bounds.upper)
Example #28
    def test_keywords_argument(self):
        @hk.without_apply_rng
        @hk.transform
        def forward(inputs, use_2=False):
            model = StaticArgumentModel()
            return model(inputs, use_2)

        z = jnp.array([[1., 2., 3.]])
        params = forward.init(jax.random.PRNGKey(1), z, use_2=True)
        input_bounds = jax_verify.IntervalBound(z - 1.0, z + 1.0)
        fun_to_prop = functools.partial(forward.apply, params, use_2=True)
        output_bounds = jax_verify.interval_bound_propagation(
            fun_to_prop, input_bounds)
        self.assertTrue((output_bounds.upper >= output_bounds.lower).all())
Example #29
    def test_multiply_crown(self):
        def multiply_model(lhs, rhs):
            return lhs * rhs

        mul_inp_shape = (4, 7)
        lhs_lb, lhs_ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                                  mul_inp_shape,
                                                  minval=-10.,
                                                  maxval=10.)
        rhs_lb, rhs_ub = test_utils.sample_bounds(jax.random.PRNGKey(1),
                                                  mul_inp_shape,
                                                  minval=-10.,
                                                  maxval=10.)

        lhs_bounds = jax_verify.IntervalBound(lhs_lb, lhs_ub)
        rhs_bounds = jax_verify.IntervalBound(rhs_lb, rhs_ub)
        output_bounds = jax_verify.backward_crown_bound_propagation(
            multiply_model, lhs_bounds, rhs_bounds)

        uniform_lhs_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(2), (lhs_lb, lhs_ub), 100)
        uniform_rhs_inps = test_utils.sample_bounded_points(
            jax.random.PRNGKey(3), (rhs_lb, rhs_ub), 100)

        uniform_outs = jax.vmap(multiply_model)(uniform_lhs_inps,
                                                uniform_rhs_inps)
        empirical_min = uniform_outs.min(axis=0)
        empirical_max = uniform_outs.max(axis=0)

        self.assertGreaterEqual(
            (output_bounds.upper - empirical_max).min(), 0.,
            'Invalid upper bound for Multiply. The gap '
            'between upper bound and empirical max is negative')
        self.assertGreaterEqual(
            (empirical_min - output_bounds.lower).min(), 0.,
            'Invalid lower bound for Multiply. The gap '
            'between emp. min and lower bound is negative.')
Example #30
    def test_staticargument_last(self):
        @hk.without_apply_rng
        @hk.transform
        def forward(inputs, use_2):
            model = StaticArgumentModel()
            return model(inputs, use_2)

        z = jnp.array([[1., 2., 3.]])
        params = forward.init(jax.random.PRNGKey(1), z, True)
        input_bounds = jax_verify.IntervalBound(z - 1.0, z + 1.0)

        def fun_to_prop(inputs):
            return forward.apply(params, inputs, True)

        output_bounds = jax_verify.interval_bound_propagation(
            fun_to_prop, input_bounds)
        self.assertTrue((output_bounds.upper >= output_bounds.lower).all())