Example #1
  def test_feedforward(self):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)

    def policy(inputs: jnp.ndarray):
      return hk.Sequential([
          hk.Flatten(),
          hk.Linear(env_spec.actions.num_values),
          lambda x: jnp.argmax(x, axis=-1),
      ])(inputs)

    policy = hk.transform(policy)  # hk.transform now always threads an RNG through apply.

    rng = hk.PRNGSequence(1)
    dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    params = policy.init(next(rng), dummy_obs)

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

    actor = actors.FeedForwardActor(
        policy.apply, rng=hk.PRNGSequence(1), variable_client=variable_client)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)
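These snippets appear to come from Acme's actor and variable-client tests, so the surrounding imports and the _make_fake_env helper are not shown. A minimal sketch of that assumed context follows; the module paths, observation shape, and action count are assumptions for illustration only.

import dm_env
import haiku as hk
import jax.numpy as jnp
import numpy as np

from acme import environment_loop
from acme import specs
from acme.agents.jax import actors
from acme.jax import utils
from acme.jax import variable_utils
from acme.testing import fakes


def _make_fake_env() -> dm_env.Environment:
  # Shapes and sizes here are illustrative assumptions, not the real test's.
  env_spec = specs.EnvironmentSpec(
      observations=specs.Array(shape=(10, 5), dtype=np.float32),
      actions=specs.DiscreteArray(num_values=3),
      rewards=specs.Array(shape=(), dtype=np.float32),
      discounts=specs.BoundedArray(
          shape=(), dtype=np.float32, minimum=0., maximum=1.),
  )
  return fakes.Environment(env_spec, episode_length=10)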
Example #2
    def test_update(self):
        # Create two instances of the same model.
        actor_model = snt.nets.MLP([50, 30])
        learner_model = snt.nets.MLP([50, 30])

        # Create variables first.
        input_spec = tf.TensorSpec(shape=(28, ), dtype=tf.float32)
        tf2_utils.create_variables(actor_model, [input_spec])
        tf2_utils.create_variables(learner_model, [input_spec])

        # Register them as client and source variables, respectively.
        actor_variables = actor_model.variables
        np_learner_variables = [
            tf2_utils.to_numpy(v) for v in learner_model.variables
        ]
        variable_source = fakes.VariableSource(np_learner_variables)
        variable_client = tf2_variable_utils.VariableClient(
            variable_source, {'policy': actor_variables})

        # Now, given some random batch of test input:
        x = tf.random.normal(shape=(8, 28))

        # Before copying variables, the models have different outputs.
        actor_output = actor_model(x).numpy()
        learner_output = learner_model(x).numpy()
        self.assertFalse(np.allclose(actor_output, learner_output))

        # Update the variable client.
        variable_client.update_and_wait()

        # After copying variables (by updating the client), the models are the same.
        actor_output = actor_model(x).numpy()
        learner_output = learner_model(x).numpy()
        self.assertTrue(np.allclose(actor_output, learner_output))
Example #3
  def test_update(self):
    init_fn, _ = hk.without_apply_rng(hk.transform(dummy_network))
    params = init_fn(jax.random.PRNGKey(1), jnp.zeros(shape=(1, 32)))
    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source,
                                                    key='policy')
    variable_client.update_and_wait()
    tree.map_structure(np.testing.assert_array_equal,
                       variable_client.params, params)
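dummy_network is defined elsewhere in that test module. Any small Haiku function taking a (1, 32) input works here; a plausible stand-in, with arbitrary layer widths, looks like this:

import haiku as hk
import jax.numpy as jnp


def dummy_network(x: jnp.ndarray) -> jnp.ndarray:
  # Hypothetical stand-in for the real dummy_network; the MLP widths are
  # arbitrary, since the test only checks that parameters round-trip unchanged.
  return hk.nets.MLP([50, 10])(x)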
Example #4
    def test_multiple_keys(self):
        init_fn, _ = hk.without_apply_rng(hk.transform(dummy_network))
        params = init_fn(jax.random.PRNGKey(1), jnp.zeros(shape=(1, 32)))
        steps = jnp.zeros(shape=1)
        variables = {'network': params, 'steps': steps}
        variable_source = fakes.VariableSource(variables,
                                               use_default_key=False)
        variable_client = variable_utils.VariableClient(
            variable_source, key=['network', 'steps'])
        variable_client.update_and_wait()

        tree.map_structure(np.testing.assert_array_equal,
                           variable_client.params[0], params)
        tree.map_structure(np.testing.assert_array_equal,
                           variable_client.params[1], steps)
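Examples #3 and #4 differ only in how the fake source stores its variables: with the default key the whole structure is registered under 'policy', while use_default_key=False lets the test supply its own mapping from names to variables, which the client then fetches per key. A rough sketch of the contract these tests rely on follows; it is an illustration, not Acme's actual fakes.VariableSource implementation.

import threading
from typing import Any, List, Optional, Sequence


class SketchVariableSource:
  """Illustrative variable source matching the behaviour used in the tests."""

  def __init__(self, variables: Any = None,
               barrier: Optional[threading.Barrier] = None,
               use_default_key: bool = True):
    # Default behaviour: register the whole structure under a single 'policy'
    # key; otherwise the caller supplies its own {name: variables} mapping.
    self._variables = {'policy': variables} if use_default_key else variables
    self._barrier = barrier

  def get_variables(self, names: Sequence[str]) -> List[Any]:
    if self._barrier is not None:
      self._barrier.wait()  # Lets Example #9 hold requests until the test is ready.
    return [self._variables[name] for name in names]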
Example #5
    def test_recurrent(self, has_extras):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)
        output_size = env_spec.actions.num_values
        obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
        rng = hk.PRNGSequence(1)

        @_transform_without_rng
        def network(inputs: jnp.ndarray, state: hk.LSTMState):
            return hk.DeepRNN(
                [hk.Reshape([-1], preserve_dims=1),
                 hk.LSTM(output_size)])(inputs, state)

        @_transform_without_rng
        def initial_state(batch_size: Optional[int] = None):
            network = hk.DeepRNN(
                [hk.Reshape([-1], preserve_dims=1),
                 hk.LSTM(output_size)])
            return network.initial_state(batch_size)

        initial_state = initial_state.apply(initial_state.init(next(rng)), 1)
        params = network.init(next(rng), obs, initial_state)

        def policy(
                params: jnp.ndarray, key: jnp.ndarray,
                observation: jnp.ndarray,
                core_state: hk.LSTMState) -> Tuple[jnp.ndarray, hk.LSTMState]:
            del key  # Unused for test-case deterministic policy.
            action_values, core_state = network.apply(params, observation,
                                                      core_state)
            actions = jnp.argmax(action_values, axis=-1)
            if has_extras:
                return (actions, (action_values, )), core_state
            else:
                return actions, core_state

        variable_source = fakes.VariableSource(params)
        variable_client = variable_utils.VariableClient(
            variable_source, 'policy')

        actor = actors.RecurrentActor(policy,
                                      jax.random.PRNGKey(1),
                                      initial_state,
                                      variable_client,
                                      has_extras=has_extras)

        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
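_transform_without_rng is a small helper from that test file; since the recurrent policy here is deterministic, it is presumably just Haiku's transform with the apply-time RNG argument stripped:

import haiku as hk


def _transform_without_rng(f):
  # Presumed definition: transform f and drop the RNG argument from apply(),
  # so network.apply(params, ...) can be called without a key.
  return hk.without_apply_rng(hk.transform(f))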
Example #6
  def test_update_and_wait(self):
    # Create a variable source (emulating the learner).
    np_learner_variables = tf2_utils.to_numpy(self._learner_model.variables)
    variable_source = fakes.VariableSource(np_learner_variables)

    # Create a variable client (emulating the actor).
    variable_client = tf2_variable_utils.VariableClient(
        variable_source, {'policy': self._actor_model.variables})

    # Create some random batch of test input:
    x = tf.random.normal(shape=(_BATCH_SIZE, _INPUT_SIZE))

    # Before copying variables, the models have different outputs.
    self.assertNotAllClose(self._actor_model(x), self._learner_model(x))

    # Update the variable client.
    variable_client.update_and_wait()

    # After copying variables (by updating the client), the models are the same.
    self.assertAllClose(self._actor_model(x), self._learner_model(x))
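Examples #6 and #9 rely on models and constants built in the test fixture's setUp, which is not shown. A minimal sketch of what that fixture might look like follows; the layer widths, batch size, input size, and update period are assumptions.

import sonnet as snt
import tensorflow as tf

from acme.tf import utils as tf2_utils

# Assumed test constants; the real values are not visible in the snippets.
_BATCH_SIZE = 8
_INPUT_SIZE = 28
_UPDATE_PERIOD = 5


class VariableClientTestSketch(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    # Two structurally identical MLPs whose weights start out different.
    self._actor_model = snt.nets.MLP([50, 30])
    self._learner_model = snt.nets.MLP([50, 30])
    input_spec = tf.TensorSpec(shape=(_INPUT_SIZE,), dtype=tf.float32)
    tf2_utils.create_variables(self._actor_model, [input_spec])
    tf2_utils.create_variables(self._learner_model, [input_spec])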
Example #7
    def test_feedforward(self, has_extras):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)

        def policy(inputs: jnp.ndarray):
            action_values = hk.Sequential([
                hk.Flatten(),
                hk.Linear(env_spec.actions.num_values),
            ])(inputs)
            action = jnp.argmax(action_values, axis=-1)
            if has_extras:
                return action, (action_values, )
            else:
                return action

        policy = hk.transform(policy)

        rng = hk.PRNGSequence(1)
        dummy_obs = utils.add_batch_dim(utils.zeros_like(
            env_spec.observations))
        params = policy.init(next(rng), dummy_obs)

        variable_source = fakes.VariableSource(params)
        variable_client = variable_utils.VariableClient(
            variable_source, 'policy')

        if has_extras:
            actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
                policy.apply)
        else:
            actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
                policy.apply)
        actor = actors.GenericActor(actor_core,
                                    random_key=jax.random.PRNGKey(1),
                                    variable_client=variable_client)

        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
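Example #7 is the ActorCore-based variant of Example #1: the transformed policy, whose apply takes the parameters, an RNG key, and a batched observation, is wrapped by batched_feed_forward_to_actor_core, or by the _with_extras variant when the policy also returns extras such as the raw action values, and the resulting core is handed to the generic actor.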
Example #8
  def test_recurrent(self):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)
    output_size = env_spec.actions.num_values
    obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    rng = hk.PRNGSequence(1)

    @hk.without_apply_rng  # apply() is called below without an explicit RNG.
    @hk.transform
    def network(inputs: jnp.ndarray, state: hk.LSTMState):
      return hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])(inputs, state)

    @hk.without_apply_rng
    @hk.transform
    def initial_state(batch_size: int):
      network = hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])
      return network.initial_state(batch_size)

    initial_state = initial_state.apply(initial_state.init(next(rng), 1), 1)
    params = network.init(next(rng), obs, initial_state)

    def policy(
        params: jnp.ndarray,
        key: jnp.ndarray,
        observation: jnp.ndarray,
        core_state: hk.LSTMState
    ) -> Tuple[jnp.ndarray, hk.LSTMState]:
      del key  # Unused for test-case deterministic policy.
      action_values, core_state = network.apply(params, observation, core_state)
      return jnp.argmax(action_values, axis=-1), core_state

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

    actor = actors.RecurrentActor(
        policy, hk.PRNGSequence(1), initial_state, variable_client)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)
Example #9
    def test_update(self):
        # Create a barrier to be shared between the test body and the variable
        # source. The barrier will block until, in this case, two threads call
        # wait(). Note that the (fake) variable source will call it within its
        # get_variables() call.
        barrier = threading.Barrier(2)

        # Create a variable source (emulating the learner).
        np_learner_variables = tf2_utils.to_numpy(
            self._learner_model.variables)
        variable_source = fakes.VariableSource(np_learner_variables, barrier)

        # Create a variable client (emulating the actor).
        variable_client = tf2_variable_utils.VariableClient(
            variable_source, {'policy': self._actor_model.variables},
            update_period=_UPDATE_PERIOD)

        # Create some random batch of test input:
        x = tf.random.normal(shape=(_BATCH_SIZE, _INPUT_SIZE))

        # Create variables by doing the computation once.
        learner_output = self._learner_model(x)
        actor_output = self._actor_model(x)
        del learner_output, actor_output

        for _ in range(_UPDATE_PERIOD):
            # Before the update period is reached, the two models still hold
            # different variable values.
            self.assertNotAllClose(self._actor_model.variables,
                                   self._learner_model.variables)

            # Before the update period is reached, the variable client should not make
            # any requests for variables.
            self.assertIsNone(variable_client._future)

            variable_client.update()

        # Make sure the last call created a request for variables and reset the
        # internal call counter.
        self.assertIsNotNone(variable_client._future)
        self.assertEqual(variable_client._call_counter, 0)
        future = variable_client._future

        for _ in range(_UPDATE_PERIOD):
            # Until the barrier releases the variables, the two models still
            # hold different variable values.
            self.assertNotAllClose(self._actor_model.variables,
                                   self._learner_model.variables)

            variable_client.update()

            # Make sure no new requests are made.
            self.assertEqual(variable_client._future, future)

        # Calling wait() on the barrier will now allow the variables to be copied
        # over from source to client.
        barrier.wait()

        # Update once more to ensure the variables are copied over.
        while variable_client._future is not None:
            variable_client.update()

        # After a number of update calls, the variables should be the same.
        self.assertAllClose(self._actor_model.variables,
                            self._learner_model.variables)
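This last example exercises the asynchronous path that complements update_and_wait(): update() is cheap to call every step, issues a request to the source at most once every _UPDATE_PERIOD calls, and applies the new values on a later call once the pending request has completed, which is why the test keeps calling update() until the internal future is cleared. The barrier injected into the fake source is what keeps that request pending, letting the test verify that repeated update() calls reuse the outstanding future instead of issuing new ones.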