Esempio n. 1
0
    def test_apply_round_robin(self):
        """Checks `apply_round_robin` output shapes and values.

        The value assertions expect the output minus the input to equal the
        ensemble parameters, i.e. they rely on `params_adding_ffn` (defined
        outside this test) adding its parameters to its input.
        """
        x = jnp.ones(10)  # Base input
        bx = jnp.ones((7, 10))  # Batched input

        wrapped_ffn = params_adding_ffn(x)

        rr_ensemble = ensemble.make_ensemble(wrapped_ffn,
                                             ensemble.apply_round_robin,
                                             num_networks=3)
        key = jax.random.PRNGKey(0)
        params = rr_ensemble.init(key)
        self.assertTupleEqual(params.shape, (3, ) + x.shape)

        # One input row per ensemble member.
        y = rr_ensemble.apply(params, jnp.broadcast_to(x, (3, ) + x.shape))
        self.assertTupleEqual(y.shape, (3, ) + x.shape)
        # `y - x` is a floating-point round trip, so compare with the same
        # explicit tolerances the other ensemble tests use instead of the
        # very strict `assert_allclose` defaults (rtol=1e-7, atol=0).
        np.testing.assert_allclose(params, y - x, atol=1E-5, rtol=1E-5)

        # Note: the ensemble dimension must lead, the batch dimension is no longer
        # the leading dimension.
        by = rr_ensemble.apply(
            params,
            jnp.broadcast_to(jnp.expand_dims(bx, axis=0), (3, ) + bx.shape))
        self.assertTupleEqual(by.shape, (3, ) + bx.shape)

        # If num_networks=3, then `round_robin(params, input)[4]` should be equal
        # to `apply(params[1], input[4])`, etc.
        yy = rr_ensemble.apply(params, jnp.broadcast_to(x, (6, ) + x.shape))
        self.assertTupleEqual(yy.shape, (6, ) + x.shape)
        # Six rows cycle through the three members twice, hence the
        # concatenated parameters as the expected difference.
        np.testing.assert_allclose(jnp.concatenate([params, params], axis=0),
                                   yy - jnp.expand_dims(x, axis=0),
                                   atol=1E-5,
                                   rtol=1E-5)
Esempio n. 2
0
    def test_round_robin_random(self):
        """Checks that round-robin application cycles through the members.

        Every row of the batched input is identical, so any difference
        between output rows can only come from which ensemble member was
        applied to that row.
        """
        x = jnp.ones(10)  # Base input
        bx = jnp.ones((9, 10))  # Batched input
        ffn = RandomFFN()
        wrapped_ffn = networks.FeedForwardNetwork(init=functools.partial(
            ffn.init, x=x),
                                                  apply=ffn.apply)
        rr_ensemble = ensemble.make_ensemble(wrapped_ffn,
                                             ensemble.apply_round_robin,
                                             num_networks=3)

        key = jax.random.PRNGKey(0)
        params = rr_ensemble.init(key)
        out = rr_ensemble.apply(params, bx)
        # The output should be the same every 3 rows:
        blocks = jnp.split(out, 3, axis=0)
        np.testing.assert_array_equal(blocks[0], blocks[1])
        np.testing.assert_array_equal(blocks[0], blocks[2])
        # Consecutive rows are expected to come from different members.
        self.assertTrue((out[0] != out[1]).any())

        for i in range(9):
            # `jax.tree_util.tree_map` is the stable spelling; the top-level
            # `jax.tree_map` alias is deprecated and absent from newer JAX
            # releases. `i=i` binds the loop index at definition time.
            member_params = jax.tree_util.tree_map(lambda p, i=i: p[i % 3],
                                                   params)
            np.testing.assert_allclose(out[i],
                                       ffn.apply(member_params, bx[i]),
                                       atol=1E-5,
                                       rtol=1E-5)
Esempio n. 3
0
    def test_mean_random(self):
        """Checks `apply_mean` against an explicit average over the members."""
        x = jnp.ones(10)
        bx = jnp.ones((9, 10))
        ffn = RandomFFN()
        wrapped_ffn = networks.FeedForwardNetwork(init=functools.partial(
            ffn.init, x=x),
                                                  apply=ffn.apply)
        mean_ensemble = ensemble.make_ensemble(wrapped_ffn,
                                               ensemble.apply_mean,
                                               num_networks=3)
        key = jax.random.PRNGKey(0)
        params = mean_ensemble.init(key)
        single_output = mean_ensemble.apply(params, x)
        # NOTE(review): 15 is presumably the output width of `RandomFFN`,
        # which is defined outside this test -- confirm there.
        self.assertEqual(single_output.shape, (15, ))
        batch_output = mean_ensemble.apply(params, bx)
        # Make sure all rows are equal:
        np.testing.assert_allclose(jnp.broadcast_to(batch_output[0],
                                                    batch_output.shape),
                                   batch_output,
                                   atol=1E-5,
                                   rtol=1E-5)

        # Check results explicitly:
        # `jax.tree_util.tree_map` is the stable spelling; the top-level
        # `jax.tree_map` alias is deprecated and absent from newer JAX
        # releases. `i=i` binds the member index at definition time.
        all_members = jnp.concatenate([
            jnp.expand_dims(ffn.apply(
                jax.tree_util.tree_map(lambda p, i=i: p[i], params), bx),
                            axis=0) for i in range(3)
        ])
        batch_means = jnp.mean(all_members, axis=0)
        np.testing.assert_allclose(batch_output,
                                   batch_means,
                                   atol=1E-5,
                                   rtol=1E-5)
