def testImport(self):
  """Converts a traced identity function and checks the emitted MLIR func.

  The function is traced with an unknown-shape float32 input, converted with
  default options, and the resulting module text must contain a `func @...`
  whose symbol mentions `identity`.
  """

  @def_function.function
  def identity(i):
    return i

  float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
  traced = identity.get_concrete_function(float_spec)
  module_text = mlir.convert_function(traced)
  self.assertRegex(module_text, r'func @.*identity.*\(')
def testImportWithControlRet(self):
  """Checks that an output-less, side-effecting op is kept by conversion.

  The traced function only prints, so it has no data outputs. The
  assertions require the `tf.PrintV2` op to appear in the module text and
  the `tf_executor.fetch` to carry a `!tf_executor.control` result.
  """

  # Named so it does not shadow the stdlib `logging` module in this scope.
  @def_function.function
  def print_message():
    logging_ops.print_v2('some message')

  concrete_function = print_message.get_concrete_function()
  # NOTE(review): an empty pass_pipeline presumably skips the default MLIR
  # passes so the raw tf_executor form is inspected -- confirm against
  # mlir.convert_function.
  mlir_module = mlir.convert_function(concrete_function, pass_pipeline='')
  self.assertRegex(mlir_module, r'tf\.PrintV2')
  # Dots are escaped so they match literally rather than any character.
  self.assertRegex(mlir_module,
                   r'tf_executor\.fetch.*: !tf_executor\.control')
def testImport(self):
  """Converts a traced squaring function with debug info enabled.

  The module text must contain a `func @...` whose symbol mentions `sqr`.
  Because `show_debug_info=True` is passed, it must also contain at least
  one `loc(` location annotation.
  """
  # NOTE(review): another test method named `testImport` appears in this
  # chunk; if both live in the same class, only the later definition runs.
  # Consider renaming one of them -- confirm class membership first.

  @def_function.function
  def sqr(i):
    return i * i

  float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
  traced = sqr.get_concrete_function(float_spec)
  module_text = mlir.convert_function(traced, show_debug_info=True)
  for expected in (r'func @.*sqr.*\(', r'loc\('):
    self.assertRegex(module_text, expected)
def testImportWithCall(self):
  """Converts a tf.function that calls another tf.function.

  Both functions must appear in the module text. The caller is emitted as a
  `func @...` and the callee as a `func private @...`.
  """

  @def_function.function
  def callee(i):
    return i

  @def_function.function
  def caller(i):
    return callee(i)

  float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
  traced = caller.get_concrete_function(float_spec)
  module_text = mlir.convert_function(traced)
  expected_patterns = (
      r'func @.*caller.*\(',
      r'func private @.*callee.*\(',
  )
  for expected in expected_patterns:
    self.assertRegex(module_text, expected)