def __init__(self, constructor, backend_info, exported_names, artifacts_dir,
             _from_existing_args=None):
    """Default constructor – use `compile` or `from_existing` instead.

    Args:
      constructor: zero-argument callable; its return value is handed to
        `compile_tf_module` as the module to compile.
      backend_info: supplies `iree_compiler_targets` (compile targets) and
        `iree_driver` (runtime driver name).
      exported_names: forwarded to the base class and to `compile_tf_module`.
      artifacts_dir: forwarded to the base class and to `compile_tf_module`.
      _from_existing_args: internal. When given, a 3-tuple of
        (module blob, VM module, runtime config) reused instead of compiling.
    """
    super().__init__(constructor, backend_info, exported_names, artifacts_dir)

    if _from_existing_args is not None:
        # Reached via IreeCompiledModule.from_existing(module): adopt the
        # already-built blob, VM module and config unchanged.
        (self._module_blob,
         self._module,
         self._config) = _from_existing_args
    else:
        # Reached via IreeCompiledModule.compile(...): build everything fresh.
        tf_module = constructor()
        self._module_blob = compile_tf_module(
            tf_module=tf_module,
            target_backends=backend_info.iree_compiler_targets,
            exported_names=exported_names,
            artifacts_dir=artifacts_dir)
        self._module = rt.VmModule.from_flatbuffer(self._module_blob)
        self._config = rt.Config(driver_name=backend_info.iree_driver)

    # The SystemContext holds all of the module's mutable state.
    self._context = rt.SystemContext(modules=[self._module],
                                     config=self._config)
def __init__(self, backend, iree_module_blob, iree_module):
    """Stores an already-loaded IREE VM module and opens a context for it.

    Args:
      backend: object whose `iree_driver` attribute selects the runtime driver.
      iree_module_blob: stored unchanged as `_iree_module_blob`.
      iree_module: the VM module registered in the new SystemContext; its
        `name` is cached as `_iree_module_name`.
    """
    self._backend, self._iree_module_blob, self._iree_module = (
        backend, iree_module_blob, iree_module)
    self._iree_module_name = iree_module.name

    config = rt.Config(driver_name=backend.iree_driver)
    self._system_config = config
    self._context = rt.SystemContext(modules=[iree_module], config=config)
def test_static_invoke(self):
    """Invokes `simple_mul` from a module added to a dynamic context."""
    context = rt.SystemContext()
    self.assertTrue(context.is_dynamic)
    context.add_module(create_simple_mul_module())

    arithmetic = context.modules.arithmetic
    self.assertEqual(arithmetic.name, "arithmetic")
    simple_mul = arithmetic["simple_mul"]

    lhs = np.array([1., 2., 3., 4.], dtype=np.float32)
    rhs = np.array([4., 5., 6., 7.], dtype=np.float32)
    product = simple_mul(lhs, rhs)
    # Element-wise product of lhs and rhs.
    np.testing.assert_allclose(product, [4., 10., 18., 28.])
def test_custom_dynamic(self):
    """Checks the signature text in repr() of a dynamically loaded function.

    The expected signature is a literal substring, so `assertIn` is used
    instead of `assertRegex` + `re.escape`. On failure, `assertIn` already
    reports the full repr, so the old debug `print` is not needed.
    """
    ctx = rt.SystemContext()
    self.assertTrue(ctx.is_dynamic)
    ctx.add_module(create_simple_mul_module())
    self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
    f = ctx.modules.arithmetic["simple_mul"]
    self.assertIn(
        "(Buffer<float32[4]>, Buffer<float32[4]>) -> (Buffer<float32[4]>)",
        repr(f))
def test_serialize_values(self):
    """Checks both the result and the serialized inputs/outputs of a call."""
    ctx = rt.SystemContext()
    self.assertTrue(ctx.is_dynamic)
    ctx.add_module(create_simple_mul_module())
    self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
    f = ctx.modules.arithmetic["simple_mul"]
    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
    # Invoke once before querying serialized values (presumably they record
    # the most recent call). The previous version bound `results` but never
    # checked it; verify it matches the serialized output below.
    results = f(arg0, arg1)
    np.testing.assert_allclose(results, [4., 10., 18., 28.])
    inputs, outputs = f.get_serialized_values()
    self.assertEqual(inputs, ("4xf32=1 2 3 4", "4xf32=4 5 6 7"))
    self.assertEqual(outputs, ("4xf32=4 10 18 28", ))
