def test_get_qval_sym(self, obs_dim, action_dim):
    """Check that get_qval_sym rebuilds a network matching the original.

    The Q-function is constructed while
    ``discrete_mlp_q_function.MLPModel`` is patched to ``SimpleMLPModel``.
    The test then requires two things: the Q-values from the construction-
    time network equal those from a network rebuilt on a fresh placeholder,
    and the rebuilt network yields 0.5 for every action (presumably
    SimpleMLPModel's fixed output; confirm against that helper).

    Args:
        obs_dim: observation shape for DummyDiscreteEnv. It is concatenated
            onto ``(None, )`` for the placeholder shape, so it must be a
            tuple.
        action_dim: number of discrete actions, and the expected length of
            each Q-value row.
    """
    patch_target = ('garage.tf.q_functions.'
                    'discrete_mlp_q_function.MLPModel')
    dummy_env = TfEnv(
        DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
    with mock.patch(patch_target, new=SimpleMLPModel):
        q_function = DiscreteMLPQFunction(env_spec=dummy_env.spec)

    dummy_env.reset()
    observation, _, _, _ = dummy_env.step(1)
    batch = [observation]

    # Q-values from the network built when the Q-function was constructed.
    original_q = self.sess.run(q_function.q_vals,
                               feed_dict={q_function.input: batch})

    # Q-values from a second copy built on a new placeholder.
    obs_ph = tf.placeholder(tf.float32, shape=(None, ) + obs_dim)
    rebuilt_q_sym = q_function.get_qval_sym(obs_ph, 'another')
    rebuilt_q = self.sess.run(rebuilt_q_sym, feed_dict={obs_ph: batch})

    assert np.array_equal(original_q, rebuilt_q)
    assert np.array_equal(rebuilt_q[0], np.full(action_dim, 0.5))
class TestDiscreteMLPQFunction(TfGraphTestCase):
    """Tests for DiscreteMLPQFunction built on a default DummyDiscreteEnv."""

    def setUp(self):
        """Create the environment, the Q-function and a batch of inputs."""
        super().setUp()
        # Two observations of width 1. This presumably matches
        # DummyDiscreteEnv's default observation shape (TODO confirm).
        self.data = np.ones((2, 1))
        self.env = TfEnv(DummyDiscreteEnv())
        self.qf = DiscreteMLPQFunction(self.env.spec)

    def _default_network_output(self):
        """Evaluate the Q-function's 'default' network on ``self.data``."""
        network = self.qf.model.networks['default']
        return self.sess.run(network.outputs,
                             feed_dict={network.input: self.data})

    def _observation_placeholder(self):
        """Create a float32 placeholder shaped like a batch of ``self.data``.

        The shape is derived from ``self.data`` so the placeholder always
        accepts the data fed into it. The placeholder is created in
        whichever graph is the default when this is called.
        """
        return tf.placeholder(tf.float32,
                              shape=(None, ) + self.data.shape[1:])

    def test_discrete_mlp_q_function(self):
        """The default network emits one Q-value per action for each row."""
        output1 = self._default_network_output()
        assert output1.shape == (self.data.shape[0],
                                 self.env.action_space.n)

    def test_discrete_mlp_q_function_is_rebuilt_output_same(self):
        """A network rebuilt via get_qval_sym matches the default network."""
        output1 = self._default_network_output()

        input_var = self._observation_placeholder()
        q_vals = self.qf.get_qval_sym(input_var, 'another')
        output2 = self.sess.run(q_vals, feed_dict={input_var: self.data})

        assert np.array_equal(output1, output2)

    def test_discrete_mlp_q_function_is_pickleable(self):
        """An unpickled copy in a new graph reproduces the Q-values."""
        output1 = self._default_network_output()
        h_data = pickle.dumps(self.qf)

        # Use a brand-new graph so the unpickled Q-function is built
        # separately from the graph behind self.sess.
        with tf.Session(graph=tf.Graph()) as sess:
            qf_pickled = pickle.loads(h_data)
            input_var = self._observation_placeholder()
            q_vals = qf_pickled.get_qval_sym(input_var, 'another')
            output2 = sess.run(q_vals, feed_dict={input_var: self.data})

        assert np.array_equal(output1, output2)