Example #1
def load(domain_name,
         task_name,
         task_kwargs=None,
         visualize_reward=False,
         render_kwargs=None,
         env_wrappers=()):
  """Returns an environment from a domain name, task name and optional settings.

  Args:
    domain_name: A string containing the name of a domain.
    task_name: A string containing the name of a task.
    task_kwargs: Optional `dict` of keyword arguments for the task.
    visualize_reward: Optional `bool`. If `True`, object colours in rendered
      frames are set to indicate the reward at each step. Default `False`.
    render_kwargs: Optional `dict` of keyword arguments for rendering.
    env_wrappers: Iterable with references to wrapper classes to use on the
      wrapped environment.

  Returns:
    The requested environment.

  Raises:
    ImportError: if dm_control module is not available.
  """
  if not is_available():
    raise ImportError("dm_control module is not available.")
  # Pass keyword arguments explicitly so `visualize_reward` cannot be misread
  # as `environment_kwargs` by newer dm_control `suite.load` signatures.
  dm_env = suite.load(domain_name, task_name, task_kwargs=task_kwargs,
                      visualize_reward=visualize_reward)
  env = dm_control_wrapper.DmControlWrapper(dm_env, render_kwargs)

  for wrapper in env_wrappers:
    env = wrapper(env)

  return env
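
A minimal usage sketch for this loader (an addition, not part of the original listing): the task names, seed, and wrapper below are illustrative assumptions.

# Hypothetical usage, assuming tf_agents and dm_control are installed and the
# load() above is in scope; 'cartpole'/'swingup' are just example names.
from tf_agents.environments import wrappers

env = load('cartpole', 'swingup',
           task_kwargs={'random': 0},                      # seed the dm_control task
           render_kwargs={'width': 84, 'height': 84},
           env_wrappers=(wrappers.FlattenObservationsWrapper,))
time_step = env.reset()
print(env.action_spec(), time_step.observation)
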
Example #2
def load_pixels(
    domain_name: Text,
    task_name: Text,
    observation_key: Text = 'pixels',
    pixels_only: bool = True,
    task_kwargs=None,
    environment_kwargs=None,
    visualize_reward: bool = False,
    render_kwargs=None,
    env_wrappers: Sequence[types.PyEnvWrapper] = ()
) -> py_environment.PyEnvironment:
    """Returns an environment from a domain name, task name and optional settings.

    Args:
      domain_name: A string containing the name of a domain.
      task_name: A string containing the name of a task.
      observation_key: Optional custom string specifying the pixel observation's
        key in the `OrderedDict` of observations. Defaults to 'pixels'.
      pixels_only: If True (default), the original set of 'state' observations
        returned by the wrapped environment will be discarded, and the
        `OrderedDict` of observations will only contain pixels. If False, the
        `OrderedDict` will contain the original observations as well as the pixel
        observations.
      task_kwargs: Optional `dict` of keyword arguments for the task.
      environment_kwargs: Optional `dict` specifying keyword arguments for the
        environment.
      visualize_reward: Optional `bool`. If `True`, object colours in rendered
        frames are set to indicate the reward at each step. Default `False`.
      render_kwargs: Optional `dict` of keyword arguments for rendering.
      env_wrappers: Iterable with references to wrapper classes to use on the
        wrapped environment.

    Returns:
      The requested environment.

    Raises:
      ImportError: if dm_control module is not available.
    """
    dm_env = _load_env(domain_name,
                       task_name,
                       task_kwargs=task_kwargs,
                       environment_kwargs=environment_kwargs,
                       visualize_reward=visualize_reward)

    dm_env = pixels.Wrapper(dm_env,
                            pixels_only=pixels_only,
                            render_kwargs=render_kwargs,
                            observation_key=observation_key)
    env = dm_control_wrapper.DmControlWrapper(dm_env, render_kwargs)

    for wrapper in env_wrappers:
        env = wrapper(env)

    return env
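
A short call sketch (added for illustration, not from the original source); the task and render size are assumptions.

# Hypothetical usage of load_pixels; with the default observation_key the
# observations come back as an OrderedDict keyed by 'pixels'.
env = load_pixels('cheetah', 'run',
                  pixels_only=True,
                  render_kwargs={'width': 84, 'height': 84, 'camera_id': 0})
time_step = env.reset()
print(time_step.observation['pixels'].shape)  # e.g. (84, 84, 3)
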
Example #3
def load(
        domain_name,
        task_name,
        task_kwargs=None,
        environment_kwargs=None,
        env_load_fn=suite.load,  # use custom_suite.load for customized env
        action_repeat_wrapper=wrappers.ActionRepeat,
        action_repeat=1,
        frame_stack=4,
        episode_length=1000,
        actions_in_obs=True,
        rewards_in_obs=False,
        pixels_obs=True,
        # Render params
        grayscale=False,
        visualize_reward=False,
        render_kwargs=None):
    """Returns an environment from a domain name, task name."""
    env = env_load_fn(domain_name,
                      task_name,
                      task_kwargs=task_kwargs,
                      environment_kwargs=environment_kwargs,
                      visualize_reward=visualize_reward)
    if pixels_obs:
        env = pixel_wrapper.Wrapper(env,
                                    pixels_only=False,
                                    render_kwargs=render_kwargs)

    env = dm_control_wrapper.DmControlWrapper(env, render_kwargs)

    if pixels_obs and grayscale:
        env = GrayscaleWrapper(env)
    if action_repeat > 1:
        env = action_repeat_wrapper(env, action_repeat)
    if pixels_obs:
        env = FrameStack(env, frame_stack, actions_in_obs, rewards_in_obs)
    else:
        env = FlattenState(env)

    # Adjust episode length based on action_repeat
    max_episode_steps = (episode_length + action_repeat - 1) // action_repeat

    # Apply a time limit wrapper at the end to properly trigger all reset() calls
    env = wrappers.TimeLimit(env, max_episode_steps)
    return env
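
A sketch of a typical pixel-based call to this loader; the task name and hyperparameter values are assumptions, not defaults prescribed by the example.

# Hypothetical usage; relies on the default env_load_fn=suite.load.
env = load('walker', 'walk',
           action_repeat=2,
           frame_stack=3,
           pixels_obs=True,
           grayscale=True,
           render_kwargs={'width': 84, 'height': 84, 'camera_id': 0})
# With episode_length=1000 and action_repeat=2, the TimeLimit allows
# (1000 + 2 - 1) // 2 == 500 agent steps per episode.
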
Example #4
def load(
    domain_name: Text,
    task_name: Text,
    task_kwargs=None,
    environment_kwargs=None,
    visualize_reward: bool = False,
    render_kwargs=None,
    env_wrappers: Sequence[types.PyEnvWrapper] = ()
) -> py_environment.PyEnvironment:
    """Returns an environment from a domain name, task name and optional settings.

    Args:
      domain_name: A string containing the name of a domain.
      task_name: A string containing the name of a task.
      task_kwargs: Optional `dict` of keyword arguments for the task.
      environment_kwargs: Optional `dict` specifying keyword arguments for the
        environment.
      visualize_reward: Optional `bool`. If `True`, object colours in rendered
        frames are set to indicate the reward at each step. Default `False`.
      render_kwargs: Optional `dict` of keyword arguments for rendering.
      env_wrappers: Iterable with references to wrapper classes to use on the
        wrapped environment.

    Returns:
      The requested environment.

    Raises:
      ImportError: if dm_control module is not available.
    """
    dmc_env = _load_env(domain_name,
                        task_name,
                        task_kwargs=task_kwargs,
                        environment_kwargs=environment_kwargs,
                        visualize_reward=visualize_reward)

    env = dm_control_wrapper.DmControlWrapper(dmc_env, render_kwargs)

    for wrapper in env_wrappers:
        env = wrapper(env)

    return env
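
A minimal call sketch; the task and the flattening wrapper are assumptions used only to show how env_wrappers is applied.

# Hypothetical usage; env_wrappers takes wrapper *classes*, applied in order.
from tf_agents.environments import wrappers

env = load('reacher', 'easy',
           env_wrappers=(wrappers.FlattenObservationsWrapper,))
print(env.observation_spec())
print(env.action_spec())
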
Example #5
def _load_dm_env(domain_name,
                 task_name,
                 pixels,
                 action_repeat,
                 max_episode_steps=None,
                 obs_type='pixels',
                 distractor=False):
    """Load a Deepmind control suite environment."""
    try:
        if not pixels:
            env = suite_dm_control.load(domain_name=domain_name,
                                        task_name=task_name)
            if action_repeat > 1:
                env = wrappers.ActionRepeat(env, action_repeat)

        else:

            def wrap_repeat(env):
                return ActionRepeatDMWrapper(env, action_repeat)

            camera_id = 2 if domain_name == 'quadruped' else 0

            pixels_only = obs_type == 'pixels'
            if distractor:
                render_kwargs = dict(width=84, height=84, camera_id=camera_id)

                env = distractor_suite.load(
                    domain_name,
                    task_name,
                    difficulty='hard',
                    dynamic=False,
                    background_dataset_path='DAVIS/JPEGImages/480p/',
                    task_kwargs={},
                    environment_kwargs={},
                    render_kwargs=render_kwargs,
                    visualize_reward=False,
                    env_state_wrappers=[wrap_repeat])

                # env = wrap_repeat(env)

                # env = suite.wrappers.pixels.Wrapper(
                #     env,
                #     pixels_only=pixels_only,
                #     render_kwargs=render_kwargs,
                #     observation_key=obs_type)

                env = dm_control_wrapper.DmControlWrapper(env, render_kwargs)

            else:
                env = suite_dm_control.load_pixels(
                    domain_name=domain_name,
                    task_name=task_name,
                    render_kwargs=dict(width=84,
                                       height=84,
                                       camera_id=camera_id),
                    env_state_wrappers=[wrap_repeat],
                    observation_key=obs_type,
                    pixels_only=pixels_only)

        if action_repeat > 1 and max_episode_steps is not None:
            # Shorten episode length.
            max_episode_steps = (max_episode_steps + action_repeat -
                                 1) // action_repeat
            env = wrappers.TimeLimit(env, max_episode_steps)

        return env

    except ValueError as e:
        logging.warning(
            'cannot instantiate dm env: domain_name=%s, task_name=%s',
            domain_name, task_name)
        logging.warning('Supported domains and tasks: %s',
                        str({
                            key: list(val.SUITE.keys())
                            for key, val in suite._DOMAINS.items()
                        }))  # pylint: disable=protected-access
        raise e
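
A hypothetical call that mirrors how load_env in the next example drives this helper; 'quadruped' is chosen only to exercise the camera_id=2 branch.

# Hypothetical call; for the 'quadruped' domain the pixels are rendered
# from camera_id=2, all other domains use camera_id=0.
env = _load_dm_env('quadruped', 'walk',
                   pixels=True,
                   action_repeat=2,
                   max_episode_steps=1000,
                   obs_type='pixels')
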
Example #6
def load_env(env_name,
             seed,
             action_repeat=0,
             frame_stack=1,
             obs_type='pixels'):
    """Loads a learning environment.

    Args:
      env_name: Name of the environment.
      seed: Random seed.
      action_repeat: (optional) action repeat multiplier. Useful for DM control
        suite tasks.
      frame_stack: (optional) number of frames to stack.
      obs_type: `pixels` or `state`.

    Returns:
      A `(env, state_env)` tuple; `state_env` is `None` unless a pixel DMC task
      is loaded with a non-pixel `obs_type`.
    """

    action_repeat_applied = False
    state_env = None

    if env_name.startswith('dm'):
        _, domain_name, task_name = env_name.split('-')
        if 'manipulation' in domain_name:
            env = manipulation.load(task_name)
            env = dm_control_wrapper.DmControlWrapper(env)
        else:
            env = _load_dm_env(domain_name,
                               task_name,
                               pixels=False,
                               action_repeat=action_repeat)
            action_repeat_applied = True
        env = wrappers.FlattenObservationsWrapper(env)

    elif env_name.startswith('pixels-dm'):
        if 'distractor' in env_name:
            _, _, domain_name, task_name, _ = env_name.split('-')
            distractor = True
        else:
            _, _, domain_name, task_name = env_name.split('-')
            distractor = False
        # TODO(tompson): Are there DMC environments that have other
        # max_episode_steps?
        env = _load_dm_env(domain_name,
                           task_name,
                           pixels=True,
                           action_repeat=action_repeat,
                           max_episode_steps=1000,
                           obs_type=obs_type,
                           distractor=distractor)
        action_repeat_applied = True
        if obs_type == 'pixels':
            env = FlattenImageObservationsWrapper(env)
            state_env = None
        else:
            env = JointImageObservationsWrapper(env)
            state_env = tf_py_environment.TFPyEnvironment(
                wrappers.FlattenObservationsWrapper(
                    _load_dm_env(domain_name,
                                 task_name,
                                 pixels=False,
                                 action_repeat=action_repeat)))

    else:
        env = suite_mujoco.load(env_name)
        env.seed(seed)

    if action_repeat > 1 and not action_repeat_applied:
        env = wrappers.ActionRepeat(env, action_repeat)
    if frame_stack > 1:
        env = FrameStackWrapperTfAgents(env, frame_stack)

    env = tf_py_environment.TFPyEnvironment(env)

    return env, state_env
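
A sketch of the env_name formats this loader understands; every name below is an illustrative assumption.

# Hypothetical usage; each call returns a (TFPyEnvironment, state_env) pair,
# where state_env is None except for state-augmented pixel DMC tasks.
env, state_env = load_env('pixels-dm-cartpole-swingup', seed=0,
                          action_repeat=2, frame_stack=3)   # pixel-based DMC task
env, state_env = load_env('dm-walker-walk', seed=0)         # state-based DMC task
env, state_env = load_env('HalfCheetah-v2', seed=0)         # falls through to suite_mujoco
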