示例#1
0
  def init(self,
           graph_def=None,
           init_op=None,
           clear_devices=False,
           default_graph_signature=None,
           named_graph_signatures=None,
           assets=None,
           assets_callback=None):
    """Initialization.

    Records the serving graph (when requested), the init op, the signatures
    and the asset bindings in collections of the default graph, and stores
    `assets_callback` for use during export.

    Args:
      graph_def: A GraphDef message of the graph to be used in inference.
        GraphDef of default graph is used when None.
      init_op: Op to be used in initialization.
      clear_devices: If device info of the graph should be cleared upon export.
      default_graph_signature: Default signature of the graph.
      named_graph_signatures: Map of named input/output signatures of the graph.
      assets: An iterable of tuples of asset files with the first element being
        the filename (string) and the second being the Tensor.
      assets_callback: callback with a single string argument; called during
        export with the asset path.

    Raises:
      RuntimeError: if init is called more than once.
      TypeError: if init_op is not an Operation or None.
    """
    if self._has_init:
      raise RuntimeError("init should be called only once")
    # Validate arguments before touching any state, so a rejected call neither
    # marks the exporter as initialized nor leaves partial collection entries.
    if init_op and not isinstance(init_op, ops.Operation):
      raise TypeError("init_op needs to be an Operation: %s" % init_op)

    # Avoid dangerous mutable default values.
    if named_graph_signatures is None:
      named_graph_signatures = {}
    if assets is None:
      assets = ()

    self._has_init = True

    if graph_def or clear_devices:
      copy = tf.GraphDef()
      if graph_def:
        copy.CopyFrom(graph_def)
      else:
        copy.CopyFrom(tf.get_default_graph().as_graph_def())
      if clear_devices:
        for node in copy.node:
          node.device = ""
      graph_any_buf = any_pb2.Any()
      graph_any_buf.Pack(copy)
      tf.add_to_collection(GRAPH_KEY, graph_any_buf)

    if init_op:
      tf.add_to_collection(INIT_OP_KEY, init_op)

    signatures_proto = manifest_pb2.Signatures()
    if default_graph_signature:
      signatures_proto.default_signature.CopyFrom(default_graph_signature)
    for signature_name, signature in six.iteritems(named_graph_signatures):
      signatures_proto.named_signatures[signature_name].CopyFrom(signature)
    signatures_any_buf = any_pb2.Any()
    signatures_any_buf.Pack(signatures_proto)
    tf.add_to_collection(SIGNATURES_KEY, signatures_any_buf)

    # Each asset is recorded as an AssetFile proto binding the filename to the
    # name of the tensor that carries its path.
    for filename, tensor in assets:
      asset = manifest_pb2.AssetFile()
      asset.filename = filename
      asset.tensor_binding.tensor_name = tensor.name
      asset_any_buf = any_pb2.Any()
      asset_any_buf.Pack(asset)
      tf.add_to_collection(ASSETS_KEY, asset_any_buf)

    self._assets_callback = assets_callback
