Example #1
    def eval_policy(
            network,
            use_distribution,
            atoms,
            net_params,
            key,
            obs: chex.Array,
            lms: chex.Array):
        """Sample action from greedy policy.
        Args:
            network            -- haiku Transformed network.
            use_distribution   -- network has distributional output
            atoms              -- support of distributional output
            net_params         -- parameters (weights) of the network.
            key                -- key for categorical sampling.
            obs                -- observation.
            lms                -- one-hot encoded legal actions
        """
        # compute q values
        # calculate q value from distributional output
        # by calculating mean of distribution
        if use_distribution:
            logits = network.apply(net_params, None, obs)
            probs = jax.nn.softmax(logits, axis=-1)
            q_vals = jnp.mean(probs * atoms, axis=-1)
            
        # q values equal network output
        else:
            q_vals = network.apply(net_params, None, obs)
        
        # mask q values of illegal actions
        q_vals = jnp.where(lms, q_vals, -jnp.inf)

        # select best action
        return rlax.greedy().sample(key, q_vals)
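All of the snippets in this collection reduce to the same call: build a vector of action preferences (Q-values or logits) and draw the maximizing action with rlax.greedy().sample. A minimal, self-contained sketch with made-up Q-values (not taken from any of the repositories above):

    import jax
    import jax.numpy as jnp
    import rlax

    key = jax.random.PRNGKey(0)
    q_vals = jnp.array([0.1, 0.7, 0.2, -0.4])  # toy Q-values for 4 actions

    # rlax.greedy() builds a discrete distribution that puts all its mass
    # on the maximizing action(s), so sampling returns the argmax.
    action = rlax.greedy().sample(key, q_vals)   # -> 1
    probs = rlax.greedy().probs(q_vals)          # -> [0., 1., 0., 0.]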
Example #2
    def decide(self,
               timestep: env_base.TimeStep,
               greedy: bool = False) -> base.Decision:
        key = next(self._rng)
        previous_reward = timestep.reward if timestep.reward is not None else 0.
        previous_action = timestep.action if timestep.action is not None else -1
        inputs = [
            timestep.observation[None, ...],
            jnp.ones((1, 1, 1), dtype=jnp.float32) * previous_reward,
            jax.nn.one_hot([[previous_action]], self._action_spec.num_values)
        ]
        (logits, value, policy_embedding, value_embedding,
         state_embedding), rnn_state = (self._forward(self._state.params,
                                                      inputs,
                                                      self._state.rnn_state))
        self._state = self._state._replace(rnn_state=rnn_state)

        if greedy:
            action = rlax.greedy().sample(key, logits).squeeze()
        else:
            action = jax.random.categorical(key, logits).squeeze()

        return base.Decision(action=int(action),
                             action_dist=jax.nn.softmax(logits),
                             policy_embedding=policy_embedding,
                             value=value,
                             value_embedding=value_embedding,
                             state_embedding=state_embedding)
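In this snippet the greedy branch takes the argmax of the policy logits, while the default branch samples from the full categorical distribution. A minimal side-by-side of those two calls on toy logits (not from the repository):

    import jax
    import jax.numpy as jnp
    import rlax

    key = jax.random.PRNGKey(3)
    logits = jnp.array([[0.2, 1.5, -0.3]])  # batch of one

    greedy_a = rlax.greedy().sample(key, logits).squeeze()     # always the argmax (1)
    sampled_a = jax.random.categorical(key, logits).squeeze()  # stochastic draw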
Example #3
 def select_action(rng_key, network_params, s_t):
     """Computes greedy (argmax) action wrt Q-values at given state."""
     rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
     q_t = network.apply(network_params, apply_key,
                         s_t[None, ...]).q_values[0]
     a_t = rlax.greedy().sample(policy_key, q_t)
     v_t = jnp.max(q_t, axis=-1)
     return rng_key, a_t, v_t
Example #4
 def actor_step(self, params, env_output, actor_state, key, evaluation):
     norm_q = self._network.apply(params, env_output.observation)
     # This is equivalent to epsilon-greedy on the (unnormalized) Q-values
     # because normalization is linear, therefore the argmaxes are the same.
     train_a = rlax.epsilon_greedy(self._epsilon).sample(key, norm_q)
     eval_a = rlax.greedy().sample(key, norm_q)
     a = jax.lax.select(evaluation, eval_a, train_a)
     return ActorOutput(actions=a), actor_state
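The comment above relies on the fact that the argmax is invariant under a positive linear rescaling of the Q-values, so (epsilon-)greedy selection on normalized and unnormalized values picks the same action. A quick standalone check of that property with toy values (not from the repository):

    import jax
    import jax.numpy as jnp
    import rlax

    key = jax.random.PRNGKey(42)
    q = jnp.array([1.0, 3.0, 2.0])
    norm_q = 0.5 * q - 1.0  # any positive scale plus shift

    a_raw = rlax.greedy().sample(key, q)
    a_norm = rlax.greedy().sample(key, norm_q)
    assert a_raw == a_norm  # both pick action 1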
Example #5
 def actor_step(self, params, env_output, actor_state, key, evaluation):
     obs = jnp.expand_dims(env_output.observation, 0)  # add dummy batch
     q = self._network.apply(params.online, obs)[0]  # remove dummy batch
     epsilon = self._epsilon_by_frame(actor_state.count)
     train_a = rlax.epsilon_greedy(epsilon).sample(key, q)
     eval_a = rlax.greedy().sample(key, q)
     a = jax.lax.select(evaluation, eval_a, train_a)
     return ActorOutput(actions=a,
                        q_values=q), ActorState(actor_state.count + 1)
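The epsilon schedule self._epsilon_by_frame is not shown in this snippet; a common choice is a linear decay over the first N actor steps. A hypothetical stand-in (name and constants are illustrative, not taken from the repository):

    import jax.numpy as jnp

    def epsilon_by_frame(step, eps_start=1.0, eps_end=0.01, decay_steps=100_000):
        """Linearly anneal epsilon from eps_start to eps_end over decay_steps."""
        frac = jnp.clip(step / decay_steps, 0.0, 1.0)
        return eps_start + frac * (eps_end - eps_start)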
Example #6
 def _actor_step(self, all_params, all_states, observation, rng_key, evaluation):
     obs = jnp.expand_dims(observation, 0)  # dummy batch
     q_val = self._q_net.apply(all_params.online, obs)[0]  # remove batch
     epsilon = self._epsilon_schedule(all_states.actor_steps)
     train_action = rlax.epsilon_greedy(epsilon).sample(rng_key, q_val)
     eval_action = rlax.greedy().sample(rng_key, q_val)
     action = jax.lax.select(evaluation, eval_action, train_action)
     return (
         ActorOutput(actions=action, q_values=q_val),
         AllStates(
             optimizer=all_states.optimizer,
             learner_steps=all_states.learner_steps,
             actor_steps=all_states.actor_steps + 1,
         ),
     )
Example #7
 def eval_policy(network, net_params, key, obs: rlax.ArrayLike, lm: rlax.ArrayLike):
     """Sample action from greedy policy.
     Args:
         network    -- haiku Transformed network.
         net_params -- parameters (weights) of the network.
         key        -- key for categorical sampling.
         obs        -- observation.
         lm         -- one-hot encoded legal actions
     """
     # compute q
     q_vals = network.apply(net_params, obs)
     # mask q values of illegal actions by setting them to -inf
     q_vals = jnp.where(lm, q_vals, -jnp.inf)
     # compute actions
     return rlax.greedy().sample(key, q_vals)
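Setting illegal actions to -inf guarantees that rlax.greedy() can only return a legal action, because the argmax is taken after the illegal entries have been suppressed. A self-contained illustration with toy values (not from the repository):

    import jax
    import jax.numpy as jnp
    import rlax

    key = jax.random.PRNGKey(0)
    q_vals = jnp.array([2.0, 5.0, 1.0])     # action 1 has the highest Q-value...
    legal = jnp.array([True, False, True])  # ...but action 1 is illegal

    masked = jnp.where(legal, q_vals, -jnp.inf)
    action = rlax.greedy().sample(key, masked)  # -> 0, the best legal action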
Example #8
    def eval_policy(
            network,
            atoms,
            net_params,
            key,
            obs: rlax.ArrayLike,
            lms: rlax.ArrayLike):
        """Sample action from greedy policy.
        Args:
            network    -- haiku Transformed network.
            atoms      -- support of distributional output
            net_params -- parameters (weights) of the network.
            key        -- key for categorical sampling.
            obs        -- observation.
            lms        -- one-hot encoded legal actions
        """
        # compute logits and convert those to q_vals
        logits = network.apply(net_params, None, obs)
        probs = jax.nn.softmax(logits, axis=-1)
        q_vals = jnp.mean(probs * atoms, axis=-1)

        q_vals = jnp.where(lms, q_vals, -jnp.inf)

        # compute actions
        return rlax.greedy().sample(key, q_vals)
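As in Example #1, the Q-values here are obtained from a categorical (distributional) head by weighting the support atoms with the predicted probabilities. Note that using jnp.mean rather than the usual probability-weighted jnp.sum only rescales every Q-value by the same constant (1 / number of atoms), so the greedy argmax is unchanged. A toy illustration of that scale-invariance (values are made up):

    import jax
    import jax.numpy as jnp
    import rlax

    key = jax.random.PRNGKey(0)
    atoms = jnp.linspace(-10.0, 10.0, 5)            # support of the return distribution
    probs = jnp.array([[0.1, 0.2, 0.4, 0.2, 0.1],   # per-action probabilities
                       [0.0, 0.1, 0.2, 0.3, 0.4]])

    q_sum = jnp.sum(probs * atoms, axis=-1)    # expected return per action
    q_mean = jnp.mean(probs * atoms, axis=-1)  # same values, scaled by 1/5

    assert rlax.greedy().sample(key, q_sum) == rlax.greedy().sample(key, q_mean)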
Example #9
 def eval_policy(net_params, key, obs):
     """Sample action from greedy policy."""
     q = network.apply(net_params, obs)
     return rlax.greedy().sample(key, q)
Example #10
 def actor_step(self, params, env_output, actor_state, key, evaluation):
     q = self._network.apply(params, env_output.observation)
     train_a = rlax.epsilon_greedy(self._epsilon).sample(key, q)
     eval_a = rlax.greedy().sample(key, q)
     a = jax.lax.select(evaluation, eval_a, train_a)
     return ActorOutput(actions=a, q_values=q), actor_state