def environment(game):
  """Build a preprocessed Atari environment for `game`.

  The raw Dopamine Atari environment (sticky actions enabled) is adapted
  to the dm_env interface, given a 4-frame observation stack, and cast
  down to single precision.
  """
  base = atari_lib.create_atari_environment(
      game_name=game, sticky_actions=True)
  adapted = AtariDopamineWrapper(base)
  stacked = wrappers.FrameStackingWrapper(adapted, num_frames=4)
  return wrappers.SinglePrecisionWrapper(stacked)
def test_second_reset(self):
  """A reset after stepping must clear stale frames from the stack."""
  base_env = FakeNonZeroObservationEnvironment()
  stacked_env = wrappers.FrameStackingWrapper(base_env, 2)

  # Fill the stack with non-zero observations, then reset.
  dummy_action = stacked_env.action_spec().generate_value()
  stacked_env.reset()
  stacked_env.step(dummy_action)
  timestep = stacked_env.reset()

  # The oldest slot must contain only zero-padding, not a leftover frame.
  self.assertTrue(np.all(timestep.observation[..., 0] == 0))
def make_environment(task,
                     end_on_success,
                     max_episode_steps,
                     distance_fn,
                     goal_image,
                     baseline_distance=None,
                     eval_mode=False,
                     logdir=None,
                     counter=None,
                     record_every=100,
                     num_episodes_to_record=3):
  """Create the environment and its wrappers.

  Builds a Gym environment for `task` and wraps it, in order, with:
  dm_env adaptation, optional early termination on success, a step
  limit, image reshaping, optional frame stacking (only when the
  distance function consumes history), goal conditioning on
  `goal_image`, a learned-distance reward model, an optional true-
  distance reward override, optional episode recording, visible-state
  exposure, and finally a single-precision cast.

  Args:
    task: Gym task id passed to `gym.make`.
    end_on_success: If True, episodes terminate as soon as the task
      reports success.
    max_episode_steps: Hard step limit per episode.
    distance_fn: Learned distance model; `distance_fn.history_length`
      controls whether frames are stacked.
    goal_image: Goal observation used for goal conditioning.
    baseline_distance: Optional baseline subtracted by the distance
      wrapper. TODO confirm exact semantics in DistanceModelWrapper.
    eval_mode: Evaluation flag forwarded to recording/state wrappers.
    logdir: If set, episodes are recorded under this directory.
    counter: Shared episode counter for the recording wrapper.
    record_every: Record one batch of episodes every this many episodes.
    num_episodes_to_record: Episodes recorded per batch.

  Returns:
    The fully wrapped, single-precision dm_env environment.
  """
  env = gym.make(task)
  env = gym_wrapper.GymWrapper(env)
  if end_on_success:
    env = env_wrappers.EndOnSuccessWrapper(env)
  env = wrappers.StepLimitWrapper(env, max_episode_steps)
  env = env_wrappers.ReshapeImageWrapper(env)
  # Stack frames only when the distance model actually looks at history.
  if distance_fn.history_length > 1:
    env = wrappers.FrameStackingWrapper(env, distance_fn.history_length)
  env = env_wrappers.GoalConditionedWrapper(env, goal_image)
  # Reward is a weighted mix of learned distance and environment reward;
  # the weights come from module-level absl FLAGS.
  env = env_wrappers.DistanceModelWrapper(
      env,
      distance_fn,
      max_episode_steps,
      baseline_distance,
      distance_reward_weight=FLAGS.distance_reward_weight,
      environment_reward_weight=FLAGS.environment_reward_weight)
  if FLAGS.use_true_distance:
    env = env_wrappers.RewardWrapper(env)
  if logdir:
    env = env_wrappers.RecordEpisodesWrapper(
        env,
        counter,
        logdir,
        record_every=record_every,
        num_to_record=num_episodes_to_record,
        eval_mode=eval_mode)
  env = env_wrappers.VisibleStateWrapper(env, eval_mode)
  return single_precision.SinglePrecisionWrapper(env)
def test_specs(self):
  """Stacking reshapes the observation spec and passes the rest through."""
  base_env = FakeNonZeroObservationEnvironment()
  wrapped_env = wrappers.FrameStackingWrapper(base_env, 2)

  # Stacking appends a trailing axis of size num_frames (here 2).
  expected_shape = base_env.observation_spec().shape + (2,)
  self.assertEqual(expected_shape, wrapped_env.observation_spec().shape)

  # Action, reward and discount specs are forwarded unchanged.
  for spec_name in ('action_spec', 'reward_spec', 'discount_spec'):
    self.assertEqual(
        getattr(base_env, spec_name)(), getattr(wrapped_env, spec_name)())
def make_environment( evaluation: bool = False, domain_name: str = 'cartpole', task_name: str = 'balance', from_pixels: bool = False, frames_to_stack: int = 3, flatten_stack: bool = False, num_action_repeats: Optional[int] = None, ) -> dm_env.Environment: """Implements a control suite environment factory.""" # Load dm_suite lazily not require Mujoco license when not using it. from dm_control import suite # pylint: disable=g-import-not-at-top from acme.wrappers import mujoco as mujoco_wrappers # pylint: disable=g-import-not-at-top # Load raw control suite environment. environment = suite.load(domain_name, task_name) # Maybe wrap to get pixel observations from environment state. if from_pixels: environment = mujoco_wrappers.MujocoPixelWrapper(environment) environment = wrappers.FrameStackingWrapper(environment, num_frames=frames_to_stack, flatten=flatten_stack) environment = wrappers.CanonicalSpecWrapper(environment, clip=True) if num_action_repeats: environment = wrappers.ActionRepeatWrapper( environment, num_repeats=num_action_repeats) environment = wrappers.SinglePrecisionWrapper(environment) if evaluation: # The evaluator in the distributed agent will set this to True so you can # use this clause to, e.g., set up video recording by the evaluator. pass return environment