def process_replay(self, controller, replay_data, map_data, player_id, race,
                   replay_path):
    """Replay one game from ``player_id``'s perspective and dump its actions.

    Starts the replay on ``controller`` with the module-level ``interface``
    options, advances it ``FLAGS.step_mul`` game loops at a time and collects
    the JSON-serialised actions reported by each observation. When an
    observation carries ``player_result`` (the game has ended) the collected
    actions are written as a JSON list of lists to
    ``<FLAGS.save_path>/<race>/<player_id>@<basename(replay_path)>``.

    Args:
        controller: controller connected to a running SC2 instance.
        replay_data: raw replay bytes.
        map_data: map bytes, or None.
        player_id: id of the player whose perspective is observed.
        race: sub-folder name under ``FLAGS.save_path``.
        replay_path: path of the replay; only its basename is used.
    """
    controller.start_replay(
        sc_pb.RequestStartReplay(
            replay_data=replay_data,
            map_data=map_data,
            options=interface,
            observed_player_id=player_id))

    save_folder = os.path.join(FLAGS.save_path, race)
    # Create the per-race folder before stepping: open() below would otherwise
    # fail with FileNotFoundError only after the whole replay was processed.
    os.makedirs(save_folder, exist_ok=True)

    actions = []
    while True:
        controller.step(FLAGS.step_mul)
        obs = controller.observe()  # Observation of the current frame.
        # Actions observed between the previous and the current observation.
        actions.append([MessageToJson(a) for a in obs.actions])
        if obs.player_result:  # A player result means the game has ended.
            out_path = os.path.join(
                save_folder,
                '{}@{}'.format(player_id, os.path.basename(replay_path)))
            with open(out_path, 'w') as f:
                json.dump(actions, f)
            return
def process_replay(self, controller, replay_data, map_data, player_id, race,
                   replay_path):
    """Replay one game from ``player_id``'s perspective and dump its actions.

    Starts the replay with the module-level ``interface`` options, takes the
    first observation after an initial ``controller.step()``, then advances
    ``FLAGS.step_mul`` game loops between observations. The JSON-serialised
    actions of every observation are collected. Once an observation carries
    ``player_result`` (the game is over), they are written as a JSON list of
    lists to ``<FLAGS.save_path>/<race>/<player_id>@<basename(replay_path)>``.

    Args:
        controller: controller connected to a running SC2 instance.
        replay_data: raw replay bytes.
        map_data: map bytes, or None.
        player_id: id of the player whose perspective is observed.
        race: sub-folder name under ``FLAGS.save_path``.
        replay_path: path of the replay; only its basename is used.
    """
    controller.start_replay(
        sc_pb.RequestStartReplay(
            replay_data=replay_data,
            map_data=map_data,
            options=interface,
            observed_player_id=player_id))

    save_folder = os.path.join(FLAGS.save_path, race)
    # Create the per-race folder before stepping: open() below would otherwise
    # fail with FileNotFoundError only after the whole replay was processed.
    os.makedirs(save_folder, exist_ok=True)

    actions = []
    controller.step()
    while True:
        obs = controller.observe()
        actions.append([MessageToJson(a) for a in obs.actions])
        if obs.player_result:  # A player result means the game has ended.
            out_path = os.path.join(
                save_folder,
                '{}@{}'.format(player_id, os.path.basename(replay_path)))
            with open(out_path, 'w') as f:
                json.dump(actions, f)
            return
        controller.step(FLAGS.step_mul)
async def start_replay(self, replay_path: str, realtime: bool, observed_id: int = 0):
    """Ask the SC2 client to start the replay at ``replay_path``.

    On Linux, only the replay's file name is sent to the client. The file
    must therefore live in ``~/Documents/StarCraft II/Replays``.

    Args:
        replay_path: path of the replay file.
        realtime: forwarded to ``RequestStartReplay.realtime``.
        observed_id: id of the player whose perspective is observed.

    Returns:
        The client's response to the start_replay request.

    Raises:
        FileNotFoundError: on Linux, when the replay is not in the folder above.
        AssertionError: when the client does not report status 4 afterwards.
    """
    ifopts = sc_pb.InterfaceOptions(
        raw=True,
        score=True,
        show_cloaked=True,
        raw_affects_selection=True,
        raw_crop_to_playable_area=False)
    if platform.system() == "Linux":
        replay_name = Path(replay_path).name
        home_replay_folder = Path.home() / "Documents" / "StarCraft II" / "Replays"
        # Compare normalised paths so equivalent spellings of the same file
        # are accepted: relative paths, "~", "..", repeated separators.
        expected = (home_replay_folder / replay_name).resolve()
        if Path(replay_path).expanduser().resolve() != expected:
            logger.warning(
                f"Linux detected, please put your replay in your home directory at {home_replay_folder}. It was detected at {replay_path}"
            )
            raise FileNotFoundError(
                f"Replay must be located in {home_replay_folder}: {replay_path}")
        replay_path = replay_name
    req = sc_pb.RequestStartReplay(
        replay_path=replay_path,
        observed_player_id=observed_id,
        realtime=realtime,
        options=ifopts)
    result = await self._execute(start_replay=req)
    # 4 is Status.in_replay in s2client-proto's sc2api.proto. The check is
    # raised explicitly (not via `assert`) so it survives `python -O`.
    # AssertionError is kept for callers that may already catch it.
    if result.status != 4:
        raise AssertionError(
            f"{result.start_replay.error} - {result.start_replay.error_details}")
    return result
def __init__(self,
             replay_file_path,
             agent,
             player_id=1,
             screen_size_px=(64, 64),  # (60, 60)
             minimap_size_px=(64, 64),  # (60, 60)
             discount=1.,
             frames_per_game=1):
    """Launch SC2, validate the replay and start it from ``player_id``'s view.

    The replay is loaded via ``run_config.replay_data(<file name>.SC2Replay)``.
    Only the file name is used, not the directory part of ``replay_file_path``.

    Args:
        replay_file_path: path of the replay; its stem becomes
            ``self.replay_file_name``.
        agent: stored as ``self.agent``.
        player_id: id of the observed player.
        screen_size_px: (width, height) of the feature-layer screen.
        minimap_size_px: (width, height) of the feature-layer minimap.
        discount: stored as ``self.discount``.
        frames_per_game: stored as ``self.frames_per_game``.

    Raises:
        Exception: if ``self._valid_replay`` rejects the replay. The SC2
            process is closed before this or any other setup error propagates.
    """
    import os  # Local import: only needed to strip the file extension.

    print("Parsing " + replay_file_path)
    # splitext strips only the extension. split(".")[0] truncated dotted names
    # such as "4.8.2.71663-....SC2Replay" and then loaded the wrong file.
    self.replay_file_name = os.path.splitext(replay_file_path.split("/")[-1])[0]
    self.agent = agent
    self.discount = discount
    self.frames_per_game = frames_per_game

    self.run_config = run_configs.get()
    self.sc2_proc = self.run_config.start()
    self.controller = self.sc2_proc.controller
    try:
        replay_data = self.run_config.replay_data(self.replay_file_name + '.SC2Replay')
        ping = self.controller.ping()
        self.info = self.controller.replay_info(replay_data)
        if not self._valid_replay(self.info, ping):
            raise Exception("{} is not a valid replay file!".format(
                self.replay_file_name + '.SC2Replay'))

        screen_size_px = point.Point(*screen_size_px)
        minimap_size_px = point.Point(*minimap_size_px)
        interface = sc_pb.InterfaceOptions(
            raw=False, score=True,
            feature_layer=sc_pb.SpatialCameraSetup(width=24))
        screen_size_px.assign_to(interface.feature_layer.resolution)
        minimap_size_px.assign_to(interface.feature_layer.minimap_resolution)

        map_data = None
        if self.info.local_map_path:
            map_data = self.run_config.map_data(self.info.local_map_path)

        self._episode_length = self.info.game_duration_loops
        self._episode_steps = 0

        self.controller.start_replay(sc_pb.RequestStartReplay(
            replay_data=replay_data,
            map_data=map_data,
            options=interface,
            observed_player_id=player_id))
    except BaseException:
        # Don't leak the SC2 process when setup fails for any reason.
        self.sc2_proc.close()
        raise

    self._state = StepType.FIRST
def openReplay(self,
               replay_file_path,
               player_id=1,
               screen_size_px=(84, 84),
               minimap_size_px=(64, 64)):
    """Validate a replay and start it on the already running controller.

    Args:
        replay_file_path: path passed to ``run_config.replay_data``.
        player_id: id of the observed player.
        screen_size_px: (width, height) of the feature-layer screen.
        minimap_size_px: (width, height) of the feature-layer minimap.

    Raises:
        Exception: if ``self._valid_replay`` rejects the replay.
    """
    run_config = self.run_config
    controller = self.controller

    replay_data = run_config.replay_data(replay_file_path)
    ping = controller.ping()
    info = controller.replay_info(replay_data)
    if not self._valid_replay(info, ping):
        raise Exception(
            "{} is not a valid replay file!".format(replay_file_path))

    screen = point.Point(*screen_size_px)
    minimap = point.Point(*minimap_size_px)
    options = sc_pb.InterfaceOptions(
        raw=False,
        score=True,
        feature_layer=sc_pb.SpatialCameraSetup(width=24))
    screen.assign_to(options.feature_layer.resolution)
    minimap.assign_to(options.feature_layer.minimap_resolution)

    map_data = (run_config.map_data(info.local_map_path)
                if info.local_map_path else None)

    self._episode_length = info.game_duration_loops
    self._episode_steps = 0

    request = sc_pb.RequestStartReplay(
        replay_data=replay_data,
        map_data=map_data,
        options=options,
        observed_player_id=player_id)
    controller.start_replay(request)
    self._state = StepType.FIRST
