예제 #1
0
def worker_process(remote: multiprocessing.connection.Connection, parameters,
                   worker_id, env):
    """
    Serve environment commands received over ``remote`` until told to close.

    This function is used as the target of each worker process. It executes
    the commands sent by the parent on the already-constructed ``env``.

    Commands are ``(cmd, data)`` tuples:
        ``'step'``         -- step ``env`` with action ``data``; when the
                              episode ends the environment is reset and the
                              reset observation replaces the final one.
                              Replies with ``(obs, reward, done, info)``.
        ``'reset'``        -- replies with ``env.reset()``.
        ``'action_space'`` -- replies with ``env.action_space``.
        ``'close'``        -- closes ``remote`` and returns.

    :param remote: child end of the pipe shared with the parent.
    :param parameters: not used by this function (kept for interface
        compatibility).
    :param worker_id: not used by this function (kept for interface
        compatibility).
    :param env: environment exposing ``step``, ``reset`` and ``action_space``.
    :raises NotImplementedError: on an unknown command.
    """
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            obs, reward, done, info = env.step(data)
            # Test truthiness rather than ``done is True``: environments often
            # return ``numpy.bool_`` and ``numpy.bool_(True) is True`` is
            # False, which would silently skip the automatic reset.
            if done:
                obs = env.reset()
            remote.send((obs, reward, done, info))
        elif cmd == 'reset':
            remote.send(env.reset())
        elif cmd == 'action_space':
            remote.send(env.action_space)
        elif cmd == 'close':
            remote.close()
            break
        else:
            raise NotImplementedError
예제 #2
0
def timeout_child(conn: multiprocessing.connection.Connection,
                  func: types.FunctionType, *args, **kwargs):
    """
    Child-process wrapper used by timeout_func.

    Calls ``func(*args, **kwargs)`` and sends its return value back through
    ``conn``. If ``func`` is falsy, or if calling it (or sending its result)
    raises an ``Exception``, a ``RuntimeError`` carrying a formatted stack
    trace is sent instead. ``conn`` is always closed before returning.

    :param conn: child end of the pipe shared with the parent.
    :param func: callable to run in the child; see the note below about
        functions defined under ``if __name__ == '__main__'``.
    :param args: positional arguments forwarded to ``func``.
    :param kwargs: keyword arguments forwarded to ``func``; a truthy
        ``verbose`` entry also logs ``func`` here (it is still forwarded).
    :return: None; the outcome is delivered through ``conn``.
    """
    try:
        verbose = kwargs.get('verbose', 0)
        if verbose:
            tplog(f"func={func}")
        if not func:
            # https://stackoverflow.com/questions/43369648/cant-get-attribute-function-inner-on-module-mp-main-from-e-python
            # when the multiprocessing library copies your main module, it won't run it as the __main__ script and
            # therefore anything defined inside the if __name__ == '__main__' is not defined in the child process
            # namespace. Hence, the AttributeError
            message = f"both func={func} is not initialized. note: func cannot be defined in __main__"
            tb = pformat(
                traceback.format_stack())  # outside exception use format_stack()
            conn.send(RuntimeError(f"{message}\n{tb}"))
            return
        try:
            # Sending stays inside the try so an unpicklable result is also
            # reported to the parent as a RuntimeError.
            conn.send(func(*args, **kwargs))
        except Exception:
            tb = pformat(
                traceback.format_exc())  # within exception use format_exc()
            conn.send(
                RuntimeError(
                    f"child process exception, pid={os.getpid()}\n{tb}"))
    finally:
        # Close our end even if a send above failed.
        conn.close()
