Example #1
    def eval(self, checkpoint_path):
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint

        Returns:
            None
        """
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))
        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.freeze()

        if len(self.config.VIDEO_OPTION) > 0:
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        self.env = construct_envs(config, get_env_class(config.ENV_NAME))
        self._setup_actor_critic_agent(ppo_cfg)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic

        # get name of performance metric, e.g. "spl"
        metric_name = self.config.TASK_CONFIG.TASK.MEASUREMENTS[0]
        metric_cfg = getattr(self.config.TASK_CONFIG.TASK, metric_name)
        measure_type = baseline_registry.get_measure(metric_cfg.TYPE)
        assert measure_type is not None, "invalid measurement type {}".format(
            metric_cfg.TYPE)
        self.metric_uuid = measure_type(sim=None, task=None,
                                        config=None)._get_uuid()

        observations = self.env.reset()
        batch = batch_obs(observations, self.device)

        current_episode_reward = torch.zeros(self.env.num_envs,
                                             1,
                                             device=self.device)

        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            self.config.NUM_PROCESSES,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(self.config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device)
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        rgb_frames = [[] for _ in range(self.config.NUM_PROCESSES)
                      ]  # type: List[List[np.ndarray]]
        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        self.actor_critic.eval()
        while (len(stats_episodes) < self.config.TEST_EPISODE_COUNT
               and self.env.num_envs > 0):
            current_episodes = self.env.current_episodes()

            with torch.no_grad():
                (
                    _,
                    actions,
                    _,
                    test_recurrent_hidden_states,
                ) = self.actor_critic.act(
                    batch,
                    test_recurrent_hidden_states,
                    prev_actions,
                    not_done_masks,
                    deterministic=False,
                )

                prev_actions.copy_(actions)

            outputs = self.env.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(observations, self.device)

            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            rewards = torch.tensor(rewards,
                                   dtype=torch.float,
                                   device=self.device).unsqueeze(1)
            current_episode_reward += rewards
            next_episodes = self.env.current_episodes()
            envs_to_pause = []
            n_envs = self.env.num_envs
            for i in range(n_envs):
                if (
                        next_episodes[i].scene_id,
                        next_episodes[i].episode_id,
                ) in stats_episodes:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:
                    episode_stats = dict()
                    episode_stats[self.metric_uuid] = infos[i][
                        self.metric_uuid]
                    episode_stats["success"] = int(
                        infos[i][self.metric_uuid] > 0)
                    episode_stats["reward"] = current_episode_reward[i].item()
                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats
                    stats_episodes[(
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )] = episode_stats

                    if len(self.config.VIDEO_OPTION) > 0:
                        generate_video(
                            video_option=self.config.VIDEO_OPTION,
                            video_dir=self.config.VIDEO_DIR,
                            images=rgb_frames[i],
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=0,
                            metric_name=self.metric_uuid,
                            metric_value=infos[i][self.metric_uuid],
                        )

                        rgb_frames[i] = []

                # episode continues
                elif len(self.config.VIDEO_OPTION) > 0:
                    frame = observations_to_image(observations[i], infos[i])
                    rgb_frames[i].append(frame)

            (
                self.env,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            ) = self._pause_envs(
                envs_to_pause,
                self.env,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            )

        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = sum(
                [v[stat_key] for v in stats_episodes.values()])
        num_episodes = len(stats_episodes)

        episode_reward_mean = aggregated_stats["reward"] / num_episodes
        episode_metric_mean = aggregated_stats[self.metric_uuid] / num_episodes
        episode_success_mean = aggregated_stats["success"] / num_episodes

        print(f"Average episode reward: {episode_reward_mean:.6f}")
        print(f"Average episode success: {episode_success_mean:.6f}")
        print(f"Average episode {self.metric_uuid}: {episode_metric_mean:.6f}")

        if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]:
            step_id = ckpt_dict["extra_state"]["step"]

        print("eval_reward", {"average reward": episode_reward_mean})
        print(
            f"eval_{self.metric_uuid}",
            {f"average {self.metric_uuid}": episode_metric_mean},
        )
        print("eval_success", {"average success": episode_success_mean})

        self.env.close()
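
Example #1 ends by delegating to a `_pause_envs` helper that is not shown in the snippet. A minimal sketch of what such a helper might look like, modeled on the inline pausing logic in Example #3 below; the exact signature is an assumption read off the call site:

    @staticmethod
    def _pause_envs(envs_to_pause, envs, test_recurrent_hidden_states,
                    not_done_masks, current_episode_reward, prev_actions,
                    batch, rgb_frames):
        # Sketch only: drop the batch rows of envs that have no new episodes.
        if len(envs_to_pause) > 0:
            state_index = list(range(envs.num_envs))
            for idx in reversed(envs_to_pause):
                state_index.pop(idx)
                envs.pause_at(idx)

            # hidden states are (num_layers, num_envs, hidden); batch is dim 1
            test_recurrent_hidden_states = test_recurrent_hidden_states[
                :, state_index]
            not_done_masks = not_done_masks[state_index]
            if current_episode_reward is not None:
                current_episode_reward = current_episode_reward[state_index]
            prev_actions = prev_actions[state_index]
            batch = {k: v[state_index] for k, v in batch.items()}
            rgb_frames = [rgb_frames[i] for i in state_index]

        return (envs, test_recurrent_hidden_states, not_done_masks,
                current_episode_reward, prev_actions, batch, rgb_frames)
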
Example #2
def run(config, env, max_steps):
    r"""Runs a random-action baseline for ``max_steps`` updates and logs
    per-update reward, step-count, and goal-distance statistics to CSV.

    Returns:
        None
    """

    observations = env.reset()

    window_episode_reward = deque(maxlen=max_steps)
    window_episode_counts = deque(maxlen=max_steps)
    dist_val = deque(maxlen=max_steps)

    for update in range(max_steps):
        print(f"update {update}")
        reward_sum = 0
        dist_sum = 0
        n_steps = 0  # step counter; avoids shadowing the builtin ``iter``
        rgb_frames = []
        if len(config.VIDEO_OPTION) > 0:
            os.makedirs(config.VIDEO_DIR, exist_ok=True)

        for step in range(500):
            dones = [False]
            while not dones[0]:
                outputs = env.step([env.action_spaces[0].sample()])
                observations, rewards, dones, infos = [
                    list(x) for x in zip(*outputs)
                ]
                reward_sum += rewards[0]
                dist_sum += observations[0]['pointgoal_with_gps_compass'][0]
                n_steps += 1

                frame = observations_to_image(observations[0], {})
                rgb_frames.append(frame)

        observations = env.reset()
        window_episode_reward.append(reward_sum / n_steps)
        window_episode_counts.append(n_steps)
        dist_val.append(dist_sum / n_steps)

        generate_video(
            video_option=config.VIDEO_OPTION,
            video_dir=config.VIDEO_DIR,
            images=np.array(rgb_frames),
            episode_id=update,
            checkpoint_idx=0,
            metric_name="spl",
            metric_value=1.0,
        )

        rgb_frames = []

    np.savetxt("window_episode_reward_ppo.csv",
               window_episode_reward,
               delimiter=",")
    np.savetxt("window_episode_counts_ppo.csv",
               window_episode_counts,
               delimiter=",")
    np.savetxt("episode_dist_ppo.csv", episode_dist, delimiter=",")

    env.close()
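
For context, a hypothetical driver for `run`, reusing the `construct_envs`/`get_env_class` helpers the other examples import; `max_steps=50` is an arbitrary placeholder:

    envs = construct_envs(config, get_env_class(config.ENV_NAME))
    run(config, envs, max_steps=50)  # run() closes the envs itself
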
Example #3
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of current checkpoint for logging

        Returns:
            None
        """
        ckpt_dict = self.load_checkpoint(checkpoint_path,
                                         map_location=self.device)

        config = self._setup_eval_config(ckpt_dict["config"])
        ppo_cfg = config.RL.PPO

        if len(self.config.VIDEO_OPTION) > 0:
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        logger.info(f"env config: {config}")
        self.envs = construct_envs(config, get_env_class(config.ENV_NAME))
        self._setup_actor_critic_agent(ppo_cfg)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic

        # get name of performance metric, e.g. "spl"
        metric_name = self.config.TASK_CONFIG.TASK.MEASUREMENTS[0]
        metric_cfg = getattr(self.config.TASK_CONFIG.TASK, metric_name)
        measure_type = baseline_registry.get_measure(metric_cfg.TYPE)
        assert measure_type is not None, "invalid measurement type {}".format(
            metric_cfg.TYPE)
        self.metric_uuid = measure_type(None, None)._get_uuid()

        observations = self.envs.reset()
        batch = batch_obs(observations)
        for sensor in batch:
            batch[sensor] = batch[sensor].to(self.device)

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)

        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            self.config.NUM_PROCESSES,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(self.config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device)
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        rgb_frames = [
            [] for _ in range(self.config.NUM_PROCESSES)
        ]  # type: List[List[np.ndarray]]
        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        while (len(stats_episodes) < self.config.TEST_EPISODE_COUNT
               and self.envs.num_envs > 0):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                _, actions, _, test_recurrent_hidden_states = self.actor_critic.act(
                    batch,
                    test_recurrent_hidden_states,
                    prev_actions,
                    not_done_masks,
                    deterministic=False,
                )

                prev_actions.copy_(actions)

            outputs = self.envs.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(observations)
            for sensor in batch:
                batch[sensor] = batch[sensor].to(self.device)

            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            rewards = torch.tensor(rewards,
                                   dtype=torch.float,
                                   device=self.device).unsqueeze(1)
            current_episode_reward += rewards
            next_episodes = self.envs.current_episodes()
            envs_to_pause = []
            n_envs = self.envs.num_envs
            for i in range(n_envs):
                if (
                        next_episodes[i].scene_id,
                        next_episodes[i].episode_id,
                ) in stats_episodes:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:
                    episode_stats = dict()
                    episode_stats[self.metric_uuid] = infos[i][
                        self.metric_uuid]
                    episode_stats["success"] = int(
                        infos[i][self.metric_uuid] > 0)
                    episode_stats["reward"] = current_episode_reward[i].item()
                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats
                    stats_episodes[(
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )] = episode_stats

                    if len(self.config.VIDEO_OPTION) > 0:
                        generate_video(
                            video_option=self.config.VIDEO_OPTION,
                            video_dir=self.config.VIDEO_DIR,
                            images=rgb_frames[i],
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=checkpoint_index,
                            metric_name=self.metric_uuid,
                            metric_value=infos[i][self.metric_uuid],
                            tb_writer=writer,
                        )

                        rgb_frames[i] = []

                # episode continues
                elif len(self.config.VIDEO_OPTION) > 0:
                    frame = observations_to_image(observations[i], infos[i])
                    rgb_frames[i].append(frame)

            # pausing self.envs with no new episode
            if len(envs_to_pause) > 0:
                state_index = list(range(self.envs.num_envs))
                for idx in reversed(envs_to_pause):
                    state_index.pop(idx)
                    self.envs.pause_at(idx)

                # indexing along the batch dimension (dim 0 is the layer dim)
                test_recurrent_hidden_states = test_recurrent_hidden_states[
                    :, state_index]
                not_done_masks = not_done_masks[state_index]
                current_episode_reward = current_episode_reward[state_index]
                prev_actions = prev_actions[state_index]

                for k, v in batch.items():
                    batch[k] = v[state_index]

                if len(self.config.VIDEO_OPTION) > 0:
                    rgb_frames = [rgb_frames[i] for i in state_index]

        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = sum(
                [v[stat_key] for v in stats_episodes.values()])
        num_episodes = len(stats_episodes)

        episode_reward_mean = aggregated_stats["reward"] / num_episodes
        episode_metric_mean = aggregated_stats[self.metric_uuid] / num_episodes
        episode_success_mean = aggregated_stats["success"] / num_episodes

        logger.info(f"Average episode reward: {episode_reward_mean:.6f}")
        logger.info(f"Average episode success: {episode_success_mean:.6f}")
        logger.info(
            f"Average episode {self.metric_uuid}: {episode_metric_mean:.6f}")

        writer.add_scalars(
            "eval_reward",
            {"average reward": episode_reward_mean},
            checkpoint_index,
        )
        writer.add_scalars(
            f"eval_{self.metric_uuid}",
            {f"average {self.metric_uuid}": episode_metric_mean},
            checkpoint_index,
        )
        writer.add_scalars(
            "eval_success",
            {"average success": episode_success_mean},
            checkpoint_index,
        )

        self.envs.close()
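
A detail worth calling out in these evaluators: `rgb_frames` must hold an independent list per env. `[[]] * n` creates n references to one shared list, so a frame appended for one env would show up for all of them; the per-env list comprehension avoids that. A standalone illustration:

    aliased = [[]] * 3                     # three names for one shared list
    aliased[0].append("frame")
    print(aliased)                         # [['frame'], ['frame'], ['frame']]

    independent = [[] for _ in range(3)]   # three distinct lists
    independent[0].append("frame")
    print(independent)                     # [['frame'], [], []]
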
Example #4
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of current checkpoint for logging

        Returns:
            None
        """
        self.add_new_based_on_cfg()

        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        # ==========================================================================================
        # -- Update config for eval
        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO

        # # Mostly for visualization
        # config.defrost()
        # config.TASK_CONFIG.SIMULATOR.HABITAT_SIM_V0.GPU_GPU = False
        # config.freeze()

        split = config.TASK_CONFIG.DATASET.SPLIT

        config.defrost()
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
        config.freeze()
        # ==========================================================================================

        num_procs = self.config.NUM_PROCESSES
        device = self.device
        cfg = self.config

        logger.info(f"env config: {config}")
        self.envs = construct_envs(config, get_env_class(self.config.ENV_NAME))
        num_envs = self.envs.num_envs

        self._setup_actor_critic_agent(ppo_cfg, train=False)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic
        self.r_policy = self.agent.actor_critic.reachability_policy

        aux_models = self.actor_critic.net.aux_models

        other_losses = dict({
            k: torch.zeros(num_envs, 1, device=device)
            for k in aux_models.keys()
        })
        other_losses_action = dict({
            k: torch.zeros(num_envs,
                           self.envs.action_spaces[0].n,
                           device=device)
            for k in aux_models.keys()
        })

        num_steps = torch.zeros(num_envs, 1, device=device)

        # Config aux models for eval per item in batch
        for k, maux in aux_models.items():
            maux.set_per_element_loss()

        total_loss = 0

        if config.EVAL_MODE:
            self.agent.eval()
            self.r_policy.eval()

        # get name of performance metric, e.g. "spl"
        metric_name = cfg.TASK_CONFIG.TASK.MEASUREMENTS[0]
        metric_cfg = getattr(cfg.TASK_CONFIG.TASK, metric_name)
        measure_type = baseline_registry.get_measure(metric_cfg.TYPE)
        assert measure_type is not None, "invalid measurement type {}".format(
            metric_cfg.TYPE)

        self.metric_uuid = measure_type(sim=None, task=None,
                                        config=None)._get_uuid()

        observations = self.envs.reset()
        batch = batch_obs_augment_aux(observations, self.envs.get_shared_mem())

        info_data_keys = ["discovered", "collisions_wall", "collisions_prox"]
        log_data_keys = [
            "current_episode_reward", "current_episode_go_reward"
        ] + info_data_keys
        log_data = dict({
            k: torch.zeros(num_envs, 1, device=device)
            for k in log_data_keys
        })
        info_data = dict({k: log_data[k] for k in info_data_keys})

        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            num_procs,
            ppo_cfg.hidden_size,
            device=device,
        )
        prev_actions = torch.zeros(num_procs,
                                   1,
                                   device=device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(num_procs, 1, device=device)

        stats_episodes = dict()  # dict of dicts that stores stats per episode
        # number of episode stats collected per scene
        stats_episodes_scenes = dict()

        max_test_ep_count = cfg.TEST_EPISODE_COUNT

        # TODO this should depend on number of scenes :(
        # TODO But than envs shouldn't be paused but fast-fwd to next scene
        # TODO We consider num envs == num scenes
        max_ep_per_env = max_test_ep_count / float(num_envs)

        rgb_frames = [[] for _ in range(num_procs)
                      ]  # type: List[List[np.ndarray]]

        if len(cfg.VIDEO_OPTION) > 0:
            os.makedirs(cfg.VIDEO_DIR, exist_ok=True)

        video_log_int = cfg.VIDEO_OPTION_INTERVAL
        num_frames = 0

        plot_pos = -1
        prev_true_pos = []
        prev_pred_pos = []

        while (len(stats_episodes) <= cfg.TEST_EPISODE_COUNT
               and self.envs.num_envs > 0):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                prev_hidden = test_recurrent_hidden_states
                _, actions, _, test_recurrent_hidden_states, aux_out \
                    = self.actor_critic.act(
                        batch,
                        test_recurrent_hidden_states,
                        prev_actions,
                        not_done_masks,
                        deterministic=False
                    )

                prev_actions.copy_(actions)

                if 'action' in batch:
                    prev_actions = batch['action'].unsqueeze(1).to(
                        actions.device).long()

                for k, v in aux_out.items():
                    loss = aux_models[k].calc_loss(v, batch, prev_hidden,
                                                   prev_actions,
                                                   not_done_masks, actions)
                    total_loss += loss

                    # accumulate per-env loss (other_losses[k] starts at zeros)
                    other_losses[k] += loss.unsqueeze(1)
                    if len(prev_actions) == 1:
                        other_losses_action[k][0, prev_actions.item()] += \
                            loss.item()

                # ==================================================================================
                # - Hacky logs

                if plot_pos >= 0:
                    prev_true_pos.append(batch["gps_compass_start"]
                                         [plot_pos].data[:2].cpu().numpy())
                    prev_pred_pos.append(aux_out["rel_start_pos_reg"]
                                         [plot_pos].data.cpu().numpy() * 15)
                    if num_frames % 10 == 0:
                        xx, yy = [], []
                        for x, y in prev_true_pos:
                            xx.append(x)
                            yy.append(y)
                        plt.scatter(xx, yy, label="true_pos")
                        xx, yy = [], []
                        for x, y in prev_pred_pos:
                            xx.append(x)
                            yy.append(y)
                        plt.scatter(xx, yy, label="pred_pos")
                        plt.legend()
                        plt.show()
                        plt.waitforbuttonpress()
                        plt.close()
                # ==================================================================================

            num_steps += 1
            outputs = self.envs.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]

            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=device,
            )

            map_values = self._get_mapping(observations, aux_out)
            batch = batch_obs_augment_aux(observations,
                                          self.envs.get_shared_mem(),
                                          device=device,
                                          map_values=map_values,
                                          masks=not_done_masks)

            valid_map_size = [
                float(ifs["top_down_map"]["valid_map"].sum()) for ifs in infos
            ]
            discovered_factor = [
                infos[ix]["top_down_map"]["explored_map"].sum() /
                valid_map_size[ix] for ix in range(len(infos))
            ]

            seen_factor = [
                infos[ix]["top_down_map"]["ful_fog_of_war_mask"].sum() /
                valid_map_size[ix] for ix in range(len(infos))
            ]

            rewards = torch.tensor(rewards, dtype=torch.float,
                                   device=device).unsqueeze(1)

            log_data["current_episode_reward"] += rewards

            # -- Add intrinsic Reward
            if self.only_intrinsic_reward:
                rewards.zero_()

            if self.r_enabled:
                ir_rewards = self._add_intrinsic_reward(
                    batch, actions, rewards, not_done_masks)
                log_data["current_episode_go_reward"] += ir_rewards

                rewards += ir_rewards

            # Log other info from infos dict
            for iii, info in enumerate(infos):
                for k_info, v_info in info_data.items():
                    v_info[iii] += info[k_info]

            next_episodes = self.envs.current_episodes()

            envs_to_pause = []
            n_envs = self.envs.num_envs

            for i in range(n_envs):
                scene = next_episodes[i].scene_id

                if scene not in stats_episodes_scenes:
                    stats_episodes_scenes[scene] = 0

                if stats_episodes_scenes[scene] >= max_ep_per_env:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:
                    episode_stats = dict()
                    episode_stats[self.metric_uuid] = infos[i][
                        self.metric_uuid]
                    episode_stats["success"] = int(
                        infos[i][self.metric_uuid] > 0)

                    for kk, vv in log_data.items():
                        episode_stats[kk] = vv[i].item()
                        vv[i] = 0

                    episode_stats["map_discovered"] = discovered_factor[i]
                    episode_stats["map_seen"] = seen_factor[i]

                    for k, v in other_losses.items():
                        episode_stats[k] = v[i].item() / num_steps[i].item()
                        other_losses_action[k][i].fill_(0)
                        other_losses[k][i] = 0

                    num_steps[i] = 0

                    # use scene_id + episode_id as unique id for storing stats
                    stats_episodes[(current_episodes[i].scene_id, current_episodes[i].episode_id)] \
                        = episode_stats

                    print(f"Episode {len(stats_episodes)} stats:",
                          episode_stats)

                    stats_episodes_scenes[current_episodes[i].scene_id] += 1

                    if (len(cfg.VIDEO_OPTION) > 0
                            and checkpoint_index % video_log_int == 0):
                        generate_video(
                            video_option=cfg.VIDEO_OPTION,
                            video_dir=cfg.VIDEO_DIR,
                            images=rgb_frames[i],
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=checkpoint_index,
                            metric_name=self.metric_uuid,
                            metric_value=infos[i][self.metric_uuid],
                            tb_writer=writer,
                        )

                        rgb_frames[i] = []

                # episode continues
                elif len(cfg.VIDEO_OPTION) > 0:
                    for k, v in observations[i].items():
                        if isinstance(v, torch.Tensor):
                            observations[i][k] = v.cpu().numpy()
                    frame = observations_to_image(observations[i], infos[i])
                    rgb_frames[i].append(frame)

            # Pop done envs:
            if len(envs_to_pause) > 0:
                s_index = list(range(self.envs.num_envs))
                for idx in reversed(envs_to_pause):
                    s_index.pop(idx)

                for k, v in other_losses.items():
                    other_losses[k] = other_losses[k][s_index]

                for k, v in log_data.items():
                    log_data[k] = log_data[k][s_index]

            (
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                None,  # per-episode rewards are tracked in log_data here
                prev_actions,
                batch,
                rgb_frames,
            )

        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = sum(
                [v[stat_key] for v in stats_episodes.values()])
        num_episodes = len(stats_episodes)

        episodes_agg_stats = dict()
        for k, v in aggregated_stats.items():
            episodes_agg_stats[k] = v / num_episodes
            logger.info(f"Average episode {k}: {episodes_agg_stats[k]:.6f}")

        for k, v in episodes_agg_stats.items():
            writer.add_scalars(f"eval_{k}", {f"{split}_average {k}": v},
                               checkpoint_index)
            print(f"[{checkpoint_index}] average {k}", v)

        self.envs.close()
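
All four examples funnel the per-env observation dicts through a batching helper (`batch_obs` or `batch_obs_augment_aux`) before calling the policy. A minimal sketch of that pattern, assuming observations are dicts of numpy arrays; the real habitat-baselines helper also handles scalar sensors and, in Example #4's variant, preallocated shared memory:

    from collections import defaultdict
    from typing import Dict, List, Optional

    import numpy as np
    import torch


    def batch_obs_sketch(
        observations: List[Dict[str, np.ndarray]],
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.Tensor]:
        # Group each sensor's readings across envs, then stack into one
        # tensor of shape (num_envs, *sensor_shape) on the target device.
        grouped = defaultdict(list)
        for obs in observations:
            for sensor, value in obs.items():
                grouped[sensor].append(torch.as_tensor(value))
        return {
            sensor: torch.stack(values, dim=0).to(device=device)
            for sensor, values in grouped.items()
        }
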