def setUp(self):
    """Build the shared fixture fed to ``value_ops.qv_max``.

    Nine rows are used throughout:
      * ``v_tm1`` is all ones,
      * ``r_t`` is all ``-1``,
      * ``pcont_t`` cycles through ``0, 0.5, 1`` three times,
      * ``q_t`` has two entries per row (presumably one per action).
    The resulting loss op and extra ops are stored for the test methods.
    """
    super(QVMAXTest, self).setUp()
    num_rows = 9
    # Keep the creation order identical so graph op names are unchanged.
    self.v_tm1 = tf.constant([1] * num_rows, dtype=tf.float32)
    self.pcont_t = tf.constant([0, 0.5, 1] * 3, dtype=tf.float32)
    self.r_t = tf.constant([-1] * num_rows, dtype=tf.float32)
    self.q_t = tf.constant(
        [[0, -1], [-2, 0], [0, -3],
         [1, 0], [1, 1], [0, 1],
         [1, 2], [2, -2], [2, 2]],
        dtype=tf.float32)
    self.loss_op, self.extra_ops = value_ops.qv_max(
        self.v_tm1, self.r_t, self.pcont_t, self.q_t)
def testCompatibilityCheck(self):
    """``qv_max`` should reject a ``pcont_t`` that disagrees with the fixture.

    The fixture tensors built in ``setUp`` have nine elements. Passing an
    eight-element placeholder instead is expected to raise ``ValueError``
    with the rank/compatibility message.
    """
    short_pcont_t = tf.placeholder(tf.float32, [8])
    expected_message = 'QVMAX: Error in rank and/or compatibility check'
    # NOTE(review): assertRaisesRegexp is a deprecated alias of
    # assertRaisesRegex on Python 3; presumably kept for Python 2 support.
    with self.assertRaisesRegexp(ValueError, expected_message):
        value_ops.qv_max(self.v_tm1, self.r_t, short_pcont_t, self.q_t)