def testBatchMatMulV3OutputType(self):
  """Checks that `Tout=np.int32` sets the dtype of an integer matmul result.

  Multiplies `a` by the adjoint of `b` (`adj_y=True`); for these real inputs
  that is a @ b.T. The test verifies the result's dtype, shape and values.
  """
  # TODO(shivaniagrawal): uint8 is not supported for mixed matmul type in XLA,
  # presumably why only same-dtype input pairs are exercised here.
  for (a_dtype, b_dtype) in [(np.int8, np.int8), (np.uint8, np.uint8)]:
    a = np.array([[1, 2], [3, 4]], dtype=a_dtype)
    b = np.array([[1, 2], [3, 4]], dtype=b_dtype)
    c = math_ops.batch_mat_mul_v3(a, b, adj_y=True, Tout=np.int32)
    # The expected values also fit in int8/uint8, so without this check a
    # kernel that ignored `Tout` would still pass the value assertion.
    self.assertEqual(c.dtype, np.int32)
    self.assertAllEqual((2, 2), c.shape)
    self.assertAllEqual([[5, 11], [11, 25]], c)
 def testBatchMatMulV3MixedPrec(self):
   """Checks an int8 x bfloat16 matmul whose output dtype is bfloat16.

   `adj_y=True` multiplies `a` by the adjoint of `b` (a @ b.T for these real
   inputs). The test verifies the result's dtype, shape and values.
   """
   # TODO(shivaniagrawal): uint8 is not supported for mixed matmul type in XLA,
   # presumably why the integer operand here is int8 rather than uint8.
   np_bf16 = dtypes.bfloat16.as_numpy_dtype
   a = np.array([[1, 2], [3, 4]], dtype=np.int8)
   b = np.array([[1, 2], [3, 4]], dtype=np_bf16)
   c = math_ops.batch_mat_mul_v3(a, b, adj_y=True, Tout=np_bf16)
   # Pin the result dtype explicitly: the value check alone cannot tell a
   # bfloat16 result apart from another dtype holding the same small integers.
   self.assertEqual(dtypes.bfloat16, c.dtype)
   self.assertAllEqual((2, 2), c.shape)
   self.assertAllEqual([[5, 11], [11, 25]], c)
Example 3
0
 def testBatchMatMulV3OutputType(self):
   """Checks that `Tout=np.int32` widens an int8 x int8 matmul result.

   `adj_y=True` multiplies `a` by the adjoint of `b` (a @ b.T for these real
   inputs). The test verifies the result's dtype, shape and values.
   """
   a = np.array([[1, 2], [3, 4]], dtype=np.int8)
   b = np.array([[1, 2], [3, 4]], dtype=np.int8)
   c = math_ops.batch_mat_mul_v3(a, b, adj_y=True, Tout=np.int32)
   # The expected values also fit in int8, so only an explicit dtype check
   # proves that `Tout` was honored.
   self.assertEqual(c.dtype, np.int32)
   self.assertAllEqual((2, 2), c.shape)
   self.assertAllEqual([[5, 11], [11, 25]], c)