예제 #1
0
def _worker(
    parent: connection.Connection,
    p: connection.Connection,
    env_fn_wrapper: CloudpickleWrapper,
    obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,
) -> None:
    """Serve environment commands received over ``p`` until told to stop.

    The environment is built by calling ``env_fn_wrapper.data()``. Each
    message read from ``p`` is a ``(cmd, data)`` pair:

    * ``"step"``: if ``data`` is ``None`` the environment is reset and the
      observation is sent back; otherwise ``env.step(data)`` is called and
      ``(obs, reward, done, info)`` is sent back. When ``obs_bufs`` is given,
      the observation is written into it and ``None`` is sent in its place.
    * ``"close"``: sends the result of ``env.close()``, closes ``p`` and
      returns.
    * ``"render"``: sends ``env.render(**data)``, or ``None`` if the
      environment has no ``render`` attribute.
    * ``"seed"``: sends ``env.seed(data)``, or ``None`` if the environment
      has no ``seed`` attribute.
    * ``"getattr"``: sends the attribute named ``data``, or ``None`` if it
      does not exist.
    * ``"setattr"``: sets ``data["key"]`` to ``data["value"]`` on the
      environment; nothing is sent back.

    The loop also ends when ``p.recv()`` raises ``EOFError`` or when a
    ``KeyboardInterrupt`` arrives; in both cases ``p`` is closed. Whatever
    the exit path, the environment is closed exactly once before returning.

    Args:
        parent: The other end of the pipe; closed immediately and never used.
        p: The connection this worker reads commands from and replies on.
        env_fn_wrapper: Object whose ``data`` attribute is a zero-argument
            callable returning the environment.
        obs_bufs: Optional buffer (or tuple/dict of buffers) that receives
            observations instead of sending them through the pipe.

    Raises:
        NotImplementedError: If an unrecognised command is received
            (``p`` is closed first).
    """

    def _encode_obs(
        obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray]
    ) -> None:
        """Store ``obs`` into ``buffer``, recursing through tuples and dicts."""
        if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray):
            buffer.save(obs)
        elif isinstance(obs, tuple) and isinstance(buffer, tuple):
            for o, b in zip(obs, buffer):
                _encode_obs(o, b)
        elif isinstance(obs, dict) and isinstance(buffer, dict):
            for k in obs:
                _encode_obs(obs[k], buffer[k])
        # NOTE(review): any other obs/buffer pairing is silently skipped, so
        # the buffer keeps its previous contents -- confirm this is intended.

    parent.close()
    env = env_fn_wrapper.data()
    env_closed = False
    try:
        while True:
            try:
                cmd, data = p.recv()
            except EOFError:  # the pipe has been closed
                p.close()
                break
            if cmd == "step":
                is_reset = data is None
                if is_reset:
                    obs = env.reset()
                else:
                    obs, reward, done, info = env.step(data)
                if obs_bufs is not None:
                    # The observation travels through obs_bufs, not the pipe.
                    _encode_obs(obs, obs_bufs)
                    obs = None
                if is_reset:
                    p.send(obs)
                else:
                    p.send((obs, reward, done, info))
            elif cmd == "close":
                # Flag first so a failing close() is not retried in finally.
                env_closed = True
                p.send(env.close())
                p.close()
                break
            elif cmd == "render":
                p.send(env.render(**data) if hasattr(env, "render") else None)
            elif cmd == "seed":
                p.send(env.seed(data) if hasattr(env, "seed") else None)
            elif cmd == "getattr":
                p.send(getattr(env, data, None))
            elif cmd == "setattr":
                setattr(env, data["key"], data["value"])
            else:
                p.close()
                raise NotImplementedError("unknown command: %r" % (cmd,))
    except KeyboardInterrupt:
        p.close()
    finally:
        # Release the environment on every exit path that did not already
        # close it (pipe dropped, interrupt, unknown command, env error).
        if not env_closed:
            env.close()
예제 #2
0
 def __init__(self, env_fns):
     """Create one pipe and one daemon worker process per environment.

     After the parent initialiser runs, ``self.env_num`` pipes are opened.
     The first end of each is kept in ``self.parent_remote`` and the second
     in ``self.child_remote``. Every process runs ``worker`` with both pipe
     ends and the wrapped environment factory. Once all processes have been
     started, this process closes each connection in ``self.child_remote``.

     :param env_fns: iterable of environment factories, one per worker.
     """
     super().__init__(env_fns)
     self.closed = False

     pipe_pairs = [Pipe() for _ in range(self.env_num)]
     self.parent_remote, self.child_remote = zip(*pipe_pairs)

     workers = []
     endpoints = zip(self.parent_remote, self.child_remote, env_fns)
     for parent_end, child_end, make_env in endpoints:
         worker_args = (parent_end, child_end, CloudpickleWrapper(make_env))
         workers.append(
             Process(target=worker, args=worker_args, daemon=True))
     self.processes = workers

     for proc in self.processes:
         proc.start()
     # Close this process's handles to the second pipe ends only after
     # every worker has been started.
     for child_end in self.child_remote:
         child_end.close()
예제 #3
0
 def __init__(self, env_fn: Callable[[], gym.Env],
              share_memory=False) -> None:
     """Start a daemon process running ``_worker`` for a single environment.

     A pipe is opened and its two ends are stored as ``self.parent_remote``
     and ``self.child_remote``. When ``share_memory`` is truthy, a throwaway
     environment is built to read its ``observation_space``, closed again,
     and ``self.buffer`` is set to ``_setup_buf`` of that space; otherwise
     ``self.buffer`` stays ``None``. The process receives both pipe ends,
     the wrapped factory and the buffer. After it starts, this process
     closes ``self.child_remote``.

     :param env_fn: zero-argument callable that builds the environment.
     :param share_memory: whether to set up ``self.buffer`` for the worker.
     """
     super().__init__(env_fn)
     self.parent_remote, self.child_remote = Pipe()
     self.share_memory = share_memory
     self.buffer = None

     if self.share_memory:
         # A temporary environment is needed only to read the space.
         probe_env = env_fn()
         space = probe_env.observation_space
         probe_env.close()
         del probe_env
         self.buffer = _setup_buf(space)

     self.process = Process(
         target=_worker,
         args=(
             self.parent_remote,
             self.child_remote,
             CloudpickleWrapper(env_fn),
             self.buffer,
         ),
         daemon=True,
     )
     self.process.start()
     self.child_remote.close()