def environment(game):
    """Create a sticky-action Atari environment with 4-frame stacking.

    Args:
        game: Game name forwarded to `atari_lib.create_atari_environment`.

    Returns:
        The environment wrapped, innermost first, by AtariDopamineWrapper,
        FrameStackingWrapper (4 frames) and SinglePrecisionWrapper.
    """
    raw_env = atari_lib.create_atari_environment(
        game_name=game, sticky_actions=True)
    stacked_env = wrappers.FrameStackingWrapper(
        AtariDopamineWrapper(raw_env), num_frames=4)
    return wrappers.SinglePrecisionWrapper(stacked_env)
Exemplo n.º 2
0
  def test_second_reset(self):
    """A reset after stepping must zero slot 0 of the frame stack again.

    The wrapped fake environment is presumably chosen because it emits
    non-zero observations, so frames from the previous episode must not
    survive into the observation returned by the second reset.
    """
    stacked_env = wrappers.FrameStackingWrapper(
        FakeNonZeroObservationEnvironment(), 2)
    spec = stacked_env.action_spec()

    # Run one step of a first episode, then start a fresh one.
    stacked_env.reset()
    stacked_env.step(spec.generate_value())
    first_slot = stacked_env.reset().observation[..., 0]

    self.assertTrue(np.all(first_slot == 0))
Exemplo n.º 3
0
def make_environment(task,
                     end_on_success,
                     max_episode_steps,
                     distance_fn,
                     goal_image,
                     baseline_distance=None,
                     eval_mode=False,
                     logdir=None,
                     counter=None,
                     record_every=100,
                     num_episodes_to_record=3):
    """Build a gym task wrapped for goal-conditioned, distance-based training.

    Wrappers are applied innermost first: GymWrapper, [EndOnSuccessWrapper],
    StepLimitWrapper, ReshapeImageWrapper, [FrameStackingWrapper],
    GoalConditionedWrapper, DistanceModelWrapper, [RewardWrapper],
    [RecordEpisodesWrapper], VisibleStateWrapper, SinglePrecisionWrapper.
    Bracketed wrappers are conditional.

    Reads the global flags `distance_reward_weight`,
    `environment_reward_weight` and `use_true_distance`.

    Args:
        task: Gym environment id passed to `gym.make`.
        end_on_success: If truthy, add EndOnSuccessWrapper.
        max_episode_steps: Step limit, also forwarded to DistanceModelWrapper.
        distance_fn: Distance model. Its `history_length` (when > 1) sets the
            number of stacked frames.
        goal_image: Goal passed to GoalConditionedWrapper.
        baseline_distance: Forwarded to DistanceModelWrapper.
        eval_mode: Forwarded to RecordEpisodesWrapper and VisibleStateWrapper.
        logdir: If truthy, episodes are recorded here via
            RecordEpisodesWrapper.
        counter: Forwarded to RecordEpisodesWrapper.
        record_every: Forwarded to RecordEpisodesWrapper.
        num_episodes_to_record: Forwarded as `num_to_record` to
            RecordEpisodesWrapper.

    Returns:
        The fully wrapped environment.
    """
    wrapped = gym_wrapper.GymWrapper(gym.make(task))
    if end_on_success:
        wrapped = env_wrappers.EndOnSuccessWrapper(wrapped)
    wrapped = env_wrappers.ReshapeImageWrapper(
        wrappers.StepLimitWrapper(wrapped, max_episode_steps))

    # Stack frames only when the distance model consumes a history.
    if distance_fn.history_length > 1:
        wrapped = wrappers.FrameStackingWrapper(
            wrapped, distance_fn.history_length)

    wrapped = env_wrappers.GoalConditionedWrapper(wrapped, goal_image)
    wrapped = env_wrappers.DistanceModelWrapper(
        wrapped, distance_fn, max_episode_steps, baseline_distance,
        distance_reward_weight=FLAGS.distance_reward_weight,
        environment_reward_weight=FLAGS.environment_reward_weight)
    if FLAGS.use_true_distance:
        wrapped = env_wrappers.RewardWrapper(wrapped)

    if logdir:
        wrapped = env_wrappers.RecordEpisodesWrapper(
            wrapped, counter, logdir,
            record_every=record_every,
            num_to_record=num_episodes_to_record,
            eval_mode=eval_mode)

    wrapped = env_wrappers.VisibleStateWrapper(wrapped, eval_mode)
    return single_precision.SinglePrecisionWrapper(wrapped)
Exemplo n.º 4
0
  def test_specs(self):
    """FrameStackingWrapper changes only the observation spec's shape.

    The observation shape gains a trailing axis whose size is the number of
    stacked frames. The action, reward and discount specs pass through
    unchanged.
    """
    num_frames = 2
    inner_env = FakeNonZeroObservationEnvironment()
    stacked_env = wrappers.FrameStackingWrapper(inner_env, num_frames)

    self.assertEqual(inner_env.observation_spec().shape + (num_frames,),
                     stacked_env.observation_spec().shape)

    for spec_name in ('action_spec', 'reward_spec', 'discount_spec'):
      self.assertEqual(getattr(inner_env, spec_name)(),
                       getattr(stacked_env, spec_name)())
Exemplo n.º 5
0
def make_environment(
    evaluation: bool = False,
    domain_name: str = 'cartpole',
    task_name: str = 'balance',
    from_pixels: bool = False,
    frames_to_stack: int = 3,
    flatten_stack: bool = False,
    num_action_repeats: Optional[int] = None,
) -> dm_env.Environment:
    """Build a DeepMind Control Suite environment with Acme wrappers.

    Args:
        evaluation: Currently unused. It is a hook the distributed evaluator
            sets to True, e.g. to add video recording later.
        domain_name: Control-suite domain passed to `suite.load`.
        task_name: Task within `domain_name` passed to `suite.load`.
        from_pixels: If True, observe pixels (MujocoPixelWrapper) and stack
            them with FrameStackingWrapper.
        frames_to_stack: `num_frames` for FrameStackingWrapper; only used when
            `from_pixels` is True.
        flatten_stack: `flatten` for FrameStackingWrapper; only used when
            `from_pixels` is True.
        num_action_repeats: If truthy, add ActionRepeatWrapper with this many
            repeats.

    Returns:
        The environment wrapped by CanonicalSpecWrapper(clip=True), the
        optional ActionRepeatWrapper, and finally SinglePrecisionWrapper.
    """
    # Imported lazily so a MuJoCo license is only required when this factory
    # is actually used.
    from dm_control import suite  # pylint: disable=g-import-not-at-top
    from acme.wrappers import mujoco as mujoco_wrappers  # pylint: disable=g-import-not-at-top

    # Placeholder for evaluator-only setup (e.g. video recording).
    del evaluation

    env = suite.load(domain_name, task_name)
    if from_pixels:
        env = wrappers.FrameStackingWrapper(
            mujoco_wrappers.MujocoPixelWrapper(env),
            num_frames=frames_to_stack,
            flatten=flatten_stack)
    env = wrappers.CanonicalSpecWrapper(env, clip=True)

    if num_action_repeats:
        env = wrappers.ActionRepeatWrapper(env, num_repeats=num_action_repeats)
    return wrappers.SinglePrecisionWrapper(env)