コード例 #1
0
ファイル: tf_test_utils.py プロジェクト: BernhardRiemann/iree
def compile_tf_module(
    module_class: Type[tf.Module], exported_names: Sequence[str] = ()
) -> "Modules":
    """Compiles module_class to each backend that we test.

    Args:
      module_class: the tf.Module subclass to compile.
      exported_names: optional iterable of strings representing which of
        module_class's functions to compile. If exported_names is empty all
        functions will be compiled.

    Returns:
      A 'Modules' namedtuple containing the reference module, target modules and
      artifacts directory.
    """
    # Setup the directory for saving compilation artifacts and traces.
    artifacts_dir = _setup_artifacts_dir(module_class.__name__)

    # Get the backend information for this test. The reference backend is given
    # a "_ref" suffixed name, presumably to distinguish it from a target
    # backend of the same kind.
    ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
                                            f"{FLAGS.reference_backend}_ref")
    tar_backend_infos = get_target_backends()

    def compile_backend(backend_info):
        # Every backend compiles the same class into the shared artifacts dir.
        return backend_info.compile_from_class(module_class, exported_names,
                                               artifacts_dir)

    ref_module = compile_backend(ref_backend_info)
    tar_modules = [
        compile_backend(backend_info) for backend_info in tar_backend_infos
    ]
    return Modules(ref_module, tar_modules, artifacts_dir)
