class TestQfDerivedPolicy(TfGraphTestCase):
    """Tests for DiscreteQfDerivedPolicy backed by a SimpleQFunction."""

    def setup_method(self):
        """Build the env, Q-function and policy, then initialise variables."""
        super().setup_method()
        self.env = GarageEnv(DummyDiscreteEnv())
        spec = self.env.spec
        self.qf = SimpleQFunction(spec)
        self.policy = DiscreteQfDerivedPolicy(env_spec=spec, qf=self.qf)
        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)
        self.env.reset()

    def _observe(self):
        """Step the env with action 1 and return only the observation."""
        observation, _reward, _done, _info = self.env.step(1)
        return observation

    def test_discrete_qf_derived_policy(self):
        """Single and batched actions must lie in the env's action space."""
        observation = self._observe()
        action_space = self.env.action_space

        single_action, _ = self.policy.get_action(observation)
        assert action_space.contains(single_action)

        batch_actions, _ = self.policy.get_actions([observation])
        assert all(action_space.contains(act) for act in batch_actions)

    def test_is_pickleable(self):
        """A pickled-and-restored policy picks the same action."""
        scope_name = 'SimpleQFunction/SimpleMLPModel'
        with tf.compat.v1.variable_scope(scope_name, reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # Overwrite the variable with ones before pickling.
        ones = tf.ones_like(return_var).eval()
        return_var.load(ones)

        observation = self._observe()
        expected_action, _ = self.policy.get_action(observation)

        serialized = pickle.dumps(self.policy)
        # Restore inside a brand-new graph/session.
        with tf.compat.v1.Session(graph=tf.Graph()):
            restored_policy = pickle.loads(serialized)
            restored_action, _ = restored_policy.get_action(observation)
            assert expected_action == restored_action
class TestQfDerivedPolicy(TfGraphTestCase):
    """Tests for DiscreteQfDerivedPolicy backed by a SimpleQFunction."""

    def setup_method(self):
        """Build the env, Q-function and policy, then initialise variables."""
        super().setup_method()
        self.env = GymEnv(DummyDiscreteEnv())
        self.qf = SimpleQFunction(self.env.spec)
        self.policy = DiscreteQfDerivedPolicy(env_spec=self.env.spec,
                                              qf=self.qf)
        self.sess.run(tf.compat.v1.global_variables_initializer())
        self.env.reset()

    def test_discrete_qf_derived_policy(self):
        """Single and batched actions must lie in the env's action space."""
        obs = self.env.step(1).observation
        action, _ = self.policy.get_action(obs)
        assert self.env.action_space.contains(action)
        actions, _ = self.policy.get_actions([obs])
        for action in actions:
            assert self.env.action_space.contains(action)

    def test_get_param(self):
        """Policy parameter values must equal the Q-function's variable."""
        with tf.compat.v1.variable_scope('SimpleQFunction', reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        assert self.policy.get_param_values() == return_var.eval()

    def test_is_pickleable(self):
        """A pickled-and-restored policy picks the same action."""
        with tf.compat.v1.variable_scope('SimpleQFunction', reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # Overwrite the variable with ones before pickling.
        return_var.load(tf.ones_like(return_var).eval())
        obs = self.env.step(1).observation
        action1, _ = self.policy.get_action(obs)

        p = pickle.dumps(self.policy)
        # Restore inside a brand-new graph/session.
        with tf.compat.v1.Session(graph=tf.Graph()):
            policy_pickled = pickle.loads(p)
            action2, _ = policy_pickled.get_action(obs)
            assert action1 == action2

    def test_does_not_support_dict_obs_space(self):
        """Test that policy raises error if passed a dict obs space."""
        env = GymEnv(DummyDictEnv(act_space_type='discrete'))
        # The Q-function is built inside the `raises` block too, so a
        # ValueError from either constructor satisfies this test.
        with pytest.raises(ValueError):
            qf = SimpleQFunction(env.spec,
                                 name='does_not_support_dict_obs_space')
            DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)

    def test_invalid_action_spaces(self):
        """Test that policy raises error if passed a box action space."""
        # NOTE(review): DummyDictEnv presumably also exposes a dict
        # observation space, so the ValueError here may come from the
        # observation-space check rather than the action-space check --
        # confirm against the policy's validation order.
        env = GymEnv(DummyDictEnv(act_space_type='box'))
        with pytest.raises(ValueError):
            qf = SimpleQFunction(env.spec)
            DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
class TestQfDerivedPolicy(TfGraphTestCase):
    """Tests for DiscreteQfDerivedPolicy backed by a SimpleQFunction."""

    def setUp(self):
        """Build the env, Q-function and policy, then initialise variables."""
        super().setUp()
        self.env = TfEnv(DummyDiscreteEnv())
        spec = self.env.spec
        self.qf = SimpleQFunction(spec)
        self.policy = DiscreteQfDerivedPolicy(env_spec=spec, qf=self.qf)
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)
        self.env.reset()

    def test_discrete_qf_derived_policy(self):
        """Single and batched actions must lie in the env's action space."""
        observation, _reward, _done, _info = self.env.step(1)
        action_space = self.env.action_space

        single_action = self.policy.get_action(observation)
        assert action_space.contains(single_action)

        batch_actions = self.policy.get_actions([observation])
        assert all(action_space.contains(act) for act in batch_actions)