def test_observation_nested(self):
    """Each leaf of a dict observation is stacked over the last 3 steps.

    The histories are compared as plain lists. `assertCountEqual` on two dicts
    would only compare their keys, which leaves the stacked values unchecked.
    """
    env = test_envs.NestedCountingEnv()
    history_env = wrappers.HistoryWrapper(env, 3)

    def as_lists(observation):
        # Turn every stacked leaf into a list so the whole dict can be
        # compared with assertEqual, as the flat-observation tests do.
        return {key: value.tolist() for key, value in observation.items()}

    time_step = history_env.reset()
    self.assertEqual({
        'total_steps': [0, 0, 0],
        'current_steps': [0, 0, 0]
    }, as_lists(time_step.observation))
    time_step = history_env.step(0)
    self.assertEqual({
        'total_steps': [0, 0, 1],
        'current_steps': [0, 0, 1]
    }, as_lists(time_step.observation))
    time_step = history_env.step(0)
    self.assertEqual({
        'total_steps': [0, 1, 2],
        'current_steps': [0, 1, 2]
    }, as_lists(time_step.observation))
    time_step = history_env.step(0)
    self.assertEqual({
        'total_steps': [1, 2, 3],
        'current_steps': [1, 2, 3]
    }, as_lists(time_step.observation))
def test_observation_spec_changed(self):
    """Wrapping prepends the history length to the observation spec shape."""
    base_env = gym_wrapper.GymWrapper(gym.spec('CartPole-v1').make())
    base_shape = base_env.observation_spec().shape

    history_env = wrappers.HistoryWrapper(base_env, 3)

    expected_shape = (3,) + base_shape
    self.assertEqual(expected_shape, history_env.observation_spec().shape)
def test_observation_spec_changed_with_action(self):
    """With include_actions, the spec holds stacked observations and actions.

    Both entries get the history length (3) prepended to the base env's
    observation and action shapes.
    """
    base_env = gym_wrapper.GymWrapper(gym.spec('CartPole-v1').make())
    base_obs_shape = base_env.observation_spec().shape
    base_action_shape = base_env.action_spec().shape

    history_env = wrappers.HistoryWrapper(base_env, 3, include_actions=True)
    history_spec = history_env.observation_spec()

    self.assertEqual((3,) + base_obs_shape, history_spec['observation'].shape)
    self.assertEqual((3,) + base_action_shape, history_spec['action'].shape)
def test_observation_and_action_nested(self):
    """Nested observations and nested actions are both stacked over 3 steps.

    Each history dict is compared as lists of values. `assertCountEqual` on
    two dicts would only compare their keys, which leaves the stacked values
    unchecked.
    """
    env = test_envs.NestedCountingEnv(nested_action=True)
    history_env = wrappers.HistoryWrapper(env, 3, include_actions=True)

    def as_lists(nest):
        # Turn every stacked leaf into a list so the whole dict can be
        # compared with assertEqual, as the flat-observation tests do.
        return {key: value.tolist() for key, value in nest.items()}

    time_step = history_env.reset()
    self.assertEqual({
        'total_steps': [0, 0, 0],
        'current_steps': [0, 0, 0]
    }, as_lists(time_step.observation['observation']))
    self.assertEqual({
        'foo': [0, 0, 0],
        'bar': [0, 0, 0]
    }, as_lists(time_step.observation['action']))

    time_step = history_env.step({'foo': 5, 'bar': 5})
    self.assertEqual({
        'total_steps': [0, 0, 1],
        'current_steps': [0, 0, 1]
    }, as_lists(time_step.observation['observation']))
    self.assertEqual({
        'foo': [0, 0, 5],
        'bar': [0, 0, 5]
    }, as_lists(time_step.observation['action']))

    time_step = history_env.step({'foo': 6, 'bar': 6})
    self.assertEqual({
        'total_steps': [0, 1, 2],
        'current_steps': [0, 1, 2]
    }, as_lists(time_step.observation['observation']))
    self.assertEqual({
        'foo': [0, 5, 6],
        'bar': [0, 5, 6]
    }, as_lists(time_step.observation['action']))

    time_step = history_env.step({'foo': 7, 'bar': 7})
    self.assertEqual({
        'total_steps': [1, 2, 3],
        'current_steps': [1, 2, 3]
    }, as_lists(time_step.observation['observation']))
    self.assertEqual({
        'foo': [5, 6, 7],
        'bar': [5, 6, 7]
    }, as_lists(time_step.observation['action']))
def test_observation_stacked(self):
    """A flat observation is stacked into a sliding window of 3 steps.

    The window starts as zeros and shifts left by one entry per step.
    """
    env = test_envs.CountingEnv()
    history_env = wrappers.HistoryWrapper(env, 3)

    first = history_env.reset()
    self.assertEqual([0, 0, 0], first.observation.tolist())

    for expected_window in ([0, 0, 1], [0, 1, 2], [1, 2, 3]):
        step_result = history_env.step(0)
        self.assertEqual(expected_window, step_result.observation.tolist())
def test_observation_tiled(self):
    """With tile_first_step_obs, reset fills the window with the first obs."""
    env = test_envs.CountingEnv()
    # Set the base env's episode counter directly so its observations start
    # at 20 instead of 0. Otherwise tiling would look the same as the
    # default zero padding.
    env._episodes = 2
    history_env = wrappers.HistoryWrapper(env, 3, tile_first_step_obs=True)

    first = history_env.reset()
    self.assertEqual([20, 20, 20], first.observation.tolist())

    for expected_window in ([20, 20, 21], [20, 21, 22], [21, 22, 23]):
        step_result = history_env.step(0)
        self.assertEqual(expected_window, step_result.observation.tolist())
def test_observation_and_action_stacked(self):
    """Observation and action windows are stacked under separate keys.

    Both windows start as zeros. Each step appends the new observation and
    the action that was taken.
    """
    env = test_envs.CountingEnv()
    history_env = wrappers.HistoryWrapper(env, 3, include_actions=True)

    def check(step_result, want_obs, want_actions):
        self.assertEqual(want_obs,
                         step_result.observation['observation'].tolist())
        self.assertEqual(want_actions,
                         step_result.observation['action'].tolist())

    check(history_env.reset(), [0, 0, 0], [0, 0, 0])

    # (action to take, expected observation window, expected action window)
    script = (
        (5, [0, 0, 1], [0, 0, 5]),
        (6, [0, 1, 2], [0, 5, 6]),
        (7, [1, 2, 3], [5, 6, 7]),
    )
    for action, want_obs, want_actions in script:
        check(history_env.step(action), want_obs, want_actions)