# Example #1
class TrainerSpec:
    """Base specification for an off-policy RL trainer.

    The trainer consumes transitions from a database server, runs gradient
    updates on the wrapped algorithm, periodically checkpoints to ``logdir``
    and pushes fresh weights back through the database for the samplers.
    """

    def __init__(
        self,
        algorithm: AlgorithmSpec,
        env_spec: EnvironmentSpec,
        db_server: DBSpec,
        logdir: str,
        num_workers: int = 1,
        batch_size: int = 64,
        min_num_transitions: int = int(1e4),
        online_update_period: int = 1,
        weights_sync_period: int = 1,
        save_period: int = 10,
        gc_period: int = 10,
        seed: int = 42,
        epoch_limit: int = None,
        monitoring_params: Dict = None,
        **kwargs,
    ):
        """
        Args:
            algorithm: RL algorithm to train.
            env_spec: environment specification.
            db_server: database used to exchange trajectories and weights.
            logdir: directory for tensorboard logs and checkpoints.
            num_workers: number of data-loading workers.
            batch_size: training batch size.
            min_num_transitions: minimum replay size before training starts
                (read by subclasses; only stored here).
            online_update_period: update period for (actor, critic); a single
                int is broadcast to both via ``utils.make_tuple``.
            weights_sync_period: epochs between weight pushes to the db.
            save_period: epochs between checkpoint saves.
            gc_period: epochs between explicit ``gc.collect()`` calls.
            seed: base seed for the per-epoch seeder.
            epoch_limit: optional hard limit on the number of epochs;
                ``None`` means run until externally stopped.
            monitoring_params: kwargs for ``wandb.init``; ``None`` disables
                wandb monitoring.
            **kwargs: must be empty (validated in ``_init``).
        """
        # algorithm & environment
        self.algorithm = algorithm
        self.env_spec = env_spec

        # logging
        self.logdir = logdir
        self._prepare_logger(logdir)
        self._seeder = Seeder(init_seed=seed)

        # updates & counters
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.epoch = 0
        self.update_step = 0
        self.num_updates = 0
        self._num_trajectories = 0
        self._num_transitions = 0

        # updates configuration
        # (actor_period, critic_period)
        self.actor_grad_period, self.critic_grad_period = \
            utils.make_tuple(online_update_period)

        # synchronization configuration
        self.db_server = db_server
        self.min_num_transitions = min_num_transitions
        self.save_period = save_period
        self.weights_sync_period = weights_sync_period

        self._gc_period = gc_period

        # replay machinery is created by subclasses, not here
        self.replay_buffer = None
        self.replay_sampler = None
        self.loader = None
        self._epoch_limit = epoch_limit

        #  special
        self.monitoring_params = monitoring_params
        self._prepare_seed()
        self._init(**kwargs)

    def _init(self, **kwargs):
        """Set up wandb monitoring; mirror configs/code dirs into the run dir.

        Disables wandb globally when no ``monitoring_params`` were given.
        """
        global WANDB_ENABLED
        assert len(kwargs) == 0
        if WANDB_ENABLED:
            if self.monitoring_params is not None:
                self.checkpoints_glob: List[str] = \
                    self.monitoring_params.pop("checkpoints_glob", [])

                wandb.init(**self.monitoring_params)

                logdir_src = Path(self.logdir)
                logdir_dst = Path(wandb.run.dir)

                # copytree requires a non-existing destination, so ensure the
                # parent exists, then remove any stale copy before copying
                configs_src = logdir_src.joinpath("configs")
                os.makedirs(f"{logdir_dst}/{configs_src.name}", exist_ok=True)
                shutil.rmtree(f"{logdir_dst}/{configs_src.name}")
                shutil.copytree(f"{str(configs_src.absolute())}",
                                f"{logdir_dst}/{configs_src.name}")

                code_src = logdir_src.joinpath("code")
                if code_src.exists():
                    os.makedirs(f"{logdir_dst}/{code_src.name}", exist_ok=True)
                    shutil.rmtree(f"{logdir_dst}/{code_src.name}")
                    shutil.copytree(f"{str(code_src.absolute())}",
                                    f"{logdir_dst}/{code_src.name}")
            else:
                WANDB_ENABLED = False
        self.wandb_mode = "trainer"

    def _prepare_logger(self, logdir):
        """Create a timestamped tensorboard writer under ``logdir``."""
        timestamp = utils.get_utcnow_time()
        logpath = f"{logdir}/trainer.{timestamp}"
        os.makedirs(logpath, exist_ok=True)
        self.logger = SummaryWriter(logpath)

    def _prepare_seed(self):
        """Draw the next seed from the seeder and seed all RNGs with it."""
        seed = self._seeder()[0]
        set_global_seed(seed)

    def _log_to_console(self, fps: float, updates_per_sample: float,
                        num_trajectories: int, num_transitions: int,
                        buffer_size: int, **kwargs):
        """Print one-line epoch statistics to stdout."""
        prefix = f"--- Epoch {self.epoch:09d}/{self._epoch_limit:09d}" \
            if self._epoch_limit is not None \
            else f"--- Epoch {self.epoch:09d}"
        metrics = [
            prefix,
            f"fps: {fps:7.1f}",
            f"updates per sample: {updates_per_sample:7.1f}",
            f"trajectories: {num_trajectories:09d}",
            f"transitions: {num_transitions:09d}",
            f"buffer size: {buffer_size:09d}",
        ]
        metrics = " | ".join(metrics)
        print(metrics)

    def _log_to_tensorboard(self, fps: float, updates_per_sample: float,
                            num_trajectories: int, num_transitions: int,
                            buffer_size: int, **kwargs):
        """Write per-epoch scalars to tensorboard."""
        self.logger.add_scalar("fps", fps, self.epoch)
        self.logger.add_scalar("updates_per_sample", updates_per_sample,
                               self.epoch)
        self.logger.add_scalar("num_trajectories", num_trajectories,
                               self.epoch)
        self.logger.add_scalar("num_transitions", num_transitions, self.epoch)
        self.logger.add_scalar("buffer_size", buffer_size, self.epoch)
        self.logger.flush()

    @staticmethod
    def _log_wandb_metrics(metrics: Dict,
                           step: int,
                           mode: str,
                           suffix: str = ""):
        """Log ``metrics`` to wandb, namespaced as ``{mode}/{key}{suffix}``."""
        metrics = {
            f"{mode}/{key}{suffix}": value
            for key, value in metrics.items()
        }
        step = None  # @TODO: fix, wandb issue
        wandb.log(metrics, step=step)

    def _log_to_wandb(self, *, step, suffix="", **metrics):
        """Log metrics to wandb when monitoring is enabled; no-op otherwise."""
        if WANDB_ENABLED:
            self._log_wandb_metrics(metrics,
                                    step=step,
                                    mode=self.wandb_mode,
                                    suffix=suffix)

    def _save_wandb(self):
        """Copy the first tensorboard events file into the wandb run dir."""
        if WANDB_ENABLED:
            logdir_src = Path(self.logdir)
            logdir_dst = Path(wandb.run.dir)

            events_src = list(logdir_src.glob("events.out.tfevents*"))
            if len(events_src) > 0:
                events_src = events_src[0]
                os.makedirs(f"{logdir_dst}/{logdir_src.name}", exist_ok=True)
                shutil.copy2(
                    f"{str(events_src.absolute())}",
                    f"{logdir_dst}/{logdir_src.name}/{events_src.name}")

    def _save_checkpoint(self):
        """Save a full algorithm checkpoint every ``save_period`` epochs."""
        if self.epoch % self.save_period == 0:
            checkpoint = self.algorithm.pack_checkpoint()
            checkpoint["epoch"] = self.epoch
            filename = utils.save_checkpoint(logdir=self.logdir,
                                             checkpoint=checkpoint,
                                             suffix=str(self.epoch))
            # bugfix: report the actual saved path instead of the
            # "(unknown)" placeholder that ignored ``filename``
            print(f"Checkpoint saved to: {filename}")

    def _update_sampler_weights(self):
        """Push optimizer-free weights to the db every sync period.

        Tensors are detached and moved to numpy so the payload is
        serializable and device-independent.
        """
        if self.epoch % self.weights_sync_period == 0:
            checkpoint = self.algorithm.pack_checkpoint(with_optimizer=False)
            for key in checkpoint:
                checkpoint[key] = {
                    k: v.detach().cpu().numpy()
                    for k, v in checkpoint[key].items()
                }

            self.db_server.put_checkpoint(checkpoint=checkpoint,
                                          epoch=self.epoch)

    def _update_target_weights(self, update_step) -> Dict:
        """Hook for target-network updates; subclasses may return metrics."""
        pass

    def _run_loader(self, loader: DataLoader) -> Dict:
        """Train on every batch of ``loader``; return timing metrics.

        Actor/critic updates are gated by their respective grad periods.
        """
        start_time = time.time()

        # @TODO: add average meters
        for batch in loader:
            metrics: Dict = self.algorithm.train(
                batch,
                actor_update=(self.update_step % self.actor_grad_period == 0),
                critic_update=(self.update_step % self.critic_grad_period
                               == 0)) or {}
            self.update_step += 1

            metrics_ = self._update_target_weights(self.update_step) or {}
            metrics.update(**metrics_)

            # keep only scalar metrics for logging
            metrics = dict((key, value) for key, value in metrics.items()
                           if isinstance(value, (float, int)))

            for key, value in metrics.items():
                self.logger.add_scalar(key, value, self.update_step)
            self._log_to_wandb(step=self.update_step,
                               suffix="_batch",
                               **metrics)

        elapsed_time = time.time() - start_time
        # number of sampled transitions processed this pass
        elapsed_num_updates = len(loader) * loader.batch_size
        self.num_updates += elapsed_num_updates
        fps = elapsed_num_updates / elapsed_time

        output = {"elapsed_time": elapsed_time, "fps": fps}

        return output

    def _run_epoch(self) -> Dict:
        """Run one training epoch; must be implemented by subclasses."""
        raise NotImplementedError()

    def _run_epoch_loop(self):
        """Run one epoch plus all bookkeeping (logging, saving, syncing)."""
        self._prepare_seed()
        metrics: Dict = self._run_epoch()
        self.epoch += 1
        self._log_to_console(**metrics)
        self._log_to_tensorboard(**metrics)
        self._log_to_wandb(step=self.epoch, suffix="_epoch", **metrics)
        self._save_checkpoint()
        self._save_wandb()
        self._update_sampler_weights()
        if self.epoch % self._gc_period == 0:
            gc.collect()

    def _run_train_stage(self):
        """Run epochs until the epoch limit; signal samplers via the db.

        On any failure, samplers are told to stop before re-raising.
        """
        self.db_server.push_message(self.db_server.Message.ENABLE_TRAINING)
        epoch_limit = self._epoch_limit or np.iinfo(np.int32).max
        while self.epoch < epoch_limit:
            try:
                self._run_epoch_loop()
            except Exception:
                self.db_server.push_message(
                    self.db_server.Message.DISABLE_TRAINING)
                # bare raise preserves the original traceback
                raise
        self.db_server.push_message(self.db_server.Message.DISABLE_TRAINING)

    def _start_train_loop(self):
        """Entry hook for the training loop (overridable by subclasses)."""
        self._run_train_stage()

    def run(self):
        """Sync initial weights, run the training loop, close the logger."""
        self._update_sampler_weights()
        self._start_train_loop()
        self.logger.close()
