示例#1
0
 def test_create_xla_tff_computation_with_reordered_tensor_indexes(self):
     """Wraps one XLA computation under two different parameter orderings.

     The XLA computation takes an int32[10,1] parameter (position 0) and an
     int32[1,20] parameter (position 1) and returns their `Dot`. It is
     serialized once with binding `[0, 1]` and once with `[1, 0]`; each time
     the TFF parameter type lists the tensors in the order the binding implies.
     """
     int32_etype = xla_client.dtype_to_etype(np.int32)
     xla_builder = xla_client.XlaBuilder('comp')
     column = xla_client.ops.Parameter(
         xla_builder, 0, xla_client.Shape.array_shape(int32_etype, (10, 1)))
     row = xla_client.ops.Parameter(
         xla_builder, 1, xla_client.Shape.array_shape(int32_etype, (1, 20)))
     xla_client.ops.Dot(column, row)
     xla_comp = xla_builder.build()

     column_type = (np.int32, (10, 1))
     row_type = (np.int32, (1, 20))
     product_type = (np.int32, (10, 20))
     # (binding, TFF parameter types in binding order, expected signature)
     expectations = (
         ([0, 1], (column_type, row_type),
          '(<int32[10,1],int32[1,20]> -> int32[10,20])'),
         ([1, 0], (row_type, column_type),
          '(<int32[1,20],int32[10,1]> -> int32[10,20])'),
     )
     for tensor_indexes, parameter_types, expected_signature in expectations:
         comp_pb = xla_serialization.create_xla_tff_computation(
             xla_comp, tensor_indexes,
             computation_types.FunctionType(parameter_types, product_type))
         self.assertIsInstance(comp_pb, pb.Computation)
         self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
         self.assertEqual(
             str(type_serialization.deserialize_type(comp_pb.type)),
             expected_signature)
示例#2
0
    def create_scalar_multiply_operator(
        self, operand_type: computation_types.Type,
        scalar_type: computation_types.TensorType
    ) -> local_computation_factory_base.ComputationProtoAndType:
        """Builds an XLA computation multiplying every operand tensor by a scalar.

        The result has TFF type `(<operand_type, scalar_type> -> operand_type)`.
        XLA receives one flat tuple parameter: the flattened operand tensors,
        followed by the scalar in the last slot. Each operand tensor is `Mul`-ed
        with that scalar.

        Args:
          operand_type: A `computation_types.Type` that is a tensor or a
            structure of tensors.
          scalar_type: A `computation_types.TensorType` for the multiplier.

        Returns:
          A `(computation_proto, computation_type)` pair.

        Raises:
          TypeError: presumably, via `py_typecheck.check_type`, if the arguments
            are not of the expected types.
          ValueError: if `operand_type` is not a tensor or a structure of
            tensors.
        """
        py_typecheck.check_type(operand_type, computation_types.Type)
        py_typecheck.check_type(scalar_type, computation_types.TensorType)
        if not type_analysis.is_structure_of_tensors(operand_type):
            raise ValueError(
                'Not a tensor or a structure of tensors: {}'.format(
                    str(operand_type)))

        flat_shapes = _xla_tensor_shape_list_from_from_tff_tensor_or_struct_type(
            operand_type)
        multiplier_shape = _xla_tensor_shape_from_tff_tensor_type(scalar_type)
        tensor_count = len(flat_shapes)

        xla_builder = xla_client.XlaBuilder('comp')
        # Single tuple parameter: operand tensors first, scalar last.
        arg_tuple = xla_client.ops.Parameter(
            xla_builder, 0,
            xla_client.Shape.tuple_shape(flat_shapes + [multiplier_shape]))
        multiplier = xla_client.ops.GetTupleElement(arg_tuple, tensor_count)
        products = [
            xla_client.ops.Mul(xla_client.ops.GetTupleElement(arg_tuple, i),
                               multiplier)
            for i in range(tensor_count)
        ]
        xla_client.ops.Tuple(xla_builder, products)
        xla_computation = xla_builder.build()

        parameter_struct = computation_types.StructType([(None, operand_type),
                                                         (None, scalar_type)])
        comp_type = computation_types.FunctionType(parameter_struct,
                                                   operand_type)
        # Identity binding over all operand tensors plus the trailing scalar.
        comp_pb = xla_serialization.create_xla_tff_computation(
            xla_computation, list(range(tensor_count + 1)), comp_type)
        return comp_pb, comp_type
