def collect_trajectories(actor, critic, port, game_num, p2, rnn, n_frame):
    """Play SoundAgent against ``p2`` and return the processed trajectory data.

    A py4j ``JavaGateway`` is opened on ``port``. A fresh ``SoundAgent`` (plus a
    ``SandboxAgent``) is registered with the Java game manager, and a ZEN-vs-ZEN
    game of ``game_num`` games is run. If anything fails while setting up or
    running the game, the gateway is closed and the whole attempt is retried with
    a new gateway and new agents. There is no retry limit.

    Args:
        actor: model passed to ``SoundAgent`` and ``process_game_agent_data``.
        critic: model passed to ``SoundAgent`` and ``process_game_agent_data``.
        port: port of the Java gateway server.
        game_num: number of games passed to ``manager.createGame``.
        p2: name of the opponent AI as known to the Java side.
        rnn: forwarded to ``SoundAgent`` and ``process_game_agent_data``.
        n_frame: forwarded to ``SoundAgent``.

    Returns:
        The result of ``process_game_agent_data`` for the agent of the
        successful attempt.
    """
    logger.info(f'start fight with {p2}')
    logger.info(f'game_num value {game_num}')

    def close_gateway(gateway):
        # Best-effort shutdown. A failure here must neither escape the retry
        # loop nor make an already finished game be played again (which would
        # throw away the data collected by its agent).
        try:
            gateway.close_callback_server()
            gateway.close()
        except Exception as close_err:
            logger.info(f'Failed to close the gateway: {close_err}')

    while True:
        gateway = JavaGateway(
            gateway_parameters=GatewayParameters(port=port),
            callback_server_parameters=CallbackServerParameters())
        try:
            manager = gateway.entry_point
            # register AIs
            collect_data_helper = CollectDataHelper(logger)
            agent = SoundAgent(gateway, actor=actor, critic=critic,
                               collect_data_helper=collect_data_helper,
                               logger=logger, n_frame=n_frame, rnn=rnn)
            sandbox_agent = SandboxAgent(gateway)
            manager.registerAI('SoundAgent', agent)
            manager.registerAI('Sandbox', sandbox_agent)
            game = manager.createGame('ZEN', 'ZEN', 'SoundAgent', p2, game_num)
            # start game
            manager.runGame(game)
        except Exception as ex:
            print(ex)
            logger.info('There is an error with the gateway, restarting')
            close_gateway(gateway)
            continue
        # finish game
        logger.info('Finish game')
        sys.stdout.flush()
        close_gateway(gateway)
        break

    return process_game_agent_data(actor, critic, agent.collect_data_helper,
                                   rnn)
