def behavior_policy(params: networks_lib.Params,
                    key: networks_lib.PRNGKey,
                    observation: types.NestedArray,
                    core_state: types.NestedArray,
                    epsilon: float):
  # `config` and `evaluation` are not arguments of this function; they are
  # captured from the enclosing scope (e.g. a policy factory).
  q_values, core_state = networks.forward.apply(params, key, observation,
                                                core_state)
  epsilon = config.evaluation_epsilon if evaluation else epsilon
  return rlax.epsilon_greedy(epsilon).sample(key, q_values), core_state

def select_action(rng_key, network_params, s_t, exploration_epsilon):
  """Samples action from eps-greedy policy wrt Q-values at given state."""
  rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
  q_t = network.apply(network_params, apply_key, s_t[None, ...]).q_values[0]
  a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
  v_t = jnp.max(q_t, axis=-1)
  return rng_key, a_t, v_t

def select_action(rng_key, network_params, s_t, exploration_epsilon):
  """Samples action from eps-greedy policy wrt Q-values at given state."""
  rng_key, tau_key, apply_key, policy_key = jax.random.split(rng_key, 4)
  tau_t = _sample_tau(tau_key, (1, tau_samples))
  q_t = network.apply(network_params, apply_key,
                      IqnInputs(s_t[None, ...], tau_t)).q_values[0]
  a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
  return rng_key, a_t
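
# `_sample_tau` and `tau_samples` are not shown above; a plausible stand-in,
# assuming IQN-style quantile fractions drawn uniformly from [0, 1):
import jax


def _sample_tau(rng_key, shape):
  return jax.random.uniform(rng_key, shape=shape, minval=0., maxval=1.)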

def apply_and_sample(params: networks_lib.Params,
                     key: networks_lib.PRNGKey,
                     observation: networks_lib.Observation,
                     epsilon: Epsilon) -> networks_lib.Action:
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  observation = utils.add_batch_dim(observation)
  action_values = network.apply(params, observation)
  action_values = utils.squeeze_batch_dim(action_values)
  return rlax.epsilon_greedy(epsilon).sample(key, action_values)

def default_behavior_policy(network: networks_lib.FeedForwardNetwork,
                            epsilon: float,
                            params: networks_lib.Params,
                            key: networks_lib.PRNGKey,
                            observation: networks_lib.Observation):
  """Returns an action for the given observation."""
  action_values = network.apply(params, observation)
  actions = rlax.epsilon_greedy(epsilon).sample(key, action_values)
  return actions.astype(jnp.int32)
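
# A self-contained sketch of the same feed-forward pattern, runnable end to
# end; the toy Q-network, observation shape and action count below are
# hypothetical stand-ins rather than anything taken from the snippets above.
import haiku as hk
import jax
import jax.numpy as jnp
import rlax

NUM_ACTIONS = 4  # hypothetical action space size


def _q_network(obs):
  # Tiny MLP Q-network used only for illustration.
  return hk.nets.MLP([32, NUM_ACTIONS])(obs)


toy_network = hk.without_apply_rng(hk.transform(_q_network))


@jax.jit
def toy_policy(params, key, observation, epsilon):
  action_values = toy_network.apply(params, observation)
  return rlax.epsilon_greedy(epsilon).sample(key, action_values)


rng = jax.random.PRNGKey(0)
obs = jnp.zeros((8,))                    # hypothetical observation
toy_params = toy_network.init(rng, obs)
action = toy_policy(toy_params, rng, obs, 0.05)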

def actor_step(self, params, env_output, actor_state, key, evaluation):
  norm_q = self._network.apply(params, env_output.observation)
  # This is equivalent to epsilon-greedy on the (unnormalized) Q-values
  # because normalization is linear, therefore the argmaxes are the same.
  train_a = rlax.epsilon_greedy(self._epsilon).sample(key, norm_q)
  eval_a = rlax.greedy().sample(key, norm_q)
  a = jax.lax.select(evaluation, eval_a, train_a)
  return ActorOutput(actions=a), actor_state
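
# A small numeric sketch (not from the codebase above) of the comment on
# normalization: an affine rescaling with positive scale preserves the argmax,
# so epsilon-greedy over normalized Q-values picks the same action as
# epsilon-greedy over the raw Q-values when given the same key.
import jax
import jax.numpy as jnp
import rlax

key = jax.random.PRNGKey(0)
q = jnp.array([1.0, 3.0, 2.0])
scale, shift = 0.1, -5.0        # hypothetical normalization statistics
norm_q = scale * q + shift

assert jnp.argmax(q) == jnp.argmax(norm_q)
assert (rlax.epsilon_greedy(0.1).sample(key, q) ==
        rlax.epsilon_greedy(0.1).sample(key, norm_q))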

def actor_step(self, params, env_output, actor_state, key, evaluation):
  obs = jnp.expand_dims(env_output.observation, 0)  # add dummy batch
  q = self._network.apply(params.online, obs)[0]  # remove dummy batch
  epsilon = self._epsilon_by_frame(actor_state.count)
  train_a = rlax.epsilon_greedy(epsilon).sample(key, q)
  eval_a = rlax.greedy().sample(key, q)
  a = jax.lax.select(evaluation, eval_a, train_a)
  return ActorOutput(actions=a, q_values=q), ActorState(actor_state.count + 1)
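
# `self._epsilon_by_frame` is not defined above; a plausible stand-in,
# assuming a polynomially decayed exploration rate built from optax schedules
# (all values below are hypothetical):
import optax

epsilon_by_frame = optax.polynomial_schedule(
    init_value=1.0,            # fully random at the start
    end_value=0.01,            # final exploration rate
    power=1.0,                 # linear decay
    transition_steps=100_000)  # decay horizon in actor steps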

def select_action(params: networks_lib.Params,
                  observation: networks_lib.Observation,
                  state: R2D2ActorState[actor_core_lib.RecurrentState]):
  rng, policy_rng = jax.random.split(state.rng)
  q_values, recurrent_state = networks.forward.apply(params, policy_rng,
                                                     observation,
                                                     state.recurrent_state)
  action = rlax.epsilon_greedy(state.epsilon).sample(policy_rng, q_values)
  return action, R2D2ActorState(rng, state.epsilon, recurrent_state)
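
# `R2D2ActorState` is referenced but not shown; a plausible shape for such a
# per-actor state container, assuming it only threads the PRNG key, the
# exploration epsilon and the recurrent network state between steps:
import dataclasses
from typing import Generic, TypeVar

import jax

RecurrentState = TypeVar('RecurrentState')


@dataclasses.dataclass(frozen=True)
class R2D2ActorStateSketch(Generic[RecurrentState]):
  rng: jax.Array                 # PRNG key carried across steps
  epsilon: float                 # per-actor exploration rate
  recurrent_state: RecurrentState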

def _actor_step(self, all_params, all_states, observation, rng_key,
                evaluation):
  obs = jnp.expand_dims(observation, 0)  # dummy batch
  q_val = self._q_net.apply(all_params.online, obs)[0]  # remove batch
  epsilon = self._epsilon_schedule(all_states.actor_steps)
  train_action = rlax.epsilon_greedy(epsilon).sample(rng_key, q_val)
  eval_action = rlax.greedy().sample(rng_key, q_val)
  action = jax.lax.select(evaluation, eval_action, train_action)
  return (
      ActorOutput(actions=action, q_values=q_val),
      AllStates(
          optimizer=all_states.optimizer,
          learner_steps=all_states.learner_steps,
          actor_steps=all_states.actor_steps + 1,
      ),
  )

def policy(net_params, key, obs):
  """Sample action from epsilon-greedy policy."""
  q = network.apply(net_params, obs)
  a = rlax.epsilon_greedy(epsilon=FLAGS.epsilon).sample(key, q)
  return q, a

def policy(params: hk.Params, key: jnp.ndarray,
           observation: jnp.ndarray) -> jnp.ndarray:
  action_values = network.apply(params, observation)
  return rlax.epsilon_greedy(epsilon).sample(key, action_values)

def actor_step(self, params, env_output, actor_state, key, evaluation):
  q = self._network.apply(params, env_output.observation)
  train_a = rlax.epsilon_greedy(self._epsilon).sample(key, q)
  eval_a = rlax.greedy().sample(key, q)
  a = jax.lax.select(evaluation, eval_a, train_a)
  return ActorOutput(actions=a, q_values=q), actor_state

def policy(params: hk.Params, key: jnp.ndarray,
           observation: jnp.ndarray) -> jnp.ndarray:
  action_values = hk.without_apply_rng(hk.transform(network)).apply(
      params, observation)
  return rlax.epsilon_greedy(epsilon).sample(key, action_values)

def apply_and_sample(params, key, obs, epsilon):
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  obs = utils.add_batch_dim(obs)
  action_values = network.apply(params, obs)[0]
  action_values = utils.squeeze_batch_dim(action_values)
  return rlax.epsilon_greedy(epsilon).sample(key, action_values)
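
# `utils.add_batch_dim` / `utils.squeeze_batch_dim` come from the surrounding
# codebase; a plausible reading of their behavior, assuming they simply add or
# remove a leading batch axis across a (possibly nested) array structure:
import jax
import jax.numpy as jnp


def add_batch_dim_sketch(nest):
  return jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, axis=0), nest)


def squeeze_batch_dim_sketch(nest):
  return jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), nest)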

def evaluator_network(params: hk.Params, key: jnp.ndarray,
                      observation: jnp.ndarray) -> jnp.ndarray:
  dist_params = network.apply(params, observation)
  return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
      key, dist_params)

def evaluator_network(params: hk.Params, key: jnp.ndarray,
                      observation: jnp.ndarray) -> jnp.ndarray:
  action_values = policy_network.apply(params, observation)
  return rlax.epsilon_greedy(FLAGS.epsilon).sample(key, action_values)
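
# A self-contained sketch (hypothetical network, shapes and epsilon) showing
# how a single-observation evaluator like the ones above can be vectorized
# over a batch of observations with jax.vmap, using one PRNG key per element:
import haiku as hk
import jax
import jax.numpy as jnp
import rlax

EVAL_EPSILON = 0.001  # hypothetical evaluation epsilon


def _toy_q_net(obs):
  return hk.nets.MLP([64, 5])(obs)  # 5 hypothetical actions


toy_policy_network = hk.without_apply_rng(hk.transform(_toy_q_net))


def toy_evaluator(params, key, observation):
  action_values = toy_policy_network.apply(params, observation)
  return rlax.epsilon_greedy(EVAL_EPSILON).sample(key, action_values)


batched_evaluator = jax.vmap(toy_evaluator, in_axes=(None, 0, 0))

rng = jax.random.PRNGKey(42)
obs_batch = jnp.zeros((16, 8))           # hypothetical batch of observations
toy_params = toy_policy_network.init(rng, obs_batch[0])
keys = jax.random.split(rng, obs_batch.shape[0])
actions = batched_evaluator(toy_params, keys, obs_batch)  # shape (16,)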