def test_nested_structure_type_signature_roundtrip(self):
  """Checks the serialized function type for a doubly-nested struct param."""

  def traced_fn(x):
    return x[0][0]

  nested_type = computation_types.to_type([(np.int32,)])
  unpacking_fn = function_utils.create_argument_unpacking_fn(
      traced_fn, nested_type)
  serialized = jax_serialization.serialize_jax_computation(
      traced_fn, unpacking_fn, nested_type, context_stack_impl.context_stack)

  self.assertIsInstance(serialized, pb.Computation)
  self.assertEqual(serialized.WhichOneof('computation'), 'xla')
  # The struct nesting of the parameter must survive the round trip.
  roundtripped_type = type_serialization.deserialize_type(serialized.type)
  self.assertEqual(str(roundtripped_type), '(<<int32>> -> int32)')
def test_arg_ordering(self):
  """Checks the serialized type when the traced fn declares `(b, a)`.

  The parameter struct is `<int32[10],int32>`; the serialized function type
  must keep that element order.
  """
  vector_type = computation_types.TensorType(np.int32, 10)
  scalar_type = computation_types.TensorType(np.int32)
  struct_type = computation_types.to_type((vector_type, scalar_type))

  def traced_fn(b, a):
    return jax.numpy.add(a, jax.numpy.sum(b))

  unpacking_fn = function_utils.create_argument_unpacking_fn(
      traced_fn, struct_type)
  serialized = jax_serialization.serialize_jax_computation(
      traced_fn, unpacking_fn, struct_type, context_stack_impl.context_stack)

  self.assertIsInstance(serialized, pb.Computation)
  self.assertEqual(serialized.WhichOneof('computation'), 'xla')
  function_type = type_serialization.deserialize_type(serialized.type)
  self.assertEqual(str(function_type), '(<int32[10],int32> -> int32)')
