def testVariableBatchSize(self):
    """Check that sarsa outputs keep an unknown leading (batch) dimension.

    All inputs are placeholders whose first dimension is None. The loss,
    TD error and target must report a static shape of [None] and, once a
    batch of 11 is fed, evaluate to arrays of shape (11,).
    """
    num_actions = 3
    q_tm1 = tf.placeholder(tf.float32, shape=[None, num_actions])
    q_t = tf.placeholder(tf.float32, shape=[None, num_actions])
    a_tm1 = tf.placeholder(tf.int32, shape=[None])
    pcont_t = tf.placeholder(tf.float32, shape=[None])
    a_t = tf.placeholder(tf.int32, shape=[None])
    r_t = tf.placeholder(tf.float32, shape=[None])
    sarsa = rl.sarsa(q_tm1, a_tm1, r_t, pcont_t, q_t, a_t)
    extra = sarsa.extra

    # Static shapes: the batch dimension is unknown at graph-build time.
    unknown_batch = [None]
    self.assertEqual(sarsa.loss.get_shape().as_list(), unknown_batch)
    self.assertEqual(extra.td_error.get_shape().as_list(), unknown_batch)
    self.assertEqual(extra.target.get_shape().as_list(), unknown_batch)

    # Runtime shapes: feed a concrete batch and inspect the results.
    batch_size = 11
    feed_dict = {
        q_tm1: np.random.random([batch_size, num_actions]),
        q_t: np.random.random([batch_size, num_actions]),
        a_tm1: np.random.randint(0, num_actions, [batch_size]),
        pcont_t: np.random.random([batch_size]),
        a_t: np.random.randint(0, num_actions, [batch_size]),
        r_t: np.random.random(batch_size),
    }
    fetches = [sarsa.loss, extra.td_error, extra.target]
    expected_shape = (batch_size,)
    with self.test_session() as sess:
        loss_val, td_error_val, target_val = sess.run(
            fetches, feed_dict=feed_dict)
        self.assertEqual(loss_val.shape, expected_shape)
        self.assertEqual(td_error_val.shape, expected_shape)
        self.assertEqual(target_val.shape, expected_shape)
def testVariableBatchSize(self):
    """Verify static and runtime output shapes with a None batch dimension.

    The graph is built from placeholders with an unspecified first
    dimension; each of loss, td_error and target is expected to have static
    shape [None] and runtime shape (batch_size,) for the fed batch.
    """
    q_tm1 = tf.placeholder(tf.float32, shape=[None, 3])
    q_t = tf.placeholder(tf.float32, shape=[None, 3])
    a_tm1 = tf.placeholder(tf.int32, shape=[None])
    pcont_t = tf.placeholder(tf.float32, shape=[None])
    a_t = tf.placeholder(tf.int32, shape=[None])
    r_t = tf.placeholder(tf.float32, shape=[None])
    sarsa = rl.sarsa(q_tm1, a_tm1, r_t, pcont_t, q_t, a_t)
    outputs = [sarsa.loss, sarsa.extra.td_error, sarsa.extra.target]

    # Check static shapes.
    for tensor in outputs:
        self.assertEqual(tensor.get_shape().as_list(), [None])

    # Check runtime shapes.
    batch_size = 11
    feed_dict = {
        q_tm1: np.random.random([batch_size, 3]),
        q_t: np.random.random([batch_size, 3]),
        a_tm1: np.random.randint(0, 3, [batch_size]),
        pcont_t: np.random.random([batch_size]),
        a_t: np.random.randint(0, 3, [batch_size]),
        r_t: np.random.random(batch_size),
    }
    with self.test_session() as sess:
        values = sess.run(outputs, feed_dict=feed_dict)
        for value in values:
            self.assertEqual(value.shape, (batch_size,))
def setUp(self):
    """Build constant inputs (batch of two, three values per Q row) and the
    resulting sarsa op, stored on the test instance for use by the tests.
    """
    super(SarsaTest, self).setUp()

    def float_const(values):
        # Shorthand for a float32 constant tensor.
        return tf.constant(values, dtype=tf.float32)

    def int_const(values):
        # Shorthand for an int32 constant tensor.
        return tf.constant(values, dtype=tf.int32)

    self.q_tm1 = float_const([[1, 1, 0], [1, 1, 0]])
    self.q_t = float_const([[0, 1, 0], [3, 2, 0]])
    self.a_tm1 = int_const([0, 1])
    self.a_t = int_const([1, 0])
    self.pcont_t = float_const([0, 1])
    self.r_t = float_const([1, 1])

    # Argument order expected by rl.sarsa, as used throughout these tests.
    sarsa_inputs = (
        self.q_tm1,
        self.a_tm1,
        self.r_t,
        self.pcont_t,
        self.q_t,
        self.a_t,
    )
    self.sarsa = rl.sarsa(*sarsa_inputs)
def testCompatibilityCheck(self):
    """Expect rl.sarsa to reject an incompatible `a_t` argument.

    A float32 placeholder of shape [3] is passed as `a_t` alongside the
    fixture tensors; building the op must raise ValueError with the Sarsa
    rank/compatibility message.
    """
    bad_a_t = tf.placeholder(tf.float32, [3])
    expected_message = "Sarsa: Error in rank and/or compatibility check"
    with self.assertRaisesRegexp(ValueError, expected_message):
        self.sarsa = rl.sarsa(
            self.q_tm1,
            self.a_tm1,
            self.r_t,
            self.pcont_t,
            self.q_t,
            bad_a_t,
        )