def save_and_load_tf_module(tf_module):
  """Round-trips a tf.Module through a SavedModel and imports it with pyiree.

  The module is saved with debug info into a temporary directory. That
  SavedModel is then loaded into a fresh pyiree CompilerContext before the
  directory is cleaned up.

  Args:
    tf_module: The tf.Module instance to save.

  Returns:
    Whatever `pyiree.tf_load_saved_model` returns for the saved module.
  """
  with tempfile.TemporaryDirectory() as saved_model_dir:
    save_options = tf.saved_model.SaveOptions(save_debug_info=True)
    tf.saved_model.save(tf_module, saved_model_dir, options=save_options)
    compiler_ctx = pyiree.CompilerContext()
    # Load before leaving the `with`, while the SavedModel still exists.
    return pyiree.tf_load_saved_model(compiler_ctx, saved_model_dir)
def _run_test(test_dict):
  """Runs an individual test dict.

  Keys read from `test_dict`:
    tf_module_builder: zero-argument callable that builds the tf.Module.
    passes: optional pass pipeline; when truthy it is run on the imported
      module via `run_pass_pipeline`.
    expect_pass_failure: when truthy, suppresses the stderr diagnostic dump
      if the pass pipeline raises. The exception is re-raised either way.
    print_input_module: when truthy, prints the module ASM to stdout.
  """
  tf_module = test_dict["tf_module_builder"]()
  ctx = pyiree.CompilerContext()
  with tempfile.TemporaryDirectory() as saved_model_dir:
    tf.saved_model.save(
        tf_module,
        saved_model_dir,
        options=tf.saved_model.SaveOptions(save_debug_info=True))
    input_module = pyiree.tf_load_saved_model(ctx, saved_model_dir)

  pass_pipeline = test_dict.get("passes")
  expect_pass_failure = test_dict.get("expect_pass_failure")
  if pass_pipeline:
    try:
      input_module.run_pass_pipeline(pass_pipeline)
    except:  # pylint: disable=bare-except
      # Dump the partially-transformed module only when the failure was not
      # anticipated by the test; always propagate the original exception.
      if not expect_pass_failure:
        print(
            "UNEXPECTED PASS FAILURE (INTERMEDIATE ASM FOLLOWS ON STDERR):",
            file=sys.stderr)
        print(input_module.to_asm(), file=sys.stderr)
      raise

  # Print the input module ASM.
  if test_dict.get("print_input_module"):
    print(input_module.to_asm())
def testParseAndCompileToSequencer(self):
  """Parses a single-function module and checks the sequencer blob is truthy."""
  mul_asm = (
      ' func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>)'
      ' -> tensor<4xf32> attributes { iree.module.export } {'
      ' %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"}'
      ' : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>'
      ' return %0 : tensor<4xf32> } ')
  compiler_ctx = pyiree.CompilerContext()
  parsed_module = compiler_ctx.parse_asm(mul_asm)
  sequencer_blob = parsed_module.compile_to_sequencer_blob()
  self.assertTrue(sequencer_blob)
def create_simple_mul_module():
  """Compiles an exported element-wise multiply and wraps it as a VM module.

  The function `simple_mul` multiplies two tensor<4xf32> operands with
  `xla_hlo.mul`. It is compiled to a sequencer blob, and that blob is handed
  to `pyiree.binding.vm.create_module_from_blob`.

  NOTE(review): another `create_simple_mul_module` is defined later in this
  file and rebinds this name. Confirm which version is intended.

  Returns:
    The module object produced by `create_module_from_blob`.
  """
  mul_asm = (
      ' func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>)'
      ' -> tensor<4xf32> attributes { iree.module.export } {'
      ' %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"}'
      ' : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>'
      ' return %0 : tensor<4xf32> } ')
  compiler_ctx = pyiree.CompilerContext()
  parsed_module = compiler_ctx.parse_asm(mul_asm)
  sequencer_blob = parsed_module.compile_to_sequencer_blob()
  return pyiree.binding.vm.create_module_from_blob(sequencer_blob)
def create_simple_mul_module():
  """Compiles the `@arithmetic` module and loads it as a VM module.

  The module exports `simple_mul`, which multiplies two tensor<4xf32>
  operands with `xla_hlo.mul`. It is compiled with `compile()`, and the
  resulting flatbuffer is loaded via `VmModule.from_flatbuffer`.

  NOTE(review): this rebinds a `create_simple_mul_module` defined earlier in
  this file. Confirm the earlier definition is meant to be superseded.

  Returns:
    The `pyiree.binding.vm.VmModule` built from the compiled flatbuffer.
  """
  arithmetic_asm = (
      ' module @arithmetic {'
      ' func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>)'
      ' -> tensor<4xf32> attributes { iree.module.export } {'
      ' %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"}'
      ' : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>'
      ' return %0 : tensor<4xf32> } } ')
  compiler_ctx = pyiree.CompilerContext()
  parsed_module = compiler_ctx.parse_asm(arithmetic_asm)
  flatbuffer = parsed_module.compile()
  return pyiree.binding.vm.VmModule.from_flatbuffer(flatbuffer)
def testParseError(self):
  """Malformed assembly must raise a ValueError naming the unknown op."""
  compiler_ctx = pyiree.CompilerContext()
  bad_asm = """FOOBAR: I SHOULD NOT PARSE"""
  with self.assertRaisesRegex(ValueError, "custom op 'FOOBAR' is unknown"):
    compiler_ctx.parse_asm(bad_asm)