class FightingiceEnv_Display_NoFrameskip(gym.Env):
    """Gym env that plays FightingICE (a Java game) as player 1 through py4j.

    The Java game runs as a subprocess that listens on ``self.port``. Player 1
    is a ``GymAIDisplay`` that exchanges messages with this env over a
    ``multiprocessing`` ``Pipe``. The game is started in a separate thread
    (``game_thread``). The observation space is declared as a ``(96, 64, 1)``
    box in ``[0, 1]``. Actions are indices into the action list below.

    Keyword args:
        java_env_path: directory containing ``FightingICE.jar``, ``data`` and
            ``lib``. Defaults to the current working directory.
        freq_restart_java: number of games passed to ``createGame``. After
            ``freq_restart_java * 3`` resets the Java process is restarted.
            Default 3.
        port: gateway port. Chosen with ``port_for`` when omitted.
    """
    metadata = {'render.modes': ['human']}

    def __init__(self, **kwargs):
        self.freq_restart_java = kwargs.get("freq_restart_java", 3)
        self.java_env_path = kwargs.get("java_env_path", os.getcwd())
        if "port" in kwargs:
            self.port = kwargs["port"]
        else:
            try:
                import port_for
            except ImportError as err:
                raise ImportError(
                    "Pass port=[your_port] when make env, or install port_for to set startup port automatically, maybe pip install port_for can help"
                ) from err
            # select one random port for java env
            self.port = port_for.select_random()

        _actions = "AIR AIR_A AIR_B AIR_D_DB_BA AIR_D_DB_BB AIR_D_DF_FA AIR_D_DF_FB AIR_DA AIR_DB AIR_F_D_DFA AIR_F_D_DFB AIR_FA AIR_FB AIR_GUARD AIR_GUARD_RECOV AIR_RECOV AIR_UA AIR_UB BACK_JUMP BACK_STEP CHANGE_DOWN CROUCH CROUCH_A CROUCH_B CROUCH_FA CROUCH_FB CROUCH_GUARD CROUCH_GUARD_RECOV CROUCH_RECOV DASH DOWN FOR_JUMP FORWARD_WALK JUMP LANDING NEUTRAL RISE STAND STAND_A STAND_B STAND_D_DB_BA STAND_D_DB_BB STAND_D_DF_FA STAND_D_DF_FB STAND_D_DF_FC STAND_F_D_DFA STAND_F_D_DFB STAND_FA STAND_FB STAND_GUARD STAND_GUARD_RECOV STAND_RECOV THROW_A THROW_B THROW_HIT THROW_SUFFER"
        action_strs = _actions.split(" ")
        self.observation_space = spaces.Box(low=0, high=1, shape=(96, 64, 1))
        self.action_space = spaces.Discrete(len(action_strs))

        os_name = platform.system()
        if os_name.startswith("Linux"):
            self.system_name = "linux"
        elif os_name.startswith("Darwin"):
            self.system_name = "macos"
        else:
            self.system_name = "windows"

        if self.system_name == "linux":
            # first check java can be run, can only be used on Linux
            java_version = subprocess.check_output(
                'java -version 2>&1 | awk -F[\\\"_] \'NR==1{print $2}\'',
                shell=True)
            if java_version == b"\n":
                raise ModuleNotFoundError("Java is not installed")
        else:
            print("Please make sure you can run java if you see some error")

        # second check if FightingIce is installed correct
        start_jar_path = os.path.join(self.java_env_path, "FightingICE.jar")
        start_data_path = os.path.join(self.java_env_path, "data")
        start_lib_path = os.path.join(self.java_env_path, "lib")
        lwjgl_path = os.path.join(start_lib_path, "lwjgl", "*")
        lib_path = os.path.join(start_lib_path, "*")
        start_system_lib_path = os.path.join(self.java_env_path, "lib",
                                             "natives", self.system_name)
        natives_path = os.path.join(start_system_lib_path, "*")
        required_paths = (start_jar_path, start_data_path, start_lib_path,
                          start_system_lib_path)
        if not all(os.path.exists(path) for path in required_paths):
            error_message = "FightingICE is not installed in your script launched path {}, set path when make() or start script in FightingICE path".format(
                self.java_env_path)
            raise FileExistsError(error_message)

        self.java_ai_path = os.path.join(self.java_env_path, "data", "ai")
        ai_path = os.path.join(self.java_ai_path, "*")
        # The JVM classpath separator is ';' on Windows and ':' elsewhere.
        classpath_sep = ";" if self.system_name == "windows" else ":"
        self.start_up_str = classpath_sep.join(
            [start_jar_path, lwjgl_path, natives_path, lib_path, ai_path])
        self.need_set_memory_when_start = self.system_name == "windows"
        self.game_started = False
        self.round_num = 0

    @staticmethod
    def _best_effort(close_fn):
        """Call ``close_fn`` and ignore any ``Exception`` it raises.

        Used when tearing down leftovers of a game that may never have been
        started (e.g. ``self.gateway`` does not exist yet on the first reset).
        """
        try:
            close_fn()
        except Exception:
            pass

    def _start_java_game(self):
        """Launch the FightingICE Java process and wait 3 s for it to start."""
        print("Start java env in {} and port {}".format(
            self.java_env_path, self.port))
        if self.system_name == "windows":
            # -Xms1024m -Xmx1024m we need set this in windows
            jvm_args = ["-Xms1024m", "-Xmx1024m"]
        elif self.system_name == "macos":
            jvm_args = ["-XstartOnFirstThread"]
        else:
            jvm_args = []
        game_args = [
            "Main", "--port", str(self.port), "--py4j", "--fastmode",
            "--grey-bg", "--inverted-player", "1", "--mute", "--limithp",
            "400", "400"
        ]
        # subprocess.DEVNULL instead of open(os.devnull): the old code leaked
        # one open file object per (re)start.
        self.java_env = subprocess.Popen(
            ["java", *jvm_args, "-cp", self.start_up_str, *game_args],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL)
        # sleep 3s for java starting, if your machine is slow, make it longer
        time.sleep(3)

    def _start_gateway(self, p2=Machete):
        """Connect to the Java game, register the AIs and start the game thread.

        Args:
            p2: either the name of a Java AI class (``str``) or a Python AI
                class that is instantiated with the gateway and registered.
        """
        # auto select callback server port and reset it in java env
        self.gateway = JavaGateway(
            gateway_parameters=GatewayParameters(port=self.port),
            callback_server_parameters=CallbackServerParameters(port=0))
        python_port = self.gateway.get_callback_server().get_listening_port()
        self.gateway.java_gateway_server.resetCallbackClient(
            self.gateway.java_gateway_server.getCallbackClient().getAddress(),
            python_port)
        self.manager = self.gateway.entry_point

        # create pipe between gym_env_api and python_ai for java env
        server, client = Pipe()
        self.pipe = server
        self.p1 = GymAIDisplay(self.gateway, client, False)
        self.manager.registerAI(self.p1.__class__.__name__, self.p1)

        if isinstance(p2, str):
            # p2 is a java class name
            self.p2 = p2
            self.game_to_start = self.manager.createGame(
                "ZEN", "ZEN", self.p1.__class__.__name__, self.p2,
                self.freq_restart_java)
        else:
            # p2 is a python class
            self.p2 = p2(self.gateway)
            self.manager.registerAI(self.p2.__class__.__name__, self.p2)
            self.game_to_start = self.manager.createGame(
                "ZEN", "ZEN", self.p1.__class__.__name__,
                self.p2.__class__.__name__, self.freq_restart_java)
        self.game = Thread(target=game_thread,
                           name="game_thread",
                           args=(self, ))
        self.game.start()

        self.game_started = True
        self.round_num = 0

    def _close_gateway(self):
        """Shut down the py4j callback server and gateway."""
        self.gateway.close_callback_server()
        self.gateway.close()
        del self.gateway

    def _close_java_game(self):
        """Kill (and reap) the Java process and close the env side of the pipe."""
        self.java_env.kill()
        # Reap the killed child so it does not linger as a zombie.
        self.java_env.wait()
        del self.java_env
        self.pipe.close()
        del self.pipe
        self.game_started = False

    def reset(self, p2=Machete):
        """Start or restart the game if needed and return the first observation."""
        # start java game if game is not started
        if not self.game_started:
            # Tear down whatever a previous (possibly crashed) game left behind.
            # The two steps are independent so a failing gateway close cannot
            # prevent the old Java process from being killed.
            self._best_effort(self._close_gateway)
            self._best_effort(self._close_java_game)
            self._start_java_game()
            self._start_gateway(p2)

        # to provide crash, restart java game in some freq
        if self.round_num == self.freq_restart_java * 3:  # 3 is for round in one game
            try:
                self._close_gateway()
                self._close_java_game()
                self._start_java_game()
            except Exception as err:
                raise SystemExit("Can not restart game") from err
            self._start_gateway(p2)

        # just reset is anything ok
        self.pipe.send("reset")
        self.round_num += 1
        obs = self.pipe.recv()
        return obs

    def step(self, action):
        """Send ``action`` to player 1 and return ``(obs, reward, done, info)``.

        If the game is not running, the env is reset and
        ``(obs, 0, None, {"pre_game_crashed": True})`` is returned (``done`` is
        ``None`` in that case, a known quirk). If no reply arrives within 5 s,
        the current observation/reward is read from player 1 directly and the
        episode is reported as done with ``{"no_data_receive": True}``.
        """
        # check if game is running, if not try restart
        # when restart, dict will contain crash info, agent should do something, it is a BUG in this version
        if not self.game_started:
            info = {"pre_game_crashed": True}
            return self.reset(), 0, None, info

        self.pipe.send(["step", action])
        if self.pipe.poll(5):
            new_obs, reward, done, info = self.pipe.recv()
            return new_obs, reward, done, info

        new_obs, reward = self.p1.get_obs(), self.p1.get_reward()
        logging.warning("server can not receive, request to reset the game")
        return new_obs, reward, True, {"no_data_receive": True}

    def render(self, mode='human'):
        # no need
        pass

    def close(self):
        """Shut down the gateway (best effort) and the Java process."""
        if self.game_started:
            self._best_effort(self._close_gateway)
            self._close_java_game()
