예제 #1
0
파일: save.py 프로젝트: mmilovec/tensorflow
def _process_asset(trackable_asset, asset_info, resource_map):
    """Registers `trackable_asset` in `asset_info` and `resource_map`.

    A new resource variable, initialized from a string placeholder named
    "asset_path_initializer", is created to replace the asset's path variable
    in the exported graph.

    Args:
      trackable_asset: Object exposing an `asset_path` variable.
      asset_info: Accumulator whose `asset_filename_map`, `asset_defs`,
        `asset_initializers_by_resource` and `asset_index` are updated in place.
      resource_map: Dict from original resource handles to replacement handles;
        updated in place.
    """
    source_var = trackable_asset.asset_path
    # `.numpy()` needs eager execution, so read the concrete path value inside
    # an eager scope.
    with context.eager_mode():
        source_path = source_var.numpy()
    filename_map = asset_info.asset_filename_map
    export_name = builder_impl.get_asset_filename_to_add(
        asset_filepath=source_path, asset_filename_map=filename_map)
    # TODO(andresp): Instead of mapping 1-1 between trackable asset
    # and asset in the graph def consider deduping the assets that
    # point to the same file.
    path_feed = array_ops.placeholder(
        shape=source_var.shape,
        dtype=dtypes.string,
        name="asset_path_initializer")
    replacement = resource_variable_ops.ResourceVariable(path_feed)
    filename_map[export_name] = source_path

    # Describe the asset: exported filename plus the placeholder to feed.
    file_def = meta_graph_pb2.AssetFileDef()
    file_def.filename = export_name
    file_def.tensor_info.name = path_feed.name
    asset_info.asset_defs.append(file_def)

    source_handle = source_var.handle
    asset_info.asset_initializers_by_resource[source_handle] = (
        replacement.initializer)
    # Position of the AssetFileDef just appended.
    asset_info.asset_index[trackable_asset] = len(asset_info.asset_defs) - 1
    resource_map[source_handle] = replacement.handle
예제 #2
0
def _process_asset(trackable_asset, asset_info, resource_map):
  """Registers `trackable_asset` in `asset_info` and `resource_map`.

  The asset path tensor is constant-folded to a Python string, a replacement
  resource variable fed by an "asset_path_initializer" placeholder is created,
  and an `AssetFileDef` describing it is appended to `asset_info.asset_defs`.

  Args:
    trackable_asset: Object exposing an `asset_path` tensor.
    asset_info: Accumulator whose `asset_filename_map`, `asset_defs`,
      `asset_initializers_by_resource` and `asset_index` are updated in place.
    resource_map: Dict from original path tensors to replacement variables;
      updated in place.
  """
  path_tensor = trackable_asset.asset_path
  source_path = tensor_util.constant_value(path_tensor)
  try:
    source_path = str(source_path.astype(str))
  except AttributeError:
    # Already a string rather than a numpy array
    # NOTE(review): presumably a `None` from constant_value() (non-constant
    # tensor) also lands here and is passed on unchanged — confirm.
    pass
  export_name = builder_impl.get_asset_filename_to_add(
      asset_filepath=source_path,
      asset_filename_map=asset_info.asset_filename_map)
  # TODO(andresp): Instead of mapping 1-1 between trackable asset
  # and asset in the graph def consider deduping the assets that
  # point to the same file.
  path_feed = array_ops.placeholder(
      shape=path_tensor.shape,
      dtype=dtypes.string,
      name="asset_path_initializer")
  replacement = resource_variable_ops.ResourceVariable(path_feed)
  asset_info.asset_filename_map[export_name] = source_path

  file_def = meta_graph_pb2.AssetFileDef()
  file_def.filename = export_name
  file_def.tensor_info.name = path_feed.name
  defs = asset_info.asset_defs
  defs.append(file_def)

  asset_info.asset_initializers_by_resource[path_tensor] = (
      replacement.initializer)
  # Index of the AssetFileDef appended above.
  asset_info.asset_index[trackable_asset] = len(defs) - 1
  resource_map[path_tensor] = replacement
