def test_trace_serialize_and_load(self):
  """Round-trips a Trace through serialize()/load() and checks it is unchanged."""

  def trace_function(module):
    module.increment()
    module.increment_by(np.array([81.], dtype=np.float32))
    module.increment_by_max(np.array([81], dtype=np.float32),
                            np.array([92], dtype=np.float32))
    module.get_count()

  module = module_utils.IreeCompiledModule.create_from_class(
      StatefulCountingModule, module_utils.BackendInfo('iree_vmla'))
  trace = trace_utils.Trace(module, trace_function)
  trace_function(trace_utils.TracedModule(module, trace))

  with tempfile.TemporaryDirectory() as artifacts_dir:
    trace_function_dir = trace_utils.get_trace_dir(artifacts_dir, trace)
    trace.serialize(trace_function_dir)
    self.assertTrue(
        os.path.exists(os.path.join(trace_function_dir, 'metadata.pkl')))
    loaded_trace = trace_utils.Trace.load(trace_function_dir)

    # Check all calls match. compare_traces returns a (same, error_messages)
    # pair; asserting on the pair itself would always pass because a
    # non-empty tuple is truthy, so unpack it and assert on the flag.
    same, error_messages = trace_utils.compare_traces(trace, loaded_trace)
    self.assertTrue(same, error_messages)

    # Check all other metadata match.
    self.assertEqual(set(trace.__dict__), set(loaded_trace.__dict__))
    for key, value in trace.__dict__.items():
      if key != 'calls':
        self.assertEqual(value, loaded_trace.__dict__[key])
def test_nonmatching_inputs(self):
  """Traces that call the same method with different inputs must not match."""

  def tf_function(module):
    module.increment_by(np.array([42.], dtype=np.float32))

  def vmla_function(module):
    module.increment_by(np.array([22.], dtype=np.float32))

  def record(compiled_module, function):
    # Build a Trace for `compiled_module` and populate it by running
    # `function` against a TracedModule wrapper.
    recorded = trace_utils.Trace(compiled_module, function)
    function(trace_utils.TracedModule(compiled_module, recorded))
    return recorded

  tf_trace = record(
      module_utils.TfCompiledModule.create_from_class(
          StatefulCountingModule, module_utils.BackendInfo('tf')),
      tf_function)
  vmla_trace = record(
      module_utils.IreeCompiledModule.create_from_class(
          StatefulCountingModule, module_utils.BackendInfo('iree_vmla')),
      vmla_function)

  same, _ = trace_utils.compare_traces(tf_trace, vmla_trace)
  self.assertFalse(same)
def test_random_initialization(self, backend_name):
  """Checks RandomInitModule yields identical values across compiles and resets."""
  info = module_utils.BackendInfo(backend_name)

  # Two separate compilations must report the same initial value.
  first = info.compile_from_class(RandomInitModule)
  second = info.compile_from_class(RandomInitModule)
  self.assertAllEqual(first.get(), second.get())

  # reinitialize() must bring the module back to that same value.
  initial_value = first.get()
  first.reinitialize()
  self.assertAllEqual(initial_value, first.get())
def test_unaltered_state(self, backend_name):
  """Checks that reinitialize() resets the counter after it was incremented."""
  counting_module = module_utils.BackendInfo(backend_name).compile_from_class(
      StatefulCountingModule)

  # The counter starts at zero and increment() bumps it by one.
  self.assertEqual([0.], counting_module.get_count())
  counting_module.increment()
  self.assertEqual([1.], counting_module.get_count())

  # After reinitialize() the counter is back at zero.
  counting_module.reinitialize()
  self.assertEqual([0.], counting_module.get_count())
def test_nonmatching_methods(self):
  """Comparing traces whose method call sequences differ raises ValueError."""

  def tf_function(module):
    module.increment()
    module.increment()

  def vmla_function(module):
    module.increment()
    module.decrement()

  def record(compiled_module, function):
    # Build a Trace for `compiled_module` and populate it by running
    # `function` against a TracedModule wrapper.
    recorded = trace_utils.Trace(compiled_module, function)
    function(trace_utils.TracedModule(compiled_module, recorded))
    return recorded

  tf_trace = record(
      module_utils.TfCompiledModule.create_from_class(
          StatefulCountingModule, module_utils.BackendInfo('tf')),
      tf_function)
  vmla_trace = record(
      module_utils.IreeCompiledModule.create_from_class(
          StatefulCountingModule, module_utils.BackendInfo('iree_vmla')),
      vmla_function)

  with self.assertRaises(ValueError):
    trace_utils.compare_traces(tf_trace, vmla_trace)
def test_artifact_saving(self):
  """Checks that incremental compilation leaves its artifacts in artifacts_dir."""
  backend_info = module_utils.BackendInfo('iree_vmla')
  with tempfile.TemporaryDirectory() as artifacts_dir:
    # Only the compiled artifact's path is needed from the result.
    _, compiled_path = module_utils._incrementally_compile_tf_module(
        ConstantModule(),
        backend_info=backend_info,
        artifacts_dir=artifacts_dir)

    # Every expected artifact must exist before the temp dir is removed.
    for artifact in ('tf_input.mlir', 'iree_input.mlir', compiled_path):
      artifact_path = os.path.join(artifacts_dir, artifact)
      logging.info('Checking path: %s', artifact_path)
      self.assertTrue(os.path.exists(artifact_path))
