def test_computation_callable(self):
  """Imports an `x + 1.0` TF function into IREE and invokes it via a callable."""
  module = tf.Module()
  module.foo = tf.function(
      lambda x: x + 1.0, input_signature=[tf.TensorSpec([], tf.float32)])

  with tempfile.TemporaryDirectory() as saved_dir:
    tf.saved_model.save(
        module,
        saved_dir,
        options=tf.saved_model.SaveOptions(save_debug_info=True))
    # The import has to happen while the SavedModel directory still exists.
    imported = iree_compiler_tf.compile_saved_model(saved_dir, import_only=True)

  signature = computation_types.FunctionType(tf.float32, tf.float32)
  wrapped = computation_module.ComputationModule(imported, 'foo', signature)
  invoker = runtime.ComputationCallable(wrapped, backend_info.VULKAN_SPIRV)

  self.assertTrue(callable(invoker))
  self.assertEqual(invoker(np.float32(5.0)), 6.0)
def test_module_class_with_add_one(self):
  """Checks that `ComputationModule` exposes its constructor arguments."""
  import tempfile  # Function-scope import keeps this test self-contained.

  tf_module = tf.Module()
  tf_module.foo = tf.function(
      lambda x: x + 1.0,
      input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])

  # Use a fresh, automatically removed directory instead of a hard-coded
  # '/tmp/foo': the fixed path leaked files, collided across concurrent runs,
  # could mix in stale output from earlier runs, and is not portable.
  with tempfile.TemporaryDirectory() as model_dir:
    save_options = tf.saved_model.SaveOptions(save_debug_info=True)
    tf.saved_model.save(tf_module, model_dir, options=save_options)
    # Import while the SavedModel directory still exists.
    iree_compiler_module = iree_compiler_tf.compile_saved_model(
        model_dir,
        import_only=True,
        exported_names=['foo'],
        target_backends=iree_compiler_tf.DEFAULT_TESTING_BACKENDS)

  my_computation_module = computation_module.ComputationModule(
      iree_compiler_module, 'foo',
      computation_types.FunctionType(tf.float32, tf.float32))
  self.assertIs(my_computation_module.compiler_module, iree_compiler_module)
  self.assertEqual(my_computation_module.function_name, 'foo')
  self.assertEqual(
      str(my_computation_module.type_signature), '(float32 -> float32)')
def _incrementally_compile_tf_signature_def_saved_model(
    saved_model_dir: str, saved_model_tags: Set[str],
    backend_info: "BackendInfo", exported_name: str, artifacts_dir: str):
  """Compiles one SignatureDef of a SavedModel into an IREE module blob.

  The blob produced here cannot be called directly; `IreeCompiledModule`
  offers an API that yields a module callable without further steps.

  Args:
    saved_model_dir: Path to the SavedModel directory.
    saved_model_tags: Optional set of tags used when loading the SavedModel.
    backend_info: `BackendInfo` whose `compiler_targets` are compiled for and
      whose `backend_id` is passed to the artifact-path helper.
    exported_name: Key of the SignatureDef in the SavedModel to compile.
    artifacts_dir: Optional directory for compilation artifacts. When falsy,
      no artifact kwargs are passed to the compiler.

  Returns:
    A `(module_blob, output_file)` tuple. `output_file` is the path of the
    compiled VM FlatBuffer when the artifact kwargs supply one (only possible
    if `artifacts_dir` is given), otherwise `None`.
  """
  if artifacts_dir:
    extra_kwargs = _get_tf_import_output_kwargs(artifacts_dir,
                                                backend_info.backend_id)
  else:
    extra_kwargs = {}

  module_blob = tf_compiler.compile_saved_model(
      saved_model_dir,
      import_type="SIGNATURE_DEF",
      target_backends=backend_info.compiler_targets,
      exported_names=[exported_name],
      saved_model_tags=saved_model_tags,
      **extra_kwargs)

  flatbuffer_path = extra_kwargs.get("output_file")
  if flatbuffer_path:
    # An output file was requested: the blob returned to the caller is read
    # back from that file instead of taken from the compiler's return value.
    with open(flatbuffer_path, "rb") as flatbuffer:
      module_blob = flatbuffer.read()
  return module_blob, flatbuffer_path
