Exemplo n.º 1
0
    def test_task_runs(self, task_name):
        """Tests that the environment runs and is coherent with its specs."""
        seed = 99 if _get_fix_seed() else None
        env = manipulation.load(task_name, seed=seed)
        random_state = np.random.RandomState(seed)

        observation_spec = env.observation_spec()
        action_spec = env.action_spec()
        # Bounds must be finite so we can sample actions uniformly below.
        self.assertTrue(np.all(np.isfinite(action_spec.minimum)))
        self.assertTrue(np.all(np.isfinite(action_spec.maximum)))

        # Run a partial episode, check observations, rewards, discount.
        for _ in range(_NUM_EPISODES):
            time_step = env.reset()
            for _ in range(_NUM_STEPS_PER_EPISODE):
                self._validate_observation(time_step.observation,
                                           observation_spec)
                if time_step.first():
                    # dm_env contract: the reset step has no reward/discount.
                    self.assertIsNone(time_step.reward)
                    self.assertIsNone(time_step.discount)
                else:
                    self._validate_reward_range(time_step.reward)
                    self._validate_discount(time_step.discount)
                action = random_state.uniform(action_spec.minimum,
                                              action_spec.maximum)
                # Bug fix: capture the step result. The original discarded it,
                # so `time_step` stayed frozen at the reset step and every
                # iteration re-validated the same observation via the
                # `first()` branch.
                time_step = env.step(action)
Exemplo n.º 2
0
def make(env_name, frame_stack, action_repeat, seed):
    """Build a fully wrapped environment for `env_name`.

    The env is seeded, action-repeated, action-rescaled to [-1, 1],
    observation-flattened, rendered to 84x84 pixels for non-manipulation
    domains, and finally frame-stacked.
    """
    domain, task = split_env_name(env_name)
    is_manip = domain == 'manip'

    if is_manip:
        env = manipulation.load(f'{task}_vision', seed=seed)
    else:
        env = suite.load(domain,
                         task,
                         task_kwargs={'random': seed},
                         visualize_reward=False)

    # Action repeat first, then normalize the action range to [-1, 1].
    env = ActionRepeatWrapper(env, action_repeat)
    env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0)
    # Flatten dict observations into a single feature vector.
    env = FlattenObservationWrapper(env)

    if not is_manip:
        # Camera choice follows dreamer:
        # https://github.com/danijar/dreamer/blob/02f0210f5991c7710826ca7881f19c64a012290c/wrappers.py#L26
        camera_id = 2 if domain == 'quadruped' else 0
        env = pixels.Wrapper(
            env,
            pixels_only=False,
            render_kwargs={'height': 84, 'width': 84, 'camera_id': camera_id})

    env = FrameStackWrapper(env, frame_stack)

    # Sanity check: actions really are normalized after wrapping.
    final_spec = env.action_spec()
    assert np.all(final_spec.minimum >= -1.0)
    assert np.all(final_spec.maximum <= +1.0)

    return env
Exemplo n.º 3
0
 def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
   """Load a dm_control-based environment selected by `name`.

   `name` is `<domain>_<task>`; the domain picks the loader
   (manipulation, rodent locomotion, or the standard suite).
   """
   # Force headless MuJoCo rendering via EGL before the env is created.
   os.environ['MUJOCO_GL'] = 'egl'
   domain, task = name.split('_', 1)
   if domain == 'cup':  # Only domain with multiple words.
     domain = 'ball_in_cup'
   if domain == 'manip':
     # Manipulation tasks; '_vision' selects the pixel-observation variant.
     from dm_control import manipulation
     self._env = manipulation.load(task + '_vision')
   elif domain == 'locom':
     # Rodent locomotion examples, looked up by task name on the module.
     from dm_control.locomotion.examples import basic_rodent_2020
     self._env = getattr(basic_rodent_2020, task)()
   else:
     from dm_control import suite
     self._env = suite.load(domain, task)
   self._action_repeat = action_repeat
   self._size = size
   if camera in (-1, None):
     # Per-task default camera id; anything not listed falls back to 0.
     camera = dict(
         quadruped_walk=2, quadruped_run=2, quadruped_escape=2,
         quadruped_fetch=2, locom_rodent_maze_forage=1,
         locom_rodent_two_touch=1,
     ).get(name, 0)
   self._camera = camera
   self._ignored_keys = []
   # Record observation entries with empty shape (0,); they carry no data
   # and are presumably skipped by the observation-building code elsewhere.
   for key, value in self._env.observation_spec().items():
     if value.shape == (0,):
       print(f"Ignoring empty observation key '{key}'.")
       self._ignored_keys.append(key)
Exemplo n.º 4
0
 def __init__(self, name: str, size: Tuple[int, int] = (64, 64), camera=0):
     """Create the underlying dm_control environment for `name`.

     Names containing 'vision' go to the manipulation loader; all others
     are parsed as `<domain>_<task>` and loaded from the suite.
     """
     if 'vision' not in name:
         domain, task = name.split('_', 1)
         self._env = suite.load(domain, task)
     else:
         self._env = manipulation.load(environment_name=name)
     self._size = size
     self._camera = camera
