Пример #1
0
    def testVariableBatchSize(self):
        """Sarsa must handle a batch dimension that is only known at run time.

        Every input is a placeholder whose leading dimension is ``None``. The
        test first checks that ``loss``, ``extra.td_error`` and
        ``extra.target`` keep a static shape of ``[None]``. It then evaluates
        them on a concrete batch of random inputs and checks that each result
        is a rank-1 array of length ``batch_size``.
        """
        num_actions = 3
        q_tm1 = tf.placeholder(tf.float32, shape=[None, num_actions])
        q_t = tf.placeholder(tf.float32, shape=[None, num_actions])
        a_tm1 = tf.placeholder(tf.int32, shape=[None])
        pcont_t = tf.placeholder(tf.float32, shape=[None])
        a_t = tf.placeholder(tf.int32, shape=[None])
        r_t = tf.placeholder(tf.float32, shape=[None])
        sarsa = rl.sarsa(q_tm1, a_tm1, r_t, pcont_t, q_t, a_t)
        fetches = [sarsa.loss, sarsa.extra.td_error, sarsa.extra.target]

        # Statically, the only dimension left is the unknown batch dimension.
        for tensor in fetches:
            self.assertEqual(tensor.get_shape().as_list(), [None])

        # Feed a concrete batch. The evaluated arrays must take on its size.
        batch_size = 11
        q_shape = [batch_size, num_actions]
        vec_shape = [batch_size]
        feed_dict = {
            q_tm1: np.random.random(q_shape),
            q_t: np.random.random(q_shape),
            # Action indices are drawn from [0, num_actions).
            a_tm1: np.random.randint(0, num_actions, vec_shape),
            pcont_t: np.random.random(vec_shape),
            a_t: np.random.randint(0, num_actions, vec_shape),
            r_t: np.random.random(vec_shape),
        }
        with self.test_session() as sess:
            results = sess.run(fetches, feed_dict=feed_dict)

        for result in results:
            self.assertEqual(result.shape, (batch_size,))
Пример #2
0
  def testVariableBatchSize(self):
    """Checks that sarsa works when the batch size is unknown statically.

    All inputs are placeholders with a ``None`` leading dimension. The
    outputs must report a static shape of ``[None]``. Evaluated on a batch
    of 11 random transitions, each output must have shape ``(11,)``.
    """
    num_actions = 3
    q_tm1 = tf.placeholder(tf.float32, shape=[None, num_actions])
    q_t = tf.placeholder(tf.float32, shape=[None, num_actions])
    a_tm1 = tf.placeholder(tf.int32, shape=[None])
    pcont_t = tf.placeholder(tf.float32, shape=[None])
    a_t = tf.placeholder(tf.int32, shape=[None])
    r_t = tf.placeholder(tf.float32, shape=[None])
    sarsa = rl.sarsa(q_tm1, a_tm1, r_t, pcont_t, q_t, a_t)
    outputs = [sarsa.loss, sarsa.extra.td_error, sarsa.extra.target]

    # Graph-construction time: batch dimension is still undetermined.
    for out in outputs:
      self.assertEqual(out.get_shape().as_list(), [None])

    # Run time: the fed batch size must propagate to every output.
    batch_size = 11
    uniform = np.random.random
    randint = np.random.randint
    feed_dict = {
        q_tm1: uniform([batch_size, num_actions]),
        q_t: uniform([batch_size, num_actions]),
        a_tm1: randint(0, num_actions, [batch_size]),
        pcont_t: uniform([batch_size]),
        a_t: randint(0, num_actions, [batch_size]),
        r_t: uniform(batch_size),
    }
    with self.test_session() as sess:
      evaluated = sess.run(outputs, feed_dict=feed_dict)

    for arr in evaluated:
      self.assertEqual(arr.shape, (batch_size,))