示例#2
0
def LoadSessionBundleFromPath(export_dir, target="", config=None):
  """Load session bundle from the given path.

  The function reads input from the export_dir, constructs the graph data to the
  default graph and restores the parameters for the session created.

  Args:
    export_dir: the directory that contains files exported by exporter.
    target: The execution engine to connect to. See target in tf.Session()
    config: A ConfigProto proto with configuration options. See config in
    tf.Session()

  Returns:
    session: a tensorflow session created from the variable files.
    meta_graph: a meta graph proto saved in the exporter directory.

  Raises:
    RuntimeError: if the required files are missing or contain unrecognizable
    fields, i.e. the exported model is invalid.
  """
  meta_graph_filename = os.path.join(export_dir,
                                     constants.META_GRAPH_DEF_FILENAME)
  if not gfile.Exists(meta_graph_filename):
    raise RuntimeError("Expected meta graph file missing %s" %
                       meta_graph_filename)
  variables_filename = os.path.join(export_dir,
                                    constants.VARIABLES_FILENAME)
  if not gfile.Exists(variables_filename):
    # No single variables file; fall back to the filename pattern.
    variables_filename = os.path.join(
        export_dir, constants.VARIABLES_FILENAME_PATTERN)
    if not gfile.Glob(variables_filename):
      raise RuntimeError("Expected variables file missing %s" %
                         variables_filename)
  assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)

  # Read the meta graph file. It holds a serialized proto, so it must be read
  # in binary mode; text mode would hand ParseFromString a str on Python 3.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  with gfile.GFile(meta_graph_filename, "rb") as f:
    meta_graph_def.ParseFromString(f.read())

  collection_def = meta_graph_def.collection_def
  graph_def = tf.GraphDef()
  if constants.GRAPH_KEY in collection_def:
    # Use serving graph_def in MetaGraphDef collection_def if exists
    graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
    if len(graph_def_any) != 1:
      raise RuntimeError(
          "Expected exactly one serving GraphDef in : %s" % meta_graph_def)
    graph_def_any[0].Unpack(graph_def)
    # Replace the graph def in meta graph proto.
    meta_graph_def.graph_def.CopyFrom(graph_def)

  tf.reset_default_graph()
  sess = tf.Session(target, graph=None, config=config)
  try:
    # Import the graph.
    saver = tf.train.import_meta_graph(meta_graph_def)
    # Restore the session.
    saver.restore(sess, variables_filename)

    init_op_tensor = None
    if constants.INIT_OP_KEY in collection_def:
      init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
      if len(init_ops) != 1:
        raise RuntimeError(
            "Expected exactly one serving init op in : %s" % meta_graph_def)
      init_op_tensor = tf.get_collection(constants.INIT_OP_KEY)[0]

    # Map each asset's bound tensor name to the asset's path under assets_dir;
    # this dict is fed to the init op below.
    asset_tensor_dict = {}
    if constants.ASSETS_KEY in collection_def:
      assets_any = collection_def[constants.ASSETS_KEY].any_list.value
      for asset in assets_any:
        asset_pb = manifest_pb2.AssetFile()
        asset.Unpack(asset_pb)
        asset_tensor_dict[asset_pb.tensor_binding.tensor_name] = os.path.join(
            assets_dir, asset_pb.filename)

    if init_op_tensor is not None:
      # Run the init op.
      sess.run(fetches=[init_op_tensor], feed_dict=asset_tensor_dict)
  except Exception:
    # Don't leak the session when the bundle turns out to be unusable.
    sess.close()
    raise

  return sess, meta_graph_def
