def test_create_xla_tff_computation_with_reordered_tensor_indexes(self):
  """Wraps one XLA computation twice, with parameter indexes in both orders.

  The XLA computation is `Dot(p0, p1)` with p0: int32[10,1] and
  p1: int32[1,20]. It is wrapped via `create_xla_tff_computation` once with
  tensor indexes `[0, 1]` and once with `[1, 0]`. Each time a matching TFF
  function type is supplied, and the deserialized type is checked against the
  expected string.
  """
  builder = xla_client.XlaBuilder('comp')
  int32_etype = xla_client.dtype_to_etype(np.int32)
  lhs = xla_client.ops.Parameter(
      builder, 0, xla_client.Shape.array_shape(int32_etype, (10, 1)))
  rhs = xla_client.ops.Parameter(
      builder, 1, xla_client.Shape.array_shape(int32_etype, (1, 20)))
  xla_client.ops.Dot(lhs, rhs)
  xla_comp = builder.build()

  lhs_spec = (np.int32, (10, 1))
  rhs_spec = (np.int32, (1, 20))
  result_spec = (np.int32, (10, 20))
  # (tensor indexes, TFF parameter types, expected rendered TFF type).
  cases = (
      ([0, 1], (lhs_spec, rhs_spec),
       '(<int32[10,1],int32[1,20]> -> int32[10,20])'),
      ([1, 0], (rhs_spec, lhs_spec),
       '(<int32[1,20],int32[10,1]> -> int32[10,20])'),
  )
  for tensor_indexes, parameter_spec, expected_type_str in cases:
    comp_pb = xla_serialization.create_xla_tff_computation(
        xla_comp, tensor_indexes,
        computation_types.FunctionType(parameter_spec, result_spec))
    self.assertIsInstance(comp_pb, pb.Computation)
    self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
    deserialized = type_serialization.deserialize_type(comp_pb.type)
    self.assertEqual(str(deserialized), expected_type_str)
def test_xla_shapes_and_binding_to_tff_type_raises_unused_element(self):
  """Expects ValueError when the binding never references shape index 0.

  Two XLA shapes (int32[10], float32[20]) are supplied, but the binding is a
  single tensor binding that points only at index 1.
  """
  xla_shapes = [
      xla_client.Shape.array_shape(xla_client.dtype_to_etype(dtype), dims)
      for dtype, dims in ((np.int32, (10,)), (np.float32, (20,)))
  ]
  second_only = pb.Xla.Binding(tensor=pb.Xla.TensorBinding(index=1))
  with self.assertRaises(ValueError):
    xla_serialization.xla_shapes_and_binding_to_tff_type(
        xla_shapes, second_only)
def test_flatten_xla_tuple_shape(self):
  """Flattening a tuple shape yields a list of its element shapes, in order."""
  int32_shape = xla_client.Shape.array_shape(
      xla_client.dtype_to_etype(np.int32), (10,))
  float32_shape = xla_client.Shape.array_shape(
      xla_client.dtype_to_etype(np.float32), (20,))
  result = xla_serialization.flatten_xla_shape(
      xla_client.Shape.tuple_shape([int32_shape, float32_shape]))
  self.assertIsInstance(result, list)
  self.assertListEqual(result, [int32_shape, float32_shape])
def test_xla_shapes_and_binding_to_tff_type_with_tuple(self):
  """A struct binding selects tensors by index to form a TFF struct type.

  The struct elements reference indexes [1, 0], so the resulting struct lists
  the float32[20] tensor before the int32[10] one.
  """
  xla_shapes = [
      xla_client.Shape.array_shape(
          xla_client.dtype_to_etype(np.int32), (10,)),
      xla_client.Shape.array_shape(
          xla_client.dtype_to_etype(np.float32), (20,)),
  ]

  def tensor_at(index):
    return pb.Xla.Binding(tensor=pb.Xla.TensorBinding(index=index))

  struct_binding = pb.Xla.Binding(
      struct=pb.Xla.StructBinding(element=[tensor_at(1), tensor_at(0)]))
  result_type = xla_serialization.xla_shapes_and_binding_to_tff_type(
      xla_shapes, struct_binding)
  self.assertEqual(str(result_type), '<float32[20],int32[10]>')
def test_xla_shapes_and_binding_to_tff_type_raises_unused_tensor(self):
  """Expects ValueError when one XLA shape is given but the binding is None."""
  lone_shape = xla_client.Shape.array_shape(
      xla_client.dtype_to_etype(np.int32), (10,))
  with self.assertRaises(ValueError):
    xla_serialization.xla_shapes_and_binding_to_tff_type([lone_shape], None)
def test_xla_shapes_and_binding_to_tff_type_with_tensor(self):
  """A lone tensor binding at index 0 maps to the plain tensor type."""
  lone_shape = xla_client.Shape.array_shape(
      xla_client.dtype_to_etype(np.int32), (10,))
  index_zero = pb.Xla.Binding(tensor=pb.Xla.TensorBinding(index=0))
  result_type = xla_serialization.xla_shapes_and_binding_to_tff_type(
      [lone_shape], index_zero)
  self.assertEqual(str(result_type), 'int32[10]')
