def main(_):
    """Start one RL learner process; which one is selected by horovod rank."""
    if has_hvd:
        hvd.init()

    # Collect every learner spec from the (repeated) --learner_spec flag.
    specs = []
    for group in FLAGS.learner_spec:
        specs.extend(group.split(','))

    # Spec for this process: index 0 without horovod, else the hvd rank.
    rank = hvd.rank() if has_hvd else 0
    fields = specs[rank].split(':')  # gpu_id:port1:port2
    gpu_id = int(fields[0])
    learner_ports = fields[1:]

    env_config = read_config_dict(FLAGS.env_config)
    interface_config = read_config_dict(FLAGS.interface_config)
    ob_space, ac_space = env_space(FLAGS.env, env_config, interface_config)
    if FLAGS.post_process_data is not None:
        post_process_data = import_module_or_data(FLAGS.post_process_data)
        ob_space, ac_space = post_process_data(ob_space, ac_space)

    policy = import_module_or_data(FLAGS.policy)
    policy_config = read_config_dict(FLAGS.policy_config)
    learner_config = read_config_dict(FLAGS.learner_config)

    # Dispatch table instead of an if/elif chain.
    learner_by_type = {
        'PPO': PPOLearner,
        'PPO2': PPO2Learner,
        'VTrace': VtraceLearner,
        'DDPG': DDPGLearner,
    }
    if FLAGS.type not in learner_by_type:
        raise KeyError(f'Not recognized learner type {FLAGS.type}!')
    Learner = learner_by_type[FLAGS.type]

    learner = Learner(league_mgr_addr=FLAGS.league_mgr_addr,
                      model_pool_addrs=FLAGS.model_pool_addrs.split(','),
                      gpu_id=gpu_id,
                      learner_ports=learner_ports,
                      unroll_length=FLAGS.unroll_length,
                      rm_size=FLAGS.rm_size,
                      batch_size=FLAGS.batch_size,
                      ob_space=ob_space,
                      ac_space=ac_space,
                      pub_interval=FLAGS.pub_interval,
                      log_interval=FLAGS.log_interval,
                      save_interval=FLAGS.save_interval,
                      total_timesteps=FLAGS.total_timesteps,
                      burn_in_timesteps=FLAGS.burn_in_timesteps,
                      policy=policy,
                      policy_config=policy_config,
                      rwd_shape=FLAGS.rwd_shape,
                      learner_id=FLAGS.learner_id,
                      batch_worker_num=FLAGS.batch_worker_num,
                      pull_worker_num=FLAGS.pull_worker_num,
                      rollout_length=FLAGS.rollout_length,
                      data_server_version=FLAGS.data_server_version,
                      decode=FLAGS.decode,
                      log_infos_interval=FLAGS.log_infos_interval,
                      **learner_config)
    learner.run()
