コード例 #1
0
    def test_fit(self, obs_dim):
        """Fit the baseline on two one-step paths and check its predictions.

        The baseline's internal regressor class is patched to
        ``SimpleMLPRegressor`` only while the baseline is being constructed.
        The baseline is then fit on two paths. Each path holds a single
        observation filled with ``v`` and a single return ``v``, for v in
        (1, 2). The test asserts that predicting on those same two
        observations yields ``[1, 2]``.

        Args:
            obs_dim: Observation shape passed to ``DummyBoxEnv`` and
                ``np.full``. Presumably supplied by a pytest parametrize
                decorator outside this view; TODO confirm.
        """
        regressor_target = ('garage.tf.baselines.'
                            'continuous_mlp_baseline_with_model.'
                            'ContinuousMLPRegressorWithModel')
        env = TfEnv(DummyBoxEnv(obs_dim=obs_dim))
        # The patch only needs to be active while the baseline builds its
        # regressor; fitting and predicting happen outside the context.
        with mock.patch(regressor_target, new=SimpleMLPRegressor):
            baseline = ContinuousMLPBaselineWithModel(env_spec=env.spec)

        fill_values = (1, 2)
        training_paths = [{
            'observations': [np.full(obs_dim, value)],
            'returns': [value]
        } for value in fill_values]
        baseline.fit(training_paths)

        query = {
            'observations': [np.full(obs_dim, value) for value in fill_values]
        }
        predicted = baseline.predict(query)
        assert np.array_equal(predicted, [1, 2])
コード例 #2
0
    def test_is_pickleable(self):
        """Check that a pickled baseline predicts the same after unpickling.

        The baseline is built with its regressor class patched to
        ``SimpleMLPRegressor``. Its ``SimpleMLPModel/return_var`` variable is
        set to 1.0 and a reference prediction is taken. The baseline is then
        pickled and restored inside a brand-new ``tf.Graph``/``tf.Session``,
        and the restored copy must produce an identical prediction.
        """
        regressor_target = ('garage.tf.baselines.'
                            'continuous_mlp_baseline_with_model.'
                            'ContinuousMLPRegressorWithModel')
        env = TfEnv(DummyBoxEnv(obs_dim=(1, )))
        with mock.patch(regressor_target, new=SimpleMLPRegressor):
            baseline = ContinuousMLPBaselineWithModel(env_spec=env.spec)
        query = {'observations': [np.full(1, 1) for _ in range(2)]}

        # Look up the existing variable created under the baseline's scope.
        with tf.variable_scope('ContinuousMLPBaselineWithModel', reuse=True):
            return_var = tf.get_variable('SimpleMLPModel/return_var')
        # NOTE(review): `load` with no session argument uses the default
        # session, presumably opened by the enclosing test fixture. Setting
        # a non-default value makes the round-trip check meaningful; confirm.
        return_var.load(1.0)

        reference = baseline.predict(query)
        payload = pickle.dumps(baseline)

        # Restore into an isolated graph so nothing is shared with the
        # original baseline's graph.
        with tf.Session(graph=tf.Graph()):
            restored = pickle.loads(payload)
            assert np.array_equal(reference, restored.predict(query))