def start_replay(self, replay_data):
    """Start ``replay_data`` on ``self._controller`` using the stored config.

    The request uses ``self._map_data`` and ``self._config.interface``, keeps
    fog of war enabled (``disable_fog=False``) and observes
    ``self._config.player_id``.
    """
    request_fields = dict(
        replay_data=replay_data,
        map_data=self._map_data,
        options=self._config.interface,
        disable_fog=False,
        observed_player_id=self._config.player_id,
    )
    self._controller.start_replay(sc_pb.RequestStartReplay(**request_fields))
def main(argv):
    """Play back ``FLAGS.replay`` and push every observation through ``features``.

    Builds interface options from the feature/RGB/unit flags and starts a game
    binary matching the replay's version. It then prints the replay info and
    steps the replay ``FLAGS.step_mul`` loops at a time, transforming each
    observation until one carries a player result. Ctrl-C ends playback early;
    the stopwatch report is printed either way.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    stopwatch.sw.enable()

    interface = sc_pb.InterfaceOptions()
    interface.raw = FLAGS.use_feature_units or FLAGS.use_raw_units
    interface.score = True
    feature_layer = interface.feature_layer
    feature_layer.width = 24

    wants_feature_layers = FLAGS.feature_screen_size and FLAGS.feature_minimap_size
    if wants_feature_layers:
        FLAGS.feature_screen_size.assign_to(feature_layer.resolution)
        FLAGS.feature_minimap_size.assign_to(feature_layer.minimap_resolution)

    wants_rgb = FLAGS.rgb_screen_size and FLAGS.rgb_minimap_size
    if wants_rgb:
        FLAGS.rgb_screen_size.assign_to(interface.render.resolution)
        FLAGS.rgb_minimap_size.assign_to(interface.render.minimap_resolution)

    # The first run config is only used to read the replay bytes. It is then
    # replaced by one matching the replay's own game version.
    run_config = run_configs.get()
    replay_data = run_config.replay_data(FLAGS.replay)
    request = sc_pb.RequestStartReplay(
        replay_data=replay_data,
        options=interface,
        observed_player_id=1)
    version = replay.get_replay_version(replay_data)
    run_config = run_configs.get(version=version)  # Replace the run config.

    try:
        with run_config.start(
                want_rgb=interface.HasField("render")) as controller:
            info = controller.replay_info(replay_data)
            print(" Replay info ".center(60, "-"))
            print(info)
            print("-" * 60)

            map_path = FLAGS.map_path or info.local_map_path
            if map_path:
                request.map_data = run_config.map_data(map_path)
            controller.start_replay(request)

            game_features = features.features_from_game_info(
                game_info=controller.game_info(),
                use_feature_units=FLAGS.use_feature_units,
                use_raw_units=FLAGS.use_raw_units,
                use_unit_counts=interface.raw,
                use_camera_position=False,
                action_space=actions.ActionSpace.FEATURES)

            finished = False
            while not finished:
                controller.step(FLAGS.step_mul)
                obs = controller.observe()
                game_features.transform_obs(obs)
                finished = bool(obs.player_result)
    except KeyboardInterrupt:
        pass

    print(stopwatch.sw)
def _extract_zstat(replay_name, player_id, game_version):
    """Run the configured zstat converter over one player's view of a replay.

    An instance of ``pb2zstat_lib.<FLAGS.zstat_converter>`` is fed each
    observation, taken every ``FLAGS.step_mul`` game loops after an initial
    ``controller.step()``. Its output for the final observation (the one
    carrying ``player_result``) is stored in ``db`` under
    ``rep_info_to_unique_key(replay_name, player_id)``. If that key already
    exists, nothing is replayed.

    Returns:
        ``MyResult(replay_name, player_id, game_version, zstat_or_None, msg)``.
    """
    logging.info('enter _extract_zstat')
    logging.info('replay_name: {}, player_id: {}, game_version: {}'.format(
        replay_name, player_id, game_version))

    already_stored = db.get(rep_info_to_unique_key(replay_name, player_id)) is not None
    if already_stored:
        logging.info('leave _extract_z_stat, skipped.')
        return MyResult(replay_name, player_id, game_version, None,
                        'skipped writing')

    # Point SC2PATH at the binary for this game version.
    # NOTE(review): because of the `or`, an existing SC2PATH is overwritten
    # even for 4.7.1 -- confirm this is intended.
    if game_version != '4.7.1' or 'SC2PATH' in os.environ:
        os.environ['SC2PATH'] = path.join(FLAGS.sc2mv_bin_root, game_version)

    converter_cls = getattr(pb2zstat_lib, FLAGS.zstat_converter)
    converter = converter_cls()
    converter.reset()

    run_config = run_configs.get()
    replay_file = path.join(FLAGS.replay_dir, '%s.SC2Replay' % replay_name)
    replay_data = run_config.replay_data(replay_file)

    with run_config.start(version=game_version) as controller:
        replay_info = controller.replay_info(replay_data)
        local_map = replay_info.local_map_path
        map_data = None
        if local_map:
            map_data = run_config.map_data(local_map)
            print('using local_map_path {}'.format(local_map))
        known_ids = {p.player_info.player_id for p in replay_info.player_info}
        assert player_id in known_ids
        controller.start_replay(sc_pb.RequestStartReplay(
            replay_data=replay_data,
            map_data=map_data,
            options=get_dft_sc2_interface(),
            observed_player_id=player_id,
            disable_fog=False))

        controller.step()
        pb_obs = controller.observe()
        zstat = converter.convert(pb_obs)
        # The converter's output at the episode end is the value we keep.
        while not pb_obs.player_result:
            controller.step(FLAGS.step_mul)
            pb_obs = controller.observe()
            zstat = converter.convert(pb_obs)

    logging.info('writing to db...')
    db.put(rep_info_to_unique_key(replay_name, player_id), zstat)
    logging.info('leave _extract_z_stat')
    return MyResult(replay_name, player_id, game_version, zstat,
                    'successful writing to db')
def process_replay(self, controller, replay_data, map_data, player_id):
    """Process a single replay, updating the stats.

    Observes the replay from ``player_id``'s perspective using the
    module-level ``interface`` options. For every observation until one
    carries a player result, ``self.stats.replay_stats`` records:

    - the feature-layer actions (per ability, camera moves, point/rect
      selections, control-group use);
    - the pysc2 function each action maps back to (-1 if it cannot be
      reversed);
    - the valid abilities, the unit types present and the available pysc2
      actions.
    """
    self._update_stage("start_replay")
    request = sc_pb.RequestStartReplay(
        replay_data=replay_data,
        map_data=map_data,
        options=interface,
        observed_player_id=player_id)
    controller.start_replay(request)

    feat = features.Features(controller.game_info())
    stats = self.stats.replay_stats
    # (feature-layer field to test, counter attribute on `stats`)
    flag_counters = (
        ("camera_move", "camera_move"),
        ("unit_selection_point", "select_pt"),
        ("unit_selection_rect", "select_rect"),
    )

    stats.replays += 1
    self._update_stage("step")
    controller.step()
    while True:
        stats.steps += 1
        self._update_stage("observe")
        obs = controller.observe()

        for action in obs.actions:
            feature_action = action.action_feature_layer
            if feature_action.HasField("unit_command"):
                stats.made_abilities[feature_action.unit_command.ability_id] += 1
            for field_name, counter in flag_counters:
                if feature_action.HasField(field_name):
                    setattr(stats, counter, getattr(stats, counter) + 1)
            if action.action_ui.HasField("control_group"):
                stats.control_group += 1
            try:
                function_id = feat.reverse_action(action).function
            except ValueError:
                function_id = -1
            stats.made_actions[function_id] += 1

        observation = obs.observation
        for valid in observation.abilities:
            stats.valid_abilities[valid.ability_id] += 1
        for unit in observation.raw_data.units:
            stats.unit_ids[unit.unit_type] += 1
        for ability_id in feat.available_actions(observation):
            stats.valid_actions[ability_id] += 1

        if obs.player_result:
            break

        self._update_stage("step")
        controller.step(FLAGS.step_mul)
def __init__(self, replay_file_path, agent, discount=1.):
    """Launch SC2 4.10.1, validate the replay and start it.

    Args:
        replay_file_path: path of the replay file (Windows or POSIX
            separators); its stem becomes ``self.replay_file_name``.
        agent: stored as ``self.agent``.
        discount: stored as ``self.discount``.

    Raises:
        Exception: if ``self._valid_replay`` rejects the replay. Errors from
            ``replay_info`` propagate with their original type. The SC2
            process is closed before any setup error propagates.
    """
    import os  # Local import: only needed to strip the file extension.

    print("Parsing " + replay_file_path)
    # Accept both separators: split("\\") alone kept the whole directory path
    # on POSIX. splitext also keeps dots that are part of the name.
    base_name = replay_file_path.replace("\\", "/").split("/")[-1]
    self.replay_file_name = os.path.splitext(base_name)[0]
    self.agent = agent
    self.discount = discount

    self.run_config = run_configs.get()
    versions = self.run_config.get_versions()
    # Hard-coded game version: indexing fails if 4.10.1 is not installed.
    self.sc2_proc = self.run_config.start(version=versions['4.10.1'])
    self.controller = self.sc2_proc.controller
    try:
        ping = self.controller.ping()
        replay_data = self.run_config.replay_data(replay_file_path)
        # Let replay_info errors propagate unchanged; re-wrapping them in a
        # plain Exception only lost their type and traceback.
        self.info = self.controller.replay_info(replay_data)
        # NOTE(review): `self` is passed explicitly, so _valid_replay is
        # presumably a staticmethod taking (self, info, ping) -- confirm.
        if not self._valid_replay(self, self.info, ping):
            raise Exception("{} Was a loser".format(replay_file_path))

        _screen_size_px = point.Point(*self.screen_size_px)
        _minimap_size_px = point.Point(*self.minimap_size_px)
        interface = sc_pb.InterfaceOptions(
            feature_layer=sc_pb.SpatialCameraSetup(width=self.camera_width),
            show_cloaked=True,
            raw=True)
        _screen_size_px.assign_to(interface.feature_layer.resolution)
        _minimap_size_px.assign_to(interface.feature_layer.minimap_resolution)

        map_data = None
        if self.info.local_map_path:
            map_data = self.run_config.map_data(self.info.local_map_path)

        self._episode_length = self.info.game_duration_loops
        self._episode_steps = 0

        self.controller.start_replay(
            sc_pb.RequestStartReplay(
                replay_data=replay_data,
                map_data=map_data,
                options=interface,
                observed_player_id=self.player_id))
    except BaseException:
        # Don't leak the SC2 process when the replay is rejected or setup fails.
        self.sc2_proc.close()
        raise

    self._state = StepType.FIRST
def start_replay(self, replay_path):
    """Load the replay at ``replay_path``, start it and return its info.

    The replay is observed from ``self._player_id`` with fog of war enabled
    and ``self._interface`` options.

    Returns:
        The ``replay_info`` response for the replay.
    """
    replay_data = self._run_config.replay_data(replay_path)
    replay_info = self._controller.replay_info(replay_data)
    # local_map_path can be empty. In that case no map data is sent and the
    # game presumably resolves the map itself, rather than map_data('')
    # reading an empty path.
    map_data = None
    if replay_info.local_map_path:
        map_data = self._run_config.map_data(replay_info.local_map_path)
    start_replay = sc_pb.RequestStartReplay(
        replay_data=replay_data,
        map_data=map_data,
        options=self._interface,
        disable_fog=False,
        observed_player_id=self._player_id)
    self._controller.start_replay(start_replay)
    return replay_info
def __init__(self,
             replay_file_path,
             agent,
             player_id=1,
             screen_size_px=(64, 64),
             minimap_size_px=(64, 64),
             discount=1.,
             step_mul=1):
    """Launch SC2, validate the replay and start it from ``player_id``'s view.

    Args:
        replay_file_path: path of the replay file; its stem becomes
            ``self.replay_file_name``.
        agent: stored as ``self.agent``.
        player_id: id of the observed player.
        screen_size_px: (width, height) of the feature-layer screen.
        minimap_size_px: (width, height) of the feature-layer minimap.
        discount: stored as ``self.discount``.
        step_mul: stored as ``self.step_mul``.

    Raises:
        Exception: if ``self._valid_replay`` rejects the replay. The SC2
            process is closed before this or any other setup error propagates.
    """
    import os  # Local import: only needed to strip the file extension.

    self.agent = agent
    self.discount = discount
    self.step_mul = step_mul
    self.skip = 10  # NOTE(review): meaning not visible here -- confirm.
    # splitext strips only the extension. split(".")[0] truncated dotted
    # replay names at the first dot.
    self.replay_file_name = os.path.splitext(replay_file_path.split("/")[-1])[0]

    self.run_config = run_configs.get()
    self.sc2_proc = self.run_config.start()
    self.controller = self.sc2_proc.controller
    try:
        replay_data = self.run_config.replay_data(replay_file_path)
        ping = self.controller.ping()
        self.info = self.controller.replay_info(replay_data)
        if not self._valid_replay(self.info, ping):
            raise Exception(
                "{} is not a valid replay file!".format(replay_file_path))

        screen_size_px = point.Point(*screen_size_px)
        minimap_size_px = point.Point(*minimap_size_px)
        interface = sc_pb.InterfaceOptions(
            raw=False, score=True,
            feature_layer=sc_pb.SpatialCameraSetup(width=24))
        screen_size_px.assign_to(interface.feature_layer.resolution)
        minimap_size_px.assign_to(interface.feature_layer.minimap_resolution)

        map_data = None
        if self.info.local_map_path:
            map_data = self.run_config.map_data(self.info.local_map_path)

        self._episode_length = self.info.game_duration_loops
        self._episode_steps = 0

        self.controller.start_replay(
            sc_pb.RequestStartReplay(
                replay_data=replay_data,
                map_data=map_data,
                options=interface,
                observed_player_id=player_id))
    except BaseException:
        # Don't leak the SC2 process when the replay is rejected or setup fails.
        self.sc2_proc.close()
        raise

    self._state = StepType.FIRST
def start_replay(self):
    """Switch from the game to a replay.

    Advances the game 300 steps and saves the replay from the first
    controller. All controllers then leave the game and each one restarts the
    saved replay, observing the player whose 1-based id matches its position
    in ``self._controllers``.
    """
    self.step(300)
    saved_replay = self._controllers[0].save_replay()
    self._parallel.run(ctrl.leave for ctrl in self._controllers)
    for observed_id, ctrl in enumerate(self._controllers, start=1):
        request = sc_pb.RequestStartReplay(
            replay_data=saved_replay,
            map_data=self._map_data,
            options=self._interface,
            disable_fog=self._disable_fog,
            observed_player_id=observed_id)
        ctrl.start_replay(request)
    self.in_game = False
    self.step()  # Get into the game properly.
def extract_macro_actions(
    controller, replay_data, map_data, player_id, macro_action_frames
):
    """
    This function takes macro_action_frames (given by extract_action_frames)
    and moves through the replay only considering the places in which macro
    actions took place.

    For each frame in ``macro_action_frames``, the replay is advanced to that
    game loop. The ``STEP_MULT`` consecutive loops starting there are then
    observed one at a time. Observations whose actions (via ``get_actions``)
    are non-empty are recorded.

    Args:
        controller: controller connected to a running SC2 instance.
        replay_data: raw replay bytes.
        map_data: map bytes, or None.
        player_id: id of the observed player.
        macro_action_frames: game loops to visit, assumed ascending. Frames
            already passed (e.g. inside the previous frame's ``STEP_MULT``
            window) are skipped, because the replay cannot step backwards.

    Returns:
        (states, actions, scores): dicts keyed by ``str(game_loop)``.
    """
    controller.start_replay(
        sc_pb.RequestStartReplay(
            replay_data=replay_data,
            map_data=map_data,
            options=interface,
            observed_player_id=player_id,
        )
    )
    obs = controller.observe()
    abilities = controller.data_raw().abilities

    # a dict of action dicts which is to be merged to the other no-ops actions.
    actions = {}
    states = {}
    scores = {}

    # Game loop the replay is currently paused at (not yet observed, except
    # for the initial observation above).
    current_loop = obs.observation.game_loop
    for frame in macro_action_frames:
        if frame < current_loop:
            # Already observed in the previous window, or before the start.
            # Stepping by a non-positive count is not possible.
            continue
        if frame > current_loop:
            controller.step(frame - current_loop)
            current_loop = frame
        for offset in range(STEP_MULT):
            # Observe exactly once per game loop. The previous version observed
            # `frame` twice: at best redundant, and at worst it dropped the
            # actions at `frame` if observe() reports actions since the last
            # observation.
            obs = controller.observe()
            frame_id = obs.observation.game_loop
            if offset == 0:
                assert frame_id == frame
            new_actions = get_actions(obs.actions, abilities)
            if len(new_actions) > 0:  # i.e. if they're not no-ops:
                states[str(frame_id)] = get_state(obs.observation)
                actions[str(frame_id)] = new_actions
                scores[str(frame_id)] = get_score(obs.observation)
            controller.step(1)
            current_loop = frame_id + 1

    return states, actions, scores
def __init__(self, replay_file_path, agent, player_id=1, discount=1.,
             step_mul=1):
    """Launch SC2, validate the replay and start it from ``player_id``'s view.

    Feature-layer sizes come from ``self.screen_size_px``,
    ``self.minimap_size_px`` and ``self.camera_width``, which must be defined
    on the class.

    Args:
        replay_file_path: path of the replay file.
        agent: stored as ``self.agent``.
        player_id: id of the observed player.
        discount: stored as ``self.discount``.
        step_mul: stored as ``self.step_mul``.

    Raises:
        Exception: if ``self._valid_replay`` rejects the replay. The SC2
            process is closed before this or any other setup error propagates.
    """
    self.agent = agent
    self.discount = discount
    self.step_mul = step_mul

    self.run_config = run_configs.get()
    self.sc2_proc = self.run_config.start()
    self.controller = self.sc2_proc.controller
    try:
        replay_data = self.run_config.replay_data(replay_file_path)
        ping = self.controller.ping()
        self.info = self.controller.replay_info(replay_data)
        if not self._valid_replay(self.info, ping):
            raise Exception(
                "{} is not a valid replay file!".format(replay_file_path))

        _screen_size_px = point.Point(*self.screen_size_px)
        _minimap_size_px = point.Point(*self.minimap_size_px)
        interface = sc_pb.InterfaceOptions(
            raw=False, score=True,
            feature_layer=sc_pb.SpatialCameraSetup(
                width=self.camera_width, crop_to_playable_area=True))
        _screen_size_px.assign_to(interface.feature_layer.resolution)
        _minimap_size_px.assign_to(interface.feature_layer.minimap_resolution)

        map_data = None
        if self.info.local_map_path:
            map_data = self.run_config.map_data(self.info.local_map_path)

        self._episode_length = self.info.game_duration_loops
        self._episode_steps = 0

        self.controller.start_replay(
            sc_pb.RequestStartReplay(
                replay_data=replay_data,
                map_data=map_data,
                options=interface,
                observed_player_id=player_id))
    except BaseException:
        # Don't leak the SC2 process when the replay is rejected or setup fails.
        self.sc2_proc.close()
        raise

    self._state = StepType.FIRST
def debug_pb2all_converter():
    """Debug helper: feed consecutive observation pairs to ``PB2AllConverter``.

    Replays ``FLAGS.replay_path`` from ``FLAGS.player_id``'s view, one game
    loop at a time. Each observation (with its game info) is converted
    together with the next one until an observation carries a player result.
    """
    pb2all = PB2AllConverter(zstat_data_src=FLAGS.zstat_data_src,
                             dict_space=True,
                             game_version=FLAGS.game_version,
                             delete_dup_action='v2',
                             sort_executors='v2',
                             inj_larv_rule=True)
    pb2all.reset(replay_name=FLAGS.replay_path.split('/')[-1],
                 player_id=FLAGS.player_id,
                 mmr=6000,
                 map_name=FLAGS.map_name)

    # Process the replay from the beginning.
    run_config = run_configs.get()
    replay_data = run_config.replay_data(FLAGS.replay_path)

    with run_config.start(version=FLAGS.game_version) as controller:
        replay_info = controller.replay_info(replay_data)
        print(replay_info)
        controller.start_replay(
            sc_pb.RequestStartReplay(
                replay_data=replay_data,
                map_data=None,
                options=get_replay_actor_interface(FLAGS.map_name),
                observed_player_id=FLAGS.player_id,
                disable_fog=False))
        controller.step()

        last_pb = None
        last_game_info = None
        while True:
            pb_obs = controller.observe()
            game_info = controller.game_info()
            if pb_obs.player_result:
                break  # Episode end.
            if last_pb is not None:
                pb2all.convert(pb=(last_pb, last_game_info),
                               next_pb=(pb_obs, game_info))
            last_pb, last_game_info = pb_obs, game_info
            # Step on every iteration, including the first. Skipping the step
            # there (the old `continue`) paired the first game loop with
            # itself.
            controller.step(1)
async def start_replay(self, replay_path, realtime, observed_id=0):
    """Ask the SC2 client to start the replay at ``replay_path``.

    Args:
        replay_path: replay path as sent to the client.
        realtime: forwarded to ``RequestStartReplay.realtime``. Previously it
            was accepted but never sent, so replays always ran in step mode.
        observed_id: id of the player whose perspective is observed.

    Returns:
        The client's response to the start_replay request.
    """
    ifopts = sc_pb.InterfaceOptions(
        raw=True,
        score=True,
        show_cloaked=True,
        raw_affects_selection=False,
        raw_crop_to_playable_area=False)
    req = sc_pb.RequestStartReplay(
        replay_path=replay_path,
        observed_player_id=observed_id,
        realtime=bool(realtime),
        options=ifopts)
    result = await self._execute(start_replay=req)
    return result
def main(_):
    """Re-record every replay in ``FLAGS.input_dir`` into ``FLAGS.output_dir``.

    Each replay is started with ``record_replay=True`` and played to the end,
    stepping 1000 game loops at a time. The recording returned by
    ``controller.save_replay()`` is written under the replay's original
    basename. All replays are played with the game version of the first one
    in sorted order.
    """
    run_config = run_configs.get()
    replay_list = sorted(run_config.replay_paths(FLAGS.input_dir))
    print(len(replay_list), "replays found.\n")
    if not replay_list:
        # Nothing to do; replay_list[0] below would raise IndexError.
        return

    version = replay.get_replay_version(run_config.replay_data(replay_list[0]))
    run_config = run_configs.get(version=version)  # Replace the run config.
    # Create the destination up front so writing cannot fail after a whole
    # replay has been played.
    os.makedirs(FLAGS.output_dir, exist_ok=True)

    with run_config.start(want_rgb=False) as controller:
        for replay_path in replay_list:
            replay_data = run_config.replay_data(replay_path)
            info = controller.replay_info(replay_data)

            print(" Starting replay: ".center(60, "-"))
            print("Path:", replay_path)
            print("Size:", len(replay_data), "bytes")
            print(" Replay info: ".center(60, "-"))
            print(info)
            print("-" * 60)

            start_replay = sc_pb.RequestStartReplay(
                replay_data=replay_data,
                options=sc_pb.InterfaceOptions(score=True),
                record_replay=True,
                observed_player_id=1)
            if info.local_map_path:
                start_replay.map_data = run_config.map_data(
                    info.local_map_path, len(info.player_info))
            controller.start_replay(start_replay)

            while True:
                controller.step(1000)
                obs = controller.observe()
                if obs.player_result:
                    print("Stepped", obs.observation.game_loop, "game loops")
                    break

            recorded = controller.save_replay()
            replay_save_loc = os.path.join(FLAGS.output_dir,
                                           os.path.basename(replay_path))
            with open(replay_save_loc, "wb") as f:
                f.write(recorded)
            print("Wrote replay, ", len(recorded), " bytes to:", replay_save_loc)
def process_replay(self, controller, replay_data, map_data, player_id,
                   actions, ostream, global_info_path):
    """Dump static game data, then write one observation per action frame.

    ``game_info`` and ``data_raw`` are written as a JSON object of
    ``MessageToJson`` strings to ``global_info_path``. After one initial
    ``controller.step()``, the replay is stepped by the difference between
    each pair of consecutive entries of ``actions`` (frame ids). The
    observation after each step is written to ``ostream``.
    """
    request = sc_pb.RequestStartReplay(
        replay_data=replay_data,
        map_data=map_data,
        options=interface,
        observed_player_id=player_id)
    controller.start_replay(request)

    global_info = dict(game_info=controller.game_info(),
                       data_raw=controller.data_raw())
    with open(global_info_path, 'w') as info_file:
        serialised = {name: MessageToJson(message)
                      for name, message in global_info.items()}
        json.dump(serialised, info_file)

    controller.step()
    for previous_frame, frame in zip(actions, actions[1:]):
        controller.step(frame - previous_frame)
        ostream.write(controller.observe())
def process_replay(self, controller, replay_data, map_data, player_id,
                   actions, ostream, global_info_path):
    """Dump static game data, then write one observation per action frame.

    ``game_info`` (basic information about the game) and ``data_raw`` (raw
    static data) are written as a JSON object of ``MessageToJson`` strings to
    ``global_info_path``. The replay is then stepped by the difference between
    each pair of consecutive entries of ``actions`` (frame ids), and the
    observation after each step is written to ``ostream``.
    """
    request = sc_pb.RequestStartReplay(
        replay_data=replay_data,
        map_data=map_data,
        options=interface,
        observed_player_id=player_id)
    controller.start_replay(request)

    global_info = dict(game_info=controller.game_info(),
                       data_raw=controller.data_raw())
    with open(global_info_path, 'w') as info_file:
        serialised = {name: MessageToJson(message)
                      for name, message in global_info.items()}
        json.dump(serialised, info_file)

    # Consecutive (previous, current) frame-id pairs.
    for previous_frame, frame in zip(actions, actions[1:]):
        controller.step(frame - previous_frame)
        ostream.write(controller.observe())  # Save observation.
def process_replay(self, exporter, controller, replay_data, map_data,
                   replay_path, replay_info, player_id):
    """Process a single replay, updating the stats.

    Every observation, taken every ``self.step_multplier`` game loops after an
    initial ``controller.step()``, is transformed with pysc2 features and
    passed to ``encode``, stopping once one carries a player result.
    Observations that raise ``ValueError`` during transform/encode are counted
    in ``self.stats.replay_stats.invalid_states`` instead.
    """
    self._update_stage("start_replay")
    controller.start_replay(
        sc_pb.RequestStartReplay(
            replay_data=replay_data,
            map_data=map_data,
            options=self.interface(),
            observed_player_id=player_id))

    game_info = controller.game_info()
    action_space = actions.ActionSpace[FLAGS.action_space.upper()]
    feat = features.features_from_game_info(
        game_info, use_feature_units=False, action_space=action_space)

    step_index = 0
    self.stats.replay_stats.replays += 1
    self._update_stage("step")
    controller.step()
    while True:
        self.stats.replay_stats.steps += 1
        step_index += 1
        self._update_stage("observe")
        raw_obs = controller.observe()
        try:
            transformed = feat.transform_obs(raw_obs)
            encode(exporter, transformed, player_id, replay_path, replay_info,
                   step_index)
        except ValueError:
            self.stats.replay_stats.invalid_states += 1

        if raw_obs.player_result:
            break

        self._update_stage("step")
        controller.step(self.step_multplier)
def process_replay(self, controller, replay_data, map_data, player_id):
    """Process a single replay, updating the stats.

    Besides the counters in ``self.stats.replay_stats``, this appends a
    "td_map" dict to ``self.stats.replay_stats.td_maps`` for every
    screen-move command (ability id 16) whose target differs from the
    previous one. Each td_map holds:

    - the move counter, the ``done`` flag and the current ``mineral_map``;
    - the last selected control group (``base_action``);
    - the target x/y and the reward.

    Errors while processing an observation are printed and the replay keeps
    stepping.
    """
    self._update_stage("start_replay")
    controller.start_replay(
        sc_pb.RequestStartReplay(replay_data=replay_data,
                                 map_data=map_data,
                                 options=interface,
                                 observed_player_id=player_id))

    feat = features.Features(controller.game_info())

    self.stats.replay_stats.replays += 1
    self._update_stage("step")
    controller.step()

    last_group_id = 0
    last_reward = 0
    last_x = 0
    last_y = 0
    replay_step = 0
    while True:
        self.stats.replay_stats.steps += 1
        self._update_stage("observe")
        obs = None
        try:
            done = False
            obs = controller.observe()
            obs_trans = feat.transform_obs(obs.observation)
            # Cells of screen layer 5 equal to 3. The name says minerals --
            # NOTE(review): confirm layer/value meaning.
            mineral_map = (obs_trans['screen'][5] == 3).astype(int).tolist()

            # Player feature index 1 divided by 100.
            # NOTE(review): confirm which player stat index 1 is.
            episode_reward = int(obs_trans['player'][1] / 100)
            if episode_reward == 0 and last_reward > 0:
                # The value dropped back to 0: treated as the episode end.
                done = True
                reward = 0
            else:
                reward = episode_reward - last_reward
            last_reward = episode_reward

            for action in obs.actions:
                act_fl = action.action_feature_layer
                if act_fl.HasField("unit_command"):
                    ability_id = act_fl.unit_command.ability_id  # 16: move screen
                    target_coord = act_fl.unit_command.target_screen_coord
                    # Record a move when the target changed in x OR y. With
                    # `and`, moves along a row or column were dropped.
                    if ability_id == 16 and (last_x != target_coord.x
                                             or last_y != target_coord.y):
                        replay_step += 1
                        last_x = target_coord.x
                        last_y = target_coord.y
                        td_map = {
                            "step": replay_step,
                            "done": done,
                            "obs": mineral_map,
                            "base_action": last_group_id,
                            "x": target_coord.x,
                            "y": target_coord.y,
                            "reward": reward,
                        }
                        self.stats.replay_stats.td_maps.append(td_map)
                    self.stats.replay_stats.made_abilities[ability_id] += 1
                if act_fl.HasField("camera_move"):
                    self.stats.replay_stats.camera_move += 1
                if act_fl.HasField("unit_selection_point"):
                    self.stats.replay_stats.select_pt += 1
                if act_fl.HasField("unit_selection_rect"):
                    self.stats.replay_stats.select_rect += 1
                if action.action_ui.HasField("control_group"):
                    last_group_id = action.action_ui.control_group.control_group_index
                    self.stats.replay_stats.control_group += 1
                try:
                    func = feat.reverse_action(action).function
                except ValueError:
                    func = -1
                self.stats.replay_stats.made_actions[func] += 1

            for valid in obs.observation.abilities:
                self.stats.replay_stats.valid_abilities[valid.ability_id] += 1
            for u in obs.observation.raw_data.units:
                self.stats.replay_stats.unit_ids[u.unit_type] += 1
            for valid_action in feat.available_actions(obs.observation):
                self.stats.replay_stats.valid_actions[valid_action] += 1
        except Exception as e:  # Best effort: report and keep going.
            print("e:", e)

        # Checked outside the try: a failure while processing the final
        # observation must not make us step past the end of the replay.
        if obs is not None and obs.player_result:
            break

        self._update_stage("step")
        controller.step(FLAGS.step_mul)
def get_random_trajectory(self):
    """Sample random replays until one passes the filters, then record it.

    Up to 500 files are drawn, with replacement, from
    ``glob(self.source + '*.*')``. A replay is accepted when all of these hold:

    - it has no error and matches the running binary's base build;
    - it lasts at least 1000 game loops and has exactly two players;
    - player index 0 plays ``self.home_race_name`` and index 1 plays
      ``self.away_race_name``;
    - both players' MMR is >= ``self.replay_filter``.

    The accepted replay is observed with ``observed_player_id=1``, stepping
    8 game loops at a time, and produces:

    - ``self.home_trajectory``: every raw observation is appended, but only
      once the replay was fully processed;
    - ``self.home_BO``: unit types from Build*/Train* functions of the first
      action of each observation;
    - ``self.away_BU``: the last cumulative score, as a dict.

    Replays raising an ``Exception`` while processed are skipped. The SC2
    process started here is always closed before returning.

    Raises:
        ValueError: if a configured race name is not Terran/Zerg/Protoss.
        IndexError: if the glob matches no files.
    """
    # Later functions sharing an ability id overwrite earlier ones.
    function_dict = {f.ability_id: f.name for f in actions._FUNCTIONS}

    # Compared against race_actual: Terran=1, Zerg=2, Protoss=3.
    race_list = ['Terran', 'Zerg', 'Protoss']
    # Resolved once up front: an unknown race name now fails loudly instead
    # of silently rejecting every one of the 500 samples.
    home_race = race_list.index(self.home_race_name) + 1
    away_race = race_list.index(self.away_race_name) + 1

    # Cumulative-score fields copied into self.away_BU.
    score_fields = (
        'score', 'idle_production_time', 'idle_worker_time',
        'total_value_units', 'total_value_structures', 'killed_value_units',
        'killed_value_structures', 'collected_minerals', 'collected_vespene',
        'collection_rate_minerals', 'collection_rate_vespene',
        'spent_minerals', 'spent_vespene')

    interface = sc_pb.InterfaceOptions(
        raw=False, score=True,
        feature_layer=sc_pb.SpatialCameraSetup(width=24))
    point.Point(128, 128).assign_to(interface.feature_layer.resolution)
    point.Point(64, 64).assign_to(interface.feature_layer.minimap_resolution)
    player_id = 1
    step_mul = 8

    def player_passes(info, index, wanted_race):
        """Print player ``index``'s stats; return True if race and MMR pass."""
        player = info.player_info[index]
        race = player.player_info.race_actual
        mmr = player.player_mmr
        print("player{}_race: {}".format(index, race))
        print("player{}_mmr: {}".format(index, mmr))
        print("player{}_apm: {}".format(index, player.player_apm))
        print("player{}_result: {}".format(index, player.player_result.result))
        if race != wanted_race:
            print("player{}_race fail".format(index))
            return False
        print("player{}_race pass".format(index))
        if mmr < self.replay_filter:
            print("player{}_mmr fail".format(index))
            return False
        print("player{}_mmr pass".format(index))
        return True

    run_config = run_configs.get()
    sc2_proc = run_config.start()
    controller = sc2_proc.controller
    try:
        file_list = glob.glob(self.source + '*.*')
        for _ in range(500):
            replay_file_path = random.choice(file_list)
            try:
                replay_data = run_config.replay_data(replay_file_path)
                ping = controller.ping()
                info = controller.replay_info(replay_data)
                print("ping: " + str(ping))
                print("replay_info: " + str(info))

                # Reject unusable replays before starting them.
                if (info.HasField("error")
                        or info.base_build != ping.base_build  # different game version
                        or info.game_duration_loops < 1000
                        or len(info.player_info) != 2):
                    print("error")
                    continue
                if not player_passes(info, 0, home_race):
                    continue
                # The away filters must look at the second player (index 1).
                if not player_passes(info, 1, away_race):
                    continue

                map_data = None
                if info.local_map_path:
                    map_data = run_config.map_data(info.local_map_path)
                controller.start_replay(
                    sc_pb.RequestStartReplay(replay_data=replay_data,
                                             map_data=map_data,
                                             options=interface,
                                             observed_player_id=player_id))
                game_features = features.features_from_game_info(
                    controller.game_info())

                trajectory = []
                build_info = []
                score_cumulative_dict = {}
                replay_step = 0
                print("True loop")
                while True:
                    replay_step += 1
                    print("replay_step: " + str(replay_step))
                    controller.step(step_mul)
                    obs = controller.observe()
                    trajectory.append(obs)

                    if obs.actions:
                        unit_command = obs.actions[0].action_feature_layer.unit_command
                        # Unknown ability ids used to raise KeyError and throw
                        # away the whole replay; they are simply ignored now.
                        function_name = function_dict.get(unit_command.ability_id)
                        if function_name is not None and function_name != 'build_queue':
                            name_parts = function_name.split('_')
                            if name_parts[0] in ('Build', 'Train'):
                                build_info.append(int(units_new.get_unit_type(
                                    self.home_race_name, name_parts[1])))

                    agent_obs = game_features.transform_obs(obs)
                    score_cumulative = agent_obs['score_cumulative']
                    score_cumulative_dict = {
                        name: getattr(score_cumulative, name)
                        for name in score_fields}
                    if obs.player_result:  # Episode over.
                        break
            except Exception as e:  # Skip bad replays; Ctrl-C still propagates.
                print("skipping {}: {}".format(replay_file_path, e))
                continue

            # Publish results only for a fully processed replay.
            for recorded_obs in trajectory:
                self.home_trajectory.append(recorded_obs)
            self.home_BO = build_info
            self.away_BU = score_cumulative_dict
            break
    finally:
        sc2_proc.close()
