Beispiel #1
0
    def test_trace_serialize_and_load(self):
        """Round-trips a Trace through serialize/load and checks it survives.

        Records a few calls on an 'iree_vmla' module, serializes the trace to a
        temporary directory, loads it back, and asserts that both the recorded
        calls and every other attribute of the trace are preserved.
        """
        def trace_function(module):
            module.increment()
            module.increment_by(np.array([81.], dtype=np.float32))
            module.increment_by_max(np.array([81], dtype=np.float32),
                                    np.array([92], dtype=np.float32))
            module.get_count()

        module = module_utils.IreeCompiledModule.create_from_class(
            StatefulCountingModule, module_utils.BackendInfo('iree_vmla'))
        trace = trace_utils.Trace(module, trace_function)
        trace_function(trace_utils.TracedModule(module, trace))

        with tempfile.TemporaryDirectory() as artifacts_dir:
            trace_function_dir = trace_utils.get_trace_dir(
                artifacts_dir, trace)
            trace.serialize(trace_function_dir)
            self.assertTrue(
                os.path.exists(os.path.join(trace_function_dir,
                                            'metadata.pkl')))
            loaded_trace = trace_utils.Trace.load(trace_function_dir)

            # Check all calls match. compare_traces returns a
            # (same, error_messages) tuple; a non-empty tuple is always
            # truthy, so the boolean has to be unpacked before asserting.
            same, error_messages = trace_utils.compare_traces(
                trace, loaded_trace)
            self.assertTrue(same, error_messages)

            # Check all other metadata match.
            original_attrs = vars(trace)
            loaded_attrs = vars(loaded_trace)
            self.assertEqual(set(original_attrs), set(loaded_attrs))
            for key, value in original_attrs.items():
                # 'calls' was already checked by compare_traces above.
                if key != 'calls':
                    self.assertEqual(value, loaded_attrs[key])
Beispiel #2
0
    def test_nonmatching_inputs(self):
        """Traces that pass different inputs to the same method must differ."""
        def tf_function(module):
            module.increment_by(np.array([42.], dtype=np.float32))

        def vmla_function(module):
            module.increment_by(np.array([22.], dtype=np.float32))

        def record(compiled_module, fn):
            # Build a Trace for `compiled_module` and populate it by running
            # `fn` against a TracedModule wrapper.
            recorded = trace_utils.Trace(compiled_module, fn)
            fn(trace_utils.TracedModule(compiled_module, recorded))
            return recorded

        reference = record(
            module_utils.TfCompiledModule.create_from_class(
                StatefulCountingModule, module_utils.BackendInfo('tf')),
            tf_function)
        target = record(
            module_utils.IreeCompiledModule.create_from_class(
                StatefulCountingModule,
                module_utils.BackendInfo('iree_vmla')),
            vmla_function)

        traces_match, _ = trace_utils.compare_traces(reference, target)
        self.assertFalse(traces_match)
Beispiel #3
0
    def test_nonmatching_methods(self):
        """Comparing traces that call different methods raises ValueError."""
        def tf_function(module):
            module.increment()
            module.increment()

        def vmla_function(module):
            module.increment()
            module.decrement()

        def record(compiled_module, fn):
            # Build a Trace for `compiled_module` and populate it by running
            # `fn` against a TracedModule wrapper.
            recorded = trace_utils.Trace(compiled_module, fn)
            fn(trace_utils.TracedModule(compiled_module, recorded))
            return recorded

        reference = record(
            module_utils.TfCompiledModule.create_from_class(
                StatefulCountingModule, module_utils.BackendInfo('tf')),
            tf_function)
        target = record(
            module_utils.IreeCompiledModule.create_from_class(
                StatefulCountingModule,
                module_utils.BackendInfo('iree_vmla')),
            vmla_function)

        with self.assertRaises(ValueError):
            trace_utils.compare_traces(reference, target)