Пример #3
0
 def setUp(self):
     """Builds a fixed batch of two transitions and the sarsa op over them.

     The input tensors are stored on ``self``, and the result of
     ``rl.sarsa`` is kept as ``self.sarsa`` for the individual tests.
     """
     super(SarsaTest, self).setUp()

     def float_const(value):
         return tf.constant(value, dtype=tf.float32)

     def int_const(value):
         return tf.constant(value, dtype=tf.int32)

     # Q-value tables have shape [2, 3]; action tensors hold one index per row.
     self.q_tm1 = float_const([[1, 1, 0], [1, 1, 0]])
     self.q_t = float_const([[0, 1, 0], [3, 2, 0]])
     self.a_tm1 = int_const([0, 1])
     self.a_t = int_const([1, 0])
     # pcont_t is 0 for the first transition and 1 for the second.
     self.pcont_t = float_const([0, 1])
     self.r_t = float_const([1, 1])
     self.sarsa = rl.sarsa(self.q_tm1, self.a_tm1, self.r_t, self.pcont_t,
                           self.q_t, self.a_t)
Пример #4
0
 def setUp(self):
   """Creates the two-transition fixture and the sarsa op shared by tests.

   Tensors are stored on ``self``. The op built from them is stored as
   ``self.sarsa``.
   """
   super(SarsaTest, self).setUp()

   def f32(value):
     return tf.constant(value, dtype=tf.float32)

   def i32(value):
     return tf.constant(value, dtype=tf.int32)

   # Q-value tables have shape [2, 3]; action tensors hold one index per row.
   self.q_tm1 = f32([[1, 1, 0], [1, 1, 0]])
   self.q_t = f32([[0, 1, 0], [3, 2, 0]])
   self.a_tm1 = i32([0, 1])
   self.a_t = i32([1, 0])
   # pcont_t is 0 for the first transition and 1 for the second.
   self.pcont_t = f32([0, 1])
   self.r_t = f32([1, 1])
   self.sarsa = rl.sarsa(self.q_tm1, self.a_tm1, self.r_t, self.pcont_t,
                         self.q_t, self.a_t)
Пример #5
0
 def testCompatibilityCheck(self):
     """Building sarsa with an incompatible ``a_t`` must raise ValueError.

     ``a_t`` is a placeholder of static shape ``[3]``, while the remaining
     inputs come from the ``self`` fixture (presumably built in ``setUp``
     with a batch of 2). ``rl.sarsa`` is expected to reject the combination
     with its rank/compatibility error message.
     """
     # NOTE(review): a_t is float32 here, while the fixture's action tensors
     # appear to be int32. The error may therefore be triggered by either
     # the shape or the dtype mismatch. Confirm which one this test covers.
     a_t = tf.placeholder(tf.float32, [3])
     # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
     with self.assertRaisesRegex(
             ValueError, "Sarsa: Error in rank and/or compatibility check"):
         # The call is expected to raise, so its result is not stored; this
         # avoids overwriting the self.sarsa fixture.
         rl.sarsa(self.q_tm1, self.a_tm1, self.r_t, self.pcont_t, self.q_t,
                  a_t)
Пример #6
0
 def testCompatibilityCheck(self):
   """Building sarsa with an incompatible ``a_t`` must raise ValueError.

   ``a_t`` is a placeholder of static shape ``[3]``, while the remaining
   inputs come from the ``self`` fixture (presumably built in ``setUp`` with
   a batch of 2). ``rl.sarsa`` is expected to reject the combination with
   its rank/compatibility error message.
   """
   # NOTE(review): a_t is float32 here, while the fixture's action tensors
   # appear to be int32. The error may therefore be triggered by either the
   # shape or the dtype mismatch. Confirm which one this test covers.
   a_t = tf.placeholder(tf.float32, [3])
   # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
   with self.assertRaisesRegex(
       ValueError, "Sarsa: Error in rank and/or compatibility check"):
     # The call is expected to raise, so its result is not stored; this
     # avoids overwriting the self.sarsa fixture.
     rl.sarsa(self.q_tm1, self.a_tm1, self.r_t, self.pcont_t, self.q_t, a_t)