Example #1
0
 def test_cnn_mlp_match_sliding_window(self):
     """Check the SDP dual bound matches between a CNN and its MLP conversion.

     A toy 'cnn_slide' verification instance is converted into an MLP instance
     with test_utils.make_mlp_verif_instance_from_cnn. Both instances are then
     passed to sdp_verify.solve_sdp_dual with the same PRNG key and number of
     steps, and the two returned dual values must agree to within 5e-3.
     """
     num_steps = 1000
     for seed in range(1):
         verif_instance = test_utils.make_toy_verif_instance(seed,
                                                             nn='cnn_slide')
         key = jax.random.PRNGKey(0)
         # Input bounds and MLP parameters derived from the CNN instance.
         lb, ub, params_mlp = test_utils.make_mlp_verif_instance_from_cnn(
             verif_instance)
         # Propagate the input interval through the MLP; no pre-activation
         # bounds are supplied for the input layer.
         bounds_mlp = sdp_verify.boundprop(
             params_mlp,
             sdp_verify.IntBound(lb=lb, ub=ub, lb_pre=None, ub_pre=None))
         verif_instance_mlp = utils.make_nn_verif_instance(
             params_mlp, bounds_mlp)
         # Solve both problems with identical solver settings and the same key
         # so that the two results are directly comparable.
         dual_ub_cnn, _ = sdp_verify.solve_sdp_dual(
             utils.make_sdp_verif_instance(verif_instance),
             key,
             num_steps=num_steps,
             verbose=False,
             use_exact_eig_train=True)
         dual_ub_mlp, _ = sdp_verify.solve_sdp_dual(
             utils.make_sdp_verif_instance(verif_instance_mlp),
             key,
             num_steps=num_steps,
             verbose=False,
             use_exact_eig_train=True)
         # Adjacent string literals are concatenated verbatim, so the first
         # literal carries the trailing space that separates the sentences.
         assert abs(dual_ub_cnn - dual_ub_mlp) < 5e-3, (
             'Dual upper bound for MLP and CNN (sliding filter) should match. '
             f'Seed is {seed}. Vals are CNN: {dual_ub_cnn} MLP: {dual_ub_mlp}'
         )
Example #2
0
 def test_cnn_mlp_match_fixed_window(self):
   """Check the SDP dual bound matches between a simple CNN and a reshaped MLP.

   The first-layer conv filter of a toy 'cnn_simple' instance is reshaped into
   a dense (in_shape, out_shape) weight matrix to build an MLP. Both instances
   are passed to sdp_verify.solve_sdp_dual with the same PRNG key and number of
   steps, and the two returned dual values must agree to within 1e-2.
   """
   num_steps = 1000
   for seed in range(1):
     verif_instance = test_utils.make_toy_verif_instance(seed, nn='cnn_simple')
     key = jax.random.PRNGKey(0)
     params_cnn = verif_instance.params_full
     # Flattened input size: product of every filter dimension except the last.
     in_shape = int(np.prod(np.array(params_cnn[0]['W'].shape[:-1])))
     # Input and filter size match -> filter is applied at just one location.
     # Number of layer 1 neurons = no. channels out of conv filter (last dim).
     out_shape = params_cnn[0]['W'].shape[-1]
     # Layer 0 is a dict with 'W'/'b'; layer 1 is indexed positionally and is
     # copied across unchanged as a (weights, bias) pair.
     params_mlp = [(jnp.reshape(params_cnn[0]['W'],
                                (in_shape, out_shape)), params_cnn[0]['b']),
                   (params_cnn[1][0], params_cnn[1][1])]
     # Input interval is [0, 1] for each of the in_shape flattened inputs.
     bounds_mlp = sdp_verify.boundprop(
         params_mlp,
         sdp_verify.IntBound(
             lb=np.zeros((1, in_shape)),
             ub=1 * np.ones((1, in_shape)),
             lb_pre=None,
             ub_pre=None))
     verif_instance_mlp = utils.make_nn_verif_instance(params_mlp, bounds_mlp)
     # Identical solver settings and key for both solves.
     dual_ub_cnn, _ = sdp_verify.solve_sdp_dual(
         utils.make_sdp_verif_instance(verif_instance), key,
         num_steps=num_steps, verbose=False, use_exact_eig_train=True)
     dual_ub_mlp, _ = sdp_verify.solve_sdp_dual(
         utils.make_sdp_verif_instance(verif_instance_mlp), key,
         num_steps=num_steps, verbose=False, use_exact_eig_train=True)
     # Adjacent string literals are concatenated verbatim, so the first
     # literal carries the trailing space that separates the sentences.
     assert abs(dual_ub_cnn - dual_ub_mlp) < 1e-2, (
         'Dual upper bound for MLP and CNN (simple CNN) should match. '
         f'Seed is {seed}. Vals are CNN: {dual_ub_cnn} MLP: {dual_ub_mlp}')