def run(self, parent_data_objs):
    """Materialize this component by running its Java implementation via py4j.

    Steps:

    * Launch a JVM whose heap limit (``-Xmx``) equals the machine's total
      physical memory. Its classpath holds the mlcomp jar plus every ``*.jar``
      in the component directory.
    * Create the component class ``self._dag_node.comp_class()`` through
      ``com.parallelm.mlcomp.ComponentEntryPoint``.
    * Configure it with the non-dict entries of ``self._params``.
    * Call ``materialize`` on ``parent_data_objs``.

    Args:
        parent_data_objs: outputs of the parent components. Each is appended
            as-is to a Java ``ArrayList``.

    Returns:
        list: the objects returned by the Java ``materialize`` call.

    Raises:
        MLCompException: on any error in the Java code or in the bridging
            logic. The original exception is chained as ``__cause__``.
    """
    # Run the java py4j entry point
    comp_dir = self._dag_node.comp_root_path()
    self._logger.info("comp_dir: {}".format(comp_dir))

    jar_files = glob.glob(os.path.join(comp_dir, "*.jar"))
    self._logger.info("Java classpath files: {}".format(jar_files))
    component_class = self._dag_node.comp_class()

    java_jars = [self._mlcomp_jar] + jar_files
    # os.pathsep is ':' on POSIX (same as before) and ';' on Windows.
    class_path = os.pathsep.join(java_jars)
    java_gateway = None
    all_ok = False
    monitor_proc = None

    try:
        total_phys_mem_size_mb = ByteConv.from_bytes(
            psutil.virtual_memory().total).mbytes
        jvm_heap_size_option = "-Xmx{}m".format(
            int(math.ceil(total_phys_mem_size_mb)))

        java_opts = [jvm_heap_size_option]
        self._logger.info("JVM options: {}".format(java_opts))

        # Note: the jarpath is set to be the path to the mlcomp jar since the launch_gateway code is checking
        # for the existence of the jar. The py4j jar is packed inside the mlcomp jar.
        java_port = launch_gateway(port=0,
                                   javaopts=java_opts,
                                   die_on_exit=True,
                                   jarpath=self._mlcomp_jar,
                                   classpath=class_path,
                                   redirect_stdout=sys.stdout,
                                   redirect_stderr=sys.stderr)

        java_gateway = JavaGateway(
            gateway_parameters=GatewayParameters(port=java_port),
            callback_server_parameters=CallbackServerParameters(port=0),
            python_server_entry_point=MLOpsPY4JWrapper())

        # Tell the Java side which port the Python callback server got.
        python_port = java_gateway.get_callback_server().get_listening_port()
        self._logger.debug("Python port: {}".format(python_port))

        java_gateway.java_gateway_server.resetCallbackClient(
            java_gateway.java_gateway_server.getCallbackClient().getAddress(),
            python_port)

        mlops_wrapper = MLOpsPY4JWrapper()

        entry_point = java_gateway.jvm.com.parallelm.mlcomp.ComponentEntryPoint(
            component_class)
        component_via_py4j = entry_point.getComponent()
        component_via_py4j.setMLOps(mlops_wrapper)

        # Configure
        java_params = java_gateway.jvm.java.util.HashMap()
        for key, value in self._params.items():
            # py4j does not handle nested structures. So the configs which is a dict will not be passed to the java
            # layer now.
            if isinstance(value, dict):
                continue
            java_params[key] = value

        component_via_py4j.configure(java_params)

        # Materialized
        java_parent_objs = java_gateway.jvm.java.util.ArrayList()
        for obj in parent_data_objs:
            java_parent_objs.append(obj)
            self._logger.info("Parent obj: {} type {}".format(obj, type(obj)))
        self._logger.info("Parent objs: {}".format(java_parent_objs))

        if mlops_loaded:
            monitor_proc = ProcessMonitor(mlops, self._ml_engine)
            monitor_proc.start()

        py4j_out_objs = component_via_py4j.materialize(java_parent_objs)

        self._logger.debug(type(py4j_out_objs))
        self._logger.debug(len(py4j_out_objs))

        python_out_objs = []
        for obj in py4j_out_objs:
            self._logger.debug("Obj:")
            self._logger.debug(obj)
            python_out_objs.append(obj)

        self._logger.info(
            "Done running of materialize and getting output objects")
        all_ok = True
    except Py4JJavaError as e:
        self._logger.error("Error in java code: {}".format(e))
        raise MLCompException(str(e)) from e
    except Exception as e:
        self._logger.error("General error: {}".format(e))
        raise MLCompException(str(e)) from e
    finally:
        self._logger.info("In finally block: all_ok {}".format(all_ok))
        try:
            if java_gateway:
                java_gateway.close_callback_server()
                java_gateway.shutdown()
        finally:
            # Stop the monitor even if shutting down the gateway failed.
            if mlops_loaded and monitor_proc:
                monitor_proc.stop_gracefully()

    return python_out_objs
