def _test_boundprop(self, boundprop_method, num_idxs_to_test=10):
        """Test `boundprop_method` on Wong-Small MNIST CNN.

        The pre-activation bounds produced by `boundprop_method` for the first
        test image are compared, at randomly sampled activation indices,
        against the activation values reached at the inputs returned by
        `utils.pgd`.

        Args:
          boundprop_method: Callable invoked as
            `boundprop_method(params, batch, eps, input_bounds=(0., 1.))`. The
            returned sequence must hold objects exposing `lb`, `lb_pre` and
            `ub_pre`.
          num_idxs_to_test: Number of activation indices to sample and check.
        """
        with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
            test_images = np.load(f)
        checkpoint = 'models/mnist_wongsmall_eps_10_adv.pkl'
        with jax_verify.open_file(checkpoint, 'rb') as f:
            params = pickle.load(f)
        image = test_images[0]
        eps = 0.1

        bounds = boundprop_method(
            params, np.expand_dims(image, axis=0), eps, input_bounds=(0., 1.))
        # The first entry of `bounds` is skipped when collecting the
        # pre-activation bounds.
        later_bounds = bounds[1:]
        crown_lbs = utils.flatten([b.lb_pre for b in later_bounds])
        crown_ubs = utils.flatten([b.ub_pre for b in later_bounds])

        # Seed the global numpy RNG so the sampled indices are reproducible.
        np.random.seed(0)
        test_idxs = np.random.randint(
            crown_lbs.shape[0], size=num_idxs_to_test)

        @jax.jit
        def fwd(x):
            batched = jnp.expand_dims(x, 0)
            _, acts = utils.predict_cnn(
                params, batched, include_preactivations=True)
            return acts

        def get_act(x, idx):
            # Single entry of the flattened activations for input `x`.
            return utils.flatten(fwd(x), backend=jnp)[idx]

        print('Number of activations:', crown_lbs.shape[0])
        print('Bound shape', [b.lb.shape for b in bounds])
        nominal_acts = fwd(image)
        print('Activation shape', [a.shape for a in nominal_acts])
        flat_shape = utils.flatten(nominal_acts).shape
        assert flat_shape == crown_lbs.shape, (
            f'bad shape {crown_lbs.shape}, {flat_shape}')

        margin = 1e-5
        for idx in test_idxs:
            nominal = get_act(image, idx)
            bound_lb, bound_ub = crown_lbs[idx], crown_ubs[idx]

            def act_loss(x):
                return get_act(x, idx)  # pylint: disable=cell-var-from-loop

            def neg_act_loss(x):
                return -get_act(x, idx)  # pylint: disable=cell-var-from-loop

            # NOTE(review): presumably `utils.pgd` minimises the given loss, so
            # the two calls look for a low and a high value of the activation
            # respectively -- confirm against utils.pgd.
            attack_lb = get_act(
                utils.pgd(act_loss, image, eps, 50, 0.01), idx)
            attack_ub = get_act(
                utils.pgd(neg_act_loss, image, eps, 50, 0.01), idx)

            print(f'Idx {idx}: Boundprop LB {bound_lb}, FGSM LB {attack_lb}, '
                  f'Nominal {nominal}, FGSM UB {attack_ub}, '
                  f'Boundprop UB {bound_ub}')
            assert bound_lb <= attack_lb + margin, (
                f'Bad lower bound. Idx {idx}.')
            assert bound_ub >= attack_ub - margin, (
                f'Bad upper bound. Idx {idx}.')

            # Repeat the comparison after clamping every value at zero.
            clamped_bound_lb = max(bound_lb, 0)
            clamped_attack_lb = max(attack_lb, 0)
            clamped_bound_ub = max(bound_ub, 0)
            clamped_attack_ub = max(attack_ub, 0)
            assert clamped_bound_lb <= clamped_attack_lb + margin, (
                f'Idx {idx}.')
            assert clamped_bound_ub >= clamped_attack_ub - margin, (
                f'Idx {idx}.')
示例#2
0
def _load_dataset(dataset):
    """Loads the 10000 MNIST (CIFAR) test set examples, saved as numpy arrays."""
    assert dataset in ('mnist', 'cifar10'), 'invalid dataset name'
    with jax_verify.open_file(os.path.join(dataset, 'x_test.npy'), 'rb') as f:
        xs = np.load(f)
    with jax_verify.open_file(os.path.join(dataset, 'y_test.npy'), 'rb') as f:
        ys = np.load(f)
    return xs, ys
