Beispiel #1
0
    def test_get_action(self, mock_normal, obs_dim, action_dim):
        """Check single and batched action sampling against fixed values.

        The policy is constructed while ``GaussianMLPModel`` is patched with
        ``SimpleGaussianMLPModel`` and ``mock_normal`` is pinned to return
        0.5, so the sampled action, mean and log-std can be compared against
        constant arrays.

        Args:
            mock_normal: Mock whose ``return_value`` is set to 0.5;
                presumably a patched noise sampler injected by a decorator
                outside this view -- TODO confirm.
            obs_dim: Observation dimension passed to ``DummyBoxEnv``.
            action_dim: Action dimension passed to ``DummyBoxEnv``; also the
                shape of the expected arrays.

        """
        mock_normal.return_value = 0.5
        env = TfEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('metarl.tf.policies.'
                         'gaussian_mlp_policy.GaussianMLPModel'),
                        new=SimpleGaussianMLPModel):
            policy = GaussianMLPPolicy(env_spec=env.spec)

        env.reset()
        obs, _, _, _ = env.step(1)

        # NOTE(review): 0.75 presumably comes from
        # mean + exp(log_std) * noise = 0.5 + 0.5 * 0.5 with the simple
        # model's outputs -- confirm against SimpleGaussianMLPModel.
        expected_action = np.full(action_dim, 0.75)
        expected_mean = np.full(action_dim, 0.5)
        expected_log_std = np.full(action_dim, np.log(0.5))

        # Single-observation path.
        action, prob = policy.get_action(obs)

        assert env.action_space.contains(action)
        assert np.array_equal(action, expected_action)
        assert np.array_equal(prob['mean'], expected_mean)
        assert np.array_equal(prob['log_std'], expected_log_std)

        # Batched path: every entry must be checked against its OWN
        # distribution parameters from ``probs``. Asserting on ``prob`` from
        # the single-action call above would leave the batched outputs
        # untested.
        actions, probs = policy.get_actions([obs, obs, obs])
        for batch_action, mean, log_std in zip(actions, probs['mean'],
                                               probs['log_std']):
            assert env.action_space.contains(batch_action)
            assert np.array_equal(batch_action, expected_action)
            assert np.array_equal(mean, expected_mean)
            assert np.array_equal(log_std, expected_log_std)
Beispiel #2
0
    def test_get_action(self, obs_dim, action_dim):
        """Verify that sampled actions lie inside the env's action space.

        A ``GaussianMLPPolicy`` is built on a placeholder sized by the
        environment's flat observation dimension, then queried once with a
        single flattened observation and once with a batch of three.

        Args:
            obs_dim: Observation dimension passed to ``DummyBoxEnv``.
            action_dim: Action dimension passed to ``DummyBoxEnv``.

        """
        box_env = MetaRLEnv(
            DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        flat_obs_dim = box_env.observation_space.flat_dim
        obs_input = tf.compat.v1.placeholder(tf.float32,
                                             shape=[None, None, flat_obs_dim],
                                             name='obs')
        policy = GaussianMLPPolicy(env_spec=box_env.spec)
        policy.build(obs_input)

        box_env.reset()
        observation, _, _, _ = box_env.step(1)

        # Single-observation query.
        single_action, _ = policy.get_action(observation.flatten())
        assert box_env.action_space.contains(single_action)

        # Batched query with three independently flattened copies.
        obs_batch = [observation.flatten() for _ in range(3)]
        batch_actions, _ = policy.get_actions(obs_batch)
        for sampled in batch_actions:
            assert box_env.action_space.contains(sampled)