def _logging_show_info():
  """Temporarily raise logging verbosity to INFO, restoring it on exit.

  This is a context-manager generator body (presumably decorated with
  `@contextlib.contextmanager` above this chunk — confirm against the
  full file).

  Yields:
    Nothing; control returns to the `with` body while verbosity is INFO.
  """
  # Save and raise the verbosity *before* entering the try block.  In the
  # previous version the save happened inside `try`, so a failure in
  # `logging.get_verbosity()` reached the `finally` clause with `verbosity`
  # unbound, raising UnboundLocalError and masking the original error.
  verbosity = logging.get_verbosity()
  logging.set_verbosity(logging.INFO)
  try:
    yield
  finally:
    # Always restore the caller's verbosity, even if the body raised.
    logging.set_verbosity(verbosity)
def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   variables_to_feed=(),
                                   enable_multithreading=False):
  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files.  Also create an include makefile
  that helps identify the appropriate necessary include and library paths
  to incorporate these files into your C++ program.

  The graph is always optimized with grappler, and optionally (by default)
  variables are frozen as constants, before compilation happens.

  If the `freeze_graph` is `True`, all variables are embedded as constants
  into the graph and binary objects.  If it is `False`, then the variable
  values become inputs and outputs of the compiled class and the C++
  caller must set these values manually.

  Args:
    checkpoint_path: Python string.  Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string.  Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen.  If `None`, then we will extract all the
      variables in the graph and mark them as to-feed.  The default behavior is
      an empty tuple: all variables must be frozen.
    enable_multithreading: Not implemented.  Enable multithreading in the
      compiled computation.

  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
    NotImplementedError: If `enable_multithreading is True`.
  """
  if _pywrap_tfcompile_import_error:
    raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

  if enable_multithreading:
    raise NotImplementedError(
        'Multithreading is not currently supported because it requires '
        'additional dependencies in the AOT runtime.')
  else:
    # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap
    # so that we can set these directly instead of relying on env vars.
    xla_flags = os.environ.get('XLA_FLAGS')
    if not xla_flags:
      xla_flags = '--xla_cpu_multi_thread_eigen=false'
    else:
      # XLA_FLAGS is parsed as *whitespace*-separated flags; appending with
      # a comma (as before) would glue this flag onto the previous flag's
      # value and make XLA reject or misread it.
      xla_flags += ' --xla_cpu_multi_thread_eigen=false'
    os.environ['XLA_FLAGS'] = xla_flags

  signature_def_map = meta_graph_def.signature_def
  if signature_def_key not in signature_def_map:
    raise ValueError(
        'Unable to find signature_def key \'{}\' in signature def map. '
        'Available keys: {}'.format(
            signature_def_key,
            list(signature_def_map.keys())))
  signature_def = signature_def_map[signature_def_key]
  if not signature_def.outputs:
    raise ValueError(
        'Signature key {} must have outputs, but saw none:\n{}'.format(
            signature_def_key, str(signature_def)))

  temp_dir = test.get_temp_dir()
  file_io.recursive_create_dir(temp_dir)
  # At INFO verbosity, dump the pristine graph for debugging.
  if logging.get_verbosity() >= logging.INFO:
    original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb')
    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(meta_graph_def.graph_def.SerializeToString())

  # This updates graph_def in place.
  _replace_input_placeholders_with_default_values(meta_graph_def.graph_def,
                                                  signature_def)

  graph_def = _optimize_graph(meta_graph_def, signature_def)

  all_variables = _get_variable_nodes_from_graph_def(graph_def)
  if variables_to_feed is None:
    # Feed every variable in the graph; none will be frozen.
    variable_nodes_to_feed = list(all_variables.values())
  else:
    not_in_graph = set(variables_to_feed).difference(list(all_variables))
    if not_in_graph:
      raise ValueError(
          'Asked to feed variables that were not found in graph: {}. '
          'Variables contained in the graph: {}'.format(
              not_in_graph, list(all_variables)))
    variable_nodes_to_feed = [
        all_variables[name] for name in variables_to_feed
    ]

  if logging.get_verbosity() >= logging.INFO:
    prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb')
    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(graph_def.SerializeToString())

  # Load the Variables so that we can freeze the graph.
  with session.Session(graph=ops_lib.Graph()) as sess:
    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
    # `import_meta_graph` returns None when the MetaGraphDef contains no
    # saver (e.g. a variable-free model); guard like `freeze_model` does so
    # such models don't crash with AttributeError.
    if restorer is not None:
      restorer.restore(sess, checkpoint_path)
    graph_def.CopyFrom(
        graph_util.convert_variables_to_constants(
            sess,
            graph_def,
            output_node_names=[
                _parse_tensor_name(n.name)[0]
                for n in signature_def.outputs.values()
            ],
            variable_names_blacklist=[
                n.name for n, _ in variable_nodes_to_feed
            ],
        ))

  # Some feed nodes may have been pruned by grappler/freezing; drop them
  # from the signature so the tf2xla config matches the final graph.
  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

  frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
  config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
  logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
  with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
    graph_writer.write(graph_def.SerializeToString())
  config = _signature_to_tf2xla_config(
      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
  logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
  with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
    config_writer.write(str(config))

  output_dir = os.path.dirname(output_prefix)
  file_io.recursive_create_dir(output_dir)

  # Digest the config and graph so the generated entry point symbol is
  # unique per compiled model (md5 is used only for uniqueness here, not
  # for security).
  entry_digest = hashlib.md5()
  entry_digest.update(str(config).encode())
  entry_digest.update(str(graph_def).encode())
  entry_digest = entry_digest.hexdigest()

  logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

  makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
  with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
    makefile_writer.write(_xla_makefile_string(output_prefix))

  # Quote for safe interpolation into the generated build artifacts.
  output_prefix = _shlex_quote(output_prefix)

  _pywrap_tfcompile.Compile(
      graph=frozen_graph_def_location,
      config=config_pbtxt_location,
      cpp_class=cpp_class,
      target_triple=target_triple,
      entry_point='entry_{}'.format(entry_digest),
      out_function_object='{}.o'.format(output_prefix),
      out_header='{}.h'.format(output_prefix),
      out_metadata_object='{}_metadata.o'.format(output_prefix),
      gen_name_to_index=True,
      # ProgramShape isn't uniquefied by entry_point.
      gen_program_shape=False)
def entity_to_graph(o, program_ctx, arg_values, arg_types):
  """Convert a Python entity into an equivalent TensorFlow-building AST.

  Recursively converts every entity referenced by `o` as well (when the
  program context requests recursion), using `dependency_cache` to avoid
  emitting duplicate code; the function is reentrant.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name, namespace) where `ast` is an AST for an entity
    equivalent to `o` that builds a TF graph when executed, `new_name` is
    the symbol under which the new entity can be found, and `namespace`
    maps every symbol visible to the converted entity to its value.

  Raises:
    ValueError: if the entity type is not supported.
    NotImplementedError: if `o` is an object instance (not yet supported).
  """
  logging.vlog(logging.DEBUG, 'Converting %s', o)

  if tf_inspect.isclass(o):
    node, name, ns = class_to_graph(o, program_ctx)
  elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
    # Functions and methods take the same conversion path.
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  elif hasattr(o, '__class__'):
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'contrib/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  # TODO(mdan): This is temporary. it should be created using a converter.
  # TODO(mdan): The attribute should be added with a helper, not directly.
  # The helper can ensure there are no collisions.
  template = '''
  entity.autograph_info__ = {}
  '''
  node.extend(templates.replace(template, entity=name))

  program_ctx.add_to_cache(o, node)

  # Only pay the ast_to_source cost when the output would be visible.
  if logging.get_verbosity() <= logging.DEBUG:
    logging.vlog(logging.DEBUG, 'Compiled output of %s:\n\n%s\n', o,
                 compiler.ast_to_source(node))

  if program_ctx.options.recursive:
    while True:
      # Pick any referenced entity that has not been converted yet.
      pending = next(
          (entity for entity in program_ctx.name_map
           if entity not in program_ctx.dependency_cache), None)
      if pending is None:
        break
      if (hasattr(pending, 'im_class') and
          pending.im_class not in program_ctx.partial_types):
        # Class members are converted with their objects, unless they're
        # only converted partially.
        continue
      entity_to_graph(pending, program_ctx, {}, {})

  return node, name, ns
def freeze_model(checkpoint_path: str,
                 meta_graph_def: meta_graph_pb2.MetaGraphDef,
                 output_prefix: str,
                 signature_def_key: str,
                 variables_to_feed: List[str]) -> Tuple[str, str]:
  """Freeze a `MetaGraphDef` in preparation for `tfcompile`.

  The graph is always optimized with grappler, and optionally (by default)
  variables are frozen as constants, before compilation happens.

  Args:
    checkpoint_path: Python string.  Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string.  Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen.  If `None`, then we will extract all the
      variables in the graph and mark them as to-feed.  The default behavior is
      an empty tuple: all variables must be frozen.

  Returns:
    a pair containing the path to the frozen model and the path to the config.

  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
  if _pywrap_tfcompile_import_error:
    raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

  signature_def_map = meta_graph_def.signature_def
  if signature_def_key not in signature_def_map:
    raise ValueError(
        f"Unable to find signature_def_key '{signature_def_key}' in signature "
        'def map of `meta_graph_def`. Available keys: '
        f'{list(signature_def_map.keys())}')
  signature_def = signature_def_map[signature_def_key]
  if not signature_def.outputs:
    raise ValueError(
        f'Signature key {signature_def_key} must have outputs, but saw none:\n'
        f'{str(signature_def)}')

  file_io.recursive_create_dir(output_prefix)
  # At INFO verbosity, keep a copy of the pristine (pre-optimization) graph
  # in the output directory for debugging.
  if logging.get_verbosity() >= logging.INFO:
    original_graph_def_location = os.path.join(output_prefix,
                                               'original_graph.pb')
    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(meta_graph_def.graph_def.SerializeToString())

  # This updates graph_def in place.
  _replace_input_placeholders_with_default_values(meta_graph_def.graph_def,
                                                  signature_def)

  # Grappler-optimized copy of the graph; the original meta_graph_def is
  # still needed below to rebuild the saver for checkpoint restore.
  graph_def = _optimize_graph(meta_graph_def, signature_def)

  all_variables = _get_variable_nodes_from_graph_def(graph_def)
  if variables_to_feed is None:
    # Feed every variable found in the graph; none will be frozen.
    variable_nodes_to_feed = list(all_variables.values())
  else:
    not_in_graph = set(variables_to_feed).difference(list(all_variables))
    if not_in_graph:
      raise ValueError(
          'Asked to feed variables that were not found in graph: '
          f'{not_in_graph}. Variables contained in the graph: '
          f'{list(all_variables)}')
    variable_nodes_to_feed = [
        all_variables[name] for name in variables_to_feed
    ]

  # Also dump the optimized-but-not-yet-frozen graph when at INFO verbosity.
  if logging.get_verbosity() >= logging.INFO:
    prefrozen_graph_def_location = os.path.join(output_prefix,
                                                'prefrozen_graph.pb')
    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(graph_def.SerializeToString())

  # Load the Variables so that we can freeze the graph.
  with session.Session(graph=ops_lib.Graph()) as sess:
    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
    # `import_meta_graph` returns None when the graph has no saver (e.g. a
    # variable-free model); in that case there is nothing to restore.
    if restorer is not None:
      restorer.restore(sess, checkpoint_path)
    # Replace variables with constants, except those the user asked to keep
    # as feeds (passed via `variable_names_blacklist`).
    graph_def.CopyFrom(
        graph_util.convert_variables_to_constants(
            sess,
            graph_def,
            output_node_names=[
                _parse_tensor_name(n.name)[0]
                for n in signature_def.outputs.values()
            ],
            variable_names_blacklist=[
                n.name for n, _ in variable_nodes_to_feed
            ],
        ))

  # Freezing/optimization may have pruned some feed nodes; drop them from
  # the signature so the emitted config matches the final graph.
  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

  frozen_graph_def_location = os.path.join(output_prefix, 'frozen_graph.pb')
  config_pbtxt_location = os.path.join(output_prefix, 'config.pbtxt')
  logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
  with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
    graph_writer.write(graph_def.SerializeToString())
  config = _signature_to_tf2xla_config(
      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
  logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
  with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
    config_writer.write(str(config))
  return frozen_graph_def_location, config_pbtxt_location
def entity_to_graph(o, program_ctx, arg_values, arg_types):
  """Compile a Python entity `o` to an AST that builds a TF graph.

  Entities referenced by `o` are converted recursively (when recursion is
  enabled in `program_ctx.options`); `dependency_cache` is consulted and
  updated so no entity is converted twice.  The function is reentrant.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name, namespace): the converted AST, the symbol name
    under which the new entity is available, and a dict of all symbols
    visible to the converted entity keyed by symbol name.

  Raises:
    ValueError: if the entity type is not supported.
    NotImplementedError: if `o` is an object instance (not yet supported).
  """
  logging.vlog(logging.DEBUG, 'Converting %s', o)

  is_callable = tf_inspect.isfunction(o) or tf_inspect.ismethod(o)
  if tf_inspect.isclass(o):
    node, name, ns = class_to_graph(o, program_ctx)
  elif is_callable:
    # Plain functions and bound/unbound methods share one conversion path.
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  elif hasattr(o, '__class__'):
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'contrib/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  # TODO(mdan): This is temporary. it should be created using a converter.
  # TODO(mdan): The attribute should be added with a helper, not directly.
  # The helper can ensure there are no collisions.
  template = '''
  entity.autograph_info__ = {}
  '''
  node.extend(templates.replace(template, entity=name))

  program_ctx.add_to_cache(o, node)

  # Skip the (potentially expensive) source rendering unless DEBUG output
  # would actually be shown.
  if logging.get_verbosity() <= logging.DEBUG:
    logging.vlog(logging.DEBUG, 'Compiled output of %s:\n\n%s\n', o,
                 compiler.ast_to_source(node))

  if program_ctx.options.recursive:
    # Drain the set of referenced-but-unconverted entities.  Each recursive
    # call may discover new names, so re-scan until nothing is pending.
    while True:
      unconverted = [
          entity for entity in program_ctx.name_map
          if entity not in program_ctx.dependency_cache
      ]
      if not unconverted:
        break
      candidate = unconverted[0]
      if (hasattr(candidate, 'im_class') and
          candidate.im_class not in program_ctx.partial_types):
        # Class members are converted with their objects, unless they're
        # only converted partially.
        continue
      entity_to_graph(candidate, program_ctx, {}, {})

  return node, name, ns