示例#1
0
class TestStackFrames(unittest.TestCase):
    """Tests for the ``StackFrames`` wrapper around a mocked environment.

    The wrapped environment is a ``mock.Mock`` whose ``reset`` returns an
    all-zero ``self.shape`` frame and whose ``step`` (see ``_step``) returns a
    differently-valued constant frame on every call (0, 1, 2, ...), so the
    order in which frames are stacked can be checked.
    """

    @overrides
    def setUp(self):
        self.shape = (50, 50)
        self.env = mock.Mock()
        self.env.observation_space = Box(
            low=0, high=255, shape=self.shape, dtype=np.uint8)
        self.env.reset.return_value = np.zeros(self.shape)
        # Number of calls made so far to the mocked ``env.step``. It is kept on
        # the instance so that it persists across calls; ``_step`` uses it to
        # produce a distinct frame per call.
        self._n_steps = 0
        self.env.step.side_effect = self._step

        self._n_frames = 4
        self.env_s = StackFrames(self.env, n_frames=self._n_frames)

        self.obs = self.env.reset()
        self.obs_s = self.env_s.reset()
        self.frame_width = self.env.observation_space.shape[0]
        self.frame_height = self.env.observation_space.shape[1]

    def _step(self, action):
        """Side effect for the mocked ``env.step``.

        Args:
            action: Ignored.

        Returns:
            ``(frame, 0, False, {})`` where ``frame`` has shape ``self.shape``
            and is filled with the number of earlier ``step`` calls, so the
            k-th call (0-based) returns a frame of all ``k``.
        """
        frame = np.full(self.shape, self._n_steps)
        self._n_steps += 1
        return frame, 0, False, dict()

    def test_stack_frames_invalid_environment_type(self):
        with self.assertRaises(ValueError):
            self.env.observation_space = Discrete(64)
            StackFrames(self.env, n_frames=4)

    def test_stack_frames_invalid_environment_shape(self):
        with self.assertRaises(ValueError):
            self.env.observation_space = Box(
                low=0, high=255, shape=(4, ), dtype=np.uint8)
            StackFrames(self.env, n_frames=4)

    def test_stack_frames_output_observation_space(self):
        self.assertEqual(
            self.env_s.observation_space.shape,
            (self.frame_width, self.frame_height, self._n_frames))

    def test_stack_frames_for_reset(self):
        # After reset, the stacked observation is the reset frame repeated
        # ``_n_frames`` times along the last axis.
        frame_stack = np.dstack([self.obs] * self._n_frames)
        np.testing.assert_array_equal(self.obs_s, frame_stack)

    def test_stack_frames_for_step(self):
        # Start from the post-reset stack (see test_stack_frames_for_reset).
        # Each step drops the oldest frame and appends the newest one along
        # the last axis.
        frame_stack = np.dstack([self.obs] * self._n_frames)
        for i in range(10):
            obs_stack, _, _, _ = self.env_s.step(0)
            # Assumes StackFrames.step calls the wrapped env's step exactly
            # once, so the newest frame is the i-th one produced by _step.
            frame_stack = np.dstack(
                (frame_stack[:, :, 1:], np.full(self.shape, i)))
            np.testing.assert_array_equal(obs_stack, frame_stack)

    def test_stack_frames_axis(self):
        env = StackFrames(DummyDiscrete2DEnv(random=False),
                          n_frames=self._n_frames,
                          axis=0)
        env.reset()
        obs, _, _, _ = env.step(1)
        self.assertEqual(obs.shape[0], self._n_frames)

        env = StackFrames(DummyDiscrete2DEnv(random=False),
                          n_frames=self._n_frames,
                          axis=2)
        env.reset()
        obs, _, _, _ = env.step(1)
        self.assertEqual(obs.shape[2], self._n_frames)
示例#3
0
    def setUp(self):
        """Build a mocked 50x50-observation env and a StackFrames wrapper.

        ``reset`` on the mock returns an all-zero frame, and ``step`` is routed
        to ``self._step``. Both the raw mock and the wrapper are reset here,
        and their initial observations are stored on the instance.
        """
        self.shape = (50, 50)
        self.env = mock.Mock()
        self.env.observation_space = Box(low=0,
                                         high=255,
                                         shape=self.shape,
                                         dtype=np.uint8)
        self.env.reset.return_value = np.zeros(self.shape)
        self.env.step.side_effect = self._step

        # Wrap the mock, keeping ``_n_frames`` frames per observation.
        self._n_frames = 4
        self.env_s = StackFrames(self.env, n_frames=self._n_frames)

        self.obs = self.env.reset()
        self.obs_s = self.env_s.reset()

        raw_shape = self.env.observation_space.shape
        self.frame_width = raw_shape[0]
        self.frame_height = raw_shape[1]