예제 #3
0
def worker_process(remote: multiprocessing.connection.Connection, env_config,
                   worker_id: int):
    """Initializes the environment and executes its interface.

    The worker answers ``(cmd, data)`` requests received on ``remote``:
    ``"step"`` replies with ``env.step(data)``, ``"reset"`` replies with
    ``env.reset(data)``, and ``"close"`` replies with ``env.close()`` and then
    closes the connection. The loop also ends when the connection to the
    parent is lost, on ``KeyboardInterrupt``, or on any other error. Such
    errors have their traceback printed rather than being silently discarded.

    Arguments:
        remote {multiprocessing.connection.Connection} -- Child end of the pipe connected to the parent process
        env_config {dict} -- The configuration data of the desired environment
        worker_id {int} -- Id for the environment's process. This is necessary for Unity ML-Agents environments, because these operate on different ports.
    """
    import traceback

    # Initialize and wrap the environment. If interrupted here there is no
    # environment to serve, so stop instead of failing later on an unbound
    # ``env``.
    try:
        env = wrap_environment(env_config, worker_id)
    except KeyboardInterrupt:
        return

    # Communication interface of the environment process
    while True:
        try:
            cmd, data = remote.recv()
            if cmd == "step":
                remote.send(env.step(data))
            elif cmd == "reset":
                remote.send(env.reset(data))
            elif cmd == "close":
                remote.send(env.close())
                remote.close()
                break
            else:
                raise NotImplementedError(f"unknown command: {cmd!r}")
        except (EOFError, ConnectionError, KeyboardInterrupt):
            # Parent went away or the user interrupted: shut down quietly.
            break
        except Exception:
            # Keep the "worker exits on error" behavior, but make the failure
            # visible instead of swallowing it with a bare ``except``.
            traceback.print_exc()
            break
예제 #4
0
def worker_process(remote: multiprocessing.connection.Connection, seed: int,
                   num: int):
    """Each worker process runs this method.

    Creates ``num`` ``Game`` instances, each seeded deterministically from
    ``seed`` and its index. It then serves ``(cmd, data)`` requests from
    ``remote`` for all of them at once:

    * ``"step"``  -- ``data[i]`` is a flat action for game ``i``, split into
      ``(data[i] // kW, data[i] % kW)`` before being passed to ``step``.
      Replies with ``(np.stack(obs), np.stack(rewards), np.array(dones),
      list(infos))`` gathered across games.
    * ``"reset"`` -- resets every game and replies with the stacked
      observations.
    * ``"close"`` -- closes ``remote`` and returns.

    Raises:
        NotImplementedError: on any other command.
    """
    def game_seed(index: int) -> int:
        # Hash (seed, index) so every game gets a distinct seed while the whole
        # batch stays reproducible from ``seed``. Both values are encoded as
        # 4-byte unsigned little-endian ints, so they must lie in [0, 2**32).
        key = int.to_bytes(seed, 4, 'little') + int.to_bytes(index, 4, 'little')
        return int.from_bytes(hashlib.sha256(key).digest(), 'little')

    # create games
    games = [Game(game_seed(i)) for i in range(num)]

    # wait for instructions from the connection and execute them
    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            # divmod(a, kW) == (a // kW, a % kW); presumably (row, column) on a
            # board of width kW -- confirm against Game.step.
            results = [game.step(divmod(data[i], kW))
                       for i, game in enumerate(games)]
            obs, rew, over, info = zip(*results)
            remote.send(
                (np.stack(obs), np.stack(rew), np.array(over), list(info)))
        elif cmd == "reset":
            remote.send(np.stack([game.reset() for game in games]))
        elif cmd == "close":
            remote.close()
            break
        else:
            raise NotImplementedError
예제 #5
0
def worker_process(remote: multiprocessing.connection.Connection,
                   player_types: list, trained_model_path: str,
                   game_to_play: str):
    """
    ##Worker Process
    Entry point executed by each worker process.

    Builds a single ``Game`` from the given settings, then answers
    ``(command, payload)`` messages from the parent until told to close:
    ``"step"`` forwards the payload to ``Game.step``, ``"reset"`` calls
    ``Game.reset``, and ``"close"`` closes the pipe and returns. Any other
    command raises ``NotImplementedError``.
    """
    game = Game(
        player_types=player_types,
        trained_model_path=trained_model_path,
        game_to_play=game_to_play,
    )

    while True:
        command, payload = remote.recv()

        if command == "close":
            remote.close()
            return

        if command == "step":
            reply = game.step(payload)
        elif command == "reset":
            reply = game.reset()
        else:
            raise NotImplementedError("received {}".format(command))

        remote.send(reply)