示例#3
0
    def doBasicsOneExportPath(self,
                              export_path,
                              clear_devices=False,
                              global_step=GLOBAL_STEP):
        """Exports a small two-device graph and validates the written bundle.

        Builds v0/v1 on two CPU devices plus an unsaved v2 that is filled by an
        init op, exports them with signatures and one asset, then re-imports
        the meta graph and checks every collection and the restored values.

        Args:
          export_path: Base directory to export into.
          clear_devices: Forwarded to `Exporter.init`; when True the expected
            graph_def also has its device fields cleared before comparison.
          global_step: Value of the global step variable; it also selects the
            versioned export subdirectory that is read back.
        """
        # Build a graph with 2 parameter nodes on different devices.
        tf.reset_default_graph()
        with tf.Session(target="",
                        config=config_pb2.ConfigProto(
                            device_count={"CPU": 2})) as sess:
            # v2 is an unsaved variable derived from v0 and v1.  It is used to
            # exercise the ability to run an init op when restoring a graph.
            with sess.graph.device("/cpu:0"):
                v0 = tf.Variable(10, name="v0")
            with sess.graph.device("/cpu:1"):
                v1 = tf.Variable(20, name="v1")
            v2 = tf.Variable(1, name="v2", trainable=False, collections=[])
            assign_v2 = tf.assign(v2, tf.add(v0, v1))
            init_op = tf.group(assign_v2, name="init_op")

            tf.add_to_collection("v", v0)
            tf.add_to_collection("v", v1)
            tf.add_to_collection("v", v2)

            global_step_tensor = tf.Variable(global_step, name="global_step")
            named_tensor_bindings = {
                "logical_input_A": v0,
                "logical_input_B": v1
            }
            signatures = {
                "foo":
                exporter.regression_signature(input_tensor=v0,
                                              output_tensor=v1),
                "generic":
                exporter.generic_signature(named_tensor_bindings)
            }

            def write_asset(path):
                file_path = os.path.join(path, "file.txt")
                with gfile.FastGFile(file_path, "w") as f:
                    f.write("your data here")

            asset_file = tf.Variable("hello42.txt", name="filename42")
            assets = {("hello42.txt", asset_file)}

            tf.initialize_all_variables().run()

            # Run an export.
            save = tf.train.Saver({
                "v0": v0,
                "v1": v1
            },
                                  restore_sequentially=True,
                                  sharded=True)
            export = exporter.Exporter(save)
            export.init(
                sess.graph.as_graph_def(),
                init_op=init_op,
                clear_devices=clear_devices,
                default_graph_signature=exporter.classification_signature(
                    input_tensor=v0),
                named_graph_signatures=signatures,
                assets=assets,
                assets_callback=write_asset)
            export.export(export_path,
                          global_step_tensor,
                          sess,
                          exports_to_keep=gc.largest_export_versions(2))

        # Restore graph.
        compare_def = tf.get_default_graph().as_graph_def()
        tf.reset_default_graph()
        with tf.Session(target="",
                        config=config_pb2.ConfigProto(
                            device_count={"CPU": 2})) as sess:
            save = tf.train.import_meta_graph(
                os.path.join(export_path,
                             exporter.VERSION_FORMAT_SPECIFIER % global_step,
                             exporter.META_GRAPH_DEF_FILENAME))
            self.assertIsNotNone(save)
            meta_graph_def = save.export_meta_graph()
            collection_def = meta_graph_def.collection_def

            # Validate custom graph_def.
            graph_def_any = collection_def[exporter.GRAPH_KEY].any_list.value
            self.assertEqual(len(graph_def_any), 1)
            graph_def = tf.GraphDef()
            graph_def_any[0].Unpack(graph_def)
            if clear_devices:
                for node in compare_def.node:
                    node.device = ""
            self.assertProtoEquals(compare_def, graph_def)

            # Validate init_op.
            init_ops = collection_def[exporter.INIT_OP_KEY].node_list.value
            self.assertEqual(len(init_ops), 1)
            self.assertEqual(init_ops[0], "init_op")

            # Validate signatures.
            signatures_any = collection_def[
                exporter.SIGNATURES_KEY].any_list.value
            self.assertEqual(len(signatures_any), 1)
            signatures = manifest_pb2.Signatures()
            signatures_any[0].Unpack(signatures)
            default_signature = signatures.default_signature
            self.assertEqual(
                default_signature.classification_signature.input.tensor_name,
                "v0:0")
            bindings = signatures.named_signatures[
                "generic"].generic_signature.map
            self.assertEqual(bindings["logical_input_A"].tensor_name, "v0:0")
            self.assertEqual(bindings["logical_input_B"].tensor_name, "v1:0")
            read_foo_signature = (
                signatures.named_signatures["foo"].regression_signature)
            self.assertEqual(read_foo_signature.input.tensor_name, "v0:0")
            self.assertEqual(read_foo_signature.output.tensor_name, "v1:0")

            # Validate the assets.
            assets_any = collection_def[exporter.ASSETS_KEY].any_list.value
            self.assertEqual(len(assets_any), 1)
            asset = manifest_pb2.AssetFile()
            assets_any[0].Unpack(asset)
            assets_path = os.path.join(
                export_path, exporter.VERSION_FORMAT_SPECIFIER % global_step,
                exporter.ASSETS_DIRECTORY, "file.txt")
            # Close the handle deterministically instead of leaking it.
            with gfile.GFile(assets_path) as asset_fh:
                asset_contents = asset_fh.read()
            self.assertEqual(asset_contents, "your data here")
            self.assertEqual("hello42.txt", asset.filename)
            self.assertEqual("filename42:0", asset.tensor_binding.tensor_name)

            # Validate graph restoration.
            save.restore(
                sess,
                os.path.join(export_path,
                             exporter.VERSION_FORMAT_SPECIFIER % global_step,
                             exporter.VARIABLES_DIRECTORY))
            self.assertEqual(10, tf.get_collection("v")[0].eval())
            self.assertEqual(20, tf.get_collection("v")[1].eval())
            # v2 is not saved; running the restored init op must derive it.
            tf.get_collection(exporter.INIT_OP_KEY)[0].run()
            self.assertEqual(30, tf.get_collection("v")[2].eval())