예제 #3
0
    def testCreationOfAssetsKeyCollectionIsDeterministic(self):
        """ASSETS_KEY entries are emitted in sorted tensor-name order."""
        asset_dir = os.path.join(self.get_temp_dir(), "assets")
        tf_v1.gfile.MakeDirs(asset_dir)
        paths = [os.path.join(asset_dir, "file%d.txt" % i) for i in range(10)]
        for p in paths:
            _write_string_to_file(p, "I am file %s" % p)

        with tf.Graph().as_default() as graph:
            # Create every constant first, then register them as assets.
            asset_tensors = [
                tf.constant(p, name=os.path.basename(p)) for p in paths
            ]
            for t in asset_tensors:
                graph.add_to_collection(tf_v1.GraphKeys.ASSET_FILEPATHS, t)
            saved_model_lib.add_signature("default", {},
                                          {"default": asset_tensors[0]})

        handler = saved_model_lib.SavedModelHandler()
        handler.add_graph_copy(graph)
        proto = copy.deepcopy(handler._proto)
        saved_model_lib._make_assets_key_collection(
            proto, os.path.join(self.get_temp_dir(), "assets_key_test"))

        assets_key = tf_v1.saved_model.constants.ASSETS_KEY
        first_graph = list(proto.meta_graphs)[0]
        names = []
        for packed in first_graph.collection_def[assets_key].any_list.value:
            file_def = meta_graph_pb2.AssetFileDef()
            packed.Unpack(file_def)
            names.append(file_def.tensor_info.name)
        self.assertEqual(names, sorted(names))
예제 #4
0
def _merge_assets_key_collection(saved_model_proto, path):
    """Folds the ASSETS_KEY collection back into each MetaGraph's GraphDef.

    For every MetaGraph carrying an ASSETS_KEY collection:
    1. The collection is deleted.
    2. Each asset node it referenced has `attr["value"].tensor.string_val[0]`
       rewritten to `_get_asset_filename(path, <asset filename>)`.

    Afterwards the GraphDefs can be used without feeding asset tensors.

    Args:
      saved_model_proto: SavedModel proto, modified in place.
      path: Path the SavedModel is being loaded from.
    """
    assets_key = tf_v1.saved_model.constants.ASSETS_KEY
    for meta_graph in saved_model_proto.meta_graphs:
        collections = meta_graph.collection_def
        if assets_key not in collections:
            # Nothing to merge; no node would be rewritten.
            continue

        node_to_file = {}
        for packed in collections[assets_key].any_list.value:
            file_def = meta_graph_pb2.AssetFileDef()
            packed.Unpack(file_def)
            resolved = _get_asset_filename(path, file_def.filename)
            node_to_file[_get_node_name_from_tensor(
                file_def.tensor_info.name)] = resolved
        del collections[assets_key]

        for node in meta_graph.graph_def.node:
            resolved = node_to_file.get(node.name)
            if not resolved:
                continue
            _check_asset_node_def(node)
            node.attr["value"].tensor.string_val[0] = resolved
예제 #5
0
def _get_asset_tensors(export_dir, meta_graph_def_to_load):
    """Maps asset tensor names to asset file paths for a meta graph.

    Args:
      export_dir: Directory where the SavedModel is located.
      meta_graph_def_to_load: The meta graph def from the SavedModel to be
        loaded.

    Returns:
      A dict keyed by asset tensor name. Each value is the bytes path
      `<export_dir>/<constants.ASSETS_DIRECTORY>/<asset filename>`. The dict
      is empty when the meta graph has no `constants.ASSETS_KEY` collection.
    """
    collections = meta_graph_def_to_load.collection_def
    if constants.ASSETS_KEY not in collections:
        return {}

    assets_root = os.path.join(compat.as_bytes(export_dir),
                               compat.as_bytes(constants.ASSETS_DIRECTORY))
    result = {}
    for packed in collections[constants.ASSETS_KEY].any_list.value:
        file_def = meta_graph_pb2.AssetFileDef()
        packed.Unpack(file_def)
        result[file_def.tensor_info.name] = os.path.join(
            compat.as_bytes(assets_root), compat.as_bytes(file_def.filename))
    return result