def _worker(remote: mp.connection.Connection,
            parent_remote: mp.connection.Connection,
            manager_fn_wrapper: CloudpickleWrapper) -> None:
    """Evaluation worker loop, run inside a subprocess.

    Closes ``parent_remote`` and builds an evaluation manager from
    ``manager_fn_wrapper.var()``. It then serves ``(cmd, data)`` requests
    received on ``remote``:

    * ``"run_eval_episodes"`` -- play ``data`` evaluation games and reply with
      ``(winners, num_game_steps, victory_points_all, policy_steps)``, each a
      list with one entry per game.
    * ``"update_policies"`` -- pass ``data.var`` to
      ``_update_policies`` and reply ``True``.
    * ``"close"`` -- close ``remote`` and return.

    The loop also ends quietly on ``EOFError`` (the other end was closed).

    Args:
        remote: This worker's end of the pipe.
        parent_remote: Presumably the parent's end of the same pipe,
            inherited by the child and unused here.
        manager_fn_wrapper: Object whose ``var`` attribute is a factory for
            the evaluation manager.

    Raises:
        NotImplementedError: On an unknown command.
    """
    parent_remote.close()

    # Presumably keeps parallel workers from oversubscribing CPU cores.
    torch.set_num_threads(1)
    evaluation_manager = manager_fn_wrapper.var()

    while True:
        try:
            cmd, data = remote.recv()
            if cmd == "run_eval_episodes":
                num_episodes = data
                winners = []
                num_game_steps = []
                victory_points_all = []
                policy_steps = []
                for _ in range(num_episodes):
                    winner, victory_points, total_steps, policy_decisions = (
                        evaluation_manager.run_evaluation_game())
                    winners.append(winner)
                    num_game_steps.append(total_steps)
                    victory_points_all.append(victory_points)
                    policy_steps.append(policy_decisions)
                remote.send((winners, num_game_steps, victory_points_all,
                             policy_steps))
            elif cmd == "update_policies":
                evaluation_manager._update_policies(data.var)
                remote.send(True)
            elif cmd == "close":
                remote.close()
                break
            else:
                # Fail loudly: silently ignoring the command would leave the
                # parent blocked forever waiting for a reply.
                raise NotImplementedError(
                    f"`{cmd}` is not implemented in the worker")
        except EOFError:
            break
예제 #7
0
def worker_process(remote: multiprocessing.connection.Connection, seed: int):
    """Minimal worker loop that acknowledges resets until it is closed.

    ``seed`` is accepted for signature compatibility but is not used.
    Replies ``'Resetted!'`` to ``'reset'``. On ``'close'`` it closes
    ``remote`` and returns. Any other command raises ``NotImplementedError``.
    """
    while True:
        command, _payload = remote.recv()
        if command == 'close':
            remote.close()
            return
        if command != 'reset':
            raise NotImplementedError
        remote.send('Resetted!')