示例#3
0
    def test_add_numbers(self):
        """Runs an int32 adder through `XlaExecutor` and checks 20 + 30 == 50."""
        int32_pair = tuple([np.array(0, dtype=np.int32)] * 2)
        xla_builder = xla_client.XlaBuilder('comp')
        arg_tuple = xla_client.ops.Parameter(
            xla_builder, 0, xla_client.shape_from_pyval(int32_pair))
        lhs = xla_client.ops.GetTupleElement(arg_tuple, 0)
        rhs = xla_client.ops.GetTupleElement(arg_tuple, 1)
        xla_client.ops.Add(lhs, rhs)
        adder = xla_builder.build()
        fn_type = computation_types.FunctionType((np.int32, np.int32),
                                                 np.int32)
        fn_proto = xla_serialization.create_xla_tff_computation(
            adder, [0, 1], fn_type)
        xla_executor = executor.XlaExecutor()

        async def _add_via_executor():
            fn_value = await xla_executor.create_value(fn_proto, fn_type)
            first = await xla_executor.create_value(20, np.int32)
            second = await xla_executor.create_value(30, np.int32)
            operands = await xla_executor.create_struct([first, second])
            sum_value = await xla_executor.create_call(fn_value, operands)
            return await sum_value.compute()

        loop = asyncio.get_event_loop()
        self.assertEqual(loop.run_until_complete(_add_via_executor()), 50)
 def test_create_xla_tff_computation_int32x10_to_int32x10(self):
   """Wraps an int32[10] -> int32[10] XLA computation with binding `[0]`."""
   vector_type = (np.int32, (10,))
   comp_pb = xla_serialization.create_xla_tff_computation(
       _make_test_xla_comp_int32x10_to_int32x10(), [0],
       computation_types.FunctionType(vector_type, vector_type))
   self.assertIsInstance(comp_pb, pb.Computation)
   self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
   self.assertEqual(
       str(type_serialization.deserialize_type(comp_pb.type)),
       '(int32[10] -> int32[10])')
 def test_create_xla_tff_computation_noarg(self):
   """Wraps a no-arg XLA computation and inspects the serialized proto."""
   noarg_comp = _make_test_xla_comp_noarg_to_int32()
   comp_pb = xla_serialization.create_xla_tff_computation(
       noarg_comp, [], computation_types.FunctionType(None, np.int32))
   self.assertIsInstance(comp_pb, pb.Computation)
   self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
   deserialized = type_serialization.deserialize_type(comp_pb.type)
   self.assertEqual(str(deserialized), '( -> int32)')
   xla_proto = comp_pb.xla
   hlo_text = xla_serialization.unpack_xla_computation(
       xla_proto.hlo_module).as_hlo_text()
   self.assertIn('ROOT constant.1 = s32[] constant(10)', hlo_text)
   # No parameter binding is recorded; the result binds output tensor 0.
   self.assertEqual(str(xla_proto.parameter), '')
   self.assertEqual(str(xla_proto.result), 'tensor {\n' '  index: 0\n' '}\n')
 def test_set_local_execution_context_and_run_simple_xla_computation(self):
     """Installs the local execution context and invokes a constant-10 comp."""
     xla_builder = xla_client.XlaBuilder('comp')
     # No-arg computations still declare one (empty) tuple parameter.
     empty_tuple_shape = xla_client.shape_from_pyval(tuple())
     xla_client.ops.Parameter(xla_builder, 0, empty_tuple_shape)
     xla_client.ops.Constant(xla_builder, np.int32(10))
     constant_comp = xla_builder.build()
     fn_type = computation_types.FunctionType(None, np.int32)
     fn_proto = xla_serialization.create_xla_tff_computation(
         constant_comp, [], fn_type)
     wrapped = computation_impl.ComputationImpl(
         fn_proto, context_stack_impl.context_stack)
     execution_contexts.set_local_execution_context()
     self.assertEqual(wrapped(), 10)
