Example no. 1
0
def main(_):
    """Build env/policy from FLAGS, construct the actor, and run it forever.

    The actor class is selected by FLAGS.type. When FLAGS.reboot_on_failure
    is set, any exception from actor.run() is reported and the SC2 processes
    are killed so the loop can reboot the actor; otherwise the exception
    propagates.
    """
    if FLAGS.replay_dir:
        os.makedirs(FLAGS.replay_dir, exist_ok=True)

    env_config = read_config_dict(FLAGS.env_config)
    interface_config = read_config_dict(FLAGS.interface_config)
    env = create_env(FLAGS.env,
                     env_config=env_config,
                     inter_config=interface_config)
    policy = import_module_or_data(FLAGS.policy)
    policy_config = read_config_dict(FLAGS.policy_config)
    distill_policy_config = read_config_dict(FLAGS.distill_policy_config)
    post_process_data = None
    if FLAGS.post_process_data is not None:
        post_process_data = import_module_or_data(FLAGS.post_process_data)

    # Dispatch table: clearer and easier to extend than an if/elif chain.
    actor_classes = {
        'PPO': PPOActor,
        'PPO2': PPO2Actor,
        'VTrace': VtraceActor,
        'DDPG': DDPGActor,
    }
    if FLAGS.type not in actor_classes:
        raise KeyError(f'Not recognized learner type {FLAGS.type}!')
    Actor = actor_classes[FLAGS.type]
    actor = Actor(env,
                  policy,
                  policy_config=policy_config,
                  league_mgr_addr=FLAGS.league_mgr_addr or None,
                  model_pool_addrs=FLAGS.model_pool_addrs.split(','),
                  learner_addr=FLAGS.learner_addr,
                  unroll_length=FLAGS.unroll_length,
                  update_model_freq=FLAGS.update_model_freq,
                  n_v=FLAGS.n_v,
                  verbose=FLAGS.verbose,
                  log_interval_steps=FLAGS.log_interval_steps,
                  rwd_shape=FLAGS.rwd_shape,
                  distillation=FLAGS.distillation,
                  distill_policy_config=distill_policy_config,
                  replay_dir=FLAGS.replay_dir,
                  compress=FLAGS.compress,
                  self_infserver_addr=FLAGS.self_infserver_addr or None,
                  distill_infserver_addr=FLAGS.distill_infserver_addr or None,
                  post_process_data=post_process_data)

    n_failures = 0
    while True:
        try:
            actor.run()
        except Exception as e:
            if not FLAGS.reboot_on_failure:
                # Bare raise preserves the original traceback (raise e rewrites it).
                raise
            print("Actor crashed no. {}, the exception:\n{}".format(
                n_failures, e))
            n_failures += 1
            print("Rebooting...")
            # Best-effort cleanup of stray SC2 processes before the next run.
            kill_sc2_processes_v2()
Example no. 2
0
def main(_):
  """Load a trained model and evaluate it for FLAGS.episodes episodes.

  Controls only the first player; the second player receives a fixed no-op
  action. Optionally renders each step when FLAGS.render is set.
  """
  inter_config = read_config_dict(FLAGS.interface_config)
  env_config = read_config_dict(FLAGS.env_config)
  env_config['replay_dir'] = FLAGS.replay_dir
  env = create_env(FLAGS.env_id, env_config=env_config, inter_config=inter_config)

  policy = import_module_or_data(FLAGS.policy)
  policy_config = read_config_dict(FLAGS.policy_config)
  # Number of value heads; default to a single head when unspecified.
  n_v = policy_config.get('n_v', 1)
  model = joblib.load(FLAGS.model)
  obs = env.reset()
  print(env.observation_space)
  agent = PGAgent(policy, env.observation_space.spaces[0],
                  env.action_space.spaces[0], n_v,
                  policy_config=policy_config, scope_name='model')
  agent.load_model(model.model)

  for _ in range(FLAGS.episodes):
    agent.reset(obs[0])
    sum_rwd = 0
    while True:
      if FLAGS.render:
        env.render()
        time.sleep(0.1)
      # Player 1 follows the loaded policy; player 2 gets a fixed [0, 0] action.
      act = [agent.step(obs[0]), [0, 0]]
      obs, rwd, done, info = env.step(act)
      sum_rwd += np.array(rwd)
      if done:
        if FLAGS.render:
          env.render()
          time.sleep(1)
        print(f'reward sum: {sum_rwd}, info: {info}')
        obs = env.reset()
        break
  print('--------------------------------')
  env.close()
