Example #1
  def test_fc_fastlin(self):
    """Fastlin propagation through a single fully-connected layer.

    With all-ones weights and a bias of 2, the linear relaxation is exact:
    both the lower and upper linear functions have coefficients 1 and
    offset 2, and the concrete output bounds are 0+1+2+2=5 and 2+3+4+2=11.
    """

    @hk.without_apply_rng
    @hk.transform
    def linear_model(x):
      return hk.Linear(1)(x)

    center = jnp.array([[1., 2., 3.]])
    params = {
        'linear': {
            'w': jnp.ones((3, 1), dtype=jnp.float32),
            'b': jnp.array([2.]),
        },
    }
    input_bounds = jax_verify.IntervalBound(center - 1., center + 1.)
    fun = functools.partial(linear_model.apply, params)
    output_bounds = jax_verify.fastlin_bound_propagation(fun, input_bounds)

    # Both linear bounds coincide with the (exact) affine layer itself.
    for lin_bound in (output_bounds.lower_lin, output_bounds.upper_lin):
      self.assertTrue(jnp.all(lin_bound.lin_coeffs == 1.))
      self.assertTrue(jnp.all(lin_bound.offset == 2.))

    # The reference bounds are the original input interval.
    self.assertArrayAlmostEqual(center - 1., output_bounds.reference.lower)
    self.assertArrayAlmostEqual(center + 1., output_bounds.reference.upper)

    self.assertAlmostEqual(5., output_bounds.lower)
    self.assertAlmostEqual(11., output_bounds.upper)
Example #2
  def test_relu_fastlin(self):
    """Fastlin propagation through an elementwise ReLU.

    Input intervals are [-3, -1] (fully negative) and [2, 4] (fully
    positive), so the ReLU output bounds are [0, 0] and [2, 4].
    """
    relu_model = lambda x: jax.nn.relu(x)
    center = jnp.array([[-2., 3.]])

    input_bounds = jax_verify.IntervalBound(center - 1., center + 1.)
    output_bounds = jax_verify.fastlin_bound_propagation(
        relu_model, input_bounds)

    self.assertArrayAlmostEqual(jnp.array([[0., 2.]]), output_bounds.lower)
    self.assertArrayAlmostEqual(jnp.array([[0., 4.]]), output_bounds.upper)
Example #3
def main(unused_args):
    """Compares bound-propagation bounds with an LP (Planet) relaxation.

    Loads a pretrained MNIST model, builds an L-infinity input interval of
    radius ``eps`` (clipped to the valid pixel range [0, 1]) around the
    first two test images, propagates bounds with the method selected by
    ``FLAGS.boundprop_method``, then solves the Planet LP relaxation for
    the first logit and logs both sets of bounds for comparison.

    Args:
        unused_args: Unused positional args forwarded by the flags runner.

    Raises:
        NotImplementedError: If ``FLAGS.boundprop_method`` is neither
            'ibp' nor 'fastlin'.
    """
    # Load the parameters of an existing model.
    model_pred, params = load_model(FLAGS.model)
    logits_fn = functools.partial(model_pred, params)

    # Load some test samples.
    with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
        inputs = np.load(f)

    # Compute boundprop bounds: an eps-ball around the inputs, clipped to
    # the valid pixel range [0, 1].
    eps = 0.1
    lower_bound = jnp.clip(inputs[:2, ...] - eps, 0.0, 1.0)
    upper_bound = jnp.clip(inputs[:2, ...] + eps, 0.0, 1.0)
    init_bound = jax_verify.IntervalBound(lower_bound, upper_bound)

    if FLAGS.boundprop_method == 'fastlin':
        final_bound = jax_verify.fastlin_bound_propagation(
            logits_fn, init_bound)
        boundprop_transform = jax_verify.fastlin_transform
    elif FLAGS.boundprop_method == 'ibp':
        final_bound = jax_verify.interval_bound_propagation(
            logits_fn, init_bound)
        boundprop_transform = jax_verify.ibp_transform
    else:
        # BUG FIX: the original implicit string concatenation was missing a
        # space, producing "arecurrently supported".
        raise NotImplementedError('Only ibp/fastlin boundprop are '
                                  'currently supported')

    # Forward pass only to discover the output shape/size.
    dummy_output = model_pred(params, inputs)

    # Run LP solver: the objective selects the first logit.
    objective = jnp.where(
        jnp.arange(dummy_output[0, ...].size) == 0,
        jnp.ones_like(dummy_output[0, ...]),
        jnp.zeros_like(dummy_output[0, ...]))
    objective_bias = 0.

    # Minimizing +objective gives a lower bound on the logit; minimizing
    # -objective (and negating) gives an upper bound.
    value, status = jax_verify.solve_planet_relaxation(logits_fn,
                                                       init_bound,
                                                       boundprop_transform,
                                                       objective,
                                                       objective_bias,
                                                       index=0)
    logging.info('Relaxation LB is : %f, Status is %s', value, status)
    value, status = jax_verify.solve_planet_relaxation(logits_fn,
                                                       init_bound,
                                                       boundprop_transform,
                                                       -objective,
                                                       objective_bias,
                                                       index=0)
    logging.info('Relaxation UB is : %f, Status is %s', -value, status)

    logging.info('Boundprop LB is : %f', final_bound.lower[0, 0])
    logging.info('Boundprop UB is : %f', final_bound.upper[0, 0])