def _worker(
    remote: mp.connection.Connection,
    parent_remote: mp.connection.Connection,
    env_fn_wrapper: CloudpickleWrapper,
    render: bool,
    render_mode: str,
) -> None:
    """Serve a single environment over ``remote`` from inside a subprocess.

    ``parent_remote`` is closed first, then the environment is built from
    ``env_fn_wrapper.var()``. Requests are ``(cmd, data)`` tuples:

    * ``"step"`` -- step with ``data``. When ``done``, the last observation is
      stored in ``info["terminal_observation"]`` and the environment is
      reset. Replies ``(observation, reward, done, info)``.
    * ``"reset"`` -- reset and reply with the observation.
    * ``"seed"`` / ``"render"`` -- reply with ``env.seed(data)`` /
      ``env.render(data)``.
    * ``"get_spaces"`` -- reply ``(observation_space, action_space)``.
    * ``"env_method"`` -- reply ``getattr(env, data[0])(*data[1], **data[2])``.
    * ``"get_attr"`` / ``"set_attr"`` -- read / write an attribute of ``env``.
    * ``"is_wrapped"`` -- reply ``is_wrapped(env, data)``.
    * ``"close"`` -- close ``env`` and ``remote`` and return.

    When ``render`` is true, the environment is rendered with ``render_mode``
    after every step and every reset. The loop ends quietly on ``EOFError``.

    Raises:
        NotImplementedError: For an unknown command.
    """
    # Import here to avoid a circular import
    from stable_baselines3.common.env_util import is_wrapped

    parent_remote.close()
    env = env_fn_wrapper.var()

    def render_if_enabled() -> None:
        # Optional live rendering after each step / reset.
        if render:
            env.render(mode=render_mode)

    while True:
        try:
            command, payload = remote.recv()

            if command == "close":
                env.close()
                remote.close()
                break

            if command == "step":
                observation, reward, done, info = env.step(payload)
                render_if_enabled()
                if done:
                    # Keep the episode's last observation before resetting.
                    info["terminal_observation"] = observation
                    observation = env.reset()
                    render_if_enabled()
                reply = (observation, reward, done, info)
            elif command == "reset":
                reply = env.reset()
                render_if_enabled()
            elif command == "seed":
                reply = env.seed(payload)
            elif command == "render":
                reply = env.render(payload)
            elif command == "get_spaces":
                reply = (env.observation_space, env.action_space)
            elif command == "env_method":
                bound_method = getattr(env, payload[0])
                reply = bound_method(*payload[1], **payload[2])
            elif command == "get_attr":
                reply = getattr(env, payload)
            elif command == "set_attr":
                reply = setattr(env, payload[0], payload[1])
            elif command == "is_wrapped":
                reply = is_wrapped(env, payload)
            else:
                raise NotImplementedError(
                    f"`{command}` is not implemented in the worker")

            remote.send(reply)
        except EOFError:
            break
    def monitor_mproc(self, proc: multiprocessing.Process,
                      pipe: multiprocessing.connection.Connection,
                      poll_interval: float = 0.1):
        """Relay progress messages from a worker process, then refresh the checksum.

        Prints every message received on ``pipe`` until the worker sends
        ``"Done"``, the worker has exited and nothing is left to read, or the
        pipe is closed. Afterwards it updates the main pak checksum, closes
        ``pipe`` and clears ``self.unpack_mproc_monitor_thread``.

        Args:
            proc: The worker process producing the messages.
            pipe: The end of the pipe the worker's messages arrive on.
            poll_interval: Seconds to wait for a message before re-checking
                whether ``proc`` is still alive.
        """
        piped_message = ""
        while piped_message != "Done" and not pipe.closed:
            # Wait with a timeout instead of a bare recv(): a worker that dies
            # without sending "Done" could otherwise block this loop forever.
            if pipe.poll(poll_interval):
                try:
                    piped_message = pipe.recv()
                except EOFError:
                    # Every writer has closed its end; nothing more will arrive.
                    break
                print("PIPED THROUGH: " + str(piped_message))
            elif not proc.is_alive():
                # Worker finished and all queued messages have been drained.
                break

        print("Updating Checksum...")
        tool_man.update_main_pak_checksum()
        print("Checksum updated.")
        pipe.close()
        self.unpack_mproc_monitor_thread = None
예제 #10
0
def worker_process(remote: multiprocessing.connection.Connection, seed: int):
    """Run a ``Game`` seeded with ``seed`` and serve commands until closed.

    ``"step"`` replies with ``game.step(data)``, ``"reset"`` replies with
    ``game.reset()``, and ``"close"`` closes ``remote`` and returns. Anything
    else raises ``NotImplementedError``.
    """
    game = Game(seed)
    while True:
        command, payload = remote.recv()
        if command == "step":
            remote.send(game.step(payload))
            continue
        if command == "reset":
            remote.send(game.reset())
            continue
        if command == "close":
            remote.close()
            return
        raise NotImplementedError
def worker_process(remote: multiprocessing.connection.Connection, seed: int):
    """Create a ``MyGame`` from ``seed`` and answer data-generation requests.

    ``"gen_dataX"``: the payload is unpacked as positional arguments to
    ``MyGame.gen_data`` and the result is sent back.
    ``"close"``: closes ``remote`` and returns.
    Any other command raises ``NotImplementedError``.
    """
    game = MyGame(seed)

    while True:
        command, gen_args = remote.recv()
        if command == "close":
            remote.close()
            return
        if command != "gen_dataX":
            raise NotImplementedError
        remote.send(game.gen_data(*gen_args))  # MyGame.gen_data
