Beispiel #1
0
def _run_test(test_dict):
    """Run a single test case described by `test_dict`.

    Keys read from `test_dict`:
      "tf_module_builder": zero-argument callable returning the module to save.
      "passes": optional pass pipeline to run on the imported module.
      "expect_pass_failure": if truthy, a failing pass pipeline does not dump
        diagnostics to stderr (the error is still re-raised either way).
      "print_input_module": if truthy, print the module's ASM to stdout.

    The module is saved (with debug info) into a temporary SavedModel
    directory and imported via `compiler.tf_load_saved_model` with an empty
    pass pipeline, so only the passes listed under "passes" are applied.
    """
    build_module = test_dict["tf_module_builder"]
    tf_module = build_module()
    ctx = compiler.Context()
    with tempfile.TemporaryDirectory() as saved_model_dir:
        save_options = tf.saved_model.SaveOptions(save_debug_info=True)
        tf.saved_model.save(tf_module, saved_model_dir, options=save_options)
        input_module = compiler.tf_load_saved_model(
            saved_model_dir, ctx, pass_pipeline=())

    pipeline = test_dict.get("passes")
    failure_expected = test_dict.get("expect_pass_failure")
    if pipeline:
        try:
            input_module.run_pass_pipeline(pipeline)
        except BaseException:  # pylint: disable=broad-except
            # Same as a bare `except:`. The error always propagates; unless a
            # failure was expected, dump the partially transformed module first.
            if not failure_expected:
                print(
                    "UNEXPECTED PASS FAILURE (INTERMEDIATE ASM FOLLOWS ON STDERR):",
                    file=sys.stderr)
                print(input_module.to_asm(), file=sys.stderr)
            raise

    # Optionally emit the (possibly transformed) module ASM on stdout.
    if test_dict.get("print_input_module"):
        print(input_module.to_asm())
Beispiel #2
0
 def compile_from_path(sm_path):
     """Load the saved model at `sm_path` and compile it.

     `exported_names` and `target_backends` come from the enclosing scope. No
     `pass_pipeline` argument is given, so `compiler.tf_load_saved_model` uses
     its default. The value returned by `compile` is returned unchanged.
     """
     context = compiler.Context()
     loaded = compiler.tf_load_saved_model(
         sm_path, exported_names=exported_names, compiler_context=context)
     compiled = loaded.compile(target_backends=target_backends)
     return compiled
Beispiel #3
0
    def _compile_from_path(sm_path: str) -> compiler.binding.OpaqueBlob:
        """Helper function for compile_tf_module.

        Compiles the saved model at `sm_path` for every compiler target listed
        by `backend_infos`. The work runs in stages so intermediate artifacts
        can be saved when `artifacts_dir` is set:
          1. a raw import with no passes,
          2. `compiler.TF_IMPORT_PASS_PIPELINE`,
          3. `compile`.

        With `artifacts_dir` set, a crash reproducer path is also configured on
        `compiler.Context`. If any stage raises, that path is reset to None so
        the reproducer is not inadvertently overwritten.

        Args:
          sm_path: directory containing the saved model.

        Returns:
          The compiled module produced by `compiler_module.compile`.

        Raises:
          Any exception from the import, passes, compilation or file writes;
          it is re-raised unchanged.
        """

        def dump_asm(module, file_name, log_message):
            # Write `module`'s textual MLIR to artifacts_dir/file_name.
            asm_path = os.path.join(artifacts_dir, file_name)
            logging.info(log_message, asm_path)
            with open(asm_path, "w") as f:
                f.write(module.to_asm())

        if artifacts_dir is not None:
            # Set up a crash reproducer for debugging.
            reproducer_name = f"reproducer__{backends_string}.mlir"
            compiler.Context.default_crash_reproducer_path = os.path.join(
                artifacts_dir, reproducer_name)
        try:
            context = compiler.Context()

            # Stage 1: raw TF input MLIR (no passes run during import).
            compiler_module = compiler.tf_load_saved_model(
                sm_path,
                exported_names=exported_names,
                compiler_context=context,
                pass_pipeline=())
            if artifacts_dir is not None:
                dump_asm(compiler_module, "tf_input.mlir",
                         "Saving raw TF input MLIR to: %s")

            # Stage 2: the passes tf_load_saved_model would usually run itself.
            compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
            if artifacts_dir is not None:
                dump_asm(compiler_module, "iree_input.mlir",
                         "Saving IREE input MLIR to: %s")

            # Stage 3: compile for the union of all backends' compiler targets.
            targets = [
                target
                for info in backend_infos
                for target in info.compiler_targets
            ]
            compiled_module = compiler_module.compile(target_backends=targets)

            if artifacts_dir is not None:
                backends_path = _get_backends_path("compiled", backend_infos,
                                                   artifacts_dir)
                compiled_path = f"{backends_path}.vmfb"
                logging.info("Saving compiled IREE module to: %s",
                             compiled_path)
                with open(compiled_path, "wb") as f:
                    f.write(compiled_module)
        except Exception:  # pylint: disable=broad-except
            if artifacts_dir is not None:
                # Disable the crash reproducer (to avoid inadvertently overwriting it).
                compiler.Context.default_crash_reproducer_path = None
            raise
        return compiled_module
