def test_create_xla_tff_computation_with_reordered_tensor_indexes(self):
  """Binds one two-parameter XLA `Dot` under both parameter orderings.

  The XLA computation takes an int32[10,1] at XLA index 0 and an int32[1,20]
  at XLA index 1. Binding it with tensor indexes `[0, 1]` or `[1, 0]` must
  produce TFF function types whose parameter tuples list the tensors in the
  corresponding order, while the result stays int32[10,20].
  """
  lhs_type = (np.int32, (10, 1))
  rhs_type = (np.int32, (1, 20))
  result_type = (np.int32, (10, 20))

  builder = xla_client.XlaBuilder('comp')
  int32_etype = xla_client.dtype_to_etype(np.int32)
  lhs = xla_client.ops.Parameter(
      builder, 0, xla_client.Shape.array_shape(int32_etype, (10, 1)))
  rhs = xla_client.ops.Parameter(
      builder, 1, xla_client.Shape.array_shape(int32_etype, (1, 20)))
  xla_client.ops.Dot(lhs, rhs)
  xla_comp = builder.build()

  # (tensor indexes, declared parameter types, expected printed TFF type)
  cases = (
      ([0, 1], (lhs_type, rhs_type),
       '(<int32[10,1],int32[1,20]> -> int32[10,20])'),
      ([1, 0], (rhs_type, lhs_type),
       '(<int32[1,20],int32[10,1]> -> int32[10,20])'),
  )
  for tensor_indexes, param_types, expected_type_str in cases:
    comp_pb = xla_serialization.create_xla_tff_computation(
        xla_comp, tensor_indexes,
        computation_types.FunctionType(param_types, result_type))
    self.assertIsInstance(comp_pb, pb.Computation)
    self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
    type_spec = type_serialization.deserialize_type(comp_pb.type)
    self.assertEqual(str(type_spec), expected_type_str)
def create_scalar_multiply_operator(
    self, operand_type: computation_types.Type,
    scalar_type: computation_types.TensorType
) -> local_computation_factory_base.ComputationProtoAndType:
  """Builds an XLA computation `(<operand_type, scalar_type> -> operand_type)`.

  Every tensor leaf of the operand is multiplied (XLA `Mul`) by the scalar
  argument; the products are returned as a flat tuple in leaf order.

  Args:
    operand_type: A tensor type or a structure of tensor types.
    scalar_type: The tensor type of the multiplier.

  Returns:
    A `(comp_pb, comp_type)` pair.

  Raises:
    TypeError: If the arguments are not of the expected Python types.
    ValueError: If `operand_type` is not a tensor or a structure of tensors.
  """
  py_typecheck.check_type(operand_type, computation_types.Type)
  py_typecheck.check_type(scalar_type, computation_types.TensorType)
  if not type_analysis.is_structure_of_tensors(operand_type):
    raise ValueError(
        'Not a tensor or a structure of tensors: {}'.format(str(operand_type)))

  tensor_shapes = _xla_tensor_shape_list_from_from_tff_tensor_or_struct_type(
      operand_type)
  scalar_shape = _xla_tensor_shape_from_tff_tensor_type(scalar_type)
  # The scalar sits right after the operand tensors in the flat argument tuple.
  scalar_position = len(tensor_shapes)

  builder = xla_client.XlaBuilder('comp')
  packed_args = xla_client.ops.Parameter(
      builder, 0,
      xla_client.Shape.tuple_shape(tensor_shapes + [scalar_shape]))
  multiplier = xla_client.ops.GetTupleElement(packed_args, scalar_position)
  products = [
      xla_client.ops.Mul(
          xla_client.ops.GetTupleElement(packed_args, position), multiplier)
      for position in range(scalar_position)
  ]
  xla_client.ops.Tuple(builder, products)

  comp_type = computation_types.FunctionType(
      computation_types.StructType([(None, operand_type),
                                    (None, scalar_type)]),
      operand_type)
  # Flattened TFF parameter tensor i is bound to XLA tuple element i.
  comp_pb = xla_serialization.create_xla_tff_computation(
      builder.build(), list(range(scalar_position + 1)), comp_type)
  return (comp_pb, comp_type)
