Ejemplo n.º 1
0
  def testHash(self):
    """Checks that identically structured computations hash equally.

    Both computations multiply an f32 scalar parameter by an f32[4] vector
    parameter. They differ only in their builder names ("computation0" vs
    "computation1"), and the assertion requires that difference not to
    change the hash.
    """

    def build_scalar_times_vector(builder_name):
      # Returns the built computation p0 * p1 for a scalar p0 and vector p1.
      builder = xla_client.XlaBuilder(builder_name)
      scalar = ops.Parameter(builder, 0,
                             xla_client.shape_from_pyval(np.float32(0)))
      vector = ops.Parameter(
          builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
      ops.Mul(scalar, vector)
      return builder.build()

    first = build_scalar_times_vector("computation0")
    second = build_scalar_times_vector("computation1")
    self.assertEqual(first.hash(), second.hash())
Ejemplo n.º 2
0
 def ExampleComputation(self):
   """Builds and returns the sample computation ``(p0 * p1) + (p0 * p1)``.

   ``p0`` is an f32 scalar parameter (index 0) and ``p1`` is an f32[4]
   vector parameter (index 1).
   """
   b = xla_client.XlaBuilder("acomputation")
   scalar_shape = xla_client.shape_from_pyval(np.float32(0))
   lhs = ops.Parameter(b, 0, scalar_shape)
   vector_shape = xla_client.shape_from_pyval(np.zeros((4,), np.float32))
   rhs = ops.Parameter(b, 1, vector_shape)
   product = ops.Mul(lhs, rhs)
   # The last op added becomes the root of the built computation.
   ops.Add(product, product)
   computation = b.build()
   return computation
Ejemplo n.º 3
0
 def testSetUpAlias(self):
     """Smoke test: registering an alias and then building must not raise.

     Builds ``p1 + p2`` over two f32 scalar parameters, calls
     ``setup_alias([], 0, [])`` on the builder, and builds with the sum as
     root. The built computation is discarded; the test only requires that
     no exception is raised.
     """
     builder = xla_client.XlaBuilder(self.id())

     def scalar_param(index):
       # f32[] parameter whose shape gets a major-to-minor layout if it has
       # none.
       shape = xla_client.shape_from_pyval(np.array(1.0, np.float32))
       return ops.Parameter(builder, index,
                            shape.with_major_to_minor_layout_if_absent())

     lhs = scalar_param(0)
     rhs = scalar_param(1)
     total = ops.Add(lhs, rhs)
     # NOTE(review): arguments are presumably (output_index, param_number,
     # param_index), i.e. alias the whole output with all of parameter 0 --
     # confirm against XlaBuilder.setup_alias.
     builder.setup_alias([], 0, [])
     builder.build(total)
Ejemplo n.º 4
0
  def testXLA(self):
    """Tests that importing a basic XLA computation grossly functions.

    Builds a tiny XLA computation, loads it into a compiler module and checks
    that the module's textual form contains the expected add op. This is
    largely here to verify that everything that needs to be linked in is,
    and that nothing has been stubbed out as a no-op.
    """
    # Generate a sample XLA computation: parameter + constant 1.0.
    builder = xla_client.XlaBuilder("testbuilder")
    # A sample value, not a shape: shape_from_pyval derives the parameter
    # shape (f32[1]) from this one-element array.
    sample_input = np.array([4], dtype=np.float32)
    in_feed = ops.Parameter(builder, 0,
                            xla_client.shape_from_pyval(sample_input))
    result = ops.Add(in_feed, ops.Constant(builder, np.float32(1.0)))
    # Lowercase `build` is the builder method used by the other XlaBuilder
    # call sites; the capitalized `Build` is a legacy alias that newer
    # xla_client releases may not provide.
    xla_computation = builder.build(result)

    # Load into XLA Module.
    module = compiler.xla_load_module_proto(xla_computation)

    # Validate imported ASM: expect the Add to appear as an mhlo.add op.
    xla_asm = module.to_asm()
    print("XLA ASM: ", xla_asm)
    # Escape the dot so the pattern matches the literal op name only.
    self.assertRegex(xla_asm, r"mhlo\.add")