コード例 #1
0
    def test_serialize_jax_with_int32_to_int32(self):
        """Serializes `x + 10` over an int32 scalar and inspects the XLA proto."""
        self.skipTest('HLO pattern matching broken by '
                      'https://github.com/google/jax/pull/10232')

        def traced_fn(x):
            return x + 10

        int32_type = computation_types.to_type(np.int32)
        unpacker = function_utils.create_argument_unpacking_fn(
            traced_fn, int32_type)
        serialized = jax_serialization.serialize_jax_computation(
            traced_fn, unpacker, int32_type, context_stack_impl.context_stack)

        # The serialized computation must be an XLA computation whose
        # signature round-trips to a scalar int32 function.
        self.assertIsInstance(serialized, pb.Computation)
        self.assertEqual(serialized.WhichOneof('computation'), 'xla')
        signature = type_serialization.deserialize_type(serialized.type)
        self.assertEqual(str(signature), '(int32 -> int32)')

        hlo_text = xla_serialization.unpack_xla_computation(
            serialized.xla.hlo_module).as_hlo_text()
        self.assertIn('ROOT tuple.6 = (s32[]) tuple(add.5)', hlo_text)

        # Parameter and result bindings are both a single tensor at index 0.
        xla_proto = serialized.xla
        self.assertEqual(str(xla_proto.result), str(xla_proto.parameter))
        self.assertEqual(str(xla_proto.result),
                         'tensor {\n'
                         '  index: 0\n'
                         '}\n')
コード例 #2
0
 def test_wrap_as_zero_or_one_arg_callable(self, fn, parameter_type, unpack,
                                           arg, expected_result):
     """Unpacks `arg` per `parameter_type` and checks what `fn` returns."""
     tff_type = computation_types.to_type(parameter_type)
     unpacker = function_utils.create_argument_unpacking_fn(
         fn, tff_type, unpack)
     positional_args, keyword_args = unpacker(arg)
     self.assertEqual(fn(*positional_args, **keyword_args), expected_result)
コード例 #3
0
def _federated_computation_serializer(fn, parameter_name, parameter_type):
  """Traces `fn` by driving the federated-computation serializer generator.

  Args:
    fn: The Python function to trace.
    parameter_name: Forwarded to
      `federated_computation_utils.federated_computation_serializer`.
    parameter_type: Forwarded both to
      `function_utils.create_argument_unpacking_fn` and to the serializer.

  Returns:
    Whatever the serializer generator produces when sent `fn`'s result.
  """
  unpack = function_utils.create_argument_unpacking_fn(fn, parameter_type)
  serializer = federated_computation_utils.federated_computation_serializer(
      parameter_name, parameter_type, context_stack_impl.context_stack)
  # The serializer's first yielded value is what gets unpacked into `fn`'s
  # positional and keyword arguments.
  positional, keyword = unpack(next(serializer))
  traced_result = fn(*positional, **keyword)
  return serializer.send(traced_result)
コード例 #4
0
    def test_nested_structure_type_signature_roundtrip(self):
        """Checks the signature of a computation over a nested struct type."""

        def traced_fn(x):
            return x[0][0]

        nested_type = computation_types.to_type([(np.int32, )])
        unpacker = function_utils.create_argument_unpacking_fn(
            traced_fn, nested_type)
        serialized = jax_serialization.serialize_jax_computation(
            traced_fn, unpacker, nested_type, context_stack_impl.context_stack)

        self.assertIsInstance(serialized, pb.Computation)
        self.assertEqual(serialized.WhichOneof('computation'), 'xla')
        roundtripped = type_serialization.deserialize_type(serialized.type)
        self.assertEqual(str(roundtripped), '(<<int32>> -> int32)')
コード例 #5
0
    def test_tracing_with_float64_input(self):
        """Traces an identity function over a float64 scalar parameter.

        The declared parameter dtype must agree with the expected signature
        `(float64 -> float64)` asserted below.
        """
        self.skipTest('b/237566862')

        # float64, not int64: the test name and the expected type signature
        # both describe a float64 -> float64 identity computation.
        param_type = computation_types.TensorType(np.float64)
        identity_fn = lambda x: x
        arg_fn = function_utils.create_argument_unpacking_fn(
            identity_fn, param_type)
        ctx_stack = context_stack_impl.context_stack
        comp_pb = jax_serialization.serialize_jax_computation(
            identity_fn, arg_fn, param_type, ctx_stack)
        self.assertIsInstance(comp_pb, pb.Computation)
        self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
        type_spec = type_serialization.deserialize_type(comp_pb.type)
        self.assertEqual(str(type_spec), '(float64 -> float64)')