示例#7
0
 def test_to_representation_for_type_with_noarg_to_int32_comp(self):
     """A no-arg constant-10 computation becomes a callable returning 10."""
     xla_builder = xla_client.XlaBuilder('comp')
     xla_client.ops.Parameter(
         xla_builder, 0, xla_client.shape_from_pyval(tuple()))
     xla_client.ops.Constant(xla_builder, np.int32(10))
     constant_comp = xla_builder.build()
     fn_type = computation_types.FunctionType(None, np.int32)
     fn_proto = xla_serialization.create_xla_tff_computation(
         constant_comp, [], fn_type)
     representation = executor.to_representation_for_type(
         fn_proto, fn_type, self._backend)
     self.assertTrue(callable(representation))
     self.assertEqual(representation(), 10)
示例#8
0
 def test_computation_callable_return_one_number(self):
     """Wraps a constant-10 computation in `runtime.ComputationCallable`."""
     xla_builder = xla_client.XlaBuilder('comp')
     xla_client.ops.Parameter(
         xla_builder, 0, xla_client.shape_from_pyval(tuple()))
     xla_client.ops.Constant(xla_builder, np.int32(10))
     constant_comp = xla_builder.build()
     fn_type = computation_types.FunctionType(None, np.int32)
     fn_proto = xla_serialization.create_xla_tff_computation(
         constant_comp, [], fn_type)
     local_backend = xla_client.get_local_backend(None)
     fn_callable = runtime.ComputationCallable(fn_proto, fn_type,
                                               local_backend)
     self.assertIsInstance(fn_callable, runtime.ComputationCallable)
     self.assertEqual(str(fn_callable.type_signature), '( -> int32)')
     self.assertEqual(fn_callable(), 10)
示例#9
0
 def test_to_representation_for_type_with_2xint32_to_int32_comp(self):
     """An int32 adder's representation accepts a 2-element Struct argument."""
     int32_pair = tuple([np.array(0, dtype=np.int32)] * 2)
     xla_builder = xla_client.XlaBuilder('comp')
     arg_tuple = xla_client.ops.Parameter(
         xla_builder, 0, xla_client.shape_from_pyval(int32_pair))
     lhs = xla_client.ops.GetTupleElement(arg_tuple, 0)
     rhs = xla_client.ops.GetTupleElement(arg_tuple, 1)
     xla_client.ops.Add(lhs, rhs)
     adder = xla_builder.build()
     fn_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
     fn_proto = xla_serialization.create_xla_tff_computation(
         adder, [0, 1], fn_type)
     representation = executor.to_representation_for_type(
         fn_proto, fn_type, self._backend)
     self.assertTrue(callable(representation))
     operands = structure.Struct([(None, np.int32(v)) for v in (20, 30)])
     self.assertEqual(representation(operands), 50)
示例#10
0
 def test_to_representation_for_type_with_noarg_to_2xint32_comp(self):
     """A no-arg computation returning a named pair becomes a callable.

     The XLA computation takes an empty tuple parameter and returns the tuple
     (10, 20); the TFF result type names those elements `a` and `b`.
     """
     builder = xla_client.XlaBuilder('comp')
     # No-arg computations still declare one (empty) tuple parameter.
     xla_client.ops.Parameter(builder, 0,
                              xla_client.shape_from_pyval(tuple()))
     xla_client.ops.Tuple(builder, [
         xla_client.ops.Constant(builder, np.int32(10)),
         xla_client.ops.Constant(builder, np.int32(20))
     ])
     xla_comp = builder.build()
     comp_type = computation_types.FunctionType(
         None,
         computation_types.StructType([('a', np.int32), ('b', np.int32)]))
     # `tensor_indexes` binds the *parameter's* tensors. This computation has
     # no parameter, so the list is empty (as in the other no-arg cases).
     comp_pb = xla_serialization.create_xla_tff_computation(
         xla_comp, [], comp_type)
     rep = executor.to_representation_for_type(comp_pb, comp_type,
                                               self._backend)
     self.assertTrue(callable(rep))
     result = rep()
     self.assertEqual(str(result), '<a=10,b=20>')
