def batched_policy(
    params: network_types.Params,
    key: RNGKey,
    observation: Observation
) -> Union[Action, Tuple[Action, types.NestedArray]]:
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  observation = utils.add_batch_dim(observation)
  output = policy(params, key, observation)
  return utils.squeeze_batch_dim(output)
def select_action(params: networks_lib.Params,
                  observation: networks_lib.Observation,
                  state: PRNGKey):
  rng = state
  rng1, rng2 = jax.random.split(rng)
  observation = utils.add_batch_dim(observation)
  action = utils.squeeze_batch_dim(policy(params, rng1, observation))
  return action, rng2
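Both wrappers above lean on the same pair of helpers to move between unbatched environment data and the batched inputs the network expects. A minimal sketch of their behavior, assuming the usual tree-map-over-the-leading-axis implementation (the real acme.jax.utils versions handle nested structures this way):

import jax
import jax.numpy as jnp

def add_batch_dim(nest):
  # Prepend a size-1 batch axis to every leaf array in the nest.
  return jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, axis=0), nest)

def squeeze_batch_dim(nest):
  # Remove the leading size-1 batch axis from every leaf array.
  return jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), nest)

obs = {'pixels': jnp.zeros((84, 84, 3))}
batched = add_batch_dim(obs)            # pixels: (1, 84, 84, 3)
unbatched = squeeze_batch_dim(batched)  # pixels: (84, 84, 3)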
def batched_recurrent_to_actor_core(
    recurrent_policy: RecurrentPolicy, initial_core_state: RecurrentState
) -> ActorCore[SimpleActorCoreRecurrentState[RecurrentState],
               Mapping[str, jnp.ndarray]]:
  """Returns ActorCore for a recurrent policy."""

  def select_action(params: networks_lib.Params,
                    observation: networks_lib.Observation,
                    state: SimpleActorCoreRecurrentState[RecurrentState]):
    # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
    rng = state.rng
    rng, policy_rng = jax.random.split(rng)
    observation = utils.add_batch_dim(observation)
    recurrent_state = utils.add_batch_dim(state.recurrent_state)
    action, new_recurrent_state = utils.squeeze_batch_dim(
        recurrent_policy(params, policy_rng, observation, recurrent_state))
    return action, SimpleActorCoreRecurrentState(rng, new_recurrent_state)

  initial_core_state = utils.squeeze_batch_dim(initial_core_state)

  def init(rng: PRNGKey) -> SimpleActorCoreRecurrentState[RecurrentState]:
    return SimpleActorCoreRecurrentState(rng, initial_core_state)

  def get_extras(
      state: SimpleActorCoreRecurrentState[RecurrentState]
  ) -> Mapping[str, jnp.ndarray]:
    return {'core_state': state.recurrent_state}

  return ActorCore(init=init, select_action=select_action,
                   get_extras=get_extras)
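A hypothetical usage sketch of the factory above; dummy_recurrent_policy, the empty params dict, and all shapes are illustrative assumptions, not library code:

import jax
import jax.numpy as jnp

def dummy_recurrent_policy(params, key, observation, recurrent_state):
  # Illustrative policy: ignores params/key, returns a zero action and
  # passes the core state through unchanged.
  del params, key
  action = jnp.zeros(observation.shape[0], dtype=jnp.int32)
  return action, recurrent_state

core = batched_recurrent_to_actor_core(
    dummy_recurrent_policy, initial_core_state=jnp.zeros((1, 8)))
state = core.init(jax.random.PRNGKey(0))
action, state = core.select_action({}, jnp.zeros(4), state)
extras = core.get_extras(state)  # {'core_state': <unbatched core state>}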
def __init__(
    self,
    networks: r2d2_networks.R2D2Networks,
    config: r2d2_config.R2D2Config,
    logger_fn: Callable[[], loggers.Logger] = lambda: None,
):
  """Creates an R2D2 learner, a behavior policy and an eval actor.

  Args:
    networks: R2D2 networks, used to build the core state spec.
    config: a config with R2D2 hyperparameters.
    logger_fn: a logger factory for the learner.
  """
  self._networks = networks
  self._config = config
  self._logger_fn = logger_fn

  # Sequence length for the dataset iterator.
  self._sequence_length = (
      self._config.burn_in_length + self._config.trace_length + 1)

  # Construct the core state spec.
  dummy_key = jax.random.PRNGKey(0)
  initial_state_params = networks.initial_state.init(dummy_key, 1)
  initial_state = networks.initial_state.apply(initial_state_params,
                                               dummy_key, 1)
  core_state_spec = utils.squeeze_batch_dim(initial_state)
  self._extra_spec = {'core_state': core_state_spec}
def unvectorized_select_action(
    params: networks_lib.Params,
    observations: networks_lib.Observation,
    state: State,
) -> Tuple[networks_lib.Action, State]:
  observations, state = utils.add_batch_dim((observations, state))
  actions, state = actor_core.select_action(params, observations, state)
  return utils.squeeze_batch_dim((actions, state))
def apply_and_sample(params: networks_lib.Params,
                     key: networks_lib.PRNGKey,
                     observation: networks_lib.Observation,
                     epsilon: Epsilon) -> networks_lib.Action:
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  observation = utils.add_batch_dim(observation)
  action_values = network.apply(params, observation)
  action_values = utils.squeeze_batch_dim(action_values)
  return rlax.epsilon_greedy(epsilon).sample(key, action_values)
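The final line delegates exploration to rlax. A quick standalone check of that sampling step with made-up action values:

import jax
import jax.numpy as jnp
import rlax

q_values = jnp.array([0.1, 2.0, -0.5])
key = jax.random.PRNGKey(42)
# With epsilon=0 this always picks the argmax (action 1); with epsilon=1 it
# samples uniformly over the three actions.
greedy = rlax.epsilon_greedy(0.0).sample(key, q_values)
explore = rlax.epsilon_greedy(1.0).sample(key, q_values)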
def select_action(params: networks_lib.Params,
                  observation: networks_lib.Observation,
                  state: SimpleActorCoreStateWithExtras):
  rng = state.rng
  rng1, rng2 = jax.random.split(rng)
  observation = utils.add_batch_dim(observation)
  action, extras = utils.squeeze_batch_dim(
      policy(params, rng1, observation))
  return action, SimpleActorCoreStateWithExtras(rng2, extras)
def batched_policy(
    params: network_lib.Params,
    key: network_lib.PRNGKey,
    observation: network_lib.Observation
) -> Tuple[Union[network_lib.Action,
                 Tuple[network_lib.Action, types.NestedArray]],
           network_lib.PRNGKey]:
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  key, key2 = jax.random.split(key)
  observation = utils.add_batch_dim(observation)
  output = policy(params, key2, observation)
  return utils.squeeze_batch_dim(output), key
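Unlike the first batched_policy, this variant also returns a fresh key, so the caller can thread randomness through repeated calls. An illustrative loop, assuming batched_policy closes over the placeholder policy below; params, policy, and the observations are all made-up stand-ins:

import jax
import jax.numpy as jnp

params = {}  # placeholder; a real policy would carry network parameters
policy = lambda p, k, obs: jnp.zeros(obs.shape[0], dtype=jnp.int32)
observations = [jnp.zeros(4) for _ in range(3)]

key = jax.random.PRNGKey(0)
for observation in observations:
  # Each call consumes one split of the key and hands back the other half,
  # so successive actions draw independent randomness.
  action, key = batched_policy(params, key, observation)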
def batched_policy(
    params,
    observation,
    discrete_action,
):
  observation = utils.add_batch_dim(observation)
  action = utils.squeeze_batch_dim(
      policy(params, observation, discrete_action))
  return action
def select_action(params: networks_lib.Params,
                  observation: networks_lib.Observation,
                  state: SimpleActorCoreRecurrentState[RecurrentState]):
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  rng = state.rng
  rng, policy_rng = jax.random.split(rng)
  observation = utils.add_batch_dim(observation)
  recurrent_state = utils.add_batch_dim(state.recurrent_state)
  action, new_recurrent_state = utils.squeeze_batch_dim(
      recurrent_policy(params, policy_rng, observation, recurrent_state))
  return action, SimpleActorCoreRecurrentState(rng, new_recurrent_state)
def critic_mean(
    critic_params: networks_lib.Params,
    observation: types.NestedArray,
    action: types.NestedArray,
) -> jnp.ndarray:
  # We add a batch dimension to make sure batch concatenation in
  # critic_network works correctly.
  observation = utils.add_batch_dim(observation)
  action = utils.add_batch_dim(action)
  # Computes the mean action-value estimate.
  logits, atoms = critic_network.apply(critic_params, observation, action)
  logits = utils.squeeze_batch_dim(logits)
  probabilities = jax.nn.softmax(logits)
  return jnp.sum(probabilities * atoms, axis=-1)
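A tiny numeric check of the expectation computed above, with made-up logits and atoms for a three-atom categorical critic:

import jax
import jax.numpy as jnp

logits = jnp.array([0.0, 0.0, 0.0])   # uniform over atoms after softmax
atoms = jnp.array([-1.0, 0.0, 1.0])   # support of the value distribution
probabilities = jax.nn.softmax(logits)          # [1/3, 1/3, 1/3]
mean = jnp.sum(probabilities * atoms, axis=-1)  # 0.0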
def apply_and_sample(params, key, obs, epsilon):
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  obs = utils.add_batch_dim(obs)
  action_values = network.apply(params, obs)[0]
  action_values = utils.squeeze_batch_dim(action_values)
  return rlax.epsilon_greedy(epsilon).sample(key, action_values)