def main(args, unknown_args):
    """Assemble a trainer from the parsed config and run it.

    The configured algorithm name selects the algorithm registry, the
    trainer class and the ``sync_epoch`` flag passed to the database.

    Args:
        args: parsed command-line arguments, forwarded to
            ``parse_args_uargs``.
        unknown_args: remaining arguments, forwarded to
            ``parse_args_uargs``.

    Raises:
        NotImplementedError: if the algorithm name appears in neither
            ``OFFPOLICY_ALGORITHMS_NAMES`` nor ``ONPOLICY_ALGORITHMS_NAMES``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    logdir = args.logdir
    expdir = args.expdir

    if logdir is not None:
        os.makedirs(logdir, exist_ok=True)
        dump_environment(config, logdir, args.configs)

    if expdir is not None:
        # The returned module is not used below; only the call matters here.
        module = import_module(expdir=expdir)  # noqa: F841
        if logdir is not None:
            dump_code(expdir, logdir)

    env = ENVIRONMENTS.get_from_params(**config["environment"])

    # The name is popped, so config["algorithm"] no longer carries it when
    # the whole config is handed to ``prepare_for_trainer`` below.
    algorithm_name = config["algorithm"].pop("algorithm")
    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        registry, trainer_cls, sync_epoch = (
            OFFPOLICY_ALGORITHMS, OffpolicyTrainer, False
        )
    elif algorithm_name in ONPOLICY_ALGORITHMS_NAMES:
        registry, trainer_cls, sync_epoch = (
            ONPOLICY_ALGORITHMS, OnpolicyTrainer, True
        )
    else:
        # @TODO: add registry for algorithms, trainers, samplers
        raise NotImplementedError()

    db_params = config.get("db", {})
    db_server = DATABASES.get_from_params(**db_params, sync_epoch=sync_epoch)

    algorithm_cls = registry.get(algorithm_name)
    algorithm = algorithm_cls.prepare_for_trainer(env_spec=env, config=config)

    if args.resume is not None:
        # Load the checkpoint, move it to the current device, then restore
        # the algorithm from it with ``with_optimizer=False``.
        loaded = utils.load_checkpoint(filepath=args.resume)
        device = utils.get_device()
        on_device = utils.any2device(loaded, device)
        algorithm.unpack_checkpoint(
            checkpoint=on_device,
            with_optimizer=False,
        )

    monitoring_params = config.get("monitoring_params")

    trainer = trainer_cls(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        logdir=logdir,
        monitoring_params=monitoring_params,
        **config["trainer"],
    )
    trainer.run()
def run_sampler(
    *,
    config,
    logdir,
    algorithm_fn,
    environment_fn,
    sampler_fn,
    vis,
    infer,
    seed=42,
    id=None,
    resume=None,
    db=True,
    exploration_power=1.0,
    sync_epoch=False
):
    """Build a single sampler from ``config`` and run it.

    Works on a deep copy of ``config`` so that keys popped from the
    ``"sampler"`` section do not affect the caller's dictionary.

    Args:
        config: experiment configuration; the ``"environment"`` and
            ``"sampler"`` sections are read, ``"db"`` is optional.
        logdir: forwarded to ``sampler_fn``.
        algorithm_fn: object providing ``prepare_for_sampler``.
        environment_fn: callable that builds the environment.
        sampler_fn: callable that builds the sampler.
        vis: forwarded to ``environment_fn`` as ``visualize``.
        infer: truthy selects mode ``"infer"`` and passes the configured
            ``valid_seeds``; otherwise mode is ``"train"`` with no seeds.
        seed: base seed; the sampler id is added to it.
        id: sampler index, ``None`` is treated as ``0``.
        resume: optional checkpoint path loaded into the sampler.
        db: when falsy, no database server is created.
        exploration_power: passed to the exploration handler's ``set_power``.
        sync_epoch: forwarded to the database factory.

    Raises:
        KeyError: if the ``"sampler"`` section has no ``"valid_seeds"`` key.
    """
    config_ = copy.deepcopy(config)

    sampler_id = id
    if sampler_id is None:
        sampler_id = 0
    set_global_seed(seed + sampler_id)

    db_server = None
    if db:
        db_server = DATABASES.get_from_params(
            **config.get("db", {}), sync_epoch=sync_epoch
        )

    env = environment_fn(**config_["environment"], visualize=vis)
    agent = algorithm_fn.prepare_for_sampler(env_spec=env, config=config_)

    sampler_params = config_["sampler"]

    # Exploration settings are optional; without them no handler is built.
    exploration_handler = None
    exploration_params = sampler_params.pop("exploration_params", None)
    if exploration_params is not None:
        exploration_handler = ExplorationHandler(*exploration_params, env=env)
        exploration_handler.set_power(exploration_power)

    # "valid_seeds" is always removed from the sampler section, even when
    # the seeds themselves end up unused in train mode.
    valid_seeds = sampler_params.pop("valid_seeds")
    if infer:
        mode = "infer"
        seeds = valid_seeds
    else:
        mode = "train"
        seeds = None

    sampler = sampler_fn(
        agent=agent,
        env=env,
        db_server=db_server,
        exploration_handler=exploration_handler,
        **sampler_params,
        logdir=logdir,
        id=sampler_id,
        mode=mode,
        seeds=seeds
    )

    if resume is not None:
        sampler.load_checkpoint(filepath=resume)

    sampler.run()
def main(args, unknown_args):
    """Create the trainer selected by the config and start it.

    The algorithm name decides which registry and trainer class are used,
    the ``sync_epoch`` flag for the database and the ``weights_sync_mode``
    given to the trainer.

    Args:
        args: parsed command-line arguments, forwarded to
            ``parse_args_uargs``.
        unknown_args: remaining arguments, forwarded to
            ``parse_args_uargs``.

    Raises:
        NotImplementedError: if the algorithm name appears in neither
            ``OFFPOLICY_ALGORITHMS_NAMES`` nor ``ONPOLICY_ALGORITHMS_NAMES``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)

    logdir = args.logdir
    if logdir is not None:
        os.makedirs(logdir, exist_ok=True)
        dump_config(config, logdir, args.configs)

    if args.expdir is not None:
        # The returned module is not used below; only the call matters here.
        module = import_module(expdir=args.expdir)  # noqa: F841

    env = ENVIRONMENTS.get_from_params(**config["environment"])

    # Popping keeps the name out of config["algorithm"] for later consumers.
    algorithm_name = config["algorithm"].pop("algorithm")

    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        registry = OFFPOLICY_ALGORITHMS
        trainer_cls = OffpolicyTrainer
        sync_epoch = False
        # Off-policy: sync the critic for discrete action spaces,
        # the actor otherwise.
        if env.discrete_actions:
            weights_sync_mode = "critic"
        else:
            weights_sync_mode = "actor"
    elif algorithm_name in ONPOLICY_ALGORITHMS_NAMES:
        registry = ONPOLICY_ALGORITHMS
        trainer_cls = OnpolicyTrainer
        sync_epoch = True
        weights_sync_mode = "actor"
    else:
        # @TODO: add registry for algorithms, trainers, samplers
        raise NotImplementedError()

    db_params = config.get("db", {})
    db_server = DATABASES.get_from_params(**db_params, sync_epoch=sync_epoch)

    algorithm_cls = registry.get(algorithm_name)
    algorithm = algorithm_cls.prepare_for_trainer(env_spec=env, config=config)

    resume_path = args.resume
    if resume_path is not None:
        algorithm.load_checkpoint(filepath=resume_path)

    trainer = trainer_cls(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        logdir=logdir,
        weights_sync_mode=weights_sync_mode,
        **config["trainer"],
    )
    trainer.run()
