class TrainerSpec:
    """Base trainer loop: consumes batches of transitions from the replay
    loader, updates the algorithm, and periodically publishes fresh weights
    to the DB server for the samplers.
    """
    def __init__(
        self,
        algorithm: AlgorithmSpec,
        env_spec: EnvironmentSpec,
        db_server: DBSpec,
        logdir: str,
        num_workers: int = 1,
        batch_size: int = 64,
        min_num_transitions: int = int(1e4),
        online_update_period: int = 1,
        weights_sync_period: int = 1,
        save_period: int = 10,
        gc_period: int = 10,
        seed: int = 42,
        epoch_limit: int = None,
        monitoring_params: Dict = None,
        **kwargs,
    ):
        # algorithm & environment
        self.algorithm = algorithm
        self.env_spec = env_spec

        # logging
        self.logdir = logdir
        self._prepare_logger(logdir)
        self._seeder = Seeder(init_seed=seed)

        # updates & counters
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.epoch = 0
        self.update_step = 0
        self.num_updates = 0
        self._num_trajectories = 0
        self._num_transitions = 0

        # updates configuration
        # (actor_period, critic_period)
        self.actor_grad_period, self.critic_grad_period = \
            utils.make_tuple(online_update_period)

        # synchronization configuration
        self.db_server = db_server
        self.min_num_transitions = min_num_transitions
        self.save_period = save_period
        self.weights_sync_period = weights_sync_period
        self._gc_period = gc_period

        self.replay_buffer = None
        self.replay_sampler = None
        self.loader = None
        self._epoch_limit = epoch_limit

        # special
        self.monitoring_params = monitoring_params

        self._prepare_seed()
        self._init(**kwargs)

    def _init(self, **kwargs):
        global WANDB_ENABLED
        assert len(kwargs) == 0
        if WANDB_ENABLED:
            if self.monitoring_params is not None:
                self.checkpoints_glob: List[str] = \
                    self.monitoring_params.pop("checkpoints_glob", [])
                wandb.init(**self.monitoring_params)

                logdir_src = Path(self.logdir)
                logdir_dst = Path(wandb.run.dir)

                configs_src = logdir_src.joinpath("configs")
                os.makedirs(f"{logdir_dst}/{configs_src.name}", exist_ok=True)
                shutil.rmtree(f"{logdir_dst}/{configs_src.name}")
                shutil.copytree(
                    f"{str(configs_src.absolute())}",
                    f"{logdir_dst}/{configs_src.name}")

                code_src = logdir_src.joinpath("code")
                if code_src.exists():
                    os.makedirs(f"{logdir_dst}/{code_src.name}",
                                exist_ok=True)
                    shutil.rmtree(f"{logdir_dst}/{code_src.name}")
                    shutil.copytree(
                        f"{str(code_src.absolute())}",
                        f"{logdir_dst}/{code_src.name}")
            else:
                WANDB_ENABLED = False
        self.wandb_mode = "trainer"

    def _prepare_logger(self, logdir):
        timestamp = utils.get_utcnow_time()
        logpath = f"{logdir}/trainer.{timestamp}"
        os.makedirs(logpath, exist_ok=True)
        self.logger = SummaryWriter(logpath)

    def _prepare_seed(self):
        seed = self._seeder()[0]
        set_global_seed(seed)

    def _log_to_console(
        self,
        fps: float,
        updates_per_sample: float,
        num_trajectories: int,
        num_transitions: int,
        buffer_size: int,
        **kwargs
    ):
        prefix = f"--- Epoch {self.epoch:09d}/{self._epoch_limit:09d}" \
            if self._epoch_limit is not None \
            else f"--- Epoch {self.epoch:09d}"
        metrics = [
            prefix,
            f"fps: {fps:7.1f}",
            f"updates per sample: {updates_per_sample:7.1f}",
            f"trajectories: {num_trajectories:09d}",
            f"transitions: {num_transitions:09d}",
            f"buffer size: {buffer_size:09d}",
        ]
        metrics = " | ".join(metrics)
        print(metrics)

    def _log_to_tensorboard(
        self,
        fps: float,
        updates_per_sample: float,
        num_trajectories: int,
        num_transitions: int,
        buffer_size: int,
        **kwargs
    ):
        self.logger.add_scalar("fps", fps, self.epoch)
        self.logger.add_scalar(
            "updates_per_sample", updates_per_sample, self.epoch)
        self.logger.add_scalar(
            "num_trajectories", num_trajectories, self.epoch)
        self.logger.add_scalar(
            "num_transitions", num_transitions, self.epoch)
        self.logger.add_scalar("buffer_size", buffer_size, self.epoch)
        self.logger.flush()
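    # NOTE: wandb metrics are namespaced as "{mode}/{key}{suffix}",
    # e.g. "trainer/fps_epoch". The explicit step is currently dropped
    # (see the @TODO below), so wandb falls back to its internal step
    # counter.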
""): metrics = { f"{mode}/{key}{suffix}": value for key, value in metrics.items() } step = None # @TODO: fix, wandb issue wandb.log(metrics, step=step) def _log_to_wandb(self, *, step, suffix="", **metrics): if WANDB_ENABLED: self._log_wandb_metrics(metrics, step=step, mode=self.wandb_mode, suffix=suffix) def _save_wandb(self): if WANDB_ENABLED: logdir_src = Path(self.logdir) logdir_dst = Path(wandb.run.dir) events_src = list(logdir_src.glob("events.out.tfevents*")) if len(events_src) > 0: events_src = events_src[0] os.makedirs(f"{logdir_dst}/{logdir_src.name}", exist_ok=True) shutil.copy2( f"{str(events_src.absolute())}", f"{logdir_dst}/{logdir_src.name}/{events_src.name}") def _save_checkpoint(self): if self.epoch % self.save_period == 0: checkpoint = self.algorithm.pack_checkpoint() checkpoint["epoch"] = self.epoch filename = utils.save_checkpoint(logdir=self.logdir, checkpoint=checkpoint, suffix=str(self.epoch)) print(f"Checkpoint saved to: {filename}") def _update_sampler_weights(self): if self.epoch % self.weights_sync_period == 0: checkpoint = self.algorithm.pack_checkpoint(with_optimizer=False) for key in checkpoint: checkpoint[key] = { k: v.detach().cpu().numpy() for k, v in checkpoint[key].items() } self.db_server.put_checkpoint(checkpoint=checkpoint, epoch=self.epoch) def _update_target_weights(self, update_step) -> Dict: pass def _run_loader(self, loader: DataLoader) -> Dict: start_time = time.time() # @TODO: add average meters for batch in loader: metrics: Dict = self.algorithm.train( batch, actor_update=(self.update_step % self.actor_grad_period == 0), critic_update=(self.update_step % self.critic_grad_period == 0)) or {} self.update_step += 1 metrics_ = self._update_target_weights(self.update_step) or {} metrics.update(**metrics_) metrics = dict((key, value) for key, value in metrics.items() if isinstance(value, (float, int))) for key, value in metrics.items(): self.logger.add_scalar(key, value, self.update_step) self._log_to_wandb(step=self.update_step, suffix="_batch", **metrics) elapsed_time = time.time() - start_time elapsed_num_updates = len(loader) * loader.batch_size self.num_updates += elapsed_num_updates fps = elapsed_num_updates / elapsed_time output = {"elapsed_time": elapsed_time, "fps": fps} return output def _run_epoch(self) -> Dict: raise NotImplementedError() def _run_epoch_loop(self): self._prepare_seed() metrics: Dict = self._run_epoch() self.epoch += 1 self._log_to_console(**metrics) self._log_to_tensorboard(**metrics) self._log_to_wandb(step=self.epoch, suffix="_epoch", **metrics) self._save_checkpoint() self._save_wandb() self._update_sampler_weights() if self.epoch % self._gc_period == 0: gc.collect() def _run_train_stage(self): self.db_server.push_message(self.db_server.Message.ENABLE_TRAINING) epoch_limit = self._epoch_limit or np.iinfo(np.int32).max while self.epoch < epoch_limit: try: self._run_epoch_loop() except Exception as ex: self.db_server.push_message( self.db_server.Message.DISABLE_TRAINING) raise ex self.db_server.push_message(self.db_server.Message.DISABLE_TRAINING) def _start_train_loop(self): self._run_train_stage() def run(self): self._update_sampler_weights() self._start_train_loop() self.logger.close()
class Sampler:
    """Base sampler loop: rolls out trajectories with the current agent,
    stores them in the DB server, and periodically syncs agent weights
    from the trainer's checkpoints.
    """
    def __init__(
        self,
        agent: Union[ActorSpec, CriticSpec],
        env: EnvironmentSpec,
        db_server: DBSpec = None,
        exploration_handler: ExplorationHandler = None,
        logdir: str = None,
        id: int = 0,
        mode: str = "infer",  # train/valid/infer
        deterministic: bool = None,
        weights_sync_period: int = 1,
        weights_sync_mode: str = None,
        sampler_seed: int = 42,
        trajectory_seeds: List = None,
        trajectory_limit: int = None,
        force_store: bool = False,
        gc_period: int = 10,
        monitoring_params: Dict = None,
        **kwargs
    ):
        self._device = utils.get_device()
        self._sampler_id = id

        self._deterministic = deterministic \
            if deterministic is not None \
            else mode in ["valid", "infer"]
        self.trajectory_seeds = trajectory_seeds
        self._seeder = Seeder(init_seed=sampler_seed)

        # logging
        self._prepare_logger(logdir, mode)

        self._sampling_flag = mp.Value(c_bool, False)
        self._training_flag = mp.Value(c_bool, True)

        # environment, model, exploration & action handlers
        self.env = env
        self.agent = agent
        self.exploration_handler = exploration_handler
        self.trajectory_index = 0
        self.trajectory_sampler = TrajectorySampler(
            env=self.env,
            agent=self.agent,
            device=self._device,
            deterministic=self._deterministic,
            sampling_flag=self._sampling_flag)

        # synchronization configuration
        self.db_server = db_server
        self._weights_sync_period = weights_sync_period
        self._weights_sync_mode = weights_sync_mode

        self._trajectory_limit = trajectory_limit or np.iinfo(np.int32).max
        self._force_store = force_store
        self._gc_period = gc_period
        self._db_loop_thread = None
        self.checkpoint = None

        # special
        self.monitoring_params = monitoring_params
        self._init(**kwargs)

    def _init(self, **kwargs):
        global WANDB_ENABLED
        assert len(kwargs) == 0
        if WANDB_ENABLED:
            if self.monitoring_params is not None:
                self.checkpoints_glob: List[str] = \
                    self.monitoring_params.pop(
                        "checkpoints_glob", ["best.pth", "last.pth"])
                wandb.init(**self.monitoring_params)
            else:
                WANDB_ENABLED = False
        self.wandb_mode = "sampler"

    def _prepare_logger(self, logdir, mode):
        if logdir is not None:
            timestamp = utils.get_utcnow_time()
            logpath = f"{logdir}/" \
                f"sampler.{mode}.{self._sampler_id}.{timestamp}"
            os.makedirs(logpath, exist_ok=True)
            self.logdir = logpath
            self.logger = SummaryWriter(logpath)
        else:
            self.logdir = None
            self.logger = None

    def _start_db_loop(self):
        if self.db_server is None:
            self._training_flag.value = True
            self._sampling_flag.value = True
            return
        self._db_loop_thread = threading.Thread(
            target=_db2sampler_loop,
            kwargs={
                "sampler": self,
            })
        self._db_loop_thread.start()

    def load_checkpoint(
        self, *, filepath: str = None, db_server: DBSpec = None
    ):
        if filepath is not None:
            checkpoint = utils.load_checkpoint(filepath)
        elif db_server is not None:
            checkpoint = db_server.get_checkpoint()
            while checkpoint is None:
                time.sleep(3.0)
                checkpoint = db_server.get_checkpoint()
        else:
            raise NotImplementedError("No checkpoint found")

        self.checkpoint = checkpoint
        weights = self.checkpoint[f"{self._weights_sync_mode}_state_dict"]
        weights = {
            k: utils.any2device(v, device=self._device)
            for k, v in weights.items()
        }
        self.agent.load_state_dict(weights)
        self.agent.to(self._device)
        self.agent.eval()

    def _store_trajectory(self, trajectory, raw=False):
        if self.db_server is None:
            return
        self.db_server.put_trajectory(trajectory, raw=raw)

    def _get_seed(self):
        if self.trajectory_seeds is not None:
            seed = self.trajectory_seeds[
                self.trajectory_index % len(self.trajectory_seeds)]
        else:
            seed = self._seeder()[0]
        set_global_seed(seed)
        return seed
    def _log_to_console(
        self, *, reward, raw_reward, num_steps, elapsed_time, seed
    ):
        metrics = [
            f"trajectory {int(self.trajectory_index):05d}",
            f"steps: {int(num_steps):05d}",
            f"reward: {reward:9.3f}",
            f"raw_reward: {raw_reward:9.3f}",
            f"time: {elapsed_time:9.3f}",
            f"seed: {seed:010d}",
        ]
        metrics = " | ".join(metrics)
        print(f"--- {metrics}")

    def _log_to_tensorboard(
        self, *, reward, raw_reward, num_steps, elapsed_time, **kwargs
    ):
        if self.logger is not None:
            self.logger.add_scalar(
                "trajectory/num_steps", num_steps, self.trajectory_index)
            self.logger.add_scalar(
                "trajectory/reward", reward, self.trajectory_index)
            self.logger.add_scalar(
                "trajectory/raw_reward", raw_reward, self.trajectory_index)
            self.logger.add_scalar(
                "time/trajectories_per_minute",
                60. / elapsed_time,
                self.trajectory_index)
            self.logger.add_scalar(
                "time/steps_per_second",
                num_steps / elapsed_time,
                self.trajectory_index)
            self.logger.add_scalar(
                "time/trajectory_time_sec",
                elapsed_time,
                self.trajectory_index)
            self.logger.add_scalar(
                "time/step_time_sec",
                elapsed_time / num_steps,
                self.trajectory_index)
            self.logger.flush()

    @staticmethod
    def _log_wandb_metrics(
        metrics: Dict, step: int, mode: str, suffix: str = ""
    ):
        metrics = {
            f"{mode}/{key}{suffix}": value
            for key, value in metrics.items()
        }
        step = None  # @TODO: fix, wandb issue
        wandb.log(metrics, step=step)

    def _log_to_wandb(self, *, step, suffix="", **metrics):
        if WANDB_ENABLED:
            self._log_wandb_metrics(
                metrics, step=step, mode=self.wandb_mode, suffix=suffix)

    def _save_wandb(self):
        if WANDB_ENABLED:
            logdir_src = Path(self.logdir)
            logdir_dst = Path(wandb.run.dir)

            events_src = list(logdir_src.glob("events.out.tfevents*"))
            if len(events_src) > 0:
                events_src = events_src[0]
                os.makedirs(f"{logdir_dst}/{logdir_src.name}", exist_ok=True)
                shutil.copy2(
                    f"{str(events_src.absolute())}",
                    f"{logdir_dst}/{logdir_src.name}/{events_src.name}")

    @torch.no_grad()
    def _run_trajectory_loop(self):
        seed = self._get_seed()
        exploration_strategy = \
            self.exploration_handler.get_exploration_strategy() \
            if self.exploration_handler is not None \
            else None
        self.trajectory_sampler.reset(exploration_strategy)

        start_time = time.time()
        trajectory, trajectory_info = self.trajectory_sampler.sample(
            exploration_strategy=exploration_strategy)
        elapsed_time = time.time() - start_time

        trajectory_info = trajectory_info or {}
        trajectory_info.update({"elapsed_time": elapsed_time, "seed": seed})
        return trajectory, trajectory_info

    def _run_sample_loop(self):
        while self._training_flag.value:
            while not self._sampling_flag.value:
                if not self._training_flag.value:
                    return
                time.sleep(5.0)

            # 1) load weights from the db server;
            # 2) resume trick: skip if a checkpoint is already loaded
            need_checkpoint = \
                self.db_server is not None or self.checkpoint is None
            if self.trajectory_index % self._weights_sync_period == 0 \
                    and need_checkpoint:
                self.load_checkpoint(db_server=self.db_server)
                self._save_wandb()

            trajectory, trajectory_info = self._run_trajectory_loop()
            if trajectory is None:
                continue
            raw_trajectory = trajectory_info.pop("raw_trajectory", None)
            # store the trajectory first, so the loggers don't crash
            if not self._deterministic or self._force_store:
                self._store_trajectory(trajectory)
                if raw_trajectory is not None:
                    self._store_trajectory(raw_trajectory, raw=True)
            self._log_to_console(**trajectory_info)
            self._log_to_tensorboard(**trajectory_info)
            self._log_to_wandb(step=self.trajectory_index, **trajectory_info)
            self.trajectory_index += 1

            if self.trajectory_index % self._gc_period == 0:
                gc.collect()

            if not self._training_flag.value \
                    or self.trajectory_index >= self._trajectory_limit:
                return

    def _start_sample_loop(self):
        self._run_sample_loop()
    def run(self):
        self._start_db_loop()
        self._start_sample_loop()
        if self.logger is not None:
            self.logger.close()
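
# A minimal usage sketch (illustrative only, not part of the original code):
# `my_actor` and `my_env` are hypothetical placeholders for ActorSpec /
# EnvironmentSpec implementations; only arguments defined in
# `Sampler.__init__` are used. Without a `db_server`, weights must be loaded
# from a checkpoint file before sampling, otherwise `load_checkpoint` raises.
#
#     sampler = Sampler(
#         agent=my_actor,
#         env=my_env,
#         logdir="./logs/experiment",
#         mode="infer",  # deterministic by default for valid/infer
#         weights_sync_mode="actor",  # reads "actor_state_dict" checkpoints
#         trajectory_limit=100,
#     )
#     sampler.load_checkpoint(
#         filepath="./logs/experiment/checkpoints/last.pth")
#     sampler.run()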