예제 #1
0
 def step(
     self,
     params: hk.Params,
     rng: jnp.ndarray,
     timestep: dm_env.TimeStep,
 ) -> Tuple[jnp.ndarray, jnp.ndarray]:
   """Selects an action for a single observation.

   Args:
     params: Network parameters forwarded to `self._net`.
     rng: Random key used to sample the action.
     timestep: The current timestep. Every leaf gets a leading axis of size 1
       before it is passed to the network.

   Returns:
     An `(action, logits)` pair: the sampled action and the logits it was
     drawn from, with the added leading axis removed again.
   """
   # `jax.tree_map` was deprecated and later removed from JAX;
   # `jax.tree_util.tree_map` is the equivalent that works on old and new
   # releases alike.
   batched_timestep = jax.tree_util.tree_map(
       lambda leaf: jnp.expand_dims(leaf, 0), timestep)
   logits, _ = self._net(params, batched_timestep)
   logits = jnp.squeeze(logits, axis=0)
   # Draw exactly one sample, then drop the trailing sample axis.
   action = hk.multinomial(rng, logits, num_samples=1)
   action = jnp.squeeze(action, axis=-1)
   return action, logits
예제 #2
0
    def step(
        self,
        rng_key,
        params: hk.Params,
        timestep: dm_env.TimeStep,
        state: Nest,
    ) -> Tuple[AgentOutput, Nest]:
        """For a given single-step, unbatched timestep, output the chosen action.

        Args:
          rng_key: Random key used to sample the action.
          params: Network parameters forwarded to `self._apply_fn`.
          timestep: Unbatched timestep for the current step.
          state: Unbatched agent state that goes with `timestep`.

        Returns:
          A pair `(agent_output, next_state)`. `agent_output` is an
          `AgentOutput` built from the policy logits, the value and the
          sampled action; `next_state` is the state returned by the network
          with its leading axis removed.
        """
        # `jax.tree_map` was deprecated and later removed from JAX;
        # `jax.tree_util.tree_map` is the equivalent that works on old and
        # new releases alike.
        tree_map = jax.tree_util.tree_map

        # Pad timestep, state to be [T, B, ...] and [B, ...] respectively.
        padded_timestep = tree_map(lambda t: t[None, None, ...], timestep)
        padded_state = tree_map(lambda t: t[None, ...], state)

        net_out, next_state = self._apply_fn(
            params, padded_timestep, padded_state)
        # Remove the padding from above.
        net_out = tree_map(lambda t: jnp.squeeze(t, axis=(0, 1)), net_out)
        next_state = tree_map(lambda t: jnp.squeeze(t, axis=0), next_state)

        # Sample a single action and drop the trailing sample axis.
        action = hk.multinomial(rng_key, net_out.policy_logits, num_samples=1)
        action = jnp.squeeze(action, axis=-1)
        return AgentOutput(net_out.policy_logits, net_out.value,
                           action), next_state