def step(
    self,
    params: hk.Params,
    rng: jnp.ndarray,
    timestep: dm_env.TimeStep,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Steps on a single observation."""
  # Add a leading batch dimension of size 1 so the network sees [B, ...].
  timestep = jax.tree_map(lambda t: jnp.expand_dims(t, 0), timestep)
  logits, _ = self._net(params, timestep)
  # Drop the batch dimension and sample a single action from the policy.
  logits = jnp.squeeze(logits, axis=0)
  action = hk.multinomial(rng, logits, num_samples=1)
  action = jnp.squeeze(action, axis=-1)
  return action, logits
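A minimal usage sketch for the feedforward step above, assuming a surrounding actor that owns an agent instance, a dm_env environment, the current params, and an RNG key (none of these names appear in the original listing). The method is jitted once and a fresh key is split off for every action:

import jax

def run_episode(agent, env, params, rng):
  """Hypothetical actor loop around the feedforward step sketched above."""
  step_fn = jax.jit(agent.step)  # compile once, reuse on every step
  timestep = env.reset()
  episode_return = 0.0
  while not timestep.last():
    rng, step_rng = jax.random.split(rng)
    # Signature matches the listing: (params, rng, timestep) -> (action, logits).
    action, _ = step_fn(params, step_rng, timestep)
    timestep = env.step(int(action))
    episode_return += timestep.reward
  return episode_return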
def step(
    self,
    rng_key,
    params: hk.Params,
    timestep: dm_env.TimeStep,
    state: Nest,
) -> Tuple[AgentOutput, Nest]:
  """For a given single-step, unbatched timestep, output the chosen action."""
  # Pad timestep, state to be [T, B, ...] and [B, ...] respectively.
  timestep = jax.tree_map(lambda t: t[None, None, ...], timestep)
  state = jax.tree_map(lambda t: t[None, ...], state)

  net_out, next_state = self._apply_fn(params, timestep, state)

  # Remove the padding from above.
  net_out = jax.tree_map(lambda t: jnp.squeeze(t, axis=(0, 1)), net_out)
  next_state = jax.tree_map(lambda t: jnp.squeeze(t, axis=0), next_state)

  # Sample an action and return.
  action = hk.multinomial(rng_key, net_out.policy_logits, num_samples=1)
  action = jnp.squeeze(action, axis=-1)
  return AgentOutput(net_out.policy_logits, net_out.value, action), next_state
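The recurrent variant differs in that the actor must thread the unbatched recurrent state returned by step back into the next call. A hedged sketch along the same lines, assuming the caller supplies an initial recurrent state and that AgentOutput exposes the sampled action under a field named action (the listing constructs the tuple positionally, so the exact field name is an assumption):

import jax

def run_recurrent_episode(agent, env, params, rng, initial_state):
  """Hypothetical actor loop that threads the recurrent state between steps."""
  step_fn = jax.jit(agent.step)
  timestep = env.reset()
  state = initial_state  # unbatched recurrent state, e.g. an LSTM carry
  trajectory = []
  while not timestep.last():
    rng, step_rng = jax.random.split(rng)
    # Signature matches the listing: (rng_key, params, timestep, state).
    agent_out, state = step_fn(step_rng, params, timestep, state)
    # Assumes the AgentOutput namedtuple names its last field 'action'.
    timestep = env.step(int(agent_out.action))
    trajectory.append((agent_out, timestep))
  return trajectory

Collecting the per-step AgentOutput (policy logits, value, action) alongside the environment transitions is what a learner would later consume; the final recurrent state is simply dropped here.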