class Sampler:
    """Trajectory sampler: rolls out an agent in an environment and stores
    the resulting trajectories into a database server.

    Sampling/training are coordinated through two shared boolean flags
    (``mp.Value``) that a background db-listener thread can toggle.
    """

    def __init__(
            self,
            agent: Union[ActorSpec, CriticSpec],
            env: EnvironmentSpec,
            db_server: DBSpec = None,
            exploration_handler: ExplorationHandler = None,
            logdir: str = None,
            id: int = 0,
            mode: str = "infer",  # train/valid/infer
            deterministic: bool = None,
            weights_sync_period: int = 1,
            weights_sync_mode: str = None,
            sampler_seed: int = 42,
            trajectory_seeds: List = None,
            trajectory_limit: int = None,
            force_store: bool = False,
            gc_period: int = 10,
            monitoring_params: Dict = None,
            **kwargs):
        """
        Args:
            agent: actor or critic used to act in the environment.
            env: environment specification.
            db_server: optional db for trajectories/weights; ``None`` means
                standalone sampling with no storage or weight sync.
            exploration_handler: provides per-trajectory exploration
                strategies; ``None`` disables exploration.
            logdir: base directory for tensorboard logs; ``None`` disables.
            id: sampler index, used in the log path.
            mode: one of "train"/"valid"/"infer".
            deterministic: force (non-)deterministic acting; defaults to
                deterministic for "valid"/"infer" modes.
            weights_sync_period: trajectories between weight reloads.
            weights_sync_mode: key prefix for the state dict to load
                (e.g. selects "{mode}_state_dict" from the checkpoint).
            sampler_seed: base seed for the seeder.
            trajectory_seeds: optional fixed seed list, cycled per trajectory.
            trajectory_limit: stop after this many trajectories.
            force_store: store trajectories even in deterministic modes.
            gc_period: trajectories between explicit ``gc.collect()`` calls.
            monitoring_params: kwargs for ``wandb.init``; ``None`` disables.
            **kwargs: must be empty (validated in ``_init``).
        """
        self._device = utils.get_device()
        self._sampler_id = id

        # valid/infer modes act deterministically unless explicitly overridden
        self._deterministic = deterministic \
            if deterministic is not None \
            else mode in ["valid", "infer"]
        self.trajectory_seeds = trajectory_seeds
        self._seeder = Seeder(init_seed=sampler_seed)

        # logging
        self._prepare_logger(logdir, mode)
        # shared flags, toggled by the db-listener thread
        self._sampling_flag = mp.Value(c_bool, False)
        self._training_flag = mp.Value(c_bool, True)

        # environment, model, exploration & action handlers
        self.env = env
        self.agent = agent
        self.exploration_handler = exploration_handler
        self.trajectory_index = 0
        self.trajectory_sampler = TrajectorySampler(
            env=self.env,
            agent=self.agent,
            device=self._device,
            deterministic=self._deterministic,
            sampling_flag=self._sampling_flag)

        # synchronization configuration
        self.db_server = db_server
        self._weights_sync_period = weights_sync_period
        self._weights_sync_mode = weights_sync_mode
        self._trajectory_limit = trajectory_limit or np.iinfo(np.int32).max
        self._force_store = force_store
        self._gc_period = gc_period
        self._db_loop_thread = None
        self.checkpoint = None

        #  special
        self.monitoring_params = monitoring_params
        self._init(**kwargs)

    def _init(self, **kwargs):
        """Set up wandb monitoring; disable it globally when unconfigured."""
        global WANDB_ENABLED
        assert len(kwargs) == 0
        if WANDB_ENABLED:
            if self.monitoring_params is not None:
                self.checkpoints_glob: List[str] = \
                    self.monitoring_params.pop(
                        "checkpoints_glob", ["best.pth", "last.pth"])

                wandb.init(**self.monitoring_params)
            else:
                WANDB_ENABLED = False
        self.wandb_mode = "sampler"

    def _prepare_logger(self, logdir, mode):
        """Create a timestamped tensorboard writer, or disable logging."""
        if logdir is not None:
            timestamp = utils.get_utcnow_time()
            logpath = f"{logdir}/" \
                f"sampler.{mode}.{self._sampler_id}.{timestamp}"
            os.makedirs(logpath, exist_ok=True)
            self.logdir = logpath
            self.logger = SummaryWriter(logpath)
        else:
            self.logdir = None
            self.logger = None

    def _start_db_loop(self):
        """Start the db-listener thread that toggles the shared flags.

        Without a db server, both flags are simply forced on.
        """
        if self.db_server is None:
            self._training_flag.value = True
            self._sampling_flag.value = True
            return
        self._db_loop_thread = threading.Thread(target=_db2sampler_loop,
                                                kwargs={
                                                    "sampler": self,
                                                })
        self._db_loop_thread.start()

    def load_checkpoint(self,
                        *,
                        filepath: str = None,
                        db_server: DBSpec = None):
        """Load agent weights from a file or (blocking) from the db server.

        Raises:
            NotImplementedError: if neither source is provided.
        """
        if filepath is not None:
            checkpoint = utils.load_checkpoint(filepath)
        elif db_server is not None:
            checkpoint = db_server.get_checkpoint()
            # poll until the trainer has published a checkpoint
            while checkpoint is None:
                time.sleep(3.0)
                checkpoint = db_server.get_checkpoint()
        else:
            raise NotImplementedError("No checkpoint found")

        self.checkpoint = checkpoint
        weights = self.checkpoint[f"{self._weights_sync_mode}_state_dict"]
        weights = {
            k: utils.any2device(v, device=self._device)
            for k, v in weights.items()
        }
        self.agent.load_state_dict(weights)
        self.agent.to(self._device)
        self.agent.eval()

    def _store_trajectory(self, trajectory, raw=False):
        """Push a trajectory to the db server; no-op without one."""
        if self.db_server is None:
            return
        self.db_server.put_trajectory(trajectory, raw=raw)

    def _get_seed(self):
        """Pick the next seed (cycled from ``trajectory_seeds`` if given),
        seed all RNGs with it, and return it."""
        if self.trajectory_seeds is not None:
            seed = self.trajectory_seeds[self.trajectory_index %
                                         len(self.trajectory_seeds)]
        else:
            seed = self._seeder()[0]
        set_global_seed(seed)
        return seed

    def _log_to_console(self, *, reward, raw_reward, num_steps, elapsed_time,
                        seed):
        """Print one-line per-trajectory statistics to stdout."""
        metrics = [
            f"trajectory {int(self.trajectory_index):05d}",
            f"steps: {int(num_steps):05d}",
            f"reward: {reward:9.3f}",
            f"raw_reward: {raw_reward:9.3f}",
            f"time: {elapsed_time:9.3f}",
            f"seed: {seed:010d}",
        ]
        metrics = " | ".join(metrics)
        print(f"--- {metrics}")

    def _log_to_tensorboard(self, *, reward, raw_reward, num_steps,
                            elapsed_time, **kwargs):
        """Write per-trajectory scalars to tensorboard, if enabled."""
        if self.logger is not None:
            self.logger.add_scalar("trajectory/num_steps", num_steps,
                                   self.trajectory_index)
            self.logger.add_scalar("trajectory/reward", reward,
                                   self.trajectory_index)
            self.logger.add_scalar("trajectory/raw_reward", raw_reward,
                                   self.trajectory_index)
            self.logger.add_scalar("time/trajectories_per_minute",
                                   60. / elapsed_time, self.trajectory_index)
            self.logger.add_scalar("time/steps_per_second",
                                   num_steps / elapsed_time,
                                   self.trajectory_index)
            self.logger.add_scalar("time/trajectory_time_sec", elapsed_time,
                                   self.trajectory_index)
            self.logger.add_scalar("time/step_time_sec",
                                   elapsed_time / num_steps,
                                   self.trajectory_index)

            self.logger.flush()

    @staticmethod
    def _log_wandb_metrics(metrics: Dict,
                           step: int,
                           mode: str,
                           suffix: str = ""):
        """Log ``metrics`` to wandb, namespaced as ``{mode}/{key}{suffix}``."""
        metrics = {
            f"{mode}/{key}{suffix}": value
            for key, value in metrics.items()
        }
        step = None  # @TODO: fix, wandb issue
        wandb.log(metrics, step=step)

    def _log_to_wandb(self, *, step, suffix="", **metrics):
        """Log metrics to wandb when monitoring is enabled; no-op otherwise."""
        if WANDB_ENABLED:
            self._log_wandb_metrics(metrics,
                                    step=step,
                                    mode=self.wandb_mode,
                                    suffix=suffix)

    def _save_wandb(self):
        """Copy the first tensorboard events file into the wandb run dir."""
        if WANDB_ENABLED:
            logdir_src = Path(self.logdir)
            logdir_dst = Path(wandb.run.dir)

            events_src = list(logdir_src.glob("events.out.tfevents*"))
            if len(events_src) > 0:
                events_src = events_src[0]
                os.makedirs(f"{logdir_dst}/{logdir_src.name}", exist_ok=True)
                shutil.copy2(
                    f"{str(events_src.absolute())}",
                    f"{logdir_dst}/{logdir_src.name}/{events_src.name}")

    @torch.no_grad()
    def _run_trajectory_loop(self):
        """Sample one trajectory; return ``(trajectory, info)`` where info
        includes elapsed time and the seed used."""
        seed = self._get_seed()
        exploration_strategy = \
            self.exploration_handler.get_exploration_strategy() \
            if self.exploration_handler is not None \
            else None
        self.trajectory_sampler.reset(exploration_strategy)

        start_time = time.time()
        trajectory, trajectory_info = self.trajectory_sampler.sample(
            exploration_strategy=exploration_strategy)
        elapsed_time = time.time() - start_time

        trajectory_info = trajectory_info or {}
        trajectory_info.update({"elapsed_time": elapsed_time, "seed": seed})
        return trajectory, trajectory_info

    def _run_sample_loop(self):
        """Main loop: sample, store and log trajectories until training
        stops or the trajectory limit is reached.

        Blocks (polling every 5s) while the sampling flag is off.
        """
        while self._training_flag.value:
            while not self._sampling_flag.value:
                if not self._training_flag.value:
                    return
                time.sleep(5.0)

            # 1 - load from db, 2 - resume load trick (already have checkpoint)
            need_checkpoint = \
                self.db_server is not None or self.checkpoint is None
            if self.trajectory_index % self._weights_sync_period == 0 \
                    and need_checkpoint:
                self.load_checkpoint(db_server=self.db_server)
                self._save_wandb()

            trajectory, trajectory_info = self._run_trajectory_loop()
            if trajectory is None:
                continue
            raw_trajectory = trajectory_info.pop("raw_trajectory", None)
            # store first, so a logger failure cannot lose the trajectory
            if not self._deterministic or self._force_store:
                self._store_trajectory(trajectory)
                if raw_trajectory is not None:
                    self._store_trajectory(raw_trajectory, raw=True)
            self._log_to_console(**trajectory_info)
            self._log_to_tensorboard(**trajectory_info)
            self._log_to_wandb(step=self.trajectory_index, **trajectory_info)
            self.trajectory_index += 1

            if self.trajectory_index % self._gc_period == 0:
                gc.collect()

            if not self._training_flag.value \
                    or self.trajectory_index >= self._trajectory_limit:
                return

    def _start_sample_loop(self):
        """Entry hook for the sample loop (overridable by subclasses)."""
        self._run_sample_loop()

    def run(self):
        """Start the db-listener thread, run the sample loop, close logs."""
        self._start_db_loop()
        self._start_sample_loop()

        if self.logger is not None:
            self.logger.close()