Example #1
0
def BuildAddAndScaleComputation(shape1, shape2):
  """Return a built XLA computation that evaluates (a + b) * 3.

  Args:
    shape1: Shape of the first parameter, ``a``. Its numpy dtype is also used
      to construct the scalar constant 3.
    shape2: Shape of the second parameter, ``b``.

  Returns:
    Whatever ``Build()`` returns for the "add-and-scale" builder. ``Build()``
    is called without an explicit root; presumably the final Mul is used as
    the root — confirm against the xla_client version in use.
  """
  builder = xla_client.ComputationBuilder("add-and-scale")
  lhs = builder.ParameterWithShape(shape1)
  rhs = builder.ParameterWithShape(shape2)
  scalar_type = shape1.numpy_dtype().type
  total = builder.Add(lhs, rhs)
  # Op creation order is Add, then Constant, then Mul, so Mul is the last op.
  builder.Mul(total, builder.Constant(scalar_type(3)))
  return builder.Build()
    def testXLA(self):
        """Tests that a basic saved model to XLA workflow grossly functions.

        This is largely here to verify that everything that needs to be linked
        in is linked in, and that the calls involved are not no-ops.
        """
        # Generate a sample XLA computation: one float32 parameter plus 1.0.
        builder = xla_client.ComputationBuilder("testbuilder")
        # A one-element float32 array; it is only used to derive the parameter
        # shape via shape_from_pyval, not fed to the computation here.
        sample_value = np.array([4], dtype=np.float32)
        in_feed = builder.ParameterWithShape(
            xla_client.shape_from_pyval(sample_value))
        result = builder.Add(in_feed, builder.ConstantF32Scalar(1.0))
        xla_computation = builder.Build(result)

        # Load into XLA Module.
        module = compiler.xla_load_module_proto(xla_computation)

        # Validate imported ASM. The dot is escaped so the pattern matches the
        # literal op name "xla_hlo.add" rather than any character there.
        xla_asm = module.to_asm()
        print("XLA ASM: ", xla_asm)
        self.assertRegex(xla_asm, r"xla_hlo\.add")
 def _NewComputation(self, name=None):
     """Create a new xla_client.ComputationBuilder.

     Args:
       name: Name for the builder. When None, the value of ``self.id()`` is
         used instead (presumably the unittest test id).

     Returns:
       A freshly constructed ``xla_client.ComputationBuilder``.
     """
     builder_name = self.id() if name is None else name
     return xla_client.ComputationBuilder(builder_name)