Esempio n. 4
0
    def test_apply_mean_multiargs(self):
        """Checks `apply_mean` with extra positional and keyword arguments."""
        base = jnp.ones(10)  # Base input
        member_shape = base.shape

        mean_ensemble = ensemble.make_ensemble(funny_args_ffn(base),
                                               ensemble.apply_mean,
                                               num_networks=3)
        params = mean_ensemble.init(jax.random.PRNGKey(0))
        self.assertTupleEqual(params.shape, (3, ) + member_shape)

        def check_output(output):
            # Both call styles below must satisfy the same expectations.
            self.assertTupleEqual(output.shape, member_shape)
            np.testing.assert_allclose(jnp.mean(params, axis=0),
                                       output - 2 * base,
                                       atol=1E-5,
                                       rtol=1E-5)

        # Extra arguments passed positionally.
        check_output(mean_ensemble.apply(params, base, 2 * base, base))

        # The same extra arguments passed by keyword, in swapped order.
        check_output(mean_ensemble.apply(params, base, bar=base,
                                         foo=2 * base))
Esempio n. 5
0
    def test_apply_round_robin_multiargs(self):
        """Checks `apply_round_robin` with extra positional/keyword arguments."""
        base = jnp.ones(10)  # Base input
        stacked_shape = (3, ) + base.shape

        round_robin = ensemble.make_ensemble(funny_args_ffn(base),
                                             ensemble.apply_round_robin,
                                             num_networks=3)
        params = round_robin.init(jax.random.PRNGKey(0))
        self.assertTupleEqual(params.shape, stacked_shape)

        # One copy of the base input per ensemble member.
        stacked = jnp.broadcast_to(base, stacked_shape)

        def check_output(output):
            # Both call styles below must satisfy the same expectations.
            self.assertTupleEqual(output.shape, stacked_shape)
            np.testing.assert_allclose(
                params,
                output - jnp.broadcast_to(2 * base, stacked_shape),
                atol=1E-5,
                rtol=1E-5)

        # Extra arguments passed positionally.
        check_output(round_robin.apply(params, stacked, 2 * stacked, stacked))

        # The same extra arguments passed by keyword, in swapped order.
        check_output(
            round_robin.apply(params, stacked, bar=stacked, foo=2 * stacked))
Esempio n. 6
0
def make_ensemble_regressor_learner(
    name: str,
    num_networks: int,
    logger_fn: loggers.LoggerFactory,
    counter: counting.Counter,
    rng_key: jnp.ndarray,
    iterator: Iterator[types.Transition],
    base_network: networks_lib.FeedForwardNetwork,
    loss: mbop_losses.TransitionLoss,
    optimizer: optax.GradientTransformation,
    num_sgd_steps_per_step: int,
):
  """Builds a learner that trains an ensemble of `base_network` regressors.

  The base network is wrapped with `ensemble.apply_all` and handed to a
  `bc.BCLearner` together with a loss adapter.

  Args:
    name: Label used as the counter prefix and passed to `logger_fn`.
    num_networks: Size of the ensemble.
    logger_fn: Factory called with `name` and the counter's steps key to
      build the logger. A falsy value disables logging.
    counter: Parent of the counter created for this learner.
    rng_key: Random key forwarded to the learner.
    iterator: Iterator over the transitions used for training.
    base_network: Network replicated to form the ensemble.
    loss: Loss called with the parameter-bound apply function and a batch of
      transitions.
    optimizer: Optax gradient transformation forwarded to the learner.
    num_sgd_steps_per_step: Number of gradient updates per learner step.

  Returns:
    The constructed `bc.BCLearner`.
  """
  mbop_ensemble = ensemble.make_ensemble(base_network, ensemble.apply_all,
                                         num_networks)
  local_counter = counting.Counter(parent=counter, prefix=name)

  # Only query the counter for its steps key when a logger is requested.
  if logger_fn:
    local_logger = logger_fn(name, local_counter.get_steps_key())
  else:
    local_logger = None

  def loss_fn(apply_fn: Callable[..., networks_lib.NetworkOutput],
              params: networks_lib.Params, key: jnp.ndarray,
              transitions: types.Transition) -> jnp.ndarray:
    # The key is accepted but not needed by this loss.
    del key
    bound_apply = functools.partial(apply_fn, params)
    return loss(bound_apply, transitions)

  # This is effectively a regressor learner.
  learner = bc.BCLearner(
      mbop_ensemble,
      rng_key,
      loss_fn,
      optimizer,
      iterator,
      num_sgd_steps_per_step,
      logger=local_logger,
      counter=local_counter)
  return learner
