Example #1
0
 def ExampleComputation(self):
   """Builds a small XLA computation over two f32 parameters.

   Parameter 0 is an f32 scalar and parameter 1 is an f32 vector of length 4.
   The computation multiplies them and then adds the product to itself, i.e.
   (p0 * p1) + (p0 * p1).

   Returns:
     The result of ``build()`` on the "acomputation" builder.
   """
   b = xla_client.XlaBuilder("acomputation")
   scalar_shape = xla_client.shape_from_pyval(np.float32(0))
   vector_shape = xla_client.shape_from_pyval(np.zeros((4,), np.float32))
   scalar = ops.Parameter(b, 0, scalar_shape)
   vector = ops.Parameter(b, 1, vector_shape)
   product = ops.Mul(scalar, vector)
   # build() gets no explicit root below; presumably the last op added (this
   # Add) becomes the root -- confirm against the XlaBuilder docs.
   ops.Add(product, product)
   return b.build()
Example #2
0
 def testSetUpAlias(self):
     """Checks that a computation with setup_alias() applied can be built.

     Two f32 scalar parameters (each given a major-to-minor layout if absent)
     are added together. setup_alias is then invoked with ([], 0, []), and the
     computation is built with the sum as its root.

     NOTE(review): nothing is asserted beyond "no exception is raised".
     """
     builder = xla_client.XlaBuilder(self.id())

     def scalar_shape():
         # Fresh f32 scalar shape with a major-to-minor layout if none is set.
         return xla_client.shape_from_pyval(
             np.array(1.0, np.float32)).with_major_to_minor_layout_if_absent()

     lhs = ops.Parameter(builder, 0, scalar_shape())
     rhs = ops.Parameter(builder, 1, scalar_shape())
     total = ops.Add(lhs, rhs)
     # Presumably (output index [], parameter 0, parameter index []) -- confirm
     # the argument meaning against XlaBuilder.setup_alias.
     builder.setup_alias([], 0, [])
     builder.build(total)
Example #3
0
  def testBasics(self):
    """Runs an add-and-scale computation on a local XRT "XLA_CPU" backend.

    Starts a one-worker local cluster and connects an XRT backend to its gRPC
    target. It then compiles the computation returned by
    BuildAddAndScaleComputation for two ``np.arange(10)`` inputs and checks
    that the result equals ``(a + b) * 3``.
    """
    workers, _ = test.create_local_cluster(num_workers=1, num_ps=0)
    (worker,) = workers
    scheme = "grpc://"
    self.assertTrue(worker.target.startswith(scheme))
    # Strip the scheme so only host:port is handed to get_tf_context.
    address = worker.target[len(scheme):]
    backend = xrt.XrtBackend(xrt.get_tf_context(address, "worker"), "XLA_CPU")

    a = np.arange(10)
    b = np.arange(10)

    shape_a = xla_client.shape_from_pyval(a)
    shape_b = xla_client.shape_from_pyval(b)
    computation = BuildAddAndScaleComputation(shape_a, shape_b)

    compiled = computation.Compile(backend=backend)
    result = xla_client.execute_with_python_values(
        compiled, (a, b), backend=backend)
    expected = (a + b) * 3
    self.assertAllEqual(result, expected)
Example #4
0
  def testHash(self):
    """Two structurally identical computations must have equal hashes.

    Both computations multiply an f32 scalar parameter by an f32[4] parameter.
    Only the builder name differs ("computation0" vs "computation1").
    """

    def build_mul(name):
      # Builds a computation whose last op is param0 * param1.
      builder = xla_client.XlaBuilder(name)
      lhs = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0)))
      rhs = ops.Parameter(builder, 1,
                          xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
      ops.Mul(lhs, rhs)
      return builder.build()

    first = build_mul("computation0")
    second = build_mul("computation1")
    self.assertEqual(first.hash(), second.hash())
Example #5
0
    def testBasics(self):
        """Compiles and runs an add-and-scale computation on a local XRT worker.

        Connects an "XLA_CPU" XRT backend to a one-worker local cluster. It
        executes through the executable's ExecuteWithPythonValues method and
        checks the output equals (a + b) * 3 for a = b = np.arange(10).
        """
        workers, _ = test.create_local_cluster(num_workers=1, num_ps=0)
        (worker, ) = workers
        self.assertTrue(worker.target.startswith("grpc://"))
        # get_tf_context receives the target without its "grpc://" prefix.
        tf_context = xrt.get_tf_context(worker.target[len("grpc://"):],
                                        "worker")
        backend = xrt.XrtBackend(tf_context, "XLA_CPU")

        a, b = np.arange(10), np.arange(10)
        arg_shapes = [xla_client.shape_from_pyval(x) for x in (a, b)]
        executable = BuildAddAndScaleComputation(*arg_shapes).Compile(
            backend=backend)
        self.assertAllEqual(executable.ExecuteWithPythonValues((a, b)),
                            (a + b) * 3)
Example #6
0
  def testXLA(self):
    """Tests that a basic saved model to XLA workflow grossly functions.

    This is largely here to verify that everything that needs to be linked in
    is, and that the pipeline stages are not no-ops.
    """
    # Generate a sample XLA computation: param0 + 1.0.
    builder = xla_client.XlaBuilder("testbuilder")
    # NOTE: this is a one-element array holding 4.0 (numpy shape (1,)), not a
    # shape of [4]; it is only used to derive the parameter's shape and dtype.
    sample_input = np.array([4], dtype=np.float32)
    in_feed = ops.Parameter(builder, 0,
                            xla_client.shape_from_pyval(sample_input))
    # Adds an f32 scalar constant to the parameter (presumably via XLA's
    # implicit scalar broadcasting -- confirm).
    result = ops.Add(in_feed, ops.Constant(builder, np.float32(1.0)))
    xla_computation = builder.Build(result)

    # Load into XLA Module.
    module = compiler.xla_load_module_proto(xla_computation)

    # Validate imported ASM.
    xla_asm = module.to_asm()
    print("XLA ASM: ", xla_asm)
    # Escape the dot: an unescaped "." in a regex matches any character, so
    # "mhlo.add" would also accept e.g. "mhloXadd".
    self.assertRegex(xla_asm, r"mhlo\.add")