Beispiel #1
0
 def test_create_xla_tff_computation(self):
     """Checks the proto produced by `create_xla_tff_computation`.

     Verifies the proto kind, the deserialized type signature, the HLO text
     of the unpacked module, and the parameter/result bindings.
     """
     source_comp = _make_test_xla_comp()
     function_type = computation_types.FunctionType(None, np.int32)
     comp_pb = xla_serialization.create_xla_tff_computation(
         source_comp, function_type)

     self.assertIsInstance(comp_pb, pb.Computation)
     self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')

     deserialized_type = type_serialization.deserialize_type(comp_pb.type)
     self.assertEqual(str(deserialized_type), '( -> int32)')

     unpacked_comp = xla_serialization.unpack_xla_computation(
         comp_pb.xla.hlo_module)
     hlo_text = unpacked_comp.as_hlo_text()
     self.assertIn('ROOT constant.1 = s32[] constant(10)', hlo_text)

     # The function type was built with no parameter, and the parameter
     # binding is expected to stringify to ''.
     self.assertEqual(str(comp_pb.xla.parameter), '')
     expected_result_binding = ('tensor {\n'
                                '  index: 0\n'
                                '}\n')
     self.assertEqual(str(comp_pb.xla.result), expected_result_binding)
  def test_serialize_jax_with_2xint32_to_2xint32(self):
    """Serializes a dict-to-dict JAX function and checks the resulting proto.

    The traced function takes `<foo=int32,bar=int32>` and returns
    `<sum=int32,difference=int32>`. The test verifies the proto kind, the
    deserialized type signature, the generated HLO instructions, and the
    parameter/result bindings.
    """
    ctx_stack = context_stack_impl.context_stack
    param_type = collections.OrderedDict([('foo', np.int32), ('bar', np.int32)])
    arg_func = lambda x: ([x], {})

    def traced_func(x):
      return collections.OrderedDict([('sum', x['foo'] + x['bar']),
                                      ('difference', x['bar'] - x['foo'])])

    comp_pb = jax_serialization.serialize_jax_computation(
        traced_func, arg_func, param_type, ctx_stack)
    self.assertIsInstance(comp_pb, pb.Computation)
    self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
    type_spec = type_serialization.deserialize_type(comp_pb.type)
    self.assertEqual(
        str(type_spec),
        '(<foo=int32,bar=int32> -> <sum=int32,difference=int32>)')
    xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
    # Only the instruction sequence is matched. The `HloModule`/`ENTRY` header
    # lines embed a generated module name that presumably varies with what
    # else has been traced in the process, so comparing the whole text for
    # equality is fragile.
    self.assertIn(
        # pylint: disable=line-too-long
        '  constant.4 = pred[] constant(false)\n'
        '  parameter.1 = (s32[], s32[]) parameter(0)\n'
        '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
        '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
        '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
        '  subtract.6 = s32[] subtract(get-tuple-element.3, get-tuple-element.2)\n'
        '  ROOT tuple.7 = (s32[], s32[]) tuple(add.5, subtract.6)\n',
        xla_comp.as_hlo_text())
    # Parameter and result are both two-element structs, so their bindings
    # are expected to print identically.
    self.assertEqual(str(comp_pb.xla.result), str(comp_pb.xla.parameter))
    self.assertEqual(
        str(comp_pb.xla.parameter), 'struct {\n'
        '  element {\n'
        '    tensor {\n'
        '      index: 0\n'
        '    }\n'
        '  }\n'
        '  element {\n'
        '    tensor {\n'
        '      index: 1\n'
        '    }\n'
        '  }\n'
        '}\n')
  def test_serialize_jax_with_int32_to_int32(self):
    """Serializes `x + 10` over an int32 and checks the resulting proto.

    The argument-unpacking function is derived from the traced function and
    the parameter type via `function_utils.create_argument_unpacking_fn`.
    """
    # NOTE(review): another method with this same name is defined later in
    # this file; if both belong to one class, the later definition replaces
    # this one -- confirm and rename if so.

    def traced_fn(x):
      return x + 10

    param_type = computation_types.to_type(np.int32)
    arg_fn = function_utils.create_argument_unpacking_fn(traced_fn, param_type)
    comp_pb = jax_serialization.serialize_jax_computation(
        traced_fn, arg_fn, param_type, context_stack_impl.context_stack)

    self.assertIsInstance(comp_pb, pb.Computation)
    self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')

    deserialized_type = type_serialization.deserialize_type(comp_pb.type)
    self.assertEqual(str(deserialized_type), '(int32 -> int32)')

    unpacked_comp = xla_serialization.unpack_xla_computation(
        comp_pb.xla.hlo_module)
    hlo_text = unpacked_comp.as_hlo_text()
    self.assertIn('ROOT tuple.6 = (s32[]) tuple(add.5)', hlo_text)

    # Input and output are both a single int32 tensor, so the two bindings
    # are expected to print identically.
    result_binding = str(comp_pb.xla.result)
    self.assertEqual(result_binding, str(comp_pb.xla.parameter))
    self.assertEqual(result_binding, 'tensor {\n'
                     '  index: 0\n'
                     '}\n')
  def test_serialize_jax_with_int32_to_int32(self):
    """Serializes `x + 10` over an int32 using a hand-written unpacking lambda.

    The unpacking function passes the whole argument as the single positional
    argument of the traced function.
    """
    # NOTE(review): an earlier method in this file has this same name; if both
    # belong to one class, this definition replaces the earlier one -- confirm
    # and rename if so.
    param_type = np.int32
    arg_func = lambda x: ([x], {})

    def traced_func(x):
      return x + 10

    comp_pb = jax_serialization.serialize_jax_computation(
        traced_func, arg_func, param_type, context_stack_impl.context_stack)

    self.assertIsInstance(comp_pb, pb.Computation)
    self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')

    deserialized_type = type_serialization.deserialize_type(comp_pb.type)
    self.assertEqual(str(deserialized_type), '(int32 -> int32)')

    unpacked_comp = xla_serialization.unpack_xla_computation(
        comp_pb.xla.hlo_module)
    hlo_text = unpacked_comp.as_hlo_text()
    self.assertIn('ROOT tuple.6 = (s32[]) tuple(add.5)', hlo_text)

    # Input and output are both a single int32 tensor, so the two bindings
    # are expected to print identically.
    result_binding = str(comp_pb.xla.result)
    self.assertEqual(result_binding, str(comp_pb.xla.parameter))
    self.assertEqual(result_binding, 'tensor {\n'
                     '  index: 0\n'
                     '}\n')
  def test_serialize_jax_with_two_args(self):
    """Serializes a two-argument JAX function and checks the resulting proto.

    The struct parameter `<a=int32,b=int32>` is unpacked by `arg_func` into
    the keyword arguments `x` and `y` of the traced function. The test
    verifies the proto kind, the deserialized type signature, the generated
    HLO instructions, and the parameter/result bindings.
    """
    ctx_stack = context_stack_impl.context_stack
    param_type = computation_types.StructType([('a', np.int32),
                                               ('b', np.int32)])
    arg_func = lambda arg: ([], {'x': arg[0], 'y': arg[1]})

    def traced_func(x, y):
      return x + y

    comp_pb = jax_serialization.serialize_jax_computation(
        traced_func, arg_func, param_type, ctx_stack)
    self.assertIsInstance(comp_pb, pb.Computation)
    self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
    type_spec = type_serialization.deserialize_type(comp_pb.type)
    self.assertEqual(str(type_spec), '(<a=int32,b=int32> -> int32)')
    xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
    # Only the instruction sequence is matched. The `HloModule`/`ENTRY` header
    # lines embed a generated module name (previously pinned here with a
    # `__3` suffix) that presumably varies with what else has been traced in
    # the process, so comparing the whole text for equality is fragile.
    self.assertIn(
        # pylint: disable=line-too-long
        '  constant.4 = pred[] constant(false)\n'
        '  parameter.1 = (s32[], s32[]) parameter(0)\n'
        '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
        '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
        '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
        '  ROOT tuple.6 = (s32[]) tuple(add.5)\n',
        xla_comp.as_hlo_text())
    self.assertEqual(
        str(comp_pb.xla.parameter), 'struct {\n'
        '  element {\n'
        '    tensor {\n'
        '      index: 0\n'
        '    }\n'
        '  }\n'
        '  element {\n'
        '    tensor {\n'
        '      index: 1\n'
        '    }\n'
        '  }\n'
        '}\n')
    self.assertEqual(str(comp_pb.xla.result), 'tensor {\n' '  index: 0\n' '}\n')
