def step_episode(self):
    """Take a single time-step in the current episode.

    If fewer than ``self._max_episode_length`` steps have been taken, query
    ``self.agent`` once for an action on ``self._prev_obs``. When
    ``self._deterministic`` is set, that action is replaced by
    ``agent_info['mean']``. The method then steps ``self.env`` and records
    the pre-step observation, the env step and every ``agent_info`` entry.
    When ``self._accum_context`` is set, the resulting ``TimeStep`` is also
    passed to ``self.agent.update_context``.

    Returns:
        bool: True iff the episode is done, either due to the environment
        indicating termination or due to reaching `max_episode_length`.
        On True, the episode length and last observation are appended to
        ``self._lengths`` and ``self._last_observations``.

    """
    if self._eps_length < self._max_episode_length:
        # Query the agent exactly once per step. A second query would
        # overwrite the deterministic action chosen below (and may re-sample
        # or advance internal agent state).
        a, agent_info = self.agent.get_action(self._prev_obs)
        if self._deterministic:
            # Use the agent's reported 'mean' (presumably the policy mean)
            # instead of the sampled action.
            a = agent_info['mean']
        es = self.env.step(a)
        self._observations.append(self._prev_obs)
        self._env_steps.append(es)
        for k, v in agent_info.items():
            self._agent_infos[k].append(v)
        self._eps_length += 1

        if self._accum_context:
            s = TimeStep.from_env_step(env_step=es,
                                       last_observation=self._prev_obs,
                                       agent_info=agent_info,
                                       episode_info=self._episode_info)
            self.agent.update_context(s)
        if not es.last:
            self._prev_obs = es.observation
            return False
    # Episode finished: either the step limit had already been reached on
    # entry, or the environment reported es.last on this step.
    # NOTE(review): on es.last, _prev_obs is not advanced to es.observation,
    # so the recorded last observation is the one the final action was taken
    # from — confirm this is intended.
    self._lengths.append(self._eps_length)
    self._last_observations.append(self._prev_obs)
    return True
def test_from_env_step_time_step(sample_data):
    """Check that TimeStep.from_env_step rebuilds the TimeStep from sample_data.

    A reference ``TimeStep`` is built directly from the fixture data. An
    ``EnvStep`` is then built from a copy of that data with ``agent_info``
    and ``next_observation`` removed and ``observation`` set to the next
    observation. Finally, ``from_env_step`` is given the original
    observation and agent info back, and its result must equal the
    reference.
    """
    agent_info = sample_data['agent_info']
    last_observation = sample_data['observation']
    observation = sample_data['next_observation']
    time_step = TimeStep(**sample_data)

    # Work on a filtered copy so the fixture dict itself is not mutated
    # (it could be shared with other tests depending on its scope).
    env_step_data = {
        key: value
        for key, value in sample_data.items()
        if key not in ('agent_info', 'next_observation')
    }
    env_step_data['observation'] = observation
    env_step = EnvStep(**env_step_data)
    time_step_new = TimeStep.from_env_step(env_step=env_step,
                                           last_observation=last_observation,
                                           agent_info=agent_info)
    assert time_step == time_step_new