Beispiel #1
0
  def testMaxFramePooling(self):
    """Checks the observation produced by a single frame-skipped step.

    The mock environment is wrapped with `frame_skip=2`; after a reset, one
    wrapped step with action 0 must return an observation whose every entry
    equals 8.
    """
    skipped_frames = 2
    wrapped_env = atari_lib.AtariPreprocessing(
        MockEnvironment(), frame_skip=skipped_frames)
    wrapped_env.reset()

    # NOTE(review): an earlier comment here said the two underlying
    # observations were 2 and 0 (max 2), which disagrees with the value
    # asserted below. Presumably the mock's screen values changed since that
    # comment was written -- confirm against MockEnvironment.
    pooled_frame, _, _, _ = wrapped_env.step(0)
    every_entry_matches = (pooled_frame == 8).all()
    self.assertTrue(every_entry_matches)
Beispiel #2
0
  def testFrameSkipAccumulatesReward(self):
    """Checks that one wrapped step returns a reward equal to `frame_skip`."""
    skipped_frames = 2
    wrapped_env = atari_lib.AtariPreprocessing(
        MockEnvironment(), frame_skip=skipped_frames)
    wrapped_env.reset()

    # Action 0 is expected to earn a reward of 1 per underlying step
    # (presumably defined by MockEnvironment -- verify there), so the reward
    # reported for one wrapped step should equal the number of skipped frames.
    step_result = wrapped_env.step(0)
    _, total_reward, _, _ = step_result
    self.assertEqual(total_reward, skipped_frames)
Beispiel #3
0
  def testTerminalPassedThrough(self):
    """Checks that the terminal flag is reported only on the final step.

    The mock is created with `max_steps=10` and wrapped with `frame_skip=1`.
    The first nine steps must report a false terminal flag and the tenth a
    true one.
    """
    episode_length = 10
    wrapped_env = atari_lib.AtariPreprocessing(
        MockEnvironment(max_steps=episode_length), frame_skip=1)
    wrapped_env.reset()

    # Every step before the last one must be reported as non-terminal.
    non_terminal_steps = episode_length - 1
    for _ in range(non_terminal_steps):
      step_result = wrapped_env.step(0)
      _, _, episode_over, _ = step_result
      self.assertFalse(episode_over)

    # The final step must carry the terminal signal through the wrapper.
    _, _, episode_over, _ = wrapped_env.step(0)
    self.assertTrue(episode_over)
Beispiel #4
0
  def testResetPassesObservation(self):
    """Checks that `reset()` returns an observation of shape (16, 16, 1).

    The wrapper is configured with `screen_size=16`, and the observation
    returned by `reset()` is compared against that size plus a trailing
    dimension of 1.
    """
    wrapped_env = atari_lib.AtariPreprocessing(
        MockEnvironment(), frame_skip=1, screen_size=16)
    initial_frame = wrapped_env.reset()

    expected_shape = (16, 16, 1)
    self.assertEqual(initial_frame.shape, expected_shape)
Beispiel #5
0
def create_atari_environment(game_name=None, sticky_actions=False):
    """Build a preprocessed Atari environment for the given game.

    The Gym environment id is assembled as
    `<game_name>NoFrameskip-<version>`, where the version is `v0` when
    `sticky_actions` is true and `v4` otherwise.

    Args:
        game_name: Name of the Atari game, used as the prefix of the Gym
            environment id. Must not be None.
        sticky_actions: Selects the `v0` variant of the game when true and
            the `v4` variant when false.

    Returns:
        An `atari_lib.AtariPreprocessing` instance wrapping the `.env`
        attribute of the environment created by `gym.make`.

    Raises:
        AssertionError: If `game_name` is None.
    """
    # Raise explicitly instead of using an `assert` statement: asserts are
    # stripped under `python -O`, which would let None be formatted into the
    # bogus id 'NoneNoFrameskip-v4'. AssertionError is kept as the exception
    # type so existing callers that expect it keep working.
    if game_name is None:
        raise AssertionError('game_name must be provided, got None')

    game_version = 'v0' if sticky_actions else 'v4'
    full_game_name = '{}NoFrameskip-{}'.format(game_name, game_version)
    gym_env = gym.make(full_game_name)
    # NOTE(review): `.env` takes the environment from inside the object that
    # `gym.make` returns -- presumably to drop Gym's outer wrapper (e.g. its
    # time limit); confirm against the Gym version in use.
    return atari_lib.AtariPreprocessing(gym_env.env)