Ejemplo n.º 1
0
 def test_cnn_mlp_match_sliding_window(self):
     """SDP dual bound of a sliding-window CNN matches its MLP equivalent.

     For each seed, a toy 'cnn_slide' verification instance is converted into
     an MLP instance via ``test_utils.make_mlp_verif_instance_from_cnn``.
     Interval bounds for the MLP are propagated from the returned lb/ub.
     Both instances are solved with the same PRNG key, step count and exact
     eigenvalue training. The resulting dual upper bounds must agree to
     within 5e-3.
     """
     num_steps = 1000
     for seed in range(1):
         verif_instance = test_utils.make_toy_verif_instance(seed,
                                                             nn='cnn_slide')
         key = jax.random.PRNGKey(0)
         lb, ub, params_mlp = test_utils.make_mlp_verif_instance_from_cnn(
             verif_instance)
         # Recompute interval bounds layer by layer for the MLP parameters,
         # starting from the lb/ub returned alongside them.
         bounds_mlp = sdp_verify.boundprop(
             params_mlp,
             sdp_verify.IntBound(lb=lb, ub=ub, lb_pre=None, ub_pre=None))
         verif_instance_mlp = utils.make_nn_verif_instance(
             params_mlp, bounds_mlp)
         dual_ub_cnn, _ = sdp_verify.solve_sdp_dual(
             utils.make_sdp_verif_instance(verif_instance),
             key,
             num_steps=num_steps,
             verbose=False,
             use_exact_eig_train=True)
         dual_ub_mlp, _ = sdp_verify.solve_sdp_dual(
             utils.make_sdp_verif_instance(verif_instance_mlp),
             key,
             num_steps=num_steps,
             verbose=False,
             use_exact_eig_train=True)
         # Adjacent literals are concatenated, so the trailing space matters.
         assert abs(dual_ub_cnn - dual_ub_mlp) < 5e-3, (
             'Dual upper bound for MLP and CNN (sliding filter) should match. '
             f'Seed is {seed}. Vals are CNN: {dual_ub_cnn} MLP: {dual_ub_mlp}'
         )
Ejemplo n.º 2
0
 def test_cnn_mlp_match_fixed_window(self):
   """SDP dual bound of a single-location CNN matches its MLP equivalent.

   For 'cnn_simple' the input and filter sizes match, so the filter is applied
   at one location. The first conv layer then acts as a dense layer whose
   weight is the filter reshaped to (in_shape, out_shape). Both problems are
   solved with the same key and step count. The dual upper bounds must agree
   to within 1e-2.
   """
   num_steps = 1000
   for seed in range(1):
     verif_instance = test_utils.make_toy_verif_instance(seed, nn='cnn_simple')
     key = jax.random.PRNGKey(0)
     params_cnn = verif_instance.params_full
     conv_w = params_cnn[0]['W']
     in_shape = int(np.prod(np.array(conv_w.shape[:-1])))
     # Input and filter size match -> filter is applied at just one location.
     # Number of layer 1 neurons = no. channels out of conv filter (last dim).
     out_shape = conv_w.shape[-1]
     params_mlp = [(jnp.reshape(conv_w, (in_shape, out_shape)),
                    params_cnn[0]['b']),
                   (params_cnn[1][0], params_cnn[1][1])]
     # MLP input bounds are the box [0, 1]^in_shape (presumably matching the
     # toy CNN instance's input domain -- TODO confirm in test_utils).
     bounds_mlp = sdp_verify.boundprop(
         params_mlp,
         sdp_verify.IntBound(
             lb=np.zeros((1, in_shape)),
             ub=np.ones((1, in_shape)),
             lb_pre=None,
             ub_pre=None))
     verif_instance_mlp = utils.make_nn_verif_instance(params_mlp, bounds_mlp)
     dual_ub_cnn, _ = sdp_verify.solve_sdp_dual(
         utils.make_sdp_verif_instance(verif_instance), key,
         num_steps=num_steps, verbose=False, use_exact_eig_train=True)
     dual_ub_mlp, _ = sdp_verify.solve_sdp_dual(
         utils.make_sdp_verif_instance(verif_instance_mlp), key,
         num_steps=num_steps, verbose=False, use_exact_eig_train=True)
     # Adjacent literals are concatenated, so the trailing space matters.
     assert abs(dual_ub_cnn - dual_ub_mlp) < 1e-2, (
         'Dual upper bound for MLP and CNN (simple CNN) should match. '
         f'Seed is {seed}. Vals are CNN: {dual_ub_cnn} MLP: {dual_ub_mlp}')