def compile_tf_signature_def_saved_model(
    saved_model_dir: str, saved_model_tags: Set[str], module_name: str,
    exported_name: str, input_names: Sequence[str],
    output_names: Sequence[str]) -> Modules:
  """Compiles a SignatureDef SavedModel to each backend that we test.

  The reference backend is taken from `--reference_backend` and the target
  backends from `get_target_backends()`.

  The result is memoized in the module-level `_global_modules`. Once it has
  been set, later calls return the cached value unchanged and ignore their
  arguments.

  Args:
    saved_model_dir: Directory of the saved model.
    saved_model_tags: Set of tags to use when loading the model.
    module_name: A name for this compiled module; also used as the name of
      the artifacts directory.
    exported_name: A str representing the signature on the saved model to
      compile.
    input_names: A sequence of kwargs to feed to the saved model.
    output_names: A sequence of named outputs to extract from the saved model.

  Returns:
    A 'Modules' namedtuple containing the reference module, target modules and
    artifacts directory.
  """
  global _global_modules
  if _global_modules is not None:
    return _global_modules

  # Setup the directory for saving compilation artifacts and traces.
  artifacts_dir = _setup_artifacts_dir(module_name)

  # Get the backend information for this test.
  ref_backend_info = module_utils.BackendInfo(FLAGS.reference_backend,
                                              f"{FLAGS.reference_backend}_ref")
  tar_backend_infos = get_target_backends()

  def compile_backend(backend_info):
    return backend_info.compile_signature_def_saved_model(
        saved_model_dir, saved_model_tags, module_name, exported_name,
        input_names, output_names, artifacts_dir)

  ref_module = compile_backend(ref_backend_info)
  tar_modules = [
      compile_backend(backend_info) for backend_info in tar_backend_infos
  ]
  _global_modules = Modules(ref_module, tar_modules, artifacts_dir)
  return _global_modules
def get_target_backends() -> Sequence[module_utils.BackendInfo]:
  """Returns the BackendInfo instances to compare against the reference.

  When `--target_backends` is not set, every backend known to BackendInfo is
  used. Otherwise the flag is parsed into (name, id) pairs and one
  BackendInfo is built per pair.

  Returns:
    Sequence of BackendInfo that should be used.
  """
  if FLAGS.target_backends is None:
    # Nothing requested explicitly: use them all.
    return module_utils.BackendInfo.get_all_backends()

  logging.info("Using backends from command line: %s", FLAGS.target_backends)
  backend_names, backend_ids = _parse_target_backends()
  return [
      module_utils.BackendInfo(name, backend_id)
      for name, backend_id in zip(backend_names, backend_ids)
  ]
def compile_tf_module(module_class: Type[tf.Module],
                      exported_names: Sequence[str] = (),
                      relative_artifacts_dir: str = None) -> Modules:
  """Compiles `module_class` for the reference backend and each target backend.

  The result is cached in the module-level `_global_modules`. If that is
  already set, it is returned as-is without compiling anything.

  Args:
    module_class: the tf.Module subclass to compile.
    exported_names: optional iterable of strings naming which of
      module_class's functions to compile. An empty iterable compiles all of
      them.
    relative_artifacts_dir: optional string giving where to save compilation
      artifacts within the artifacts_dir. Defaults to module_class.__name__.

  Returns:
    A 'Modules' namedtuple containing the reference module, target modules and
    artifacts directory.
  """
  global _global_modules
  if _global_modules is not None:
    return _global_modules

  # Artifacts and traces go under a directory named after the module class
  # unless the caller chose one explicitly.
  artifacts_dir = _setup_artifacts_dir(
      module_class.__name__
      if relative_artifacts_dir is None else relative_artifacts_dir)

  reference_info = module_utils.BackendInfo(FLAGS.reference_backend,
                                            f"{FLAGS.reference_backend}_ref")
  target_infos = get_target_backends()

  def compile_for(info):
    return info.compile_from_class(module_class, exported_names, artifacts_dir)

  reference_module = compile_for(reference_info)
  target_modules = [compile_for(info) for info in target_infos]
  _global_modules = Modules(reference_module, target_modules, artifacts_dir)
  return _global_modules
def test_trace_inputs_and_outputs(self):
  """Checks how a Trace records inputs/outputs for calls with and without them."""

  def trace_function(module):
    # No inputs or outputs
    module.increment()
    # Only inputs
    module.increment_by(np.array([81.], dtype=np.float32))
    # Only outputs
    module.get_count()

  module = module_utils.TfCompiledModule.create_from_class(
      StatefulCountingModule, module_utils.BackendInfo('tf'))
  trace = trace_utils.Trace(module, trace_function)
  trace_function(trace_utils.TracedModule(module, trace))

  # A call with neither inputs nor outputs is recorded as two empty tuples.
  no_io_call = trace.calls[0]
  self.assertIsInstance(no_io_call.inputs, tuple)
  self.assertEmpty(no_io_call.inputs)
  self.assertIsInstance(no_io_call.outputs, tuple)
  self.assertEmpty(no_io_call.outputs)

  # Inputs of increment_by, then the count it produced (1 + 81).
  self.assertAllClose(trace.calls[1].inputs[0], [81.])
  self.assertAllClose(trace.calls[2].outputs[0], [82.])