def test_add_numbers(self):
  """Adds two int32 values by calling an XLA computation in `XlaExecutor`."""
  builder = xla_client.XlaBuilder('comp')
  param = xla_client.ops.Parameter(
      builder, 0,
      xla_client.shape_from_pyval(tuple([np.array(0, dtype=np.int32)] * 2)))
  xla_client.ops.Add(xla_client.ops.GetTupleElement(param, 0),
                     xla_client.ops.GetTupleElement(param, 1))
  xla_comp = builder.build()
  comp_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      xla_comp, [0, 1], comp_type)
  ex = executor.XlaExecutor()

  async def _compute_fn():
    comp_val = await ex.create_value(comp_pb, comp_type)
    x_val = await ex.create_value(20, np.int32)
    y_val = await ex.create_value(30, np.int32)
    arg_val = await ex.create_struct([x_val, y_val])
    call_val = await ex.create_call(comp_val, arg_val)
    return await call_val.compute()

  # Run on a dedicated event loop that is closed when the test finishes.
  # `asyncio.get_event_loop()` outside a running loop is deprecated (and an
  # error in newer Python releases), and it leaked the loop it created.
  # `asyncio.run()` is avoided because it unsets the thread's current loop,
  # which could break other tests that still call `get_event_loop()`.
  loop = asyncio.new_event_loop()
  self.addCleanup(loop.close)
  result = loop.run_until_complete(_compute_fn())
  self.assertEqual(result, 50)
def test_create_xla_tff_computation_int32x10_to_int32x10(self):
  """Serializes an int32[10] -> int32[10] XLA computation bound at index 0."""
  int32x10 = (np.int32, (10,))
  comp_pb = xla_serialization.create_xla_tff_computation(
      _make_test_xla_comp_int32x10_to_int32x10(), [0],
      computation_types.FunctionType(int32x10, int32x10))

  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  self.assertEqual(
      str(type_serialization.deserialize_type(comp_pb.type)),
      '(int32[10] -> int32[10])')
def test_create_xla_tff_computation_noarg(self):
  """Serializes a no-argument XLA computation and inspects the result.

  Checks the proto kind and TFF type, checks that the embedded HLO module
  unpacks to a computation whose root is the constant 10, and checks that
  the parameter binding is empty while the result binding is tensor index 0.
  """
  xla_comp = _make_test_xla_comp_noarg_to_int32()
  noarg_type = computation_types.FunctionType(None, np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      xla_comp, [], noarg_type)

  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  self.assertEqual(
      str(type_serialization.deserialize_type(comp_pb.type)), '( -> int32)')

  hlo_text = xla_serialization.unpack_xla_computation(
      comp_pb.xla.hlo_module).as_hlo_text()
  self.assertIn('ROOT constant.1 = s32[] constant(10)', hlo_text)

  xla_section = comp_pb.xla
  self.assertEqual(str(xla_section.parameter), '')
  self.assertEqual(str(xla_section.result), 'tensor {\n' ' index: 0\n' '}\n')
def test_set_local_execution_context_and_run_simple_xla_computation(self):
  """Invokes a no-argument XLA computation (constant 10) in the local context."""
  builder = xla_client.XlaBuilder('comp')
  # With no arguments, the single XLA parameter is an empty tuple.
  empty_tuple_shape = xla_client.shape_from_pyval(tuple())
  xla_client.ops.Parameter(builder, 0, empty_tuple_shape)
  xla_client.ops.Constant(builder, np.int32(10))
  noarg_type = computation_types.FunctionType(None, np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      builder.build(), [], noarg_type)

  comp = computation_impl.ComputationImpl(comp_pb,
                                          context_stack_impl.context_stack)
  execution_contexts.set_local_execution_context()
  self.assertEqual(comp(), 10)
