def testAssets(self):
        """Exports a SavedModel with one tracked asset and verifies it round-trips."""
        temp_dir = tf.test.get_temp_dir()
        export_dir = os.path.join(compat.as_bytes(temp_dir),
                                  compat.as_bytes("with-assets"))
        builder = saved_model_builder.SavedModelBuilder(export_dir)

        with self.test_session(graph=tf.Graph()) as sess:
            v = tf.Variable(42, name="v")
            sess.run(tf.initialize_all_variables())
            self.assertEqual(42, v.eval())

            # One file registered in the ASSET_FILEPATHS collection; it should
            # be copied into the export.
            asset_filepath = os.path.join(compat.as_bytes(temp_dir),
                                          compat.as_bytes("hello42.txt"))
            file_io.write_string_to_file(asset_filepath, "foo bar baz")
            asset_file_tensor = tf.constant(asset_filepath,
                                            name="asset_file_tensor")
            tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS,
                                 asset_file_tensor)

            # A second file deliberately left out of the collection; it must
            # NOT be copied into the export.
            ignored_filepath = os.path.join(compat.as_bytes(temp_dir),
                                            compat.as_bytes("ignored.txt"))
            file_io.write_string_to_file(ignored_filepath, "will be ignored")

            builder.add_meta_graph_and_variables(
                sess, ["foo"],
                assets_collection=tf.get_collection(
                    tf.GraphKeys.ASSET_FILEPATHS))

        # Save the SavedModel to disk.
        builder.save()

        with self.test_session(graph=tf.Graph()) as sess:
            foo_graph = loader.load(sess, ["foo"], export_dir)

            # Exactly one AssetFile proto should have been recorded.
            assets_any = (
                foo_graph.collection_def[constants.ASSETS_KEY].any_list.value)
            self.assertEqual(len(assets_any), 1)
            asset = manifest_pb2.AssetFile()
            assets_any[0].Unpack(asset)

            # The tracked asset was copied into the assets directory with its
            # original contents, and the proto points back at its tensor.
            assets_path = os.path.join(
                compat.as_bytes(export_dir),
                compat.as_bytes(constants.ASSETS_DIRECTORY),
                compat.as_bytes("hello42.txt"))
            asset_contents = file_io.read_file_to_string(assets_path)
            self.assertEqual("foo bar baz", compat.as_text(asset_contents))
            self.assertEqual("hello42.txt", asset.filename)
            self.assertEqual("asset_file_tensor:0",
                             asset.tensor_binding.tensor_name)

            # The untracked file was not exported.
            ignored_asset_path = os.path.join(
                compat.as_bytes(export_dir),
                compat.as_bytes(constants.ASSETS_DIRECTORY),
                compat.as_bytes("ignored.txt"))
            self.assertFalse(file_io.file_exists(ignored_asset_path))
# Example #2
    def _add_asset_to_collection(self, asset_filename, asset_tensor):
        """Builds an asset proto and adds it to the asset collection of the graph.

    Args:
      asset_filename: The filename of the asset to be added.
      asset_tensor: The asset tensor used to populate the tensor binding of the
          asset proto.
    """
        # Describe the asset: its filename plus the tensor that binds to it.
        proto = manifest_pb2.AssetFile()
        proto.filename = asset_filename
        proto.tensor_binding.tensor_name = asset_tensor.name

        # The collection stores type-erased Any messages, so pack before adding.
        wrapped = Any()
        wrapped.Pack(proto)
        ops.add_to_collection(constants.ASSETS_KEY, wrapped)
