def ExampleComputation(self):
  """Builds a small two-parameter XLA computation named "acomputation".

  Parameter 0 takes the shape of a float32 scalar and parameter 1 the shape
  of a 4-element float32 vector. The body computes ``(p0 * p1) + (p0 * p1)``.

  Returns:
    Whatever ``builder.build()`` returns. No root op is passed explicitly;
    presumably the last op added (the Add) becomes the root -- confirm
    against the XlaBuilder API.
  """
  comp_builder = xla_client.XlaBuilder("acomputation")
  scalar_shape = xla_client.shape_from_pyval(np.float32(0))
  scalar_param = ops.Parameter(comp_builder, 0, scalar_shape)
  vector_shape = xla_client.shape_from_pyval(np.zeros((4,), np.float32))
  vector_param = ops.Parameter(comp_builder, 1, vector_shape)
  product = ops.Mul(scalar_param, vector_param)
  # Double the product by adding it to itself; the op handle is not needed.
  ops.Add(product, product)
  return comp_builder.build()
def testSetUpAlias(self):
  """Registers an alias on a builder, then builds an Add computation.

  The computation adds two float32 parameters whose shapes come from a 0-d
  ``np.array(1.0, np.float32)``, with a major-to-minor layout applied if none
  is set. The test makes no explicit assertion: it passes as long as
  ``setup_alias`` and ``build`` raise no exception.
  """
  builder = xla_client.XlaBuilder(self.id())

  def _scalar_f32_shape():
    # Fresh shape per call, identical to the expression used for each param.
    return xla_client.shape_from_pyval(
        np.array(1.0, np.float32)).with_major_to_minor_layout_if_absent()

  lhs = ops.Parameter(builder, 0, _scalar_f32_shape())
  rhs = ops.Parameter(builder, 1, _scalar_f32_shape())
  total = ops.Add(lhs, rhs)
  # NOTE(review): argument meaning assumed to be (output index, parameter
  # number, parameter index) -- confirm against XlaBuilder.setup_alias.
  builder.setup_alias([], 0, [])
  builder.build(total)
def testBasics(self):
  """Runs an add-and-scale computation on an XRT "XLA_CPU" backend.

  Starts a local cluster with one worker and no parameter servers, connects
  an XRT backend to that worker, compiles BuildAddAndScaleComputation for two
  ``np.arange(10)`` inputs, and checks that the output equals ``(a + b) * 3``.
  """
  workers, _ = test.create_local_cluster(num_workers=1, num_ps=0)
  (worker,) = workers
  scheme = "grpc://"
  self.assertTrue(worker.target.startswith(scheme))
  # The TF context is given the worker target without its "grpc://" scheme.
  address = worker.target[len(scheme):]
  tf_context = xrt.get_tf_context(address, "worker")
  backend = xrt.XrtBackend(tf_context, "XLA_CPU")

  lhs = np.arange(10)
  rhs = np.arange(10)
  computation = BuildAddAndScaleComputation(
      xla_client.shape_from_pyval(lhs), xla_client.shape_from_pyval(rhs))
  executable = computation.Compile(backend=backend)
  result = xla_client.execute_with_python_values(
      executable, (lhs, rhs), backend=backend)
  self.assertAllEqual(result, (lhs + rhs) * 3)
def testHash(self):
  """Checks that two identically-structured computations hash equally.

  Builders named "computation0" and "computation1" each construct the same
  Mul of a float32-scalar-shaped parameter and a 4-element float32-vector-
  shaped parameter. The builder names differ, yet ``hash()`` must match.
  """

  def _build_mul(name):
    # Build one Mul(p0, p1) computation under the given builder name.
    mul_builder = xla_client.XlaBuilder(name)
    scalar = ops.Parameter(
        mul_builder, 0, xla_client.shape_from_pyval(np.float32(0)))
    vector = ops.Parameter(
        mul_builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
    ops.Mul(scalar, vector)
    return mul_builder.build()

  first = _build_mul("computation0")
  second = _build_mul("computation1")
  self.assertEqual(first.hash(), second.hash())
def testBasics(self):
  """Compiles and executes an add-and-scale computation through XRT.

  A one-worker, zero-PS local cluster provides the target; the XRT backend
  runs on "XLA_CPU". The executable is invoked via ExecuteWithPythonValues
  and its output must equal ``(a + b) * 3`` for ``a = b = np.arange(10)``.
  """
  cluster_workers, _ = test.create_local_cluster(num_workers=1, num_ps=0)
  (worker,) = cluster_workers
  grpc_prefix = "grpc://"
  self.assertTrue(worker.target.startswith(grpc_prefix))
  # Hand the TF context the target with the "grpc://" scheme stripped.
  tf_context = xrt.get_tf_context(worker.target[len(grpc_prefix):], "worker")
  backend = xrt.XrtBackend(tf_context, "XLA_CPU")

  x = np.arange(10)
  y = np.arange(10)
  lhs_shape = xla_client.shape_from_pyval(x)
  rhs_shape = xla_client.shape_from_pyval(y)
  executable = BuildAddAndScaleComputation(lhs_shape, rhs_shape).Compile(
      backend=backend)
  result = executable.ExecuteWithPythonValues((x, y))
  self.assertAllEqual(result, (x + y) * 3)
def testXLA(self):
  """Tests that a basic saved model to XLA workflow grossly functions.

  This is largely here to verify that everything that needs to be linked in
  is linked in, and that the entry points are not no-ops.
  """
  # Generate a sample XLA computation: add the constant 1.0 to a float32
  # parameter whose shape is taken from a one-element sample array.
  builder = xla_client.XlaBuilder("testbuilder")
  # This is a sample *value* (the array [4.0]); only its shape/dtype is used.
  sample_input = np.array([4], dtype=np.float32)
  in_feed = ops.Parameter(builder, 0, xla_client.shape_from_pyval(sample_input))
  result = ops.Add(in_feed, ops.Constant(builder, np.float32(1.0)))
  xla_computation = builder.Build(result)

  # Load into XLA Module.
  module = compiler.xla_load_module_proto(xla_computation)

  # Validate imported ASM. Use a literal substring check: as a regex, the "."
  # in "mhlo.add" would match any character. assertIn's failure message
  # includes the ASM text, so no separate debug print is needed.
  xla_asm = module.to_asm()
  self.assertIn("mhlo.add", xla_asm)