Example #1
0
 def test_cnn_mlp_match_sliding_window(self):
     """Checks the SDP dual bound of a sliding-window CNN matches its MLP form.

     The toy `cnn_slide` instance is converted to an MLP with
     `test_utils.make_mlp_verif_instance_from_cnn`. The SDP dual is then
     solved for both networks with the same PRNG key and step budget. The two
     upper bounds must agree to within 5e-3.
     """
     num_steps = 1000
     for seed in range(1):
         verif_instance = test_utils.make_toy_verif_instance(seed,
                                                             nn='cnn_slide')
         key = jax.random.PRNGKey(0)
         lb, ub, params_mlp = test_utils.make_mlp_verif_instance_from_cnn(
             verif_instance)
         # Propagate interval bounds through the MLP from the input box.
         bounds_mlp = sdp_verify.boundprop(
             params_mlp,
             sdp_verify.IntBound(lb=lb, ub=ub, lb_pre=None, ub_pre=None))
         verif_instance_mlp = utils.make_nn_verif_instance(
             params_mlp, bounds_mlp)
         dual_ub_cnn, _ = sdp_verify.solve_sdp_dual(
             problem.make_sdp_verif_instance(verif_instance),
             key,
             num_steps=num_steps,
             verbose=False,
             use_exact_eig_train=True)
         dual_ub_mlp, _ = sdp_verify.solve_sdp_dual(
             problem.make_sdp_verif_instance(verif_instance_mlp),
             key,
             num_steps=num_steps,
             verbose=False,
             use_exact_eig_train=True)
         # Trailing space keeps the two concatenated literals readable.
         assert abs(dual_ub_cnn - dual_ub_mlp) < 5e-3, (
             'Dual upper bound for MLP and CNN (sliding filter) should match. '
             f'Seed is {seed}. Vals are CNN: {dual_ub_cnn} MLP: {dual_ub_mlp}'
         )
Example #2
0
 def test_sdp_dual_simple_no_crash(self, model_type):
     """Smoke-tests `solve_sdp_dual_simple` with and without optional kwargs.

     The solver is called twice on the same toy instance. The first call
     passes the full keyword set with cheap settings. The second call passes
     only `num_steps`. Each call must return a Python float dual value.

     Args:
       model_type: network type forwarded as `nn` to the toy instance builder.
     """
     toy_instance = test_utils.make_toy_verif_instance(seed=0,
                                                       target_label=1,
                                                       label=2,
                                                       nn=model_type)
     all_kwargs = {
         'key': jax.random.PRNGKey(0),
         'opt': optax.adam(1e-3),
         'num_steps': 10,
         'eval_every': 5,
         'verbose': False,
         'use_exact_eig_eval': False,
         'use_exact_eig_train': False,
         'n_iter_lanczos': 5,
         'kappa_reg_weight': 1e-5,
         'kappa_zero_after': 8,
         'device_type': None,
     }
     sdp_instance = problem.make_sdp_verif_instance(toy_instance)
     # First exercise every keyword argument, then the bare minimum.
     for call_kwargs in (all_kwargs, {'num_steps': 5}):
         dual_val, _ = sdp_verify.solve_sdp_dual_simple(sdp_instance,
                                                        **call_kwargs)
         assert isinstance(dual_val, float)