예제 #12
0
def worker_process(remote: multiprocessing.connection.Connection, parameters,
                   worker_id, seed):
    """
    Build the environment selected by ``parameters`` and serve its commands.

    This function is used as the target by each worker process: it builds
    one environment instance and then executes the commands sent by the
    parent.

    Environment selection (``parameters['env_type']``):
        ``'atari'``     -- ``make_atari(parameters['scene'])``, wrapped in a
                           ``bench.Monitor`` writing to ``./log/<worker_id>``
                           and then in ``wrap_deepmind`` (Atari wrappers from
                           OpenAI baselines, https://github.com/openai/baselines).
        ``'warehouse'`` -- ``Warehouse(seed, parameters)``.
        ``'sumo'``      -- ``LoopNetwork(parameters, seed)``.
        ``'minigrid'``  -- ``gym.make(parameters['scene'])`` with the mission
                           field removed, grayscale observations and a
                           feature-vector wrapper; seeded with ``seed``.

    Commands (``(cmd, data)`` tuples received on ``remote``):
        ``'step'``         -- step with action ``data``; on episode end the
                              environment is reset and the reset observation
                              is returned in place of the final one.
        ``'reset'``        -- replies with ``env.reset()``.
        ``'action_space'`` -- replies with ``env.action_space.n``.
        ``'close'``        -- closes ``remote`` and returns.

    :raises ValueError: if ``parameters['env_type']`` is not one of the above.
        Raising here avoids an unbound ``env`` on the first command.
    :raises NotImplementedError: on an unknown command.
    """
    env_type = parameters['env_type']
    if env_type == 'atari':
        log_dir = './log'
        env = make_atari(parameters['scene'])
        env = bench.Monitor(
            env,
            os.path.join(log_dir, str(worker_id)),
            allow_early_resets=False)
        env = wrap_deepmind(env, True)
    elif env_type == 'warehouse':
        env = Warehouse(seed, parameters)
    elif env_type == 'sumo':
        env = LoopNetwork(parameters, seed)
    elif env_type == 'minigrid':
        env = gym.make(parameters['scene'])
        env = ImgObsWrapper(env)  # Get rid of the 'mission' field
        env = wrappers.GrayScaleObservation(env, keep_dim=True)  # Gray scale
        env = FeatureVectorWrapper(env)
        env.seed(seed)
    else:
        raise ValueError(f"Unsupported env_type: {env_type!r}")

    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            obs, reward, done, info = env.step(data)
            if done:
                obs = env.reset()
            remote.send((obs, reward, done, info))
        elif cmd == 'reset':
            remote.send(env.reset())
        elif cmd == 'action_space':
            remote.send(env.action_space.n)
        elif cmd == 'close':
            remote.close()
            break
        else:
            raise NotImplementedError
