def _run_test(test_dict):
    """Runs an individual test dict.

    Builds a TF module via ``test_dict["tf_module_builder"]``, saves it as a
    SavedModel in a temporary directory, loads that SavedModel into a
    compiler module, optionally runs a pass pipeline on it, and optionally
    prints the module's ASM to stdout.

    Args:
      test_dict: Test description. Keys read here:
        "tf_module_builder": required zero-argument callable that returns the
          TF module to save.
        "passes": optional pass pipeline handed to ``run_pass_pipeline``;
          the pipeline step is skipped when this is missing or falsy.
        "expect_pass_failure": optional; when truthy, the stderr diagnostic
          dump is skipped if the pass pipeline raises.
        "print_input_module": optional; when truthy, the module ASM is
          printed to stdout after any passes have run.

    Raises:
      KeyError: if ``test_dict`` has no "tf_module_builder" key.
      Any exception raised by ``run_pass_pipeline`` is re-raised, whether or
      not ``expect_pass_failure`` is set.
    """
    tf_module_builder_lambda = test_dict["tf_module_builder"]
    tf_module = tf_module_builder_lambda()
    ctx = compiler.Context()
    with tempfile.TemporaryDirectory() as sm_path:
        # save_debug_info=True makes TF write op stack-trace debug info next
        # to the SavedModel; presumably so the imported IR carries source
        # locations.
        options = tf.saved_model.SaveOptions(save_debug_info=True)
        tf.saved_model.save(tf_module, sm_path, options=options)
        # Load while the temporary directory still exists; it is deleted when
        # this `with` block exits.
        input_module = compiler.binding.tf_interop.load_saved_model(
            ctx, sm_path)

    passes = test_dict.get("passes")
    expect_pass_failure = test_dict.get("expect_pass_failure")
    if passes:
        try:
            input_module.run_pass_pipeline(passes)
        except:  # pylint: disable=bare-except
            # Dump the intermediate module ASM to stderr to aid debugging,
            # unless the test declared that a pass failure is expected.
            if not expect_pass_failure:
                print(
                    "UNEXPECTED PASS FAILURE (INTERMEDIATE ASM FOLLOWS ON STDERR):",
                    file=sys.stderr)
                print(input_module.to_asm(), file=sys.stderr)
            # Always propagate; expect_pass_failure only silences the dump.
            raise

    # Print the input module ASM.
    if test_dict.get("print_input_module"):
        print(input_module.to_asm())
def testParseAndCompileToFlatbuffer(self):
    """Compiling SIMPLE_MUL_ASM with default options yields non-empty bytes."""
    ctx = compiler.Context()
    input_module = ctx.parse_asm(SIMPLE_MUL_ASM)
    binary = input_module.compile()
    # Fetch the serialized output once so that the size printed and the value
    # asserted on are the same object.
    flatbuffer_bytes = binary.bytes
    print("Flatbuffer size =", len(flatbuffer_bytes))
    self.assertTrue(flatbuffer_bytes)
def compile_from_path(sm_path):
    """Loads the SavedModel at ``sm_path`` and compiles it.

    ``exported_names`` and ``target_backends`` are not parameters; they are
    resolved from an outer scope (enclosing function or module level).

    Returns:
      The result of ``compile(target_backends=...)`` on the loaded module.
    """
    ctx = compiler.Context()
    loaded_module = compiler.tf_load_saved_model(
        sm_path,
        compiler_context=ctx,
        exported_names=exported_names)
    return loaded_module.compile(target_backends=target_backends)
def testParseAndCompileToMlirText(self):
    """Compiling SIMPLE_MUL_ASM with MLIR_TEXT output yields non-empty text."""
    context = compiler.Context()
    parsed = context.parse_asm(SIMPLE_MUL_ASM)
    compile_options = compiler.CompileOptions()
    compile_options.output_format = compiler.OutputFormat.MLIR_TEXT
    result = parsed.compile(options=compile_options)
    self.assertTrue(result.text)
def create_add_scalar_module():
    """Builds a runtime module exporting ``add_scalar(i32, i32) -> i32``.

    The exported function returns the sum of its two i32 arguments (addi).

    Returns:
      An ``rt.VmModule`` created from the compiled flatbuffer.
    """
    context = compiler.Context()
    parsed = context.parse_asm(""" func @add_scalar(%arg0: i32, %arg1: i32) -> i32 attributes { iree.module.export } { %0 = addi %arg0, %arg1 : i32 return %0 : i32 } """)
    return rt.VmModule.from_flatbuffer(parsed.compile())
def create_simple_static_mul_module():
    """Builds a runtime module exporting ``simple_mul``.

    ``simple_mul`` takes two ``tensor<4xf32>`` operands and returns their
    ``mhlo.multiply`` product as a ``tensor<4xf32>``.

    Returns:
      An ``rt.VmModule`` created from the compiled flatbuffer.
    """
    context = compiler.Context()
    parsed = context.parse_asm(""" func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> attributes { iree.module.export } { %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } """)
    return rt.VmModule.from_flatbuffer(parsed.compile())
def create_simple_dynamic_abs_module(target_backends=None):
    """Builds a runtime module with a dynamically shaped elementwise abs.

    The exported function takes a ``tensor<?x?xf32>`` and returns
    ``mhlo.abs`` of it with the same type.

    NOTE(review): the exported symbol is named ``simple_mul`` even though it
    computes abs. It is intentionally left unchanged, because renaming an
    exported symbol would break any caller that looks it up by name.

    Args:
      target_backends: Optional iterable of backend names to compile for.
        A single string is treated as one backend. Defaults to ``["vmla"]``,
        the original hard-coded choice.

    Returns:
      An ``rt.VmModule`` created from the compiled flatbuffer.
    """
    # TODO(laurenzo): Compile for more backends by default as dynamic shapes
    # come online.
    if target_backends is None:
        target_backends = ["vmla"]
    elif isinstance(target_backends, str):
        # Avoid list("vmla") -> ["v", "m", "l", "a"].
        target_backends = [target_backends]
    ctx = compiler.Context()
    input_module = ctx.parse_asm(""" func @simple_mul(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> attributes { iree.module.export } { %0 = "mhlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32> return %0 : tensor<?x?xf32> } """)
    # Pass a fresh list so tuples and other iterables are accepted, and a
    # caller's list is never handed on directly.
    binary = input_module.compile(target_backends=list(target_backends))
    return rt.VmModule.from_flatbuffer(binary)
def testParseError(self):
    """parse_asm raises ValueError naming the unknown custom op."""
    context = compiler.Context()
    bad_asm = "FOOBAR: I SHOULD NOT PARSE"
    with self.assertRaisesRegex(ValueError, "custom op 'FOOBAR' is unknown"):
        context.parse_asm(bad_asm)