# Example no. 1 (scraped sample header; the original two lines were not valid Python)
def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   variables_to_feed=(),
                                   enable_multithreading=False):
    """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files.  Also create an include makefile
  that helps identify the appropriate necessary include and library paths
  to incorporate these files into your C++ program.

  The graph is always optimized with grappler, and all variables not listed
  in `variables_to_feed` are frozen as constants, before compilation happens.

  Variables that are fed (listed in `variables_to_feed`) are NOT frozen:
  their values become inputs and outputs of the compiled class and the C++
  caller must set these values manually.

  Args:
    checkpoint_path: Python string.  Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string.  Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen.  If `None`, then we will extract all the
      variables in the graph and mark them as to-feed.  The default behavior is
      an empty tuple: all variables must be frozen.
    enable_multithreading: Not implemented.  Enable multithreading in the
      compiled computation.

  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs, or if `variables_to_feed` names a
      variable that is not in the graph.
    NotImplementedError: If `enable_multithreading is True`.
  """
    if _pywrap_tfcompile_import_error:
        raise _pywrap_tfcompile_import_error

    if enable_multithreading:
        raise NotImplementedError(
            'Multithreading is not currently supported because it requires '
            'additional dependencies in the AOT runtime.')
    else:
        # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap
        # so that we can set these directly instead of relying on env vars.
        xla_flags = os.environ.get('XLA_FLAGS')
        if not xla_flags:
            xla_flags = '--xla_cpu_multi_thread_eigen=false'
        else:
            # XLA_FLAGS entries are whitespace-separated; joining with a comma
            # would fuse this flag onto the previous token and it would be
            # silently ignored (the newer variant of this function below uses
            # a space for the same reason).
            xla_flags += ' --xla_cpu_multi_thread_eigen=false'
        os.environ['XLA_FLAGS'] = xla_flags

    # Validate the requested signature before doing any expensive work.
    signature_def_map = meta_graph_def.signature_def
    if signature_def_key not in signature_def_map:
        raise ValueError(
            'Unable to find signature_def key \'{}\' in signature def map.  '
            'Available keys: {}'.format(signature_def_key,
                                        list(signature_def_map.keys())))
    signature_def = signature_def_map[signature_def_key]
    if not signature_def.outputs:
        raise ValueError(
            'Signature key {} must have outputs, but saw none:\n{}'.format(
                signature_def_key, str(signature_def)))

    temp_dir = test.get_temp_dir()
    file_io.recursive_create_dir(temp_dir)
    # Dump the unmodified graph for debugging when verbose logging is on.
    if logging.get_verbosity() >= logging.INFO:
        original_graph_def_location = os.path.join(temp_dir,
                                                   'original_graph.pb')
        with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
            graph_writer.write(meta_graph_def.graph_def.SerializeToString())

    # This updates graph_def in place.
    _replace_input_placeholders_with_default_values(meta_graph_def.graph_def,
                                                    signature_def)

    # Run grappler optimizations; returns a new GraphDef.
    graph_def = _optimize_graph(meta_graph_def, signature_def)

    all_variables = _get_variable_nodes_from_graph_def(graph_def)
    if variables_to_feed is None:
        # Feed every variable in the graph (none will be frozen).
        variable_nodes_to_feed = list(all_variables.values())
    else:
        not_in_graph = set(variables_to_feed).difference(list(all_variables))
        if not_in_graph:
            raise ValueError(
                'Asked to feed variables that were not found in graph: {}.  '
                'Variables contained in the graph: {}'.format(
                    not_in_graph, list(all_variables)))
        variable_nodes_to_feed = [
            all_variables[name] for name in variables_to_feed
        ]

    # Dump the optimized-but-not-yet-frozen graph for debugging.
    if logging.get_verbosity() >= logging.INFO:
        prefrozen_graph_def_location = os.path.join(temp_dir,
                                                    'prefrozen_graph.pb')
        with file_io.FileIO(prefrozen_graph_def_location,
                            'wb') as graph_writer:
            graph_writer.write(graph_def.SerializeToString())

    # Load the Variables so that we can freeze the graph.
    with session.Session(graph=ops_lib.Graph()) as sess:
        restorer = saver_lib.import_meta_graph(meta_graph_def,
                                               clear_devices=True)
        restorer.restore(sess, checkpoint_path)
        graph_def.CopyFrom(
            graph_util.convert_variables_to_constants(
                sess,
                graph_def,
                output_node_names=[
                    _parse_tensor_name(n.name)[0]
                    for n in signature_def.outputs.values()
                ],
                # Variables the user will feed are excluded from freezing.
                variable_names_blacklist=[
                    n.name for n, _ in variable_nodes_to_feed
                ],
            ))

    # Freezing may have removed feed nodes; drop them from the signature too.
    signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

    frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
    config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
    logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
    with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
        graph_writer.write(graph_def.SerializeToString())
    config = _signature_to_tf2xla_config(
        signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
    logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
    with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
        config_writer.write(str(config))

    output_dir = os.path.dirname(output_prefix)
    file_io.recursive_create_dir(output_dir)

    # Hash the config + graph so the generated entry point symbol is unique
    # per compiled model (md5 is used for uniqueness only, not security).
    entry_digest = hashlib.md5()
    entry_digest.update(str(config).encode())
    entry_digest.update(str(graph_def).encode())
    entry_digest = entry_digest.hexdigest()

    logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

    makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
    with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
        makefile_writer.write(_xla_makefile_string(output_prefix))

    # Quote before interpolating into tfcompile output-path arguments.
    output_prefix = _shlex_quote(output_prefix)

    _pywrap_tfcompile.Compile(
        graph=frozen_graph_def_location,
        config=config_pbtxt_location,
        cpp_class=cpp_class,
        target_triple=target_triple,
        entry_point='entry_{}'.format(entry_digest),
        out_function_object='{}.o'.format(output_prefix),
        out_header='{}.h'.format(output_prefix),
        out_metadata_object='{}_metadata.o'.format(output_prefix),
        gen_name_to_index=True,
        # ProgramShape isn't uniquefied by entry_point.
        gen_program_shape=False)
def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   target_cpu,
                                   variables_to_feed=(),
                                   multithreading=False):
    """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files.  Also create an include makefile
  that helps identify the appropriate necessary include and library paths
  to incorporate these files into your C++ program.

  The model is first frozen (via `freeze_model`): the checkpoint is restored
  and inputs/variables are replaced with constants, after which the graph is
  pruned and optimized with grappler.  Variables listed in
  `variables_to_feed` are left unfrozen; their values become inputs and
  outputs of the compiled class and the C++ caller must set them manually.

  Args:
    checkpoint_path: Python string.  Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string.  Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    target_cpu: String, LLVM target cpu name.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen.  If `None`, then we will extract all the
      variables in the graph and mark them as to-feed.  The default behavior is
      an empty tuple: all variables must be frozen.
    multithreading: Whether to enable multithreading in the compiled
      computation.  Note that if using this option, the resulting object files
      may have external dependencies on multithreading libraries like nsync.

  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
    if _pywrap_tfcompile_import_error:
        raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

    # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap
    # so that we can set these directly instead of relying on env vars.
    eigen_flag = '--xla_cpu_multi_thread_eigen={}'.format(
        'true' if multithreading else 'false')
    current_flags = os.environ.get('XLA_FLAGS')
    if current_flags:
        # XLA_FLAGS entries are whitespace-separated; append ours.
        os.environ['XLA_FLAGS'] = current_flags + ' ' + eigen_flag
    else:
        os.environ['XLA_FLAGS'] = eigen_flag

    # Freeze into a scratch directory; tfcompile reads the results from disk.
    scratch_dir = test.get_temp_dir()
    file_io.recursive_create_dir(scratch_dir)
    frozen_graph_path, config_path = freeze_model(
        checkpoint_path=checkpoint_path,
        meta_graph_def=meta_graph_def,
        output_prefix=scratch_dir,
        signature_def_key=signature_def_key,
        variables_to_feed=variables_to_feed)

    artifact_dir = os.path.dirname(output_prefix)
    file_io.recursive_create_dir(artifact_dir)

    # Build a C-identifier-safe entry point symbol from the prefix and class.
    entry_symbol = re.sub('[^0-9a-zA-Z]+', '_',
                          '__xla_' + output_prefix + '__' + cpp_class)

    logging.info('Generating XLA AOT artifacts in: {}'.format(artifact_dir))

    makefile_path = '{}_makefile.inc'.format(output_prefix)
    with file_io.FileIO(makefile_path, mode='w') as makefile_writer:
        makefile_writer.write(_xla_makefile_string(output_prefix))

    # Quote before interpolating into tfcompile output-path arguments.
    output_prefix = _shlex_quote(output_prefix)

    _pywrap_tfcompile.Compile(
        graph=frozen_graph_path,
        config=config_path,
        cpp_class=cpp_class,
        target_triple=target_triple,
        target_cpu=target_cpu,
        entry_point=entry_symbol,
        out_function_object='{}.o'.format(output_prefix),
        out_header='{}.h'.format(output_prefix),
        out_metadata_object='{}_metadata.o'.format(output_prefix),
        gen_name_to_index=True,
        # ProgramShape isn't uniquefied by entry_point.
        gen_program_shape=False)