def test_to_representation_for_type_with_noarg_to_int32_comp(self):
  """Converts a no-argument XLA computation into a callable returning 10."""
  builder = xla_client.XlaBuilder('comp')
  # With no arguments, the single XLA parameter is an empty tuple.
  xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(()))
  xla_client.ops.Constant(builder, np.int32(10))
  noarg_type = computation_types.FunctionType(None, np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      builder.build(), [], noarg_type)

  rep = executor.to_representation_for_type(comp_pb, noarg_type,
                                            self._backend)
  self.assertTrue(callable(rep))
  self.assertEqual(rep(), 10)
def test_computation_callable_return_one_number(self):
  """Wraps a no-argument XLA computation in `ComputationCallable` and calls it."""
  builder = xla_client.XlaBuilder('comp')
  # With no arguments, the single XLA parameter is an empty tuple.
  xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(()))
  xla_client.ops.Constant(builder, np.int32(10))
  noarg_type = computation_types.FunctionType(None, np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      builder.build(), [], noarg_type)

  local_backend = xla_client.get_local_backend(None)
  comp_callable = runtime.ComputationCallable(comp_pb, noarg_type,
                                              local_backend)
  self.assertIsInstance(comp_callable, runtime.ComputationCallable)
  self.assertEqual(str(comp_callable.type_signature), '( -> int32)')
  self.assertEqual(comp_callable(), 10)
def test_to_representation_for_type_with_2xint32_to_int32_comp(self):
  """Converts an XLA int32 adder into a callable and evaluates 20 + 30."""
  zero = np.array(0, dtype=np.int32)
  builder = xla_client.XlaBuilder('comp')
  # Both int32 operands arrive packed into one tuple parameter.
  pair = xla_client.ops.Parameter(builder, 0,
                                  xla_client.shape_from_pyval((zero, zero)))
  lhs = xla_client.ops.GetTupleElement(pair, 0)
  rhs = xla_client.ops.GetTupleElement(pair, 1)
  xla_client.ops.Add(lhs, rhs)
  adder_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      builder.build(), [0, 1], adder_type)

  rep = executor.to_representation_for_type(comp_pb, adder_type,
                                            self._backend)
  self.assertTrue(callable(rep))
  operands = structure.Struct([(None, np.int32(20)), (None, np.int32(30))])
  self.assertEqual(rep(operands), 50)
def test_to_representation_for_type_with_noarg_to_2xint32_comp(self):
  """Converts a no-argument XLA computation returning a named int32 pair.

  The XLA computation returns the tuple (10, 20). The TFF result type names
  the elements `a` and `b`, so the callable's result prints as `<a=10,b=20>`.
  """
  builder = xla_client.XlaBuilder('comp')
  # With no arguments, the single XLA parameter is an empty tuple.
  xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(tuple()))
  xla_client.ops.Tuple(builder, [
      xla_client.ops.Constant(builder, np.int32(10)),
      xla_client.ops.Constant(builder, np.int32(20))
  ])
  xla_comp = builder.build()
  comp_type = computation_types.FunctionType(
      None, computation_types.StructType([('a', np.int32), ('b', np.int32)]))
  # The computation has no TFF parameter, so there are no parameter tensors
  # to bind. Pass an empty index list, consistent with every other
  # no-argument computation here (previously `[0, 1]`, which was misleading).
  comp_pb = xla_serialization.create_xla_tff_computation(
      xla_comp, [], comp_type)
  rep = executor.to_representation_for_type(comp_pb, comp_type, self._backend)
  self.assertTrue(callable(rep))
  result = rep()
  self.assertEqual(str(result), '<a=10,b=20>')