def main(_):
    """Launch an RL actor and keep it running, rebooting on crashes if asked.

    Fixes: use a bare ``raise`` so the original traceback is preserved when
    reboot-on-failure is disabled, and correct the "crushed" -> "crashed"
    typo in the crash log message.
    """
    if FLAGS.replay_dir:
        os.makedirs(FLAGS.replay_dir, exist_ok=True)

    env_config = read_config_dict(FLAGS.env_config)
    interface_config = read_config_dict(FLAGS.interface_config)
    env = create_env(FLAGS.env, env_config=env_config,
                     inter_config=interface_config)

    policy = import_module_or_data(FLAGS.policy)
    policy_config = read_config_dict(FLAGS.policy_config)
    distill_policy_config = read_config_dict(FLAGS.distill_policy_config)
    post_process_data = None
    if FLAGS.post_process_data is not None:
        post_process_data = import_module_or_data(FLAGS.post_process_data)

    if FLAGS.type == 'PPO':
        Actor = PPOActor
    elif FLAGS.type == 'PPO2':
        Actor = PPO2Actor
    elif FLAGS.type == 'VTrace':
        Actor = VtraceActor
    elif FLAGS.type == 'DDPG':
        Actor = DDPGActor
    else:
        raise KeyError(f'Not recognized learner type {FLAGS.type}!')

    actor = Actor(env, policy,
                  policy_config=policy_config,
                  league_mgr_addr=FLAGS.league_mgr_addr or None,
                  model_pool_addrs=FLAGS.model_pool_addrs.split(','),
                  learner_addr=FLAGS.learner_addr,
                  unroll_length=FLAGS.unroll_length,
                  update_model_freq=FLAGS.update_model_freq,
                  n_v=FLAGS.n_v,
                  verbose=FLAGS.verbose,
                  log_interval_steps=FLAGS.log_interval_steps,
                  rwd_shape=FLAGS.rwd_shape,
                  distillation=FLAGS.distillation,
                  distill_policy_config=distill_policy_config,
                  replay_dir=FLAGS.replay_dir,
                  compress=FLAGS.compress,
                  self_infserver_addr=FLAGS.self_infserver_addr or None,
                  distill_infserver_addr=FLAGS.distill_infserver_addr or None,
                  post_process_data=post_process_data)

    n_failures = 0
    while True:
        try:
            actor.run()
        except Exception as e:
            if not FLAGS.reboot_on_failure:
                raise  # bare raise keeps the original traceback intact
            print("Actor crashed no. {}, the exception:\n{}".format(
                n_failures, e))
            n_failures += 1
            print("Rebooting...")
            # Kill orphaned SC2 processes before restarting the actor loop.
            kill_sc2_processes_v2()
def main(_):
    """Evaluate a saved policy in the given environment for FLAGS.episodes.

    Fixes: removed a large block of commented-out dead policy_config code,
    and replaced the manual ``'n_v' in cfg`` membership test with the
    idiomatic ``dict.get`` with a default.
    """
    inter_config = read_config_dict(FLAGS.interface_config)
    env_config = read_config_dict(FLAGS.env_config)
    env_config['replay_dir'] = FLAGS.replay_dir
    env = create_env(FLAGS.env_id, env_config=env_config,
                     inter_config=inter_config)

    policy = import_module_or_data(FLAGS.policy)
    policy_config = read_config_dict(FLAGS.policy_config)
    # Number of value heads; defaults to 1 when the config omits it.
    n_v = policy_config.get('n_v', 1)

    model_path = FLAGS.model
    model = joblib.load(model_path)
    obs = env.reset()
    print(env.observation_space)
    agent = PGAgent(policy, env.observation_space.spaces[0],
                    env.action_space.spaces[0], n_v,
                    policy_config=policy_config, scope_name='model')
    agent.load_model(model.model)

    for _ in range(FLAGS.episodes):
        agent.reset(obs[0])
        sum_rwd = 0
        while True:
            if FLAGS.render:
                env.render()
                time.sleep(0.1)
            # Our agent controls player 0; player 1 gets a fixed no-op action.
            act = [agent.step(obs[0]), [0, 0]]
            obs, rwd, done, info = env.step(act)
            sum_rwd += np.array(rwd)
            if done:
                if FLAGS.render:
                    env.render()
                    time.sleep(1)
                print(f'reward sum: {sum_rwd}, info: {info}')
                obs = env.reset()
                break
        print('--------------------------------')
    env.close()
