Beispiel #1
0
    def test_add_numbers(self):
        """Adds two int32 scalars through a compiled XLA computation.

        Builds an XLA computation whose single parameter is a 2-tuple of int32
        scalars and which adds the two elements. Wraps it as a TFF function
        taking a pair of int32 and returning int32. Invokes it through
        `XlaExecutor` on (20, 30) and expects 50.
        """
        builder = xla_client.XlaBuilder('comp')
        # One tuple-shaped parameter carries both int32 operands.
        param = xla_client.ops.Parameter(
            builder, 0,
            xla_client.shape_from_pyval(
                tuple([np.array(0, dtype=np.int32)] * 2)))
        # The Add is the last op built; the test relies on it being the
        # computation's output (it expects 50 below).
        xla_client.ops.Add(xla_client.ops.GetTupleElement(param, 0),
                           xla_client.ops.GetTupleElement(param, 1))
        xla_comp = builder.build()
        comp_type = computation_types.FunctionType((np.int32, np.int32),
                                                   np.int32)
        # NOTE(review): [0, 1] presumably binds the two flattened parameter
        # tensors to the XLA parameter; confirm against xla_serialization.
        comp_pb = xla_serialization.create_xla_tff_computation(
            xla_comp, [0, 1], comp_type)
        ex = executor.XlaExecutor()

        async def _compute_fn():
            comp_val = await ex.create_value(comp_pb, comp_type)
            x_val = await ex.create_value(20, np.int32)
            y_val = await ex.create_value(30, np.int32)
            arg_val = await ex.create_struct([x_val, y_val])
            call_val = await ex.create_call(comp_val, arg_val)
            return await call_val.compute()

        # asyncio.run creates and closes its own event loop. It does not
        # depend on a "current" loop being set, unlike the deprecated
        # asyncio.get_event_loop().run_until_complete(...) pattern.
        result = asyncio.run(_compute_fn())
        self.assertEqual(result, 50)
Beispiel #2
0
 def test_create_compute_int32(self):
     """Creates an int32 value, inspects it, then computes it back to 10."""
     executor_under_test = executor.XlaExecutor()
     value = asyncio.run(executor_under_test.create_value(10, np.int32))
     self.assertIsInstance(value, executor.XlaValue)
     self.assertEqual(str(value.type_signature), 'int32')
     # The stored representation is a numpy int32 scalar equal to the input.
     rep = value.internal_representation
     self.assertIsInstance(rep, np.int32)
     self.assertEqual(rep, 10)
     computed = asyncio.run(value.compute())
     self.assertEqual(computed, 10)
Beispiel #3
0
 def test_create_compute_2xint32_struct(self):
     """Packs two int32 values into a struct and checks type and contents."""
     xla_ex = executor.XlaExecutor()
     # Create the member values in order: 10 first, then 20.
     members = [
         asyncio.run(xla_ex.create_value(v, np.int32)) for v in (10, 20)
     ]
     struct_val = asyncio.run(xla_ex.create_struct(members))
     self.assertIsInstance(struct_val, executor.XlaValue)
     self.assertEqual(str(struct_val.type_signature), '<int32,int32>')
     internal = struct_val.internal_representation
     self.assertIsInstance(internal, structure.Struct)
     self.assertEqual(str(internal), '<10,20>')
     self.assertEqual(str(asyncio.run(struct_val.compute())), '<10,20>')
Beispiel #4
0
 def test_selection(self):
     """Selects members of a named struct by positional index and by name."""
     ex = executor.XlaExecutor()
     named_value = collections.OrderedDict([('a', 10), ('b', 20)])
     named_type = computation_types.StructType([('a', np.int32),
                                                ('b', np.int32)])
     struct_val = asyncio.run(ex.create_value(named_value, named_type))
     self.assertIsInstance(struct_val, executor.XlaValue)
     self.assertEqual(str(struct_val.type_signature), '<a=int32,b=int32>')
     # Index 0 should yield field 'a' (10); name 'b' should yield 20.
     for selector, expected in (({'index': 0}, 10), ({'name': 'b'}, 20)):
         selected = asyncio.run(ex.create_selection(struct_val, **selector))
         self.assertEqual(selected.internal_representation, expected)
Beispiel #5
0
 def test_create_and_invoke_noarg_comp_returning_int32(self):
     """Builds a no-arg XLA computation returning 10 and invokes it two ways.

     First the executor value's underlying callable is called directly. Then
     the computation is invoked via `create_call` and computed.
     """
     xla_builder = xla_client.XlaBuilder('comp')
     # The computation takes an empty-tuple parameter (no real inputs).
     empty_tuple_shape = xla_client.shape_from_pyval(tuple())
     xla_client.ops.Parameter(xla_builder, 0, empty_tuple_shape)
     xla_client.ops.Constant(xla_builder, np.int32(10))
     xla_comp = xla_builder.build()
     fn_type = computation_types.FunctionType(None, np.int32)
     comp_pb = xla_serialization.create_xla_tff_computation(
         xla_comp, [], fn_type)
     ex = executor.XlaExecutor()
     comp_val = asyncio.run(ex.create_value(comp_pb, fn_type))
     self.assertIsInstance(comp_val, executor.XlaValue)
     self.assertEqual(str(comp_val.type_signature), str(fn_type))
     self.assertTrue(callable(comp_val.internal_representation))
     # Direct invocation of the wrapped callable, bypassing the executor.
     direct_result = comp_val.internal_representation()
     self.assertEqual(direct_result, 10)
     # Invocation through the executor's call/compute path.
     call_val = asyncio.run(ex.create_call(comp_val))
     self.assertIsInstance(call_val, executor.XlaValue)
     self.assertEqual(str(call_val.type_signature), 'int32')
     self.assertEqual(asyncio.run(call_val.compute()), 10)