Ejemplo n.º 1
0
 def test_create_xla_tff_computation_with_reordered_tensor_indexes(self):
   """Wraps one XLA computation using two different tensor index orders.

   The XLA computation takes an int32[10,1] parameter (index 0) and an
   int32[1,20] parameter (index 1), and its last op is a `Dot` of the two.
   It is wrapped once with tensor indexes [0, 1] and once with [1, 0]. Each
   time, the serialized TFF function type must list the parameters in the
   order that was given.
   """
   int32_etype = xla_client.dtype_to_etype(np.int32)
   xla_builder = xla_client.XlaBuilder('comp')
   lhs = xla_client.ops.Parameter(
       xla_builder, 0, xla_client.Shape.array_shape(int32_etype, (10, 1)))
   rhs = xla_client.ops.Parameter(
       xla_builder, 1, xla_client.Shape.array_shape(int32_etype, (1, 20)))
   xla_client.ops.Dot(lhs, rhs)
   xla_comp = xla_builder.build()

   result_spec = (np.int32, (10, 20))
   # Each case: (tensor indexes, TFF parameter spec, expected type string).
   cases = (
       ([0, 1], ((np.int32, (10, 1)), (np.int32, (1, 20))),
        '(<int32[10,1],int32[1,20]> -> int32[10,20])'),
       ([1, 0], ((np.int32, (1, 20)), (np.int32, (10, 1))),
        '(<int32[1,20],int32[10,1]> -> int32[10,20])'),
   )
   for tensor_indexes, parameter_spec, expected in cases:
     comp_pb = xla_serialization.create_xla_tff_computation(
         xla_comp, tensor_indexes,
         computation_types.FunctionType(parameter_spec, result_spec))
     self.assertIsInstance(comp_pb, pb.Computation)
     self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
     type_spec = type_serialization.deserialize_type(comp_pb.type)
     self.assertEqual(str(type_spec), expected)
Ejemplo n.º 2
0
 def test_xla_shapes_and_binding_to_tff_type_raises_unused_element(self):
   """Rejects a binding that leaves one of the XLA shapes unreferenced.

   Two flat XLA shapes are supplied, but the binding refers only to index 1.
   The shape at index 0 is therefore unused, and a ValueError is expected.
   """
   int32_vector = xla_client.Shape.array_shape(
       xla_client.dtype_to_etype(np.int32), (10,))
   float32_vector = xla_client.Shape.array_shape(
       xla_client.dtype_to_etype(np.float32), (20,))
   only_second = pb.Xla.Binding(tensor=pb.Xla.TensorBinding(index=1))
   with self.assertRaises(ValueError):
     xla_serialization.xla_shapes_and_binding_to_tff_type(
         [int32_vector, float32_vector], only_second)
Ejemplo n.º 3
0
 def test_flatten_xla_tuple_shape(self):
   """Checks that flattening a tuple shape yields its element shapes in order."""
   elements = [
       xla_client.Shape.array_shape(xla_client.dtype_to_etype(dtype), dims)
       for dtype, dims in ((np.int32, (10,)), (np.float32, (20,)))
   ]
   flattened = xla_serialization.flatten_xla_shape(
       xla_client.Shape.tuple_shape(elements))
   self.assertIsInstance(flattened, list)
   self.assertListEqual(flattened, elements)
Ejemplo n.º 4
0
 def test_xla_shapes_and_binding_to_tff_type_with_tuple(self):
   """Checks that a struct binding yields a TFF struct ordered by the binding.

   The binding lists index 1 before index 0. The resulting struct type
   therefore puts the float32 tensor first, even though it is the second of
   the XLA shapes.
   """
   int32_etype = xla_client.dtype_to_etype(np.int32)
   float32_etype = xla_client.dtype_to_etype(np.float32)
   xla_shapes = [
       xla_client.Shape.array_shape(int32_etype, (10,)),
       xla_client.Shape.array_shape(float32_etype, (20,)),
   ]
   reversed_binding = pb.Xla.Binding(
       struct=pb.Xla.StructBinding(element=[
           pb.Xla.Binding(tensor=pb.Xla.TensorBinding(index=i)) for i in (1, 0)
       ]))
   tff_type = xla_serialization.xla_shapes_and_binding_to_tff_type(
       xla_shapes, reversed_binding)
   self.assertEqual(str(tff_type), '<float32[20],int32[10]>')
Ejemplo n.º 5
0
 def test_xla_shapes_and_binding_to_tff_type_raises_unused_tensor(self):
   """Expects ValueError when a `None` binding leaves the lone XLA shape unused."""
   lone_shape = xla_client.Shape.array_shape(
       xla_client.dtype_to_etype(np.int32), (10,))
   with self.assertRaises(ValueError):
     xla_serialization.xla_shapes_and_binding_to_tff_type([lone_shape], None)
Ejemplo n.º 6
0
 def test_xla_shapes_and_binding_to_tff_type_with_tensor(self):
   """Checks that a tensor binding over one XLA shape gives that tensor type."""
   int32_vector = xla_client.Shape.array_shape(
       xla_client.dtype_to_etype(np.int32), (10,))
   tff_type = xla_serialization.xla_shapes_and_binding_to_tff_type(
       [int32_vector], pb.Xla.Binding(tensor=pb.Xla.TensorBinding(index=0)))
   self.assertEqual(str(tff_type), 'int32[10]')
