Example #1
    def simulated(self,
                  state_now,
                  action_now,
                  v_preds_now,
                  dyna_steps,
                  append_to_buffer=True):
        game_state = GameState(dynamic_net=self.dynamic_net, state=state_now)
        sim_buffer = Buffer()
        for _ in range(dyna_steps):
            # simulate next state
            next_game_state = game_state.play(action_now, verbose=False)

            state_last = state_now
            action_last = action_now
            state_now = next_game_state.obs()

            v_preds_last = v_preds_now
            v_preds_now = self.net.policy.get_values(state_now)
            v_preds_now = self.get_values(v_preds_now)

            reward = state_now[1] - state_last[1]

            if append_to_buffer:
                sim_buffer.append(state_last, action_last, state_now, reward,
                                  v_preds_last, v_preds_now)

            action_now, v_preds_now = self.net.policy.get_action(state_now,
                                                                 verbose=False)
            game_state = next_game_state

        #print('sim_buffer:', sim_buffer)
        self.global_buffer.add(sim_buffer, add_return=False)
    def __init__(self,
                 index=0,
                 rl_training=False,
                 restore_model=False,
                 global_buffer=None,
                 net=None,
                 strategy_agent=None,
                 greedy_action=False,
                 extract_save_dir=None,
                 ob_space_add=0,
                 image_debug=False,
                 action_prob_debug=False,
                 act_space_add=0):
        super(MiniSourceAgent, self).__init__()
        self.net = net
        self.index = index
        self.global_buffer = global_buffer
        self.restore_model = restore_model

        # model in brain
        self.strategy_agent = strategy_agent
        self.strategy_act = None

        # step counter
        self.step = 0

        self.strategy_wait_secs = 2
        self.strategy_flag = False
        self.policy_wait_secs = 2
        self.policy_flag = True

        # interval in seconds between debug image displays
        self.image_wait_secs = 1

        # game times (in seconds) at which to print the action probabilities
        self.prob_show_wait_seconds = (np.array([0, 0.5, 1, 3, 5, 10]) *
                                       60).astype(int).tolist()

        self.env = None
        self.obs = None

        # buffer
        self.local_buffer = Buffer()

        self.num_players = 2
        self.on_select = None
        self._result = None
        self._gases = None
        self.is_end = False

        self.greedy_action = greedy_action
        self.rl_training = rl_training

        self.extract_save_dir = extract_save_dir
        self.ob_space_add = ob_space_add
        self.act_space_add = act_space_add

        self.image_debug = image_debug
        self.action_prob_debug = action_prob_debug
        self.action_num = 10 + act_space_add
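For context, the Dyna-style `simulated` rollout above is driven from the agent's play loop; a minimal sketch of that call site (it mirrors the real one in Example #15, where `state_now`, `tech_act`, and `v_preds` come from the preceding policy step):

        # inside the play loop, after the policy has picked an action
        if self.use_dyna:
            # roll the learned dynamics model forward dyna_steps steps
            # and store the imagined transitions in the global buffer
            self.simulated(state_now, tech_act, v_preds, self.dyna_steps)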
Example #3
    def __init__(self, index=0, rl_training=False, restore_model=False, global_buffer=None, net=None, strategy_agent=None, greedy_action=False):
        super(MiniSourceAgent, self).__init__()
        self.net = net
        self.index = index
        self.global_buffer = global_buffer
        self.restore_model = restore_model

        # model in brain
        self.strategy_agent = strategy_agent
        self.strategy_act = None

        # step counter
        self.step = 0

        self.strategy_wait_secs = 2
        self.strategy_flag = False
        self.policy_wait_secs = 2
        self.policy_flag = True

        self.env = None
        self.obs = None

        # buffer
        self.local_buffer = Buffer()

        self.num_players = 2
        self.on_select = None
        self._result = None
        self._gases = None
        self.is_end = False

        self.greedy_action = greedy_action
        self.rl_training = rl_training
Example #4
    def __init__(self, agent_id=0, global_buffer=None, net=None, restore_model=False):
        self.agent_id = agent_id
        self.net = net
        self.global_buffer = global_buffer
        self.greedy_action = False
        self.local_buffer = Buffer()
        self.env = None
        self.restore_model = restore_model

        self.reset()
Example #5
    def __init__(self,
                 index=0,
                 rl_training=False,
                 restore_model=False,
                 restore_internal_model=False,
                 global_buffer=None,
                 net=None,
                 use_mcts=False,
                 num_reads=0,
                 policy_in_mcts=None,
                 dynamic_net=None,
                 use_dyna=False,
                 dyna_steps_fisrt=0,
                 dyna_decrese_counter=0):
        super(MultiAgent, self).__init__()
        self.net = net
        self.index = index
        self.global_buffer = global_buffer
        self.restore_model = restore_model

        self.restore_dynamic = restore_internal_model

        # step counter
        self.step = 0

        self.policy_wait_secs = 2
        self.policy_flag = True

        self.env = None
        self.obs = None

        # buffer
        self.local_buffer = Buffer()

        self.num_players = 2
        self.on_select = None
        self._result = None
        self.is_end = False

        self.rl_training = rl_training

        self.reward_type = 0

        # MCTS-related settings
        self.use_mcts = use_mcts
        self.num_reads = num_reads
        self.policy_in_mcts = policy_in_mcts
        self.dynamic_net = dynamic_net

        # Dyna-related settings
        self.use_dyna = use_dyna
        self.dyna_steps_fisrt = dyna_steps_fisrt
        self.dyna_decrese_counter = dyna_decrese_counter
        self.dyna_steps = dyna_steps_fisrt
Example #6
def Parameter_Server(Synchronizer, cluster, log_path, model_path, procs):
    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
    )
    config.gpu_options.allow_growth = True
    server = tf.train.Server(cluster,
                             job_name="ps",
                             task_index=0,
                             config=config)
    #config.gpu_options.per_process_gpu_memory_fraction = 0.2
    sess = tf.Session(target=server.target, config=config)
    summary_writer = tf.summary.FileWriter(log_path)
    mini_net = MiniNetwork(sess,
                           index=0,
                           summary_writer=summary_writer,
                           rl_training=True,
                           cluster=cluster,
                           ppo_load_path=FLAGS.restore_model_path,
                           ppo_save_path=model_path)
    agent = MiniAgent(agent_id=-1,
                      global_buffer=Buffer(),
                      net=mini_net,
                      restore_model=FLAGS.restore_model)

    print("Parameter server: waiting for cluster connection...")
    sess.run(tf.report_uninitialized_variables())
    print("Parameter server: cluster ready!")

    print("Parameter server: initializing variables...")
    agent.init_network()
    print("Parameter server: variables initialized")

    last_win_rate = 0.

    update_counter = 0
    while update_counter <= TRAIN_ITERS:
        agent.reset_old_network()

        # wait for update
        Synchronizer.wait()
        logging("Update Network!")
        # TODO count the time , compare cpu and gpu
        time.sleep(1)

        # update finish
        Synchronizer.wait()
        logging("Update Network finished!")

        steps, win_rate = agent.update_summary(update_counter)
        logging("Steps: %d, win rate: %f" % (steps, win_rate))

        update_counter += 1
        if win_rate >= last_win_rate:
            agent.save_model()

        last_win_rate = win_rate
    for p in procs:
        print('Process terminate')
        p.terminate()
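Neither `Parameter_Server` nor the `Worker` functions below construct `cluster`; they assume a TF1 between-graph replication setup whose spec is built by the launcher. A minimal sketch of that spec (hostnames and ports are placeholders, not taken from the source):

import tensorflow as tf

# one parameter-server task plus two worker tasks on placeholder addresses;
# job names "ps" and "worker" match the job_name arguments used above
cluster = tf.train.ClusterSpec({
    "ps": ["localhost:2222"],
    "worker": ["localhost:2223", "localhost:2224"],
})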
Example #7
def Worker(index, update_game_num, Synchronizer, cluster, model_path):
    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
    )
    config.gpu_options.allow_growth = True
    worker = tf.train.Server(cluster,
                             job_name="worker",
                             task_index=index,
                             config=config)
    #config.gpu_options.per_process_gpu_memory_fraction = 0.2
    sess = tf.Session(target=worker.target, config=config)

    mini_net = MiniNetwork(sess,
                           index=index,
                           summary_writer=None,
                           rl_training=True,
                           cluster=cluster,
                           ppo_load_path=FLAGS.restore_model_path,
                           ppo_save_path=model_path)
    global_buffer = Buffer()
    agents = []
    for i in range(THREAD_NUM):
        agent = MiniAgent(agent_id=i,
                          global_buffer=global_buffer,
                          net=mini_net,
                          restore_model=FLAGS.restore_model)
        agents.append(agent)

    print("Worker %d: waiting for cluster connection..." % index)
    sess.run(tf.report_uninitialized_variables())
    print("Worker %d: cluster ready!" % index)

    while len(sess.run(tf.report_uninitialized_variables())):
        print("Worker %d: waiting for variable initialization..." % index)
        time.sleep(1)
    print("Worker %d: variables initialized" % index)

    game_num = int(np.ceil(update_game_num / THREAD_NUM))  # true ceiling; '//' would floor before np.ceil

    UPDATE_EVENT.clear()
    ROLLING_EVENT.set()
    difficulty = INITIAL_DIFF

    # Run threads
    threads = []
    for i in range(THREAD_NUM - 1):
        t = threading.Thread(target=run_thread,
                             args=(agents[i], game_num, Synchronizer,
                                   difficulty))
        threads.append(t)
        t.daemon = True
        t.start()
        time.sleep(3)

    run_thread(agents[-1], game_num, Synchronizer, difficulty)

    for t in threads:
        t.join()
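How these processes get launched is outside the snippet. A hypothetical launcher, assuming `Synchronizer` is a `multiprocessing.Barrier` shared by the parameter server and every worker (WORKER_NUM and the argument values are assumptions, not from the source):

from multiprocessing import Barrier, Process

def launch(cluster, log_path, model_path, update_game_num):
    # one barrier slot per worker plus one for the parameter server,
    # so Synchronizer.wait() synchronizes the whole cluster
    synchronizer = Barrier(WORKER_NUM + 1)
    procs = []
    for i in range(WORKER_NUM):
        p = Process(target=Worker,
                    args=(i, update_game_num, synchronizer, cluster, model_path))
        p.daemon = True
        p.start()
        procs.append(p)
    # the parameter-server loop terminates the workers once training ends
    Parameter_Server(synchronizer, cluster, log_path, model_path, procs)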
Example #8
    def __init__(self,
                 agent_id=0,
                 global_buffer=None,
                 net=None,
                 restore_model=False):
        self.env = None
        self.mpc = None
        self.net = net

        self.agent_id = agent_id
        self.player_id = 0

        self.global_buffer = global_buffer
        self.restore_model = restore_model

        self.local_buffer = Buffer()
        self.restart_game()
Example #9
def Worker(index, update_game_num, Synchronizer, cluster, model_path):
    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
    )
    config.gpu_options.allow_growth = True
    worker = tf.train.Server(cluster,
                             job_name="worker",
                             task_index=index,
                             config=config)
    sess = tf.Session(target=worker.target, config=config)
    Net = MiniNetwork(sess=sess,
                      summary_writer=None,
                      rl_training=FLAGS.training,
                      cluster=cluster,
                      index=index,
                      device=DEVICE[index % len(DEVICE)],
                      ppo_load_path=FLAGS.restore_model_path,
                      ppo_save_path=model_path)

    global_buffer = Buffer()
    agents = []
    for i in range(THREAD_NUM):
        agent = mini_source_agent.SourceAgent(
            index=i,
            global_buffer=global_buffer,
            net=Net,
            restore_model=FLAGS.restore_model,
            rl_training=FLAGS.training,
            strategy_agent=None,
            greedy_action=True)
        agents.append(agent)

    print("Worker %d: waiting for cluster connection..." % index)
    sess.run(tf.report_uninitialized_variables())
    print("Worker %d: cluster ready!" % index)

    while len(sess.run(tf.report_uninitialized_variables())):
        print("Worker %d: waiting for variable initialization..." % index)
        time.sleep(1)
    print("Worker %d: variables initialized" % index)

    UPDATE_EVENT.clear()
    ROLLING_EVENT.set()

    # Run threads
    threads = []
    for i in range(THREAD_NUM - 1):
        t = threading.Thread(target=run_thread, args=(agents[i], Synchronizer))
        threads.append(t)
        t.daemon = True
        t.start()
        time.sleep(3)

    run_thread(agents[-1], Synchronizer)

    for t in threads:
        t.join()
Example #10
    def __init__(self,
                 index=0,
                 rl_training=False,
                 restore_model=False,
                 global_buffer=None,
                 net=None,
                 strategy_agent=None):
        super(SourceAgent, self).__init__()
        self.net = net
        self.index = index
        self.global_buffer = global_buffer
        self.restore_model = restore_model

        # model in brain
        self.strategy_agent = strategy_agent
        self.strategy_act = None

        # step counter
        self.step = 0

        self.strategy_wait_secs = 4
        self.strategy_flag = False
        self.policy_wait_secs = 2
        self.policy_flag = True

        self.env = None
        self.obs = None

        # buffer
        self.local_buffer = Buffer()
        self.mini_state = []
        self.mini_state_mapping = []

        self.num_players = 2
        self.on_select = None
        self._result = None
        self.is_end = False
        self.is_attack = False
        self._gases = None

        self.rl_training = rl_training

        self.reward_type = 0
Example #11
def Worker(index, update_game_num, Synchronizer, cluster, log_path, model_path, dynamic_path):
    config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False,
    )
    config.gpu_options.allow_growth = True
    worker = tf.train.Server(cluster, job_name="worker", task_index=index, config=config)
    #config.gpu_options.per_process_gpu_memory_fraction = 0.2
    sess = tf.Session(target=worker.target, config=config)
    Net = HierNetwork(sess=sess, summary_writer=None, rl_training=FLAGS.training,
                      cluster=cluster, index=index, device=DEVICE[index % len(DEVICE)], ppo_save_path=model_path, 
                      ppo_load_path=FLAGS.restore_model_path, dynamic_load_path=FLAGS.restore_dynamic_path)
    policy_in_mcts = PolicyNetinMCTS(Net)
    dynamic_net = Net.dynamic_net
    dynamic_net.restore_sl_model(FLAGS.restore_dynamic_path + "probe")

    global_buffer = Buffer()
    agents = []
    for i in range(THREAD_NUM):
        agent = MultiAgent(index=i, global_buffer=global_buffer, net=Net,
                                       restore_model=FLAGS.restore_model, rl_training=FLAGS.training,
                                       restore_internal_model=FLAGS.restore_dynamic,
                                       use_mcts=FLAGS.use_MCTS, num_reads=NUM_READS,
                                       policy_in_mcts=policy_in_mcts, dynamic_net=dynamic_net,
                                       use_dyna=FLAGS.use_Dyna, dyna_steps_fisrt=FLAGS.Dyna_steps_fisrt,
                                       dyna_decrese_counter=FLAGS.Dyna_decrese_counter)
        agents.append(agent)

    print("Worker %d: waiting for cluster connection..." % index)
    sess.run(tf.report_uninitialized_variables())
    print("Worker %d: cluster ready!" % index)

    while len(sess.run(tf.report_uninitialized_variables())):
        print("Worker %d: waiting for variable initialization..." % index)
        time.sleep(1)
    print("Worker %d: variables initialized" % index)

    game_num = int(np.ceil(update_game_num / THREAD_NUM))  # true ceiling; '//' would floor before np.ceil

    UPDATE_EVENT.clear()
    ROLLING_EVENT.set()

    # Run threads
    threads = []
    for i in range(THREAD_NUM - 1):
        t = threading.Thread(target=run_thread, args=(agents[i], game_num, Synchronizer, FLAGS.difficulty))
        threads.append(t)
        t.daemon = True
        t.start()
        time.sleep(3)

    run_thread(agents[-1], game_num, Synchronizer, FLAGS.difficulty)

    for t in threads:
        t.join()
Example #12
def Worker(index, update_game_num, Synchronizer, cluster, model_path, log_path):
    config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False,
    )
    config.gpu_options.allow_growth = True
    worker = tf.train.Server(cluster, job_name="worker", task_index=index, config=config)
    sess = tf.Session(target=worker.target, config=config)
    summary_writer = tf.summary.FileWriter(log_path)
    Net = MiniNetwork(sess=sess, summary_writer=summary_writer, rl_training=FLAGS.training,
                      cluster=cluster, index=index, device=DEVICE[index % len(DEVICE)],
                      ppo_load_path=FLAGS.restore_model_path, ppo_save_path=model_path, 
                      ob_space_add=FLAGS.ob_space_add, act_space_add=FLAGS.act_space_add, 
                      freeze_head=FLAGS.freeze_head, use_bn=FLAGS.use_bn, 
                      use_sep_net=FLAGS.use_sep_net, restore_model=FLAGS.restore_model,
                      restore_from=FLAGS.restore_from, restore_to=FLAGS.restore_to,
                      load_latest=FLAGS.load_latest, add_image=FLAGS.add_image, partial_restore=FLAGS.partial_restore,
                      weighted_sum_type=FLAGS.weighted_sum_type, initial_type=FLAGS.initial_type)

    global_buffer = Buffer()
    agents = []
    for i in range(THREAD_NUM):
        agent = mini_source_agent.MiniSourceAgent(index=i, global_buffer=global_buffer, net=Net,
                                                  restore_model=FLAGS.restore_model, rl_training=FLAGS.training,
                                                  strategy_agent=None, ob_space_add=FLAGS.ob_space_add)
        agents.append(agent)

    print("Worker %d: waiting for cluster connection..." % index)
    sess.run(tf.report_uninitialized_variables())
    print("Worker %d: cluster ready!" % index)

    while len(sess.run(tf.report_uninitialized_variables())):
        print("Worker %d: waiting for variable initialization..." % index)
        time.sleep(1)
    print("Worker %d: variables initialized" % index)

    game_num = int(np.ceil(update_game_num / THREAD_NUM))  # true ceiling; '//' would floor before np.ceil

    UPDATE_EVENT.clear()
    ROLLING_EVENT.set()

    # Run threads
    threads = []
    for i in range(THREAD_NUM - 1):
        t = threading.Thread(target=run_thread, args=(agents[i], game_num, Synchronizer, FLAGS.difficulty))
        threads.append(t)
        t.daemon = True
        t.start()
        time.sleep(3)

    run_thread(agents[-1], game_num, Synchronizer, FLAGS.difficulty)

    for t in threads:
        t.join()
Example #13
class MiniSourceAgent(base_agent.BaseAgent):
    """Agent for source game of starcraft."""

    def __init__(self, index=0, rl_training=False, restore_model=False, global_buffer=None, net=None, strategy_agent=None, greedy_action=False):
        super(MiniSourceAgent, self).__init__()
        self.net = net
        self.index = index
        self.global_buffer = global_buffer
        self.restore_model = restore_model

        # model in brain
        self.strategy_agent = strategy_agent
        self.strategy_act = None

        # step counter
        self.step = 0

        self.strategy_wait_secs = 2
        self.strategy_flag = False
        self.policy_wait_secs = 2
        self.policy_flag = True

        self.env = None
        self.obs = None

        # buffer
        self.local_buffer = Buffer()

        self.num_players = 2
        self.on_select = None
        self._result = None
        self._gases = None
        self.is_end = False

        self.greedy_action = greedy_action
        self.rl_training = rl_training

    def reset(self):
        super(MiniSourceAgent, self).reset()
        self.step = 0
        self.obs = None
        self._result = None
        self._gases = None
        self.is_end = False

        self.strategy_flag = False
        self.policy_flag = True

        self.local_buffer.reset()

        if self.strategy_agent is not None:
            self.strategy_agent.reset()

    def set_env(self, env):
        self.env = env

    def init_network(self):
        self.net.initialize()
        if self.restore_model:
            self.net.restore_policy()

    def reset_old_network(self):
        self.net.reset_old_network()

    def save_model(self):
        self.net.save_policy()

    def update_network(self, result_list):
        self.net.Update_policy(self.global_buffer)
        self.net.Update_result(result_list)

    def update_summary(self, counter):
        return self.net.Update_summary(counter)

    def mini_step(self, action):
        if action == ProtossAction.Build_probe.value:
            M.mineral_worker(self)

        elif action == ProtossAction.Build_zealot.value:
            M.train_army(self, C._TRAIN_ZEALOT)

        elif action == ProtossAction.Build_Stalker.value:
            M.train_army(self, C._TRAIN_STALKER)

        elif action == ProtossAction.Build_pylon.value:
            no_unit_index = U.get_unit_mask_screen(self.obs, size=2)
            pos = U.get_pos(no_unit_index)
            M.build_by_idle_worker(self, C._BUILD_PYLON_S, pos)

        elif action == ProtossAction.Build_gateway.value:
            power_index = U.get_power_mask_screen(self.obs, size=5)
            pos = U.get_pos(power_index)
            M.build_by_idle_worker(self, C._BUILD_GATEWAY_S, pos)

        elif action == ProtossAction.Build_Assimilator.value:
            if self._gases is not None:
                #U.find_gas_pos(self.obs, 1)
                gas_1 = self._gases[0]
                gas_2 = self._gases[1]

                if gas_1 is not None and not U.is_assimilator_on_gas(self.obs, gas_1):
                    gas_1_pos = T.world_to_screen_pos(self.env.game_info, gas_1.pos, self.obs)
                    M.build_by_idle_worker(self, C._BUILD_ASSIMILATOR_S, gas_1_pos)

                elif gas_2 is not None and not U.is_assimilator_on_gas(self.obs, gas_2):
                    gas_2_pos = T.world_to_screen_pos(self.env.game_info, gas_2.pos, self.obs)
                    M.build_by_idle_worker(self, C._BUILD_ASSIMILATOR_S, gas_2_pos)

        elif action == ProtossAction.Build_CyberneticsCore.value:
            power_index = U.get_power_mask_screen(self.obs, size=3)
            pos = U.get_pos(power_index)
            M.build_by_idle_worker(self, C._BUILD_CYBER_S, pos)

        elif action == ProtossAction.Attack.value:
            M.attack_step(self)

        elif action == ProtossAction.Retreat.value:
            M.retreat_step(self)

        elif action == ProtossAction.Do_nothing.value:
            self.safe_action(C._NO_OP, 0, [])

    def get_the_input(self):
        high_input, tech_cost, pop_num = U.get_input(self.obs)
        controller_input = np.concatenate([high_input, tech_cost, pop_num], axis=0)
        return controller_input

    def mapping_source_to_mini_by_rule(self, source_state):
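        # compress the full source-game observation into the 20-dim state used by the mini-game policy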
        simple_input = np.zeros([20])
        simple_input[0] = 0  # self.time_seconds
        simple_input[1] = source_state[28]  # self.mineral_worker_nums
        simple_input[2] = source_state[30] + source_state[32]  # self.gas_worker_nums
        simple_input[3] = source_state[2]  # self.mineral
        simple_input[4] = source_state[3]  # self.gas
        simple_input[5] = source_state[6]  # self.food_cap
        simple_input[6] = source_state[7]  # self.food_used
        simple_input[7] = source_state[10]  # self.army_nums

        simple_input[8] = source_state[16]  # self.gateway_num
        simple_input[9] = source_state[14]  # self.pylon_num
        simple_input[10] = source_state[15]  # self.Assimilator_num
        simple_input[11] = source_state[17]  # self.CyberneticsCore_num

        simple_input[12] = source_state[12]  # self.zealot_num
        simple_input[13] = source_state[13]  # self.Stalker_num
        simple_input[14] = source_state[11]  # self.probe_num

        simple_input[15] = source_state[4] + source_state[2]  # self.collected_mineral
        simple_input[16] = source_state[4]  # self.spent_mineral
        simple_input[17] = source_state[5] + source_state[3]  # self.collected_gas
        simple_input[18] = source_state[5]  # self.spent_gas
        simple_input[19] = 1  # self.Nexus_num

        return simple_input

    def play(self, verbose=False):
        self.play_train_mini(verbose=verbose)

    def play_train_mini(self, verbose=False):
        is_attack = False
        state_last = None

        self.safe_action(C._NO_OP, 0, [])
        self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
        self._gases = U.find_initial_gases(self.obs)
        while True:

            self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
            if self.policy_flag and (not self.is_end):

                state_now = self.mapping_source_to_mini_by_rule(self.get_the_input())
                if self.greedy_action:
                    action_prob, v_preds = self.net.policy.get_action_probs(state_now, verbose=False)
                    action = np.argmax(action_prob)
                else:
                    action, v_preds = self.net.policy.get_action(state_now, verbose=False)

                # print(ProtossAction(action).name)
                self.mini_step(action)

                if state_last is not None:
                    if verbose:
                        print('state_last:', state_last, ', action_last:', action_last, ', state_now:', state_now)
                    v_preds_next = self.net.policy.get_values(state_now)
                    v_preds_next = self.get_values(v_preds_next)
                    reward = 0
                    self.local_buffer.append(state_last, action_last, state_now, reward, v_preds, v_preds_next)

                # continuous attack, consistent with mind-game
                if action == ProtossAction.Attack.value:
                    is_attack = True
                if is_attack:
                    self.mini_step(ProtossAction.Attack.value)

                state_last = state_now
                action_last = action
                self.policy_flag = False

            if self.is_end:
                if self.rl_training:
                    self.local_buffer.rewards[-1] += 1 * self.result['reward']  # self.result['win']
                    print(self.local_buffer.rewards)
                    self.global_buffer.add(self.local_buffer)
                    print("add %d buffer!" % (len(self.local_buffer.rewards)))
                break

    def set_flag(self):
        if self.step % C.time_wait(self.strategy_wait_secs) == 1:
            self.strategy_flag = True

        if self.step % C.time_wait(self.policy_wait_secs) == 1:
            self.policy_flag = True

    def safe_action(self, action, unit_type, args):
        if M.check_params(self, action, unit_type, args, 1):
            obs = self.env.step([sc2_actions.FunctionCall(action, args)])[0]
            self.obs = obs
            self.step += 1
            self.update_result()
            self.set_flag()

    def select(self, action, unit_type, args):
        # safe select
        if M.check_params(self, action, unit_type, args, 0):
            self.obs = self.env.step([sc2_actions.FunctionCall(action, args)])[0]
            self.on_select = unit_type
            self.update_result()
            self.step += 1
            self.set_flag()

    @property
    def result(self):
        return self._result

    def update_result(self):
        if self.obs is None:
            return
        if self.obs.last() or self.env.state == environment.StepType.LAST:
            self.is_end = True
            outcome = 0
            o = self.obs.raw_observation
            player_id = o.observation.player_common.player_id
            for r in o.player_result:
                if r.player_id == player_id:
                    outcome = sc2_env._possible_results.get(r.result, 0)
            frames = o.observation.game_loop
            result = {}
            result['outcome'] = outcome
            result['reward'] = self.obs.reward
            result['frames'] = frames

            self._result = result
            print('play end, total return', self.obs.reward)
            self.step = 0

    def get_values(self, values):
        # if the game ended with a win or loss, there is no future value to bootstrap
        if self.is_end and self.result['reward'] != 0:
            return 0
        else:
            return values
Example #14
class Agent:
    def __init__(self,
                 agent_id=0,
                 global_buffer=None,
                 net=None,
                 restore_model=False):
        self.env = None
        self.mpc = None
        self.net = net

        self.agent_id = agent_id
        self.player_id = 0

        self.global_buffer = global_buffer
        self.restore_model = restore_model

        self.local_buffer = Buffer()
        self.restart_game()

    def reset(self, pos):
        self.set_pos(pos)
        self.set_army()

        self.local_buffer.reset()
        self.restart_game()

    def restart_game(self):
        self.is_end = False
        self._result = 0

        self.time_seconds = 0
        self.mineral_worker_nums = 12
        self.gas_worker_nums = 0
        self.mineral = 50
        self.gas = 0
        self.food_cap = 14
        self.food_used = 12
        self.army_nums = 0
        self.enemy_army_nums = 0
        self.building_nums = 1
        self.enemy_building_nums = 1
        self.defender_nums = 0
        self.enemy_defender_nums = 0
        self.strategy = StrategyforSC2.RUSH
        self.enemy_strategy = StrategyforSC2.ECONOMY
        self.workers_list = {}
        self.army_list = {}
        self.building_list = {}
        self.remain_buildings_hp = 0

        self.time_per_step = 9

    def obs(self):
        return None

    def init(self, env, player_id, pos):
        self.set_env(env)
        self.set_id(player_id)
        self.set_pos(pos)
        self.set_army()

    def init_network(self):
        self.net.initialize()
        if self.restore_model:
            self.net.restore_policy()

    def update_network(self, result_list):
        self.net.Update_policy(self.global_buffer)
        self.net.Update_result(result_list)

    def reset_old_network(self):
        self.net.reset_old_network()

    def save_model(self):
        self.net.save_policy()

    def update_summary(self, counter):
        return self.net.Update_summary(counter)

    def set_buffer(self, global_buffer):
        self.global_buffer = global_buffer

    def set_env(self, env):
        self.env = env

    def set_net(self, net):
        self.net = net

    def set_mpc(self, mpc):
        self.mpc = mpc

    def set_id(self, player_id):
        self.player_id = player_id

    def set_pos(self, pos):
        self.pos = pos

    def set_army(self):
        army = Army(self.player_id)
        army.pos = self.pos
        self.env.army[self.player_id] = army

    def add_unit(self, unit, num=1, u_type='army'):
        unit_type = type(unit)
        if u_type == 'worker':
            if unit_type in self.workers_list:
                self.workers_list[unit_type] += num
            else:
                self.workers_list[unit_type] = num
        elif u_type == 'army':
            if unit_type in self.army_list:
                self.army_list[unit_type] += num
            else:
                self.army_list[unit_type] = num

    def add_building(self, building, num=1):
        building_type = type(building)
        hp = building.hp
        self.remain_buildings_hp += hp * num
        if building_type in self.building_list.keys():
            self.building_list[building_type] += num
        else:
            self.building_list[building_type] = num

    def building_hp(self):
        return self.remain_buildings_hp

    def under_attack(self, attack_hp):
        self.remain_buildings_hp -= attack_hp

    def military_force(self):
        return self.army_list

    def military_num(self):
        return sum(self.army_list.values())

    def reset_military(self, remain_hp):
        all_hp = 0
        remain_creatures_list = {}
        for key, value in self.army_list.items():
            unit_type = key
            number = value
            unit = unit_type()
            count = 0
            for _ in range(number):
                all_hp += unit.hp
                if all_hp <= remain_hp:
                    count += 1
                else:
                    break
            remain_creatures_list[unit_type] = count
            if all_hp >= remain_hp:
                break
        # print(remain_hp)
        # print(self.army_list)
        # print(remain_creatures_list)
        self.army_list = remain_creatures_list
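    # A quick worked example of the hp-truncation above: assuming a unit type
    # with 100 hp, an army_list of {UnitType: 3} and remain_hp=150 keeps exactly
    # one whole unit, since the cumulative hp first exceeds 150 at the second unit.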
Example #15
class MultiAgent(base_agent.BaseAgent):
    """My first agent for starcraft."""
    def __init__(self,
                 index=0,
                 rl_training=False,
                 restore_model=False,
                 restore_internal_model=False,
                 global_buffer=None,
                 net=None,
                 use_mcts=False,
                 num_reads=0,
                 policy_in_mcts=None,
                 dynamic_net=None,
                 use_dyna=False,
                 dyna_steps_fisrt=0,
                 dyna_decrese_counter=0):
        super(MultiAgent, self).__init__()
        self.net = net
        self.index = index
        self.global_buffer = global_buffer
        self.restore_model = restore_model

        self.restore_dynamic = restore_internal_model

        # step counter
        self.step = 0

        self.policy_wait_secs = 2
        self.policy_flag = True

        self.env = None
        self.obs = None

        # buffer
        self.local_buffer = Buffer()

        self.num_players = 2
        self.on_select = None
        self._result = None
        self.is_end = False

        self.rl_training = rl_training

        self.reward_type = 0

        # MCTS-related settings
        self.use_mcts = use_mcts
        self.num_reads = num_reads
        self.policy_in_mcts = policy_in_mcts
        self.dynamic_net = dynamic_net

        # Dyna-related settings
        self.use_dyna = use_dyna
        self.dyna_steps_fisrt = dyna_steps_fisrt
        self.dyna_decrese_counter = dyna_decrese_counter
        self.dyna_steps = dyna_steps_fisrt

    def reset(self):
        super(MultiAgent, self).reset()
        self.step = 0
        self.obs = None
        self._result = None
        self.is_end = False

        self.policy_flag = True

        self.local_buffer.reset()

    def set_env(self, env):
        self.env = env

    def init_network(self):
        self.net.initialize()
        if self.restore_model:
            self.net.restore_policy()
        if self.restore_dynamic:
            # print('self.net.restore_dynamic()')
            self.net.restore_dynamic("")
            # self.dynamic_net.restore_sl_model("")

    def reset_old_network(self):
        self.net.reset_old_network()

    def save_model(self):
        self.net.save_policy()

    def update_network(self, result_list):
        self.net.Update_policy(self.global_buffer)
        # self.net.Update_internal_model(self.global_buffer)
        self.net.Update_result(result_list)
        # self.update_policy_in_mcts()

    def update_policy_in_mcts(self):
        values = self.global_buffer.values
        values_array = np.array(values).astype(dtype=np.float32).reshape(-1)
        print('values_array:', values_array)
        min_v = np.min(values_array)
        print('min_v:', min_v)
        max_v = np.max(values_array)
        print('max_v:', max_v)
        self.policy_in_mcts.update_min_max_v(min_v, max_v)

        mean_v = np.mean(values_array)
        print('mean_v:', mean_v)
        std_v = np.std(values_array)
        print('std_v:', std_v)
        self.policy_in_mcts.update_mean_std_v(mean_v, std_v)

    def update_summary(self, counter):
        self.net.Update_summary(counter)
        #self.global_update_count = counter
        # every some global_update_count dyna_step-1
        # if self.use_dyna:
        #    self.dyna_steps = 5 - self.global_update_count // 20
        #logging("global_update_count: %d, dyna_steps: %d" % (self.global_update_count, self.dyna_steps))

    def get_policy_input(self, obs):
        high_input, tech_cost, pop_num = U.get_input(obs)
        policy_input = np.concatenate([high_input, tech_cost, pop_num], axis=0)
        return policy_input

    def tech_step(self, tech_action):
        if tech_action == 0:  # nothing
            self.safe_action(C._NO_OP, 0, [])
        elif tech_action == 1:  # worker
            M.mineral_worker(self)
        elif tech_action == 2:  # pylon
            no_unit_index = U.get_unit_mask_screen(self.obs, size=2)
            pos = U.get_pos(no_unit_index)
            M.build_by_idle_worker(self, C._BUILD_PYLON_S, pos)

    def get_simple_state(self, obs):
        simple_state = U.get_simple_state(obs)
        return simple_state

    def set_dyna_steps(self):
        global_steps = self.net.get_global_steps()
        # decrease dyna_steps by one every dyna_decrese_counter global updates
        self.dyna_steps = max(
            self.dyna_steps_fisrt - global_steps // self.dyna_decrese_counter,
            0)
        logging("global_update_count: %d, dyna_steps: %d" %
                (global_steps, self.dyna_steps))

    def play(self, verbose=False):
        M.set_source(self)

        if self.use_dyna:
            self.set_dyna_steps()

        tech_act, v_preds = np.zeros(2)
        last_obs, state_last = None, None
        action_last, state_now = None, None
        step = 0

        while True:
            self.safe_action(C._NO_OP, 0, [])

            # only one second do one thing
            if self.policy_flag:
                now_obs = self.obs
                state_now = self.get_simple_state(now_obs)

                # transition: (state_last, action_last) -> state_now
                if last_obs:
                    #rule_state_diff = self.predict_state_diff_by_rule(state_last, action_last)
                    #print('state_last:', state_last, ', action_last:', action_last)
                    #print('rule_state_diff:', rule_state_diff, 'state_diff:', state_now - state_last)
                    if verbose:
                        print('state_last:', state_last, ', action_last:',
                              action_last, ', state_now:', state_now)
                    # add data to buffer
                    reward = self.get_mineral_reward(last_obs, now_obs)
                    if self.reward_type == 0:
                        reward = 0
                    if verbose:
                        print("reward: ", reward)
                    v_preds_next = self.net.policy.get_values(state_now)
                    v_preds_next = self.get_values(v_preds_next)
                    self.local_buffer.append(state_last, action_last,
                                             state_now, reward, v_preds,
                                             v_preds_next)

                # predict action
                tech_act, v_preds = self.net.policy.get_action(state_now,
                                                               verbose=False)

                # print mcts choose action
                if self.use_mcts:
                    game_state = GameState(dynamic_net=self.dynamic_net,
                                           state=state_now)
                    mcts_act = UCT_search(game_state=game_state,
                                          num_reads=self.num_reads,
                                          policy_in_mcts=self.policy_in_mcts)
                    if verbose:
                        #print('state_now:', state_now)
                        print('mcts_act: ', mcts_act)
                        print('\n')
                    tech_act = mcts_act[0]

                # use dyna to add predicted trace
                if self.use_dyna:
                    self.simulated(state_now, tech_act, v_preds,
                                   self.dyna_steps)

                # do action
                self.tech_step(tech_act)
                # bookkeeping for the next iteration
                step += 1
                last_obs = now_obs
                state_last = state_now
                action_last = tech_act
                self.policy_flag = False

            if self.is_end:
                if self.rl_training:
                    if self.reward_type == 0:
                        final_mineral = now_obs.raw_observation.observation.player_common.minerals
                        self.local_buffer.rewards[-1] += final_mineral
                        print('final_mineral:', final_mineral)
                        if verbose:
                            print('final_reward:',
                                  self.local_buffer.rewards[-1])
                    self.global_buffer.add(self.local_buffer)
                break

    def simulated(self,
                  state_now,
                  action_now,
                  v_preds_now,
                  dyna_steps,
                  append_to_buffer=True):
        game_state = GameState(dynamic_net=self.dynamic_net, state=state_now)
        sim_buffer = Buffer()
        for _ in range(dyna_steps):
            # simulate next state
            next_game_state = game_state.play(action_now, verbose=False)

            state_last = state_now
            action_last = action_now
            state_now = next_game_state.obs()

            v_preds_last = v_preds_now
            v_preds_now = self.net.policy.get_values(state_now)
            v_preds_now = self.get_values(v_preds_now)

            reward = state_now[1] - state_last[1]

            if append_to_buffer:
                sim_buffer.append(state_last, action_last, state_now, reward,
                                  v_preds_last, v_preds_now)

            action_now, v_preds_now = self.net.policy.get_action(state_now,
                                                                 verbose=False)
            game_state = next_game_state

        #print('sim_buffer:', sim_buffer)
        self.global_buffer.add(sim_buffer, add_return=False)

    def get_mineral_reward(self, old_obs, now_obs):
        state_last = self.get_simple_state(old_obs)
        state_now = self.get_simple_state(now_obs)

        mineral_reward = state_now[1] - state_last[1]
        return mineral_reward

    def set_flag(self):
        if self.step % C.time_wait(self.policy_wait_secs) == 1:
            self.policy_flag = True

    def safe_action(self, action, unit_type, args):
        if M.check_params(self, action, unit_type, args, 1):
            obs = self.env.step([sc2_actions.FunctionCall(action, args)])[0]
            self.obs = obs
            self.step += 1
            self.update_result()
            self.set_flag()

    def select(self, action, unit_type, args):
        # safe select
        if M.check_params(self, action, unit_type, args, 0):
            self.obs = self.env.step([sc2_actions.FunctionCall(action,
                                                               args)])[0]
            self.on_select = unit_type
            self.update_result()
            self.step += 1
            self.set_flag()

        # else:
        # print('Unavailable_actions id:', action, ' and type:', unit_type, ' and args:', args)

    @property
    def result(self):
        return self._result

    def update_result(self):
        if self.obs is None:
            return
        if self.obs.last() or self.env.state == environment.StepType.LAST:
            self.is_end = True
            outcome = 0
            o = self.obs.raw_observation
            player_id = o.observation.player_common.player_id
            for r in o.player_result:
                if r.player_id == player_id:
                    outcome = sc2_env._possible_results.get(r.result, 0)
            frames = o.observation.game_loop
            result = {}
            result['outcome'] = outcome
            result['reward'] = self.obs.reward
            result['frames'] = frames
            self._result = result
            # print('play end, total return', self.obs.reward)

    def get_values(self, values):
        # if the game ended with a win or loss, there is no future value to bootstrap
        if self.is_end and self.result['reward'] != 0:
            return 0
        else:
            return values
class MiniSourceAgent(base_agent.BaseAgent):
    """Agent for source game of starcraft."""
    def __init__(self,
                 index=0,
                 rl_training=False,
                 restore_model=False,
                 global_buffer=None,
                 net=None,
                 sec_net=None,
                 strategy_agent=None,
                 greedy_action=False,
                 extract_save_dir=None):
        super(MiniSourceAgent, self).__init__()
        self.net = net
        self.sec_net = sec_net
        self.index = index
        self.global_buffer = global_buffer
        self.restore_model = restore_model

        # step counter
        self.step = 0

        self.strategy_wait_secs = 2
        self.strategy_flag = False
        self.policy_wait_secs = 2
        self.policy_flag = True

        self.env = None
        self.obs = None

        # buffer
        self.local_buffer = Buffer()

        self.num_players = 2
        self.on_select = None
        self._result = None
        self._gases = None
        self.is_end = False

        self.rl_training = rl_training

        self.rnn_state = None
        self.zero_state = self.sec_net.rnn_init_state()

        self.extract_save_dir = extract_save_dir
        self.reset()

    def reset(self):
        super(MiniSourceAgent, self).reset()
        self.step = 0
        self.obs = None
        self._result = None
        self._gases = None
        self.is_end = False

        self.strategy_flag = False
        self.policy_flag = True

        self.local_buffer.reset()

        self.rnn_state = self.zero_state

    def set_env(self, env):
        self.env = env

    def init_network(self):
        self.net.initialize()
        if self.restore_model:
            self.net.restore_policy()

    def reset_old_network(self):
        self.net.reset_old_network()

    def save_model(self):
        self.net.save_policy()

    def update_policy(self):
        self.net.Update_policy(self.global_buffer)

    def update_result(self, result_list):
        self.net.update_result(result_list)

    def update_network(self, result_list):
        self.net.Update_policy(self.global_buffer)
        self.net.Update_result(result_list)

    def update_summary(self, counter):
        return self.net.Update_summary(counter)

    def mini_step(self, action):
        if action == ProtossAction.Build_probe.value:
            M.mineral_worker(self)

        elif action == ProtossAction.Build_zealot.value:
            M.train_army(self, C._TRAIN_ZEALOT)

        elif action == ProtossAction.Build_Stalker.value:
            M.train_army(self, C._TRAIN_STALKER)

        elif action == ProtossAction.Build_pylon.value:
            no_unit_index = U.get_unit_mask_screen(self.obs, size=2)
            pos = U.get_pos(no_unit_index)
            M.build_by_idle_worker(self, C._BUILD_PYLON_S, pos)

        elif action == ProtossAction.Build_gateway.value:
            power_index = U.get_power_mask_screen(self.obs, size=5)
            pos = U.get_pos(power_index)
            M.build_by_idle_worker(self, C._BUILD_GATEWAY_S, pos)

        elif action == ProtossAction.Build_Assimilator.value:
            if self._gases is not None:
                #U.find_gas_pos(self.obs, 1)
                gas_1 = self._gases[0]
                gas_2 = self._gases[1]

                if gas_1 is not None and not U.is_assimilator_on_gas(
                        self.obs, gas_1):
                    gas_1_pos = T.world_to_screen_pos(self.env.game_info,
                                                      gas_1.pos, self.obs)
                    M.build_by_idle_worker(self, C._BUILD_ASSIMILATOR_S,
                                           gas_1_pos)

                elif gas_2 is not None and not U.is_assimilator_on_gas(
                        self.obs, gas_2):
                    gas_2_pos = T.world_to_screen_pos(self.env.game_info,
                                                      gas_2.pos, self.obs)
                    M.build_by_idle_worker(self, C._BUILD_ASSIMILATOR_S,
                                           gas_2_pos)

        elif action == ProtossAction.Build_CyberneticsCore.value:
            power_index = U.get_power_mask_screen(self.obs, size=3)
            pos = U.get_pos(power_index)
            M.build_by_idle_worker(self, C._BUILD_CYBER_S, pos)

        elif action == ProtossAction.Attack.value:
            M.attack_step(self)

        elif action == ProtossAction.Retreat.value:
            M.retreat_step(self)

        elif action == ProtossAction.Do_nothing.value:
            self.safe_action(C._NO_OP, 0, [])

    def get_the_input(self):
        high_input, tech_cost, pop_num = U.get_input(self.obs)
        controller_input = np.concatenate([high_input, tech_cost, pop_num],
                                          axis=0)
        return controller_input

    def mapping_source_to_mini_by_rule(self, source_state):
        simple_input = np.zeros([20], dtype=np.int16)
        simple_input[0] = 0  # self.time_seconds
        simple_input[1] = source_state[28]  # self.mineral_worker_nums
        simple_input[2] = source_state[30] + source_state[32]  # self.gas_worker_nums
        simple_input[3] = source_state[2]  # self.mineral
        simple_input[4] = source_state[3]  # self.gas
        simple_input[5] = source_state[6]  # self.food_cap
        simple_input[6] = source_state[7]  # self.food_used
        simple_input[7] = source_state[10]  # self.army_nums

        simple_input[8] = source_state[16]  # self.gateway_num
        simple_input[9] = source_state[14]  # self.pylon_num
        simple_input[10] = source_state[15]  # self.Assimilator_num
        simple_input[11] = source_state[17]  # self.CyberneticsCore_num

        simple_input[12] = source_state[12]  # self.zealot_num
        simple_input[13] = source_state[13]  # self.Stalker_num
        simple_input[14] = source_state[11]  # self.probe_num

        simple_input[15] = source_state[4] + source_state[2]  # self.collected_mineral
        simple_input[16] = source_state[4]  # self.spent_mineral
        simple_input[17] = source_state[5] + source_state[3]  # self.collected_gas
        simple_input[18] = source_state[5]  # self.spent_gas
        simple_input[19] = 1  # self.Nexus_num

        return simple_input

    def play(self, verbose=False):
        self.play_train(verbose=verbose)

    def sample(self, verbose=False, use_image=True):
        is_attack = False
        state_last = None

        random_generated_int = random.randint(0, 2**31 - 1)
        filename = self.extract_save_dir + "/" + str(
            random_generated_int) + ".npz"

        recording_obs = []
        recording_img = []
        recording_action = []
        recording_reward = []

        np.random.seed(random_generated_int)
        tf.set_random_seed(random_generated_int)

        self.safe_action(C._NO_OP, 0, [])
        self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
        self._gases = U.find_initial_gases(self.obs)
        while True:

            self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
            if self.policy_flag and (not self.is_end):

                non_image_feature = self.mapping_source_to_mini_by_rule(
                    self.get_the_input())
                #print('non_image_feature.shape:', non_image_feature.shape)
                #print('non_image_feature:', non_image_feature)

                image_feature = U.get_simple_map_data(self.obs)
                #print('image_feature.shape:', image_feature.shape)
                #print('image_feature:', image_feature)

                latent_image_feature, mu, logvar = self.encode_obs(
                    image_feature)
                #print('latent_image_feature.shape:', latent_image_feature.shape)
                #print('latent_image_feature:', latent_image_feature)

                feature = np.concatenate(
                    [non_image_feature, latent_image_feature], axis=-1)
                #print('feature.shape:', feature.shape)
                #print('feature:', feature)

                #state_now = feature
                reward_last = 0
                state_now, action, v_preds = self.get_action(
                    feature, reward_last)

                # print(ProtossAction(action).name)
                self.mini_step(action)

                if state_last is not None:
                    if verbose:
                        print('state_last:', state_last, ', action_last:',
                              action_last, ', state_now:', state_now)
                    v_preds_next = self.net.policy.get_values(state_now)
                    v_preds_next = self.get_values(v_preds_next)
                    reward = 0

                    recording_obs.append(non_image_feature)
                    recording_img.append(image_feature)
                    recording_action.append(action)
                    recording_reward.append(reward)

                    #self.local_buffer.append(state_last, action_last, state_now, reward, v_preds, v_preds_next)

                state_last = state_now
                action_last = action
                self.policy_flag = False

            if self.is_end:
                if True:
                    # encode the result: reward -1/0/+1 shifted to 1 (loss), 2 (draw), 3 (win); 0 means not ended
                    recording_reward[-1] = (1 * self.result['reward'] + 2)
                    if recording_reward[-1] != 0:
                        print("result is:", recording_reward[-1])

                    recording_obs = np.array(recording_obs, dtype=np.uint16)
                    recording_action = np.array(recording_action,
                                                dtype=np.uint8)
                    recording_reward = np.array(recording_reward,
                                                dtype=np.uint8)
                    recording_img = np.array(recording_img, dtype=np.float16)

                    np.savez_compressed(filename,
                                        obs=recording_obs,
                                        img=recording_img,
                                        action=recording_action,
                                        reward=recording_reward)
                break

    def play_train(self, continues_attack=False, verbose=False):
        is_attack = False
        state_last = None

        self.safe_action(C._NO_OP, 0, [])
        self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
        self._gases = U.find_initial_gases(self.obs)

        while True:

            self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
            if self.policy_flag and (not self.is_end):

                non_image_feature = self.mapping_source_to_mini_by_rule(
                    self.get_the_input())
                #print('non_image_feature.shape:', non_image_feature.shape)
                #print('non_image_feature:', non_image_feature)

                image_feature = U.get_simple_map_data(self.obs)
                #print('image_feature.shape:', image_feature.shape)
                #print('image_feature:', image_feature)

                latent_image_feature, mu, logvar = self.encode_obs(
                    image_feature)
                #print('latent_image_feature.shape:', latent_image_feature.shape)
                #print('latent_image_feature:', latent_image_feature)

                feature = np.concatenate(
                    [non_image_feature, latent_image_feature], axis=-1)
                #print('feature.shape:', feature.shape)
                #print('feature:', feature)

                #state_now = feature
                reward_last = 0
                state_now, action, v_preds = self.get_action(
                    feature, reward_last)

                # print(ProtossAction(action).name)
                self.mini_step(action)

                if state_last is not None:
                    if verbose:
                        print('state_last:', state_last, ', action_last:',
                              action_last, ', state_now:', state_now)
                    v_preds_next = self.net.policy.get_values(state_now)
                    v_preds_next = self.get_values(v_preds_next)
                    reward = 0
                    self.local_buffer.append(state_last, action_last,
                                             state_now, reward, v_preds,
                                             v_preds_next)

                state_last = state_now
                action_last = action
                self.policy_flag = False

            if self.is_end:
                if self.rl_training:
                    self.local_buffer.rewards[-1] += 1 * self.result[
                        'reward']  # self.result['win']
                    #print(self.local_buffer.rewards)
                    self.global_buffer.add(self.local_buffer)
                    #print("add %d buffer!" % (len(self.local_buffer.rewards)))
                break

    def encode_obs(self, obs):
        # convert raw obs to z, mu, logvar
        result = np.copy(obs)
        result = result.reshape(1, 64, 64, 12)
        mu, logvar = self.sec_net.vae.encode_mu_logvar(result)
        mu = mu[0]
        logvar = logvar[0]
        s = logvar.shape
        z = mu + np.exp(logvar / 2.0) * np.random.randn(*s)
        return z, mu, logvar
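
    # A small illustrative sketch (not part of the original code) of the
    # reparameterization trick used in encode_obs above: for a diagonal
    # Gaussian latent, z = mu + sigma * eps with eps ~ N(0, I) and
    # sigma = exp(logvar / 2), so sampling stays differentiable w.r.t.
    # mu and logvar.
    @staticmethod
    def _reparameterize(mu, logvar):
        eps = np.random.randn(*logvar.shape)
        return mu + np.exp(logvar / 2.0) * eps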

    def get_action(self, feature, reward):
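        # world-model style control: the RNN compresses the running history of
        # (feature, action, reward) into self.rnn_state; the policy acts on the
        # RNN output instead of the raw feature, and the recurrent state is
        # advanced with the chosen action before the next decision.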
        input_h = self.sec_net.rnn_output(self.rnn_state, feature)

        action, v_preds = self.net.policy.get_action(input_h, verbose=False)

        self.rnn_state = self.sec_net.rnn_next_state(feature, action, reward,
                                                     self.rnn_state)
        return input_h, action, v_preds

    def set_flag(self):
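        # raise each flag once per period: C.time_wait(secs) converts seconds
        # into environment steps, and the `== 1` check fires on exactly one
        # step of each period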
        if self.step % C.time_wait(self.strategy_wait_secs) == 1:
            self.strategy_flag = True

        if self.step % C.time_wait(self.policy_wait_secs) == 1:
            self.policy_flag = True

    def safe_action(self, action, unit_type, args):
        if M.check_params(self, action, unit_type, args, 1):
            obs = self.env.step([sc2_actions.FunctionCall(action, args)])[0]
            self.obs = obs
            self.step += 1
            self.update_result()
            self.set_flag()

    def select(self, action, unit_type, args):
        # safe select
        if M.check_params(self, action, unit_type, args, 0):
            self.obs = self.env.step([sc2_actions.FunctionCall(action,
                                                               args)])[0]
            self.on_select = unit_type
            self.update_result()
            self.step += 1
            self.set_flag()

    @property
    def result(self):
        return self._result

    def update_result(self):
        if self.obs is None:
            return
        if self.obs.last() or self.env.state == environment.StepType.LAST:
            self.is_end = True
            outcome = 0
            o = self.obs.raw_observation
            player_id = o.observation.player_common.player_id
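            # look up our player's protobuf result and map it to a scalar
            # outcome (in pysc2, Victory -> 1, Defeat -> -1,
            # Tie/Undecided -> 0)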
            for r in o.player_result:
                if r.player_id == player_id:
                    outcome = sc2_env._possible_results.get(r.result, 0)
            frames = o.observation.game_loop
            result = {}
            result['outcome'] = outcome
            result['reward'] = self.obs.reward
            result['frames'] = frames

            self._result = result
            print('play end, total return', self.obs.reward)
            self.step = 0

    def get_values(self, values):
        # check if the game is end
        if self.is_end and self.result['reward'] != 0:
            return 0
        else:
            return values

class MiniSourceAgent(base_agent.BaseAgent):
    """Agent for source game of starcraft."""
    def __init__(self,
                 index=0,
                 rl_training=False,
                 restore_model=False,
                 global_buffer=None,
                 net=None,
                 strategy_agent=None,
                 greedy_action=False,
                 extract_save_dir=None,
                 ob_space_add=0,
                 image_debug=False,
                 action_prob_debug=False,
                 act_space_add=0):
        super(MiniSourceAgent, self).__init__()
        self.net = net
        self.index = index
        self.global_buffer = global_buffer
        self.restore_model = restore_model

        # model in brain
        self.strategy_agent = strategy_agent
        self.strategy_act = None

        # count num
        self.step = 0

        self.strategy_wait_secs = 2
        self.strategy_flag = False
        self.policy_wait_secs = 2
        self.policy_flag = True

        # how often (in seconds) to show the debug image
        self.image_wait_secs = 1

        # checkpoints (in game seconds: 0, 30, 60, 180, 300, 600) at which to
        # show action probabilities
        self.prob_show_wait_seconds = (np.array([0, 0.5, 1, 3, 5, 10]) *
                                       60).astype(int).tolist()

        self.env = None
        self.obs = None

        # buffer
        self.local_buffer = Buffer()

        self.num_players = 2
        self.on_select = None
        self._result = None
        self._gases = None
        self.is_end = False

        self.greedy_action = greedy_action
        self.rl_training = rl_training

        self.extract_save_dir = extract_save_dir
        self.ob_space_add = ob_space_add
        self.act_space_add = act_space_add

        self.image_debug = image_debug
        self.action_prob_debug = action_prob_debug
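        # 10 base Protoss macro actions (see mini_step) plus any added ones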
        self.action_num = 10 + act_space_add

    def reset(self):
        super(MiniSourceAgent, self).reset()
        self.step = 0
        self.obs = None
        self._result = None
        self._gases = None
        self.is_end = False

        self.strategy_flag = False
        self.policy_flag = True

        self.local_buffer.reset()

        if self.strategy_agent is not None:
            self.strategy_agent.reset()

    def set_env(self, env):
        self.env = env

    def init_network(self):
        self.net.initialize()
        if self.restore_model:
            self.net.restore_policy()

    def reset_old_network(self):
        self.net.reset_old_network()

    def save_model(self):
        self.net.save_policy()

    def update_policy(self):
        self.net.Update_policy(self.global_buffer)

    def update_result_list(self, result_list):
        self.net.Update_result(result_list)

    def update_network(self, result_list):
        self.net.Update_policy(self.global_buffer)
        self.net.Update_result(result_list)

    def update_summary(self, counter):
        return self.net.Update_summary(counter)

    def mini_step(self, action):
        if action == ProtossAction.Build_probe.value:
            M.mineral_worker(self)

        elif action == ProtossAction.Build_zealot.value:
            M.train_army(self, C._TRAIN_ZEALOT)

        elif action == ProtossAction.Build_Stalker.value:
            M.train_army(self, C._TRAIN_STALKER)

        elif action == ProtossAction.Build_pylon.value:
            no_unit_index = U.get_unit_mask_screen(self.obs, size=2)
            pos = U.get_pos(no_unit_index)
            M.build_by_idle_worker(self, C._BUILD_PYLON_S, pos)

        elif action == ProtossAction.Build_gateway.value:
            power_index = U.get_power_mask_screen(self.obs, size=5)
            pos = U.get_pos(power_index)
            M.build_by_idle_worker(self, C._BUILD_GATEWAY_S, pos)

        elif action == ProtossAction.Build_Assimilator.value:
            if self._gases is not None:
                #U.find_gas_pos(self.obs, 1)
                gas_1 = self._gases[0]
                gas_2 = self._gases[1]

                if gas_1 is not None and not U.is_assimilator_on_gas(
                        self.obs, gas_1):
                    gas_1_pos = T.world_to_screen_pos(self.env.game_info,
                                                      gas_1.pos, self.obs)
                    M.build_by_idle_worker(self, C._BUILD_ASSIMILATOR_S,
                                           gas_1_pos)

                elif gas_2 is not None and not U.is_assimilator_on_gas(
                        self.obs, gas_2):
                    gas_2_pos = T.world_to_screen_pos(self.env.game_info,
                                                      gas_2.pos, self.obs)
                    M.build_by_idle_worker(self, C._BUILD_ASSIMILATOR_S,
                                           gas_2_pos)

        elif action == ProtossAction.Build_CyberneticsCore.value:
            power_index = U.get_power_mask_screen(self.obs, size=3)
            pos = U.get_pos(power_index)
            M.build_by_idle_worker(self, C._BUILD_CYBER_S, pos)

        elif action == ProtossAction.Attack.value:
            M.attack_step(self)

        elif action == ProtossAction.Retreat.value:
            M.retreat_step(self)

        elif action == ProtossAction.Do_nothing.value:
            self.safe_action(C._NO_OP, 0, [])

        # added action
        elif action == ProtossAction.All.value + 0:
            M.attack_queued(self)

        elif action == ProtossAction.All.value + 1:
            M.retreat_queued(self)

        elif action == ProtossAction.All.value + 2:
            M.gas_worker_only(self)

        elif action == ProtossAction.All.value + 3:
            M.attack_main_base(self)

        elif action == ProtossAction.All.value + 4:
            M.attack_sub_base(self)

    def get_the_input(self):
        high_input, tech_cost, pop_num = U.get_input(self.obs)
        controller_input = np.concatenate([high_input, tech_cost, pop_num],
                                          axis=0)
        return controller_input

    def get_the_input_right(self, obs):
        high_input, tech_cost, pop_num = U.get_input(obs)
        controller_input = np.concatenate([high_input, tech_cost, pop_num],
                                          axis=0)
        return controller_input

    def mapping_source_to_mini_by_rule(self, source_state):
        simple_input = np.zeros([20], dtype=np.int16)
        simple_input[0] = 0  # self.time_seconds
        simple_input[1] = source_state[28]  # self.mineral_worker_nums
        simple_input[2] = source_state[30] + source_state[
            32]  # self.gas_worker_nums
        simple_input[3] = source_state[2]  # self.mineral
        simple_input[4] = source_state[3]  # self.gas
        simple_input[5] = source_state[6]  # self.food_cup
        simple_input[6] = source_state[7]  # self.food_used
        simple_input[7] = source_state[10]  # self.army_nums

        simple_input[8] = source_state[16]  # self.gateway_num
        simple_input[9] = source_state[14]  # self.pylon_num
        simple_input[10] = source_state[15]  # self.Assimilator_num
        simple_input[11] = source_state[17]  # self.CyberneticsCore_num

        simple_input[12] = source_state[12]  # self.zealot_num
        simple_input[13] = source_state[13]  # self.Stalker_num
        simple_input[14] = source_state[11]  # self.probe_num

        simple_input[15] = source_state[4] + source_state[
            2]  # self.collected_mineral
        simple_input[16] = source_state[4]  # self.spent_mineral
        simple_input[17] = source_state[5] + source_state[
            3]  # self.collected_gas
        simple_input[18] = source_state[5]  # self.spent_gas
        simple_input[19] = 1  # self.Nexus_num

        return simple_input

    def get_add_state(self, source_state):
        add_input = np.zeros([4], dtype=np.int16)
        add_input[0] = source_state[0]  # self.difficulty
        add_input[1] = source_state[1]  # self.game_loop
        add_input[2] = source_state[8]  # self.food_army
        add_input[3] = source_state[9]  # self.food_workers
        return add_input

    def play_right_add(self, verbose=False):
        # note: this is the "right" (corrected) version of game play, which
        # also records the additional add/map state inputs and the action
        prev_state = None
        prev_action = None
        prev_value = None
        prev_add_state = None
        prev_map_state = None
        show_image = False

        self.safe_action(C._NO_OP, 0, [])
        self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
        self._gases = U.find_initial_gases(self.obs)

        self.dummy_add_state = np.zeros(1)
        self.dummy_map_state = np.zeros([1, 1, 1])

        simulate_seconds = 0
        feature_dict = U.edge_state()
        previous_match = -1

        while True:

            self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
            if self.policy_flag and (not self.is_end):
                # get the state
                state = self.mapping_source_to_mini_by_rule(
                    self.get_the_input_right(self.obs))

                if self.image_debug and verbose and self.step % C.time_wait(
                        self.image_wait_secs) == 0:
                    show_image = True
                else:
                    show_image = False

                if verbose:
                    print('show_image:', show_image)
                map_state = U.get_small_simple_map_data(
                    self.obs, show_image, show_image)

                if verbose:
                    print('map_state.shape:', map_state.shape)
                add_state = self.get_add_state(
                    self.get_the_input_right(self.obs))

                # get the action and value according to the state
                #print("add_state:", add_state)
                if self.ob_space_add == 0:
                    add_state = self.dummy_add_state
                    map_state = self.dummy_map_state
                action, action_probs, value = self.net.policy.get_act_action_probs(
                    state, add_state, map_state, verbose=verbose)

                # if this is not the first state, store the transition in the buffer
                if prev_state is not None:
                    # try reward = self.obs.reward
                    reward = self.obs.reward
                    if verbose:
                        print(prev_state, prev_add_state, prev_action, state,
                              reward, prev_value, value)
                    self.local_buffer.append_more_more(
                        prev_state, prev_add_state, prev_map_state,
                        prev_action, state, reward, prev_value, value)

                self.mini_step(action)
                simulate_seconds += self.policy_wait_secs
                # the env steps to the new state

                prev_state = state
                prev_action = action
                prev_value = value
                prev_add_state = add_state
                prev_map_state = map_state

                self.policy_flag = False

            if self.is_end:
                # get the last state and its reward
                state = self.mapping_source_to_mini_by_rule(
                    self.get_the_input_right(self.obs))
                map_state = U.get_small_simple_map_data(self.obs)
                add_state = self.get_add_state(
                    self.get_the_input_right(self.obs))

                if self.ob_space_add == 0:
                    add_state = self.dummy_add_state
                    map_state = self.dummy_map_state

                value = self.net.policy.get_values(state, add_state, map_state)
                # the value of the last state is defined somewhat differently
                value = self.get_values_right(value)

                # if this is not the first state, store the transition in the buffer
                if prev_state is not None:
                    reward = self.obs.reward
                    if verbose:
                        print(prev_state, prev_add_state, prev_action, state,
                              reward, prev_value, value)
                    self.local_buffer.append_more_more(
                        prev_state, prev_add_state, prev_map_state,
                        prev_action, state, reward, prev_value, value)
                break

        if self.rl_training:
            #print(self.local_buffer.values)
            #print(self.local_buffer.values_next)
            #print(self.local_buffer.rewards)
            self.global_buffer.add(self.local_buffer)
            print("add map bn:")
            print("add %d buffer!" % (len(self.local_buffer.rewards)))
            #print("returns:", self.global_buffer.returns)
            #print("gaes:", self.global_buffer.gaes)

    def set_flag(self):
        if self.step % C.time_wait(self.strategy_wait_secs) == 1:
            self.strategy_flag = True

        if self.step % C.time_wait(self.policy_wait_secs) == 1:
            self.policy_flag = True

    def safe_action(self, action, unit_type, args):
        if M.check_params(self, action, unit_type, args, 1):
            obs = self.env.step([sc2_actions.FunctionCall(action, args)])[0]
            self.obs = obs
            self.step += 1
            self.update_result()
            self.set_flag()

    def select(self, action, unit_type, args):
        # safe select
        if M.check_params(self, action, unit_type, args, 0):
            self.obs = self.env.step([sc2_actions.FunctionCall(action,
                                                               args)])[0]
            self.on_select = unit_type
            self.update_result()
            self.step += 1
            self.set_flag()

    @property
    def result(self):
        return self._result

    def update_result(self):
        if self.obs is None:
            return
        if self.obs.last() or self.env.state == environment.StepType.LAST:
            self.is_end = True
            outcome = 0
            o = self.obs.raw_observation
            player_id = o.observation.player_common.player_id
            for r in o.player_result:
                if r.player_id == player_id:
                    outcome = sc2_env._possible_results.get(r.result, 0)
            frames = o.observation.game_loop
            result = {}
            result['outcome'] = outcome
            result['reward'] = self.obs.reward
            result['frames'] = frames

            self._result = result
            print('play end, total return', self.obs.reward)
            self.step = 0

    def get_values(self, values):
        # check if the game is end
        if self.is_end and self.result['reward'] != 0:
            return 0
        else:
            return values

    def get_values_right(self, values):
        # if the game ends with a win or loss (the result reward is +1 or -1),
        # the value is set to 0; if it ends without a decisive result (the
        # result reward is 0), the predicted value is kept as before
        if self.is_end and self.result['reward'] != 0:
            return 0
        else:
            return values
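
# A minimal standalone sketch (not part of the original classes) of the
# terminal-value convention shared by get_values / get_values_right above:
# the bootstrap value of the final state is zeroed only when the episode
# ends with a decisive result (reward +1 or -1); otherwise the network's
# prediction is kept. The function name is illustrative.
def terminal_value(is_end, result_reward, predicted_value):
    if is_end and result_reward != 0:
        return 0
    return predicted_value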
Example #18
class MiniSourceAgent(base_agent.BaseAgent):
    """Agent for source game of starcraft."""
    def __init__(self,
                 index=0,
                 rl_training=False,
                 restore_model=False,
                 global_buffer=None,
                 net=None,
                 strategy_agent=None,
                 greedy_action=False):
        super(MiniSourceAgent, self).__init__()
        self.net = net
        self.index = index
        self.global_buffer = global_buffer
        self.restore_model = restore_model

        # model in brain
        self.strategy_agent = strategy_agent
        self.strategy_act = None

        # count num
        self.step = 0

        self.strategy_wait_secs = 3
        self.strategy_flag = False
        self.policy_wait_secs = 2
        self.policy_flag = True

        self.env = None
        self.obs = None

        # buffer
        self.local_buffer = Buffer()

        self.num_players = 2
        self.on_select = None
        self._result = None
        self._gases = None
        self.is_end = False

        self.greedy_action = greedy_action
        self.rl_training = rl_training

        self.reset_tc()

    def reset(self):
        super(MiniSourceAgent, self).reset()
        self.step = 0
        self.obs = None
        self._result = None
        self._gases = None
        self.is_end = False

        self.strategy_flag = False
        self.policy_flag = True

        self.local_buffer.reset()

        if self.strategy_agent is not None:
            self.strategy_agent.reset()

        self.reset_tc()

    def reset_tc(self):
        self.num_pylon = 0
        self.num_gateway = 0
        self.num_cyber = 0

        self.max_pylon = 20
        self.max_gateway = 5
        self.max_cyber = 3

        self.enemy_pos = None
        self.retreat_pos = None
        self.rally_pos = [455, 165]

        self.base = None
        self.resourceUnits = []
        self.vespeneUnits = []

        self.initial_scout = False
        self.initial = False

    def set_env(self, env):
        self.env = env

    def set_obs(self, state):
        self.obs = state

    def init_network(self):
        self.net.initialize()
        if self.restore_model:
            self.net.restore_policy()

    def reset_old_network(self):
        self.net.reset_old_network()

    def save_model(self):
        self.net.save_policy()

    def update_network(self, result_list):
        self.net.Update_policy(self.global_buffer)
        self.net.Update_result(result_list)

    def update_summary(self, counter):
        return self.net.Update_summary(counter)

    def mini_step(self, action):
        if action == ProtossAction.Build_probe.value:
            M.train_unit(self, T.Protoss_Nexus, T.Protoss_Probe)

        elif action == ProtossAction.Build_zealot.value:
            M.train_unit(self, T.Protoss_Gateway, T.Protoss_Zealot)

        elif action == ProtossAction.Build_Stalker.value:
            M.train_unit(self, T.Protoss_Gateway, T.Protoss_Dragoon)

        elif action == ProtossAction.Build_pylon.value:
            if not self.initial_scout:
                M.scout_manager(self)
                self.initial_scout = True

            pos = self.placer_manager(self.base, T.Protoss_Pylon)
            M.build_by_worker(self, T.Protoss_Probe, T.Protoss_Pylon, pos)

        elif action == ProtossAction.Build_gateway.value:
            pos = self.placer_manager(self.base, T.Protoss_Gateway)
            M.build_by_worker(self, T.Protoss_Probe, T.Protoss_Gateway, pos)

        elif action == ProtossAction.Build_Assimilator.value:
            pos = self.placer_manager(self.base, T.Protoss_Assimilator)
            M.build_by_worker(self, T.Protoss_Probe, T.Protoss_Assimilator,
                              pos)

        elif action == ProtossAction.Build_CyberneticsCore.value:
            pos = self.placer_manager(self.base, T.Protoss_Cybernetics_Core)
            M.build_by_worker(self, T.Protoss_Probe,
                              T.Protoss_Cybernetics_Core, pos)

        elif action == ProtossAction.Attack.value:
            M.attack_step(self, [T.Protoss_Zealot, T.Protoss_Dragoon],
                          self.enemy_pos)

        elif action == ProtossAction.Retreat.value:
            M.retreat_step(self, [T.Protoss_Zealot, T.Protoss_Dragoon],
                           self.retreat_pos)
            #pass

        elif action == ProtossAction.Do_nothing.value:
            M.no_op(self)

    def calculate_features(self):
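        # single pass over our units to derive worker / army / building counts
        # and the minerals and gas already spent (a unit's price is counted
        # even while it is still under construction)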
        state = self.obs
        myunits = state.units[state.player_id]

        self.mineral_worker_nums = 0
        self.gas_worker_nums = 0

        self.spent_mineral = 0
        self.spent_gas = 0

        self.probe_num = 0
        self.zealot_num = 0
        self.Stalker_num = 0
        self.army_nums = 0

        self.gateway_num = 0
        self.pylon_num = 0
        self.Assimilator_num = 0
        self.CyberneticsCore_num = 0

        for unit in myunits:
            if unit.type == T.Protoss_Probe:
                self.spent_mineral += P.Probe().mineral_price
                if unit.completed:
                    self.probe_num += 1
                    if unit.gathering_minerals:
                        self.mineral_worker_nums += 1
                    if unit.gathering_gas:
                        self.gas_worker_nums += 1
            if unit.type == T.Protoss_Zealot:
                self.spent_mineral += P.Zealot().mineral_price
                if unit.completed:
                    if unit.visible:
                        self.zealot_num += 1
                        self.army_nums += 1
                    else:
                        print('find invisible Zealot')
            if unit.type == T.Protoss_Dragoon:
                self.spent_mineral += P.Stalker().mineral_price
                self.spent_gas += P.Stalker().gas_price
                if unit.completed:
                    if unit.visible:
                        self.Stalker_num += 1
                        self.army_nums += 1
                    else:
                        print('find invisible Dragoon')
            if unit.type == T.Protoss_Pylon:
                self.spent_mineral += P.Pylon().mineral_price
                if unit.completed:
                    self.pylon_num += 1
            if unit.type == T.Protoss_Gateway:
                self.spent_mineral += P.Gateway().mineral_price
                if unit.completed:
                    self.gateway_num += 1
            if unit.type == T.Protoss_Assimilator:
                self.spent_mineral += P.Assimilator().mineral_price
                if unit.completed:
                    self.Assimilator_num += 1
            if unit.type == T.Protoss_Cybernetics_Core:
                self.spent_mineral += P.CyberneticsCore().mineral_price
                if unit.completed:
                    self.CyberneticsCore_num += 1

        self.mineral = state.frame.resources[state.player_id].ore
        self.gas = state.frame.resources[state.player_id].gas
        self.food_used = state.frame.resources[state.player_id].used_psi
        self.food_cup = state.frame.resources[state.player_id].total_psi

    def mapping_source_to_mini_by_rule(self, state):
        simple_input = np.zeros([20])
        simple_input[0] = 0  # self.time_seconds
        simple_input[1] = self.mineral_worker_nums  # self.mineral_worker_nums
        simple_input[2] = self.gas_worker_nums  # self.gas_worker_nums
        simple_input[3] = self.mineral  # self.mineral
        simple_input[4] = self.gas  # self.gas
        simple_input[5] = self.food_cup  # self.food_cup
        simple_input[6] = self.food_used  # self.food_used
        simple_input[7] = self.army_nums  # self.army_nums

        simple_input[8] = self.gateway_num  # self.gateway_num
        simple_input[9] = self.pylon_num  # self.pylon_num
        simple_input[10] = self.Assimilator_num  # self.Assimilator_num
        simple_input[11] = self.CyberneticsCore_num  # self.CyberneticsCore_num

        simple_input[12] = self.zealot_num  # self.zealot_num
        simple_input[13] = self.Stalker_num  # self.Stalker_num
        simple_input[14] = self.probe_num  # self.probe_num

        simple_input[
            15] = self.mineral + self.spent_mineral  # self.collected_mineral
        simple_input[16] = self.spent_mineral  # self.spent_mineral
        simple_input[17] = self.gas + self.spent_gas  # self.collected_gas
        simple_input[18] = self.spent_gas  # self.spent_gas
        simple_input[19] = 1  # self.Nexus_num

        return simple_input

    def play(self, verbose=False):
        self.play_train_mini(verbose=verbose)

    def play_train_mini(self, verbose=False):
        is_attack = False
        state_last = None

        while not self.obs.game_ended:
            #print('self.step:', self.step)
            #print('self.frame_from_bwapi:', self.obs.frame_from_bwapi)

            if self.obs.game_ended:
                break

            if self.step >= C.time_wait_sc1(900):
                self.env.send([[tcc.restart]])
                self.obs = self.env.recv()
                self.update_result(time_out=True)
                continue

            if not self.initial:
                self.initial_manager()
                self.initial = True

            self.safe_action([])

            if self.policy_flag and (not self.is_end):
                self.calculate_features()
                state_now = self.mapping_source_to_mini_by_rule(self.obs)
                if self.greedy_action:
                    action_prob, v_preds = self.net.policy.get_action_probs(
                        state_now, verbose=False)
                    action = np.argmax(action_prob)
                else:
                    action, v_preds = self.net.policy.get_action(state_now,
                                                                 verbose=False)

                #print(ProtossAction(action).name)
                self.mini_step(action)
                if self.is_end:
                    #self.env.send([[tcc.restart]])
                    #self.obs = self.env.recv()
                    break

                if state_last is not None:
                    if verbose:
                        print('state_last:', state_last, ', action_last:',
                              action_last, ', state_now:', state_now)
                    v_preds_next = self.net.policy.get_values(state_now)
                    v_preds_next = self.get_values(v_preds_next)
                    reward = 0
                    self.local_buffer.append(state_last, action_last,
                                             state_now, reward, v_preds,
                                             v_preds_next)

                # continuous attack, consistent with mind-game
                if action == ProtossAction.Attack.value:
                    is_attack = True

                if is_attack:
                    self.mini_step(ProtossAction.Attack.value)
                    if self.is_end:
                        break

                state_last = state_now
                action_last = action
                self.policy_flag = False

            if self.strategy_flag and (not self.is_end):
                #print('self.strategy_flag:', self.strategy_flag)
                M.worker_manager(self)
                self.strategy_flag = False

        if self.rl_training:
            self.local_buffer.rewards[-1] += 1 * self.result[
                'reward']  # self.result['win']
            #print(self.local_buffer.rewards)
            self.global_buffer.add(self.local_buffer)
            #print("add %d buffer!" % (len(self.local_buffer.rewards)))

    def set_flag(self):
        if self.step % C.time_wait_sc1(self.strategy_wait_secs) == 1:
            self.strategy_flag = True

        if self.step % C.time_wait_sc1(self.policy_wait_secs) == 1:
            self.policy_flag = True

    def safe_action(self, actions):
        if len(actions) > 0:
            pass
            #print("Sending actions: " + str(actions))
        self.env.send(actions)
        self.obs = self.env.recv()
        self.step += 1
        self.update_result()
        self.set_flag()

    @property
    def result(self):
        return self._result

    def judge_if_we_win_or_draw(self):
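        # timeout heuristic: call it a win/draw when the enemy has at most 4
        # known units left while we still field at least 6 zealots/dragoons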
        enemy_units = None
        for i in self.obs.units:
            if i != self.obs.player_id and i != self.obs.neutral_id:
                enemy_units = self.obs.units[i]

        if enemy_units is None:
            return True
        else:
            if len(enemy_units) <= 4:
                army = M.selectArmy(self,
                                    [T.Protoss_Zealot, T.Protoss_Dragoon])
                if army is not None:
                    if len(army) >= 6:
                        return True
                    else:
                        print('len(army):', len(army))
        print('len(enemy_units):', len(enemy_units))
        return False

    def update_result(self, time_out=False):
        if self.obs is None:
            return
        if self.obs.waiting_for_restart:
            print("WAITING FOR RESTART...")
        if self.obs.game_ended:
            self.is_end = True
            frames = self.obs.frame_from_bwapi
            outcome = self.obs.game_won
            reward = self.obs.game_won

            if time_out:
                if self.judge_if_we_win_or_draw():
                    reward = 1
                    outcome = 1
                else:
                    reward = 0
                    outcome = 0
            elif not self.obs.game_won:
                reward = -1
                outcome = -1

            result = {}
            result['outcome'] = outcome
            result['reward'] = reward
            result['frames'] = frames

            self._result = result
            #print('play end, total result', self._result)
            self.step = 0

    def get_values(self, values):
        # check if the game is end
        if self.is_end and self.result['reward'] != 0:
            return 0
        else:
            return values

    def placer_manager(self, base, unit_type):
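        # rule-based placement on a fixed grid relative to the main base:
        # pylons fill two columns to the left of the Nexus, gateways and
        # cybernetics cores get their own columns, and assimilators snap to
        # the first known vespene geyser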
        pylon_size = 8
        gateway_size = 16
        cyber_size = 16

        if unit_type == T.Protoss_Pylon:
            initial_pylon_x = base.x - 16
            initial_pylon_y = base.y + 8
            #print(initial_pylon_x, initial_pylon_y)
            column_index = 0
            row_index = self.num_pylon
            if self.num_pylon > 0.5 * self.max_pylon:
                row_index = self.num_pylon - 0.5 * self.max_pylon
                column_index = 1

            target_x = initial_pylon_x - int(column_index * 8)
            target_y = initial_pylon_y + int(row_index * pylon_size)
            self.num_pylon = (self.num_pylon + 1) % self.max_pylon
            return [target_x, target_y]
        elif unit_type == T.Protoss_Gateway:
            initial_gateway_x = base.x - 8
            initial_gateway_y = base.y + 10
            #print(initial_gateway_x, initial_gateway_y)
            target_x = initial_gateway_x
            target_y = initial_gateway_y + self.num_gateway * gateway_size
            self.num_gateway = (self.num_gateway + 1) % self.max_gateway
            return [target_x, target_y]
        elif unit_type == T.Protoss_Cybernetics_Core:
            initial_cyber_x = base.x - 36
            initial_cyber_y = base.y + 10
            #print(initial_cyber_x, initial_cyber_y)
            target_x = initial_cyber_x
            target_y = initial_cyber_y + self.num_cyber * cyber_size
            self.num_cyber = (self.num_cyber + 1) % self.max_cyber
            return [target_x, target_y]
        elif unit_type == T.Protoss_Assimilator:
            if len(self.vespeneUnits) > 0:
                vespene = self.vespeneUnits[0]
                target_x = vespene.x - 8
                target_y = vespene.y - 4
                #print(target_x, target_y)
                return [target_x, target_y]
        return [-1, -1]

    def initial_manager(self):
        self.obs = self.env.recv()
        state = self.obs

        frame_no = state.frame_from_bwapi
        #print('begin frame_no:', frame_no)

        myunits = state.units[state.player_id]

        # initial base
        for unit in myunits:
            if unit.type == T.Protoss_Nexus:
                self.base = unit
                break

        # initial mineral and gas
        neutralUnits = state.units[state.neutral_id]
        for u in neutralUnits:
            if u.type == T.Resource_Mineral_Field or u.type == T.Resource_Mineral_Field_Type_2 \
                    or u.type == T.Resource_Mineral_Field_Type_3:
                if u.visible:
                    self.resourceUnits.append(u)
            if u.type == T.Resource_Vespene_Geyser:
                if u.visible:
                    self.vespeneUnits.append(u)

        #print('resourceUnits:', len(self.resourceUnits))
        #print('vespeneUnits:', len(self.vespeneUnits))

        # initial workers
        actions = []
        for unit in myunits:
            if unit.type == T.Protoss_Probe and unit.completed:
                if unit.idle:
                    target = M.get_closest(unit.x, unit.y, self.resourceUnits)
                    actions.append([
                        tcc.command_unit,
                        unit.id,
                        tcc.unitcommandtypes.Right_Click_Unit,
                        target.id,
                    ])

        self.safe_action(actions)
Example #19
class Agent:
    def __init__(self,
                 agent_id=0,
                 global_buffer=None,
                 net=None,
                 restore_model=False):
        self.env = None
        self.mpc = None
        self.net = net

        self.agent_id = agent_id
        self.player_id = 0

        self.global_buffer = global_buffer
        self.restore_model = restore_model

        self.local_buffer = Buffer()
        self.restart_game()

    def __deepcopy__(self, memo):
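        # custom deepcopy: the net, env and mpc references are shared rather
        # than copied; all other per-agent state is deep-copied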
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            # print(k)
            if k == 'net' or k == 'env' or k == 'mpc':
                #print(k, 'is not deepcopy')
                setattr(result, k, v)
                continue
            setattr(result, k, deepcopy(v, memo))
        return result

    def reset(self, pos):
        self.set_pos(pos)
        self.set_army()

        self.local_buffer.reset()
        self.restart_game()

    def restart_game(self):
        self.is_end = False
        self._result = None

        self.time_seconds = 0
        self.mineral_worker_nums = 12
        self.gas_worker_nums = 0
        self.mineral = 50
        self.gas = 0
        self.food_cup = 15
        self.food_used = 12
        self.army_nums = 0
        self.enemy_army_nums = 0
        self.building_nums = 1
        self.enemy_building_nums = 1
        self.defender_nums = 0
        self.enemy_defender_nums = 0
        self.strategy = StrategyforSC2.RUSH
        self.enemy_strategy = StrategyforSC2.ECONOMY
        self.creatures_list = {}
        self.building_list = {}
        self.remain_buildings_hp = 1500

        self.on_building_list = []
        self.production_building_list = []

        self.time_for_one_step = 9

    def obs(self):
        return None

    def init(self, env, player_id, pos):
        self.set_env(env)
        self.set_id(player_id)
        self.set_pos(pos)
        self.set_army()

    def init_network(self):
        self.net.initialize()
        if self.restore_model:
            self.net.restore_policy()

    def update_network(self, result_list):
        self.net.Update_policy(self.global_buffer)
        self.net.Update_result(result_list)

    def reset_old_network(self):
        self.net.reset_old_network()

    def save_model(self):
        self.net.save_policy()

    def update_summary(self, counter):
        return self.net.Update_summary(counter)

    def set_buffer(self, global_buffer):
        self.global_buffer = global_buffer

    def set_env(self, env):
        self.env = env

    def set_net(self, net):
        self.net = net

    def set_mpc(self, mpc):
        self.mpc = mpc

    def set_id(self, player_id):
        self.player_id = player_id

    def set_pos(self, pos):
        self.pos = pos

    def set_army(self):
        army = Army(self.player_id)
        army.pos = self.pos
        self.env.army[self.player_id] = army

    def init_building(self, building_obj):
        self.on_building_list.append(building_obj)

    def add_unit(self, unit_obj, num=1):
        unit_type = type(unit_obj)
        if unit_type in self.creatures_list:
            self.creatures_list[unit_type] += num
        else:
            self.creatures_list[unit_type] = num

    def add_building(self, building_obj, num=1):
        building_type = type(building_obj)

        # TODO: add shied to hp
        hp = building_type().hp
        self.remain_buildings_hp += hp * num

        if building_type in self.building_list:
            self.building_list[building_type] += num
        else:
            self.building_list[building_type] = num

        queue = getattr(building_obj, "queue", None)
        if queue is not None:
            self.production_building_list.append(building_obj)

    def building_hp(self):
        return self.remain_buildings_hp

    def under_attack(self, attack_hp):
        self.remain_buildings_hp -= attack_hp

    def military_force(self):
        return self.creatures_list

    def buildings(self):
        return self.building_list

    def building_num(self):
        count = 0
        for key, value in self.building_list.items():
            count += value
        return count

    def military_num(self):
        count = 0
        for key, value in self.creatures_list.items():
            count += value
        return count

    def reset_military(self, remain_hp):
        all_hp = 0
        remain_creatures_list = {}
        for key, value in self.creatures_list.items():
            unit_type = key
            number = value
            unit = unit_type()
            count = 0
            for _ in range(number):
                all_hp += unit.hp
                if all_hp <= remain_hp:
                    count += 1
                else:
                    break
            remain_creatures_list[unit_type] = count
            if all_hp >= remain_hp:
                break

        self.creatures_list = remain_creatures_list
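
# Usage sketch for reset_military above (illustrative only; _SketchZealot is
# a hypothetical stand-in for a unit class exposing an hp attribute): units
# are dropped from the tail until the kept units' cumulative hp reaches
# remain_hp, approximating casualties after a battle.
class _SketchZealot:
    hp = 100

_sketch_agent = Agent()
_sketch_agent.creatures_list = {_SketchZealot: 5}   # 500 hp in total
_sketch_agent.reset_military(remain_hp=230)         # keeps 2 (100 + 100 <= 230 < 300)
assert _sketch_agent.creatures_list == {_SketchZealot: 2}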
Example #20
class SourceAgent(base_agent.BaseAgent):
    """Agent for source game of starcraft."""

    def __init__(self, index=0, rl_training=False, restore_model=False, global_buffer=None, net=None,
                 strategy_agent=None):
        super(SourceAgent, self).__init__()
        self.net = net
        self.index = index
        self.global_buffer = global_buffer
        self.restore_model = restore_model

        # model in brain
        self.strategy_agent = strategy_agent
        self.strategy_act = None

        # count num
        self.step = 0

        self.strategy_wait_secs = 4
        self.strategy_flag = False
        self.policy_wait_secs = 2
        self.policy_flag = True

        self.env = None
        self.obs = None

        # buffer
        self.local_buffer = Buffer()
        self.mini_state = []
        self.mini_state_mapping = []

        self.num_players = 2
        self.on_select = None
        self._result = None
        self.is_end = False
        self.is_attack = False
        self._gases = None

        self.rl_training = rl_training

        self.reward_type = 0

    def reset(self):
        super(SourceAgent, self).reset()
        self.step = 0
        self.obs = None
        self._result = None
        self.is_end = False
        self.is_attack = False
        self._gases = None

        self.policy_flag = True

        self.local_buffer.reset()

        if self.strategy_agent is not None:
            self.strategy_agent.reset()

    def set_env(self, env):
        self.env = env

    def init_network(self):
        self.net.initialize()
        if self.restore_model:
            self.net.restore_policy()

    def reset_old_network(self):
        self.net.reset_old_network()

    def save_model(self):
        self.net.save_policy()

    def update_network(self, result_list):
        self.net.Update_policy(self.global_buffer)
        self.net.Update_result(result_list)

    def update_summary(self, counter):
        return self.net.Update_summary(counter)

    def get_policy_input(self, obs):
        high_input, tech_cost, pop_num = U.get_input(obs)
        policy_input = np.concatenate([high_input, tech_cost, pop_num], axis=0)
        return policy_input


    def mini_step(self, action):

        if action == TerranAction.Build_SCV.value:
            M.mineral_worker(self)

        elif action == TerranAction.Build_Refinery.value:
            if self._gases is not None:
                # U.find_gas_pos(self.obs, 1)
                gas_1 = self._gases[0]
                gas_2 = self._gases[1]

                if gas_1 is not None and not U.is_Refinery_on_gas(self.obs, gas_1):
                    gas_1_pos = T.world_to_screen_pos(self.env.game_info, gas_1.pos, self.obs)
                    M.build_by_idle_worker(self, C._BUILD_REFINERY_S, gas_1_pos)

                elif gas_2 is not None and not U.is_Refinery_on_gas(self.obs, gas_2):
                    gas_2_pos = T.world_to_screen_pos(self.env.game_info, gas_2.pos, self.obs)
                    M.build_by_idle_worker(self, C._BUILD_REFINERY_S, gas_2_pos)

        elif action == TerranAction.Gather_gas.value:
            M.gather_resource(self, 'gas')

        elif action == TerranAction.Gather_mineral.value:
            M.gather_resource(self, 'mineral')

        elif action == TerranAction.Build_Marine.value:
            M.train_army(self, C._TRAIN_MARINE)

        elif action == TerranAction.Build_Reaper.value:
            M.train_army(self, C._TRAIN_REAPER)

        elif action == TerranAction.Build_SupplyDepot.value:
            valid_index = U.get_valid_mask_screen(self.obs, size=3)
            pos = U.get_pos(valid_index)
            M.build_by_idle_worker(self, C._BUILD_SUPPLYDEPOT_S, pos)

        elif action == TerranAction.Build_Barracks.value:
            valid_index = U.get_valid_mask_screen(self.obs, size=3)
            pos = U.get_pos(valid_index)
            M.build_by_idle_worker(self, C._BUILD_BARRACKS_S, pos)

        elif action == TerranAction.Attack.value:
            self.is_attack = True

        elif action == TerranAction.Defend.value:
            M.retreat_step(self)

        else:
            self.safe_action(C._NO_OP, 0, [])

        if self.is_attack:
            M.attack_step(self)
        # if any queen exists, try to inject larva into the hatchery.
        # M.inject_larva(self)  TODO

    def get_the_input(self):
        high_input, tech_cost, pop_num = U.get_input(self.obs)
        controller_input = np.concatenate([high_input, tech_cost, pop_num], axis=0)
        return controller_input

    def combine_state_and_mini_action(self, state, strategy_act):
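        # condition the low-level policy on the high-level strategy action by
        # appending its one-hot encoding to the state vector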
        act = np.zeros((1, 1))
        act[0, 0] = strategy_act
        action_array = self.one_hot_label(act, C._SIZE_MINI_ACTIONS)[0]
        combined_state = np.concatenate([state, action_array], axis=0)
        return combined_state

    def mapping_source_to_mini(self, source_state):
        mini_state = self.net.mapping.predict_func(source_state, use_transform=False)
        return mini_state

    def mapping_source_to_mini_by_rule(self, source_state):
        simple_input = np.zeros([17])
        simple_input[0] = 0  # self.time_seconds
        simple_input[1] = source_state[27]  # self.mineral_worker_nums
        simple_input[2] = source_state[29] + source_state[31]  # self.gas_worker_nums
        simple_input[3] = source_state[2]  # self.mineral
        simple_input[4] = source_state[3]  # self.gas
        simple_input[5] = source_state[6]  # self.food_cap
        simple_input[6] = source_state[7]  # self.food_used
        simple_input[7] = source_state[10]  # self.army_nums

        simple_input[8] = source_state[11]  # self.Refinery_num
        simple_input[9] = source_state[12]  # self.SupplyDepot_num
        simple_input[10] = source_state[13]  # self.Barracks_num

        simple_input[11] = source_state[15]  # self.Marine_num
        simple_input[12] = source_state[16]  # self.Reaper_num

        simple_input[13] = source_state[4] + source_state[2]
        simple_input[14] = source_state[5] + source_state[3]

        simple_input[15] = source_state[4]
        simple_input[16] = source_state[5]

        return simple_input

    def play(self, verbose=False):
        is_attack = False
        state_last = None

        self.safe_action(C._NO_OP, 0, [])
        self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
        self._gases = U.find_initial_gases(self.obs)

        while True:
            # self.mini_step(3)
            self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
            #self.safe_action(C._NO_OP, 0, [])
            if self.policy_flag and (not self.is_end):

                state_now = self.mapping_source_to_mini_by_rule(self.get_the_input())

                # action_prob, v_preds = self.net.policy.get_action_probs(state_now, verbose=False)
                # action = np.argmax(action_prob)
                # print(action)

                action, v_preds = self.net.policy.get_action(state_now, verbose=False)
                # print(action)

                self.mini_step(action)
                if state_last is not None:
                    # print(state_now)
                    # print(TerranAction(int(action)).name)
                    # time.sleep(0.5)
                    v_preds_next = self.net.policy.get_values(state_now)
                    v_preds_next = self.get_values(v_preds_next)
                    reward = 0
                    self.local_buffer.append(state_last, action_last, state_now, reward, v_preds, v_preds_next)


                state_last = state_now
                action_last = action
                self.policy_flag = False

            if self.is_end:
                if self.rl_training:
                    self.local_buffer.rewards[-1] += 1 * self.result['reward']  # self.result['win']
                    # print(self.local_buffer.rewards)
                    self.global_buffer.add(self.local_buffer)
                    # print("add %d buffer!" % (len(self.local_buffer.rewards)))
                break

    # def play(self, verbose=False):
    #     is_attack = False
    #     while True:
    #         #self.safe_action(C._NO_OP, 0, [])
    #         self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
    #         if self.policy_flag and (not self.is_end):
    #             mini_state_mapping = self.mapping_source_to_mini_by_rule(self.get_the_input())
    #             #print('state:', mini_state_mapping)
    #             mini_act = self.strategy_agent.get_action_by_policy(mini_state_mapping)[0]
    #             print("Action: ", ZergAction(int(mini_act)).name)
    #             self.mini_step(mini_act)

    #             if mini_act == ZergAction.Attack.value:
    #                 is_attack = True
    #             if is_attack:
    #                 self.mini_step(ZergAction.Attack.value)

    #             self.policy_flag = False

    #         if self.is_end:
    #             break

    def set_flag(self):
        if self.step % C.time_wait(self.strategy_wait_secs) == 1:
            self.strategy_flag = True

        if self.step % C.time_wait(self.policy_wait_secs) == 1:
            self.policy_flag = True

    def safe_action(self, action, unit_type, args):
        if M.check_params(self, action, unit_type, args, 1):
            self.obs = self.env.step([sc2_actions.FunctionCall(action, args)])[0]
            self.step += 1
            self.update_result()
            self.set_flag()

    def select(self, action, unit_type, args):
        # safe select
        if M.check_params(self, action, unit_type, args, 0):
            self.obs = self.env.step([sc2_actions.FunctionCall(action, args)])[0]
            self.on_select = unit_type
            self.update_result()
            self.step += 1
            self.set_flag()

    @property
    def result(self):
        return self._result

    def update_result(self):
        if self.obs is None:
            return
        if self.obs.last() or self.env.state == environment.StepType.LAST:
            self.is_end = True
            outcome = 0
            o = self.obs.raw_observation
            player_id = o.observation.player_common.player_id
            for r in o.player_result:
                if r.player_id == player_id:
                    outcome = sc2_env._possible_results.get(r.result, 0)
            frames = o.observation.game_loop
            result = {}
            result['outcome'] = outcome
            result['reward'] = self.obs.reward
            result['frames'] = frames

            result['win'] = 0
            if result['reward'] == 1:
                result['win'] = 1

            self._result = result
            print('play end, total return', self.obs.reward)
            self.step = 0

    def one_hot_label(self, action_type_array, action_max_num):
        rows = action_type_array.shape[0]
        cols = action_max_num
        data = np.zeros((rows, cols))

        for i in range(rows):
            data[i, int(action_type_array[i])] = 1

        return data
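        # note: an equivalent vectorized form would be
        #   np.eye(action_max_num)[action_type_array.astype(int).ravel()]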

    def get_values(self, values):
        # check if the game is end
        if self.is_end and self.result['reward'] != 0:
            return 0
        else:
            return values
Example #21
class MiniSourceAgent(base_agent.BaseAgent):
    """Agent for source game of starcraft."""
    def __init__(self,
                 index=0,
                 rl_training=False,
                 restore_model=False,
                 global_buffer=None,
                 net=None,
                 strategy_agent=None,
                 greedy_action=False,
                 extract_save_dir=None):
        super(MiniSourceAgent, self).__init__()
        self.net = net
        self.index = index
        self.global_buffer = global_buffer
        self.restore_model = restore_model

        # model in brain
        self.strategy_agent = strategy_agent
        self.strategy_act = None

        # count num
        self.step = 0

        self.strategy_wait_secs = 2
        self.strategy_flag = False
        self.policy_wait_secs = 2
        self.policy_flag = True

        self.env = None
        self.obs = None

        # buffer
        self.local_buffer = Buffer()

        self.num_players = 2
        self.on_select = None
        self._result = None
        self._gases = None
        self.is_end = False

        self.greedy_action = greedy_action
        self.rl_training = rl_training

        self.extract_save_dir = extract_save_dir

    def reset(self):
        super(MiniSourceAgent, self).reset()
        self.step = 0
        self.obs = None
        self._result = None
        self._gases = None
        self.is_end = False

        self.strategy_flag = False
        self.policy_flag = True

        self.local_buffer.reset()

        if self.strategy_agent is not None:
            self.strategy_agent.reset()

    def set_env(self, env):
        self.env = env

    def init_network(self):
        self.net.initialize()
        if self.restore_model:
            self.net.restore_policy()

    def reset_old_network(self):
        self.net.reset_old_network()

    def save_model(self):
        self.net.save_policy()

    def update_policy(self):
        self.net.Update_policy(self.global_buffer)

    def update_result(self, result_list):
        self.net.update_result(result_list)

    def update_network(self, result_list):
        self.net.Update_policy(self.global_buffer)
        self.net.Update_result(result_list)

    def update_summary(self, counter):
        return self.net.Update_summary(counter)

    def mini_step(self, action):
        if action == ProtossAction.Build_probe.value:
            M.mineral_worker(self)

        elif action == ProtossAction.Build_zealot.value:
            M.train_army(self, C._TRAIN_ZEALOT)

        elif action == ProtossAction.Build_Stalker.value:
            M.train_army(self, C._TRAIN_STALKER)

        elif action == ProtossAction.Build_pylon.value:
            no_unit_index = U.get_unit_mask_screen(self.obs, size=2)
            pos = U.get_pos(no_unit_index)
            M.build_by_idle_worker(self, C._BUILD_PYLON_S, pos)

        elif action == ProtossAction.Build_gateway.value:
            power_index = U.get_power_mask_screen(self.obs, size=5)
            pos = U.get_pos(power_index)
            M.build_by_idle_worker(self, C._BUILD_GATEWAY_S, pos)

        elif action == ProtossAction.Build_Assimilator.value:
            if self._gases is not None:
                #U.find_gas_pos(self.obs, 1)
                gas_1 = self._gases[0]
                gas_2 = self._gases[1]

                if gas_1 is not None and not U.is_assimilator_on_gas(
                        self.obs, gas_1):
                    gas_1_pos = T.world_to_screen_pos(self.env.game_info,
                                                      gas_1.pos, self.obs)
                    M.build_by_idle_worker(self, C._BUILD_ASSIMILATOR_S,
                                           gas_1_pos)

                elif gas_2 is not None and not U.is_assimilator_on_gas(
                        self.obs, gas_2):
                    gas_2_pos = T.world_to_screen_pos(self.env.game_info,
                                                      gas_2.pos, self.obs)
                    M.build_by_idle_worker(self, C._BUILD_ASSIMILATOR_S,
                                           gas_2_pos)

        elif action == ProtossAction.Build_CyberneticsCore.value:
            power_index = U.get_power_mask_screen(self.obs, size=3)
            pos = U.get_pos(power_index)
            M.build_by_idle_worker(self, C._BUILD_CYBER_S, pos)

        elif action == ProtossAction.Attack.value:
            M.attack_step(self)

        elif action == ProtossAction.Retreat.value:
            M.retreat_step(self)

        elif action == ProtossAction.Do_nothing.value:
            self.safe_action(C._NO_OP, 0, [])

    def get_the_input(self):
        high_input, tech_cost, pop_num = U.get_input(self.obs)
        controller_input = np.concatenate([high_input, tech_cost, pop_num],
                                          axis=0)
        return controller_input

    def get_the_input_right(self, obs):
        high_input, tech_cost, pop_num = U.get_input(obs)
        controller_input = np.concatenate([high_input, tech_cost, pop_num],
                                          axis=0)
        return controller_input

    def mapping_source_to_mini_by_rule(self, source_state):
        simple_input = np.zeros([20], dtype=np.int16)
        simple_input[0] = 0  # self.time_seconds
        simple_input[1] = source_state[28]  # self.mineral_worker_nums
        simple_input[2] = source_state[30] + source_state[
            32]  # self.gas_worker_nums
        simple_input[3] = source_state[2]  # self.mineral
        simple_input[4] = source_state[3]  # self.gas
        simple_input[5] = source_state[6]  # self.food_cap
        simple_input[6] = source_state[7]  # self.food_used
        simple_input[7] = source_state[10]  # self.army_nums

        simple_input[8] = source_state[16]  # self.gateway_num
        simple_input[9] = source_state[14]  # self.pylon_num
        simple_input[10] = source_state[15]  # self.Assimilator_num
        simple_input[11] = source_state[17]  # self.CyberneticsCore_num

        simple_input[12] = source_state[12]  # self.zealot_num
        simple_input[13] = source_state[13]  # self.Stalker_num
        simple_input[14] = source_state[11]  # self.probe_num

        simple_input[15] = source_state[4] + source_state[
            2]  # self.collected_mineral
        simple_input[16] = source_state[4]  # self.spent_mineral
        simple_input[17] = source_state[5] + source_state[
            3]  # self.collected_gas
        simple_input[18] = source_state[5]  # self.spent_gas
        simple_input[19] = 1  # self.Nexus_num

        return simple_input

    def play(self, verbose=False):
        self.play_train_mini(verbose=verbose)

    def sample(self, verbose=False, use_image=False):
        is_attack = False
        state_last = None

        random_generated_int = random.randint(0, 2**31 - 1)
        filename = self.extract_save_dir + "/" + str(
            random_generated_int) + ".npz"
        recording_obs = []
        recording_img = []
        recording_action = []

        np.random.seed(random_generated_int)
        tf.set_random_seed(random_generated_int)

        self.safe_action(C._NO_OP, 0, [])
        self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
        self._gases = U.find_initial_gases(self.obs)
        while True:

            self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
            if self.policy_flag and (not self.is_end):

                state_now = self.mapping_source_to_mini_by_rule(
                    self.get_the_input())
                recording_obs.append(state_now)

                if use_image:
                    recording_img.append(U.get_simple_map_data(self.obs))

                action, v_preds = self.net.policy.get_action(state_now,
                                                             verbose=False)
                recording_action.append(action)

                self.mini_step(action)

                if state_last is not None:
                    if verbose:
                        print('state_last:', state_last, ', action_last:',
                              action_last, ', state_now:', state_now)
                    v_preds_next = self.net.policy.get_values(state_now)
                    v_preds_next = self.get_values(v_preds_next)
                    reward = 0

                    self.local_buffer.append(state_last, action_last,
                                             state_now, reward, v_preds,
                                             v_preds_next)

                state_last = state_now
                action_last = action
                self.policy_flag = False

            if self.is_end:
                # note: this will not handle mineral counts larger than 256!
                recording_obs = np.array(recording_obs, dtype=np.uint16)
                recording_action = np.array(recording_action, dtype=np.uint8)
                if not use_image:
                    np.savez_compressed(filename,
                                        obs=recording_obs,
                                        action=recording_action)
                else:
                    recording_img = np.array(recording_img, dtype=np.float16)
                    np.savez_compressed(filename,
                                        obs=recording_obs,
                                        img=recording_img,
                                        action=recording_action)
                break

    def play_train_mini(self, continuous_attack=False, verbose=False):
        is_attack = False
        state_last = None

        self.safe_action(C._NO_OP, 0, [])
        self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
        self._gases = U.find_initial_gases(self.obs)
        while True:

            self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
            if self.policy_flag and (not self.is_end):

                state_now = self.mapping_source_to_mini_by_rule(
                    self.get_the_input())
                if self.greedy_action:
                    action_prob, v_preds = self.net.policy.get_action_probs(
                        state_now, verbose=False)
                    action = np.argmax(action_prob)
                else:
                    action, v_preds = self.net.policy.get_action(state_now,
                                                                 verbose=False)

                # print(ProtossAction(action).name)
                self.mini_step(action)

                if state_last is not None:
                    if verbose:
                        print('state_last:', state_last, ', action_last:',
                              action_last, ', state_now:', state_now)
                    v_preds_next = self.net.policy.get_values(state_now)
                    v_preds_next = self.get_values(v_preds_next)
                    reward = 0
                    if verbose:
                        print(state_last, action_last, state_now, reward,
                              v_preds, v_preds_next)
                    self.local_buffer.append(state_last, action_last,
                                             state_now, reward, v_preds,
                                             v_preds_next)

                # continuous attack, consistent with the mind-game
                if continuous_attack:
                    if action == ProtossAction.Attack.value:
                        is_attack = True
                    if is_attack:
                        self.mini_step(ProtossAction.Attack.value)

                state_last = state_now
                action_last = action
                self.policy_flag = False

            if self.is_end:
                if self.rl_training:
                    self.local_buffer.rewards[-1] += 1 * self.result[
                        'reward']  # self.result['win']
                    #print(self.local_buffer.values)
                    #print(self.local_buffer.values_next)
                    #print(self.local_buffer.rewards)

                    self.global_buffer.add(self.local_buffer)
                    print("add %d buffer!" % (len(self.local_buffer.rewards)))
                    #print("returns:", self.global_buffer.returns)
                    #print("gaes:", self.global_buffer.gaes)
                break

    def play_right(self, verbose=False):
        # note: this is the corrected version of the game-play loop
        prev_state = None
        prev_action = None
        prev_value = None

        self.safe_action(C._NO_OP, 0, [])
        self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
        self._gases = U.find_initial_gases(self.obs)

        while True:

            self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
            if self.policy_flag and (not self.is_end):
                # get the state
                state = self.mapping_source_to_mini_by_rule(
                    self.get_the_input_right(self.obs))

                # get the action and value according to the state
                action, value = self.net.policy.get_action(state,
                                                           verbose=verbose)

                # if this is not the first state, store the transition in the buffer
                if prev_state is not None:
                    # use the environment-provided reward
                    reward = self.obs.reward
                    if verbose:
                        print(prev_state, prev_action, state, reward,
                              prev_value, value)
                    self.local_buffer.append(prev_state, prev_action, state,
                                             reward, prev_value, value)

                self.mini_step(action)
                # the env now steps to the new state

                prev_state = state
                prev_action = action
                prev_value = value

                self.policy_flag = False

            if self.is_end:
                # get the last state and its value
                state = self.mapping_source_to_mini_by_rule(
                    self.get_the_input_right(self.obs))

                value = self.net.policy.get_values(state)
                # the value of the last state is defined somewhat differently
                value = self.get_values_right(value)

                # if this is not the first state, store the transition in the buffer
                if prev_state is not None:
                    reward = self.obs.reward
                    if verbose:
                        print(prev_state, prev_action, state, reward,
                              prev_value, value)
                    self.local_buffer.append(prev_state, prev_action, state,
                                             reward, prev_value, value)
                break

        if self.rl_training:
            if verbose:
                print(self.local_buffer.values)
                print(self.local_buffer.values_next)
            #print(self.local_buffer.rewards)
            self.global_buffer.add(self.local_buffer)
            #print("add %d buffer!" % (len(self.local_buffer.rewards)))
            #print("returns:", self.global_buffer.returns)
            #print("gaes:", self.global_buffer.gaes)

    def set_flag(self):
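        # C.time_wait presumably converts a wall-clock interval in seconds into
        # a number of environment steps; the modulo schedule below raises each
        # flag periodically so the strategy / policy nets run at a fixed
        # cadence instead of on every frame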
        if self.step % C.time_wait(self.strategy_wait_secs) == 1:
            self.strategy_flag = True

        if self.step % C.time_wait(self.policy_wait_secs) == 1:
            self.policy_flag = True

    def safe_action(self, action, unit_type, args):
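        # only issue the FunctionCall when M.check_params confirms it is
        # currently valid; invalid actions are silently dropped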
        if M.check_params(self, action, unit_type, args, 1):
            obs = self.env.step([sc2_actions.FunctionCall(action, args)])[0]
            self.obs = obs
            self.step += 1
            self.update_result()
            self.set_flag()

    def select(self, action, unit_type, args):
        # safe select
        if M.check_params(self, action, unit_type, args, 0):
            self.obs = self.env.step([sc2_actions.FunctionCall(action,
                                                               args)])[0]
            self.on_select = unit_type
            self.update_result()
            self.step += 1
            self.set_flag()

    @property
    def result(self):
        return self._result

    def update_result(self):
        if self.obs is None:
            return
        if self.obs.last() or self.env.state == environment.StepType.LAST:
            self.is_end = True
            outcome = 0
            o = self.obs.raw_observation
            player_id = o.observation.player_common.player_id
            for r in o.player_result:
                if r.player_id == player_id:
                    outcome = sc2_env._possible_results.get(r.result, 0)
            frames = o.observation.game_loop
            result = {}
            result['outcome'] = outcome
            result['reward'] = self.obs.reward
            result['frames'] = frames

            self._result = result
            print('play end, total return', self.obs.reward)
            self.step = 0

    def get_values(self, values):
        # check whether the game has ended
        if self.is_end and self.result['reward'] != 0:
            return 0
        else:
            return values

    def get_values_right(self, values):
        # if the game ends with a win or loss (the result reward is 1 or -1), the value is set to 0;
        # if it ends without a decisive result (the result reward is 0), the value is kept as before
        if self.is_end and self.result['reward'] != 0:
            return 0
        else:
            return values
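As a companion to the `sample` method above, the sketch below (not part of the original agent) shows one way to read a saved trajectory back; the archive keys ('obs', 'img', 'action') and dtypes mirror the np.savez_compressed call in `sample`, while the file path is hypothetical.

import numpy as np

def load_recording(path):
    # np.load on an .npz archive returns an NpzFile; index it by key
    with np.load(path) as data:
        obs = data['obs']        # (T, 20) uint16 mini states
        action = data['action']  # (T,) uint8 action ids
        img = data['img'] if 'img' in data.files else None  # optional map data
    return obs, action, img

# obs, action, img = load_recording('./extract/123456.npz')  # hypothetical path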
Example #22
class SourceAgent(base_agent.BaseAgent):
    """Agent for source game of starcraft."""
    def __init__(self,
                 index=0,
                 rl_training=False,
                 restore_model=False,
                 global_buffer=None,
                 net=None,
                 strategy_agent=None):
        super(SourceAgent, self).__init__()
        self.net = net
        self.index = index
        self.global_buffer = global_buffer
        self.restore_model = restore_model

        # model in brain
        self.strategy_agent = strategy_agent
        self.strategy_act = None

        # count num
        self.step = 0

        self.strategy_wait_secs = 4
        self.strategy_flag = False
        self.policy_wait_secs = 2
        self.policy_flag = True

        self.env = None
        self.obs = None

        # buffer
        self.local_buffer = Buffer()
        self.mini_state = []
        self.mini_state_mapping = []

        self.num_players = 2
        self.on_select = None
        self._result = None
        self.is_end = False

        self.rl_training = rl_training

        self.reward_type = 0

    def reset(self):
        super(SourceAgent, self).reset()
        self.step = 0
        self.obs = None
        self._result = None
        self.is_end = False

        self.policy_flag = True

        self.local_buffer.reset()
        if self.strategy_agent is not None:
            self.strategy_agent.reset()

    def set_env(self, env):
        self.env = env

    def init_network(self):
        self.net.initialize()
        if self.restore_model:
            self.net.restore_policy()

    def reset_old_network(self):
        self.net.reset_old_network()

    def save_model(self):
        self.net.save_policy()

    def update_network(self, result_list):
        self.net.Update_policy(self.global_buffer)
        self.net.Update_result(result_list)

    def update_summary(self, counter):
        self.net.Update_summary(counter)

    def get_policy_input(self, obs):
        high_input, tech_cost, pop_num = U.get_input(obs)
        policy_input = np.concatenate([high_input, tech_cost, pop_num], axis=0)
        return policy_input

    def tech_step(self, tech_action):
        # to execute a tech_action
        # [pylon, gas1, gas2, gateway, cyber]

        if tech_action == 0:  # pylon
            no_unit_index = U.get_unit_mask_screen(self.obs, size=2)
            pos = U.get_pos(no_unit_index)
            M.build_by_idle_worker(self, C._BUILD_PYLON_S, pos)

        elif tech_action == 1 and not U.find_gas(self.obs, 1):  # gas_1
            gas_1 = U.find_gas_pos(self.obs, 1)
            gas_1_pos = T.world_to_screen_pos(self.env.game_info, gas_1.pos,
                                              self.obs)
            M.build_by_idle_worker(self, C._BUILD_ASSIMILATOR_S, gas_1_pos)

        elif tech_action == 1 and not U.find_gas(self.obs, 2):  # gas_2
            gas_2 = U.find_gas_pos(self.obs, 2)
            gas_2_pos = T.world_to_screen_pos(self.env.game_info, gas_2.pos,
                                              self.obs)
            M.build_by_idle_worker(self, C._BUILD_ASSIMILATOR_S, gas_2_pos)

        elif tech_action == 2:  # gateway
            power_index = U.get_power_mask_screen(self.obs, size=5)
            pos = U.get_pos(power_index)
            M.build_by_idle_worker(self, C._BUILD_GATEWAY_S, pos)

        elif tech_action == 3:  # cyber
            power_index = U.get_power_mask_screen(self.obs, size=3)
            pos = U.get_pos(power_index)
            M.build_by_idle_worker(self, C._BUILD_CYBER_S, pos)

        else:
            self.safe_action(C._NO_OP, 0, [])

    def pop_step(self, pop_action):
        # to execute a pop_action
        # [ mineral_probe, zealot, stalker]
        #print('pop_action', pop_action)
        if pop_action == 0:  # mineral_probe
            M.mineral_worker(self)
            # print('mineral_worker')
        elif pop_action == 1:  # zealot
            M.train_army(self, C._TRAIN_ZEALOT)
            # print('_TRAIN_ZEALOT')
        elif pop_action == 2:  # stalker
            M.train_army(self, C._TRAIN_STALKER)
            # print('_TRAIN_STALKER')
        else:
            self.safe_action(C._NO_OP, 0, [])

    def battle_step(self, battle_action):
        if battle_action == 0:  # attack
            M.attack_step(self)

        elif battle_action == 1:  # retreat
            M.retreat_step(self)

        else:
            self.safe_action(C._NO_OP, 0, [])

    def mini_step(self, action):
        if action == ProtossAction.Build_worker.value:
            M.mineral_worker(self)
        elif action == ProtossAction.Build_zealot.value:
            M.train_army(self, C._TRAIN_ZEALOT)
        elif action == ProtossAction.Build_pylon.value:
            no_unit_index = U.get_unit_mask_screen(self.obs, size=2)
            pos = U.get_pos(no_unit_index)
            M.build_by_idle_worker(self, C._BUILD_PYLON_S, pos)
        elif action == ProtossAction.Build_gateway.value:
            power_index = U.get_power_mask_screen(self.obs, size=5)
            pos = U.get_pos(power_index)
            M.build_by_idle_worker(self, C._BUILD_GATEWAY_S, pos)
        elif action == ProtossAction.Attack.value:
            M.attack_step(self)
        elif action == ProtossAction.Defend.value:
            M.retreat_step(self)
        elif action == ProtossAction.Build_sub_base.value:
            self.safe_action(C._NO_OP, 0, [])
        elif action == ProtossAction.Build_cannon.value:
            self.safe_action(C._NO_OP, 0, [])
        else:
            self.safe_action(C._NO_OP, 0, [])

    def get_the_input(self):
        high_input, tech_cost, pop_num = U.get_input(self.obs)
        controller_input = np.concatenate([high_input, tech_cost, pop_num],
                                          axis=0)
        return controller_input

    def combine_state_and_mini_action(self, state, strategy_act):
        act = np.zeros((1, 1))
        act[0, 0] = strategy_act
        action_array = self.one_hot_label(act, C._SIZE_MINI_ACTIONS)[0]
        combined_state = np.concatenate([state, action_array], axis=0)
        return combined_state

    def mapping_source_to_mini(self, source_state):
        mini_state = self.net.mapping.predict_func(source_state,
                                                   use_transform=False)
        return mini_state

    def mapping_source_to_mini_by_rule(self, source_state):
        simple_input = np.zeros([11])
        simple_input[0] = 0  # self.time_seconds
        simple_input[1] = source_state[28]  # self.mineral_worker_nums
        simple_input[2] = source_state[30] + source_state[
            32]  # self.gas_worker_nums
        simple_input[3] = source_state[2]  # self.mineral
        simple_input[4] = source_state[3]  # self.gas
        simple_input[5] = source_state[6]  # self.food_cup
        simple_input[6] = source_state[7]  # self.food_used
        simple_input[7] = source_state[10]  # self.army_nums
        simple_input[8] = source_state[16]  # self.gateway_num
        simple_input[9] = source_state[14]  # self.pylon_num
        simple_input[10] = source_state[12]  # self.zealot_num

        return simple_input

    def play_bak(self, verbose=False):
        # self.safe_action(C._NO_OP, 0, [])
        state_last = None
        mini_state = self.strategy_agent.obs()
        while True:
            self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])

            source_state = self.get_the_input()
            mini_state_mapping = self.mapping_source_to_mini_by_rule(
                source_state)
            if verbose:
                print('source_state:', source_state)
                print('mini_state_mapping:', mini_state_mapping)

            # test use mini_state_mapping
            strategy_state = mini_state_mapping

            mini_act = self.strategy_agent.get_action_by_policy(
                strategy_state)[0]
            #print('strategy_act:', mini_act)

            self.strategy_agent.set_obs(strategy_state)
            mini_state = self.strategy_agent.get_next_state(mini_act)

            self.strategy_act = mini_act
            self.strategy_flag = False

            while (not self.strategy_flag) and (not self.is_end):
                self.safe_action(C._NO_OP, 0, [])

                if self.policy_flag and (not self.is_end):
                    state_now = self.combine_state_and_mini_action(
                        self.get_the_input(), self.strategy_act)
                    #print('state_now:', state_now)
                    action, v_preds = self.net.policy.get_action(state_now,
                                                                 verbose=False)
                    #print('action:', action)

                    print('action:', self.strategy_act)
                    self.mini_step(self.strategy_act)
                    '''
                    if action < C._SIZE_TECH_NET_OUT:
                        reward = self.tech_step(action)
                    elif action < C._SIZE_TECH_NET_OUT + C._SIZE_POP_NET_OUT:
                        reward = self.pop_step(action - C._SIZE_TECH_NET_OUT)
                    elif action < C._SIZE_TECH_NET_OUT + C._SIZE_POP_NET_OUT + C._SIZE_BATTLE_NET_OUT:
                        reward = self.battle_step(action - C._SIZE_TECH_NET_OUT - C._SIZE_POP_NET_OUT)
                    else:
                        self.safe_action(C._NO_OP, 0, [])
                        reward = 0
                    '''

                    if state_last is not None:
                        if verbose:
                            print('state_last:', state_last, ', action_last:',
                                  action_last, ', state_now:', state_now)
                        v_preds_next = self.net.policy.get_values(state_now)
                        v_preds_next = self.get_values(v_preds_next)
                        reward = 0
                        self.local_buffer.append(state_last, action_last,
                                                 state_now, reward, v_preds,
                                                 v_preds_next)

                    state_last = state_now
                    action_last = action
                    self.policy_flag = False

            if self.is_end:
                if self.rl_training:
                    self.local_buffer.rewards[-1] += 1 * self.result['reward']
                    print(self.local_buffer.rewards)
                    self.global_buffer.add(self.local_buffer)
                    print("add %d buffer!" % (len(self.local_buffer.rewards)))
                break

    def play(self, verbose=False):
        is_attack = False
        while True:
            #self.safe_action(C._NO_OP, 0, [])
            self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
            if self.policy_flag and (not self.is_end):
                mini_state_mapping = self.mapping_source_to_mini_by_rule(
                    self.get_the_input())
                #print('state:', mini_state_mapping)
                mini_act = self.strategy_agent.get_action_by_policy(
                    mini_state_mapping)[0]
                print('action:', mini_act)
                self.mini_step(mini_act)

                if mini_act == ProtossAction.Attack.value:
                    is_attack = True
                if is_attack:
                    self.mini_step(ProtossAction.Attack.value)

                self.policy_flag = False

            if self.is_end:
                break

    def set_flag(self):
        if self.step % C.time_wait(self.strategy_wait_secs) == 1:
            self.strategy_flag = True

        if self.step % C.time_wait(self.policy_wait_secs) == 1:
            self.policy_flag = True

    def safe_action(self, action, unit_type, args):
        if M.check_params(self, action, unit_type, args, 1):
            obs = self.env.step([sc2_actions.FunctionCall(action, args)])[0]
            self.obs = obs
            self.step += 1
            self.update_result()
            self.set_flag()

    def select(self, action, unit_type, args):
        # safe select
        if M.check_params(self, action, unit_type, args, 0):
            self.obs = self.env.step([sc2_actions.FunctionCall(action,
                                                               args)])[0]
            self.on_select = unit_type
            self.update_result()
            self.step += 1
            self.set_flag()

    @property
    def result(self):
        return self._result

    def update_result(self):
        if self.obs is None:
            return
        if self.obs.last() or self.env.state == environment.StepType.LAST:
            self.is_end = True
            outcome = 0
            o = self.obs.raw_observation
            player_id = o.observation.player_common.player_id
            for r in o.player_result:
                if r.player_id == player_id:
                    outcome = sc2_env._possible_results.get(r.result, 0)
            frames = o.observation.game_loop
            result = {}
            result['outcome'] = outcome
            result['reward'] = self.obs.reward
            result['frames'] = frames

            result['win'] = 0
            if result['reward'] == 1:
                result['win'] = 1

            self._result = result
            print('play end, total return', self.obs.reward)
            self.step = 0

    def one_hot_label(self, action_type_array, action_max_num):
        rows = action_type_array.shape[0]
        cols = action_max_num
        data = np.zeros((rows, cols))

        for i in range(rows):
            data[i, int(action_type_array[i])] = 1

        return data

    def get_values(self, values):
        # check whether the game has ended
        if self.is_end and self.result['reward'] != 0:
            return 0
        else:
            return values
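The combine_state_and_mini_action / one_hot_label pair above just appends a one-hot strategy action to the state vector. A self-contained sketch, assuming a hypothetical mini-action space of size 8 (C._SIZE_MINI_ACTIONS in the original):

import numpy as np

def combine_state_and_mini_action(state, strategy_act, num_actions=8):
    # one-hot encode the strategy action and append it to the source state
    one_hot = np.zeros(num_actions)
    one_hot[int(strategy_act)] = 1
    return np.concatenate([state, one_hot], axis=0)

state = np.zeros(38)  # hypothetical source-state size
print(combine_state_and_mini_action(state, 3).shape)  # (46,)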
Example #23
class SourceAgent(base_agent.BaseAgent):
    """Agent for source game of starcraft."""
    def __init__(self,
                 index=0,
                 rl_training=False,
                 restore_model=False,
                 global_buffer=None,
                 net=None,
                 strategy_agent=None):
        super(SourceAgent, self).__init__()
        self.net = net
        self.index = index
        self.global_buffer = global_buffer
        self.restore_model = restore_model

        # model in brain
        self.strategy_agent = strategy_agent
        self.strategy_act = None

        # count num
        self.step = 0

        self.strategy_wait_secs = 4
        self.strategy_flag = False
        self.policy_wait_secs = 2
        self.policy_flag = True

        self.env = None
        self.obs = None

        # buffer
        self.local_buffer = Buffer()
        self.mini_state = []
        self.mini_state_mapping = []

        self.num_players = 2
        self.on_select = None
        self._result = None
        self.is_end = False
        self.is_attack = False
        self._gases = None

        self.rl_training = rl_training

        self.reward_type = 0

    def reset(self):
        super(SourceAgent, self).reset()
        self.step = 0
        self.obs = None
        self._result = None
        self.is_end = False
        self.is_attack = False
        self._gases = None

        self.policy_flag = True

        self.local_buffer.reset()

        if self.strategy_agent is not None:
            self.strategy_agent.reset()

    def set_env(self, env):
        self.env = env

    def init_network(self):
        self.net.initialize()
        if self.restore_model:
            self.net.restore_policy()

    def reset_old_network(self):
        self.net.reset_old_network()

    def save_model(self):
        self.net.save_policy()

    def update_network(self, result_list):
        self.net.Update_policy(self.global_buffer)
        self.net.Update_result(result_list)

    def update_summary(self, counter):
        self.net.Update_summary(counter)

    def get_policy_input(self, obs):
        high_input, tech_cost, pop_num = U.get_input(obs)
        policy_input = np.concatenate([high_input, tech_cost, pop_num], axis=0)
        return policy_input

    def tech_step(self, tech_action):
        # to execute a tech_action
        # [pylon, gas1, gas2, gateway, cyber]

        if tech_action == 0:  # pylon
            no_unit_index = U.get_unit_mask_screen(self.obs, size=2)
            pos = U.get_pos(no_unit_index)
            M.build_by_idle_worker(self, C._BUILD_PYLON_S, pos)

        elif tech_action == 1 and not U.find_gas(self.obs, 1):  # gas_1
            gas_1 = U.find_geyser_pos(self.obs, 1)
            gas_1_pos = T.world_to_screen_pos(self.env.game_info, gas_1.pos,
                                              self.obs)
            M.build_by_idle_worker(self, C._BUILD_ASSIMILATOR_S, gas_1_pos)

        elif tech_action == 1 and not U.find_gas(self.obs, 2):  # gas_2
            gas_2 = U.find_geyser_pos(self.obs, 2)
            gas_2_pos = T.world_to_screen_pos(self.env.game_info, gas_2.pos,
                                              self.obs)
            M.build_by_idle_worker(self, C._BUILD_ASSIMILATOR_S, gas_2_pos)

        elif tech_action == 2:  # gateway
            power_index = U.get_power_mask_screen(self.obs, size=5)
            pos = U.get_pos(power_index)
            M.build_by_idle_worker(self, C._BUILD_GATEWAY_S, pos)

        elif tech_action == 3:  # cyber
            power_index = U.get_power_mask_screen(self.obs, size=3)
            pos = U.get_pos(power_index)
            M.build_by_idle_worker(self, C._BUILD_CYBER_S, pos)

        else:
            self.safe_action(C._NO_OP, 0, [])

    def pop_step(self, pop_action):
        # to execute a pop_action
        # [ mineral_probe, zealot, stalker]
        #print('pop_action', pop_action)
        if pop_action == 0:  # mineral_probe
            M.mineral_worker(self)
            # print('mineral_worker')
        elif pop_action == 1:  # zealot
            M.train_army(self, C._TRAIN_ZEALOT)
            # print('_TRAIN_ZEALOT')
        elif pop_action == 2:  # stalker
            M.train_army(self, C._TRAIN_STALKER)
            # print('_TRAIN_STALKER')
        else:
            self.safe_action(C._NO_OP, 0, [])

    def battle_step(self, battle_action):
        if battle_action == 0:  # attack
            M.attack_step(self)

        elif battle_action == 1:  # retreat
            M.retreat_step(self)

        else:
            self.safe_action(C._NO_OP, 0, [])

    def mini_step(self, action):

        if action == ZergAction.Build_drone.value:
            M.mineral_worker(self)

        elif action == ZergAction.Build_extractor.value:
            if self._gases is not None:
                #U.find_gas_pos(self.obs, 1)
                gas_1 = self._gases[0]
                gas_2 = self._gases[1]

                if gas_1 is not None and not U.is_extractor_on_gas(
                        self.obs, gas_1):
                    gas_1_pos = T.world_to_screen_pos(self.env.game_info,
                                                      gas_1.pos, self.obs)
                    M.build_by_idle_worker(self, C._BUILD_EXTRACTOR_S,
                                           gas_1_pos)

                elif gas_2 is not None and not U.is_extractor_on_gas(
                        self.obs, gas_2):
                    gas_2_pos = T.world_to_screen_pos(self.env.game_info,
                                                      gas_2.pos, self.obs)
                    M.build_by_idle_worker(self, C._BUILD_EXTRACTOR_S,
                                           gas_2_pos)

        elif action == ZergAction.Gather_gas.value:
            M.gather_resource(self, 'gas')

        elif action == ZergAction.Gather_mineral.value:
            M.gather_resource(self, 'mineral')

        elif action == ZergAction.Build_queen.value:
            M.train_army(self, C._TRAIN_QUEEN)

        elif action == ZergAction.Build_zergling.value:
            M.train_army(self, C._TRAIN_ZERGLING)

        elif action == ZergAction.Build_roach.value:
            M.train_army(self, C._TRAIN_ROACH)

        elif action == ZergAction.Build_overlord.value:
            M.train_army(self, C._TRAIN_OVERLORD)

        elif action == ZergAction.Build_spawningpool.value:
            creep_index = U.get_creep_mask_screen(self.obs, size=2)
            pos = U.get_pos(creep_index)
            M.build_by_idle_worker(self, C._BUILD_SPAWNINGPOOL_S, pos)

        elif action == ZergAction.Build_roachwarren.value:
            creep_index = U.get_creep_mask_screen(self.obs, size=2)
            pos = U.get_pos(creep_index)
            M.build_by_idle_worker(self, C._BUILD_ROACHWARREN_S, pos)

        elif action == ZergAction.Build_evolutionchamber.value:
            creep_index = U.get_creep_mask_screen(self.obs, size=2)
            pos = U.get_pos(creep_index)
            M.build_by_idle_worker(self, C._BUILD_EVOLUTIONCHAMBER_S, pos)

        elif action == ZergAction.Build_spinecrawler.value:
            creep_index = U.get_creep_mask_screen(self.obs, size=2)
            pos = U.get_pos(creep_index)
            M.build_by_idle_worker(self, C._BUILD_SPINECRAWLER_S, pos)

        elif action == ZergAction.Attack.value:
            self.is_attack = True

        elif action == ZergAction.Defend.value:
            M.retreat_step(self)

        else:
            self.safe_action(C._NO_OP, 0, [])

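        # once Attack has been chosen, is_attack latches and the attack is
        # re-issued on every subsequent mini_step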
        if self.is_attack:
            M.attack_step(self)
        # if any queen exists, try to inject larva into the hatchery
        M.inject_larva(self)

    def get_the_input(self):
        high_input, tech_cost, pop_num = U.get_input(self.obs)
        controller_input = np.concatenate([high_input, tech_cost, pop_num],
                                          axis=0)
        return controller_input

    def combine_state_and_mini_action(self, state, strategy_act):
        act = np.zeros((1, 1))
        act[0, 0] = strategy_act
        action_array = self.one_hot_label(act, C._SIZE_MINI_ACTIONS)[0]
        combined_state = np.concatenate([state, action_array], axis=0)
        return combined_state

    def mapping_source_to_mini(self, source_state):
        mini_state = self.net.mapping.predict_func(source_state,
                                                   use_transform=False)
        return mini_state

    def mapping_source_to_mini_by_rule(self, source_state):
        simple_input = np.zeros([17])
        simple_input[0] = 0  # self.time_seconds
        simple_input[1] = source_state[31]  # self.mineral_worker_nums
        simple_input[2] = source_state[33] + source_state[
            35]  # self.gas_worker_nums
        simple_input[3] = source_state[2]  # self.mineral
        simple_input[4] = source_state[3]  # self.gas
        simple_input[5] = source_state[6]  # self.food_cap
        simple_input[6] = source_state[7]  # self.food_used
        simple_input[7] = source_state[10]  # self.army_nums

        simple_input[8] = source_state[11]  # self.larva_num
        simple_input[9] = source_state[15]  # self.overlord_num
        simple_input[10] = source_state[17]  # self.spawningpool_num
        simple_input[11] = source_state[18]  # self.roachwarren_num

        simple_input[12] = source_state[13]  # self.zergling_num
        simple_input[13] = source_state[14]  # self.roach_num
        simple_input[14] = source_state[16]  # self.extractor_num
        simple_input[15] = source_state[19]  # self.evolutionchamber_num
        simple_input[16] = source_state[20]  # self.queen_num

        return simple_input

    def play(self, verbose=False):
        is_attack = False
        state_last = None

        self.safe_action(C._NO_OP, 0, [])
        self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
        self._gases = U.find_initial_gases(self.obs)
        while True:

            #self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
            self.safe_action(C._NO_OP, 0, [])
            if self.policy_flag and (not self.is_end):

                state_now = self.mapping_source_to_mini_by_rule(
                    self.get_the_input())
                action, v_preds = self.net.policy.get_action(state_now,
                                                             verbose=False)
                #action = 3
                self.mini_step(action)
                if state_last is not None:
                    # print(state_now)
                    # time.sleep(0.5)
                    v_preds_next = self.net.policy.get_values(state_now)
                    v_preds_next = self.get_values(v_preds_next)
                    reward = 0
                    self.local_buffer.append(state_last, action_last,
                                             state_now, reward, v_preds,
                                             v_preds_next)

                if action == ZergAction.Attack.value:
                    is_attack = True
                if is_attack:
                    self.mini_step(ZergAction.Attack.value)

                state_last = state_now
                action_last = action
                self.policy_flag = False

            if self.is_end:
                if self.rl_training:
                    self.local_buffer.rewards[-1] += 1 * self.result[
                        'reward']  # self.result['win']
                    print(self.local_buffer.rewards)
                    self.global_buffer.add(self.local_buffer)
                    print("add %d buffer!" % (len(self.local_buffer.rewards)))
                break

    # def play(self, verbose=False):
    #     is_attack = False
    #     while True:
    #         #self.safe_action(C._NO_OP, 0, [])
    #         self.safe_action(C._MOVE_CAMERA, 0, [C.base_camera_pos])
    #         if self.policy_flag and (not self.is_end):
    #             mini_state_mapping = self.mapping_source_to_mini_by_rule(self.get_the_input())
    #             #print('state:', mini_state_mapping)
    #             mini_act = self.strategy_agent.get_action_by_policy(mini_state_mapping)[0]
    #             print("Action: ", ZergAction(int(mini_act)).name)
    #             self.mini_step(mini_act)

    #             if mini_act == ZergAction.Attack.value:
    #                 is_attack = True
    #             if is_attack:
    #                 self.mini_step(ZergAction.Attack.value)

    #             self.policy_flag = False

    #         if self.is_end:
    #             break

    def set_flag(self):
        if self.step % C.time_wait(self.strategy_wait_secs) == 1:
            self.strategy_flag = True

        if self.step % C.time_wait(self.policy_wait_secs) == 1:
            self.policy_flag = True

    def safe_action(self, action, unit_type, args):
        if M.check_params(self, action, unit_type, args, 1):
            self.obs = self.env.step([sc2_actions.FunctionCall(action,
                                                               args)])[0]
            self.step += 1
            self.update_result()
            self.set_flag()

    def select(self, action, unit_type, args):
        # safe select
        if M.check_params(self, action, unit_type, args, 0):
            self.obs = self.env.step([sc2_actions.FunctionCall(action,
                                                               args)])[0]
            self.on_select = unit_type
            self.update_result()
            self.step += 1
            self.set_flag()

    @property
    def result(self):
        return self._result

    def update_result(self):
        if self.obs is None:
            return
        if self.obs.last() or self.env.state == environment.StepType.LAST:
            self.is_end = True
            outcome = 0
            o = self.obs.raw_observation
            player_id = o.observation.player_common.player_id
            for r in o.player_result:
                if r.player_id == player_id:
                    outcome = sc2_env._possible_results.get(r.result, 0)
            frames = o.observation.game_loop
            result = {}
            result['outcome'] = outcome
            result['reward'] = self.obs.reward
            result['frames'] = frames

            result['win'] = 0
            if result['reward'] == 1:
                result['win'] = 1

            self._result = result
            print('play end, total return', self.obs.reward)
            self.step = 0

    def one_hot_label(self, action_type_array, action_max_num):
        rows = action_type_array.shape[0]
        cols = action_max_num
        data = np.zeros((rows, cols))

        for i in range(rows):
            data[i, int(action_type_array[i])] = 1

        return data

    def get_values(self, values):
        # check whether the game has ended
        if self.is_end and self.result['reward'] != 0:
            return 0
        else:
            return values
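get_values above zeroes the bootstrap value whenever the episode ends with a decisive reward, so the terminal reward is not propagated past the episode boundary. A minimal sketch of the discounted-return computation this convention implies; the Buffer internals and the gamma value are assumptions, not from the source:

def discounted_returns(rewards, last_value_next, gamma=0.995):
    # accumulate returns backwards; last_value_next is 0 after a win/loss
    # (see get_values), V(s_T) otherwise
    returns = []
    g = last_value_next
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    return returns[::-1]

print(discounted_returns([0, 0, 1], last_value_next=0))  # [0.990025, 0.995, 1.0]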
Example #24
class MiniAgent():

    def __init__(self, agent_id=0, global_buffer=None, net=None, restore_model=False):
        self.agent_id = agent_id
        self.net = net
        self.global_buffer = global_buffer
        self.greedy_action = False
        self.local_buffer = Buffer()
        self.env = None
        self.restore_model = restore_model

        self.reset()

    def __str__(self):
        # __str__ must return a string; returning None raises a TypeError
        return 'MiniAgent(agent_id=%s)' % self.agent_id

    def set_env(self, env):
        self.env = env

    def reset(self):
        self.step = 0
        self.obs = None
        self.reward = 0
        self.done = False
        self.result = 0
        self.local_buffer.reset()

    def play(self, show_details=False):
        #self.reset()
        self.obs = self.env.reset()
        state_last = None

        while True:
            # get the action
            if self.greedy_action:
                action_prob, v_preds = self.net.policy.get_action_probs(self.obs, verbose=False)
                action = np.argmax(action_prob)
            else:
                action, v_preds = self.net.policy.get_action(self.obs, verbose=False)

            # use the action to push the env step
            self.obs, self.reward, self.done, info = self.env.step(action)

            # add info to buffer
            if state_last is not None:
                if show_details:
                    print('state_last:', state_last, ', action_last:', action_last, ', state_now:', self.obs)
                v_preds_next = self.net.policy.get_values(self.obs)
                v_preds_next = self.get_values(v_preds_next)
                self.local_buffer.append(state_last, action_last, self.obs, self.reward, v_preds, v_preds_next)
            
            state_last = self.obs
            action_last = action

            if self.done:
                self.result = self.reward
                if show_details:
                    print('play end, total return', self.result)
                if len(self.local_buffer.rewards) > 0:
                    self.global_buffer.add(self.local_buffer)
                print("add %d buffer!" % (len(self.local_buffer.rewards))) if 1 else None
                break

    def init_network(self):
        self.net.initialize()
        if self.restore_model:
            self.net.restore_policy()

    def update_network(self, result_list):
        self.net.Update_policy(self.global_buffer)
        self.net.Update_result(result_list)

    def reset_old_network(self):
        self.net.reset_old_network()

    def save_model(self):
        self.net.save_policy()

    def update_summary(self, counter):
        return self.net.Update_summary(counter)

    def get_values(self, values):
        # check whether the game has ended
        if self.done:
            return 0
        else:
            return values
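The MiniAgent interface above suggests an outer training loop along the following lines; the iteration structure, episode counts, and the contents of result_list are assumptions inferred from the method names, not code from the source:

def train(agent, num_iterations=100, episodes_per_update=10):
    # hypothetical driver: collect episodes, then update the policy and summaries
    agent.init_network()
    for it in range(num_iterations):
        agent.reset_old_network()  # snapshot the old policy before the update
        result_list = []
        for _ in range(episodes_per_update):
            agent.reset()
            agent.play()
            result_list.append(agent.result)
        agent.update_network(result_list)  # Update_policy + Update_result
        agent.update_summary(it)
        agent.save_model()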