示例#4
0
 def setUp(self):
     """Create a plain and a StackFrames-wrapped ``DummyDiscrete2DEnv``.

     Both environments are wrapped in ``TfEnv``. The raw env's 2-D
     observation shape is stored as ``width``/``height``.
     """
     n_frames = 4
     self.n_frames = n_frames
     self.env = TfEnv(DummyDiscrete2DEnv(random=False))
     stacked_env = StackFrames(DummyDiscrete2DEnv(random=False),
                               n_frames=n_frames)
     self.env_s = TfEnv(stacked_env)
     obs_shape = self.env.observation_space.shape
     self.width, self.height = obs_shape
 def test_stack_frames_invalid_environment_shape(self):
     """Wrapping an env whose Box observation space has shape (4,) raises ValueError."""
     with pytest.raises(ValueError):
         flat_space = gym.spaces.Box(
             low=0, high=255, shape=(4, ), dtype=np.uint8)
         self.env.observation_space = flat_space
         StackFrames(self.env, n_frames=4)
class TestStackFrames:
    """Tests for ``StackFrames`` wrapping a ``DummyDiscrete2DEnv``.

    ``self.env`` is an unwrapped reference environment. ``self.env_s`` is a
    second ``DummyDiscrete2DEnv(random=False)`` wrapped in ``StackFrames``.
    The tests below assume that both produce identical frames for identical
    action sequences.
    """

    def setup_method(self):
        """Create the reference env and the frame-stacked env."""
        self.n_frames = 4
        self.env = DummyDiscrete2DEnv(random=False)
        self.env_s = StackFrames(DummyDiscrete2DEnv(random=False),
                                 n_frames=self.n_frames)
        self.width, self.height = self.env.observation_space.shape

    def teardown_method(self):
        """Close both environments created in ``setup_method``."""
        self.env.close()
        self.env_s.close()

    def test_stack_frames_invalid_environment_type(self):
        with pytest.raises(ValueError):
            self.env.observation_space = gym.spaces.Discrete(64)
            StackFrames(self.env, n_frames=4)

    def test_stack_frames_invalid_environment_shape(self):
        with pytest.raises(ValueError):
            self.env.observation_space = gym.spaces.Box(low=0,
                                                        high=255,
                                                        shape=(4, ),
                                                        dtype=np.uint8)
            StackFrames(self.env, n_frames=4)

    def test_stack_frames_output_observation_space(self):
        assert self.env_s.observation_space.shape == (self.width, self.height,
                                                      self.n_frames)

    def test_stack_frames_for_reset(self):
        frame_stack = self.env.reset()
        for _ in range(self.n_frames - 1):
            frame_stack = np.dstack((frame_stack, self.env.reset()))

        np.testing.assert_array_equal(self.env_s.reset(), frame_stack)

    def test_stack_frames_for_step(self):
        obs = self.env.reset()
        self.env_s.reset()

        # After reset the wrapper holds ``n_frames`` copies of the reset
        # observation (see test_stack_frames_for_reset). Seeding the expected
        # stack with them keeps it well-defined from the first step, instead
        # of relying on enough steps to flush uninitialized memory out.
        frame_stack = np.dstack([obs] * self.n_frames)
        for _ in range(10):
            obs = self.env.step(1)[0]
            # Drop the oldest frame and append the newest on the last axis.
            frame_stack = np.dstack((frame_stack[:, :, 1:], obs))
            obs_stack = self.env_s.step(1)[0]
            np.testing.assert_array_equal(obs_stack, frame_stack)
示例#7
0
 def test_stack_frames_invalid_environment_type(self):
     """Wrapping an env with a ``Discrete(64)`` observation space raises ValueError."""
     with self.assertRaises(ValueError):
         discrete_space = Discrete(64)
         self.env.observation_space = discrete_space
         StackFrames(self.env, n_frames=4)
 def test_stack_frames_invalid_environment_type(self):
     """Wrapping an env with a ``Discrete(64)`` observation space raises ValueError."""
     with pytest.raises(ValueError):
         discrete_space = gym.spaces.Discrete(64)
         self.env.observation_space = discrete_space
         StackFrames(self.env, n_frames=4)
 def setup_method(self):
     """Build a plain ``DummyDiscrete2DEnv`` and a StackFrames-wrapped one."""
     n_frames = 4
     self.n_frames = n_frames
     self.env = DummyDiscrete2DEnv(random=False)
     self.env_s = StackFrames(DummyDiscrete2DEnv(random=False),
                              n_frames=n_frames)
     observation_shape = self.env.observation_space.shape
     self.width, self.height = observation_shape
示例#10
0
 def test_invalid_axis_raises_error(self):
     """Constructing StackFrames with ``axis=5`` raises ValueError."""
     with pytest.raises(ValueError):
         base_env = DummyDiscrete2DEnv(random=False)
         StackFrames(base_env, n_frames=self.n_frames, axis=5)