示例#1
0
def main(_):
    """Entry point: configure and launch one RL learner on this worker.

    Selects this worker's learner spec (``gpu_id:port1:port2``) by horovod
    rank, builds the spaces and policy from the flags, then runs the learner
    class matching ``FLAGS.type``.
    """
    if has_hvd:
        hvd.init()

    # Flatten the (possibly comma-separated) learner specs into one list.
    all_learner_specs = []
    for spec_item in FLAGS.learner_spec:
        all_learner_specs.extend(spec_item.split(','))
    this_learner_ind = hvd.rank() if has_hvd else 0
    # Each spec has the layout gpu_id:port1:port2
    spec_fields = all_learner_specs[this_learner_ind].split(':')
    gpu_id = int(spec_fields[0])
    learner_ports = spec_fields[1:]

    env_config = read_config_dict(FLAGS.env_config)
    interface_config = read_config_dict(FLAGS.interface_config)
    ob_space, ac_space = env_space(FLAGS.env, env_config, interface_config)
    if FLAGS.post_process_data is not None:
        # Optional hook that rewrites the spaces before learner creation.
        post_process_data = import_module_or_data(FLAGS.post_process_data)
        ob_space, ac_space = post_process_data(ob_space, ac_space)
    policy = import_module_or_data(FLAGS.policy)
    policy_config = read_config_dict(FLAGS.policy_config)
    learner_config = read_config_dict(FLAGS.learner_config)

    # Map the learner-type flag onto its implementation class.
    learner_cls_by_type = {
        'PPO': PPOLearner,
        'PPO2': PPO2Learner,
        'VTrace': VtraceLearner,
        'DDPG': DDPGLearner,
    }
    if FLAGS.type not in learner_cls_by_type:
        raise KeyError(f'Not recognized learner type {FLAGS.type}!')
    Learner = learner_cls_by_type[FLAGS.type]

    learner = Learner(league_mgr_addr=FLAGS.league_mgr_addr,
                      model_pool_addrs=FLAGS.model_pool_addrs.split(','),
                      gpu_id=gpu_id,
                      learner_ports=learner_ports,
                      unroll_length=FLAGS.unroll_length,
                      rm_size=FLAGS.rm_size,
                      batch_size=FLAGS.batch_size,
                      ob_space=ob_space,
                      ac_space=ac_space,
                      pub_interval=FLAGS.pub_interval,
                      log_interval=FLAGS.log_interval,
                      save_interval=FLAGS.save_interval,
                      total_timesteps=FLAGS.total_timesteps,
                      burn_in_timesteps=FLAGS.burn_in_timesteps,
                      policy=policy,
                      policy_config=policy_config,
                      rwd_shape=FLAGS.rwd_shape,
                      learner_id=FLAGS.learner_id,
                      batch_worker_num=FLAGS.batch_worker_num,
                      pull_worker_num=FLAGS.pull_worker_num,
                      rollout_length=FLAGS.rollout_length,
                      data_server_version=FLAGS.data_server_version,
                      decode=FLAGS.decode,
                      log_infos_interval=FLAGS.log_infos_interval,
                      **learner_config)
    learner.run()
