Ejemplo n.º 1
0
    def test_trace_serialize_and_load(self):
        """Checks that a serialized trace loads back with identical calls/metadata."""

        def trace_function(module):
            module.increment()
            module.increment_by(np.array([81.], dtype=np.float32))
            module.increment_by_max(np.array([81], dtype=np.float32),
                                    np.array([92], dtype=np.float32))
            module.get_count()

        module = module_utils.IreeCompiledModule.create_from_class(
            StatefulCountingModule, module_utils.BackendInfo('iree_vmla'))
        trace = trace_utils.Trace(module, trace_function)
        trace_function(trace_utils.TracedModule(module, trace))

        with tempfile.TemporaryDirectory() as artifacts_dir:
            trace_function_dir = trace_utils.get_trace_dir(
                artifacts_dir, trace)
            trace.serialize(trace_function_dir)
            self.assertTrue(
                os.path.exists(os.path.join(trace_function_dir,
                                            'metadata.pkl')))
            loaded_trace = trace_utils.Trace.load(trace_function_dir)

            # Check all calls match. compare_traces returns a
            # (same, error_messages) tuple; a non-empty tuple is always truthy,
            # so the boolean has to be unpacked before asserting on it.
            same, error_messages = trace_utils.compare_traces(
                trace, loaded_trace)
            self.assertTrue(same, error_messages)

            # Check all other metadata match.
            self.assertSetEqual(set(trace.__dict__),
                                set(loaded_trace.__dict__))
            for key in trace.__dict__:
                if key != 'calls':
                    self.assertEqual(trace.__dict__[key],
                                     loaded_trace.__dict__[key])
Ejemplo n.º 2
0
    def test_nonmatching_inputs(self):
        """Traces that pass different inputs to the same method compare unequal."""

        def tf_function(module):
            module.increment_by(np.array([42.], dtype=np.float32))

        def vmla_function(module):
            module.increment_by(np.array([22.], dtype=np.float32))

        def run_traced(compiled_module, function):
            # Record `function`'s calls on `compiled_module` into a new Trace.
            recorded = trace_utils.Trace(compiled_module, function)
            function(trace_utils.TracedModule(compiled_module, recorded))
            return recorded

        tf_trace = run_traced(
            module_utils.TfCompiledModule.create_from_class(
                StatefulCountingModule, module_utils.BackendInfo('tf')),
            tf_function)
        vmla_trace = run_traced(
            module_utils.IreeCompiledModule.create_from_class(
                StatefulCountingModule,
                module_utils.BackendInfo('iree_vmla')),
            vmla_function)

        traces_match, _ = trace_utils.compare_traces(tf_trace, vmla_trace)
        self.assertFalse(traces_match)
Ejemplo n.º 3
0
    def test_nonmatching_methods(self):
        """Comparing traces whose call sequences differ raises ValueError."""

        def tf_function(module):
            module.increment()
            module.increment()

        def vmla_function(module):
            module.increment()
            module.decrement()

        def run_traced(compiled_module, function):
            # Record `function`'s calls on `compiled_module` into a new Trace.
            recorded = trace_utils.Trace(compiled_module, function)
            function(trace_utils.TracedModule(compiled_module, recorded))
            return recorded

        tf_trace = run_traced(
            module_utils.TfCompiledModule.create_from_class(
                StatefulCountingModule, module_utils.BackendInfo('tf')),
            tf_function)
        vmla_trace = run_traced(
            module_utils.IreeCompiledModule.create_from_class(
                StatefulCountingModule,
                module_utils.BackendInfo('iree_vmla')),
            vmla_function)

        with self.assertRaises(ValueError):
            trace_utils.compare_traces(tf_trace, vmla_trace)
Ejemplo n.º 4
0
    def test_random_initialization(self, backend_name):
        """Random initial state is identical across compiles and after reset."""
        info = module_utils.BackendInfo(backend_name)

        # Two separate compilations must start from the same state.
        first = info.compile_from_class(RandomInitModule)
        second = info.compile_from_class(RandomInitModule)
        self.assertAllEqual(first.get(), second.get())

        # reinitialize() must bring the module back to that same state.
        before_reset = first.get()
        first.reinitialize()
        self.assertAllEqual(before_reset, first.get())
Ejemplo n.º 5
0
    def test_unaltered_state(self, backend_name):
        """Incrementing changes the count and reinitialize() restores zero."""
        counter = module_utils.BackendInfo(backend_name).compile_from_class(
            StatefulCountingModule)

        # A freshly compiled module starts at zero and increments by one.
        self.assertEqual([0.], counter.get_count())
        counter.increment()
        self.assertEqual([1.], counter.get_count())

        # Reinitializing discards the increment.
        counter.reinitialize()
        self.assertEqual([0.], counter.get_count())
Ejemplo n.º 6
0
    def test_artifact_saving(self):
        """Incremental compilation writes its MLIR inputs and compiled output."""
        backend_info = module_utils.BackendInfo('iree_vmvx')
        with tempfile.TemporaryDirectory() as artifacts_dir:
            tf_module = ConstantModule()
            # Only the compiled artifact's path is needed here.
            _, compiled_path = module_utils._incrementally_compile_tf_module(
                tf_module,
                backend_info=backend_info,
                artifacts_dir=artifacts_dir)

            expected = ('tf_input.mlir', 'iree_input.mlir', compiled_path)
            for name in expected:
                path = os.path.join(artifacts_dir, name)
                logging.info('Checking path: %s', path)
                self.assertTrue(os.path.exists(path))
