def test_nested_structure_type_signature_roundtrip(self):
    """Serializes a JAX function over a nested struct and checks its type.

    The parameter is a one-element struct wrapping another one-element struct
    of `int32`; the deserialized signature must read `(<<int32>> -> int32)`.
    """

    def traced_fn(x):
      return x[0][0]

    nested_type = computation_types.to_type([(np.int32,)])
    stack = context_stack_impl.context_stack
    unpack_fn = function_utils.create_argument_unpacking_fn(
        traced_fn, nested_type)
    serialized = jax_serialization.serialize_jax_computation(
        traced_fn, unpack_fn, nested_type, stack)

    # The result must be a computation proto whose populated variant is XLA.
    self.assertIsInstance(serialized, pb.Computation)
    populated_variant = serialized.WhichOneof('computation')
    self.assertEqual(populated_variant, 'xla')

    # Round-trip the serialized type and compare its string form.
    signature = type_serialization.deserialize_type(serialized.type)
    self.assertEqual(str(signature), '(<<int32>> -> int32)')
  def test_arg_ordering(self):
    """Checks the type signature for a two-argument function `(b, a)`.

    The parameter is an unnamed pair of a length-10 `int32` tensor and a
    scalar `int32`; the traced function sums the first argument and adds the
    second, so the expected signature is `(<int32[10],int32> -> int32)`.
    """

    def traced_fn(b, a):
      return jax.numpy.add(a, jax.numpy.sum(b))

    vector_type = computation_types.TensorType(np.int32, 10)
    scalar_type = computation_types.TensorType(np.int32)
    pair_type = computation_types.to_type((vector_type, scalar_type))

    stack = context_stack_impl.context_stack
    unpack_fn = function_utils.create_argument_unpacking_fn(
        traced_fn, pair_type)
    serialized = jax_serialization.serialize_jax_computation(
        traced_fn, unpack_fn, pair_type, stack)

    # The result must be a computation proto whose populated variant is XLA.
    self.assertIsInstance(serialized, pb.Computation)
    populated_variant = serialized.WhichOneof('computation')
    self.assertEqual(populated_variant, 'xla')

    signature = type_serialization.deserialize_type(serialized.type)
    self.assertEqual(str(signature), '(<int32[10],int32> -> int32)')
  def test_serialize_jax_with_2xint32_to_2xint32(self):
    """Serializes a struct-to-struct JAX function and inspects the result.

    Verifies the TFF type signature, the XLA instruction sequence, and the
    parameter/result tensor bindings of the serialized computation.

    Only the instruction lines of the HLO text are pinned. The `HloModule` and
    `ENTRY` header lines embed a module name derived from the traced function,
    so they are deliberately left out of the expectation to keep the test from
    depending on how that name is generated.
    """

    ctx_stack = context_stack_impl.context_stack
    param_type = collections.OrderedDict([('foo', np.int32), ('bar', np.int32)])
    # Pass the whole parameter through as the single positional argument.
    arg_func = lambda x: ([x], {})

    def traced_func(x):
      return collections.OrderedDict([('sum', x['foo'] + x['bar']),
                                      ('difference', x['bar'] - x['foo'])])

    comp_pb = jax_serialization.serialize_jax_computation(
        traced_func, arg_func, param_type, ctx_stack)
    self.assertIsInstance(comp_pb, pb.Computation)
    self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
    type_spec = type_serialization.deserialize_type(comp_pb.type)
    self.assertEqual(
        str(type_spec),
        '(<foo=int32,bar=int32> -> <sum=int32,difference=int32>)')
    xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
    self.assertIn(
        # pylint: disable=line-too-long
        '  constant.4 = pred[] constant(false)\n'
        '  parameter.1 = (s32[], s32[]) parameter(0)\n'
        '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
        '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
        '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
        '  subtract.6 = s32[] subtract(get-tuple-element.3, get-tuple-element.2)\n'
        '  ROOT tuple.7 = (s32[], s32[]) tuple(add.5, subtract.6)\n',
        xla_comp.as_hlo_text())
    # Both sides are two-element structs, so their bindings print identically.
    self.assertEqual(str(comp_pb.xla.result), str(comp_pb.xla.parameter))
    self.assertEqual(
        str(comp_pb.xla.parameter), 'struct {\n'
        '  element {\n'
        '    tensor {\n'
        '      index: 0\n'
        '    }\n'
        '  }\n'
        '  element {\n'
        '    tensor {\n'
        '      index: 1\n'
        '    }\n'
        '  }\n'
        '}\n')
    def test_serialize_jax_with_nested_struct_arg(self):
        """Checks the signature when the first argument is itself a struct.

        The parameter is `<<int32>,int32>`: the traced function indexes into
        the inner struct and adds the scalar second argument.
        """

        def traced_fn(x, y):
            return x[0] + y

        inner_struct = computation_types.StructType([(None, np.int32)])
        outer_struct = computation_types.StructType([
            (None, inner_struct),
            (None, np.int32),
        ])

        stack = context_stack_impl.context_stack
        unpack_fn = function_utils.create_argument_unpacking_fn(
            traced_fn, outer_struct)
        serialized = jax_serialization.serialize_jax_computation(
            traced_fn, unpack_fn, outer_struct, stack)

        # The result must be a computation proto whose variant is XLA.
        self.assertIsInstance(serialized, pb.Computation)
        populated_variant = serialized.WhichOneof('computation')
        self.assertEqual(populated_variant, 'xla')

        signature = type_serialization.deserialize_type(serialized.type)
        self.assertEqual(str(signature), '(<<int32>,int32> -> int32)')
  def test_serialize_jax_with_int32_to_int32(self):
    """Serializes `x + 10` over a scalar `int32` and inspects the result.

    Checks the type signature, the root HLO instruction, and that parameter
    and result are each bound to tensor index 0.
    """
    # NOTE(review): a second method with this exact name appears further down;
    # if both live in the same class the later one replaces this one - confirm.

    def traced_fn(x):
      return x + 10

    scalar_type = computation_types.to_type(np.int32)
    stack = context_stack_impl.context_stack
    unpack_fn = function_utils.create_argument_unpacking_fn(
        traced_fn, scalar_type)
    serialized = jax_serialization.serialize_jax_computation(
        traced_fn, unpack_fn, scalar_type, stack)

    # The result must be a computation proto whose populated variant is XLA.
    self.assertIsInstance(serialized, pb.Computation)
    populated_variant = serialized.WhichOneof('computation')
    self.assertEqual(populated_variant, 'xla')

    signature = type_serialization.deserialize_type(serialized.type)
    self.assertEqual(str(signature), '(int32 -> int32)')

    # Only the root instruction of the HLO text is pinned here.
    hlo = xla_serialization.unpack_xla_computation(serialized.xla.hlo_module)
    self.assertIn('ROOT tuple.6 = (s32[]) tuple(add.5)', hlo.as_hlo_text())

    # Parameter and result share the same single-tensor binding.
    self.assertEqual(str(serialized.xla.result), str(serialized.xla.parameter))
    single_tensor_binding = ('tensor {\n'
                             '  index: 0\n'
                             '}\n')
    self.assertEqual(str(serialized.xla.result), single_tensor_binding)
  def test_serialize_jax_with_int32_to_int32(self):
    """Serializes `x + 10` using a hand-written argument unpacking function.

    The raw `np.int32` is passed as the parameter type and the unpacker hands
    the whole argument to the traced function positionally. Checks the type
    signature, the root HLO instruction, and the tensor bindings.
    """

    def traced_func(x):
      return x + 10

    unpack = lambda x: ([x], {})
    scalar_type = np.int32
    stack = context_stack_impl.context_stack

    serialized = jax_serialization.serialize_jax_computation(
        traced_func, unpack, scalar_type, stack)

    # The result must be a computation proto whose populated variant is XLA.
    self.assertIsInstance(serialized, pb.Computation)
    populated_variant = serialized.WhichOneof('computation')
    self.assertEqual(populated_variant, 'xla')

    signature = type_serialization.deserialize_type(serialized.type)
    self.assertEqual(str(signature), '(int32 -> int32)')

    # Only the root instruction of the HLO text is pinned here.
    hlo = xla_serialization.unpack_xla_computation(serialized.xla.hlo_module)
    self.assertIn('ROOT tuple.6 = (s32[]) tuple(add.5)', hlo.as_hlo_text())

    # Parameter and result share the same single-tensor binding.
    self.assertEqual(str(serialized.xla.result), str(serialized.xla.parameter))
    single_tensor_binding = ('tensor {\n'
                             '  index: 0\n'
                             '}\n')
    self.assertEqual(str(serialized.xla.result), single_tensor_binding)
  def test_serialize_jax_with_two_args(self):
    """Serializes a two-argument JAX function fed from a named struct.

    Verifies the TFF type signature, the XLA instruction sequence, and the
    parameter/result tensor bindings of the serialized computation.

    Only the instruction lines of the HLO text are pinned. The `HloModule` and
    `ENTRY` header lines embed a module name derived from the traced function
    plus a suffix this test does not control, so matching them exactly would
    make the outcome depend on factors outside this method.
    """
    # NOTE(review): a second method with this exact name appears further down;
    # if both live in the same class the later one replaces this one - confirm.

    ctx_stack = context_stack_impl.context_stack
    param_type = computation_types.StructType([('a', np.int32),
                                               ('b', np.int32)])
    # Map the struct's two elements, by position, onto keyword args x and y.
    arg_func = lambda arg: ([], {'x': arg[0], 'y': arg[1]})

    def traced_func(x, y):
      return x + y

    comp_pb = jax_serialization.serialize_jax_computation(
        traced_func, arg_func, param_type, ctx_stack)
    self.assertIsInstance(comp_pb, pb.Computation)
    self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
    type_spec = type_serialization.deserialize_type(comp_pb.type)
    self.assertEqual(str(type_spec), '(<a=int32,b=int32> -> int32)')
    xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
    self.assertIn(
        # pylint: disable=line-too-long
        '  constant.4 = pred[] constant(false)\n'
        '  parameter.1 = (s32[], s32[]) parameter(0)\n'
        '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
        '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
        '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
        '  ROOT tuple.6 = (s32[]) tuple(add.5)\n',
        xla_comp.as_hlo_text())
    self.assertEqual(
        str(comp_pb.xla.parameter), 'struct {\n'
        '  element {\n'
        '    tensor {\n'
        '      index: 0\n'
        '    }\n'
        '  }\n'
        '  element {\n'
        '    tensor {\n'
        '      index: 1\n'
        '    }\n'
        '  }\n'
        '}\n')
    self.assertEqual(str(comp_pb.xla.result), 'tensor {\n' '  index: 0\n' '}\n')
    def test_serialize_jax_with_two_args(self):
        """Serializes `x + y` with arguments drawn from a named struct.

        The parameter type is `<x=int32,y=int32>`. Checks the type signature,
        the HLO instruction sequence (header lines are not pinned), and the
        parameter/result tensor bindings.
        """

        def traced_fn(x, y):
            return x + y

        named_pair = collections.OrderedDict([('x', np.int32),
                                              ('y', np.int32)])
        param_type = computation_types.to_type(named_pair)

        stack = context_stack_impl.context_stack
        unpack_fn = function_utils.create_argument_unpacking_fn(
            traced_fn, param_type)
        serialized = jax_serialization.serialize_jax_computation(
            traced_fn, unpack_fn, param_type, stack)

        # The result must be a computation proto whose variant is XLA.
        self.assertIsInstance(serialized, pb.Computation)
        populated_variant = serialized.WhichOneof('computation')
        self.assertEqual(populated_variant, 'xla')

        signature = type_serialization.deserialize_type(serialized.type)
        self.assertEqual(str(signature), '(<x=int32,y=int32> -> int32)')

        # Instruction lines expected somewhere inside the HLO text.
        expected_instructions = (
            # pylint: disable=line-too-long
            '  constant.4 = pred[] constant(false)\n'
            '  parameter.1 = (s32[], s32[]) parameter(0)\n'
            '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
            '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
            '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
            '  ROOT tuple.6 = (s32[]) tuple(add.5)\n')
        hlo = xla_serialization.unpack_xla_computation(
            serialized.xla.hlo_module)
        self.assertIn(expected_instructions, hlo.as_hlo_text())

        # The parameter binds its two struct elements to tensors 0 and 1.
        expected_parameter_binding = (
            'struct {\n'
            '  element {\n'
            '    tensor {\n'
            '      index: 0\n'
            '    }\n'
            '  }\n'
            '  element {\n'
            '    tensor {\n'
            '      index: 1\n'
            '    }\n'
            '  }\n'
            '}\n')
        self.assertEqual(
            str(serialized.xla.parameter), expected_parameter_binding)

        # The scalar result binds to tensor 0.
        expected_result_binding = ('tensor {\n'
                                   '  index: 0\n'
                                   '}\n')
        self.assertEqual(str(serialized.xla.result), expected_result_binding)