示例#2
0
def main(_):
    """Entry point: create an actor of the requested type and run it.

    When ``FLAGS.reboot_on_failure`` is set, a crashed actor is restarted
    indefinitely, killing stray SC2 processes between attempts; otherwise
    the first exception propagates.
    """
    if FLAGS.replay_dir:
        os.makedirs(FLAGS.replay_dir, exist_ok=True)

    env_config = read_config_dict(FLAGS.env_config)
    interface_config = read_config_dict(FLAGS.interface_config)
    env = create_env(FLAGS.env,
                     env_config=env_config,
                     inter_config=interface_config)
    policy = import_module_or_data(FLAGS.policy)
    policy_config = read_config_dict(FLAGS.policy_config)
    distill_policy_config = read_config_dict(FLAGS.distill_policy_config)
    post_process_data = None
    if FLAGS.post_process_data is not None:
        # Optional hook that rewrites the spaces/data before use.
        post_process_data = import_module_or_data(FLAGS.post_process_data)
    if FLAGS.type == 'PPO':
        Actor = PPOActor
    elif FLAGS.type == 'PPO2':
        Actor = PPO2Actor
    elif FLAGS.type == 'VTrace':
        Actor = VtraceActor
    elif FLAGS.type == 'DDPG':
        Actor = DDPGActor
    else:
        # Fixed message: this switch selects an actor type, not a learner.
        raise KeyError(f'Not recognized actor type {FLAGS.type}!')
    actor = Actor(env,
                  policy,
                  policy_config=policy_config,
                  league_mgr_addr=FLAGS.league_mgr_addr or None,
                  model_pool_addrs=FLAGS.model_pool_addrs.split(','),
                  learner_addr=FLAGS.learner_addr,
                  unroll_length=FLAGS.unroll_length,
                  update_model_freq=FLAGS.update_model_freq,
                  n_v=FLAGS.n_v,
                  verbose=FLAGS.verbose,
                  log_interval_steps=FLAGS.log_interval_steps,
                  rwd_shape=FLAGS.rwd_shape,
                  distillation=FLAGS.distillation,
                  distill_policy_config=distill_policy_config,
                  replay_dir=FLAGS.replay_dir,
                  compress=FLAGS.compress,
                  self_infserver_addr=FLAGS.self_infserver_addr or None,
                  distill_infserver_addr=FLAGS.distill_infserver_addr or None,
                  post_process_data=post_process_data)

    n_failures = 0
    while True:
        try:
            actor.run()
        except Exception as e:
            if not FLAGS.reboot_on_failure:
                # Bare `raise` preserves the original traceback
                # (`raise e` would re-anchor it here).
                raise
            print("Actor crushed no. {}, the exception:\n{}".format(
                n_failures, e))
            n_failures += 1
            print("Rebooting...")
            kill_sc2_processes_v2()
示例#3
0
def main(_):
  """Entry point: roll out a saved model in the environment for evaluation.

  Loads the model from ``FLAGS.model``, plays ``FLAGS.episodes`` episodes
  (optionally rendering), and prints the per-episode reward sums.
  """
  inter_config = read_config_dict(FLAGS.interface_config)
  env_config = read_config_dict(FLAGS.env_config)
  env_config['replay_dir'] = FLAGS.replay_dir
  env = create_env(FLAGS.env_id, env_config=env_config,
                   inter_config=inter_config)

  policy = import_module_or_data(FLAGS.policy)
  policy_config = read_config_dict(FLAGS.policy_config)
  # Number of value heads; default to 1 when the config does not set it.
  n_v = policy_config.get('n_v', 1)
  model = joblib.load(FLAGS.model)
  obs = env.reset()
  print(env.observation_space)
  agent = PGAgent(policy, env.observation_space.spaces[0],
                  env.action_space.spaces[0], n_v,
                  policy_config=policy_config, scope_name='model')
  agent.load_model(model.model)

  for _ in range(FLAGS.episodes):
    agent.reset(obs[0])
    sum_rwd = 0
    while True:
      if FLAGS.render:
        env.render()
        time.sleep(0.1)
      # Only player 0 is driven by the agent; player 1 gets a no-op action.
      act = [agent.step(obs[0]), [0, 0]]
      obs, rwd, done, info = env.step(act)
      sum_rwd += np.array(rwd)
      if done:
        if FLAGS.render:
          env.render()
          time.sleep(1)
        print(f'reward sum: {sum_rwd}, info: {info}')
        obs = env.reset()
        break
  print('--------------------------------')
  env.close()