def __init__(self,
             replay_file_path,
             parser_objects,
             player_id=1,
             screen_size=(64, 64),
             minimap_size=(64, 64),
             discount=1.,
             step_mul=1,
             override=False):
    """Start an SC2 process, validate the replay and begin replaying it.

    Args:
        replay_file_path: Path to the ``.SC2Replay`` file to load.
        parser_objects: A single parser class (a subclass of ``ParserBase``)
            or a list/tuple of parser classes. Each one is instantiated with
            this replay's output directory.
        player_id: Player whose point of view is observed.
        screen_size: ``(width, height)`` tuple, or one int for a square
            feature-screen resolution.
        minimap_size: ``(width, height)`` tuple, or one int for a square
            feature-minimap resolution.
        discount: Stored as ``self.discount``.
        step_mul: Stored as ``self.step_mul``.
        override: Stored as ``self.override``.

    Raises:
        ValueError: If ``parser_objects``, ``screen_size`` or ``minimap_size``
            has an unsupported type, or the replay is invalid or does not
            match ``FLAGS.race_matchup``.
        ConnectionRefusedError: If the SC2 process cannot be reached.
    """

    def _as_size_pair(size, arg_name):
        # Normalize a size argument: tuples pass through, ints become squares.
        if isinstance(size, tuple):
            return size
        if isinstance(size, int):
            return (size, size)
        raise ValueError(
            f"Argument '{arg_name}' requires a tuple of size 2 or a single integer."
        )

    self.replay_file_path = os.path.abspath(replay_file_path)
    self.replay_name = os.path.split(replay_file_path)[-1].replace(
        '.SC2Replay', '')

    # FLAGS.race_matchup may be None (it is checked against None below), and
    # os.path.join() raises TypeError on None, so only include it when set.
    dir_parts = [FLAGS.result_dir]
    if FLAGS.race_matchup is not None:
        dir_parts.append(FLAGS.race_matchup)
    self.write_dir = os.path.join(*dir_parts, self.replay_name)

    if isinstance(parser_objects, (list, tuple)):
        self.parsers = [p_obj(self.write_dir) for p_obj in parser_objects]
    elif (isinstance(parser_objects, type)
          and issubclass(parser_objects, ParserBase)):
        # The isinstance() guard turns a non-class argument into the
        # documented ValueError instead of a TypeError from issubclass().
        self.parsers = [parser_objects(self.write_dir)]
    else:
        raise ValueError(
            "Argument 'parsers' expects a single or list of Parser objects.")

    self.player_id = player_id
    self.discount = discount
    self.step_mul = step_mul
    self.override = override

    self.screen_size = _as_size_pair(screen_size, 'screen_size')
    self.minimap_size = _as_size_pair(minimap_size, 'minimap_size')
    assert len(self.screen_size) == 2
    assert len(self.minimap_size) == 2

    # Arguments for 'sc_process.StarCraftProcess'. Check the following:
    # https://github.com/deepmind/pysc2/blob/master/pysc2/lib/sc_process.py
    try:
        sc2_process_configs = {"full_screen": False, 'timeout_seconds': 300}
        self.run_config = run_configs.get()
        self.sc2_process = self.run_config.start(**sc2_process_configs)
        self.controller = self.sc2_process.controller
    except (websocket.WebSocketTimeoutException,
            protocol.ConnectionError) as e:
        raise ConnectionRefusedError(
            f'Connection to SC2 process unavailable. ({e})') from e

    # Check the following links for usage of run_config and controller.
    # https://github.com/deepmind/pysc2/blob/master/pysc2/run_configs/platforms.py
    # https://github.com/deepmind/pysc2/blob/master/pysc2/lib/sc_process.py
    # https://github.com/deepmind/pysc2/blob/master/pysc2/lib/remote_controller.py

    # Load replay information & check validity.
    replay_data = self.run_config.replay_data(self.replay_file_path)
    info = self.controller.replay_info(replay_data)
    if not self.check_valid_replay(info, self.controller.ping()):
        self.safe_escape()
        raise ValueError('Invalid replay.')

    # Filter replay by race matchup.
    if FLAGS.race_matchup is not None:
        if not self.check_valid_matchup(info, matchup=FLAGS.race_matchup):
            self.safe_escape()
            raise ValueError('Invalid matchup.')

    self.map_name = info.map_name
    print('...')

    # Raw data is disabled here (raw=False); per the protocol docs below,
    # 'raw=True' is what enables raw data such as 'feature_units'.
    # https://github.com/Blizzard/s2client-proto/blob/master/docs/protocol.md#interfaces
    interface = sc_pb.InterfaceOptions(
        raw=False,
        score=True,
        show_cloaked=False,
        feature_layer=sc_pb.SpatialCameraSetup(width=24,
                                               allow_cheating_layers=True))
    self.screen_size = point.Point(*self.screen_size)
    self.minimap_size = point.Point(*self.minimap_size)
    self.screen_size.assign_to(interface.feature_layer.resolution)
    self.minimap_size.assign_to(interface.feature_layer.minimap_resolution)

    map_data = None
    if info.local_map_path:
        map_data = self.run_config.map_data(info.local_map_path)

    self._episode_length = info.game_duration_loops
    self._episode_steps = 0

    # Request replay.
    self.controller.start_replay(req_start_replay=sc_pb.RequestStartReplay(
        replay_data=replay_data,
        map_data=map_data,
        options=interface,
        observed_player_id=self.player_id,
        disable_fog=True,
    ))

    self._state = environment.StepType.FIRST
    self.info = info
