def main(args, unknown_args):
    """Trainer entry point: prepare logdir/config, build the algorithm and run training."""
    args, config = parse_args_uargs(args, unknown_args, dump_config=True)

    # persist the resolved config next to the experiment logs
    os.makedirs(args.logdir, exist_ok=True)
    save_config(config=config, logdir=args.logdir)

    if args.expdir is not None:
        # import user code and snapshot it into the logdir for reproducibility
        modules = prepare_modules(  # noqa: F841
            expdir=args.expdir, dump_dir=args.logdir)

    algorithm = Registry.get_fn("algorithm", args.algorithm)
    algorithm_kwargs = algorithm.prepare_for_trainer(config)

    redis_params = config.get("redis", {})
    redis_server = StrictRedis(port=redis_params.get("port", 12000))
    redis_prefix = redis_params.get("prefix", "")

    pprint(config["trainer"])
    pprint(algorithm_kwargs)

    trainer = Trainer(
        **config["trainer"],
        **algorithm_kwargs,
        logdir=args.logdir,
        redis_server=redis_server,
        redis_prefix=redis_prefix)
    pprint(trainer)

    def on_exit():
        # stop every child process owned by the trainer on interpreter shutdown
        for proc in trainer.get_processes():
            proc.terminate()

    atexit.register(on_exit)

    trainer.run()
def dump_environment(
    experiment_config: Any,
    logdir: str,
    configs_path: List[str] = None,
) -> None:
    """
    Saves config, environment variables and package list in JSON into logdir.

    Args:
        experiment_config: experiment config
        logdir: path to logdir
        configs_path: path(s) to config; entries may be ``str`` or
            ``pathlib.Path`` — anything else is silently skipped
    """
    configs_path = configs_path or []
    # accept both str and Path entries; the previous `isinstance(path, str)`
    # filter silently dropped pathlib.Path objects even though they are valid
    configs_path = [
        Path(path) for path in configs_path if isinstance(path, (str, Path))
    ]
    config_dir = Path(logdir) / "configs"
    config_dir.mkdir(exist_ok=True, parents=True)

    if IS_HYDRA_AVAILABLE and isinstance(experiment_config, DictConfig):
        # dump the hydra config as YAML, then flatten it to a plain container
        # so the JSON/tensorboard dumps below work uniformly
        with open(config_dir / "config.yaml", "w") as f:
            f.write(OmegaConf.to_yaml(experiment_config, resolve=True))
        experiment_config = OmegaConf.to_container(experiment_config, resolve=True)

    environment = get_environment_vars()
    save_config(experiment_config, config_dir / "_config.json")
    save_config(environment, config_dir / "_environment.json")

    pip_pkg = list_pip_packages()
    (config_dir / "pip-packages.txt").write_text(pip_pkg)
    conda_pkg = list_conda_packages()
    # conda may be absent — only write the file when there is something to write
    if conda_pkg:
        (config_dir / "conda-packages.txt").write_text(conda_pkg)

    # copy every referenced config file into the logdir snapshot
    for path in configs_path:
        name: str = path.name
        outpath = config_dir / name
        shutil.copyfile(path, outpath)

    # double newlines so tensorboard's markdown renderer keeps line breaks
    config_str = json.dumps(experiment_config, indent=2, ensure_ascii=False)
    config_str = config_str.replace("\n", "\n\n")
    environment_str = json.dumps(environment, indent=2, ensure_ascii=False)
    environment_str = environment_str.replace("\n", "\n\n")
    pip_pkg = pip_pkg.replace("\n", "\n\n")
    conda_pkg = conda_pkg.replace("\n", "\n\n")
    with SummaryWriter(config_dir) as writer:
        writer.add_text("_config", config_str, 0)
        writer.add_text("_environment", environment_str, 0)
        writer.add_text("pip-packages", pip_pkg, 0)
        if conda_pkg:
            writer.add_text("conda-packages", conda_pkg, 0)