def test_serialize_jax_with_2xint32_to_2xint32(self):
  """Serializes a named 2-tuple -> named 2-tuple JAX function to XLA.

  Checks the TFF function type, the exact HLO text, and the tensor-index
  bindings of both the parameter and the result.
  """
  ctx_stack = context_stack_impl.context_stack
  param_type = collections.OrderedDict([('foo', np.int32), ('bar', np.int32)])
  arg_func = lambda x: ([x], {})

  def traced_func(x):
    return collections.OrderedDict([('sum', x['foo'] + x['bar']),
                                    ('difference', x['bar'] - x['foo'])])

  comp_pb = jax_serialization.serialize_jax_computation(
      traced_func, arg_func, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(
      str(type_spec),
      '(<foo=int32,bar=int32> -> <sum=int32,difference=int32>)')
  xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
  # HLO text indents every instruction of a computation by two spaces, so
  # the expected lines must carry that indentation to compare equal.
  self.assertEqual(
      xla_comp.as_hlo_text(),
      # pylint: disable=line-too-long
      'HloModule xla_computation_traced_func.8\n\n'
      'ENTRY xla_computation_traced_func.8 {\n'
      '  constant.4 = pred[] constant(false)\n'
      '  parameter.1 = (s32[], s32[]) parameter(0)\n'
      '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
      '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
      '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
      '  subtract.6 = s32[] subtract(get-tuple-element.3, get-tuple-element.2)\n'
      '  ROOT tuple.7 = (s32[], s32[]) tuple(add.5, subtract.6)\n'
      '}\n\n')
  # Parameter and result are both 2-element structs of tensors, so their
  # binding protos are expected to print identically.
  self.assertEqual(str(comp_pb.xla.result), str(comp_pb.xla.parameter))
  # Protobuf text format indents each nesting level by two spaces.
  self.assertEqual(
      str(comp_pb.xla.parameter),
      'struct {\n'
      '  element {\n'
      '    tensor {\n'
      '      index: 0\n'
      '    }\n'
      '  }\n'
      '  element {\n'
      '    tensor {\n'
      '      index: 1\n'
      '    }\n'
      '  }\n'
      '}\n')
def test_serialize_jax_with_nested_struct_arg(self):
  """Serializes `(x, y)` where `x` is itself a one-element struct."""

  def traced_fn(x, y):
    return x[0] + y

  inner_type = computation_types.StructType([(None, np.int32)])
  outer_type = computation_types.StructType([(None, inner_type),
                                             (None, np.int32)])
  unpacking_fn = function_utils.create_argument_unpacking_fn(
      traced_fn, outer_type)
  serialized = jax_serialization.serialize_jax_computation(
      traced_fn, unpacking_fn, outer_type, context_stack_impl.context_stack)

  self.assertIsInstance(serialized, pb.Computation)
  self.assertEqual(serialized.WhichOneof('computation'), 'xla')
  self.assertEqual(
      str(type_serialization.deserialize_type(serialized.type)),
      '(<<int32>,int32> -> int32)')
def test_serialize_jax_with_int32_to_int32_using_unpacking_fn(self):
  """Serializes int32 -> int32 using `create_argument_unpacking_fn`.

  This test previously shared its name with the variant that uses an explicit
  `([x], {})` argument function. That later definition rebound the name, so
  this body never ran; it now has a distinct name.
  """

  def traced_fn(x):
    return x + 10

  param_type = computation_types.to_type(np.int32)
  arg_fn = function_utils.create_argument_unpacking_fn(traced_fn, param_type)
  ctx_stack = context_stack_impl.context_stack
  comp_pb = jax_serialization.serialize_jax_computation(
      traced_fn, arg_fn, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(str(type_spec), '(int32 -> int32)')
  xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
  self.assertIn('ROOT tuple.6 = (s32[]) tuple(add.5)', xla_comp.as_hlo_text())
  # A single tensor at index 0 serves as both parameter and result binding.
  self.assertEqual(str(comp_pb.xla.result), str(comp_pb.xla.parameter))
  # Protobuf text format indents nested fields by two spaces.
  self.assertEqual(str(comp_pb.xla.result), 'tensor {\n'
                   '  index: 0\n'
                   '}\n')
def test_serialize_jax_with_int32_to_int32(self):
  """Serializes int32 -> int32 with an explicit `([x], {})` argument fn."""
  ctx_stack = context_stack_impl.context_stack
  param_type = np.int32
  arg_func = lambda x: ([x], {})

  def traced_func(x):
    return x + 10

  comp_pb = jax_serialization.serialize_jax_computation(
      traced_func, arg_func, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(str(type_spec), '(int32 -> int32)')
  xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
  self.assertIn('ROOT tuple.6 = (s32[]) tuple(add.5)', xla_comp.as_hlo_text())
  # A single tensor at index 0 serves as both parameter and result binding.
  self.assertEqual(str(comp_pb.xla.result), str(comp_pb.xla.parameter))
  # Protobuf text format indents nested fields by two spaces.
  self.assertEqual(str(comp_pb.xla.result), 'tensor {\n'
                   '  index: 0\n'
                   '}\n')
def test_serialize_jax_with_two_args_unpacked_as_kwargs(self):
  """Serializes `x + y` where the `<a,b>` struct is passed as keywords.

  `arg_func` maps the two struct elements onto keyword args `x` and `y`.
  This test previously shared its name with the
  `create_argument_unpacking_fn` variant; that later definition rebound the
  name, so this body never ran. It now has a distinct name.
  """
  ctx_stack = context_stack_impl.context_stack
  param_type = computation_types.StructType([('a', np.int32), ('b', np.int32)])
  arg_func = lambda arg: ([], {'x': arg[0], 'y': arg[1]})

  def traced_func(x, y):
    return x + y

  comp_pb = jax_serialization.serialize_jax_computation(
      traced_func, arg_func, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(str(type_spec), '(<a=int32,b=int32> -> int32)')
  xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
  # Only the ENTRY body is pinned. The module/entry name for a function named
  # `traced_func` carries a suffix that differs between tests (`__3.7` was
  # expected here, `.8` in the 2xint32 test), presumably depending on trace
  # order, so comparing the whole module text is brittle.
  # HLO text indents instructions by two spaces.
  self.assertIn(
      # pylint: disable=line-too-long
      '  constant.4 = pred[] constant(false)\n'
      '  parameter.1 = (s32[], s32[]) parameter(0)\n'
      '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
      '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
      '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
      '  ROOT tuple.6 = (s32[]) tuple(add.5)\n',
      xla_comp.as_hlo_text())
  # Protobuf text format indents each nesting level by two spaces.
  self.assertEqual(
      str(comp_pb.xla.parameter),
      'struct {\n'
      '  element {\n'
      '    tensor {\n'
      '      index: 0\n'
      '    }\n'
      '  }\n'
      '  element {\n'
      '    tensor {\n'
      '      index: 1\n'
      '    }\n'
      '  }\n'
      '}\n')
  self.assertEqual(str(comp_pb.xla.result), 'tensor {\n'
                   '  index: 0\n'
                   '}\n')
def test_serialize_jax_with_two_args(self):
  """Serializes `x + y` with args unpacked by `create_argument_unpacking_fn`."""

  def traced_fn(x, y):
    return x + y

  param_type = computation_types.to_type(
      collections.OrderedDict([('x', np.int32), ('y', np.int32)]))
  arg_fn = function_utils.create_argument_unpacking_fn(
      traced_fn, param_type)
  ctx_stack = context_stack_impl.context_stack
  comp_pb = jax_serialization.serialize_jax_computation(
      traced_fn, arg_fn, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(str(type_spec), '(<x=int32,y=int32> -> int32)')
  xla_comp = xla_serialization.unpack_xla_computation(
      comp_pb.xla.hlo_module)
  # HLO text indents instructions by two spaces. This substring spans several
  # lines, so each line must carry that indentation to be found.
  self.assertIn(
      # pylint: disable=line-too-long
      '  constant.4 = pred[] constant(false)\n'
      '  parameter.1 = (s32[], s32[]) parameter(0)\n'
      '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
      '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
      '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
      '  ROOT tuple.6 = (s32[]) tuple(add.5)\n',
      xla_comp.as_hlo_text())
  # Protobuf text format indents each nesting level by two spaces.
  self.assertEqual(
      str(comp_pb.xla.parameter),
      'struct {\n'
      '  element {\n'
      '    tensor {\n'
      '      index: 0\n'
      '    }\n'
      '  }\n'
      '  element {\n'
      '    tensor {\n'
      '      index: 1\n'
      '    }\n'
      '  }\n'
      '}\n')
  self.assertEqual(str(comp_pb.xla.result), 'tensor {\n'
                   '  index: 0\n'
                   '}\n')
def _jax_strategy_fn(fn_to_wrap, fn_name, parameter_type, unpack):
  """Builds a TFF computation from a Python function containing JAX code.

  The function is serialized via `jax_serialization` into a computation
  holding XLA.

  Args:
    fn_to_wrap: The Python function with JAX code to serialize.
    fn_name: Requested name for the computation; currently ignored.
    parameter_type: A `computation_types.Type` describing the computation's
      parameter, or `None` if it takes no parameter.
    unpack: Forwarded as `unpack` to
      `function_utils.create_argument_unpacking_fn`.

  Returns:
    A `computation_impl.ComputationImpl` wrapping the serialized computation.
  """
  del fn_name  # Currently ignored.
  arg_unpacker = function_utils.create_argument_unpacking_fn(
      fn_to_wrap, parameter_type, unpack=unpack)
  stack = context_stack_impl.context_stack
  serialized_comp = jax_serialization.serialize_jax_computation(
      fn_to_wrap, arg_unpacker, parameter_type, stack)
  return computation_impl.ComputationImpl(serialized_comp, stack)