def test_cnn_mlp_match_sliding_window(self):
  """Dual SDP bounds of a sliding-window CNN and its MLP counterpart agree.

  A toy 'cnn_slide' instance is turned into MLP parameters plus input bounds
  by `test_utils.make_mlp_verif_instance_from_cnn`. That network is presumably
  intended to be functionally equivalent; confirm in test_utils. Interval
  bounds for the MLP are recomputed with `sdp_verify.boundprop`.

  Both instances are then solved with `sdp_verify.solve_sdp_dual` under
  identical settings (same PRNG key, same step count, exact eigen-decomposition
  during training). The two dual upper bounds must agree to within 5e-3.
  """
  num_steps = 1000
  for seed in range(1):
    verif_instance = test_utils.make_toy_verif_instance(seed, nn='cnn_slide')
    key = jax.random.PRNGKey(0)
    lb, ub, params_mlp = test_utils.make_mlp_verif_instance_from_cnn(
        verif_instance)
    bounds_mlp = sdp_verify.boundprop(
        params_mlp,
        sdp_verify.IntBound(lb=lb, ub=ub, lb_pre=None, ub_pre=None))
    verif_instance_mlp = utils.make_nn_verif_instance(params_mlp, bounds_mlp)
    dual_ub_cnn, _ = sdp_verify.solve_sdp_dual(
        utils.make_sdp_verif_instance(verif_instance), key,
        num_steps=num_steps, verbose=False, use_exact_eig_train=True)
    dual_ub_mlp, _ = sdp_verify.solve_sdp_dual(
        utils.make_sdp_verif_instance(verif_instance_mlp), key,
        num_steps=num_steps, verbose=False, use_exact_eig_train=True)
    # Adjacent literals are concatenated with no separator; the trailing space
    # keeps "match." and "Seed" from running together in the failure message.
    assert abs(dual_ub_cnn - dual_ub_mlp) < 5e-3, (
        'Dual upper bound for MLP and CNN (sliding filter) should match. '
        f'Seed is {seed}. Vals are CNN: {dual_ub_cnn} MLP: {dual_ub_mlp}')
def test_cnn_mlp_match_fixed_window(self):
  """Dual SDP bounds of a single-window CNN and its dense MLP rewrite agree.

  For the 'cnn_simple' toy network, the first conv filter covers the whole
  input, so it is applied at exactly one location. It can therefore be
  rewritten as a dense layer by flattening every filter dimension except the
  last (output channels). The second layer's `(W, b)` pair is reused as-is.

  The MLP's interval bounds are recomputed from the input box [0, 1]^in_shape.
  This presumably matches the CNN instance's own input box; confirm in
  test_utils. Both instances are solved with `sdp_verify.solve_sdp_dual` under
  identical settings, and the dual upper bounds must agree to within 1e-2.
  """
  num_steps = 1000
  for seed in range(1):
    verif_instance = test_utils.make_toy_verif_instance(seed, nn='cnn_simple')
    key = jax.random.PRNGKey(0)
    params_cnn = verif_instance.params_full
    conv_w = params_cnn[0]['W']
    # Input and filter size match -> filter is applied at just one location.
    # Dense input width = product of all filter dims except the last.
    in_shape = int(np.prod(conv_w.shape[:-1]))
    # Number of layer 1 neurons = no. channels out of conv filter (last dim).
    out_shape = conv_w.shape[-1]
    params_mlp = [
        (jnp.reshape(conv_w, (in_shape, out_shape)), params_cnn[0]['b']),
        (params_cnn[1][0], params_cnn[1][1]),
    ]
    bounds_mlp = sdp_verify.boundprop(
        params_mlp,
        sdp_verify.IntBound(
            lb=np.zeros((1, in_shape)),
            ub=np.ones((1, in_shape)),
            lb_pre=None,
            ub_pre=None))
    verif_instance_mlp = utils.make_nn_verif_instance(params_mlp, bounds_mlp)
    dual_ub_cnn, _ = sdp_verify.solve_sdp_dual(
        utils.make_sdp_verif_instance(verif_instance), key,
        num_steps=num_steps, verbose=False, use_exact_eig_train=True)
    dual_ub_mlp, _ = sdp_verify.solve_sdp_dual(
        utils.make_sdp_verif_instance(verif_instance_mlp), key,
        num_steps=num_steps, verbose=False, use_exact_eig_train=True)
    # Trailing space separates the two concatenated message fragments.
    assert abs(dual_ub_cnn - dual_ub_mlp) < 1e-2, (
        'Dual upper bound for MLP and CNN (simple CNN) should match. '
        f'Seed is {seed}. Vals are CNN: {dual_ub_cnn} MLP: {dual_ub_mlp}')
