def compile_tf_module(
    module_class: Type[tf.Module],
    exported_names: Sequence[str] = ()) -> "Modules":
  """Compiles module_class to each backend that we test.

  Args:
    module_class: the tf.Module subclass to compile.
    exported_names: optional sequence of strings naming which of
      module_class's functions to compile. If exported_names is empty all
      functions will be compiled.

  Returns:
    A 'Modules' namedtuple containing the reference module, target modules and
    artifacts directory.
  """
  # Setup the directory for saving compilation artifacts and traces.
  artifacts_dir = _setup_artifacts_dir(module_class.__name__)

  # Get the backend information for this test.
  ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
                                          f"{FLAGS.reference_backend}_ref")
  tar_backend_infos = get_target_backends()

  def compile_backend(backend_info):
    # Every backend compiles the same class into the same artifacts directory.
    return backend_info.compile_from_class(module_class, exported_names,
                                           artifacts_dir)

  ref_module = compile_backend(ref_backend_info)
  tar_modules = [
      compile_backend(backend_info) for backend_info in tar_backend_infos
  ]
  return Modules(ref_module, tar_modules, artifacts_dir)
def test_trace_serialize_and_load(self):
  """Checks that a serialized Trace loads back with identical contents."""

  def trace_function(module):
    module.increment()
    module.increment_by(np.array([81.], dtype=np.float32))
    module.increment_by_max(
        np.array([81], dtype=np.float32), np.array([92], dtype=np.float32))
    module.get_count()

  module = tf_utils.IreeCompiledModule.create_from_class(
      StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
  trace = tf_test_utils.Trace(module, trace_function)
  trace_function(tf_test_utils.TracedModule(module, trace))

  with tempfile.TemporaryDirectory() as artifacts_dir:
    trace_function_dir = tf_test_utils._get_trace_dir(artifacts_dir, trace)
    trace.serialize(trace_function_dir)
    self.assertTrue(
        os.path.exists(os.path.join(trace_function_dir, 'metadata.pkl')))
    loaded_trace = tf_test_utils.Trace.load(trace_function_dir)

    # Check all calls match.
    self.assertTrue(tf_test_utils.Trace.compare_traces(trace, loaded_trace))

    # Check all other metadata match. The attribute names are compared as
    # sets: dict_keys views are not array-like, so assertAllEqual is the wrong
    # tool here, and assertSetEqual reports exactly which names differ.
    self.assertSetEqual(set(trace.__dict__), set(loaded_trace.__dict__))
    for key, value in trace.__dict__.items():
      if key != 'calls':
        self.assertEqual(value, loaded_trace.__dict__[key])
def test_nonmatching_inputs(self):
  """Traces calling the same method with different inputs must not match."""

  def tf_function(module):
    module.increment_by(np.array([42.], dtype=np.float32))

  def vmla_function(module):
    module.increment_by(np.array([22.], dtype=np.float32))

  def run_traced(module, trace_function):
    # Run trace_function against a TracedModule wrapper so its calls end up
    # recorded in a fresh Trace, which is returned.
    recorded = tf_test_utils.Trace(module, trace_function)
    trace_function(tf_test_utils.TracedModule(module, recorded))
    return recorded

  tf_trace = run_traced(
      tf_utils.TfCompiledModule.create_from_class(StatefulCountingModule,
                                                  tf_utils.BackendInfo('tf')),
      tf_function)
  vmla_trace = run_traced(
      tf_utils.IreeCompiledModule.create_from_class(
          StatefulCountingModule, tf_utils.BackendInfo('iree_vmla')),
      vmla_function)

  # Same methods, different inputs: compare_traces should report a mismatch.
  self.assertFalse(tf_test_utils.Trace.compare_traces(tf_trace, vmla_trace))
def test_unaltered_state(self, backend_name):
  """Checks incrementing and in-place reinitialization of a stateful module."""
  module = tf_utils.BackendInfo(backend_name).compile_from_class(
      StatefulCountingModule)

  # The count should read zero at first and one after a single increment().
  self.assertEqual([0.], module.get_count())
  module.increment()
  self.assertEqual([1.], module.get_count())

  # reinitialize() should bring the count back to zero.
  module.reinitialize()
  self.assertEqual([0.], module.get_count())
def test_random_initialization(self, backend_name):
  """Checks that randomly initialized values are reproducible."""
  info = tf_utils.BackendInfo(backend_name)

  # Two separate compilations must start from identical values.
  first = info.compile_from_class(RandomInitModule)
  second = info.compile_from_class(RandomInitModule)
  self.assertAllEqual(first.get(), second.get())

  # Reinitializing must reproduce the value seen before it.
  value_before = first.get()
  first.reinitialize()
  self.assertAllEqual(value_before, first.get())
def test_nonmatching_methods(self):
  """Traces that call different methods must raise when compared."""

  def tf_function(module):
    module.increment()
    module.increment()

  def vmla_function(module):
    module.increment()
    module.decrement()

  def run_traced(module, trace_function):
    # Run trace_function against a TracedModule wrapper so its calls end up
    # recorded in a fresh Trace, which is returned.
    recorded = tf_test_utils.Trace(module, trace_function)
    trace_function(tf_test_utils.TracedModule(module, recorded))
    return recorded

  tf_trace = run_traced(
      tf_utils.TfCompiledModule(StatefulCountingModule,
                                tf_utils.BackendInfo('tf')), tf_function)
  vmla_trace = run_traced(
      tf_utils.IreeCompiledModule(StatefulCountingModule,
                                  tf_utils.BackendInfo('iree_vmla')),
      vmla_function)

  # compare_traces is expected to raise when the traced methods differ.
  with self.assertRaises(ValueError):
    tf_test_utils.Trace.compare_traces(tf_trace, vmla_trace)
def setUp(self):
  # Runs before each unit test.
  super().setUp()
  # Build a CompiledModule for the reference backend, then one per target.
  reference_name = FLAGS.reference_backend
  self._ref_module = self._compile(
      tf_utils.BackendInfo(reference_name, f"{reference_name}_ref"))
  self._tar_modules = [
      self._compile(target_info) for target_info in get_target_backends()
  ]
def test_unaltered_state(self, backend_name):
  """Checks incrementing and that create_reinitialized() yields fresh state."""
  module = tf_utils.BackendInfo(backend_name).compile(StatefulCountingModule)

  # The count should read zero at first and one after a single increment().
  self.assertEqual([0.], module.get_count())
  module.increment()
  self.assertEqual([1.], module.get_count())

  fresh_module = module.create_reinitialized()
  # The reinitialized module starts over from zero...
  self.assertEqual([0.], fresh_module.get_count())
  # ...while the module it came from keeps its incremented count.
  self.assertEqual([1.], module.get_count())
def test_artifact_saving(self):
  """Checks that incremental compilation writes its expected artifacts."""
  backend_info = tf_utils.BackendInfo('iree_vmla')
  with tempfile.TemporaryDirectory() as artifacts_dir:
    _, compiled_path = tf_utils._incrementally_compile_tf_module(
        ConstantModule(),
        backend_info=backend_info,
        artifacts_dir=artifacts_dir)

    # Both intermediate MLIR files and the compiled output must be on disk.
    for artifact in ('tf_input.mlir', 'iree_input.mlir', compiled_path):
      artifact_path = os.path.join(artifacts_dir, artifact)
      logging.info('Checking path: %s', artifact_path)
      self.assertTrue(os.path.exists(artifact_path))
def compile_tf_signature_def_saved_model(
    saved_model_dir: str, saved_model_tags: Set[str], module_name: str,
    exported_name: str, input_names: Sequence[str],
    output_names: Sequence[str]) -> "Modules":
  """Compiles a SignatureDef SavedModel to each backend that we test.

  The result is cached in the module-level `_global_modules`. After the first
  call, every later call returns that same 'Modules' instance without
  recompiling, whatever arguments it is given.

  Args:
    saved_model_dir: Directory of the saved model.
    saved_model_tags: Set of tags to use when loading the model.
    module_name: A name for this compiled module. Also used to name the
      artifacts directory.
    exported_name: A str representing the signature on the saved model to
      compile.
    input_names: A sequence of kwargs to feed to the saved model.
    output_names: A sequence of named outputs to extract from the saved model.

  Returns:
    A 'Modules' namedtuple containing the reference module, target modules and
    artifacts directory.
  """
  global _global_modules
  # NOTE(review): the cache is not keyed on the arguments, so this assumes a
  # single saved model is compiled per process. Confirm with callers.
  if _global_modules is not None:
    return _global_modules

  # Setup the directory for saving compilation artifacts and traces.
  artifacts_dir = _setup_artifacts_dir(module_name)

  # Get the backend information for this test.
  ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
                                          f"{FLAGS.reference_backend}_ref")
  tar_backend_infos = get_target_backends()

  def compile_backend(backend_info):
    # Every backend compiles the same saved model into the same artifacts dir.
    return backend_info.compile_signature_def_saved_model(
        saved_model_dir, saved_model_tags, module_name, exported_name,
        input_names, output_names, artifacts_dir)

  ref_module = compile_backend(ref_backend_info)
  tar_modules = [
      compile_backend(backend_info) for backend_info in tar_backend_infos
  ]
  _global_modules = Modules(ref_module, tar_modules, artifacts_dir)
  return _global_modules
