Example 1
    def _warmup(self, rollouts):
        model_state = {k: v.clone() for k, v in self.agent.state_dict().items()}
        optim_state = self.agent.optimizer.state.copy()
        self.agent.eval()

        for _ in range(20):
            self._inference(rollouts, 0)
            # Empty the CUDA cache; cuDNN algorithm search can leave
            # extra memory cached that it does not release on its own
            torch.cuda.empty_cache()

        t_inference_start = time.time()
        n_infers = 200
        for _ in range(n_infers):
            self._inference(rollouts, 0)

        if self.world_rank == 0:
            logger.info(
                "Inference time: {:.3f} ms".format(
                    (time.time() - t_inference_start) / n_infers * 1000
                )
            )
            logger.info(
                "PyTorch CUDA Memory Cache Size: {:.3f} GB".format(
                    torch.cuda.memory_reserved(self.device) / 1e9
                )
            )

        self.agent.train()
        for _ in range(10):
            self._update_agent(rollouts, warmup=True)
            # Empty the CUDA cache; cuDNN algorithm search can leave
            # extra memory cached that it does not release on its own
            torch.cuda.empty_cache()

        t_learning_start = time.time()
        n_learns = 15
        for _ in range(n_learns):
            self._update_agent(rollouts, warmup=True)

        if self.world_rank == 0:
            logger.info(
                "Learning time: {:.3f} ms".format(
                    (time.time() - t_learning_start) / n_learns * 1000
                )
            )
            logger.info(self.timing)
            logger.info(
                "PyTorch CUDA Memory Cache Size: {:.3f} GB".format(
                    torch.cuda.memory_reserved(self.device) / 1e9
                )
            )

        self.agent.load_state_dict(model_state)
        self.agent.optimizer.state = optim_state
        self.agent.ada_scale.zero_grad()

        self.timing = Timing()
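One caveat about the timing loops above: CUDA kernels launch asynchronously, so wall-clock timings like these are only reliable if the measured call blocks on its result or the device is synchronized. A minimal, self-contained sketch of the same benchmarking pattern with explicit synchronization points (`model` and `batch` are placeholder names, not from this codebase):

import time

import torch


def benchmark_forward(model, batch, n_warmup=20, n_iters=200):
    """Return average forward latency in milliseconds."""
    with torch.no_grad():
        # Warmup triggers cuDNN algorithm search and lazy allocations.
        for _ in range(n_warmup):
            model(batch)
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t_start = time.time()
        for _ in range(n_iters):
            model(batch)
        if torch.cuda.is_available():
            # Wait for all queued kernels before reading the clock.
            torch.cuda.synchronize()

    return (time.time() - t_start) / n_iters * 1000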
Example 2
def set_cpus(local_rank, world_size):
    local_size = min(world_size, 8)

    curr_process = psutil.Process()
    total_cpus = curr_process.cpu_affinity()
    total_cpu_count = len(total_cpus)

    # Only repartition when this process still sees more than its per-rank
    # share of CPUs; otherwise assume affinity was already set externally
    if total_cpu_count > multiprocessing.cpu_count() / world_size:

        orig_cpus = total_cpus
        total_cpus = []
        for i in range(total_cpu_count // 2):
            total_cpus.append(orig_cpus[i])
            total_cpus.append(orig_cpus[i + total_cpu_count // 2])

        ptr = 0
        local_cpu_count = 0
        local_cpus = []
        CORE_GROUPING = min(
            local_size,
            4 if total_cpu_count / 2 >= 20 else (2 if total_cpu_count / 2 >= 10 else 1),
        )
        # NOTE: grouping is currently forced to 1, overriding the value computed above
        CORE_GROUPING = 1
        core_dist_size = max(local_size // CORE_GROUPING, 1)
        core_dist_rank = local_rank // CORE_GROUPING

        for r in range(core_dist_rank + 1):
            ptr += local_cpu_count
            local_cpu_count = total_cpu_count // core_dist_size + (
                1 if r < (total_cpu_count % core_dist_size) else 0
            )

        local_cpus += total_cpus[ptr : ptr + local_cpu_count]
        pop_inds = [
            ((local_rank + offset + 1) % CORE_GROUPING)
            for offset in range(CORE_GROUPING - 1)
        ]
        for ind in sorted(pop_inds, reverse=True):
            local_cpus.pop(ind)

        if BPS_BENCHMARK and world_size == 1:
            local_cpus = total_cpus[0:12]

        curr_process.cpu_affinity(local_cpus)

    logger.info(
        "Rank {} uses cpus {}".format(local_rank, sorted(curr_process.cpu_affinity()))
    )
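For orientation, the core idea of set_cpus is to give each local rank a disjoint share of the host's CPUs via psutil's cpu_affinity; the interleaving and core-grouping logic above refines that. A minimal sketch of just the basic split (illustrative only, not the function's actual policy):

import psutil


def pin_to_share(local_rank, local_size):
    """Give each local rank a contiguous, disjoint slice of the allowed CPUs."""
    proc = psutil.Process()
    cpus = proc.cpu_affinity()  # CPUs this process is currently allowed to use
    share = max(len(cpus) // local_size, 1)
    start = local_rank * share
    proc.cpu_affinity(cpus[start:start + share])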
Example 3
def requeue_job():
    r"""Requeues the job by calling ``scontrol requeue ${SLURM_JOBID}``
    """
    if SLURM_JOBID is None:
        return

    if not REQUEUE.is_set():
        return

    if distrib.is_initialized():
        distrib.barrier()

    if not distrib.is_initialized() or distrib.get_rank() == 0:
        logger.info(f"Requeueing job {SLURM_JOBID}")
        subprocess.check_call(shlex.split(f"scontrol requeue {SLURM_JOBID}"))
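requeue_job depends on module-level state (SLURM_JOBID and the REQUEUE/EXIT events) that is set up elsewhere; the commented-out add_signal_handlers() call in the DD-PPO train method shown later hints at the usual wiring. A hedged sketch of that wiring, assuming SIGUSR1 is the cluster's preemption/requeue signal (clusters differ):

import os
import signal
import threading

SLURM_JOBID = os.environ.get("SLURM_JOB_ID", None)
EXIT = threading.Event()
REQUEUE = threading.Event()


def _requeue_handler(signum, frame):
    # Ask the training loop to checkpoint, exit cleanly, and requeue itself.
    EXIT.set()
    REQUEUE.set()


def add_signal_handlers():
    signal.signal(signal.SIGUSR1, _requeue_handler)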
Example 4
    def eval(self) -> None:
        r"""Main method of trainer evaluation. Calls _eval_checkpoint() that
        is specified in Trainer class that inherits from BaseRLTrainer

        Returns:
            None
        """
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))

        if "tensorboard" in self.config.VIDEO_OPTION:
            assert (len(self.config.TENSORBOARD_DIR) > 0
                    ), "Must specify a tensorboard directory for video display"
            os.makedirs(self.config.TENSORBOARD_DIR, exist_ok=True)
        if "disk" in self.config.VIDEO_OPTION:
            assert (len(self.config.VIDEO_DIR) >
                    0), "Must specify a directory for storing videos on disk"

        with TensorboardWriter(
                self.config.TENSORBOARD_DIR,
                flush_secs=self.flush_secs,
                purge_step=int(self.num_frames),
        ) as writer:
            if os.path.isfile(self.config.EVAL_CKPT_PATH_DIR):
                # evaluate a single checkpoint
                self._eval_checkpoint(self.config.EVAL_CKPT_PATH_DIR, writer)
            else:
                # evaluate multiple checkpoints in order
                while True:
                    current_ckpt = None
                    while current_ckpt is None:
                        time.sleep(2)  # sleep for 2 secs before polling again
                        current_ckpt = poll_checkpoint_folder(
                            self.config.EVAL_CKPT_PATH_DIR, self.prev_ckpt_ind)

                        if (current_ckpt is not None
                                and os.path.basename(current_ckpt)
                                == "ckpt.done.pth"):
                            return

                    logger.info(f"=======current_ckpt: {current_ckpt}=======")
                    self._eval_checkpoint(
                        checkpoint_path=current_ckpt,
                        writer=writer,
                        checkpoint_index=self.prev_ckpt_ind,
                    )

                    self.prev_ckpt_ind += 1
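poll_checkpoint_folder is used here as if it returns the (prev_ckpt_ind + 1)-th checkpoint in the folder once it exists, and None otherwise. A minimal sketch consistent with that usage (the habitat-baselines implementation may differ in details such as file naming):

import glob
import os
from typing import Optional


def poll_checkpoint_folder(checkpoint_folder: str,
                           previous_ckpt_ind: int) -> Optional[str]:
    # All saved checkpoints, oldest first.
    models = sorted(
        glob.glob(os.path.join(checkpoint_folder, "*.pth")),
        key=os.path.getmtime,
    )
    ind = previous_ckpt_ind + 1
    return models[ind] if ind < len(models) else None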
Example 5
    def _setup_eval_config(self, checkpoint_config):
        r"""Sets up and returns a merged config for evaluation. Config
            object saved from checkpoint is merged into config file specified
            at evaluation time with the following overwrite priority:
                  eval_opts > ckpt_opts > eval_cfg > ckpt_cfg
            If the saved config is outdated, only the eval config is returned.

        Args:
            checkpoint_config: saved config from checkpoint.

        Returns:
            Config: merged config for eval.
        """

        config = self.config.clone()
        config.defrost()

        ckpt_cmd_opts = checkpoint_config.CMD_TRAILING_OPTS
        eval_cmd_opts = config.CMD_TRAILING_OPTS

        try:
            config.merge_from_other_cfg(checkpoint_config)
            config.merge_from_other_cfg(self.config)
            config.merge_from_list(ckpt_cmd_opts)
            config.merge_from_list(eval_cmd_opts)
        except KeyError:
            logger.info("Saved config is outdated, using solely eval config")
            config = self.config.clone()
            config.merge_from_list(eval_cmd_opts)
        if config.TASK_CONFIG.DATASET.SPLIT == "train":
            config.TASK_CONFIG.defrost()
            config.TASK_CONFIG.DATASET.SPLIT = "val"

        config.TASK_CONFIG.SIMULATOR.AGENT_0.SENSORS = self.config.SENSORS
        config.freeze()

        return config
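The overwrite priority described in the docstring falls out of merge order: with a yacs-style CfgNode (which the merge_from_* calls suggest), each successive merge overwrites earlier values, so the last source merged (eval_cmd_opts) wins. A toy illustration with made-up keys:

from yacs.config import CfgNode as CN

base = CN()
base.LR = 0.001
base.SEED = 1

override = CN()
override.LR = 0.01

base.merge_from_other_cfg(override)  # later merge overwrites: LR -> 0.01
base.merge_from_list(["SEED", 42])   # list form overwrites again: SEED -> 42
print(base.LR, base.SEED)            # 0.01 42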
Example 6
def load_interrupted_state(filename: str = None,
                           resume_from: str = None) -> Optional[Any]:
    r"""Loads the saved interrupted state

    :param filename: The filename of the saved state.
        Defaults to "${HOME}/.interrupted_states/${SLURM_JOBID}.pth"

    :return: The saved state if the file exists, else none
    """
    if SLURM_JOBID is None and filename is None:
        return None

    if filename is None:
        filename = INTERRUPTED_STATE_FILE

    if not osp.exists(filename) and resume_from is not None:
        filename = resume_from

    if not osp.exists(filename):
        return None

    logger.info(f"Loading saved state from {filename}")

    return torch.load(filename, map_location="cpu")
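The complementary save path referenced by the DD-PPO trainer below (save_interrupted_state, INTERRUPTED_STATE_FILE) is not shown in these excerpts. A minimal sketch under the assumption that it simply torch.save's the state dict to the default location described in the docstring above:

import os
import os.path as osp
from typing import Any

import torch

SLURM_JOBID = os.environ.get("SLURM_JOB_ID", None)
INTERRUPTED_STATE_FILE = osp.join(
    os.environ["HOME"], ".interrupted_states", f"{SLURM_JOBID}.pth"
)


def save_interrupted_state(state: Any, filename: str = None):
    if filename is None:
        if SLURM_JOBID is None:
            return  # nowhere sensible to save outside of a SLURM job
        filename = INTERRUPTED_STATE_FILE
    os.makedirs(osp.dirname(filename), exist_ok=True)
    torch.save(state, filename)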
Example 7
    def train(self) -> None:
        r"""Main method for DD-PPO.

        Returns:
            None
        """
        import apex

        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend
        )
        # add_signal_handlers()
        self.timing = Timing()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store)
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        set_cpus(self.local_rank, self.world_size)

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += self.world_rank * self.config.SIM_BATCH_SIZE
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        double_buffered = False
        self._num_worker_groups = self.config.NUM_PARALLEL_SCENES

        self._depth = self.config.DEPTH
        self._color = self.config.COLOR

        if self.config.TASK.lower() == "pointnav":
            self.observation_space = SpaceDict(
                {
                    "pointgoal_with_gps_compass": spaces.Box(
                        low=0.0, high=1.0, shape=(2,), dtype=np.float32
                    )
                }
            )
        else:
            self.observation_space = SpaceDict({})

        self.action_space = spaces.Discrete(4)

        if self._color:
            self.observation_space = SpaceDict(
                {
                    "rgb": spaces.Box(
                        low=np.finfo(np.float32).min,
                        high=np.finfo(np.float32).max,
                        shape=(3, *self.config.RESOLUTION),
                        dtype=np.uint8,
                    ),
                    **self.observation_space.spaces,
                }
            )

        if self._depth:
            self.observation_space = SpaceDict(
                {
                    "depth": spaces.Box(
                        low=np.finfo(np.float32).min,
                        high=np.finfo(np.float32).max,
                        shape=(1, *self.config.RESOLUTION),
                        dtype=np.float32,
                    ),
                    **self.observation_space.spaces,
                }
            )

        ppo_cfg = self.config.RL.PPO
        if not os.path.isdir(self.config.CHECKPOINT_FOLDER) and self.world_rank == 0:
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(ppo_cfg)
        self.count_steps = 0
        burn_steps = 0
        burn_time = 0
        count_checkpoints = 0
        prev_time = 0
        self.update = 0

        LR_SCALE = (
            max(
                np.sqrt(
                    ppo_cfg.num_steps
                    * self.config.SIM_BATCH_SIZE
                    * ppo_cfg.num_accumulate_steps
                    / ppo_cfg.num_mini_batch
                    * self.world_size
                    / (128 * 2)
                ),
                1.0,
            )
            if (self.config.RL.DDPPO.scale_lr and not self.config.RL.PPO.ada_scale)
            else 1.0
        )

        def cosine_decay(x):
            if x < 1:
                return (np.cos(x * np.pi) + 1.0) / 2.0
            else:
                return 0.0

        def warmup_fn(x):
            return LR_SCALE * (0.5 + 0.5 * x)

        def decay_fn(x):
            return LR_SCALE * (DECAY_TARGET + (1 - DECAY_TARGET) * cosine_decay(x))

        # NOTE: the "or True" below makes the first branch unconditional
        DECAY_TARGET = (
            0.01 / LR_SCALE
            if self.config.RL.PPO.ada_scale or True
            else (0.25 / LR_SCALE if self.config.RL.DDPPO.scale_lr else 1.0)
        )
        DECAY_PERCENT = 1.0 if self.config.RL.PPO.ada_scale or True else 0.5
        WARMUP_PERCENT = (
            0.01
            if (self.config.RL.DDPPO.scale_lr and not self.config.RL.PPO.ada_scale)
            else 0.0
        )

        def lr_fn():
            x = self.percent_done()
            if x < WARMUP_PERCENT:
                return warmup_fn(x / WARMUP_PERCENT)
            else:
                return decay_fn((x - WARMUP_PERCENT) / DECAY_PERCENT)

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer, lr_lambda=lambda x: lr_fn()
        )

        interrupted_state = load_interrupted_state(resume_from=self.resume_from)
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])

        self.agent.init_amp(self.config.SIM_BATCH_SIZE)
        self.actor_critic.init_trt(self.config.SIM_BATCH_SIZE)
        self.actor_critic.script_net()
        self.agent.init_distributed(find_unused_params=False)

        if self.world_rank == 0:
            logger.info(
                "agent number of trainable parameters: {}".format(
                    sum(
                        param.numel()
                        for param in self.agent.parameters()
                        if param.requires_grad
                    )
                )
            )

        if self._static_encoder:
            self._encoder = self.actor_critic.net.visual_encoder
            self.observation_space = SpaceDict(
                {
                    "visual_features": spaces.Box(
                        low=np.finfo(np.float32).min,
                        high=np.finfo(np.float32).max,
                        shape=self._encoder.output_shape,
                        dtype=np.float32,
                    ),
                    **self.observation_space.spaces,
                }
            )
            with torch.no_grad():
                batch["visual_features"] = self._encoder(batch)

        nenvs = self.config.SIM_BATCH_SIZE
        rollouts = DoubleBufferedRolloutStorage(
            ppo_cfg.num_steps,
            nenvs,
            self.observation_space,
            self.action_space,
            ppo_cfg.hidden_size,
            num_recurrent_layers=self.actor_critic.num_recurrent_layers,
            use_data_aug=ppo_cfg.use_data_aug,
            aug_type=ppo_cfg.aug_type,
            double_buffered=double_buffered,
            vtrace=ppo_cfg.vtrace,
        )
        rollouts.to(self.device)
        rollouts.to_fp16()

        self._warmup(rollouts)

        (
            self.envs,
            self._observations,
            self._rewards,
            self._masks,
            self._rollout_infos,
            self._syncs,
        ) = construct_envs(
            self.config,
            num_worker_groups=self.config.NUM_PARALLEL_SCENES,
            double_buffered=double_buffered,
        )

        def _setup_render_and_populate_initial_frame():
            for idx in range(2 if double_buffered else 1):
                self.envs.reset(idx)

                batch = self._observations[idx]
                self._syncs[idx].wait()

                tree_copy_in_place(
                    tree_select(0, rollouts[idx].storage_buffers["observations"]),
                    batch,
                )

        _setup_render_and_populate_initial_frame()

        current_episode_reward = torch.zeros(nenvs, 1)
        running_episode_stats = dict(
            count=torch.zeros(nenvs, 1,), reward=torch.zeros(nenvs, 1,),
        )

        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size)
        )
        time_per_frame_window = deque(maxlen=ppo_cfg.reward_window_size)

        buffer_ranges = []
        for i in range(2 if double_buffered else 1):
            start_ind = buffer_ranges[-1].stop if i > 0 else 0
            buffer_ranges.append(
                slice(
                    start_ind,
                    start_ind
                    + self.config.SIM_BATCH_SIZE // (2 if double_buffered else 1),
                )
            )

        if interrupted_state is not None:
            requeue_stats = interrupted_state["requeue_stats"]

            self.count_steps = requeue_stats["count_steps"]
            self.update = requeue_stats["start_update"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            prev_time = requeue_stats["prev_time"]
            burn_steps = requeue_stats["burn_steps"]
            burn_time = requeue_stats["burn_time"]

            self.agent.ada_scale.load_state_dict(interrupted_state["ada_scale_state"])

            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            if "amp_state" in interrupted_state:
                apex.amp.load_state_dict(interrupted_state["amp_state"])

            if "grad_scaler_state" in interrupted_state:
                self.agent.grad_scaler.load_state_dict(
                    interrupted_state["grad_scaler_state"]
                )

        with (
            TensorboardWriter(
                self.config.TENSORBOARD_DIR,
                flush_secs=self.flush_secs,
                purge_step=int(self.count_steps),
            )
            if self.world_rank == 0
            else contextlib.suppress()
        ) as writer:
            distrib.barrier()
            t_start = time.time()
            while not self.is_done():
                t_rollout_start = time.time()
                if self.update == BURN_IN_UPDATES:
                    burn_time = t_rollout_start - t_start
                    burn_steps = self.count_steps

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        self.percent_done(), final_decay=ppo_cfg.decay_factor,
                    )

                if (
                    not BPS_BENCHMARK
                    and (REQUEUE.is_set() or ((self.update + 1) % 100) == 0)
                    and self.world_rank == 0
                ):
                    requeue_stats = dict(
                        count_steps=self.count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=self.update,
                        prev_time=(time.time() - t_start) + prev_time,
                        burn_time=burn_time,
                        burn_steps=burn_steps,
                    )

                    def _cast(param):
                        if "Half" in param.type():
                            param = param.to(dtype=torch.float32)

                        return param

                    save_interrupted_state(
                        dict(
                            state_dict={
                                k: _cast(v) for k, v in self.agent.state_dict().items()
                            },
                            ada_scale_state=self.agent.ada_scale.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                            grad_scaler_state=self.agent.grad_scaler.state_dict(),
                        )
                    )

                if EXIT.is_set():
                    self._observations = None
                    self._rewards = None
                    self._masks = None
                    self._rollout_infos = None
                    self._syncs = None

                    del self.envs
                    self.envs = None

                    requeue_job()
                    return

                self.agent.eval()

                count_steps_delta = self._n_buffered_sampling(
                    rollouts,
                    current_episode_reward,
                    running_episode_stats,
                    buffer_ranges,
                    ppo_cfg.num_steps,
                    num_rollouts_done_store,
                )

                num_rollouts_done_store.add("num_done", 1)

                if not rollouts.vtrace:
                    self._compute_returns(ppo_cfg, rollouts)

                (value_loss, action_loss, dist_entropy) = self._update_agent(rollouts)

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                lr_scheduler.step()

                with self.timing.add_time("Logging"):
                    stats_ordering = list(sorted(running_episode_stats.keys()))
                    stats = torch.stack(
                        [running_episode_stats[k] for k in stats_ordering], 0,
                    ).to(device=self.device)
                    distrib.all_reduce(stats)
                    stats = stats.to(device="cpu")

                    for i, k in enumerate(stats_ordering):
                        window_episode_stats[k].append(stats[i])

                    stats = torch.tensor(
                        [
                            value_loss,
                            action_loss,
                            count_steps_delta,
                            *self.envs.swap_stats,
                        ],
                        device=self.device,
                    )
                    distrib.all_reduce(stats)
                    stats = stats.to(device="cpu")
                    count_steps_delta = int(stats[2].item())
                    self.count_steps += count_steps_delta

                    time_per_frame_window.append(
                        (time.time() - t_rollout_start) / count_steps_delta
                    )

                    if self.world_rank == 0:
                        losses = [
                            stats[0].item() / self.world_size,
                            stats[1].item() / self.world_size,
                        ]
                        deltas = {
                            k: (
                                (v[-1] - v[0]).sum().item()
                                if len(v) > 1
                                else v[0].sum().item()
                            )
                            for k, v in window_episode_stats.items()
                        }
                        deltas["count"] = max(deltas["count"], 1.0)

                        writer.add_scalar(
                            "reward",
                            deltas["reward"] / deltas["count"],
                            self.count_steps,
                        )

                        # Check to see if there are any metrics
                        # that haven't been logged yet
                        metrics = {
                            k: v / deltas["count"]
                            for k, v in deltas.items()
                            if k not in {"reward", "count"}
                        }
                        if len(metrics) > 0:
                            writer.add_scalars("metrics", metrics, self.count_steps)

                        writer.add_scalars(
                            "losses",
                            {k: l for l, k in zip(losses, ["value", "policy"])},
                            self.count_steps,
                        )

                        optim = self.agent.optimizer
                        writer.add_scalar(
                            "optimizer/base_lr",
                            optim.param_groups[-1]["lr"],
                            self.count_steps,
                        )
                        if "gain" in optim.param_groups[-1]:
                            for idx, group in enumerate(optim.param_groups):
                                writer.add_scalar(
                                    f"optimizer/lr_{idx}",
                                    group["lr"] * group["gain"],
                                    self.count_steps,
                                )
                                writer.add_scalar(
                                    f"optimizer/gain_{idx}",
                                    group["gain"],
                                    self.count_steps,
                                )

                        # log stats
                        if (
                            self.update > 0
                            and self.update % self.config.LOG_INTERVAL == 0
                        ):
                            logger.info(
                                "update: {}\twindow fps: {:.3f}\ttotal fps: {:.3f}\tframes: {}".format(
                                    self.update,
                                    1.0
                                    / (
                                        sum(time_per_frame_window)
                                        / len(time_per_frame_window)
                                    ),
                                    (self.count_steps - burn_steps)
                                    / ((time.time() - t_start) + prev_time - burn_time),
                                    self.count_steps,
                                )
                            )

                            logger.info(
                                "swap percent: {:.3f}\tscenes in use: {:.3f}\tenvs per scene: {:.3f}".format(
                                    stats[3].item() / self.world_size,
                                    stats[4].item() / self.world_size,
                                    stats[5].item() / self.world_size,
                                )
                            )

                            logger.info(
                                "Average window size: {}  {}".format(
                                    len(window_episode_stats["count"]),
                                    "  ".join(
                                        "{}: {:.3f}".format(k, v / deltas["count"])
                                        for k, v in deltas.items()
                                        if k != "count"
                                    ),
                                )
                            )

                            logger.info(self.timing)
                            # self.envs.print_renderer_stats()

                        # checkpoint model
                        if self.should_checkpoint():
                            self.save_checkpoint(
                                f"ckpt.{count_checkpoints}.pth",
                                dict(
                                    step=self.count_steps,
                                    wall_clock_time=(
                                        (time.time() - t_start) + prev_time
                                    ),
                                ),
                            )
                            count_checkpoints += 1

                self.update += 1

            self.save_checkpoint(
                "ckpt.done.pth",
                dict(
                    step=self.count_steps,
                    wall_clock_time=((time.time() - t_start) + prev_time),
                ),
            )
            self._observations = None
            self._rewards = None
            self._masks = None
            self._rollout_infos = None
            self._syncs = None
            del self.envs
            self.envs = None
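The learning-rate schedule assembled above (LR_SCALE, warmup_fn, decay_fn, and a LambdaLR whose lambda ignores its step argument and reads percent_done()) can be hard to follow inline. Below is a standalone sketch of the same schedule shape, linear warmup to the base rate followed by cosine decay toward a small floor, driven by an explicit progress fraction; the names and constants are illustrative, not the trainer's:

import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR


def make_schedule(optimizer, warmup_percent=0.01, decay_target=0.01):
    progress = {"done": 0.0}  # fraction of training completed, updated by the caller

    def lr_multiplier(_step):
        x = progress["done"]
        if x < warmup_percent:
            return 0.5 + 0.5 * (x / warmup_percent)      # linear warmup to 1.0
        x = (x - warmup_percent) / (1.0 - warmup_percent)
        cos = (np.cos(min(x, 1.0) * np.pi) + 1.0) / 2.0  # cosine from 1 down to 0
        return decay_target + (1.0 - decay_target) * cos

    return LambdaLR(optimizer, lr_lambda=lr_multiplier), progress


# Usage (in real training, optimizer.step() would run before scheduler.step()):
model = torch.nn.Linear(8, 4)
opt = torch.optim.SGD(model.parameters(), lr=2.5e-4)
scheduler, progress = make_schedule(opt)
total_updates = 100
for update in range(total_updates):
    progress["done"] = update / total_updates
    scheduler.step()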
Example 8
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of cur checkpoint for logging

        Returns:
            None
        """

        from habitat_baselines.common.environments import get_env_class

        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
        config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
        config.freeze()

        if len(self.config.VIDEO_OPTION) > 0:
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        #  logger.info(f"env config: {config}")
        self.envs = construct_envs_habitat(config,
                                           get_env_class(config.ENV_NAME))
        self.observation_space = SpaceDict({
            "pointgoal_with_gps_compass":
            spaces.Box(low=0.0, high=1.0, shape=(2, ), dtype=np.float32)
        })

        if self.config.COLOR:
            self.observation_space = SpaceDict({
                "rgb":
                spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(3, *self.config.RESOLUTION),
                    dtype=np.uint8,
                ),
                **self.observation_space.spaces,
            })

        if self.config.DEPTH:
            self.observation_space = SpaceDict({
                "depth":
                spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(1, *self.config.RESOLUTION),
                    dtype=np.float32,
                ),
                **self.observation_space.spaces,
            })

        self.action_space = self.envs.action_spaces[0]
        self._setup_actor_critic_agent(ppo_cfg)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic
        self.actor_critic.script_net()
        self.actor_critic.eval()

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)

        test_recurrent_hidden_states = torch.zeros(
            self.config.NUM_PROCESSES,
            self.actor_critic.num_recurrent_layers,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(self.config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device,
                                     dtype=torch.bool)
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        rgb_frames = [
            [] for _ in range(self.config.NUM_PROCESSES)
        ]  # type: List[List[np.ndarray]]
        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        number_of_eval_episodes = self.config.TEST_EPISODE_COUNT
        if number_of_eval_episodes == -1:
            number_of_eval_episodes = sum(self.envs.number_of_episodes)
        else:
            total_num_eps = sum(self.envs.number_of_episodes)
            if total_num_eps < number_of_eval_episodes:
                logger.warn(
                    f"Config specified {number_of_eval_episodes} eval episodes"
                    f", dataset only has {total_num_eps}.")
                logger.warn(f"Evaluating with {total_num_eps} instead.")
                number_of_eval_episodes = total_num_eps

        evals_per_ep = 5
        count_per_ep = defaultdict(lambda: 0)

        pbar = tqdm.tqdm(total=number_of_eval_episodes * evals_per_ep)
        self.actor_critic.eval()
        while (len(stats_episodes) < (number_of_eval_episodes * evals_per_ep)
               and self.envs.num_envs > 0):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                (_, dist_result,
                 test_recurrent_hidden_states) = self.actor_critic.act(
                     batch,
                     test_recurrent_hidden_states,
                     prev_actions,
                     not_done_masks,
                     deterministic=False,
                 )
                actions = dist_result["actions"]

                prev_actions.copy_(actions)
                actions = actions.to("cpu")

            outputs = self.envs.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(observations, device=self.device)

            not_done_masks = torch.tensor(
                [[False] if done else [True] for done in dones],
                dtype=torch.bool,
                device=self.device,
            )

            rewards = torch.tensor(rewards,
                                   dtype=torch.float,
                                   device=self.device).unsqueeze(1)
            current_episode_reward += rewards
            next_episodes = self.envs.current_episodes()
            envs_to_pause = []
            n_envs = self.envs.num_envs
            for i in range(n_envs):
                next_count_key = (
                    next_episodes[i].scene_id,
                    next_episodes[i].episode_id,
                )

                if count_per_ep[next_count_key] == evals_per_ep:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:
                    pbar.update()
                    episode_stats = dict()
                    episode_stats["reward"] = current_episode_reward[i].item()
                    episode_stats.update(
                        self._extract_scalars_from_info(infos[i]))
                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats
                    count_key = (
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )
                    count_per_ep[count_key] = count_per_ep[count_key] + 1

                    ep_key = (count_key, count_per_ep[count_key])
                    stats_episodes[ep_key] = episode_stats

                    if len(self.config.VIDEO_OPTION) > 0:
                        generate_video(
                            video_option=self.config.VIDEO_OPTION,
                            video_dir=self.config.VIDEO_DIR,
                            images=rgb_frames[i],
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=checkpoint_index,
                            metrics=self._extract_scalars_from_info(infos[i]),
                            tb_writer=writer,
                        )

                        rgb_frames[i] = []

                # episode continues
                elif len(self.config.VIDEO_OPTION) > 0:
                    frame = observations_to_image(observations[i], infos[i])
                    rgb_frames[i].append(frame)

            (
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            )

        self.envs.close()
        pbar.close()
        num_episodes = len(stats_episodes)
        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            values = [
                v[stat_key] for v in stats_episodes.values()
                if v[stat_key] is not None
            ]
            if len(values) > 0:
                aggregated_stats[stat_key] = sum(values) / len(values)
            else:
                aggregated_stats[stat_key] = 0

        for k, v in aggregated_stats.items():
            logger.info(f"Average episode {k}: {v:.4f}")

        step_id = checkpoint_index
        if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]:
            step_id = ckpt_dict["extra_state"]["step"]

        writer.add_scalars(
            "eval_reward",
            {"average reward": aggregated_stats["reward"]},
            step_id,
        )

        metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"}
        if len(metrics) > 0:
            writer.add_scalars("eval_metrics", metrics, step_id)

        self.num_frames = step_id

        time.sleep(30)
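_pause_envs (inherited from the base trainer) is assumed to remove the paused indices from every per-env tensor and buffer so that the remaining rows stay aligned with the still-active environments. A simplified sketch of that index-dropping pattern (not the habitat-baselines implementation; pause_at is assumed to exist on the vectorized env):

import torch


def pause_envs(envs_to_pause, envs, hidden_states, masks,
               rewards, prev_actions, batch, frames):
    """Drop paused env indices so every per-env structure stays aligned with envs."""
    if len(envs_to_pause) == 0:
        return envs, hidden_states, masks, rewards, prev_actions, batch, frames

    keep = [i for i in range(envs.num_envs) if i not in envs_to_pause]
    for idx in reversed(envs_to_pause):  # envs_to_pause is built in ascending order
        envs.pause_at(idx)  # stop stepping this env
        frames.pop(idx)     # keep the frame buffers aligned

    index = torch.tensor(keep, device=hidden_states.device)
    hidden_states = hidden_states[index]
    masks = masks[index]
    rewards = rewards[index]
    prev_actions = prev_actions[index]
    batch = {k: v[index] for k, v in batch.items()}
    return envs, hidden_states, masks, rewards, prev_actions, batch, frames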