Ejemplo n.º 7
0
def compile_tf_signature_def_saved_model(
    saved_model_dir: str, saved_model_tags: Set[str], module_name: str,
    exported_name: str, input_names: Sequence[str],
    output_names: Sequence[str]) -> Modules:
  """Compiles a SignatureDef SavedModel to each backend that we test.

  The reference backend comes from `--reference_backend`, and the target
  backends come from `get_target_backends()`. The result is cached in the
  module-level `_global_modules`. Once the cache is set, later calls return it
  immediately and ignore their arguments.

  Args:
    saved_model_dir: Directory of the saved model.
    saved_model_tags: Optional set of tags to use when loading the model.
    module_name: A name for this compiled module. It is also passed to
      `_setup_artifacts_dir` to choose where compilation artifacts are saved.
    exported_name: A str representing the signature on the saved model to
      compile.
    input_names: A sequence of kwargs to feed to the saved model.
    output_names: A sequence of named outputs to extract from the saved model.

  Returns:
    A 'Modules' namedtuple containing the reference module, target modules and
    artifacts directory.
  """
  global _global_modules
  if _global_modules is not None:
    return _global_modules

  # Setup the directory for saving compilation artifacts and traces.
  artifacts_dir = _setup_artifacts_dir(module_name)

  # Get the backend information for this test.
  ref_backend_info = module_utils.BackendInfo(FLAGS.reference_backend,
                                              f"{FLAGS.reference_backend}_ref")
  tar_backend_infos = get_target_backends()

  def compile_backend(backend_info):
    # Every backend compiles the same signature into the shared artifacts dir.
    return backend_info.compile_signature_def_saved_model(
        saved_model_dir, saved_model_tags, module_name, exported_name,
        input_names, output_names, artifacts_dir)

  ref_module = compile_backend(ref_backend_info)
  tar_modules = [
      compile_backend(backend_info) for backend_info in tar_backend_infos
  ]
  _global_modules = Modules(ref_module, tar_modules, artifacts_dir)
  return _global_modules
Ejemplo n.º 8
0
def get_target_backends() -> Sequence[module_utils.BackendInfo]:
  """Returns the BackendInfo instances to compare with the reference backend.

  If `--target_backends` is set, the backends parsed from it by
  `_parse_target_backends()` are used. Otherwise every backend returned by
  `BackendInfo.get_all_backends()` is used.

  Returns:
    Sequence of BackendInfo that should be used.
  """
  if FLAGS.target_backends is None:
    # No backends were requested explicitly, so use all of them.
    return module_utils.BackendInfo.get_all_backends()

  logging.info("Using backends from command line: %s", FLAGS.target_backends)
  names, ids = _parse_target_backends()
  return [
      module_utils.BackendInfo(name, backend_id)
      for name, backend_id in zip(names, ids)
  ]
Ejemplo n.º 9
0
def compile_tf_module(module_class: Type[tf.Module],
                      exported_names: Sequence[str] = (),
                      relative_artifacts_dir: str = None) -> Modules:
  """Compiles `module_class` for the reference and every target backend.

  The result is cached in the module-level `_global_modules`. Once the cache is
  set, later calls return it immediately and ignore their arguments.

  Args:
    module_class: the tf.Module subclass to compile.
    exported_names: optional iterable of strings naming which of
      module_class's functions to compile. If it is empty, all functions are
      compiled.
    relative_artifacts_dir: optional string naming where to save compilation
      artifacts within the artifacts_dir. Defaults to module_class.__name__.

  Returns:
    A 'Modules' namedtuple containing the reference module, target modules and
    artifacts directory.
  """
  global _global_modules
  if _global_modules is not None:
    return _global_modules

  # Compilation artifacts and traces go under a directory named after the
  # module class unless the caller chose one explicitly.
  artifacts_dir = _setup_artifacts_dir(
      module_class.__name__
      if relative_artifacts_dir is None else relative_artifacts_dir)

  reference_info = module_utils.BackendInfo(FLAGS.reference_backend,
                                            f"{FLAGS.reference_backend}_ref")
  target_infos = get_target_backends()

  def compile_backend(info):
    return info.compile_from_class(module_class, exported_names, artifacts_dir)

  reference_module = compile_backend(reference_info)
  target_modules = [compile_backend(info) for info in target_infos]
  _global_modules = Modules(reference_module, target_modules, artifacts_dir)
  return _global_modules
Ejemplo n.º 10
0
    def test_trace_inputs_and_outputs(self):
        """Recorded calls store their inputs and outputs as tuples."""

        def trace_function(module):
            module.increment()  # Neither inputs nor outputs.
            module.increment_by(np.array([81.], dtype=np.float32))  # Inputs only.
            module.get_count()  # Outputs only.

        module = module_utils.TfCompiledModule.create_from_class(
            StatefulCountingModule, module_utils.BackendInfo('tf'))
        trace = trace_utils.Trace(module, trace_function)
        trace_function(trace_utils.TracedModule(module, trace))

        # A call with no inputs/outputs still records empty tuples.
        first_call = trace.calls[0]
        for recorded in (first_call.inputs, first_call.outputs):
            self.assertIsInstance(recorded, tuple)
            self.assertEmpty(recorded)

        # 82 = 1 from increment() plus 81 from increment_by().
        self.assertAllClose(trace.calls[1].inputs[0], [81.])
        self.assertAllClose(trace.calls[2].outputs[0], [82.])