def main(_):
    """Play a trained agent vs. a built-in zerg bot and report episode rewards.

    Fix: the episode counter was named ``iter``, shadowing the builtin;
    renamed to ``episode_count``.
    """
    players = [
        sc2_env.Agent(sc2_env.Race.zerg),
        sc2_env.Bot(sc2_env.Race.zerg, FLAGS.difficulty)
    ]
    env = SC2BaseEnv(
        players=players,
        agent_interface='feature',
        map_name='KairosJunction',
        max_steps_per_episode=48000,
        screen_resolution=168,
        screen_ratio=0.905,
        step_mul=1,
        version=FLAGS.version,
        replay_dir=FLAGS.replay_dir,
        save_replay_episodes=1,
        use_pysc2_feature=False,
        minimap_resolution=(152, 168),
    )
    interface_cls = import_module_or_data(FLAGS.interface)
    interface_config = read_config_dict(FLAGS.interface_config)
    interface = interface_cls(**interface_config)
    env = EnvIntWrapper(env, [interface])
    obs = env.reset()
    print(env.observation_space.spaces)

    policy = import_module_or_data(FLAGS.policy)
    policy_config = read_config_dict(FLAGS.policy_config)
    agent = PGAgent2(policy, env.observation_space.spaces[0],
                     env.action_space.spaces[0], policy_config=policy_config)
    model_path = FLAGS.model
    model = joblib.load(model_path)
    agent.load_model(model.model)
    agent.reset(obs[0])

    episodes = FLAGS.episodes
    episode_count = 0  # renamed from `iter`, which shadowed the builtin
    sum_rwd = []
    while True:
        # Roll out one full episode.
        while True:
            if obs[0] is not None:
                act = [agent.step(obs[0])]
            else:
                act = [[]]  # no observation for our agent -> empty action
            obs, rwd, done, info = env.step(act)
            if done:
                print(rwd)
                sum_rwd.append(rwd[0])
                break
        episode_count += 1
        if episode_count >= episodes:
            print(sum_rwd)
            break
        obs = env.reset()
def main(_):
    """Spin up an imitation-learning learner for this horovod rank."""
    if has_hvd:
        hvd.init()

    # Gather every spec given via the repeated --learner_spec flag.
    specs = []
    for spec_group in FLAGS.learner_spec:
        specs.extend(spec_group.split(','))

    # Select this worker's spec; single-process runs use index 0.
    my_index = hvd.rank() if has_hvd else 0
    fields = specs[my_index].split(':')  # gpu_id:port1:port2
    gpu_id = int(fields[0])
    learner_ports = fields[1:]

    replay_converter_type = import_module_or_data(FLAGS.replay_converter)
    converter_config = read_config_dict(FLAGS.converter_config)
    policy = import_module_or_data(FLAGS.policy)
    policy_config = read_config_dict(FLAGS.policy_config)
    post_process_data = None
    if FLAGS.post_process_data is not None:
        post_process_data = import_module_or_data(FLAGS.post_process_data)

    learner = ImitationLearner3(
        ports=learner_ports,
        gpu_id=gpu_id,
        policy=policy.net_build_fun,
        policy_config=policy_config,
        policy_config_type=policy.net_config_cls,
        replay_filelist=FLAGS.replay_filelist,
        batch_size=FLAGS.batch_size,
        min_train_sample_num=FLAGS.min_train_sample_num,
        min_val_sample_num=FLAGS.min_val_sample_num,
        rm_size=FLAGS.rm_size,
        learning_rate=FLAGS.learning_rate,
        print_interval=FLAGS.print_interval,
        replay_converter_type=replay_converter_type,
        converter_config=converter_config,
        checkpoint_interval=FLAGS.checkpoint_interval,
        num_val_batches=FLAGS.num_val_batches,
        checkpoints_dir=FLAGS.checkpoints_dir,
        restore_checkpoint_path=FLAGS.restore_checkpoint_path,
        train_generator_worker_num=FLAGS.train_generator_worker_num,
        val_generator_worker_num=FLAGS.val_generator_worker_num,
        repeat_training_task=FLAGS.repeat_training_task,
        unroll_length=FLAGS.unroll_length,
        rollout_length=FLAGS.rollout_length,
        model_pool_addrs=FLAGS.model_pool_addrs.split(','),
        pub_interval=FLAGS.pub_interval,
        after_loading_init_scope=FLAGS.after_loading_init_scope,
        max_clip_grad_norm=FLAGS.max_clip_grad_norm,
        use_mixed_precision=FLAGS.use_mixed_precision,
        use_sparse_as_dense=FLAGS.use_sparse_as_dense,
        enable_validation=FLAGS.enable_validation,
        post_process_data=post_process_data)
    learner.run()
