示例#1
0
  def test_collect_mineral_shards(self):
    """Run the scripted CollectMineralShards agent and check that it scores."""
    episode_game_steps = self.steps * self.step_mul
    env_kwargs = dict(step_mul=self.step_mul,
                      game_steps_per_episode=episode_game_steps)
    with sc2_env.SC2Env("CollectMineralShards", **env_kwargs) as env:
      agent = scripted_agent.CollectMineralShards()
      run_loop.run_loop([agent], env, self.steps)

    # Total reward must be at least the number of episodes played.
    self.assertLessEqual(agent.episodes, agent.reward)
    self.assertEqual(agent.steps, self.steps)
示例#2
0
  def test_defeat_roaches(self):
    """The scripted DefeatRoaches agent should earn points on its minigame."""
    with sc2_env.SC2Env("DefeatRoaches",
                        step_mul=self.step_mul,
                        game_steps_per_episode=self.steps * self.step_mul
                       ) as env:
      roach_agent = scripted_agent.DefeatRoaches()
      run_loop.run_loop([roach_agent], env, self.steps)

    # Expect a total reward of at least one point per episode.
    self.assertLessEqual(roach_agent.episodes, roach_agent.reward)
    self.assertEqual(roach_agent.steps, self.steps)
示例#3
0
  def test_random_agent(self):
    """Run a RandomAgent on Simple64 for a fixed number of agent steps."""
    step_mul, steps = 50, 100
    episode_length = steps * step_mul
    with sc2_env.SC2Env("Simple64", step_mul=step_mul,
                        game_steps_per_episode=episode_length) as env:
      agent = random_agent.RandomAgent()
      run_loop.run_loop([agent], env, steps)

    # run_loop should have driven the agent for exactly `steps` steps.
    self.assertEqual(agent.steps, steps)
示例#4
0
  def test_move_to_beacon(self):
    """The scripted MoveToBeacon agent should reach the beacon and score."""
    total_game_steps = self.steps * self.step_mul
    with sc2_env.SC2Env(
        "MoveToBeacon",
        step_mul=self.step_mul,
        game_steps_per_episode=total_game_steps) as env:
      beacon_agent = scripted_agent.MoveToBeacon()
      run_loop.run_loop([beacon_agent], env, self.steps)

    # At least one point per episode, and the full step budget consumed.
    self.assertLessEqual(beacon_agent.episodes, beacon_agent.reward)
    self.assertEqual(beacon_agent.steps, self.steps)
示例#5
0
def main():
  """Run a NOOPAgent on DefeatZerglingsAndBanelings with rendering enabled.

  Parses command-line flags first. `step_mul` and `steps` are presumably
  module-level settings defined elsewhere in this file — verify.
  """
  FLAGS(sys.argv)
  with sc2_env.SC2Env(
      map_name="DefeatZerglingsAndBanelings",
      step_mul=step_mul,
      visualize=True,
      game_steps_per_episode=steps * step_mul) as env:
    agent = noop_agent.NOOPAgent(env=env)
    # Kept from the original: set the attribute explicitly in case the
    # NOOPAgent constructor does not store `env` itself.
    agent.env = env
    run_loop.run_loop([agent], env, steps)
def main():
  """Run the CollectMineralShards agent on a 32x32 screen/minimap.

  Replays are saved every 10 episodes into the 'replay' directory. `steps`
  and `step_mul` are presumably module-level settings defined elsewhere in
  this file — verify.
  """
  with sc2_env.SC2Env(
      map_name="CollectMineralShards",
      step_mul=step_mul,
      visualize=True,
      save_replay_episodes=10,
      replay_dir='replay',
      game_steps_per_episode=steps * step_mul,
      screen_size_px=(32, 32),
      minimap_size_px=(32, 32)) as env:
    agent = CollectMineralShards(env=env)
    run_loop.run_loop([agent], env, steps)
示例#7
0
文件: agent.py 项目: 2kg-jp/pysc2
def run_thread(agent_cls, map_name, visualize):
  """Play `map_name` with one `agent_cls` instance, configured from FLAGS.

  The environment is wrapped in AvailableActionsPrinter. When
  FLAGS.save_replay is set, a replay named after the agent class is saved
  before the environment is closed.
  """
  with sc2_env.SC2Env(
      map_name,
      agent_race=FLAGS.agent_race,
      bot_race=FLAGS.bot_race,
      difficulty=FLAGS.difficulty,
      step_mul=FLAGS.step_mul,
      game_steps_per_episode=FLAGS.game_steps_per_episode,
      screen_size_px=(FLAGS.screen_resolution,) * 2,
      minimap_size_px=(FLAGS.minimap_resolution,) * 2,
      visualize=visualize) as raw_env:
    printing_env = available_actions_printer.AvailableActionsPrinter(raw_env)
    agent = agent_cls()
    run_loop.run_loop([agent], printing_env, FLAGS.max_agent_steps)
    if FLAGS.save_replay:
      printing_env.save_replay(agent_cls.__name__)