Example #3
0
    def test_sdp_problem_equivalent_to_sdp_verify(self):
        """Checks `SdpReluProblem` builds the same instance as `sdp_verify`.

        A spec function that replicates the toy problem's objective is wrapped
        in `problem_from_graph.SdpReluProblem`. The SDP verification instance
        it builds is compared with the one produced by
        `problem.make_sdp_verif_instance` on the same toy instance.
        """
        toy_instance = test_utils.make_toy_verif_instance(label=2,
                                                          target_label=1)

        # Input box [0, 1] for a single input of shape (1, 5).
        dummy_inputs = jnp.zeros((1, 5))
        box = jax_verify.IntervalBound(jnp.zeros_like(dummy_inputs),
                                       jnp.ones_like(dummy_inputs))

        def spec_fn(x):
            # Network output -> ReLU -> flattened dot with `obj`, plus `const`.
            activations = jax.nn.relu(
                utils.predict_mlp(toy_instance.params, x))
            flat = jnp.reshape(activations, (-1, ))
            return jnp.sum(flat * toy_instance.obj) + toy_instance.const

        # Instance under test, built from the traced spec function.
        graph_vi = problem_from_graph.SdpReluProblem(
            ibp.bound_transform, spec_fn,
            box).build_sdp_verification_instance()
        # Reference instance from the existing `sdp_verify` code path.
        reference_vi = problem.make_sdp_verif_instance(toy_instance)

        self._assert_verif_instances_equal(graph_vi, reference_vi)
Example #4
0
 def test_crossing_bounds(self):
     """Checks the SDP dual upper bound does not cross the primal optimum.

     The primal optimum comes from `cvxpy_verify.solve_mip_mlp_elided`. The
     dual bound must exceed it, allowing a small `loss_margin` of numerical
     slack. A new random seed is drawn on each run, so the seed is included
     in the failure message to make failures reproducible.
     """
     loss_margin = 1e-3
     seed = random.randint(1, 10000)
     verif_instance = test_utils.make_toy_verif_instance(seed)
     key = jax.random.PRNGKey(0)
     primal_opt, _ = cvxpy_verify.solve_mip_mlp_elided(verif_instance)
     dual_ub, _ = sdp_verify.solve_sdp_dual(
         problem.make_sdp_verif_instance(verif_instance),
         key,
         num_steps=1000)
     # Trailing space keeps the two concatenated literals readable.
     assert dual_ub > primal_opt - loss_margin, (
         'Dual upper bound should be greater than optimal primal objective. '
         f'Seed is {seed}. Vals are Dual: {dual_ub} Primal: {primal_opt}')
Example #5
0
 def test_cnn_mlp_match_fixed_window(self):
     """Checks the SDP dual bound of a single-window CNN matches its MLP form.

     For the `cnn_simple` toy network the input and filter sizes match, so the
     conv filter is applied at a single location. Its weights can therefore be
     reshaped into a dense first layer. The SDP dual is solved for both the
     CNN and the resulting MLP, and the bounds must agree to within 1e-2.
     """
     num_steps = 1000
     for seed in range(1):
         verif_instance = test_utils.make_toy_verif_instance(
             seed, nn='cnn_simple')
         key = jax.random.PRNGKey(0)
         params_cnn = verif_instance.params_full
         conv_w = params_cnn[0]['W']
         # Input and filter size match -> filter is applied at just one location,
         # so every axis but the last flattens into the dense input dimension.
         in_shape = int(np.prod(np.array(conv_w.shape[:-1])))
         # Number of layer 1 neurons = no. channels out of conv filter (last dim).
         out_shape = conv_w.shape[-1]
         params_mlp = [(jnp.reshape(conv_w, (in_shape, out_shape)),
                        params_cnn[0]['b']),
                       (params_cnn[1][0], params_cnn[1][1])]
         # Propagate interval bounds from the [0, 1] input box.
         bounds_mlp = sdp_verify.boundprop(
             params_mlp,
             sdp_verify.IntBound(lb=np.zeros((1, in_shape)),
                                 ub=np.ones((1, in_shape)),
                                 lb_pre=None,
                                 ub_pre=None))
         verif_instance_mlp = utils.make_nn_verif_instance(
             params_mlp, bounds_mlp)
         dual_ub_cnn, _ = sdp_verify.solve_sdp_dual(
             problem.make_sdp_verif_instance(verif_instance),
             key,
             num_steps=num_steps,
             verbose=False,
             use_exact_eig_train=True)
         dual_ub_mlp, _ = sdp_verify.solve_sdp_dual(
             problem.make_sdp_verif_instance(verif_instance_mlp),
             key,
             num_steps=num_steps,
             verbose=False,
             use_exact_eig_train=True)
         # Trailing space keeps the two concatenated literals readable.
         assert abs(dual_ub_cnn - dual_ub_mlp) < 1e-2, (
             'Dual upper bound for MLP and CNN (simple CNN) should match. '
             f'Seed is {seed}. Vals are CNN: {dual_ub_cnn} MLP: {dual_ub_mlp}'
         )