def main(_):
    """Run an inference server; its index and port come from horovod rank."""
    if not has_hvd:
        print('horovod is unavailable, FLAGS.hvd_run will be ignored.')
    index = 0
    if has_hvd and FLAGS.hvd_run:
        hvd.init()
        index = hvd.local_rank()
        print('Horovod initialized.')
    port = FLAGS.port - index
    print('index: {}, using port: {}'.format(index, port), flush=True)

    policy = import_module_or_data(FLAGS.policy)
    # Per-index configs cycle through the repeated flag values.
    policy_config = read_config_dict(
        FLAGS.policy_config[index % len(FLAGS.policy_config)])
    infserver_config = read_config_dict(
        FLAGS.infserver_config[index % len(FLAGS.infserver_config)])
    post_process_data = FLAGS.post_process_data[
        index % len(FLAGS.post_process_data)] or None

    if FLAGS.is_rl:
        env_config = read_config_dict(FLAGS.env_config)
        interface_config = read_config_dict(FLAGS.interface_config)
        ob_space, ac_space = env_space(FLAGS.env, env_config, interface_config)
    else:
        # Imitation-learning path: spaces come from the replay converter.
        replay_converter_type = import_module_or_data(FLAGS.replay_converter)
        converter_config = read_config_dict(FLAGS.converter_config)
        replay_converter = replay_converter_type(**converter_config)
        ob_space, ac_space = replay_converter.space.spaces
        # NOTE(review): default model key assumed to apply to the IL branch
        # only — confirm against the original (unindented) source.
        if 'model_key' not in infserver_config:
            infserver_config['model_key'] = 'IL-model'

    if post_process_data is not None:
        post_process_data = import_module_or_data(post_process_data)
        ob_space, ac_space = post_process_data(ob_space, ac_space)

    net_cfg = policy.net_config_cls(ob_space, ac_space, **policy_config)
    data_spec = InfData(ob_space, ac_space, net_cfg.use_self_fed_heads,
                        net_cfg.use_lstm, net_cfg.hs_len)
    server = InfServer(league_mgr_addr=FLAGS.league_mgr_addr or None,
                       model_pool_addrs=FLAGS.model_pool_addrs.split(','),
                       port=port,
                       ds=data_spec,
                       batch_size=policy_config['batch_size'],
                       ob_space=ob_space,
                       ac_space=ac_space,
                       policy=policy,
                       policy_config=policy_config,
                       gpu_id=FLAGS.gpu_id,
                       compress=FLAGS.compress,
                       batch_worker_num=FLAGS.batch_worker_num,
                       learner_id=FLAGS.learner_id,
                       **infserver_config)
    server.run()
def main(_):
    """Launch a replay actor that streams imitation data to a learner.

    Fixes: resolve the converter class via ``import_module_or_data`` — the
    same helper every other launcher here uses (see the imitation-learner
    main) — instead of a hand-rolled rsplit/importlib/getattr; use a bare
    ``raise`` to preserve the traceback; fix the "crushed" typo in the log.
    """
    replay_converter_type = import_module_or_data(FLAGS.replay_converter)
    converter_config = read_config_dict(FLAGS.converter_config)

    policy = None
    policy_config = None
    model_pool_addrs = None
    agent = None
    post_process_data = None
    if FLAGS.post_process_data is not None:
        post_process_data = import_module_or_data(FLAGS.post_process_data)
    if FLAGS.policy:
        policy = import_module_or_data(FLAGS.policy)
        policy_config = read_config_dict(FLAGS.policy_config)
        # A policy needs a model pool to pull weights from.
        assert FLAGS.model_pool_addrs is not None
        model_pool_addrs = FLAGS.model_pool_addrs.split(',')
        agent = import_module_or_data(FLAGS.agent)

    actor = ReplayActor(learner_addr=FLAGS.learner_addr,
                        replay_dir=FLAGS.replay_dir,
                        replay_converter_type=replay_converter_type,
                        log_interval=FLAGS.log_interval,
                        step_mul=FLAGS.step_mul,
                        n_v=FLAGS.n_v,
                        game_version=FLAGS.game_version,
                        unroll_length=FLAGS.unroll_length,
                        policy=policy,
                        policy_config=policy_config,
                        model_pool_addrs=model_pool_addrs,
                        update_model_freq=FLAGS.update_model_freq,
                        converter_config=converter_config,
                        SC2_bin_root=FLAGS.SC2_bin_root,
                        agent_cls=agent,
                        infserver_addr=FLAGS.infserver_addr or None,
                        compress=FLAGS.compress,
                        post_process_data=post_process_data,
                        da_rate=FLAGS.data_augment_rate,
                        unk_mmr_dft_to=FLAGS.unk_mmr_dft_to)

    n_failures = 0
    while True:
        try:
            actor.run()
        except Exception as e:
            if not FLAGS.reboot_on_failure:
                raise  # bare raise keeps the original traceback intact
            print("Actor crashed no. {}, the exception:\n{}".format(
                n_failures, e))
            n_failures += 1
            print("Rebooting...")
            # Kill orphaned SC2 processes before restarting the actor loop.
            kill_sc2_processes_v2()