示例#4
0
  def init(self,
           graph_def=None,
           init_op=None,
           clear_devices=False,
           default_graph_signature=None,
           named_graph_signatures=None,
           assets_collection=None,
           assets_callback=gfile_copy_callback):
    """Initialization.

    Records the serving graph (when requested), the init op, the signatures
    and the asset bindings in collections of the default graph, remembers
    which asset files to copy, and stores `assets_callback` for export.

    Args:
      graph_def: A GraphDef message of the graph to be used in inference.
        GraphDef of default graph is used when None.
      init_op: Op to be used in initialization.
      clear_devices: If device info of the graph should be cleared upon export.
      default_graph_signature: Default signature of the graph.
      named_graph_signatures: Map of named input/output signatures of the graph.
      assets_collection: A collection of constant asset filepath tensors. If set
        the assets will be exported into the asset directory.
      assets_callback: callback with two argument called during export with the
        list of files to copy and the asset path.

    Raises:
      RuntimeError: if init is called more than once.
      TypeError: if init_op is not an Operation or None.
      ValueError: if asset file path tensors are not non-empty constant string
        scalar tensors.
    """
    if self._has_init:
      raise RuntimeError("init should be called only once")

    # Avoid Dangerous default value {}
    if named_graph_signatures is None:
      named_graph_signatures = {}

    # Validate every asset tensor into locals first, so that a rejected call
    # (bad asset or repeated init) leaves self._assets_to_copy untouched.
    assets = []
    pending_copies = []
    if assets_collection:
      for asset_tensor in assets_collection:
        asset_filepath = self._file_path_value(asset_tensor)
        if not asset_filepath:
          raise ValueError("invalid asset filepath tensor %s" % asset_tensor)
        basename = os.path.basename(asset_filepath)
        assets.append((basename, asset_tensor))
        pending_copies.append((asset_filepath, basename))

    if init_op and not isinstance(init_op, ops.Operation):
      raise TypeError("init_op needs to be an Operation: %s" % init_op)

    # All validation passed: commit state.
    self._has_init = True
    for asset_filepath, basename in pending_copies:
      self._assets_to_copy[asset_filepath] = basename

    if graph_def or clear_devices:
      copy = tf.GraphDef()
      if graph_def:
        copy.CopyFrom(graph_def)
      else:
        copy.CopyFrom(tf.get_default_graph().as_graph_def())
      if clear_devices:
        for node in copy.node:
          node.device = ""
      graph_any_buf = any_pb2.Any()
      graph_any_buf.Pack(copy)
      tf.add_to_collection(constants.GRAPH_KEY, graph_any_buf)

    if init_op:
      tf.add_to_collection(constants.INIT_OP_KEY, init_op)

    signatures_proto = manifest_pb2.Signatures()
    if default_graph_signature:
      signatures_proto.default_signature.CopyFrom(default_graph_signature)
    for signature_name, signature in six.iteritems(named_graph_signatures):
      signatures_proto.named_signatures[signature_name].CopyFrom(signature)
    signatures_any_buf = any_pb2.Any()
    signatures_any_buf.Pack(signatures_proto)
    tf.add_to_collection(constants.SIGNATURES_KEY, signatures_any_buf)

    # Each asset is recorded as an AssetFile proto binding the file's basename
    # to the name of the tensor that carries its path.
    for filename, tensor in assets:
      asset = manifest_pb2.AssetFile()
      asset.filename = filename
      asset.tensor_binding.tensor_name = tensor.name
      asset_any_buf = any_pb2.Any()
      asset_any_buf.Pack(asset)
      tf.add_to_collection(constants.ASSETS_KEY, asset_any_buf)

    self._assets_callback = assets_callback