示例#4
0
def main(_):
    """Entry point: evaluate a saved agent against a built-in SC2 bot.

    Plays ``FLAGS.episodes`` episodes on KairosJunction, saving a replay per
    episode, and prints the final reward of each episode.
    """
    players = [
        sc2_env.Agent(sc2_env.Race.zerg),
        sc2_env.Bot(sc2_env.Race.zerg, FLAGS.difficulty)
    ]
    env = SC2BaseEnv(
        players=players,
        agent_interface='feature',
        map_name='KairosJunction',
        max_steps_per_episode=48000,
        screen_resolution=168,
        screen_ratio=0.905,
        step_mul=1,
        version=FLAGS.version,
        replay_dir=FLAGS.replay_dir,
        save_replay_episodes=1,
        use_pysc2_feature=False,
        minimap_resolution=(152, 168),
    )
    interface_cls = import_module_or_data(FLAGS.interface)
    interface_config = read_config_dict(FLAGS.interface_config)
    interface = interface_cls(**interface_config)
    env = EnvIntWrapper(env, [interface])
    obs = env.reset()
    print(env.observation_space.spaces)
    policy = import_module_or_data(FLAGS.policy)
    policy_config = read_config_dict(FLAGS.policy_config)
    agent = PGAgent2(policy,
                     env.observation_space.spaces[0],
                     env.action_space.spaces[0],
                     policy_config=policy_config)
    model_path = FLAGS.model
    model = joblib.load(model_path)
    agent.load_model(model.model)
    agent.reset(obs[0])

    episodes = FLAGS.episodes
    # Renamed from `iter`, which shadowed the builtin.
    n_episodes_done = 0
    sum_rwd = []
    while True:
        while True:
            if obs[0] is not None:
                act = [agent.step(obs[0])]
            else:
                # Observation unavailable this frame; issue an empty action.
                act = [[]]
            obs, rwd, done, info = env.step(act)
            if done:
                print(rwd)
                sum_rwd.append(rwd[0])
                break
        n_episodes_done += 1
        if n_episodes_done >= episodes:
            print(sum_rwd)
            break
        obs = env.reset()
示例#5
0
def main(_):
    """Entry point: configure and run an imitation learner on this worker.

    Picks this worker's spec (``gpu_id:port1:port2``) by horovod rank, then
    starts an ImitationLearner3 with the flag-provided configuration.
    """
    if has_hvd:
        hvd.init()

    # Gather every learner spec (flags may hold comma-joined groups).
    all_learner_specs = []
    for spec_item in FLAGS.learner_spec:
        all_learner_specs.extend(spec_item.split(','))
    this_learner_ind = hvd.rank() if has_hvd else 0
    local_learner_spec = all_learner_specs[this_learner_ind]
    # Spec layout: gpu_id:port1:port2
    first_field, *learner_ports = local_learner_spec.split(':')
    gpu_id = int(first_field)

    replay_converter_type = import_module_or_data(FLAGS.replay_converter)
    converter_config = read_config_dict(FLAGS.converter_config)
    policy = import_module_or_data(FLAGS.policy)
    policy_config = read_config_dict(FLAGS.policy_config)

    post_process_data = None
    if FLAGS.post_process_data is not None:
        post_process_data = import_module_or_data(FLAGS.post_process_data)

    learner = ImitationLearner3(
        ports=learner_ports,
        gpu_id=gpu_id,
        policy=policy.net_build_fun,
        policy_config=policy_config,
        policy_config_type=policy.net_config_cls,
        replay_filelist=FLAGS.replay_filelist,
        batch_size=FLAGS.batch_size,
        min_train_sample_num=FLAGS.min_train_sample_num,
        min_val_sample_num=FLAGS.min_val_sample_num,
        rm_size=FLAGS.rm_size,
        learning_rate=FLAGS.learning_rate,
        print_interval=FLAGS.print_interval,
        replay_converter_type=replay_converter_type,
        converter_config=converter_config,
        checkpoint_interval=FLAGS.checkpoint_interval,
        num_val_batches=FLAGS.num_val_batches,
        checkpoints_dir=FLAGS.checkpoints_dir,
        restore_checkpoint_path=FLAGS.restore_checkpoint_path,
        train_generator_worker_num=FLAGS.train_generator_worker_num,
        val_generator_worker_num=FLAGS.val_generator_worker_num,
        repeat_training_task=FLAGS.repeat_training_task,
        unroll_length=FLAGS.unroll_length,
        rollout_length=FLAGS.rollout_length,
        model_pool_addrs=FLAGS.model_pool_addrs.split(','),
        pub_interval=FLAGS.pub_interval,
        after_loading_init_scope=FLAGS.after_loading_init_scope,
        max_clip_grad_norm=FLAGS.max_clip_grad_norm,
        use_mixed_precision=FLAGS.use_mixed_precision,
        use_sparse_as_dense=FLAGS.use_sparse_as_dense,
        enable_validation=FLAGS.enable_validation,
        post_process_data=post_process_data)
    learner.run()
