Пример #1
0
    def testVariableBatchSize(self):
        """Checks qv_learning on inputs whose batch dimension is unknown statically.

        All inputs are placeholders with a leading dimension of None. The test
        asserts that the loss, TD error and target tensors have static shape
        [None], then feeds a concrete batch and asserts that each evaluated
        output has shape (batch_size,).
        """
        num_actions = 3
        q_tm1 = tf.placeholder(tf.float32, shape=[None, num_actions])
        a_tm1 = tf.placeholder(tf.int32, shape=[None])
        pcont_t = tf.placeholder(tf.float32, shape=[None])
        r_t = tf.placeholder(tf.float32, shape=[None])
        v_t = tf.placeholder(tf.float32, shape=[None])
        loss_op, extra_ops = rl.qv_learning(q_tm1, a_tm1, r_t, pcont_t, v_t)
        checked_ops = [loss_op, extra_ops.td_error, extra_ops.target]

        # Static shapes: only the (unknown) batch dimension should remain.
        for op in checked_ops:
            self.assertEqual(op.get_shape().as_list(), [None])

        # Runtime shapes: feed a concrete batch and inspect the results.
        batch_size = 11
        feed_dict = {
            q_tm1: np.random.random((batch_size, num_actions)),
            a_tm1: np.random.randint(0, num_actions, (batch_size,)),
            pcont_t: np.random.random((batch_size,)),
            r_t: np.random.random((batch_size,)),
            v_t: np.random.random((batch_size,)),
        }
        with self.test_session() as sess:
            evaluated = sess.run(checked_ops, feed_dict=feed_dict)

        for value in evaluated:
            self.assertEqual(value.shape, (batch_size,))
Пример #2
0
 def setUp(self):
     """Builds a fixed two-element input batch and the qv_learning ops.

     The input constants are stored on the instance, together with the loss
     op and the extra ops returned by rl.qv_learning, so the test methods can
     reuse them.
     """
     super(QVTest, self).setUp()
     self.q_tm1 = tf.constant(
         [[1, 1, 0],
          [1, 1, 0]],
         dtype=tf.float32)
     self.a_tm1 = tf.constant([0, 1], dtype=tf.int32)
     self.pcont_t = tf.constant([0, 1], dtype=tf.float32)
     self.r_t = tf.constant([1, 1], dtype=tf.float32)
     self.v_t = tf.constant([1, 3], dtype=tf.float32)
     # Positional order expected by rl.qv_learning: q, action, reward,
     # continuation, next-state value (matches the call in the other tests).
     qv_args = (self.q_tm1, self.a_tm1, self.r_t, self.pcont_t, self.v_t)
     self.loss_op, self.extra_ops = rl.qv_learning(*qv_args)
Пример #3
0
 def testRankCheck(self):
     """Checks that qv_learning rejects a rank-1 q_tm1.

     Passing a rank-1 q_tm1 placeholder together with the instance's other
     inputs must raise ValueError with a message matching the rank-check
     error pattern.
     """
     expected_error = "QVLearning: Error in rank and/or compatibility check"
     rank1_q_tm1 = tf.placeholder(tf.float32, [None])
     # NOTE(review): assertRaisesRegexp is a deprecated alias (removed in
     # Python 3.12); switch to assertRaisesRegex once Python 2 support is
     # no longer required.
     with self.assertRaisesRegexp(ValueError, expected_error):
         rl.qv_learning(rank1_q_tm1, self.a_tm1, self.r_t, self.pcont_t,
                        self.v_t)