示例#1
0
    def test_baseline(self):
        """Smoke-test that both MLP baselines can be built and queried.

        Builds a ContinuousMLPBaseline and a GaussianMLPBaseline for a dummy
        box environment, runs the global variable initializer in the test
        session, and reads the parameter values of each baseline. The
        returned values are not inspected; the test passes when none of
        these steps raises.
        """
        box_env = TfEnv(DummyBoxEnv())
        try:
            # NOTE(review): the environment object itself is passed as
            # ``env_spec`` rather than a spec taken from it -- confirm that
            # this is intended.
            deterministic_mlp_baseline = ContinuousMLPBaseline(
                env_spec=box_env)
            gaussian_mlp_baseline = GaussianMLPBaseline(env_spec=box_env)

            self.sess.run(tf.compat.v1.global_variables_initializer())
            deterministic_mlp_baseline.get_param_values()
            gaussian_mlp_baseline.get_param_values()
        finally:
            # Close the environment even when construction, initialization
            # or the parameter reads fail, so a failing test does not leak
            # it.
            box_env.close()
    def test_param_values(self, obs_dim):
        """Check that parameter values can be copied between two baselines.

        Two GaussianMLPBaseline instances are created while the regressor
        class is patched with SimpleGaussianMLPRegressor. A variable under
        the 'GaussianMLPBaseline' scope is then loaded with 1.0, after which
        the two baselines must report different parameter values. Once the
        second baseline receives the first one's values through
        ``set_param_values``, both must report equal values.

        Args:
            obs_dim: Observation dimension forwarded to DummyBoxEnv.

        """
        regressor_path = ('metarl.tf.baselines.'
                          'gaussian_mlp_baseline.'
                          'GaussianMLPRegressor')
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim))

        with mock.patch(regressor_path, new=SimpleGaussianMLPRegressor):
            source = GaussianMLPBaseline(env_spec=env.spec)
            target = GaussianMLPBaseline(env_spec=env.spec,
                                         name='GaussianMLPBaseline2')

        # Overwrite one variable by hand so the two baselines no longer hold
        # identical parameters.
        with tf.compat.v1.variable_scope('GaussianMLPBaseline', reuse=True):
            patched_var = tf.compat.v1.get_variable(
                'SimpleGaussianMLPModel/return_var')
        patched_var.load(1.0)

        expected = source.get_param_values()
        before_copy = target.get_param_values()
        assert not np.array_equal(expected, before_copy)

        # Copying the parameters across must make the baselines agree.
        target.set_param_values(expected)
        after_copy = target.get_param_values()
        assert np.array_equal(expected, after_copy)