コード例 #6
0
    def test_serialize_jax_with_2xint32_to_2xint32(self):
        """Serializes a struct-in, struct-out function and inspects the XLA proto."""
        self.skipTest('HLO pattern matching broken by '
                      'https://github.com/google/jax/pull/10232')

        def traced_fn(x):
            # Evaluation order (add, then subtract) matches the HLO below.
            total = x['foo'] + x['bar']
            delta = x['bar'] - x['foo']
            return collections.OrderedDict([('sum', total),
                                            ('difference', delta)])

        struct_type = computation_types.to_type(
            collections.OrderedDict([('foo', np.int32), ('bar', np.int32)]))
        unpacker = function_utils.create_argument_unpacking_fn(
            traced_fn, struct_type)
        serialized = jax_serialization.serialize_jax_computation(
            traced_fn, unpacker, struct_type, context_stack_impl.context_stack)

        self.assertIsInstance(serialized, pb.Computation)
        self.assertEqual(serialized.WhichOneof('computation'), 'xla')
        signature = type_serialization.deserialize_type(serialized.type)
        self.assertEqual(
            str(signature),
            '(<foo=int32,bar=int32> -> <sum=int32,difference=int32>)')

        # pylint: disable=line-too-long
        expected_hlo = (
            '  constant.4 = pred[] constant(false)\n'
            '  parameter.1 = (s32[], s32[]) parameter(0)\n'
            '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
            '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
            '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
            '  subtract.6 = s32[] subtract(get-tuple-element.3, get-tuple-element.2)\n'
            '  ROOT tuple.7 = (s32[], s32[]) tuple(add.5, subtract.6)\n')
        # pylint: enable=line-too-long
        hlo_module = xla_serialization.unpack_xla_computation(
            serialized.xla.hlo_module)
        self.assertIn(expected_hlo, hlo_module.as_hlo_text())

        # Both the parameter and the result bind a two-element struct of
        # tensors at indices 0 and 1.
        xla_proto = serialized.xla
        self.assertEqual(str(xla_proto.result), str(xla_proto.parameter))
        expected_binding = ('struct {\n'
                            '  element {\n'
                            '    tensor {\n'
                            '      index: 0\n'
                            '    }\n'
                            '  }\n'
                            '  element {\n'
                            '    tensor {\n'
                            '      index: 1\n'
                            '    }\n'
                            '  }\n'
                            '}\n')
        self.assertEqual(str(xla_proto.parameter), expected_binding)
コード例 #7
0
  def test_arg_ordering(self):
    """Checks the signature when the first parameter is a vector, the second a scalar."""
    vector_then_scalar = computation_types.to_type(
        (computation_types.TensorType(np.int32, 10),
         computation_types.TensorType(np.int32)))

    def traced_fn(b, a):
      return jax.numpy.add(a, jax.numpy.sum(b))

    unpacker = function_utils.create_argument_unpacking_fn(
        traced_fn, vector_then_scalar)
    serialized = jax_serialization.serialize_jax_computation(
        traced_fn, unpacker, vector_then_scalar,
        context_stack_impl.context_stack)

    self.assertIsInstance(serialized, pb.Computation)
    self.assertEqual(serialized.WhichOneof('computation'), 'xla')
    signature = type_serialization.deserialize_type(serialized.type)
    self.assertEqual(str(signature), '(<int32[10],int32> -> int32)')
コード例 #8
0
    def test_serialize_jax_with_nested_struct_arg(self):
        """Serializes a two-argument function whose first argument is a struct."""

        def traced_fn(x, y):
            return x[0] + y

        inner_struct = computation_types.StructType([(None, np.int32)])
        outer_struct = computation_types.StructType([
            (None, inner_struct),
            (None, np.int32),
        ])
        unpacker = function_utils.create_argument_unpacking_fn(
            traced_fn, outer_struct)
        serialized = jax_serialization.serialize_jax_computation(
            traced_fn, unpacker, outer_struct, context_stack_impl.context_stack)

        self.assertIsInstance(serialized, pb.Computation)
        self.assertEqual(serialized.WhichOneof('computation'), 'xla')
        signature = type_serialization.deserialize_type(serialized.type)
        self.assertEqual(str(signature), '(<<int32>,int32> -> int32)')
コード例 #9
0
  def test_serialize_jax_with_noarg_to_int32(self):
    """Serializes a parameterless function that returns a constant."""

    def traced_fn():
      return 10

    # A `None` parameter type means the computation takes no argument.
    unpacker = function_utils.create_argument_unpacking_fn(traced_fn, None)
    serialized = jax_serialization.serialize_jax_computation(
        traced_fn, unpacker, None, context_stack_impl.context_stack)

    self.assertIsInstance(serialized, pb.Computation)
    self.assertEqual(serialized.WhichOneof('computation'), 'xla')
    signature = type_serialization.deserialize_type(serialized.type)
    self.assertEqual(str(signature), '( -> int32)')

    hlo_text = xla_serialization.unpack_xla_computation(
        serialized.xla.hlo_module).as_hlo_text()
    self.assertIn('ROOT tuple.4 = (s32[]) tuple(constant.3)', hlo_text)

    xla_proto = serialized.xla
    self.assertEqual(str(xla_proto.parameter), '')
    self.assertEqual(str(xla_proto.result),
                     'tensor {\n'
                     '  index: 0\n'
                     '}\n')