예제 #6
0
def get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
    """Maps (optionally scoped) asset tensor names to asset file paths.

    Args:
      export_dir: Directory where the SavedModel is located.
      meta_graph_def_to_load: The meta graph def from the SavedModel to be
        loaded.
      import_scope: Optional `string`; when truthy, every returned tensor
        name is prefixed with `import_scope` followed by '/'.

    Returns:
      A dict keyed by asset tensor name. Each value is the bytes path
      `<export_dir>/<constants.ASSETS_DIRECTORY>/<asset filename>`. The dict
      is empty when the meta graph has no `constants.ASSETS_KEY` collection.
    """
    collections = meta_graph_def_to_load.collection_def
    if constants.ASSETS_KEY not in collections:
        return {}

    def _key_for(tensor_name):
        # Prefix with the import scope only when one was given.
        if not import_scope:
            return tensor_name
        return "%s/%s" % (import_scope, tensor_name)

    assets_root = os.path.join(compat.as_bytes(export_dir),
                               compat.as_bytes(constants.ASSETS_DIRECTORY))
    result = {}
    for packed in collections[constants.ASSETS_KEY].any_list.value:
        file_def = meta_graph_pb2.AssetFileDef()
        packed.Unpack(file_def)
        result[_key_for(file_def.tensor_info.name)] = os.path.join(
            compat.as_bytes(assets_root), compat.as_bytes(file_def.filename))
    return result
예제 #7
0
    def testAssets(self):
        """Only the file in ASSET_FILEPATHS is exported into the assets dir."""
        tmp_dir = compat.as_bytes(tf.test.get_temp_dir())
        export_dir = os.path.join(tmp_dir, compat.as_bytes("with-assets"))
        builder = saved_model_builder.SavedModelBuilder(export_dir)

        def exported_asset(name):
            # Location of `name` inside the SavedModel's assets directory.
            return os.path.join(compat.as_bytes(export_dir),
                                compat.as_bytes(constants.ASSETS_DIRECTORY),
                                compat.as_bytes(name))

        with self.test_session(graph=tf.Graph()) as sess:
            var = tf.Variable(42, name="v")
            sess.run(tf.initialize_all_variables())
            self.assertEqual(42, var.eval())

            # Register exactly one file as an asset.
            hello_path = os.path.join(tmp_dir, compat.as_bytes("hello42.txt"))
            file_io.write_string_to_file(hello_path, "foo bar baz")
            tf.add_to_collection(
                tf.GraphKeys.ASSET_FILEPATHS,
                tf.constant(hello_path, name="asset_file_tensor"))

            # A second file that is never registered and must not be exported.
            unused_path = os.path.join(tmp_dir, compat.as_bytes("ignored.txt"))
            file_io.write_string_to_file(unused_path, "will be ignored")

            builder.add_meta_graph_and_variables(
                sess, ["foo"],
                assets_collection=tf.get_collection(
                    tf.GraphKeys.ASSET_FILEPATHS))

        # Save the SavedModel to disk.
        builder.save()

        with self.test_session(graph=tf.Graph()) as sess:
            loaded = loader.load(sess, ["foo"], export_dir)

            packed_assets = loaded.collection_def[
                constants.ASSETS_KEY].any_list.value
            self.assertEqual(len(packed_assets), 1)
            file_def = meta_graph_pb2.AssetFileDef()
            packed_assets[0].Unpack(file_def)
            contents = file_io.read_file_to_string(
                exported_asset("hello42.txt"))
            self.assertEqual("foo bar baz", compat.as_text(contents))
            self.assertEqual("hello42.txt", file_def.filename)
            self.assertEqual("asset_file_tensor:0", file_def.tensor_info.name)
            self.assertFalse(
                file_io.file_exists(exported_asset("ignored.txt")))
