Exemple #1
0
def _run_test(test_dict):
    """Run one compiler test described by ``test_dict``.

    The TF module returned by ``test_dict["tf_module_builder"]()`` is saved as
    a SavedModel (with debug info) into a temporary directory and loaded into a
    fresh compiler context. Optional keys:

      * ``"passes"``: when truthy, handed to ``run_pass_pipeline``. Any
        exception is re-raised; unless ``"expect_pass_failure"`` is truthy,
        the intermediate module ASM is dumped to stderr first.
      * ``"print_input_module"``: when truthy, the module ASM is printed to
        stdout.
    """
    tf_module = test_dict["tf_module_builder"]()
    compiler_ctx = compiler.Context()
    with tempfile.TemporaryDirectory() as saved_model_dir:
        save_options = tf.saved_model.SaveOptions(save_debug_info=True)
        tf.saved_model.save(tf_module, saved_model_dir, options=save_options)
        # Load while the temporary directory still exists; it is removed when
        # the `with` block exits.
        input_module = compiler.binding.tf_interop.load_saved_model(
            compiler_ctx, saved_model_dir)

    pass_pipeline = test_dict.get("passes")
    failure_expected = test_dict.get("expect_pass_failure")
    if pass_pipeline:
        try:
            input_module.run_pass_pipeline(pass_pipeline)
        except:  # pylint: disable=bare-except
            # Dump diagnostics only for surprising failures, then always
            # propagate so the caller sees the original error.
            if not failure_expected:
                print(
                    "UNEXPECTED PASS FAILURE (INTERMEDIATE ASM FOLLOWS ON STDERR):",
                    file=sys.stderr)
                print(input_module.to_asm(), file=sys.stderr)
            raise

    if test_dict.get("print_input_module"):
        print(input_module.to_asm())
Exemple #2
0
 def testParseAndCompileToFlatbuffer(self):
     """Compiling SIMPLE_MUL_ASM with default options yields non-empty bytes."""
     ctx = compiler.Context()
     input_module = ctx.parse_asm(SIMPLE_MUL_ASM)
     binary = input_module.compile()
     # Read the serialized bytes once and use that same value for both the
     # size report and the assertion, instead of re-reading the property.
     flatbuffer = binary.bytes
     print("Flatbuffer size =", len(flatbuffer))
     self.assertTrue(flatbuffer)
Exemple #3
0
 def compile_from_path(sm_path):
     """Import the SavedModel at ``sm_path`` and compile it.

     ``exported_names`` and ``target_backends`` are not parameters; they are
     resolved from an enclosing (or module-level) scope.

     Args:
       sm_path: Location of the SavedModel handed to ``tf_load_saved_model``.

     Returns:
       The result of ``compile(target_backends=...)`` on the imported module.
     """
     # Keep the context referenced by a local for the whole call; it is passed
     # to the loader and may need to outlive the load step.
     context = compiler.Context()
     imported = compiler.tf_load_saved_model(
         sm_path,
         exported_names=exported_names,
         compiler_context=context)
     return imported.compile(target_backends=target_backends)
Exemple #4
0
 def testParseAndCompileToMlirText(self):
     """Compiling with the MLIR_TEXT output format yields non-empty text."""
     ctx = compiler.Context()
     input_module = ctx.parse_asm(SIMPLE_MUL_ASM)
     compile_options = compiler.CompileOptions()
     compile_options.output_format = compiler.OutputFormat.MLIR_TEXT
     mlir_blob = input_module.compile(options=compile_options)
     self.assertTrue(mlir_blob.text)
Exemple #5
0
def create_add_scalar_module():
  """Compile a scalar i32 add function and wrap it as a runtime VM module.

  Returns:
    The ``rt.VmModule`` built from the compiled output; its ASM exports
    ``add_scalar(i32, i32) -> i32``.
  """
  add_scalar_asm = """
    func @add_scalar(%arg0: i32, %arg1: i32) -> i32 attributes { iree.module.export } {
      %0 = addi %arg0, %arg1 : i32
      return %0 : i32
    }
    """
  # Context and parsed module are held in locals until the VM module is built.
  ctx = compiler.Context()
  input_module = ctx.parse_asm(add_scalar_asm)
  compiled = input_module.compile()
  return rt.VmModule.from_flatbuffer(compiled)
Exemple #6
0
def create_simple_static_mul_module():
  """Compile a static-shape elementwise multiply into a runtime VM module.

  Returns:
    The ``rt.VmModule`` built from the compiled output; its ASM exports
    ``simple_mul(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>``.
  """
  static_mul_asm = """
    func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
          attributes { iree.module.export } {
        %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
        return %0 : tensor<4xf32>
    }
    """
  # Context and parsed module are held in locals until the VM module is built.
  ctx = compiler.Context()
  input_module = ctx.parse_asm(static_mul_asm)
  compiled = input_module.compile()
  return rt.VmModule.from_flatbuffer(compiled)
Exemple #7
0
def create_simple_dynamic_abs_module():
  """Compile a dynamic-shape elementwise abs into a runtime VM module.

  Only the ``vmla`` target backend is requested.

  Returns:
    The ``rt.VmModule`` built from the compiled output.
  """
  ctx = compiler.Context()
  # TODO(laurenzo): Compile for more backends as dynamic shapes come online.
  backends = ["vmla"]
  # NOTE(review): the exported function is named `simple_mul` although it
  # computes `mhlo.abs`; callers presumably look it up by that name, so
  # confirm before renaming it.
  dynamic_abs_asm = """
    func @simple_mul(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
          attributes { iree.module.export } {
        %0 = "mhlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
        return %0 : tensor<?x?xf32>
    }
    """
  input_module = ctx.parse_asm(dynamic_abs_asm)
  compiled = input_module.compile(target_backends=backends)
  return rt.VmModule.from_flatbuffer(compiled)
Exemple #8
0
 def testParseError(self):
     """Malformed assembly must raise ValueError naming the unknown op."""
     ctx = compiler.Context()
     malformed_asm = """FOOBAR: I SHOULD NOT PARSE"""
     expected_message = "custom op 'FOOBAR' is unknown"
     with self.assertRaisesRegex(ValueError, expected_message):
         ctx.parse_asm(malformed_asm)