示例#3
0
 def test_mnist_mlp(self, model_name, num_examples, expected_correct):
     """Check the number of correct MLP predictions on MNIST examples.

     Args:
       model_name: Path of the pickled parameters, opened through
         `jax_verify.open_file`.
       num_examples: Number of leading examples to evaluate.
       expected_correct: Exact number of correct predictions required.
     """
     with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
         images = np.load(f)
     with jax_verify.open_file('mnist/y_test.npy', 'rb') as f:
         labels = np.load(f)
     with jax_verify.open_file(model_name, 'rb') as f:
         params = pickle.load(f)  # pytype: disable=wrong-arg-types  # due to GFile

     # Restrict both arrays to the evaluated prefix once, up front.
     images = images[:num_examples]
     labels = labels[:num_examples]

     logits = np.array(_predict_mlp(params, images))
     predicted = np.argmax(logits, axis=1)
     num_correct = np.sum(np.equal(labels, predicted))
     print(num_correct)
     assert num_correct == expected_correct, f'Number correct: {num_correct}'
示例#4
0
def load_model(model_name):
    """Load model parameters and prediction function.

    Args:
      model_name: One of 'mlp', 'toy' or 'cnn'. 'mlp' and 'cnn' read pickled
        parameters from `MLP_PATH` / `CNN_PATH`; 'toy' uses the MLP prediction
        function with freshly sampled random parameters (784 -> 2 -> 10) and
        reads no file.

    Returns:
      A tuple `(model_fn, params)` where `model_fn(params, inputs)` is the
      prediction function.

    Raises:
      ValueError: If `model_name` is not one of the supported names.
    """
    # Choose appropriate prediction function
    if model_name in ('mlp', 'toy'):

        def model_fn(params, inputs):
            # Collapse every non-leading axis before calling the MLP.
            inputs = np.reshape(inputs, (inputs.shape[0], -1))
            return utils.predict_mlp(params, inputs)
    elif model_name == 'cnn':
        model_fn = utils.predict_cnn
    else:
        raise ValueError(
            f'Unknown model name {model_name!r}; '
            "expected one of 'mlp', 'toy', 'cnn'.")

    # Get parameters
    if model_name == 'toy':
        # Random two-layer network; no checkpoint path is needed here.
        params = [
            (np.random.normal(size=(784, 2)), np.random.normal(size=(2, ))),
            (np.random.normal(size=(2, 10)), np.random.normal(size=(10, ))),
        ]
        return model_fn, params

    model_path = MLP_PATH if model_name == 'mlp' else CNN_PATH
    # NOTE: pickle.load executes arbitrary code from the file; only point the
    # path constants at trusted checkpoints.
    with jax_verify.open_file(model_path, 'rb') as f:
        params = pickle.load(f)
    return model_fn, params
示例#5
0
def main(unused_args):
    """Log LP-relaxation and bound-propagation bounds on the first output.

    Loads the model selected by `FLAGS.model`, builds an interval of radius
    0.1 (clipped to [0, 1]) around the first two test inputs, propagates
    bounds with the method selected by `FLAGS.boundprop_method`, and solves
    the planet relaxation for the lower and upper bound of output 0 of the
    example at index 0.

    Args:
      unused_args: Ignored.

    Raises:
      NotImplementedError: If `FLAGS.boundprop_method` is neither 'fastlin'
        nor 'ibp'.
    """

    # Load the parameters of an existing model.
    model_pred, params = load_model(FLAGS.model)
    logits_fn = functools.partial(model_pred, params)

    # Load some test samples
    with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
        inputs = np.load(f)

    # Compute boundprop bounds
    eps = 0.1
    lower_bound = jnp.minimum(jnp.maximum(inputs[:2, ...] - eps, 0.0), 1.0)
    upper_bound = jnp.minimum(jnp.maximum(inputs[:2, ...] + eps, 0.0), 1.0)
    init_bound = jax_verify.IntervalBound(lower_bound, upper_bound)

    if FLAGS.boundprop_method == 'fastlin':
        final_bound = jax_verify.fastlin_bound_propagation(
            logits_fn, init_bound)
        boundprop_transform = jax_verify.fastlin_transform
    elif FLAGS.boundprop_method == 'ibp':
        final_bound = jax_verify.interval_bound_propagation(
            logits_fn, init_bound)
        boundprop_transform = jax_verify.ibp_transform
    else:
        # The trailing space matters: adjacent literals are concatenated
        # verbatim, so without it the message reads "arecurrently".
        raise NotImplementedError('Only ibp/fastlin boundprop are '
                                  'currently supported')

    # Only used to obtain the shape of a single example's output.
    dummy_output = model_pred(params, inputs)

    # Run LP solver
    # One-hot objective selecting the first output coordinate.
    objective = jnp.where(
        jnp.arange(dummy_output[0, ...].size) == 0,
        jnp.ones_like(dummy_output[0, ...]),
        jnp.zeros_like(dummy_output[0, ...]))
    objective_bias = 0.
    value, status = jax_verify.solve_planet_relaxation(logits_fn,
                                                       init_bound,
                                                       boundprop_transform,
                                                       objective,
                                                       objective_bias,
                                                       index=0)
    logging.info('Relaxation LB is : %f, Status is %s', value, status)
    # Solving with the negated objective and negating the result yields the
    # upper bound on the same output.
    value, status = jax_verify.solve_planet_relaxation(logits_fn,
                                                       init_bound,
                                                       boundprop_transform,
                                                       -objective,
                                                       objective_bias,
                                                       index=0)
    logging.info('Relaxation UB is : %f, Status is %s', -value, status)

    logging.info('Boundprop LB is : %f', final_bound.lower[0, 0])
    logging.info('Boundprop UB is : %f', final_bound.upper[0, 0])