def test_sdp_dual_simple_no_crash(self, model_type):
  """Smoke-test `solve_sdp_dual_simple`, first with every kwarg, then minimal.

  Args:
    model_type: toy network kind, forwarded as `nn=` to
      `test_utils.make_toy_verif_instance`. Presumably supplied by a
      parameterized-test decorator outside this view.
  """
  toy_instance = test_utils.make_toy_verif_instance(
      seed=0, target_label=1, label=2, nn=model_type)
  # Every keyword argument the solver accepts, with tiny budgets for speed.
  full_kwargs = dict(
      key=jax.random.PRNGKey(0),
      opt=optax.adam(1e-3),
      num_steps=10,
      eval_every=5,
      verbose=False,
      use_exact_eig_eval=False,
      use_exact_eig_train=False,
      n_iter_lanczos=5,
      kappa_reg_weight=1e-5,
      kappa_zero_after=8,
      device_type=None,
  )
  sdp_instance = utils.make_sdp_verif_instance(toy_instance)
  # First pass: check all kwargs work. Second pass: check defaults suffice.
  for call_kwargs in (full_kwargs, {'num_steps': 5}):
    dual_val, _ = sdp_verify.solve_sdp_dual_simple(sdp_instance, **call_kwargs)
    assert isinstance(dual_val, float)
def test_dual_sdp_no_crash(self):
  """`solve_sdp_dual` returns a float for both the toy CNN and the toy MLP."""
  for model_kind in ('cnn', 'mlp'):
    toy_instance = test_utils.make_toy_verif_instance(
        seed=0, target_label=1, label=2, nn=model_kind)
    prng_key = jax.random.PRNGKey(0)
    sdp_instance = utils.make_sdp_verif_instance(toy_instance)
    # Tiny step / Lanczos budgets: this only checks that the solver runs.
    result, _ = sdp_verify.solve_sdp_dual(
        sdp_instance, prng_key, num_steps=10, n_iter_lanczos=5)
    assert isinstance(result, float)
def test_crossing_bounds(self):
  """The dual SDP upper bound must not fall below the optimal primal value.

  A fresh random toy instance is used on every run. The primal optimum comes
  from `cvxpy_verify.solve_mip_mlp_elided`, presumably an exact mixed-integer
  solve. The dual bound from `sdp_verify.solve_sdp_dual` must exceed it by
  more than `-loss_margin`.
  """
  loss_margin = 1e-3
  # Randomized on purpose; the seed is included in the failure message so a
  # failing case can be reproduced.
  seed = random.randint(1, 10000)
  verif_instance = test_utils.make_toy_verif_instance(seed)
  key = jax.random.PRNGKey(0)
  primal_opt, _ = cvxpy_verify.solve_mip_mlp_elided(verif_instance)
  dual_ub, _ = sdp_verify.solve_sdp_dual(
      utils.make_sdp_verif_instance(verif_instance), key, num_steps=1000)
  # Trailing space separates the two concatenated message fragments.
  assert dual_ub > primal_opt - loss_margin, (
      'Dual upper bound should be greater than optimal primal objective. '
      f'Seed is {seed}. Vals are Dual: {dual_ub} Primal: {primal_opt}')
def test_correct_dual_var_types(self):
  """`init_duals` yields exactly (DualVar, DualVarFin, array) for both toys.

  NOTE(review): `jax.interpreters.xla.DeviceArray` is gone in recent JAX
  releases (superseded by `jax.Array`). Update this check if the JAX version
  pin moves.
  """
  for model_kind in ['cnn', 'mlp']:
    toy_instance = test_utils.make_toy_verif_instance(
        seed=0, target_label=1, label=2, nn=model_kind)
    prng_key = jax.random.PRNGKey(0)
    duals = sdp_verify.init_duals(
        utils.make_sdp_verif_instance(toy_instance), prng_key)
    assert len(duals) == 3, 'Input, one hidden layer, kappa'
    expected_types = (sdp_verify.DualVar, sdp_verify.DualVarFin,
                      jax.interpreters.xla.DeviceArray)
    for dual, expected_type in zip(duals, expected_types):
      assert isinstance(dual, expected_type)
