Exemplo n.º 1
0
    def __init__(self,
                 constructor,
                 backend_info,
                 exported_names,
                 artifacts_dir,
                 _from_existing_args=None):
        """Default constructor – use `compile` or `from_existing` instead.

        Args:
          constructor: zero-argument callable; its result is passed to
            `compile_tf_module` as the module to compile.
          backend_info: provides `iree_compiler_targets` (compilation targets)
            and `iree_driver` (runtime driver name).
          exported_names: names to keep, forwarded to `compile_tf_module`.
          artifacts_dir: forwarded to `compile_tf_module`.
          _from_existing_args: internal; a `(module_blob, vm_module, config)`
            tuple that skips compilation when provided.
        """
        super().__init__(constructor, backend_info, exported_names,
                         artifacts_dir)

        if _from_existing_args is not None:
            # Reached via IreeCompiledModule.from_existing(module): reuse the
            # already-built artifacts instead of recompiling.
            blob, vm_module, runtime_config = _from_existing_args
        else:
            # Reached via IreeCompiledModule.compile(...): compile from scratch.
            blob = compile_tf_module(
                tf_module=constructor(),
                target_backends=backend_info.iree_compiler_targets,
                exported_names=exported_names,
                artifacts_dir=artifacts_dir)
            vm_module = rt.VmModule.from_flatbuffer(blob)
            runtime_config = rt.Config(driver_name=backend_info.iree_driver)

        self._module_blob = blob
        self._module = vm_module
        self._config = runtime_config

        # The SystemContext holds all of the module's mutable state.
        self._context = rt.SystemContext(modules=[self._module],
                                         config=self._config)
Exemplo n.º 2
0
    def create_from_instance(cls,
                             module_instance: tf.Module,
                             backend_info: "BackendInfo",
                             exported_names: Sequence[str] = (),
                             artifacts_dir: str = None):
        """Compile a tf.Module instance to the target backend in backend_info.

        Args:
          module_instance: The tf.Module instance to compile.
          backend_info: BackendInfo with the details for compiling module to
            IREE; its `driver` attribute selects the runtime driver.
          exported_names: Optional sequence representing the exported names to
            keep.
          artifacts_dir: An optional string pointing to where compilation
            artifacts should be saved. No compilation artifacts will be saved if
            this is not provided.

        Returns:
          A `cls` instance named after the type of `module_instance`.
        """
        blob, path = _incrementally_compile_tf_module(
            module=module_instance,
            backend_info=backend_info,
            exported_names=exported_names,
            artifacts_dir=artifacts_dir)
        vm_module = rt.VmModule.from_flatbuffer(blob)
        runtime_config = rt.Config(driver_name=backend_info.driver)

        # IREE bundles every compiled method into the same compiled module, so
        # every method name resolves to that single compiled path.
        compiled_paths = (collections.defaultdict(lambda: path)
                          if path is not None else None)

        return cls(type(module_instance).__name__, backend_info,
                   compiled_paths, vm_module, runtime_config)
Exemplo n.º 3
0
    def __init__(self, backend, iree_module_blob, iree_module):
        """Wrap a loaded IREE VM module in its own runtime SystemContext.

        Args:
          backend: object whose `iree_driver` attribute names the runtime
            driver used to build the `rt.Config`.
          iree_module_blob: the serialized module blob; stored as-is.
          iree_module: the loaded VM module; its `name` is cached on self.
        """
        self._backend = backend
        self._iree_module_blob = iree_module_blob
        self._iree_module = iree_module
        self._iree_module_name = iree_module.name

        driver = backend.iree_driver
        self._system_config = rt.Config(driver_name=driver)
        self._context = rt.SystemContext(modules=[iree_module],
                                         config=self._system_config)
Exemplo n.º 4
0
def _get_default_config_for_driver(driver_name):
    """Returns an IREE runtime config for the given driver.

    Enforces that there is always at most one config per driver: the first
    call for a driver name creates the config and caches it in a module-level
    dict, and later calls return the cached object. Lookup and insertion both
    happen while holding `_driver_name_to_config_lock`.

    Args:
      driver_name: A string that represents the name of the driver.

    Returns:
      An instance of `iree_runtime.Config` for this driver.
    """
    # TODO(b/153499219): Upstream this to IREE in some form (we won't block on
    # this here, though, since upstreaming well would mean yanking the constructor
    # for `iree_runtime.Config` and updating all existing call sites that use it).
    py_typecheck.check_type(driver_name, str)
    with _driver_name_to_config_lock:
        cached = _driver_name_to_config_dict.get(driver_name)
        if cached is not None:
            return cached
        fresh = iree_runtime.Config(driver_name=driver_name)
        _driver_name_to_config_dict[driver_name] = fresh
        return fresh
