class TestMaxAndSkip:
    def setup_method(self):
        self.env = DummyDiscretePixelEnv(random=False)
        self.env_wrap = MaxAndSkip(DummyDiscretePixelEnv(random=False), skip=4)

    def teardown_method(self):
        self.env.close()
        self.env_wrap.close()

    def test_max_and_skip_reset(self):
        np.testing.assert_array_equal(self.env.reset(), self.env_wrap.reset())

    def test_max_and_skip_step(self):
        self.env.reset()
        self.env_wrap.reset()
        obs_wrap, reward_wrap, _, _ = self.env_wrap.step(1)
        reward = 0
        for _ in range(4):
            obs, r, _, _ = self.env.step(1)
            reward += r

        np.testing.assert_array_equal(obs, obs_wrap)
        np.testing.assert_array_equal(reward, reward_wrap)

        # done=True because both envs have stepped more than 4 times in total
        obs_wrap, _, done_wrap, _ = self.env_wrap.step(1)
        obs, _, done, _ = self.env.step(1)

        assert done
        assert done_wrap
        np.testing.assert_array_equal(obs, obs_wrap)
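
# For context, a minimal sketch of the behaviour the test above exercises,
# assuming the wrapper follows the common Atari max-and-skip recipe: repeat
# the chosen action `skip` times, sum the rewards, and return a pixel-wise
# max over the last two frames. The name SketchMaxAndSkip and the exact
# structure are illustrative only, not the garage implementation.
import gym
import numpy as np


class SketchMaxAndSkip(gym.Wrapper):
    def __init__(self, env, skip=4):
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        total_reward = 0.0
        done = False
        frames = []
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            frames.append(obs)
            total_reward += reward
            if done:
                break
        # Max over the last two frames suppresses flickering sprites.
        max_frame = np.max(np.stack(frames[-2:]), axis=0)
        return max_frame, total_reward, done, info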
Example #2
class TestGrayscale:
    def setup_method(self):
        self.env = DummyDiscretePixelEnv(random=False)
        self.env_g = Grayscale(DummyDiscretePixelEnv(random=False))

    def teardown_method(self):
        self.env.close()
        self.env_g.close()

    def test_grayscale_invalid_environment_type(self):
        with pytest.raises(ValueError):
            self.env.observation_space = gym.spaces.Discrete(64)
            Grayscale(self.env)

    def test_grayscale_invalid_environment_shape(self):
        with pytest.raises(ValueError):
            self.env.observation_space = gym.spaces.Box(low=0,
                                                        high=255,
                                                        shape=(4, ),
                                                        dtype=np.uint8)
            Grayscale(self.env)

    def test_grayscale_observation_space(self):
        assert self.env_g.observation_space.shape == (
            self.env.observation_space.shape[:-1])

    def test_grayscale_reset(self):
        """
        RGB to grayscale conversion using scikit-image.

        Weights used for conversion:
        Y = 0.2125 R + 0.7154 G + 0.0721 B

        Reference:
        http://scikit-image.org/docs/dev/api/skimage.color.html#skimage.color.rgb2grey
        """
        grayscale_output = np.round(
            np.dot(self.env.reset()[:, :, :3],
                   [0.2125, 0.7154, 0.0721])).astype(np.uint8)
        np.testing.assert_array_almost_equal(grayscale_output,
                                             self.env_g.reset())

    def test_grayscale_step(self):
        self.env.reset()
        self.env_g.reset()
        obs, _, _, _ = self.env.step(1)
        obs_g, _, _, _ = self.env_g.step(1)

        grayscale_output = np.round(
            np.dot(obs[:, :, :3], [0.2125, 0.7154, 0.0721])).astype(np.uint8)
        np.testing.assert_array_almost_equal(grayscale_output, obs_g)
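
# A quick standalone sanity check of the weighted sum quoted in the docstring
# above (Y = 0.2125 R + 0.7154 G + 0.0721 B, the scikit-image rgb2gray
# weights). The array shape is arbitrary and this does not touch the
# Grayscale wrapper itself.
import numpy as np

rgb = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
gray = np.round(np.dot(rgb[:, :, :3],
                       [0.2125, 0.7154, 0.0721])).astype(np.uint8)
assert gray.shape == rgb.shape[:-1]  # the colour channel is dropped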
Example #3
    def test_get_action(self, hidden_channels, kernel_sizes, strides,
                        hidden_sizes):
        """Test get_action function."""
        env = DummyDiscretePixelEnv()
        env = self._initialize_obs_env(env)
        policy = CategoricalCNNPolicy(env=env,
                                      kernel_sizes=kernel_sizes,
                                      hidden_channels=hidden_channels,
                                      strides=strides,
                                      hidden_sizes=hidden_sizes)
        env.reset()
        obs, _, _, _ = env.step(1)
        action, _ = policy.get_action(obs)
        assert env.action_space.contains(action)
Example #4
    def test_get_actions(self, hidden_channels, kernel_sizes, strides,
                         hidden_sizes):
        """Test get_actions function with akro.Image observation space."""
        env = DummyDiscretePixelEnv()
        env = self._initialize_obs_env(env)
        policy = CategoricalCNNPolicy(env=env,
                                      kernel_sizes=kernel_sizes,
                                      hidden_channels=hidden_channels,
                                      strides=strides,
                                      hidden_sizes=hidden_sizes)
        env.reset()
        obs, _, _, _ = env.step(1)

        actions, _ = policy.get_actions([obs, obs, obs])
        for action in actions:
            assert env.action_space.contains(action)
        torch_obs = torch.Tensor(obs)
        actions, _ = policy.get_actions([torch_obs, torch_obs, torch_obs])
        for action in actions:
            assert env.action_space.contains(action)
Example #5
    def test_fire_reset(self):
        env = DummyDiscretePixelEnv()
        env_wrap = FireReset(env)
        obs = env.reset()
        obs_wrap = env_wrap.reset()

        assert np.array_equal(obs, np.ones(env.observation_space.shape))
        assert np.array_equal(obs_wrap, np.full(env.observation_space.shape,
                                                2))

        env_wrap.step(2)
        obs_wrap = env_wrap.reset()  # env will call reset again, after fire
        assert np.array_equal(obs_wrap, np.ones(env.observation_space.shape))
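
# A minimal sketch of what a FireReset-style wrapper typically does, matching
# the behaviour asserted above: after every reset it takes the FIRE action,
# and if that ends the episode it resets once more. The name SketchFireReset
# and the default action index are assumptions for illustration.
import gym


class SketchFireReset(gym.Wrapper):
    def __init__(self, env, fire_action=1):
        super().__init__(env)
        self._fire_action = fire_action

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(self._fire_action)
        if done:
            # Firing ended the episode, so reset again and return that frame.
            obs = self.env.reset(**kwargs)
        return obs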
Example #6
    def test_is_pickleable(self, hidden_channels, kernel_sizes, strides,
                           hidden_sizes):
        """Test if policy is pickable."""
        env = DummyDiscretePixelEnv()
        env = self._initialize_obs_env(env)
        policy = CategoricalCNNPolicy(env=env,
                                      kernel_sizes=kernel_sizes,
                                      hidden_channels=hidden_channels,
                                      strides=strides,
                                      hidden_sizes=hidden_sizes)
        env.reset()
        obs, _, _, _ = env.step(1)

        output_action_1, _ = policy.get_action(obs)

        p = cloudpickle.dumps(policy)
        policy_pickled = cloudpickle.loads(p)
        output_action_2, _ = policy_pickled.get_action(obs)

        assert env.action_space.contains(output_action_1)
        assert env.action_space.contains(output_action_2)
        assert output_action_1.shape == output_action_2.shape