コード例 #1
0
    def testVariableBatchSize(self):
        """Checks that SARSE outputs keep an unknown leading (batch) dimension.

        All inputs are placeholders whose first dimension is None. The loss,
        TD error and target must therefore have static shape [None] when the
        graph is built, and shape (batch,) once a concrete batch is fed.
        """
        f32 = tf.float32
        q_tm1 = tf.placeholder(f32, shape=[None, 3])
        q_t = tf.placeholder(f32, shape=[None, 3])
        a_tm1 = tf.placeholder(tf.int32, shape=[None])
        pcont_t = tf.placeholder(f32, shape=[None])
        probs_a_t = tf.placeholder(f32, shape=[None, 3])
        r_t = tf.placeholder(f32, shape=[None])
        sarse_op = rl.sarse(q_tm1, a_tm1, r_t, pcont_t, q_t, probs_a_t)
        outputs = [sarse_op.loss, sarse_op.extra.td_error,
                   sarse_op.extra.target]

        # Graph-build time: only the batch dimension is left unknown.
        for tensor in outputs:
            self.assertEqual(tensor.get_shape().as_list(), [None])

        # Run time: feeding a concrete batch resolves that dimension.
        # Entries are drawn in a fixed order so the RNG sequence is stable.
        batch = 11
        feeds = {}
        feeds[q_tm1] = np.random.random([batch, 3])
        feeds[q_t] = np.random.random([batch, 3])
        feeds[a_tm1] = np.random.randint(0, 3, [batch])
        feeds[pcont_t] = np.random.random([batch])
        feeds[r_t] = np.random.random(batch)
        feeds[probs_a_t] = np.random.uniform(size=[batch, 3])
        with self.test_session() as sess:
            results = sess.run(outputs, feed_dict=feeds)

        for value in results:
            self.assertEqual(value.shape, (batch,))
コード例 #2
0
  def testVariableBatchSize(self):
    """Checks that SARSE outputs keep an unknown leading (batch) dimension.

    All inputs are placeholders whose first dimension is None. The loss,
    TD error and target must therefore have static shape [None] when the
    graph is built, and shape (batch,) once a concrete batch is fed.
    """
    f32 = tf.float32
    q_tm1 = tf.placeholder(f32, shape=[None, 3])
    q_t = tf.placeholder(f32, shape=[None, 3])
    a_tm1 = tf.placeholder(tf.int32, shape=[None])
    pcont_t = tf.placeholder(f32, shape=[None])
    probs_a_t = tf.placeholder(f32, shape=[None, 3])
    r_t = tf.placeholder(f32, shape=[None])
    sarse_op = rl.sarse(q_tm1, a_tm1, r_t, pcont_t, q_t, probs_a_t)
    outputs = [sarse_op.loss, sarse_op.extra.td_error, sarse_op.extra.target]

    # Graph-build time: only the batch dimension is left unknown.
    for tensor in outputs:
      self.assertEqual(tensor.get_shape().as_list(), [None])

    # Run time: feeding a concrete batch resolves that dimension.
    # Entries are drawn in a fixed order so the RNG sequence is stable.
    batch = 11
    feeds = {}
    feeds[q_tm1] = np.random.random([batch, 3])
    feeds[q_t] = np.random.random([batch, 3])
    feeds[a_tm1] = np.random.randint(0, 3, [batch])
    feeds[pcont_t] = np.random.random([batch])
    feeds[r_t] = np.random.random(batch)
    feeds[probs_a_t] = np.random.uniform(size=[batch, 3])
    with self.test_session() as sess:
      results = sess.run(outputs, feed_dict=feeds)

    for value in results:
      self.assertEqual(value.shape, (batch,))
コード例 #3
0
 def testIncorrectProbsTensor(self):
   """Checks that debug mode flags action probabilities that do not sum to 1.

   The second row of `probs_a_t` sums to 0.3 + 0.5 + 0.3 = 1.1. Building the
   op with `debug=True` and running it is asserted to raise an
   InvalidArgumentError whose message mentions the failed sum check.
   """
   probs_a_t = tf.constant([[0.2, 0.5, 0.3], [0.3, 0.5, 0.3]],
                           dtype=tf.float32)
   with self.test_session() as sess:
     with self.assertRaisesRegexp(
         tf.errors.InvalidArgumentError, "probs_a_t tensor does not sum to 1"):
       # Bind to a local rather than `self.sarse` so this deliberately invalid
       # op does not overwrite the fixture attribute of the same name.
       sarse = rl.sarse(self.q_tm1, self.a_tm1, self.r_t, self.pcont_t,
                        self.q_t, probs_a_t, debug=True)
       sess.run(sarse.extra.target)
