def testVariableBatchSize(self):
    """Check that qv_learning outputs keep an unspecified batch dimension.

    The op is built from placeholders whose leading dimension is `None`.
    The static shapes of the loss, TD error and target must then be
    `[None]`, and running the graph on a batch of 11 random rows must
    yield arrays of shape `(11,)` for each of the three outputs.
    """
    q_tm1 = tf.placeholder(tf.float32, shape=[None, 3])
    a_tm1 = tf.placeholder(tf.int32, shape=[None])
    pcont_t = tf.placeholder(tf.float32, shape=[None])
    r_t = tf.placeholder(tf.float32, shape=[None])
    v_t = tf.placeholder(tf.float32, shape=[None])
    loss_op, extra_ops = rl.qv_learning(q_tm1, a_tm1, r_t, pcont_t, v_t)
    outputs = [loss_op, extra_ops.td_error, extra_ops.target]

    # Static shapes: the batch dimension must remain unknown.
    for tensor in outputs:
        self.assertEqual(tensor.get_shape().as_list(), [None])

    # Runtime shapes: every output must have one entry per fed row.
    batch_size = 11
    runtime_shape = (batch_size,)
    feed_dict = {
        q_tm1: np.random.random([batch_size, 3]),
        a_tm1: np.random.randint(0, 3, [batch_size]),
        pcont_t: np.random.random([batch_size]),
        r_t: np.random.random(batch_size),
        v_t: np.random.random(batch_size),
    }
    with self.test_session() as sess:
        for value in sess.run(outputs, feed_dict=feed_dict):
            self.assertEqual(value.shape, runtime_shape)
def setUp(self):
    """Create the constant inputs and the QV-learning ops used by the tests.

    Every input is a constant tensor holding a batch of two entries. The
    loss and the extra outputs returned by `rl.qv_learning` are stored on
    the instance as `loss_op` and `extra_ops`.
    """
    super(QVTest, self).setUp()

    def as_float32(values):
        # All inputs except the action indices are float32 constants.
        return tf.constant(values, dtype=tf.float32)

    self.q_tm1 = as_float32([[1, 1, 0], [1, 1, 0]])
    self.a_tm1 = tf.constant([0, 1], dtype=tf.int32)
    self.pcont_t = as_float32([0, 1])
    self.r_t = as_float32([1, 1])
    self.v_t = as_float32([1, 3])

    outputs = rl.qv_learning(
        self.q_tm1, self.a_tm1, self.r_t, self.pcont_t, self.v_t)
    self.loss_op, self.extra_ops = outputs
def testRankCheck(self):
    """Check that a rank-1 `q_tm1` makes qv_learning raise a ValueError."""
    expected_message = "QVLearning: Error in rank and/or compatibility check"
    # A placeholder of shape [None] instead of the rank-2 fixture input.
    rank_one_q = tf.placeholder(tf.float32, [None])
    with self.assertRaisesRegexp(ValueError, expected_message):
        rl.qv_learning(
            rank_one_q, self.a_tm1, self.r_t, self.pcont_t, self.v_t)