def testVariableBatchSize(self):
    """Check sarse outputs when the batch dimension is unknown at graph time.

    All inputs are placeholders whose leading dimension is ``None``. The test
    asserts that each output's static shape is ``[None]`` and that feeding a
    batch of 11 produces outputs of shape ``(11,)``.
    """
    num_actions = 3
    q_tm1 = tf.placeholder(tf.float32, shape=[None, num_actions])
    q_t = tf.placeholder(tf.float32, shape=[None, num_actions])
    a_tm1 = tf.placeholder(tf.int32, shape=[None])
    pcont_t = tf.placeholder(tf.float32, shape=[None])
    probs_a_t = tf.placeholder(tf.float32, shape=[None, num_actions])
    r_t = tf.placeholder(tf.float32, shape=[None])
    op = rl.sarse(q_tm1, a_tm1, r_t, pcont_t, q_t, probs_a_t)
    fetches = [op.loss, op.extra.td_error, op.extra.target]

    # Check static shapes.
    for tensor in fetches:
        self.assertEqual(tensor.get_shape().as_list(), [None])

    # Check runtime shapes.
    batch_size = 11
    feeds = {
        q_tm1: np.random.random([batch_size, num_actions]),
        q_t: np.random.random([batch_size, num_actions]),
        a_tm1: np.random.randint(0, num_actions, [batch_size]),
        pcont_t: np.random.random([batch_size]),
        r_t: np.random.random(batch_size),
        probs_a_t: np.random.uniform(size=[batch_size, num_actions]),
    }
    expected_shape = (batch_size,)
    with self.test_session() as sess:
        for value in sess.run(fetches, feed_dict=feeds):
            self.assertEqual(value.shape, expected_shape)
def testVariableBatchSize(self):
    """Verify static and runtime output shapes with a variable batch size.

    The sarse op is built from placeholders with an unspecified leading
    dimension. Its loss, td_error and target must report a static shape of
    ``[None]`` and evaluate to arrays of shape ``(batch_size,)`` when fed.
    """

    def unknown_batch(dtype, *trailing_dims):
        # Leading dimension is None so the batch size is fixed only by the feed.
        return tf.placeholder(dtype, shape=[None] + list(trailing_dims))

    q_tm1 = unknown_batch(tf.float32, 3)
    q_t = unknown_batch(tf.float32, 3)
    a_tm1 = unknown_batch(tf.int32)
    pcont_t = unknown_batch(tf.float32)
    probs_a_t = unknown_batch(tf.float32, 3)
    r_t = unknown_batch(tf.float32)
    result = rl.sarse(q_tm1, a_tm1, r_t, pcont_t, q_t, probs_a_t)
    outputs = (result.loss, result.extra.td_error, result.extra.target)

    # Check static shapes.
    for output in outputs:
        static_shape = output.get_shape().as_list()
        self.assertEqual(static_shape, [None])

    # Check runtime shapes.
    batch_size = 11
    feed_dict = {
        q_tm1: np.random.random([batch_size, 3]),
        q_t: np.random.random([batch_size, 3]),
        a_tm1: np.random.randint(0, 3, [batch_size]),
        pcont_t: np.random.random([batch_size]),
        r_t: np.random.random(batch_size),
        probs_a_t: np.random.uniform(size=[batch_size, 3]),
    }
    with self.test_session() as sess:
        evaluated = sess.run(list(outputs), feed_dict=feed_dict)
        for array in evaluated:
            self.assertEqual(array.shape, (batch_size,))
def testIncorrectProbsTensor(self):
    """Expect an InvalidArgumentError for probabilities that do not sum to 1.

    The second row of the probabilities below sums to 1.1. Building the op
    with ``debug=True`` and evaluating its target is expected to raise with
    the message checked here.
    """
    bad_probs = tf.constant(
        [[0.2, 0.5, 0.3], [0.3, 0.5, 0.3]], dtype=tf.float32)
    expected_error = tf.errors.InvalidArgumentError
    expected_message = "probs_a_t tensor does not sum to 1"
    with self.test_session() as sess:
        with self.assertRaisesRegexp(expected_error, expected_message):
            # Construction stays inside the assertion context, as does the run.
            self.sarse = rl.sarse(
                self.q_tm1, self.a_tm1, self.r_t, self.pcont_t, self.q_t,
                bad_probs, debug=True)
            sess.run(self.sarse.extra.target)
def setUp(self):
    """Build the shared two-example, three-action fixtures and the sarse op."""
    super(SarseTest, self).setUp()

    def float_constant(values):
        # Every fixture except the action indices is a float32 constant.
        return tf.constant(values, dtype=tf.float32)

    self.q_tm1 = float_constant([[1, 1, 0.5], [1, 1, 3]])
    self.q_t = float_constant([[1.5, 1, 2], [3, 2, 1]])
    self.a_tm1 = tf.constant([0, 1], dtype=tf.int32)
    # Each row of the action probabilities sums to 1.
    self.probs_a_t = float_constant([[0.2, 0.5, 0.3], [0.3, 0.4, 0.3]])
    self.pcont_t = float_constant([1, 1])
    self.r_t = float_constant([4, 1])

    sarse_inputs = (self.q_tm1, self.a_tm1, self.r_t, self.pcont_t, self.q_t,
                    self.probs_a_t)
    self.sarse = rl.sarse(*sarse_inputs)
def testIncorrectProbsTensor(self):
    """Check that a probability tensor not summing to 1 is reported.

    One row of ``invalid_probs`` sums to 1.1. The op is built with
    ``debug=True`` and its target evaluated; an InvalidArgumentError matching
    the message below must be raised somewhere in that sequence.
    """
    invalid_probs = tf.constant([[0.2, 0.5, 0.3],
                                 [0.3, 0.5, 0.3]], dtype=tf.float32)
    message_pattern = "probs_a_t tensor does not sum to 1"
    with self.test_session() as sess, self.assertRaisesRegexp(
            tf.errors.InvalidArgumentError, message_pattern):
        sarse_args = (self.q_tm1, self.a_tm1, self.r_t, self.pcont_t,
                      self.q_t, invalid_probs)
        self.sarse = rl.sarse(*sarse_args, debug=True)
        sess.run(self.sarse.extra.target)
def testCompatibilityCheck(self):
    """Expect a ValueError when probs_a_t has an incompatible shape.

    The placeholder below has two columns, and building the op with it
    alongside the fixture tensors must fail with the message checked here.
    """
    mismatched_probs = tf.placeholder(tf.float32, [None, 2])
    expected_message = "Sarse: Error in rank and/or compatibility check"
    with self.assertRaisesRegexp(ValueError, expected_message):
        self.sarse = rl.sarse(
            self.q_tm1, self.a_tm1, self.r_t, self.pcont_t, self.q_t,
            mismatched_probs)