예제 #13
0
    def _poll_pipe(
        self,
        actor_idx: int,
        pipe: mp.connection.Connection,
        replay_buffer_lock: mp.synchronize.Lock,
        exception_event: mp.synchronize.Event,
    ) -> None:
        """Handle every message already waiting on one actor's pipe.

        Returns immediately if ``pipe`` is closed. Otherwise it drains
        ``pipe`` without blocking while ``exception_event`` is not set,
        dispatching each ``(cmd, data)`` message:

        * ``"get_statistics"`` -- ``data`` must be ``None``; replies with
          ``self.get_statistics()`` computed under ``replay_buffer_lock``.
        * ``"load"`` / ``"save"`` -- call ``self.load(data)`` /
          ``self.save(data)`` and reply ``None``.
        * ``"transition"`` -- ``data`` is a dict of keyword arguments for
          ``self.replay_buffer.append``. ``env_id`` defaults to ``actor_idx``
          when missing, and ``self._cumulative_steps`` is incremented. No
          reply is sent.
        * ``"stop_episode"`` -- stops the current episode for env ``data``
          (or ``actor_idx`` when ``data`` is ``None``) and replies with fresh
          statistics.

        Args:
            actor_idx: Index of the actor that owns ``pipe``.
            pipe: This side's end of the connection to that actor.
            replay_buffer_lock: Held while touching ``self.replay_buffer`` or
                computing statistics.
            exception_event: Set here when handling fails; while set, no
                further messages are processed.

        Error handling: ``EOFError`` (the actor closed its end) closes
        ``pipe``. Any other exception, including an unknown command, is
        logged via ``self.logger.exception`` and ``exception_event`` is set.
        Nothing is re-raised.
        """
        # Nothing to do once the connection has been closed (e.g. after EOF).
        if pipe.closed:
            return
        try:
            # poll() with no timeout returns immediately, so this only drains
            # messages that are already queued.
            while pipe.poll() and not exception_event.is_set():
                cmd, data = pipe.recv()
                if cmd == "get_statistics":
                    # NOTE(review): assert is stripped under ``python -O``.
                    assert data is None
                    with replay_buffer_lock:
                        stats = self.get_statistics()
                    pipe.send(stats)
                elif cmd == "load":
                    # NOTE(review): load/save run without replay_buffer_lock;
                    # confirm that is intended.
                    self.load(data)
                    pipe.send(None)
                elif cmd == "save":
                    self.save(data)
                    pipe.send(None)
                elif cmd == "transition":
                    # Fire-and-forget: no reply is sent for transitions.
                    with replay_buffer_lock:
                        if "env_id" not in data:
                            data["env_id"] = actor_idx
                        self.replay_buffer.append(**data)
                        self._cumulative_steps += 1
                elif cmd == "stop_episode":
                    idx = actor_idx if data is None else data
                    with replay_buffer_lock:
                        self.replay_buffer.stop_current_episode(env_id=idx)
                        stats = self.get_statistics()
                    pipe.send(stats)

                else:
                    raise RuntimeError(
                        "Unknown command from actor: {}".format(cmd))
        except EOFError:
            # The actor closed its end; closing ours makes later calls return early.
            pipe.close()
        except Exception:
            # Log and signal the failure through the event instead of re-raising.
            self.logger.exception("Poller loop failed. Exiting")
            exception_event.set()
예제 #14
0
def repack_process(p_tool_man: ToolMan, filename: str,
                   pipe: multiprocessing.connection.Connection):
    """Rebuild CastleDB and then ``res.pak``, reporting progress over ``pipe``.

    Each step's ``communicate()`` result is printed and sent through
    ``pipe``, and ``"Done"`` is sent after both steps finish. ``pipe`` is
    closed on exit, even if a step raises.

    Args:
        p_tool_man: Tool manager whose ``run_rebuild_cdb()`` and
            ``run_rebuild_pak(filename)`` start the rebuild steps (presumably
            returning ``subprocess.Popen`` objects -- TODO confirm).
        filename: Forwarded to ``run_rebuild_pak``.
        pipe: Connection used to stream progress messages to the parent.
    """
    def smart_print(x):
        # Echo locally and forward to the parent; flush so output shows promptly.
        print(x)
        pipe.send(x)
        sys.stdout.flush()

    try:
        smart_print("Rebuilding CastleDB")
        cdb_proc = p_tool_man.run_rebuild_cdb()
        # communicate() blocks until the process exits and returns its
        # (stdout, stderr), so one call collects everything. A
        # ``while proc.poll() is None`` guard could run it at most once, and
        # skipped it entirely when the process had already exited, losing
        # that step's output.
        smart_print(cdb_proc.communicate())

        smart_print("Rebuilding res.pak")
        pak_proc = p_tool_man.run_rebuild_pak(filename)
        smart_print(pak_proc.communicate())

        smart_print("Done")
    finally:
        # Always close our end of the pipe, even if a step failed.
        pipe.close()
예제 #15
0
def unpack_process(p_tool_man: ToolMan, filename: str,
                   pipe: multiprocessing.connection.Connection):
    """Extract ``res.pak`` and then expand ``data.cdb``, reporting progress over ``pipe``.

    Each step's ``communicate()`` result is printed and sent through
    ``pipe``, and ``"Done"`` is sent after both steps finish. ``pipe`` is
    closed on exit, even if a step raises.

    Args:
        p_tool_man: Tool manager whose ``run_extract_pak(filename)`` and
            ``run_extract_cdb()`` start the extraction steps (presumably
            returning ``subprocess.Popen`` objects -- TODO confirm).
        filename: Forwarded to ``run_extract_pak``.
        pipe: Connection used to stream progress messages to the parent.
    """
    def smart_print(x):
        # Echo locally and forward to the parent; flush so output shows promptly.
        print(x)
        pipe.send(x)
        sys.stdout.flush()

    try:
        smart_print("Unpacking res.pak")
        pak_proc = p_tool_man.run_extract_pak(filename)
        # communicate() blocks until the process exits and returns its
        # (stdout, stderr), so one call collects everything. A
        # ``while proc.poll() is None`` guard could run it at most once, and
        # skipped it entirely when the process had already exited, losing
        # that step's output.
        smart_print(pak_proc.communicate())

        smart_print("Expanding data.cdb")
        cdb_proc = p_tool_man.run_extract_cdb()
        smart_print(cdb_proc.communicate())

        smart_print("Done")
    finally:
        # Always close our end of the pipe, even if a step failed.
        pipe.close()
