Ejemplo n.º 1
0
    def test_add_numbers(self):
        """Calls a (int32, int32) -> int32 XLA adder through `XlaExecutor`."""
        builder = xla_client.XlaBuilder('comp')
        # The argument arrives as a single XLA tuple of two int32 scalars.
        param = xla_client.ops.Parameter(
            builder, 0,
            xla_client.shape_from_pyval(
                tuple([np.array(0, dtype=np.int32)] * 2)))
        xla_client.ops.Add(xla_client.ops.GetTupleElement(param, 0),
                           xla_client.ops.GetTupleElement(param, 1))
        xla_comp = builder.build()
        comp_type = computation_types.FunctionType((np.int32, np.int32),
                                                   np.int32)
        comp_pb = xla_serialization.create_xla_tff_computation(
            xla_comp, [0, 1], comp_type)
        ex = executor.XlaExecutor()

        async def _compute_fn():
            comp_val = await ex.create_value(comp_pb, comp_type)
            x_val = await ex.create_value(20, np.int32)
            y_val = await ex.create_value(30, np.int32)
            arg_val = await ex.create_struct([x_val, y_val])
            call_val = await ex.create_call(comp_val, arg_val)
            return await call_val.compute()

        # `asyncio.run` creates and closes its own loop. Implicit loop creation
        # via `asyncio.get_event_loop()` is deprecated and raises RuntimeError
        # on Python 3.14+ when there is no current event loop.
        result = asyncio.run(_compute_fn())
        self.assertEqual(result, 50)
Ejemplo n.º 2
0
 def test_set_local_execution_context(self):
   """Invokes a no-arg XLA computation under the local execution context."""
   xla_builder = xla_client.XlaBuilder('comp')
   # No-arg computations still declare an (empty) tuple parameter.
   empty_arg_shape = xla_client.shape_from_pyval(tuple())
   xla_client.ops.Parameter(xla_builder, 0, empty_arg_shape)
   xla_client.ops.Constant(xla_builder, np.int32(10))
   built = xla_builder.build()
   fn_type = computation_types.FunctionType(None, np.int32)
   proto = xla_serialization.create_xla_tff_computation(built, [], fn_type)
   computation = computation_impl.ComputationImpl(
       proto, context_stack_impl.context_stack)
   # NOTE(review): the context installed here is presumably global and is not
   # restored by this test; confirm no later test depends on the old one.
   execution_contexts.set_local_execution_context()
   self.assertEqual(computation(), 10)
Ejemplo n.º 3
0
 def test_computation_callable_return_one_number(self):
   """Wraps a no-arg XLA constant in `ComputationCallable` and calls it."""
   xla_builder = xla_client.XlaBuilder('comp')
   # No-arg computations still declare an (empty) tuple parameter.
   empty_arg_shape = xla_client.shape_from_pyval(tuple())
   xla_client.ops.Parameter(xla_builder, 0, empty_arg_shape)
   xla_client.ops.Constant(xla_builder, np.int32(10))
   built = xla_builder.build()
   fn_type = computation_types.FunctionType(None, np.int32)
   proto = xla_serialization.create_xla_tff_computation(built, [], fn_type)
   callable_comp = runtime.ComputationCallable(
       proto, fn_type, jax.lib.xla_bridge.get_backend())
   self.assertIsInstance(callable_comp, runtime.ComputationCallable)
   self.assertEqual(str(callable_comp.type_signature), '( -> int32)')
   self.assertEqual(callable_comp(), 10)
Ejemplo n.º 4
0
 def test_to_representation_for_type_with_noarg_to_int32_comp(self):
     """Converts a no-arg int32 XLA computation into a Python callable."""
     xla_builder = xla_client.XlaBuilder('comp')
     # No-arg computations still declare an (empty) tuple parameter.
     empty_arg_shape = xla_client.shape_from_pyval(tuple())
     xla_client.ops.Parameter(xla_builder, 0, empty_arg_shape)
     xla_client.ops.Constant(xla_builder, np.int32(10))
     built = xla_builder.build()
     fn_type = computation_types.FunctionType(None, np.int32)
     proto = xla_serialization.create_xla_tff_computation(built, [], fn_type)
     representation = executor.to_representation_for_type(
         proto, fn_type, self._backend)
     self.assertTrue(callable(representation))
     self.assertEqual(representation(), 10)
