Пример #1
0
  def _CreateBackend(self):
    """Starts a one-worker local gRPC cluster and builds an XRT CPU backend on it.

    Returns:
      A `(worker, backend)` pair. Callers should keep `worker` bound for as
      long as they use `backend` (the original tests kept it as a local for
      the whole test); NOTE(review): presumably the local server may be torn
      down once its object is released -- confirm before dropping it.
    """
    (worker,), _ = test.create_local_cluster(num_workers=1, num_ps=0)
    self.assertTrue(worker.target.startswith("grpc://"))
    # The TF context is given the worker target with its "grpc://" scheme
    # stripped off.
    tf_context = xrt.get_tf_context(worker.target[len("grpc://"):], "worker")
    backend = xrt.XrtBackend(tf_context, "XLA_CPU")
    return worker, backend

  def testBasics(self):
    """Compiles and runs an add-and-scale computation on the XRT CPU backend."""
    # `worker` is not read below; it stays bound so the local server remains
    # referenced until the test finishes.
    worker, backend = self._CreateBackend()

    a = np.arange(10)
    b = np.arange(10)

    c = BuildAddAndScaleComputation(
        xla_client.Shape.from_pyval(a), xla_client.Shape.from_pyval(b))

    executable = c.Compile(backend=backend)
    output = executable.ExecuteWithPythonValues((a, b))
    self.assertAllEqual(output, (a + b) * 3)

  # Previously this method was indented one level deeper, which nested it
  # inside testBasics so it was never collected or run as a test.
  def testTuples(self):
    """Round-trips two device buffers through a tuple buffer and back."""
    # `worker` is not read below; it stays bound so the local server remains
    # referenced until the test finishes.
    worker, backend = self._CreateBackend()

    a = np.random.randn(10)
    b = np.random.randn(15, 3)
    pieces = [
        xla_client.Buffer.from_pyval(a, backend=backend),
        xla_client.Buffer.from_pyval(b, backend=backend),
    ]
    t = xla_client.Buffer.make_tuple(pieces, backend=backend)
    # Destructuring must yield the elements in the order they were packed.
    a_out, b_out = t.destructure()
    self.assertAllEqual(a, a_out.to_py())
    self.assertAllEqual(b, b_out.to_py())