Ejemplo n.º 3
0
 def test_sdp_dual_simple_no_crash(self, model_type):
     """Smoke test: solve_sdp_dual_simple runs with and without optional kwargs.

     Args:
       model_type: forwarded as ``nn`` to
         ``test_utils.make_toy_verif_instance`` (presumably supplied by a
         parameterized-test decorator outside this view -- confirm).
     """
     toy_instance = test_utils.make_toy_verif_instance(
         seed=0, target_label=1, label=2, nn=model_type)
     # Every optional keyword argument, passed together in one call.
     all_kwargs = dict(
         key=jax.random.PRNGKey(0),
         opt=optax.adam(1e-3),
         num_steps=10,
         eval_every=5,
         verbose=False,
         use_exact_eig_eval=False,
         use_exact_eig_train=False,
         n_iter_lanczos=5,
         kappa_reg_weight=1e-5,
         kappa_zero_after=8,
         device_type=None,
     )
     sdp_instance = utils.make_sdp_verif_instance(toy_instance)

     full_dual, _ = sdp_verify.solve_sdp_dual_simple(sdp_instance,
                                                     **all_kwargs)
     assert isinstance(full_dual, float)

     # Minimal call: only num_steps is given.
     minimal_dual, _ = sdp_verify.solve_sdp_dual_simple(sdp_instance,
                                                        num_steps=5)
     assert isinstance(minimal_dual, float)
Ejemplo n.º 4
0
 def test_dual_sdp_no_crash(self):
   """Smoke test: solve_sdp_dual runs on both CNN and MLP toy instances.

   The returned dual value is checked to be a Python float for every model
   type, not only for the last one iterated.
   """
   for nn in ['cnn', 'mlp']:
     verif_instance = test_utils.make_toy_verif_instance(
         seed=0, target_label=1, label=2, nn=nn)
     key = jax.random.PRNGKey(0)
     dual_val, _ = sdp_verify.solve_sdp_dual(
         utils.make_sdp_verif_instance(verif_instance), key, num_steps=10,
         n_iter_lanczos=5)
     # Assert inside the loop so the CNN result is checked too.
     assert isinstance(dual_val, float), (
         f'Dual value for nn={nn} should be a float, got {type(dual_val)}')
Ejemplo n.º 5
0
 def test_crossing_bounds(self):
     """SDP dual upper bound does not fall below the primal optimum.

     The primal optimum comes from ``cvxpy_verify.solve_mip_mlp_elided``.
     A fresh random seed is drawn on every run and reported in the failure
     message so a failing instance can be reproduced. ``loss_margin``
     tolerates small numerical error in the dual solve.
     """
     loss_margin = 1e-3
     seed = random.randint(1, 10000)
     verif_instance = test_utils.make_toy_verif_instance(seed)
     key = jax.random.PRNGKey(0)
     primal_opt, _ = cvxpy_verify.solve_mip_mlp_elided(verif_instance)
     dual_ub, _ = sdp_verify.solve_sdp_dual(
         utils.make_sdp_verif_instance(verif_instance), key, num_steps=1000)
     # Adjacent literals are concatenated, so the trailing space matters.
     assert dual_ub > primal_opt - loss_margin, (
         'Dual upper bound should be greater than optimal primal objective. '
         f'Seed is {seed}. Vals are Dual: {dual_ub} Primal: {primal_opt}')
Ejemplo n.º 6
0
 def test_correct_dual_var_types(self):
   """init_duals yields [input DualVar, hidden-layer DualVarFin, kappa array].

   Checked for both the CNN and the MLP toy instances.
   """
   for model in ('cnn', 'mlp'):
     toy_instance = test_utils.make_toy_verif_instance(
         seed=0, target_label=1, label=2, nn=model)
     rng = jax.random.PRNGKey(0)
     duals = sdp_verify.init_duals(
         utils.make_sdp_verif_instance(toy_instance), rng)
     assert len(duals) == 3, 'Input, one hidden layer, kappa'
     input_dual, hidden_dual, kappa = duals
     assert isinstance(input_dual, sdp_verify.DualVar)
     assert isinstance(hidden_dual, sdp_verify.DualVarFin)
     assert isinstance(kappa, jax.interpreters.xla.DeviceArray)