def get_target_backends() -> Sequence[tf_utils.BackendInfo]:
  """Returns the BackendInfo instances to compare against the reference backend.

  If the `--target_backends` flag is set, only the backends it names are used.
  Otherwise every backend known to BackendInfo is returned.

  Returns:
    Sequence of BackendInfo that should be used.
  """
  if FLAGS.target_backends is None:
    # No backends were requested explicitly, so test against all of them.
    return tf_utils.BackendInfo.get_all_backends()

  logging.info("Using backends from command line: %s", FLAGS.target_backends)
  backend_names, names = _parse_target_backends()
  return [
      tf_utils.BackendInfo(backend_name, name)
      for backend_name, name in zip(backend_names, names)
  ]
def setUpClass(cls):
  # Runs once, before any of the unit tests in this class.
  super().setUpClass()
  if cls._module_class is None:
    raise AttributeError(
        "setUpClass was called but no module was specified. Specify a module "
        "to compile via the @tf_test_utils.compile_module decorator.")

  # Compilation artifacts and traces go in a directory named after the module.
  cls._artifacts_dir = _setup_artifacts_dir(cls._module_class.__name__)

  # Compile for the reference backend first, then once per target backend.
  reference_name = FLAGS.reference_backend
  cls._ref_module = cls._compile(
      tf_utils.BackendInfo(reference_name, f"{reference_name}_ref"))
  cls._tar_modules = [
      cls._compile(target_info) for target_info in get_target_backends()
  ]