コード例 #4
0
 def setUp(self):
     """Builds fixed inputs for a batch of two examples and the SARSE op.

     Q-values and action probabilities have 3 columns. Each row of
     `probs_a_t` sums to exactly 1. The resulting op is stored on
     `self.sarse` for the individual tests.
     """
     super(SarseTest, self).setUp()
     f32 = tf.float32
     self.q_tm1 = tf.constant([[1, 1, 0.5],
                               [1, 1, 3]], dtype=f32)
     self.q_t = tf.constant([[1.5, 1, 2],
                             [3, 2, 1]], dtype=f32)
     self.a_tm1 = tf.constant([0, 1], dtype=tf.int32)
     self.probs_a_t = tf.constant([[0.2, 0.5, 0.3],
                                   [0.3, 0.4, 0.3]], dtype=f32)
     self.pcont_t = tf.constant([1, 1], dtype=f32)
     self.r_t = tf.constant([4, 1], dtype=f32)
     self.sarse = rl.sarse(
         self.q_tm1, self.a_tm1, self.r_t, self.pcont_t, self.q_t,
         self.probs_a_t)
コード例 #5
0
 def setUp(self):
   """Builds fixed inputs for a batch of two examples and the SARSE op.

   Q-values and action probabilities have 3 columns. Each row of
   `probs_a_t` sums to exactly 1. The resulting op is stored on
   `self.sarse` for the individual tests.
   """
   super(SarseTest, self).setUp()
   f32 = tf.float32
   self.q_tm1 = tf.constant([[1, 1, 0.5],
                             [1, 1, 3]], dtype=f32)
   self.q_t = tf.constant([[1.5, 1, 2],
                           [3, 2, 1]], dtype=f32)
   self.a_tm1 = tf.constant([0, 1], dtype=tf.int32)
   self.probs_a_t = tf.constant([[0.2, 0.5, 0.3],
                                 [0.3, 0.4, 0.3]], dtype=f32)
   self.pcont_t = tf.constant([1, 1], dtype=f32)
   self.r_t = tf.constant([4, 1], dtype=f32)
   self.sarse = rl.sarse(
       self.q_tm1, self.a_tm1, self.r_t, self.pcont_t, self.q_t,
       self.probs_a_t)
コード例 #6
0
 def testIncorrectProbsTensor(self):
     """Checks that debug mode flags action probabilities that do not sum to 1.

     The second row of `probs_a_t` sums to 0.3 + 0.5 + 0.3 = 1.1. Building
     the op with `debug=True` and running it is asserted to raise an
     InvalidArgumentError whose message mentions the failed sum check.
     """
     probs_a_t = tf.constant([[0.2, 0.5, 0.3], [0.3, 0.5, 0.3]],
                             dtype=tf.float32)
     with self.test_session() as sess:
         with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                                      "probs_a_t tensor does not sum to 1"):
             # Bind to a local rather than `self.sarse` so this deliberately
             # invalid op does not overwrite the fixture attribute.
             sarse = rl.sarse(self.q_tm1,
                              self.a_tm1,
                              self.r_t,
                              self.pcont_t,
                              self.q_t,
                              probs_a_t,
                              debug=True)
             sess.run(sarse.extra.target)
コード例 #7
0
 def testCompatibilityCheck(self):
     """Checks that a mismatched `probs_a_t` width is rejected at build time.

     `probs_a_t` has 2 columns, while the Q-value fixtures (built in setUp)
     have 3. Constructing the op is asserted to raise a ValueError.
     """
     probs_a_t = tf.placeholder(tf.float32, [None, 2])
     with self.assertRaisesRegexp(
             ValueError, "Sarse: Error in rank and/or compatibility check"):
         # The call is expected to raise, so its result is intentionally not
         # bound. Assigning it to `self.sarse` would only clobber the fixture.
         rl.sarse(self.q_tm1, self.a_tm1, self.r_t,
                  self.pcont_t, self.q_t, probs_a_t)
コード例 #8
0
 def testCompatibilityCheck(self):
   """Checks that a mismatched `probs_a_t` width is rejected at build time.

   `probs_a_t` has 2 columns, while the Q-value fixtures (built in setUp)
   have 3. Constructing the op is asserted to raise a ValueError.
   """
   probs_a_t = tf.placeholder(tf.float32, [None, 2])
   with self.assertRaisesRegexp(
       ValueError, "Sarse: Error in rank and/or compatibility check"):
     # The call is expected to raise, so its result is intentionally not
     # bound. Assigning it to `self.sarse` would only clobber the fixture.
     rl.sarse(self.q_tm1, self.a_tm1, self.r_t, self.pcont_t,
              self.q_t, probs_a_t)