def load(env_name):
  """Creates the training and evaluation environments.

  The underlying environment's observation is the current state concatenated
  with the goal state, so the state dimension is taken to be half of the
  observation size. The gin parameters `obs_to_goal.start_index` (default 0)
  and `obs_to_goal.end_index` (default: the full goal) optionally select a
  slice of the goal. The returned environments observe the full state followed
  by only that slice of the goal.

  Args:
    env_name: (str) Name of the environment. One of 'sawyer_reach',
      'sawyer_push', 'sawyer_drawer', 'sawyer_window' or 'sawyer_faucet'.

  Returns:
    A tuple (tf_env, eval_tf_env, obs_dim): the training environment, the
    evaluation environment, and the dimension of the state portion of the
    observation.

  Raises:
    NotImplementedError: If `env_name` is not a supported environment.
  """
  # Built at call time so the loader functions may be defined anywhere in the
  # module.
  loaders = {
      'sawyer_reach': load_sawyer_reach,
      'sawyer_push': load_sawyer_push,
      'sawyer_drawer': load_sawyer_drawer,
      'sawyer_window': load_sawyer_window,
      'sawyer_faucet': load_sawyer_faucet,
  }
  if env_name not in loaders:
    raise NotImplementedError('Unsupported environment: %s' % env_name)
  loader = loaders[env_name]
  tf_env = loader()
  eval_tf_env = loader()
  if env_name == 'sawyer_push':
    # The push task's evaluation copy is switched to 'eval' mode; the other
    # tasks use two identically constructed instances.
    eval_tf_env.envs[0]._env.gym.MODE = 'eval'  # pylint: disable=protected-access

  # The observation filter below wraps envs[0] directly, so each TF
  # environment must hold exactly one Python environment.
  assert len(tf_env.envs) == 1
  assert len(eval_tf_env.envs) == 1

  # By default, the environment observation contains the current state and the
  # goal state. By setting the obs_to_goal parameters, the user can specify
  # that the agent should only look at certain dimensions of the goal state.
  # The observation is rebuilt to contain the full state but only the
  # user-specified dimensions of the goal state.
  obs_dim = tf_env.observation_spec().shape[0] // 2
  start_index = _query_gin_param('obs_to_goal.start_index', 0)
  end_index = _query_gin_param('obs_to_goal.end_index', None)
  if end_index is None:
    end_index = obs_dim
  # Every state dimension, then goal dimensions [start_index, end_index).
  indices = np.concatenate([
      np.arange(obs_dim),
      np.arange(obs_dim + start_index, obs_dim + end_index),
  ])
  tf_env = tf_py_environment.TFPyEnvironment(
      wrappers.ObservationFilterWrapper(tf_env.envs[0], indices))
  eval_tf_env = tf_py_environment.TFPyEnvironment(
      wrappers.ObservationFilterWrapper(eval_tf_env.envs[0], indices))
  return (tf_env, eval_tf_env, obs_dim)


def _query_gin_param(name, default):
  """Returns the value gin has bound to `name`, or `default` if it is unbound."""
  try:
    return gin.query_parameter(name)
  except ValueError:
    # gin.query_parameter signals a missing binding with ValueError.
    return default
def test_obs_filtered_reset(self):
  """Reset observations are filtered down to the requested indices."""
  mock_env = self._get_mock_env_step()
  env = wrappers.ObservationFilterWrapper(mock_env, [0])
  time_step = env.reset()
  self.assertLen(time_step.observation, 1)
  # Compare element-wise, matching test_obs_filtered_step. Plain assertEqual
  # relies on `list == array` collapsing to a single truth value, which only
  # works by accident for one-element arrays.
  self.assertAllEqual([3], time_step.observation)
def test_checks_nested_obs(self):
  """Wrapping an environment with a nested observation spec raises."""
  mock_env = self._get_mock_env_step()
  nested_spec = [
      array_spec.BoundedArraySpec((2,), np.int32, -10, 10) for _ in range(2)
  ]
  # The first observation_spec() call on the mock returns the nested list.
  mock_env.observation_spec.side_effect = [nested_spec]
  self.assertRaises(
      ValueError, wrappers.ObservationFilterWrapper, mock_env, [0])
def test_obs_filtered_step(self):
  """Step observations keep only the requested indices."""
  filtered = wrappers.ObservationFilterWrapper(
      self._get_mock_env_step(), [0, 2])
  filtered.reset()
  observation = filtered.step(0).observation
  self.assertLen(observation, 2)
  self.assertAllEqual([1, 3], observation)
def test_checks_idx_outofbounds(self):
  """An out-of-bounds index (5) is rejected when the wrapper is built."""
  mock_env = self._get_mock_env_step()
  self.assertRaises(ValueError, wrappers.ObservationFilterWrapper, mock_env,
                    [5])
def test_checks_multidim_idx(self):
  """A nested (multi-dimensional) index list is rejected at construction."""
  mock_env = self._get_mock_env_step()
  nested_indices = [[0]]
  self.assertRaises(ValueError, wrappers.ObservationFilterWrapper, mock_env,
                    nested_indices)
def test_filtered_obs_spec(self):
  """Filtering with a single index yields a spec of shape (1,)."""
  env = wrappers.ObservationFilterWrapper(self._get_mock_env_step(), [1])
  filtered_spec = env.observation_spec()
  self.assertEqual((1,), filtered_spec.shape)