示例#8
0
  def test_defeat_zerglings(self):
    """Inspect one no-op observation, then run a RandomAgent on the map.

    The map is DefeatZerglingsAndBanelings. The player_relative screen
    layer is printed for debugging.
    """
    with sc2_env.SC2Env(
        map_name="DefeatZerglingsAndBanelings",
        step_mul=self.step_mul,
        visualize=True,
        game_steps_per_episode=self.steps * self.step_mul) as env:
      noop_call = sc2_actions.FunctionCall(_NO_OP, [])
      first_timesteps = env.step(actions=[noop_call])
      screen_layers = first_timesteps[0].observation["screen"]
      player_relative = screen_layers[_PLAYER_RELATIVE]

      # Debug aid: dump the layer (a convenient spot for a breakpoint).
      print(player_relative)

      agent = random_agent.RandomAgent()
      run_loop.run_loop([agent], env, self.steps)

    self.assertEqual(agent.steps, self.steps)
示例#9
0
def run_thread(agent_classes, players, map_name, visualize):
    """Run one thread's worth of games on `map_name` with the given agents.

    Each class in `agent_classes` is instantiated with the unwrapped
    environment. The run loop then sees the env wrapped in
    AvailableActionsPrinter. When FLAGS.save_replay is set, a replay named
    after the first agent class is saved.
    """
    interface_format = sc2_env.parse_agent_interface_format(
        feature_screen=FLAGS.feature_screen_size,
        feature_minimap=FLAGS.feature_minimap_size,
        rgb_screen=FLAGS.rgb_screen_size,
        rgb_minimap=FLAGS.rgb_minimap_size,
        action_space=FLAGS.action_space,
        use_feature_units=FLAGS.use_feature_units,
        use_raw_units=FLAGS.use_raw_units,
        camera_width_world_units=FLAGS.camera_width)
    env_kwargs = {
        "map_name": map_name,
        "players": players,
        "agent_interface_format": interface_format,
        "step_mul": FLAGS.step_mul,
        "game_steps_per_episode": FLAGS.game_steps_per_episode,
        "disable_fog": FLAGS.disable_fog,
        "visualize": visualize,
    }
    with sc2_env.SC2Env(**env_kwargs) as base_env:
        agents = [cls(base_env) for cls in agent_classes]
        printer_env = available_actions_printer.AvailableActionsPrinter(
            base_env)
        run_loop.run_loop(agents, printer_env, FLAGS.max_agent_steps,
                          FLAGS.max_episodes)
        if FLAGS.save_replay:
            printer_env.save_replay(agent_classes[0].__name__)
示例#10
0
def main(unused_argv):
    """Play Catalyst as zerg against a medium built-in bot until interrupted.

    Args:
      unused_argv: Ignored; presumably supplied by absl's app.run — verify.
    """
    agent1 = ZergAgent()
    try:
        # A fresh environment is built on each pass and handed to run_loop.
        while True:
            with sc2_env.SC2Env(
                    map_name="Catalyst",
                    players=[
                        sc2_env.Agent(sc2_env.Race.zerg),
                        sc2_env.Bot(sc2_env.Race.random,
                                    sc2_env.Difficulty.medium)
                    ],
                    agent_interface_format=features.AgentInterfaceFormat(
                        feature_dimensions=features.Dimensions(screen=84,
                                                               minimap=64),
                        use_feature_units=True),
                    step_mul=16,
                    game_steps_per_episode=0,
                    visualize=False) as env:
                # `run_loop` is used as the pysc2 module everywhere else in
                # this code (including agent() below), so call its run_loop
                # function; calling the module itself raises TypeError.
                run_loop.run_loop([agent1], env)

    except KeyboardInterrupt:
        pass

    # NOTE(review): this nested helper is defined but never called from
    # main(). It looks like a separate "connect to a remote host" entry point;
    # kept unchanged pending confirmation.
    def agent():
        """Run the agent, connecting to a (remote) host started independently."""
        agent_module, agent_name = FLAGS.agent.rsplit(".", 1)
        agent_cls = getattr(importlib.import_module(agent_module), agent_name)

        with lan_sc2_env.LanSC2Env(
                host=FLAGS.host,
                config_port=FLAGS.config_port,
                race=sc2_env.Race[FLAGS.agent_race],
                step_mul=FLAGS.step_mul,
                realtime=FLAGS.realtime,
                agent_interface_format=sc2_env.parse_agent_interface_format(
                    feature_screen=FLAGS.feature_screen_size,
                    feature_minimap=FLAGS.feature_minimap_size,
                    rgb_screen=FLAGS.rgb_screen_size,
                    rgb_minimap=FLAGS.rgb_minimap_size,
                    action_space=FLAGS.action_space,
                    use_unit_counts=True,
                    use_camera_position=True,
                    show_cloaked=True,
                    show_burrowed_shadows=True,
                    show_placeholders=True,
                    send_observation_proto=True,
                    crop_to_playable_area=True,
                    raw_crop_to_playable_area=True,
                    allow_cheating_layers=True,
                    add_cargo_to_units=True,
                    use_feature_units=FLAGS.use_feature_units),
                visualize=FLAGS.render) as env:
            agents = [agent_cls()]
            logging.info("Connected, starting run_loop.")
            try:
                run_loop.run_loop(agents, env)
            except lan_sc2_env.RestartError:
                # A restart request ends this session normally.
                pass
        logging.info("Done.")