# Load a stored replay through the py4j game manager, step through it while
# printing periodic status, then release the replay and the gateway. The
# cleanup runs even if loading or stepping fails.
try:
    print("Replay: Loading")
    # Load replay data
    replay = manager.loadReplay(
        "HPMode_KickAIPython_RandomAI_2017.12.07-15.51.44")
    try:
        print("Replay: Init")
        replay.init()

        # Main process: advance the replay 1000 times. On every 10th
        # iteration, while the replay is in its UPDATE state, print the current
        # round, frame number and both players' HP.
        for i in range(1000):
            print("Replay: Run frame", i)
            if i % 10 == 0 and replay.getState().name() == "UPDATE":
                framedata = replay.getFrameData()
                print("Replay: Infos")
                print("Replay: Round:", framedata.getRound())
                print("Replay: Frame:", framedata.getFramesNumber())
                print("Replay: P1 HP:", framedata.getCharacter(True).getHp())
                print("Replay: P2 HP:", framedata.getCharacter(False).getHp())
                sys.stdout.flush()
            replay.updateState()
    finally:
        print("Replay: Close")
        replay.close()
finally:
    sys.stdout.flush()
    gateway.close_callback_server()
    gateway.close()
class RAlphTrainable(tune.Trainable):
    """Ray Tune ``Trainable`` that delegates each trial to a Java backend over py4j.

    ``setup`` claims a free port pair, opens a py4j gateway on it and starts the
    trial on the Java side with a ``_TrialResultsReceiver``. Training and
    evaluation results arrive in the ``train_results`` / ``eval_results``
    queues, which ``step`` consumes. The receiver presumably pushes them there
    (TODO confirm against ``_TrialResultsReceiver``).
    """

    def setup(self, config):
        """Read the tuning config, claim a gateway port pair and start the trial."""
        tuning_config = self.config["tuning_config"]
        self.risk_threshold = tuning_config["risk_threshold"]
        self.max_concurrent_trials = tuning_config["max_concurrent_trials"]
        obj_fnc_params = tuning_config["obj_fnc_params"]
        self.payoff_empirical_min = obj_fnc_params["payoff_empirical_min"]
        self.payoff_empirical_max = obj_fnc_params["payoff_empirical_max"]
        self.min_eval_eps = tuning_config["eval_eps_min"]
        self.max_eval_eps = tuning_config["eval_eps_max"]
        self.lock = threading.Lock()
        self.train_results = queue.Queue()
        self.eval_results = queue.Queue()
        self.java_gateway = None
        self.my_id = self.trial_id
        self.start_time = time.time()
        self.am_stopped = False
        self.cur_max_score = self.payoff_empirical_min
        self.cur_max_train_score = self.payoff_empirical_min

        default_gateway = JavaGateway()
        # Candidate pairs are (java_port, python_port) =
        # (25334 + 2*i - 1, 25334 + 2*i). Keep scanning until one is claimed.
        while self.java_gateway is None:
            for i in range(self.max_concurrent_trials):
                self.my_port = 25334 + 2 * i
                java_port = self.my_port - 1
                try:
                    # if addJavaPort fails, it means it is being used by some other trial running
                    # in a different process
                    if not default_gateway.addJavaPort(java_port):
                        continue
                    gateway_params = GatewayParameters(
                        port=java_port, enable_memory_management=False)
                    callback_params = CallbackServerParameters(
                        port=self.my_port)
                    self.java_gateway = JavaGateway(
                        gateway_parameters=gateway_params,
                        callback_server_parameters=callback_params)
                    break
                except Exception:
                    default_gateway.removeJavaPort(java_port)
            if self.java_gateway is None:
                # Every pair was busy or failed. Back off before rescanning
                # instead of busy-spinning.
                time.sleep(0.5)
        default_gateway.close(keep_callback_server=True,
                              close_callback_server_connections=False)

        encoded_trial_params = TuningParser.encode_trial_params(self.config)
        self.results_receiver = _TrialResultsReceiver(self)
        self.java_gateway.entry_point.runTrial(self.trial_id,
                                               encoded_trial_params,
                                               self.results_receiver)
        self.additional_setup()

    def additional_setup(self):
        # To be overriden by subclasses.
        pass

    def compute_stage_value(self, stage_result):
        """Score a stage result from its average payoff and average risk.

        If ``risk_avg`` is within ``risk_threshold``, return
        ``payoff_avg - payoff_empirical_min``.

        Otherwise return a penalty plus a fraction:

        * The penalty is ``floor(threshold*100) - ceil(risk*100)``, which is at
          most -1.
        * The fraction is the payoff normalised by the empirical payoff range.
          It lies in [0, 1] when the payoff is inside that range.
        """
        stage_payoff_avg = stage_result["payoff_avg"]
        stage_risk_avg = stage_result["risk_avg"]
        if stage_risk_avg <= self.risk_threshold:
            return stage_payoff_avg - self.payoff_empirical_min
        base = math.floor(self.risk_threshold * 100) - math.ceil(
            stage_risk_avg * 100)
        dist_from_min_payoff = stage_payoff_avg - self.payoff_empirical_min
        payoff_range = self.payoff_empirical_max - self.payoff_empirical_min
        return base + dist_from_min_payoff / payoff_range

    def get_first_eval_batch_size(self):
        # Default implementation; run all available eval episodes
        return self.max_eval_eps

    def get_next_eval_batch_size(self, eval_result):
        # Default implementation; all eval eps have been done, report and announce
        # the end of evaluation by returning 0.
        # The queue lives on this trainable itself; the old code referenced a
        # non-existent ``self.ralph_trainable`` attribute.
        self.eval_results.put((eval_result, self.max_eval_eps))
        return 0

    def _finish_timed_out_trial(self):
        """Stop the Java trial and build the result reported on timeout."""
        self.java_gateway.entry_point.stopTrial(self.trial_id)
        self.am_stopped = True
        return {'done': True, 'timed_out': True, 'score': self.cur_max_score}

    def step(self):
        """Consume one training result plus its evaluation and report metrics.

        If ``tuning_config['configuration_timeout']`` is set and that many
        seconds have passed since ``start_time``, the Java trial is stopped.
        In that case a ``done``/``timed_out`` result is returned instead.
        """
        tuning_config = self.config['tuning_config']
        if 'configuration_timeout' in tuning_config:
            timeout_time = tuning_config['configuration_timeout']
            remaining_time = (self.start_time + timeout_time) - time.time()
            if remaining_time <= 0:
                return self._finish_timed_out_trial()
            try:
                train_result = self.train_results.get(timeout=remaining_time)
            except queue.Empty:
                return self._finish_timed_out_trial()
        else:
            train_result = self.train_results.get()

        ret = {'timed_out': False}
        ret["train_risk"] = train_result["risk_avg"]
        ret["train_payoff"] = train_result["payoff_avg"]
        ret["train_value"] = self.compute_stage_value(train_result)
        self.cur_max_train_score = max(self.cur_max_train_score,
                                       ret["train_value"])
        ret["train_score"] = self.cur_max_train_score

        eval_result = self.eval_results.get()  # tuple (stage_result, eval_eps_done)
        eval_stage = eval_result[0]
        ret["eval_value"] = self.compute_stage_value(eval_stage)
        self.cur_max_score = max(self.cur_max_score, ret["eval_value"])
        ret["score"] = self.cur_max_score
        ret["eval_eps_done"] = eval_result[1]
        ret["eval_payoff"] = eval_stage["payoff_avg"]
        ret["eval_risk"] = eval_stage["risk_avg"]
        ret["eval_feasible"] = 1 if ret["eval_risk"] <= self.risk_threshold else 0
        ret["eval_time"] = eval_stage["decision_avg_ms"]  # currently unused

        if self.iteration + 1 == self.config["stages"] or self.am_stopped:
            ret["done"] = True

        self.additional_step(ret)
        return ret

    def additional_step(self, ret):
        pass

    def reset_config(self, new_config):
        """Stop the current Java trial and start a new one on the same gateway."""
        # ``with`` guarantees the lock is released even if a Java call raises.
        with self.lock:
            self.java_gateway.entry_point.stopTrial(self.my_id)
            self.my_id = self.trial_id
            self.am_stopped = False
            self.start_time = time.time()
            self.train_results = queue.Queue()
            self.eval_results = queue.Queue()
            self.cur_max_score = self.payoff_empirical_min
            self.cur_max_train_score = self.payoff_empirical_min
            self.additional_reset_config()
            encoded_trial_params = TuningParser.encode_trial_params(
                self.config)
            self.java_gateway.entry_point.runTrial(self.trial_id,
                                                   encoded_trial_params,
                                                   self.results_receiver)
        return True

    def additional_reset_config(self):
        pass

    def cleanup(self):
        """Stop the trial, release the claimed port and close the gateway."""
        self.java_gateway.entry_point.stopTrial(self.trial_id)
        self.java_gateway.close_callback_server()
        self.java_gateway.entry_point.resetTrialCallback(self.my_port)
        self.java_gateway.removeJavaPort(self.my_port - 1)
        self.java_gateway.close(keep_callback_server=True,
                                close_callback_server_connections=False)
