def BuildAddAndScaleComputation(shape1, shape2):
  """Builds an XLA computation evaluating (a + b) * 3.

  Args:
    shape1: shape of the first parameter `a`. Its numpy dtype is also used
      for the scalar constant 3.
    shape2: shape of the second parameter `b` (assumed dtype-compatible with
      shape1 -- not checked here).

  Returns:
    The computation produced by `ComputationBuilder.Build()`.
  """
  builder = xla_client.ComputationBuilder("add-and-scale")
  lhs = builder.ParameterWithShape(shape1)
  rhs = builder.ParameterWithShape(shape2)
  # The constant takes shape1's dtype so the multiply operands agree.
  scale = shape1.numpy_dtype().type(3)
  total = builder.Add(lhs, rhs)
  builder.Mul(total, builder.Constant(scale))
  return builder.Build()
def testXLA(self):
  """Tests that a basic saved model to XLA workflow grossly functions.

  This is largely here to verify that everything that needs to be linked in
  actually is, and that the steps involved are not silently no-ops.
  """
  # Generate a sample XLA computation: parameter + 1.0.
  builder = xla_client.ComputationBuilder("testbuilder")
  # Note: this is a one-element array holding the value 4 (numpy shape (1,)),
  # not a shape of [4]; shape_from_pyval derives the parameter shape from it.
  sample_input = np.array([4], dtype=np.float32)
  in_feed = builder.ParameterWithShape(
      xla_client.shape_from_pyval(sample_input))
  result = builder.Add(in_feed, builder.ConstantF32Scalar(1.0))
  xla_computation = builder.Build(result)

  # Load into XLA Module.
  module = compiler.xla_load_module_proto(xla_computation)

  # Validate imported ASM.
  xla_asm = module.to_asm()
  print("XLA ASM: ", xla_asm)
  # Escape the dot so the pattern matches the literal op name instead of
  # "xla_hlo<any char>add".
  self.assertRegex(xla_asm, r"xla_hlo\.add")
def _NewComputation(self, name=None):
  """Returns a new xla_client.ComputationBuilder.

  Args:
    name: name for the builder; when None, `self.id()` is used instead.
  """
  builder_name = self.id() if name is None else name
  return xla_client.ComputationBuilder(builder_name)