def main():
    """Train an agent with the algorithm selected by --algorithm.

    Supported algorithms are "deepq", "deepq-4way" and "a2c". The
    hyper-parameters are printed first. When --lr is 0, a learning rate is
    drawn uniformly between 1e-5 and 1e-3. Training output goes to
    TensorBoard or stdout according to --log.

    Relies on names defined elsewhere in this module (e.g. `start_time`,
    `step_mul`, `deepq_callback`, `deepq_4way_callback`, `a2c_callback`).
    """
    FLAGS(sys.argv)

    # Agent steps for the RandomAgent run after "deepq" training; 0 skips it.
    # pysc2's run_loop treats max_frames=0 as "no limit", so passing 0 there
    # would never return.
    steps = 0

    print("algorithm : %s" % FLAGS.algorithm)
    print("timesteps : %s" % FLAGS.timesteps)
    print("exploration_fraction : %s" % FLAGS.exploration_fraction)
    print("prioritized : %s" % FLAGS.prioritized)
    print("dueling : %s" % FLAGS.dueling)
    print("num_agents : %s" % FLAGS.num_agents)
    print("lr : %s" % FLAGS.lr)

    if FLAGS.lr == 0:
        FLAGS.lr = random.uniform(0.00001, 0.001)

    print("random lr : %s" % FLAGS.lr)

    lr_round = round(FLAGS.lr, 8)

    logdir = "tensorboard"
    if FLAGS.algorithm in ("deepq", "deepq-4way"):
        logdir = "tensorboard/mineral/%s/%s_%s_prio%s_duel%s_lr%s/%s" % (
            FLAGS.algorithm, FLAGS.timesteps, FLAGS.exploration_fraction,
            FLAGS.prioritized, FLAGS.dueling, lr_round, start_time)
    elif FLAGS.algorithm == "a2c":
        logdir = "tensorboard/mineral/%s/%s_n%s_s%s_nsteps%s/lr%s/%s" % (
            FLAGS.algorithm, FLAGS.timesteps,
            FLAGS.num_agents + FLAGS.num_scripts, FLAGS.num_scripts,
            FLAGS.nsteps, lr_round, start_time)

    output_formats = None
    if FLAGS.log == "tensorboard":
        output_formats = [TensorBoardOutputFormat(logdir)]
    elif FLAGS.log == "stdout":
        output_formats = [HumanOutputFormat(sys.stdout)]
    if output_formats is not None:
        # Install a single logger as both the default and the current one.
        Logger.DEFAULT = Logger.CURRENT = Logger(
            dir=None, output_formats=output_formats)

    if FLAGS.algorithm == "deepq":

        agent_interface_format = sc2_env.AgentInterfaceFormat(
            feature_dimensions=sc2_env.Dimensions(screen=16, minimap=16))
        # temp solution - sc2_env.Agent(sc2_env.Race.terran) might be too restricting
        # We need this change because sc2 now requires specifying players.
        with sc2_env.SC2Env(
                map_name="Simple64",
                players=[
                    sc2_env.Agent(race=sc2_env.Race.terran),
                    sc2_env.Agent(race=sc2_env.Race.terran)
                ],
                step_mul=step_mul,
                visualize=True,
                agent_interface_format=agent_interface_format) as env:

            model = cnn_to_mlp(convs=[(16, 8, 4), (32, 4, 2)],
                               hiddens=[256],
                               dueling=True)

            acts = deepq_nexus_wars.learn(
                env,
                q_func=model,
                num_actions=16,
                lr=FLAGS.lr,
                max_timesteps=FLAGS.timesteps,
                buffer_size=10000,
                exploration_fraction=FLAGS.exploration_fraction,
                exploration_final_eps=0.01,
                train_freq=4,
                learning_starts=10000,
                target_network_update_freq=1000,
                gamma=0.99,
                prioritized_replay=True,
                callback=deepq_callback)

            # Save the trained policies before the optional test run, so they
            # are kept even if that run fails or is interrupted.
            acts[0].save("mineral_shards_x.pkl")
            acts[1].save("mineral_shards_y.pkl")

            if steps > 0:
                # NOTE(review): the env declares two Agent players, but only
                # one agent is handed to run_loop — confirm this is intended.
                agent = random_agent.RandomAgent()
                run_loop.run_loop([agent], env, steps)

    elif FLAGS.algorithm == "deepq-4way":

        agent_interface_format = sc2_env.AgentInterfaceFormat(
            feature_dimensions=sc2_env.Dimensions(screen=32, minimap=32))
        with sc2_env.SC2Env(map_name="Simple64",
                            players=[
                                sc2_env.Agent(race=sc2_env.Race.terran),
                                sc2_env.Agent(race=sc2_env.Race.terran)
                            ],
                            step_mul=step_mul,
                            agent_interface_format=agent_interface_format,
                            visualize=True) as env:

            model = cnn_to_mlp(convs=[(16, 8, 4), (32, 4, 2)],
                               hiddens=[256],
                               dueling=True)

            act = deepq_mineral_4way.learn(
                env,
                q_func=model,
                num_actions=4,
                lr=FLAGS.lr,
                max_timesteps=FLAGS.timesteps,
                buffer_size=10000,
                exploration_fraction=FLAGS.exploration_fraction,
                exploration_final_eps=0.01,
                train_freq=4,
                learning_starts=10000,
                target_network_update_freq=1000,
                gamma=0.99,
                prioritized_replay=True,
                callback=deepq_4way_callback)

            act.save("mineral_shards.pkl")

    elif FLAGS.algorithm == "a2c":

        # 40M total timesteps, divided by 4 (10M).
        num_timesteps = int(40e6) // 4

        seed = 0

        env = SubprocVecEnv(FLAGS.num_agents + FLAGS.num_scripts,
                            FLAGS.num_scripts, FLAGS.map)

        policy_fn = CnnPolicy
        a2c.learn(policy_fn,
                  env,
                  seed,
                  total_timesteps=num_timesteps,
                  nprocs=FLAGS.num_agents + FLAGS.num_scripts,
                  nscripts=FLAGS.num_scripts,
                  ent_coef=0.5,
                  nsteps=FLAGS.nsteps,
                  max_grad_norm=0.01,
                  callback=a2c_callback)
