Example #1
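These snippets are excerpted from a test module, so they rely on a few names defined elsewhere. As a rough sketch (the module paths assume a recent Acme release, and _transform_without_rng is presumed to wrap hk.without_apply_rng around hk.transform), the surrounding imports and helper for this first example could look like:

from typing import Optional, Tuple

import haiku as hk
import jax
import jax.numpy as jnp

from acme import environment_loop
from acme import specs
from acme.agents.jax import actors
from acme.jax import utils
from acme.jax import variable_utils
from acme.testing import fakes


def _transform_without_rng(f):
    # Presumed helper: transform f into a pure (init, apply) pair whose
    # apply function takes no RNG key argument.
    return hk.without_apply_rng(hk.transform(f))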
    def test_recurrent(self, has_extras):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)
        output_size = env_spec.actions.num_values
        obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
        rng = hk.PRNGSequence(1)

        # Recurrent network: flatten each observation, then run an LSTM whose
        # output has one value per discrete action.
        @_transform_without_rng
        def network(inputs: jnp.ndarray, state: hk.LSTMState):
            return hk.DeepRNN(
                [hk.Reshape([-1], preserve_dims=1),
                 hk.LSTM(output_size)])(inputs, state)

        # The initial core state must also be built inside a transform,
        # since hk.LSTM is a Haiku module.
        @_transform_without_rng
        def initial_state(batch_size: Optional[int] = None):
            network = hk.DeepRNN(
                [hk.Reshape([-1], preserve_dims=1),
                 hk.LSTM(output_size)])
            return network.initial_state(batch_size)

        # Materialise a batch-size-1 core state and initialise the network
        # parameters (the transformed function is shadowed by its output).
        initial_state = initial_state.apply(initial_state.init(next(rng)), 1)
        params = network.init(next(rng), obs, initial_state)

        def policy(
                params: hk.Params, key: jnp.ndarray,
                observation: jnp.ndarray,
                core_state: hk.LSTMState) -> Tuple[jnp.ndarray, hk.LSTMState]:
            del key  # Unused for test-case deterministic policy.
            action_values, core_state = network.apply(params, observation,
                                                      core_state)
            actions = jnp.argmax(action_values, axis=-1)
            if has_extras:
                # Also expose the raw action values as policy extras.
                return (actions, (action_values,)), core_state
            else:
                return actions, core_state

        # Serve the parameters to the actor through a fake variable source.
        variable_source = fakes.VariableSource(params)
        variable_client = variable_utils.VariableClient(
            variable_source, 'policy')

        actor = actors.RecurrentActor(policy,
                                      jax.random.PRNGKey(1),
                                      initial_state,
                                      variable_client,
                                      has_extras=has_extras)

        # Run the actor for 20 episodes against the fake environment.
        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
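Because this version of test_recurrent accepts a has_extras argument, it is presumably driven by a parameterized test runner. A minimal harness (hypothetical, using absl.testing.parameterized) might look like:

from absl.testing import absltest
from absl.testing import parameterized


class RecurrentActorTest(parameterized.TestCase):

    @parameterized.parameters(False, True)
    def test_recurrent(self, has_extras):
        ...  # body as in Example #1 above


if __name__ == '__main__':
    absltest.main()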
Example #2
  def test_recurrent(self):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)
    output_size = env_spec.actions.num_values
    obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    rng = hk.PRNGSequence(1)

    # hk.without_apply_rng lets the transformed apply be called without an
    # RNG key, matching the calls below.
    @hk.without_apply_rng
    @hk.transform
    def network(inputs: jnp.ndarray, state: hk.LSTMState):
      return hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])(inputs, state)

    @hk.without_apply_rng
    @hk.transform
    def initial_state(batch_size: int):
      network = hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])
      return network.initial_state(batch_size)

    # Build a batch-size-1 core state, then initialise the parameters.
    initial_state = initial_state.apply(initial_state.init(next(rng), 1), 1)
    params = network.init(next(rng), obs, initial_state)

    def policy(
        params: hk.Params,
        key: jnp.ndarray,
        observation: jnp.ndarray,
        core_state: hk.LSTMState
    ) -> Tuple[jnp.ndarray, hk.LSTMState]:
      del key  # Unused for test-case deterministic policy.
      action_values, core_state = network.apply(params, observation, core_state)
      return jnp.argmax(action_values, axis=-1), core_state

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

    # This actor variant draws its per-step keys from a Haiku PRNG sequence.
    actor = actors.RecurrentActor(
        policy, hk.PRNGSequence(1), initial_state, variable_client)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)
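Neither example defines _make_fake_env. A plausible reconstruction using Acme's testing fakes is sketched below; the observation shape, action count, and episode length are illustrative assumptions, chosen only so that env_spec.actions.num_values is well defined:

import dm_env
import numpy as np

from acme import specs
from acme.testing import fakes


def _make_fake_env() -> dm_env.Environment:
  # A small fake environment with a discrete action space.
  env_spec = specs.EnvironmentSpec(
      observations=specs.Array(shape=(10, 5), dtype=np.float32),
      actions=specs.DiscreteArray(num_values=3),
      rewards=specs.Array(shape=(), dtype=np.float32),
      discounts=specs.BoundedArray(
          shape=(), dtype=np.float32, minimum=0., maximum=1.),
  )
  return fakes.Environment(env_spec, episode_length=10)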