예제 #16
0
def _worker(
        remote: mp.connection.Connection, parent_remote: mp.connection.Connection, manager_fn_wrapper: CloudpickleWrapper
) -> None:
    """Rollout worker loop executed inside a subprocess.

    Closes ``parent_remote``, limits torch to one thread and builds a game
    manager from ``manager_fn_wrapper.var()``. It then serves
    ``(cmd, data)`` requests:

    * ``"gather_rollouts"`` -- collect rollouts, call ``_after_rollouts()``
      and reply with the seven rollout fields wrapped in
      ``CloudpickleWrapper``.
    * ``"reset"`` -- reset the manager and reply ``True``.
    * ``"update_policy"`` -- load ``data[0].var`` into policy ``data[1]`` and
      reply ``True``.
    * ``"seed"`` -- seed ``numpy``, ``random`` and ``torch`` with ``data``
      and reply ``True``.
    * ``"close"`` -- close ``remote`` and return.

    The loop ends quietly on ``EOFError``.

    Raises:
        NotImplementedError: For an unknown command.
    """
    parent_remote.close()

    torch.set_num_threads(1)
    game_manager = manager_fn_wrapper.var()

    while True:
        try:
            command, payload = remote.recv()

            if command == "close":
                remote.close()
                break

            if command == "gather_rollouts":
                (observations, hidden_states, rewards, actions, action_masks,
                 action_log_probs, dones) = game_manager.gather_rollouts()
                game_manager._after_rollouts()
                rollout = (observations, hidden_states, rewards, actions,
                           action_masks, action_log_probs, dones)
                remote.send(CloudpickleWrapper(rollout))
            elif command == "reset":
                game_manager.reset()
                remote.send(True)
            elif command == "update_policy":
                game_manager._update_policy(payload[0].var, payload[1])
                remote.send(True)
            elif command == "seed":
                # Seed every RNG source used by the rollouts, in a fixed order.
                for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
                    seed_fn(payload)
                remote.send(True)
            else:
                raise NotImplementedError(f"`{command}` is not implemented in the worker")
        except EOFError:
            break
예제 #17
0
def worker_process(remote: multiprocessing.connection.Connection,
                   env_name: str, crop_size: int, kwargs: Dict):
    """Host a cropped-image PCGRL environment and serve commands until closed.

    The environment is ``wrappers.CroppedImagePCGRLWrapper(env_name,
    crop_size, **kwargs)``. ``"step"`` replies with ``game.step(data)``,
    ``"reset"`` replies with ``game.reset()``, and ``"close"`` closes
    ``remote`` and returns. Any other command raises ``NotImplementedError``.
    """
    game = wrappers.CroppedImagePCGRLWrapper(env_name, crop_size, **kwargs)

    while True:
        command, payload = remote.recv()
        if command == "step":
            reply = game.step(payload)
        elif command == "reset":
            reply = game.reset()
        elif command == "close":
            remote.close()
            return
        else:
            raise NotImplementedError
        remote.send(reply)
예제 #18
0
파일: game.py 프로젝트: weihaoxie/nn
def worker_process(remote: multiprocessing.connection.Connection, seed: int):
    """
    ##Worker Process

    Entry point executed by every worker process.

    A ``Game`` is created from ``seed``, and the worker then handles
    ``(command, payload)`` requests from the parent in a loop:
    ``"step"`` replies with ``game.step(payload)``, ``"reset"`` replies with
    ``game.reset()``, and ``"close"`` closes the connection and returns.
    Any other command raises ``NotImplementedError``.
    """
    game = Game(seed)

    while True:
        command, payload = remote.recv()
        if command == "close":
            remote.close()
            return
        if command == "step":
            reply = game.step(payload)
        elif command == "reset":
            reply = game.reset()
        else:
            raise NotImplementedError
        remote.send(reply)
