Exemplo n.º 1
0
  def test_observation_nested(self):
    """Each leaf of a nested observation gets its own zero-padded history.

    With a history length of 3, every key of the observation dict holds the
    last three values of that leaf, oldest first, padded with zeros right
    after reset.
    """
    import numpy as np

    def as_lists(observation):
      # `assertCountEqual` on two dicts only compares their keys (iterating a
      # dict yields keys), so the expected histories were never checked.
      # Converting each stacked leaf to a plain list lets `assertEqual` compare
      # both keys and values.
      return {key: np.asarray(value).tolist()
              for key, value in observation.items()}

    env = test_envs.NestedCountingEnv()
    history_env = wrappers.HistoryWrapper(env, 3)
    time_step = history_env.reset()
    self.assertEqual({
        'total_steps': [0, 0, 0],
        'current_steps': [0, 0, 0]
    }, as_lists(time_step.observation))

    time_step = history_env.step(0)
    self.assertEqual({
        'total_steps': [0, 0, 1],
        'current_steps': [0, 0, 1]
    }, as_lists(time_step.observation))

    time_step = history_env.step(0)
    self.assertEqual({
        'total_steps': [0, 1, 2],
        'current_steps': [0, 1, 2]
    }, as_lists(time_step.observation))

    time_step = history_env.step(0)
    self.assertEqual({
        'total_steps': [1, 2, 3],
        'current_steps': [1, 2, 3]
    }, as_lists(time_step.observation))
Exemplo n.º 2
0
  def test_observation_spec_changed(self):
    """The wrapped observation spec gains a leading history axis of length 3."""
    base_env = gym_wrapper.GymWrapper(gym.spec('CartPole-v1').make())
    expected_shape = (3,) + base_env.observation_spec().shape

    wrapped_env = wrappers.HistoryWrapper(base_env, 3)
    actual_shape = wrapped_env.observation_spec().shape
    self.assertEqual(expected_shape, actual_shape)
Exemplo n.º 3
0
  def test_observation_spec_changed_with_action(self):
    """With include_actions, both observation and action specs get a history axis.

    The wrapped observation spec is a dict with 'observation' and 'action'
    entries, each shaped (3,) + the base env's corresponding spec shape.
    """
    base_env = gym_wrapper.GymWrapper(gym.spec('CartPole-v1').make())
    single_obs_shape = base_env.observation_spec().shape
    single_action_shape = base_env.action_spec().shape

    history_env = wrappers.HistoryWrapper(base_env, 3, include_actions=True)
    for key, single_shape in (('observation', single_obs_shape),
                              ('action', single_action_shape)):
      self.assertEqual((3,) + single_shape,
                       history_env.observation_spec()[key].shape)
Exemplo n.º 4
0
  def test_observation_and_action_nested(self):
    """Nested observations and nested actions are each stacked per leaf.

    With include_actions=True, the wrapped observation is a dict holding an
    'observation' history and an 'action' history. Both are zero-padded
    right after reset, and each step appends the newest value.
    """
    import numpy as np

    def as_lists(nested):
      # `assertCountEqual` on two dicts only compares their keys, so the
      # expected histories were never checked. Convert each stacked leaf to
      # a plain list so `assertEqual` compares both keys and values.
      return {key: np.asarray(value).tolist() for key, value in nested.items()}

    env = test_envs.NestedCountingEnv(nested_action=True)
    history_env = wrappers.HistoryWrapper(env, 3, include_actions=True)
    time_step = history_env.reset()
    self.assertEqual({
        'total_steps': [0, 0, 0],
        'current_steps': [0, 0, 0]
    }, as_lists(time_step.observation['observation']))
    self.assertEqual({
        'foo': [0, 0, 0],
        'bar': [0, 0, 0]
    }, as_lists(time_step.observation['action']))

    time_step = history_env.step({
        'foo': 5,
        'bar': 5
    })
    self.assertEqual({
        'total_steps': [0, 0, 1],
        'current_steps': [0, 0, 1]
    }, as_lists(time_step.observation['observation']))
    self.assertEqual({
        'foo': [0, 0, 5],
        'bar': [0, 0, 5]
    }, as_lists(time_step.observation['action']))

    time_step = history_env.step({
        'foo': 6,
        'bar': 6
    })
    self.assertEqual({
        'total_steps': [0, 1, 2],
        'current_steps': [0, 1, 2]
    }, as_lists(time_step.observation['observation']))
    self.assertEqual({
        'foo': [0, 5, 6],
        'bar': [0, 5, 6]
    }, as_lists(time_step.observation['action']))

    time_step = history_env.step({
        'foo': 7,
        'bar': 7
    })
    self.assertEqual({
        'total_steps': [1, 2, 3],
        'current_steps': [1, 2, 3]
    }, as_lists(time_step.observation['observation']))
    self.assertEqual({
        'foo': [5, 6, 7],
        'bar': [5, 6, 7]
    }, as_lists(time_step.observation['action']))
Exemplo n.º 5
0
  def test_observation_stacked(self):
    """Observations form a zero-padded sliding window of the last 3 values."""
    history_env = wrappers.HistoryWrapper(test_envs.CountingEnv(), 3)
    self.assertEqual([0, 0, 0], history_env.reset().observation.tolist())

    # Each step shifts the window left by one and appends the newest count.
    for expected_window in ([0, 0, 1], [0, 1, 2], [1, 2, 3]):
      stepped = history_env.step(0)
      self.assertEqual(expected_window, stepped.observation.tolist())
Exemplo n.º 6
0
  def test_observation_tiled(self):
    """With tile_first_step_obs=True, the reset observation fills the history.

    Unlike the zero padding checked in test_observation_stacked, every history
    slot starts as a copy of the first observation. Later steps then slide the
    window as usual.
    """
    env = test_envs.CountingEnv()
    # Pretend earlier episodes have run so the base env's observations are
    # non-zero; otherwise tiling would be indistinguishable from zero padding.
    # Per the assertions below, this makes the reset observation 20.
    env._episodes = 2
    history_env = wrappers.HistoryWrapper(env, 3, tile_first_step_obs=True)
    time_step = history_env.reset()
    self.assertEqual([20, 20, 20], time_step.observation.tolist())

    time_step = history_env.step(0)
    self.assertEqual([20, 20, 21], time_step.observation.tolist())

    time_step = history_env.step(0)
    self.assertEqual([20, 21, 22], time_step.observation.tolist())

    time_step = history_env.step(0)
    self.assertEqual([21, 22, 23], time_step.observation.tolist())
Exemplo n.º 7
0
  def test_observation_and_action_stacked(self):
    """Observations and the actions taken are stacked in parallel windows."""
    history_env = wrappers.HistoryWrapper(
        test_envs.CountingEnv(), 3, include_actions=True)

    first = history_env.reset()
    self.assertEqual([0, 0, 0], first.observation['observation'].tolist())
    self.assertEqual([0, 0, 0], first.observation['action'].tolist())

    # (action to take, expected observation window, expected action window)
    script = (
        (5, [0, 0, 1], [0, 0, 5]),
        (6, [0, 1, 2], [0, 5, 6]),
        (7, [1, 2, 3], [5, 6, 7]),
    )
    for action, obs_window, action_window in script:
      stacked = history_env.step(action).observation
      self.assertEqual(obs_window, stacked['observation'].tolist())
      self.assertEqual(action_window, stacked['action'].tolist())