コード例 #1
0
ファイル: value_ops_test.py プロジェクト: zhuanglineu/trfl
 def setUp(self):
     super(QVMAXTest, self).setUp()
     self.v_tm1 = tf.constant([1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=tf.float32)
     self.pcont_t = tf.constant([0, 0.5, 1, 0, 0.5, 1, 0, 0.5, 1],
                                dtype=tf.float32)
     self.r_t = tf.constant([-1, -1, -1, -1, -1, -1, -1, -1, -1],
                            dtype=tf.float32)
     self.q_t = tf.constant([[0, -1], [-2, 0], [0, -3], [1, 0], [1, 1],
                             [0, 1], [1, 2], [2, -2], [2, 2]],
                            dtype=tf.float32)
     self.loss_op, self.extra_ops = value_ops.qv_max(
         self.v_tm1, self.r_t, self.pcont_t, self.q_t)
コード例 #2
0
ファイル: value_ops_test.py プロジェクト: zhuanglineu/trfl
 def testCompatibilityCheck(self):
     pcont_t = tf.placeholder(tf.float32, [8])
     with self.assertRaisesRegexp(
             ValueError, 'QVMAX: Error in rank and/or compatibility check'):
         value_ops.qv_max(self.v_tm1, self.r_t, pcont_t, self.q_t)