Example #4
    def test_fastlin(self, model_cls):
        """Smoke test: fastlin propagation runs on a stateful haiku model.

        Initializes `model_cls` on MNIST-shaped zeros, wraps its apply
        function (evaluation mode, state held fixed), and checks that
        fastlin bound propagation completes without error.
        """
        @hk.transform_with_state
        def model_pred(inputs, is_training, test_local_stats=False):
            net = model_cls()
            return net(inputs, is_training, test_local_stats)

        dummy_inputs = jnp.zeros((4, 28, 28, 1), dtype=jnp.float32)
        params, state = model_pred.init(
            jax.random.PRNGKey(42), dummy_inputs, is_training=True)

        def logits_fun(inputs):
            # Evaluation mode (is_training=False), rng=None; keep only the
            # model output and drop the returned state.
            outputs, _ = model_pred.apply(params,
                                          state,
                                          None,
                                          inputs,
                                          False,
                                          test_local_stats=False)
            return outputs

        input_bounds = jax_verify.IntervalBound(
            dummy_inputs - 1.0, dummy_inputs + 1.0)
        jax_verify.fastlin_bound_propagation(logits_fun, input_bounds)
Example #5
  def test_conv1d_fastlin(self):
    """Fastlin propagation through a 1-D convolution.

    A VALID conv with a length-2 all-ones kernel and bias 2 over the
    interval [2, 4] x [3, 5] sums both positions plus the bias, giving
    output bounds 2+3+2=7 and 4+5+2=11.
    """

    @hk.without_apply_rng
    @hk.transform
    def conv1d_model(x):
      return hk.Conv1D(output_channels=1, kernel_shape=2,
                       padding='VALID', stride=1, with_bias=True)(x)

    # Shape [batch=1, length=2, channels=1].
    center = jnp.reshape(jnp.array([3., 4.]), [1, 2, 1])

    params = {
        'conv1_d': {
            'w': jnp.ones((2, 1, 1), dtype=jnp.float32),
            'b': jnp.array([2.]),
        },
    }

    fun = functools.partial(conv1d_model.apply, params)
    input_bounds = jax_verify.IntervalBound(center - 1., center + 1.)
    output_bounds = jax_verify.fastlin_bound_propagation(fun, input_bounds)

    self.assertAlmostEqual(7., output_bounds.lower)
    self.assertAlmostEqual(11., output_bounds.upper)