import numpy as np

from minatar import Environment


class MiniAtariTask:
    """Thin task wrapper around a MinAtar environment with flattened observations."""

    def __init__(self, env_id, seed=None, sticky_action_prob=0.0):
        # Draw the seed at construction time; a np.random.randint default argument
        # would be evaluated only once, at class-definition time.
        if seed is None:
            seed = np.random.randint(int(1e5))
        random_seed(seed)  # random_seed is assumed to be the project's global RNG seeding helper
        # TODO: Allow sticky_action_prob and difficulty_ramping to be set by the configuration file
        self.env = Environment(env_id, random_seed=seed, sticky_action_prob=0.0,
                               difficulty_ramping=False)
        self.name = env_id
        self.state_dim = self.env.state_shape()
        self.action_set = self.env.minimal_action_set()
        self.action_dim = len(self.action_set)

    def reset(self):
        self.env.reset()
        return self.env.state().flatten()

    def step(self, actions):
        rew, done = self.env.act(self.action_set[actions[0]])
        # Auto-reset on episode termination so the caller always receives a valid observation.
        obs = self.reset() if done else self.env.state()
        return obs.flatten(), np.asarray(rew), np.asarray(done), ""
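
# Usage sketch (illustrative, not part of the wrapper): assumes MinAtar's "breakout"
# game id and a uniformly random policy over the minimal action set.
if __name__ == "__main__":
    task = MiniAtariTask("breakout", seed=0)
    obs = task.reset()
    for _ in range(100):
        action = np.random.randint(task.action_dim)
        obs, rew, done, _ = task.step([action])
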
import gym
import numpy as np
from gym import spaces
from minatar import Environment


class BaseEnv(gym.Env):
    """Gym wrapper around a MinAtar game; concrete subclasses set game_name."""

    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self, display_time=50, **kwargs):
        self.game_name = 'Game Name'  # placeholder, overridden by concrete subclasses
        self.display_time = display_time
        self.init(**kwargs)

    def init(self, **kwargs):
        self.game = Environment(env_name=self.game_name, **kwargs)
        self.action_set = self.game.env.action_map
        self.action_space = spaces.Discrete(self.game.num_actions())
        self.observation_space = spaces.Box(0.0, 1.0, shape=self.game.state_shape(),
                                            dtype=np.float32)

    def step(self, action):
        reward, done = self.game.act(action)
        return (self.game.state(), reward, done, {})

    def reset(self):
        self.game.reset()
        return self.game.state()

    def seed(self, seed=None):
        # Re-create the underlying game with the given seed.
        self.game = Environment(env_name=self.game_name, random_seed=seed)
        return seed

    def render(self, mode='human'):
        if mode == 'rgb_array':
            return self.game.state()
        elif mode == 'human':
            self.game.display_state(self.display_time)

    def close(self):
        if self.game.visualized:
            self.game.close_display()
        return 0
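
# Sketch of a concrete subclass (illustrative): the class name BreakoutEnv and the
# 'breakout' game name are assumptions; only BaseEnv is defined above. The subclass
# assigns game_name before calling init, since BaseEnv.__init__ installs a placeholder.
class BreakoutEnv(BaseEnv):
    def __init__(self, display_time=50, **kwargs):
        self.game_name = 'breakout'
        self.display_time = display_time
        self.init(**kwargs)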