def main(args, unknown_args):
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)

    Experiment, Runner = import_experiment_and_runner(Path(args.expdir))

    experiment = Experiment(config)
    runner = Runner()

    if experiment.logdir is not None:
        dump_config(config, experiment.logdir, args.configs)
        dump_code(args.expdir, experiment.logdir)

    runner.run_experiment(experiment, check=args.check)
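
# --- Hedged usage sketch (added, not part of the original source): one way a
# CLI entry point could produce the `args` / `unknown_args` pair that main()
# above consumes, via argparse's parse_known_args(). The flag names mirror the
# attributes main() reads (expdir, seed, configs, check), but the actual
# command-line interface is an assumption here.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--expdir", type=str, required=True)
    parser.add_argument("--configs", nargs="+", default=[])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--check", action="store_true")
    args, unknown_args = parser.parse_known_args()
    main(args, unknown_args)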
def main(args, unknown_args):
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)

    if args.logdir is not None:
        os.makedirs(args.logdir, exist_ok=True)
        dump_config(config, args.logdir, args.configs)

    if args.expdir is not None:
        module = import_module(expdir=args.expdir)  # noqa: F841

    env = ENVIRONMENTS.get_from_params(**config["environment"])

    # pick the algorithm registry, trainer class, and weight-sync policy
    # depending on whether the algorithm is off-policy or on-policy
    algorithm_name = config["algorithm"].pop("algorithm")
    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = OFFPOLICY_ALGORITHMS
        trainer_fn = OffpolicyTrainer
        sync_epoch = False
        weights_sync_mode = "critic" if env.discrete_actions else "actor"
    elif algorithm_name in ONPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = ONPOLICY_ALGORITHMS
        trainer_fn = OnpolicyTrainer
        sync_epoch = True
        weights_sync_mode = "actor"
    else:
        # @TODO: add registry for algorithms, trainers, samplers
        raise NotImplementedError()

    db_server = DATABASES.get_from_params(
        **config.get("db", {}), sync_epoch=sync_epoch
    )

    algorithm_fn = ALGORITHMS.get(algorithm_name)
    algorithm = algorithm_fn.prepare_for_trainer(env_spec=env, config=config)

    if args.resume is not None:
        algorithm.load_checkpoint(filepath=args.resume)

    trainer = trainer_fn(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        logdir=args.logdir,
        weights_sync_mode=weights_sync_mode,
        **config["trainer"],
    )

    trainer.run()
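
# --- Hedged sketch (added) of the registry idiom behind ENVIRONMENTS /
# DATABASES / ALGORITHMS above. This is a simplified stand-in, not the
# library's actual Registry class: get_from_params() pops a lookup key from
# the config params and instantiates the registered factory with the rest.
# The key name "registry_key" is hypothetical.
class SimpleRegistry:
    def __init__(self):
        self._factories = {}

    def add(self, name, factory):
        self._factories[name] = factory

    def get(self, name):
        return self._factories[name]

    def get_from_params(self, **params):
        name = params.pop("registry_key")
        return self._factories[name](**params)

# illustrative usage:
#   ENVIRONMENTS = SimpleRegistry()
#   ENVIRONMENTS.add("MyEnv", MyEnv)
#   env = ENVIRONMENTS.get_from_params(registry_key="MyEnv", frame_skip=4)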
def main(args, unknown_args):
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)

    if args.logdir is not None:
        os.makedirs(args.logdir, exist_ok=True)
        dump_config(args.configs, args.logdir)

    if args.expdir is not None:
        module = import_module(expdir=args.expdir)  # noqa: F841

    algorithm_name = config["algorithm"].pop("algorithm")
    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = OFFPOLICY_ALGORITHMS
        trainer_fn = OffpolicyTrainer
        sync_epoch = False
    else:
        ALGORITHMS = ONPOLICY_ALGORITHMS
        trainer_fn = OnpolicyTrainer
        sync_epoch = True

    db_server = DATABASES.get_from_params(
        **config.get("db", {}), sync_epoch=sync_epoch
    )
    env = ENVIRONMENTS.get_from_params(**config["environment"])

    algorithm_fn = ALGORITHMS.get(algorithm_name)
    algorithm = algorithm_fn.prepare_for_trainer(env_spec=env, config=config)

    if args.resume is not None:
        algorithm.load_checkpoint(filepath=args.resume)

    trainer = trainer_fn(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        **config["trainer"],
        logdir=args.logdir,
    )

    # make sure the trainer's worker processes are terminated on interpreter exit
    def on_exit():
        for p in trainer.get_processes():
            p.terminate()

    atexit.register(on_exit)

    trainer.run()
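
# --- Hedged standalone sketch (added) of the atexit cleanup pattern used in
# on_exit() above: register a handler that terminates spawned worker processes
# so a normal interpreter exit does not leave orphans. Everything here besides
# the atexit.register() idiom itself is illustrative.
import atexit
import multiprocessing as mp
import time

def _worker():
    while True:
        time.sleep(1.0)

if __name__ == "__main__":
    procs = [mp.Process(target=_worker) for _ in range(2)]
    for p in procs:
        p.start()

    def _cleanup():
        for p in procs:
            p.terminate()

    atexit.register(_cleanup)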
def work(self):
    args, config = self.parse_args_uargs()
    set_global_seed(args.seed)

    Experiment, R = import_experiment_and_runner(Path(args.expdir))
    experiment = Experiment(config)
    runner: Runner = R()
    register()

    self.experiment = experiment
    self.runner = runner

    stages = experiment.stages[:]

    # on the master, record the number of stages as the task's step count
    if self.master:
        task = self.task if not self.task.parent \
            else self.task_provider.by_id(self.task.parent)
        task.steps = len(stages)
        self.task_provider.commit()

    self._checkpoint_fix_config(experiment)

    # merge the executor's own callbacks into the experiment's callbacks
    _get_callbacks = experiment.get_callbacks

    def get_callbacks(stage):
        res = self.callbacks()
        for k, v in _get_callbacks(stage).items():
            res[k] = v

        self._checkpoint_fix_callback(res)
        return res

    experiment.get_callbacks = get_callbacks

    if experiment.logdir is not None:
        dump_config(config, experiment.logdir, args.configs)

    # in distributed mode, store resume info on the task and restrict
    # this run to the first stage
    if self.distr_info:
        info = yaml_load(self.task.additional_info)
        info['resume'] = {
            'master_computer': self.distr_info['master_computer'],
            'master_task_id': self.task.id - self.distr_info['rank'],
            'load_best': True
        }
        self.task.additional_info = yaml_dump(info)
        self.task_provider.commit()

        experiment.stages_config = {
            k: v
            for k, v in experiment.stages_config.items()
            if k == experiment.stages[0]
        }

    runner.run_experiment(experiment, check=args.check)

    # optionally export a TorchScript trace of the trained model
    if self.master and self.trace:
        traced = trace_model_from_checkpoint(self.experiment.logdir, self)
        torch.jit.save(traced, self.trace)

    return {'stage': experiment.stages[-1], 'stages': stages}
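
# --- Hedged sketch (added) of the trace-and-save step at the end of work():
# trace_model_from_checkpoint() is project-specific, but the underlying idiom
# is torch.jit.trace followed by torch.jit.save. The model and input shape
# below are illustrative assumptions.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
model.eval()

example_input = torch.randn(1, 8)
traced = torch.jit.trace(model, example_input)  # record ops on an example input
torch.jit.save(traced, "traced_model.pt")       # serialize the TorchScript module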