def test_trace_serialize_and_load(self):
  # Records a trace on the iree_vmla backend, serializes it to a temporary
  # artifacts directory, loads it back and checks that the loaded trace has
  # the same calls and the same remaining attributes as the original.

  def trace_function(module):
    module.increment()
    module.increment_by(np.array([81.], dtype=np.float32))
    module.increment_by_max(np.array([81], dtype=np.float32),
                            np.array([92], dtype=np.float32))
    module.get_count()

  module = module_utils.IreeCompiledModule.create_from_class(
      StatefulCountingModule, module_utils.BackendInfo('iree_vmla'))
  trace = trace_utils.Trace(module, trace_function)
  trace_function(trace_utils.TracedModule(module, trace))

  with tempfile.TemporaryDirectory() as artifacts_dir:
    trace_function_dir = trace_utils.get_trace_dir(artifacts_dir, trace)
    trace.serialize(trace_function_dir)
    self.assertTrue(
        os.path.exists(os.path.join(trace_function_dir, 'metadata.pkl')))
    loaded_trace = trace_utils.Trace.load(trace_function_dir)

    # Check all calls match. compare_traces returns a
    # (traces_match, error_messages) tuple; a non-empty tuple is always
    # truthy, so the boolean must be unpacked before asserting on it.
    traces_match, error_messages = trace_utils.compare_traces(
        trace, loaded_trace)
    self.assertTrue(traces_match, msg=error_messages)

    # Check all other metadata match: the same attribute names exist on both
    # traces, and every attribute except 'calls' compares equal.
    self.assertEqual(set(vars(trace)), set(vars(loaded_trace)))
    for key, value in vars(trace).items():
      if key != 'calls':
        self.assertEqual(value, vars(loaded_trace)[key])
def test_nonmatching_methods(self):
  # The tf run and the vmla run call different methods (increment vs.
  # decrement on the second call); comparing their traces should raise
  # ValueError.

  def tf_function(module):
    module.increment()
    module.increment()

  def vmla_function(module):
    module.increment()
    module.decrement()

  def record(compiled_module_cls, backend_name, function):
    # Compiles StatefulCountingModule for the backend, runs `function` on a
    # TracedModule wrapper and returns the populated Trace.
    compiled = compiled_module_cls.create_from_class(
        StatefulCountingModule, module_utils.BackendInfo(backend_name))
    recorded = trace_utils.Trace(compiled, function)
    function(trace_utils.TracedModule(compiled, recorded))
    return recorded

  tf_trace = record(module_utils.TfCompiledModule, 'tf', tf_function)
  vmla_trace = record(module_utils.IreeCompiledModule, 'iree_vmla',
                      vmla_function)

  with self.assertRaises(ValueError):
    trace_utils.compare_traces(tf_trace, vmla_trace)
def test_nonmatching_inputs(self):
  # Both runs call increment_by, but with different input values (42 vs. 22);
  # compare_traces should report that the traces do not match.

  def tf_function(module):
    module.increment_by(np.array([42.], dtype=np.float32))

  def vmla_function(module):
    module.increment_by(np.array([22.], dtype=np.float32))

  def record(compiled_module_cls, backend_name, function):
    # Compiles StatefulCountingModule for the backend, runs `function` on a
    # TracedModule wrapper and returns the populated Trace.
    compiled = compiled_module_cls.create_from_class(
        StatefulCountingModule, module_utils.BackendInfo(backend_name))
    recorded = trace_utils.Trace(compiled, function)
    function(trace_utils.TracedModule(compiled, recorded))
    return recorded

  tf_trace = record(module_utils.TfCompiledModule, 'tf', tf_function)
  vmla_trace = record(module_utils.IreeCompiledModule, 'iree_vmla',
                      vmla_function)

  # Only the match flag is asserted on; the error messages are unused here.
  traces_match, _ = trace_utils.compare_traces(tf_trace, vmla_trace)
  self.assertFalse(traces_match)
def compare_backends(self,
                     trace_function: Callable[[trace_utils.TracedModule],
                                              None],
                     modules: Modules) -> None:
  """Runs trace_function on every backend and checks targets match the reference.

  Random seeds for tensorflow, numpy and python are set before each invocation
  of trace_function.

  Every trace (reference and targets) is written under `modules.artifacts_dir`,
  both as plaintext and serialized, before the comparison result is acted on,
  so the artifacts are on disk even when the comparison fails.

  Args:
    trace_function: a function accepting a TracedModule as its argument.
    modules: provides `ref_module` (the reference backend's module),
      `tar_modules` (the target backends' modules) and `artifacts_dir` (where
      the traces are saved).

  Fails the test (via self.fail) when any target trace does not match the
  reference trace, listing the failing backend ids and the collected errors.
  """
  # One Trace for the reference module, then one per target module.
  ref_trace = trace_utils.Trace(modules.ref_module, trace_function)
  tar_traces = [
      trace_utils.Trace(tar_module, trace_function)
      for tar_module in modules.tar_modules
  ]
  all_modules = [modules.ref_module, *modules.tar_modules]
  all_traces = [ref_trace, *tar_traces]

  # Record each trace, reseeding first so every backend sees the same
  # random inputs.
  for current_module, current_trace in zip(all_modules, all_traces):
    tf_utils.set_random_seed()
    trace_function(trace_utils.TracedModule(current_module, current_trace))
    if FLAGS.log_all_traces:
      logging.info(current_trace)

  # Compare every target trace against the reference, remembering failures.
  failed_traces = []
  error_messages = []
  for tar_trace in tar_traces:
    logging.info("Comparing the reference backend '%s' with '%s'",
                 ref_trace.backend_id, tar_trace.backend_id)
    traces_match, errors = trace_utils.compare_traces(ref_trace, tar_trace)
    if not traces_match:
      failed_traces.append(tar_trace)
      error_messages.extend(errors)

  # Persist all traces before deciding whether the test passes.
  for current_trace in all_traces:
    trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir,
                                          current_trace)
    current_trace.save_plaintext(trace_dir, FLAGS.summarize)
    current_trace.serialize(trace_dir)

  if not failed_traces:
    return

  failed_backends = [failed.backend_id for failed in failed_traces]
  error_list = ''.join(f'\n - {message}' for message in error_messages)
  self.fail(
      "Comparison between the reference backend and the following targets "
      f"failed: {failed_backends}. Errors: {error_list}\n"
      "See the logs above for more details about the non-matching calls.")