Exemplo n.º 1
0
    def test_get_action(self, mock_rand, obs_dim, action_dim, filter_dims,
                        filter_sizes, strides, padding, hidden_sizes):
        """Check single and batched action sampling with stubbed models.

        While the policy is constructed, ``MLPModel`` and ``CNNModel`` in
        ``garage.tf.policies.categorical_cnn_policy`` are replaced with
        ``SimpleMLPModel`` and ``SimpleCNNModel``. ``mock_rand`` is set to
        return 0, and every sampled action is expected to be 0 with a
        probability vector filled with 0.5.
        """
        mock_rand.return_value = 0
        env = TfEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))

        mlp_patch = mock.patch(
            'garage.tf.policies.categorical_cnn_policy.MLPModel',
            new=SimpleMLPModel)
        cnn_patch = mock.patch(
            'garage.tf.policies.categorical_cnn_policy.CNNModel',
            new=SimpleCNNModel)
        # Both patches are only active while the policy is being built.
        with mlp_patch, cnn_patch:
            policy = CategoricalCNNPolicy(
                env_spec=env.spec,
                conv_filters=filter_dims,
                conv_filter_sizes=filter_sizes,
                conv_strides=strides,
                conv_pad=padding,
                hidden_sizes=hidden_sizes,
            )

        env.reset()
        obs, _, _, _ = env.step(1)

        # Single observation.
        single_action, single_info = policy.get_action(obs)
        expected_prob = np.full(action_dim, 0.5)

        assert env.action_space.contains(single_action)
        assert single_action == 0
        assert np.array_equal(single_info['prob'], expected_prob)

        # Batch of three identical observations.
        batch_actions, batch_info = policy.get_actions([obs, obs, obs])
        for sampled, sampled_prob in zip(batch_actions, batch_info['prob']):
            assert env.action_space.contains(sampled)
            assert sampled == 0
            assert np.array_equal(sampled_prob, expected_prob)
    def test_get_action(self, filters, strides, padding, hidden_sizes):
        """Verify that sampled actions lie inside the env's action space.

        Exercises ``get_action`` on one observation and ``get_actions`` on a
        batch made of three copies of that observation.
        """
        env = GymEnv(DummyDiscretePixelEnv())
        policy = CategoricalCNNPolicy(
            env_spec=env.spec,
            filters=filters,
            strides=strides,
            padding=padding,
            hidden_sizes=hidden_sizes,
        )

        env.reset()
        first_step = env.step(1)
        obs = first_step.observation

        single_action, _ = policy.get_action(obs)
        assert env.action_space.contains(single_action)

        batch = [obs, obs, obs]
        batch_actions, _ = policy.get_actions(batch)
        for sampled in batch_actions:
            assert env.action_space.contains(sampled)
class TestCategoricalCNNPolicyImageObs(TfGraphTestCase):
    """Smoke test of ``CategoricalCNNPolicy`` on an ``is_image=True`` env."""

    def setup_method(self):
        """Create the image env, run the variable initializer and reset."""
        super().setup_method()
        self.env = GymEnv(DummyDiscretePixelEnv(), is_image=True)
        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)
        self.env.reset()

    @pytest.mark.parametrize(
        'filters, strides, padding, hidden_sizes',
        [
            (((3, (32, 32)), ), (1, ), 'VALID', (4, )),
        ],
    )
    def test_obs_unflattened(self, filters, strides, padding, hidden_sizes):
        """Feed a flattened sampled observation to the policy, then step.

        There are no assertions; the test passes when none of the calls
        raises.
        """
        self.policy = CategoricalCNNPolicy(
            env_spec=self.env.spec,
            filters=filters,
            strides=strides,
            padding=padding,
            hidden_sizes=hidden_sizes,
        )
        obs_space = self.env.observation_space
        sampled_obs = obs_space.sample()
        flat_obs = obs_space.flatten(sampled_obs)
        action, _ = self.policy.get_action(flat_obs)
        self.env.step(action)
    def test_get_action(self, filters, strides, padding, hidden_sizes):
        """Build the policy on a placeholder and validate sampled actions.

        The policy is built on a float32 placeholder whose shape is two
        unspecified leading dimensions followed by the observation shape.
        Actions from ``get_action`` and ``get_actions`` must then be
        contained in the env's action space.
        """
        env = GarageEnv(DummyDiscretePixelEnv())
        policy = CategoricalCNNPolicy(
            env_spec=env.spec,
            filters=filters,
            strides=strides,
            padding=padding,
            hidden_sizes=hidden_sizes,
        )

        placeholder_shape = (None, None) + env.observation_space.shape
        obs_var = tf.compat.v1.placeholder(tf.float32,
                                           shape=placeholder_shape,
                                           name='obs')
        policy.build(obs_var)

        env.reset()
        obs, _, _, _ = env.step(1)

        single_action, _ = policy.get_action(obs)
        assert env.action_space.contains(single_action)

        batch_actions, _ = policy.get_actions([obs, obs, obs])
        for sampled in batch_actions:
            assert env.action_space.contains(sampled)