Example no. 3
0
def main(_):
    """Construct a PPO actor for the configured environment and run it."""
    environment = create_env(FLAGS.env)
    # Resolve the dotted path in FLAGS.policy to the policy object.
    module_path, attr_name = FLAGS.policy.rsplit(".", 1)
    policy_obj = getattr(importlib.import_module(module_path), attr_name)
    ppo_actor = PPOActor(
        environment,
        policy_obj,
        league_mgr_addr=FLAGS.league_mgr_addr,
        model_pool_addrs=FLAGS.model_pool_addrs.split(','),
        learner_addr=FLAGS.learner_addr,
        unroll_length=FLAGS.unroll_length,
        update_model_freq=FLAGS.update_model_freq,
        n_v=FLAGS.n_v,
        verbose=FLAGS.verbose,
        log_interval_steps=FLAGS.log_interval_steps,
        rwd_shape=FLAGS.rwd_shape)
    ppo_actor.run()
Example no. 4
0
def main(_):
    """Play one two-player episode with models loaded from FLAGS.

    Player 2 falls back to player 1's policy/model when FLAGS.policy2 /
    FLAGS.model2 are not given.
    """

    def _resolve(dotted_path):
        # Resolve "pkg.module.name" into the attribute it points at.
        mod_name, attr_name = dotted_path.rsplit(".", 1)
        return getattr(importlib.import_module(mod_name), attr_name)

    env = create_env(FLAGS.env)
    obs = env.reset()
    print(env.observation_space.spaces)
    policy1 = _resolve(FLAGS.policy1)
    policy2 = _resolve(FLAGS.policy2) if FLAGS.policy2 else policy1
    policies = [policy1, policy2]
    policy_config = _resolve(FLAGS.policy_config) if FLAGS.policy_config else {}

    agents = []
    for policy, ob_sp, ac_sp, scope in zip(policies,
                                           env.observation_space.spaces,
                                           env.action_space.spaces,
                                           ['p1', 'p2']):
        agents.append(PGAgent(policy,
                              ob_sp,
                              ac_sp,
                              n_v=FLAGS.n_v,
                              scope_name=scope,
                              policy_config=policy_config))

    path1 = FLAGS.model1
    path2 = FLAGS.model2 or path1
    loaded1 = joblib.load(path1)
    loaded2 = joblib.load(path2)
    agents[0].load_model(loaded1.model)
    agents[1].load_model(loaded2.model)
    agents[0].reset(obs[0])
    agents[1].reset(obs[1])

    # SC2-style environments are not rendered even when render() exists.
    skip_render = ('sc2', 'sc2full_formal', 'sc2_unit_rwd_no_micro')
    while True:
        if hasattr(env, 'render') and FLAGS.env not in skip_render:
            env.render()
        act = [agent.step(ob) for agent, ob in zip(agents, obs)]
        obs, rwd, done, info = env.step(act)
        if done:
            print(rwd)
            break
def main(_):
  """Evaluate a single agent against the built-in bot for FLAGS.episodes.

  Collects the final per-episode reward of player 1 and prints the list
  when done.
  """
  env = create_env(FLAGS.env, difficulty=FLAGS.difficulty)
  obs = env.reset()
  print(env.observation_space.spaces)
  policy_module, policy_name = FLAGS.policy1.rsplit(".", 1)
  policy1 = getattr(importlib.import_module(policy_module), policy_name)
  policies = [policy1]
  if FLAGS.policy_config:
    config_module, config_name = FLAGS.policy_config.rsplit(".", 1)
    policy_config = getattr(importlib.import_module(config_module), config_name)
  else:
    policy_config = {}
  agents = [PGAgent(policy, ob_sp, ac_sp, n_v=FLAGS.n_v, scope_name=name,
                    policy_config=policy_config)
            for policy, ob_sp, ac_sp, name in
            zip(policies,
                env.observation_space.spaces,
                env.action_space.spaces,
                ['p1'])]
  model_0 = joblib.load(FLAGS.model1)
  agents[0].load_model(model_0.model)
  agents[0].reset(obs[0])

  # Final reward of player 1 for each episode (rwd is indexable as [0, 0] —
  # presumably a 2-D array; verify against the env's reward shape).
  sum_rwd = []
  # A for-loop replaces the manual counter, which shadowed the builtin `iter`.
  for _ in range(FLAGS.episodes):
    while True:
      if hasattr(env, 'render') and FLAGS.env not in [
          'sc2', 'sc2full_formal', 'sc2vsbot_unit_rwd_no_micro']:
        env.render()
      act = [agent.step(ob) for agent, ob in zip(agents, obs)]
      obs, rwd, done, info = env.step(act)
      if done:
        print(rwd)
        sum_rwd.append(rwd[0, 0])
        obs = env.reset()
        break
  print(sum_rwd)