def __init__(self, model_pool_apis, mutable_hyperparam_type,
             hyperparam_config_name):
    """Initialize the hyperparam manager.

    Fix: the bare ``except:`` swallowed everything (including
    KeyboardInterrupt/SystemExit); ``getattr`` on a module only raises
    ``AttributeError``, so catch exactly that before falling back to
    parsing the name as a config-dict string.

    Args:
      model_pool_apis: model-pool API handle(s) used by this manager.
      mutable_hyperparam_type: name of a class in ``hyperparam_types``.
      hyperparam_config_name: name of a predefined config in ``configs``,
        or a config-dict string parseable by ``read_config_dict``; may be
        None for an empty config.
    """
    self._model_pool_apis = model_pool_apis
    self._mutable_hyperparam_cls = getattr(hyperparam_types,
                                           mutable_hyperparam_type)
    if hyperparam_config_name is not None:
        try:
            # Prefer a predefined config looked up by name.
            self._hyperparam_config = getattr(configs, hyperparam_config_name)
        except AttributeError:
            # Not a predefined name: parse it as a config-dict string.
            self._hyperparam_config = read_config_dict(hyperparam_config_name)
    else:
        self._hyperparam_config = {}
    self.blackboard = HyperparamMgr.Blackboard()
def main(_):
    """Start a league manager (parallel variant when pseudo_learner_num > 0).

    Fix: the two construction branches duplicated 13 keyword arguments;
    they now share one kwargs dict so future options only need adding once.
    """
    game_mgr_config = read_config_dict(FLAGS.game_mgr_config)
    # SECURITY: eval() executes arbitrary code from the flag value — only
    # run with trusted command lines. If the flag is always a plain Python
    # literal (e.g. a list of paths), prefer ast.literal_eval.
    if FLAGS.init_model_paths is not None:
        init_model_paths = eval(FLAGS.init_model_paths)
    else:
        init_model_paths = []

    # Arguments common to both manager variants.
    common_kwargs = dict(
        port=FLAGS.port,
        model_pool_addrs=FLAGS.model_pool_addrs.split(','),
        mutable_hyperparam_type=FLAGS.mutable_hyperparam_type,
        hyperparam_config_name=FLAGS.hyperparam_config_name or None,
        restore_checkpoint_dir=FLAGS.restore_checkpoint_dir or None,
        save_checkpoint_root=FLAGS.save_checkpoint_root or None,
        save_interval_secs=FLAGS.save_interval_secs,
        mute_actor_msg=FLAGS.mute_actor_msg,
        game_mgr_type=FLAGS.game_mgr_type,
        game_mgr_config=game_mgr_config,
        verbose=FLAGS.verbose,
        init_model_paths=init_model_paths,
        save_learner_meta=FLAGS.save_learner_meta,
    )
    if FLAGS.pseudo_learner_num > 0:
        league_mgr = PARLeagueMgr(
            pseudo_learner_num=FLAGS.pseudo_learner_num,
            **common_kwargs)
    else:
        league_mgr = LeagueMgr(**common_kwargs)
    league_mgr.run()