# Example #3
    def doBasicsOneExportPath(self,
                              export_path,
                              clear_devices=False,
                              global_step=GLOBAL_STEP,
                              sharded=True):
        """Exports a small two-variable graph and validates the round trip.

        Builds a graph with two parameter variables pinned to different CPU
        devices, exports it via exporter.Exporter (with default and named
        signatures, an init op and one asset file), then re-imports the
        exported meta graph and checks every exported artifact.

        Args:
          export_path: Directory to export the model into.
          clear_devices: If True, device annotations are expected to have been
              stripped from the exported graph def.
          global_step: Global step value used to version the export.
          sharded: Whether the checkpoint is written sharded; controls which
              variables filename is used for restoration.
        """
        # Build a graph with 2 parameter nodes on different devices.
        tf.reset_default_graph()
        with tf.Session(target="",
                        config=config_pb2.ConfigProto(
                            device_count={"CPU": 2})) as sess:
            # v2 is an unsaved variable derived from v0 and v1.  It is used to
            # exercise the ability to run an init op when restoring a graph.
            with sess.graph.device("/cpu:0"):
                v0 = tf.Variable(10, name="v0")
            with sess.graph.device("/cpu:1"):
                v1 = tf.Variable(20, name="v1")
            v2 = tf.Variable(1, name="v2", trainable=False, collections=[])
            assign_v2 = tf.assign(v2, tf.add(v0, v1))
            init_op = tf.group(assign_v2, name="init_op")

            tf.add_to_collection("v", v0)
            tf.add_to_collection("v", v1)
            tf.add_to_collection("v", v2)

            global_step_tensor = tf.Variable(global_step, name="global_step")
            named_tensor_bindings = {
                "logical_input_A": v0,
                "logical_input_B": v1
            }
            signatures = {
                "foo":
                exporter.regression_signature(input_tensor=v0,
                                              output_tensor=v1),
                "generic":
                exporter.generic_signature(named_tensor_bindings)
            }

            # One asset that is registered (and therefore exported) ...
            asset_filepath_orig = os.path.join(tf.test.get_temp_dir(),
                                               "hello42.txt")
            asset_file = tf.constant(asset_filepath_orig, name="filename42")
            tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset_file)

            with gfile.FastGFile(asset_filepath_orig, "w") as f:
                f.write("your data here")
            assets_collection = tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)

            # ... and one file that is not registered and must be ignored.
            ignored_asset = os.path.join(tf.test.get_temp_dir(), "ignored.txt")
            with gfile.FastGFile(ignored_asset, "w") as f:
                f.write("additional data here")

            tf.initialize_all_variables().run()

            # Run an export.
            save = tf.train.Saver({
                "v0": v0,
                "v1": v1
            },
                                  restore_sequentially=True,
                                  sharded=sharded)
            export = exporter.Exporter(save)
            export.init(
                sess.graph.as_graph_def(),
                init_op=init_op,
                clear_devices=clear_devices,
                default_graph_signature=exporter.classification_signature(
                    input_tensor=v0),
                named_graph_signatures=signatures,
                assets_collection=assets_collection)
            export.export(export_path,
                          global_step_tensor,
                          sess,
                          exports_to_keep=gc.largest_export_versions(2))

        # Restore graph.
        compare_def = tf.get_default_graph().as_graph_def()
        tf.reset_default_graph()
        with tf.Session(target="",
                        config=config_pb2.ConfigProto(
                            device_count={"CPU": 2})) as sess:
            save = tf.train.import_meta_graph(
                os.path.join(export_path,
                             constants.VERSION_FORMAT_SPECIFIER % global_step,
                             constants.META_GRAPH_DEF_FILENAME))
            self.assertIsNotNone(save)
            meta_graph_def = save.export_meta_graph()
            collection_def = meta_graph_def.collection_def

            # Validate custom graph_def.
            graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
            self.assertEqual(len(graph_def_any), 1)
            graph_def = tf.GraphDef()
            graph_def_any[0].Unpack(graph_def)
            if clear_devices:
                for node in compare_def.node:
                    node.device = ""
            self.assertProtoEquals(compare_def, graph_def)

            # Validate init_op.
            init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
            self.assertEqual(len(init_ops), 1)
            self.assertEqual(init_ops[0], "init_op")

            # Validate signatures.
            signatures_any = collection_def[
                constants.SIGNATURES_KEY].any_list.value
            self.assertEqual(len(signatures_any), 1)
            signatures = manifest_pb2.Signatures()
            signatures_any[0].Unpack(signatures)
            default_signature = signatures.default_signature
            self.assertEqual(
                default_signature.classification_signature.input.tensor_name,
                "v0:0")
            bindings = signatures.named_signatures[
                "generic"].generic_signature.map
            self.assertEqual(bindings["logical_input_A"].tensor_name, "v0:0")
            self.assertEqual(bindings["logical_input_B"].tensor_name, "v1:0")
            read_foo_signature = (
                signatures.named_signatures["foo"].regression_signature)
            self.assertEqual(read_foo_signature.input.tensor_name, "v0:0")
            self.assertEqual(read_foo_signature.output.tensor_name, "v1:0")

            # Validate the assets.
            assets_any = collection_def[constants.ASSETS_KEY].any_list.value
            self.assertEqual(len(assets_any), 1)
            asset = manifest_pb2.AssetFile()
            assets_any[0].Unpack(asset)
            assets_path = os.path.join(
                export_path, constants.VERSION_FORMAT_SPECIFIER % global_step,
                constants.ASSETS_DIRECTORY, "hello42.txt")
            asset_contents = gfile.GFile(assets_path).read()
            self.assertEqual(asset_contents, "your data here")
            self.assertEqual("hello42.txt", asset.filename)
            self.assertEqual("filename42:0", asset.tensor_binding.tensor_name)
            ignored_asset_path = os.path.join(
                export_path, constants.VERSION_FORMAT_SPECIFIER % global_step,
                constants.ASSETS_DIRECTORY, "ignored.txt")
            self.assertFalse(gfile.Exists(ignored_asset_path))

            # Validate graph restoration.
            if sharded:
                save.restore(
                    sess,
                    os.path.join(
                        export_path,
                        constants.VERSION_FORMAT_SPECIFIER % global_step,
                        constants.VARIABLES_FILENAME_PATTERN))
            else:
                save.restore(
                    sess,
                    os.path.join(
                        export_path,
                        constants.VERSION_FORMAT_SPECIFIER % global_step,
                        constants.VARIABLES_FILENAME))
            self.assertEqual(10, tf.get_collection("v")[0].eval())
            self.assertEqual(20, tf.get_collection("v")[1].eval())
            # Running the exported init op derives v2 = v0 + v1.
            tf.get_collection(constants.INIT_OP_KEY)[0].run()
            self.assertEqual(30, tf.get_collection("v")[2].eval())
