Exemplo n.º 1
0
def force_text(s, encoding='utf-8'):
    """Coerce `s` to a text (unicode) string.

    Bytes are decoded with `encoding`; any other object is converted via
    `six.text_type` (`str` on Python 3, `unicode` on Python 2).

    Args:
        s: Value to coerce to text.
        encoding: Charset used to decode byte strings. Defaults to 'utf-8'.

    Returns:
        A `six.text_type` instance.
    """
    if six.PY3:
        if isinstance(s, bytes):
            s = six.text_type(s, encoding)
        else:
            s = six.text_type(s)
    else:
        # Bug fix: the original unconditionally did
        # six.text_type(bytes(s), encoding). On Python 2, when `s` is already
        # a unicode string containing non-ASCII characters, bytes(s) (i.e.
        # str(s)) implicitly encodes with ASCII and raises UnicodeEncodeError.
        # Unicode input is already text, so return it unchanged.
        if isinstance(s, six.text_type):
            return s
        s = six.text_type(bytes(s), encoding)

    return s
Exemplo n.º 2
0
def force_bytes(s, encoding='utf-8'):
    """Coerce `s` to a byte string in the given `encoding`.

    Byte input is assumed to already be UTF-8 and is transcoded only when a
    different target encoding is requested. Text is encoded directly. Any
    other object is stringified first (Python 3) or handed to bytes()
    (Python 2).

    Args:
        s: Value to coerce to bytes.
        encoding: Target charset. Defaults to 'utf-8'.

    Returns:
        A byte string.
    """
    if isinstance(s, bytes):
        # Already bytes: only transcode when the caller asked for something
        # other than the assumed UTF-8 source encoding.
        if encoding == 'utf-8':
            return s
        return s.decode('utf-8').encode(encoding)

    if isinstance(s, six.string_types):
        # Plain text: encode straight into the requested charset.
        return s.encode(encoding)

    # Arbitrary object: stringify on Python 3, rely on bytes() on Python 2.
    return six.text_type(s).encode(encoding) if six.PY3 else bytes(s)
Exemplo n.º 3
0
def load_op_library(library_filename):
  """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here.
  Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList", to be
  defined in the library.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
  # Allocate a C-API status object to capture errors from TF_LoadLibrary; it
  # must be freed with TF_DeleteStatus on every path.
  status = py_tf.TF_NewStatus()

  lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
  try:
    # A non-zero status code means the dynamic library failed to load.
    if py_tf.TF_GetCode(status) != 0:
      raise RuntimeError(compat.as_text(py_tf.TF_Message(status)))
  finally:
    # Release the status object even when raising above.
    py_tf.TF_DeleteStatus(status)

  # Serialized OpList proto describing the ops the library registered.
  op_list_str = py_tf.TF_GetOpList(lib_handle)
  op_list = op_def_pb2.OpList()
  op_list.ParseFromString(bytes(op_list_str))
  # Generated Python source text for the op wrapper functions.
  wrappers = py_tf.GetPythonWrappers(op_list_str, len(op_list_str))

  # Get a unique name for the module.
  # NOTE(review): md5 is a content fingerprint here, not a security measure;
  # identical wrapper source yields the same module name.
  module_name = hashlib.md5(wrappers).hexdigest()
  module = imp.new_module(module_name)
  # pylint: disable=exec-used
  # Execute the generated wrapper source in the fresh module's namespace.
  exec(wrappers, module.__dict__)
  # Stash away the library handle for making calls into the dynamic library.
  module.LIB_HANDLE = lib_handle
  # OpDefs of the list of ops defined in the library.
  module.OP_LIST = op_list
  # Register the synthetic module in sys.modules under its fingerprint name.
  sys.modules[module_name] = module
  return module
Exemplo n.º 4
0
def load_op_library(library_filename):
    """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here.
  Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList", to be
  defined in the library.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
    # Allocate a C-API status object to capture errors from TF_LoadLibrary; it
    # must be freed with TF_DeleteStatus on every path.
    status = py_tf.TF_NewStatus()

    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    try:
        # A non-zero status code means the dynamic library failed to load.
        if py_tf.TF_GetCode(status) != 0:
            raise RuntimeError(compat.as_text(py_tf.TF_Message(status)))
    finally:
        # Release the status object even when raising above.
        py_tf.TF_DeleteStatus(status)

    # Serialized OpList proto describing the ops the library registered.
    op_list_str = py_tf.TF_GetOpList(lib_handle)
    op_list = op_def_pb2.OpList()
    op_list.ParseFromString(bytes(op_list_str))
    # Generated Python source text for the op wrapper functions.
    wrappers = py_tf.GetPythonWrappers(op_list_str, len(op_list_str))

    # Get a unique name for the module.
    # NOTE(review): md5 is a content fingerprint here, not a security measure;
    # identical wrapper source yields the same module name.
    module_name = hashlib.md5(wrappers).hexdigest()
    module = imp.new_module(module_name)
    # pylint: disable=exec-used
    # Execute the generated wrapper source in the fresh module's namespace.
    exec(wrappers, module.__dict__)
    # Stash away the library handle for making calls into the dynamic library.
    module.LIB_HANDLE = lib_handle
    # OpDefs of the list of ops defined in the library.
    module.OP_LIST = op_list
    # Register the synthetic module in sys.modules under its fingerprint name.
    sys.modules[module_name] = module
    return module
