Esempio n. 1
0
    def __init__(self,
                 constructor,
                 backend_info,
                 exported_names,
                 artifacts_dir,
                 _from_existing_args=None):
        """Default constructor – use `compile` or `from_existing` instead.

        Args:
          constructor: zero-argument callable whose result is handed to
            `compile_tf_module`; only invoked when `_from_existing_args` is
            None.
          backend_info: backend description; `iree_compiler_targets` and
            `iree_driver` are read from it when compiling.
          exported_names: forwarded to the base class and `compile_tf_module`.
          artifacts_dir: forwarded to the base class and `compile_tf_module`.
          _from_existing_args: internal; a 3-tuple of
            (module blob, VM module, runtime config) to reuse as-is.
        """
        super().__init__(constructor, backend_info, exported_names,
                         artifacts_dir)

        if _from_existing_args is not None:
            # Reached via IreeCompiledModule.from_existing(module): reuse the
            # previously compiled artifacts instead of compiling again.
            (self._module_blob, self._module,
             self._config) = _from_existing_args
        else:
            # Reached via IreeCompiledModule.compile(...).
            tf_module = constructor()
            self._module_blob = compile_tf_module(
                tf_module=tf_module,
                target_backends=backend_info.iree_compiler_targets,
                exported_names=exported_names,
                artifacts_dir=artifacts_dir)
            self._module = rt.VmModule.from_flatbuffer(self._module_blob)
            driver = backend_info.iree_driver
            self._config = rt.Config(driver_name=driver)

        # The SystemContext holds all of the module's mutable state.
        loaded_modules = [self._module]
        self._context = rt.SystemContext(modules=loaded_modules,
                                         config=self._config)
Esempio n. 2
0
    def __init__(self, backend, iree_module_blob, iree_module):
        """Wraps an already-loaded IREE VM module in a new SystemContext.

        Args:
          backend: object whose `iree_driver` attribute selects the runtime
            driver for the `rt.Config`.
          iree_module_blob: the serialized module; stored unchanged.
          iree_module: the loaded VM module; its `name` is cached.
        """
        self._backend = backend
        self._iree_module_blob = iree_module_blob
        self._iree_module = iree_module
        self._iree_module_name = self._iree_module.name

        driver = backend.iree_driver
        self._system_config = rt.Config(driver_name=driver)
        # The context is built over exactly this one module.
        self._context = rt.SystemContext(
            modules=[self._iree_module],
            config=self._system_config,
        )
Esempio n. 3
0
 def test_static_invoke(self):
     """Calls `simple_mul` on a module added to a context and checks results.

     Note: despite the test name, the context under test is dynamic (it is
     created without a module list and `is_dynamic` is asserted True).
     """
     context = rt.SystemContext()
     self.assertTrue(context.is_dynamic)
     context.add_module(create_simple_mul_module())
     self.assertEqual(context.modules.arithmetic.name, "arithmetic")
     simple_mul = context.modules.arithmetic["simple_mul"]
     lhs = np.array([1., 2., 3., 4.], dtype=np.float32)
     rhs = np.array([4., 5., 6., 7.], dtype=np.float32)
     product = simple_mul(lhs, rhs)
     np.testing.assert_allclose(product, [4., 10., 18., 28.])
Esempio n. 4
0
 def test_custom_dynamic(self):
   """Checks that a module function's repr shows its buffer signature."""
   ctx = rt.SystemContext()
   self.assertTrue(ctx.is_dynamic)
   ctx.add_module(create_simple_mul_module())
   self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
   f = ctx.modules.arithmetic["simple_mul"]
   # A plain substring check: equivalent to assertRegex with re.escape, but
   # states the intent directly and reports both strings on failure. The
   # debug print of the repr has been dropped to keep test output clean.
   self.assertIn(
       "(Buffer<float32[4]>, Buffer<float32[4]>) -> (Buffer<float32[4]>)",
       repr(f))
Esempio n. 5
0
 def test_serialize_values(self):
     """Checks that a call's inputs and outputs are available serialized."""
     ctx = rt.SystemContext()
     self.assertTrue(ctx.is_dynamic)
     ctx.add_module(create_simple_mul_module())
     self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
     f = ctx.modules.arithmetic["simple_mul"]
     arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
     arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
     # The call has to happen before get_serialized_values(); presumably it
     # reports the most recent invocation — the assertions below expect
     # exactly these arguments and this result.
     results = f(arg0, arg1)
     # Previously `results` was computed but never checked; verify it too.
     np.testing.assert_allclose(results, [4., 10., 18., 28.])
     inputs, outputs = f.get_serialized_values()
     self.assertEqual(inputs, ("4xf32=1 2 3 4", "4xf32=4 5 6 7"))
     self.assertEqual(outputs, ("4xf32=4 10 18 28", ))