def load_session_bundle_from_path(export_dir, target="", config=None):
    """Load session bundle from the given path.

  The function reads input from the export_dir, constructs the graph data to the
  default graph and restores the parameters for the session created.

  Args:
    export_dir: the directory that contains files exported by exporter.
    target: The execution engine to connect to. See target in tf.Session()
    config: A ConfigProto proto with configuration options. See config in
    tf.Session()

  Returns:
    session: a tensorflow session created from the variable files.
    meta_graph: a meta graph proto saved in the exporter directory.

  Raises:
    RuntimeError: if the required files are missing or contain unrecognizable
    fields, i.e. the exported model is invalid.
  """
    meta_graph_filename = os.path.join(export_dir,
                                       constants.META_GRAPH_DEF_FILENAME)
    if not file_io.file_exists(meta_graph_filename):
        raise RuntimeError("Expected meta graph file missing %s" %
                           meta_graph_filename)
    # Prefer the non-sharded variables file; fall back to the sharded pattern.
    variables_filename = os.path.join(export_dir, constants.VARIABLES_FILENAME)
    if not file_io.file_exists(variables_filename):
        variables_filename = os.path.join(export_dir,
                                          constants.VARIABLES_FILENAME_PATTERN)
        if not file_io.get_matching_files(variables_filename):
            raise RuntimeError("Expected variables file missing %s" %
                               variables_filename)
    assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)

    # Reads meta graph file.
    # NOTE(review): this reads without binary_mode; the newer variant of this
    # loader passes binary_mode=True before ParseFromString — confirm the
    # default mode here yields bytes on this TF version.
    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    meta_graph_def.ParseFromString(
        file_io.read_file_to_string(meta_graph_filename))

    collection_def = meta_graph_def.collection_def
    graph_def = tf.GraphDef()
    if constants.GRAPH_KEY in collection_def:
        # Use serving graph_def in MetaGraphDef collection_def if exists
        graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
        if len(graph_def_any) != 1:
            raise RuntimeError(
                "Expected exactly one serving GraphDef in : %s" %
                meta_graph_def)
        else:
            graph_def_any[0].Unpack(graph_def)
            # Replace the graph def in meta graph proto.
            meta_graph_def.graph_def.CopyFrom(graph_def)

    tf.reset_default_graph()
    sess = tf.Session(target, graph=None, config=config)
    # Import the graph.
    saver = tf.train.import_meta_graph(meta_graph_def)
    # Restore the session.
    saver.restore(sess, variables_filename)

    init_op_tensor = None
    if constants.INIT_OP_KEY in collection_def:
        init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
        if len(init_ops) != 1:
            raise RuntimeError("Expected exactly one serving init op in : %s" %
                               meta_graph_def)
        init_op_tensor = tf.get_collection(constants.INIT_OP_KEY)[0]

    # Create asset input tensor list: maps each asset tensor name to the
    # on-disk location of the asset inside the export's assets directory.
    asset_tensor_dict = {}
    if constants.ASSETS_KEY in collection_def:
        assets_any = collection_def[constants.ASSETS_KEY].any_list.value
        for asset in assets_any:
            asset_pb = manifest_pb2.AssetFile()
            asset.Unpack(asset_pb)
            asset_tensor_dict[
                asset_pb.tensor_binding.tensor_name] = os.path.join(
                    assets_dir, asset_pb.filename)

    # Use an explicit None check: truthiness of a tf.Tensor pulled from the
    # collection is undefined in graph mode and would raise.
    if init_op_tensor is not None:
        # Run the init op, feeding the asset file locations.
        sess.run(fetches=[init_op_tensor], feed_dict=asset_tensor_dict)

    return sess, meta_graph_def
