def make_atari(env_id, max_episode_steps=4500, action_space=None, *, world=2, stage=1):
    """Build a Super Mario Bros environment wrapped with the project's Atari-style stack.

    Despite the name, this constructs a ``SuperMario_Env`` rather than a gym
    Atari env. ``env_id`` does not select the level; it is only searched for
    substrings that toggle optional wrappers.

    Args:
        env_id: Identifier string.
            - Contains "Montezuma" or "Pitfall": wrapped in
              ``MontezumaInfoWrapper`` with ``room_address`` 3 or 1 respectively.
            - Otherwise: wrapped in ``DummyMontezumaInfoWrapper``.
            - Contains "sparse": additionally wrapped in ``SparseRewardEnv``.
        max_episode_steps: Multiplied by 4 and stored as ``_max_episode_steps``
            on the ``JoypadSpace`` wrapper.
        action_space: "RIGHT_ONLY", "SIMPLE_MOVEMENT", "COMPLEX_MOVEMENT" or
            None. None selects RIGHT_ONLY. Any other value also falls back to
            RIGHT_ONLY, and a ``UserWarning`` is emitted so the mistake is
            visible.
        world: Mario world to load. The default of 2 keeps the previously
            hard-coded level.
        stage: Mario stage to load. The default of 1 keeps the previously
            hard-coded level.

    Returns:
        The environment after the wrapper chain:
        JoypadSpace -> StickyActionEnv -> MaxAndSkipEnv(skip=4) ->
        (Dummy)MontezumaInfoWrapper -> [SparseRewardEnv] -> AddRandomStateToInfo.
    """
    import warnings

    # from environments.gym_super_mario_bros import make
    from nes_py.wrappers import JoypadSpace
    from environments.gym_super_mario_bros.actions import (
        COMPLEX_MOVEMENT,
        RIGHT_ONLY,
        SIMPLE_MOVEMENT,
    )
    from environments.mario_env import SuperMario_Env

    action_sets = {
        'RIGHT_ONLY': RIGHT_ONLY,
        'SIMPLE_MOVEMENT': SIMPLE_MOVEMENT,
        'COMPLEX_MOVEMENT': COMPLEX_MOVEMENT,
    }
    if action_space is None:
        mario_action_space = RIGHT_ONLY
    else:
        # Only strings are valid keys. Anything else (e.g. an action list passed
        # directly) previously fell through silently to RIGHT_ONLY.
        mario_action_space = (
            action_sets.get(action_space) if isinstance(action_space, str) else None
        )
        if mario_action_space is None:
            warnings.warn(
                f"Unknown action_space {action_space!r}; falling back to RIGHT_ONLY.",
                stacklevel=2,
            )
            mario_action_space = RIGHT_ONLY

    # env = gym.make(env_id)
    env = SuperMario_Env(world=world, stage=stage, scale=True)  # make('SuperMarioBros-2-1-v0')
    env = JoypadSpace(env, mario_action_space)
    # NOTE(review): the x4 presumably mirrors the skip=4 frame skipping below.
    # Confirm against whatever reads _max_episode_steps.
    env._max_episode_steps = max_episode_steps * 4
    # assert 'NoFrameskip' in env.spec.id
    env = StickyActionEnv(env)
    env = MaxAndSkipEnv(env, skip=4)
    if "Montezuma" in env_id or "Pitfall" in env_id:
        env = MontezumaInfoWrapper(env, room_address=3 if "Montezuma" in env_id else 1)
    else:
        env = DummyMontezumaInfoWrapper(env)
    if 'sparse' in env_id:
        env = SparseRewardEnv(env)
    env = AddRandomStateToInfo(env)
    return env
def get_marioenv(world=1, stage=1, version=0, movement=RIGHT_ONLY, max_episode_steps=4500):
    """Create one Super Mario Bros environment with the full wrapper stack.

    Args:
        world: Mario world, passed positionally to ``SuperMario_Env``.
        stage: Mario stage, passed positionally to ``SuperMario_Env``.
        version: Passed positionally to ``SuperMario_Env``.
        movement: Action set handed to ``JoypadSpace``.
        max_episode_steps: Multiplied by 4 and stored as ``_max_episode_steps``
            on the ``JoypadSpace`` wrapper.

    Returns:
        The result of ``wrap_deepmind(..., frame_stack=True)`` applied over
        JoypadSpace -> StickyActionEnv -> MaxAndSkipEnv(skip=4) ->
        DummyMontezumaInfoWrapper -> AddRandomStateToInfo.
    """
    base = SuperMario_Env(world, stage, version)
    env = JoypadSpace(base, movement)
    # NOTE(review): the x4 presumably matches the skip=4 frame skip applied
    # next. Confirm against the code that reads _max_episode_steps.
    env._max_episode_steps = max_episode_steps * 4

    env = MaxAndSkipEnv(StickyActionEnv(env), skip=4)
    env = AddRandomStateToInfo(DummyMontezumaInfoWrapper(env))
    return wrap_deepmind(env, frame_stack=True)
def _thunk():
    """Construct one Mario environment instance.

    This is a closure. ``world``, ``stage``, ``version``, ``movement``,
    ``wrap_atari``, ``max_episode_steps``, ``rank`` and ``seed`` come from an
    enclosing scope that is not visible here. ``logger``, ``os`` and the
    wrapper classes are presumably module-level names; confirm in the full file.

    Returns:
        The wrapped environment, already seeded with ``seed``.
    """
    mario_env = JoypadSpace(SuperMario_Env(world, stage, version), movement)
    if wrap_atari:
        # Atari-style pre-processing applied before Monitor.
        mario_env._max_episode_steps = max_episode_steps * 4
        mario_env = StickyActionEnv(mario_env)
        mario_env = MaxAndSkipEnv(mario_env, skip=4)
        mario_env = DummyMontezumaInfoWrapper(mario_env)
        mario_env = AddRandomStateToInfo(mario_env)
    # mario_env.seed(seed + rank)
    # If logger.get_dir() is falsy, that falsy value itself is passed as the
    # Monitor filename. Otherwise a per-rank path under the log dir is used.
    mario_env = Monitor(
        mario_env,
        logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
        allow_early_resets=True)
    if wrap_atari:
        mario_env = wrap_deepmind(mario_env)
    mario_env = BlocksWrapper(mario_env)
    # NOTE(review): every instance is seeded with the same `seed`, regardless of
    # `rank`. The commented-out `seed + rank` call above suggests per-worker
    # seeding was considered. Confirm that identical seeds across parallel
    # workers are intended.
    mario_env.seed(seed)
    return mario_env