Ejemplo n.º 1
0
    def test_param_values(self):
        """Check that set_param_values copies parameters between baselines.

        Two baselines with the same architecture are built under different
        names. One bias variable of the first baseline is overwritten with
        ones so the two parameter sets differ. The second baseline is then
        synced from the first and compared again.
        """
        # NOTE(review): another method named `test_param_values` follows this
        # one; if both live in the same class, this definition is shadowed
        # and will never be collected.
        shared_config = dict(env_spec=test_env_spec,
                             filters=((3, (3, 3)), (6, (3, 3))),
                             strides=(1, 1),
                             padding='SAME',
                             hidden_sizes=(32, ),
                             adaptive_std=False,
                             use_trust_region=False)
        baseline = GaussianCNNBaseline(**shared_config)
        mirror = GaussianCNNBaseline(name='GaussianCNNBaseline2',
                                     **shared_config)

        # Manually perturb a parameter of the first baseline (looked up by
        # its scoped variable name) so the two parameter sets must differ.
        with tf.compat.v1.variable_scope('GaussianCNNBaseline', reuse=True):
            perturbed = tf.compat.v1.get_variable(
                'dist_params/mean_network/hidden_0/bias')
        ones = tf.ones_like(perturbed).eval()
        perturbed.load(ones)

        source = baseline.get_param_values()
        assert not np.array_equal(source, mirror.get_param_values())

        mirror.set_param_values(source)
        assert np.array_equal(source, mirror.get_param_values())
    def test_param_values(self, obs_dim):
        """Check parameter syncing between baselines with a stubbed regressor.

        While both baselines are constructed, ``GaussianCNNRegressor`` is
        patched with ``SimpleGaussianCNNRegressor``. The first baseline's
        ``return_var`` is then set to 1.0 so the parameter sets differ. The
        second baseline is synced from the first and compared again.

        Args:
            obs_dim: Observation dimension forwarded to ``DummyBoxEnv``
                (presumably supplied by a parametrize decorator outside this
                view -- confirm).
        """
        regressor_path = ('garage.tf.baselines.'
                          'gaussian_cnn_baseline.'
                          'GaussianCNNRegressor')
        env = GarageEnv(DummyBoxEnv(obs_dim=obs_dim))
        with mock.patch(regressor_path, new=SimpleGaussianCNNRegressor):
            baseline = GaussianCNNBaseline(env_spec=env.spec)
            mirror = GaussianCNNBaseline(env_spec=env.spec,
                                         name='GaussianCNNBaseline2')

        # Manually perturb a parameter of the first baseline so the two
        # parameter sets must differ before syncing.
        with tf.compat.v1.variable_scope('GaussianCNNBaseline', reuse=True):
            tweaked = tf.compat.v1.get_variable(
                'SimpleGaussianCNNModel/return_var')
        tweaked.load(1.0)

        source = baseline.get_param_values()
        assert not np.array_equal(source, mirror.get_param_values())

        mirror.set_param_values(source)
        assert np.array_equal(source, mirror.get_param_values())