def _test_boundprop(self, boundprop_method, num_idxs_to_test=10):
  """Test `boundprop_method` on the Wong-Small MNIST CNN."""
  with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
    xs = np.load(f)
  model_name = 'models/mnist_wongsmall_eps_10_adv.pkl'
  with jax_verify.open_file(model_name, 'rb') as f:
    params = pickle.load(f)

  x = xs[0]
  eps = 0.1

  bounds = boundprop_method(params, np.expand_dims(x, axis=0), eps,
                            input_bounds=(0., 1.))
  crown_lbs = utils.flatten([b.lb_pre for b in bounds[1:]])
  crown_ubs = utils.flatten([b.ub_pre for b in bounds[1:]])

  max_idx = crown_lbs.shape[0]
  np.random.seed(0)
  test_idxs = np.random.randint(max_idx, size=num_idxs_to_test)

  @jax.jit
  def fwd(x):
    _, acts = utils.predict_cnn(params, jnp.expand_dims(x, 0),
                                include_preactivations=True)
    return acts

  get_act = lambda x, idx: utils.flatten(fwd(x), backend=jnp)[idx]

  print('Number of activations:', crown_lbs.shape[0])
  print('Bound shape', [b.lb.shape for b in bounds])
  print('Activation shape', [a.shape for a in fwd(x)])
  assert utils.flatten(fwd(x)).shape == crown_lbs.shape, (
      f'bad shape {crown_lbs.shape}, {utils.flatten(fwd(x)).shape}')

  for idx in test_idxs:
    nom = get_act(x, idx)
    crown_lb = crown_lbs[idx]
    crown_ub = crown_ubs[idx]

    # PGD gives empirical extrema of the pre-activation over the eps-ball;
    # any sound propagated bound must enclose them.
    adv_loss = lambda x: get_act(x, idx)  # pylint: disable=cell-var-from-loop
    x_lb = utils.pgd(adv_loss, x, eps, 50, 0.01)
    fgsm_lb = get_act(x_lb, idx)

    adv_loss = lambda x: -get_act(x, idx)  # pylint: disable=cell-var-from-loop
    x_ub = utils.pgd(adv_loss, x, eps, 50, 0.01)
    fgsm_ub = get_act(x_ub, idx)

    print(f'Idx {idx}: Boundprop LB {crown_lb}, FGSM LB {fgsm_lb}, '
          f'Nominal {nom}, FGSM UB {fgsm_ub}, Boundprop UB {crown_ub}')

    margin = 1e-5
    assert crown_lb <= fgsm_lb + margin, f'Bad lower bound. Idx {idx}.'
    assert crown_ub >= fgsm_ub - margin, f'Bad upper bound. Idx {idx}.'

    # The same ordering must also hold after the ReLU (post-activation).
    crown_lb_post, fgsm_lb_post = max(crown_lb, 0), max(fgsm_lb, 0)
    crown_ub_post, fgsm_ub_post = max(crown_ub, 0), max(fgsm_ub, 0)
    assert crown_lb_post <= fgsm_lb_post + margin, f'Idx {idx}.'
    assert crown_ub_post >= fgsm_ub_post - margin, f'Idx {idx}.'
def _load_dataset(dataset):
  """Loads the 10,000 MNIST (or CIFAR-10) test set examples as numpy arrays."""
  assert dataset in ('mnist', 'cifar10'), 'invalid dataset name'
  with jax_verify.open_file(os.path.join(dataset, 'x_test.npy'), 'rb') as f:
    xs = np.load(f)
  with jax_verify.open_file(os.path.join(dataset, 'y_test.npy'), 'rb') as f:
    ys = np.load(f)
  return xs, ys
def test_mnist_mlp(self, model_name, num_examples, expected_correct):
  """Checks nominal accuracy of an MNIST MLP on the first `num_examples` test images."""
  with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
    mnist_x = np.load(f)
  with jax_verify.open_file('mnist/y_test.npy', 'rb') as f:
    mnist_y = np.load(f)
  with jax_verify.open_file(model_name, 'rb') as f:
    params = pickle.load(f)  # pytype: disable=wrong-arg-types  # due to GFile

  logits = np.array(_predict_mlp(params, mnist_x[:num_examples]))
  pred_labels = np.argmax(logits, axis=1)
  num_correct = np.sum(np.equal(mnist_y[:num_examples], pred_labels))
  print(num_correct)
  assert num_correct == expected_correct, f'Number correct: {num_correct}'
def load_model(model_name):
  """Load model parameters and prediction function."""
  # Choose the appropriate prediction function.
  if model_name in ('mlp', 'toy'):
    model_path = MLP_PATH
    def model_fn(params, inputs):
      inputs = np.reshape(inputs, (inputs.shape[0], -1))
      return utils.predict_mlp(params, inputs)
  elif model_name == 'cnn':
    model_path = CNN_PATH
    model_fn = utils.predict_cnn
  else:
    raise ValueError(f'Unsupported model name: {model_name}')

  # Get the parameters: the 'toy' model is randomly initialised,
  # the others are loaded from file.
  if model_name == 'toy':
    params = [
        (np.random.normal(size=(784, 2)), np.random.normal(size=(2,))),
        (np.random.normal(size=(2, 10)), np.random.normal(size=(10,))),
    ]
  else:
    with jax_verify.open_file(model_path, 'rb') as f:
      params = pickle.load(f)

  return model_fn, params
def main(unused_args):
  # Load the parameters of an existing model.
  model_pred, params = load_model(FLAGS.model)
  logits_fn = functools.partial(model_pred, params)

  # Load some test samples.
  with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
    inputs = np.load(f)

  # Compute boundprop bounds on an eps-ball around the first two inputs,
  # intersected with the valid pixel range [0, 1].
  eps = 0.1
  lower_bound = jnp.minimum(jnp.maximum(inputs[:2, ...] - eps, 0.0), 1.0)
  upper_bound = jnp.minimum(jnp.maximum(inputs[:2, ...] + eps, 0.0), 1.0)
  init_bound = jax_verify.IntervalBound(lower_bound, upper_bound)

  if FLAGS.boundprop_method == 'fastlin':
    final_bound = jax_verify.fastlin_bound_propagation(logits_fn, init_bound)
    boundprop_transform = jax_verify.fastlin_transform
  elif FLAGS.boundprop_method == 'ibp':
    final_bound = jax_verify.interval_bound_propagation(logits_fn, init_bound)
    boundprop_transform = jax_verify.ibp_transform
  else:
    raise NotImplementedError('Only ibp/fastlin boundprop are '
                              'currently supported')

  dummy_output = model_pred(params, inputs)

  # Run the LP solver. The objective is a one-hot vector selecting the first
  # logit, so the LP bounds that single output.
  objective = jnp.where(
      jnp.arange(dummy_output[0, ...].size) == 0,
      jnp.ones_like(dummy_output[0, ...]),
      jnp.zeros_like(dummy_output[0, ...]))
  objective_bias = 0.
  value, status = jax_verify.solve_planet_relaxation(
      logits_fn, init_bound, boundprop_transform, objective,
      objective_bias, index=0)
  logging.info('Relaxation LB is : %f, Status is %s', value, status)
  value, status = jax_verify.solve_planet_relaxation(
      logits_fn, init_bound, boundprop_transform, -objective,
      objective_bias, index=0)
  logging.info('Relaxation UB is : %f, Status is %s', -value, status)

  logging.info('Boundprop LB is : %f', final_bound.lower[0, 0])
  logging.info('Boundprop UB is : %f', final_bound.upper[0, 0])
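# Illustrative sketch, not part of the original example. The `objective`
# passed to `jax_verify.solve_planet_relaxation` defines the linear function
# of the logits being bounded (objective @ logits + objective_bias); above it
# is a one-hot vector picking out logit 0. The helper below shows how the same
# mechanism could encode a logit margin instead. The function name and its
# arguments are hypothetical.
def _margin_objective(dummy_output, target_index, other_index):
  """Builds a linear objective whose value is logits[target] - logits[other]."""
  flat_size = dummy_output[0, ...].size
  one_hot = lambda i: jnp.where(  # pylint: disable=g-long-lambda
      jnp.arange(flat_size) == i,
      jnp.ones_like(dummy_output[0, ...]),
      jnp.zeros_like(dummy_output[0, ...]))
  # A lower bound on this objective certifies that `target_index` stays ahead
  # of `other_index` over the whole input region when the bound is positive.
  return one_hot(target_index) - one_hot(other_index)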
def load_model(model_name):
  """Load model parameters and prediction function."""
  # Choose the appropriate prediction function.
  if model_name == 'mlp':
    model_path = MLP_PATH
    def model_fn(params, inputs):
      inputs = np.reshape(inputs, (inputs.shape[0], -1))
      return utils.predict_mlp(params, inputs)
  elif model_name == 'cnn':
    model_path = CNN_PATH
    model_fn = utils.predict_cnn
  else:
    raise ValueError(f'Unsupported model name: {model_name}')

  # Load parameters from file.
  with jax_verify.open_file(model_path, 'rb') as f:
    params = pickle.load(f)
  return model_fn, params
def main(unused_args):
  # Load some test samples.
  with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
    inputs = np.load(f)

  # Load the parameters of an existing model.
  model_pred, params = load_model(FLAGS.model)

  # Evaluation of the model on unperturbed images.
  clean_preds = model_pred(params, inputs)

  # Define the initial bound: an eps-ball around the inputs, clipped to the
  # valid pixel range [0, 1].
  eps = 0.1
  initial_bound = jax_verify.IntervalBound(
      jnp.minimum(jnp.maximum(inputs - eps, 0.0), 1.0),
      jnp.minimum(jnp.maximum(inputs + eps, 0.0), 1.0))

  # Because our function `model_pred` takes as inputs both the parameters
  # `params` and the `inputs`, we need to wrap it such that it only takes
  # `inputs` as parameters.
  logits_fn = functools.partial(model_pred, params)

  # Apply bound propagation. All boundprop methods take as input the model
  # `function` and the initial bounds, and return final bounds with the same
  # structure as the output of `function`. Internally, these methods work by
  # replacing each operation with its boundprop equivalent - see
  # bound_propagation.py for details.
  boundprop_method = (jax_verify.interval_bound_propagation
                      if not FLAGS.boundprop_method
                      else getattr(jax_verify, FLAGS.boundprop_method))
  assert boundprop_method in ALL_BOUNDPROP_METHODS, 'unsupported method'
  final_bound = boundprop_method(logits_fn, initial_bound)

  logging.info('Lower bound: %s', final_bound.lower)
  logging.info('Upper bound: %s', final_bound.upper)
  logging.info('Clean predictions: %s', clean_preds)

  # Sanity check: the propagated bounds must contain the clean predictions.
  assert jnp.all(final_bound.lower <= clean_preds), 'Invalid lower bounds'
  assert jnp.all(final_bound.upper >= clean_preds), 'Invalid upper bounds'
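# Illustrative sketch, not part of the original example: comparing two of the
# boundprop methods referenced in this repo on the same initial bound and
# reporting how wide the resulting output intervals are. The helper name is
# hypothetical; only functions already used above
# (`jax_verify.interval_bound_propagation`, `jax_verify.fastlin_bound_propagation`)
# are assumed.
def _compare_bound_widths(logits_fn, initial_bound):
  """Logs the mean output-bound width under IBP and fastlin propagation."""
  ibp_bound = jax_verify.interval_bound_propagation(logits_fn, initial_bound)
  fastlin_bound = jax_verify.fastlin_bound_propagation(logits_fn, initial_bound)
  ibp_width = jnp.mean(ibp_bound.upper - ibp_bound.lower)
  fastlin_width = jnp.mean(fastlin_bound.upper - fastlin_bound.lower)
  logging.info('Mean IBP width: %f, mean fastlin width: %f',
               ibp_width, fastlin_width)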
def _load_weights(path):
  """Loads pickled model weights from `path`."""
  with jax_verify.open_file(path, 'rb') as f:
    data = pickle.load(f)
  return data