# Example #5
  def init(self,
           graph_def=None,
           init_op=None,
           clear_devices=False,
           default_graph_signature=None,
           named_graph_signatures=None,
           assets_collection=None,
           assets_callback=gfile_copy_callback):
    """Initialization.

    Args:
      graph_def: A GraphDef message of the graph to be used in inference.
        GraphDef of default graph is used when None.
      init_op: Op to be used in initialization.
      clear_devices: If device info of the graph should be cleared upon export.
      default_graph_signature: Default signature of the graph.
      named_graph_signatures: Map of named input/output signatures of the graph.
      assets_collection: A collection of constant asset filepath tensors. If set
        the assets will be exported into the asset directory.
      assets_callback: callback with two argument called during export with the
        list of files to copy and the asset path.
    Raises:
      RuntimeError: if init is called more than once.
      TypeError: if init_op is not an Operation or None.
      ValueError: if asset file path tensors are not non-empty constant string
        scalar tensors.
    """
    # Guard against double initialization FIRST, before mutating any exporter
    # state or graph collections, so a second call leaves no partial state.
    if self._has_init:
      raise RuntimeError("init should be called only once")
    self._has_init = True

    # Avoid Dangerous default value []
    if named_graph_signatures is None:
      named_graph_signatures = {}
    assets = []
    if assets_collection:
      for asset_tensor in assets_collection:
        asset_filepath = self._file_path_value(asset_tensor)
        if not asset_filepath:
          raise ValueError("invalid asset filepath tensor %s" % asset_tensor)
        basename = os.path.basename(asset_filepath)
        assets.append((basename, asset_tensor))
        self._assets_to_copy[asset_filepath] = basename

    if graph_def or clear_devices:
      copy = graph_pb2.GraphDef()
      if graph_def:
        copy.CopyFrom(graph_def)
      else:
        copy.CopyFrom(ops.get_default_graph().as_graph_def())
      if clear_devices:
        for node in copy.node:
          node.device = ""
      graph_any_buf = Any()
      graph_any_buf.Pack(copy)
      ops.add_to_collection(constants.GRAPH_KEY, graph_any_buf)

    if init_op:
      if not isinstance(init_op, ops.Operation):
        raise TypeError("init_op needs to be an Operation: %s" % init_op)
      ops.add_to_collection(constants.INIT_OP_KEY, init_op)

    signatures_proto = manifest_pb2.Signatures()
    if default_graph_signature:
      signatures_proto.default_signature.CopyFrom(default_graph_signature)
    for signature_name, signature in six.iteritems(named_graph_signatures):
      signatures_proto.named_signatures[signature_name].CopyFrom(signature)
    signatures_any_buf = Any()
    signatures_any_buf.Pack(signatures_proto)
    ops.add_to_collection(constants.SIGNATURES_KEY, signatures_any_buf)

    # Record each asset as an AssetFile proto packed into an Any.
    for filename, tensor in assets:
      asset = manifest_pb2.AssetFile()
      asset.filename = filename
      asset.tensor_binding.tensor_name = tensor.name
      asset_any_buf = Any()
      asset_any_buf.Pack(asset)
      ops.add_to_collection(constants.ASSETS_KEY, asset_any_buf)

    self._assets_callback = assets_callback
