def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of the current checkpoint, used for logging

        Returns:
            None
        """
        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.freeze()

        if len(self.config.VIDEO_OPTION) > 0:
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        logger.info(f"env config: {config}")
        self.envs = construct_envs(config, get_env_class(config.ENV_NAME))
        self._setup_actor_critic_agent(ppo_cfg)
        self.actor_critic.eval()

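        # with a frozen ("static") visual encoder, keep a handle to it so
        # visual features can be precomputed outside the policy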
        if self._static_encoder:
            self._encoder = self.agent.actor_critic.net.visual_encoder

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        if self._static_encoder:
            batch["visual_features"] = self._encoder(batch)
            batch["prev_visual_features"] = torch.zeros_like(
                batch["visual_features"])

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)

        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            self.config.NUM_PROCESSES,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(self.config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device)
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        rgb_frames = [[] for _ in range(self.config.NUM_PROCESSES)
                      ]  # type: List[List[np.ndarray]]
        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        number_of_eval_episodes = self.config.TEST_EPISODE_COUNT
        if number_of_eval_episodes == -1:
            number_of_eval_episodes = sum(self.envs.number_of_episodes)
        else:
            total_num_eps = sum(self.envs.number_of_episodes)
            if total_num_eps < number_of_eval_episodes:
                logger.warn(
                    f"Config specified {number_of_eval_episodes} eval episodes"
                    f", dataset only has {total_num_eps}.")
                logger.warn(f"Evaluating with {total_num_eps} instead.")
                number_of_eval_episodes = total_num_eps

        pbar = tqdm.tqdm(total=number_of_eval_episodes)
        self.actor_critic.eval()
        while (len(stats_episodes) < number_of_eval_episodes
               and self.envs.num_envs > 0):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                step_batch = batch
                (
                    _,
                    actions,
                    _,
                    test_recurrent_hidden_states,
                ) = self.actor_critic.act(
                    batch,
                    test_recurrent_hidden_states,
                    prev_actions,
                    not_done_masks,
                    deterministic=False,
                )

                prev_actions.copy_(actions)

            outputs = self.envs.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(observations, device=self.device)

            if self._static_encoder:
                batch["prev_visual_features"] = step_batch["visual_features"]
                batch["visual_features"] = self._encoder(batch)

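            # a mask value of 0.0 marks an env whose episode just ended; the
            # policy uses it to reset its recurrent state on the next act()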
            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            rewards = torch.tensor(rewards,
                                   dtype=torch.float,
                                   device=self.device).unsqueeze(1)
            current_episode_reward += rewards
            next_episodes = self.envs.current_episodes()
            envs_to_pause = []
            n_envs = self.envs.num_envs
            for i in range(n_envs):
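                # pause envs that have looped back to an episode whose stats
                # were already collected, to avoid double-counting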
                if (
                        next_episodes[i].scene_id,
                        next_episodes[i].episode_id,
                ) in stats_episodes:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:
                    pbar.update()
                    episode_stats = dict()
                    episode_stats["reward"] = current_episode_reward[i].item()
                    episode_stats.update(
                        self._extract_scalars_from_info(infos[i]))
                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats
                    stats_episodes[(
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )] = episode_stats

                    if len(self.config.VIDEO_OPTION) > 0:
                        generate_video(
                            video_option=self.config.VIDEO_OPTION,
                            video_dir=self.config.VIDEO_DIR,
                            images=rgb_frames[i],
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=checkpoint_index,
                            metrics=self._extract_scalars_from_info(infos[i]),
                            tb_writer=writer,
                        )

                        rgb_frames[i] = []

                # episode continues
                elif len(self.config.VIDEO_OPTION) > 0:
                    frame = observations_to_image(observations[i], infos[i])
                    rgb_frames[i].append(frame)

            (
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            )

        num_episodes = len(stats_episodes)
        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = (
                sum([v[stat_key]
                     for v in stats_episodes.values()]) / num_episodes)

        for k, v in aggregated_stats.items():
            logger.info(f"Average episode {k}: {v:.4f}")

        step_id = checkpoint_index
        if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]:
            step_id = ckpt_dict["extra_state"]["step"]

        writer.add_scalars(
            "eval_reward",
            {"average reward": aggregated_stats["reward"]},
            step_id,
        )

        metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"}
        if len(metrics) > 0:
            writer.add_scalars("eval_metrics", metrics, step_id)

        self.envs.close()
Example #2
    def _eval_checkpoint(self,
                         checkpoint_path: str,
                         writer: TensorboardWriter,
                         checkpoint_index: int = 0) -> None:
        r"""Evaluates a single checkpoint. Assumes episode IDs are unique.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of the current checkpoint, used for logging

        Returns:
            None
        """
        logger.info(f"checkpoint_path: {checkpoint_path}")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(
                self.load_checkpoint(checkpoint_path,
                                     map_location="cpu")["config"])
        else:
            config = self.config.clone()

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.TASK_CONFIG.TASK.NDTW.SPLIT = config.EVAL.SPLIT
        config.TASK_CONFIG.TASK.SDTW.SPLIT = config.EVAL.SPLIT
        config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
        config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
        if len(config.VIDEO_OPTION) > 0:
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")

        config.freeze()

        # setup agent
        self.envs = construct_envs_auto_reset_false(
            config, get_env_class(config.ENV_NAME))
        self.device = (torch.device("cuda", config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))

        self._setup_actor_critic_agent(config.MODEL, True, checkpoint_path)

        observations = self.envs.reset()
        observations = transform_obs(
            observations, config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID)
        batch = batch_obs(observations, self.device)

        eval_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            config.NUM_PROCESSES,
            config.MODEL.STATE_ENCODER.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(config.NUM_PROCESSES,
                                     1,
                                     device=self.device)

        stats_episodes = {}  # dict of dicts that stores stats per episode

        if len(config.VIDEO_OPTION) > 0:
            os.makedirs(config.VIDEO_DIR, exist_ok=True)
            rgb_frames = [[] for _ in range(config.NUM_PROCESSES)]

        self.actor_critic.eval()
        while (self.envs.num_envs > 0
               and len(stats_episodes) < config.EVAL.EPISODE_COUNT):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                (_, actions, _,
                 eval_recurrent_hidden_states) = self.actor_critic.act(
                     batch,
                     eval_recurrent_hidden_states,
                     prev_actions,
                     not_done_masks,
                     deterministic=True,
                 )
                # actions = batch["vln_oracle_action_sensor"].long()
                prev_actions.copy_(actions)

            outputs = self.envs.step([a[0].item() for a in actions])
            observations, _, dones, infos = [list(x) for x in zip(*outputs)]

            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            # reset envs and observations if necessary
            for i in range(self.envs.num_envs):
                if len(config.VIDEO_OPTION) > 0:
                    frame = observations_to_image(observations[i], infos[i])
                    frame = append_text_to_image(
                        frame,
                        current_episodes[i].instruction.instruction_text)
                    rgb_frames[i].append(frame)

                if not dones[i]:
                    continue

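                # episode finished: record its metrics and reset this env
                # manually (envs were built with auto-reset disabled)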
                stats_episodes[current_episodes[i].episode_id] = infos[i]
                observations[i] = self.envs.reset_at(i)[0]
                prev_actions[i] = torch.zeros(1, dtype=torch.long)

                if len(config.VIDEO_OPTION) > 0:
                    generate_video(
                        video_option=config.VIDEO_OPTION,
                        video_dir=config.VIDEO_DIR,
                        images=rgb_frames[i],
                        episode_id=current_episodes[i].episode_id,
                        checkpoint_idx=checkpoint_index,
                        metrics={
                            "SPL":
                            round(
                                stats_episodes[current_episodes[i].episode_id]
                                ["spl"], 6)
                        },
                        tb_writer=writer,
                    )

                    del stats_episodes[
                        current_episodes[i].episode_id]["top_down_map"]
                    del stats_episodes[
                        current_episodes[i].episode_id]["collisions"]
                    rgb_frames[i] = []

            observations = transform_obs(
                observations, config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID)
            batch = batch_obs(observations, self.device)

            envs_to_pause = []
            next_episodes = self.envs.current_episodes()

            for i in range(self.envs.num_envs):
                if next_episodes[i].episode_id in stats_episodes:
                    envs_to_pause.append(i)

            (
                self.envs,
                eval_recurrent_hidden_states,
                not_done_masks,
                prev_actions,
                batch,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                eval_recurrent_hidden_states,
                not_done_masks,
                prev_actions,
                batch,
            )

        self.envs.close()

        aggregated_stats = {}
        num_episodes = len(stats_episodes)
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = (
                sum([v[stat_key]
                     for v in stats_episodes.values()]) / num_episodes)

        split = config.TASK_CONFIG.DATASET.SPLIT
        with open(f"stats_ckpt_{checkpoint_index}_{split}.json", "w") as f:
            json.dump(aggregated_stats, f, indent=4)

        logger.info(f"Episodes evaluated: {num_episodes}")
        checkpoint_num = checkpoint_index + 1
        for k, v in aggregated_stats.items():
            logger.info(f"Average episode {k}: {v:.6f}")
            writer.add_scalar(f"eval_{split}_{k}", v, checkpoint_num)
Example #3
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of the current checkpoint, used for logging

        Returns:
            None
        """
        ckpt_dict = self.load_checkpoint(checkpoint_path,
                                         map_location=self.device)

        config = self._setup_eval_config(ckpt_dict["config"])
        ppo_cfg = config.RL.PPO

        if len(self.config.VIDEO_OPTION) > 0:
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        logger.info(f"env config: {config}")
        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))
        self._setup_actor_critic_agent(ppo_cfg)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic

        # get name of performance metric, e.g. "spl"
        metric_name = self.config.TASK_CONFIG.TASK.MEASUREMENTS[0]
        metric_cfg = getattr(self.config.TASK_CONFIG.TASK, metric_name)
        measure_type = baseline_registry.get_measure(metric_cfg.TYPE)
        assert measure_type is not None, "invalid measurement type {}".format(
            metric_cfg.TYPE)
        self.metric_uuid = measure_type(None, None)._get_uuid()

        observations = self.envs.reset()
        batch = batch_obs(observations)
        for sensor in batch:
            batch[sensor] = batch[sensor].to(self.device)

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)

        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            self.config.NUM_PROCESSES,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(self.config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device)
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        # one independent list per env (avoids aliasing from `[[]] * n`)
        rgb_frames = [[] for _ in range(self.config.NUM_PROCESSES)
                      ]  # type: List[List[np.ndarray]]
        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        while (len(stats_episodes) < self.config.TEST_EPISODE_COUNT
               and self.envs.num_envs > 0):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                _, actions, _, test_recurrent_hidden_states = self.actor_critic.act(
                    batch,
                    test_recurrent_hidden_states,
                    prev_actions,
                    not_done_masks,
                    deterministic=False,
                )

                prev_actions.copy_(actions)

            outputs = self.envs.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(observations)
            for sensor in batch:
                batch[sensor] = batch[sensor].to(self.device)

            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            rewards = torch.tensor(rewards,
                                   dtype=torch.float,
                                   device=self.device).unsqueeze(1)
            current_episode_reward += rewards
            next_episodes = self.envs.current_episodes()
            envs_to_pause = []
            n_envs = self.envs.num_envs
            for i in range(n_envs):
                if (
                        next_episodes[i].scene_id,
                        next_episodes[i].episode_id,
                ) in stats_episodes:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:
                    episode_stats = dict()
                    episode_stats[self.metric_uuid] = infos[i][
                        self.metric_uuid]
                    episode_stats["success"] = int(
                        infos[i][self.metric_uuid] > 0)
                    episode_stats["reward"] = current_episode_reward[i].item()
                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats
                    stats_episodes[(
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )] = episode_stats

                    if len(self.config.VIDEO_OPTION) > 0:
                        generate_video(
                            video_option=self.config.VIDEO_OPTION,
                            video_dir=self.config.VIDEO_DIR,
                            images=rgb_frames[i],
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=checkpoint_index,
                            metric_name=self.metric_uuid,
                            metric_value=infos[i][self.metric_uuid],
                            tb_writer=writer,
                        )

                        rgb_frames[i] = []

                # episode continues
                elif len(self.config.VIDEO_OPTION) > 0:
                    frame = observations_to_image(observations[i], infos[i])
                    rgb_frames[i].append(frame)

            # pausing self.envs with no new episode
            if len(envs_to_pause) > 0:
                state_index = list(range(self.envs.num_envs))
                for idx in reversed(envs_to_pause):
                    state_index.pop(idx)
                    self.envs.pause_at(idx)

                # indexing along the batch dimensions
                test_recurrent_hidden_states = test_recurrent_hidden_states[
                    state_index]
                not_done_masks = not_done_masks[state_index]
                current_episode_reward = current_episode_reward[state_index]
                prev_actions = prev_actions[state_index]

                for k, v in batch.items():
                    batch[k] = v[state_index]

                if len(self.config.VIDEO_OPTION) > 0:
                    rgb_frames = [rgb_frames[i] for i in state_index]

        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = sum(
                [v[stat_key] for v in stats_episodes.values()])
        num_episodes = len(stats_episodes)

        episode_reward_mean = aggregated_stats["reward"] / num_episodes
        episode_metric_mean = aggregated_stats[self.metric_uuid] / num_episodes
        episode_success_mean = aggregated_stats["success"] / num_episodes

        logger.info(f"Average episode reward: {episode_reward_mean:.6f}")
        logger.info(f"Average episode success: {episode_success_mean:.6f}")
        logger.info(
            f"Average episode {self.metric_uuid}: {episode_metric_mean:.6f}")

        writer.add_scalars(
            "eval_reward",
            {"average reward": episode_reward_mean},
            checkpoint_index,
        )
        writer.add_scalars(
            f"eval_{self.metric_uuid}",
            {f"average {self.metric_uuid}": episode_metric_mean},
            checkpoint_index,
        )
        writer.add_scalars(
            "eval_success",
            {"average success": episode_success_mean},
            checkpoint_index,
        )

        self.envs.close()
Example #4
    def _eval_checkpoint(self,
                         checkpoint_path: str,
                         writer: TensorboardWriter,
                         checkpoint_index: int = 0,
                         log_diagnostics=[],
                         output_dir='.',
                         label='.',
                         num_eval_runs=1) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of the current checkpoint, used for logging

        Returns:
            None
        """
        if checkpoint_index == -1:
            ckpt_file = checkpoint_path.split('/')[-1]
            split_info = ckpt_file.split('.')
            checkpoint_index = split_info[1]
        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO
        task_cfg = config.TASK_CONFIG.TASK

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.freeze()

        if len(self.config.VIDEO_OPTION) > 0:
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        logger.info(f"env config: {config}")
        self.envs = construct_envs(config, get_env_class(config.ENV_NAME))
        # pass in aux config if we're doing attention
        aux_cfg = self.config.RL.AUX_TASKS
        self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg)

        # Check if we accidentally recorded `visual_resnet` in our checkpoint and drop it (it's redundant with `visual_encoder`)
        ckpt_dict['state_dict'] = {
            k: v
            for k, v in ckpt_dict['state_dict'].items()
            if 'visual_resnet' not in k
        }
        self.agent.load_state_dict(ckpt_dict["state_dict"])

        logger.info("agent number of trainable parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters()
                if param.requires_grad)))

        self.actor_critic = self.agent.actor_critic

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)

        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            self.config.NUM_PROCESSES,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        _, num_recurrent_memories, _ = self._setup_auxiliary_tasks(
            aux_cfg, ppo_cfg, task_cfg, is_eval=True)
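        # multiple-belief policies keep one recurrent memory per auxiliary
        # task, so the hidden state gets an extra per-task memory dimension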
        if self.config.RL.PPO.policy in MULTIPLE_BELIEF_CLASSES:
            aux_tasks = self.config.RL.AUX_TASKS.tasks
            num_recurrent_memories = len(self.config.RL.AUX_TASKS.tasks)
            test_recurrent_hidden_states = test_recurrent_hidden_states.unsqueeze(
                2).repeat(1, 1, num_recurrent_memories, 1)

        prev_actions = torch.zeros(self.config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)

        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device)
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        rgb_frames = [[] for _ in range(self.config.NUM_PROCESSES)
                      ]  # type: List[List[np.ndarray]]

        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        number_of_eval_episodes = self.config.TEST_EPISODE_COUNT
        if number_of_eval_episodes == -1:
            number_of_eval_episodes = sum(self.envs.number_of_episodes)
        else:
            total_num_eps = sum(self.envs.number_of_episodes)
            if total_num_eps < number_of_eval_episodes:
                logger.warn(
                    f"Config specified {number_of_eval_episodes} eval episodes"
                    f", dataset only has {total_num_eps}.")
                logger.warn(f"Evaluating with {total_num_eps} instead.")
                number_of_eval_episodes = total_num_eps

        videos_cap = 2  # number of videos to generate per checkpoint
        if len(log_diagnostics) > 0:
            videos_cap = 10
        # video_indices = random.sample(range(self.config.TEST_EPISODE_COUNT),
        # min(videos_cap, self.config.TEST_EPISODE_COUNT))
        video_indices = range(10)
        print(f"Videos: {video_indices}")

        total_stats = []
        dones_per_ep = dict()

        # Logging more extensive evaluation stats for analysis
        if len(log_diagnostics) > 0:
            d_stats = {}
            for d in log_diagnostics:
                d_stats[d] = [
                    [] for _ in range(self.config.NUM_PROCESSES)
                ]  # stored as nested list envs x timesteps x k (# tasks)

        pbar = tqdm.tqdm(total=number_of_eval_episodes * num_eval_runs)
        self.agent.eval()
        while (len(stats_episodes) < number_of_eval_episodes * num_eval_runs
               and self.envs.num_envs > 0):
            current_episodes = self.envs.current_episodes()
            with torch.no_grad():
                weights_output = None
                if self.config.RL.PPO.policy in MULTIPLE_BELIEF_CLASSES:
                    weights_output = torch.empty(self.envs.num_envs,
                                                 len(aux_tasks))
                (
                    _,
                    actions,
                    _,
                    test_recurrent_hidden_states,
                ) = self.actor_critic.act(batch,
                                          test_recurrent_hidden_states,
                                          prev_actions,
                                          not_done_masks,
                                          deterministic=False,
                                          weights_output=weights_output)
                prev_actions.copy_(actions)

                for i in range(self.envs.num_envs):
                    if Diagnostics.actions in log_diagnostics:
                        d_stats[Diagnostics.actions][i].append(
                            prev_actions[i].item())
                    if Diagnostics.weights in log_diagnostics:
                        aux_weights = None if weights_output is None else weights_output[
                            i]
                        if aux_weights is not None:
                            d_stats[Diagnostics.weights][i].append(
                                aux_weights.half().tolist())

            outputs = self.envs.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(observations, device=self.device)

            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            rewards = torch.tensor(rewards,
                                   dtype=torch.float,
                                   device=self.device).unsqueeze(1)
            current_episode_reward += rewards
            next_episodes = self.envs.current_episodes()
            envs_to_pause = []
            n_envs = self.envs.num_envs
            for i in range(n_envs):
                next_k = (
                    next_episodes[i].scene_id,
                    next_episodes[i].episode_id,
                )
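                # pause an env once its next episode has already been run
                # num_eval_runs times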
                if dones_per_ep.get(next_k, 0) == num_eval_runs:
                    envs_to_pause.append(i)  # wait for the rest

                if not_done_masks[i].item() == 0:
                    episode_stats = dict()

                    episode_stats["reward"] = current_episode_reward[i].item()
                    episode_stats.update(
                        self._extract_scalars_from_info(infos[i]))

                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats

                    k = (
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )
                    dones_per_ep[k] = dones_per_ep.get(k, 0) + 1

                    if dones_per_ep.get(k, 0) == 1 and len(
                            self.config.VIDEO_OPTION) > 0 and len(
                                stats_episodes) in video_indices:
                        logger.info(f"Generating video {len(stats_episodes)}")
                        category = getattr(current_episodes[i],
                                           "object_category", "")
                        if category != "":
                            category += "_"
                        try:
                            generate_video(
                                video_option=self.config.VIDEO_OPTION,
                                video_dir=self.config.VIDEO_DIR,
                                images=rgb_frames[i],
                                episode_id=current_episodes[i].episode_id,
                                checkpoint_idx=checkpoint_index,
                                metrics=self._extract_scalars_from_info(
                                    infos[i]),
                                tag=f"{category}{label}",
                                tb_writer=writer,
                            )
                        except Exception as e:
                            logger.warning(str(e))
                    rgb_frames[i] = []

                    stats_episodes[(
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                        dones_per_ep[k],
                    )] = episode_stats

                    if len(log_diagnostics) > 0:
                        diagnostic_info = dict()
                        for metric in log_diagnostics:
                            diagnostic_info[metric] = d_stats[metric][i]
                            d_stats[metric][i] = []
                        if Diagnostics.top_down_map in log_diagnostics:
                            top_down_map = torch.tensor([])
                            if len(self.config.VIDEO_OPTION) > 0:
                                top_down_map = infos[i]["top_down_map"]["map"]
                                top_down_map = maps.colorize_topdown_map(
                                    top_down_map, fog_of_war_mask=None)
                            diagnostic_info.update(
                                dict(top_down_map=top_down_map))
                        total_stats.append(
                            dict(
                                stats=episode_stats,
                                did_stop=bool(prev_actions[i] == 0),
                                episode_info=attr.asdict(current_episodes[i]),
                                info=diagnostic_info,
                            ))
                    pbar.update()

                # episode continues
                else:
                    if len(self.config.VIDEO_OPTION) > 0:
                        aux_weights = None if weights_output is None else weights_output[
                            i]
                        frame = observations_to_image(
                            observations[i], infos[i],
                            current_episode_reward[i].item(), aux_weights,
                            aux_tasks)
                        rgb_frames[i].append(frame)
                    if Diagnostics.gps in log_diagnostics:
                        d_stats[Diagnostics.gps][i].append(
                            observations[i]["gps"].tolist())
                    if Diagnostics.heading in log_diagnostics:
                        d_stats[Diagnostics.heading][i].append(
                            observations[i]["heading"].tolist())

            (
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            )

        num_episodes = len(stats_episodes)
        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = (
                sum([v[stat_key]
                     for v in stats_episodes.values()]) / num_episodes)

        for k, v in aggregated_stats.items():
            logger.info(f"Average episode {k}: {v:.4f}")

        step_id = checkpoint_index
        if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]:
            step_id = ckpt_dict["extra_state"]["step"]

        writer.add_scalars(
            "eval_reward",
            {"average reward": aggregated_stats["reward"]},
            step_id,
        )

        metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"}
        if len(metrics) > 0:
            writer.add_scalars("eval_metrics", metrics, step_id)
            logger.info("eval_metrics")
            logger.info(metrics)
        if len(log_diagnostics) > 0:
            os.makedirs(output_dir, exist_ok=True)
            eval_fn = f"{label}.json"
            with open(os.path.join(output_dir, eval_fn), 'w',
                      encoding='utf-8') as f:
                json.dump(total_stats, f, ensure_ascii=False, indent=4)
        self.envs.close()
Example #5
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        cur_ckpt_idx: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            cur_ckpt_idx: index of the current checkpoint, used for logging

        Returns:
            None
        """
        ckpt_dict = self.load_checkpoint(checkpoint_path,
                                         map_location=self.device)

        ckpt_config = ckpt_dict["config"]
        config = self.config.clone()
        ckpt_cmd_opts = ckpt_config.CMD_TRAILING_OPTS
        eval_cmd_opts = config.CMD_TRAILING_OPTS

        # config merge priority: eval_opts > ckpt_opts > eval_cfg > ckpt_cfg
        # first line for old checkpoint compatibility
        config.merge_from_other_cfg(ckpt_config)
        config.merge_from_other_cfg(self.config)
        config.merge_from_list(ckpt_cmd_opts)
        config.merge_from_list(eval_cmd_opts)

        ppo_cfg = config.TRAINER.RL.PPO
        config.TASK_CONFIG.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = "val"
        agent_sensors = ppo_cfg.sensors.strip().split(",")
        config.TASK_CONFIG.SIMULATOR.AGENT_0.SENSORS = agent_sensors
        if self.video_option:
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
        config.freeze()

        logger.info(f"env config: {config}")
        self.envs = construct_envs(config, NavRLEnv)
        self._setup_actor_critic_agent(ppo_cfg)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic

        observations = self.envs.reset()
        batch = batch_obs(observations)
        for sensor in batch:
            batch[sensor] = batch[sensor].to(self.device)

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)

        test_recurrent_hidden_states = torch.zeros(ppo_cfg.num_processes,
                                                   ppo_cfg.hidden_size,
                                                   device=self.device)
        not_done_masks = torch.zeros(ppo_cfg.num_processes,
                                     1,
                                     device=self.device)
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        # one independent list per env (avoids aliasing from `[[]] * n`)
        rgb_frames = [[] for _ in range(ppo_cfg.num_processes)
                      ]  # type: List[List[np.ndarray]]
        if self.video_option:
            os.makedirs(ppo_cfg.video_dir, exist_ok=True)

        while (len(stats_episodes) < ppo_cfg.count_test_episodes
               and self.envs.num_envs > 0):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                _, actions, _, test_recurrent_hidden_states = self.actor_critic.act(
                    batch,
                    test_recurrent_hidden_states,
                    not_done_masks,
                    deterministic=False,
                )

            outputs = self.envs.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(observations)
            for sensor in batch:
                batch[sensor] = batch[sensor].to(self.device)

            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            rewards = torch.tensor(rewards,
                                   dtype=torch.float,
                                   device=self.device).unsqueeze(1)
            current_episode_reward += rewards
            next_episodes = self.envs.current_episodes()
            envs_to_pause = []
            n_envs = self.envs.num_envs
            for i in range(n_envs):
                if (
                        next_episodes[i].scene_id,
                        next_episodes[i].episode_id,
                ) in stats_episodes:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:
                    episode_stats = dict()
                    episode_stats["spl"] = infos[i]["spl"]
                    episode_stats["success"] = int(infos[i]["spl"] > 0)
                    episode_stats["reward"] = current_episode_reward[i].item()
                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats
                    stats_episodes[(
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )] = episode_stats
                    if self.video_option:
                        generate_video(
                            ppo_cfg,
                            rgb_frames[i],
                            current_episodes[i].episode_id,
                            cur_ckpt_idx,
                            infos[i]["spl"],
                            writer,
                        )
                        rgb_frames[i] = []

                # episode continues
                elif self.video_option:
                    frame = observations_to_image(observations[i], infos[i])
                    rgb_frames[i].append(frame)

            # pausing self.envs with no new episode
            if len(envs_to_pause) > 0:
                state_index = list(range(self.envs.num_envs))
                for idx in reversed(envs_to_pause):
                    state_index.pop(idx)
                    self.envs.pause_at(idx)

                # indexing along the batch dimensions
                test_recurrent_hidden_states = test_recurrent_hidden_states[
                    state_index]
                not_done_masks = not_done_masks[state_index]
                current_episode_reward = current_episode_reward[state_index]

                for k, v in batch.items():
                    batch[k] = v[state_index]

                if self.video_option:
                    rgb_frames = [rgb_frames[i] for i in state_index]

        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = sum(
                [v[stat_key] for v in stats_episodes.values()])
        num_episodes = len(stats_episodes)

        episode_reward_mean = aggregated_stats["reward"] / num_episodes
        episode_spl_mean = aggregated_stats["spl"] / num_episodes
        episode_success_mean = aggregated_stats["success"] / num_episodes

        logger.info(
            "Average episode reward: {:.6f}".format(episode_reward_mean))
        logger.info(
            "Average episode success: {:.6f}".format(episode_success_mean))
        logger.info("Average episode SPL: {:.6f}".format(episode_spl_mean))

        writer.add_scalars(
            "eval_reward",
            {"average reward": episode_reward_mean},
            cur_ckpt_idx,
        )
        writer.add_scalars("eval_SPL", {"average SPL": episode_spl_mean},
                           cur_ckpt_idx)
        writer.add_scalars(
            "eval_success",
            {"average success": episode_success_mean},
            cur_ckpt_idx,
        )
Example #6
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of the current checkpoint, used for logging

        Returns:
            None
        """
        self.add_new_based_on_cfg()

        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        # ==========================================================================================
        # -- Update config for eval
        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO

        # # Mostly for visualization
        # config.defrost()
        # config.TASK_CONFIG.SIMULATOR.HABITAT_SIM_V0.GPU_GPU = False
        # config.freeze()

        split = config.TASK_CONFIG.DATASET.SPLIT

        config.defrost()
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
        config.freeze()
        # ==========================================================================================

        num_procs = self.config.NUM_PROCESSES
        device = self.device
        cfg = self.config

        logger.info(f"env config: {config}")
        self.envs = construct_envs(config, get_env_class(self.config.ENV_NAME))
        num_envs = self.envs.num_envs

        self._setup_actor_critic_agent(ppo_cfg, train=False)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic
        self.r_policy = self.agent.actor_critic.reachability_policy

        aux_models = self.actor_critic.net.aux_models

        other_losses = dict({
            k: torch.zeros(num_envs, 1, device=device)
            for k in aux_models.keys()
        })
        other_losses_action = dict({
            k: torch.zeros(num_envs,
                           self.envs.action_spaces[0].n,
                           device=device)
            for k in aux_models.keys()
        })

        num_steps = torch.zeros(num_envs, 1, device=device)

        # Config aux models for eval per item in batch
        for k, maux in aux_models.items():
            maux.set_per_element_loss()

        total_loss = 0

        if config.EVAL_MODE:
            self.agent.eval()
            self.r_policy.eval()

        # get name of performance metric, e.g. "spl"
        metric_name = cfg.TASK_CONFIG.TASK.MEASUREMENTS[0]
        metric_cfg = getattr(cfg.TASK_CONFIG.TASK, metric_name)
        measure_type = baseline_registry.get_measure(metric_cfg.TYPE)
        assert measure_type is not None, "invalid measurement type {}".format(
            metric_cfg.TYPE)

        self.metric_uuid = measure_type(sim=None, task=None,
                                        config=None)._get_uuid()

        observations = self.envs.reset()
        batch = batch_obs_augment_aux(observations, self.envs.get_shared_mem())

        info_data_keys = ["discovered", "collisions_wall", "collisions_prox"]
        log_data_keys = [
            "current_episode_reward", "current_episode_go_reward"
        ] + info_data_keys
        log_data = dict({
            k: torch.zeros(num_envs, 1, device=device)
            for k in log_data_keys
        })
        info_data = dict({k: log_data[k] for k in info_data_keys})

        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            num_procs,
            ppo_cfg.hidden_size,
            device=device,
        )
        prev_actions = torch.zeros(num_procs,
                                   1,
                                   device=device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(num_procs, 1, device=device)

        stats_episodes = dict()  # dict of dicts that stores stats per episode
        stats_episodes_scenes = dict()  # number of stats collected per scene

        max_test_ep_count = cfg.TEST_EPISODE_COUNT

        # TODO this should depend on number of scenes :(
        # TODO But than envs shouldn't be paused but fast-fwd to next scene
        # TODO We consider num envs == num scenes
        max_ep_per_env = max_test_ep_count / float(num_envs)

        rgb_frames = [[] for _ in range(num_procs)
                      ]  # type: List[List[np.ndarray]]

        if len(cfg.VIDEO_OPTION) > 0:
            os.makedirs(cfg.VIDEO_DIR, exist_ok=True)

        video_log_int = cfg.VIDEO_OPTION_INTERVAL
        num_frames = 0

        plot_pos = -1
        prev_true_pos = []
        prev_pred_pos = []

        while (len(stats_episodes) <= cfg.TEST_EPISODE_COUNT and num_envs > 0):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                prev_hidden = test_recurrent_hidden_states
                _, actions, _, test_recurrent_hidden_states, aux_out \
                    = self.actor_critic.act(
                        batch,
                        test_recurrent_hidden_states,
                        prev_actions,
                        not_done_masks,
                        deterministic=False
                    )

                prev_actions.copy_(actions)

                if 'action' in batch:
                    prev_actions = batch['action'].unsqueeze(1).to(
                        actions.device).long()

                for k, v in aux_out.items():
                    loss = aux_models[k].calc_loss(v, batch, prev_hidden,
                                                   prev_actions,
                                                   not_done_masks, actions)
                    total_loss += loss

                    if other_losses[k] is None:
                        other_losses[k] = loss
                    else:
                        other_losses[k] += loss.unsqueeze(1)
                    if len(prev_actions) == 1:
                        other_losses_action[k][0, prev_actions.item()] += \
                            loss.item()

                # ==================================================================================
                # - Hacky logs

                if plot_pos >= 0:
                    prev_true_pos.append(batch["gps_compass_start"]
                                         [plot_pos].data[:2].cpu().numpy())
                    prev_pred_pos.append(aux_out["rel_start_pos_reg"]
                                         [plot_pos].data.cpu().numpy() * 15)
                    if num_frames % 10 == 0:
                        xx, yy = [], []
                        for x, y in prev_true_pos:
                            xx.append(x)
                            yy.append(y)
                        plt.scatter(xx, yy, label="true_pos")
                        xx, yy = [], []
                        for x, y in prev_pred_pos:
                            xx.append(x)
                            yy.append(y)
                        plt.scatter(xx, yy, label="pred_pos")
                        plt.legend()
                        plt.show()
                        plt.waitforbuttonpress()
                        plt.close()
                # ==================================================================================

            num_steps += 1
            outputs = self.envs.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]

            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=device,
            )

            map_values = self._get_mapping(observations, aux_out)
            batch = batch_obs_augment_aux(observations,
                                          self.envs.get_shared_mem(),
                                          device=device,
                                          map_values=map_values,
                                          masks=not_done_masks)

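            # fraction of the valid top-down map that has been discovered /
            # seen, computed from the TOP_DOWN_MAP measure in infos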
            valid_map_size = [
                float(ifs["top_down_map"]["valid_map"].sum()) for ifs in infos
            ]
            discovered_factor = [
                infos[ix]["top_down_map"]["explored_map"].sum() /
                valid_map_size[ix] for ix in range(len(infos))
            ]

            seen_factor = [
                infos[ix]["top_down_map"]["ful_fog_of_war_mask"].sum() /
                valid_map_size[ix] for ix in range(len(infos))
            ]

            rewards = torch.tensor(rewards, dtype=torch.float,
                                   device=device).unsqueeze(1)

            log_data["current_episode_reward"] += rewards

            # -- Add intrinsic Reward
            if self.only_intrinsic_reward:
                rewards.zero_()

            if self.r_enabled:
                ir_rewards = self._add_intrinsic_reward(
                    batch, actions, rewards, not_done_masks)
                log_data["current_episode_go_reward"] += ir_rewards

                rewards += ir_rewards

            # Log other info from infos dict
            for iii, info in enumerate(infos):
                for k_info, v_info in info_data.items():
                    v_info[iii] += info[k_info]

            next_episodes = self.envs.current_episodes()

            envs_to_pause = []
            n_envs = num_envs

            for i in range(n_envs):
                scene = next_episodes[i].scene_id

                if scene not in stats_episodes_scenes:
                    stats_episodes_scenes[scene] = 0

                if stats_episodes_scenes[scene] >= max_ep_per_env:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:
                    episode_stats = dict()
                    episode_stats[self.metric_uuid] = infos[i][
                        self.metric_uuid]
                    episode_stats["success"] = int(
                        infos[i][self.metric_uuid] > 0)

                    for kk, vv in log_data.items():
                        episode_stats[kk] = vv[i].item()
                        vv[i] = 0

                    episode_stats["map_discovered"] = discovered_factor[i]
                    episode_stats["map_seen"] = seen_factor[i]

                    for k, v in other_losses.items():
                        episode_stats[k] = v[i].item() / num_steps[i].item()
                        other_losses_action[k][i].fill_(0)
                        other_losses[k][i] = 0

                    num_steps[i] = 0

                    # use scene_id + episode_id as unique id for storing stats
                    stats_episodes[(current_episodes[i].scene_id,
                                    current_episodes[i].episode_id)] = episode_stats

                    print(f"Episode {len(stats_episodes)} stats:",
                          episode_stats)

                    stats_episodes_scenes[current_episodes[i].scene_id] += 1

                    if (len(cfg.VIDEO_OPTION) > 0
                            and checkpoint_index % video_log_int == 0):
                        generate_video(
                            video_option=cfg.VIDEO_OPTION,
                            video_dir=cfg.VIDEO_DIR,
                            images=rgb_frames[i],
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=checkpoint_index,
                            metric_name=self.metric_uuid,
                            metric_value=infos[i][self.metric_uuid],
                            tb_writer=writer,
                        )

                        rgb_frames[i] = []

                # episode continues
                elif len(cfg.VIDEO_OPTION) > 0:
                    for k, v in observations[i].items():
                        if isinstance(v, torch.Tensor):
                            observations[i][k] = v.cpu().numpy()
                    frame = observations_to_image(observations[i], infos[i])
                    rgb_frames[i].append(frame)

            # Pop done envs:
            if len(envs_to_pause) > 0:
                s_index = list(range(num_envs))
                for idx in reversed(envs_to_pause):
                    s_index.pop(idx)

                for k, v in other_losses.items():
                    other_losses[k] = other_losses[k][s_index]

                for k, v in log_data.items():
                    log_data[k] = log_data[k][s_index]

            (
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                None,
                prev_actions,
                batch,
                rgb_frames,
            )
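            # _pause_envs strips the paused env indices out of every per-env tensor
            # (hidden states, masks, prev_actions, batch, frames). The reward slot
            # is passed as None here since episode returns are tracked in log_data
            # instead.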

        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = sum(
                [v[stat_key] for v in stats_episodes.values()])
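        # Sum every metric across episodes first; the division by num_episodes
        # below turns these into the per-checkpoint averages that get logged to
        # TensorBoard.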
        num_episodes = len(stats_episodes)

        episodes_agg_stats = dict()
        for k, v in aggregated_stats.items():
            episodes_agg_stats[k] = v / num_episodes
            logger.info(f"Average episode {k}: {episodes_agg_stats[k]:.6f}")

        for k, v in episodes_agg_stats.items():
            writer.add_scalars(f"eval_{k}", {f"{split}_average {k}": v},
                               checkpoint_index)
            print(f"[{checkpoint_index}] average {k}", v)

        self.envs.close()
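    # ------------------------------------------------------------------
    # Hedged sketch (not part of the original example): the aggregation loop
    # above reduces to a per-key mean over all evaluated episodes. A minimal,
    # self-contained version, assuming the same {(scene_id, episode_id):
    # {metric: value}} layout for `stats_episodes` and at least one episode:
    @staticmethod
    def _average_episode_stats_sketch(stats_episodes):
        # Average every metric across episodes, mirroring aggregated_stats /
        # episodes_agg_stats as computed in _eval_checkpoint above.
        keys = next(iter(stats_episodes.values())).keys()
        num_episodes = len(stats_episodes)
        return {
            k: sum(ep[k] for ep in stats_episodes.values()) / num_episodes
            for k in keys
        }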
Example #7
    def _eval(self):

        start_time = time.time()

        if self.config.MANUAL_COMMANDS:
            init_time = None
            manual_step_start_time = None
            total_manual_time = 0.0

        checkpoint_index = int(
            re.findall(r'\d+', self.config.EVAL_CKPT_PATH_DIR)[-1])
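        # The checkpoint index is taken as the last run of digits in the checkpoint
        # path, e.g. '.../ckpt.25.pth' -> 25.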
        ckpt_dict = torch.load(self.config.EVAL_CKPT_PATH_DIR,
                               map_location="cpu")

        print(
            f'Number of steps of the ckpt: {ckpt_dict["extra_state"]["step"]}')

        config = self._setup_config(ckpt_dict)
        ppo_cfg = config.RL.PPO
        ans_cfg = config.RL.ANS

        self.mapper_rollouts = None
        self._setup_actor_critic_agent(ppo_cfg, ans_cfg)

        self.mapper_agent.load_state_dict(ckpt_dict["mapper_state_dict"])
        if self.local_agent is not None:
            self.local_agent.load_state_dict(ckpt_dict["local_state_dict"])
            self.local_actor_critic = self.local_agent.actor_critic
        else:
            self.local_actor_critic = self.ans_net.local_policy
        self.global_agent.load_state_dict(ckpt_dict["global_state_dict"])
        self.mapper = self.mapper_agent.mapper
        self.global_actor_critic = self.global_agent.actor_critic

        # Set models to evaluation
        self.mapper.eval()
        self.local_actor_critic.eval()
        self.global_actor_critic.eval()

        M = ans_cfg.overall_map_size
        V = ans_cfg.MAPPER.map_size
        s = ans_cfg.MAPPER.map_scale
        imH, imW = ans_cfg.image_scale_hw
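        # Shorthand from the ANS config: M is presumably the side of the allocentric
        # overall map in cells, V the egocentric mapper map size, s the map scale in
        # metres per cell, and (imH, imW) the input image resolution.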

        num_steps = self.config.T_EXP

        prev_action = torch.zeros(1, 1, device=self.device, dtype=torch.long)
        masks = torch.zeros(1, 1, device=self.device)

        try:
            self.sim = make_sim('PyRobot-v1',
                                config=self.config.TASK_CONFIG.PYROBOT)
        except (KeyboardInterrupt, SystemExit):
            sys.exit()

        pose = defaultdict()
        self.sim._robot.camera.set_tilt(math.radians(self.config.CAMERA_TILT),
                                        wait=True)
        print(
            f"\nStarting Camera State: {self.sim.get_agent_state()['camera']}")
        print(f"Starting Agent State: {self.sim.get_agent_state()['base']}")
        obs = [self.sim.reset()]

        if self.config.SAVE_OBS_IMGS:
            cv2.imwrite(f'obs/depth_dirty_s.jpg', obs[0]['depth'] * 255.0)

        obs[0]['depth'][..., 0] = self._correct_depth(obs, -1)

        if self.config.SAVE_OBS_IMGS:
            cv2.imwrite(f'obs/rgb_s.jpg', obs[0]['rgb'][:, :, ::-1])
            cv2.imwrite(f'obs/depth_s.jpg', obs[0]['depth'] * 255.0)

        starting_agent_state = self.sim.get_agent_state()
        locobot2relative = CoordProjection(starting_agent_state['base'])
        pose['base'] = locobot2relative(starting_agent_state['base'])

        print(f"Starting Agent Pose: {pose['base']}\n")
        batch = self._prepare_batch(obs, -1, device=self.device)
        if ans_cfg.MAPPER.use_sensor_positioning:
            batch['pose'] = pose['base'].to(self.device)
            batch['pose'][0][1:] = -batch['pose'][0][1:]
        prev_batch = batch

        num_envs = self.config.NUM_PROCESSES
        agent_poses_over_time = []
        for i in range(num_envs):
            agent_poses_over_time.append(
                torch.tensor([(M - 1) / 2, (M - 1) / 2, 0]))
        state_estimates = {
            "pose_estimates":
            torch.zeros(num_envs, 3).to(self.device),
            "map_states":
            torch.zeros(num_envs, 2, M, M).to(self.device),
            "recurrent_hidden_states":
            torch.zeros(1, num_envs,
                        ans_cfg.LOCAL_POLICY.hidden_size).to(self.device),
            "visited_states":
            torch.zeros(num_envs, 1, M, M).to(self.device),
        }
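        # Per-env state carried across steps: an (x, y, heading) pose estimate, a
        # 2-channel M x M map state, the local policy's recurrent hidden state and
        # a visited-cell mask. All start at zero for a fresh run.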
        ground_truth_states = {
            "visible_occupancy": torch.zeros(num_envs, 2, M,
                                             M).to(self.device),
            "pose": torch.zeros(num_envs, 3).to(self.device),
            "environment_layout": torch.zeros(num_envs, 2, M,
                                              M).to(self.device)
        }
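        # Since this example runs on a real robot through PyRobot, no ground-truth
        # layout is available; here these buffers stay zero and are only read by the
        # top-down visualisations further down.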

        # Reset ANS states
        self.ans_net.reset()

        # Frames for video creation
        rgb_frames = []
        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        step_start_time = time.time()

        for i in range(num_steps):
            print(
                f"\n\n---------------------------------------------------<<< STEP {i} >>>---------------------------------------------------"
            )
            ep_time = torch.zeros(num_envs, 1, device=self.device).fill_(i)

            (
                mapper_inputs,
                local_policy_inputs,
                global_policy_inputs,
                mapper_outputs,
                local_policy_outputs,
                global_policy_outputs,
                state_estimates,
                intrinsic_rewards,
            ) = self.ans_net.act(
                batch,
                prev_batch,
                state_estimates,
                ep_time,
                masks,
                deterministic=True,
            )
            if self.config.SAVE_MAP_IMGS:
                cv2.imwrite(
                    f'maps/test_map_{i - 1}.jpg',
                    self._round_map(state_estimates['map_states']) * 255)

            action = local_policy_outputs["actions"][0][0]

            distance2ggoal = torch.norm(
                mapper_outputs['curr_map_position'] -
                self.ans_net.states["curr_global_goals"],
                dim=1) * s
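            # Distance from the agent's current map cell to the global goal: map
            # cells scaled by s to give metres, compared against goal_success_radius
            # to decide whether the goal has been reached.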

            print(f"Distance to Global Goal: {distance2ggoal}")

            reached_flag = distance2ggoal < ans_cfg.goal_success_radius

            if self.config.MANUAL_COMMANDS:
                if init_time is None:
                    init_time = time.time() - start_time
                    total_manual_time = total_manual_time + init_time
                if manual_step_start_time is not None:
                    manual_step_time = time.time() - manual_step_start_time
                    total_manual_time = total_manual_time + manual_step_time
                action = torch.tensor(
                    int(input('Waiting input to start new action: ')))
                manual_step_start_time = time.time()

                if action.item() == 3:
                    reached_flag = True

            prev_action.copy_(action)

            if not reached_flag and action.item() != 3:
                print(f'Doing Env Step [{self.ACT_2_NAME[action.item()]}]...')
                action_command = self.ACT_2_COMMAND[action.item()]

                obs = self._do_action(action_command)

                if self.config.SAVE_OBS_IMGS:
                    cv2.imwrite(f'obs/depth_dirty_{i}.jpg',
                                obs[0]['depth'] * 255.0)

                # Correcting invalid depth pixels
                obs[0]['depth'][..., 0] = self._correct_depth(obs, i)

                if self.config.SAVE_OBS_IMGS:
                    cv2.imwrite(f'obs/rgb_{i}.jpg', obs[0]['rgb'][:, :, ::-1])
                    cv2.imwrite(f'obs/depth_{i}.jpg', obs[0]['depth'] * 255.0)

                agent_state = self.sim.get_agent_state()
                prev_batch = batch
                batch = self._prepare_batch(obs, i, device=self.device)

                pose = defaultdict()
                pose['base'] = locobot2relative(agent_state['base'])

                if ans_cfg.MAPPER.use_sensor_positioning:
                    batch['pose'] = pose['base'].to(self.device)
                    batch['pose'][0][1:] = -batch['pose'][0][1:]

                map_coords = convert_world2map(
                    batch['pose'], (M, M), ans_cfg.OCCUPANCY_ANTICIPATOR.
                    EGO_PROJECTION.map_scale).squeeze()
                map_coords = torch.cat(
                    (map_coords, batch['pose'][0][-1].reshape(1)))
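                # Convert the metric pose into grid coordinates on the M x M map and
                # append the heading, so the agent trajectory can be drawn on the
                # allocentric top-down visualisations.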
                if self.config.COORD_DEBUG:
                    print('COORDINATES CHECK')
                    print(
                        f'Starting Agent State: {starting_agent_state["base"]}'
                    )
                    print(f'Current Agent State: {agent_state["base"]}')
                    print(
                        f'Current Sim Agent State: {self.sim.get_agent_state()["base"]}'
                    )
                    print(f'Current Global Coords: {batch["pose"]}')
                    print(f'Current Map Coords: {map_coords}')
                agent_poses_over_time.append(map_coords)

                step_time = time.time() - step_start_time
                print(f"\nStep Time: {step_time}")
                step_start_time = time.time()

            # Create new frame of the video
            if (len(self.config.VIDEO_OPTION) > 0):
                frame = observations_to_image(
                    obs[0],
                    observation_size=300,
                    collision_flag=self.config.DRAW_COLLISIONS)
                # Add ego_map_gt to frame
                ego_map_gt_i = asnumpy(batch["ego_map_gt"][0])  # (2, H, W)
                ego_map_gt_i = convert_gt2channel_to_gtrgb(ego_map_gt_i)
                ego_map_gt_i = cv2.resize(ego_map_gt_i, (300, 300))
                # frame = np.concatenate([frame], axis=1)
                # Generate ANS specific visualizations
                environment_layout = asnumpy(
                    ground_truth_states["environment_layout"][0])  # (2, H, W)
                visible_occupancy = mapper_outputs["gt_mt"][0].cpu().numpy()  # (2, H, W)
                anticipated_occupancy = mapper_outputs["hat_mt"][0].cpu().numpy()  # (2, H, W)
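                # "gt_mt" feeds the visible-occupancy panel and "hat_mt" the
                # anticipated-occupancy panel (the map the model predicts beyond
                # what has been directly observed); both are rendered as
                # allocentric top-down maps below.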

                H = frame.shape[0]
                visible_occupancy_vis = generate_topdown_allocentric_map(
                    environment_layout,
                    visible_occupancy,
                    agent_poses_over_time,
                    thresh_explored=ans_cfg.thresh_explored,
                    thresh_obstacle=ans_cfg.thresh_obstacle,
                    zoom=False)
                visible_occupancy_vis = cv2.resize(visible_occupancy_vis,
                                                   (H, H))
                anticipated_occupancy_vis = generate_topdown_allocentric_map(
                    environment_layout,
                    anticipated_occupancy,
                    agent_poses_over_time,
                    thresh_explored=ans_cfg.thresh_explored,
                    thresh_obstacle=ans_cfg.thresh_obstacle,
                    zoom=False)
                anticipated_occupancy_vis = cv2.resize(
                    anticipated_occupancy_vis, (H, H))
                anticipated_action_map = generate_topdown_allocentric_map(
                    environment_layout,
                    anticipated_occupancy,
                    agent_poses_over_time,
                    zoom=False,
                    thresh_explored=ans_cfg.thresh_explored,
                    thresh_obstacle=ans_cfg.thresh_obstacle,
                )
                global_goals = self.ans_net.states["curr_global_goals"]
                local_goals = self.ans_net.states["curr_local_goals"]
                if global_goals is not None:
                    cX = int(global_goals[0, 0].item())
                    cY = int(global_goals[0, 1].item())
                    anticipated_action_map = cv2.circle(
                        anticipated_action_map,
                        (cX, cY),
                        10,
                        (255, 0, 0),
                        -1,
                    )
                if local_goals is not None:
                    cX = int(local_goals[0, 0].item())
                    cY = int(local_goals[0, 1].item())
                    anticipated_action_map = cv2.circle(
                        anticipated_action_map,
                        (cX, cY),
                        10,
                        (0, 255, 255),
                        -1,
                    )
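                # The current global and local exploration goals are overlaid on the
                # anticipated map as filled circles (radius 10 px) before resizing.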
                anticipated_action_map = cv2.resize(anticipated_action_map,
                                                    (H, H))

                maps_vis = np.concatenate(
                    [
                        visible_occupancy_vis, anticipated_occupancy_vis,
                        anticipated_action_map, ego_map_gt_i
                    ],
                    axis=1,
                )
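                # Four map panels laid out side by side: visible occupancy,
                # anticipated occupancy, the anticipated map with goal markers, and
                # the ground-truth ego map. The observation frame is stacked above
                # this row (after width padding, if needed) to form the video frame.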

                if self.config.RL.ANS.overall_map_size in (2001, 961):
                    if frame.shape[1] < maps_vis.shape[1]:
                        diff = maps_vis.shape[1] - frame.shape[1]
                        npad = ((0, 0), (diff // 2, diff // 2), (0, 0))
                        frame = np.pad(frame,
                                       pad_width=npad,
                                       mode='constant',
                                       constant_values=0)
                    elif frame.shape[1] > maps_vis.shape[1]:
                        diff = frame.shape[1] - maps_vis.shape[1]
                        npad = ((0, 0), (diff // 2, diff // 2), (0, 0))
                        maps_vis = np.pad(maps_vis,
                                          pad_width=npad,
                                          mode='constant',
                                          constant_values=0)
                frame = np.concatenate([frame, maps_vis], axis=0)
                rgb_frames.append(frame)
                if self.config.SAVE_VIDEO_IMGS:
                    os.makedirs("fig1", exist_ok=True)
                    print("Saved imgs for Fig. 1!")
                    cv2.imwrite(f'fig1/rgb_{step_start_time}.jpg',
                                obs[0]['rgb'][:, :, ::-1])
                    cv2.imwrite(f'fig1/depth_{step_start_time}.jpg',
                                obs[0]['depth'] * 255.0)
                    cv2.imwrite(f'fig1/aap_{step_start_time}.jpg',
                                anticipated_action_map[..., ::-1])
                    cv2.imwrite(f'fig1/em_{step_start_time}.jpg',
                                ego_map_gt_i[..., ::-1])
                if self.config.DEBUG_VIDEO_FRAME:
                    cv2.imwrite('last_frame.jpg', frame)

                if reached_flag:
                    for _ in range(20):
                        rgb_frames.append(frame)

                # Video creation
                video_dict = {"t": start_time}
                if (i + 1) % 10 == 0 or reached_flag:
                    generate_video(
                        video_option=self.config.VIDEO_OPTION,
                        video_dir=self.config.VIDEO_DIR,
                        images=rgb_frames,
                        episode_id=0,
                        checkpoint_idx=checkpoint_index,
                        metrics=video_dict,
                        tb_writer=TensorboardWriter('tb/locobot'),
                    )

            if reached_flag:
                if self.config.MANUAL_COMMANDS:
                    manual_step_time = time.time() - manual_step_start_time
                    total_manual_time = total_manual_time + manual_step_time
                    print(f"Manual elapsed time: {total_manual_time}")

                print(f"Number of steps: {i + 1}")
                print(f"Elapsed time: {time.time() - start_time}")
                print(f"Final Distance to Goal: {distance2ggoal}")
                if "bump" in obs[0]:
                    print(f"Collision: {obs[0]['bump']}")
                print("Exiting...")
                break
        return
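    # ------------------------------------------------------------------
    # Hedged sketch (not part of the original example): a minimal stand-in
    # illustrating the world-to-map conversion relied on above, assuming an
    # M x M allocentric grid centred on the start pose with `s` metres per
    # cell. The real convert_world2map in the codebase may use different
    # axis/sign conventions; treat this purely as an illustration.
    @staticmethod
    def _world2map_sketch(x_m, y_m, M, s):
        # (x_m, y_m): metric displacement from the starting position.
        # Returns integer (row, col) indices on the M x M map, with the start
        # pose mapped to the map centre ((M - 1) / 2, (M - 1) / 2).
        row = int(round((M - 1) / 2 + y_m / s))
        col = int(round((M - 1) / 2 + x_m / s))
        return row, col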