示例#13
0
if __name__ == "__main__":
    # Parse absl flags before any pysc2 objects are created.
    FLAGS = flags.FLAGS
    FLAGS(sys.argv)

    agent = Agent()

    try:
        # One terran agent on MoveToBeacon, rendered, for a single episode.
        interface_format = sc2_env.parse_agent_interface_format(
            feature_screen=screen_size,
            feature_minimap=screen_size,
            action_space=None,
            use_feature_units=False,
            use_raw_units=False)
        beacon_players = [sc2_env.Agent(sc2_env.Race.terran)]
        with sc2_env.SC2Env(map_name="MoveToBeacon",
                            players=beacon_players,
                            agent_interface_format=interface_format,
                            step_mul=8,
                            game_steps_per_episode=None,
                            disable_fog=False,
                            visualize=True) as env:
            run_loop.run_loop([agent], env, max_episodes=1)
    except KeyboardInterrupt:
        # Ctrl-C ends the run quietly.
        pass
# if __name__ == "__main__":
#     FLAGS = flags.FLAGS
#     FLAGS(sys.argv)
#
#     agent = Agent(1)
#     for i in range(10):
#         print(agent.test())
示例#14
0
from pysc2.env import sc2_env, run_loop
from DQN import DQNModel

import sys
from absl import flags

import Agent

FLAGS = flags.FLAGS
FLAGS(sys.argv)

# Use the environment as a context manager so that it is closed (via
# __exit__) even if agent setup or the run loop raises. The original never
# closed it. Module-level names (env, agent, observation_spec, action_spec)
# remain bound afterwards, as before.
with sc2_env.SC2Env(map_name='Simple64',
                    players=[
                        sc2_env.Agent(race=sc2_env.Race.terran),
                        sc2_env.Bot(difficulty=sc2_env.Difficulty.very_easy,
                                    race=sc2_env.Race.random)
                    ],
                    step_mul=16,
                    agent_interface_format=sc2_env.AgentInterfaceFormat(
                        use_feature_units=True,
                        hide_specific_actions=False,
                        feature_dimensions=sc2_env.Dimensions(screen=64,
                                                              minimap=64)),
                    game_steps_per_episode=0,
                    visualize=True) as env:
    agent = Agent.Agent()
    observation_spec = env.observation_spec()
    action_spec = env.action_spec()
    # Index 0: the specs for the first (and only Agent) player.
    # NOTE(review): pysc2's run_loop may call agent.setup() again; confirm
    # that a second setup call is harmless for Agent.Agent.
    agent.setup(observation_spec[0], action_spec[0])
    run_loop.run_loop([agent], env)