예제 #8
0
def _add_asset_to_collection(asset_filename, asset_tensor):
  """Packs an `AssetFileDef` into an `Any` and adds it to ASSETS_KEY.

  The proto is appended via `ops.add_to_collection` under
  `constants.ASSETS_KEY`.

  Args:
    asset_filename: Filename recorded in the `AssetFileDef`.
    asset_tensor: Tensor whose `.name` becomes the proto's
      `tensor_info.name`.
  """
  file_def = meta_graph_pb2.AssetFileDef(filename=asset_filename)
  file_def.tensor_info.name = asset_tensor.name

  packed = Any()
  packed.Pack(file_def)
  ops.add_to_collection(constants.ASSETS_KEY, packed)
예제 #9
0
 def _validate_asset_collection(self, export_dir, graph_collection_def,
                                expected_asset_file_name,
                                expected_asset_file_contents,
                                expected_asset_tensor_name):
     """Asserts the first ASSETS_KEY entry describes the expected asset.

     Checks three things:
     * the file's contents under `<export_dir>/<constants.ASSETS_DIRECTORY>`;
     * the filename recorded in the `AssetFileDef`;
     * the tensor name recorded in the `AssetFileDef`.
     """
     first_entry = graph_collection_def[
         constants.ASSETS_KEY].any_list.value[0]
     file_def = meta_graph_pb2.AssetFileDef()
     first_entry.Unpack(file_def)
     on_disk = os.path.join(compat.as_bytes(export_dir),
                            compat.as_bytes(constants.ASSETS_DIRECTORY),
                            compat.as_bytes(expected_asset_file_name))
     self.assertEqual(expected_asset_file_contents,
                      compat.as_text(file_io.read_file_to_string(on_disk)))
     self.assertEqual(expected_asset_file_name, file_def.filename)
     self.assertEqual(expected_asset_tensor_name, file_def.tensor_info.name)
예제 #10
0
파일: save.py 프로젝트: vic-yes/tensorflow
def _process_asset(trackable_asset, asset_info, resource_map):
    """Add `trackable_asset` to `asset_info` and `resource_map`.

    Reads the asset's path eagerly and picks the exported filename with
    `builder_impl.get_asset_filename_to_add`. It then creates a replacement
    resource variable fed by an "asset_path_initializer" placeholder and
    records it:
    * an `AssetFileDef` is appended to `asset_info.asset_defs`;
    * the variable's initializer is stored in
      `asset_info.asset_initializers_by_resource`;
    * `resource_map[<original handle>]` is set to the new variable's handle.

    Args:
      trackable_asset: Object exposing an `asset_path` variable.
      asset_info: Accumulator whose `asset_filename_map`, `asset_defs` and
        `asset_initializers_by_resource` are updated in place.
      resource_map: Dict from original resource handles to replacement
        handles; updated in place.
    """
    original_variable = trackable_asset.asset_path
    # `.numpy()` needs eager execution.
    with context.eager_mode():
        original_path = original_variable.numpy()
    path = builder_impl.get_asset_filename_to_add(
        asset_filepath=original_path,
        asset_filename_map=asset_info.asset_filename_map)
    # `asset_filename_map` maps exported filename -> original path (assigned
    # below), not filename -> variable. Looking a variable up in it returns
    # the path value, whose missing `.handle` raises AttributeError. That also
    # leaves the new handle without a registered initializer. So every
    # trackable asset gets its own variable.
    # TODO(andresp): Instead of mapping 1-1 between trackable asset
    # and asset in the graph def consider deduping the assets that
    # point to the same file.
    asset_path_initializer = array_ops.placeholder(
        shape=original_variable.shape,
        dtype=dtypes.string,
        name="asset_path_initializer")
    asset_variable = resource_variable_ops.ResourceVariable(
        asset_path_initializer)
    asset_info.asset_filename_map[path] = original_path
    asset_def = meta_graph_pb2.AssetFileDef()
    asset_def.filename = path
    asset_def.tensor_info.name = asset_path_initializer.name
    asset_info.asset_defs.append(asset_def)
    asset_info.asset_initializers_by_resource[original_variable.handle] = (
        asset_variable.initializer)
    resource_map[original_variable.handle] = asset_variable.handle
