Example #1
File: base_trainer.py  Project: yilundu/crl
    def eval(self,
             eval_ckpt=None,
             log_diagnostics=[],
             output_dir=".",
             label="eval") -> None:
        r"""Main method of trainer evaluation. Calls _eval_checkpoint() that
        is specified in Trainer class that inherits from BaseRLTrainer

        Returns:
            None
        """
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))

        if "tensorboard" in self.config.VIDEO_OPTION:
            assert (len(self.config.TENSORBOARD_DIR) > 0
                    ), "Must specify a tensorboard directory for video display"
        if "disk" in self.config.VIDEO_OPTION:
            assert (len(self.config.VIDEO_DIR) >
                    0), "Must specify a directory for storing videos on disk"

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:
            if eval_ckpt is not None:  # evaluate a single checkpoint from path
                ckpt_index = os.path.split(eval_ckpt)[1].split(".")[-2]
                self._eval_checkpoint(eval_ckpt,
                                      writer,
                                      checkpoint_index=ckpt_index,
                                      log_diagnostics=log_diagnostics,
                                      output_dir=output_dir,
                                      label=label)
            else:
                if os.path.isfile(self.config.EVAL_CKPT_PATH_DIR):
                    # evaluate single checkpoint
                    # parse checkpoint from filename
                    ckpt_index = self.config.EVAL_CKPT_PATH_DIR.split('.')[-2]
                    self._eval_checkpoint(self.config.EVAL_CKPT_PATH_DIR,
                                          writer,
                                          checkpoint_index=ckpt_index)
                else:
                    # evaluate multiple checkpoints in order
                    prev_ckpt_ind = -1
                    while True:
                        current_ckpt = None
                        while current_ckpt is None:
                            current_ckpt = poll_checkpoint_folder(
                                self.config.EVAL_CKPT_PATH_DIR, prev_ckpt_ind)
                            time.sleep(2)  # sleep for 2 secs before polling again
                        logger.info(
                            f"=======current_ckpt: {current_ckpt}=======")
                        prev_ckpt_ind += 1
                        self._eval_checkpoint(
                            checkpoint_path=current_ckpt,
                            writer=writer,
                            checkpoint_index=prev_ckpt_ind,
                        )
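
Example #1 is the only variant above that derives checkpoint_index from the checkpoint path itself, via os.path.split(eval_ckpt)[1].split(".")[-2]. Below is a minimal sketch of that parsing step in isolation; the ckpt.<N>.pth naming scheme used in the demo call is an assumption, not something the example states.

import os

def parse_ckpt_index(ckpt_path):
    # Same expression as in Example #1: take the filename and keep the
    # second-to-last dot-separated component as the checkpoint index.
    return os.path.split(ckpt_path)[1].split(".")[-2]

# Hypothetical path, assuming checkpoints are named ckpt.<N>.pth.
print(parse_ckpt_index("/tmp/checkpoints/ckpt.42.pth"))  # prints "42"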
Example #2
    def eval(self) -> None:
        r"""
        Main method of evaluating PPO
        Returns:
            None
        """
        ppo_cfg = self.config.TRAINER.RL.PPO
        self.device = torch.device("cuda", ppo_cfg.pth_gpu_id)
        self.video_option = ppo_cfg.video_option.strip().split(",")

        if "tensorboard" in self.video_option:
            assert (ppo_cfg.tensorboard_dir is not None
                    ), "Must specify a tensorboard directory for video display"
        if "disk" in self.video_option:
            assert (ppo_cfg.video_dir is not None
                    ), "Must specify a directory for storing videos on disk"

        with get_tensorboard_writer(ppo_cfg.tensorboard_dir,
                                    purge_step=0,
                                    flush_secs=30) as writer:
            if os.path.isfile(ppo_cfg.eval_ckpt_path_or_dir):
                # evaluate single checkpoint
                self._eval_checkpoint(ppo_cfg.eval_ckpt_path_or_dir, writer)
            else:
                # evaluate multiple checkpoints in order
                prev_ckpt_ind = -1
                while True:
                    current_ckpt = None
                    while current_ckpt is None:
                        current_ckpt = poll_checkpoint_folder(
                            ppo_cfg.eval_ckpt_path_or_dir, prev_ckpt_ind)
                        time.sleep(2)  # sleep for 2 secs before polling again
                    logger.warning(
                        "=============current_ckpt: {}=============".format(
                            current_ckpt))
                    prev_ckpt_ind += 1
                    self._eval_checkpoint(
                        checkpoint_path=current_ckpt,
                        writer=writer,
                        cur_ckpt_idx=prev_ckpt_ind,
                    )
Example #3
    def eval(self) -> None:
        r"""Main method of trainer evaluation. Calls _eval_checkpoint() that
        is specified in Trainer class that inherits from BaseRLTrainer

        Returns:
            None
        """
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))

        if "tensorboard" in self.config.VIDEO_OPTION:
            assert (len(self.config.TENSORBOARD_DIR) > 0
                    ), "Must specify a tensorboard directory for video display"
            os.makedirs(self.config.TENSORBOARD_DIR, exist_ok=True)
        if "disk" in self.config.VIDEO_OPTION:
            assert (len(self.config.VIDEO_DIR) >
                    0), "Must specify a directory for storing videos on disk"

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:
            if os.path.isfile(self.config.EVAL_CKPT_PATH_DIR):
                # evaluate single checkpoint
                self._eval_checkpoint(self.config.EVAL_CKPT_PATH_DIR, writer)
            else:
                # evaluate multiple checkpoints in order
                prev_ckpt_ind = -1
                while True:
                    current_ckpt = None
                    while current_ckpt is None:
                        current_ckpt = poll_checkpoint_folder(
                            self.config.EVAL_CKPT_PATH_DIR, prev_ckpt_ind)
                        time.sleep(2)  # sleep for 2 secs before polling again
                    logger.info(f"=======current_ckpt: {current_ckpt}=======")
                    prev_ckpt_ind += 1
                    self._eval_checkpoint(
                        checkpoint_path=current_ckpt,
                        writer=writer,
                        checkpoint_index=prev_ckpt_ind,
                    )
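
Every eval() variant above delegates the actual work to _eval_checkpoint(), which, per the docstrings, must be supplied by the Trainer subclass that inherits from BaseRLTrainer. The stub below is only a sketch of that contract: it mirrors the keyword arguments used at the call sites in Example #3, and the import path and class name are assumptions, not code from these projects.

from base_trainer import BaseRLTrainer  # assumed import path, based on the file name in Example #1

class MyEvalTrainer(BaseRLTrainer):
    def _eval_checkpoint(self, checkpoint_path, writer, checkpoint_index=0):
        # Load the checkpoint, run the evaluation episodes, and log metrics
        # (and optionally videos) through the provided TensorboardWriter.
        raise NotImplementedError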
Example #4
    def eval_bruce(self):
        r"""Trainer evaluation method by Bruce. Stripped Tensorflow stuff

        Returns:
            device
            actor_critic
            batch
            not_done_masks
            test_recurrent_hidden_states
        """
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))

        actor_critic = None
        batch = None
        not_done_masks = None
        test_recurrent_hidden_states = None
        if os.path.isfile(self.config.EVAL_CKPT_PATH_DIR):
            # evaluate single checkpoint
            actor_critic, batch, not_done_masks, test_recurrent_hidden_states = self._eval_checkpoint_bruce(
                self.config.EVAL_CKPT_PATH_DIR)
        else:
            # evaluate multiple checkpoints in order
            prev_ckpt_ind = -1
            while True:
                current_ckpt = None
                while current_ckpt is None:
                    current_ckpt = poll_checkpoint_folder(
                        self.config.EVAL_CKPT_PATH_DIR, prev_ckpt_ind)
                    time.sleep(2)  # sleep for 2 secs before polling again
                logger.info(f"=======current_ckpt: {current_ckpt}=======")
                prev_ckpt_ind += 1
                actor_critic, batch, not_done_masks, test_recurrent_hidden_states = self._eval_checkpoint_bruce(
                    checkpoint_path=current_ckpt,
                    checkpoint_index=prev_ckpt_ind,
                )

        return actor_critic, batch, self.device, not_done_masks, test_recurrent_hidden_states
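
Each example waits for new checkpoints with poll_checkpoint_folder(), which is expected to return None until a checkpoint newer than prev_ckpt_ind appears, so the caller keeps sleeping and polling. The helper below is a hypothetical stand-in written only to match that calling convention; it is not the real implementation used by these projects.

import glob
import os
from typing import Optional

def poll_checkpoint_folder_sketch(checkpoint_dir: str,
                                  previous_index: int) -> Optional[str]:
    # List the *.pth files in the folder, sort them by modification time,
    # and return the (previous_index + 1)-th one if it already exists;
    # otherwise return None so the caller polls again later.
    checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, "*.pth")),
                         key=os.path.getmtime)
    if previous_index + 1 < len(checkpoints):
        return checkpoints[previous_index + 1]
    return None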