def test_get_action(self, hidden_channels, kernel_sizes, strides,
                    hidden_sizes):
    """Check that get_action returns an action inside the action space.

    A GymEnv-wrapped pixel environment (``is_image=True``) is used to build
    an NHWC CategoricalCNNPolicy. The policy is queried on the observation
    produced by one env step with action 1.

    Args:
        hidden_channels, kernel_sizes, strides, hidden_sizes: layer
            configuration forwarded unchanged to CategoricalCNNPolicy
            (presumably supplied by pytest parametrization).
    """
    pixel_env = GymEnv(DummyDiscretePixelEnv(), is_image=True)
    cnn_kwargs = dict(kernel_sizes=kernel_sizes,
                      hidden_channels=hidden_channels,
                      strides=strides,
                      hidden_sizes=hidden_sizes)
    cnn_policy = CategoricalCNNPolicy(env_spec=pixel_env.spec,
                                      image_format='NHWC',
                                      **cnn_kwargs)
    pixel_env.reset()
    step_result = pixel_env.step(1)
    sampled_action, _agent_info = cnn_policy.get_action(
        step_result.observation)
    assert pixel_env.action_space.contains(sampled_action)
Esempio n. 2
0
 def test_get_action(self, hidden_channels, kernel_sizes, strides,
                     hidden_sizes):
     """Check that get_action yields an action from the env's action space.

     The raw pixel env is first passed through ``self._initialize_obs_env``.
     The policy is built directly from that env (``env=`` keyword) and
     queried on the observation returned by one step with action 1.

     NOTE(review): another ``test_get_action`` appears nearby; if both live
     in the same class, only the later definition is collected by pytest.
     """
     obs_env = self._initialize_obs_env(DummyDiscretePixelEnv())
     layer_config = dict(kernel_sizes=kernel_sizes,
                         hidden_channels=hidden_channels,
                         strides=strides,
                         hidden_sizes=hidden_sizes)
     cnn_policy = CategoricalCNNPolicy(env=obs_env, **layer_config)
     obs_env.reset()
     observation, _reward, _done, _info = obs_env.step(1)
     picked, _agent_info = cnn_policy.get_action(observation)
     assert obs_env.action_space.contains(picked)
Esempio n. 3
0
    def test_get_action_img_obs(self, hidden_channels, kernel_sizes, strides,
                                hidden_sizes):
        """Check get_action on an env whose observation space is akro.Image.

        The pixel env is wrapped in GarageEnv with ``is_image=True`` and then
        passed through ``self._initialize_obs_env``. After one step with
        action 1, the policy's action for that observation must belong to
        the env's action space.
        """
        wrapped = self._initialize_obs_env(
            GarageEnv(DummyDiscretePixelEnv(), is_image=True))
        layer_config = dict(kernel_sizes=kernel_sizes,
                            hidden_channels=hidden_channels,
                            strides=strides,
                            hidden_sizes=hidden_sizes)
        policy_under_test = CategoricalCNNPolicy(env=wrapped, **layer_config)
        wrapped.reset()
        first_obs, _reward, _done, _info = wrapped.step(1)

        chosen, _agent_info = policy_under_test.get_action(first_obs)
        assert wrapped.action_space.contains(chosen)
 def test_obs_unflattened(self, hidden_channels, kernel_sizes, strides,
                          hidden_sizes):
     """Test that a flattened image obs passed to get_action is unflattened.

     A sampled image observation is flattened with the observation space
     before being handed to the policy. The policy must still return an
     action that belongs to the env's action space, and that action is then
     stepped in the env.
     """
     env = GymEnv(DummyDiscretePixelEnv(), is_image=True)
     env.reset()
     policy = CategoricalCNNPolicy(env_spec=env.spec,
                                   image_format='NHWC',
                                   kernel_sizes=kernel_sizes,
                                   hidden_channels=hidden_channels,
                                   strides=strides,
                                   hidden_sizes=hidden_sizes)
     obs = env.observation_space.sample()
     action, _ = policy.get_action(env.observation_space.flatten(obs))
     # Without an explicit check the test could only fail if step() raised;
     # verify the action produced from the flattened obs is valid.
     assert env.action_space.contains(action)
     env.step(action)
Esempio n. 5
0
    def test_is_pickleable(self, hidden_channels, kernel_sizes, strides,
                           hidden_sizes):
        """Check that the policy survives a cloudpickle round trip.

        An action is sampled from the original policy and from its
        cloudpickle-restored copy for the same observation. Both must lie in
        the action space, and their shapes must match. Only shapes are
        compared, not values.
        """
        garage_env = self._initialize_obs_env(
            GarageEnv(DummyDiscretePixelEnv(), is_image=True))
        original_policy = CategoricalCNNPolicy(env=garage_env,
                                               kernel_sizes=kernel_sizes,
                                               hidden_channels=hidden_channels,
                                               strides=strides,
                                               hidden_sizes=hidden_sizes)
        garage_env.reset()
        obs, _reward, _done, _info = garage_env.step(1)

        action_before, _ = original_policy.get_action(obs)

        restored_policy = cloudpickle.loads(cloudpickle.dumps(original_policy))
        action_after, _ = restored_policy.get_action(obs)

        for sampled in (action_before, action_after):
            assert garage_env.action_space.contains(sampled)
        assert action_before.shape == action_after.shape
    def test_is_pickleable(self, hidden_channels, kernel_sizes, strides,
                           hidden_sizes):
        """Check that an NHWC CategoricalCNNPolicy can be cloudpickled.

        The policy is built from a GymEnv-wrapped pixel env. Actions from
        the original and from the unpickled copy (same observation) must
        both be valid for the action space and share a shape.

        NOTE(review): another ``test_is_pickleable`` appears nearby; if both
        live in the same class, only the later definition is collected.
        """
        gym_env = GymEnv(DummyDiscretePixelEnv(), is_image=True)
        cnn_kwargs = dict(kernel_sizes=kernel_sizes,
                          hidden_channels=hidden_channels,
                          strides=strides,
                          hidden_sizes=hidden_sizes)
        source_policy = CategoricalCNNPolicy(env_spec=gym_env.spec,
                                             image_format='NHWC',
                                             **cnn_kwargs)
        gym_env.reset()
        step_result = gym_env.step(1)
        obs = step_result.observation

        first_action, _ = source_policy.get_action(obs)

        serialized = cloudpickle.dumps(source_policy)
        clone_policy = cloudpickle.loads(serialized)
        second_action, _ = clone_policy.get_action(obs)

        for candidate in (first_action, second_action):
            assert gym_env.action_space.contains(candidate)
        assert first_action.shape == second_action.shape