def import_tensorflow_computation(comp, name='fn'):
  """Creates a `computation_module.ComputationModule` from a TF computation.

  WARNING: This helper function is under construction, and most capabilities
  are not implemented at this stage:

  * The parameter and result of `comp` can only be a single tensor. Named
    tuples, sequences, or functional types are not currently supported.

  * Only TensorFlow code can be imported.

  TODO(b/153499219): Add support for named tuples, sequences, and functions.

  The computation's graph is re-imported into a fresh `tf.Graph`, written out
  as a temporary TF1-style SavedModel with a single SignatureDef keyed by
  `name`, and then imported into IREE with `import_only=True`.

  Args:
    comp: An instance of a `pb.Computation` with TensorFlow code to import.
    name: An optional `str` name of the (single) function in the IREE module.
      It is used as the SignatureDef key (and method name) of the temporary
      SavedModel and as the exported name passed to the IREE importer.

  Returns:
    An instance of `computation_module.ComputationModule` wrapping the
    imported IREE module, with function name `name` and the computation's
    functional type as its type signature.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation.
  """
  py_typecheck.check_type(comp, pb.Computation)
  type_spec = type_serialization.deserialize_type(comp.type)
  if not type_spec.is_function():
    # A non-functional computation is treated as a no-argument function.
    type_spec = computation_types.FunctionType(None, type_spec)

  # TODO(b/153499219): Replace this with a recursive check of the signature
  # after relaxing the type restrictions and introducing nested structures.
  py_typecheck.check_type(type_spec.result, computation_types.TensorType)
  if type_spec.parameter is not None:
    py_typecheck.check_type(type_spec.parameter, computation_types.TensorType)

  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        which_computation))

  output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
      comp.tensorflow.result)
  if type_spec.parameter is not None:
    input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.parameter)
  else:
    input_tensor_names = []

  graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
  init_op = comp.tensorflow.initialize_op
  return_elements = input_tensor_names + output_tensor_names
  if init_op:
    graph_def = tensorflow_utils.add_control_deps_for_init_op(
        graph_def, init_op)
    # Requested last so that it can simply be popped off the import results.
    return_elements.append(init_op)

  with tf.Graph().as_default() as graph:
    # TODO(b/153499219): See if we can reintroduce uniquify_shared_names().
    # Right now, it causes loader breakage, and unclear if still necessary.
    import_results = tf.import_graph_def(
        graph_def, input_map={}, return_elements=return_elements, name='')

  # The init op (if any) is the last returned element; split it off first so
  # the remaining results are exactly the input tensors followed by outputs.
  initializer = import_results.pop() if init_op else None
  num_inputs = len(input_tensor_names)
  inputs = import_results[:num_inputs]
  outputs = import_results[num_inputs:]

  with graph.as_default():
    # TODO(b/153499219): Find a way to reflect the nested parameter and result
    # structure here after relaxing the restrictions.
    if inputs:
      # Expected to hold given the single-tensor parameter check above.
      assert len(inputs) < 2
      input_dict = {
          'parameter':
              tf.compat.v1.saved_model.utils.build_tensor_info(inputs[0])
      }
    else:
      input_dict = {}
    # Expected to hold given the single-tensor result check above.
    assert len(outputs) == 1
    output_dict = {
        'result': tf.compat.v1.saved_model.utils.build_tensor_info(outputs[0])
    }
    sig_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
        inputs=input_dict, outputs=output_dict, method_name=name)

    # The same tag must be used when writing and when importing the
    # SavedModel, so it is defined once here.
    saved_model_tag = 'unused'
    with tempfile.TemporaryDirectory() as model_dir:
      builder = tf.compat.v1.saved_model.Builder(model_dir)
      with tf.compat.v1.Session(graph=graph) as sess:
        builder.add_meta_graph_and_variables(
            sess, [saved_model_tag],
            signature_def_map={name: sig_def},
            legacy_init_op=initializer,
            strip_default_attrs=True)
        builder.save()
      # Import while the temporary SavedModel directory still exists.
      iree_module = iree_compiler_tf.compile_saved_model(
          model_dir,
          import_type='SIGNATURE_DEF',
          import_only=True,
          saved_model_tags={saved_model_tag},
          exported_names=[name])
      return computation_module.ComputationModule(iree_module, name, type_spec)