def debug_pb2all_converter():
    """Feed one replay through ``PB2AllConverter`` frame by frame (debug aid).

    The replay ``FLAGS.replay_dir/FLAGS.replay_name + '.SC2Replay'`` is
    observed from ``FLAGS.player_id``'s perspective, stepping one game loop
    at a time. Each pair of consecutive observations (with their game info)
    is passed to ``pb2all.convert`` until the game reports a player result.
    """
    # reset env/process replay from the beginning
    run_config = run_configs.get()
    replay_path = path.join(FLAGS.replay_dir, FLAGS.replay_name + '.SC2Replay')
    replay_data = run_config.replay_data(replay_path)

    with run_config.start(version=FLAGS.game_version) as controller:
        replay_info = controller.replay_info(replay_data)
        map_name = replay_info.map_name
        # map_name = 'Stasis'  # debugging override; VERY dangerous!!

        pb2all = PB2AllConverter(
            zstat_data_src=FLAGS.zstat_data_src,
            input_map_size=(128, 128),
            output_map_size=(128, 128),
            dict_space=True,
            game_version=FLAGS.game_version,
            zmaker_version='v5',
            zstat_zeroing_prob=0.0,
            max_bo_count=50,
            max_bobt_count=20,
            sort_executors=True,
        )
        pb2all.reset(replay_name=FLAGS.replay_name,
                     player_id=FLAGS.player_id,
                     mmr=6000,
                     map_name=map_name)

        controller.start_replay(sc_pb.RequestStartReplay(
            replay_data=replay_data,
            map_data=None,
            options=get_replay_actor_interface(map_name),
            observed_player_id=FLAGS.player_id,
            disable_fog=False))
        controller.step()

        last_pb = None
        last_game_info = None
        while True:
            pb_obs = controller.observe()
            game_info = controller.game_info()
            if pb_obs.player_result:
                # episode end, the zstat to this extent is what we need
                break
            if last_pb is not None:
                pb2all.convert(pb=(last_pb, last_game_info),
                               next_pb=(pb_obs, game_info))
            last_pb = pb_obs
            last_game_info = game_info
            # Step on every iteration, including the first. Skipping the step
            # (a bare `continue`) made the next observe() re-read the same
            # game loop, so the first converted pair was (frame, frame).
            controller.step(1)  # step_mul
