def __init__(self,
             constructor,
             backend_info,
             exported_names,
             artifacts_dir,
             _from_existing_args=None):
    """Build an IREE compiled module; prefer `compile` or `from_existing`.

    There are two construction paths:

    * `_from_existing_args` given: it is unpacked as a
      `(module_blob, vm_module, config)` triple and reused unchanged.
    * `_from_existing_args` is None: `constructor()` is instantiated and
      compiled for `backend_info.iree_compiler_targets`. The resulting blob is
      loaded as an `rt.VmModule`, and an `rt.Config` is created for
      `backend_info.iree_driver`.

    Either way, a fresh `rt.SystemContext` is created over the module.

    Args:
      constructor: zero-argument callable; its result is compiled as the
        tf module (compile path only).
      backend_info: backend description; `iree_compiler_targets` and
        `iree_driver` are read on the compile path.
      exported_names: names forwarded to the compiler.
      artifacts_dir: directory forwarded to the compiler.
      _from_existing_args: internal; a (blob, module, config) triple.
    """
    super().__init__(constructor, backend_info, exported_names, artifacts_dir)

    if _from_existing_args is not None:
        # Reached via IreeCompiledModule.from_existing(module): reuse its state.
        self._module_blob, self._module, self._config = _from_existing_args
    else:
        # Reached via IreeCompiledModule.compile(...): compile from scratch.
        self._module_blob = compile_tf_module(
            tf_module=constructor(),
            target_backends=backend_info.iree_compiler_targets,
            exported_names=exported_names,
            artifacts_dir=artifacts_dir)
        self._module = rt.VmModule.from_flatbuffer(self._module_blob)
        self._config = rt.Config(driver_name=backend_info.iree_driver)

    # The system context owns all of the module's mutable state.
    self._context = rt.SystemContext(modules=[self._module],
                                     config=self._config)
def create_from_instance(cls,
                         module_instance: tf.Module,
                         backend_info: "BackendInfo",
                         exported_names: Sequence[str] = (),
                         artifacts_dir: str = None):
    """Compile an already-constructed tf.Module for `backend_info`'s backend.

    Args:
      module_instance: the tf.Module instance to compile. Its class name
        becomes the compiled module's name.
      backend_info: BackendInfo describing the compilation target. Its
        `driver` attribute selects the runtime driver.
      exported_names: optional names of the methods to keep.
      artifacts_dir: optional directory for compilation artifacts. When it is
        omitted, no artifacts are saved.

    Returns:
      `cls(module_name, backend_info, compiled_paths, vm_module, config)`.
      `compiled_paths` is None when the compiler reports no compiled path.
      Otherwise it is a defaultdict mapping every method name to that single
      path.
    """
    blob, path = _incrementally_compile_tf_module(
        module=module_instance,
        backend_info=backend_info,
        exported_names=exported_names,
        artifacts_dir=artifacts_dir)

    # IREE bundles every compiled method into the same compiled module, so all
    # method names share one artifact path.
    paths_by_method = (None if path is None else
                       collections.defaultdict(lambda: path))

    return cls(type(module_instance).__name__,
               backend_info,
               paths_by_method,
               rt.VmModule.from_flatbuffer(blob),
               rt.Config(driver_name=backend_info.driver))
def __init__(self, backend, iree_module_blob, iree_module):
    """Wrap a loaded IREE VM module together with a fresh runtime context.

    Args:
      backend: backend description; its `iree_driver` attribute selects the
        runtime driver.
      iree_module_blob: stored unchanged. Presumably this is the compiled
        blob that `iree_module` was loaded from (confirm with callers).
      iree_module: the loaded VM module; its `name` is cached.
    """
    self._backend = backend
    self._iree_module_blob = iree_module_blob
    self._iree_module = iree_module
    self._iree_module_name = iree_module.name

    driver_config = rt.Config(driver_name=backend.iree_driver)
    self._system_config = driver_config
    self._context = rt.SystemContext(modules=[iree_module],
                                     config=driver_config)
def _get_default_config_for_driver(driver_name):
    """Return the shared IREE runtime config for `driver_name`.

    Configs are memoized in `_driver_name_to_config_dict`, and the lookup and
    insert both happen while holding `_driver_name_to_config_lock`. As a
    result, this function constructs at most one `iree_runtime.Config` per
    driver name.

    Args:
      driver_name: name of the driver; checked to be a `str`.

    Returns:
      The cached or newly created `iree_runtime.Config` for this driver.

    Raises:
      Whatever `py_typecheck.check_type` raises when `driver_name` is not a
      `str` (presumably a TypeError).
    """
    # TODO(b/153499219): Upstream this to IREE in some form (we won't block on
    # this here, though, since upstreaming well would mean yanking the constructor
    # for `iree_runtime.Config` and updating all existing call sites that use it).
    py_typecheck.check_type(driver_name, str)
    with _driver_name_to_config_lock:
        cached = _driver_name_to_config_dict.get(driver_name)
        if cached is not None:
            return cached
        # First request for this driver: create and remember its config.
        fresh = iree_runtime.Config(driver_name=driver_name)
        _driver_name_to_config_dict[driver_name] = fresh
        return fresh