Exemplo n.º 5
0
    def __init__(self,
                 domain_name,
                 task_name,
                 task_kwargs=None,
                 visualize_reward={},  # NOTE(review): mutable default; never mutated here, but prefer None.
                 from_pixels=False,
                 height=84,
                 width=84,
                 camera_id=0,
                 frame_skip=1,
                 environment_kwargs=None,
                 channels_first=True):
        """Gym-style wrapper around a dm_control environment.

        Args:
            domain_name: dm_control domain; 'jaco' selects the manipulation
                loader, everything else goes through `suite.load`.
            task_name: Task within the domain.
            task_kwargs: Must contain a 'random' seed for determinism.
            visualize_reward: Forwarded to `suite.load`.
            from_pixels: If True, observations are rendered images.
            height: Render height for pixel observations.
            width: Render width for pixel observations.
            camera_id: Camera used for pixel observations.
            frame_skip: Number of physics steps per environment step.
            environment_kwargs: Forwarded to `suite.load`.
            channels_first: Pixel layout, CHW when True else HWC.

        Raises:
            ValueError: If `task_kwargs` does not provide a 'random' seed.
        """
        # Bug fix: the original `assert 'random' in task_kwargs` raised a
        # confusing TypeError when task_kwargs was left at its None default,
        # and asserts are stripped under `python -O`. Validate explicitly.
        if not task_kwargs or 'random' not in task_kwargs:
            raise ValueError(
                'please specify a seed, for deterministic behaviour')
        self._from_pixels = from_pixels
        self._height = height
        self._width = width
        self._camera_id = camera_id
        self._frame_skip = frame_skip
        self._channels_first = channels_first

        # create task
        if domain_name == 'jaco':
            # NOTE(review): this unpacks two values, so the project-local
            # `manipulation.load` apparently returns (env, mod_tag) —
            # confirm against its definition.
            self._env, self._mod_tag = manipulation.load(task_name)
        else:
            self._env = suite.load(domain_name=domain_name,
                                   task_name=task_name,
                                   task_kwargs=task_kwargs,
                                   visualize_reward=visualize_reward,
                                   environment_kwargs=environment_kwargs)

        # true and normalized action spaces
        self._true_action_space = _spec_to_box([self._env.action_spec()])
        self._norm_action_space = spaces.Box(
            low=-1.0,
            high=1.0,
            shape=self._true_action_space.shape,
            dtype=np.float32)

        # create observation space
        if from_pixels:
            shape = [3, height, width
                     ] if channels_first else [height, width, 3]
            self._observation_space = spaces.Box(low=0,
                                                 high=255,
                                                 shape=shape,
                                                 dtype=np.uint8)
        else:
            self._observation_space = _spec_to_box(
                self._env.observation_spec().values())

        self._state_space = _spec_to_box(self._env.observation_spec().values())

        self.current_state = None

        # set seed ('random' is guaranteed present by the check above, so
        # the fallback value 1 is never used)
        self.seed(seed=task_kwargs.get('random', 1))
Exemplo n.º 6
0
def load_env(env_name,
             seed,
             action_repeat=0,
             frame_stack=1,
             obs_type='pixels'):
    """Loads a learning environment.

    Args:
        env_name: Name of the environment.
        seed: Random seed.
        action_repeat: (optional) action repeat multiplier. Useful for DM
            control suite tasks.
        frame_stack: (optional) frame stack.
        obs_type: `pixels` or `state`.

    Returns:
        A `(env, state_env)` tuple: the TF-wrapped learning environment and,
        for pixel-DM state-observation runs, a parallel state environment
        (None otherwise).
    """

    action_repeat_applied = False
    state_env = None

    if env_name.startswith('dm'):
        _, domain_name, task_name = env_name.split('-')
        if 'manipulation' in domain_name:
            # Bug fix: forward the seed; it was silently ignored on this
            # branch (manipulation.load accepts a `seed` keyword).
            env = manipulation.load(task_name, seed=seed)
            env = dm_control_wrapper.DmControlWrapper(env)
        else:
            # NOTE(review): seed is not forwarded to _load_dm_env — confirm
            # whether it seeds internally.
            env = _load_dm_env(domain_name,
                               task_name,
                               pixels=False,
                               action_repeat=action_repeat)
            action_repeat_applied = True
        env = wrappers.FlattenObservationsWrapper(env)

    elif env_name.startswith('pixels-dm'):
        if 'distractor' in env_name:
            _, _, domain_name, task_name, _ = env_name.split('-')
            distractor = True
        else:
            _, _, domain_name, task_name = env_name.split('-')
            distractor = False
        # TODO(tompson): Are there DMC environments that have other
        # max_episode_steps?
        env = _load_dm_env(domain_name,
                           task_name,
                           pixels=True,
                           action_repeat=action_repeat,
                           max_episode_steps=1000,
                           obs_type=obs_type,
                           distractor=distractor)
        action_repeat_applied = True
        if obs_type == 'pixels':
            env = FlattenImageObservationsWrapper(env)
            state_env = None
        else:
            env = JointImageObservationsWrapper(env)
            # Separate state-only environment mirroring the pixel one.
            state_env = tf_py_environment.TFPyEnvironment(
                wrappers.FlattenObservationsWrapper(
                    _load_dm_env(domain_name,
                                 task_name,
                                 pixels=False,
                                 action_repeat=action_repeat)))

    else:
        env = suite_mujoco.load(env_name)
        env.seed(seed)

    # Avoid double-applying action repeat when a loader already did it.
    if action_repeat > 1 and not action_repeat_applied:
        env = wrappers.ActionRepeat(env, action_repeat)
    if frame_stack > 1:
        env = FrameStackWrapperTfAgents(env, frame_stack)

    env = tf_py_environment.TFPyEnvironment(env)

    return env, state_env