Example #1
    def test_defeat_roaches(self):
        with sc2_env.SC2Env(
                map_name="DefeatRoaches",
                players=[sc2_env.Agent(sc2_env.Race.terran)],
                agent_interface_format=sc2_env.AgentInterfaceFormat(
                    feature_dimensions=sc2_env.Dimensions(screen=84,
                                                          minimap=64)),
                step_mul=self.step_mul,
                game_steps_per_episode=self.steps * self.step_mul) as env:
            agent = scripted_agent.DefeatRoaches()
            run_loop.run_loop([agent], env, self.steps)

        # Get some points
        self.assertLessEqual(agent.episodes, agent.reward)
        self.assertEqual(agent.steps, self.steps)
Example #2
def make_sc2env(**kwargs):

    agent_format = sc2_env.AgentInterfaceFormat(
        feature_dimensions=sc2_env.Dimensions(
            screen=(64, 64),
            minimap=(64, 64),
        ))

    kwargs['agent_interface_format'] = [agent_format]
    # Drop the legacy size kwargs if they were passed; the sizes now come from
    # the agent_interface_format above.
    kwargs.pop('screen_size_px', None)
    kwargs.pop('minimap_size_px', None)

    env = sc2_env.SC2Env(**kwargs)
    # env = available_actions_printer.AvailableActionsPrinter(env)
    return env
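A minimal usage sketch for the helper above (the map name and step_mul are illustrative; the legacy size kwargs are simply discarded in favour of the fixed 64x64 interface format):

env = make_sc2env(
    map_name="MoveToBeacon",        # illustrative mini-game map
    step_mul=8,
    screen_size_px=(84, 84),        # legacy kwarg, popped by the helper
    minimap_size_px=(64, 64),       # legacy kwarg, popped by the helper
)
env.close()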
Example #3
    def test_collect_mineral_shards_raw(self):
        with sc2_env.SC2Env(
                map_name="CollectMineralShards",
                players=[sc2_env.Agent(sc2_env.Race.terran)],
                agent_interface_format=sc2_env.AgentInterfaceFormat(
                    action_space=sc2_env.ActionSpace.RAW,  # or: use_raw_actions=True,
                    use_raw_units=True),
                step_mul=self.step_mul,
                game_steps_per_episode=self.steps * self.step_mul) as env:
            agent = scripted_agent.CollectMineralShardsRaw()
            run_loop.run_loop([agent], env, self.steps)

        # Get some points
        self.assertLessEqual(agent.episodes, agent.reward)
        self.assertEqual(agent.steps, self.steps)
Example #4
def main():
    FLAGS(sys.argv)
    with sc2_env.SC2Env(
            map_name="DefeatZerglingsAndBanelings",
            step_mul=step_mul,
            visualize=True,
            players=[sc2_env.Agent(sc2_env.Race.terran)],
            agent_interface_format=sc2_env.AgentInterfaceFormat(
                feature_dimensions=sc2_env.Dimensions(screen=64, minimap=64)),
            game_steps_per_episode=steps * step_mul) as env:

        demo_replay = []

        agent = demo_agent.MarineAgent(env=env)
        agent.env = env
        run_loop.run_loop([agent], env, steps)
Example #5
  def test_collect_mineral_shards_feature_units(self):
    with sc2_env.SC2Env(
        map_name="CollectMineralShards",
        agent_interface_format=sc2_env.AgentInterfaceFormat(
            feature_dimensions=sc2_env.Dimensions(
                screen=84,
                minimap=64),
            use_feature_units=True),
        step_mul=self.step_mul,
        game_steps_per_episode=self.steps * self.step_mul) as env:
      agent = scripted_agent.CollectMineralShardsFeatureUnits()
      run_loop.run_loop([agent], env, self.steps)

    # Get some points
    self.assertLessEqual(agent.episodes, agent.reward)
    self.assertEqual(agent.steps, self.steps)
Example #6
def make_sc2env():
    env_args = {
        'agent_interface_format':
        sc2_env.AgentInterfaceFormat(
            feature_dimensions=sc2_env.Dimensions(screen=(MAP_SIZE, MAP_SIZE),
                                                  minimap=(MAP_SIZE,
                                                           MAP_SIZE)),
            rgb_dimensions=sc2_env.Dimensions(
                screen=(RGB_SCREEN_SIZE, RGB_SCREEN_SIZE),
                minimap=(RGB_SCREEN_SIZE, RGB_SCREEN_SIZE),
            ),
            action_space=actions.ActionSpace.FEATURES,
        ),
        'map_name': MAP_NAME,
        'step_mul': 170,  # ~1 action every 10 seconds (a step_mul of 17 is roughly 1 action per second)
    }
    maps_dir = os.path.join(os.path.dirname(__file__), '..', 'maps')
    register_map(maps_dir, env_args['map_name'])
    return sc2_env.SC2Env(**env_args)
Example #7
 def run(self):
     super(Worker, self).run()
     env = sc2_env.SC2Env(
         map_name='Simple64',
         players=[
             sc2_env.Agent(race=sc2_env.Race.terran),
             sc2_env.Bot(difficulty=sc2_env.Difficulty.very_easy,
                         race=sc2_env.Race.random)
         ],
         step_mul=16,
         agent_interface_format=sc2_env.AgentInterfaceFormat(
             use_feature_units=True,
             hide_specific_actions=False,
             feature_dimensions=sc2_env.Dimensions(screen=64, minimap=64)),
         game_steps_per_episode=0,
         visualize=True)
     agent = Agent.Agent(model=self.model)
     observation_spec = env.observation_spec()
     action_spec = env.action_spec()
     agent.setup(observation_spec[0], action_spec[0])
     run_loop.run_loop([agent], env)
Example #8
    def test_defeat_zerglings(self):
        FLAGS(sys.argv)

        with sc2_env.SC2Env(
                map_name="DefeatZerglingsAndBanelings",
                step_mul=self.step_mul,
                visualize=True,
                players=[sc2_env.Agent(sc2_env.Race.terran)],
                agent_interface_format=sc2_env.AgentInterfaceFormat(
                    feature_dimensions=sc2_env.Dimensions(screen=64,
                                                          minimap=64)),
                game_steps_per_episode=self.steps * self.step_mul) as env:
            obs = env.step(actions=[sc2_actions.FunctionCall(_NO_OP, [])])
            # "feature_screen" is the observation key when AgentInterfaceFormat is used.
            player_relative = obs[0].observation["feature_screen"][_PLAYER_RELATIVE]

            # Break Point!!
            print(player_relative)

            agent = random_agent.RandomAgent()
            run_loop.run_loop([agent], env, self.steps)

        self.assertEqual(agent.steps, self.steps)
Example #9
def generate_env(map_name, viz=False):
    """This function returns a new sc2 environment, with 64x64 dimensions.
    Parameters
    ----------
    map_name (String): the name of the sc2 minigame to generate.
    viz (Boolean): activate game visualization.

    Returns
    -------
    env (pysc2.env.sc2_env.SC2Env): the new environment."""
    FLAGS = flags.FLAGS
    FLAGS(['run_sc2'])

    # Note: agent_race, bot_race and difficulty are legacy arguments from older
    # pysc2 releases; pysc2 2.x expects a `players` list instead.
    env = sc2_env.SC2Env(
        agent_race=None,
        bot_race=None,
        difficulty=None,
        map_name=map_name,
        visualize=viz,
        agent_interface_format=sc2_env.AgentInterfaceFormat(
            feature_dimensions=sc2_env.Dimensions(screen=64, minimap=64)))
    return env
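A short usage sketch for generate_env, assuming a standard mini-game map name; the no-op loop below is illustrative only:

from pysc2.lib import actions as sc2_actions

env = generate_env("MoveToBeacon", viz=False)
timesteps = env.reset()
while not timesteps[0].last():
    # Issue a no-op each step; a real agent would pick from env.action_spec().
    timesteps = env.step([sc2_actions.FUNCTIONS.no_op()])
env.close()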
Example #10
def run_thread(agent, map_name, visualize):
  with sc2_env.SC2Env(
    map_name=map_name,
    agent_race=FLAGS.agent_race,
    bot_race=FLAGS.bot_race,
    difficulty=FLAGS.difficulty,
    step_mul=FLAGS.step_mul,
    agent_interface_format=sc2_env.AgentInterfaceFormat(
        feature_dimensions=sc2_env.Dimensions(
            screen=64,
            minimap=64)),
    visualize=visualize) as env:
    env = available_actions_printer.AvailableActionsPrinter(env)

    # Only for a single player!
    replay_buffer = []
    # Note: this run_loop is a project-local generator yielding (recorder, is_done)
    # pairs, not pysc2.env.run_loop.run_loop.
    for recorder, is_done in run_loop([agent], env, MAX_AGENT_STEPS):
      if FLAGS.training:
        replay_buffer.append(recorder)
        if is_done:
          counter = 0
          with LOCK:
            global COUNTER
            COUNTER += 1
            counter = COUNTER
          # Learning rate schedule
          learning_rate = FLAGS.learning_rate * (1 - 0.9 * counter / FLAGS.max_steps)
          agent.update(replay_buffer, FLAGS.discount, learning_rate, counter)
          replay_buffer = []
          if counter % FLAGS.snapshot_step == 1:
            agent.save_model(SNAPSHOT, counter)
          if counter >= FLAGS.max_steps:
            break
      elif is_done:
        obs = recorder[-1].observation
        score = obs["score_cumulative"][0]
        print('Your score is '+str(score)+'!')
    if FLAGS.save_replay:
      env.save_replay(agent.name)
Example #11
def main(args):
    with tf.Session() as sess:

        with sc2_env.SC2Env(
                map_name=args['map_name'],
                agent_interface_format=sc2_env.AgentInterfaceFormat(
                    feature_dimensions=sc2_env.Dimensions(
                        screen=args['screen_size'],
                        minimap=args['minimap_size'])),
                step_mul=args['step_mul'],
                game_steps_per_episode=args['max_episode_step'],
                visualize=True) as env:
            actor = actorNetwork(
                sess,
                args['actor_lr'],
                args['screen_size'],
                args['action_size'],
            )
            critic = criticNetwork(sess, args['critic_lr'],
                                   args['screen_size'], args['gamma'])

            train(sess, env, actor, critic, args)
Example #12
def make_sc2env(render=False, screen_size=RGB_SCREEN_SIZE, map_size=MAP_SIZE):
    rgb_dimensions = None  # None (the default) disables the RGB interface
    if render:
        rgb_dimensions = sc2_env.Dimensions(screen=(screen_size, screen_size),
                                            minimap=(screen_size, screen_size))
    env_args = {
        'agent_interface_format':
        sc2_env.AgentInterfaceFormat(
            feature_dimensions=sc2_env.Dimensions(screen=(map_size, map_size),
                                                  minimap=(map_size,
                                                           map_size)),
            rgb_dimensions=rgb_dimensions,
            action_space=actions.ActionSpace.FEATURES,
        ),
        'map_name': MAP_NAME,
        'step_mul': SIMULATION_STEP_MUL,
    }
    maps_dir = os.path.join(os.path.dirname(__file__), '..', 'maps')
    register_map(maps_dir, env_args['map_name'])
    return sc2_env.SC2Env(**env_args)
Example #13
def default_env_maker(kwargs):
    '''
    :param kwargs: map_name, players, ... mostly the same arguments as SC2Env
    :return: a configured sc2_env.SC2Env instance
    '''
    assert kwargs.get('map_name') is not None

    screen_sz = kwargs.pop('screen_size', DEFAULT_SCREEN_SIZE)
    minimap_sz = kwargs.pop('minimap_size', DEFAULT_MINIMAP_SIZE)
    assert screen_sz == minimap_sz
    if 'agent_interface_format' not in kwargs:
        kwargs['agent_interface_format'] = sc2_env.AgentInterfaceFormat(
            use_feature_units=True,
            use_raw_units=True,
            feature_dimensions=sc2_env.Dimensions(screen=(screen_sz,
                                                          screen_sz),
                                                  minimap=(minimap_sz,
                                                           minimap_sz)))
    if 'visualize' not in kwargs:
        kwargs['visualize'] = False
    if 'step_mul' not in kwargs:
        kwargs['step_mul'] = DEFAULT_STEP_MUL
    return sc2_env.SC2Env(**kwargs)
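A minimal sketch of calling the maker above for a single-agent mini-game (map name and player race are illustrative):

env = default_env_maker({
    'map_name': 'MoveToBeacon',
    'players': [sc2_env.Agent(sc2_env.Race.terran)],
})
timesteps = env.reset()
env.close()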
Example #14
    def test_observation_matches_obs_spec(self):
        with sc2_env.SC2Env(
                map_name="Simple64",
                players=[
                    sc2_env.Agent(sc2_env.Race.random),
                    sc2_env.Bot(sc2_env.Race.random, sc2_env.Difficulty.easy)
                ],
                agent_interface_format=sc2_env.AgentInterfaceFormat(
                    feature_dimensions=sc2_env.Dimensions(
                        screen=(84, 87), minimap=(64, 67)))) as env:

            multiplayer_obs_spec = env.observation_spec()
            self.assertIsInstance(multiplayer_obs_spec, tuple)
            self.assertLen(multiplayer_obs_spec, 1)
            obs_spec = multiplayer_obs_spec[0]

            multiplayer_action_spec = env.action_spec()
            self.assertIsInstance(multiplayer_action_spec, tuple)
            self.assertLen(multiplayer_action_spec, 1)
            action_spec = multiplayer_action_spec[0]

            agent = random_agent.RandomAgent()
            agent.setup(obs_spec, action_spec)

            multiplayer_obs = env.reset()
            agent.reset()
            for _ in range(100):
                self.assertIsInstance(multiplayer_obs, tuple)
                self.assertLen(multiplayer_obs, 1)
                raw_obs = multiplayer_obs[0]
                obs = raw_obs.observation
                self.check_observation_matches_spec(obs, obs_spec)

                act = agent.step(raw_obs)
                multiplayer_act = (act, )
                multiplayer_obs = env.step(multiplayer_act)
Example #15
def main():
    FLAGS(sys.argv)

    logdir = "tensorboard"
    if (FLAGS.algorithm == "deepq"):
        logdir = "tensorboard/zergling/%s/%s_%s_prio%s_duel%s_lr%s/%s" % (
            FLAGS.algorithm, FLAGS.timesteps, FLAGS.exploration_fraction,
            FLAGS.prioritized, FLAGS.dueling, FLAGS.lr, start_time)
    elif (FLAGS.algorithm == "acktr"):
        logdir = "tensorboard/zergling/%s/%s_num%s_lr%s/%s" % (
            FLAGS.algorithm, FLAGS.timesteps, FLAGS.num_cpu, FLAGS.lr,
            start_time)

    if (FLAGS.log == "tensorboard"):
        Logger.DEFAULT \
          = Logger.CURRENT \
          = Logger(dir=None,
                   output_formats=[TensorBoardOutputFormat(logdir)])

    elif (FLAGS.log == "stdout"):
        Logger.DEFAULT \
          = Logger.CURRENT \
          = Logger(dir=None,
                   output_formats=[HumanOutputFormat(sys.stdout)])

    with sc2_env.SC2Env(
            map_name="DefeatZerglingsAndBanelings",
            step_mul=step_mul,
            visualize=True,
            agent_interface_format=sc2_env.AgentInterfaceFormat(
                feature_dimensions=sc2_env.Dimensions(screen=32, minimap=32)),
            game_steps_per_episode=steps * step_mul) as env:

        print(env.observation_spec())
        screen_dim = env.observation_spec()[0]['feature_screen'][1:3]
        print(screen_dim)
Example #16
from pysc2.env import run_loop

from pysc2.env import remote_sc2_env

from raw_agents import ZerglingRush, MacroZerg
from settings import RESOLUTION, STEP_MUL

AGENT = ZerglingRush()
RACE = sc2_env.Race.zerg
AGENT_INTERFACE_FORMAT = sc2_env.AgentInterfaceFormat(
    feature_dimensions=sc2_env.Dimensions(screen=RESOLUTION, minimap=RESOLUTION),
    raw_resolution=RESOLUTION,
    use_feature_units=True,
    use_raw_units=True,
    use_raw_actions=AGENT.raw_interface,
    show_cloaked=True,
    show_burrowed_shadows=True,
    show_placeholders=True,
    add_cargo_to_units=True,
    crop_to_playable_area=True,
    raw_crop_to_playable_area=True,
    send_observation_proto=True)

# Flags
FLAGS = flags.FLAGS
flags.DEFINE_integer("GamePort", None, "GamePort")
flags.DEFINE_integer("StartPort", None, "StartPort")
flags.DEFINE_string("LadderServer", "127.0.0.1", "LadderServer")
flags.DEFINE_string("OpponentId", None, "OpponentId")

# Run ladder game
Example #17
def main():
    FLAGS(sys.argv)

    steps = 0  #Test steps

    print("algorithm : %s" % FLAGS.algorithm)
    print("timesteps : %s" % FLAGS.timesteps)
    print("exploration_fraction : %s" % FLAGS.exploration_fraction)
    print("prioritized : %s" % FLAGS.prioritized)
    print("dueling : %s" % FLAGS.dueling)
    print("num_agents : %s" % FLAGS.num_agents)
    print("lr : %s" % FLAGS.lr)

    if FLAGS.lr == 0:
        FLAGS.lr = random.uniform(0.00001, 0.001)

    print("random lr : %s" % FLAGS.lr)

    lr_round = round(FLAGS.lr, 8)

    logdir = "tensorboard"

    if FLAGS.algorithm == "deepq-4way":
        logdir = "tensorboard/mineral/%s/%s_%s_prio%s_duel%s_lr%s/%s" % (
            FLAGS.algorithm, FLAGS.timesteps, FLAGS.exploration_fraction,
            FLAGS.prioritized, FLAGS.dueling, lr_round, start_time)
    elif FLAGS.algorithm == "deepq":
        logdir = "tensorboard/mineral/%s/%s_%s_prio%s_duel%s_lr%s/%s" % (
            FLAGS.algorithm, FLAGS.timesteps, FLAGS.exploration_fraction,
            FLAGS.prioritized, FLAGS.dueling, lr_round, start_time)
    elif FLAGS.algorithm == "a2c":
        logdir = "tensorboard/mineral/%s/%s_n%s_s%s_nsteps%s/lr%s/%s" % (
            FLAGS.algorithm, FLAGS.timesteps,
            FLAGS.num_agents + FLAGS.num_scripts, FLAGS.num_scripts,
            FLAGS.nsteps, lr_round, start_time)

    if FLAGS.log == "tensorboard":
        Logger.DEFAULT \
          = Logger.CURRENT \
          = Logger(dir=None,
                   output_formats=[TensorBoardOutputFormat(logdir)])

    elif FLAGS.log == "stdout":
        Logger.DEFAULT \
          = Logger.CURRENT \
          = Logger(dir=None,
                   output_formats=[HumanOutputFormat(sys.stdout)])

    if FLAGS.algorithm == "deepq":

        AGENT_INTERFACE_FORMAT = sc2_env.AgentInterfaceFormat(
            feature_dimensions=sc2_env.Dimensions(screen=16, minimap=16))
        # Temporary solution - sc2_env.Agent(sc2_env.Race.terran) might be too restrictive.
        # This change is needed because pysc2 now requires specifying players.
        with sc2_env.SC2Env(
                map_name="Simple64",
                players=[
                    sc2_env.Agent(race=sc2_env.Race.terran),
                    sc2_env.Agent(race=sc2_env.Race.terran)
                ],
                #players=[sc2_env.Agent(sc2_env.Race.terran),sc2_env.Agent(sc2_env.Race.terran)],
                step_mul=step_mul,
                visualize=True,
                agent_interface_format=AGENT_INTERFACE_FORMAT) as env:

            model = cnn_to_mlp(convs=[(16, 8, 4), (32, 4, 2)],
                               hiddens=[256],
                               dueling=True)

            acts = deepq_nexus_wars.learn(
                env,
                q_func=model,
                num_actions=16,
                lr=FLAGS.lr,
                max_timesteps=FLAGS.timesteps,
                buffer_size=10000,
                exploration_fraction=FLAGS.exploration_fraction,
                exploration_final_eps=0.01,
                train_freq=4,
                learning_starts=10000,
                target_network_update_freq=1000,
                gamma=0.99,
                prioritized_replay=True,
                callback=deepq_callback)

            agent = random_agent.RandomAgent()
            run_loop.run_loop([agent], env, steps)

            acts[0].save("mineral_shards_x.pkl")
            acts[1].save("mineral_shards_y.pkl")

    elif FLAGS.algorithm == "deepq-4way":

        AGENT_INTERFACE_FORMAT = sc2_env.AgentInterfaceFormat(
            feature_dimensions=sc2_env.Dimensions(screen=32, minimap=32))
        with sc2_env.SC2Env(map_name="Simple64",
                            players=[
                                sc2_env.Agent(race=sc2_env.Race.terran),
                                sc2_env.Agent(race=sc2_env.Race.terran)
                            ],
                            step_mul=step_mul,
                            agent_interface_format=AGENT_INTERFACE_FORMAT,
                            visualize=True) as env:

            model = cnn_to_mlp(convs=[(16, 8, 4), (32, 4, 2)],
                               hiddens=[256],
                               dueling=True)

            act = deepq_mineral_4way.learn(
                env,
                q_func=model,
                num_actions=4,
                lr=FLAGS.lr,
                max_timesteps=FLAGS.timesteps,
                buffer_size=10000,
                exploration_fraction=FLAGS.exploration_fraction,
                exploration_final_eps=0.01,
                train_freq=4,
                learning_starts=10000,
                target_network_update_freq=1000,
                gamma=0.99,
                prioritized_replay=True,
                callback=deepq_4way_callback)

            act.save("mineral_shards.pkl")

    elif FLAGS.algorithm == "a2c":

        num_timesteps = int(40e6)

        num_timesteps //= 4

        seed = 0

        env = SubprocVecEnv(FLAGS.num_agents + FLAGS.num_scripts,
                            FLAGS.num_scripts, FLAGS.map)

        policy_fn = CnnPolicy
        a2c.learn(policy_fn,
                  env,
                  seed,
                  total_timesteps=num_timesteps,
                  nprocs=FLAGS.num_agents + FLAGS.num_scripts,
                  nscripts=FLAGS.num_scripts,
                  ent_coef=0.5,
                  nsteps=FLAGS.nsteps,
                  max_grad_norm=0.01,
                  callback=a2c_callback)
Example #18
def main(unused_argv):
    rs = FLAGS.random_seed
    if FLAGS.random_seed is None:
        rs = int((time.time() % 1) * 1000000)

    logger.configure(dir=FLAGS.train_log_dir, format_strs=['log'])

    players = []
    players.append(sc2_env.Agent(races[FLAGS.agent_race]))
    players.append(sc2_env.Agent(races[FLAGS.oppo_race]))

    # First screen dimension is scaled by screen_ratio and rounded down to a
    # multiple of 4.
    screen_res = (int(FLAGS.screen_ratio * FLAGS.screen_resolution) // 4 * 4,
                  FLAGS.screen_resolution)
    if FLAGS.agent_interface_format == 'feature':
        agent_interface_format = sc2_env.AgentInterfaceFormat(
            feature_dimensions=sc2_env.Dimensions(
                screen=screen_res, minimap=FLAGS.minimap_resolution))
    elif FLAGS.agent_interface_format == 'rgb':
        agent_interface_format = sc2_env.AgentInterfaceFormat(
            rgb_dimensions=sc2_env.Dimensions(
                screen=screen_res, minimap=FLAGS.minimap_resolution))
    else:
        raise NotImplementedError

    agents = [ZergBotAgent()]

    env = ZergScoutSelfplayEnv(
        agents,
        map_name=FLAGS.map,
        players=players,
        step_mul=FLAGS.step_mul,
        random_seed=rs,
        game_steps_per_episode=FLAGS.max_step,
        agent_interface_format=agent_interface_format,
        score_index=-1,  # this indicates the outcome is reward
        disable_fog=FLAGS.disable_fog,
        visualize=FLAGS.render)

    env = make(FLAGS.wrapper, env)

    network = model(FLAGS.wrapper)  #deepq.models.mlp([64, 32])

    print('params, lr={} bf={} ef={} ef_eps={}'.format(FLAGS.param_lr,
                                                       FLAGS.param_bf,
                                                       FLAGS.param_ef,
                                                       FLAGS.param_efps))

    random_support = False
    total_rwd = 0.0
    act_val = 1
    try:
        obs = env.reset()
        n_step = 0
        # run this episode
        while True:
            n_step += 1
            #print('observation=', obs, 'observation_none=', obs[None])
            action = act_val  #act(obs[None])[0]
            obs, rwd, done, other = env.step(action)
            print('action=', action, '; rwd=', rwd, '; step=', n_step)
            total_rwd += rwd
            if other:
                act_val = 7
            if random_support:
                act_val = random.randint(0, 8)

            if n_step == 50:
                act_val = 3
            '''
            if n_step == 20:
                act_val = 0
            elif n_step == 94:
                act_val = 1
            '''
            #print('step rwd=', rwd, ',action=', action, "obs=", obs)
            if done:
                print("game over, total_rwd=", total_rwd)
                break
    except KeyboardInterrupt:
        pass
    finally:
        print("evaluation over")
    env.unwrapped.save_replay('evaluate')
    env.close()
Example #19
    def __init__(self,
                 level: LevelSelection,
                 frame_skip: int,
                 visualization_parameters: VisualizationParameters,
                 target_success_rate: float = 1.0,
                 seed: Union[None, int] = None,
                 human_control: bool = False,
                 custom_reward_threshold: Union[int, float] = None,
                 screen_size: int = 84,
                 minimap_size: int = 64,
                 feature_minimap_maps_to_use: List = range(7),
                 feature_screen_maps_to_use: List = range(17),
                 observation_type:
                 StarcraftObservationType = StarcraftObservationType.Features,
                 disable_fog: bool = False,
                 auto_select_all_army: bool = True,
                 use_full_action_space: bool = False,
                 **kwargs):
        super().__init__(level, seed, frame_skip, human_control,
                         custom_reward_threshold, visualization_parameters,
                         target_success_rate)

        self.screen_size = screen_size
        self.minimap_size = minimap_size
        self.feature_minimap_maps_to_use = feature_minimap_maps_to_use
        self.feature_screen_maps_to_use = feature_screen_maps_to_use
        self.observation_type = observation_type
        self.features_screen_size = None
        self.feature_minimap_size = None
        self.rgb_screen_size = None
        self.rgb_minimap_size = None
        if self.observation_type == StarcraftObservationType.Features:
            self.features_screen_size = screen_size
            self.feature_minimap_size = minimap_size
        elif self.observation_type == StarcraftObservationType.RGB:
            self.rgb_screen_size = screen_size
            self.rgb_minimap_size = minimap_size
        self.disable_fog = disable_fog
        self.auto_select_all_army = auto_select_all_army
        self.use_full_action_space = use_full_action_space

        # step_mul is the pysc2 equivalent of frame skipping: the game advances
        # step_mul game steps between agent observations. Issued orders keep
        # executing during the skipped steps; actions are not re-sent.
        self.env = sc2_env.SC2Env(
            map_name=self.env_id,
            step_mul=frame_skip,
            visualize=self.is_rendered,
            agent_interface_format=sc2_env.AgentInterfaceFormat(
                feature_dimensions=sc2_env.Dimensions(
                    screen=self.features_screen_size,
                    minimap=self.feature_minimap_size)
                # rgb_dimensions=sc2_env.Dimensions(
                #     screen=self.rgb_screen_size,
                #     minimap=self.rgb_screen_size
                # )
            ),
            # feature_screen_size=self.features_screen_size,
            # feature_minimap_size=self.feature_minimap_size,
            # rgb_screen_size=self.rgb_screen_size,
            # rgb_minimap_size=self.rgb_screen_size,
            disable_fog=disable_fog,
            random_seed=self.seed)

        # print all the available actions
        # self.env = available_actions_printer.AvailableActionsPrinter(self.env)

        self.reset_internal_state(True)
        """
        feature_screen:  [height_map, visibility_map, creep, power, player_id, player_relative, unit_type, selected,
                          unit_hit_points, unit_hit_points_ratio, unit_energy, unit_energy_ratio, unit_shields,
                          unit_shields_ratio, unit_density, unit_density_aa, effects]
        feature_minimap: [height_map, visibility_map, creep, camera, player_id, player_relative, selected]
        player:          [player_id, minerals, vespene, food_cap, food_army, food_workers, idle_worker_count,
                          army_count, warp_gate_count, larva_count]
        """
        self.screen_shape = np.array(
            self.env.observation_spec()[0]['feature_screen'])
        self.screen_shape[0] = len(self.feature_screen_maps_to_use)
        self.minimap_shape = np.array(
            self.env.observation_spec()[0]['feature_minimap'])
        self.minimap_shape[0] = len(self.feature_minimap_maps_to_use)
        self.state_space = StateSpace({
            "screen":
            PlanarMapsObservationSpace(shape=self.screen_shape,
                                       low=0,
                                       high=255,
                                       channels_axis=0),
            "minimap":
            PlanarMapsObservationSpace(shape=self.minimap_shape,
                                       low=0,
                                       high=255,
                                       channels_axis=0),
            "measurements":
            VectorObservationSpace(self.env.observation_spec()[0]["player"][0])
        })
        if self.use_full_action_space:
            action_identifiers = list(self.env.action_spec()[0].functions)
            num_action_identifiers = len(action_identifiers)
            action_arguments = [(arg.name, arg.sizes)
                                for arg in self.env.action_spec()[0].types]
            sub_action_spaces = [DiscreteActionSpace(num_action_identifiers)]
            for argument in action_arguments:
                for dimension in argument[1]:
                    sub_action_spaces.append(DiscreteActionSpace(dimension))
            self.action_space = CompoundActionSpace(sub_action_spaces)
        else:
            self.action_space = BoxActionSpace(2,
                                               0,
                                               self.screen_size - 1,
                                               ["X-Axis, Y-Axis"],
                                               default_action=np.array([
                                                   self.screen_size / 2,
                                                   self.screen_size / 2
                                               ]))

        self.target_success_rate = target_success_rate
Example #20
def main():
    FLAGS(sys.argv)
    AGENT_INTERFACE_FORMAT = sc2_env.AgentInterfaceFormat(
        feature_dimensions=sc2_env.Dimensions(screen=64, minimap=64))
    with sc2_env.SC2Env(map_name="CollectMineralShards",
                        step_mul=step_mul,
                        visualize=True,
                        game_steps_per_episode=steps * step_mul,
                        agent_interface_format=AGENT_INTERFACE_FORMAT) as env:

        model = deepq.models.cnn_to_mlp(convs=[(32, 8, 4), (64, 4, 2),
                                               (64, 3, 1)],
                                        hiddens=[256],
                                        dueling=True)

        def make_obs_ph(name):
            return U_b.BatchInput((64, 64), name=name)  #64 64

        act_params = {
            'make_obs_ph': make_obs_ph,
            'q_func': model,
            'num_actions': 4,
        }

        act = deepq_mineral_shards.load("mineral_shards.pkl",
                                        act_params=act_params)

        while True:

            obs = env.reset()
            episode_rew = 0

            done = False

            step_result = env.step(actions=[
                sc2_actions.FunctionCall(_SELECT_ARMY, [_SELECT_ALL])
            ])

            while not done:

                player_relative = step_result[0].observation["feature_screen"][
                    _PLAYER_RELATIVE]

                obs = player_relative

                player_y, player_x = (
                    player_relative == _PLAYER_FRIENDLY).nonzero()
                player = [int(player_x.mean()), int(player_y.mean())]

                if (player[0] > 32):
                    obs = shift(LEFT, player[0] - 32, obs)
                elif (player[0] < 32):
                    obs = shift(RIGHT, 32 - player[0], obs)

                if (player[1] > 32):
                    obs = shift(UP, player[1] - 32, obs)
                elif (player[1] < 32):
                    obs = shift(DOWN, 32 - player[1], obs)
                action = act(np.array(obs)[None])[0]
                print('action=', action)
                coord = [player[0], player[1]]

                if (action == 0):  #UP

                    if (player[1] >= 16):
                        coord = [player[0], player[1] - 16]
                    elif (player[1] > 0):
                        coord = [player[0], 0]

                elif (action == 1):  #DOWN

                    if (player[1] <= 47):
                        coord = [player[0], player[1] + 16]
                    elif (player[1] > 47):
                        coord = [player[0], 63]

                elif (action == 2):  #LEFT

                    if (player[0] >= 16):
                        coord = [player[0] - 16, player[1]]
                    elif (player[0] < 16):
                        coord = [0, player[1]]

                elif (action == 3):  #RIGHT

                    if (player[0] <= 47):
                        coord = [player[0] + 16, player[1]]
                    elif (player[0] > 47):
                        coord = [63, player[1]]

                new_action = [
                    sc2_actions.FunctionCall(_MOVE_SCREEN,
                                             [_NOT_QUEUED, coord])
                ]

                step_result = env.step(actions=new_action)

                rew = step_result[0].reward
                done = step_result[0].step_type == environment.StepType.LAST

                episode_rew += rew
            print("Episode reward", episode_rew)
Example #21
env_kwargs = {
    # "map_name": "EconomicRLTraining",
    # "map_name": "StalkersVsRoaches",
    # "map_name": "MoveToBeacon",
    "map_name":
    "CollectMineralShards",
    "visualize":
    False,
    "step_mul":
    8,
    'game_steps_per_episode':
    None,
    "agent_interface_format":
    sc2_env.AgentInterfaceFormat(feature_dimensions=sc2_env.Dimensions(
        screen=84, minimap=84),
                                 action_space=actions.ActionSpace.FEATURES,
                                 use_feature_units=True)
}


def main(_):
    env_interface = EmbeddingInterfaceWrapper(BeaconEnvironmentInterface())
    # env_interface = EmbeddingInterfaceWrapper(TrainMarines())
    learner = Learner(10,
                      env_kwargs,
                      env_interface,
                      run_name="MineralWithBeacon2",
                      load_name="Beacon2",
                      load_model=True)
    learner.train()
Example #22
def main(unused_argv):
    rs = FLAGS.random_seed
    if FLAGS.random_seed is None:
        rs = int((time.time() % 1) * 1000000)

    players = []
    players.append(sc2_env.Agent(races[FLAGS.agent_race]))
    players.append(sc2_env.Agent(races[FLAGS.oppo_race]))

    screen_res = (int(FLAGS.screen_ratio * FLAGS.screen_resolution) // 4 * 4,
                  FLAGS.screen_resolution)
    if FLAGS.agent_interface_format == 'feature':
        agent_interface_format = sc2_env.AgentInterfaceFormat(
        feature_dimensions = sc2_env.Dimensions(screen=screen_res,
            minimap=FLAGS.minimap_resolution))
    elif FLAGS.agent_interface_format == 'rgb':
        agent_interface_format = sc2_env.AgentInterfaceFormat(
        rgb_dimensions=sc2_env.Dimensions(screen=screen_res,
            minimap=FLAGS.minimap_resolution))
    else:
        raise NotImplementedError

    agents = [ZergBotAgent()]

    ncpu = 1
    if sys.platform == 'darwin': ncpu //= 2
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=ncpu,
                            inter_op_parallelism_threads=ncpu)
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    tf.Session(config=config).__enter__()

    # env = make_sc2_dis_env(num_env=1, seed=rs, players=players, agent_interface_format=agent_interface_format)
    model_dir = FLAGS.model_dir

    total_rwd = 0

    env = ZergScoutSelfplayEnv(
        agents,
        map_name=FLAGS.map,
        players=players,
        step_mul=FLAGS.step_mul,
        random_seed=rs,
        game_steps_per_episode=FLAGS.max_step,
        agent_interface_format=agent_interface_format,
        score_index=-1,  # this indicates the outcome is reward
        disable_fog=FLAGS.disable_fog,
        visualize=FLAGS.render
    )

    env = make(FLAGS.wrapper, env)
    agent = ppo2.load_model(CnnPolicy,env,model_dir)

    try:
        obs = env.reset()
        state = agent.initial_state
        n_step = 0
        done = False

        # run this episode
        while True:
            n_step += 1
            obs = np.reshape(obs, (1,) + obs.shape)  # convert shape (32,32,20) to (1,32,32,20)
            action, value, state, _ = agent.step(obs, state, done)
            obs, rwd, done, info = env.step(action)
            print('action=', action, '; rwd=', rwd)
            # print('step rwd=', rwd, ',action=', action, "obs=", obs)
            total_rwd += rwd
            if done:
                print("game over, total_rwd=", total_rwd)
                break
    except KeyboardInterrupt:
        pass
    finally:
        print("evaluation over")
    env.unwrapped.save_replay('evaluate')
Example #23
def main():
    FLAGS(sys.argv)
    AGENT_INTERFACE_FORMAT = sc2_env.AgentInterfaceFormat(
        feature_dimensions=sc2_env.Dimensions(screen=16, minimap=16))
    with sc2_env.SC2Env(map_name="CollectMineralShards",
                        players=[sc2_env.Agent(sc2_env.Race.terran)],
                        step_mul=step_mul,
                        visualize=True,
                        agent_interface_format=AGENT_INTERFACE_FORMAT) as env:

        # model = cnn_to_mlp(
        #   convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        #   hiddens=[256],
        #   dueling=True)

        # def make_obs_ph(name):
        #   return BatchInput((1, 64, 64), name=name)

        model = cnn_to_mlp(convs=[(16, 8, 4), (32, 4, 2)],
                           hiddens=[256],
                           dueling=True)

        def make_obs_ph(name):
            return BatchInput((1, 16, 16), name=name)

        # Using deepq_x here instead of deepq for agent x
        act_params = {
            'make_obs_ph': make_obs_ph,
            'q_func': model,
            'num_actions': 16,
            'scope': "deepq_x"
        }

        # This needs to be the saved model for deepq_x
        # You can change the scope to deepq_y for agent y
        act = deepq_mineral_shards.load("mineral_shards.pkl",
                                        act_params=act_params)

        while True:

            obs = env.reset()
            episode_rew = 0

            done = False

            step_result = env.step(actions=[
                sc2_actions.FunctionCall(_SELECT_ARMY, [_SELECT_ALL])
            ])

            while not done:

                player_relative = step_result[0].observation["feature_screen"][
                    _PLAYER_RELATIVE]

                obs = player_relative

                player_y, player_x = (
                    player_relative == _PLAYER_FRIENDLY).nonzero()
                player = [int(player_x.mean()), int(player_y.mean())]

                if (player[0] > 32):
                    obs = shift(LEFT, player[0] - 32, obs)
                elif (player[0] < 32):
                    obs = shift(RIGHT, 32 - player[0], obs)

                if (player[1] > 32):
                    obs = shift(UP, player[1] - 32, obs)
                elif (player[1] < 32):
                    obs = shift(DOWN, 32 - player[1], obs)

                action = act(np.expand_dims(obs[None], axis=0))[0]
                coord = [player[0], player[1]]

                if (action == 0):  #UP

                    if (player[1] >= 16):
                        coord = [player[0], player[1] - 16]
                    elif (player[1] > 0):
                        coord = [player[0], 0]

                elif (action == 1):  #DOWN

                    if (player[1] <= 47):
                        coord = [player[0], player[1] + 16]
                    elif (player[1] > 47):
                        coord = [player[0], 63]

                elif (action == 2):  #LEFT

                    if (player[0] >= 16):
                        coord = [player[0] - 16, player[1]]
                    elif (player[0] < 16):
                        coord = [0, player[1]]

                elif (action == 3):  #RIGHT

                    if (player[0] <= 47):
                        coord = [player[0] + 16, player[1]]
                    elif (player[0] > 47):
                        coord = [63, player[1]]

                new_action = [
                    sc2_actions.FunctionCall(_MOVE_SCREEN,
                                             [_NOT_QUEUED, coord])
                ]

                step_result = env.step(actions=new_action)

                rew = step_result[0].reward
                done = step_result[0].step_type == environment.StepType.LAST

                episode_rew += rew
            print("Episode reward", episode_rew)
Example #24
from Agent import Agent
from pysc2.env import sc2_env
from Util import Global
from pysc2.env.environment import StepType

import sys
from absl import flags
FLAGS = flags.FLAGS
FLAGS(sys.argv)

env = sc2_env.SC2Env(
    map_name="BuildMarines",
    step_mul=8,
    visualize=False,
    agent_interface_format=sc2_env.AgentInterfaceFormat(
        feature_dimensions=sc2_env.Dimensions(
            screen=32,
            minimap=32))
)


agent = Agent(450, "checkpoints_marines_9/")


for episode in range(1000):

    episode_reward = 0
    step = 0
    done = False
    obs = env.reset()[0]
    agent.reset(episode)
Example #25
from absl import flags
import sys

import numpy as np

nb_episodes = 10
nb_max_steps = 2000

FLAGS = flags.FLAGS
FLAGS(sys.argv)
flags.DEFINE_bool("render", False, "Whether to render with pygame.")
flags.DEFINE_float("fps", 1, "Frames per second to run the game.")

agent_format = sc2_env.AgentInterfaceFormat(
    feature_dimensions=sc2_env.Dimensions(
        screen=(32, 32),
        minimap=(32, 32),
    ))

env_names = [
    "DefeatZerglingsAndBanelings", "DefeatRoaches", "CollectMineralShards",
    "MoveToBeacon", "FindAndDefeatZerglings", "BuildMarines",
    "CollectMineralsAndGas"
]


def run(env_name):
    env = sc2_env.SC2Env(
        map_name=env_name,  # "BuildMarines",
        step_mul=16,
        visualize=False,
Example #26
"""Test that stepping without observing works correctly for multiple players."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import absltest

from pysc2.env import sc2_env
from pysc2.lib import actions
from pysc2.tests import utils


AGENT_INTERFACE_FORMAT = sc2_env.AgentInterfaceFormat(
    feature_dimensions=sc2_env.Dimensions(screen=32, minimap=32)
)


class StepWithoutObserveTest(utils.TestCase):

  def test_returns_observation_on_first_step_despite_no_observe(self):
    with sc2_env.SC2Env(
        map_name="DefeatRoaches",
        players=[sc2_env.Agent(sc2_env.Race.random)],
        step_mul=1,
        agent_interface_format=AGENT_INTERFACE_FORMAT) as env:
      timestep = env.step(
          actions=[actions.FUNCTIONS.no_op()],
          update_observation=[False])
Example #27
from pysc2.env import sc2_env, run_loop
from DQN import DQNModel

import sys
from absl import flags

import Agent

FLAGS = flags.FLAGS
FLAGS(sys.argv)

env = sc2_env.SC2Env(map_name='Simple64',
                     players=[
                         sc2_env.Agent(race=sc2_env.Race.terran),
                         sc2_env.Bot(difficulty=sc2_env.Difficulty.very_easy,
                                     race=sc2_env.Race.random)
                     ],
                     step_mul=16,
                     agent_interface_format=sc2_env.AgentInterfaceFormat(
                         use_feature_units=True,
                         hide_specific_actions=False,
                         feature_dimensions=sc2_env.Dimensions(screen=64,
                                                               minimap=64)),
                     game_steps_per_episode=0,
                     visualize=True)
agent = Agent.Agent()
observation_spec = env.observation_spec()
action_spec = env.action_spec()
agent.setup(observation_spec[0], action_spec[0])
run_loop.run_loop([agent], env)
Example #28
def main(unused_argv):
    rs = FLAGS.random_seed
    if FLAGS.random_seed is None:
        rs = int((time.time() % 1) * 1000000)

    logger.configure(dir=FLAGS.train_log_dir, format_strs=['log'])

    players = []
    players.append(sc2_env.Agent(races[FLAGS.agent_race]))
    players.append(sc2_env.Agent(races[FLAGS.oppo_race]))

    screen_res = (int(FLAGS.screen_ratio * FLAGS.screen_resolution) // 4 * 4,
                  FLAGS.screen_resolution)
    if FLAGS.agent_interface_format == 'feature':
        agent_interface_format = sc2_env.AgentInterfaceFormat(
            feature_dimensions=sc2_env.Dimensions(
                screen=screen_res, minimap=FLAGS.minimap_resolution))
    elif FLAGS.agent_interface_format == 'rgb':
        agent_interface_format = sc2_env.AgentInterfaceFormat(
            rgb_dimensions=sc2_env.Dimensions(
                screen=screen_res, minimap=FLAGS.minimap_resolution))
    else:
        raise NotImplementedError

    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin': ncpu //= 2
    config = tf.ConfigProto(
        allow_soft_placement=True,  #log_device_placement=True,
        intra_op_parallelism_threads=ncpu,
        inter_op_parallelism_threads=ncpu)
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    tf.Session(config=config).__enter__()

    param_lam = FLAGS.param_lam
    param_gamma = FLAGS.param_gamma
    param_concurrent = FLAGS.param_concurrent
    param_lr = FLAGS.param_lr
    param_cr = FLAGS.param_cr
    param_tstep = FLAGS.param_tstep
    print('params, lam={} gamma={} concurrent={} lr={} tstep={}'.format(
        param_lam, param_gamma, param_concurrent, param_lr, param_tstep))

    policy_dict = {
        'twoStreamCNNPolicy': twoStreamCNNPolicy,
        'CnnPolicy': CnnPolicy,
        'LnLstmPolicy': LnLstmPolicy,
        'LstmPolicy': LstmPolicy,
        'MlpPolicy': MlpPolicy,
        'CnnVecSplitPolicy': CnnVecSplitPolicy
    }

    obs_dict = {
        'img_obs_shape': (32, 32, 3),
        'vec_obs_shape': (2, ),
        'num_frame_stack': 4
    }

    env = VecFrameStack(
        make_sc2_dis_env(num_env=FLAGS.param_concurrent,
                         seed=rs,
                         players=players,
                         agent_interface_format=agent_interface_format),
        obs_dict['num_frame_stack'])

    if FLAGS.policy == 'twoStreamCNNPolicy':
        ppo2_twoStreamModel.learn(
            policy=policy_dict[FLAGS.policy],
            env=env,
            nsteps=64,
            nminibatches=1,
            lam=param_lam,
            gamma=param_gamma,
            noptepochs=4,
            log_interval=1,
            ent_coef=0.01,
            lr=lambda f: f * param_lr,
            cliprange=lambda f: f * param_cr,
            total_timesteps=param_tstep,
            save_interval=50,
            load_path=None  #FLAGS.model_dir
        )
    elif FLAGS.policy == 'CnnVecSplitPolicy':
        ppo2_v1.learn(
            policy=policy_dict[FLAGS.policy],
            env=env,
            nsteps=64,
            nminibatches=1,
            obs_dict=obs_dict,
            lam=param_lam,
            gamma=param_gamma,
            noptepochs=4,
            log_interval=1,
            ent_coef=0.01,
            lr=lambda f: f * param_lr,
            cliprange=lambda f: f * param_cr,
            total_timesteps=param_tstep,
            save_interval=50,
            load_path=None  #FLAGS.model_dir
        )
    else:
        ppo2.learn(policy=policy_dict[FLAGS.policy],
                   env=env,
                   nsteps=128,
                   nminibatches=1,
                   lam=param_lam,
                   gamma=param_gamma,
                   noptepochs=4,
                   log_interval=1,
                   ent_coef=0.01,
                   lr=lambda f: f * param_lr,
                   cliprange=lambda f: f * param_cr,
                   total_timesteps=param_tstep,
                   save_interval=50)
Example #29
import time
import os
from pysc2.env import sc2_env
from utils import arglist
from copy import deepcopy
import torch.multiprocessing as mp
import numpy as np
import torch
from torch.distributions import Categorical
from pysc2.lib import actions

agent_format = sc2_env.AgentInterfaceFormat(
    feature_dimensions=sc2_env.Dimensions(
        screen=(arglist.FEAT2DSIZE, arglist.FEAT2DSIZE),
        minimap=(arglist.FEAT2DSIZE, arglist.FEAT2DSIZE),
    ))

from networks.acnetwork_a3c import A3CNet


def record(global_ep, global_ep_r, ep_r, res_queue, name):
    with global_ep.get_lock():
        global_ep.value += 1

    with global_ep_r.get_lock():
        if global_ep_r.value == 0.:
            global_ep_r.value = ep_r
        else:
            global_ep_r.value = global_ep_r.value * 0.99 + ep_r * 0.01
        res_queue.put(global_ep_r.value)
Example #30
def main():
    FLAGS(sys.argv)

    print("algorithm : %s" % FLAGS.algorithm)
    print("timesteps : %s" % FLAGS.timesteps)
    print("exploration_fraction : %s" % FLAGS.exploration_fraction)
    print("prioritized : %s" % FLAGS.prioritized)
    print("dueling : %s" % FLAGS.dueling)
    print("num_agents : %s" % FLAGS.num_agents)
    print("lr : %s" % FLAGS.lr)

    if (FLAGS.lr == 0):
        FLAGS.lr = random.uniform(0.00001, 0.001)

    print("random lr : %s" % FLAGS.lr)

    lr_round = round(FLAGS.lr, 8)

    logdir = "tensorboard"

    if (FLAGS.algorithm == "deepq-4way"):
        logdir = "tensorboard/mineral/%s/%s_%s_prio%s_duel%s_lr%s/%s" % (
            FLAGS.algorithm, FLAGS.timesteps, FLAGS.exploration_fraction,
            FLAGS.prioritized, FLAGS.dueling, lr_round, start_time)
    elif (FLAGS.algorithm == "deepq"):
        logdir = "tensorboard/mineral/%s/%s_%s_prio%s_duel%s_lr%s/%s" % (
            FLAGS.algorithm, FLAGS.timesteps, FLAGS.exploration_fraction,
            FLAGS.prioritized, FLAGS.dueling, lr_round, start_time)
    elif (FLAGS.algorithm == "a2c"):
        logdir = "tensorboard/mineral/%s/%s_n%s_s%s_nsteps%s/lr%s/%s" % (
            FLAGS.algorithm, FLAGS.timesteps,
            FLAGS.num_agents + FLAGS.num_scripts, FLAGS.num_scripts,
            FLAGS.nsteps, lr_round, start_time)

    if (FLAGS.log == "tensorboard"):
        Logger.DEFAULT \
          = Logger.CURRENT \
          = Logger(dir=None,
                   output_formats=[TensorBoardOutputFormat(logdir)])

    elif (FLAGS.log == "stdout"):
        Logger.DEFAULT \
          = Logger.CURRENT \
          = Logger(dir=None,
                   output_formats=[HumanOutputFormat(sys.stdout)])

    if (FLAGS.algorithm == "deepq"):

        with sc2_env.SC2Env(
                map_name="CollectMineralShards",
                step_mul=step_mul,
                #screen_size_px=(16, 16),
                #minimap_size_px=(16, 16),
                agent_interface_format=sc2_env.AgentInterfaceFormat(
                    feature_dimensions=sc2_env.Dimensions(screen=16,
                                                          minimap=16)),
                visualize=True) as env:

            model = deepq.models.cnn_to_mlp(convs=[(16, 8, 4), (32, 4, 2)],
                                            hiddens=[256],
                                            dueling=True)

            act = deepq_mineral_shards.learn(
                env,
                q_func=model,
                num_actions=16,
                lr=FLAGS.lr,
                max_timesteps=FLAGS.timesteps,
                buffer_size=10000,
                exploration_fraction=FLAGS.exploration_fraction,
                exploration_final_eps=0.01,
                train_freq=4,
                learning_starts=10000,
                target_network_update_freq=1000,
                gamma=0.99,
                prioritized_replay=True,
                callback=deepq_callback)
            act.save("mineral_shards.pkl")

    elif (FLAGS.algorithm == "deepq-4way"):

        with sc2_env.SC2Env(
                map_name="CollectMineralShards",
                step_mul=step_mul,
                #screen_size_px=(32, 32),
                #minimap_size_px=(32, 32),
                agent_interface_format=sc2_env.AgentInterfaceFormat(
                    feature_dimensions=sc2_env.Dimensions(screen=32,
                                                          minimap=32)),
                visualize=True) as env:

            model = deepq.models.cnn_to_mlp(convs=[(16, 8, 4), (32, 4, 2)],
                                            hiddens=[256],
                                            dueling=True)

            act = deepq_mineral_4way.learn(
                env,
                q_func=model,
                num_actions=4,
                lr=FLAGS.lr,
                max_timesteps=FLAGS.timesteps,
                buffer_size=10000,
                exploration_fraction=FLAGS.exploration_fraction,
                exploration_final_eps=0.01,
                train_freq=4,
                learning_starts=10000,
                target_network_update_freq=1000,
                gamma=0.99,
                prioritized_replay=True,
                callback=deepq_4way_callback)

            act.save("mineral_shards.pkl")

    elif (FLAGS.algorithm == "a2c"):

        num_timesteps = int(40e6)

        num_timesteps //= 4

        seed = 0

        env = SubprocVecEnv(FLAGS.num_agents + FLAGS.num_scripts,
                            FLAGS.num_scripts, FLAGS.map)

        policy_fn = CnnPolicy
        a2c.learn(policy_fn,
                  env,
                  seed,
                  total_timesteps=num_timesteps,
                  nprocs=FLAGS.num_agents + FLAGS.num_scripts,
                  nscripts=FLAGS.num_scripts,
                  ent_coef=0.5,
                  nsteps=FLAGS.nsteps,
                  max_grad_norm=0.01,
                  callback=a2c_callback)