Beispiel #1
0
 def _recreate_asset(self, proto):
   """Rebuilds the `tracking.Asset` referenced by `proto`.

   Args:
     proto: Asset proto whose `asset_file_def_index` selects an entry of
       `self._asset_file_def`.

   Returns:
     An `(asset, setattr)` tuple: the recreated asset together with the
     builtin `setattr`.
   """
   assets_dir = saved_model_utils.get_assets_dir(self._export_dir)
   file_def = self._asset_file_def[proto.asset_file_def_index]
   restored = tracking.Asset(os.path.join(assets_dir, file_def.filename))
   if context.executing_eagerly():
     return restored, setattr
   # Outside eager execution the asset path is also recorded in the
   # ASSET_FILEPATHS collection.
   ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, restored.asset_path)
   return restored, setattr
 def __init__(self):
      """Creates a text-file asset and a static hash table initialized from it.

      The table is built from a `TextFileInitializer` over the asset, using
      `dtypes.string` / `TextFileIndex.WHOLE_LINE` and `dtypes.int64` /
      `TextFileIndex.LINE_NUMBER`, with -1 as the second table argument.
      """
      vocab_file = test.test_src_dir_path(
          "cc/saved_model/testdata/static_hashtable_asset.txt")
      self.asset = tracking.Asset(vocab_file)
      table_init = lookup_ops.TextFileInitializer(
          self.asset,
          dtypes.string,
          lookup_ops.TextFileIndex.WHOLE_LINE,
          dtypes.int64,
          lookup_ops.TextFileIndex.LINE_NUMBER)
      self.table = lookup_ops.StaticHashTable(table_init, -1)
Beispiel #3
0
    def test_unused_asset(self):
        """Saves an object holding an asset that `f` never reads, then infers.

        The exported model is re-imported and fed `x = [0.1]`; the doubled
        value is expected under the `output_0` key.
        """
        float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
        model = tracking.AutoTrackable()
        model.f = def_function.function(
            lambda x: 2. * x, input_signature=[float_spec])
        model.asset = tracking.Asset(self._vocab_path)

        target_dir = os.path.join(self.get_temp_dir(), "save_dir")
        save.save(model, target_dir)
        inferred = _import_and_infer(target_dir, {"x": [0.1]})
        self.assertAllClose({"output_0": [0.2]}, inferred)
Beispiel #4
0
 def test_asset_path_returned(self):
   """Checks the served asset path points into the relocated SavedModel.

   The model is saved with a signature returning its asset path, the export
   directory is renamed, and the path returned by inference on the renamed
   directory must contain that new directory name.
   """
   model = tracking.AutoTrackable()
   model.path = tracking.Asset(self._vocab_path)
   original_dir = os.path.join(self.get_temp_dir(), "saved_model")
   model.get_asset = def_function.function(lambda: model.path.asset_path)
   asset_signature = model.get_asset.get_concrete_function()
   save.save(model, original_dir, signatures=asset_signature)

   # Move the export before importing it.
   moved_dir = os.path.join(self.get_temp_dir(), "second_dir")
   file_io.rename(original_dir, moved_dir)

   outputs = _import_and_infer(moved_dir, {})
   served_path = outputs["output_0"]
   self.assertIn(
       compat.as_str_any(moved_dir), compat.as_str_any(served_path))
 def __init__(self,
              resource_name,
              filename,
              maximum_cached_engines,
              device="GPU"):
   """Initializes the resource and tracks the serialized engine file.

   Args:
     resource_name: Stored as `_resource_name` and passed to the deleter.
     filename: Path wrapped in a `tracking.Asset` and tracked under the name
       `_serialized_trt_resource_filename`.
     maximum_cached_engines: Stored as `_maximum_cached_engines`.
     device: Forwarded to the base class and to the deleter; defaults to
       "GPU".
   """
   resource_deleter = _TRTEngineResourceDeleter(resource_name, device)
   super(_TRTEngineResource, self).__init__(
       device=device, deleter=resource_deleter)
   self._resource_name = resource_name
   # Track the serialized engine file in the SavedModel.
   engine_asset = tracking.Asset(filename)
   self._filename = self._track_trackable(
       engine_asset, "_serialized_trt_resource_filename")
   self._maximum_cached_engines = maximum_cached_engines
 def _recreate_asset(self, proto):
   """Recreates a `tracking.Asset` for the asset file named by `proto`.

   Args:
     proto: Asset proto; its `asset_file_def_index` indexes
       `self._asset_file_def`.

   Returns:
     A tuple of the new `tracking.Asset` and the builtin `setattr`.
   """
   assets_root = saved_model_utils.get_assets_dir(self._export_dir)
   relative_name = self._asset_file_def[proto.asset_file_def_index].filename
   asset_path = os.path.join(assets_root, relative_name)
   return tracking.Asset(asset_path), setattr