Exemplo n.º 5
0
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph: Path to the GraphDef proto file (binary or text).
        input_saver: Optional path to a SaverDef proto file; falsy to skip.
        input_binary: True if the input protos are binary, False for text.
        input_checkpoint: Path to the checkpoint holding the variable values.
        output_node_names: Comma-separated names of the output nodes to keep.
        restore_op_name: Name of the restore op run when no saver is given.
        filename_tensor_name: Name of the tensor fed the checkpoint path.
        output_graph: Path the frozen binary GraphDef is written to.
        clear_devices: If True, strip device placements for portability.

    Returns:
        -1 on argument/file errors; None on success (the frozen graph is
        written to `output_graph` as a side effect).
    """

    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    if not gfile.Exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    with open(input_graph, "rb") as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            # Bug fix: the file is opened in binary mode, so f.read() yields
            # bytes; text_format.Merge requires str on Python 3. Decode once
            # at the boundary (the original passed bytes(f.read())).
            text_format.Merge(f.read().decode("utf-8"), input_graph_def)
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    _ = tf.import_graph_def(input_graph_def, name="")

    with tf.Session() as sess:
        if input_saver:
            with open(input_saver, "rb") as f:
                saver_def = tf.train.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    # Same bytes-vs-str fix as for the graph proto above.
                    text_format.Merge(f.read().decode("utf-8"), saver_def)
                saver = tf.train.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name],
                     {filename_tensor_name: input_checkpoint})
        # Snapshot every variable's current value by finding its Assign op:
        # the op's first input names the variable node being assigned.
        found_variables = {}
        for node in input_graph_def.node:
            if node.op == "Assign":
                variable_name = node.input[0]
                found_variables[variable_name] = sess.run(variable_name + ":0")

    # This graph only includes the nodes needed to evaluate the output nodes, and
    # removes unneeded nodes like those involved in saving and assignment.
    inference_graph = graph_util.extract_sub_graph(
        input_graph_def, output_node_names.split(","))

    output_graph_def = tf.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = tf.NodeDef()
        if input_node.name in found_variables:
            # Replace the variable node with a Const carrying its value.
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            set_attr_dtype(output_node, "dtype", dtype)
            set_attr_tensor(output_node, "value", data, dtype.type, data.shape)
            how_many_converted += 1
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])

    # Bug fix: SerializeToString() returns bytes, so the output file must be
    # opened in binary mode ("wb"); text mode ("w") fails under Python 3.
    with gfile.FastGFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("Converted %d variables to const ops." % how_many_converted)
    print("%d ops in the final graph." % len(output_graph_def.node))
Exemplo n.º 6
0
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph: Path to the GraphDef proto file (binary or text).
    input_saver: Optional path to a SaverDef proto file; falsy to skip.
    input_binary: True if the input protos are binary, False for text.
    input_checkpoint: Path to the checkpoint holding the variable values.
    output_node_names: Comma-separated names of the output nodes to keep.
    restore_op_name: Name of the restore op run when no saver is given.
    filename_tensor_name: Name of the tensor fed the checkpoint path.
    output_graph: Path the frozen binary GraphDef is written to.
    clear_devices: If True, strip device placements for portability.

  Returns:
    -1 on argument/file errors; None on success (the frozen graph is
    written to `output_graph` as a side effect).
  """

  if not gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if input_saver and not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  if not gfile.Exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = tf.GraphDef()
  with open(input_graph, "rb") as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      # Bug fix: the file is opened in binary mode, so f.read() yields bytes;
      # text_format.Merge requires str on Python 3. Decode once at the
      # boundary (the original passed bytes(f.read())).
      text_format.Merge(f.read().decode("utf-8"), input_graph_def)
  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""
  _ = tf.import_graph_def(input_graph_def, name="")

  with tf.Session() as sess:
    if input_saver:
      with open(input_saver, "rb") as f:
        saver_def = tf.train.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          # Same bytes-vs-str fix as for the graph proto above.
          text_format.Merge(f.read().decode("utf-8"), saver_def)
        saver = tf.train.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
    # Snapshot every variable's current value by finding its Assign op: the
    # op's first input names the variable node being assigned.
    found_variables = {}
    for node in input_graph_def.node:
      if node.op == "Assign":
        variable_name = node.input[0]
        found_variables[variable_name] = sess.run(variable_name + ":0")

  # This graph only includes the nodes needed to evaluate the output nodes, and
  # removes unneeded nodes like those involved in saving and assignment.
  inference_graph = graph_util.extract_sub_graph(
      input_graph_def, output_node_names.split(","))

  output_graph_def = tf.GraphDef()
  how_many_converted = 0
  for input_node in inference_graph.node:
    output_node = tf.NodeDef()
    if input_node.name in found_variables:
      # Replace the variable node with a Const carrying its value.
      output_node.op = "Const"
      output_node.name = input_node.name
      dtype = input_node.attr["dtype"]
      data = found_variables[input_node.name]
      set_attr_dtype(output_node, "dtype", dtype)
      set_attr_tensor(output_node, "value", data, dtype.type, data.shape)
      how_many_converted += 1
    else:
      output_node.CopyFrom(input_node)
    output_graph_def.node.extend([output_node])

  # Bug fix: SerializeToString() returns bytes, so the output file must be
  # opened in binary mode ("wb"); text mode ("w") fails under Python 3.
  with gfile.FastGFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("Converted %d variables to const ops." % how_many_converted)
  print("%d ops in the final graph." % len(output_graph_def.node))