コード例 #10
0
    def test_serialize_jax_with_two_args(self):
        """Serializes `x + y` over two named int32 parameters."""

        def traced_fn(x, y):
            return x + y

        named_pair = computation_types.to_type(
            collections.OrderedDict([('x', np.int32), ('y', np.int32)]))
        unpacker = function_utils.create_argument_unpacking_fn(
            traced_fn, named_pair)
        serialized = jax_serialization.serialize_jax_computation(
            traced_fn, unpacker, named_pair, context_stack_impl.context_stack)

        self.assertIsInstance(serialized, pb.Computation)
        self.assertEqual(serialized.WhichOneof('computation'), 'xla')
        signature = type_serialization.deserialize_type(serialized.type)
        self.assertEqual(str(signature), '(<x=int32,y=int32> -> int32)')

        # pylint: disable=line-too-long
        expected_hlo = (
            '  constant.4 = pred[] constant(false)\n'
            '  parameter.1 = (s32[], s32[]) parameter(0)\n'
            '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
            '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
            '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
            '  ROOT tuple.6 = (s32[]) tuple(add.5)\n')
        # pylint: enable=line-too-long
        hlo_module = xla_serialization.unpack_xla_computation(
            serialized.xla.hlo_module)
        self.assertIn(expected_hlo, hlo_module.as_hlo_text())

        # The parameter binds a two-element struct; the result a single tensor.
        xla_proto = serialized.xla
        expected_parameter = ('struct {\n'
                              '  element {\n'
                              '    tensor {\n'
                              '      index: 0\n'
                              '    }\n'
                              '  }\n'
                              '  element {\n'
                              '    tensor {\n'
                              '      index: 1\n'
                              '    }\n'
                              '  }\n'
                              '}\n')
        self.assertEqual(str(xla_proto.parameter), expected_parameter)
        self.assertEqual(str(xla_proto.result),
                         'tensor {\n'
                         '  index: 0\n'
                         '}\n')
コード例 #11
0
def _jax_strategy_fn(fn_to_wrap, fn_name, parameter_type, unpack):
    """Serializes a Python function containing JAX code as a TFF computation.

    Args:
      fn_to_wrap: The Python function containing JAX code to be serialized as
        a computation containing XLA.
      fn_name: The name for the constructed computation (currently ignored).
      parameter_type: An instance of `computation_types.Type` that represents
        the TFF type of the computation parameter, or `None` if there's none.
      unpack: See `unpack` in `function_utils.create_argument_unpacking_fn`.

    Returns:
      An instance of `computation_impl.ComputationImpl` wrapping the
      serialized computation.
    """
    del fn_name  # Unused.
    context_stack = context_stack_impl.context_stack
    argument_unpacker = function_utils.create_argument_unpacking_fn(
        fn_to_wrap, parameter_type, unpack=unpack)
    serialized = jax_serialization.serialize_jax_computation(
        fn_to_wrap, argument_unpacker, parameter_type, context_stack)
    return computation_impl.ComputationImpl(serialized, context_stack)
コード例 #12
0
 def __call__(self, fn_to_wrap, fn_name, parameter_type, unpack):
     """Traces `fn_to_wrap` by driving the `_wrap_concrete` generator.

     Args:
       fn_to_wrap: The Python function to trace.
       fn_name: Forwarded to `_wrap_concrete`.
       parameter_type: Forwarded to both
         `function_utils.create_argument_unpacking_fn` and `_wrap_concrete`.
       unpack: Forwarded as `unpack` to
         `function_utils.create_argument_unpacking_fn`.

     Returns:
       The value the `_wrap_concrete` generator produces when it is sent the
       result of calling `fn_to_wrap`.

     Raises:
       ComputationReturnedNoneError: If `fn_to_wrap` returns `None`.
       Any exception raised while unpacking arguments or calling
       `fn_to_wrap` is re-raised unchanged, after the generator has been
       given a chance to clean up.
     """
     unpack_arguments_fn = function_utils.create_argument_unpacking_fn(
         fn_to_wrap, parameter_type, unpack=unpack)
     wrapped_fn_generator = _wrap_concrete(fn_name, self._wrapper_fn,
                                           parameter_type)
     # The generator's first yielded value is the packed argument(s) for
     # `fn_to_wrap`. This call is outside the `try`, so errors raised here do
     # not go through the cleanup path below.
     packed_args = next(wrapped_fn_generator)
     try:
         args, kwargs = unpack_arguments_fn(packed_args)
         result = fn_to_wrap(*args, **kwargs)
         if result is None:
             raise ComputationReturnedNoneError(fn_to_wrap)
     except Exception:
         # Give nested generators an opportunity to clean up, then
         # re-raise the original error without extra context.
         # We don't want to simply pass the error into the generators,
         # as that would result in the whole generator stack being added
         # to the error message.
         # NOTE(review): if the generator catches `_TracingError` and yields
         # again, the yielded value is discarded. If it raises some other
         # exception, that exception propagates instead of the original one.
         try:
             wrapped_fn_generator.throw(_TracingError())
         except _TracingError:
             pass
         # Bare `raise` re-raises the exception caught by the outer handler.
         raise
     return wrapped_fn_generator.send(result)