def _test_tight_duality_gap(self, seed, loss_margin=0.003, num_steps=3000):
  """Assert the SDP dual bound lies just above the primal SDP optimum.

  The leading underscore keeps test runners from collecting this directly. It
  is presumably called by seed-parameterized tests elsewhere in the class.

  Args:
    seed: seed for `test_utils.make_toy_verif_instance`.
    loss_margin: maximum allowed `dual_ub - primal_opt`.
    num_steps: optimization steps given to `sdp_verify.solve_sdp_dual`.
  """
  toy_instance = test_utils.make_toy_verif_instance(
      seed, label=1, target_label=2)
  prng_key = jax.random.PRNGKey(0)
  primal_opt, _ = cvxpy_verify.solve_sdp_mlp_elided(toy_instance)
  sdp_instance = utils.make_sdp_verif_instance(toy_instance)
  dual_ub, _ = sdp_verify.solve_sdp_dual(
      sdp_instance, prng_key, num_steps=num_steps, verbose=False)
  gap = dual_ub - primal_opt
  assert gap < loss_margin, (
      'Primal and dual vals should be close. '
      f'Seed: {seed}. Primal: {primal_opt}, Dual: {dual_ub}')
  # The dual bound may undershoot the primal by at most 1e-3.
  assert dual_ub > primal_opt - 1e-3, 'crossing bounds'
def test_ibp_init_matches_ibp_bound(self):
  """IBP-initialized duals give a dual loss equal to the elided IBP bound.

  Checked for both toy architectures and seeds 0-19, with tolerance 1e-4.
  """
  def _zeros_or_none(shape):
    # Zero-filled array for each dual shape; keep None entries as None.
    return None if shape is None else jnp.zeros(shape)

  for model_kind in ['cnn', 'mlp']:
    for seed in range(20):
      base_instance = test_utils.make_toy_verif_instance(seed, nn=model_kind)
      prng_key = jax.random.PRNGKey(0)
      sdp_instance = utils.make_sdp_verif_instance(base_instance)
      zero_duals = jax.tree_map(_zeros_or_none, sdp_instance.dual_shapes)
      ibp_duals = sdp_verify.init_duals_ibp(sdp_instance, zero_duals)
      dual_loss = sdp_verify.dual_fun(
          sdp_instance, ibp_duals, prng_key, exact=True)
      ibp_bound = utils.ibp_bound_elided(base_instance)
      assert abs(dual_loss - ibp_bound) < 1e-4, (
          f'Loss at initialization should match IBP: {dual_loss} {ibp_bound}')
def verify_cnn_single_dual(verif_instance):
  """Run verification for a CNN on a single MNIST/CIFAR problem.

  The instance is converted with `utils.make_sdp_verif_instance`. The solver
  configuration is assembled from command-line FLAGS, and the SDP dual is
  optimized with `sdp_verify.solve_sdp_dual`.

  Args:
    verif_instance: verification instance accepted by
      `utils.make_sdp_verif_instance`.

  Returns:
    A `(value, info)` pair:
      value: the solver's objective cast to a Python float.
      info: the solver's info dict, with `info['final_dual_vars']` converted
        leaf-by-leaf to NumPy arrays via `np.array`.
  """
  sdp_instance = utils.make_sdp_verif_instance(verif_instance)

  # FLAGS.anneal_lengths is a comma-separated list of per-phase step counts.
  # The total optimization budget is their sum.
  phase_lengths = [int(length) for length in FLAGS.anneal_lengths.split(',')]
  total_steps = sum(phase_lengths)

  solver_params = {
      'use_exact_eig_train': FLAGS.use_exact_eig_train,
      'use_exact_eig_eval': FLAGS.use_exact_eig_eval,
      'n_iter_lanczos': FLAGS.n_iter_lanczos,
      'eval_every': FLAGS.eval_every,
      'opt_name': FLAGS.opt_name,
      'anneal_factor': FLAGS.anneal_factor,
      'lr_init': FLAGS.lr_init,
      'kappa_zero_after': FLAGS.kappa_zero_after,
      'kappa_reg_weight': FLAGS.kappa_reg_weight,
      # The last phase length is replaced with 1e9. Presumably this lets
      # `num_steps`, not the schedule, end the final phase; confirm against
      # the solver.
      'steps_per_anneal': phase_lengths[:-1] + [int(1e9)],
  }

  # Learning-rate multipliers key off kappa, the last dual variable, which
  # must have shape (1, kappa_dim).
  dual_shapes = sdp_instance.dual_shapes
  kappa_shape = dual_shapes[-1]
  kappa_index = len(dual_shapes) - 1
  assert len(kappa_shape) == 2 and kappa_shape[0] == 1
  opt_multiplier_fn = functools.partial(
      _opt_multiplier_fn, kappa_index=kappa_index, kappa_dim=kappa_shape[1])

  objective, info = sdp_verify.solve_sdp_dual(
      sdp_instance,
      num_steps=total_steps,
      verbose=True,
      opt_multiplier_fn=opt_multiplier_fn,
      **solver_params)
  info['final_dual_vars'] = jax.tree_map(np.array, info['final_dual_vars'])
  return float(objective), info