コード例 #1
0
def train_army(agent, train_action, use_multi=True):
    """Select the gateways and issue ``train_action`` on the gateway selection.

    If a gateway position is found and the camera is on the base, this issues a
    double-click select at that position, stores the selection in the gateway
    control group, then issues a single-click select at the same position.
    Otherwise it recalls the previously stored gateway control group.

    Args:
        agent: the acting agent; must expose ``obs``, ``env.game_info``,
            ``select`` and ``safe_action``.
        train_action: action id passed to ``agent.safe_action`` for every
            train order.
        use_multi: when True (the previous hard-coded behavior), issue one
            train order per gateway counted by ``U.get_unit_num``; when False,
            issue exactly one order.
    """
    gateway_count = U.get_unit_num(agent.obs, C._GATEWAY_TYPE_INDEX)
    pos = selectGateway(agent)
    camera_on_base = U.check_base_camera(agent.env.game_info, agent.obs)

    if pos and camera_on_base:
        agent.select(C._SELECT_POINT, C._GATEWAY_GROUP_INDEX,
                     [C._DBL_CLICK, pos])
        agent.safe_action(C._CONTROL_GROUP, C._GATEWAY_GROUP_INDEX,
                          [C._SET_GROUP, C._GATEWAY_GROUP_ID])
        # NOTE(review): this single-click select presumably narrows the
        # selection to the gateway at `pos`, whereas the recall branch below
        # selects the whole stored group. Confirm which selection is intended
        # for the train orders issued afterwards.
        agent.select(C._SELECT_POINT, C._GATEWAY_GROUP_INDEX, [C._CLICK, pos])
    else:
        agent.select(C._CONTROL_GROUP, C._GATEWAY_GROUP_INDEX,
                     [C._RECALL_GROUP, C._GATEWAY_GROUP_ID])

    # One order per counted gateway in multi mode, otherwise a single order.
    order_count = gateway_count if use_multi else 1
    for _ in range(order_count):
        agent.safe_action(train_action, C._GATEWAY_GROUP_INDEX,
                          [C._NOT_QUEUED])
コード例 #2
0
def _append_flat(chunks, values):
    """Append a flattened private copy of ``values`` to the list ``chunks``.

    This copies exactly what ``np.append`` would have copied into a running
    array, but costs O(len(values)) instead of re-copying everything
    accumulated so far.
    """
    chunks.append(np.array(np.ravel(values)))


def _flatten_records(chunks):
    """Concatenate the chunks gathered by ``_append_flat`` into one flat array.

    This is equivalent to starting from ``np.array([])`` and calling
    ``np.append`` once per chunk, including the resulting dtype. An empty
    ``chunks`` yields an empty float array.
    """
    return np.concatenate([np.array([])] + list(chunks))


def _append_rows(filename, flat, width):
    """Append ``flat``, reshaped to rows of ``width`` columns, to ``filename`` as text."""
    with open(filename, 'ab') as f:
        np.savetxt(f, flat.reshape(-1, width))


