コード例 #1
0
def parse_replay(replay_player_path, sampled_action_path, reward):
    """Compute global features for one replay/player pair and save them as JSON.

    Does nothing if ``GlobalFeatures/<replay_player_path>`` already exists under
    ``FLAGS.parsed_replay_path``.

    Args:
        replay_player_path: Relative file name shared by the ``GlobalInfos``,
            ``Actions``, ``SampledObservations`` and ``GlobalFeatures``
            sub-directories of ``FLAGS.parsed_replay_path``.
        sampled_action_path: Path to a JSON file holding the list of sampled
            step numbers.
        reward: Forwarded unchanged to ``process_replay``.
    """
    output_path = os.path.join(FLAGS.parsed_replay_path, 'GlobalFeatures', replay_player_path)
    if os.path.isfile(output_path):
        return  # already processed

    # Global Info: static unit data and the feature layout for this game.
    with open(os.path.join(FLAGS.parsed_replay_path, 'GlobalInfos', replay_player_path)) as f:
        global_info = json.load(f)
    units_info = static_data.StaticData(Parse(global_info['data_raw'], sc_pb.ResponseData())).units
    feat = features.Features(Parse(global_info['game_info'], sc_pb.ResponseGameInfo()))

    # Sampled Actions: turn each sampled step number into an index into the
    # Actions list (``step // step_mul + 1``).
    with open(sampled_action_path) as f:
        sampled_action = json.load(f)
    sampled_action_id = [step // FLAGS.step_mul + 1 for step in sampled_action]

    # Actions: keep the first recorded action at each sampled index, or None
    # when the entry is empty.
    with open(os.path.join(FLAGS.parsed_replay_path, 'Actions', replay_player_path)) as f:
        actions = json.load(f)
    actions = [None if len(actions[idx]) == 0 else Parse(actions[idx][0], sc_pb.Action())
               for idx in sampled_action_id]

    # Observations: expected to be one message per sampled step (checked below).
    observations = list(stream.parse(
        os.path.join(FLAGS.parsed_replay_path, 'SampledObservations', replay_player_path),
        sc_pb.ResponseObservation))

    assert (len(sampled_action) == len(sampled_action_id)
            == len(actions) == len(observations)), (
        'sampled actions and observations are misaligned for %s' % replay_player_path)

    states = process_replay(sampled_action, actions, observations, feat, units_info, reward)

    with open(output_path, 'w') as f:
        json.dump(states, f)
コード例 #2
0
    def from_vg_graph_file(cls,
                           vg_graph_file_name,
                           only_read_nodes=False,
                           use_cache_if_available=False):
        """Build a graph from a stream of vg protobuf ``Graph`` messages.

        Args:
            vg_graph_file_name: Path to a file parsed with
                ``stream.parse(..., vg_pb2.Graph)``.
            only_read_nodes: If true, only node sequences are read; paths are
                not collected and edges are not checked.
            use_cache_if_available: Accepted for interface compatibility; it is
                not read by this method.

        Returns:
            ``cls(nodes, edges, paths)`` after ``_cache_nodes()`` has been
            called on it.
        """
        nodes = {}  # node id -> node sequence
        paths = []  # path messages, in file order
        # NOTE(review): nothing is ever appended to ``edges``; the edge loop
        # below only validates overlaps. Confirm the constructor does not
        # expect edges to be supplied here.
        edges = []

        for chunk in stream.parse(vg_graph_file_name, vg_pb2.Graph):
            if hasattr(chunk, "node"):
                for node in chunk.node:
                    nodes[node.id] = node.sequence

            if only_read_nodes:
                continue

            if hasattr(chunk, "path"):
                paths.extend(chunk.path)

            if hasattr(chunk, "edge"):
                for edge in chunk.edge:
                    # Only edges with zero overlap are accepted.
                    assert edge.overlap == 0

        graph = cls(nodes, edges, paths)
        graph._cache_nodes()
        return graph
コード例 #3
0
def parse_replay(replay_player_path, sampled_frame_path, reward):
    """Extract and save global features for one replay/player pair.

    The result is written as JSON to ``GlobalFeatures/<replay_player_path>``
    under ``FLAGS.parsed_replay_path``. If that file already exists, the
    function returns immediately without doing any work.

    Args:
        replay_player_path: Relative file name used inside each of the
            ``GlobalInfos``, ``Actions``, ``SampledObservations`` and
            ``GlobalFeatures`` directories.
        sampled_frame_path: JSON file containing the list of sampled frames.
        reward: Passed straight through to ``process_replay``.
    """
    def subdir_path(subdir):
        # This replay/player's file inside one sub-directory of the parsed tree.
        return os.path.join(FLAGS.parsed_replay_path, subdir, replay_player_path)

    target = subdir_path('GlobalFeatures')
    if os.path.isfile(target):
        return

    # Static game data and feature definitions.
    with open(subdir_path('GlobalInfos')) as info_file:
        global_info = json.load(info_file)
    raw_data = Parse(global_info['data_raw'], sc_pb.ResponseData())
    units_info = static_data.StaticData(raw_data).units
    game_info = Parse(global_info['game_info'], sc_pb.ResponseGameInfo())
    feat = features.features_from_game_info(game_info)

    # Index into the Actions list for each sampled frame.
    # NOTE(review): a frame below FLAGS.step_mul gives index -1, which selects
    # the LAST action entry; confirm sampled frames are always >= step_mul.
    with open(sampled_frame_path) as frames_file:
        sampled_frames = json.load(frames_file)
    sampled_actions_idx = [frame // FLAGS.step_mul - 1 for frame in sampled_frames]

    with open(subdir_path('Actions')) as actions_file:
        actions = json.load(actions_file)
    sampled_actions = []
    for idx in sampled_actions_idx:
        recorded = actions[idx]
        if len(recorded) == 0:
            sampled_actions.append(None)  # nothing recorded at this index
        else:
            # Only the first recorded action is kept.
            sampled_actions.append(Parse(recorded[0], sc_pb.Action()))

    observation_stream = stream.parse(subdir_path('SampledObservations'),
                                      sc_pb.ResponseObservation)
    observations = list(observation_stream)

    assert len(sampled_frames) == len(sampled_actions_idx) == len(sampled_actions) == len(observations)

    states = process_replay(sampled_frames, sampled_actions, observations, feat, units_info, reward)

    with open(target, 'w') as out_file:
        json.dump(states, out_file)
コード例 #4
0
def load_single_ticker_market_event_data_bundle(
        universe,
        atype,
        market,
        ticker,
        start_dt_str,
        end_dt_str,
        db_path,
        pb_cls=kfinformat_pb2.kfinevent):
    """Return an iterator over one ticker's tick events within a date range.

    Matching files under ``<db_path>/<universe>/TICK/<atype>/...`` are parsed
    with ``stream.parse`` and their streams are chained one after another.

    Args:
        universe, atype, market, ticker: Components of the on-disk layout and
            file names. ``atype`` must be ``'STOCK'``, ``'ETF'`` or ``'FUTURES'``.
        start_dt_str, end_dt_str: Strings whose part before the first ``'T'``
            is the date bound; both bounds are inclusive.
        db_path: Root directory of the tick database.
        pb_cls: Protobuf message class passed to ``stream.parse``.

    Returns:
        An iterator of ``pb_cls`` messages; empty when no file matches
        (previously this case raised ``UnboundLocalError``).

    Raises:
        ValueError: if ``atype`` is not one of the supported types.
    """
    import os  # local import keeps this block self-contained

    def _join(*parts):
        # Platform-native path; str() mirrors the '{}'.format conversion of each part.
        return os.path.join(*(str(part) for part in parts))

    start_date, end_date = start_dt_str.split('T')[0], end_dt_str.split('T')[0]
    to_do_file_path_list = []

    if atype in ('STOCK', 'ETF'):
        file_name = '{}_{}_RM0_{}_*-*-*.gz'.format(atype, market, ticker)
        pattern = _join(db_path, universe, 'TICK', atype, ticker, file_name)
        for fp in glob.glob(pattern):
            # File names end in '_<date>.gz'; dates are compared as strings,
            # which assumes a zero-padded YYYY-MM-DD form.
            file_date = os.path.basename(fp).split('_')[-1][:-3]
            if start_date <= file_date <= end_date:
                to_do_file_path_list.append(fp)
        to_do_file_path_list.sort()

    elif atype in ('FUTURES', ):
        for date in generate_day_list(start_date=start_date, end_date=end_date):
            file_name = '{}_{}_*_{}.gz'.format(atype, ticker, date)
            pattern = _join(db_path, universe, 'TICK', atype,
                            ticker.split('_')[0], date, file_name)
            # glob order is arbitrary; sort so the chosen file is deterministic.
            matches = sorted(glob.glob(pattern))
            if matches:
                to_do_file_path_list.append(matches[0])

    else:
        raise ValueError('unsupported atype: {!r}'.format(atype))

    return chain(*[stream.parse(ifp=fp, pb_cls=pb_cls) for fp in to_do_file_path_list])
コード例 #5
0
def parse_replay(replay_player_path, sampled_action_path, reward, race, enemy_race, stat):
    """Build spatial and global feature matrices for one replay/player pair.

    Two sparse matrices are saved with ``scipy.sparse.save_npz`` under
    ``SpatialFeatureTensor``:

    * ``<replay_player_path>@S``: the per-step concatenation of screen and
      minimap layers, flattened to one row per sampled step.
    * ``<replay_player_path>@G``: per step, the normalised player stats, the
      score, ``reward`` and ``stat['action_id'][action_id]``.

    Args:
        replay_player_path: Relative file name shared by the ``GlobalInfos``,
            ``SampledObservations``, ``Actions`` and ``SpatialFeatureTensor``
            sub-directories of ``FLAGS.parsed_replay_path``.
        sampled_action_path: Path to a JSON list of sampled step numbers.
        reward: Appended unchanged to every global feature row.
        race, enemy_race: Not used by this function.
        stat: Mapping with ``'max'`` (divisor for ``obs['player']``) and
            ``'action_id'`` (indexed by the tracked function id, or ``-1``).
    """
    with open(os.path.join(FLAGS.parsed_replay_path, 'GlobalInfos', replay_player_path)) as f:
        global_info = json.load(f)
    feat = SpatialFeatures(Parse(global_info['game_info'], sc_pb.ResponseGameInfo()))

    states = list(stream.parse(
        os.path.join(FLAGS.parsed_replay_path, 'SampledObservations', replay_player_path),
        sc_pb.ResponseObservation))

    # Sampled Actions: step number -> index into the Actions list.
    with open(sampled_action_path) as f:
        sampled_action = json.load(f)
    sampled_action_id = [step // FLAGS.step_mul + 1 for step in sampled_action]
    # Actions: first recorded action at each sampled index, or None if empty.
    with open(os.path.join(FLAGS.parsed_replay_path, 'Actions', replay_player_path)) as f:
        actions = json.load(f)
    actions = [None if len(actions[idx]) == 0 else Parse(actions[idx][0], sc_pb.Action())
               for idx in sampled_action_id]

    assert len(states) == len(actions)

    # Only functions whose name starts with one of these prefixes are recorded;
    # everything else is encoded as action_id -1.
    tracked_prefixes = {'Build', 'Train', 'Research', 'Morph', 'Cancel', 'Halt', 'Stop'}

    spatial_states_np, global_states_np = [], []
    for state, action in zip(states, actions):
        action_id = -1
        if action is not None:
            try:
                func_id = feat.reverse_action(action).function
                func_name = FUNCTIONS[func_id].name
                if func_name.split('_')[0] in tracked_prefixes:
                    action_id = func_id
            except Exception:
                # Best effort: actions that cannot be mapped keep action_id -1.
                # Exception (not a bare except) so Ctrl-C / SystemExit still work.
                pass

        obs = feat.transform_obs(state.observation)
        spatial_states_np.append(np.concatenate([obs['screen'], obs['minimap']], axis=0))

        # 1e-5 keeps the division finite when a stat['max'] entry is 0.
        global_states_np.append(np.hstack([obs['player'] / (stat['max'] + 1e-5), obs['score'],
                                           [reward], [stat['action_id'][action_id]]]))

    # One row per sampled step: flatten all spatial layers into a single vector.
    spatial_states_np = np.asarray(spatial_states_np).reshape([len(states), -1])
    global_states_np = np.asarray(global_states_np)

    out_dir = os.path.join(FLAGS.parsed_replay_path, 'SpatialFeatureTensor')
    sparse.save_npz(os.path.join(out_dir, replay_player_path + '@S'),
                    sparse.csc_matrix(spatial_states_np))
    sparse.save_npz(os.path.join(out_dir, replay_player_path + '@G'),
                    sparse.csc_matrix(global_states_np))
コード例 #6
0
 def from_vg_snarls_file(cls, vg_snarls_file_name):
     """Construct an instance from the Snarl messages in a vg snarls file.

     The messages from ``stream.parse(..., vg_pb2.Snarl)`` are handed to the
     constructor wrapped in a generator, so they are produced only as the
     constructor iterates over them.
     """
     parsed = stream.parse(vg_snarls_file_name, vg_pb2.Snarl)
     return cls(item for item in parsed)
コード例 #7
0
 def get_protobuf_message(self, path, doc_id):
     """Return the first Document in the binary file whose ``doc_id`` matches.

     Reading stops at the first match instead of loading every message in
     the file into a list.

     Args:
         path: Path to a stream of ``document_pb2.Document`` messages.
         doc_id: Value compared against each message's ``doc_id`` field.

     Raises:
         IndexError: if no message in the file has a matching ``doc_id``.
     """
     for d in stream.parse(path, document_pb2.Document):
         if d.doc_id == doc_id:
             return d
     raise IndexError('no Document with doc_id {!r} in {}'.format(doc_id, path))
コード例 #8
0
 def get_list_protobuf_messages(self, path):
     """Return every Document message stored in the binary file at ``path``.

     Args:
         path: Path to a stream of ``document_pb2.Document`` messages.

     Returns:
         A list of the parsed messages, in the order ``stream.parse`` yields them.
     """
     return list(stream.parse(path, document_pb2.Document))
コード例 #9
0
import stream
from s2clientprotocol import sc2api_pb2 as sc_pb

# Debug script: load all sampled observations of one replay and print them.
SAMPLED_OBSERVATION_PATH = 'parsed_replays/SampledObservations/Terran_vs_Terran/Terran/1@0008edd10656cc66b14b367400bfe5bd50ddd6bb6f5941c2921e144ef0c01f18.SC2Replay'

# Each record in the file is decoded as a ResponseObservation message.
OBS = list(stream.parse(SAMPLED_OBSERVATION_PATH, sc_pb.ResponseObservation))
print(OBS)
コード例 #10
0
import stream
from s2clientprotocol import sc2api_pb2 as sc_pb

# NOTE(review): other code reading SampledActions loads those files with
# json.load rather than as protobuf streams; confirm this PATH really holds
# ResponseObservation records (a SampledObservations path may be intended).
PATH = 'parsed_replays/SampledActions/Terran_vs_Terran/Terran/7cc6fe85694768dbab7987344196ab44615842bd896f9172ee1177a2b899ba58.SC2Replay'

# stream.parse needs the message class to decode each record. Protobuf
# message constructors accept only keyword field arguments, so the parsed
# message is printed directly rather than re-wrapped.
for obs in stream.parse(PATH, sc_pb.ResponseObservation):
    print(obs)