Ejemplo n.º 7
0
def _make_test_xla_comp_int32x10_to_int32x10():
  """Builds a test XLA computation that adds a zero constant to its input.

  The single parameter (index 0) is an int32 vector of length 10. The last op
  emitted is an `Add` of that parameter and an all-zeros int32[10] constant.

  Returns:
    The computation returned by `XlaBuilder.build()`.
  """
  vector_shape = xla_client.Shape.array_shape(
      xla_client.dtype_to_etype(np.int32), (10,))
  xla_builder = xla_client.XlaBuilder('comp')
  zeros = np.zeros((10,), dtype=np.int32)
  # Arguments are evaluated left to right, so Parameter is emitted before
  # Constant, as in the straight-line form.
  xla_client.ops.Add(
      xla_client.ops.Parameter(xla_builder, 0, vector_shape),
      xla_client.ops.Constant(xla_builder, zeros))
  return xla_builder.build()
Ejemplo n.º 8
0
 def test_xla_computation_and_bindings_to_tff_type_raises_unused_element(self):
   """Rejects a parameter binding that leaves a tuple element unbound.

   The XLA parameter is a 2-tuple of int32[10] tensors. The parameter binding
   only references index 0, so a ValueError is expected.
   """
   xla_builder = xla_client.XlaBuilder('comp')
   int32_vector = xla_client.Shape.array_shape(
       xla_client.dtype_to_etype(np.int32), (10,))
   pair_param = xla_client.ops.Parameter(
       xla_builder, 0, xla_client.Shape.tuple_shape([int32_vector] * 2))
   zeros = xla_client.ops.Constant(
       xla_builder, np.zeros((10,), dtype=np.int32))
   first_element = xla_client.ops.GetTupleElement(pair_param, 0)
   xla_client.ops.Add(first_element, zeros)
   xla_computation = xla_builder.build()

   partial_parameter_binding = pb.Xla.Binding(
       struct=pb.Xla.StructBinding(
           element=[pb.Xla.Binding(tensor=pb.Xla.TensorBinding(index=0))]))
   scalar_result_binding = pb.Xla.Binding(
       tensor=pb.Xla.TensorBinding(index=0))
   with self.assertRaises(ValueError):
     xla_serialization.xla_computation_and_bindings_to_tff_type(
         xla_computation, partial_parameter_binding, scalar_result_binding)
Ejemplo n.º 9
0
 def test_xla_computation_and_bindings_to_tff_type_int32_tuple_to_int32(self):
   """Maps a 1-tuple parameter and a tensor result to a TFF function type.

   The XLA parameter is a tuple holding one int32[10] tensor, and it is bound
   as a one-element struct. The result is bound as a plain tensor. The
   expected TFF type is `(<int32[10]> -> int32[10])`.
   """
   xla_builder = xla_client.XlaBuilder('comp')
   int32_vector = xla_client.Shape.array_shape(
       xla_client.dtype_to_etype(np.int32), (10,))
   single_param = xla_client.ops.Parameter(
       xla_builder, 0, xla_client.Shape.tuple_shape([int32_vector]))
   zeros = xla_client.ops.Constant(
       xla_builder, np.zeros((10,), dtype=np.int32))
   xla_client.ops.Add(xla_client.ops.GetTupleElement(single_param, 0), zeros)
   xla_computation = xla_builder.build()

   struct_parameter_binding = pb.Xla.Binding(
       struct=pb.Xla.StructBinding(
           element=[pb.Xla.Binding(tensor=pb.Xla.TensorBinding(index=0))]))
   tensor_result_binding = pb.Xla.Binding(
       tensor=pb.Xla.TensorBinding(index=0))
   tff_type = xla_serialization.xla_computation_and_bindings_to_tff_type(
       xla_computation, struct_parameter_binding, tensor_result_binding)
   self.assertEqual(str(tff_type), '(<int32[10]> -> int32[10])')
Ejemplo n.º 10
0
def _xla_tensor_shape_from_tff_tensor_type(tensor_type):
  """Converts a TFF tensor type into an XLA array shape.

  Args:
    tensor_type: A `computation_types.TensorType`. Its `dtype` must expose
      `as_numpy_dtype`, and its `shape` must expose `dims`.

  Returns:
    The `xla_client.Shape.array_shape` built from the tensor's element type
    and `shape.dims`.

  Raises:
    Whatever `py_typecheck.check_type` raises for a non-`TensorType`
    argument (presumably `TypeError` -- confirm against py_typecheck).
  """
  py_typecheck.check_type(tensor_type, computation_types.TensorType)
  element_type = xla_client.dtype_to_etype(tensor_type.dtype.as_numpy_dtype)
  # NOTE(review): `dims` is passed through unchanged. Unknown-rank (None) or
  # partially-unknown shapes are not handled here -- confirm callers only
  # pass fully-defined shapes.
  dims = tensor_type.shape.dims
  return xla_client.Shape.array_shape(element_type, dims)
Ejemplo n.º 11
0
 def test_flatten_xla_tensor_shape(self):
   """Checks that flattening a non-tuple shape gives a one-item list of it."""
   int32_vector = xla_client.Shape.array_shape(
       xla_client.dtype_to_etype(np.int32), (10,))
   result = xla_serialization.flatten_xla_shape(int32_vector)
   self.assertIsInstance(result, list)
   self.assertListEqual(result, [int32_vector])