def on_episode_end(
     self,
     *,
     worker: RolloutWorker,
     base_env: BaseEnv,
     policies: Dict[str, Policy],
     episode: Episode,
     env_index: int,
     **kwargs
 ):
     """Summarize the pole angles collected over a finished episode.

     If the worker's ``policy_config["batch_mode"]`` is
     ``"truncate_episodes"``, first asserts that the last ``dones`` entry of
     the last batch held by the ``"default_policy"`` collector is truthy.
     Then prints a one-line summary, stores the mean of
     ``episode.user_data["pole_angles"]`` as
     ``episode.custom_metrics["pole_angle"]`` and publishes the per-step list
     itself as ``episode.hist_data["pole_angles"]``.

     Args:
         worker: Rollout worker; only ``policy_config["batch_mode"]`` is read.
         base_env: Not used by this hook.
         policies: Not used by this hook.
         episode: The episode being finalized. Its
             ``user_data["pole_angles"]`` list is expected to exist (it is
             created in ``on_episode_start`` and appended to in
             ``on_episode_step``).
         env_index: Sub-environment index; only used in the printed message.
         **kwargs: Extra keyword arguments; ignored.
     """
     # A single batch may hold several episodes in "truncate_episodes" mode,
     # so double-check this episode really ended. The whole lookup is kept
     # inside the assert, so it is skipped entirely under ``python -O``.
     if worker.policy_config["batch_mode"] == "truncate_episodes":
         assert (
             episode.batch_builder.policy_collectors["default_policy"]
             .batches[-1]["dones"][-1]
         ), (
             "ERROR: `on_episode_end()` should only be called "
             "after episode is done!"
         )

     angles = episode.user_data["pole_angles"]
     # NOTE(review): np.mean of an empty list gives nan plus a RuntimeWarning.
     mean_angle = np.mean(angles)
     template = (
         "episode {} (env-idx={}) ended with length {} and pole "
         "angles {}"
     )
     print(
         template.format(
             episode.episode_id, env_index, episode.length, mean_angle
         )
     )

     # The scalar mean goes into custom_metrics; the raw per-step values
     # go into hist_data.
     episode.custom_metrics["pole_angle"] = mean_angle
     episode.hist_data["pole_angles"] = angles
 def on_episode_step(
     self,
     *,
     worker: RolloutWorker,
     base_env: BaseEnv,
     policies: Dict[str, Policy],
     episode: Episode,
     env_index: int,
     **kwargs
 ):
     """Record the absolute pole angle observed at the current step.

     Reads element ``[2]`` of both ``episode.last_observation_for()`` and
     ``episode.last_raw_obs_for()``. It asserts that their absolute values
     are equal and appends that value to ``episode.user_data["pole_angles"]``.
     Index 2 is presumably the pole angle of a CartPole-style observation;
     confirm against the env in use.

     Args:
         worker: Not used by this hook.
         base_env: Not used by this hook.
         policies: Not used by this hook.
         episode: The running episode. Its ``user_data["pole_angles"]``
             list must already exist.
         env_index: Not used by this hook.
         **kwargs: Extra keyword arguments; ignored.
     """
     # At least one env step must have happened before this hook fires.
     assert episode.length > 0, (
         "ERROR: `on_episode_step()` callback should not be called right "
         "after env reset!"
     )
     processed_angle = abs(episode.last_observation_for()[2])
     unprocessed_angle = abs(episode.last_raw_obs_for()[2])
     # Exact float comparison: assumes no observation preprocessing or
     # filtering alters this component.
     assert processed_angle == unprocessed_angle
     episode.user_data["pole_angles"].append(processed_angle)
 def on_episode_start(
     self,
     *,
     worker: RolloutWorker,
     base_env: BaseEnv,
     policies: Dict[str, Policy],
     episode: Episode,
     env_index: int,
     **kwargs
 ):
     """Announce a new episode and reset its pole-angle buffers.

     Asserts that the episode has length 0, meaning only the initial
     observation is logged. It then prints a start message and sets both
     ``episode.user_data["pole_angles"]`` and
     ``episode.hist_data["pole_angles"]`` to fresh empty lists.

     Args:
         worker: Not used by this hook.
         base_env: Not used by this hook.
         policies: Not used by this hook.
         episode: The episode that was just started.
         env_index: Sub-environment index; only used in the printed message.
         **kwargs: Extra keyword arguments; ignored.
     """
     # Nothing but the reset observation may have been recorded yet.
     assert episode.length == 0, (
         "ERROR: `on_episode_start()` callback should be called right "
         "after env reset!"
     )
     announcement = "episode {} (env-idx={}) started.".format(
         episode.episode_id, env_index
     )
     print(announcement)
     # Two distinct list objects: one accumulates per-step values, the other
     # is the histogram slot.
     episode.user_data["pole_angles"] = []
     episode.hist_data["pole_angles"] = []
 def on_postprocess_trajectory(
     self,
     *,
     worker: RolloutWorker,
     episode: Episode,
     agent_id: str,
     policy_id: str,
     policies: Dict[str, Policy],
     postprocessed_batch: SampleBatch,
     original_batches: Dict[str, SampleBatch],
     **kwargs
 ):
     """Log the size of a postprocessed batch and count such calls.

     Prints ``postprocessed_batch.count`` and increments
     ``episode.custom_metrics["num_batches"]``. The counter is treated as 0
     if it is not present yet.

     Args:
         worker: Not used by this hook.
         episode: Episode whose ``custom_metrics`` receives the counter.
         agent_id: Not used by this hook.
         policy_id: Not used by this hook.
         policies: Not used by this hook.
         postprocessed_batch: Batch whose ``count`` is printed.
         original_batches: Not used by this hook.
         **kwargs: Extra keyword arguments; ignored.
     """
     print("postprocessed {} steps".format(postprocessed_batch.count))
     metrics = episode.custom_metrics
     # Start from zero the first time this episode is postprocessed.
     calls_so_far = metrics["num_batches"] if "num_batches" in metrics else 0
     metrics["num_batches"] = calls_so_far + 1