Beispiel #7
0
 def __init__(self):
   """Attaches the `test_asset.txt` test data file as a tracked asset."""
   asset_file = test.test_src_dir_path(
       "cc/saved_model/testdata/test_asset.txt")
   self.asset = tracking.Asset(asset_file)
Beispiel #8
0
    def load(self, tags):
        """Creates an object from the MetaGraph identified by `tags`.

        The MetaGraphDef's function library is reloaded under a load-specific
        shared-name suffix, the graph is imported through `wrap_function`,
        variables are restored, and an initializer fed with the asset paths is
        built and its `_initialize()` called.

        Args:
          tags: Passed to `self.get_meta_graph_def_from_tags` to select the
            MetaGraphDef to load.

        Returns:
          A `tracking.AutoTrackable` carrying `initializer`, `asset_paths`,
          `signatures`, `variables`, `tensorflow_version`,
          `tensorflow_git_version`, `graph` and `prune` attributes.
        """
        meta_graph_def = self.get_meta_graph_def_from_tags(tags)
        # Suffix built from `ops.uid()` and applied to the loaded functions and
        # node defs below.
        load_shared_name_suffix = "_load_{}".format(ops.uid())
        functions = function_deserialization.load_function_def_library(
            meta_graph_def.graph_def.library,
            load_shared_name_suffix=load_shared_name_suffix)
        # Replace existing functions in the MetaGraphDef with renamed functions so
        # we don't have duplicates or name collisions.
        meta_graph_def.graph_def.library.Clear()
        for function in functions.values():
            meta_graph_def.graph_def.library.function.add().CopyFrom(
                function.function_def)
        # We've renamed functions and shared names. We need the same operation on
        # the GraphDef itself for consistency.
        for node_def in meta_graph_def.graph_def.node:
            function_deserialization.fix_node_def(
                node_def,
                functions,
                load_shared_name_suffix,
                debug_name="MetaGraph import")

        # One-element list handed to `self.load_graph`; it is unpacked as the
        # saver once wrapping is done (presumably `load_graph` stores the saver
        # in slot 0 -- confirm against `load_graph`).
        load_graph_returns = [None]
        wrapped = wrap_function.wrap_function(functools.partial(
            self.load_graph, load_graph_returns, meta_graph_def),
                                              signature=[])
        saver, = load_graph_returns
        self.restore_variables(wrapped, saver)
        with wrapped.graph.as_default():
            # Fall back to the default local init op when the MetaGraphDef
            # yields no init op.
            init_op = loader_impl.get_init_op(
                meta_graph_def
            ) or monitored_session.Scaffold.default_local_init_op()
            # Add a dummy Tensor we know we can fetch to add control dependencies to.
            init_anchor = constant_op.constant(0., name="dummy_fetch")

        root = tracking.AutoTrackable()
        # Parallel lists: the graph element fed for each asset tensor and a
        # `tracking.Asset` wrapping the matching value.
        asset_feed_tensors = []
        asset_paths = []
        for tensor_name, value in loader_impl.get_asset_tensors(
                self._export_dir, meta_graph_def).items():
            asset_feed_tensors.append(
                wrapped.graph.as_graph_element(tensor_name))
            asset_paths.append(tracking.Asset(value))
        # Prune the wrapped function down to the init op (plus the dummy
        # anchor), fed by the asset tensors.
        init_fn = wrapped.prune(
            feeds=asset_feed_tensors,
            fetches=[init_anchor,
                     wrapped.graph.as_graph_element(init_op)])
        initializer = _Initializer(init_fn, asset_paths)
        # pylint: disable=protected-access
        local_init_op, _ = initializer._initialize()
        # pylint: enable=protected-access
        with ops.init_scope():
            # Only when not executing eagerly: register the init op as a table
            # initializer and set it as the initializer op of every local
            # variable of the wrapped graph.
            if not context.executing_eagerly():
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      local_init_op)
                for variable in wrapped.graph.get_collection_ref(
                        ops.GraphKeys.LOCAL_VARIABLES):
                    # pylint: disable=protected-access
                    variable._initializer_op = local_init_op
                    # pylint: enable=protected-access
        root.initializer = initializer
        root.asset_paths = asset_paths
        signature_functions = self._extract_signatures(wrapped, meta_graph_def)

        root.signatures = signature_serialization.create_signature_map(
            signature_functions)
        root.variables = list(wrapped.graph.variables)
        # Version information copied from the MetaGraphDef's meta info.
        root.tensorflow_version = (
            meta_graph_def.meta_info_def.tensorflow_version)
        root.tensorflow_git_version = (
            meta_graph_def.meta_info_def.tensorflow_git_version)
        # Expose the wrapped graph and its `prune` method on the returned
        # object.
        root.graph = wrapped.graph
        root.prune = wrapped.prune
        return root