class FightingiceEnv_TwoPlayer(gym.Env):
    """Gym env for two-player FightingICE over py4j, one env instance per player.

    The player-1 env owns the Java process and the gateway. The player-2 env is
    created with ``p2_server`` (returned by the player-1 env's
    ``build_pipe_and_return_p2``) and only talks over that pipe. The
    observation space is declared as a ``(143,)`` box in ``[0, 1]``.

    Args:
        freq_restart_java: number of games passed to ``createGame``. The Java
            process is restarted after ``freq_restart_java * 3`` resets.
        env_config: unused.
        java_env_path: FightingICE directory. Defaults to the current working
            directory.
        port: gateway port. Chosen with ``port_for`` when falsy.
        auto_start_up: call ``start_up()`` instead of raising when
            FightingICE is missing.
        frameskip: forwarded to the player AI constructors.
        display: use ``GymAIDisplay`` for both players.
        p2_server: pipe end for a player-2 env. ``None`` for player 1.
    """
    metadata = {'render.modes': ['human']}

    def __init__(self,
                 freq_restart_java=3,
                 env_config=None,
                 java_env_path=None,
                 port=None,
                 auto_start_up=False,
                 frameskip=False,
                 display=False,
                 p2_server=None):
        _actions = "AIR AIR_A AIR_B AIR_D_DB_BA AIR_D_DB_BB AIR_D_DF_FA AIR_D_DF_FB AIR_DA AIR_DB AIR_F_D_DFA AIR_F_D_DFB AIR_FA AIR_FB AIR_GUARD AIR_GUARD_RECOV AIR_RECOV AIR_UA AIR_UB BACK_JUMP BACK_STEP CHANGE_DOWN CROUCH CROUCH_A CROUCH_B CROUCH_FA CROUCH_FB CROUCH_GUARD CROUCH_GUARD_RECOV CROUCH_RECOV DASH DOWN FOR_JUMP FORWARD_WALK JUMP LANDING NEUTRAL RISE STAND STAND_A STAND_B STAND_D_DB_BA STAND_D_DB_BB STAND_D_DF_FA STAND_D_DF_FB STAND_D_DF_FC STAND_F_D_DFA STAND_F_D_DFB STAND_FA STAND_FB STAND_GUARD STAND_GUARD_RECOV STAND_RECOV THROW_A THROW_B THROW_HIT THROW_SUFFER"
        action_strs = _actions.split(" ")

        self.observation_space = spaces.Box(low=0, high=1, shape=(143, ))
        self.action_space = spaces.Discrete(len(action_strs))

        os_name = platform.system()
        if os_name.startswith("Linux"):
            system_name = "linux"
        elif os_name.startswith("Darwin"):
            system_name = "macos"
        else:
            system_name = "windows"

        if system_name == "linux":
            # first check java can be run, can only be used on Linux
            java_version = subprocess.check_output(
                'java -version 2>&1 | awk -F[\\\"_] \'NR==1{print $2}\'',
                shell=True)
            if java_version == b"\n":
                raise ModuleNotFoundError("Java is not installed")
        else:
            print("Please make sure you can run java if you see some error")

        # second check if FightingIce is installed correct
        if java_env_path is None:
            self.java_env_path = os.getcwd()
        else:
            self.java_env_path = java_env_path
        start_jar_path = os.path.join(self.java_env_path, "FightingICE.jar")
        start_data_path = os.path.join(self.java_env_path, "data")
        start_lib_path = os.path.join(self.java_env_path, "lib")
        lwjgl_path = os.path.join(start_lib_path, "lwjgl", "*")
        lib_path = os.path.join(start_lib_path, "*")
        start_system_lib_path = os.path.join(self.java_env_path, "lib",
                                             "natives", system_name)
        natives_path = os.path.join(start_system_lib_path, "*")
        required_paths = (start_jar_path, start_data_path, start_lib_path,
                          start_system_lib_path)
        if not all(os.path.exists(path) for path in required_paths):
            if auto_start_up is False:
                error_message = "FightingICE is not installed in {}".format(
                    self.java_env_path)
                raise FileExistsError(error_message)
            else:
                start_up()

        if port:
            self.port = port
        else:
            try:
                import port_for
            except ImportError as err:
                raise ImportError(
                    "Pass port=[your_port] when make env, or install port_for to set startup port automatically, maybe pip install port_for can help"
                ) from err
            # select one random port for java env
            self.port = port_for.select_random()

        self.java_ai_path = os.path.join(self.java_env_path, "data", "ai")
        ai_path = os.path.join(self.java_ai_path, "*")
        # The JVM classpath separator is ';' on Windows and ':' elsewhere.
        classpath_sep = ";" if system_name == "windows" else ":"
        self.start_up_str = classpath_sep.join(
            [start_jar_path, lwjgl_path, natives_path, lib_path, ai_path])
        self.need_set_memory_when_start = system_name == "windows"

        self.game_started = False
        self.round_num = 0

        self.freq_restart_java = freq_restart_java
        self.frameskip = frameskip
        self.display = display
        self.p2_server = p2_server
        # Set by build_pipe_and_return_p2(). Initialised here so that
        # _start_gateway can raise its explanatory error instead of an
        # AttributeError when that call was forgotten.
        self.p1_server = None

    @staticmethod
    def _best_effort(close_fn):
        """Call ``close_fn`` and ignore any ``Exception`` it raises.

        Used when tearing down leftovers of a game that may never have been
        started (e.g. ``self.gateway`` does not exist yet on the first reset).
        """
        try:
            close_fn()
        except Exception:
            pass

    def _start_java_game(self):
        """Launch the FightingICE Java process and wait 3 s for it to start."""
        print("Start java env in {} and port {}".format(
            self.java_env_path, self.port))
        # -Xms1024m -Xmx1024m
        # we need set this in windows
        jvm_args = (["-Xms1024m", "-Xmx1024m"]
                    if self.need_set_memory_when_start else [])
        game_args = [
            "Main", "--port", str(self.port), "--py4j", "--fastmode",
            "--grey-bg", "--inverted-player", "1", "--mute", "--limithp",
            "400", "400"
        ]
        # subprocess.DEVNULL instead of open(os.devnull): the old code leaked
        # one open file object per (re)start.
        self.java_env = subprocess.Popen(
            ["java", *jvm_args, "-cp", self.start_up_str, *game_args],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL)
        # sleep 3s for java starting, if your machine is slow, make it longer
        time.sleep(3)

    def _start_gateway(self, p1=GymAI, p2=GymAI):
        """Connect to Java, register both players' AIs and start the game thread.

        Raises:
            Exception: if ``build_pipe_and_return_p2`` has not been called.
        """
        # auto select callback server port and reset it in java env
        self.gateway = JavaGateway(
            gateway_parameters=GatewayParameters(port=self.port),
            callback_server_parameters=CallbackServerParameters(port=0))
        python_port = self.gateway.get_callback_server().get_listening_port()
        self.gateway.java_gateway_server.resetCallbackClient(
            self.gateway.java_gateway_server.getCallbackClient().getAddress(),
            python_port)
        self.manager = self.gateway.entry_point

        # check if pipe built
        if self.p1_server is None:
            raise Exception(
                "Must call build_pipe_and_return_p2 and also make p2 env after gym.make() but before env.reset()"
            )
        self.pipe = self.p1_server

        if self.display:
            self.p1 = GymAIDisplay(self.gateway, self.p1_client,
                                   self.frameskip)
            self.p2 = GymAIDisplay(self.gateway, self.p2_client,
                                   self.frameskip)
        else:
            self.p1 = p1(self.gateway, self.p1_client, self.frameskip)
            self.p2 = p2(self.gateway, self.p2_client, self.frameskip)

        self.manager.registerAI("P1", self.p1)
        self.manager.registerAI("P2", self.p2)

        self.game_to_start = self.manager.createGame("ZEN", "ZEN", "P1", "P2",
                                                     self.freq_restart_java)

        self.game = Thread(target=game_thread,
                           name="game_thread",
                           args=(self, ))
        self.game.start()

        self.game_started = True
        self.round_num = 0

    # Must call this function after "gym.make()" but before "env.reset()"
    def build_pipe_and_return_p2(self):
        """Create both players' pipes and return the end for the player-2 env."""
        # create pipe between gym_env_api and python_ai for java env
        if self.p2_server is not None:
            raise Exception(
                "Can not build pipe again if env is used as p2 (p2_server set)"
            )
        self.p1_server, self.p1_client = Pipe()
        # p2_server should not be used in this env but another
        self._p2_server, self.p2_client = Pipe()
        return self._p2_server  # p2_server is returned to build a gym env for p2

    def _close_gateway(self):
        """Shut down the py4j callback server and gateway."""
        self.gateway.close_callback_server()
        self.gateway.close()
        del self.gateway

    def _close_java_game(self):
        """Kill (and reap) the Java process."""
        self.java_env.kill()
        # Reap the killed child so it does not linger as a zombie.
        self.java_env.wait()
        del self.java_env
        # The pipes are created once by build_pipe_and_return_p2 and reused by
        # every _start_gateway call, so they are intentionally left open here.
        self.game_started = False

    def reset(self, p1=GymAI, p2=GymAI):
        """Start or restart the game if needed and return the first observation.

        For the player-1 env, ``p1``/``p2`` are the AI classes used whenever a
        (re)started game registers its players. The player-2 env only waits
        (10 s at a game start) and then talks over its pipe.
        """
        if self.p2_server is None:
            # start java game if game is not started
            if not self.game_started:
                # Tear down whatever a previous (possibly crashed) game left
                # behind. The two steps are independent so a failing gateway
                # close cannot prevent the old Java process from being killed.
                self._best_effort(self._close_gateway)
                self._best_effort(self._close_java_game)
                self._start_java_game()
                self._start_gateway(p1, p2)

            # to provide crash, restart java game in some freq
            if self.round_num == self.freq_restart_java * 3:  # 3 is for round in one game
                try:
                    self._close_gateway()
                    self._close_java_game()
                    self._start_java_game()
                except Exception as err:
                    raise SystemExit("Can not restart game") from err
                # Re-use the caller's AI classes; the old code fell back to
                # the GymAI defaults on every periodic restart.
                self._start_gateway(p1, p2)
        else:
            self.pipe = self.p2_server
            if self.round_num == 0 or self.round_num == self.freq_restart_java * 3:
                time.sleep(10)  # p2 wait 10s
                self.round_num = 0
                self.game_started = True

        # just reset is anything ok
        self.pipe.send("reset")
        self.round_num += 1
        obs = self.pipe.recv()
        return obs

    def step(self, action):
        """Send ``action`` and return ``(obs, reward, done, {})``.

        If the game is not running, the env is reset and
        ``(obs, 0, None, {"pre_game_crashed": True})`` is returned.
        """
        # check if game is running, if not try restart
        # when restart, dict will contain crash info, agent should do something, it is a BUG in this version
        if not self.game_started:
            info = {"pre_game_crashed": True}
            return self.reset(), 0, None, info

        self.pipe.send(["step", action])
        # The info sent by the AI is deliberately not forwarded.
        new_obs, reward, done, _info = self.pipe.recv()
        return new_obs, reward, done, {}

    def render(self, mode='human'):
        # no need
        pass

    def close(self):
        """Shut down the gateway (best effort) and the Java process (player 1 only)."""
        if self.game_started and self.p2_server is None:
            self._best_effort(self._close_gateway)
            self._close_java_game()