Example #3
0
 def test_dual_sdp_no_crash(self):
   """Smoke test: a short SDP dual solve returns a float for each network type.

   For both the 'cnn' and 'mlp' toy instances, runs sdp_verify.solve_sdp_dual
   for 10 steps with n_iter_lanczos=5 and checks the type of the returned dual
   value.
   """
   for nn in ['cnn', 'mlp']:
     verif_instance = test_utils.make_toy_verif_instance(
         seed=0, target_label=1, label=2, nn=nn)
     key = jax.random.PRNGKey(0)
     dual_val, _ = sdp_verify.solve_sdp_dual(
         utils.make_sdp_verif_instance(verif_instance), key, num_steps=10,
         n_iter_lanczos=5)
     # Check inside the loop so every network type is validated, not only the
     # one solved last.
     assert isinstance(dual_val, float), (
         f'Expected a float dual value for nn={nn!r}, '
         f'got {type(dual_val).__name__}')
Example #4
0
 def test_crossing_bounds(self):
     """Check the SDP dual value does not fall below the MIP primal optimum.

     A toy verification instance is built from a randomly drawn seed. The value
     returned by cvxpy_verify.solve_mip_mlp_elided is compared with the value
     returned by sdp_verify.solve_sdp_dual after 1000 steps; the dual value
     must exceed the primal value minus a small tolerance (loss_margin).
     """
     loss_margin = 1e-3
     # The seed differs between runs; it is included in the failure message so
     # that a failing instance can be reproduced.
     seed = random.randint(1, 10000)
     verif_instance = test_utils.make_toy_verif_instance(seed)
     key = jax.random.PRNGKey(0)
     primal_opt, _ = cvxpy_verify.solve_mip_mlp_elided(verif_instance)
     dual_ub, _ = sdp_verify.solve_sdp_dual(
         utils.make_sdp_verif_instance(verif_instance), key, num_steps=1000)
     # Adjacent string literals are concatenated verbatim, so the first
     # literal carries the trailing space that separates the sentences.
     assert dual_ub > primal_opt - loss_margin, (
         'Dual upper bound should be greater than optimal primal objective. '
         f'Seed is {seed}. Vals are Dual: {dual_ub} Primal: {primal_opt}')
Example #5
0
 def _test_tight_duality_gap(self, seed, loss_margin=0.003, num_steps=3000):
   """Helper asserting the SDP dual value is close to the cvxpy SDP value.

   Args:
     seed: Seed passed to test_utils.make_toy_verif_instance.
     loss_margin: Strict upper limit on (dual value - primal value).
     num_steps: Number of steps passed to sdp_verify.solve_sdp_dual.

   The dual value must also exceed the primal value minus 1e-3, i.e. the two
   bounds must not cross by more than that tolerance.
   """
   toy_instance = test_utils.make_toy_verif_instance(
       seed, label=1, target_label=2)
   rng_key = jax.random.PRNGKey(0)

   # Reference value from the cvxpy-based solver.
   primal_opt, _ = cvxpy_verify.solve_sdp_mlp_elided(toy_instance)

   # Value from the first-order dual solver on the same instance.
   sdp_instance = utils.make_sdp_verif_instance(toy_instance)
   dual_ub, _ = sdp_verify.solve_sdp_dual(
       sdp_instance, rng_key, num_steps=num_steps, verbose=False)

   duality_gap = dual_ub - primal_opt
   assert duality_gap < loss_margin, (
       'Primal and dual vals should be close. '
       f'Seed: {seed}. Primal: {primal_opt}, Dual: {dual_ub}')
   assert dual_ub > primal_opt - 1e-3, 'crossing bounds'