def testBasics(self):
    """Run the add-and-scale computation on an XRT CPU backend.

    Starts a single-worker local cluster, connects an XRT backend to it,
    compiles ``BuildAddAndScaleComputation`` for two ``np.arange(10)``
    inputs and checks that executing it yields ``(a + b) * 3``.
    """
    scheme = "grpc://"
    (worker,), _ = test.create_local_cluster(num_workers=1, num_ps=0)
    # The worker target is expected to be a gRPC URL; the XRT context
    # wants the bare host:port part, so strip the scheme off.
    self.assertTrue(worker.target.startswith(scheme))
    address = worker.target[len(scheme):]
    tf_context = xrt.get_tf_context(address, "worker")
    backend = xrt.XrtBackend(tf_context, "XLA_CPU")

    lhs = np.arange(10)
    rhs = np.arange(10)
    computation = BuildAddAndScaleComputation(
        xla_client.Shape.from_pyval(lhs), xla_client.Shape.from_pyval(rhs))
    result = computation.Compile(backend=backend).ExecuteWithPythonValues(
        (lhs, rhs))
    self.assertAllEqual(result, (lhs + rhs) * 3)
def testTuples(self):
    """Round-trip two device buffers through an XRT tuple.

    Uploads two random arrays as buffers on an XRT CPU backend, packs
    them with ``Buffer.make_tuple``, destructures the tuple back into two
    buffers and checks each one converts back to its original array.
    """
    scheme = "grpc://"
    (worker, ), _ = test.create_local_cluster(num_workers=1, num_ps=0)
    self.assertTrue(worker.target.startswith(scheme))
    tf_context = xrt.get_tf_context(worker.target[len(scheme):], "worker")
    backend = xrt.XrtBackend(tf_context, "XLA_CPU")

    # Arrays of different shapes so a mix-up in element order is visible.
    originals = (np.random.randn(10), np.random.randn(15, 3))
    pieces = [
        xla_client.Buffer.from_pyval(arr, backend=backend)
        for arr in originals
    ]
    packed = xla_client.Buffer.make_tuple(pieces, backend=backend)
    # Strict two-way unpacking: fails if destructure() yields any other count.
    first, second = packed.destructure()
    for expected, buf in zip(originals, (first, second)):
        self.assertAllEqual(expected, buf.to_py())