Beispiel #1
0
def _query_gin_or_default(name, default):
    """Returns the gin-bound value of parameter `name`, or `default`.

    `gin.query_parameter` raises ValueError when the parameter cannot be
    queried (e.g. it was never configured); in that case `default` is used.
    """
    try:
        return gin.query_parameter(name)
    except ValueError:
        return default


def load(env_name):
    """Creates the training and evaluation environments.

    This method automatically detects whether we are using a subset of the
    observation for the goal and modifies the observation space to include the
    full state + partial goal.

    Args:
      env_name: (str) Name of the environment. One of 'sawyer_reach',
        'sawyer_push', 'sawyer_drawer', 'sawyer_window' or 'sawyer_faucet'.

    Returns:
      A tuple (tf_env, eval_tf_env, obs_dim): the training and evaluation
      environments, each wrapped so that its observation is the full state
      followed by the selected goal dimensions, and obs_dim, which is half the
      length of the unfiltered observation (the size of the state part).

    Raises:
      NotImplementedError: if env_name is not a supported environment.
      ValueError: if a loaded environment does not contain exactly one
        underlying environment.
    """
    if env_name == 'sawyer_reach':
        tf_env = load_sawyer_reach()
        eval_tf_env = load_sawyer_reach()
    elif env_name == 'sawyer_push':
        tf_env = load_sawyer_push()
        eval_tf_env = load_sawyer_push()
        # Only the evaluation copy is switched to MODE 'eval'.
        eval_tf_env.envs[0]._env.gym.MODE = 'eval'  # pylint: disable=protected-access
    elif env_name == 'sawyer_drawer':
        tf_env = load_sawyer_drawer()
        eval_tf_env = load_sawyer_drawer()
    elif env_name == 'sawyer_window':
        tf_env = load_sawyer_window()
        eval_tf_env = load_sawyer_window()
    elif env_name == 'sawyer_faucet':
        tf_env = load_sawyer_faucet()
        eval_tf_env = load_sawyer_faucet()
    else:
        raise NotImplementedError('Unsupported environment: %s' % env_name)

    # The filtering below rewraps only envs[0], so any additional environment
    # would be silently dropped. Check explicitly instead of with `assert`,
    # which is removed when Python runs with -O.
    for env in (tf_env, eval_tf_env):
        if len(env.envs) != 1:
            raise ValueError(
                'Expected exactly 1 environment, got %d' % len(env.envs))

    # By default, the environment observation contains the current state and goal
    # state. By setting the obs_to_goal parameters, the user can specify that the
    # agent should only look at certain subsets of the goal state. The following
    # code modifies the environment observation to include the full state but only
    # the user-specified dimensions of the goal state.
    obs_dim = tf_env.observation_spec().shape[0] // 2
    start_index = _query_gin_or_default('obs_to_goal.start_index', 0)
    end_index = _query_gin_or_default('obs_to_goal.end_index', None)
    if end_index is None:
        # An unset (or explicitly None) end index keeps the goal through its end.
        end_index = obs_dim

    # Keep every state dimension, then the [start_index, end_index) slice of the
    # goal, which begins at offset obs_dim in the original observation.
    indices = np.concatenate([
        np.arange(obs_dim),
        np.arange(obs_dim + start_index, obs_dim + end_index)
    ])
    tf_env = tf_py_environment.TFPyEnvironment(
        wrappers.ObservationFilterWrapper(tf_env.envs[0], indices))
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        wrappers.ObservationFilterWrapper(eval_tf_env.envs[0], indices))
    return (tf_env, eval_tf_env, obs_dim)
Beispiel #2
0
  def test_obs_filtered_reset(self):
    """reset() returns only the observation entry at the selected index."""
    mock_env = self._get_mock_env_step()
    env = wrappers.ObservationFilterWrapper(mock_env, [0])
    time_step = env.reset()

    self.assertLen(time_step.observation, 1)
    # assertAllEqual compares element-wise; plain assertEqual on an
    # array-valued observation relies on the truthiness of an element-wise
    # comparison, which is ambiguous for more than one element.
    self.assertAllEqual([3], time_step.observation)
Beispiel #3
0
 def test_checks_nested_obs(self):
   """Wrapping an env whose observation spec is a list of specs raises."""
   base_env = self._get_mock_env_step()
   nested_spec = [
       array_spec.BoundedArraySpec((2,), np.int32, -10, 10),
       array_spec.BoundedArraySpec((2,), np.int32, -10, 10),
   ]
   # side_effect is a one-element list, so the first observation_spec() call
   # returns the nested spec.
   base_env.observation_spec.side_effect = [nested_spec]
   self.assertRaises(
       ValueError, wrappers.ObservationFilterWrapper, base_env, [0])
Beispiel #4
0
  def test_obs_filtered_step(self):
    """step() returns only the observation entries at the selected indices."""
    base_env = self._get_mock_env_step()
    filtered_env = wrappers.ObservationFilterWrapper(base_env, [0, 2])
    filtered_env.reset()
    observation = filtered_env.step(0).observation

    self.assertLen(observation, 2)
    self.assertAllEqual([1, 3], observation)
Beispiel #5
0
 def test_checks_idx_outofbounds(self):
   """An out-of-bounds filter index (5) is rejected with ValueError."""
   base_env = self._get_mock_env_step()
   self.assertRaises(
       ValueError, wrappers.ObservationFilterWrapper, base_env, [5])
Beispiel #6
0
 def test_checks_multidim_idx(self):
   """A nested (two-level) index list is rejected with ValueError."""
   base_env = self._get_mock_env_step()
   nested_indices = [[0]]
   self.assertRaises(
       ValueError, wrappers.ObservationFilterWrapper, base_env, nested_indices)
Beispiel #7
0
  def test_filtered_obs_spec(self):
    """Filtering to a single index yields an observation spec of shape (1,)."""
    base_env = self._get_mock_env_step()
    filtered_spec = wrappers.ObservationFilterWrapper(
        base_env, [1]).observation_spec()

    self.assertEqual((1,), filtered_spec.shape)