def run(replay_name, replay_version, difficulty, run_config, interface, net, *,
        replay_dir='D:/sc2/multi_agent/init/data/replays/',
        use_rule=True, use_no_op=False):
    """Step through a replay and extract training data for the networks.

    At every step, each sub-goal flag is 1 if the current game loop lies
    inside the ``[start, end]`` window that ``getSubGoalFrame`` returned for
    it, and 0 otherwise. With ``use_rule``, the first three flags are then
    overwritten by threshold rules on workers, gateways and army count.

    Three datasets are collected:
      * high: the ``U.get_input`` features followed by the sub-goal flags,
        one row per step (only when ``FLAGS.save_data``);
      * tech: ``get_tech_data_array`` records for observed tech-building
        abilities, plus optional "wait" records (ability -1) when
        ``use_rule`` and ``use_no_op`` are both set;
      * pop: ``get_pop_data_array`` records for observed population abilities.

    When ``FLAGS.save_data`` is set, these are appended as text rows to
    ``FLAGS.save_path + "high.txt"``, ``"tech.txt"`` and ``"pop.txt"``.

    Args:
        replay_name: replay file name, joined onto ``replay_dir``.
        replay_version: passed as ``game_version`` to ``run_config.start``.
        difficulty: passed to ``U.get_input`` and stored in ``C.difficulty``.
        run_config: provides ``replay_data`` and ``start``.
        interface: passed as ``options`` to ``RequestStartReplay``.
        net: must provide ``predict_high(timestep)``; its result is unused.
        replay_dir: directory containing the replays. The default is the
            previously hard-coded location.
        use_rule: apply the threshold rules to the first three sub-goals.
        use_no_op: also record "wait" tech samples once the build targets
            are reached (only effective together with ``use_rule``).
    """
    import os

    replay_path = os.path.join(replay_dir, replay_name)
    print(replay_path)

    tech_ability_list = [C._A_BUILD_PYLON_S, C._A_BUILD_ASSIMILATOR_S, C._A_BUILD_GATEWAY_S, C._A_BUILD_CYBER_S]
    pop_ability_list = [C._A_SMART_SCREEN, C._A_TRAIN_PROBE, C._A_TRAIN_ZEALOT, C._A_TRAIN_STALKER]

    replay_data = run_config.replay_data(replay_path)
    start_replay = sc_pb.RequestStartReplay(
        replay_data=replay_data,
        options=interface,
        disable_fog=FLAGS.disable_fog,
        observed_player_id=FLAGS.observed_player)

    with run_config.start(full_screen=FLAGS.full_screen, game_version=replay_version) as controller:
        info = controller.replay_info(replay_data)
        print(" Replay info ".center(60, "-"))
        print(info)
        print("-" * 60)
        print(" Replay difficulty: ", difficulty)
        C.difficulty = difficulty

        frame_num = info.game_duration_loops
        step_num = frame_num // FLAGS.step_mul
        sub_goal_frames = getSubGoalFrame(frame_num, replay=replay_path, fps=FLAGS.fps)

        obs_array_count = FLAGS.obs_array_count
        obs_array = [None] * obs_array_count

        controller.start_replay(start_replay)
        feature_layer = features.Features(controller.game_info())

        # Records are gathered in lists and concatenated once at the end;
        # np.append in the loop re-copied the whole array every time.
        high_chunks = []
        tech_chunks = []
        pop_chunks = []
        for i in range(step_num):
            controller.step(FLAGS.step_mul)
            obs = controller.observe()
            timestep = environment.TimeStep(step_type=None,
                                            reward=None,
                                            discount=None,
                                            observation=None, raw_observation=obs)
            # The prediction is unused here. The call is kept in case
            # predict_high has side effects — TODO confirm and drop if not.
            net.predict_high(timestep)

            obs_data = feature_layer.transform_obs(obs.observation)
            frame_idx = obs_data["game_loop"][0]
            subgoals = [1 if start <= frame_idx <= end else 0
                        for start, end, _ in sub_goal_frames]
            # Two consecutive steps share a slot, so the ring buffer holds
            # roughly every other step.
            obs_array[(i // 2) % obs_array_count] = timestep

            if use_rule:
                gateway_count = U.get_unit_num(timestep, C._GATEWAY_TYPE_INDEX)
                player_common = timestep.raw_observation.observation.player_common
                subgoals[0] = 0 if player_common.food_workers >= 22 else 1
                subgoals[1] = 1 if gateway_count >= 1 else 0
                subgoals[2] = 1 if player_common.army_count >= 10 else 0

                if use_no_op:
                    cyber_count = U.get_unit_num(timestep, C._CYBER_TYPE_INDEX)
                    pylon_count = U.get_unit_num(timestep, C._PYLON_TYPE_INDEX)
                    build_wait = ((gateway_count >= 4 and cyber_count >= 1 and pylon_count >= 8)
                                  or (gateway_count >= 6 and pylon_count >= 10))
                    if build_wait:
                        _append_flat(tech_chunks,
                                     get_tech_data_array(obs_array, np.array(subgoals), -1))

            for action in obs.actions:
                act_fl = action.action_feature_layer
                if not act_fl.HasField("unit_command"):
                    continue
                ability_id = act_fl.unit_command.ability_id
                if ability_id in tech_ability_list:
                    _append_flat(tech_chunks,
                                 get_tech_data_array(obs_array, np.array(subgoals), ability_id))
                if ability_id in pop_ability_list:
                    _append_flat(pop_chunks,
                                 get_pop_data_array(obs_array, np.array(subgoals), ability_id,
                                                    act_fl.unit_command))

            print('subgoals:', subgoals)

            if FLAGS.save_data:
                high_input, _, _ = U.get_input(timestep, difficulty)
                record = np.zeros(C._SIZE_HIGH_NET_INPUT + C._SIZE_HIGH_NET_OUT)
                record[:C._SIZE_HIGH_NET_INPUT] = high_input
                record[C._SIZE_HIGH_NET_INPUT:] = np.array(subgoals)
                _append_flat(high_chunks, record)

        if FLAGS.save_data:
            path = FLAGS.save_path
            _append_rows(path + "high.txt", _flatten_records(high_chunks),
                         C._SIZE_HIGH_NET_INPUT + C._SIZE_HIGH_NET_OUT)
            # Row widths below must match the record sizes produced by
            # get_tech_data_array / get_pop_data_array (presumably features
            # plus one label column — TODO confirm).
            _append_rows(path + "tech.txt", _flatten_records(tech_chunks), 26 + 1)
            _append_rows(path + "pop.txt", _flatten_records(pop_chunks), 30 + 1)