Beispiel #6
0
  def __init__(self, comp_pb: pb.Computation,
               type_spec: computation_types.FunctionType,
               backend: xla_client.Client):
    """Compiles the XLA computation carried by `comp_pb` on `backend`.

    Each argument is validated with `py_typecheck.check_type` before any
    other work is done.

    Args:
      comp_pb: The `pb.Computation` whose `xla.hlo_module` is unpacked and
        compiled.
      type_spec: The `computation_types.FunctionType` stored as this
        callable's type signature.
      backend: The `xla_client.Client` used for compilation; it is also
        retained on the instance.
    """
    # Validate in declaration order so the first offending argument is the
    # one reported.
    expected_types = (
        (comp_pb, pb.Computation),
        (type_spec, computation_types.FunctionType),
        (backend, xla_client.Client),
    )
    for value, expected in expected_types:
      py_typecheck.check_type(value, expected)

    computation = xla_serialization.unpack_xla_computation(
        comp_pb.xla.hlo_module)

    # Compile with `parameter_is_tupled_arguments` enabled.
    options = xla_client.CompileOptions()
    options.parameter_is_tupled_arguments = True

    self._executable = backend.compile(computation, options)
    self._type_signature = type_spec
    self._backend = backend
Beispiel #7
0
    def test_serialize_jax_with_two_args(self):
        """Serializes a two-argument JAX function and checks the proto.

        The parameter type is the struct `<x=int32,y=int32>`; the
        argument-unpacking function is derived from the traced function and
        that type via `function_utils.create_argument_unpacking_fn`.
        """

        def traced_fn(x, y):
            return x + y

        struct_fields = collections.OrderedDict([('x', np.int32),
                                                 ('y', np.int32)])
        param_type = computation_types.to_type(struct_fields)
        arg_fn = function_utils.create_argument_unpacking_fn(
            traced_fn, param_type)
        comp_pb = jax_serialization.serialize_jax_computation(
            traced_fn, arg_fn, param_type, context_stack_impl.context_stack)

        self.assertIsInstance(comp_pb, pb.Computation)
        self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')

        deserialized_type = type_serialization.deserialize_type(comp_pb.type)
        self.assertEqual(str(deserialized_type),
                         '(<x=int32,y=int32> -> int32)')

        unpacked_comp = xla_serialization.unpack_xla_computation(
            comp_pb.xla.hlo_module)
        # Only the instruction sequence is matched, not the module header.
        # pylint: disable=line-too-long
        expected_instructions = (
            '  constant.4 = pred[] constant(false)\n'
            '  parameter.1 = (s32[], s32[]) parameter(0)\n'
            '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
            '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
            '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
            '  ROOT tuple.6 = (s32[]) tuple(add.5)\n')
        self.assertIn(expected_instructions, unpacked_comp.as_hlo_text())

        expected_parameter_binding = ('struct {\n'
                                      '  element {\n'
                                      '    tensor {\n'
                                      '      index: 0\n'
                                      '    }\n'
                                      '  }\n'
                                      '  element {\n'
                                      '    tensor {\n'
                                      '      index: 1\n'
                                      '    }\n'
                                      '  }\n'
                                      '}\n')
        self.assertEqual(str(comp_pb.xla.parameter),
                         expected_parameter_binding)

        expected_result_binding = ('tensor {\n'
                                   '  index: 0\n'
                                   '}\n')
        self.assertEqual(str(comp_pb.xla.result), expected_result_binding)
Beispiel #8
0
 def test_pack_unpack_xla_computation_roundtrip(self):
     """Packing then unpacking an XLA computation preserves its HLO text."""
     original_comp = _make_test_xla_comp()
     packed = xla_serialization.pack_xla_computation(original_comp)
     restored_comp = xla_serialization.unpack_xla_computation(packed)

     restored_text = restored_comp.as_hlo_text()
     original_text = original_comp.as_hlo_text()
     self.assertEqual(restored_text, original_text)