def _warmup(self, rollouts):
    model_state = {k: v.clone() for k, v in self.agent.state_dict().items()}
    optim_state = self.agent.optimizer.state.copy()

    self.agent.eval()
    for _ in range(20):
        self._inference(rollouts, 0)

    # Do a cache empty as sometimes cudnn searching
    # doesn't do that
    torch.cuda.empty_cache()

    t_inference_start = time.time()
    n_infers = 200
    for _ in range(n_infers):
        self._inference(rollouts, 0)

    if self.world_rank == 0:
        logger.info(
            "Inference time: {:.3f} ms".format(
                (time.time() - t_inference_start) / n_infers * 1000
            )
        )
        logger.info(
            "PyTorch CUDA Memory Cache Size: {:.3f} GB".format(
                torch.cuda.memory_reserved(self.device) / 1e9
            )
        )

    self.agent.train()
    for _ in range(10):
        self._update_agent(rollouts, warmup=True)

    # Do a cache empty as sometimes cudnn searching
    # doesn't do that
    torch.cuda.empty_cache()

    t_learning_start = time.time()
    n_learns = 15
    for _ in range(n_learns):
        self._update_agent(rollouts, warmup=True)

    if self.world_rank == 0:
        logger.info(
            "Learning time: {:.3f} ms".format(
                (time.time() - t_learning_start) / n_learns * 1000
            )
        )
        logger.info(self.timing)
        logger.info(
            "PyTorch CUDA Memory Cache Size: {:.3f} GB".format(
                torch.cuda.memory_reserved(self.device) / 1e9
            )
        )

    self.agent.load_state_dict(model_state)
    self.agent.optimizer.state = optim_state
    self.agent.ada_scale.zero_grad()

    self.timing = Timing()

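# A minimal sketch (an assumption, not part of this trainer) of a stricter way
# to time the loops above: wall-clock timing of CUDA work usually needs an
# explicit synchronize so pending kernels are not still in flight when
# time.time() is read.
#
#     torch.cuda.synchronize(self.device)
#     t0 = time.time()
#     for _ in range(n_infers):
#         self._inference(rollouts, 0)
#     torch.cuda.synchronize(self.device)
#     per_infer_ms = (time.time() - t0) / n_infers * 1000
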
def set_cpus(local_rank, world_size):
    local_size = min(world_size, 8)
    curr_process = psutil.Process()
    total_cpus = curr_process.cpu_affinity()
    total_cpu_count = len(total_cpus)

    # Assuming things were already set
    if total_cpu_count > multiprocessing.cpu_count() / world_size:
        orig_cpus = total_cpus
        total_cpus = []
        # Interleave the affinity list so that the two hardware threads of a
        # physical core end up adjacent
        for i in range(total_cpu_count // 2):
            total_cpus.append(orig_cpus[i])
            total_cpus.append(orig_cpus[i + total_cpu_count // 2])

    ptr = 0
    local_cpu_count = 0
    local_cpus = []

    CORE_GROUPING = min(
        local_size,
        4 if total_cpu_count / 2 >= 20 else (2 if total_cpu_count / 2 >= 10 else 1),
    )
    # NOTE: the heuristic above is overridden here; grouping is disabled
    CORE_GROUPING = 1

    core_dist_size = max(local_size // CORE_GROUPING, 1)
    core_dist_rank = local_rank // CORE_GROUPING
    for r in range(core_dist_rank + 1):
        ptr += local_cpu_count
        local_cpu_count = total_cpu_count // core_dist_size + (
            1 if r < (total_cpu_count % core_dist_size) else 0
        )

    local_cpus += total_cpus[ptr : ptr + local_cpu_count]

    pop_inds = [
        ((local_rank + offset + 1) % CORE_GROUPING)
        for offset in range(CORE_GROUPING - 1)
    ]
    for ind in sorted(pop_inds, reverse=True):
        local_cpus.pop(ind)

    if BPS_BENCHMARK and world_size == 1:
        local_cpus = total_cpus[0:12]

    curr_process.cpu_affinity(local_cpus)

    logger.info(
        "Rank {} uses cpus {}".format(local_rank, sorted(curr_process.cpu_affinity()))
    )

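# Worked example of the interleaving above (illustrative values, not from this
# file): with cpu_affinity() == [0, 1, 2, 3, 4, 5, 6, 7] and hyperthread
# siblings offset by total_cpu_count // 2, the reorder loop yields
# [0, 4, 1, 5, 2, 6, 3, 7], so each physical core's two hardware threads are
# adjacent and a contiguous slice hands whole cores to a rank.
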
def requeue_job():
    r"""Requeues the job by calling ``scontrol requeue ${SLURM_JOBID}``."""
    if SLURM_JOBID is None:
        return

    if not REQUEUE.is_set():
        return

    if distrib.is_initialized():
        distrib.barrier()

    if not distrib.is_initialized() or distrib.get_rank() == 0:
        logger.info(f"Requeueing job {SLURM_JOBID}")
        subprocess.check_call(shlex.split(f"scontrol requeue {SLURM_JOBID}"))

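# REQUEUE is an Event that is presumably set from a signal handler shortly
# before SLURM preempts the job (see the commented-out add_signal_handlers()
# call in train()). A minimal sketch of such a handler; the names and the
# signal choice are assumptions, the actual registration lives elsewhere:
#
#     import signal
#
#     def _requeue_handler(signum, frame):
#         REQUEUE.set()
#
#     signal.signal(signal.SIGUSR1, _requeue_handler)
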
def eval(self) -> None:
    r"""Main method of trainer evaluation. Calls _eval_checkpoint() that
    is specified in the Trainer class that inherits from BaseRLTrainer.

    Returns:
        None
    """
    self.device = (
        torch.device("cuda", self.config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )

    if "tensorboard" in self.config.VIDEO_OPTION:
        assert (
            len(self.config.TENSORBOARD_DIR) > 0
        ), "Must specify a tensorboard directory for video display"
        os.makedirs(self.config.TENSORBOARD_DIR, exist_ok=True)

    if "disk" in self.config.VIDEO_OPTION:
        assert (
            len(self.config.VIDEO_DIR) > 0
        ), "Must specify a directory for storing videos on disk"

    with TensorboardWriter(
        self.config.TENSORBOARD_DIR,
        flush_secs=self.flush_secs,
        purge_step=int(self.num_frames),
    ) as writer:
        if os.path.isfile(self.config.EVAL_CKPT_PATH_DIR):
            # evaluate single checkpoint
            self._eval_checkpoint(self.config.EVAL_CKPT_PATH_DIR, writer)
        else:
            # evaluate multiple checkpoints in order
            while True:
                current_ckpt = None
                while current_ckpt is None:
                    time.sleep(2)  # sleep for 2 secs before polling again
                    current_ckpt = poll_checkpoint_folder(
                        self.config.EVAL_CKPT_PATH_DIR, self.prev_ckpt_ind
                    )

                if os.path.basename(current_ckpt) == "ckpt.done.pth":
                    return

                logger.info(f"=======current_ckpt: {current_ckpt}=======")

                self._eval_checkpoint(
                    checkpoint_path=current_ckpt,
                    writer=writer,
                    checkpoint_index=self.prev_ckpt_ind,
                )

                self.prev_ckpt_ind += 1

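# Note: "ckpt.done.pth" is the sentinel checkpoint that train() writes after
# its final update, so the polling loop above terminates once training has
# finished and every earlier checkpoint has been evaluated.
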
def _setup_eval_config(self, checkpoint_config):
    r"""Sets up and returns a merged config for evaluation. The config
    object saved in the checkpoint is merged into the config file specified
    at evaluation time with the following overwrite priority:
        eval_opts > ckpt_opts > eval_cfg > ckpt_cfg
    If the saved config is outdated, only the eval config is returned.

    Args:
        checkpoint_config: saved config from checkpoint.

    Returns:
        Config: merged config for eval.
    """
    config = self.config.clone()

    config.defrost()

    ckpt_cmd_opts = checkpoint_config.CMD_TRAILING_OPTS
    eval_cmd_opts = config.CMD_TRAILING_OPTS

    try:
        config.merge_from_other_cfg(checkpoint_config)
        config.merge_from_other_cfg(self.config)
        config.merge_from_list(ckpt_cmd_opts)
        config.merge_from_list(eval_cmd_opts)
    except KeyError:
        logger.info("Saved config is outdated, using solely eval config")
        config = self.config.clone()
        config.merge_from_list(eval_cmd_opts)

    if config.TASK_CONFIG.DATASET.SPLIT == "train":
        config.TASK_CONFIG.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = "val"

    config.TASK_CONFIG.SIMULATOR.AGENT_0.SENSORS = self.config.SENSORS
    config.freeze()
    return config

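# Illustration of the overwrite priority (hypothetical values): if the
# checkpoint config sets RL.PPO.lr = 2.5e-4, the eval config sets it to 1e-4,
# ckpt_cmd_opts contains ["RL.PPO.lr", "3e-4"], and eval_cmd_opts contains
# ["RL.PPO.lr", "5e-5"], the merged config ends up with lr == 5e-5, since each
# successive merge overwrites the previous one.
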
def load_interrupted_state(
    filename: Optional[str] = None, resume_from: Optional[str] = None
) -> Optional[Any]:
    r"""Loads the saved interrupted state.

    :param filename: The filename of the saved state.
        Defaults to "${HOME}/.interrupted_states/${SLURM_JOBID}.pth"
    :param resume_from: Fallback file to load from if ``filename`` does not
        exist.

    :return: The saved state if the file exists, else none
    """
    if SLURM_JOBID is None and filename is None:
        return None

    if filename is None:
        filename = INTERRUPTED_STATE_FILE

    if not osp.exists(filename) and resume_from is not None:
        filename = resume_from

    if not osp.exists(filename):
        return None

    logger.info(f"Loading saved state from {filename}")

    return torch.load(filename, map_location="cpu")

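# INTERRUPTED_STATE_FILE is defined elsewhere in this module; based on the
# docstring above it presumably resolves to something like (sketch, assuming
# osp is os.path):
#
#     INTERRUPTED_STATE_FILE = osp.join(
#         osp.expanduser("~"), ".interrupted_states", f"{SLURM_JOBID}.pth"
#     )
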
def train(self) -> None:
    r"""Main method for DD-PPO.

    Returns:
        None
    """
    import apex

    self.local_rank, tcp_store = init_distrib_slurm(
        self.config.RL.DDPPO.distrib_backend
    )
    # add_signal_handlers()

    self.timing = Timing()

    # Stores the number of workers that have finished their rollout
    num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store)
    num_rollouts_done_store.set("num_done", "0")

    self.world_rank = distrib.get_rank()
    self.world_size = distrib.get_world_size()

    set_cpus(self.local_rank, self.world_size)

    self.config.defrost()
    self.config.TORCH_GPU_ID = self.local_rank
    self.config.SIMULATOR_GPU_ID = self.local_rank
    # Multiply by the number of simulators to make sure they also get unique seeds
    self.config.TASK_CONFIG.SEED += self.world_rank * self.config.SIM_BATCH_SIZE
    self.config.freeze()

    random.seed(self.config.TASK_CONFIG.SEED)
    np.random.seed(self.config.TASK_CONFIG.SEED)
    torch.manual_seed(self.config.TASK_CONFIG.SEED)

    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")

    double_buffered = False
    self._num_worker_groups = self.config.NUM_PARALLEL_SCENES

    self._depth = self.config.DEPTH
    self._color = self.config.COLOR

    if self.config.TASK.lower() == "pointnav":
        self.observation_space = SpaceDict(
            {
                "pointgoal_with_gps_compass": spaces.Box(
                    low=0.0, high=1.0, shape=(2,), dtype=np.float32
                )
            }
        )
    else:
        self.observation_space = SpaceDict({})

    self.action_space = spaces.Discrete(4)

    if self._color:
        self.observation_space = SpaceDict(
            {
                "rgb": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(3, *self.config.RESOLUTION),
                    dtype=np.uint8,
                ),
                **self.observation_space.spaces,
            }
        )

    if self._depth:
        self.observation_space = SpaceDict(
            {
                "depth": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(1, *self.config.RESOLUTION),
                    dtype=np.float32,
                ),
                **self.observation_space.spaces,
            }
        )

    ppo_cfg = self.config.RL.PPO
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER) and self.world_rank == 0:
        os.makedirs(self.config.CHECKPOINT_FOLDER)

    self._setup_actor_critic_agent(ppo_cfg)

    self.count_steps = 0
    burn_steps = 0
    burn_time = 0
    count_checkpoints = 0
    prev_time = 0
    self.update = 0

    LR_SCALE = (
        max(
            np.sqrt(
                ppo_cfg.num_steps
                * self.config.SIM_BATCH_SIZE
                * ppo_cfg.num_accumulate_steps
                / ppo_cfg.num_mini_batch
                * self.world_size
                / (128 * 2)
            ),
            1.0,
        )
        if (self.config.RL.DDPPO.scale_lr and not self.config.RL.PPO.ada_scale)
        else 1.0
    )

    def cosine_decay(x):
        if x < 1:
            return (np.cos(x * np.pi) + 1.0) / 2.0
        else:
            return 0.0

    def warmup_fn(x):
        return LR_SCALE * (0.5 + 0.5 * x)

    def decay_fn(x):
        return LR_SCALE * (DECAY_TARGET + (1 - DECAY_TARGET) * cosine_decay(x))

    # NOTE: the ``or True`` in the next two expressions forces the first
    # branch regardless of the ada_scale setting
    DECAY_TARGET = (
        0.01 / LR_SCALE
        if self.config.RL.PPO.ada_scale or True
        else (0.25 / LR_SCALE if self.config.RL.DDPPO.scale_lr else 1.0)
    )
    DECAY_PERCENT = 1.0 if self.config.RL.PPO.ada_scale or True else 0.5
    WARMUP_PERCENT = (
        0.01
        if (self.config.RL.DDPPO.scale_lr and not self.config.RL.PPO.ada_scale)
        else 0.0
    )

    def lr_fn():
        x = self.percent_done()
        if x < WARMUP_PERCENT:
            return warmup_fn(x / WARMUP_PERCENT)
        else:
            return decay_fn((x - WARMUP_PERCENT) / DECAY_PERCENT)

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer, lr_lambda=lambda x: lr_fn()
    )

    interrupted_state = load_interrupted_state(resume_from=self.resume_from)
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])

    self.agent.init_amp(self.config.SIM_BATCH_SIZE)
    self.actor_critic.init_trt(self.config.SIM_BATCH_SIZE)
    self.actor_critic.script_net()
    self.agent.init_distributed(find_unused_params=False)

    if self.world_rank == 0:
        logger.info(
            "agent number of trainable parameters: {}".format(
                sum(
                    param.numel()
                    for param in self.agent.parameters()
                    if param.requires_grad
                )
            )
        )

    if self._static_encoder:
        self._encoder = self.actor_critic.net.visual_encoder
        self.observation_space = SpaceDict(
            {
                "visual_features": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                **self.observation_space,
            }
        )
        with torch.no_grad():
            batch["visual_features"] = self._encoder(batch)

    nenvs = self.config.SIM_BATCH_SIZE
    rollouts = DoubleBufferedRolloutStorage(
        ppo_cfg.num_steps,
        nenvs,
        self.observation_space,
        self.action_space,
        ppo_cfg.hidden_size,
        num_recurrent_layers=self.actor_critic.num_recurrent_layers,
        use_data_aug=ppo_cfg.use_data_aug,
        aug_type=ppo_cfg.aug_type,
        double_buffered=double_buffered,
        vtrace=ppo_cfg.vtrace,
    )
    rollouts.to(self.device)
    rollouts.to_fp16()

    self._warmup(rollouts)

    (
        self.envs,
        self._observations,
        self._rewards,
        self._masks,
        self._rollout_infos,
        self._syncs,
    ) = construct_envs(
        self.config,
        num_worker_groups=self.config.NUM_PARALLEL_SCENES,
        double_buffered=double_buffered,
    )

    def _setup_render_and_populate_initial_frame():
        for idx in range(2 if double_buffered else 1):
            self.envs.reset(idx)

            batch = self._observations[idx]
            self._syncs[idx].wait()

            tree_copy_in_place(
                tree_select(0, rollouts[idx].storage_buffers["observations"]),
                batch,
            )

    _setup_render_and_populate_initial_frame()

    current_episode_reward = torch.zeros(nenvs, 1)
    running_episode_stats = dict(
        count=torch.zeros(nenvs, 1),
        reward=torch.zeros(nenvs, 1),
    )

    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size)
    )
    time_per_frame_window = deque(maxlen=ppo_cfg.reward_window_size)

    buffer_ranges = []
    for i in range(2 if double_buffered else 1):
        start_ind = buffer_ranges[-1].stop if i > 0 else 0
        buffer_ranges.append(
            slice(
                start_ind,
                start_ind
                + self.config.SIM_BATCH_SIZE // (2 if double_buffered else 1),
            )
        )

    if interrupted_state is not None:
        requeue_stats = interrupted_state["requeue_stats"]
        self.count_steps = requeue_stats["count_steps"]
        self.update = requeue_stats["start_update"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        prev_time = requeue_stats["prev_time"]
        burn_steps = requeue_stats["burn_steps"]
        burn_time = requeue_stats["burn_time"]

        self.agent.ada_scale.load_state_dict(interrupted_state["ada_scale_state"])

        lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

        if "amp_state" in interrupted_state:
            apex.amp.load_state_dict(interrupted_state["amp_state"])

        if "grad_scaler_state" in interrupted_state:
            self.agent.grad_scaler.load_state_dict(
                interrupted_state["grad_scaler_state"]
            )

    with (
        TensorboardWriter(
            self.config.TENSORBOARD_DIR,
            flush_secs=self.flush_secs,
            purge_step=int(self.count_steps),
        )
        if self.world_rank == 0
        else contextlib.suppress()
    ) as writer:
        distrib.barrier()
        t_start = time.time()
        while not self.is_done():
            t_rollout_start = time.time()
            if self.update == BURN_IN_UPDATES:
                burn_time = t_rollout_start - t_start
                burn_steps = self.count_steps

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    self.percent_done(), final_decay=ppo_cfg.decay_factor
                )

            if (
                not BPS_BENCHMARK
                and (REQUEUE.is_set() or ((self.update + 1) % 100) == 0)
                and self.world_rank == 0
            ):
                requeue_stats = dict(
                    count_steps=self.count_steps,
                    count_checkpoints=count_checkpoints,
                    start_update=self.update,
                    prev_time=(time.time() - t_start) + prev_time,
                    burn_time=burn_time,
                    burn_steps=burn_steps,
                )

                def _cast(param):
                    if "Half" in param.type():
                        param = param.to(dtype=torch.float32)

                    return param

                save_interrupted_state(
                    dict(
                        state_dict={
                            k: _cast(v) for k, v in self.agent.state_dict().items()
                        },
                        ada_scale_state=self.agent.ada_scale.state_dict(),
                        lr_sched_state=lr_scheduler.state_dict(),
                        config=self.config,
                        requeue_stats=requeue_stats,
                        grad_scaler_state=self.agent.grad_scaler.state_dict(),
                    )
                )

            if EXIT.is_set():
                self._observations = None
                self._rewards = None
                self._masks = None
                self._rollout_infos = None
                self._syncs = None

                del self.envs
                self.envs = None

                requeue_job()
                return

            self.agent.eval()

            count_steps_delta = self._n_buffered_sampling(
                rollouts,
                current_episode_reward,
                running_episode_stats,
                buffer_ranges,
                ppo_cfg.num_steps,
                num_rollouts_done_store,
            )

            num_rollouts_done_store.add("num_done", 1)

            if not rollouts.vtrace:
                self._compute_returns(ppo_cfg, rollouts)

            (value_loss, action_loss, dist_entropy) = self._update_agent(rollouts)

            if self.world_rank == 0:
                num_rollouts_done_store.set("num_done", "0")

            lr_scheduler.step()

            with self.timing.add_time("Logging"):
                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering], 0
                ).to(device=self.device)
                distrib.all_reduce(stats)
                stats = stats.to(device="cpu")

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i])

                stats = torch.tensor(
                    [
                        value_loss,
                        action_loss,
                        count_steps_delta,
                        *self.envs.swap_stats,
                    ],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                stats = stats.to(device="cpu")
                count_steps_delta = int(stats[2].item())
                self.count_steps += count_steps_delta

                time_per_frame_window.append(
                    (time.time() - t_rollout_start) / count_steps_delta
                )

                if self.world_rank == 0:
                    losses = [
                        stats[0].item() / self.world_size,
                        stats[1].item() / self.world_size,
                    ]
                    deltas = {
                        k: (
                            (v[-1] - v[0]).sum().item()
                            if len(v) > 1
                            else v[0].sum().item()
                        )
                        for k, v in window_episode_stats.items()
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        self.count_steps,
                    )

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        writer.add_scalars("metrics", metrics, self.count_steps)

                    writer.add_scalars(
                        "losses",
                        {k: l for l, k in zip(losses, ["value", "policy"])},
                        self.count_steps,
                    )

                    optim = self.agent.optimizer
                    writer.add_scalar(
                        "optimizer/base_lr",
                        optim.param_groups[-1]["lr"],
                        self.count_steps,
                    )

                    if "gain" in optim.param_groups[-1]:
                        for idx, group in enumerate(optim.param_groups):
                            writer.add_scalar(
                                f"optimizer/lr_{idx}",
                                group["lr"] * group["gain"],
                                self.count_steps,
                            )
                            writer.add_scalar(
                                f"optimizer/gain_{idx}",
                                group["gain"],
                                self.count_steps,
                            )

                    # log stats
                    if self.update > 0 and self.update % self.config.LOG_INTERVAL == 0:
                        logger.info(
                            "update: {}\twindow fps: {:.3f}\ttotal fps: {:.3f}\tframes: {}".format(
                                self.update,
                                1.0
                                / (
                                    sum(time_per_frame_window)
                                    / len(time_per_frame_window)
                                ),
                                (self.count_steps - burn_steps)
                                / ((time.time() - t_start) + prev_time - burn_time),
                                self.count_steps,
                            )
                        )

                        logger.info(
                            "swap percent: {:.3f}\tscenes in use: {:.3f}\tenvs per scene: {:.3f}".format(
                                stats[3].item() / self.world_size,
                                stats[4].item() / self.world_size,
                                stats[5].item() / self.world_size,
                            )
                        )

                        logger.info(
                            "Average window size: {} {}".format(
len(window_episode_stats["count"]), " ".join( "{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count" ), ) ) logger.info(self.timing) # self.envs.print_renderer_stats() # checkpoint model if self.should_checkpoint(): self.save_checkpoint( f"ckpt.{count_checkpoints}.pth", dict( step=self.count_steps, wall_clock_time=( (time.time() - t_start) + prev_time ), ), ) count_checkpoints += 1 self.update += 1 self.save_checkpoint( "ckpt.done.pth", dict( step=self.count_steps, wall_clock_time=((time.time() - t_start) + prev_time), ), ) self._observations = None self._rewards = None self._masks = None self._rollout_infos = None self._syncs = None del self.envs self.envs = None
def _eval_checkpoint(
    self,
    checkpoint_path: str,
    writer: TensorboardWriter,
    checkpoint_index: int = 0,
) -> None:
    r"""Evaluates a single checkpoint.

    Args:
        checkpoint_path: path of checkpoint
        writer: tensorboard writer object for logging to tensorboard
        checkpoint_index: index of cur checkpoint for logging

    Returns:
        None
    """
    from habitat_baselines.common.environments import get_env_class

    # Map location CPU is almost always better than mapping to a CUDA device.
    ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

    if self.config.EVAL.USE_CKPT_CONFIG:
        config = self._setup_eval_config(ckpt_dict["config"])
    else:
        config = self.config.clone()

    ppo_cfg = config.RL.PPO

    config.defrost()
    config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
    config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
    config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
    config.freeze()

    if len(self.config.VIDEO_OPTION) > 0:
        config.defrost()
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
        config.freeze()

    # logger.info(f"env config: {config}")
    self.envs = construct_envs_habitat(config, get_env_class(config.ENV_NAME))

    self.observation_space = SpaceDict(
        {
            "pointgoal_with_gps_compass": spaces.Box(
                low=0.0, high=1.0, shape=(2,), dtype=np.float32
            )
        }
    )

    if self.config.COLOR:
        self.observation_space = SpaceDict(
            {
                "rgb": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(3, *self.config.RESOLUTION),
                    dtype=np.uint8,
                ),
                **self.observation_space.spaces,
            }
        )

    if self.config.DEPTH:
        self.observation_space = SpaceDict(
            {
                "depth": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(1, *self.config.RESOLUTION),
                    dtype=np.float32,
                ),
                **self.observation_space.spaces,
            }
        )

    self.action_space = self.envs.action_spaces[0]

    self._setup_actor_critic_agent(ppo_cfg)

    self.agent.load_state_dict(ckpt_dict["state_dict"])
    self.actor_critic = self.agent.actor_critic
    self.actor_critic.script_net()
    self.actor_critic.eval()

    observations = self.envs.reset()
    batch = batch_obs(observations, device=self.device)

    current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device)

    test_recurrent_hidden_states = torch.zeros(
        self.config.NUM_PROCESSES,
        self.actor_critic.num_recurrent_layers,
        ppo_cfg.hidden_size,
        device=self.device,
    )
    prev_actions = torch.zeros(
        self.config.NUM_PROCESSES, 1, device=self.device, dtype=torch.long
    )
    not_done_masks = torch.zeros(
        self.config.NUM_PROCESSES, 1, device=self.device, dtype=torch.bool
    )
    stats_episodes = dict()  # dict of dicts that stores stats per episode

    rgb_frames = [
        [] for _ in range(self.config.NUM_PROCESSES)
    ]  # type: List[List[np.ndarray]]

    if len(self.config.VIDEO_OPTION) > 0:
        os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

    number_of_eval_episodes = self.config.TEST_EPISODE_COUNT
    if number_of_eval_episodes == -1:
        number_of_eval_episodes = sum(self.envs.number_of_episodes)
    else:
        total_num_eps = sum(self.envs.number_of_episodes)
        if total_num_eps < number_of_eval_episodes:
            logger.warn(
                f"Config specified {number_of_eval_episodes} eval episodes"
                f", dataset only has {total_num_eps}."
            )
            logger.warn(f"Evaluating with {total_num_eps} instead.")
            number_of_eval_episodes = total_num_eps

    evals_per_ep = 5
    count_per_ep = defaultdict(lambda: 0)

    pbar = tqdm.tqdm(total=number_of_eval_episodes * evals_per_ep)
    self.actor_critic.eval()
    while (
        len(stats_episodes) < (number_of_eval_episodes * evals_per_ep)
        and self.envs.num_envs > 0
    ):
        current_episodes = self.envs.current_episodes()
        with torch.no_grad():
            (
                _,
                dist_result,
                test_recurrent_hidden_states,
            ) = self.actor_critic.act(
                batch,
                test_recurrent_hidden_states,
                prev_actions,
                not_done_masks,
                deterministic=False,
            )
            actions = dist_result["actions"]

            prev_actions.copy_(actions)

        actions = actions.to("cpu")

        outputs = self.envs.step([a[0].item() for a in actions])

        observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
        batch = batch_obs(observations, device=self.device)

        not_done_masks = torch.tensor(
            [[False] if done else [True] for done in dones],
            dtype=torch.bool,
            device=self.device,
        )

        rewards = torch.tensor(
            rewards, dtype=torch.float, device=self.device
        ).unsqueeze(1)
        current_episode_reward += rewards
        next_episodes = self.envs.current_episodes()
        envs_to_pause = []
        n_envs = self.envs.num_envs
        for i in range(n_envs):
            next_count_key = (
                next_episodes[i].scene_id,
                next_episodes[i].episode_id,
            )

            if count_per_ep[next_count_key] == evals_per_ep:
                envs_to_pause.append(i)

            # episode ended
            if not_done_masks[i].item() == 0:
                pbar.update()
                episode_stats = dict()
                episode_stats["reward"] = current_episode_reward[i].item()
                episode_stats.update(self._extract_scalars_from_info(infos[i]))
                current_episode_reward[i] = 0

                # use scene_id + episode_id as unique id for storing stats
                count_key = (
                    current_episodes[i].scene_id,
                    current_episodes[i].episode_id,
                )
                count_per_ep[count_key] = count_per_ep[count_key] + 1

                ep_key = (count_key, count_per_ep[count_key])
                stats_episodes[ep_key] = episode_stats

                if len(self.config.VIDEO_OPTION) > 0:
                    generate_video(
                        video_option=self.config.VIDEO_OPTION,
                        video_dir=self.config.VIDEO_DIR,
                        images=rgb_frames[i],
                        episode_id=current_episodes[i].episode_id,
                        checkpoint_idx=checkpoint_index,
                        metrics=self._extract_scalars_from_info(infos[i]),
                        tb_writer=writer,
                    )

                    rgb_frames[i] = []

            # episode continues
            elif len(self.config.VIDEO_OPTION) > 0:
                frame = observations_to_image(observations[i], infos[i])
                rgb_frames[i].append(frame)

        (
            self.envs,
            test_recurrent_hidden_states,
            not_done_masks,
            current_episode_reward,
            prev_actions,
            batch,
            rgb_frames,
        ) = self._pause_envs(
            envs_to_pause,
            self.envs,
            test_recurrent_hidden_states,
            not_done_masks,
            current_episode_reward,
            prev_actions,
            batch,
            rgb_frames,
        )

    self.envs.close()
    pbar.close()

    num_episodes = len(stats_episodes)
    aggregated_stats = dict()
    for stat_key in next(iter(stats_episodes.values())).keys():
        values = [
            v[stat_key] for v in stats_episodes.values() if v[stat_key] is not None
        ]
        if len(values) > 0:
            aggregated_stats[stat_key] = sum(values) / len(values)
        else:
            aggregated_stats[stat_key] = 0

    for k, v in aggregated_stats.items():
        logger.info(f"Average episode {k}: {v:.4f}")

    step_id = checkpoint_index
    if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]:
        step_id = ckpt_dict["extra_state"]["step"]

    writer.add_scalars(
        "eval_reward",
        {"average reward": aggregated_stats["reward"]},
        step_id,
    )

    metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"}
    if len(metrics) > 0:
        writer.add_scalars("eval_metrics", metrics, step_id)

    self.num_frames = step_id
    time.sleep(30)

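# Note: every episode is evaluated evals_per_ep (= 5) times, and
# stats_episodes is keyed by ((scene_id, episode_id), trial_index), so
# aggregated_stats averages over episodes and trials alike; e.g. 100 episodes
# produce 500 entries.
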