def test_trace_inputs_and_outputs(self):
  """Checks how a Trace records call inputs and outputs."""

  def trace_function(module):
    # A call with neither inputs nor outputs.
    module.increment()
    # A call with inputs only.
    module.increment_by(np.array([81.], dtype=np.float32))
    # A call with outputs only.
    module.get_count()

  module = tf_utils.TfCompiledModule(StatefulCountingModule,
                                     tf_utils.BackendInfo('tf'))
  trace = tf_test_utils.Trace(module, trace_function)
  trace_function(tf_test_utils.TracedModule(module, trace))

  calls = trace.calls
  # Missing inputs/outputs are recorded as empty tuples.
  self.assertIsInstance(calls[0].inputs, tuple)
  self.assertEmpty(calls[0].inputs)
  self.assertIsInstance(calls[0].outputs, tuple)
  self.assertEmpty(calls[0].outputs)

  self.assertAllClose(calls[1].inputs[0], [81.])
  self.assertAllClose(calls[2].outputs[0], [82.])
class UtilsTests(tf.test.TestCase, parameterized.TestCase):
  """Tests for helpers exposed by tf_utils."""

  @parameterized.named_parameters([
      {
          'testcase_name': 'single_backend',
          'backend_infos': [tf_utils.BackendInfo('iree_vmla')],
      },
      {
          'testcase_name': 'multiple_backends',
          'backend_infos': [
              tf_utils.BackendInfo('iree_vmla'),
              tf_utils.BackendInfo('iree_llvmjit')
          ],
      },
  ])
  def test_artifact_saving(self, backend_infos):
    with tempfile.TemporaryDirectory() as artifacts_dir:
      _, compiled_path = tf_utils.compile_tf_module(
          ConstantModule(),
          backend_infos=backend_infos,
          artifacts_dir=artifacts_dir)

      # Both intermediate MLIR files and the compiled output must be on disk.
      for artifact in ('tf_input.mlir', 'iree_input.mlir', compiled_path):
        artifact_path = os.path.join(artifacts_dir, artifact)
        logging.info('Checking path: %s', artifact_path)
        self.assertTrue(os.path.exists(artifact_path))

  @parameterized.named_parameters([
      {
          'testcase_name': 'tensorflow',
          'backend_name': 'tf',
      },
      {
          'testcase_name': 'vmla',
          'backend_name': 'iree_vmla',
      },
  ])
  def test_unaltered_state(self, backend_name):
    module = tf_utils.BackendInfo(backend_name).compile(StatefulCountingModule)

    # The count should read zero at first and one after a single increment().
    self.assertEqual([0.], module.get_count())
    module.increment()
    self.assertEqual([1.], module.get_count())

    fresh_module = module.create_reinitialized()
    # The reinitialized module starts over from zero...
    self.assertEqual([0.], fresh_module.get_count())
    # ...while the module it came from keeps its incremented count.
    self.assertEqual([1.], module.get_count())

  def test_to_mlir_type(self):
    expectations = (
        ('i8', 'int8'),
        ('i32', 'int32'),
        ('f32', 'float32'),
        ('f64', 'float64'),
    )
    for mlir_type, numpy_type in expectations:
      self.assertEqual(mlir_type, tf_utils.to_mlir_type(np.dtype(numpy_type)))

  def test_save_input_values(self):
    expectations = (
        (np.int32, '2xi32=1 2'),
        (np.float32, '2xf32=1.0 2.0'),
    )
    for dtype, serialized in expectations:
      inputs = [np.array([1, 2], dtype=dtype)]
      self.assertEqual(serialized, tf_utils.save_input_values(inputs))