Exemplo n.º 1
0
    def testResetPassesObservation(self):
        """reset() should return an observation resized to screen_size x screen_size x 1."""
        wrapped = preprocessing.AtariPreprocessing(
            MockEnvironment(), frame_skip=1, screen_size=16)
        first_obs = wrapped.reset()

        self.assertEqual(first_obs.shape, (16, 16, 1))
Exemplo n.º 2
0
    def testMaxFramePooling(self):
        """A frame-skipped step should return the max-pooled observation."""
        frame_skip = 2
        env = MockEnvironment()
        env = preprocessing.AtariPreprocessing(env, frame_skip=frame_skip)
        env.reset()

        # One wrapped step with frame_skip=2 should pool the frames it saw and
        # yield 8 in every pixel. NOTE(review): this presumes the mock's screen
        # value shrinks each underlying step so the max is the earlier frame
        # (8) — confirm against MockEnvironment.
        observation, _, _, _ = env.step(0)
        self.assertTrue((observation == 8).all(),
                        msg='max-pooled observation should be 8 everywhere')
Exemplo n.º 3
0
    def testFrameSkipAccumulatesReward(self):
        """Rewards from every skipped frame are added into one step's reward."""
        skip = 2
        wrapped = preprocessing.AtariPreprocessing(MockEnvironment(),
                                                   frame_skip=skip)
        wrapped.reset()

        # Action 0 earns a reward of 1 per underlying frame, so a single
        # wrapped step must report `skip` in total.
        _, total_reward, _, _ = wrapped.step(0)
        self.assertEqual(total_reward, skip)
Exemplo n.º 4
0
    def testTerminalPassedThrough(self):
        """The wrapper reports termination on exactly the max_steps-th step."""
        max_steps = 10
        wrapped = preprocessing.AtariPreprocessing(
            MockEnvironment(max_steps=max_steps), frame_skip=1)
        wrapped.reset()

        # Steps 1 .. max_steps-1 must all be non-terminal...
        for _step_number in range(1, max_steps):
            _, _, done, _ = wrapped.step(0)
            self.assertFalse(done)

        # ...and the final allowed step must signal termination.
        _, _, done, _ = wrapped.step(0)
        self.assertTrue(done)