Пример #1
0
def collect_trajectories(actor, critic, port, game_num, p2, rnn, n_frame):
    """Play ``game_num`` ZEN-vs-ZEN games of a ``SoundAgent`` against ``p2``.

    A new py4j gateway is opened for every attempt. If anything in an attempt
    raises, the error is printed and the attempt is retried; there is no limit
    on the number of retries.

    Args:
        actor: Forwarded to ``SoundAgent`` and ``process_game_agent_data``.
        critic: Forwarded to ``SoundAgent`` and ``process_game_agent_data``.
        port: Port of the Java side's py4j gateway server.
        game_num: Number of games passed to ``manager.createGame``.
        p2: Name of the opponent AI passed to ``manager.createGame``
            (e.g. ``'Sandbox'``, which this function registers).
        rnn: Forwarded to ``SoundAgent`` and ``process_game_agent_data``.
        n_frame: Forwarded to ``SoundAgent``.

    Returns:
        The result of ``process_game_agent_data`` for the data collected by
        the agent of the successful attempt.
    """
    logger.info(f'start fight with {p2}')
    logger.info(f'game_num value {game_num}')
    while True:
        gateway = JavaGateway(
            gateway_parameters=GatewayParameters(port=port),
            callback_server_parameters=CallbackServerParameters())
        try:
            manager = gateway.entry_point
            # register AIs
            collect_data_helper = CollectDataHelper(logger)
            agent = SoundAgent(gateway,
                               actor=actor,
                               critic=critic,
                               collect_data_helper=collect_data_helper,
                               logger=logger,
                               n_frame=n_frame,
                               rnn=rnn)
            sandbox_agent = SandboxAgent(gateway)
            manager.registerAI('SoundAgent', agent)
            manager.registerAI('Sandbox', sandbox_agent)
            game = manager.createGame('ZEN', 'ZEN', 'SoundAgent', p2, game_num)

            # start game
            manager.runGame(game)

            # finish game
            logger.info('Finish game')
            sys.stdout.flush()
            break
        except Exception as ex:
            print(ex)
            logger.info('There is an error with the gateway, restarting')
        finally:
            # Release the gateway exactly once per attempt. Close errors are
            # only logged, so a failing close after a successful game does not
            # trigger a replay, and a failing close after a failed attempt
            # does not abort the retry loop.
            for close in (gateway.close_callback_server, gateway.close):
                try:
                    close()
                except Exception as close_ex:
                    logger.info(f'Error while closing gateway: {close_ex}')

    return process_game_agent_data(actor, critic,
                                   agent.collect_data_helper, rnn)