def _create_xla_binary_op_computation(type_spec, xla_binary_op_constructor):
  """Builds an XLA computation of type `(<T,T> -> T)` for a binary operator.

  `T` is `type_spec`. Both operands are passed flattened in one XLA tuple
  parameter: the first operand's tensors occupy positions `[0, n)` and the
  second operand's tensors occupy `[n, 2n)`, where `n` is the number of tensor
  leaves in `T`. Leaf `i` of the result is
  `xla_binary_op_constructor(lhs_leaf_i, rhs_leaf_i)`.

  Args:
    type_spec: The type of a single operand.
    xla_binary_op_constructor: A two-argument callable that constructs a
      binary XLA op from tensor parameters (such as `xla_client.ops.Add`).

  Returns:
    An instance of `local_computation_factory_base.ComputationProtoAndType`.

  Raises:
    ValueError: if `type_spec` is not a tensor or a structure of tensors.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if not type_analysis.is_structure_of_tensors(type_spec):
    raise ValueError('Not a tensor or a structure of tensors: {}'.format(
        str(type_spec)))

  operand_shapes = _xla_tensor_shape_list_from_from_tff_tensor_or_struct_type(
      type_spec)
  leaf_count = len(operand_shapes)

  builder = xla_client.XlaBuilder('comp')
  packed_operands = xla_client.ops.Parameter(
      builder, 0, xla_client.Shape.tuple_shape(operand_shapes * 2))
  outputs = []
  for lhs_pos, rhs_pos in zip(range(leaf_count),
                              range(leaf_count, 2 * leaf_count)):
    lhs = xla_client.ops.GetTupleElement(packed_operands, lhs_pos)
    rhs = xla_client.ops.GetTupleElement(packed_operands, rhs_pos)
    outputs.append(xla_binary_op_constructor(lhs, rhs))
  xla_client.ops.Tuple(builder, outputs)
  xla_computation = builder.build()

  comp_type = computation_types.FunctionType(
      computation_types.StructType([(None, type_spec)] * 2), type_spec)
  # Flattened TFF parameter tensor i is bound to XLA tuple element i.
  comp_pb = xla_serialization.create_xla_tff_computation(
      xla_computation, list(range(2 * leaf_count)), comp_type)
  return (comp_pb, comp_type)
def create_constant_from_scalar(
    self, value, type_spec: computation_types.Type
) -> local_computation_factory_base.ComputationProtoAndType:
  """Creates a no-argument XLA computation that returns a constant.

  Every tensor in `type_spec` is filled with `value`, which `np.full`
  converts to that tensor's dtype.

  Args:
    value: The scalar fill value.
    type_spec: A tensor type or a structure of tensor types describing the
      result. Every tensor shape must be fully defined.

  Returns:
    A `(comp_pb, comp_type)` pair, where `comp_type` is `( -> type_spec)`.

  Raises:
    TypeError: If `type_spec` is not a `computation_types.Type`.
    ValueError: If `type_spec` is not a tensor or a structure of tensors, or
      if any of its tensors has a shape that is not fully defined (a constant
      cannot be materialized for an unknown dimension or rank).
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if not type_analysis.is_structure_of_tensors(type_spec):
    raise ValueError(
        'Not a tensor or a structure of tensors: {}'.format(
            str(type_spec)))

  if isinstance(type_spec, computation_types.TensorType):
    tensor_types = [type_spec]
  else:
    tensor_types = structure.flatten(type_spec)

  # Validate every leaf before building anything. An unknown dimension or
  # rank is then reported as a ValueError naming the offending type, rather
  # than surfacing as an unrelated error from `np.full`.
  for tensor_type in tensor_types:
    py_typecheck.check_type(tensor_type, computation_types.TensorType)
    if not tensor_type.shape.is_fully_defined():
      raise ValueError(
          'Cannot create a constant of type {}: the tensor type {} does not '
          'have a fully defined shape.'.format(type_spec, tensor_type))

  builder = xla_client.XlaBuilder('comp')

  # We maintain the convention that arguments are supplied as a tuple for the
  # sake of consistency and uniformity (see comments in `computation.proto`).
  # Since there are no arguments here, we create an empty tuple.
  xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(tuple()))

  tensors = [
      xla_client.ops.Constant(
          builder,
          np.full(
              shape=tensor_type.shape.dims,
              fill_value=value,
              dtype=tensor_type.dtype.as_numpy_dtype))
      for tensor_type in tensor_types
  ]

  # Likewise, results are always returned as a single tuple with results.
  # This is always a flat tuple; the nested TFF structure is defined by the
  # binding.
  xla_client.ops.Tuple(builder, tensors)
  xla_computation = builder.build()

  comp_type = computation_types.FunctionType(None, type_spec)
  comp_pb = xla_serialization.create_xla_tff_computation(
      xla_computation, [], comp_type)
  return (comp_pb, comp_type)
