def test_is_pickleable(self, obs_dim, action_dim, mock_rand):
    """A pickled-and-restored policy must act identically to the original.

    Builds the policy with a stubbed MLP model, round-trips it through
    ``pickle``, reloads it inside a brand-new graph/session, and checks
    that both actions and distribution parameters match.
    """
    # Force the mocked RNG to a fixed value so the sampled action is
    # deterministic (always index 0).
    mock_rand.return_value = 0
    env = TfEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
    target = ('garage.tf.policies.'
              'categorical_mlp_policy_with_model.MLPModel')
    # Swap in SimpleMLPModel so the network emits known, uniform outputs.
    with mock.patch(target, new=SimpleMLPModel):
        policy = CategoricalMLPPolicyWithModel(env_spec=env.spec)

    env.reset()
    obs, _, _, _ = env.step(1)
    # SimpleMLPModel yields a uniform categorical distribution.
    expected_prob = np.full(action_dim, 0.5)

    pickled_bytes = pickle.dumps(policy)
    # Restore in a fresh graph to prove the clone is self-contained.
    with tf.Session(graph=tf.Graph()):
        restored = pickle.loads(pickled_bytes)
        action, prob = restored.get_action(obs)
        assert env.action_space.contains(action)
        assert action == 0
        assert np.array_equal(prob['prob'], expected_prob)
        # Original and restored policies must agree on dist_info.
        original_info = policy.dist_info([obs.flatten()])
        restored_info = restored.dist_info([obs.flatten()])
        assert np.array_equal(original_info['prob'], restored_info['prob'])
        assert np.array_equal(restored_info['prob'][0], expected_prob)
def test_dist_info(self, obs_dim, action_dim):
    """``dist_info`` must report the stubbed model's uniform probabilities.

    With SimpleMLPModel patched in, every action gets probability 0.5,
    so the distribution parameters for any observation are known exactly.
    """
    env = TfEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
    target = ('garage.tf.policies.'
              'categorical_mlp_policy_with_model.MLPModel')
    # Replace the real MLP with a deterministic stub.
    with mock.patch(target, new=SimpleMLPModel):
        policy = CategoricalMLPPolicyWithModel(env_spec=env.spec)

    env.reset()
    obs, _, _, _ = env.step(1)

    expected_prob = np.full(action_dim, 0.5)
    info = policy.dist_info([obs.flatten()])
    assert np.array_equal(info['prob'][0], expected_prob)