コード例 #2
0
  def test_trace_serialize_and_load(self):
    """Checks that a serialized Trace round-trips through Trace.load."""

    def trace_function(module):
      module.increment()
      module.increment_by(np.array([81.], dtype=np.float32))
      module.increment_by_max(np.array([81], dtype=np.float32),
                              np.array([92], dtype=np.float32))
      module.get_count()

    module = tf_utils.IreeCompiledModule.create_from_class(
        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
    trace = tf_test_utils.Trace(module, trace_function)
    trace_function(tf_test_utils.TracedModule(module, trace))

    with tempfile.TemporaryDirectory() as artifacts_dir:
      trace_function_dir = tf_test_utils._get_trace_dir(artifacts_dir, trace)
      trace.serialize(trace_function_dir)
      self.assertTrue(
          os.path.exists(os.path.join(trace_function_dir, 'metadata.pkl')))
      loaded_trace = tf_test_utils.Trace.load(trace_function_dir)

      # Check all calls match.
      self.assertTrue(tf_test_utils.Trace.compare_traces(trace, loaded_trace))

      # Check all other metadata match. The attribute names are compared as
      # plain collections: assertAllEqual would wrap each dict_keys view in a
      # 0-d object ndarray, which hides what is compared and gives a poor
      # failure message.
      self.assertCountEqual(trace.__dict__.keys(),
                            loaded_trace.__dict__.keys())
      for key, value in trace.__dict__.items():
        if key != 'calls':
          self.assertEqual(value, loaded_trace.__dict__[key])
コード例 #3
0
  def test_nonmatching_inputs(self):
    """Traces calling the same method with different inputs must not match."""

    def tf_function(module):
      module.increment_by(np.array([42.], dtype=np.float32))

    def vmla_function(module):
      module.increment_by(np.array([22.], dtype=np.float32))

    # Record a trace on the TensorFlow backend.
    reference = tf_utils.TfCompiledModule.create_from_class(
        StatefulCountingModule, tf_utils.BackendInfo('tf'))
    reference_trace = tf_test_utils.Trace(reference, tf_function)
    tf_function(tf_test_utils.TracedModule(reference, reference_trace))

    # Record a trace on the IREE VMLA backend using a different input value.
    candidate = tf_utils.IreeCompiledModule.create_from_class(
        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
    candidate_trace = tf_test_utils.Trace(candidate, vmla_function)
    vmla_function(tf_test_utils.TracedModule(candidate, candidate_trace))

    traces_match = tf_test_utils.Trace.compare_traces(reference_trace,
                                                      candidate_trace)
    self.assertFalse(traces_match)
コード例 #4
0
    def test_unaltered_state(self, backend_name):
        """Checks the counter state and that reinitialize() resets it."""
        module = tf_utils.BackendInfo(backend_name).compile_from_class(
            StatefulCountingModule)

        # A freshly compiled module reports 0 and increment() moves it to 1.
        self.assertEqual([0.], module.get_count())
        module.increment()
        self.assertEqual([1.], module.get_count())

        # After reinitialization the count is back at 0.
        module.reinitialize()
        self.assertEqual([0.], module.get_count())
コード例 #5
0
    def test_random_initialization(self, backend_name):
        """Checks RandomInitModule is reproducible across compiles and reinits."""
        backend_info = tf_utils.BackendInfo(backend_name)

        # Two independent compilations must yield identical values.
        first = backend_info.compile_from_class(RandomInitModule)
        second = backend_info.compile_from_class(RandomInitModule)
        self.assertAllEqual(first.get(), second.get())

        # Reinitializing must reproduce the value seen before it.
        value_before = first.get()
        first.reinitialize()
        self.assertAllEqual(value_before, first.get())
コード例 #6
0
    def test_nonmatching_methods(self):
        """Comparing traces that call different methods raises ValueError."""
        def tf_function(module):
            module.increment()
            module.increment()

        def vmla_function(module):
            module.increment()
            module.decrement()

        def record(compiled_module, function):
            # Runs `function` against a traced wrapper and returns its Trace.
            recorded = tf_test_utils.Trace(compiled_module, function)
            function(tf_test_utils.TracedModule(compiled_module, recorded))
            return recorded

        reference_module = tf_utils.TfCompiledModule(
            StatefulCountingModule, tf_utils.BackendInfo('tf'))
        reference_trace = record(reference_module, tf_function)

        candidate_module = tf_utils.IreeCompiledModule(
            StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
        candidate_trace = record(candidate_module, vmla_function)

        with self.assertRaises(ValueError):
            tf_test_utils.Trace.compare_traces(reference_trace,
                                               candidate_trace)
コード例 #7
0
ファイル: tf_test_utils.py プロジェクト: navdeepkk/iree
    def setUp(self):
        """Compiles the reference module and one module per target backend."""
        # Runs before each unit test.
        super().setUp()
        reference_name = FLAGS.reference_backend
        self._ref_module = self._compile(
            tf_utils.BackendInfo(reference_name, f"{reference_name}_ref"))
        self._tar_modules = list(map(self._compile, get_target_backends()))
コード例 #8
0
    def test_unaltered_state(self, backend_name):
        """Checks create_reinitialized() yields a module with separate state."""
        compiled = tf_utils.BackendInfo(backend_name).compile(
            StatefulCountingModule)

        # The counter starts at 0 and increment() moves it to 1.
        self.assertEqual([0.], compiled.get_count())
        compiled.increment()
        self.assertEqual([1.], compiled.get_count())

        # A reinitialized copy starts over at 0...
        fresh = compiled.create_reinitialized()
        self.assertEqual([0.], fresh.get_count())
        # ...while the original keeps its incremented count.
        self.assertEqual([1.], compiled.get_count())
コード例 #9
0
  def test_artifact_saving(self):
    """Checks that incremental compilation writes the expected artifacts."""
    backend_info = tf_utils.BackendInfo('iree_vmla')
    with tempfile.TemporaryDirectory() as artifacts_dir:
      tf_module = ConstantModule()
      # Only the compiled artifact's path is needed; the module itself is not.
      _, compiled_path = tf_utils._incrementally_compile_tf_module(
          tf_module, backend_info=backend_info, artifacts_dir=artifacts_dir)

      artifacts_to_check = [
          'tf_input.mlir',
          'iree_input.mlir',
          compiled_path,
      ]
      for artifact in artifacts_to_check:
        artifact_path = os.path.join(artifacts_dir, artifact)
        logging.info('Checking path: %s', artifact_path)
        self.assertTrue(os.path.exists(artifact_path),
                        msg=f'Missing artifact: {artifact_path}')
コード例 #10
0
def compile_tf_signature_def_saved_model(
    saved_model_dir: str,
    saved_model_tags: Set[str],
    module_name: str,
    exported_name: str,
    input_names: Sequence[str],
    output_names: Sequence[str],
) -> "Modules":
    """Compiles a SignatureDef SavedModel to each backend that we test.

    The result is memoized in the module-level `_global_modules`. Once it has
    been set, every later call returns that same `Modules` tuple and IGNORES
    its arguments. NOTE(review): this presumably assumes one saved model per
    test process; confirm before calling it with different models.

    Args:
      saved_model_dir: Directory of the saved model.
      saved_model_tags: Set of tags to use when loading the model.
      module_name: A name for this compiled module; also used to set up the
        artifacts directory.
      exported_name: A str representing the signature on the saved model to
        compile.
      input_names: A sequence of kwargs to feed to the saved model.
      output_names: A sequence of named outputs to extract from the saved model.

    Returns:
      A 'Modules' namedtuple containing the reference module, target modules and
      artifacts directory.
    """
    global _global_modules
    if _global_modules is not None:
        return _global_modules

    # Setup the directory for saving compilation artifacts and traces.
    artifacts_dir = _setup_artifacts_dir(module_name)

    # Get the backend information for this test.
    ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
                                            f"{FLAGS.reference_backend}_ref")
    tar_backend_infos = get_target_backends()

    def compile_backend(backend_info):
        # Every backend compiles the same signature into the shared artifacts
        # directory.
        return backend_info.compile_signature_def_saved_model(
            saved_model_dir, saved_model_tags, module_name, exported_name,
            input_names, output_names, artifacts_dir)

    ref_module = compile_backend(ref_backend_info)
    tar_modules = [
        compile_backend(backend_info) for backend_info in tar_backend_infos
    ]
    _global_modules = Modules(ref_module, tar_modules, artifacts_dir)
    return _global_modules
コード例 #11
0
ファイル: tf_test_utils.py プロジェクト: sailfish009/iree
def get_target_backends() -> Sequence[tf_utils.BackendInfo]:
  """Returns the BackendInfo instances to compare against the reference.

  Every backend known to BackendInfo is used unless the `--target_backends`
  flag is given, in which case only the backends it names are used.

  Returns:
    Sequence of BackendInfo that should be used.
  """
  if FLAGS.target_backends is None:
    # Nothing requested explicitly: compare against every backend.
    return tf_utils.BackendInfo.get_all_backends()

  logging.info("Using backends from command line: %s", FLAGS.target_backends)
  backend_names, names = _parse_target_backends()
  return [
      tf_utils.BackendInfo(backend_name, name)
      for backend_name, name in zip(backend_names, names)
  ]
コード例 #12
0
ファイル: tf_test_utils.py プロジェクト: zhangys-lucky/iree
  def setUpClass(cls):
    """Compiles the class's module once for the reference and target backends."""
    # Runs once, before any of the unit tests.
    super().setUpClass()
    if cls._module_class is None:
      raise AttributeError(
          "setUpClass was called but no module was specified. Specify a module "
          "to compile via the @tf_test_utils.compile_module decorator.")

    # Directory that receives compilation artifacts and traces for this class.
    cls._artifacts_dir = _setup_artifacts_dir(cls._module_class.__name__)

    # One CompiledModule for the reference backend...
    reference_name = FLAGS.reference_backend
    cls._ref_module = cls._compile(
        tf_utils.BackendInfo(reference_name, f"{reference_name}_ref"))
    # ...and one for each target backend.
    cls._tar_modules = [cls._compile(info) for info in get_target_backends()]
コード例 #13
0
    def test_trace_inputs_and_outputs(self):
        """Checks that Trace records call inputs and outputs as tuples."""
        def trace_function(module):
            # Neither inputs nor outputs.
            module.increment()
            # Inputs only.
            module.increment_by(np.array([81.], dtype=np.float32))
            # Outputs only.
            module.get_count()

        module = tf_utils.TfCompiledModule(StatefulCountingModule,
                                           tf_utils.BackendInfo('tf'))
        trace = tf_test_utils.Trace(module, trace_function)
        trace_function(tf_test_utils.TracedModule(module, trace))

        # A call without arguments or results still records empty tuples.
        no_io_call = trace.calls[0]
        self.assertIsInstance(no_io_call.inputs, tuple)
        self.assertEmpty(no_io_call.inputs)
        self.assertIsInstance(no_io_call.outputs, tuple)
        self.assertEmpty(no_io_call.outputs)

        # The argument of increment_by and the result of get_count are kept.
        self.assertAllClose(trace.calls[1].inputs[0], [81.])
        self.assertAllClose(trace.calls[2].outputs[0], [82.])
コード例 #14
0
class UtilsTests(tf.test.TestCase, parameterized.TestCase):
    """Tests for the tf_utils compilation and conversion helpers."""

    @parameterized.named_parameters([
        {
            'testcase_name': 'single_backend',
            'backend_infos': [tf_utils.BackendInfo('iree_vmla')],
        },
        {
            'testcase_name': 'multiple_backends',
            'backend_infos': [
                tf_utils.BackendInfo('iree_vmla'),
                tf_utils.BackendInfo('iree_llvmjit'),
            ],
        },
    ])
    def test_artifact_saving(self, backend_infos):
        """Checks compile_tf_module writes its MLIR and compiled artifacts."""
        with tempfile.TemporaryDirectory() as artifacts_dir:
            tf_module = ConstantModule()
            # Only the compiled artifact's path is needed; the module is not.
            _, compiled_path = tf_utils.compile_tf_module(
                tf_module,
                backend_infos=backend_infos,
                artifacts_dir=artifacts_dir)

            artifacts_to_check = [
                'tf_input.mlir',
                'iree_input.mlir',
                compiled_path,
            ]
            for artifact in artifacts_to_check:
                artifact_path = os.path.join(artifacts_dir, artifact)
                logging.info('Checking path: %s', artifact_path)
                self.assertTrue(os.path.exists(artifact_path),
                                msg=f'Missing artifact: {artifact_path}')

    @parameterized.named_parameters([
        {
            'testcase_name': 'tensorflow',
            'backend_name': 'tf',
        },
        {
            'testcase_name': 'vmla',
            'backend_name': 'iree_vmla',
        },
    ])
    def test_unaltered_state(self, backend_name):
        """Checks create_reinitialized() yields a module with separate state."""
        backend_info = tf_utils.BackendInfo(backend_name)
        module = backend_info.compile(StatefulCountingModule)

        # Test that incrementing works properly.
        self.assertEqual([0.], module.get_count())
        module.increment()
        self.assertEqual([1.], module.get_count())

        reinitialized_module = module.create_reinitialized()
        # Test reinitialization.
        self.assertEqual([0.], reinitialized_module.get_count())
        # Test independent state.
        self.assertEqual([1.], module.get_count())

    def test_to_mlir_type(self):
        """Checks the numpy dtype -> MLIR element type mapping."""
        self.assertEqual('i8', tf_utils.to_mlir_type(np.dtype('int8')))
        self.assertEqual('i32', tf_utils.to_mlir_type(np.dtype('int32')))
        self.assertEqual('f32', tf_utils.to_mlir_type(np.dtype('float32')))
        self.assertEqual('f64', tf_utils.to_mlir_type(np.dtype('float64')))

    def test_save_input_values(self):
        """Checks the textual serialization of integer and float inputs."""
        inputs = [np.array([1, 2], dtype=np.int32)]
        self.assertEqual('2xi32=1 2', tf_utils.save_input_values(inputs))
        inputs = [np.array([1, 2], dtype=np.float32)]
        self.assertEqual('2xf32=1.0 2.0', tf_utils.save_input_values(inputs))