# Example #6
def load_session_bundle_from_path(export_dir,
                                  target="",
                                  config=None,
                                  meta_graph_def=None):
    """Load session bundle from the given path.

  The function reads input from the export_dir, constructs the graph data to the
  default graph and restores the parameters for the session created.

  Args:
    export_dir: the directory that contains files exported by exporter.
    target: The execution engine to connect to. See target in tf.Session()
    config: A ConfigProto proto with configuration options. See config in
    tf.Session()
    meta_graph_def: optional object of type MetaGraphDef. If this object is
    present, then it is used instead of parsing MetaGraphDef from export_dir.

  Returns:
    session: a tensorflow session created from the variable files.
    meta_graph: a meta graph proto saved in the exporter directory.

  Raises:
    RuntimeError: if the required files are missing or contain unrecognizable
    fields, i.e. the exported model is invalid.
  """
    # Explicit None check (rather than truthiness): a freshly constructed,
    # empty MetaGraphDef passed by a caller should still be honored.
    if meta_graph_def is None:
        meta_graph_filename = os.path.join(export_dir,
                                           constants.META_GRAPH_DEF_FILENAME)
        if not file_io.file_exists(meta_graph_filename):
            raise RuntimeError("Expected meta graph file missing %s" %
                               meta_graph_filename)
        # Reads meta graph file.
        meta_graph_def = meta_graph_pb2.MetaGraphDef()
        meta_graph_def.ParseFromString(
            file_io.read_file_to_string(meta_graph_filename, binary_mode=True))

    variables_filename = ""
    variables_filename_list = []
    checkpoint_sharded = False

    variables_index_filename = os.path.join(
        export_dir, constants.VARIABLES_INDEX_FILENAME_V2)
    checkpoint_v2 = file_io.file_exists(variables_index_filename)

    # Find matching checkpoint files.
    if checkpoint_v2:
        # The checkpoint is in v2 format.
        variables_filename_pattern = os.path.join(
            export_dir, constants.VARIABLES_FILENAME_PATTERN_V2)
        variables_filename_list = file_io.get_matching_files(
            variables_filename_pattern)
        checkpoint_sharded = True
    else:
        variables_filename = os.path.join(export_dir,
                                          constants.VARIABLES_FILENAME)
        if file_io.file_exists(variables_filename):
            variables_filename_list = [variables_filename]
        else:
            variables_filename = os.path.join(
                export_dir, constants.VARIABLES_FILENAME_PATTERN)
            variables_filename_list = file_io.get_matching_files(
                variables_filename)
            checkpoint_sharded = True

    # Prepare the files to restore a session.
    if not variables_filename_list:
        restore_files = ""
    elif checkpoint_v2 or not checkpoint_sharded:
        # For checkpoint v2 or v1 with non-sharded files, use "export" to restore
        # the session.
        restore_files = constants.VARIABLES_FILENAME
    else:
        restore_files = constants.VARIABLES_FILENAME_PATTERN

    assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)

    collection_def = meta_graph_def.collection_def
    graph_def = graph_pb2.GraphDef()
    if constants.GRAPH_KEY in collection_def:
        # Use serving graph_def in MetaGraphDef collection_def if exists
        graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
        if len(graph_def_any) != 1:
            raise RuntimeError(
                "Expected exactly one serving GraphDef in : %s" %
                meta_graph_def)
        else:
            graph_def_any[0].Unpack(graph_def)
            # Replace the graph def in meta graph proto.
            meta_graph_def.graph_def.CopyFrom(graph_def)

    ops.reset_default_graph()
    sess = session.Session(target, graph=None, config=config)
    # Import the graph.
    saver = saver_lib.import_meta_graph(meta_graph_def)
    # Restore the session.
    if restore_files:
        saver.restore(sess, os.path.join(export_dir, restore_files))

    init_op_tensor = None
    if constants.INIT_OP_KEY in collection_def:
        init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
        if len(init_ops) != 1:
            raise RuntimeError("Expected exactly one serving init op in : %s" %
                               meta_graph_def)
        init_op_tensor = ops.get_collection(constants.INIT_OP_KEY)[0]

    # Create asset input tensor list: maps each asset tensor name to the
    # on-disk location of the asset inside the export's assets directory.
    asset_tensor_dict = {}
    if constants.ASSETS_KEY in collection_def:
        assets_any = collection_def[constants.ASSETS_KEY].any_list.value
        for asset in assets_any:
            asset_pb = manifest_pb2.AssetFile()
            asset.Unpack(asset_pb)
            asset_tensor_dict[
                asset_pb.tensor_binding.tensor_name] = os.path.join(
                    assets_dir, asset_pb.filename)

    # Use an explicit None check: truthiness of a tf.Tensor pulled from the
    # collection is undefined in graph mode and would raise.
    if init_op_tensor is not None:
        # Run the init op, feeding the asset file locations.
        sess.run(fetches=[init_op_tensor], feed_dict=asset_tensor_dict)

    return sess, meta_graph_def