def main(unused_argv):
    """Run SC2 to play a game or a replay.

    Exactly one of ``FLAGS.map`` / ``FLAGS.replay`` must be given. With
    ``FLAGS.render`` the pygame human renderer drives the session; otherwise
    the game is stepped until it ends (or Ctrl-C), and the final score and
    result are printed.
    """
    import os  # local import: only used to derive the score file path

    stopwatch.sw.enabled = FLAGS.profile or FLAGS.trace
    stopwatch.sw.trace = FLAGS.trace

    if (FLAGS.map and FLAGS.replay) or (not FLAGS.map and not FLAGS.replay):
        sys.exit("Must supply either a map or replay.")
    if FLAGS.replay and not FLAGS.replay.lower().endswith("sc2replay"):
        sys.exit("Replay must end in .SC2Replay.")
    if FLAGS.realtime and FLAGS.replay:
        # TODO(tewalds): Support realtime in replays once the game supports it.
        sys.exit("realtime isn't possible for replays yet.")
    if FLAGS.render and (FLAGS.realtime or FLAGS.full_screen):
        sys.exit("disable pygame rendering if you want realtime or full_screen.")
    if platform.system() == "Linux" and (FLAGS.realtime or FLAGS.full_screen):
        sys.exit("realtime and full_screen only make sense on Windows/MacOS.")
    if not FLAGS.render and FLAGS.render_sync:
        sys.exit("render_sync only makes sense with pygame rendering on.")

    run_config = run_configs.get()

    interface = sc_pb.InterfaceOptions()
    interface.raw = FLAGS.render
    interface.score = True
    interface.feature_layer.width = 24
    interface.feature_layer.resolution.x = FLAGS.screen_resolution
    interface.feature_layer.resolution.y = FLAGS.screen_resolution
    interface.feature_layer.minimap_resolution.x = FLAGS.minimap_resolution
    interface.feature_layer.minimap_resolution.y = FLAGS.minimap_resolution

    max_episode_steps = FLAGS.max_episode_steps

    if FLAGS.map:
        map_inst = maps.get(FLAGS.map)
        if map_inst.game_steps_per_episode:
            max_episode_steps = map_inst.game_steps_per_episode
        create = sc_pb.RequestCreateGame(
            realtime=FLAGS.realtime,
            disable_fog=FLAGS.disable_fog,
            local_map=sc_pb.LocalMap(map_path=map_inst.path,
                                     map_data=map_inst.data(run_config)))
        create.player_setup.add(type=sc_pb.Participant)
        create.player_setup.add(
            type=sc_pb.Computer,
            race=sc2_env.races[FLAGS.bot_race],
            difficulty=sc2_env.difficulties[FLAGS.difficulty])
        join = sc_pb.RequestJoinGame(race=sc2_env.races[FLAGS.user_race],
                                     options=interface)
        game_version = None
    else:
        replay_data = run_config.replay_data(FLAGS.replay)
        start_replay = sc_pb.RequestStartReplay(
            replay_data=replay_data,
            options=interface,
            disable_fog=FLAGS.disable_fog,
            observed_player_id=FLAGS.observed_player)
        game_version = get_game_version(replay_data)

    with run_config.start(game_version=game_version,
                          full_screen=FLAGS.full_screen) as controller:
        if FLAGS.map:
            controller.create_game(create)
            controller.join_game(join)
        else:
            info = controller.replay_info(replay_data)
            print(" Replay info ".center(60, "-"))
            print(info)
            print("-" * 60)
            map_path = FLAGS.map_path or info.local_map_path
            if map_path:
                start_replay.map_data = run_config.map_data(map_path)
            controller.start_replay(start_replay)

        if FLAGS.render:
            renderer = renderer_human.RendererHuman(
                fps=FLAGS.fps, step_mul=FLAGS.step_mul,
                render_sync=FLAGS.render_sync)
            renderer.run(
                run_config, controller, max_game_steps=FLAGS.max_game_steps,
                game_steps_per_episode=max_episode_steps,
                save_replay=FLAGS.save_replay)
        else:
            # Still step forward so the Mac/Windows renderer works.
            obs = None
            try:
                while True:
                    frame_start_time = time.time()
                    if not FLAGS.realtime:
                        controller.step(FLAGS.step_mul)
                    obs = controller.observe()
                    if obs.player_result:
                        break
                    time.sleep(max(0, frame_start_time + 1 / FLAGS.fps -
                                   time.time()))
            except KeyboardInterrupt:
                pass

            # obs is still None if Ctrl-C arrived before the first
            # observation; there is no score to report in that case.
            if obs is not None:
                print("Score: ", obs.observation.score.score)
                print("Result: ", obs.player_result)
                if FLAGS.map and FLAGS.save_replay:
                    replay_save_loc = run_config.save_replay(
                        controller.save_replay(), "local", FLAGS.map)
                    print("Replay saved to:", replay_save_loc)
                    # Save scores so we know how the human player did.
                    # splitext only swaps the extension, unlike str.replace,
                    # which would also rewrite matching directory names.
                    score_path = os.path.splitext(replay_save_loc)[0] + ".txt"
                    with open(score_path, "w") as f:
                        f.write("{}\n".format(obs.observation.score.score))

    if FLAGS.profile:
        print(stopwatch.sw)