예제 #11
0
def _make_assets_key_collection(saved_model_proto, export_path):
    """Creates an ASSETS_KEY collection in the GraphDefs in saved_model_proto.

    Adds an ASSETS_KEY collection to the GraphDefs in the SavedModel and
    returns a map from original asset filename to filename when exporting the
    SavedModel to `export_path`.

    This is roughly the inverse operation of `_merge_assets_key_collection`.

    For each MetaGraph with an ASSET_FILEPATHS collection:
    * each listed asset node's value is replaced by b"SAVEDMODEL-ASSET" so the
      original path is not leaked;
    * an `AssetFileDef` for the node's ":0" tensor is added to ASSETS_KEY;
    * entries are added in sorted tensor-name order, so the result is
      deterministic.

    Args:
      saved_model_proto: SavedModel proto to be modified in place.
      export_path: string with path where the saved_model_proto will be
        exported.

    Returns:
      A map from original asset filename to asset filename when exporting the
      SavedModel to path.

    Raises:
      ValueError: on unsupported/unexpected SavedModel. This happens when an
        ASSET_FILEPATHS collection is not a node list, or lists a tensor
        other than an op's ":0" output.
    """
    asset_filenames = {}
    used_asset_filenames = set()

    def _make_asset_filename(original_filename):
        """Returns the asset filename to use for the file.

        The basename is used when still free. Otherwise a numeric suffix
        (0, 1, ...) is appended until the name is unused. Results are
        memoized per original filename, so a file referenced from several
        MetaGraphs maps to a single exported name.
        """
        if original_filename in asset_filenames:
            return asset_filenames[original_filename]

        basename = os.path.basename(original_filename)
        suggestion = basename
        index = 0
        while suggestion in used_asset_filenames:
            suggestion = tf.compat.as_bytes(basename) + tf.compat.as_bytes(
                str(index))
            index += 1
        asset_filenames[original_filename] = suggestion
        used_asset_filenames.add(suggestion)
        return suggestion

    for meta_graph in saved_model_proto.meta_graphs:
        collection_def = meta_graph.collection_def.get(
            tf_v1.GraphKeys.ASSET_FILEPATHS)

        if collection_def is None:
            continue
        if collection_def.WhichOneof("kind") != "node_list":
            raise ValueError(
                "MetaGraph collection ASSET_FILEPATHS is not a list of tensors."
            )

        for tensor in collection_def.node_list.value:
            if not tensor.endswith(":0"):
                raise ValueError(
                    "Unexpected tensor in ASSET_FILEPATHS collection.")

        asset_nodes = {
            _get_node_name_from_tensor(tensor)
            for tensor in collection_def.node_list.value
        }

        tensor_filename_map = {}
        for node in meta_graph.graph_def.node:
            if node.name in asset_nodes:
                _check_asset_node_def(node)
                filename = node.attr["value"].tensor.string_val[0]
                tensor_filename_map[node.name + ":0"] = filename
                logging.debug("Found asset node %s pointing to %s", node.name,
                              filename)
                # Clear value to avoid leaking the original path.
                node.attr["value"].tensor.string_val[0] = (
                    tf.compat.as_bytes("SAVEDMODEL-ASSET"))

        if tensor_filename_map:
            assets_key_collection = meta_graph.collection_def[
                tf_v1.saved_model.constants.ASSETS_KEY]

            # Sorted so the collection order is independent of dict order.
            for tensor, filename in sorted(tensor_filename_map.items()):
                asset_proto = meta_graph_pb2.AssetFileDef()
                asset_proto.filename = _make_asset_filename(filename)
                asset_proto.tensor_info.name = tensor
                assets_key_collection.any_list.value.add().Pack(asset_proto)

    return {
        original_filename: _get_asset_filename(export_path, asset_filename)
        for original_filename, asset_filename in asset_filenames.items()
    }
예제 #12
0
def _make_asset_file_def_any(node_name):
    """Returns an `Any` wrapping an `AssetFileDef` for `node_name`.

    Only `tensor_info.name` is set (to `node_name`); `filename` keeps its
    default.
    """
    file_def = meta_graph_pb2.AssetFileDef()
    file_def.tensor_info.name = node_name
    packed = Any()
    packed.Pack(file_def)
    return packed