示例#11
0
def _create_xla_binary_op_computation(type_spec, xla_binary_op_constructor):
    """Builds an XLA computation applying a binary op tensor-by-tensor.

    The constructed computation has TFF type `(<T,T> -> T)`, where `T` is
    `type_spec`. XLA receives both operands in one flat tuple parameter: the
    tensors of the first operand, followed by the tensors of the second. The op
    is applied to each pair of corresponding tensors, and the results are
    returned as a flat tuple.

    Args:
      type_spec: The type of a single operand; must be a tensor or a structure
        of tensors.
      xla_binary_op_constructor: A two-argument callable that builds a binary
        XLA op from two tensor operands (such as `xla_client.ops.Add`).

    Returns:
      An instance of `local_computation_factory_base.ComputationProtoAndType`.

    Raises:
      TypeError: presumably, via `py_typecheck.check_type`, if `type_spec` is
        not a `computation_types.Type`.
      ValueError: if `type_spec` is not a tensor or a structure of tensors.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if not type_analysis.is_structure_of_tensors(type_spec):
        raise ValueError('Not a tensor or a structure of tensors: {}'.format(
            str(type_spec)))

    flat_shapes = _xla_tensor_shape_list_from_from_tff_tensor_or_struct_type(
        type_spec)
    count = len(flat_shapes)
    xla_builder = xla_client.XlaBuilder('comp')
    # Layout: [lhs_0 .. lhs_{count-1}, rhs_0 .. rhs_{count-1}].
    operands = xla_client.ops.Parameter(
        xla_builder, 0, xla_client.Shape.tuple_shape(flat_shapes + flat_shapes))
    elementwise = [
        xla_binary_op_constructor(
            xla_client.ops.GetTupleElement(operands, i),
            xla_client.ops.GetTupleElement(operands, count + i))
        for i in range(count)
    ]
    xla_client.ops.Tuple(xla_builder, elementwise)
    xla_computation = xla_builder.build()

    operand_pair = computation_types.StructType([(None, type_spec),
                                                 (None, type_spec)])
    comp_type = computation_types.FunctionType(operand_pair, type_spec)
    comp_pb = xla_serialization.create_xla_tff_computation(
        xla_computation, list(range(count * 2)), comp_type)
    return comp_pb, comp_type
示例#12
0
    def create_constant_from_scalar(
        self, value, type_spec: computation_types.Type
    ) -> local_computation_factory_base.ComputationProtoAndType:
        """Builds a no-argument XLA computation returning filled constants.

        Every tensor in `type_spec` becomes an XLA constant with that tensor's
        shape and dtype, with all elements set to `value` (via `np.full`). The
        resulting computation has type `( -> type_spec)`.

        Args:
          value: The fill value. Presumably a scalar convertible to each
            tensor's dtype; this is not checked here.
          type_spec: A `computation_types.Type` that is a tensor or a structure
            of tensors.

        Returns:
          A `(computation_proto, computation_type)` pair.

        Raises:
          TypeError: presumably, via `py_typecheck.check_type`, if `type_spec`
            is not a `computation_types.Type` or a leaf is not a `TensorType`.
          ValueError: if `type_spec` is not a tensor or a structure of tensors.
        """
        py_typecheck.check_type(type_spec, computation_types.Type)
        if not type_analysis.is_structure_of_tensors(type_spec):
            raise ValueError(
                'Not a tensor or a structure of tensors: {}'.format(
                    str(type_spec)))

        xla_builder = xla_client.XlaBuilder('comp')

        # By convention, XLA arguments always arrive as a single tuple (see the
        # comments in `computation.proto`); with no arguments it is empty.
        xla_client.ops.Parameter(xla_builder, 0,
                                 xla_client.shape_from_pyval(tuple()))

        def _filled_constant(leaf_type):
            py_typecheck.check_type(leaf_type, computation_types.TensorType)
            # NOTE(review): `shape.dims` is passed straight to `np.full`, so
            # this assumes a fully defined shape.
            filled = np.full(shape=leaf_type.shape.dims,
                             fill_value=value,
                             dtype=leaf_type.dtype.as_numpy_dtype)
            return xla_client.ops.Constant(xla_builder, filled)

        leaf_types = ([type_spec]
                      if isinstance(type_spec, computation_types.TensorType)
                      else structure.flatten(type_spec))
        constants = [_filled_constant(leaf) for leaf in leaf_types]

        # Results likewise come back as one flat tuple; the nested TFF
        # structure is restored by the binding, not by XLA.
        xla_client.ops.Tuple(xla_builder, constants)
        xla_computation = xla_builder.build()

        comp_type = computation_types.FunctionType(None, type_spec)
        comp_pb = xla_serialization.create_xla_tff_computation(
            xla_computation, [], comp_type)
        return comp_pb, comp_type
示例#13
0
 def test_computation_callable_add_two_numbers(self):
     """Wraps an int32 adder in `runtime.ComputationCallable` and adds 2 + 3."""
     int32_pair = tuple([np.array(0, dtype=np.int32)] * 2)
     xla_builder = xla_client.XlaBuilder('comp')
     arg_tuple = xla_client.ops.Parameter(
         xla_builder, 0, xla_client.shape_from_pyval(int32_pair))
     lhs = xla_client.ops.GetTupleElement(arg_tuple, 0)
     rhs = xla_client.ops.GetTupleElement(arg_tuple, 1)
     xla_client.ops.Add(lhs, rhs)
     adder = xla_builder.build()
     fn_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
     fn_proto = xla_serialization.create_xla_tff_computation(
         adder, [0, 1], fn_type)
     local_backend = xla_client.get_local_backend(None)
     fn_callable = runtime.ComputationCallable(fn_proto, fn_type,
                                               local_backend)
     self.assertIsInstance(fn_callable, runtime.ComputationCallable)
     self.assertEqual(str(fn_callable.type_signature),
                      '(<int32,int32> -> int32)')
     operands = structure.Struct([(None, np.int32(v)) for v in (2, 3)])
     self.assertEqual(fn_callable(operands), 5)
示例#14
0
 def test_create_and_invoke_noarg_comp_returning_int32(self):
     """Creates, calls and computes a no-arg constant-10 value in XlaExecutor."""
     xla_builder = xla_client.XlaBuilder('comp')
     xla_client.ops.Parameter(
         xla_builder, 0, xla_client.shape_from_pyval(tuple()))
     xla_client.ops.Constant(xla_builder, np.int32(10))
     constant_comp = xla_builder.build()
     fn_type = computation_types.FunctionType(None, np.int32)
     fn_proto = xla_serialization.create_xla_tff_computation(
         constant_comp, [], fn_type)
     xla_executor = executor.XlaExecutor()
     loop = asyncio.get_event_loop()

     # The created value wraps a directly callable internal representation.
     fn_value = loop.run_until_complete(
         xla_executor.create_value(fn_proto, fn_type))
     self.assertIsInstance(fn_value, executor.XlaValue)
     self.assertEqual(str(fn_value.type_signature), str(fn_type))
     direct_fn = fn_value.internal_representation
     self.assertTrue(callable(direct_fn))
     self.assertEqual(direct_fn(), 10)

     # Invoking it through the executor yields an int32 value computing to 10.
     call_value = loop.run_until_complete(xla_executor.create_call(fn_value))
     self.assertIsInstance(call_value, executor.XlaValue)
     self.assertEqual(str(call_value.type_signature), 'int32')
     self.assertEqual(loop.run_until_complete(call_value.compute()), 10)
示例#15
0
def serialize_jax_computation(traced_fn, arg_fn, parameter_type,
                              context_stack):
    """Serializes a Python function containing JAX code as a TFF computation.

    Args:
      traced_fn: The Python function containing JAX code to be traced by JAX
        and serialized as a TFF computation containing XLA code.
      arg_fn: An unpacking function that takes a TFF argument, and returns a
        combo of (args, kwargs) to invoke `traced_fn` with (e.g., as the one
        constructed by `function_utils.create_argument_unpacking_fn`).
      parameter_type: An instance of `computation_types.Type` that represents
        the TFF type of the computation parameter, or `None` if the function
        does not take any parameters.
      context_stack: The context stack to use during serialization.

    Returns:
      An instance of `pb.Computation` with the constructed computation.

    Raises:
      TypeError: if the arguments are of the wrong types.
    """
    py_typecheck.check_callable(traced_fn)
    py_typecheck.check_callable(arg_fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

    if parameter_type is not None:
        parameter_type = computation_types.to_type(parameter_type)
        packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)
    else:
        packed_arg = None

    args, kwargs = arg_fn(packed_arg)

    # While the fake parameters are fed via args/kwargs during serialization,
    # they may get reordered in the generated XLA code. We use the same
    # flattening function that the JAX serializer uses to determine the
    # ordering, so the ordering can be captured in the parameter binding.
    # Nothing special is needed for the results: multiple results are always
    # returned as a tuple.
    flattened_obj, _ = jax.tree_util.tree_flatten((args, kwargs))
    # `flattened_obj[j].tensor_index` is the TFF tensor that lands in XLA
    # parameter position `j`. Assuming those indexes are 0..n-1, `argsort`
    # inverts the permutation into "TFF tensor i -> XLA position", which is the
    # binding order. `.tolist()` yields plain Python ints; iterating the array
    # would give NumPy integer scalars, which are not `int` instances.
    tensor_indexes = np.argsort(
        [x.tensor_index for x in flattened_obj]).tolist()

    def _adjust_arg(x):
        # Struct-typed fake arguments are converted back into the Python
        # containers implied by their TFF type before being handed to JAX.
        if isinstance(x, structure.Struct):
            return type_conversions.type_to_py_container(x, x.type_signature)
        else:
            return x

    args = [_adjust_arg(x) for x in args]
    kwargs = {k: _adjust_arg(v) for k, v in kwargs.items()}

    context = jax_computation_context.JaxComputationContext()
    with context_stack.install(context):
        tracer_callable = jax.xla_computation(traced_fn,
                                              tuple_args=True,
                                              return_shape=True)
        compiled_xla, returned_shape = tracer_callable(*args, **kwargs)

    if isinstance(returned_shape, jax.ShapeDtypeStruct):
        returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(
            returned_shape)
    else:
        returned_type_spec = computation_types.to_type(
            structure.map_structure(
                _jax_shape_dtype_struct_to_tff_tensor,
                structure.from_container(returned_shape, recursive=True)))

    computation_type = computation_types.FunctionType(parameter_type,
                                                      returned_type_spec)
    return xla_serialization.create_xla_tff_computation(
        compiled_xla, tensor_indexes, computation_type)
示例#16
0
 def test_create_xla_tff_computation_raises_result_type_mismatch(self):
     """A scalar TFF result type is rejected for the int32x10 XLA helper."""
     vector_comp = _make_test_xla_comp_int32x10_to_int32x10()
     with self.assertRaises(ValueError):
         scalar_result_type = computation_types.FunctionType(
             (np.int32, (10,)), np.int32)
         xla_serialization.create_xla_tff_computation(
             vector_comp, [0], scalar_result_type)
示例#17
0
 def test_create_xla_tff_computation_raises_missing_arg_in_type_spec(self):
     """A TFF type with no parameter is rejected for an XLA comp taking one."""
     vector_comp = _make_test_xla_comp_int32x10_to_int32x10()
     with self.assertRaises(ValueError):
         noarg_type = computation_types.FunctionType(None, np.int32)
         xla_serialization.create_xla_tff_computation(vector_comp, [],
                                                      noarg_type)
示例#18
0
 def test_create_xla_tff_computation_raises_missing_arg_in_xla(self):
     """A TFF parameter is rejected when the XLA helper takes no argument."""
     noarg_comp = _make_test_xla_comp_noarg_to_int32()
     with self.assertRaises(ValueError):
         unary_type = computation_types.FunctionType(np.int32, np.int32)
         xla_serialization.create_xla_tff_computation(noarg_comp, [0],
                                                      unary_type)