def test_recurrent(self, has_extras):
  """Runs a RecurrentActor wrapping an LSTM policy through an EnvironmentLoop.

  The policy is a deterministic argmax over the outputs of a
  Reshape -> LSTM core sized to the environment's number of discrete actions.

  Args:
    has_extras: if True, the policy returns `(actions, (action_values,))` as
      its first output and the actor is built with `has_extras=True`;
      otherwise the policy returns the bare actions.
  """
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)
  output_size = env_spec.actions.num_values
  obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  rng = hk.PRNGSequence(1)

  def make_core():
    # Single source of truth for the architecture, so the network and the
    # initial state it consumes can never drift apart. Must only be called
    # from inside a Haiku-transformed function.
    return hk.DeepRNN(
        [hk.Reshape([-1], preserve_dims=1), hk.LSTM(output_size)])

  @_transform_without_rng
  def network(inputs: jnp.ndarray, state: hk.LSTMState):
    return make_core()(inputs, state)

  @_transform_without_rng
  def initial_state_fn(batch_size: Optional[int] = None):
    return make_core().initial_state(batch_size)

  # Recurrent state for a batch of size 1. `init` is called only to obtain the
  # params structure that `apply` requires.
  initial_state = initial_state_fn.apply(initial_state_fn.init(next(rng)), 1)
  params = network.init(next(rng), obs, initial_state)

  def policy(
      params: hk.Params,
      key: jnp.ndarray,
      observation: jnp.ndarray,
      core_state: hk.LSTMState) -> Tuple[jnp.ndarray, hk.LSTMState]:
    del key  # Unused for test-case deterministic policy.
    action_values, core_state = network.apply(params, observation, core_state)
    actions = jnp.argmax(action_values, axis=-1)
    if has_extras:
      # With extras, the first output is an (actions, extras) pair.
      return (actions, (action_values,)), core_state
    return actions, core_state

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  actor = actors.RecurrentActor(
      policy, jax.random.PRNGKey(1), initial_state, variable_client,
      has_extras=has_extras)
  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
def test_recurrent(self):
  """Smoke-tests RecurrentActor with a Flatten -> LSTM core.

  Builds a Haiku recurrent network whose LSTM width equals the environment's
  number of discrete actions, wraps a greedy (argmax) policy around it in a
  RecurrentActor fed by a fake variable source, and drives the actor with an
  EnvironmentLoop via `run(20)`.
  """
  env = _make_fake_env()
  spec = specs.make_environment_spec(env)
  num_actions = spec.actions.num_values
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
  key_seq = hk.PRNGSequence(1)

  def build_core():
    # Only ever invoked from inside the hk.transform-ed functions below.
    return hk.DeepRNN([hk.Flatten(), hk.LSTM(num_actions)])

  @hk.transform
  def rnn(inputs: jnp.ndarray, state: hk.LSTMState):
    core = build_core()
    return core(inputs, state)

  @hk.transform
  def initial_state_fn(batch_size: int):
    return build_core().initial_state(batch_size)

  # Initial recurrent state for a single-element batch.
  state_params = initial_state_fn.init(next(key_seq), 1)
  initial_state = initial_state_fn.apply(state_params, 1)
  params = rnn.init(next(key_seq), dummy_obs, initial_state)

  def greedy_policy(
      params: jnp.ndarray,
      key: jnp.ndarray,
      observation: jnp.ndarray,
      core_state: hk.LSTMState,
  ) -> Tuple[jnp.ndarray, hk.LSTMState]:
    del key  # Unused for test-case deterministic policy.
    scores, next_state = rnn.apply(params, observation, core_state)
    return jnp.argmax(scores, axis=-1), next_state

  client = variable_utils.VariableClient(fakes.VariableSource(params), 'policy')
  actor = actors.RecurrentActor(
      greedy_policy, hk.PRNGSequence(1), initial_state, client)
  environment_loop.EnvironmentLoop(env, actor).run(20)