Code example #1
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of the current checkpoint, used for logging

        Returns:
            None
        """
        config = self.config

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = self.config.EVAL.SPLIT
        config.freeze()

        eqa_cnn_pretrain_dataset = EQACNNPretrainDataset(config, mode="val")

        eval_loader = DataLoader(
            eqa_cnn_pretrain_dataset,
            batch_size=config.IL.EQACNNPretrain.batch_size,
            shuffle=False,
        )

        logger.info("[ eval_loader has {} samples ]".format(
            len(eqa_cnn_pretrain_dataset)))

        model = MultitaskCNN()

        state_dict = torch.load(checkpoint_path)
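        # NOTE: torch.load without map_location assumes the checkpoint's
        # original device is available; map_location="cpu" (as used in the
        # later examples) avoids device mismatches on CPU-only machines.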
        model.load_state_dict(state_dict)

        model.to(self.device).eval()

        depth_loss = torch.nn.SmoothL1Loss()
        ae_loss = torch.nn.SmoothL1Loss()
        seg_loss = torch.nn.CrossEntropyLoss()

        t = 0
        avg_loss = 0.0
        avg_l1 = 0.0
        avg_l2 = 0.0
        avg_l3 = 0.0

        with torch.no_grad():
            for batch in eval_loader:
                t += 1

                idx, gt_rgb, gt_depth, gt_seg = batch
                gt_rgb = gt_rgb.to(self.device)
                gt_depth = gt_depth.to(self.device)
                gt_seg = gt_seg.to(self.device)

                pred_seg, pred_depth, pred_rgb = model(gt_rgb)
                l1 = seg_loss(pred_seg, gt_seg.long())
                l2 = ae_loss(pred_rgb, gt_rgb)
                l3 = depth_loss(pred_depth, gt_depth)

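                # Weighted sum of the three task losses; the 10x weights
                # up-weight reconstruction and depth relative to segmentation.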
                loss = l1 + (10 * l2) + (10 * l3)

                avg_loss += loss.item()
                avg_l1 += l1.item()
                avg_l2 += l2.item()
                avg_l3 += l3.item()

                if t % config.LOG_INTERVAL == 0:
                    logger.info(
                        "[ Iter: {}; loss: {:.3f} ]".format(t, loss.item())
                    )

                if (config.EVAL_SAVE_RESULTS
                        and t % config.EVAL_SAVE_RESULTS_INTERVAL == 0):

                    result_id = "ckpt_{}_{}".format(checkpoint_index,
                                                    idx[0].item())
                    result_path = os.path.join(self.config.RESULTS_DIR,
                                               result_id)

                    self._save_results(
                        gt_rgb,
                        pred_rgb,
                        gt_seg,
                        pred_seg,
                        gt_depth,
                        pred_depth,
                        result_path,
                    )

        avg_loss /= len(eval_loader)
        avg_l1 /= len(eval_loader)
        avg_l2 /= len(eval_loader)
        avg_l3 /= len(eval_loader)

        writer.add_scalar("avg val total loss", avg_loss, checkpoint_index)
        writer.add_scalars(
            "avg val individual_losses",
            {
                "seg_loss": avg_l1,
                "ae_loss": avg_l2,
                "depth_loss": avg_l3
            },
            checkpoint_index,
        )

        logger.info("[ Average loss: {:.3f} ]".format(avg_loss))
        logger.info("[ Average seg loss: {:.3f} ]".format(avg_l1))
        logger.info("[ Average autoencoder loss: {:.4f} ]".format(avg_l2))
        logger.info("[ Average depth loss: {:.4f} ]".format(avg_l3))
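
A minimal standalone sketch of the multitask loss weighting used above; the loss functions and 10x weights come from the example, while the output heads of MultitaskCNN are assumed rather than shown:

    import torch

    seg_loss_fn = torch.nn.CrossEntropyLoss()  # semantic segmentation head
    ae_loss_fn = torch.nn.SmoothL1Loss()       # RGB autoencoder head
    depth_loss_fn = torch.nn.SmoothL1Loss()    # depth regression head

    def multitask_loss(pred_seg, gt_seg, pred_rgb, gt_rgb, pred_depth, gt_depth):
        """Combine the three task losses as in the example above."""
        l_seg = seg_loss_fn(pred_seg, gt_seg.long())
        l_rgb = ae_loss_fn(pred_rgb, gt_rgb)
        l_depth = depth_loss_fn(pred_depth, gt_depth)
        # Reconstruction and depth are up-weighted 10x relative to segmentation.
        return l_seg + 10.0 * l_rgb + 10.0 * l_depth
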
Code example #2
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of the current checkpoint, used for logging

        Returns:
            None
        """
        # Mapping to CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO
        ans_cfg = config.RL.ANS

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.freeze()

        self.envs = construct_envs(config, get_env_class(config.ENV_NAME))
        self._setup_actor_critic_agent(ppo_cfg, ans_cfg)

        # Convert the state_dict of mapper_agent to mapper
        mapper_dict = {
            k.replace("mapper.", ""): v
            for k, v in ckpt_dict["mapper_state_dict"].items()
        }
        # Convert the state_dict of local_agent to just the local_policy.
        local_dict = {
            k.replace("actor_critic.", ""): v
            for k, v in ckpt_dict["local_state_dict"].items()
        }
        # strict=False is set to handle the case where the
        # pose_estimator is not required.
        self.mapper.load_state_dict(mapper_dict, strict=False)
        self.local_actor_critic.load_state_dict(local_dict)

        # Set models to evaluation
        self.mapper.eval()
        self.local_actor_critic.eval()

        number_of_eval_episodes = self.config.TEST_EPISODE_COUNT
        if number_of_eval_episodes == -1:
            number_of_eval_episodes = sum(self.envs.number_of_episodes)
        else:
            total_num_eps = sum(self.envs.number_of_episodes)
            if total_num_eps < number_of_eval_episodes:
                logger.warn(
                    f"Config specified {number_of_eval_episodes} eval episodes"
                    f", dataset only has {total_num_eps}."
                )
                logger.warn(f"Evaluating with {total_num_eps} instead.")
                number_of_eval_episodes = total_num_eps

        M = ans_cfg.overall_map_size
        V = ans_cfg.MAPPER.map_size
        s = ans_cfg.MAPPER.map_scale
        imH, imW = ans_cfg.image_scale_hw

        assert (
            self.envs.num_envs == 1
        ), "Number of environments needs to be 1 for evaluation"

        # Define metric accumulators
        # Navigation metrics
        navigation_metrics = {
            "success_rate": Metric(),
            "spl": Metric(),
            "distance_to_goal": Metric(),
            "time": Metric(),
            "softspl": Metric(),
        }
        per_difficulty_navigation_metrics = {
            "easy": {
                "success_rate": Metric(),
                "spl": Metric(),
                "distance_to_goal": Metric(),
                "time": Metric(),
                "softspl": Metric(),
            },
            "medium": {
                "success_rate": Metric(),
                "spl": Metric(),
                "distance_to_goal": Metric(),
                "time": Metric(),
                "softspl": Metric(),
            },
            "hard": {
                "success_rate": Metric(),
                "spl": Metric(),
                "distance_to_goal": Metric(),
                "time": Metric(),
                "softspl": Metric(),
            },
        }

        times_per_episode = deque()
        times_per_step = deque()
        # Define a simple function to return episode difficulty based on
        # the geodesic distance
        def classify_difficulty(gd):
            if gd < 5.0:
                return "easy"
            elif gd < 10.0:
                return "medium"
            else:
                return "hard"

        eval_start_time = time.time()
        # Reset environments only for the very first batch
        observations = self.envs.reset()
        for ep in range(number_of_eval_episodes):
            # ============================== Reset agent ==============================
            # Reset agent states
            state_estimates = {
                "pose_estimates": torch.zeros(self.envs.num_envs, 3).to(self.device),
                "map_states": torch.zeros(self.envs.num_envs, 2, M, M).to(self.device),
                "recurrent_hidden_states": torch.zeros(
                    1, self.envs.num_envs, ans_cfg.LOCAL_POLICY.hidden_size
                ).to(self.device),
            }
            # Reset ANS states
            self.ans_net.reset()
            self.not_done_masks = torch.zeros(self.envs.num_envs, 1, device=self.device)
            self.prev_actions = torch.zeros(self.envs.num_envs, 1, device=self.device)
            self.prev_batch = None
            self.ep_time = torch.zeros(self.envs.num_envs, 1, device=self.device)
            # =========================== Episode loop ================================
            ep_start_time = time.time()
            current_episodes = self.envs.current_episodes()
            for ep_step in range(self.config.T_MAX):
                step_start_time = time.time()
                # ============================ Action step ============================
                batch = self._prepare_batch(observations)
                if self.prev_batch is None:
                    self.prev_batch = copy.deepcopy(batch)

                prev_pose_estimates = state_estimates["pose_estimates"]
                with torch.no_grad():
                    (
                        _,
                        _,
                        mapper_outputs,
                        local_policy_outputs,
                        state_estimates,
                    ) = self.ans_net.act(
                        batch,
                        self.prev_batch,
                        state_estimates,
                        self.ep_time,
                        self.not_done_masks,
                        deterministic=ans_cfg.LOCAL_POLICY.deterministic_flag,
                    )
                    actions = local_policy_outputs["actions"]
                    # Make masks not done till reset (end of episode)
                    self.not_done_masks = torch.ones(
                        self.envs.num_envs, 1, device=self.device
                    )
                    self.prev_actions.copy_(actions)

                if ep_step == 0:
                    state_estimates["pose_estimates"].copy_(prev_pose_estimates)

                self.ep_time += 1
                # Update prev batch
                for k, v in batch.items():
                    self.prev_batch[k].copy_(v)

                # Remap actions from exploration to navigation agent.
                actions_rmp = self._remap_actions(actions)

                # =========================== Environment step ========================
                outputs = self.envs.step([a[0].item() for a in actions_rmp])

                observations, _, dones, infos = [list(x) for x in zip(*outputs)]

                times_per_step.append(time.time() - step_start_time)
                # ============================ Process metrics ========================
                if dones[0]:
                    times_per_episode.append(time.time() - ep_start_time)
                    mins_per_episode = np.mean(times_per_episode).item() / 60.0
                    eta_completion = mins_per_episode * (
                        number_of_eval_episodes - ep - 1
                    )
                    secs_per_step = np.mean(times_per_step).item()
                    for i in range(self.envs.num_envs):
                        episode_id = int(current_episodes[i].episode_id)
                        curr_metrics = {
                            "spl": infos[i]["spl"],
                            "softspl": infos[i]["softspl"],
                            "success_rate": infos[i]["success"],
                            "time": ep_step + 1,
                            "distance_to_goal": infos[i]["distance_to_goal"],
                        }
                        # Estimate difficulty of episode
                        episode_difficulty = classify_difficulty(
                            current_episodes[i].info["geodesic_distance"]
                        )
                        for k, v in curr_metrics.items():
                            navigation_metrics[k].update(v, 1.0)
                            per_difficulty_navigation_metrics[episode_difficulty][
                                k
                            ].update(v, 1.0)

                        logger.info(f"====> {ep}/{number_of_eval_episodes} done")
                        for k, v in curr_metrics.items():
                            logger.info(f"{k:25s} : {v:10.3f}")
                        logger.info("{:25s} : {:10d}".format("episode_id", episode_id))
                        logger.info(f"Time per episode: {mins_per_episode:.3f} mins")
                        logger.info(f"Time per step: {secs_per_step:.3f} secs")
                        logger.info(f"ETA: {eta_completion:.3f} mins")

                    # For navigation, terminate episode loop when dones is called
                    break
            # end of the per-step episode loop

        if checkpoint_index == 0:
            try:
                eval_ckpt_idx = self.config.EVAL_CKPT_PATH_DIR.split("/")[-1].split(
                    "."
                )[1]
                logger.add_filehandler(
                    f"{self.config.TENSORBOARD_DIR}/navigation_results_ckpt_final_{eval_ckpt_idx}.txt"
                )
            except Exception:
                logger.add_filehandler(
                    f"{self.config.TENSORBOARD_DIR}/navigation_results_ckpt_{checkpoint_index}.txt"
                )
        else:
            logger.add_filehandler(
                f"{self.config.TENSORBOARD_DIR}/navigation_results_ckpt_{checkpoint_index}.txt"
            )

        logger.info(
            f"======= Evaluating over {number_of_eval_episodes} episodes ============="
        )

        logger.info("=======> Navigation metrics")
        for k, v in navigation_metrics.items():
            logger.info(f"{k}: {v.get_metric():.3f}")
            writer.add_scalar(f"navigation/{k}", v.get_metric(), checkpoint_index)

        for diff, diff_metrics in per_difficulty_navigation_metrics.items():
            logger.info(f"=============== {diff:^10s} metrics ==============")
            for k, v in diff_metrics.items():
                logger.info(f"{k}: {v.get_metric():.3f}")
                writer.add_scalar(
                    f"{diff}_navigation/{k}", v.get_metric(), checkpoint_index
                )

        total_eval_time = (time.time() - eval_start_time) / 60.0
        logger.info(f"Total evaluation time: {total_eval_time:.3f} mins")
        self.envs.close()
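
The Metric accumulator used above is not shown in the snippet. A minimal implementation consistent with its update(value, weight) / get_metric() call pattern is a running weighted mean (a sketch; the project's own class may track more state):

    class Metric:
        """Running weighted mean, matching update()/get_metric() above."""

        def __init__(self):
            self.total = 0.0
            self.weight = 0.0

        def update(self, value, weight=1.0):
            self.total += value * weight
            self.weight += weight

        def get_metric(self):
            # Guard against division by zero before any update() call.
            return self.total / self.weight if self.weight > 0 else 0.0
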
Code example #3
File: dagger_trainer.py  Project: roy860328/VLN-CE
    def _eval_checkpoint(self,
                         checkpoint_path: str,
                         writer: TensorboardWriter,
                         checkpoint_index: int = 0) -> None:
        r"""Evaluates a single checkpoint. Assumes episode IDs are unique.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
                             checkpoint_index: index of the current checkpoint, used for logging

        Returns:
            None
        """
        logger.info(f"checkpoint_path: {checkpoint_path}")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(
                self.load_checkpoint(checkpoint_path,
                                     map_location="cpu")["config"])
        else:
            config = self.config.clone()

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.TASK_CONFIG.TASK.NDTW.SPLIT = config.EVAL.SPLIT
        config.TASK_CONFIG.TASK.SDTW.SPLIT = config.EVAL.SPLIT
        config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
        config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
        if len(config.VIDEO_OPTION) > 0:
            # config is still defrosted here, so no second defrost is needed.
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")

        config.freeze()

        # setup agent
        self.envs = construct_envs_auto_reset_false(
            config, get_env_class(config.ENV_NAME))
        self.device = (torch.device("cuda", config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))

        self._setup_actor_critic_agent(config.MODEL, True, checkpoint_path)

        observations = self.envs.reset()
        observations = transform_obs(
            observations, config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID)
        batch = batch_obs(observations, self.device)

        eval_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            config.NUM_PROCESSES,
            config.MODEL.STATE_ENCODER.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(config.NUM_PROCESSES,
                                     1,
                                     device=self.device)
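        # Masks are 0.0 at the start of an episode and 1.0 while it is in
        # progress; the policy uses them to reset its recurrent state.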

        stats_episodes = {}  # dict of dicts that stores stats per episode

        if len(config.VIDEO_OPTION) > 0:
            os.makedirs(config.VIDEO_DIR, exist_ok=True)
            rgb_frames = [[] for _ in range(config.NUM_PROCESSES)]

        self.actor_critic.eval()
        while (self.envs.num_envs > 0
               and len(stats_episodes) < config.EVAL.EPISODE_COUNT):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                (_, actions, _,
                 eval_recurrent_hidden_states) = self.actor_critic.act(
                     batch,
                     eval_recurrent_hidden_states,
                     prev_actions,
                     not_done_masks,
                     deterministic=True,
                 )
                # actions = batch["vln_oracle_action_sensor"].long()
                prev_actions.copy_(actions)

            outputs = self.envs.step([a[0].item() for a in actions])
            observations, _, dones, infos = [list(x) for x in zip(*outputs)]

            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            # reset envs and observations if necessary
            for i in range(self.envs.num_envs):
                if len(config.VIDEO_OPTION) > 0:
                    frame = observations_to_image(observations[i], infos[i])
                    frame = append_text_to_image(
                        frame,
                        current_episodes[i].instruction.instruction_text)
                    rgb_frames[i].append(frame)

                if not dones[i]:
                    continue

                stats_episodes[current_episodes[i].episode_id] = infos[i]
                observations[i] = self.envs.reset_at(i)[0]
                prev_actions[i] = torch.zeros(1, dtype=torch.long)

                if len(config.VIDEO_OPTION) > 0:
                    generate_video(
                        video_option=config.VIDEO_OPTION,
                        video_dir=config.VIDEO_DIR,
                        images=rgb_frames[i],
                        episode_id=current_episodes[i].episode_id,
                        checkpoint_idx=checkpoint_index,
                        metrics={
                            "SPL":
                            round(
                                stats_episodes[current_episodes[i].episode_id]
                                ["spl"], 6)
                        },
                        tb_writer=writer,
                    )

                    del stats_episodes[
                        current_episodes[i].episode_id]["top_down_map"]
                    del stats_episodes[
                        current_episodes[i].episode_id]["collisions"]
                    rgb_frames[i] = []

            observations = transform_obs(
                observations, config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID)
            batch = batch_obs(observations, self.device)

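            # The episode iterator wraps around, so any env whose next episode
            # has already been evaluated is paused (episode IDs assumed unique,
            # per the docstring).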
            envs_to_pause = []
            next_episodes = self.envs.current_episodes()

            for i in range(self.envs.num_envs):
                if next_episodes[i].episode_id in stats_episodes:
                    envs_to_pause.append(i)

            (
                self.envs,
                eval_recurrent_hidden_states,
                not_done_masks,
                prev_actions,
                batch,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                eval_recurrent_hidden_states,
                not_done_masks,
                prev_actions,
                batch,
            )

        self.envs.close()

        aggregated_stats = {}
        num_episodes = len(stats_episodes)
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = (
                sum([v[stat_key]
                     for v in stats_episodes.values()]) / num_episodes)

        split = config.TASK_CONFIG.DATASET.SPLIT
        with open(f"stats_ckpt_{checkpoint_index}_{split}.json", "w") as f:
            json.dump(aggregated_stats, f, indent=4)

        logger.info(f"Episodes evaluated: {num_episodes}")
        checkpoint_num = checkpoint_index + 1
        for k, v in aggregated_stats.items():
            logger.info(f"Average episode {k}: {v:.6f}")
            writer.add_scalar(f"eval_{split}_{k}", v, checkpoint_num)
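
The _pause_envs helper is assumed by the example rather than shown. A sketch of the usual pattern under that assumption (illustrative names; pause_at() is assumed to be a habitat VectorEnv-style call) drops the per-env rows of every recurrent tensor for environments that are finished:

    import torch

    def pause_envs(envs_to_pause, envs, recurrent_hidden_states,
                   not_done_masks, prev_actions, batch):
        if len(envs_to_pause) > 0:
            state_index = [i for i in range(envs.num_envs)
                           if i not in envs_to_pause]
            # Pause from the highest index down so remaining indices stay valid.
            for idx in reversed(envs_to_pause):
                envs.pause_at(idx)
            # Keep only the rows belonging to still-active environments.
            recurrent_hidden_states = recurrent_hidden_states[:, state_index]
            not_done_masks = not_done_masks[state_index]
            prev_actions = prev_actions[state_index]
            batch = {k: v[state_index] for k, v in batch.items()}
        return (envs, recurrent_hidden_states, not_done_masks, prev_actions,
                batch)
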
Code example #4
File: vqa_trainer.py  Project: jsdd25/habitat-lab
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of the current checkpoint, used for logging

        Returns:
            None
        """
        config = self.config

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = self.config.EVAL.SPLIT
        config.freeze()

        vqa_dataset = (
            EQADataset(
                config,
                input_type="vqa",
                num_frames=config.IL.VQA.num_frames,
            )
            .shuffle(1000)
            .to_tuple(
                "episode_id",
                "question",
                "answer",
                *["{0:0=3d}.jpg".format(x) for x in range(0, 5)],
            )
            .map(img_bytes_2_np_array)
        )

        eval_loader = DataLoader(
            vqa_dataset, batch_size=config.IL.VQA.batch_size
        )

        logger.info("eval_loader has {} samples".format(len(vqa_dataset)))

        q_vocab_dict, ans_vocab_dict = vqa_dataset.get_vocab_dicts()

        model_kwargs = {
            "q_vocab": q_vocab_dict.word2idx_dict,
            "ans_vocab": ans_vocab_dict.word2idx_dict,
            "eqa_cnn_pretrain_ckpt_path": config.EQA_CNN_PRETRAIN_CKPT_PATH,
        }
        model = VqaLstmCnnAttentionModel(**model_kwargs)

        state_dict = torch.load(
            checkpoint_path, map_location={"cuda:0": "cpu"}
        )
        model.load_state_dict(state_dict)

        lossFn = torch.nn.CrossEntropyLoss()

        t = 0

        avg_loss = 0.0
        avg_accuracy = 0.0
        avg_mean_rank = 0.0
        avg_mean_reciprocal_rank = 0.0

        model.eval()
        model.cnn.eval()
        model.to(self.device)

        metrics = VqaMetric(
            info={"split": "val"},
            metric_names=[
                "loss",
                "accuracy",
                "mean_rank",
                "mean_reciprocal_rank",
            ],
            log_json=os.path.join(config.OUTPUT_LOG_DIR, "eval.json"),
        )
        with torch.no_grad():
            for batch in eval_loader:
                t += 1
                episode_ids, questions, answers, frame_queue = batch
                questions = questions.to(self.device)
                answers = answers.to(self.device)
                frame_queue = frame_queue.to(self.device)

                scores, _ = model(frame_queue, questions)

                loss = lossFn(scores, answers)

                accuracy, ranks = metrics.compute_ranks(
                    scores.data.cpu(), answers
                )
                metrics.update([loss.item(), accuracy, ranks, 1.0 / ranks])

                (
                    metrics_loss,
                    accuracy,
                    mean_rank,
                    mean_reciprocal_rank,
                ) = metrics.get_stats(mode=0)

                avg_loss += metrics_loss
                avg_accuracy += accuracy
                avg_mean_rank += mean_rank
                avg_mean_reciprocal_rank += mean_reciprocal_rank

                if t % config.LOG_INTERVAL == 0:
                    logger.info(metrics.get_stat_string(mode=0))
                    metrics.dump_log()

                if (
                    config.EVAL_SAVE_RESULTS
                    and t % config.EVAL_SAVE_RESULTS_INTERVAL == 0
                ):

                    self._save_vqa_results(
                        checkpoint_index,
                        episode_ids,
                        questions,
                        frame_queue,
                        scores,
                        answers,
                        q_vocab_dict,
                        ans_vocab_dict,
                    )

        num_batches = math.ceil(len(vqa_dataset) / config.IL.VQA.batch_size)

        avg_loss /= num_batches
        avg_accuracy /= num_batches
        avg_mean_rank /= num_batches
        avg_mean_reciprocal_rank /= num_batches

        writer.add_scalar("avg val loss", avg_loss, checkpoint_index)
        writer.add_scalar("avg val accuracy", avg_accuracy, checkpoint_index)
        writer.add_scalar("avg val mean rank", avg_mean_rank, checkpoint_index)
        writer.add_scalar(
            "avg val mean reciprocal rank",
            avg_mean_reciprocal_rank,
            checkpoint_index,
        )

        logger.info("Average loss: {:.2f}".format(avg_loss))
        logger.info("Average accuracy: {:.2f}".format(avg_accuracy))
        logger.info("Average mean rank: {:.2f}".format(avg_mean_rank))
        logger.info(
            "Average mean reciprocal rank: {:.2f}".format(
                avg_mean_reciprocal_rank
            )
        )
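
VqaMetric.compute_ranks is not shown here; a sketch consistent with how it is called above, returning batch accuracy and the 1-based ranks of the ground-truth answers (an assumption about its behaviour, not the project's source), could be:

    import torch

    def compute_ranks(scores: torch.Tensor, answers: torch.Tensor):
        # scores: (batch, num_answers) logits; answers: (batch,) target indices.
        order = scores.argsort(dim=1, descending=True)
        # Position of the ground-truth answer in the sorted order, 1-based;
        # rank 1 means the top prediction was correct.
        ranks = (order == answers.unsqueeze(1)).nonzero()[:, 1].float() + 1.0
        accuracy = (ranks == 1).float().mean().item()
        return accuracy, ranks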