Example #1
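The excerpts below are methods cut from a larger test class (hence the self.sess session handle), so they share a preamble that the page does not show. A minimal sketch of the imports they appear to rely on; the exact module paths are assumptions inferred from the mock.patch targets:

    # Assumed preamble for the excerpts below; module paths are guesses
    # based on the mock.patch targets and metarl's garage-style layout.
    from unittest import mock

    import numpy as np
    import pytest
    import tensorflow as tf

    from metarl.envs import MetaRLEnv    # assumption
    from metarl.tf.envs import TfEnv     # assumption
    from metarl.tf.policies.gaussian_gru_policy import GaussianGRUPolicy
    # Test fixtures (paths are assumptions):
    # from tests.fixtures.envs.dummy import DummyBoxEnv
    # from tests.fixtures.models import SimpleGaussianGRUModel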
    def test_dist_info_sym_include_action(self, obs_dim, action_dim,
                                          hidden_dim):
        env = TfEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))

        obs_ph = tf.compat.v1.placeholder(
            tf.float32, shape=(None, None, env.observation_space.flat_dim))

        with mock.patch(('metarl.tf.policies.'
                         'gaussian_gru_policy.GaussianGRUModel'),
                        new=SimpleGaussianGRUModel):
            policy = GaussianGRUPolicy(env_spec=env.spec,
                                       state_include_action=True)

            policy.reset()
            obs = env.reset()
            dist_sym = policy.dist_info_sym(
                obs_var=obs_ph,
                state_info_vars={'prev_action': np.zeros((2, 1) + action_dim)},
                name='p2_sym')
        dist = self.sess.run(
            dist_sym, feed_dict={obs_ph: [[obs.flatten()], [obs.flatten()]]})

        assert np.array_equal(dist['mean'], np.full((2, 1) + action_dim, 0.5))
        assert np.array_equal(dist['log_std'], np.full((2, 1) + action_dim,
                                                       0.5))
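dist_info_sym builds symbolic distribution parameters rather than concrete values: the returned dict maps 'mean' and 'log_std' to tensors shaped (batch, time) + action_dim. That is why the test feeds two single-step observation sequences and passes prev_action zeros shaped (2, 1) + action_dim. A short shape-bookkeeping sketch (action_dim=(2,) is an arbitrary concrete stand-in for the parametrized value):

    # Shape bookkeeping for the call above; action_dim=(2,) is only a
    # concrete stand-in for the parametrized value.
    batch_size, seq_len, action_dim = 2, 1, (2,)
    prev_action = np.zeros((batch_size, seq_len) + action_dim)  # (2, 1, 2)
    # obs_ph is fed (batch, time, obs_flat_dim); 'mean' and 'log_std'
    # come back with shape (batch, time) + action_dim to match.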
Example #2
    # The excerpt evidently lost a decorator: the mock_normal parameter
    # implies @mock.patch('numpy.random.normal') on this test.
    @mock.patch('numpy.random.normal')
    def test_get_action(self, mock_normal, obs_dim, action_dim, hidden_dim):
        mock_normal.return_value = 0.5
        env = TfEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('metarl.tf.policies.'
                         'gaussian_gru_policy.GaussianGRUModel'),
                        new=SimpleGaussianGRUModel):
            policy = GaussianGRUPolicy(env_spec=env.spec,
                                       state_include_action=False)

        policy.reset()
        obs = env.reset()

        expected_action = np.full(action_dim, 0.5 * np.exp(0.5) + 0.5)
        action, agent_info = policy.get_action(obs)
        assert env.action_space.contains(action)
        assert np.allclose(action, expected_action)

        expected_mean = np.full(action_dim, 0.5)
        assert np.array_equal(agent_info['mean'], expected_mean)
        expected_log_std = np.full(action_dim, 0.5)
        assert np.array_equal(agent_info['log_std'], expected_log_std)

        actions, agent_infos = policy.get_actions([obs])
        for action, mean, log_std in zip(actions, agent_infos['mean'],
                                         agent_infos['log_std']):
            assert env.action_space.contains(action)
            assert np.allclose(action, expected_action)
            assert np.array_equal(mean, expected_mean)
            assert np.array_equal(log_std, expected_log_std)
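Outside the test, get_action is the per-step sampling call: reset the policy at episode boundaries so the GRU hidden state starts fresh, then query it once per step. A minimal rollout sketch under those assumptions (max_steps is a hypothetical cap; env is any gym-style environment such as the TfEnv above):

    # Minimal rollout sketch; assumes `policy` and `env` built as above
    # and an active TF session.
    max_steps = 100         # hypothetical episode cap
    policy.reset()          # clear the GRU hidden state for a new episode
    obs = env.reset()
    for _ in range(max_steps):
        action, agent_info = policy.get_action(obs)  # agent_info: 'mean', 'log_std'
        obs, reward, done, _ = env.step(action)
        if done:
            break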
Example #3

    def test_get_action_state_include_action(self, obs_dim, action_dim,
                                             hidden_dim):
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        obs_var = tf.compat.v1.placeholder(
            tf.float32,
            shape=[
                None, None,
                env.observation_space.flat_dim + np.prod(action_dim)
            ],
            name='obs')
        policy = GaussianGRUPolicy(env_spec=env.spec,
                                   hidden_dim=hidden_dim,
                                   state_include_action=True)

        policy.build(obs_var)
        policy.reset()
        obs = env.reset()

        action, _ = policy.get_action(obs.flatten())
        assert env.action_space.contains(action)

        policy.reset()

        actions, _ = policy.get_actions([obs.flatten()])
        for action in actions:
            assert env.action_space.contains(action)
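The placeholder's last dimension is flat_dim + prod(action_dim) because, with state_include_action=True, the previous action is concatenated onto each observation before it reaches the GRU. The bookkeeping, spelled out (a sketch, not metarl internals):

    # With state_include_action=True the network input per step is
    # [observation, previous action], so the input width is:
    input_dim = env.observation_space.flat_dim + np.prod(action_dim)
    # ...which matches the last entry of `shape=` in the placeholder above.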
Example #4
    def test_gaussian_gru_policy(self):
        gaussian_gru_policy = GaussianGRUPolicy(env_spec=self.env,
                                                hidden_dim=1)
        self.sess.run(tf.compat.v1.global_variables_initializer())

        gaussian_gru_policy.reset()

        obs = self.env.observation_space.high
        assert gaussian_gru_policy.get_action(obs)
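Note that get_action returns an (action, agent_info) tuple, so the bare assert above only checks that a non-empty tuple came back. Unpacking makes the check explicit (a sketch):

    # Equivalent, but explicit about what is being asserted:
    action, agent_info = gaussian_gru_policy.get_action(obs)
    assert action is not None
    assert 'mean' in agent_info and 'log_std' in agent_info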
Example #5
    def test_dist_info_sym_wrong_input(self):
        env = TfEnv(DummyBoxEnv(obs_dim=(1, ), action_dim=(1, )))

        obs_ph = tf.compat.v1.placeholder(
            tf.float32, shape=(None, None, env.observation_space.flat_dim))

        with mock.patch(('metarl.tf.policies.'
                         'gaussian_gru_policy.GaussianGRUModel'),
                        new=SimpleGaussianGRUModel):
            policy = GaussianGRUPolicy(env_spec=env.spec,
                                       state_include_action=True)

            policy.reset()
            obs = env.reset()

            policy.dist_info_sym(
                obs_var=obs_ph,
                state_info_vars={'prev_action': np.zeros((3, 1, 1))},
                name='p2_sym')
        # observation batch size = 2 but prev_action batch size = 3
        with pytest.raises(tf.errors.InvalidArgumentError):
            self.sess.run(
                policy.model.networks['p2_sym'].input,
                feed_dict={obs_ph: [[obs.flatten()], [obs.flatten()]]})
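Because dist_info_sym only builds the graph, the mismatch (two observation sequences fed, but three prev_action rows baked in) is not caught until the tensors are evaluated; hence pytest.raises wraps sess.run rather than the dist_info_sym call itself. The shape that would have evaluated cleanly (a sketch):

    # A prev_action batch whose leading dimension matches the two
    # observation sequences would evaluate without error:
    state_info_vars = {'prev_action': np.zeros((2, 1, 1))}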