def make_atari_environment(
    level: str = 'Pong',
    sticky_actions: bool = True,
    zero_discount_on_life_loss: bool = False,
    oar_wrapper: bool = False,
) -> dm_env.Environment:
    """Build a wrapped Atari environment for the requested game.

    Args:
        level: Game name without suffix; ``NoFrameskip-<version>`` is appended
            to form the Gym environment id.
        sticky_actions: Selects the ``v0`` environment version when true and
            ``v4`` otherwise.
        zero_discount_on_life_loss: Forwarded unchanged to
            ``wrappers.AtariWrapper``.
        oar_wrapper: When true, ``wrappers.ObservationActionRewardWrapper`` is
            added at the end of the wrapper list.

    Returns:
        Whatever ``wrappers.wrap_all`` produces for the Gym environment and
        the assembled wrapper list.
    """
    if sticky_actions:
        suffix = 'v0'
    else:
        suffix = 'v4'
    gym_env = gym.make(f'{level}NoFrameskip-{suffix}', full_action_space=True)

    atari_wrapper = functools.partial(
        wrappers.AtariWrapper,
        to_float=True,
        max_episode_len=108_000,
        zero_discount_on_life_loss=zero_discount_on_life_loss,
    )
    stack = [
        wrappers.GymAtariAdapter,
        atari_wrapper,
        wrappers.SinglePrecisionWrapper,
    ]
    if oar_wrapper:
        # E.g. IMPALA and R2D2 use this particular variant.
        stack += [wrappers.ObservationActionRewardWrapper]

    return wrappers.wrap_all(gym_env, stack)
def make_environmment():
    """Create the Missile Command environment and wrap it for Acme.

    Returns:
        env (acme.wrappers.single_precision.SinglePrecisionWrapper).
    """
    # Instantiate the raw Gym environment with the module-level configuration.
    base_env = gym.make(
        "gym_missile_command:missile-command-v0",
        custom_config=ENV_CONFIG,
    )

    # Attach a stand-in `ale` attribute (presumably needed by the Atari
    # wrappers applied below -- confirm against GymAtariAdapter).
    base_env.ale = DummyALE()

    # Acme-side processing: adapter, Atari preprocessing, single precision.
    atari_wrapper = functools.partial(
        wrappers.AtariWrapper,
        to_float=True,
        max_episode_len=MAX_EPISODE_LEN,
        zero_discount_on_life_loss=True,
    )
    return wrappers.wrap_all(
        base_env,
        [
            wrappers.GymAtariAdapter,
            atari_wrapper,
            wrappers.SinglePrecisionWrapper,
        ],
    )
def make_dqn_atari_environment(
    task_and_level: str = 'PongNoFrameskip-v4',
    evaluation: bool = False,
) -> dm_env.Environment:
    """Build a wrapped Atari environment with a mode-dependent episode cap.

    Args:
        task_and_level: Full Gym environment id passed to ``gym.make``.
        evaluation: When true the episode length cap is 108,000 steps;
            otherwise it is 50,000 steps.

    Returns:
        Whatever ``wrappers.wrap_all`` produces for the Gym environment and
        the wrapper list (adapter, Atari wrapper, single precision).
    """
    gym_env = gym.make(task_and_level, full_action_space=True)

    if evaluation:
        episode_cap = 108_000
    else:
        episode_cap = 50_000

    stack = [
        wrappers.GymAtariAdapter,
        functools.partial(
            wrappers.AtariWrapper,
            to_float=True,
            max_episode_len=episode_cap,
            zero_discount_on_life_loss=True,
        ),
        wrappers.SinglePrecisionWrapper,
    ]
    return wrappers.wrap_all(gym_env, stack)
def make_environment(
    level: str = 'PongNoFrameskip-v4',
    oar_wrapper: bool = False,
) -> dm_env.Environment:
    """Build a wrapped Atari environment from a full Gym environment id.

    Args:
        level: Full Gym environment id passed to ``gym.make``.
        oar_wrapper: When true, ``wrappers.ObservationActionRewardWrapper`` is
            inserted just before the final ``SinglePrecisionWrapper``.

    Returns:
        Whatever ``wrappers.wrap_all`` produces for the Gym environment and
        the assembled wrapper list.
    """
    gym_env = gym.make(level, full_action_space=True)

    atari_wrapper = functools.partial(
        wrappers.AtariWrapper,
        to_float=True,
        # Episodes are always capped at 108k steps; this is the standard
        # setting, matching the paper.
        max_episode_len=108_000,
        zero_discount_on_life_loss=True,
    )

    # E.g. IMPALA and R2D2 use the observation-action-reward variant.
    optional = [wrappers.ObservationActionRewardWrapper] if oar_wrapper else []

    stack = [
        wrappers.GymAtariAdapter,
        atari_wrapper,
        *optional,
        wrappers.SinglePrecisionWrapper,
    ]
    return wrappers.wrap_all(gym_env, stack)