def main(_):
  """Entry point: configure and run one inference server instance.

  With horovod, each local rank serves on its own port (``FLAGS.port`` minus
  the local rank) and cycles through the per-worker config lists.
  """
  if not has_hvd:
    print('horovod is unavailable, FLAGS.hvd_run will be ignored.')

  index = 0
  if has_hvd and FLAGS.hvd_run:
    hvd.init()
    index = hvd.local_rank()
    print('Horovod initialized.')
  port = FLAGS.port - index
  print('index: {}, using port: {}'.format(index, port), flush=True)

  policy = import_module_or_data(FLAGS.policy)
  # Per-worker configs: each rank cycles through the provided flag lists.
  policy_config = read_config_dict(
    FLAGS.policy_config[index % len(FLAGS.policy_config)])
  infserver_config = read_config_dict(
    FLAGS.infserver_config[index % len(FLAGS.infserver_config)])
  post_process_data = (
    FLAGS.post_process_data[index % len(FLAGS.post_process_data)] or None)

  if FLAGS.is_rl:
    # RL mode: spaces come from the environment definition.
    env_config = read_config_dict(FLAGS.env_config)
    interface_config = read_config_dict(FLAGS.interface_config)
    ob_space, ac_space = env_space(FLAGS.env, env_config, interface_config)
  else:
    # Supervised mode: spaces come from the replay converter.
    replay_converter_type = import_module_or_data(FLAGS.replay_converter)
    converter_config = read_config_dict(FLAGS.converter_config)
    replay_converter = replay_converter_type(**converter_config)
    ob_space, ac_space = replay_converter.space.spaces
    infserver_config.setdefault('model_key', 'IL-model')
  if post_process_data is not None:
    post_process_data = import_module_or_data(post_process_data)
    ob_space, ac_space = post_process_data(ob_space, ac_space)

  nc = policy.net_config_cls(ob_space, ac_space, **policy_config)
  ds = InfData(ob_space, ac_space, nc.use_self_fed_heads, nc.use_lstm,
               nc.hs_len)

  server = InfServer(league_mgr_addr=FLAGS.league_mgr_addr or None,
                     model_pool_addrs=FLAGS.model_pool_addrs.split(','),
                     port=port,
                     ds=ds,
                     batch_size=policy_config['batch_size'],
                     ob_space=ob_space, ac_space=ac_space,
                     policy=policy,
                     policy_config=policy_config,
                     gpu_id=FLAGS.gpu_id,
                     compress=FLAGS.compress,
                     batch_worker_num=FLAGS.batch_worker_num,
                     learner_id=FLAGS.learner_id,
                     **infserver_config)
  server.run()
