def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   variables_to_feed=(),
                                   enable_multithreading=False):
  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files.  Also create an include makefile
  that helps identify the appropriate necessary include and library paths
  to incorporate these files into your C++ program.

  The graph is always optimized with grappler, and optionally (by default)
  variables are frozen as constants, before compilation happens.

  If the `freeze_graph` is `True`, all variables are embedded as constants
  into the graph and binary objects.  If it is `False`, then the variable
  values become inputs and outputs of the compiled class and the C++
  caller must set these values manually.

  Args:
    checkpoint_path: Python string.  Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string.  Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen.  If `None`, then we will extract all the
      variables in the graph and mark them as to-feed.  The default behavior is
      an empty tuple: all variables must be frozen.
    enable_multithreading: Not implemented.  Enable multithreading in the
      compiled computation.

  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs, or if variables requested in
      `variables_to_feed` are not present in the graph.
    NotImplementedError: If `enable_multithreading is True`.
  """
  if _pywrap_tfcompile_import_error:
    raise _pywrap_tfcompile_import_error

  if enable_multithreading:
    raise NotImplementedError(
        'Multithreading is not currently supported because it requires '
        'additional dependencies in the AOT runtime.')
  else:
    # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap
    # so that we can set these directly instead of relying on env vars.
    xla_flags = os.environ.get('XLA_FLAGS')
    if not xla_flags:
      xla_flags = '--xla_cpu_multi_thread_eigen=false'
    else:
      # BUGFIX: XLA_FLAGS is parsed as whitespace-separated flags; the
      # previous comma separator corrupted any pre-existing flags and the
      # appended flag was never honored.  Use a space, matching the newer
      # variant of this function.
      xla_flags += ' --xla_cpu_multi_thread_eigen=false'
    os.environ['XLA_FLAGS'] = xla_flags

  signature_def_map = meta_graph_def.signature_def
  if signature_def_key not in signature_def_map:
    raise ValueError(
        'Unable to find signature_def key \'{}\' in signature def map. '
        'Available keys: {}'.format(
            signature_def_key, list(signature_def_map.keys())))
  signature_def = signature_def_map[signature_def_key]
  if not signature_def.outputs:
    raise ValueError(
        'Signature key {} must have outputs, but saw none:\n{}'.format(
            signature_def_key, str(signature_def)))

  temp_dir = test.get_temp_dir()
  file_io.recursive_create_dir(temp_dir)
  if logging.get_verbosity() >= logging.INFO:
    # Keep a copy of the unmodified graph for debugging.
    original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb')
    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(meta_graph_def.graph_def.SerializeToString())

  # This updates graph_def in place.
  _replace_input_placeholders_with_default_values(
      meta_graph_def.graph_def, signature_def)

  graph_def = _optimize_graph(meta_graph_def, signature_def)

  all_variables = _get_variable_nodes_from_graph_def(graph_def)
  if variables_to_feed is None:
    # Feed all variables found in the graph.
    variable_nodes_to_feed = list(all_variables.values())
  else:
    not_in_graph = set(variables_to_feed).difference(list(all_variables))
    if not_in_graph:
      raise ValueError(
          'Asked to feed variables that were not found in graph: {}. '
          'Variables contained in the graph: {}'.format(
              not_in_graph, list(all_variables)))
    variable_nodes_to_feed = [
        all_variables[name] for name in variables_to_feed
    ]

  if logging.get_verbosity() >= logging.INFO:
    # Keep a copy of the optimized-but-not-yet-frozen graph for debugging.
    prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb')
    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(graph_def.SerializeToString())

  # Load the Variables so that we can freeze the graph.
  with session.Session(graph=ops_lib.Graph()) as sess:
    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
    restorer.restore(sess, checkpoint_path)
    graph_def.CopyFrom(
        graph_util.convert_variables_to_constants(
            sess,
            graph_def,
            output_node_names=[
                _parse_tensor_name(n.name)[0]
                for n in signature_def.outputs.values()
            ],
            # Variables that remain user-fed must not be frozen.
            variable_names_blacklist=[
                n.name for n, _ in variable_nodes_to_feed
            ],
        ))

  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

  frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
  config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
  logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
  with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
    graph_writer.write(graph_def.SerializeToString())
  config = _signature_to_tf2xla_config(
      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
  logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
  with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
    config_writer.write(str(config))

  output_dir = os.path.dirname(output_prefix)
  file_io.recursive_create_dir(output_dir)

  # The entry point is derived from a digest of the config and graph so that
  # distinct compilations get distinct symbol names.
  entry_digest = hashlib.md5()
  entry_digest.update(str(config).encode())
  entry_digest.update(str(graph_def).encode())
  entry_digest = entry_digest.hexdigest()

  logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

  makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
  with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
    makefile_writer.write(_xla_makefile_string(output_prefix))

  output_prefix = _shlex_quote(output_prefix)

  _pywrap_tfcompile.Compile(
      graph=frozen_graph_def_location,
      config=config_pbtxt_location,
      cpp_class=cpp_class,
      target_triple=target_triple,
      entry_point='entry_{}'.format(entry_digest),
      out_function_object='{}.o'.format(output_prefix),
      out_header='{}.h'.format(output_prefix),
      out_metadata_object='{}_metadata.o'.format(output_prefix),
      gen_name_to_index=True,
      # ProgramShape isn't uniquefied by entry_point.
      gen_program_shape=False)
def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   target_cpu,
                                   variables_to_feed=(),
                                   multithreading=False):
  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Runs XLA AOT compilation (`tfcompile`) on the graph identified by
  `signature_def_key`, producing a C++ header, object files, and an include
  makefile (`<output_prefix>_makefile.inc`) that records the include and
  library paths needed to link the result into a C++ program.

  The model is first frozen: the checkpoint is restored and variables (other
  than those in `variables_to_feed`) are converted to constants; fed inputs
  are replaced with default all-zero constants where no value is supplied.
  The frozen graph is then pruned and optimized with grappler before
  compilation.

  Args:
    checkpoint_path: Python string.  Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string.  Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    target_cpu: String, LLVM target cpu name.
    variables_to_feed: A list of strings, the variables that will be fed by
      the user; these won't be frozen.  If `None`, then we will extract all
      the variables in the graph and mark them as to-feed.  The default
      behavior is an empty tuple: all variables must be frozen.
    multithreading: Whether to enable multithreading in the compiled
      computation.  Note that if using this option, the resulting object files
      may have external dependencies on multithreading libraries like nsync.

  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
  if _pywrap_tfcompile_import_error:
    raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

  # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap
  # so that we can set these directly instead of relying on env vars.
  eigen_enabled = 'true' if multithreading else 'false'
  flags_value = os.environ.get('XLA_FLAGS')
  if flags_value:
    flags_value += ' --xla_cpu_multi_thread_eigen={}'.format(eigen_enabled)
  else:
    flags_value = '--xla_cpu_multi_thread_eigen={}'.format(eigen_enabled)
  os.environ['XLA_FLAGS'] = flags_value

  # Freeze the model into a scratch directory; tfcompile consumes the frozen
  # graph and tf2xla config files from disk.
  scratch_dir = test.get_temp_dir()
  file_io.recursive_create_dir(scratch_dir)

  frozen_graph_def_location, config_pbtxt_location = freeze_model(
      checkpoint_path=checkpoint_path,
      meta_graph_def=meta_graph_def,
      output_prefix=scratch_dir,
      signature_def_key=signature_def_key,
      variables_to_feed=variables_to_feed)

  output_dir = os.path.dirname(output_prefix)
  file_io.recursive_create_dir(output_dir)

  # Build a C-identifier-safe entry point symbol from the output location and
  # class name so distinct compilations get distinct symbols.
  entry_point = re.sub(
      '[^0-9a-zA-Z]+', '_', '__xla_' + output_prefix + '__' + cpp_class)

  logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

  makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
  with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
    makefile_writer.write(_xla_makefile_string(output_prefix))

  output_prefix = _shlex_quote(output_prefix)

  _pywrap_tfcompile.Compile(
      graph=frozen_graph_def_location,
      config=config_pbtxt_location,
      cpp_class=cpp_class,
      target_triple=target_triple,
      target_cpu=target_cpu,
      entry_point=entry_point,
      out_function_object='{}.o'.format(output_prefix),
      out_header='{}.h'.format(output_prefix),
      out_metadata_object='{}_metadata.o'.format(output_prefix),
      gen_name_to_index=True,
      # ProgramShape isn't uniquefied by entry_point.
      gen_program_shape=False)