예제 #13
0
    def _build_model(self):
        """Load the SavedModel at `self._model_path` into its own graph/session.

        Creates `self._graph` and `self._session` (GPU allow_growth and soft
        placement enabled). Loads the SERVING-tagged SavedModel from the
        directory returned by `self.search_pb`, then fills:

        * `self._inputs_map` / `self._outputs_map`: signature key -> tensor for
          the default serving signature. NOTE(review): these dicts are assumed
          to be created before this call (not visible here).
        * `self._assets`: asset node name -> path under the SavedModel's
          assets directory.
        * `self._export_config` / `self._use_bgr`: parsed from the first
          'EV_EXPORT_CONFIG' collection entry, if any.

        Raises:
          AssertionError: if no SavedModel is found in the searched directory,
            or a listed asset file does not exist.
          ValueError: if `self._model_path` is not a directory.
        """

        model_path = self._model_path
        self._graph = tf.Graph()
        gpu_options = tf.GPUOptions(allow_growth=True)
        session_config = tf.ConfigProto(gpu_options=gpu_options,
                                        allow_soft_placement=True,
                                        log_device_placement=False)
        self._session = tf.Session(config=session_config, graph=self._graph)

        with self._graph.as_default():
            with self._session.as_default():
                # load model
                tf.logging.info('loading model from %s' % model_path)
                if tf.gfile.IsDirectory(model_path):
                    model_path = self.search_pb(model_path)
                    logging.info('model find in %s' % model_path)
                    # Explicit check instead of `assert`, which is stripped
                    # under `python -O`. AssertionError is kept so callers
                    # see the same exception type as before.
                    if not tf.saved_model.loader.maybe_saved_model_directory(
                            model_path):
                        raise AssertionError(
                            'saved model does not exists in %s' % model_path)
                    self._is_saved_model = True
                    meta_graph_def = tf.saved_model.loader.load(
                        self._session, [tf.saved_model.tag_constants.SERVING],
                        model_path)
                    # parse signature
                    signature_def = meta_graph_def.signature_def[
                        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
                    inputs = signature_def.inputs
                    for name, tensor in inputs.items():
                        logging.info('Load input binding: %s -> %s' %
                                     (name, tensor.name))
                        self._inputs_map[
                            name] = self._graph.get_tensor_by_name(tensor.name)
                    outputs = signature_def.outputs
                    for name, tensor in outputs.items():
                        logging.info('Load output binding: %s -> %s' %
                                     (name, tensor.name))
                        self._outputs_map[
                            name] = self._graph.get_tensor_by_name(tensor.name)

                    # get assets
                    self._assets = {}
                    asset_files = tf.get_collection(constants.ASSETS_KEY)
                    for any_proto in asset_files:
                        asset_file = meta_graph_pb2.AssetFileDef()
                        any_proto.Unpack(asset_file)
                        # Tensor name without the ":<output index>" suffix.
                        type_name = asset_file.tensor_info.name.split(':')[0]
                        asset_path = os.path.join(model_path,
                                                  constants.ASSETS_DIRECTORY,
                                                  asset_file.filename)
                        # Explicit check so it also runs under `python -O`.
                        if not tf.gfile.Exists(asset_path):
                            raise AssertionError(
                                '%s is missing in saved model' % asset_path)
                        self._assets[type_name] = asset_path
                    logging.info(self._assets)

                    # get export config
                    self._export_config = {}
                    self._use_bgr = False
                    export_config_collection = tf.get_collection(
                        'EV_EXPORT_CONFIG')
                    if export_config_collection:
                        self._export_config = json.loads(
                            export_config_collection[0])
                        logging.info('load export config info %s' %
                                     export_config_collection[0])
                        self._use_bgr = self._export_config.get(
                            'color_format', 'rgb').lower() == 'bgr'
                        if self._use_bgr:
                            logging.info(
                                'prediction will use image in bgr order')
                        else:
                            logging.info(
                                'prediction will use image in rgb order')

                else:
                    raise ValueError('saved model is not found in %s' %
                                     self._model_path)