def __init__(self,
             module_class: Type[tf.Module],
             backend_info: "BackendInfo",
             exported_names: Sequence[str] = (),
             artifacts_dir: str = None,
             _create_reinitialized_dict: Dict[str, Any] = None):
    """Compile a tf.Module subclass for the backend in `backend_info`.

    Args:
      module_class: the tf.Module subclass to instantiate and compile.
      backend_info: the BackendInfo for the IREE backend to target. Its
        `driver` attribute selects the runtime driver.
      exported_names: optional names of `module_class`'s functions to
        compile. When empty, all functions are compiled.
      artifacts_dir: optional directory to save compilation artifacts in.
      _create_reinitialized_dict: internal. When given, it supplies the
        "_module_blob", "_module", "_config" and "compiled_path" entries, and
        compilation is skipped.
    """
    super().__init__(module_class, backend_info, exported_names, artifacts_dir)

    if _create_reinitialized_dict is not None:
        # Called from self.create_reinitialized(): reuse the compiled state.
        state = _create_reinitialized_dict
        self._module_blob = state["_module_blob"]
        self._module = state["_module"]
        self._config = state["_config"]
        self.compiled_path = state["compiled_path"]
    else:
        # Seed before instantiating module_class. Presumably this keeps
        # variable initialization reproducible; confirm against set_random_seed.
        set_random_seed()
        self._module_blob, self.compiled_path = compile_tf_module(
            tf_module=module_class(),
            backend_infos=[backend_info],
            exported_names=exported_names,
            artifacts_dir=artifacts_dir)
        self._module = rt.VmModule.from_flatbuffer(self._module_blob)
        self._config = rt.Config(driver_name=backend_info.driver)

    # The system context holds all of the module's mutable state.
    self._context = rt.SystemContext(modules=[self._module],
                                     config=self._config)
def create_from_signature_def_saved_model(cls,
                                          saved_model_dir: str,
                                          saved_model_tags: Set[str],
                                          module_name: str,
                                          backend_info: "BackendInfo",
                                          exported_name: str,
                                          input_names: Sequence[str],
                                          output_names: Sequence[str],
                                          artifacts_dir: str = None):
    """Compile one SignatureDef of a SavedModel for `backend_info`'s backend.

    Args:
      saved_model_dir: directory containing the SavedModel.
      saved_model_tags: set of tags used when loading the SavedModel
        (required here; there is no default).
      module_name: name to give the compiled module.
      backend_info: BackendInfo describing the compilation target. Its
        `driver` attribute selects the runtime driver.
      exported_name: the signature of the SavedModel to compile.
      input_names: accepted for interface compatibility; unused here.
      output_names: accepted for interface compatibility; unused here.
      artifacts_dir: optional directory for compilation artifacts. When it is
        omitted, no artifacts are saved.

    Returns:
      `cls(module_name, backend_info, compiled_paths, vm_module, config)`.
      `compiled_paths` is None when the compiler reports no compiled path.
      Otherwise it is a defaultdict mapping every method name to that single
      path.
    """
    del input_names, output_names  # Unused.

    blob, path = _incrementally_compile_tf_signature_def_saved_model(
        saved_model_dir, saved_model_tags, backend_info, exported_name,
        artifacts_dir)

    paths_by_method = None
    if path is not None:
        # IREE bundles every compiled method into the same compiled module.
        paths_by_method = collections.defaultdict(lambda: path)

    return cls(module_name,
               backend_info,
               paths_by_method,
               rt.VmModule.from_flatbuffer(blob),
               rt.Config(driver_name=backend_info.driver))
def test_subsequent_driver(self):
    """An unknown leading driver is skipped in favor of a later one.

    "nothere1" is not a real driver. The test passes as long as creating the
    config with "nothere1,vmla" does not raise, i.e. the runtime falls through
    to "vmla".

    NOTE(review): a test with this same name (using "interpreter") may exist
    elsewhere. If both are defined in one class, one silently shadows the
    other; confirm and rename if so.
    """
    # Only success matters here; the returned config is not inspected.
    rt.Config("nothere1,vmla")
def test_non_existing_driver(self):
    """Config creation fails when none of the listed drivers exist."""
    with self.assertRaisesRegex(RuntimeError,
                                "Could not create any requested driver"):
        # The call is expected to raise; no result is ever bound.
        rt.Config("nothere1,nothere2")
def test_subsequent_driver(self):
    """An unknown leading driver is skipped in favor of a later one.

    "nothere1" is not a real driver. The test passes as long as creating the
    config with "nothere1,interpreter" does not raise, i.e. the runtime falls
    through to "interpreter".

    NOTE(review): a test with this same name (using "vmla") may exist
    elsewhere. If both are defined in one class, one silently shadows the
    other; confirm and rename if so.
    """
    # Only success matters here; the returned config is not inspected.
    rt.Config("nothere1,interpreter")
def __init__(self, function, driver: str, **options):
    """Pair a function with an IREE runtime config for `driver`.

    Args:
      function: the wrapped function; stored as-is.
      driver: passed positionally to `rt.Config`. Elsewhere that constructor
        also accepts a comma-separated fallback list of drivers.
      **options: extra keyword options, kept as a dict.
    """
    self._function = function
    self._driver_config = rt.Config(driver)
    self._options = options
    # Starts empty; presumably filled lazily by other methods of this class.
    self._memoized_signatures = dict()