Ejemplo n.º 5
0
 def test_to_representation_for_type_with_2xint32_to_int32_comp(self):
     """Converts a (int32, int32) -> int32 XLA adder into a callable."""
     xla_builder = xla_client.XlaBuilder('comp')
     # The argument is a single XLA tuple holding two int32 scalars.
     pair_shape = xla_client.shape_from_pyval(
         tuple([np.array(0, dtype=np.int32)] * 2))
     arg = xla_client.ops.Parameter(xla_builder, 0, pair_shape)
     lhs = xla_client.ops.GetTupleElement(arg, 0)
     rhs = xla_client.ops.GetTupleElement(arg, 1)
     xla_client.ops.Add(lhs, rhs)
     built = xla_builder.build()
     fn_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
     proto = xla_serialization.create_xla_tff_computation(
         built, [0, 1], fn_type)
     representation = executor.to_representation_for_type(
         proto, fn_type, self._backend)
     self.assertTrue(callable(representation))
     arg_struct = structure.Struct(
         [(None, np.int32(20)), (None, np.int32(30))])
     self.assertEqual(representation(arg_struct), 50)
Ejemplo n.º 6
0
 def test_to_representation_for_type_with_noarg_to_2xint32_comp(self):
     """Converts a no-arg XLA computation returning a named 2xint32 struct.

     The XLA computation returns a flat tuple; the field names `a` and `b`
     come from the TFF result type supplied alongside it.
     """
     builder = xla_client.XlaBuilder('comp')
     # No-arg computations still declare an (empty) tuple parameter.
     xla_client.ops.Parameter(builder, 0,
                              xla_client.shape_from_pyval(tuple()))
     xla_client.ops.Tuple(builder, [
         xla_client.ops.Constant(builder, np.int32(10)),
         xla_client.ops.Constant(builder, np.int32(20))
     ])
     xla_comp = builder.build()
     comp_type = computation_types.FunctionType(
         None,
         computation_types.StructType([('a', np.int32), ('b', np.int32)]))
     # The tensor-index list describes the computation's parameter, and this
     # computation has none, so it is empty.
     comp_pb = xla_serialization.create_xla_tff_computation(
         xla_comp, [], comp_type)
     rep = executor.to_representation_for_type(comp_pb, comp_type,
                                               self._backend)
     self.assertTrue(callable(rep))
     result = rep()
     self.assertEqual(str(result), '<a=10,b=20>')
Ejemplo n.º 7
0
 def test_computation_callable_add_two_numbers(self):
   """Wraps a (int32, int32) -> int32 XLA adder in `ComputationCallable`."""
   xla_builder = xla_client.XlaBuilder('comp')
   # The argument is a single XLA tuple holding two int32 scalars.
   pair_shape = xla_client.shape_from_pyval(
       tuple([np.array(0, dtype=np.int32)] * 2))
   arg = xla_client.ops.Parameter(xla_builder, 0, pair_shape)
   lhs = xla_client.ops.GetTupleElement(arg, 0)
   rhs = xla_client.ops.GetTupleElement(arg, 1)
   xla_client.ops.Add(lhs, rhs)
   built = xla_builder.build()
   fn_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
   proto = xla_serialization.create_xla_tff_computation(
       built, [0, 1], fn_type)
   callable_comp = runtime.ComputationCallable(
       proto, fn_type, jax.lib.xla_bridge.get_backend())
   self.assertIsInstance(callable_comp, runtime.ComputationCallable)
   self.assertEqual(
       str(callable_comp.type_signature), '(<int32,int32> -> int32)')
   two_and_three = structure.Struct(
       [(None, np.int32(2)), (None, np.int32(3))])
   self.assertEqual(callable_comp(two_and_three), 5)