Ejemplo n.º 7
0
 def _test_tight_duality_gap(self, seed, loss_margin=0.003, num_steps=3000):
   """Check the SDP dual bound is close to the CVXPY primal SDP value.

   Args:
     seed: seed passed to ``test_utils.make_toy_verif_instance``.
     loss_margin: largest allowed value of ``dual_ub - primal_opt``.
     num_steps: number of steps for ``sdp_verify.solve_sdp_dual``.
   """
   toy_instance = test_utils.make_toy_verif_instance(
       seed, label=1, target_label=2)
   rng = jax.random.PRNGKey(0)
   primal_opt, _ = cvxpy_verify.solve_sdp_mlp_elided(toy_instance)
   sdp_instance = utils.make_sdp_verif_instance(toy_instance)
   dual_ub, _ = sdp_verify.solve_sdp_dual(
       sdp_instance, rng, num_steps=num_steps, verbose=False)

   duality_gap = dual_ub - primal_opt
   assert duality_gap < loss_margin, (
       'Primal and dual vals should be close. '
       f'Seed: {seed}. Primal: {primal_opt}, Dual: {dual_ub}')
   # The dual is expected to upper-bound the primal, up to a 1e-3 tolerance.
   assert dual_ub > primal_opt - 1e-3, 'crossing bounds'
Ejemplo n.º 8
0
 def test_ibp_init_matches_ibp_bound(self):
   """Exact dual loss at the IBP initialization equals the IBP bound.

   Runs over 20 seeds for each of the CNN and MLP toy instances.
   """
   def _zeros_or_none(shape):
     # Leaves that are None stay None; every other leaf becomes a zero array.
     return None if shape is None else jnp.zeros(shape)

   for nn in ['cnn', 'mlp']:
     for seed in range(20):
       toy_instance = test_utils.make_toy_verif_instance(seed, nn=nn)
       rng = jax.random.PRNGKey(0)
       sdp_instance = utils.make_sdp_verif_instance(toy_instance)
       zero_duals = jax.tree_map(_zeros_or_none, sdp_instance.dual_shapes)
       ibp_duals = sdp_verify.init_duals_ibp(sdp_instance, zero_duals)
       loss_at_init = sdp_verify.dual_fun(
           sdp_instance, ibp_duals, rng, exact=True)
       expected = utils.ibp_bound_elided(toy_instance)
       assert abs(loss_at_init - expected) < 1e-4, (
           f'Loss at initialization should match IBP: {loss_at_init} {expected}')
Ejemplo n.º 9
0
def verify_cnn_single_dual(verif_instance):
    """Run verification for a CNN on a single MNIST/CIFAR problem.

    Converts ``verif_instance`` to an SDP instance and calls
    ``sdp_verify.solve_sdp_dual``. Solver settings are read from the
    module-level ``FLAGS``.

    Schedule: ``FLAGS.anneal_lengths`` is a comma-separated list of step
    counts. Their sum is the total number of steps. The final phase length is
    replaced by ``int(1e9)``, so the last phase never ends before training does.

    Args:
      verif_instance: verification instance accepted by
        ``utils.make_sdp_verif_instance``.

    Returns:
      Tuple ``(objective, info)``. ``objective`` is the solver's objective
      value cast to ``float``. ``info`` is the solver's info dict, with
      ``'final_dual_vars'`` mapped through ``np.array``.
    """
    sdp_instance = utils.make_sdp_verif_instance(verif_instance)

    phase_lengths = list(map(int, FLAGS.anneal_lengths.split(',')))
    total_steps = sum(phase_lengths)

    solver_kwargs = {
        'use_exact_eig_train': FLAGS.use_exact_eig_train,
        'use_exact_eig_eval': FLAGS.use_exact_eig_eval,
        'n_iter_lanczos': FLAGS.n_iter_lanczos,
        'eval_every': FLAGS.eval_every,
        'opt_name': FLAGS.opt_name,
        'anneal_factor': FLAGS.anneal_factor,
        'lr_init': FLAGS.lr_init,
        'kappa_zero_after': FLAGS.kappa_zero_after,
        'kappa_reg_weight': FLAGS.kappa_reg_weight,
        'steps_per_anneal': phase_lengths[:-1] + [int(1e9)],
    }

    # Learning-rate multipliers. The last dual variable is kappa, whose shape
    # is expected to be (1, kappa_dim).
    dual_shapes = sdp_instance.dual_shapes
    last_index = len(dual_shapes) - 1
    kappa_shape = dual_shapes[-1]
    assert len(kappa_shape) == 2 and kappa_shape[0] == 1
    multiplier_fn = functools.partial(_opt_multiplier_fn,
                                      kappa_index=last_index,
                                      kappa_dim=kappa_shape[1])

    objective, info = sdp_verify.solve_sdp_dual(
        sdp_instance,
        num_steps=total_steps,
        verbose=True,
        opt_multiplier_fn=multiplier_fn,
        **solver_kwargs)
    final_duals = jax.tree_map(np.array, info['final_dual_vars'])
    info['final_dual_vars'] = final_duals
    return float(objective), info