Beispiel #4
0
    def _compile_from_path(sm_path):
        """Helper function for compile_tf_module.

        Compiles the saved model at `sm_path` for `target_backends` in three
        stages:
          1. a raw import with no passes,
          2. `compiler.TF_IMPORT_PASS_PIPELINE`,
          3. `compile`.

        When `artifacts_dir` is set, the output of each stage is written there
        under a file name suffixed with the sanitized backend list.

        Args:
          sm_path: directory containing the saved model.

        Returns:
          The compiled module produced by `compiler_module.compile`.
        """

        def open_artifact(file_name, log_message, mode):
            # Log and open artifacts_dir/file_name; the caller writes into it.
            artifact_path = os.path.join(artifacts_dir, file_name)
            logging.info(log_message, artifact_path)
            return open(artifact_path, mode)

        # We break up the compilation here so we can save intermediary artifacts.
        compiler_context = compiler.Context()

        if artifacts_dir is not None:
            # For each backend, collapse runs of unusual characters to "_" and
            # trim "_" from both ends. Then join the names with "__".
            backends_string = "__".join(
                re.sub("[^0-9a-zA-Z_]+", "_", backend).strip("_")
                for backend in target_backends)

        # Stage 1: convert the tf_module into raw TF input MLIR.
        compiler_module = compiler.tf_load_saved_model(
            sm_path,
            exported_names=exported_names,
            compiler_context=compiler_context,
            pass_pipeline=())
        if artifacts_dir is not None:
            with open_artifact(f"tf_input__{backends_string}.mlir",
                               "Saving raw TF input MLIR to: %s", "w") as f:
                f.write(compiler_module.to_asm())

        # Stage 2: run the passes tf_load_saved_model would usually run itself.
        compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
        if artifacts_dir is not None:
            with open_artifact(f"iree_input__{backends_string}.mlir",
                               "Saving IREE input MLIR to: %s", "w") as f:
                f.write(compiler_module.to_asm())

        # Stage 3: compile for the requested backends.
        compiled_module = compiler_module.compile(
            target_backends=target_backends)
        if artifacts_dir is not None:
            with open_artifact(f"compiled__{backends_string}.vmfb",
                               "Saving compiled IREE module to: %s", "wb") as f:
                f.write(compiled_module)

        return compiled_module
Beispiel #5
0
  def testLoadSavedModelToXlaPipeline(self):
    """Tests that a basic saved model to XLA workflow grossly functions.

    This is largely here to verify that everything that needs to be linked in
    is linked in, and that the pipeline is not a no-op.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
      sm_dir = os.path.join(temp_dir, "simple.sm")
      print("Saving to:", sm_dir)
      my_module = StatelessModule()
      options = tf.saved_model.SaveOptions(save_debug_info=True)
      tf.saved_model.save(my_module, sm_dir, options=options)

      # Load it up using tf_load_saved_model's default pass pipeline.
      input_module = compiler.tf_load_saved_model(sm_dir)
      xla_asm = input_module.to_asm()
      print("XLA ASM:", xla_asm)
      # "xla_hlo.tanh" is a literal op name. Under assertRegex the "." would
      # match any character, so check for the exact substring instead.
      self.assertIn("xla_hlo.tanh", xla_asm)
Beispiel #6
0
    def compile_from_path(sm_path):
        """Compile the saved model at `sm_path` for `target_backends`.

        The work is split into three stages:
          1. a raw import with no passes,
          2. `compiler.TF_IMPORT_PASS_PIPELINE`,
          3. `compile`.

        When `global_debug_dir` is set, the raw TF MLIR, the IREE input MLIR
        and the compiled `.vmfb` are each written there for debugging.
        `exported_names`, `target_backends` and `global_debug_dir` are read
        from the enclosing scope.

        Args:
          sm_path: directory containing the saved model.

        Returns:
          The compiled module produced by `compiler_module.compile`.
        """
        compiler_context = compiler.Context()
        # Break up the compilation so we can save debug artifacts.
        compiler_module = compiler.tf_load_saved_model(
            sm_path,
            exported_names=exported_names,
            compiler_context=compiler_context,
            pass_pipeline=())

        # Sanitize each backend name separately, then join with "__".
        # Sanitizing the already-joined string also rewrote the "__" separator
        # to "_". Distinct backend lists could then collide (["a", "b-c"] and
        # ["a-b", "c"] both became "a_b_c") and overwrite each other's files.
        flattened_target_backends = "__".join(
            re.sub("[^0-9a-zA-Z_]+", "_", backend).strip("_")
            for backend in target_backends)

        # Save the input MLIR module. MLIR text is written as UTF-8 rather than
        # in the platform's locale encoding.
        if global_debug_dir:
            mlir_path = os.path.join(global_debug_dir,
                                     "raw_%s.mlir" % flattened_target_backends)
            logging.info("Saving raw TF input MLIR to: %s", mlir_path)
            with open(mlir_path, "w", encoding="utf-8") as f:
                f.write(compiler_module.to_asm())

        # Now run the passes manually that tf_load_saved_model would usually do.
        compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)

        if global_debug_dir:
            mlir_path = os.path.join(
                global_debug_dir, "input_%s.mlir" % flattened_target_backends)
            logging.info("Saving IREE input MLIR to: %s", mlir_path)
            with open(mlir_path, "w", encoding="utf-8") as f:
                f.write(compiler_module.to_asm())

        compiled_module = compiler_module.compile(
            target_backends=target_backends)
        if global_debug_dir:
            compiled_path = os.path.join(
                global_debug_dir,
                "compiled_%s.vmfb" % flattened_target_backends)
            logging.info("Saving compiled IREE module to: %s", compiled_path)
            with open(compiled_path, "wb") as f:
                f.write(compiled_module)

        return compiled_module