def process_replay(self, controller, replay_data, map_data, player_id,
                   dump_callbacks, info):
    """Process a single replay, updating the stats.

    The replay is observed from ``player_id``'s perspective. Every observed
    frame updates the counters in ``self.stats.replay_stats``, is converted
    with ``features.Features`` and is handed to the dump callbacks together
    with the player's outcome string ('Victory' or 'Defeat').
    """
    outcome = info.player_info[player_id - 1].player_result.result
    # Only decided games are handled: 1 -> 'Victory', 2 -> 'Defeat'.
    assert outcome in (1, 2)
    uscore = {1: 'Victory', 2: 'Defeat'}[outcome]

    self._update_stage("start_replay")
    request = sc_pb.RequestStartReplay(
        replay_data=replay_data,
        map_data=map_data,
        options=interface,
        observed_player_id=player_id)
    controller.start_replay(request)

    feat = features.Features(controller.game_info())

    def _record_action(stats, action):
        # Tally the feature-layer/UI parts of one action and return its
        # pysc2 FunctionCall (function id -1 when it cannot be reversed).
        layer = action.action_feature_layer
        if layer.HasField("unit_command"):
            stats.made_abilities[layer.unit_command.ability_id] += 1
        if layer.HasField("camera_move"):
            stats.camera_move += 1
        if layer.HasField("unit_selection_point"):
            stats.select_pt += 1
        if layer.HasField("unit_selection_rect"):
            stats.select_rect += 1
        if action.action_ui.HasField("control_group"):
            stats.control_group += 1
        try:
            call = feat.reverse_action(action)
        except ValueError:
            call = actions.FunctionCall(-1, [])
        stats.made_actions[call.function] += 1
        return call

    self.stats.replay_stats.replays += 1
    self._update_stage("step")
    replay_stepper = stepper(controller.step)
    replay_stepper.step()

    while True:
        self.stats.replay_stats.steps += 1
        self._update_stage("observe")
        obs = controller.observe()

        frame_stats = self.stats.replay_stats
        performed = [_record_action(frame_stats, act) for act in obs.actions]

        for ability in obs.observation.abilities:
            frame_stats.valid_abilities[ability.ability_id] += 1
        for unit in obs.observation.raw_data.units:
            frame_stats.unit_ids[unit.unit_type] += 1
        for ability_id in feat.available_actions(obs.observation):
            frame_stats.valid_actions[ability_id] += 1

        transformed = feat.transform_obs(obs.observation)
        call_dump_callbacks(dump_callbacks, replay_stepper.step_, transformed,
                            uscore, performed)

        if obs.player_result:
            break

        self._update_stage("step")
        replay_stepper.step(FLAGS.step_mul)
