Esempio n. 1
0
def GetDotDimensionsFromLists(dimension_numbers):
  """Builds a `DotDimensionNumbers` proto from nested dimension sequences.

  Args:
    dimension_numbers: A pair `(contracting, batch)` in which each element is
      itself a pair `(lhs_dims, rhs_dims)` of iterables of dimension indices.

  Returns:
    An `xla_data_pb2.DotDimensionNumbers` message whose four repeated
    dimension fields are filled from `dimension_numbers`.
  """
  contracting_pair, batch_pair = dimension_numbers
  lhs_contracting, rhs_contracting = contracting_pair
  lhs_batching, rhs_batching = batch_pair

  proto = xla_data_pb2.DotDimensionNumbers()
  # Pair each repeated field with the indices that belong in it, then fill
  # them in the order: contracting (lhs, rhs), batch (lhs, rhs).
  field_values = (
      (proto.lhs_contracting_dimensions, lhs_contracting),
      (proto.rhs_contracting_dimensions, rhs_contracting),
      (proto.lhs_batch_dimensions, lhs_batching),
      (proto.rhs_batch_dimensions, rhs_batching),
  )
  for repeated_field, indices in field_values:
    repeated_field.extend(indices)
  return proto
Esempio n. 2
0
 def dot_fn(lhs, rhs):
   """Calls `xla.dot_general` on `lhs` and `rhs` with fixed dimension numbers.

   Contracts dimension 2 of `lhs` with dimension 1 of `rhs`, treats
   dimension 0 of both operands as the batch dimension, and requests
   `np.int32` as the preferred element type.
   """
   dimension_numbers = xla_data_pb2.DotDimensionNumbers(
       lhs_contracting_dimensions=[2],
       rhs_contracting_dimensions=[1],
       lhs_batch_dimensions=[0],
       rhs_batch_dimensions=[0],
   )
   result = xla.dot_general(
       lhs,
       rhs,
       dimension_numbers=dimension_numbers,
       preferred_element_type=np.int32,
   )
   return result
Esempio n. 3
0
    def testDotUnknownAndKnownContractingDimension(self):
        # lhs has a fully known shape; the contracting dimension of rhs
        # (dimension 0) is left unknown.
        lhs = array_ops.placeholder(np.float32, shape=(3, 4))
        rhs = array_ops.placeholder(np.float32, shape=(None, 2))

        # Contract lhs dimension 1 against rhs dimension 0, no batch dims.
        dot_dims = xla_data_pb2.DotDimensionNumbers(
            lhs_contracting_dimensions=[1],
            rhs_contracting_dimensions=[0],
        )

        result = xla.dot_general(lhs, rhs, dot_dims)

        # Only the non-contracted dimensions remain in the inferred shape.
        inferred_shape = result.shape.as_list()
        self.assertEqual(inferred_shape, [3, 2])
Esempio n. 4
0
    def testDotDifferentContractingDimensionsSizes(self):
        # Every lhs dimension has size 2 and every rhs dimension has size 4,
        # so any pair of contracting dimensions is mismatched.
        lhs = array_ops.placeholder(np.float32, shape=(2, 2, 2, 2))
        rhs = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))

        dot_dims = xla_data_pb2.DotDimensionNumbers(
            lhs_contracting_dimensions=[2],
            rhs_contracting_dimensions=[3],
        )

        expected_error = 'Dimensions must be equal, but are 2 and 4'
        with self.assertRaisesRegex(ValueError, expected_error):
            xla.dot_general(lhs, rhs, dot_dims)
Esempio n. 5
0
  def testDotShapeInference(self):
    lhs = array_ops.placeholder(np.float32, shape=(1, 2, 3, 4))
    rhs = array_ops.placeholder(np.float32, shape=(4, 3, 2, 1))

    # Contract lhs dim 1 with rhs dim 2 (both size 2); batch over lhs dim 3
    # and rhs dim 0 (both size 4).
    dot_dims = xla_data_pb2.DotDimensionNumbers(
        lhs_contracting_dimensions=[1],
        rhs_contracting_dimensions=[2],
        lhs_batch_dimensions=[3],
        rhs_batch_dimensions=[0],
    )

    result = xla.dot_general(lhs, rhs, dot_dims)

    # Expected layout: batch size, then the remaining lhs sizes (1, 3), then
    # the remaining rhs sizes (3, 1).
    expected_shape = tensor_shape.TensorShape([4, 1, 3, 3, 1])
    self.assertEqual(result.shape, expected_shape)
Esempio n. 6
0
  def testDotDifferentNumberOfBatchDimensions(self):
    operand_shape = (4, 4, 4, 4)
    lhs = array_ops.placeholder(np.float32, shape=operand_shape)
    rhs = array_ops.placeholder(np.float32, shape=operand_shape)

    # One batch dimension on the lhs but two on the rhs.
    dot_dims = xla_data_pb2.DotDimensionNumbers(
        lhs_batch_dimensions=[2],
        rhs_batch_dimensions=[2, 3],
    )

    expected_error = ('Must specify the same number of batch '
                      'dimensions for lhs and rhs. Got: 1 and 2')
    with self.assertRaisesRegex(ValueError, expected_error):
      xla.dot_general(lhs, rhs, dot_dims)
Esempio n. 7
0
  def testDotDifferentBatchDimensionsSizes(self):
    lhs = array_ops.placeholder(np.float32, shape=(2, 2, 2, 2))
    rhs = array_ops.placeholder(np.float32, shape=(4, 4, 4, 2))

    # The contracting dimensions agree (lhs dim 2 and rhs dim 3 are both
    # size 2), but batch dimension 0 is size 2 on the lhs and 4 on the rhs.
    dot_dims = xla_data_pb2.DotDimensionNumbers(
        lhs_contracting_dimensions=[2],
        rhs_contracting_dimensions=[3],
        lhs_batch_dimensions=[0],
        rhs_batch_dimensions=[0],
    )

    expected_error = ('Batch dimension sizes do not match. '
                      'Got: 2 and 4')
    with self.assertRaisesRegex(ValueError, expected_error):
      xla.dot_general(lhs, rhs, dot_dims)
Esempio n. 8
0
 def dot_fn(lhs, rhs):
   """Calls `xla.dot_general` with fixed dimension numbers and optional precision.

   Contracts dimension 2 of `lhs` with dimension 1 of `rhs` and treats
   dimension 0 of both operands as the batch dimension. `precision` is read
   from the enclosing scope: when truthy it is applied to both operands via
   a `PrecisionConfig`; otherwise no precision config is passed.
   """
   if precision:
     # Same precision value for each of the two operands.
     operand_settings = xla_data_pb2.PrecisionConfig(
         operand_precision=[precision, precision])
   else:
     operand_settings = None

   dimension_numbers = xla_data_pb2.DotDimensionNumbers(
       lhs_contracting_dimensions=[2],
       rhs_contracting_dimensions=[1],
       lhs_batch_dimensions=[0],
       rhs_batch_dimensions=[0],
   )
   return xla.dot_general(
       lhs,
       rhs,
       dimension_numbers=dimension_numbers,
       precision_config=operand_settings)