示例#1
0
def _logging_show_info():
    """Temporarily set the logging verbosity to INFO, restoring it afterwards.

    This is a generator that yields exactly once; it is meant to be driven as a
    context manager (presumably via `contextlib.contextmanager` applied where
    it is defined -- no decorator is visible here, TODO confirm).

    Yields:
      None, while the verbosity is set to `logging.INFO`.
    """
    # Read the current level *before* entering `try`: if this call raised
    # inside the `try`, the `finally` clause would reference an unbound
    # `verbosity` and mask the real error with an UnboundLocalError.
    verbosity = logging.get_verbosity()
    try:
        logging.set_verbosity(logging.INFO)
        yield
    finally:
        logging.set_verbosity(verbosity)
示例#2
0
def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   variables_to_feed=(),
                                   enable_multithreading=False):
    """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

    Use XLA AOT (`tfcompile`) to convert the given meta graph and
    signature into a header + object files.  Also create an include makefile
    that helps identify the appropriate necessary include and library paths
    to incorporate these files into your C++ program.

    The graph is always optimized with grappler, and variables are frozen as
    constants before compilation happens, except for the ones selected by
    `variables_to_feed`.  Frozen variables are embedded into the graph and
    binary objects; fed variables become inputs and outputs of the compiled
    class, and the C++ caller must set their values manually.

    Side effects worth knowing about:
      * `XLA_FLAGS` in `os.environ` is updated for the whole process so that
        it contains `--xla_cpu_multi_thread_eigen=false`.
      * Intermediate graphs and the tf2xla config are written under
        `test.get_temp_dir()`.
      * `{output_prefix}_makefile.inc` is written, and tfcompile is asked to
        emit `{output_prefix}.o`, `{output_prefix}.h` and
        `{output_prefix}_metadata.o`.

    Args:
      checkpoint_path: Python string.  Path to checkpoints/variables.
      meta_graph_def: Instance of `MetaGraphDef`.  Its `graph_def` is modified
        in place (input placeholders are replaced with default values).
      output_prefix: Python string.  Path prefix for outputs.
      signature_def_key: String, the signature_def to use in the SavedModel.
      cpp_class: String, Name of output C++ class.
      target_triple: String, LLVM target triple.
      variables_to_feed: A list of strings, the variables that will be fed by
        the user; these won't be frozen.  If `None`, then we will extract all
        the variables in the graph and mark them as to-feed.  The default
        behavior is an empty tuple: all variables must be frozen.
      enable_multithreading: Not implemented.  Enable multithreading in the
        compiled computation.

    Raises:
      RuntimeError: If tensorflow was not built with XLA.
      ImportError: If tensorflow was built with XLA but there was another
        issue importing the tfcompile python wrapper.
      ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
        missing or has empty outputs, or if `variables_to_feed` names a
        variable that is not present in the graph.
      NotImplementedError: If `enable_multithreading is True`.
    """
    if _pywrap_tfcompile_import_error:
        raise _pywrap_tfcompile_import_error

    if enable_multithreading:
        raise NotImplementedError(
            'Multithreading is not currently supported because it requires '
            'additional dependencies in the AOT runtime.')

    # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap
    # so that we can set these directly instead of relying on env vars.
    # XLA parses XLA_FLAGS as a whitespace-separated list, so the flag is
    # appended with a space (a comma would be glued onto the value of the
    # previous flag).  Repeated calls in one process must not keep appending
    # the same flag, so skip it when it is already present.
    single_thread_flag = '--xla_cpu_multi_thread_eigen=false'
    xla_flags = os.environ.get('XLA_FLAGS')
    if not xla_flags:
        xla_flags = single_thread_flag
    elif single_thread_flag not in xla_flags.split():
        xla_flags += ' ' + single_thread_flag
    os.environ['XLA_FLAGS'] = xla_flags

    signature_def_map = meta_graph_def.signature_def
    if signature_def_key not in signature_def_map:
        raise ValueError(
            'Unable to find signature_def key \'{}\' in signature def map.  '
            'Available keys: {}'.format(signature_def_key,
                                        list(signature_def_map.keys())))
    signature_def = signature_def_map[signature_def_key]
    if not signature_def.outputs:
        raise ValueError(
            'Signature key {} must have outputs, but saw none:\n{}'.format(
                signature_def_key, str(signature_def)))

    temp_dir = test.get_temp_dir()
    file_io.recursive_create_dir(temp_dir)
    if logging.get_verbosity() >= logging.INFO:
        # Debugging aid: keep a copy of the graph exactly as it was passed in.
        original_graph_def_location = os.path.join(temp_dir,
                                                   'original_graph.pb')
        with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
            graph_writer.write(meta_graph_def.graph_def.SerializeToString())

    # This updates graph_def in place.
    _replace_input_placeholders_with_default_values(meta_graph_def.graph_def,
                                                    signature_def)

    graph_def = _optimize_graph(meta_graph_def, signature_def)

    all_variables = _get_variable_nodes_from_graph_def(graph_def)
    if variables_to_feed is None:
        variable_nodes_to_feed = list(all_variables.values())
    else:
        not_in_graph = set(variables_to_feed).difference(list(all_variables))
        if not_in_graph:
            raise ValueError(
                'Asked to feed variables that were not found in graph: {}.  '
                'Variables contained in the graph: {}'.format(
                    not_in_graph, list(all_variables)))
        variable_nodes_to_feed = [
            all_variables[name] for name in variables_to_feed
        ]

    if logging.get_verbosity() >= logging.INFO:
        prefrozen_graph_def_location = os.path.join(temp_dir,
                                                    'prefrozen_graph.pb')
        with file_io.FileIO(prefrozen_graph_def_location,
                            'wb') as graph_writer:
            graph_writer.write(graph_def.SerializeToString())

    # Load the Variables so that we can freeze the graph.
    with session.Session(graph=ops_lib.Graph()) as sess:
        restorer = saver_lib.import_meta_graph(meta_graph_def,
                                               clear_devices=True)
        # import_meta_graph returns None when the graph has no variables;
        # in that case there is nothing to restore.
        if restorer is not None:
            restorer.restore(sess, checkpoint_path)
        graph_def.CopyFrom(
            graph_util.convert_variables_to_constants(
                sess,
                graph_def,
                output_node_names=[
                    _parse_tensor_name(n.name)[0]
                    for n in signature_def.outputs.values()
                ],
                variable_names_blacklist=[
                    n.name for n, _ in variable_nodes_to_feed
                ],
            ))

    signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

    frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
    config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
    logging.info('Writing graph def to: %s', frozen_graph_def_location)
    with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
        graph_writer.write(graph_def.SerializeToString())
    config = _signature_to_tf2xla_config(
        signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
    logging.info('Writing config_pbtxt to: %s', config_pbtxt_location)
    with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
        config_writer.write(str(config))

    output_dir = os.path.dirname(output_prefix)
    # A bare prefix such as 'model' has no directory component; there is
    # nothing to create in that case (outputs go to the current directory).
    if output_dir:
        file_io.recursive_create_dir(output_dir)

    # Derive a stable entry-point name from the config and graph contents.
    entry_digest = hashlib.md5()
    entry_digest.update(str(config).encode())
    entry_digest.update(str(graph_def).encode())
    entry_digest = entry_digest.hexdigest()

    logging.info('Generating XLA AOT artifacts in: %s', output_dir)

    makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
    with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
        makefile_writer.write(_xla_makefile_string(output_prefix))

    output_prefix = _shlex_quote(output_prefix)

    _pywrap_tfcompile.Compile(
        graph=frozen_graph_def_location,
        config=config_pbtxt_location,
        cpp_class=cpp_class,
        target_triple=target_triple,
        entry_point='entry_{}'.format(entry_digest),
        out_function_object='{}.o'.format(output_prefix),
        out_header='{}.h'.format(output_prefix),
        out_metadata_object='{}_metadata.o'.format(output_prefix),
        gen_name_to_index=True,
        # ProgramShape isn't uniquefied by entry_point.
        gen_program_shape=False)
示例#3
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compile a Python entity into equivalent TensorFlow.

    The function will also recursively compile all the entities that `o`
    references, updating `dependency_cache`.

    This function is reentrant, and relies on dependency_cache to avoid
    generating duplicate code.

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.
      arg_values: A dict containing value hints for symbols like function
          parameters.
      arg_types: A dict containing type hints for symbols like function
          parameters.

    Returns:
      A tuple (ast, new_name, namespace):
          * ast: An AST representing an entity with interface equivalent to
              `o`, but which when executed creates a TF graph.
          * new_name: The symbol name under which the new entity can be found.
          * namespace: A dict mapping all symbols visible to the converted
              entity, keyed by their symbol name.

    Raises:
      NotImplementedError: if `o` is not a class, function or method but has
          a `__class__` attribute (e.g. an arbitrary object instance).
      ValueError: if `o` is none of the above and has no `__class__`
          attribute.
    """
    logging.vlog(logging.DEBUG, 'Converting %s', o)

    if tf_inspect.isclass(o):
        node, name, ns = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
        node, name, ns = function_to_graph(o, program_ctx, arg_values,
                                           arg_types)
    # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
    elif hasattr(o, '__class__'):
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'contrib/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    # TODO(mdan): This is temporary. it should be created using a converter.
    # TODO(mdan): The attribute should be added with a helper, not directly.
    # The helper can ensure there are no collisions.
    template = '''
      entity.autograph_info__ = {}
  '''
    node.extend(templates.replace(template, entity=name))

    program_ctx.add_to_cache(o, node)

    if logging.get_verbosity() <= logging.DEBUG:
        logging.vlog(logging.DEBUG, 'Compiled output of %s:\n\n%s\n', o,
                     compiler.ast_to_source(node))

    if program_ctx.options.recursive:
        # Candidates deliberately left for their owning class to convert. The
        # `continue` branch below does not add them to dependency_cache, so
        # without remembering them here the scan would select the same
        # candidate again on every iteration and never terminate.
        deferred = set()
        while True:
            candidate = None
            for obj in program_ctx.name_map.keys():
                if (obj not in program_ctx.dependency_cache and
                        obj not in deferred):
                    candidate = obj
                    break
            if candidate is None:
                break
            if (hasattr(candidate, 'im_class') and getattr(
                    candidate, 'im_class') not in program_ctx.partial_types):
                # Class members are converted with their objects, unless they're
                # only converted partially.
                deferred.add(candidate)
                continue
            entity_to_graph(candidate, program_ctx, {}, {})

    return node, name, ns
def freeze_model(checkpoint_path: str,
                 meta_graph_def: meta_graph_pb2.MetaGraphDef,
                 output_prefix: str, signature_def_key: str,
                 variables_to_feed: List[str]) -> Tuple[str, str]:
    """Freeze a `MetaGraphDef` in preparation for `tfcompile`.

    The graph is always optimized with grappler, and variables are frozen as
    constants, except for the ones selected by `variables_to_feed`.

    Args:
      checkpoint_path: Python string.  Path to checkpoints/variables.
      meta_graph_def: Instance of `MetaGraphDef`.  Its `graph_def` is modified
        in place (input placeholders are replaced with default values).
      output_prefix: Python string.  Used as a directory: it is created if
        missing, and the frozen graph and config (plus, at INFO verbosity or
        higher, the original and pre-frozen graphs) are written inside it.
      signature_def_key: String, the signature_def to use in the SavedModel.
      variables_to_feed: A list of strings, the variables that will be fed by
        the user; these won't be frozen.  If `None` (accepted despite the
        `List[str]` annotation), then we will extract all the variables in the
        graph and mark them as to-feed.  Pass an empty list to freeze all
        variables.

    Returns:
      A pair `(frozen_graph_def_location, config_pbtxt_location)`: the path to
      the frozen model (`frozen_graph.pb`) and the path to the tf2xla config
      (`config.pbtxt`), both inside `output_prefix`.

    Raises:
      RuntimeError: If tensorflow was not built with XLA.
      ImportError: If tensorflow was built with XLA but there was another
        issue importing the tfcompile python wrapper.
      ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
        missing or has empty outputs, or if `variables_to_feed` names a
        variable that is not present in the graph.
    """
    if _pywrap_tfcompile_import_error:
        raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

    signature_def_map = meta_graph_def.signature_def
    if signature_def_key not in signature_def_map:
        raise ValueError(
            f"Unable to find signature_def_key '{signature_def_key}' in signature "
            'def map of `meta_graph_def`. Available keys: '
            f'{list(signature_def_map.keys())}')
    signature_def = signature_def_map[signature_def_key]
    if not signature_def.outputs:
        raise ValueError(
            f'Signature key {signature_def_key} must have outputs, but saw none:\n'
            f'{str(signature_def)}')

    file_io.recursive_create_dir(output_prefix)
    if logging.get_verbosity() >= logging.INFO:
        # Debugging aid: keep a copy of the graph exactly as it was passed in.
        original_graph_def_location = os.path.join(output_prefix,
                                                   'original_graph.pb')
        with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
            graph_writer.write(meta_graph_def.graph_def.SerializeToString())

    # This updates graph_def in place.
    _replace_input_placeholders_with_default_values(meta_graph_def.graph_def,
                                                    signature_def)

    graph_def = _optimize_graph(meta_graph_def, signature_def)

    all_variables = _get_variable_nodes_from_graph_def(graph_def)
    if variables_to_feed is None:
        variable_nodes_to_feed = list(all_variables.values())
    else:
        not_in_graph = set(variables_to_feed).difference(list(all_variables))
        if not_in_graph:
            raise ValueError(
                'Asked to feed variables that were not found in graph: '
                f'{not_in_graph}. Variables contained in the graph: '
                f'{list(all_variables)}')
        variable_nodes_to_feed = [
            all_variables[name] for name in variables_to_feed
        ]

    if logging.get_verbosity() >= logging.INFO:
        prefrozen_graph_def_location = os.path.join(output_prefix,
                                                    'prefrozen_graph.pb')
        with file_io.FileIO(prefrozen_graph_def_location,
                            'wb') as graph_writer:
            graph_writer.write(graph_def.SerializeToString())

    # Load the Variables so that we can freeze the graph.
    with session.Session(graph=ops_lib.Graph()) as sess:
        restorer = saver_lib.import_meta_graph(meta_graph_def,
                                               clear_devices=True)
        # import_meta_graph returns None when the graph has no variables;
        # in that case there is nothing to restore.
        if restorer is not None:
            restorer.restore(sess, checkpoint_path)
        graph_def.CopyFrom(
            graph_util.convert_variables_to_constants(
                sess,
                graph_def,
                output_node_names=[
                    _parse_tensor_name(n.name)[0]
                    for n in signature_def.outputs.values()
                ],
                variable_names_blacklist=[
                    n.name for n, _ in variable_nodes_to_feed
                ],
            ))

    signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

    frozen_graph_def_location = os.path.join(output_prefix, 'frozen_graph.pb')
    config_pbtxt_location = os.path.join(output_prefix, 'config.pbtxt')
    logging.info('Writing graph def to: %s', frozen_graph_def_location)
    with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
        graph_writer.write(graph_def.SerializeToString())
    config = _signature_to_tf2xla_config(
        signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
    logging.info('Writing config_pbtxt to: %s', config_pbtxt_location)
    with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
        config_writer.write(str(config))
    return frozen_graph_def_location, config_pbtxt_location
示例#5
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
  """Compile a Python entity into equivalent TensorFlow.

  The function will also recursively compile all the entities that `o`
  references, updating `dependency_cache`.

  This function is reentrant, and relies on dependency_cache to avoid
  generating duplicate code.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which when executed creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted entity,
            keyed by their symbol name.

  Raises:
    NotImplementedError: if `o` is not a class, function or method but has a
        `__class__` attribute (e.g. an arbitrary object instance).
    ValueError: if `o` is none of the above and has no `__class__` attribute.
  """
  logging.vlog(logging.DEBUG, 'Converting %s', o)

  if tf_inspect.isclass(o):
    node, name, ns = class_to_graph(o, program_ctx)
  elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  elif hasattr(o, '__class__'):
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'contrib/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  # TODO(mdan): This is temporary. it should be created using a converter.
  # TODO(mdan): The attribute should be added with a helper, not directly.
  # The helper can ensure there are no collisions.
  template = '''
      entity.autograph_info__ = {}
  '''
  node.extend(templates.replace(template, entity=name))

  program_ctx.add_to_cache(o, node)

  if logging.get_verbosity() <= logging.DEBUG:
    logging.vlog(logging.DEBUG, 'Compiled output of %s:\n\n%s\n', o,
                 compiler.ast_to_source(node))

  if program_ctx.options.recursive:
    # Candidates deliberately left for their owning class to convert. The
    # `continue` branch below does not add them to dependency_cache, so
    # without remembering them here the scan would select the same candidate
    # again on every iteration and never terminate.
    deferred = set()
    while True:
      candidate = None
      for obj in program_ctx.name_map.keys():
        if obj not in program_ctx.dependency_cache and obj not in deferred:
          candidate = obj
          break
      if candidate is None:
        break
      if (hasattr(candidate, 'im_class') and
          getattr(candidate, 'im_class') not in program_ctx.partial_types):
        # Class members are converted with their objects, unless they're
        # only converted partially.
        deferred.add(candidate)
        continue
      entity_to_graph(candidate, program_ctx, {}, {})

  return node, name, ns