Example #1
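This test checks that a CategoricalMLPPolicyWithModel survives a pickle round trip. The snippet is a method of a test class and omits the imports and decorators it relies on; a minimal sketch of the scaffolding it assumes (the fixture import paths follow garage's test layout and are assumptions, not confirmed by the snippet):

import pickle
from unittest import mock

import numpy as np
import tensorflow as tf

from garage.tf.envs import TfEnv
from garage.tf.policies import CategoricalMLPPolicyWithModel
# Test-suite fixtures; the exact module paths are assumptions.
from tests.fixtures.envs.dummy import DummyDiscreteEnv
from tests.fixtures.models import SimpleMLPModel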
    def test_is_pickleable(self, obs_dim, action_dim, mock_rand):
        # obs_dim and action_dim come from a pytest.mark.parametrize decorator
        # and mock_rand from a mock.patch decorator, both omitted here.
        # Forcing the patched random call to return 0 makes sampling
        # deterministic, so the sampled action is predictable below.
        mock_rand.return_value = 0
        env = TfEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        # Swap the real MLPModel for the constant-output SimpleMLPModel while
        # the policy is constructed, so its probabilities are known exactly.
        with mock.patch(('garage.tf.policies.'
                         'categorical_mlp_policy_with_model.MLPModel'),
                        new=SimpleMLPModel):
            policy = CategoricalMLPPolicyWithModel(env_spec=env.spec)

        env.reset()
        # Step once just to obtain a valid observation to feed the policy.
        obs, _, _, _ = env.step(1)

        # SimpleMLPModel emits a constant 0.5 per action, so the expected
        # distribution is known regardless of the observation.
        expected_prob = np.full(action_dim, 0.5)

        p = pickle.dumps(policy)

        # A fresh graph and session simulate a new process: unpickling has to
        # rebuild the policy's network and restore its parameters.
        with tf.Session(graph=tf.Graph()):
            policy_pickled = pickle.loads(p)
            action, prob = policy_pickled.get_action(obs)
            assert env.action_space.contains(action)
            # Deterministic because of mock_rand.return_value = 0 above.
            assert action == 0
            assert np.array_equal(prob['prob'], expected_prob)

            # The original and the restored policy must report identical
            # distributions for the same observation.
            prob1 = policy.dist_info([obs.flatten()])
            prob2 = policy_pickled.dist_info([obs.flatten()])
            assert np.array_equal(prob1['prob'], prob2['prob'])
            assert np.array_equal(prob2['prob'][0], expected_prob)
Example #2
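This test builds the same stubbed policy and checks only the distribution interface: dist_info should report the constant probabilities produced by SimpleMLPModel. It assumes the same imports and fixtures sketched above.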
    def test_dist_info(self, obs_dim, action_dim):
        # obs_dim and action_dim are again supplied by a pytest.mark.parametrize
        # decorator omitted from the snippet.
        env = TfEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        # Same stubbing as in Example #1: build the policy on the
        # constant-output SimpleMLPModel.
        with mock.patch(('garage.tf.policies.'
                         'categorical_mlp_policy_with_model.MLPModel'),
                        new=SimpleMLPModel):
            policy = CategoricalMLPPolicyWithModel(env_spec=env.spec)

        env.reset()
        obs, _, _, _ = env.step(1)

        expected_prob = np.full(action_dim, 0.5)

        # The reported distribution must match the stub's constant output.
        policy_probs = policy.dist_info([obs.flatten()])
        assert np.array_equal(policy_probs['prob'][0], expected_prob)
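Both tests can assert exact probabilities only because the patched-in SimpleMLPModel is a stub with a constant output. For reference, a hypothetical sketch of what such a fixture could look like (the real one lives in garage's test suite; the base class and names here are assumptions):

import tensorflow as tf

from garage.tf.models import Model  # assumed base class


class SimpleMLPModel(Model):
    """Stub that ignores the observation and emits a constant 0.5 per action."""

    def __init__(self, output_dim, name=None, *args, **kwargs):
        super().__init__(name)
        self.output_dim = output_dim

    def _build(self, obs_input):
        # A single scalar variable initialized to 0.5, broadcast over the
        # batch, which is why the tests expect np.full(action_dim, 0.5).
        return_var = tf.get_variable(
            'return_var', (), initializer=tf.constant_initializer(0.5))
        return tf.fill([tf.shape(obs_input)[0], self.output_dim], return_var)

Patching the model at construction time, rather than subclassing the policy, keeps the tests focused on the policy's plumbing (pickling, graph rebuilding, and the dist_info interface) instead of the network's numerics.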