def _make_test_xla_comp_int32x10_to_int32x10():
  """Builds a test XLA computation that adds zeros to an int32[10] parameter.

  Parameter 0 is an int32[10] array. The last op emitted is
  `Add(parameter, zeros)`, where `zeros` is an all-zeros int32[10] constant.
  Presumably that op becomes the computation's root.

  Returns:
    The result of `XlaBuilder.build()` for the builder named 'comp'.
  """
  builder = xla_client.XlaBuilder('comp')
  x = xla_client.ops.Parameter(
      builder, 0,
      xla_client.Shape.array_shape(
          xla_client.dtype_to_etype(np.int32), (10,)))
  zeros = xla_client.ops.Constant(builder, np.zeros((10,), dtype=np.int32))
  xla_client.ops.Add(x, zeros)
  return builder.build()
def test_xla_computation_and_bindings_to_tff_type_raises_unused_element(self):
  """Expects ValueError when the parameter binding skips a tuple element.

  The XLA parameter is a 2-tuple of int32[10] arrays. The struct parameter
  binding references only tensor index 0, leaving the second element unbound.
  """
  builder = xla_client.XlaBuilder('comp')
  element_shape = xla_client.Shape.array_shape(
      xla_client.dtype_to_etype(np.int32), (10,))
  pair = xla_client.ops.Parameter(
      builder, 0,
      xla_client.Shape.tuple_shape([element_shape, element_shape]))
  zeros = xla_client.ops.Constant(builder, np.zeros((10,), dtype=np.int32))
  # Emit GetTupleElement after the constant, matching the original op order.
  first = xla_client.ops.GetTupleElement(pair, 0)
  xla_client.ops.Add(first, zeros)
  computation = builder.build()

  param_binding = pb.Xla.Binding(
      struct=pb.Xla.StructBinding(
          element=[pb.Xla.Binding(tensor=pb.Xla.TensorBinding(index=0))]))
  output_binding = pb.Xla.Binding(tensor=pb.Xla.TensorBinding(index=0))
  with self.assertRaises(ValueError):
    xla_serialization.xla_computation_and_bindings_to_tff_type(
        computation, param_binding, output_binding)
def test_xla_computation_and_bindings_to_tff_type_int32_tuple_to_int32(self):
  """A 1-tuple int32[10] parameter maps to a one-element TFF struct argument.

  The struct parameter binding covers the tuple's only element, and the result
  is bound as a single tensor. The expected TFF type is
  `(<int32[10]> -> int32[10])`.
  """
  builder = xla_client.XlaBuilder('comp')
  element_shape = xla_client.Shape.array_shape(
      xla_client.dtype_to_etype(np.int32), (10,))
  single = xla_client.ops.Parameter(
      builder, 0, xla_client.Shape.tuple_shape([element_shape]))
  zeros = xla_client.ops.Constant(builder, np.zeros((10,), dtype=np.int32))
  # Emit GetTupleElement after the constant, matching the original op order.
  only_element = xla_client.ops.GetTupleElement(single, 0)
  xla_client.ops.Add(only_element, zeros)
  computation = builder.build()

  param_binding = pb.Xla.Binding(
      struct=pb.Xla.StructBinding(
          element=[pb.Xla.Binding(tensor=pb.Xla.TensorBinding(index=0))]))
  output_binding = pb.Xla.Binding(tensor=pb.Xla.TensorBinding(index=0))
  result_type = xla_serialization.xla_computation_and_bindings_to_tff_type(
      computation, param_binding, output_binding)
  self.assertEqual(str(result_type), '(<int32[10]> -> int32[10])')
def _xla_tensor_shape_from_tff_tensor_type(tensor_type):
  """Converts a TFF tensor type into the corresponding XLA array shape.

  Args:
    tensor_type: A `computation_types.TensorType`.

  Returns:
    An XLA array shape. Its element type is derived from
    `tensor_type.dtype.as_numpy_dtype` and its dimensions are
    `tensor_type.shape.dims`.

  Raises:
    Whatever `py_typecheck.check_type` raises (presumably TypeError) if
    `tensor_type` is not a `computation_types.TensorType`.
  """
  py_typecheck.check_type(tensor_type, computation_types.TensorType)
  element_type = xla_client.dtype_to_etype(tensor_type.dtype.as_numpy_dtype)
  # NOTE(review): depending on the TF version, `shape.dims` may hold
  # `Dimension` objects, or be None for unknown rank. Confirm that
  # `array_shape` accepts these.
  return xla_client.Shape.array_shape(element_type, tensor_type.shape.dims)
def test_flatten_xla_tensor_shape(self):
  """Flattening a non-tuple shape yields a one-element list of that shape."""
  shape = xla_client.Shape.array_shape(
      xla_client.dtype_to_etype(np.int32), (10,))
  result = xla_serialization.flatten_xla_shape(shape)
  self.assertIsInstance(result, list)
  self.assertListEqual(result, [shape])