def __call__(self, *args, **kwargs):
    """Runs the wrapped function, forwarding all arguments unchanged.

    A new SystemContext is built (from `self._config`) and `self._vm_module`
    is added to it on every invocation. The function named
    `self._function_name` is then looked up on that context's `module`
    entry and called.

    Args:
      *args: Positional arguments for the wrapped function.
      **kwargs: Keyword arguments for the wrapped function.

    Returns:
      Whatever the wrapped function returns.
    """
    # Context creation should cost on the order of milliseconds or less, so
    # one context per call is cheap enough. It keeps things simple while IREE
    # lacks true local variables, and it avoids any kind of global state.
    fresh_context = iree_runtime.SystemContext(config=self._config)
    fresh_context.add_module(self._vm_module)
    bound_module = fresh_context.modules.module
    target = getattr(bound_module, self._function_name)
    return target(*args, **kwargs)
def __init__(self,
             module_class: Type[tf.Module],
             backend_info: "BackendInfo",
             exported_names: Sequence[str] = (),
             artifacts_dir: str = None,
             _create_reinitialized_dict: Dict[str, Any] = None):
    """Compiles a tf.Module to the target backend described by backend_info.

    Args:
      module_class: the tf.Module subclass to compile.
      backend_info: an element of BackendInfo naming the IREE backend; its
        `driver` attribute selects the runtime driver.
      exported_names: optional iterable of names of module_class's functions
        to compile. When empty, all functions are compiled.
      artifacts_dir: optional path where compilation artifacts are saved.
      _create_reinitialized_dict: internal. When given, must provide the keys
        "_module_blob", "_module", "_config" and "compiled_path"; these are
        reused instead of recompiling.
    """
    super().__init__(module_class, backend_info, exported_names, artifacts_dir)

    if _create_reinitialized_dict is not None:
        # Called from self.create_reinitialized(): reuse the compiled state.
        saved = _create_reinitialized_dict
        self._module_blob = saved["_module_blob"]
        self._module = saved["_module"]
        self._config = saved["_config"]
        self.compiled_path = saved["compiled_path"]
    else:
        # Seed before instantiating module_class, presumably so any randomly
        # initialized state in its __init__ is reproducible.
        set_random_seed()
        tf_module = module_class()
        self._module_blob, self.compiled_path = compile_tf_module(
            tf_module=tf_module,
            backend_infos=[backend_info],
            exported_names=exported_names,
            artifacts_dir=artifacts_dir)
        self._module = rt.VmModule.from_flatbuffer(self._module_blob)
        self._config = rt.Config(driver_name=backend_info.driver)

    # The SystemContext holds all of the module's mutable state.
    self._context = rt.SystemContext(modules=[self._module],
                                     config=self._config)
def reinitialize(self):
    """Resets all stateful variables by replacing the SystemContext.

    No seeding happens here: model_class.__init__ is not called again, so
    set_random_seed is unnecessary.
    """
    modules = [self._vm_module]
    self._context = rt.SystemContext(modules=modules, config=self._config)
def test_duplicate_module(self):
    """Adding a second module with an already-registered name must fail."""
    context = rt.SystemContext()
    self.assertTrue(context.is_dynamic)
    add_module = context.add_module
    add_module(create_simple_mul_module())
    # A second copy registers under the same name ("arithmetic"); expect a
    # ValueError whose message mentions it.
    with self.assertRaisesRegex(ValueError, "arithmetic"):
        add_module(create_simple_mul_module())
def test_empty_static(self):
    """An explicit empty module tuple yields a static context with 'hal'."""
    context = rt.SystemContext(modules=())
    self.assertFalse(context.is_dynamic)
    loaded = context.modules
    self.assertIn("hal", loaded)
    self.assertEqual(loaded.hal.name, "hal")
def test_empty_dynamic(self):
    """A context created without modules is dynamic and still exposes 'hal'."""
    context = rt.SystemContext()
    self.assertTrue(context.is_dynamic)
    loaded = context.modules
    self.assertIn("hal", loaded)
    self.assertEqual(loaded.hal.name, "hal")