def test_computation_callable_add_two_numbers(self):
  """Wraps an XLA int32 adder in `ComputationCallable` and computes 2 + 3."""
  zero = np.array(0, dtype=np.int32)
  builder = xla_client.XlaBuilder('comp')
  # Both int32 operands arrive packed into one tuple parameter.
  pair = xla_client.ops.Parameter(builder, 0,
                                  xla_client.shape_from_pyval((zero, zero)))
  lhs = xla_client.ops.GetTupleElement(pair, 0)
  rhs = xla_client.ops.GetTupleElement(pair, 1)
  xla_client.ops.Add(lhs, rhs)
  adder_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      builder.build(), [0, 1], adder_type)

  comp_callable = runtime.ComputationCallable(
      comp_pb, adder_type, xla_client.get_local_backend(None))
  self.assertIsInstance(comp_callable, runtime.ComputationCallable)
  self.assertEqual(
      str(comp_callable.type_signature), '(<int32,int32> -> int32)')
  operands = structure.Struct([(None, np.int32(2)), (None, np.int32(3))])
  self.assertEqual(comp_callable(operands), 5)
def test_create_and_invoke_noarg_comp_returning_int32(self):
  """Embeds a no-argument XLA computation in `XlaExecutor`, then invokes it.

  Checks the embedded value directly through its callable internal
  representation, and again through `create_call` + `compute`.
  """
  builder = xla_client.XlaBuilder('comp')
  # With no arguments, the single XLA parameter is an empty tuple.
  xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(tuple()))
  xla_client.ops.Constant(builder, np.int32(10))
  xla_comp = builder.build()
  comp_type = computation_types.FunctionType(None, np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      xla_comp, [], comp_type)
  ex = executor.XlaExecutor()

  # Use one dedicated event loop for all async steps and close it when the
  # test finishes. `asyncio.get_event_loop()` outside a running loop is
  # deprecated (and an error in newer Python releases), and the loop it
  # created was never closed.
  loop = asyncio.new_event_loop()
  self.addCleanup(loop.close)

  comp_val = loop.run_until_complete(ex.create_value(comp_pb, comp_type))
  self.assertIsInstance(comp_val, executor.XlaValue)
  self.assertEqual(str(comp_val.type_signature), str(comp_type))
  self.assertTrue(callable(comp_val.internal_representation))
  result = comp_val.internal_representation()
  self.assertEqual(result, 10)

  call_val = loop.run_until_complete(ex.create_call(comp_val))
  self.assertIsInstance(call_val, executor.XlaValue)
  self.assertEqual(str(call_val.type_signature), 'int32')
  result = loop.run_until_complete(call_val.compute())
  self.assertEqual(result, 10)