Exemplo n.º 5
0
    def __init__(self,
                 module_class: Type[tf.Module],
                 backend_info: "BackendInfo",
                 exported_names: Sequence[str] = (),
                 artifacts_dir: str = None,
                 _create_reinitialized_dict: Dict[str, Any] = None):
        """Compile a tf.Module to the target backend in backend_info.

        Args:
          module_class: the tf.Module subclass to compile; it is instantiated
            with no arguments after the random seed is set.
          backend_info: an element of BackendInfo corresponding to the IREE
            backend to compile to; its `driver` selects the runtime driver.
          exported_names: an optional iterable of strings representing which of
            the module_class's functions to compile. If exported_names is empty
            all functions will be compiled.
          artifacts_dir: an optional path to save compilation artifacts to.
          _create_reinitialized_dict: used internally by create_reinitialized();
            must hold the keys "_module_blob", "_module", "_config" and
            "compiled_path", which are copied onto the new instance.
        """
        super().__init__(module_class, backend_info, exported_names,
                         artifacts_dir)

        if _create_reinitialized_dict is not None:
            # Called from self.create_reinitialized(): reuse compiled state.
            for attr in ("_module_blob", "_module", "_config",
                         "compiled_path"):
                setattr(self, attr, _create_reinitialized_dict[attr])
        else:
            set_random_seed()
            blob, path = compile_tf_module(
                tf_module=module_class(),
                backend_infos=[backend_info],
                exported_names=exported_names,
                artifacts_dir=artifacts_dir)
            self._module_blob = blob
            self.compiled_path = path
            self._module = rt.VmModule.from_flatbuffer(blob)
            self._config = rt.Config(driver_name=backend_info.driver)

        # The SystemContext holds all of the module's mutable state.
        self._context = rt.SystemContext(modules=[self._module],
                                         config=self._config)
Exemplo n.º 6
0
    def create_from_signature_def_saved_model(cls,
                                              saved_model_dir: str,
                                              saved_model_tags: Set[str],
                                              module_name: str,
                                              backend_info: "BackendInfo",
                                              exported_name: str,
                                              input_names: Sequence[str],
                                              output_names: Sequence[str],
                                              artifacts_dir: str = None):
        """Compile a SignatureDef SavedModel to the target backend in backend_info.

        Args:
          saved_model_dir: Directory of the saved model.
          saved_model_tags: Set of tags to use when loading the model.
          module_name: A name for this compiled module.
          backend_info: BackendInfo with the details for compiling the saved
            model; its `driver` selects the runtime driver.
          exported_name: A str representing the signature on the saved model to
            compile.
          input_names: A sequence of kwargs to feed to the saved model. Ignored
            by this implementation.
          output_names: A sequence of named outputs to extract from the saved
            model. Ignored by this implementation.
          artifacts_dir: An optional string pointing to where compilation
            artifacts should be saved. No compilation artifacts will be saved if
            this is not provided.

        Returns:
          A `cls` instance wrapping the compiled VM module.
        """
        del input_names, output_names  # Unused by this backend.

        blob, path = _incrementally_compile_tf_signature_def_saved_model(
            saved_model_dir, saved_model_tags, backend_info, exported_name,
            artifacts_dir)
        vm_module = rt.VmModule.from_flatbuffer(blob)
        runtime_config = rt.Config(driver_name=backend_info.driver)

        # IREE bundles every compiled method into the same compiled module, so
        # any method name maps to the one compiled path.
        compiled_paths = (collections.defaultdict(lambda: path)
                          if path is not None else None)

        return cls(module_name, backend_info, compiled_paths, vm_module,
                   runtime_config)
Exemplo n.º 7
0
 def test_subsequent_driver(self):
     """rt.Config should succeed when a later driver in the list is usable.

     "nothere1" is presumably an unavailable driver name; construction is
     expected to fall through to "vmla". The test passes if no exception is
     raised. The result is deliberately not bound to a name (it was unused).
     """
     rt.Config("nothere1,vmla")
Exemplo n.º 8
0
 def test_non_existing_driver(self):
     """rt.Config raises RuntimeError when none of the listed drivers exist."""
     with self.assertRaisesRegex(RuntimeError,
                                 "Could not create any requested driver"):
         # Only the raised exception matters; no result is bound.
         rt.Config("nothere1,nothere2")
Exemplo n.º 9
0
 def test_subsequent_driver(self):
   """rt.Config should succeed when a later driver in the list is usable.

   "nothere1" is presumably an unavailable driver name; construction is
   expected to fall through to "interpreter". The test passes if no exception
   is raised. The result is deliberately not bound to a name (it was unused).
   """
   rt.Config("nothere1,interpreter")
Exemplo n.º 10
0
 def __init__(self, function, driver: str, **options):
     """Store the wrapped function and options, and build a driver config.

     Args:
       function: the callable being wrapped; stored as-is.
       driver: driver name passed positionally to `rt.Config`.
       **options: extra keyword options, kept as a dict on the instance.
     """
     self._function = function
     config = rt.Config(driver)
     self._driver_config = config
     self._options = options
     # Starts empty; presumably filled per call signature by other methods
     # (not visible here) — TODO confirm.
     self._memoized_signatures = {}