Example #6
0
 def test_correct_dual_var_types(self):
     """Checks the types of the dual variables returned by `init_duals`.

     For both the 'cnn' and 'mlp' toy networks, `init_duals` must return
     three entries: an input `DualVar`, a final-layer `DualVarFin`, and a
     kappa device array.
     """
     for net in ('cnn', 'mlp'):
         toy_instance = test_utils.make_toy_verif_instance(seed=0,
                                                           target_label=1,
                                                           label=2,
                                                           nn=net)
         prng_key = jax.random.PRNGKey(0)
         sdp_instance = problem.make_sdp_verif_instance(toy_instance)
         dual_vars = sdp_verify.init_duals(sdp_instance, prng_key)
         assert len(dual_vars) == 3, 'Input, one hidden layer, kappa'
         input_dual, hidden_dual, kappa = dual_vars
         assert isinstance(input_dual, problem.DualVar)
         assert isinstance(hidden_dual, problem.DualVarFin)
         # NOTE(review): `jax.interpreters.xla.DeviceArray` no longer exists in
         # newer JAX releases (superseded by `jax.Array`); update on a JAX bump.
         assert isinstance(kappa, jax.interpreters.xla.DeviceArray)
Example #7
0
 def test_dual_sdp_no_crash(self):
     """Smoke-tests `solve_sdp_dual` on the 'cnn' and 'mlp' toy networks.

     Uses a tiny step budget and few Lanczos iterations. For each network
     type it only checks that the solver returns a Python float dual value.
     """
     for nn in ['cnn', 'mlp']:
         verif_instance = test_utils.make_toy_verif_instance(seed=0,
                                                             target_label=1,
                                                             label=2,
                                                             nn=nn)
         key = jax.random.PRNGKey(0)
         dual_val, _ = sdp_verify.solve_sdp_dual(
             problem.make_sdp_verif_instance(verif_instance),
             key,
             num_steps=10,
             n_iter_lanczos=5)
         # Check inside the loop so every network type is validated, not
         # only the last one iterated.
         assert isinstance(dual_val, float)
Example #8
0
 def _test_tight_duality_gap(self, seed, loss_margin=0.003, num_steps=3000):
     """Asserts the SDP dual bound is close to, and not below, the SDP primal.

     The primal value comes from `cvxpy_verify.solve_sdp_mlp_elided`, and the
     dual upper bound from `sdp_verify.solve_sdp_dual`.

     Args:
       seed: seed for the toy verification instance.
       loss_margin: maximum allowed value of `dual_ub - primal_opt`.
       num_steps: number of steps passed to `solve_sdp_dual`.
     """
     toy_instance = test_utils.make_toy_verif_instance(seed,
                                                       label=1,
                                                       target_label=2)
     prng_key = jax.random.PRNGKey(0)
     primal_opt, _ = cvxpy_verify.solve_sdp_mlp_elided(toy_instance)
     sdp_instance = problem.make_sdp_verif_instance(toy_instance)
     dual_ub, _ = sdp_verify.solve_sdp_dual(sdp_instance,
                                            prng_key,
                                            num_steps=num_steps,
                                            verbose=False)
     gap = dual_ub - primal_opt
     assert gap < loss_margin, (
         'Primal and dual vals should be close. '
         f'Seed: {seed}. Primal: {primal_opt}, Dual: {dual_ub}')
     # The dual must not fall below the primal beyond a small numerical slack.
     assert dual_ub > primal_opt - 1e-3, 'crossing bounds'