def serialize_jax_computation(traced_fn, arg_fn, parameter_type, context_stack):
  """Serializes a Python function containing JAX code as a TFF computation.

  Args:
    traced_fn: The Python function containing JAX code to be traced by JAX and
      serialized as a TFF computation containing XLA code.
    arg_fn: An unpacking function that takes a TFF argument, and returns a combo
      of (args, kwargs) to invoke `traced_fn` with (e.g., as the one
      constructed by `function_utils.create_argument_unpacking_fn`).
    parameter_type: An instance of `computation_types.Type` that represents the
      TFF type of the computation parameter, or `None` if the function does not
      take any parameters.
    context_stack: The context stack to use during serialization.

  Returns:
    An instance of `pb.Computation` with the constructed computation.

  Raises:
    TypeError: if the arguments are of the wrong types.
  """
  py_typecheck.check_callable(traced_fn)
  py_typecheck.check_callable(arg_fn)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

  if parameter_type is not None:
    parameter_type = computation_types.to_type(parameter_type)
    packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)
  else:
    packed_arg = None

  args, kwargs = arg_fn(packed_arg)

  def _adjust_arg(x):
    # Struct placeholders are converted to the Python containers recorded in
    # their type signature; this is the form `traced_fn` is called with.
    if isinstance(x, structure.Struct):
      return type_conversions.type_to_py_container(x, x.type_signature)
    else:
      return x

  args = [_adjust_arg(x) for x in args]
  kwargs = {k: _adjust_arg(v) for k, v in kwargs.items()}

  # While the fake parameters are fed via args/kwargs during serialization,
  # they may be reordered in the generated XLA code. JAX flattens the
  # (args, kwargs) it is actually called with, i.e. the *adjusted* structure
  # above, so we flatten that same structure here to recover the ordering and
  # capture it in the parameter binding. Flattening before the adjustment
  # could disagree with JAX, e.g. for dict containers, whose keys JAX
  # flattens in sorted order. Each leaf carries the TFF `tensor_index` of its
  # parameter tensor; `argsort` inverts that permutation, so entry i is the
  # XLA position of TFF tensor i. Indexes are converted to plain Python ints
  # (`np.argsort` yields numpy integers). We do not need to do anything
  # special for the results, since the results, if multiple, are always
  # returned as a tuple.
  flattened_obj, _ = jax.tree_util.tree_flatten((args, kwargs))
  tensor_indexes = [
      int(i) for i in np.argsort([x.tensor_index for x in flattened_obj])
  ]

  context = jax_computation_context.JaxComputationContext()
  with context_stack.install(context):
    tracer_callable = jax.xla_computation(
        traced_fn, tuple_args=True, return_shape=True)
    compiled_xla, returned_shape = tracer_callable(*args, **kwargs)

  if isinstance(returned_shape, jax.ShapeDtypeStruct):
    returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(returned_shape)
  else:
    returned_type_spec = computation_types.to_type(
        structure.map_structure(
            _jax_shape_dtype_struct_to_tff_tensor,
            structure.from_container(returned_shape, recursive=True)))

  computation_type = computation_types.FunctionType(parameter_type,
                                                    returned_type_spec)
  return xla_serialization.create_xla_tff_computation(
      compiled_xla, tensor_indexes, computation_type)
def test_create_xla_tff_computation_raises_result_type_mismatch(self):
  """Rejects a declared int32 result for an XLA comp returning int32[10]."""
  xla_comp = _make_test_xla_comp_int32x10_to_int32x10()
  declared_type = computation_types.FunctionType((np.int32, (10,)), np.int32)
  self.assertRaises(ValueError,
                    xla_serialization.create_xla_tff_computation, xla_comp,
                    [0], declared_type)
def test_create_xla_tff_computation_raises_missing_arg_in_type_spec(self):
  """Rejects a no-parameter declared type for an int32[10]-taking XLA comp."""
  xla_comp = _make_test_xla_comp_int32x10_to_int32x10()
  declared_type = computation_types.FunctionType(None, np.int32)
  self.assertRaises(ValueError,
                    xla_serialization.create_xla_tff_computation, xla_comp,
                    [], declared_type)
def test_create_xla_tff_computation_raises_missing_arg_in_xla(self):
  """Rejects a declared int32 parameter for an XLA comp taking no arguments."""
  xla_comp = _make_test_xla_comp_noarg_to_int32()
  declared_type = computation_types.FunctionType(np.int32, np.int32)
  self.assertRaises(ValueError,
                    xla_serialization.create_xla_tff_computation, xla_comp,
                    [0], declared_type)