示例#7
0
def main(_):
    """Entry point: run a replay actor that feeds replay data to a learner.

    When ``FLAGS.reboot_on_failure`` is set, a crashed actor is restarted
    indefinitely, killing stray SC2 processes between attempts; otherwise
    the first exception propagates.
    """
    # Consistent with the other entry points: resolve the converter class
    # via import_module_or_data instead of a hand-rolled importlib lookup.
    replay_converter_type = import_module_or_data(FLAGS.replay_converter)
    converter_config = read_config_dict(FLAGS.converter_config)
    policy = None
    policy_config = None
    model_pool_addrs = None
    agent = None
    post_process_data = None
    if FLAGS.post_process_data is not None:
        post_process_data = import_module_or_data(FLAGS.post_process_data)
    if FLAGS.policy:
        # A policy implies model updates, which require a model pool.
        policy = import_module_or_data(FLAGS.policy)
        policy_config = read_config_dict(FLAGS.policy_config)
        assert FLAGS.model_pool_addrs is not None
        model_pool_addrs = FLAGS.model_pool_addrs.split(',')
        agent = import_module_or_data(FLAGS.agent)
    actor = ReplayActor(learner_addr=FLAGS.learner_addr,
                        replay_dir=FLAGS.replay_dir,
                        replay_converter_type=replay_converter_type,
                        log_interval=FLAGS.log_interval,
                        step_mul=FLAGS.step_mul,
                        n_v=FLAGS.n_v,
                        game_version=FLAGS.game_version,
                        unroll_length=FLAGS.unroll_length,
                        policy=policy,
                        policy_config=policy_config,
                        model_pool_addrs=model_pool_addrs,
                        update_model_freq=FLAGS.update_model_freq,
                        converter_config=converter_config,
                        SC2_bin_root=FLAGS.SC2_bin_root,
                        agent_cls=agent,
                        infserver_addr=FLAGS.infserver_addr or None,
                        compress=FLAGS.compress,
                        post_process_data=post_process_data,
                        da_rate=FLAGS.data_augment_rate,
                        unk_mmr_dft_to=FLAGS.unk_mmr_dft_to)

    n_failures = 0
    while True:
        try:
            actor.run()
        except Exception as e:
            if not FLAGS.reboot_on_failure:
                # Bare `raise` preserves the original traceback
                # (`raise e` would re-anchor it here).
                raise
            print("Actor crushed no. {}, the exception:\n{}".format(
                n_failures, e))
            n_failures += 1
            print("Rebooting...")
            kill_sc2_processes_v2()
示例#8
0
  def __init__(self, model_pool_apis, mutable_hyperparam_type,
               hyperparam_config_name):
    """Initialize the hyperparameter manager.

    Args:
      model_pool_apis: handle(s) for talking to the model pool.
      mutable_hyperparam_type: name of a class defined in hyperparam_types.
      hyperparam_config_name: name of a preset attribute in `configs`, or a
        config string parsed by read_config_dict; None yields an empty config.
    """
    self._model_pool_apis = model_pool_apis
    self._mutable_hyperparam_cls = getattr(hyperparam_types,
                                           mutable_hyperparam_type)
    if hyperparam_config_name is not None:
      try:
        # Prefer a named preset from `configs`.
        self._hyperparam_config = getattr(configs, hyperparam_config_name)
      except Exception:
        # Not a preset; parse it as a config string instead. Narrowed from a
        # bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        self._hyperparam_config = read_config_dict(hyperparam_config_name)
    else:
      self._hyperparam_config = {}

    self.blackboard = HyperparamMgr.Blackboard()
示例#9
0
def main(_):
  """Entry point: start a league manager.

  Uses the parallel variant (PARLeagueMgr) when pseudo learners are
  requested, otherwise the plain LeagueMgr; both share the same base
  configuration, so it is built once.
  """
  game_mgr_config = read_config_dict(FLAGS.game_mgr_config)
  if FLAGS.init_model_paths is not None:
    # SECURITY: eval() on a command-line flag executes arbitrary code.
    # Consider ast.literal_eval if the flag is always a Python literal.
    init_model_paths = eval(FLAGS.init_model_paths)
  else:
    init_model_paths = []

  # Arguments common to both manager variants (previously duplicated).
  common_kwargs = dict(
    port=FLAGS.port,
    model_pool_addrs=FLAGS.model_pool_addrs.split(','),
    mutable_hyperparam_type=FLAGS.mutable_hyperparam_type,
    hyperparam_config_name=FLAGS.hyperparam_config_name or None,
    restore_checkpoint_dir=FLAGS.restore_checkpoint_dir or None,
    save_checkpoint_root=FLAGS.save_checkpoint_root or None,
    save_interval_secs=FLAGS.save_interval_secs,
    mute_actor_msg=FLAGS.mute_actor_msg,
    game_mgr_type=FLAGS.game_mgr_type,
    game_mgr_config=game_mgr_config,
    verbose=FLAGS.verbose,
    init_model_paths=init_model_paths,
    save_learner_meta=FLAGS.save_learner_meta,
  )
  if FLAGS.pseudo_learner_num > 0:
    league_mgr = PARLeagueMgr(pseudo_learner_num=FLAGS.pseudo_learner_num,
                              **common_kwargs)
  else:
    league_mgr = LeagueMgr(**common_kwargs)
  league_mgr.run()