示例#6
0
def load_model(model_name):
    """Load model parameters and prediction function.

    Args:
      model_name: Either 'mlp' (parameters read from `MLP_PATH`) or 'cnn'
        (parameters read from `CNN_PATH`).

    Returns:
      A tuple `(model_fn, params)` where `model_fn(params, inputs)` is the
      prediction function and `params` is the unpickled parameter object.

    Raises:
      ValueError: If `model_name` is neither 'mlp' nor 'cnn'.
    """
    # Choose appropriate prediction function
    if model_name == 'mlp':
        model_path = MLP_PATH

        def model_fn(params, inputs):
            # Collapse every non-leading axis before calling the MLP.
            inputs = np.reshape(inputs, (inputs.shape[0], -1))
            return utils.predict_mlp(params, inputs)
    elif model_name == 'cnn':
        model_path = CNN_PATH
        model_fn = utils.predict_cnn
    else:
        raise ValueError(
            f"Unknown model name {model_name!r}; expected 'mlp' or 'cnn'.")

    # Load parameters from file
    # NOTE: pickle.load executes arbitrary code from the file; only point the
    # path constants at trusted checkpoints.
    with jax_verify.open_file(model_path, 'rb') as f:
        params = pickle.load(f)
    return model_fn, params
示例#7
0
def main(unused_args):
    """Propagate input bounds through a model and log the output bounds.

    Args:
      unused_args: Ignored.
    """
    # Test samples are read first, then the model named by `FLAGS.model`.
    with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
        inputs = np.load(f)
    model_pred, params = load_model(FLAGS.model)

    # Predictions on the unperturbed images, used below to check the bounds.
    clean_preds = model_pred(params, inputs)

    # Initial bound: a box of radius `eps` around each input, clipped to
    # the [0, 1] range.
    eps = 0.1
    perturbed_low = jnp.maximum(inputs - eps, 0.0)
    box_lower = jnp.minimum(perturbed_low, 1.0)
    perturbed_high = jnp.maximum(inputs + eps, 0.0)
    box_upper = jnp.minimum(perturbed_high, 1.0)
    initial_bound = jax_verify.IntervalBound(box_lower, box_upper)

    # `model_pred` expects `(params, inputs)`; bind `params` so that the
    # resulting function only takes `inputs`.
    logits_fn = functools.partial(model_pred, params)

    # All boundprop methods take as input the model function and the initial
    # bounds, and return final bounds with the same structure as the output
    # of the function. Internally, these methods work by replacing each
    # operation with its boundprop equivalent - see bound_propagation.py for
    # details. Interval bound propagation is used when no method is named.
    if FLAGS.boundprop_method:
        boundprop_method = getattr(jax_verify, FLAGS.boundprop_method)
    else:
        boundprop_method = jax_verify.interval_bound_propagation
    assert boundprop_method in ALL_BOUNDPROP_METHODS, 'unsupported method'
    final_bound = boundprop_method(logits_fn, initial_bound)

    logging.info('Lower bound: %s', final_bound.lower)
    logging.info('Upper bound: %s', final_bound.upper)
    logging.info('Clean predictions: %s', clean_preds)

    # The clean predictions must lie inside the computed bounds.
    lower_ok = jnp.all(final_bound.lower <= clean_preds)
    assert lower_ok, 'Invalid lower bounds'
    upper_ok = jnp.all(final_bound.upper >= clean_preds)
    assert upper_ok, 'Invalid upper bounds'
示例#8
0
def _load_weights(path):
    """Return the object unpickled from the file at `path`.

    Args:
      path: Location passed to `jax_verify.open_file`, opened in binary mode.

    Returns:
      Whatever object the pickle stream contains.
    """
    with jax_verify.open_file(path, 'rb') as weights_file:
        return pickle.load(weights_file)