class Environment():
    """Connection to a FightingICE server exposed through py4j on port 4242.

    ``play`` registers Python AI classes with the game manager and runs one
    game between them, or against a built-in Java AI.
    """
    actions = ["↖", "↑", "↗", "←", "→", "↙", "↓", "↘", "A", "B", "C", "_"]

    # Create the connection to the server
    def __init__(self):
        self._connect()
        logo = ''' Welcome to ███████╗██╗ ██████╗ ██╗ ██╗████████╗██╗███╗ ██╗ ██████╗ ██╗ ██████╗███████╗ ██╔════╝██║██╔════╝ ██║ ██║╚══██╔══╝██║████╗ ██║██╔════╝ ██║██╔════╝██╔════╝ █████╗ ██║██║ ███╗███████║ ██║ ██║██╔██╗ ██║██║ ███╗ ██║██║ █████╗ ██╔══╝ ██║██║ ██║██╔══██║ ██║ ██║██║╚██╗██║██║ ██║ ██║██║ ██╔══╝ ██║ ██║╚██████╔╝██║ ██║ ██║ ██║██║ ╚████║╚██████╔╝ ██║╚██████╗███████╗ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝╚═╝ ╚═══╝ ╚═════╝ ╚═╝ ╚═════╝╚══════╝ ...the competitive deep reinforcement learning environment since 2013. Powered by Prof. Ruck Thawonmas (Ritsumeikan University) '''
        print(logo)

    def _connect(self):
        """Open the py4j gateway on port 4242 and keep its entry point as ``manager``."""
        self.gateway = JavaGateway(
            gateway_parameters=GatewayParameters(port=4242),
            callback_server_parameters=CallbackServerParameters())
        self.manager = self.gateway.entry_point

    def play(self,
             cls,
             opponent_cls=None,
             player_character="ZEN",
             opponent_character="GARNET",
             opponent_agent="GigaThunder"):
        """Run one game of ``cls`` (P1) against an opponent (P2).

        Args:
            cls: Python AI class for P1. It is instantiated as ``cls(self)`` and
                registered under ``cls.__name__``.
            opponent_cls: optional Python AI class for P2. When given, it
                replaces ``opponent_agent``. If it has the same name as
                ``cls``, the already registered P1 instance is used for both
                sides.
            player_character: character for P1.
            opponent_character: character for P2.
            opponent_agent: name of the opponent AI when ``opponent_cls`` is
                ``None`` (e.g. a built-in Java AI).
        """
        self.reset()

        # Register custom AI with the engine. The name is needed later to select that agent.
        self.manager.registerAI(cls.__name__, cls(self))
        if opponent_cls is not None:
            if cls.__name__ != opponent_cls.__name__:
                self.manager.registerAI(opponent_cls.__name__,
                                        opponent_cls(self))
            opponent_agent = opponent_cls.__name__

        print("GAME START ▶")
        print(" (P1) %s/%s vs. (P2) %s/%s" %
              (cls.__name__, player_character, opponent_agent,
               opponent_character))

        GAME_NUM = 1
        # Create a game with the selected characters and AI
        # Possible default AIs are: MctsAi, GigaThunder
        # Characters: ZEN, GARNET
        game = self.manager.createGame(player_character, opponent_character,
                                       cls.__name__, opponent_agent, GAME_NUM)

        # Run the game (GAME_NUM games). Ctrl+C stops it; the connection is
        # closed in every case.
        try:
            self.manager.runGame(game)
        except KeyboardInterrupt:
            pass
        finally:
            print("STOP ▮")
            # Close connection
            self.finalize()

    def finalize(self):
        """Close the py4j gateway and kill the FightingICE Java process(es)."""
        self.gateway.close_callback_server()
        self.gateway.close()
        # Kill every process whose command line contains "java Main". An
        # argument list (no shell) replaces the old `ps aux | grep | awk |
        # xargs kill` pipeline, which also matched its own grep process.
        try:
            subprocess.call(["pkill", "-f", "java Main"])
        except OSError:
            # pkill unavailable (e.g. on Windows). This step stays best-effort,
            # as the shell pipeline was.
            pass

    def reset(self):
        """Close the current connection, wait 3 s and reconnect on port 4242."""
        print("Resetting game...")
        self.finalize()
        del self.gateway
        time.sleep(3)
        self._connect()