Esempio n. 1
0
 def _act(self, timestep) -> parts.Action:
     """Picks an action for `timestep` using the epsilon-greedy policy.

     The stored PRNG key is passed to `self._select_action` with the online
     parameters, the observation and the current exploration epsilon. The
     key returned by that call replaces the stored one. The chosen action is
     copied to the host with `jax.device_get` and then wrapped in
     `parts.Action`.
     """
     observation = timestep.observation
     next_key, action = self._select_action(
         self._rng_key,
         self._online_params,
         observation,
         self.exploration_epsilon,
     )
     self._rng_key = next_key
     host_action = jax.device_get(action)
     return parts.Action(host_action)
Esempio n. 2
0
 def _act(self, timestep) -> parts.Action:
     """Picks an action for `timestep` using the greedy policy.

     `self._select_action` returns an updated PRNG key, an action and a
     value. The key replaces the stored one. The action and value are
     fetched to the host in one `jax.device_get` call. The value is recorded
     under the 'state_value' key of `self._statistics`, and the action is
     returned wrapped in `parts.Action`.
     """
     observation = timestep.observation
     selection = self._select_action(self._rng_key, self._online_params,
                                     observation)
     self._rng_key, action, value = selection
     host_action, host_value = jax.device_get((action, value))
     self._statistics['state_value'] = host_value
     return parts.Action(host_action)
Esempio n. 3
0
  def step(self, timestep: dm_env.TimeStep) -> parts.Action:
    """Returns the action to take in response to `timestep`.

    The timestep first goes through `self._preprocessor`. If that yields
    `None`, no new decision is made and the previously cached action is
    returned again. Otherwise a fresh action is selected from the
    preprocessed observation and cached on the agent, and the PRNG key is
    advanced. The cached action is then returned.
    """
    processed = self._preprocessor(timestep)
    if processed is not None:
      observation = processed.observation
      self._rng_key, action = self._select_action(
          self._rng_key, self.network_params, observation)
      # Bring the action back from the device before caching it.
      self._action = parts.Action(jax.device_get(action))
    return self._action
Esempio n. 4
0
 def step(self, timestep):
   """Ignores `timestep` and always returns action 0."""
   del timestep  # Unused: the returned action never depends on the input.
   return parts.Action(0)