def save_and_load_tf_module(tf_module):
    """Save a TF module as a SavedModel and load it back through pyiree.

    The module is written (with ``save_debug_info=True``) into a temporary
    directory, which is then handed to ``pyiree.tf_load_saved_model`` together
    with a freshly created ``pyiree.CompilerContext``. The temporary directory
    is removed when the ``with`` block exits, before this function returns.

    Args:
        tf_module: Object forwarded to ``tf.saved_model.save``.

    Returns:
        Whatever ``pyiree.tf_load_saved_model`` returned.
    """
    with tempfile.TemporaryDirectory() as export_dir:
        tf.saved_model.save(
            tf_module,
            export_dir,
            options=tf.saved_model.SaveOptions(save_debug_info=True),
        )
        compiler_ctx = pyiree.CompilerContext()
        loaded_module = pyiree.tf_load_saved_model(compiler_ctx, export_dir)
    return loaded_module
Exemplo n.º 2
0
def _run_test(test_dict):
    """Run a single test described by ``test_dict``.

    Keys read from ``test_dict``:
        "tf_module_builder": required; a zero-argument callable whose result
            is saved with ``tf.saved_model.save``.
        "passes": optional; when truthy it is passed to
            ``run_pass_pipeline`` on the loaded module.
        "expect_pass_failure": optional; when truthy, the diagnostic dump on
            a pipeline failure is skipped.
        "print_input_module": optional; when truthy, the module ASM is
            printed to stdout at the end.

    Raises:
        Whatever ``run_pass_pipeline`` raised. The exception is always
        re-raised, whether or not a failure was expected.
    """
    build_module = test_dict["tf_module_builder"]
    built_module = build_module()
    compiler_ctx = pyiree.CompilerContext()
    with tempfile.TemporaryDirectory() as export_dir:
        tf.saved_model.save(
            built_module,
            export_dir,
            options=tf.saved_model.SaveOptions(save_debug_info=True),
        )
        input_module = pyiree.tf_load_saved_model(compiler_ctx, export_dir)

    pipeline = test_dict.get("passes")
    failure_expected = test_dict.get("expect_pass_failure")
    if pipeline:
        try:
            input_module.run_pass_pipeline(pipeline)
        except:  # pylint: disable=bare-except
            if failure_expected:
                # Expected failures propagate without the diagnostic dump.
                raise
            print(
                "UNEXPECTED PASS FAILURE (INTERMEDIATE ASM FOLLOWS ON STDERR):",
                file=sys.stderr)
            print(input_module.to_asm(), file=sys.stderr)
            raise

    # Optionally dump the (possibly transformed) module ASM to stdout.
    if test_dict.get("print_input_module"):
        print(input_module.to_asm())
Exemplo n.º 3
0
 def testParseAndCompileToSequencer(self):
     """Parse an exported ``simple_mul`` function and compile it.

     The test passes when ``compile_to_sequencer_blob`` returns a truthy
     value.
     """
     compiler_ctx = pyiree.CompilerContext()
     simple_mul_asm = """
   func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
         attributes { iree.module.export } {
       %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
       return %0 : tensor<4xf32>
   }
   """
     parsed_module = compiler_ctx.parse_asm(simple_mul_asm)
     sequencer_blob = parsed_module.compile_to_sequencer_blob()
     self.assertTrue(sequencer_blob)
Exemplo n.º 4
0
def create_simple_mul_module():
    """Build a VM module containing one exported ``simple_mul`` function.

    The ASM below is parsed with a new ``pyiree.CompilerContext``, compiled
    with ``compile_to_sequencer_blob``, and the resulting blob is passed to
    ``pyiree.binding.vm.create_module_from_blob``.

    Returns:
        The object returned by ``create_module_from_blob``.
    """
    compiler_ctx = pyiree.CompilerContext()
    simple_mul_asm = """
    func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
          attributes { iree.module.export } {
        %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
        return %0 : tensor<4xf32>
    }
    """
    parsed_module = compiler_ctx.parse_asm(simple_mul_asm)
    sequencer_blob = parsed_module.compile_to_sequencer_blob()
    return pyiree.binding.vm.create_module_from_blob(sequencer_blob)
Exemplo n.º 5
0
def create_simple_mul_module():
    """Build a VM module from the ``@arithmetic`` module's ``simple_mul``.

    The ASM below is parsed with a new ``pyiree.CompilerContext``, compiled
    with ``compile``, and the result is passed to
    ``pyiree.binding.vm.VmModule.from_flatbuffer``.

    Returns:
        The object returned by ``VmModule.from_flatbuffer``.
    """
    compiler_ctx = pyiree.CompilerContext()
    arithmetic_asm = """
  module @arithmetic {
    func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
          attributes { iree.module.export } {
        %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
        return %0 : tensor<4xf32>
    }
  }
  """
    parsed_module = compiler_ctx.parse_asm(arithmetic_asm)
    compiled = parsed_module.compile()
    return pyiree.binding.vm.VmModule.from_flatbuffer(compiled)
Exemplo n.º 6
0
 def testParseError(self):
     """Check that unparseable ASM raises ``ValueError`` with the expected text."""
     compiler_ctx = pyiree.CompilerContext()
     expected_message = "custom op 'FOOBAR' is unknown"
     with self.assertRaisesRegex(ValueError, expected_message):
         compiler_ctx.parse_asm("""FOOBAR: I SHOULD NOT PARSE""")