def test_serialize_jax_with_int32_to_int32(self):
  self.skipTest('HLO pattern matching broken by '
                'https://github.com/google/jax/pull/10232')

  def traced_fn(x):
    return x + 10

  param_type = computation_types.to_type(np.int32)
  arg_fn = function_utils.create_argument_unpacking_fn(traced_fn, param_type)
  ctx_stack = context_stack_impl.context_stack
  comp_pb = jax_serialization.serialize_jax_computation(
      traced_fn, arg_fn, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(str(type_spec), '(int32 -> int32)')
  xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
  self.assertIn('ROOT tuple.6 = (s32[]) tuple(add.5)',
                xla_comp.as_hlo_text())
  self.assertEqual(str(comp_pb.xla.result), str(comp_pb.xla.parameter))
  self.assertEqual(
      str(comp_pb.xla.result), 'tensor {\n'
      '  index: 0\n'
      '}\n')


def test_wrap_as_zero_or_one_arg_callable(self, fn, parameter_type, unpack,
                                          arg, expected_result):
  parameter_type = computation_types.to_type(parameter_type)
  unpack_arguments = function_utils.create_argument_unpacking_fn(
      fn, parameter_type, unpack)
  args, kwargs = unpack_arguments(arg)
  actual_result = fn(*args, **kwargs)
  self.assertEqual(actual_result, expected_result)


def _federated_computation_serializer(fn, parameter_name, parameter_type):
  unpack_arguments = function_utils.create_argument_unpacking_fn(
      fn, parameter_type)
  fn_gen = federated_computation_utils.federated_computation_serializer(
      parameter_name, parameter_type, context_stack_impl.context_stack)
  args, kwargs = unpack_arguments(next(fn_gen))
  result = fn(*args, **kwargs)
  return fn_gen.send(result)


def test_nested_structure_type_signature_roundtrip(self):

  def traced_fn(x):
    return x[0][0]

  param_type = computation_types.to_type([(np.int32,)])
  arg_fn = function_utils.create_argument_unpacking_fn(traced_fn, param_type)
  ctx_stack = context_stack_impl.context_stack
  comp_pb = jax_serialization.serialize_jax_computation(
      traced_fn, arg_fn, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(str(type_spec), '(<<int32>> -> int32)')


def test_tracing_with_float64_input(self):
  self.skipTest('b/237566862')
  # The parameter dtype must be float64 to match the expected
  # `(float64 -> float64)` signature asserted below.
  param_type = computation_types.TensorType(np.float64)
  identity_fn = lambda x: x
  arg_fn = function_utils.create_argument_unpacking_fn(identity_fn, param_type)
  ctx_stack = context_stack_impl.context_stack
  comp_pb = jax_serialization.serialize_jax_computation(
      identity_fn, arg_fn, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(str(type_spec), '(float64 -> float64)')


def test_serialize_jax_with_2xint32_to_2xint32(self):
  self.skipTest('HLO pattern matching broken by '
                'https://github.com/google/jax/pull/10232')

  def traced_fn(x):
    return collections.OrderedDict([('sum', x['foo'] + x['bar']),
                                    ('difference', x['bar'] - x['foo'])])

  param_type = computation_types.to_type(
      collections.OrderedDict([('foo', np.int32), ('bar', np.int32)]))
  arg_fn = function_utils.create_argument_unpacking_fn(traced_fn, param_type)
  ctx_stack = context_stack_impl.context_stack
  comp_pb = jax_serialization.serialize_jax_computation(
      traced_fn, arg_fn, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(
      str(type_spec),
      '(<foo=int32,bar=int32> -> <sum=int32,difference=int32>)')
  xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
  self.assertIn(
      # pylint: disable=line-too-long
      '  constant.4 = pred[] constant(false)\n'
      '  parameter.1 = (s32[], s32[]) parameter(0)\n'
      '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
      '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
      '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
      '  subtract.6 = s32[] subtract(get-tuple-element.3, get-tuple-element.2)\n'
      '  ROOT tuple.7 = (s32[], s32[]) tuple(add.5, subtract.6)\n',
      xla_comp.as_hlo_text())
  self.assertEqual(str(comp_pb.xla.result), str(comp_pb.xla.parameter))
  self.assertEqual(
      str(comp_pb.xla.parameter), 'struct {\n'
      '  element {\n'
      '    tensor {\n'
      '      index: 0\n'
      '    }\n'
      '  }\n'
      '  element {\n'
      '    tensor {\n'
      '      index: 1\n'
      '    }\n'
      '  }\n'
      '}\n')


def test_arg_ordering(self):
  param_type = computation_types.to_type(
      (computation_types.TensorType(np.int32, 10),
       computation_types.TensorType(np.int32)))

  def traced_fn(b, a):
    return jax.numpy.add(a, jax.numpy.sum(b))

  arg_fn = function_utils.create_argument_unpacking_fn(traced_fn, param_type)
  ctx_stack = context_stack_impl.context_stack
  comp_pb = jax_serialization.serialize_jax_computation(
      traced_fn, arg_fn, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(str(type_spec), '(<int32[10],int32> -> int32)')


def test_serialize_jax_with_nested_struct_arg(self):

  def traced_fn(x, y):
    return x[0] + y

  param_type = computation_types.StructType([
      (None, computation_types.StructType([(None, np.int32)])),
      (None, np.int32),
  ])
  arg_fn = function_utils.create_argument_unpacking_fn(traced_fn, param_type)
  ctx_stack = context_stack_impl.context_stack
  comp_pb = jax_serialization.serialize_jax_computation(
      traced_fn, arg_fn, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(str(type_spec), '(<<int32>,int32> -> int32)')


def test_serialize_jax_with_noarg_to_int32(self):

  def traced_fn():
    return 10

  param_type = None
  arg_fn = function_utils.create_argument_unpacking_fn(traced_fn, param_type)
  ctx_stack = context_stack_impl.context_stack
  comp_pb = jax_serialization.serialize_jax_computation(
      traced_fn, arg_fn, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(str(type_spec), '( -> int32)')
  xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
  self.assertIn('ROOT tuple.4 = (s32[]) tuple(constant.3)',
                xla_comp.as_hlo_text())
  self.assertEqual(str(comp_pb.xla.parameter), '')
  self.assertEqual(
      str(comp_pb.xla.result), 'tensor {\n'
      '  index: 0\n'
      '}\n')


def test_serialize_jax_with_two_args(self):

  def traced_fn(x, y):
    return x + y

  param_type = computation_types.to_type(
      collections.OrderedDict([('x', np.int32), ('y', np.int32)]))
  arg_fn = function_utils.create_argument_unpacking_fn(traced_fn, param_type)
  ctx_stack = context_stack_impl.context_stack
  comp_pb = jax_serialization.serialize_jax_computation(
      traced_fn, arg_fn, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(str(type_spec), '(<x=int32,y=int32> -> int32)')
  xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
  self.assertIn(
      # pylint: disable=line-too-long
      '  constant.4 = pred[] constant(false)\n'
      '  parameter.1 = (s32[], s32[]) parameter(0)\n'
      '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
      '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
      '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
      '  ROOT tuple.6 = (s32[]) tuple(add.5)\n',
      xla_comp.as_hlo_text())
  self.assertEqual(
      str(comp_pb.xla.parameter), 'struct {\n'
      '  element {\n'
      '    tensor {\n'
      '      index: 0\n'
      '    }\n'
      '  }\n'
      '  element {\n'
      '    tensor {\n'
      '      index: 1\n'
      '    }\n'
      '  }\n'
      '}\n')
  self.assertEqual(
      str(comp_pb.xla.result), 'tensor {\n'
      '  index: 0\n'
      '}\n')


def _jax_strategy_fn(fn_to_wrap, fn_name, parameter_type, unpack):
  """Serializes a Python function containing JAX code as a TFF computation.

  Args:
    fn_to_wrap: The Python function containing JAX code to be serialized as a
      computation containing XLA.
    fn_name: The name for the constructed computation (currently ignored).
    parameter_type: An instance of `computation_types.Type` that represents the
      TFF type of the computation parameter, or `None` if there's none.
    unpack: See `unpack` in `function_utils.create_argument_unpacking_fn`.

  Returns:
    An instance of `computation_impl.ComputationImpl` with the constructed
    computation.
  """
  del fn_name  # Unused.
  unpack_arguments_fn = function_utils.create_argument_unpacking_fn(
      fn_to_wrap, parameter_type, unpack=unpack)
  ctx_stack = context_stack_impl.context_stack
  comp_pb = jax_serialization.serialize_jax_computation(
      fn_to_wrap, unpack_arguments_fn, parameter_type, ctx_stack)
  return computation_impl.ComputationImpl(comp_pb, ctx_stack)


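# Illustrative usage sketch (not part of the original module): calling the
# strategy function above directly with a toy JAX function. The lambda, the
# name string, and the int32 parameter type are arbitrary example values, and
# `np` is assumed to be `numpy` imported as in the surrounding modules. In
# practice this helper is reached through TFF's computation-wrapper machinery
# rather than being called by hand.
def _example_jax_strategy_fn_usage():
  example_comp = _jax_strategy_fn(
      fn_to_wrap=lambda x: x + 1,
      fn_name='add_one',
      parameter_type=computation_types.TensorType(np.int32),
      unpack=None)
  # `example_comp` is a `computation_impl.ComputationImpl` wrapping a
  # serialized XLA computation with type signature `(int32 -> int32)`.
  return example_comp

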
def __call__(self, fn_to_wrap, fn_name, parameter_type, unpack):
  unpack_arguments_fn = function_utils.create_argument_unpacking_fn(
      fn_to_wrap, parameter_type, unpack=unpack)
  wrapped_fn_generator = _wrap_concrete(fn_name, self._wrapper_fn,
                                        parameter_type)
  packed_args = next(wrapped_fn_generator)
  try:
    args, kwargs = unpack_arguments_fn(packed_args)
    result = fn_to_wrap(*args, **kwargs)
    if result is None:
      raise ComputationReturnedNoneError(fn_to_wrap)
  except Exception:
    # Give nested generators an opportunity to clean up, then re-raise the
    # original error without extra context. We don't want to simply pass the
    # error into the generators, as that would result in the whole generator
    # stack being added to the error message.
    try:
      wrapped_fn_generator.throw(_TracingError())
    except _TracingError:
      pass
    raise
  return wrapped_fn_generator.send(result)
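

# Illustrative sketch (not part of the wrapper above): the generator clean-up
# pattern used in `__call__`, reduced to a self-contained toy example. Throwing
# a sentinel error into the suspended generator lets its `finally` blocks run,
# and the original exception is then re-raised without the generator frames
# attached. `_ToyTracingError`, `_toy_tracing_generator`, and
# `_example_generator_cleanup` are hypothetical names used only here.
class _ToyTracingError(Exception):
  pass


def _toy_tracing_generator():
  try:
    yield 'placeholder argument'
  finally:
    pass  # Nested resources would be released here during unwinding.


def _example_generator_cleanup():
  gen = _toy_tracing_generator()
  next(gen)  # Advance to the yield, mirroring `next(wrapped_fn_generator)`.
  try:
    raise ValueError('user code failed during tracing')
  except Exception:
    try:
      gen.throw(_ToyTracingError())  # Run the generator's clean-up blocks.
    except _ToyTracingError:
      pass  # The sentinel propagates back out of the generator; swallow it.
    raise  # Re-raises the original ValueError, so callers see only that error.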