def testHash(self):
  """Checks that two identically-structured computations hash the same.

  Two builders with different names ("computation0" and "computation1") each
  get a float32 scalar parameter, a length-4 float32 vector parameter and a
  multiply of the two. The ``hash()`` of the built computations must match.
  """

  def _build_scalar_times_vector(name):
    # Build one computation: Mul(param0: f32 scalar, param1: f32[4]).
    b = xla_client.XlaBuilder(name)
    scalar = ops.Parameter(b, 0, xla_client.shape_from_pyval(np.float32(0)))
    vector = ops.Parameter(
        b, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
    ops.Mul(scalar, vector)
    return b.build()

  first = _build_scalar_times_vector("computation0")
  second = _build_scalar_times_vector("computation1")
  self.assertEqual(first.hash(), second.hash())
def ExampleComputation(self):
  """Returns a small computation evaluating ``(p0 * p1) + (p0 * p1)``.

  Parameter 0 is shaped like a float32 scalar and parameter 1 like a
  length-4 float32 vector. Their product is added to itself, and the
  computation built from builder "acomputation" is returned.
  """
  builder = xla_client.XlaBuilder("acomputation")
  scalar_shape = xla_client.shape_from_pyval(np.float32(0))
  vector_shape = xla_client.shape_from_pyval(np.zeros((4,), np.float32))
  # Arguments are evaluated left to right, so parameter 0 is created first.
  product = ops.Mul(
      ops.Parameter(builder, 0, scalar_shape),
      ops.Parameter(builder, 1, vector_shape))
  ops.Add(product, product)
  return builder.build()
def testSetUpAlias(self):
  """Smoke-tests ``setup_alias`` on a builder that adds two scalars.

  Two float32 scalar parameters get a major-to-minor layout when none is
  present. The builder then calls ``setup_alias([], 0, [])`` and builds with
  the sum as root. The test only checks that none of these calls raises.
  """
  c = xla_client.XlaBuilder(self.id())
  # Parameters 0 and 1 share the same scalar float32 shape.
  params = [
      ops.Parameter(
          c, index,
          xla_client.shape_from_pyval(
              np.array(1.0, np.float32)).with_major_to_minor_layout_if_absent())
      for index in (0, 1)
  ]
  out = ops.Add(*params)
  c.setup_alias([], 0, [])
  c.build(out)
def testXLA(self):
  """Tests that a basic saved model to XLA workflow grossly functions.

  This is largely here to verify that everything that needs to be linked in
  is linked in, and that the conversion steps are not no-ops.
  """
  # Generate a sample XLA computation: parameter 0 (shaped like the 1-element
  # float32 array below) plus the float32 constant 1.0.
  builder = xla_client.XlaBuilder("testbuilder")
  sample_input = np.array([4], dtype=np.float32)
  in_feed = ops.Parameter(builder, 0, xla_client.shape_from_pyval(sample_input))
  result = ops.Add(in_feed, ops.Constant(builder, np.float32(1.0)))
  # Lowercase ``build`` matches the other XlaBuilder call sites; the
  # capitalized ``Build`` spelling is a legacy name.
  xla_computation = builder.build(result)

  # Load into XLA Module.
  module = compiler.xla_load_module_proto(xla_computation)

  # Validate imported ASM: the Add above should appear as an mhlo add op.
  xla_asm = module.to_asm()
  print("XLA ASM: ", xla_asm)
  self.assertRegex(xla_asm, "mhlo.add")