Esempio n. 7
0
    def test_ensemble_init(self):
        """Checks that init stacks per-member params that are not identical."""
        base = jnp.ones(10)  # Base input

        round_robin = ensemble.make_ensemble(params_adding_ffn(base),
                                             ensemble.apply_round_robin,
                                             num_networks=3)
        params = round_robin.init(jax.random.PRNGKey(0))

        expected_shape = (3, ) + base.shape
        self.assertTupleEqual(params.shape, expected_shape)

        # The ensemble dimension is the lead dimension, so indexing it picks
        # out the parameters of individual members.
        first_member = params[0, ...]
        second_member = params[1, ...]
        self.assertFalse((first_member == second_member).all())
Esempio n. 8
0
    def test_apply_all_structured(self):
        """Checks `apply_all` on a nested (list/tuple) input structure."""
        base = jnp.ones(10)
        structured_input = [(3 * base, 2 * base), 5 * base]  # Base input

        all_ensemble = ensemble.make_ensemble(
            struct_params_adding_ffn(structured_input),
            ensemble.apply_all,
            num_networks=3)
        params = all_ensemble.init(jax.random.PRNGKey(0))

        outputs = all_ensemble.apply(params, structured_input)

        # Only the first leaf is checked: its input was `3 * base`, repeated
        # here once per ensemble member.
        stacked_base = jnp.broadcast_to(base, (3, ) + base.shape)
        expected_first_leaf = params[0][0] + 3 * stacked_base
        np.testing.assert_allclose(outputs[0][0], expected_first_leaf)
Esempio n. 9
0
    def test_apply_mean_structured(self):
        """Checks `apply_mean` on a nested (list/tuple) input structure."""
        base = jnp.ones(10)
        structured_input = [(3 * base, 2 * base), 5 * base]  # Base input

        mean_ensemble = ensemble.make_ensemble(
            struct_params_adding_ffn(structured_input),
            ensemble.apply_mean,
            num_networks=3)
        params = mean_ensemble.init(jax.random.PRNGKey(0))

        outputs = mean_ensemble.apply(params, structured_input)

        # Only the first leaf is checked: its input was `3 * base`, and the
        # parameters are averaged over the leading (ensemble) axis.
        expected_first_leaf = jnp.mean(params[0][0], axis=0) + 3 * base
        np.testing.assert_allclose(outputs[0][0],
                                   expected_first_leaf,
                                   atol=1E-5,
                                   rtol=1E-5)
Esempio n. 10
0
    def test_apply_all(self):
        """Checks `apply_all` output shapes and values.

        The value assertion expects the output minus the input to equal the
        ensemble parameters, i.e. it relies on `params_adding_ffn` (defined
        outside this test) adding its parameters to its input.
        """
        x = jnp.ones(10)  # Base input
        bx = jnp.ones((7, 10))  # Batched input

        wrapped_ffn = params_adding_ffn(x)

        all_ensemble = ensemble.make_ensemble(wrapped_ffn,
                                              ensemble.apply_all,
                                              num_networks=3)
        key = jax.random.PRNGKey(0)
        params = all_ensemble.init(key)
        self.assertTupleEqual(params.shape, (3, ) + x.shape)

        y = all_ensemble.apply(params, x)
        self.assertTupleEqual(y.shape, (3, ) + x.shape)
        # `y - x` is a floating-point round trip, so compare with the same
        # explicit tolerances the other ensemble tests use instead of the
        # very strict `assert_allclose` defaults (rtol=1e-7, atol=0).
        np.testing.assert_allclose(params,
                                   y - jnp.broadcast_to(x, (3, ) + x.shape),
                                   atol=1E-5,
                                   rtol=1E-5)

        by = all_ensemble.apply(params, bx)
        # Note: the batch dimension is no longer the leading dimension.
        self.assertTupleEqual(by.shape, (3, ) + bx.shape)
Esempio n. 11
0
    def test_apply_mean(self):
        """Checks `apply_mean` output shapes and values."""
        base = jnp.ones(10)  # Base input
        batch = jnp.ones((7, 10))  # Batched input

        mean_ensemble = ensemble.make_ensemble(params_adding_ffn(base),
                                               ensemble.apply_mean,
                                               num_networks=3)
        params = mean_ensemble.init(jax.random.PRNGKey(0))
        self.assertTupleEqual(params.shape, (3, ) + base.shape)
        # The first two members must not share identical parameters.
        first_member = params[0, ...]
        second_member = params[1, ...]
        self.assertFalse((first_member == second_member).all())

        # Averaging removes the ensemble axis, so the output has the shape of
        # a single input.
        single_output = mean_ensemble.apply(params, base)
        self.assertTupleEqual(single_output.shape, base.shape)
        mean_params = jnp.mean(params, axis=0)
        np.testing.assert_allclose(mean_params,
                                   single_output - base,
                                   atol=1E-5,
                                   rtol=1E-5)

        batched_output = mean_ensemble.apply(params, batch)
        self.assertTupleEqual(batched_output.shape, batch.shape)