Example #1
0
def set_up_toy_problem(rng_key, batch_size, architecture, eps=0.1):
  """Build a small MLP verification problem with box-constrained inputs.

  Args:
    rng_key: JAX PRNG key; split into one key for sampling the inputs and one
      for initialising the network parameters.
    batch_size: Number of input points to sample.
    architecture: Layer sizes passed to `sdp_test_utils.make_mlp_params`;
      `architecture[0]` is used as the input dimension.
    eps: Half-width of the box placed around each sampled input before the box
      is clipped to [0, 1]. Defaults to 0.1 (the previously hard-coded value).

  Returns:
    A pair `(fun, (lb, ub))`. `fun` is `utils.predict_cnn` with the sampled
    parameters bound as its first argument. `lb` and `ub` are the lower and
    upper input bounds, each of shape `(batch_size, architecture[0])`.
  """
  inputs_key, params_key = jax.random.split(rng_key)
  params = sdp_test_utils.make_mlp_params(architecture, params_key)

  # Inputs are uniform in [0, 1); the eps-box around them is clipped back to
  # the unit hypercube so the bounds never leave that domain.
  inputs = jax.random.uniform(inputs_key, (batch_size, architecture[0]))
  lb = jnp.clip(inputs - eps, 0., 1.)
  ub = jnp.clip(inputs + eps, 0., 1.)
  fun = functools.partial(utils.predict_cnn, params)
  return fun, (lb, ub)
Example #2
0
    def test_mlp_extract(self):
        """Check that the weights of a plain MLP can be extracted."""
        params_key, inputs_key = random.split(random.PRNGKey(0))

        layer_sizes = (5, 8, 5)
        # 1-D example input whose length equals the first layer size (5).
        input_sizes = (layer_sizes[0], )
        example_inputs = random.normal(inputs_key, input_sizes)

        mlp = functools.partial(
            utils.predict_mlp,
            test_utils.make_mlp_params(layer_sizes, params_key),
        )
        self.check_fun_extract(mlp, example_inputs)
Example #3
0
    def test_mlp_withpreproc(self):
        """Test extraction of weights from a MLP with input preprocessing.

        The function under test standardises its input elementwise as
        `(inputs - input_mean) / input_std` and then runs the MLP.
        """
        key = random.PRNGKey(0)
        k1, k2, k3, k4 = random.split(key, num=4)

        input_sizes = (5, )
        layer_sizes = (5, 8, 5)
        mlp_params = test_utils.make_mlp_params(layer_sizes, k1)
        example_inputs = random.normal(k2, input_sizes)
        input_mean = random.normal(k3, input_sizes)
        # A standard deviation should be positive. Sampling it from a normal
        # distribution can give negative values or values arbitrarily close to
        # zero, which makes the division below ill-conditioned. Draw it from a
        # range bounded away from zero instead.
        input_std = random.uniform(k4, input_sizes, minval=0.5, maxval=2.0)

        def fun_to_extract(inputs):
            inp = (inputs - input_mean) / input_std
            return utils.predict_mlp(mlp_params, inp)

        self.check_fun_extract(fun_to_extract, example_inputs)
Example #4
0
    def setUp(self) -> None:
        """Create the network parameters, inputs and linear objective for tests.

        Fixtures stored on `self` include the layer sizes, MLP parameters,
        sampled inputs, an objective vector and a scalar objective bias.
        """
        super(LinearTest, self).setUp()

        # The objective vector built below is +1 at `target_label` and -1 at
        # `label`.
        self.target_label = 1
        self.label = 0
        # NOTE(review): presumably the (lower, upper) bounds of the input
        # domain; confirm where this tuple is consumed.
        self.input_bounds = (0.0, 1.0)
        self.layer_sizes = LAYER_SIZES
        # NOTE(review): presumably the perturbation radius around self.inputs;
        # confirm against the tests that read it.
        self.eps = 0.1

        # Fixed seed so the fixtures are deterministic across runs.
        prng_key = jax.random.PRNGKey(13579)

        # Five subkeys; keys[0], keys[1] and keys[2] are consumed below.
        self.keys = jax.random.split(prng_key, 5)
        self.network_params = sdp_test_utils.make_mlp_params(
            self.layer_sizes, self.keys[0])

        self.inputs = create_inputs(self.keys[1])

        # Build e_target - e_label over the output dimension. JAX arrays are
        # immutable, so the functional `.at[...].add` update is used.
        # Presumably this vector is contracted with the network output to form
        # a margin-style objective; confirm against its users.
        objective = jnp.zeros(self.layer_sizes[-1])
        objective = objective.at[self.target_label].add(1)
        objective = objective.at[self.label].add(-1)
        self.objective = objective
        # Scalar (shape []) random bias term for the objective.
        self.objective_bias = jax.random.normal(self.keys[2], [])