예제 #19
0
def _worker(
    env_constructor: bytes,
    auto_reset: bool,
    pipe: mp.connection.Connection,
    polling_period: float = 0.1,
):
    """Process to build and run an environment. Using a pipe to
    communicate with parent, the process receives action, steps
    the environment, and returns the observations.

    Protocol: right after constructing the environment, the worker sends
    ``(_Message.RESULT, None)``, apparently as a readiness signal. It then
    answers each ``(message, payload)`` request with
    ``(_Message.RESULT, value)``:

    * ``SEED``   -- ``env.seed(payload)``.
    * ``ACCESS`` -- ``getattr(env, payload, None)``, i.e. ``None`` when the
      attribute is missing.
    * ``RESET``  -- ``env.reset()``.
    * ``STEP``   -- ``(observation, reward, done, info)`` from
      ``env.step(payload)``. When ``auto_reset`` is set and
      ``done["__all__"]`` is truthy, ``observation`` is replaced by that of a
      fresh ``env.reset()``.
    * ``CLOSE``  -- leave the loop without replying.

    On any exception (including ``KeyboardInterrupt``) the worker sends
    ``(_Message.EXCEPTION, (process_name, stacktrace))`` instead. Traceback
    frames are omitted for ``KeyboardInterrupt``. The environment and the
    pipe are closed on exit from the loop.

    Args:
        env_constructor (bytes): Cloudpickled callable which constructs the environment.
        auto_reset (bool): If True, auto resets environment when episode ends.
        pipe (mp.connection.Connection): Child's end of the pipe.
        polling_period (float, optional): Time to wait for keyboard interrupts. Defaults to 0.1.

    Raises:
        KeyError: If unknown message type is received. It is raised inside the
            ``try`` below, so it is caught and reported to the parent as an
            ``EXCEPTION`` message rather than propagating.
    """
    # NOTE(review): construction happens outside the try below, so a failing
    # constructor is not reported as an EXCEPTION message and the finally
    # clause does not run -- confirm the parent handles this case.
    env = cloudpickle.loads(env_constructor)()
    pipe.send((_Message.RESULT, None))

    try:
        while True:
            # Wait with a timeout rather than blocking in recv(), so the loop
            # regularly regains control (see polling_period above).
            if not pipe.poll(polling_period):
                continue
            message, payload = pipe.recv()
            if message == _Message.SEED:
                env_seed = env.seed(payload)
                pipe.send((_Message.RESULT, env_seed))
            elif message == _Message.ACCESS:
                result = getattr(env, payload, None)
                pipe.send((_Message.RESULT, result))
            elif message == _Message.RESET:
                observation = env.reset()
                pipe.send((_Message.RESULT, observation))
            elif message == _Message.STEP:
                observation, reward, done, info = env.step(payload)
                if done["__all__"] and auto_reset:
                    # Final observation can be obtained from `info` as follows:
                    # `final_obs = info[agent_id]["env_obs"]`
                    observation = env.reset()
                pipe.send((_Message.RESULT, (observation, reward, done, info)))
            elif message == _Message.CLOSE:
                break
            else:
                raise KeyError(
                    f"Expected message from {_Message.__members__}, but got unknown message `{message}`."
                )
    except (Exception, KeyboardInterrupt):
        etype, evalue, tb = sys.exc_info()
        # For KeyboardInterrupt, send only the exception line (no frames);
        # for everything else, send the full traceback.
        if etype == KeyboardInterrupt:
            stacktrace = "".join(
                traceback.format_exception(etype, evalue, None))
        else:
            stacktrace = "".join(traceback.format_exception(etype, evalue, tb))
        payload = (mp.current_process().name, stacktrace)
        pipe.send((_Message.EXCEPTION, payload))
    finally:
        # NOTE(review): if env.close() raises, pipe.close() is skipped.
        env.close()
        pipe.close()