def on_episode_end(
    self,
    *,
    worker: RolloutWorker,
    base_env: BaseEnv,
    policies: Dict[str, Policy],
    episode: Episode,
    env_index: int,
    **kwargs
):
    """Summarise the pole angles collected over a finished episode.

    Averages ``episode.user_data["pole_angles"]`` with ``np.mean``, prints a
    one-line summary, stores the average as the ``"pole_angle"`` custom
    metric, and publishes the collected list as ``"pole_angles"`` histogram
    data.
    """
    batch_mode = worker.policy_config["batch_mode"]
    # With "truncate_episodes" one batch may hold several episodes, so verify
    # that the newest recorded step of the default policy is flagged done.
    # The whole lookup stays inside the assert so it is skipped under -O,
    # exactly like the original.
    if batch_mode == "truncate_episodes":
        assert episode.batch_builder.policy_collectors["default_policy"].batches[
            -1
        ]["dones"][-1], (
            "ERROR: `on_episode_end()` should only be called after episode is done!"
        )

    collected = episode.user_data["pole_angles"]
    average = np.mean(collected)
    summary_template = (
        "episode {} (env-idx={}) ended with length {} and pole angles {}"
    )
    print(
        summary_template.format(
            episode.episode_id, env_index, episode.length, average
        )
    )
    episode.custom_metrics["pole_angle"] = average
    # Same list object as user_data, exposed for histogram reporting.
    episode.hist_data["pole_angles"] = collected
def on_episode_step(
    self,
    *,
    worker: RolloutWorker,
    base_env: BaseEnv,
    policies: Dict[str, Policy],
    episode: Episode,
    env_index: int,
    **kwargs
):
    """Append the absolute value of observation component 2 to the episode log.

    Cross-checks that ``last_observation_for()`` and ``last_raw_obs_for()``
    agree on that component, then appends it to
    ``episode.user_data["pole_angles"]``.
    """
    # This hook must not fire immediately after an env reset.
    assert episode.length > 0, (
        "ERROR: `on_episode_step()` callback should not be called right "
        "after env reset!"
    )
    # Index 2 is presumably the pole angle in a CartPole-style observation
    # vector (TODO confirm against the env in use).
    angle_slot = 2
    observed = abs(episode.last_observation_for()[angle_slot])
    unprocessed = abs(episode.last_raw_obs_for()[angle_slot])
    assert observed == unprocessed
    episode.user_data["pole_angles"].append(observed)
def on_episode_start(
    self,
    *,
    worker: RolloutWorker,
    base_env: BaseEnv,
    policies: Dict[str, Policy],
    episode: Episode,
    env_index: int,
    **kwargs
):
    """Announce a new episode and reset its pole-angle storage.

    Must run right after an env reset, while ``episode.length`` is still 0.
    Afterwards both ``user_data`` and ``hist_data`` hold their own empty
    ``"pole_angles"`` list.
    """
    # Only the initial observation may have been logged at this point.
    assert episode.length == 0, (
        "ERROR: `on_episode_start()` callback should be called right "
        "after env reset!"
    )
    print("episode {} (env-idx={}) started.".format(episode.episode_id, env_index))
    # A fresh list is built on every iteration, so the two stores never share one.
    for store in (episode.user_data, episode.hist_data):
        store["pole_angles"] = []
def on_postprocess_trajectory(
    self,
    *,
    worker: RolloutWorker,
    episode: Episode,
    agent_id: str,
    policy_id: str,
    policies: Dict[str, Policy],
    postprocessed_batch: SampleBatch,
    original_batches: Dict[str, SampleBatch],
    **kwargs
):
    """Print the postprocessed step count and bump the episode's batch counter.

    ``episode.custom_metrics["num_batches"]`` starts at 0 when absent and
    grows by one per call.
    """
    print("postprocessed {} steps".format(postprocessed_batch.count))
    metrics = episode.custom_metrics
    metrics["num_batches"] = metrics.get("num_batches", 0) + 1