def debug_pb2all_converter():
    """Feed one replay through ``PB2AllConverter`` and print its actions.

    The replay ``FLAGS.replay_dir/FLAGS.replay_name + '.SC2Replay'`` is
    observed from ``FLAGS.player_id``'s perspective, stepping one game loop
    at a time. Each pair of consecutive observations (with their game info)
    is converted. Whenever the converter returns data, a line with
    ``FLAGS.prefix``, the game loop and the ability name is printed.
    """
    # reset env/process replay from the beginning
    run_config = run_configs.get()
    replay_path = path.join(FLAGS.replay_dir, FLAGS.replay_name + '.SC2Replay')
    replay_data = run_config.replay_data(replay_path)
    game_core_config = {
        'show_placeholders': True,
        'show_burrowed_shadows': True
    }

    with run_config.start(version=FLAGS.game_version) as controller:
        replay_info = controller.replay_info(replay_data)
        map_name = replay_info.map_name
        # map_name = 'Stasis'  # debugging override; VERY dangerous!!

        pb2all = PB2AllConverter(
            zstat_data_src=FLAGS.zstat_data_src,
            input_map_size=(128, 128),
            output_map_size=(128, 128),
            dict_space=True,
            game_version=FLAGS.game_version,
            zmaker_version='v5',
            zstat_zeroing_prob=0.0,
            max_bo_count=50,
            max_bobt_count=20,
            delete_dup_action='v2',
            sort_executors='v2',
            inj_larv_rule=True,
            add_lurker_spine_to_units=True,
            add_cargo_to_units=True,
            use_display_type=True,
            distinguish_effect_camp=True,
            lurker_effect_decay=0.999,
        )
        pb2all.reset(replay_name=FLAGS.replay_name,
                     player_id=FLAGS.player_id,
                     mmr=6000,
                     map_name=map_name)

        controller.start_replay(
            sc_pb.RequestStartReplay(replay_data=replay_data,
                                     map_data=None,
                                     options=tleague_rep_actor_interface(
                                         map_name,
                                         game_core_config=game_core_config),
                                     observed_player_id=FLAGS.player_id,
                                     disable_fog=False))
        controller.step()

        last_pb = None
        last_game_info = None
        while True:
            pb_obs = controller.observe()
            game_info = controller.game_info()
            if pb_obs.player_result:
                # episode end, the zstat to this extent is what we need
                break
            if last_pb is not None:
                data = pb2all.convert(pb=(last_pb, last_game_info),
                                      next_pb=(pb_obs, game_info))
                if data:
                    # note: data is [((obs, act), w)]
                    act = data[0][0][1]
                    ab_idx = act['A_AB']
                    print(', '.join([
                        FLAGS.prefix,
                        'gal:{}'.format(pb_obs.observation.game_loop),  # game loop
                        'abn:{}'.format(ZERG_ABILITIES[ab_idx][0]),  # ability name
                    ]))
            last_pb = pb_obs
            last_game_info = game_info
            # Step on every iteration, including the first. Skipping the step
            # (a bare `continue`) made the next observe() re-read the same
            # game loop, so the first converted pair was (frame, frame).
            controller.step(1)  # step_mul