def log_hparams(
    self,
    hparams: Dict,
    scope: str = None,
    # experiment info
    run_key: str = None,
    stage_key: str = None,
) -> None:
    """
    Persist hyperparameters for the ``"experiment"`` scope.

    Writes *hparams* to ``<logdir>/hparams.yml``; any other scope
    (including ``None``) is a no-op.
    """
    if scope != "experiment":
        return
    hparams_path = os.path.join(self.logdir, "hparams.yml")
    save_config(config=hparams, path=hparams_path)
def on_epoch_end(self, runner: "IRunner") -> None:
    """
    Collects and saves checkpoint after epoch.

    Scores the current model (by the tracked metric or, without model
    selection, by the global epoch counter), updates the best-score
    record, and — when ``save_n_best > 0`` — packs/saves a checkpoint,
    appends it to ``top_best_metrics``, prunes old checkpoints and dumps
    the metrics log to ``<logdir>/<metrics_filename>``.

    Args:
        runner: current runner
    """
    # nothing to checkpoint during inference
    if runner.is_infer_stage:
        return
    # under DDP only the master process writes checkpoints
    if runner.engine.is_ddp and not runner.engine.is_master_process:
        return
    if self._use_model_selection:
        # score model based on the specified metric
        score = runner.epoch_metrics[self.loader_key][self.metric_key]
    else:
        # score model based on epoch number
        score = runner.global_epoch_step
    is_best = False
    # first epoch (best_score is None) always counts as best
    if self.best_score is None or self.is_better(score, self.best_score):
        self.best_score = score
        is_best = True
    if self.save_n_best > 0:
        # pack checkpoint
        checkpoint = self._pack_checkpoint(runner)
        # save checkpoint; the current epoch is always "last"
        checkpoint_path = self._save_checkpoint(
            runner=runner,
            checkpoint=checkpoint,
            is_best=is_best,
            is_last=True,
        )
        # add metrics to records (tuple layout matches the unpacking in
        # the top-best report: score, path, stage, epoch, metrics)
        metrics_record = (
            float(score),
            checkpoint_path,
            runner.stage_key,
            runner.stage_epoch_step,
            dict(runner.epoch_metrics),
        )
        self.top_best_metrics.append(metrics_record)
        # truncate checkpoints down to the save_n_best budget
        self._truncate_checkpoints()
        # save checkpoint metrics
        metrics_log = self._prepare_metrics_log(float(score), dict(runner.epoch_metrics))
        save_config(metrics_log, f"{self.logdir}/{self.metrics_filename}")
def main(args, unknown_args):
    """Trainer entry point: prepare logdir/config, infer env shapes, run training."""
    args, config = parse_args_uargs(args, unknown_args, dump_config=True)

    # persist the resolved config next to the experiment logs
    os.makedirs(args.logdir, exist_ok=True)
    save_config(config=config, logdir=args.logdir)

    if args.expdir is not None:
        # import user code and snapshot it into the logdir for reproducibility
        modules = prepare_modules(  # noqa: F841
            expdir=args.expdir, dump_dir=args.logdir)

    algorithm = Registry.get_fn("algorithm", args.algorithm)

    if args.environment is not None:
        # @TODO: remove this hack
        # come on, just refactor whole rl
        environment_fn = Registry.get_fn("environment", args.environment)
        env = environment_fn(**config["env"])
        config["shared"]["observation_size"] = env.observation_shape[0]
        config["shared"]["action_size"] = env.action_shape[0]
        del env

    algorithm_kwargs = algorithm.prepare_for_trainer(config)

    redis_params = config.get("redis", {})
    redis_server = StrictRedis(port=redis_params.get("port", 12000))
    redis_prefix = redis_params.get("prefix", "")

    pprint(config["trainer"])
    pprint(algorithm_kwargs)

    trainer = Trainer(
        **config["trainer"],
        **algorithm_kwargs,
        logdir=args.logdir,
        redis_server=redis_server,
        redis_prefix=redis_prefix)
    pprint(trainer)

    def on_exit():
        # stop every child process owned by the trainer on interpreter shutdown
        for proc in trainer.get_processes():
            proc.terminate()

    atexit.register(on_exit)

    trainer.run()
