示例#1
0
    def testImport(self):
        """Converts a traced identity function and checks its MLIR func header.

        The function is traced for a float32 tensor of unspecified shape, and
        the resulting MLIR text must contain a ``func @`` declaration whose
        symbol includes ``identity`` followed by an argument list.
        """
        @def_function.function
        def identity(i):
            return i

        float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
        traced = identity.get_concrete_function(float_spec)
        module_text = mlir.convert_function(traced)
        # The `.*` on either side tolerates any prefix/suffix on the symbol.
        self.assertRegex(module_text, r'func @.*identity.*\(')
示例#2
0
  def testImportWithControlRet(self):
    """Converts a side-effect-only function and checks its control return.

    The traced function has no data outputs; it only calls ``print_v2``.
    It is converted with ``pass_pipeline=''``. The MLIR text must contain
    the ``tf.PrintV2`` op and a ``tf_executor.fetch`` whose type is
    ``!tf_executor.control``.
    """

    # NOTE(review): this local name shadows any module-level `logging`
    # import inside this method; kept as-is because the traced function's
    # name may be reflected in the emitted MLIR.
    @def_function.function
    def logging():
      logging_ops.print_v2('some message')

    concrete_function = logging.get_concrete_function()
    # Presumably an empty pipeline skips post-import passes so the
    # tf_executor form stays visible -- TODO confirm against convert_function.
    mlir_module = mlir.convert_function(concrete_function, pass_pipeline='')
    self.assertRegex(mlir_module, r'tf\.PrintV2')
    # Dots are escaped so they match literally, consistent with the pattern
    # above; unescaped `.` would accept any character in those positions.
    self.assertRegex(
        mlir_module, r'tf_executor\.fetch.*: !tf_executor\.control')
示例#3
0
    def testImport(self):
        """Converts a traced squaring function with debug info enabled.

        The MLIR text must declare a ``func @`` whose symbol contains ``sqr``
        and must include location annotations (``loc(``), which are requested
        via ``show_debug_info=True``.
        """
        @def_function.function
        def sqr(i):
            return i * i

        input_spec = tensor_spec.TensorSpec(None, dtypes.float32)
        traced = sqr.get_concrete_function(input_spec)
        module_text = mlir.convert_function(traced, show_debug_info=True)
        # Checked in order: function header first, then location info.
        for pattern in (r'func @.*sqr.*\(', r'loc\('):
            self.assertRegex(module_text, pattern)
示例#4
0
    def testImportWithCall(self):
        """Converts a function that calls another traced function.

        Only ``caller`` is converted directly. The MLIR text must declare
        ``caller`` as a ``func @`` and ``callee`` as a ``func private @``.
        """
        @def_function.function
        def callee(i):
            return i

        @def_function.function
        def caller(i):
            return callee(i)

        input_spec = tensor_spec.TensorSpec(None, dtypes.float32)
        traced = caller.get_concrete_function(input_spec)
        module_text = mlir.convert_function(traced)
        expected_patterns = (
            r'func @.*caller.*\(',
            r'func private @.*callee.*\(',
        )
        for pattern in expected_patterns:
            self.assertRegex(module_text, pattern)