Example #9
0
 def test_ibp_init_matches_ibp_bound(self):
     """Checks the IBP dual initialisation reproduces the IBP bound.

     The test covers 20 seeds of each toy network type. For each one,
     zero-filled duals shaped by `dual_shapes` are passed through
     `sdp_verify.init_duals_ibp`. The exact dual objective at that point
     must then equal `utils.ibp_bound_elided` on the same instance.
     """
     # The key never changes, so create it once rather than per iteration.
     key = jax.random.PRNGKey(0)

     def zeros_or_none(shape):
         return None if shape is None else jnp.zeros(shape)

     for nn in ['cnn', 'mlp']:
         for seed in range(20):
             orig_verif_instance = test_utils.make_toy_verif_instance(seed,
                                                                      nn=nn)
             sdp_instance = problem.make_sdp_verif_instance(
                 orig_verif_instance)
             zero_duals = jax.tree_map(zeros_or_none,
                                       sdp_instance.dual_shapes)
             ibp_duals = sdp_verify.init_duals_ibp(sdp_instance, zero_duals)
             dual_loss = sdp_verify.dual_fun(sdp_instance,
                                             ibp_duals,
                                             key,
                                             exact=True)
             ibp_bound = utils.ibp_bound_elided(orig_verif_instance)
             assert abs(dual_loss - ibp_bound) < 1e-4, (
                 f'Loss at initialization should match IBP: {dual_loss} {ibp_bound}'
             )
Example #10
0
def verify_cnn_single_dual(verif_instance):
    """Run verification for a CNN on a single MNIST/CIFAR problem.

    Converts `verif_instance` to an SDP verification instance and solves the
    SDP dual. Solver settings are read from command-line FLAGS.

    Args:
      verif_instance: verification instance accepted by
        `problem.make_sdp_verif_instance`.

    Returns:
      A `(obj_value, info)` tuple with two elements:
        - `obj_value`: the solver objective converted to a Python float.
        - `info`: the solver's info dict, with its 'final_dual_vars' entry
          converted to NumPy arrays.
    """
    sdp_instance = problem.make_sdp_verif_instance(verif_instance)
    solver_params = dict(
        use_exact_eig_train=FLAGS.use_exact_eig_train,
        use_exact_eig_eval=FLAGS.use_exact_eig_eval,
        n_iter_lanczos=FLAGS.n_iter_lanczos,
        eval_every=FLAGS.eval_every,
        opt_name=FLAGS.opt_name,
        anneal_factor=FLAGS.anneal_factor,
        lr_init=FLAGS.lr_init,
        kappa_zero_after=FLAGS.kappa_zero_after,
        kappa_reg_weight=FLAGS.kappa_reg_weight,
    )

    # Annealing schedule: FLAGS.anneal_lengths is a comma-separated list of
    # per-phase step counts, and their sum is the total step budget. The last
    # phase length is replaced by 1e9, presumably so that phase simply runs
    # until the budget is exhausted.
    phase_lengths = [int(length) for length in FLAGS.anneal_lengths.split(',')]
    total_steps = sum(phase_lengths)
    solver_params['steps_per_anneal'] = phase_lengths[:-1] + [int(1e9)]

    # Learning-rate multipliers: kappa is the last dual variable and is
    # expected to have shape (1, kappa_dim).
    kappa_index = len(sdp_instance.dual_shapes) - 1
    kappa_shape = sdp_instance.dual_shapes[kappa_index]
    assert len(kappa_shape) == 2 and kappa_shape[0] == 1
    multiplier_fn = functools.partial(_opt_multiplier_fn,
                                      kappa_index=kappa_index,
                                      kappa_dim=kappa_shape[1])

    obj_value, info = sdp_verify.solve_sdp_dual(
        sdp_instance,
        num_steps=total_steps,
        verbose=True,
        opt_multiplier_fn=multiplier_fn,
        **solver_params)
    info['final_dual_vars'] = jax.tree_map(np.array, info['final_dual_vars'])
    return float(obj_value), info