def discrete_atari_env():
    """Build a MsPacman-v0 Atari environment wrapped in ``DiscreteEnv``.

    The underlying ``AtariEnvironment`` is created with ``clone_seeds=True``
    and ``autoreset=True`` and is reset once before it is wrapped.

    Returns:
        The ``DiscreteEnv`` wrapper around the freshly reset environment.
    """
    base_env = AtariEnvironment(
        name="MsPacman-v0",
        clone_seeds=True,
        autoreset=True,
    )
    base_env.reset()
    return DiscreteEnv(base_env)
def atari_env():
    """Build a wrapped MsPacman-v0 environment plus a matching ``States`` batch.

    The ``AtariEnvironment`` is created with ``clone_seeds=True`` and
    ``autoreset=True``, reset once, and wrapped in ``DiscreteEnv``. A
    ``States`` container with ``N_WALKERS`` rows is created from a schema
    declaring ``actions`` (int64) and ``critic`` (float32), and both fields are
    filled with ones.

    Returns:
        tuple: ``(env, states)`` -- the ``DiscreteEnv`` wrapper and the
        populated ``States`` instance.
    """
    env = AtariEnvironment(name="MsPacman-v0", clone_seeds=True, autoreset=True)
    env.reset()
    env = DiscreteEnv(env)
    params = {
        "actions": {"dtype": np.int64},
        "critic": {"dtype": np.float32},
    }
    states = States(state_dict=params, batch_size=N_WALKERS)
    # Create the fill arrays with the dtypes declared in ``params``. A bare
    # ``np.ones`` yields float64, which contradicts the schema in case
    # ``States.update`` does not cast (not visible here -- confirm).
    states.update(
        actions=np.ones(N_WALKERS, dtype=params["actions"]["dtype"]),
        critic=np.ones(N_WALKERS, dtype=params["critic"]["dtype"]),
    )
    return env, states
class DistributedRam(DistributedSwarm):
    """``DistributedSwarm`` variant that renders states as RGB frames.

    A local swarm is instantiated to discover the environment name. A
    separate local ``AtariEnvironment`` with that name is kept for turning
    states into screen images that are sent through ``frame_pipe``.
    """

    def __init__(self, swarm, *args, **kwargs):
        """Initialise the distributed swarm and the local rendering env.

        Args:
            swarm: Callable forwarded to ``DistributedSwarm`` as ``swarm`` and
                also called once here to build ``self.local_swarm``.
            *args: Extra positional arguments for ``DistributedSwarm``.
            **kwargs: Extra keyword arguments for ``DistributedSwarm``.
        """
        super().__init__(*args, swarm=swarm, **kwargs)
        self.local_swarm = swarm()
        swarm_env = self.local_swarm.env
        # For a ParallelEnvironment the name is read from the env itself;
        # otherwise it is read from the wrapped ``_env``.
        if isinstance(swarm_env, ParallelEnvironment):
            env_name = swarm_env.name
        else:
            env_name = swarm_env._env.name
        self.local_env = AtariEnvironment(name=env_name, clone_seeds=True)
        self.local_env.reset()

    def image_from_state(self, state):
        """Render ``state`` to an RGB screen using the local environment.

        Args:
            state: Array-like supporting ``astype``; it is cast to ``uint8``,
                copied, and loaded with ``set_state``. Presumably an emulator
                state snapshot -- confirm against callers.

        Returns:
            numpy.ndarray: The ALE RGB screen (``uint8``) after one step with
            action 0.
        """
        restored = state.astype(np.uint8).copy()
        self.local_env.set_state(restored)
        # One step with action 0, presumably so the screen reflects the
        # restored state before it is read back.
        self.local_env.step(0)
        screen = self.local_env._env.ale.getScreenRGB()
        return np.asarray(screen, dtype=np.uint8)

    def stream_progress(self, state, observation, reward):
        """Emit the reward to ``self.stream`` and send the rendered frame.

        Args:
            state: State rendered through ``image_from_state``.
            observation: Accepted for interface compatibility; unused here.
            reward: Value placed in a one-row ``reward`` DataFrame indexed by
                ``n_iters // n_swarms``.
        """
        step_index = self.n_iters // self.n_swarms
        reward_frame = pd.DataFrame({"reward": [reward]}, index=[step_index])
        self.stream.emit(reward_frame)
        self.frame_pipe.send(self.image_from_state(state))
def qbert_rgb():
    """Return a reset Qbert-v0 environment wrapped in ``AtariEnv``.

    The underlying ``AtariEnvironment`` uses ``clone_seeds=True`` and
    ``autoreset=True``.
    """
    qbert = AtariEnvironment(name="Qbert-v0", clone_seeds=True, autoreset=True)
    qbert.reset()
    return AtariEnv(qbert)
def pacman_ram():
    """Return a reset MsPacman-ram-v0 environment wrapped in ``AtariEnv``.

    The underlying ``AtariEnvironment`` uses ``clone_seeds=True`` and
    ``autoreset=True``.
    """
    pacman = AtariEnvironment(
        name="MsPacman-ram-v0",
        clone_seeds=True,
        autoreset=True,
    )
    pacman.reset()
    return AtariEnv(pacman)