def run_loop(replay, player_id, mainDQN):
    """Replay one SC2 replay and train ``mainDQN`` on every observed frame.

    Args:
        replay: Path to a ``.SC2Replay`` file.
        player_id: 1-based id of the player whose view is observed (used to
            index ``info.player_info``).
        mainDQN: Model passed through to ``train``.

    The replay is skipped (early return) when it does not satisfy
    ``FLAGS.map_name``, ``FLAGS.agent_race`` or ``FLAGS.win_only``.
    """
    stopwatch.sw.enabled = False
    stopwatch.sw.trace = False

    if not replay:
        sys.exit("Must supply a replay.")
    if not replay.lower().endswith("sc2replay"):
        sys.exit("Replay must end in .SC2Replay.")

    run_config = run_configs.get()

    interface = sc_pb.InterfaceOptions()
    interface.raw = False
    interface.score = True
    interface.feature_layer.width = 24
    interface.feature_layer.resolution.x = FLAGS.screen_size
    interface.feature_layer.resolution.y = FLAGS.screen_size
    interface.feature_layer.minimap_resolution.x = FLAGS.minimap_size
    interface.feature_layer.minimap_resolution.y = FLAGS.minimap_size

    replay_data = run_config.replay_data(replay)
    start_replay = sc_pb.RequestStartReplay(replay_data=replay_data,
                                            options=interface,
                                            disable_fog=False,
                                            observed_player_id=player_id)

    with run_config.start(full_screen=False) as controller:
        info = controller.replay_info(replay_data)
        # NOTE(review): assumes player_info is ordered by player id.
        player = info.player_info[player_id - 1]
        infomap = info.map_name
        inforace = player.player_info.race_actual
        inforesult = player.player_result.result

        if FLAGS.map_name and not mapNameMatch(infomap):
            print("map doesn't match, continue...")
            print("map_name:", FLAGS.map_name, "infomap:", infomap)
            return
        if FLAGS.agent_race and raceToCode(FLAGS.agent_race) != inforace:
            print("agent race doesn't match, continue...")
            print("agent_race:", raceToCode(FLAGS.agent_race), "inforace:",
                  inforace)
            return
        # The result is an enum (Victory=1, Defeat=2, ...), so test for
        # Victory explicitly; `not inforesult` only matched an unset result
        # and let defeated players through.
        if FLAGS.win_only and inforesult != sc_pb.Victory:
            print("this player was defeated, continue...")
            print("result:", inforesult)
            return

        print("condition's satisfied, training starts :", replay)
        print("map :", infomap)
        print("player id :", player_id)
        print("race :", inforace)
        print("result :", inforesult)

        map_path = info.local_map_path
        if map_path:
            start_replay.map_data = run_config.map_data(map_path)
        controller.start_replay(start_replay)

        game_info = controller.game_info()
        _features = features.Features(game_info)
        action_spec = _features.action_spec()

        obs = None
        try:
            while True:
                controller.step(1)
                obs = controller.observe()
                real_obs = _features.transform_obs(obs.observation)
                real_actions = []
                for action in obs.actions:
                    try:
                        real_actions.append(_features.reverse_action(action))
                    except ValueError:
                        # Unmappable actions are recorded as function id 0
                        # (pysc2's no_op).
                        real_actions.append(
                            actlib.FunctionCall(function=0, arguments=[]))
                train(mainDQN, real_obs, real_actions, action_spec)
                if obs.player_result:
                    break
        except KeyboardInterrupt:
            pass

        # obs is still None if Ctrl-C arrived before the first observation.
        if obs is not None:
            print("Score: ", obs.observation.score.score)
            print("Result: ", obs.player_result)
def _extract(self):
    """Yield converter samples for ``self._player_id`` from one replay.

    Starts SC2 at ``self._version`` and looks up the player's MMR, replacing
    values below 1 with ``self._unk_mmr_dft_to``. When ``self._da_rate > 0``,
    data augmentation is enabled with that probability. The replay is then
    stepped ``self._step_mul`` game loops at a time, and every sample that
    ``self._replay_converter`` produces is yielded until the game ends.
    """
    run_config = run_configs.get()
    replay_data = run_config.replay_data(self._replay_filepath)
    with run_config.start(version=self._version) as controller:
        replay_info = controller.replay_info(replay_data)

        # If several entries match, the last one wins.
        mmr_matches = [
            entry.player_mmr for entry in replay_info.player_info
            if entry.player_info.player_id == self._player_id
        ]
        mmr = mmr_matches[-1] if mmr_matches else None
        assert mmr is not None
        if mmr < 1:
            logger.log(
                "Encounter unknown mmr: {}, defaults it to: {}".format(
                    mmr, self._unk_mmr_dft_to))
            mmr = self._unk_mmr_dft_to

        map_name = replay_info.map_name

        # Ship the map data along with the request only when a local copy exists.
        local_map_path = replay_info.local_map_path
        if local_map_path:
            map_data = run_config.map_data(local_map_path)
            print('using local_map_path {}'.format(local_map_path))
        else:
            map_data = None

        interface = _get_interface(map_name,
                                   game_core_config=self._game_core_config)
        controller.start_replay(
            sc_pb.RequestStartReplay(replay_data=replay_data,
                                     map_data=map_data,
                                     options=interface,
                                     observed_player_id=self._player_id,
                                     disable_fog=False))
        controller.step()

        start_pos = None
        enable_da = False
        if self._da_rate > 0:
            enable_da = random.random() < self._da_rate
            if enable_da:
                logger.log("data augmentation: revise")
                start_pos = controller.game_info().start_raw.start_locations[0]
                pb2pb.replace_loc(start_pos, map_name)
            else:
                logger.log("data augmentation: not revise")

        self._replay_converter.reset(replay_name=self._replay_name,
                                     player_id=self._player_id,
                                     mmr=mmr,
                                     map_name=map_name,
                                     start_pos=start_pos)

        prev_obs = None
        while True:
            obs = controller.observe()
            if enable_da:
                obs = pb2pb.make_aug_data(obs, map_name)
            if obs.player_result:
                # Final frame: convert it on its own, then stop.
                yield from self._replay_converter.convert(pb=(obs, None),
                                                          next_pb=None)
                break
            # Game info is not fetched per frame, so None is passed in its
            # place on both sides of the transition.
            if prev_obs is not None:
                yield from self._replay_converter.convert(
                    pb=(prev_obs, None), next_pb=(obs, None))
            prev_obs = obs
            controller.step(self._step_mul)