def main(args, unknown_args):
    """Training entry point: resolve logdir, load user modules, run the stages."""
    args, config = parse_args_uargs(args, unknown_args, dump_config=True)
    set_global_seeds(args.seed)

    # at least one of the two logdir sources must be given
    assert args.baselogdir is not None or args.logdir is not None

    if args.logdir is None:
        # derive the logdir name from the user's model module
        logdir_modules = prepare_modules(expdir=args.expdir)
        run_logdir = logdir_modules["model"].prepare_logdir(config=config)
        args.logdir = str(pathlib.Path(args.baselogdir) / run_logdir)

    os.makedirs(args.logdir, exist_ok=True)
    save_config(config=config, logdir=args.logdir)

    # import user code and snapshot it into the logdir for reproducibility
    modules = prepare_modules(expdir=args.expdir, dump_dir=args.logdir)

    model = Registry.get_model(**config["model_params"])
    datasource = modules["data"].DataSource()
    runner = modules["model"].ModelRunner(model=model)
    runner.train_stages(
        datasource=datasource,
        args=args,
        stages_config=config["stages"],
        verbose=args.verbose)
def on_stage_end(self, runner: "IRunner") -> None:
    """
    Show information about best checkpoints during the stage and
    load model specified in ``load_on_stage_end``.

    Args:
        runner: current runner
    """
    # nothing to report/restore during inference
    if runner.is_infer_stage:
        return
    if runner.engine.is_ddp and not runner.engine.is_master_process:
        # worker sync — wait for the master's matching barrier below
        dist.barrier()
        return
    # let's log Top-N base metrics
    log_message = "Top best models:\n"
    # store latest state
    if self.save_n_best == 0:
        # no per-epoch checkpoints were kept, so save the final state once here
        score = runner.epoch_metrics[self.loader_key][self.metric_key]
        # pack checkpoint
        checkpoint = self._pack_checkpoint(runner)
        # save checkpoint
        checkpoint_path = self._save_checkpoint(
            runner=runner,
            checkpoint=checkpoint,
            is_best=True,  # will duplicate current (last) as best
            is_last=False,  # don't need that because current state is last
        )
        # add metrics to records
        # save checkpoint metrics
        metrics_log = self._prepare_metrics_log(float(score), dict(runner.epoch_metrics))
        save_config(metrics_log, f"{self.logdir}/{self.metrics_filename}")
        log_message += f"{checkpoint_path}\t{score:3.4f}"
    else:
        # report the checkpoints accumulated during the stage
        log_message += "\n".join([
            f"{filepath}\t{score:3.4f}"
            for score, filepath, _, _, _ in self.top_best_metrics
        ])
    print(log_message)
    # let's load runner state (model, criterion, optimizer, scheduler) if required
    # "last"/"last_full" need no reload — the runner already holds that state
    not_required_load_states = {"last", "last_full"}
    if (isinstance(self.load_on_stage_end, str)
            and self.load_on_stage_end not in not_required_load_states
            and self.save_n_best > 0):
        # "*_full" checkpoint names carry optimizer/scheduler state too
        need_load_full = (self.load_on_stage_end.endswith("full")
                          if isinstance(self.load_on_stage_end, str)
                          else False)
        _load_runner(
            logdir=self.logdir,
            runner=runner,
            mapping=self.load_on_stage_end,
            load_full=need_load_full,
        )
    elif isinstance(self.load_on_stage_end, dict) and self.save_n_best > 0:
        # dict mapping: load each requested part, skipping no-op "last*" entries
        to_load = {
            k: v
            for k, v in self.load_on_stage_end.items()
            if v not in not_required_load_states
        }
        _load_runner(logdir=self.logdir, runner=runner, mapping=to_load)
    if runner.engine.is_ddp and runner.engine.is_master_process:
        # master sync — release the workers waiting at the barrier above
        dist.barrier()