Esempio n. 6
0
    def __call__(self, *args, **kwargs):
        """Invokes this callable with the given set of arguments.

        A new SystemContext is created on every call, so state held by the
        context does not persist from one invocation to the next.

        Args:
          *args: Positional arguments.
          **kwargs: Keyword arguments.

        Returns:
          The result of the call.
        """
        # Context creation can be expected to be on the order of milliseconds or
        # less, so constructing one per call should be cheap enough, and can make
        # things simpler while we look for ways to support true local variables
        # in IREE and eliminate any kind of global state.
        fresh_context = iree_runtime.SystemContext(config=self._config)
        fresh_context.add_module(self._vm_module)
        # The VM module is looked up under the attribute name `module`.
        loaded = fresh_context.modules.module
        target_fn = getattr(loaded, self._function_name)
        return target_fn(*args, **kwargs)
Esempio n. 7
0
    def __init__(self,
                 module_class: Type[tf.Module],
                 backend_info: "BackendInfo",
                 exported_names: Sequence[str] = (),
                 artifacts_dir: str = None,
                 _create_reinitialized_dict: Dict[str, Any] = None):
        """Compile a tf.Module to the target backend in backend_info.

        Args:
          module_class: the tf.Module subclass to compile.
          backend_info: an element of BackendInfo corresponding to the IREE
            backend to compile to.
          exported_names: an optional iterable of strings representing which
            of the module_class's functions to compile. If exported_names is
            empty all functions will be compiled.
          artifacts_dir: an optional path to save compilation artifacts to.
            NOTE(review): annotated `str` but defaults to None; Optional[str]
            would be accurate.
          _create_reinitialized_dict: used internally; when given, its
            "_module_blob", "_module", "_config" and "compiled_path" entries
            are reused instead of compiling.
        """
        super().__init__(module_class, backend_info, exported_names,
                         artifacts_dir)

        if _create_reinitialized_dict is not None:
            # Called from self.create_reinitialized(): reuse the existing
            # compiled artifacts.
            saved = _create_reinitialized_dict
            self._module_blob = saved["_module_blob"]
            self._module = saved["_module"]
            self._config = saved["_config"]
            self.compiled_path = saved["compiled_path"]
        else:
            # Seed before instantiating module_class — presumably so any
            # randomness in its construction is reproducible.
            set_random_seed()
            tf_module = module_class()
            self._module_blob, self.compiled_path = compile_tf_module(
                tf_module=tf_module,
                backend_infos=[backend_info],
                exported_names=exported_names,
                artifacts_dir=artifacts_dir)
            self._module = rt.VmModule.from_flatbuffer(self._module_blob)
            self._config = rt.Config(driver_name=backend_info.driver)

        # Holds all of the module's mutable state.
        self._context = rt.SystemContext(
            modules=[self._module],
            config=self._config,
        )
Esempio n. 8
0
 def reinitialize(self):
     """Reinitializes all stateful variables.

     A new SystemContext is built from the existing VM module and config,
     replacing the old one. set_random_seed is not needed here because
     model_class.__init__ is not called.
     """
     modules = [self._vm_module]
     self._context = rt.SystemContext(modules=modules, config=self._config)
Esempio n. 9
0
 def test_duplicate_module(self):
     """Adding a second module with an already-used name raises ValueError."""
     context = rt.SystemContext()
     self.assertTrue(context.is_dynamic)
     context.add_module(create_simple_mul_module())
     # A second copy of the same module must be rejected; the error message
     # is expected to mention the module name "arithmetic".
     with self.assertRaisesRegex(ValueError, "arithmetic"):
         context.add_module(create_simple_mul_module())
Esempio n. 10
0
 def test_empty_static(self):
     """An explicit empty module list gives a static context with "hal"."""
     static_ctx = rt.SystemContext(modules=())
     self.assertFalse(static_ctx.is_dynamic)
     available = static_ctx.modules
     self.assertIn("hal", available)
     self.assertEqual(available.hal.name, "hal")
Esempio n. 11
0
 def test_empty_dynamic(self):
     """A context created without a module list is dynamic and has "hal"."""
     dynamic_ctx = rt.SystemContext()
     self.assertTrue(dynamic_ctx.is_dynamic)
     available = dynamic_ctx.modules
     self.assertIn("hal", available)
     self.assertEqual(available.hal.name, "hal")