def main(args, unknown_args):
    """Create the trainer selected by the config and start it.

    Any algorithm name that is not listed in ``OFFPOLICY_ALGORITHMS_NAMES``
    is handled as an on-policy algorithm. Before the trainer runs, an
    ``atexit`` hook is registered that terminates every process returned by
    ``trainer.get_processes()``.

    Args:
        args: parsed command-line arguments, forwarded to
            ``parse_args_uargs``.
        unknown_args: remaining arguments, forwarded to
            ``parse_args_uargs``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)

    logdir = args.logdir
    if logdir is not None:
        os.makedirs(logdir, exist_ok=True)
        dump_config(args.configs, logdir)

    if args.expdir is not None:
        # The returned module is not used below; only the call matters here.
        module = import_module(expdir=args.expdir)  # noqa: F841

    # Popping keeps the name out of config["algorithm"] for later consumers.
    algorithm_name = config["algorithm"].pop("algorithm")

    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        registry, trainer_cls, sync_epoch = (
            OFFPOLICY_ALGORITHMS, OffpolicyTrainer, False
        )
    else:
        registry, trainer_cls, sync_epoch = (
            ONPOLICY_ALGORITHMS, OnpolicyTrainer, True
        )

    db_params = config.get("db", {})
    db_server = DATABASES.get_from_params(**db_params, sync_epoch=sync_epoch)

    env = ENVIRONMENTS.get_from_params(**config["environment"])

    algorithm_cls = registry.get(algorithm_name)
    algorithm = algorithm_cls.prepare_for_trainer(env_spec=env, config=config)

    resume_path = args.resume
    if resume_path is not None:
        algorithm.load_checkpoint(filepath=resume_path)

    trainer = trainer_cls(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        **config["trainer"],
        logdir=logdir,
    )

    def _terminate_trainer_processes():
        # Runs at interpreter exit; the process list is queried at that time.
        for process in trainer.get_processes():
            process.terminate()

    atexit.register(_terminate_trainer_processes)

    trainer.run()
def run_sampler(
    *,
    config,
    logdir,
    algorithm_fn,
    environment_fn,
    visualize,
    mode,
    seed=42,
    id=None,
    resume=None,
    db=True,
    exploration_power=1.0,
    sync_epoch=False
):
    """Build a single sampler for the given ``mode`` and run it.

    Works on a deep copy of ``config`` so that keys popped from the
    ``"sampler"`` section do not affect the caller's dictionary.

    Args:
        config: experiment configuration; the ``"environment"`` and
            ``"sampler"`` sections are read, ``"db"`` and
            ``"monitoring_params"`` are optional.
        logdir: forwarded to the sampler.
        algorithm_fn: object providing ``prepare_for_sampler``; it must be a
            value of ``OFFPOLICY_ALGORITHMS`` or ``ONPOLICY_ALGORITHMS``.
        environment_fn: callable that builds the environment.
        visualize: forwarded to ``environment_fn``.
        mode: one of ``"train"``, ``"valid"`` or ``"infer"``; ``"valid"``
            selects ``ValidSampler``, anything else ``Sampler``.
        seed: base seed; the sampler id is added to it.
        id: sampler index, ``None`` is treated as ``0``.
        resume: optional checkpoint path loaded into the sampler.
        db: when falsy, no database server is created.
        exploration_power: passed to the exploration handler's ``set_power``.
        sync_epoch: forwarded to the database factory.

    Raises:
        KeyError: if ``mode`` is not ``"train"``, ``"valid"`` or ``"infer"``.
        NotImplementedError: if ``algorithm_fn`` is found in neither
            algorithm registry.
    """
    config_ = copy.deepcopy(config)

    sampler_id = id
    if sampler_id is None:
        sampler_id = 0
    set_global_seed(seed + sampler_id)

    db_server = None
    if db:
        db_server = DATABASES.get_from_params(
            **config.get("db", {}), sync_epoch=sync_epoch
        )

    env = environment_fn(
        **config_["environment"],
        visualize=visualize,
        mode=mode,
        sampler_id=sampler_id,
    )
    agent = algorithm_fn.prepare_for_sampler(env_spec=env, config=config_)

    sampler_params = config_["sampler"]

    # Exploration settings are optional; without them no handler is built.
    exploration_handler = None
    exploration_params = sampler_params.pop("exploration_params", None)
    if exploration_params is not None:
        exploration_handler = ExplorationHandler(*exploration_params, env=env)
        exploration_handler.set_power(exploration_power)

    # All three "<mode>_seeds" keys are removed from the sampler section so
    # they are not forwarded to the sampler; only the current mode's is used.
    seeds_by_mode = {
        phase: sampler_params.pop(f"{phase}_seeds", None)
        for phase in ("train", "valid", "infer")
    }
    seeds = seeds_by_mode[mode]

    if algorithm_fn in OFFPOLICY_ALGORITHMS.values():
        # Off-policy: sync the critic for discrete action spaces,
        # the actor otherwise.
        if env.discrete_actions:
            weights_sync_mode = "critic"
        else:
            weights_sync_mode = "actor"
    elif algorithm_fn in ONPOLICY_ALGORITHMS.values():
        weights_sync_mode = "actor"
    else:
        # @TODO: add registry for algorithms, trainers, samplers
        raise NotImplementedError()

    sampler_cls = ValidSampler if mode == "valid" else Sampler

    monitoring_params = config.get("monitoring_params")

    sampler = sampler_cls(
        agent=agent,
        env=env,
        db_server=db_server,
        exploration_handler=exploration_handler,
        logdir=logdir,
        id=sampler_id,
        mode=mode,
        weights_sync_mode=weights_sync_mode,
        seeds=seeds,
        monitoring_params=monitoring_params,
        **sampler_params,
    )

    if resume is not None:
        sampler.load_checkpoint(filepath=resume)

    sampler.run()