class FightingiceEnv_Display_NoFrameskip(gym.Env):
    """Gym env that drives a FightingICE Java game with a ``GymAIDisplay`` as P1.

    The Java game runs as a subprocess listening on ``self.port``. The env
    exchanges messages with the P1 AI through a ``Pipe()``. The observation
    space is a ``(96, 64, 1)`` box. The action space has one discrete action
    per FightingICE action name listed in ``__init__``.
    """
    metadata = {'render.modes': ['human']}

    def __init__(self, **kwargs):
        """Locate the FightingICE installation and build the Java classpath.

        Keyword Args:
            java_env_path: Directory containing ``FightingICE.jar``, ``data``
                and ``lib`` (defaults to the current working directory).
            freq_restart_java: Number of games per Java process before it is
                restarted by ``reset`` (defaults to 3).
            port: py4j port for the Java game. If omitted, a random port is
                picked with the optional ``port_for`` package.

        Raises:
            ImportError: ``port`` was not given and ``port_for`` is missing.
            ModuleNotFoundError: on Linux, ``java -version`` printed nothing.
            FileExistsError: required FightingICE files/directories are missing.
        """
        self.freq_restart_java = kwargs.get("freq_restart_java", 3)
        self.java_env_path = kwargs.get("java_env_path", os.getcwd())
        if "port" in kwargs:
            self.port = kwargs["port"]
        else:
            # Only a missing package is reported as ImportError; errors from
            # select_random itself are no longer disguised as one.
            try:
                import port_for
            except ImportError as err:
                raise ImportError(
                    "Pass port=[your_port] when make env, or install port_for to set startup port automatically, maybe pip install port_for can help"
                ) from err
            self.port = port_for.select_random(
            )  # select one random port for java env

        _actions = "AIR AIR_A AIR_B AIR_D_DB_BA AIR_D_DB_BB AIR_D_DF_FA AIR_D_DF_FB AIR_DA AIR_DB AIR_F_D_DFA AIR_F_D_DFB AIR_FA AIR_FB AIR_GUARD AIR_GUARD_RECOV AIR_RECOV AIR_UA AIR_UB BACK_JUMP BACK_STEP CHANGE_DOWN CROUCH CROUCH_A CROUCH_B CROUCH_FA CROUCH_FB CROUCH_GUARD CROUCH_GUARD_RECOV CROUCH_RECOV DASH DOWN FOR_JUMP FORWARD_WALK JUMP LANDING NEUTRAL RISE STAND STAND_A STAND_B STAND_D_DB_BA STAND_D_DB_BB STAND_D_DF_FA STAND_D_DF_FB STAND_D_DF_FC STAND_F_D_DFA STAND_F_D_DFB STAND_FA STAND_FB STAND_GUARD STAND_GUARD_RECOV STAND_RECOV THROW_A THROW_B THROW_HIT THROW_SUFFER"
        action_strs = _actions.split(" ")

        self.observation_space = spaces.Box(low=0, high=1, shape=(96, 64, 1))
        self.action_space = spaces.Discrete(len(action_strs))

        os_name = platform.system()
        if os_name.startswith("Linux"):
            self.system_name = "linux"
        elif os_name.startswith("Darwin"):
            self.system_name = "macos"
        else:
            self.system_name = "windows"

        if self.system_name == "linux":
            # first check java can be run, can only be used on Linux
            java_version = subprocess.check_output(
                'java -version 2>&1 | awk -F[\\\"_] \'NR==1{print $2}\'',
                shell=True)
            if java_version == b"\n":
                raise ModuleNotFoundError("Java is not installed")
        else:
            print("Please make sure you can run java if you see some error")

        # second check if FightingIce is installed correct
        start_jar_path = os.path.join(self.java_env_path, "FightingICE.jar")
        start_data_path = os.path.join(self.java_env_path, "data")
        start_lib_path = os.path.join(self.java_env_path, "lib")
        lwjgl_path = os.path.join(start_lib_path, "lwjgl", "*")
        lib_path = os.path.join(start_lib_path, "*")
        start_system_lib_path = os.path.join(self.java_env_path, "lib",
                                             "natives", self.system_name)
        natives_path = os.path.join(start_system_lib_path, "*")
        required_paths = (start_jar_path, start_data_path, start_lib_path,
                          start_system_lib_path)
        if not all(os.path.exists(path) for path in required_paths):
            error_message = "FightingICE is not installed in your script launched path {}, set path when make() or start script in FightingICE path".format(
                self.java_env_path)
            raise FileExistsError(error_message)

        self.java_ai_path = os.path.join(self.java_env_path, "data", "ai")
        ai_path = os.path.join(self.java_ai_path, "*")
        # The Java classpath separator is ';' on Windows and ':' elsewhere.
        self.need_set_memory_when_start = self.system_name == "windows"
        classpath_sep = ";" if self.need_set_memory_when_start else ":"
        self.start_up_str = classpath_sep.join(
            [start_jar_path, lwjgl_path, natives_path, lib_path, ai_path])

        self.game_started = False
        self.round_num = 0

    def _start_java_game(self):
        """Launch the FightingICE Java process on ``self.port`` and sleep 3 s."""
        print("Start java env in {} and port {}".format(
            self.java_env_path, self.port))

        # Only the JVM flags differ between platforms.
        if self.system_name == "windows":
            # -Xms1024m -Xmx1024m we need set this in windows
            jvm_opts = ["-Xms1024m", "-Xmx1024m"]
        elif self.system_name == "macos":
            jvm_opts = ["-XstartOnFirstThread"]
        else:
            jvm_opts = []

        # subprocess.DEVNULL replaces open(os.devnull, 'w'), whose file handle
        # was never closed and leaked a descriptor on every (re)start.
        self.java_env = subprocess.Popen(
            [
                "java", *jvm_opts, "-cp", self.start_up_str, "Main", "--port",
                str(self.port), "--py4j", "--fastmode", "--grey-bg",
                "--inverted-player", "1", "--mute", "--limithp", "400", "400"
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL)
        # sleep 3s for java starting, if your machine is slow, make it longer
        time.sleep(3)

    def _start_gateway(self, p2=Machete):
        """Connect to the Java game, register P1 and ``p2``, and start the game thread.

        Args:
            p2: Either the name of a Java AI (``str``) or a Python AI class.
                A Python AI class is instantiated with the gateway and
                registered under its class name.
        """
        # auto select callback server port and reset it in java env
        self.gateway = JavaGateway(
            gateway_parameters=GatewayParameters(port=self.port),
            callback_server_parameters=CallbackServerParameters(port=0))
        python_port = self.gateway.get_callback_server().get_listening_port()
        self.gateway.java_gateway_server.resetCallbackClient(
            self.gateway.java_gateway_server.getCallbackClient().getAddress(),
            python_port)
        self.manager = self.gateway.entry_point

        # create pipe between gym_env_api and python_ai for java env
        server, client = Pipe()
        self.pipe = server
        self.p1 = GymAIDisplay(self.gateway, client, False)
        self.manager.registerAI(self.p1.__class__.__name__, self.p1)

        if isinstance(p2, str):
            # p2 is a java class name
            self.p2 = p2
            self.game_to_start = self.manager.createGame(
                "ZEN", "ZEN", self.p1.__class__.__name__, self.p2,
                self.freq_restart_java)
        else:
            # p2 is a python class
            self.p2 = p2(self.gateway)
            self.manager.registerAI(self.p2.__class__.__name__, self.p2)
            self.game_to_start = self.manager.createGame(
                "ZEN", "ZEN", self.p1.__class__.__name__,
                self.p2.__class__.__name__, self.freq_restart_java)
        self.game = Thread(target=game_thread,
                           name="game_thread",
                           args=(self, ))
        self.game.start()

        self.game_started = True
        self.round_num = 0

    def _close_gateway(self):
        """Shut down the callback server and the gateway connection."""
        self.gateway.close_callback_server()
        self.gateway.close()
        del self.gateway

    def _close_java_game(self):
        """Kill the Java process, close the pipe and mark the game as stopped."""
        self.java_env.kill()
        del self.java_env
        self.pipe.close()
        del self.pipe
        self.game_started = False

    def reset(self, p2=Machete):
        """Start or restart the Java game when needed, then reset the round.

        The Java process is started when no game is running. It is also
        restarted after ``freq_restart_java * 3`` rounds to limit the damage
        of Java crashes.

        Returns:
            The observation received over the pipe after sending ``"reset"``.
        """
        # start java game if game is not started
        if self.game_started is False:
            # Best effort: on the first call there is no gateway/process yet.
            try:
                self._close_gateway()
                self._close_java_game()
            except Exception:
                pass
            self._start_java_game()
            self._start_gateway(p2)

        # to provide crash, restart java game in some freq
        if self.round_num == self.freq_restart_java * 3:  # 3 is for round in one game
            try:
                self._close_gateway()
                self._close_java_game()
                self._start_java_game()
            except Exception as err:
                raise SystemExit("Can not restart game") from err
            self._start_gateway(p2)

        # just reset is anything ok
        self.pipe.send("reset")
        self.round_num += 1
        obs = self.pipe.recv()
        return obs

    def step(self, action):
        """Send ``action`` to P1 and return ``(obs, reward, done, info)``.

        If no game is running, the env is reset. In that case ``done`` is
        ``None`` and ``info["pre_game_crashed"]`` is set. If no reply arrives
        within 5 seconds, the observation and reward reported by ``self.p1``
        are returned with ``done=True`` and ``info["no_data_receive"]`` set.
        """
        # check if game is running, if not try restart
        # when restart, dict will contain crash info, agent should do something, it is a BUG in this version
        if self.game_started is False:
            info = {"pre_game_crashed": True}
            return self.reset(), 0, None, info

        self.pipe.send(["step", action])
        if self.pipe.poll(5):
            new_obs, reward, done, info = self.pipe.recv()
            return new_obs, reward, done, info

        new_obs, reward = self.p1.get_obs(), self.p1.get_reward()
        info = {"no_data_receive": True}
        logging.warning(
            "server can not receive, request to reset the game")
        return new_obs, reward, True, info

    def render(self, mode='human'):
        """No-op: rendering is handled by the Java game itself."""
        pass

    def close(self):
        """Stop the Java game if it is running."""
        if self.game_started:
            self._close_java_game()

    def run(self, parent_data_objs):
        """Materialize a Java MLComp component over py4j and return its outputs.

        NOTE(review): this method reads ``self._dag_node``, ``self._mlcomp_jar``,
        ``self._params``, ``self._logger`` and ``self._ml_engine``, none of
        which this class sets. It appears to have been pasted from a
        different component class. Confirm before relying on it here.

        Args:
            parent_data_objs: Objects copied into a ``java.util.ArrayList`` and
                passed to the component's ``materialize``.

        Returns:
            list: The objects returned by ``materialize``.

        Raises:
            MLCompException: wraps any error from launching the JVM or running
                the component (the original exception is kept as the cause).
        """
        # Run the java py4j entry point
        comp_dir = self._dag_node.comp_root_path()
        self._logger.info("comp_dir: {}".format(comp_dir))

        jar_files = glob.glob(os.path.join(comp_dir, "*.jar"))
        self._logger.info("Java classpath files: {}".format(jar_files))
        component_class = self._dag_node.comp_class()

        java_jars = [self._mlcomp_jar] + jar_files
        class_path = ":".join(java_jars)
        java_gateway = None
        all_ok = False
        monitor_proc = None

        try:
            total_phys_mem_size_mb = ByteConv.from_bytes(
                psutil.virtual_memory().total).mbytes
            jvm_heap_size_option = "-Xmx{}m".format(
                int(math.ceil(total_phys_mem_size_mb)))

            java_opts = [jvm_heap_size_option]
            self._logger.info("JVM options: {}".format(java_opts))

            # Note: the jarpath is set to be the path to the mlcomp jar since the launch_gateway code is checking
            #       for the existence of the jar. The py4j jar is packed inside the mlcomp jar.
            java_port = launch_gateway(port=0,
                                       javaopts=java_opts,
                                       die_on_exit=True,
                                       jarpath=self._mlcomp_jar,
                                       classpath=class_path,
                                       redirect_stdout=sys.stdout,
                                       redirect_stderr=sys.stderr)

            java_gateway = JavaGateway(
                gateway_parameters=GatewayParameters(port=java_port),
                callback_server_parameters=CallbackServerParameters(port=0),
                python_server_entry_point=MLOpsPY4JWrapper())

            python_port = java_gateway.get_callback_server(
            ).get_listening_port()
            self._logger.debug("Python port: {}".format(python_port))

            java_gateway.java_gateway_server.resetCallbackClient(
                java_gateway.java_gateway_server.getCallbackClient().
                getAddress(), python_port)

            mlops_wrapper = MLOpsPY4JWrapper()
            entry_point = java_gateway.jvm.com.parallelm.mlcomp.ComponentEntryPoint(
                component_class)
            component_via_py4j = entry_point.getComponent()
            component_via_py4j.setMLOps(mlops_wrapper)

            # Configure
            params_map = java_gateway.jvm.java.util.HashMap()
            for key, value in self._params.items():
                # py4j does not handle nested structures. So the configs which is a dict will not be passed to the java
                # layer now.
                if isinstance(value, dict):
                    continue
                params_map[key] = value

            component_via_py4j.configure(params_map)

            # Materialized
            parent_list = java_gateway.jvm.java.util.ArrayList()
            for obj in parent_data_objs:
                parent_list.append(obj)
                self._logger.info("Parent obj: {} type {}".format(
                    obj, type(obj)))
            self._logger.info("Parent objs: {}".format(parent_list))

            if mlops_loaded:
                monitor_proc = ProcessMonitor(mlops, self._ml_engine)
                monitor_proc.start()

            py4j_out_objs = component_via_py4j.materialize(parent_list)

            self._logger.debug(type(py4j_out_objs))
            self._logger.debug(len(py4j_out_objs))

            python_out_objs = []
            for obj in py4j_out_objs:
                self._logger.debug("Obj:")
                self._logger.debug(obj)
                python_out_objs.append(obj)
            self._logger.info(
                "Done running of materialize and getting output objects")
            all_ok = True
        except Py4JJavaError as e:
            self._logger.error("Error in java code: {}".format(e))
            raise MLCompException(str(e)) from e
        except Exception as e:
            self._logger.error("General error: {}".format(e))
            raise MLCompException(str(e)) from e
        finally:
            self._logger.info("In finally block: all_ok {}".format(all_ok))
            # Nested try so a failing gateway shutdown cannot leave the
            # monitor process running.
            try:
                if java_gateway:
                    java_gateway.close_callback_server()
                    java_gateway.shutdown()
            finally:
                if mlops_loaded and monitor_proc:
                    monitor_proc.stop_gracefully()

        return python_out_objs
# Step through a recorded replay, printing a short status every 10 frames.
# NOTE(review): `manager` and `gateway` are expected to be an open py4j
# entry point and gateway created earlier in the script.
REPLAY_NAME = "HPMode_KickAIPython_RandomAI_2017.12.07-15.51.44"
REPLAY_FRAMES = 1000

try:
    print("Replay: Loading")
    replay = manager.loadReplay(REPLAY_NAME)  # Load replay data

    try:
        print("Replay: Init")
        replay.init()

        # Main process
        for i in range(REPLAY_FRAMES):  # Simulate REPLAY_FRAMES frames
            print("Replay: Run frame", i)

            # Report only every 10th frame, and only while in the UPDATE state.
            if i % 10 == 0 and replay.getState().name() == "UPDATE":
                framedata = replay.getFrameData()

                print("Replay: Infos")
                print("Replay:     Round:", framedata.getRound())
                print("Replay:     Frame:", framedata.getFramesNumber())
                print("Replay:     P1 HP:", framedata.getCharacter(True).getHp())
                print("Replay:     P2 HP:", framedata.getCharacter(False).getHp())

            sys.stdout.flush()

            replay.updateState()
    finally:
        # Close the replay even if a frame failed.
        print("Replay: Close")
        replay.close()

        sys.stdout.flush()
finally:
    # Always release the gateway, even if loading the replay failed.
    gateway.close_callback_server()
    gateway.close()
Пример #5
0
class RAlphTrainable(tune.Trainable):
    """Ray Tune trainable that runs one tuning trial in a Java process via py4j.

    Each trial claims a free (java port, callback port) pair and asks the Java
    side to run the trial. Per-stage results are expected to arrive on the
    ``train_results`` / ``eval_results`` queues, presumably via
    ``_TrialResultsReceiver``; ``step`` consumes them.
    """

    def setup(self, config):
        """Read tuning settings, claim a free port pair and start the Java trial.

        Port pairs ``(25333 + 2*i, 25334 + 2*i)`` are probed for
        ``i < max_concurrent_trials``. A Java port is reserved through the
        default gateway's ``addJavaPort``; the callback server listens on the
        other port. If no pair can be claimed, the scan repeats after 0.5 s.
        """
        tuning_config = self.config["tuning_config"]
        self.risk_threshold = tuning_config["risk_threshold"]
        self.max_concurrent_trials = tuning_config["max_concurrent_trials"]
        obj_fnc_params = tuning_config["obj_fnc_params"]
        self.payoff_empirical_min = obj_fnc_params["payoff_empirical_min"]
        self.payoff_empirical_max = obj_fnc_params["payoff_empirical_max"]
        self.min_eval_eps = tuning_config["eval_eps_min"]
        self.max_eval_eps = tuning_config["eval_eps_max"]

        self.lock = threading.Lock()

        self.train_results = queue.Queue()
        self.eval_results = queue.Queue()
        self.java_gateway = None
        self.my_id = self.trial_id
        self.start_time = time.time()
        self.am_stopped = False
        self.cur_max_score = self.payoff_empirical_min
        self.cur_max_train_score = self.payoff_empirical_min

        default_gateway = JavaGateway()
        while self.java_gateway is None:
            for i in range(self.max_concurrent_trials):
                self.my_port = 25334 + 2 * i
                java_port = self.my_port - 1
                try:
                    # if addJavaPort fails, it means it is being used by some other trial running
                    # in a different process
                    if not default_gateway.addJavaPort(java_port):
                        continue

                    gateway_params = GatewayParameters(
                        port=java_port, enable_memory_management=False)
                    callback_params = CallbackServerParameters(
                        port=self.my_port)
                    self.java_gateway = JavaGateway(
                        gateway_parameters=gateway_params,
                        callback_server_parameters=callback_params)
                    break
                except Exception:
                    # Undo the reservation (if one was made) so the slot can be reused.
                    default_gateway.removeJavaPort(java_port)
            else:
                # No slot was claimed in this pass (taken or failed): back off
                # instead of busy-looping before the next scan.
                time.sleep(0.5)

        default_gateway.close(keep_callback_server=True,
                              close_callback_server_connections=False)

        encoded_trial_params = TuningParser.encode_trial_params(self.config)
        self.results_receiver = _TrialResultsReceiver(self)
        self.java_gateway.entry_point.runTrial(self.trial_id,
                                               encoded_trial_params,
                                               self.results_receiver)

        self.additional_setup()

    def additional_setup(self):
        """Hook for subclasses, called at the end of ``setup``."""
        # To be overriden by subclasses.
        pass

    def compute_stage_value(self, stage_result):
        """Score a stage result from its ``payoff_avg`` and ``risk_avg``.

        If ``risk_avg <= risk_threshold`` the value is
        ``payoff_avg - payoff_empirical_min``. Otherwise it is the integer gap
        ``floor(threshold*100) - ceil(risk*100)`` (at most -1) plus the payoff
        normalised by ``payoff_empirical_max - payoff_empirical_min``.

        NOTE(review): raises ZeroDivisionError when the empirical min and max
        are equal and the stage is infeasible.
        """
        stage_payoff_avg = stage_result["payoff_avg"]
        stage_risk_avg = stage_result["risk_avg"]

        if stage_risk_avg <= self.risk_threshold:
            return stage_payoff_avg - self.payoff_empirical_min

        base = math.floor(self.risk_threshold * 100) - math.ceil(
            stage_risk_avg * 100)
        dist_from_min_payoff = stage_payoff_avg - self.payoff_empirical_min
        payoff_range = self.payoff_empirical_max - self.payoff_empirical_min

        return base + dist_from_min_payoff / payoff_range

    def get_first_eval_batch_size(self):
        """Return the size of the first eval batch (all eval episodes by default)."""
        # Default implementation; run all available eval episodes
        return self.max_eval_eps

    def get_next_eval_batch_size(self, eval_result):
        """Report ``eval_result`` and return 0 to signal that evaluation is done."""
        # Default implementation; all eval eps have been done, report and announce
        # the end of evaluation by returning 0.
        # Fixed: this previously used self.ralph_trainable, an attribute this
        # class never defines (AttributeError on every call).
        self.eval_results.put((eval_result, self.max_eval_eps))
        return 0

    def _stop_on_timeout(self):
        """Stop the Java trial and build the final result reported on timeout."""
        self.java_gateway.entry_point.stopTrial(self.trial_id)
        self.am_stopped = True
        return {
            'done': True,
            'timed_out': True,
            'score': self.cur_max_score,
        }

    def step(self):
        """Wait for the next train and eval results and report metrics to Tune.

        If ``configuration_timeout`` is set in the tuning config, the trial is
        stopped once that many seconds have passed since it started. In that
        case only ``done``, ``timed_out`` and ``score`` are reported.
        """
        ret = {}

        tuning_config = self.config['tuning_config']
        if 'configuration_timeout' in tuning_config:
            timeout_time = tuning_config['configuration_timeout']
            remaining_time = (self.start_time + timeout_time) - time.time()
            if remaining_time <= 0:
                return self._stop_on_timeout()
            try:
                train_result = self.train_results.get(timeout=remaining_time)
            except queue.Empty:
                return self._stop_on_timeout()
        else:
            train_result = self.train_results.get()

        ret['timed_out'] = False
        ret["train_risk"] = train_result["risk_avg"]
        ret["train_payoff"] = train_result["payoff_avg"]
        ret["train_value"] = self.compute_stage_value(train_result)
        self.cur_max_train_score = max(self.cur_max_train_score,
                                       ret["train_value"])
        ret["train_score"] = self.cur_max_train_score

        # tuple (stage_result, eval_eps_done)
        eval_stage, eval_eps_done = self.eval_results.get()

        ret["eval_value"] = self.compute_stage_value(eval_stage)
        self.cur_max_score = max(self.cur_max_score, ret["eval_value"])
        ret["score"] = self.cur_max_score

        ret["eval_eps_done"] = eval_eps_done

        ret["eval_payoff"] = eval_stage["payoff_avg"]
        ret["eval_risk"] = eval_stage["risk_avg"]
        ret["eval_feasible"] = 1 if ret[
            "eval_risk"] <= self.risk_threshold else 0
        ret["eval_time"] = eval_stage["decision_avg_ms"]  # currently unused

        if self.iteration + 1 == self.config["stages"] or self.am_stopped:
            ret["done"] = True

        self.additional_step(ret)

        return ret

    def additional_step(self, ret):
        """Hook for subclasses to add to or inspect the result dict of ``step``."""
        pass

    def reset_config(self, new_config):
        """Stop the current Java trial and start a fresh one when Tune reuses this actor.

        NOTE(review): this reads ``self.config`` / ``self.trial_id`` rather than
        ``new_config``. It relies on Ray Tune having updated them before
        calling ``reset_config``; confirm for the Ray version in use.
        """
        # ``with`` releases the lock even if one of the py4j calls raises.
        with self.lock:
            self.java_gateway.entry_point.stopTrial(self.my_id)

            self.my_id = self.trial_id
            self.am_stopped = False
            self.start_time = time.time()
            self.train_results = queue.Queue()
            self.eval_results = queue.Queue()
            self.cur_max_score = self.payoff_empirical_min
            self.cur_max_train_score = self.payoff_empirical_min

            self.additional_reset_config()

            encoded_trial_params = TuningParser.encode_trial_params(
                self.config)
            self.java_gateway.entry_point.runTrial(self.trial_id,
                                                   encoded_trial_params,
                                                   self.results_receiver)

        return True

    def additional_reset_config(self):
        """Hook for subclasses, called inside ``reset_config`` before the new trial starts."""
        pass

    def cleanup(self):
        """Stop the trial, release its ports and close the py4j connection."""
        self.java_gateway.entry_point.stopTrial(self.trial_id)
        self.java_gateway.close_callback_server()
        self.java_gateway.entry_point.resetTrialCallback(self.my_port)
        self.java_gateway.removeJavaPort(self.my_port - 1)
        self.java_gateway.close(keep_callback_server=True,
                                close_callback_server_connections=False)
Пример #6
0
class FightingiceEnv_TwoPlayer(gym.Env):
    """Two-player FightingICE gym env in which both players are Python AIs.

    An instance created without ``p2_server`` owns the Java game and acts as
    P1. It creates the pipes in ``build_pipe_and_return_p2``. A second instance
    created with ``p2_server=<returned pipe end>`` acts as P2 and only talks
    over that pipe.
    """
    metadata = {'render.modes': ['human']}

    def __init__(self,
                 freq_restart_java=3,
                 env_config=None,
                 java_env_path=None,
                 port=None,
                 auto_start_up=False,
                 frameskip=False,
                 display=False,
                 p2_server=None):
        """Locate FightingICE, build the Java classpath and store settings.

        Args:
            freq_restart_java: Games per Java process before ``reset`` restarts it.
            env_config: Accepted for compatibility; not used.
            java_env_path: FightingICE install directory (defaults to the cwd).
            port: py4j port; when falsy a random port is picked via ``port_for``.
            auto_start_up: If the install is missing, call ``start_up()``
                instead of raising (any value other than ``False``).
            frameskip: Forwarded to the AI constructors.
            display: Use ``GymAIDisplay`` for both players.
            p2_server: Pipe end returned by another env's
                ``build_pipe_and_return_p2``; makes this env the P2 side.

        Raises:
            ModuleNotFoundError: on Linux, ``java -version`` printed nothing.
            FileExistsError: FightingICE files are missing and ``auto_start_up``
                is ``False``.
            ImportError: no ``port`` given and ``port_for`` is not installed.
        """
        _actions = "AIR AIR_A AIR_B AIR_D_DB_BA AIR_D_DB_BB AIR_D_DF_FA AIR_D_DF_FB AIR_DA AIR_DB AIR_F_D_DFA AIR_F_D_DFB AIR_FA AIR_FB AIR_GUARD AIR_GUARD_RECOV AIR_RECOV AIR_UA AIR_UB BACK_JUMP BACK_STEP CHANGE_DOWN CROUCH CROUCH_A CROUCH_B CROUCH_FA CROUCH_FB CROUCH_GUARD CROUCH_GUARD_RECOV CROUCH_RECOV DASH DOWN FOR_JUMP FORWARD_WALK JUMP LANDING NEUTRAL RISE STAND STAND_A STAND_B STAND_D_DB_BA STAND_D_DB_BB STAND_D_DF_FA STAND_D_DF_FB STAND_D_DF_FC STAND_F_D_DFA STAND_F_D_DFB STAND_FA STAND_FB STAND_GUARD STAND_GUARD_RECOV STAND_RECOV THROW_A THROW_B THROW_HIT THROW_SUFFER"
        action_strs = _actions.split(" ")

        self.observation_space = spaces.Box(low=0, high=1, shape=(143, ))
        self.action_space = spaces.Discrete(len(action_strs))

        os_name = platform.system()
        if os_name.startswith("Linux"):
            system_name = "linux"
        elif os_name.startswith("Darwin"):
            system_name = "macos"
        else:
            system_name = "windows"

        if system_name == "linux":
            # first check java can be run, can only be used on Linux
            java_version = subprocess.check_output(
                'java -version 2>&1 | awk -F[\\\"_] \'NR==1{print $2}\'',
                shell=True)
            if java_version == b"\n":
                raise ModuleNotFoundError("Java is not installed")
        else:
            print("Please make sure you can run java if you see some error")

        # second check if FightingIce is installed correct
        self.java_env_path = os.getcwd(
        ) if java_env_path is None else java_env_path
        start_jar_path = os.path.join(self.java_env_path, "FightingICE.jar")
        start_data_path = os.path.join(self.java_env_path, "data")
        start_lib_path = os.path.join(self.java_env_path, "lib")
        lwjgl_path = os.path.join(start_lib_path, "lwjgl", "*")
        lib_path = os.path.join(start_lib_path, "*")
        start_system_lib_path = os.path.join(self.java_env_path, "lib",
                                             "natives", system_name)
        natives_path = os.path.join(start_system_lib_path, "*")
        required_paths = (start_jar_path, start_data_path, start_lib_path,
                          start_system_lib_path)
        if not all(os.path.exists(path) for path in required_paths):
            if auto_start_up is False:
                error_message = "FightingICE is not installed in {}".format(
                    self.java_env_path)
                raise FileExistsError(error_message)
            else:
                start_up()
        if port:
            self.port = port
        else:
            try:
                import port_for
            except ImportError as err:
                raise ImportError(
                    "Pass port=[your_port] when make env, or install port_for to set startup port automatically, maybe pip install port_for can help"
                ) from err
            self.port = port_for.select_random(
            )  # select one random port for java env

        self.java_ai_path = os.path.join(self.java_env_path, "data", "ai")
        ai_path = os.path.join(self.java_ai_path, "*")
        # The Java classpath separator is ';' on Windows and ':' elsewhere.
        self.need_set_memory_when_start = system_name == "windows"
        classpath_sep = ";" if self.need_set_memory_when_start else ":"
        self.start_up_str = classpath_sep.join(
            [start_jar_path, lwjgl_path, natives_path, lib_path, ai_path])

        self.game_started = False
        self.round_num = 0

        self.freq_restart_java = freq_restart_java

        self.frameskip = frameskip
        self.display = display
        self.p2_server = p2_server

        # Pipe ends are created by build_pipe_and_return_p2(). Initialise them
        # so _start_gateway's "pipe not built" check raises its intended
        # message instead of an AttributeError.
        self.p1_server = None
        self.p1_client = None
        self.p2_client = None
        self._p2_server = None

    def _start_java_game(self):
        """Launch the FightingICE Java process on ``self.port`` and sleep 3 s."""
        print("Start java env in {} and port {}".format(
            self.java_env_path, self.port))
        # -Xms1024m -Xmx1024m
        # we need set this in windows
        jvm_opts = (["-Xms1024m", "-Xmx1024m"]
                    if self.need_set_memory_when_start else [])
        # subprocess.DEVNULL replaces open(os.devnull, 'w'), whose file handle
        # was never closed and leaked a descriptor on every (re)start.
        self.java_env = subprocess.Popen(
            [
                "java", *jvm_opts, "-cp", self.start_up_str, "Main", "--port",
                str(self.port), "--py4j", "--fastmode", "--grey-bg",
                "--inverted-player", "1", "--mute", "--limithp", "400", "400"
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL)
        # sleep 3s for java starting, if your machine is slow, make it longer
        time.sleep(3)

    def _start_gateway(self, p1=GymAI, p2=GymAI):
        """Connect to the Java game, register both AIs as "P1"/"P2" and start the game thread.

        Args:
            p1: AI class for player 1, built as ``p1(gateway, pipe, frameskip)``.
            p2: AI class for player 2, built the same way.
            Both are ignored when ``display`` is set (``GymAIDisplay`` is used).

        Raises:
            Exception: ``build_pipe_and_return_p2`` has not been called yet.
        """
        # auto select callback server port and reset it in java env
        self.gateway = JavaGateway(
            gateway_parameters=GatewayParameters(port=self.port),
            callback_server_parameters=CallbackServerParameters(port=0))
        python_port = self.gateway.get_callback_server().get_listening_port()
        self.gateway.java_gateway_server.resetCallbackClient(
            self.gateway.java_gateway_server.getCallbackClient().getAddress(),
            python_port)
        self.manager = self.gateway.entry_point

        # check if pipe built
        if self.p1_server is None:
            raise Exception(
                "Must call build_pipe_and_return_p2 and also make p2 env after gym.make() but before env.reset()"
            )
        self.pipe = self.p1_server

        if self.display:
            self.p1 = GymAIDisplay(self.gateway, self.p1_client,
                                   self.frameskip)
            self.p2 = GymAIDisplay(self.gateway, self.p2_client,
                                   self.frameskip)
        else:
            self.p1 = p1(self.gateway, self.p1_client, self.frameskip)
            self.p2 = p2(self.gateway, self.p2_client, self.frameskip)

        self.manager.registerAI("P1", self.p1)
        self.manager.registerAI("P2", self.p2)

        self.game_to_start = self.manager.createGame("ZEN", "ZEN", "P1", "P2",
                                                     self.freq_restart_java)

        self.game = Thread(target=game_thread,
                           name="game_thread",
                           args=(self, ))
        self.game.start()
        self.game_started = True
        self.round_num = 0

    # Must call this function after "gym.make()" but before "env.reset()"
    def build_pipe_and_return_p2(self):
        """Create the P1 and P2 pipes and return the pipe end for the P2 env.

        Raises:
            Exception: this env was itself created as the P2 side.
        """
        # create pipe between gym_env_api and python_ai for java env
        if self.p2_server is not None:
            raise Exception(
                "Can not build pipe again if env is used as p2 (p2_server set)"
            )
        self.p1_server, self.p1_client = Pipe()
        self._p2_server, self.p2_client = Pipe(
        )  # p2_server should not be used in this env but another
        return self._p2_server  # p2_server is returned to build a gym env for p2

    def _close_gateway(self):
        """Shut down the callback server and the gateway connection."""
        self.gateway.close_callback_server()
        self.gateway.close()
        del self.gateway

    def _close_java_game(self):
        """Kill the Java process and mark the game as stopped (pipes stay open)."""
        self.java_env.kill()
        del self.java_env
        #self.pipe.close()
        #del self.pipe
        self.game_started = False

    def reset(self, p1=GymAI, p2=GymAI):
        """Reset the round, (re)starting the Java game first when this env owns it.

        For the P1 env, the Java game is started when not running and is
        restarted after ``freq_restart_java * 3`` rounds. The P2 env only
        waits 10 s at those points for the P1 env to bring the game up.

        Returns:
            The observation received over the pipe after sending ``"reset"``.
        """
        if self.p2_server is None:
            # start java game if game is not started
            if self.game_started is False:
                # Best effort: on the first call there is no gateway/process yet.
                try:
                    self._close_gateway()
                    self._close_java_game()
                except Exception:
                    pass
                self._start_java_game()
                self._start_gateway(p1, p2)

            # to provide crash, restart java game in some freq
            if self.round_num == self.freq_restart_java * 3:  # 3 is for round in one game
                try:
                    self._close_gateway()
                    self._close_java_game()
                    self._start_java_game()
                except Exception as err:
                    raise SystemExit("Can not restart game") from err
                # Pass the caller's AI classes; previously the restart silently
                # fell back to the GymAI defaults.
                self._start_gateway(p1, p2)
        else:
            self.pipe = self.p2_server
            if self.round_num == 0 or self.round_num == self.freq_restart_java * 3:
                time.sleep(10)  # p2 wait 10s
                self.round_num = 0
                self.game_started = True

        # just reset is anything ok
        self.pipe.send("reset")
        self.round_num += 1
        obs = self.pipe.recv()
        return obs

    def step(self, action):
        """Send ``action`` over the pipe and return ``(obs, reward, done, {})``.

        If no game is running, the env is reset. In that case ``done`` is
        ``None`` and the info dict contains ``"pre_game_crashed"``.
        """
        # check if game is running, if not try restart
        # when restart, dict will contain crash info, agent should do something, it is a BUG in this version
        if self.game_started is False:
            info = {"pre_game_crashed": True}
            return self.reset(), 0, None, info

        self.pipe.send(["step", action])
        new_obs, reward, done, _info = self.pipe.recv()
        # NOTE(review): the info dict received from the game is discarded;
        # confirm this is intentional.
        return new_obs, reward, done, {}

    def render(self, mode='human'):
        """No-op: rendering is handled by the Java game itself."""
        pass

    def close(self):
        """Stop the Java game if this (P1) env started it."""
        if self.game_started and self.p2_server is None:
            self._close_java_game()
Пример #7
0
class Environment():
    """Interactive front end to a FightingICE game server reached over py4j.

    Construction connects a ``JavaGateway`` (with a callback server, so the
    Java side can call back into Python AIs) and prints a welcome banner.
    ``play`` then registers Python AI classes with the server's entry point
    and runs a game.
    """

    # Direction/button symbols. Not referenced inside this class; presumably
    # used by agents or callers — confirm before changing.
    actions = ["↖", "↑", "↗", "←", "→", "↙", "↓", "↘", "A", "B", "C", "_"]

    # Port of the server-side py4j gateway used when none is given.
    DEFAULT_PORT = 4242

    # Kills every process whose `ps aux` line contains "java Main". The
    # "[j]ava Main" pattern still matches that text but not grep's own command
    # line, so kill is no longer handed grep's already-exited (possibly reused)
    # PID.
    _KILL_JAVA_CMD = 'ps aux | grep "[j]ava Main" | awk \'{print $2}\' | xargs kill'

    # Create the connection to the server
    def __init__(self, port=DEFAULT_PORT):
        """Connect to the game server and print the welcome banner.

        Args:
            port: TCP port of the server's py4j gateway (default 4242).
        """
        self.port = port
        self._open_gateway()

        logo = '''
Welcome to

  ███████╗██╗ ██████╗ ██╗  ██╗████████╗██╗███╗   ██╗ ██████╗     ██╗ ██████╗███████╗
  ██╔════╝██║██╔════╝ ██║  ██║╚══██╔══╝██║████╗  ██║██╔════╝     ██║██╔════╝██╔════╝
  █████╗  ██║██║  ███╗███████║   ██║   ██║██╔██╗ ██║██║  ███╗    ██║██║     █████╗  
  ██╔══╝  ██║██║   ██║██╔══██║   ██║   ██║██║╚██╗██║██║   ██║    ██║██║     ██╔══╝  
  ██║     ██║╚██████╔╝██║  ██║   ██║   ██║██║ ╚████║╚██████╔╝    ██║╚██████╗███████╗
  ╚═╝     ╚═╝ ╚═════╝ ╚═╝  ╚═╝   ╚═╝   ╚═╝╚═╝  ╚═══╝ ╚═════╝     ╚═╝ ╚═════╝╚══════╝                                                                                
...the competitive deep reinforcement learning environment since 2013.
Powered by Prof. Ruck Thawonmas (Ritsumeikan University) 
        '''
        print(logo)

    def play(self,
             cls,
             opponent_cls=None,
             player_character="ZEN",
             opponent_character="GARNET",
             opponent_agent="GigaThunder"):
        """Reconnect, register the AI(s) and run one game.

        Args:
            cls: agent class for P1; instantiated as ``cls(self)`` and
                registered under ``cls.__name__``.
            opponent_cls: optional agent class for P2, instantiated the same
                way. When given, its name replaces ``opponent_agent``. It is not
                registered a second time if it shares ``cls``'s name.
            player_character: character for P1.
            opponent_character: character for P2.
            opponent_agent: name of an AI the server already knows, used when
                ``opponent_cls`` is None.

        The connection is always torn down afterwards via ``finalize``, also
        when the game is stopped with Ctrl+C.
        """
        self.reset()

        # Register custom AI with the engine. The name is needed later to select that agent.
        self.manager.registerAI(cls.__name__, cls(self))
        if opponent_cls is not None:
            if opponent_cls.__name__ != cls.__name__:
                self.manager.registerAI(opponent_cls.__name__, opponent_cls(self))
            opponent_agent = opponent_cls.__name__

        print("GAME START ▶")
        print(f"   (P1) {cls.__name__}/{player_character} vs. "
              f"(P2) {opponent_agent}/{opponent_character}")

        # Last argument of createGame; presumably the number of games to run.
        game_num = 1

        # Create a game with the selected characters and AI
        # Possible default AIs are: MctsAi, GigaThunder
        # Characters: ZEN, GARNET
        game = self.manager.createGame(player_character, opponent_character,
                                       cls.__name__, opponent_agent, game_num)

        try:
            self.manager.runGame(game)
        except KeyboardInterrupt:
            # Ctrl+C ends the game early; cleanup still happens below.
            pass
        finally:
            print("STOP ▮")
            # Close connection
            self.finalize()

    def finalize(self):
        """Close the py4j connection and kill local "java Main" processes.

        Every step runs even if an earlier one raises; the exception is
        re-raised once the remaining steps have run.
        """
        try:
            self.gateway.close_callback_server()
        finally:
            try:
                self.gateway.close()
            finally:
                subprocess.call(self._KILL_JAVA_CMD, shell=True)

    def _open_gateway(self):
        """Connect a JavaGateway (with callback server) on ``self.port`` and
        keep its entry point as ``self.manager``."""
        self.gateway = JavaGateway(
            gateway_parameters=GatewayParameters(port=self.port),
            callback_server_parameters=CallbackServerParameters())
        self.manager = self.gateway.entry_point

    def reset(self):
        """Tear down the current connection and open a fresh one.

        Calls ``finalize`` (which also kills local "java Main" processes),
        waits 3 seconds, then reconnects on ``self.port``.
        """
        print("Resetting game...")
        self.finalize()
        del self.gateway

        # Pause before reconnecting, presumably so the server side can come
        # back up — confirm the required delay.
        time.sleep(3)

        self._open_gateway()