def eval_policy(
    network, use_distribution, atoms, net_params, key,
    obs: chex.Array, lms: chex.Array):
  """Sample action from greedy policy.

  Args:
    network -- haiku Transformed network.
    use_distribution -- whether the network has a distributional output.
    atoms -- support of the distributional output.
    net_params -- parameters (weights) of the network.
    key -- key for categorical sampling.
    obs -- observation.
    lms -- one-hot encoded legal actions.
  """
  if use_distribution:
    # Reduce the distributional output to Q-values by taking the
    # expectation of the return distribution over its support.
    logits = network.apply(net_params, None, obs)
    probs = jax.nn.softmax(logits, axis=-1)
    q_vals = jnp.sum(probs * atoms, axis=-1)
  else:
    # Q-values are the network output directly.
    q_vals = network.apply(net_params, None, obs)
  # Mask out the Q-values of illegal actions.
  q_vals = jnp.where(lms, q_vals, -jnp.inf)
  # Select the best action.
  return rlax.greedy().sample(key, q_vals)
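# A minimal usage sketch for the distributional branch of eval_policy above.
# The network architecture, atom support, shapes, and all names below are
# illustrative assumptions, not taken from the original code.
import haiku as hk
import jax
import jax.numpy as jnp

NUM_ACTIONS, NUM_ATOMS = 4, 51
atoms = jnp.linspace(-10., 10., NUM_ATOMS)  # support of the return distribution

def q_net(obs):
  # Hypothetical head: one vector of categorical logits per action.
  out = hk.nets.MLP([64, NUM_ACTIONS * NUM_ATOMS])(obs)
  return out.reshape((NUM_ACTIONS, NUM_ATOMS))

network = hk.transform(q_net)
key = jax.random.PRNGKey(0)
obs = jnp.zeros((8,))          # dummy observation
lms = jnp.array([1, 1, 0, 1])  # action 2 is illegal
params = network.init(key, obs)
action = eval_policy(network, True, atoms, params, key, obs, lms)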
def decide(self, timestep: env_base.TimeStep,
           greedy: bool = False) -> base.Decision:
  key = next(self._rng)
  # Feed the previous reward and action back in as part of the input.
  previous_reward = timestep.reward if timestep.reward is not None else 0.
  previous_action = timestep.action if timestep.action is not None else -1
  inputs = [
      timestep.observation[None, ...],
      jnp.ones((1, 1, 1), dtype=jnp.float32) * previous_reward,
      jax.nn.one_hot([[previous_action]], self._action_spec.num_values),
  ]
  (logits, value, policy_embedding, value_embedding,
   state_embedding), rnn_state = self._forward(
       self._state.params, inputs, self._state.rnn_state)
  # Carry the recurrent state forward to the next call.
  self._state = self._state._replace(rnn_state=rnn_state)
  if greedy:
    action = rlax.greedy().sample(key, logits).squeeze()
  else:
    action = jax.random.categorical(key, logits).squeeze()
  return base.Decision(
      action=int(action),
      action_dist=jax.nn.softmax(logits),
      policy_embedding=policy_embedding,
      value=value,
      value_embedding=value_embedding,
      state_embedding=state_embedding)
def select_action(rng_key, network_params, s_t):
  """Computes the greedy (argmax) action w.r.t. the Q-values at a given state."""
  rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
  q_t = network.apply(network_params, apply_key, s_t[None, ...]).q_values[0]
  a_t = rlax.greedy().sample(policy_key, q_t)
  v_t = jnp.max(q_t, axis=-1)  # greedy state value
  return rng_key, a_t, v_t
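# Illustrative rollout loop showing how the rng_key threaded through
# select_action is reused across steps. `env`, `network_params`, and
# `num_steps` are placeholders, not from the source.
rng_key = jax.random.PRNGKey(42)
s_t = env.reset()
for _ in range(num_steps):
  rng_key, a_t, v_t = select_action(rng_key, network_params, s_t)
  s_t = env.step(int(a_t))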
def actor_step(self, params, env_output, actor_state, key, evaluation):
  norm_q = self._network.apply(params, env_output.observation)
  # This is equivalent to epsilon-greedy on the (unnormalized) Q-values:
  # the normalization is linear, so the argmaxes are the same.
  train_a = rlax.epsilon_greedy(self._epsilon).sample(key, norm_q)
  eval_a = rlax.greedy().sample(key, norm_q)
  a = jax.lax.select(evaluation, eval_a, train_a)
  return ActorOutput(actions=a), actor_state
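# A quick check of the claim in the comment above: a positive linear
# (affine) rescaling of the Q-values does not change the argmax, so the
# greedy and epsilon-greedy action choices are unaffected. The numbers
# are arbitrary.
import jax.numpy as jnp

q = jnp.array([1.0, 3.0, 2.0])
scale, shift = 0.1, -5.0  # any scale > 0 preserves the argmax
norm_q = scale * q + shift
assert jnp.argmax(norm_q) == jnp.argmax(q)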
def actor_step(self, params, env_output, actor_state, key, evaluation):
  obs = jnp.expand_dims(env_output.observation, 0)  # add dummy batch dim
  q = self._network.apply(params.online, obs)[0]    # remove dummy batch dim
  epsilon = self._epsilon_by_frame(actor_state.count)
  train_a = rlax.epsilon_greedy(epsilon).sample(key, q)
  eval_a = rlax.greedy().sample(key, q)
  a = jax.lax.select(evaluation, eval_a, train_a)
  return ActorOutput(actions=a, q_values=q), ActorState(actor_state.count + 1)
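# `self._epsilon_by_frame` is not shown; one plausible stand-in is a
# linear decay built with optax (an assumption, not the source's
# definition): epsilon anneals from 1.0 to 0.05 over the first 100k
# actor steps and stays at 0.05 afterwards.
import optax

epsilon_by_frame = optax.linear_schedule(
    init_value=1.0, end_value=0.05, transition_steps=100_000)

epsilon_by_frame(0)        # -> 1.0
epsilon_by_frame(50_000)   # -> 0.525
epsilon_by_frame(200_000)  # -> 0.05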
def _actor_step(self, all_params, all_states, observation, rng_key, evaluation):
  obs = jnp.expand_dims(observation, 0)                 # add dummy batch dim
  q_val = self._q_net.apply(all_params.online, obs)[0]  # remove dummy batch dim
  epsilon = self._epsilon_schedule(all_states.actor_steps)
  train_action = rlax.epsilon_greedy(epsilon).sample(rng_key, q_val)
  eval_action = rlax.greedy().sample(rng_key, q_val)
  action = jax.lax.select(evaluation, eval_action, train_action)
  return (
      ActorOutput(actions=action, q_values=q_val),
      AllStates(
          optimizer=all_states.optimizer,
          learner_steps=all_states.learner_steps,
          actor_steps=all_states.actor_steps + 1,
      ),
  )
def eval_policy(network, net_params, key,
                obs: rlax.ArrayLike, lm: rlax.ArrayLike):
  """Sample action from greedy policy.

  Args:
    network -- haiku Transformed network.
    net_params -- parameters (weights) of the network.
    key -- key for categorical sampling.
    obs -- observation.
    lm -- one-hot encoded legal actions.
  """
  # Compute Q-values.
  q_vals = network.apply(net_params, obs)
  # Replace the Q-values of illegal actions with -inf so they can
  # never be selected.
  q_vals = jnp.where(lm, q_vals, -jnp.inf)
  # Select the greedy action.
  return rlax.greedy().sample(key, q_vals)
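# A standalone demonstration of the legal-move masking used above:
# actions whose Q-values are replaced with -inf can never be the greedy
# choice. The values are arbitrary.
import jax
import jax.numpy as jnp
import rlax

q_vals = jnp.array([0.2, 0.9, 0.5])
lm = jnp.array([1, 0, 1])                 # action 1 is illegal
masked = jnp.where(lm, q_vals, -jnp.inf)  # [0.2, -inf, 0.5]
action = rlax.greedy().sample(jax.random.PRNGKey(0), masked)  # -> 2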
def eval_policy(
    network, atoms, net_params, key,
    obs: rlax.ArrayLike, lms: rlax.ArrayLike):
  """Sample action from greedy policy.

  Args:
    network -- haiku Transformed network.
    atoms -- support of the distributional output.
    net_params -- parameters (weights) of the network.
    key -- key for categorical sampling.
    obs -- observation.
    lms -- one-hot encoded legal actions.
  """
  # Convert the distributional logits to Q-values by taking the
  # expectation of the return distribution over its support.
  logits = network.apply(net_params, None, obs)
  probs = jax.nn.softmax(logits, axis=-1)
  q_vals = jnp.sum(probs * atoms, axis=-1)
  # Mask out the Q-values of illegal actions.
  q_vals = jnp.where(lms, q_vals, -jnp.inf)
  # Select the greedy action.
  return rlax.greedy().sample(key, q_vals)
def eval_policy(net_params, key, obs):
  """Sample action from greedy policy."""
  q = network.apply(net_params, obs)
  return rlax.greedy().sample(key, q)
def actor_step(self, params, env_output, actor_state, key, evaluation):
  q = self._network.apply(params, env_output.observation)
  train_a = rlax.epsilon_greedy(self._epsilon).sample(key, q)
  eval_a = rlax.greedy().sample(key, q)
  a = jax.lax.select(evaluation, eval_a, train_a)
  return ActorOutput(actions=a, q_values=q), actor_state