def _act(self, timestep) -> parts.Action:
    """Chooses an action for `timestep` using the epsilon-greedy policy.

    Delegates to `self._select_action` with the current RNG key, the online
    network parameters, the observation and `self.exploration_epsilon`. The
    RNG key returned by that call replaces `self._rng_key`.

    Args:
      timestep: Object exposing an `observation` attribute.

    Returns:
      The selected action, fetched to host with `jax.device_get` and wrapped
      in `parts.Action`.
    """
    observation = timestep.observation
    next_key, device_action = self._select_action(
        self._rng_key,
        self._online_params,
        observation,
        self.exploration_epsilon,
    )
    # Keep the key handed back by `_select_action` for the next call.
    self._rng_key = next_key
    host_action = jax.device_get(device_action)
    return parts.Action(host_action)
def _act(self, timestep) -> parts.Action:
    """Chooses an action for `timestep` using the greedy policy.

    Calls `self._select_action` with the current RNG key, the online network
    parameters and the observation. That call yields a new RNG key, an action
    and a value. The key replaces `self._rng_key`. The value is copied to host
    and recorded as `self._statistics['state_value']`.

    Args:
      timestep: Object exposing an `observation` attribute.

    Returns:
      The selected action, fetched to host and wrapped in `parts.Action`.
    """
    observation = timestep.observation
    next_key, device_action, device_value = self._select_action(
        self._rng_key, self._online_params, observation
    )
    self._rng_key = next_key
    # Transfer both outputs to host in a single `device_get` call.
    host_action, host_value = jax.device_get((device_action, device_value))
    self._statistics['state_value'] = host_value
    return parts.Action(host_action)
def step(self, timestep: dm_env.TimeStep) -> parts.Action:
    """Returns the action to take for `timestep`.

    The timestep is first passed through `self._preprocessor`. When the
    preprocessor yields `None`, the previously chosen action (`self._action`)
    is repeated. Otherwise a fresh action is selected via
    `self._select_action` using `self.network_params`. The returned RNG key
    replaces `self._rng_key`, and the new action is cached in `self._action`.

    Args:
      timestep: The incoming environment timestep.

    Returns:
      `self._action` after any update.
    """
    processed = self._preprocessor(timestep)

    if processed is not None:
        observation = processed.observation
        next_key, device_action = self._select_action(
            self._rng_key, self.network_params, observation
        )
        self._rng_key = next_key
        self._action = parts.Action(jax.device_get(device_action))

    # When `processed` is None this repeats the previous action.
    return self._action
def step(self, timestep):
    """Ignores `timestep` and always returns `parts.Action(0)`."""
    del timestep  # Unused: the returned action does not depend on the input.
    constant_action = parts.Action(0)
    return constant_action