Example #1
 def behavior_policy(params: networks_lib.Params, key: networks_lib.PRNGKey,
                     observation: types.NestedArray,
                     core_state: types.NestedArray, epsilon: float):
   q_values, core_state = networks.forward.apply(params, key, observation,
                                                 core_state)
   epsilon = config.evaluation_epsilon if evaluation else epsilon
   return rlax.epsilon_greedy(epsilon).sample(key, q_values), core_state
Example #2
 def select_action(rng_key, network_params, s_t, exploration_epsilon):
   """Samples action from eps-greedy policy wrt Q-values at given state."""
   rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
   q_t = network.apply(network_params, apply_key, s_t[None, ...]).q_values[0]
   a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
   v_t = jnp.max(q_t, axis=-1)
   return rng_key, a_t, v_t
Example #3
 def select_action(rng_key, network_params, s_t):
   """Samples action from eps-greedy policy wrt Q-values at given state."""
   rng_key, tau_key, apply_key, policy_key = jax.random.split(rng_key, 4)
   tau_t = _sample_tau(tau_key, (1, tau_samples))
   q_t = network.apply(network_params, apply_key,
                       IqnInputs(s_t[None, ...], tau_t)).q_values[0]
   a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
   return rng_key, a_t
Example #4
 def apply_and_sample(params: networks_lib.Params, key: networks_lib.PRNGKey,
                      observation: networks_lib.Observation, epsilon: Epsilon
                      ) -> networks_lib.Action:
   # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
   observation = utils.add_batch_dim(observation)
   action_values = network.apply(params, observation)
   action_values = utils.squeeze_batch_dim(action_values)
   return rlax.epsilon_greedy(epsilon).sample(key, action_values)
Example #5
def default_behavior_policy(network: networks_lib.FeedForwardNetwork,
                            epsilon: float, params: networks_lib.Params,
                            key: networks_lib.PRNGKey,
                            observation: networks_lib.Observation):
    """Returns an action for the given observation."""
    action_values = network.apply(params, observation)
    actions = rlax.epsilon_greedy(epsilon).sample(key, action_values)
    return actions.astype(jnp.int32)
Example #6
 def actor_step(self, params, env_output, actor_state, key, evaluation):
     norm_q = self._network.apply(params, env_output.observation)
     # This is equivalent to epsilon-greedy on the (unnormalized) Q-values
     # because normalization is linear, therefore the argmaxes are the same.
     train_a = rlax.epsilon_greedy(self._epsilon).sample(key, norm_q)
     eval_a = rlax.greedy().sample(key, norm_q)
     a = jax.lax.select(evaluation, eval_a, train_a)
     return ActorOutput(actions=a), actor_state
Example #7
 def actor_step(self, params, env_output, actor_state, key, evaluation):
     obs = jnp.expand_dims(env_output.observation, 0)  # add dummy batch
     q = self._network.apply(params.online, obs)[0]  # remove dummy batch
     epsilon = self._epsilon_by_frame(actor_state.count)
     train_a = rlax.epsilon_greedy(epsilon).sample(key, q)
     eval_a = rlax.greedy().sample(key, q)
     a = jax.lax.select(evaluation, eval_a, train_a)
     return ActorOutput(actions=a,
                        q_values=q), ActorState(actor_state.count + 1)
Example #8
  def select_action(params: networks_lib.Params,
                    observation: networks_lib.Observation,
                    state: R2D2ActorState[actor_core_lib.RecurrentState]):
    rng, policy_rng = jax.random.split(state.rng)

    q_values, recurrent_state = networks.forward.apply(params, policy_rng,
                                                       observation,
                                                       state.recurrent_state)
    action = rlax.epsilon_greedy(state.epsilon).sample(policy_rng, q_values)

    return action, R2D2ActorState(rng, state.epsilon, recurrent_state)
Example #9
 def _actor_step(self, all_params, all_states, observation, rng_key, evaluation):
     obs = jnp.expand_dims(observation, 0)  # dummy batch
     q_val = self._q_net.apply(all_params.online, obs)[0]  # remove batch
     epsilon = self._epsilon_schedule(all_states.actor_steps)
     train_action = rlax.epsilon_greedy(epsilon).sample(rng_key, q_val)
     eval_action = rlax.greedy().sample(rng_key, q_val)
     action = jax.lax.select(evaluation, eval_action, train_action)
     return (
         ActorOutput(actions=action, q_values=q_val),
         AllStates(
             optimizer=all_states.optimizer,
             learner_steps=all_states.learner_steps,
             actor_steps=all_states.actor_steps + 1,
         ),
     )
Example #10
 def policy(net_params, key, obs):
     """Sample action from epsilon-greedy policy."""
     q = network.apply(net_params, obs)
     a = rlax.epsilon_greedy(epsilon=FLAGS.epsilon).sample(key, q)
     return q, a
Example #11
 def policy(params: hk.Params, key: jnp.ndarray,
            observation: jnp.ndarray) -> jnp.ndarray:
     action_values = network.apply(params, observation)
     return rlax.epsilon_greedy(epsilon).sample(key, action_values)
Example #12
 def actor_step(self, params, env_output, actor_state, key, evaluation):
     q = self._network.apply(params, env_output.observation)
     train_a = rlax.epsilon_greedy(self._epsilon).sample(key, q)
     eval_a = rlax.greedy().sample(key, q)
     a = jax.lax.select(evaluation, eval_a, train_a)
     return ActorOutput(actions=a, q_values=q), actor_state
Example #13
 def policy(params: hk.Params, key: jnp.ndarray,
            observation: jnp.ndarray) -> jnp.ndarray:
      action_values = hk.without_apply_rng(
          hk.transform(network)).apply(params, observation)
     return rlax.epsilon_greedy(epsilon).sample(key, action_values)
Example #14
 def apply_and_sample(params, key, obs, epsilon):
     # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
     obs = utils.add_batch_dim(obs)
      action_values = network.apply(params, obs)
     action_values = utils.squeeze_batch_dim(action_values)
     return rlax.epsilon_greedy(epsilon).sample(key, action_values)
Example #15
 def evaluator_network(params: hk.Params, key: jnp.ndarray,
                       observation: jnp.ndarray) -> jnp.ndarray:
     dist_params = network.apply(params, observation)
     return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
         key, dist_params)
Example #16
 def evaluator_network(params: hk.Params, key: jnp.ndarray,
                       observation: jnp.ndarray) -> jnp.ndarray:
     action_values = policy_network.apply(params, observation)
     return rlax.epsilon_greedy(FLAGS.epsilon).sample(key, action_values)
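All of the examples above follow the same pattern: a JAX/Haiku network computes Q-values for the current observation, and rlax.epsilon_greedy(epsilon).sample(key, q_values) draws the action (epsilon may also be passed to sample directly, as in Examples 2 and 3). Because each snippet relies on names from its enclosing module (network, config, FLAGS, utils, and so on), none of them runs on its own. The sketch below is a minimal self-contained version of the pattern; the MLP architecture, observation shape, number of actions, and epsilon value are illustrative assumptions rather than values taken from any example above.

import haiku as hk
import jax
import jax.numpy as jnp
import rlax

NUM_ACTIONS = 4  # assumed action-space size
EPSILON = 0.1    # assumed exploration rate

def q_network(obs):
    # Small MLP producing one Q-value per action.
    return hk.nets.MLP([64, NUM_ACTIONS])(obs)

network = hk.without_apply_rng(hk.transform(q_network))

def policy(params, key, observation):
    """Samples an epsilon-greedy action w.r.t. the network's Q-values."""
    q_values = network.apply(params, observation)
    return rlax.epsilon_greedy(EPSILON).sample(key, q_values)

rng = jax.random.PRNGKey(0)
dummy_obs = jnp.zeros((8,))  # assumed observation shape
params = network.init(rng, dummy_obs)
action = policy(params, jax.random.PRNGKey(1), dummy_obs)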