def main(args, unknown_args):
    """
    Sampler-farm entry point: prepares logdir/config and spawns sampler
    processes — visualization samplers, inference samplers and training
    samplers with linearly scaled exploration noise — then joins them all.
    """
    args, config = parse_args_uargs(args, unknown_args)

    os.makedirs(args.logdir, exist_ok=True)
    save_config(config=config, logdir=args.logdir)

    if args.expdir is not None:
        # import user code and snapshot it into the logdir
        modules = prepare_modules(  # noqa: F841
            expdir=args.expdir, dump_dir=args.logdir)

    algorithm = Registry.get_fn("algorithm", args.algorithm)
    environment = Registry.get_fn("environment", args.environment)

    processes = []
    sampler_id = 0

    def on_exit():
        # terminate every spawned sampler on interpreter shutdown
        for p in processes:
            p.terminate()

    atexit.register(on_exit)

    # kwargs shared by every sampler run
    params = dict(
        logdir=args.logdir,
        algorithm=algorithm,
        environment=environment,
        config=config,
        resume=args.resume,
        redis=args.redis)

    if args.debug:
        # debug mode: run a single sampler inline, in this process
        # NOTE(review): sampler_id is not incremented here and execution
        # falls through to the spawning loops below — confirm whether
        # debug mode is meant to also spawn the regular samplers
        params_ = dict(
            vis=False,
            infer=False,
            action_noise=0.5,
            param_noise=0.5,
            action_noise_prob=args.action_noise_prob,
            param_noise_prob=args.param_noise_prob,
            id=sampler_id,
        )
        run_sampler(**params, **params_)

    for i in range(args.vis):
        # NOTE(review): vis=False inside the visualization loop looks
        # suspicious — presumably vis=True was intended; verify intent
        params_ = dict(
            vis=False,
            infer=False,
            action_noise_prob=0,
            param_noise_prob=0,
            id=sampler_id,
        )
        p = mp.Process(target=run_sampler, kwargs=dict(**params, **params_))
        p.start()
        processes.append(p)
        sampler_id += 1

    for i in range(args.infer):
        # inference samplers: no exploration noise
        params_ = dict(
            vis=False,
            infer=True,
            action_noise_prob=0,
            param_noise_prob=0,
            id=sampler_id,
        )
        p = mp.Process(target=run_sampler, kwargs=dict(**params, **params_))
        p.start()
        processes.append(p)
        sampler_id += 1

    for i in range(1, args.train + 1):
        # training samplers: noise magnitude scales linearly from
        # max_noise/train up to max_noise across the sampler pool
        action_noise = args.max_action_noise * i / args.train \
            if args.max_action_noise is not None \
            else None
        param_noise = args.max_param_noise * i / args.train \
            if args.max_param_noise is not None \
            else None
        params_ = dict(
            vis=False,
            infer=False,
            action_noise=action_noise,
            param_noise=param_noise,
            action_noise_prob=args.action_noise_prob,
            param_noise_prob=args.param_noise_prob,
            id=sampler_id,
        )
        p = mp.Process(target=run_sampler, kwargs=dict(**params, **params_))
        p.start()
        processes.append(p)
        sampler_id += 1

    # block until every sampler process finishes
    for p in processes:
        p.join()
def log_hparams(self, hparams: Dict, runner: "IRunner" = None) -> None:
    """Write *hparams* to ``<logdir>/hparams.json``; *runner* is unused."""
    hparams_path = os.path.join(self.logdir, "hparams.json")
    save_config(config=hparams, path=hparams_path)
def _save_metric(self, logdir: Union[str, Path], metrics: Dict) -> None:
    """Dump *metrics* into ``<logdir>/checkpoints/<metrics_filename>``."""
    metrics_path = f"{logdir}/checkpoints/{self.metrics_filename}"
    save_config(metrics, metrics_path)