Example #1
def batched_policy(
    params: network_types.Params, key: RNGKey, observation: Observation
) -> Union[Action, Tuple[Action, types.NestedArray]]:
    # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
    observation = utils.add_batch_dim(observation)
    output = policy(params, key, observation)
    return utils.squeeze_batch_dim(output)
Example #2
def select_action(params: networks_lib.Params,
                  observation: networks_lib.Observation, state: PRNGKey):
    # The actor state here is just an RNG key; split it so the policy gets a
    # fresh key and the remaining key is carried over as the next state.
    rng = state
    rng1, rng2 = jax.random.split(rng)
    # The policy expects batched inputs, so batch the single observation and
    # squeeze the batch dimension back out of the resulting action.
    observation = utils.add_batch_dim(observation)
    action = utils.squeeze_batch_dim(policy(params, rng1, observation))
    return action, rng2
Example #3
def batched_recurrent_to_actor_core(
    recurrent_policy: RecurrentPolicy, initial_core_state: RecurrentState
) -> ActorCore[SimpleActorCoreRecurrentState[RecurrentState], Mapping[
        str, jnp.ndarray]]:
    """Returns ActorCore for a recurrent policy."""
    def select_action(params: networks_lib.Params,
                      observation: networks_lib.Observation,
                      state: SimpleActorCoreRecurrentState[RecurrentState]):
        # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
        rng = state.rng
        rng, policy_rng = jax.random.split(rng)
        observation = utils.add_batch_dim(observation)
        recurrent_state = utils.add_batch_dim(state.recurrent_state)
        action, new_recurrent_state = utils.squeeze_batch_dim(
            recurrent_policy(params, policy_rng, observation, recurrent_state))
        return action, SimpleActorCoreRecurrentState(rng, new_recurrent_state)

    initial_core_state = utils.squeeze_batch_dim(initial_core_state)

    def init(rng: PRNGKey) -> SimpleActorCoreRecurrentState[RecurrentState]:
        return SimpleActorCoreRecurrentState(rng, initial_core_state)

    def get_extras(
        state: SimpleActorCoreRecurrentState[RecurrentState]
    ) -> Mapping[str, jnp.ndarray]:
        return {'core_state': state.recurrent_state}

    return ActorCore(init=init,
                     select_action=select_action,
                     get_extras=get_extras)
Example #4
    def __init__(
        self,
        networks: r2d2_networks.R2D2Networks,
        config: r2d2_config.R2D2Config,
        logger_fn: Callable[[], loggers.Logger] = lambda: None,
    ):
        """Creates a R2D2 learner, a behavior policy and an eval actor.

    Args:
      networks: R2D2 networks, used to build core state spec.
      config: a config with R2D2 hps
      logger_fn: a logger factory for the learner
    """
        self._networks = networks
        self._config = config
        self._logger_fn = logger_fn

        # Sequence length for dataset iterator.
        self._sequence_length = (self._config.burn_in_length +
                                 self._config.trace_length + 1)

        # Construct the core state spec.
        dummy_key = jax.random.PRNGKey(0)
        initial_state_params = networks.initial_state.init(dummy_key, 1)
        initial_state = networks.initial_state.apply(initial_state_params,
                                                     dummy_key, 1)
        core_state_spec = utils.squeeze_batch_dim(initial_state)
        self._extra_spec = {'core_state': core_state_spec}
Example #5
def unvectorized_select_action(
    params: networks_lib.Params,
    observations: networks_lib.Observation,
    state: State,
) -> Tuple[networks_lib.Action, State]:
    # The wrapped actor core expects batched inputs, so add a leading batch
    # dimension to the observation and state, and squeeze it from the outputs.
    observations, state = utils.add_batch_dim((observations, state))
    actions, state = actor_core.select_action(params, observations, state)
    return utils.squeeze_batch_dim((actions, state))
Example #6
def apply_and_sample(params: networks_lib.Params, key: networks_lib.PRNGKey,
                     observation: networks_lib.Observation, epsilon: Epsilon
                     ) -> networks_lib.Action:
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  observation = utils.add_batch_dim(observation)
  action_values = network.apply(params, observation)
  action_values = utils.squeeze_batch_dim(action_values)
  return rlax.epsilon_greedy(epsilon).sample(key, action_values)
Example #7
def select_action(params: networks_lib.Params,
                  observation: networks_lib.Observation,
                  state: SimpleActorCoreStateWithExtras):
    rng = state.rng
    rng1, rng2 = jax.random.split(rng)
    # The policy expects batched observations; add the batch dimension before
    # the call and squeeze it from both the action and the returned extras.
    observation = utils.add_batch_dim(observation)
    action, extras = utils.squeeze_batch_dim(
        policy(params, rng1, observation))
    return action, SimpleActorCoreStateWithExtras(rng2, extras)
Example #8
def batched_policy(
    params: network_lib.Params, key: network_lib.PRNGKey,
    observation: network_lib.Observation
) -> Tuple[Union[network_lib.Action, Tuple[
        network_lib.Action, types.NestedArray]], network_lib.PRNGKey]:
    # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
    key, key2 = jax.random.split(key)
    observation = utils.add_batch_dim(observation)
    output = policy(params, key2, observation)
    return utils.squeeze_batch_dim(output), key
Example #9
def batched_policy(
    params,
    observation,
    discrete_action,
):
    # The policy network expects a batch dimension; add it to the observation
    # and remove it from the resulting action.
    observation = utils.add_batch_dim(observation)
    action = utils.squeeze_batch_dim(
        policy(params, observation, discrete_action))

    return action
Example #10
def select_action(params: networks_lib.Params,
                  observation: networks_lib.Observation,
                  state: SimpleActorCoreRecurrentState[RecurrentState]):
    # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
    rng = state.rng
    rng, policy_rng = jax.random.split(rng)
    observation = utils.add_batch_dim(observation)
    recurrent_state = utils.add_batch_dim(state.recurrent_state)
    action, new_recurrent_state = utils.squeeze_batch_dim(
        recurrent_policy(params, policy_rng, observation, recurrent_state))
    return action, SimpleActorCoreRecurrentState(rng, new_recurrent_state)
Example #11
def critic_mean(
    critic_params: networks_lib.Params,
    observation: types.NestedArray,
    action: types.NestedArray,
) -> jnp.ndarray:
  # We add a batch dimension to make sure batch concat in critic_network
  # works correctly.
  observation = utils.add_batch_dim(observation)
  action = utils.add_batch_dim(action)
  # Computes the mean action-value estimate.
  logits, atoms = critic_network.apply(critic_params, observation, action)
  logits = utils.squeeze_batch_dim(logits)
  probabilities = jax.nn.softmax(logits)
  return jnp.sum(probabilities * atoms, axis=-1)
Example #12
def apply_and_sample(params, key, obs, epsilon):
    # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
    obs = utils.add_batch_dim(obs)
    action_values = network.apply(params, obs)[0]
    action_values = utils.squeeze_batch_dim(action_values)
    return rlax.epsilon_greedy(epsilon).sample(key, action_values)
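
All of the examples above follow the same round trip from Acme's JAX utilities: acme.jax.utils.add_batch_dim adds a leading batch axis to an unbatched observation before calling a network that expects batched inputs, and acme.jax.utils.squeeze_batch_dim removes that axis from the output again. Below is a minimal, self-contained sketch of that pattern, assuming dm-acme is installed; the toy policy function is hypothetical and only stands in for the learned networks used in the examples.

import jax
import jax.numpy as jnp
from acme.jax import utils

def policy(params, key, observation):
    # Toy stand-in for a batched policy: pick the argmax feature per row.
    del params, key
    return jnp.argmax(observation, axis=-1)

observation = jnp.array([0.1, 0.7, 0.2])               # unbatched, shape [3]
batched = utils.add_batch_dim(observation)              # shape [1, 3]
action = policy(None, jax.random.PRNGKey(0), batched)   # shape [1]
action = utils.squeeze_batch_dim(action)                # scalar action (here 1)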