Ejemplo n.º 8
0
    def create_constant_from_scalar(
        self, value, type_spec: computation_types.Type
    ) -> local_computation_factory_base.ComputationProtoAndType:
        """Builds a no-arg XLA computation that returns `value` as `type_spec`.

        Every tensor in `type_spec` becomes a constant of that tensor's shape
        and dtype, filled with `value` via `np.full`.

        Args:
          value: The scalar used as the fill value for every tensor.
          type_spec: A `computation_types.Type` that must be a tensor type or
            a structure of tensor types.

        Returns:
          A `(comp_pb, comp_type)` pair, where `comp_type` is
          `FunctionType(None, type_spec)` and `comp_pb` is the serialized
          XLA computation.

        Raises:
          ValueError: If `type_spec` is not a tensor or a structure of
            tensors. (Type mismatches are rejected by `py_typecheck`.)
        """
        py_typecheck.check_type(type_spec, computation_types.Type)
        if not type_analysis.is_structure_of_tensors(type_spec):
            raise ValueError(
                'Not a tensor or a structure of tensors: {}'.format(
                    str(type_spec)))

        xla_builder = xla_client.XlaBuilder('comp')

        # By convention, arguments are always passed as one tuple (see
        # `computation.proto`); with no arguments that tuple is empty.
        empty_arg_shape = xla_client.shape_from_pyval(tuple())
        xla_client.ops.Parameter(xla_builder, 0, empty_arg_shape)

        if isinstance(type_spec, computation_types.TensorType):
            leaf_types = [type_spec]
        else:
            leaf_types = structure.flatten(type_spec)

        # Emit one constant per leaf tensor, in flattened order.
        constants = []
        for leaf_type in leaf_types:
            py_typecheck.check_type(leaf_type, computation_types.TensorType)
            filled = np.full(shape=leaf_type.shape.dims,
                             fill_value=value,
                             dtype=leaf_type.dtype.as_numpy_dtype)
            constants.append(xla_client.ops.Constant(xla_builder, filled))

        # Results are likewise returned as a single flat tuple; any nesting of
        # the TFF structure is recovered from the binding, not from XLA.
        xla_client.ops.Tuple(xla_builder, constants)
        built = xla_builder.build()

        comp_type = computation_types.FunctionType(None, type_spec)
        comp_pb = xla_serialization.create_xla_tff_computation(
            built, [], comp_type)
        return comp_pb, comp_type
Ejemplo n.º 9
0
 def test_create_and_invoke_noarg_comp_returning_int32(self):
     """Creates a no-arg XLA value in `XlaExecutor`, then calls and computes it."""
     builder = xla_client.XlaBuilder('comp')
     # No-arg computations still declare an (empty) tuple parameter.
     xla_client.ops.Parameter(builder, 0,
                              xla_client.shape_from_pyval(tuple()))
     xla_client.ops.Constant(builder, np.int32(10))
     xla_comp = builder.build()
     comp_type = computation_types.FunctionType(None, np.int32)
     comp_pb = xla_serialization.create_xla_tff_computation(
         xla_comp, [], comp_type)
     ex = executor.XlaExecutor()
     # Use one dedicated loop for all steps and close it afterwards. Implicit
     # loop creation via `asyncio.get_event_loop()` is deprecated and raises
     # RuntimeError on Python 3.14+ when there is no current event loop.
     loop = asyncio.new_event_loop()
     try:
         comp_val = loop.run_until_complete(
             ex.create_value(comp_pb, comp_type))
         self.assertIsInstance(comp_val, executor.XlaValue)
         self.assertEqual(str(comp_val.type_signature), str(comp_type))
         self.assertTrue(callable(comp_val.internal_representation))
         # The internal representation can also be invoked directly.
         result = comp_val.internal_representation()
         self.assertEqual(result, 10)
         call_val = loop.run_until_complete(ex.create_call(comp_val))
         self.assertIsInstance(call_val, executor.XlaValue)
         self.assertEqual(str(call_val.type_signature), 'int32')
         result = loop.run_until_complete(call_val.compute())
         self.assertEqual(result, 10)
     finally:
         loop.close()