def _jax_strategy_fn(fn_to_wrap, fn_name, parameter_type, unpack):
    """Builds a TFF computation from a Python function that contains JAX code.

    The function is serialized to an XLA-backed computation proto, which is
    then wrapped together with the global context stack.

    Args:
      fn_to_wrap: The Python function containing JAX code to be serialized as
        a computation containing XLA.
      fn_name: The name for the constructed computation (currently ignored;
        it is discarded immediately).
      parameter_type: An instance of `computation_types.Type` that represents
        the TFF type of the computation parameter, or `None` if there's none.
      unpack: Forwarded as `unpack` to
        `function_utils.create_argument_unpacking_fn`; see its documentation.

    Returns:
      An instance of `computation_impl.ComputationImpl` wrapping the
      serialized computation and the context stack.
    """
    del fn_name  # Unused.
    stack = context_stack_impl.context_stack
    # Adapter that turns the single TFF argument into the call arguments
    # expected by `fn_to_wrap`.
    arg_unpacker = function_utils.create_argument_unpacking_fn(
        fn_to_wrap, parameter_type, unpack=unpack)
    serialized = jax_serialization.serialize_jax_computation(
        fn_to_wrap, arg_unpacker, parameter_type, stack)
    return computation_impl.ComputationImpl(serialized, stack)