def test_add_numbers(self):
  """Builds an XLA computation adding two int32 values and invokes it.

  The computation takes a single tuple-shaped XLA parameter holding two int32
  operands and returns their sum. It is embedded in an `XlaExecutor`, called
  on the struct <20, 30>, and the computed result is checked to be 50.
  """
  builder = xla_client.XlaBuilder('comp')
  # A single tuple-typed XLA parameter whose two elements are the operands.
  param = xla_client.ops.Parameter(
      builder, 0,
      xla_client.shape_from_pyval(tuple([np.array(0, dtype=np.int32)] * 2)))
  xla_client.ops.Add(
      xla_client.ops.GetTupleElement(param, 0),
      xla_client.ops.GetTupleElement(param, 1))
  xla_comp = builder.build()
  comp_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      xla_comp, [0, 1], comp_type)
  ex = executor.XlaExecutor()

  async def _compute_fn():
    comp_val = await ex.create_value(comp_pb, comp_type)
    x_val = await ex.create_value(20, np.int32)
    y_val = await ex.create_value(30, np.int32)
    arg_val = await ex.create_struct([x_val, y_val])
    call_val = await ex.create_call(comp_val, arg_val)
    return await call_val.compute()

  # asyncio.run creates and closes a fresh event loop, matching the other
  # tests here; asyncio.get_event_loop() is deprecated when no loop is running
  # and run_until_complete on it would leave that loop unclosed.
  result = asyncio.run(_compute_fn())
  self.assertEqual(result, 50)
def test_create_compute_int32(self):
  """Embeds an int32 scalar in an `XlaExecutor` and computes it back.

  Checks the resulting value's class, its printed type signature, the type and
  content of its internal representation, and the value returned by
  `compute()`.
  """
  executor_under_test = executor.XlaExecutor()
  embedded = asyncio.run(executor_under_test.create_value(10, np.int32))

  self.assertIsInstance(embedded, executor.XlaValue)
  self.assertEqual(str(embedded.type_signature), 'int32')

  representation = embedded.internal_representation
  self.assertIsInstance(representation, np.int32)
  self.assertEqual(representation, 10)

  computed = asyncio.run(embedded.compute())
  self.assertEqual(computed, 10)
def test_create_compute_2xint32_struct(self):
  """Packs two embedded int32 values into a struct and computes it.

  The struct is expected to have type `<int32,int32>`, a `structure.Struct`
  internal representation printing as `<10,20>`, and to compute to a value
  that also prints as `<10,20>`.
  """
  ex = executor.XlaExecutor()
  # Embed 10 then 20, in that order, as the struct's two elements.
  members = [asyncio.run(ex.create_value(n, np.int32)) for n in (10, 20)]
  packed = asyncio.run(ex.create_struct(members))

  self.assertIsInstance(packed, executor.XlaValue)
  self.assertEqual(str(packed.type_signature), '<int32,int32>')
  self.assertIsInstance(packed.internal_representation, structure.Struct)
  self.assertEqual(str(packed.internal_representation), '<10,20>')

  computed = asyncio.run(packed.compute())
  self.assertEqual(str(computed), '<10,20>')
def test_selection(self):
  """Selects members of an embedded named struct by index and by name.

  Embeds `<a=10,b=20>` and checks that selecting index 0 yields 10 and
  selecting name 'b' yields 20.
  """
  ex = executor.XlaExecutor()
  py_value = collections.OrderedDict([('a', 10), ('b', 20)])
  struct_type = computation_types.StructType(
      [('a', np.int32), ('b', np.int32)])
  embedded = asyncio.run(ex.create_value(py_value, struct_type))

  self.assertIsInstance(embedded, executor.XlaValue)
  self.assertEqual(str(embedded.type_signature), '<a=int32,b=int32>')

  # Each case: keyword arguments for create_selection and the expected member.
  cases = (({'index': 0}, 10), ({'name': 'b'}, 20))
  for selector, expected in cases:
    selected = asyncio.run(ex.create_selection(embedded, **selector))
    self.assertEqual(selected.internal_representation, expected)
def test_create_and_invoke_noarg_comp_returning_int32(self):
  """Embeds a no-argument XLA computation returning 10 and invokes it.

  The computation is checked twice:
  * directly, through the callable held as the embedded value's internal
    representation;
  * via `create_call` on the executor, followed by `compute()`.
  """
  builder = xla_client.XlaBuilder('comp')
  # The XLA computation still declares one parameter, of empty-tuple shape,
  # while the TFF function type below takes no argument and the serialization
  # call is given an empty list.
  empty_shape = xla_client.shape_from_pyval(tuple())
  xla_client.ops.Parameter(builder, 0, empty_shape)
  xla_client.ops.Constant(builder, np.int32(10))
  xla_comp = builder.build()
  fn_type = computation_types.FunctionType(None, np.int32)
  serialized = xla_serialization.create_xla_tff_computation(
      xla_comp, [], fn_type)
  ex = executor.XlaExecutor()

  fn_val = asyncio.run(ex.create_value(serialized, fn_type))
  self.assertIsInstance(fn_val, executor.XlaValue)
  self.assertEqual(str(fn_val.type_signature), str(fn_type))
  self.assertTrue(callable(fn_val.internal_representation))
  direct_result = fn_val.internal_representation()
  self.assertEqual(direct_result, 10)

  invoked = asyncio.run(ex.create_call(fn_val))
  self.assertIsInstance(invoked, executor.XlaValue)
  self.assertEqual(str(invoked.type_signature), 'int32')
  computed = asyncio.run(invoked.compute())
  self.assertEqual(computed, 10)