def GetDotDimensionsFromLists(dimension_numbers):
  """Builds a DotDimensionNumbers proto from nested dimension lists.

  Args:
    dimension_numbers: a pair of pairs,
      ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)), each an
      iterable of dimension indices.

  Returns:
    An xla_data_pb2.DotDimensionNumbers proto populated with the given
    contracting and batch dimensions.
  """
  contracting, batch = dimension_numbers
  # The generated proto constructor accepts iterables for repeated fields,
  # so the whole message can be built in one call.
  return xla_data_pb2.DotDimensionNumbers(
      lhs_contracting_dimensions=contracting[0],
      rhs_contracting_dimensions=contracting[1],
      lhs_batch_dimensions=batch[0],
      rhs_batch_dimensions=batch[1])
def dot_fn(lhs, rhs):
  """Batched matmul via xla.dot_general, accumulating into int32.

  Contracts lhs dim 2 against rhs dim 1, with dim 0 of both operands
  treated as the batch dimension.
  """
  dims = xla_data_pb2.DotDimensionNumbers()
  dims.lhs_contracting_dimensions.extend([2])
  dims.rhs_contracting_dimensions.extend([1])
  dims.lhs_batch_dimensions.extend([0])
  dims.rhs_batch_dimensions.extend([0])
  return xla.dot_general(
      lhs,
      rhs,
      dimension_numbers=dims,
      preferred_element_type=np.int32)
def testDotUnknownAndKnownContractingDimension(self):
  """Shape inference succeeds when one contracting dim size is unknown."""
  lhs = array_ops.placeholder(np.float32, shape=(3, 4))
  # First dimension deliberately unknown; it contracts against lhs dim 1.
  rhs = array_ops.placeholder(np.float32, shape=(None, 2))
  dims = xla_data_pb2.DotDimensionNumbers()
  dims.lhs_contracting_dimensions.extend([1])
  dims.rhs_contracting_dimensions.extend([0])
  out = xla.dot_general(lhs, rhs, dims)
  self.assertEqual(out.shape.as_list(), [3, 2])
def testDotDifferentContractingDimensionsSizes(self):
  """Mismatched contracting dimension sizes must raise at graph build."""
  lhs = array_ops.placeholder(np.float32, shape=(2, 2, 2, 2))
  rhs = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
  dims = xla_data_pb2.DotDimensionNumbers()
  # lhs dim 2 has size 2, rhs dim 3 has size 4 — they cannot contract.
  dims.lhs_contracting_dimensions.extend([2])
  dims.rhs_contracting_dimensions.extend([3])
  with self.assertRaisesRegex(
      ValueError, 'Dimensions must be equal, but are 2 and 4'):
    xla.dot_general(lhs, rhs, dims)
def testDotShapeInference(self):
  """Inferred output shape is batch dims + remaining lhs + remaining rhs."""
  lhs = array_ops.placeholder(np.float32, shape=(1, 2, 3, 4))
  rhs = array_ops.placeholder(np.float32, shape=(4, 3, 2, 1))
  dims = xla_data_pb2.DotDimensionNumbers()
  dims.lhs_contracting_dimensions.extend([1])
  dims.rhs_contracting_dimensions.extend([2])
  dims.lhs_batch_dimensions.extend([3])
  dims.rhs_batch_dimensions.extend([0])
  out = xla.dot_general(lhs, rhs, dims)
  # Batch dim (size 4) first, then lhs free dims (1, 3), then rhs (3, 1).
  self.assertEqual(out.shape, tensor_shape.TensorShape([4, 1, 3, 3, 1]))
def testDotDifferentNumberOfBatchDimensions(self):
  """Unequal batch-dimension counts for lhs and rhs must raise."""
  lhs = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
  rhs = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
  dims = xla_data_pb2.DotDimensionNumbers()
  # One batch dim on lhs, two on rhs — intentionally inconsistent.
  dims.lhs_batch_dimensions.extend([2])
  dims.rhs_batch_dimensions.extend([2, 3])
  with self.assertRaisesRegex(ValueError,
                              'Must specify the same number of batch '
                              'dimensions for lhs and rhs. Got: 1 and 2'):
    xla.dot_general(lhs, rhs, dims)
def testDotDifferentBatchDimensionsSizes(self):
  """Batch dimensions with unequal sizes must raise."""
  lhs = array_ops.placeholder(np.float32, shape=(2, 2, 2, 2))
  rhs = array_ops.placeholder(np.float32, shape=(4, 4, 4, 2))
  dims = xla_data_pb2.DotDimensionNumbers()
  dims.lhs_contracting_dimensions.extend([2])
  dims.rhs_contracting_dimensions.extend([3])
  # Batch dim 0 has size 2 on lhs but 4 on rhs.
  dims.lhs_batch_dimensions.extend([0])
  dims.rhs_batch_dimensions.extend([0])
  with self.assertRaisesRegex(ValueError,
                              'Batch dimension sizes do not match. '
                              'Got: 2 and 4'):
    xla.dot_general(lhs, rhs, dims)
def dot_fn(lhs, rhs):
  """Batched matmul via xla.dot_general with optional precision config.

  Contracts lhs dim 2 against rhs dim 1, batching over dim 0 of both
  operands. Reads `precision` from the enclosing scope; when it is
  falsy, no PrecisionConfig is attached.
  """
  dims = xla_data_pb2.DotDimensionNumbers()
  dims.lhs_contracting_dimensions.extend([2])
  dims.rhs_contracting_dimensions.extend([1])
  dims.lhs_batch_dimensions.extend([0])
  dims.rhs_batch_dimensions.extend([0])
  if precision:
    # Apply the same operand precision to both lhs and rhs.
    config = xla_data_pb2.PrecisionConfig()
    config.operand_precision.extend([precision, precision])
  else:
    config = None
  return xla.dot_general(
      lhs, rhs, dimension_numbers=dims, precision_config=config)