Beispiel #4
0
    def test_trace_inputs_and_outputs(self):
        """Recorded calls store inputs/outputs as tuples, empty when absent."""
        def trace_function(module):
            # No inputs or outputs
            module.increment()
            # Only inputs
            module.increment_by(np.array([81.], dtype=np.float32))
            # Only outputs
            module.get_count()

        tf_module = module_utils.TfCompiledModule.create_from_class(
            StatefulCountingModule, module_utils.BackendInfo('tf'))
        recorded = trace_utils.Trace(tf_module, trace_function)
        trace_function(trace_utils.TracedModule(tf_module, recorded))

        # increment(): neither inputs nor outputs, but both are still tuples.
        no_io_call = recorded.calls[0]
        for field in (no_io_call.inputs, no_io_call.outputs):
            self.assertIsInstance(field, tuple)
            self.assertEmpty(field)

        inputs_only_call = recorded.calls[1]
        self.assertAllClose(inputs_only_call.inputs[0], [81.])
        outputs_only_call = recorded.calls[2]
        self.assertAllClose(outputs_only_call.outputs[0], [82.])
Beispiel #5
0
  def compare_backends(self,
                       trace_function: Callable[[trace_utils.TracedModule],
                                                None],
                       modules: Modules) -> None:
    """Run the reference and target backends on trace_function and compare them.

    Random seeds for tensorflow, numpy and python are set before each invocation
    of trace_function. All traces are written to disk before the comparison
    results are validated, so the artifacts are available even on failure.

    Args:
      trace_function: a function accepting a TracedModule as its argument.
      modules: the compiled modules to exercise. `modules.ref_module` is the
        reference backend, each entry of `modules.tar_modules` is a target
        backend compared against it, and traces are saved under
        `modules.artifacts_dir`.

    Fails the test (via `self.fail`) if any target trace does not match the
    reference trace, listing the failing backends and their errors.
    """
    # Create Traces for each backend.
    ref_trace = trace_utils.Trace(modules.ref_module, trace_function)
    tar_traces = [
        trace_utils.Trace(module, trace_function)
        for module in modules.tar_modules
    ]

    # Run the traces through trace_function with their associated modules.
    tf_utils.set_random_seed()
    trace_function(trace_utils.TracedModule(modules.ref_module, ref_trace))
    if FLAGS.log_all_traces:
      logging.info(ref_trace)
    for module, trace in zip(modules.tar_modules, tar_traces):
      tf_utils.set_random_seed()
      trace_function(trace_utils.TracedModule(module, trace))
      if FLAGS.log_all_traces:
        logging.info(trace)

    # Compare each target trace of trace_function with the reference trace.
    failed_backends = []
    error_messages = []
    for tar_trace in tar_traces:
      logging.info("Comparing the reference backend '%s' with '%s'",
                   ref_trace.backend_id, tar_trace.backend_id)
      traces_match, errors = trace_utils.compare_traces(ref_trace, tar_trace)
      if not traces_match:
        failed_backends.append(tar_trace.backend_id)
        # Prefix each error with its backend so that failures from several
        # targets can be told apart in the final message.
        error_messages.extend(
            f'{tar_trace.backend_id}: {error}' for error in errors)

    # Save the results to disk before validating.
    ref_trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir, ref_trace)
    ref_trace.save_plaintext(ref_trace_dir, FLAGS.summarize)
    ref_trace.serialize(ref_trace_dir)
    for tar_trace in tar_traces:
      tar_trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir,
                                                tar_trace)
      tar_trace.save_plaintext(tar_trace_dir, FLAGS.summarize)
      tar_trace.serialize(tar_trace_dir)

    # Validate results.
    if failed_backends:
      error_list = ''.join(f'\n  - {message}' for message in error_messages)
      self.fail(
          "Comparison between the reference backend and the following targets "
          f"failed: {failed_backends}. Errors: {error_list}\n"
          "See the logs above for more details about the non-matching calls.")