def _recreate_asset(self, proto):
  """Rebuilds an `Asset` object from its serialized proto.

  The asset's file name is looked up in `self._asset_file_def` via the
  proto's `asset_file_def_index` and joined onto the SavedModel's assets
  directory. Outside eager execution the resulting asset path is also
  registered in the `ASSET_FILEPATHS` graph collection.

  Args:
    proto: serialized asset proto; only `asset_file_def_index` is read.

  Returns:
    A `(asset, setattr)` pair: the recreated `Asset` together with the
    builtin `setattr` (presumably used by the caller to attach the object --
    confirm against the caller).
  """
  assets_dir = saved_model_utils.get_assets_dir(self._export_dir)
  file_def = self._asset_file_def[proto.asset_file_def_index]
  recreated = tracking.Asset(os.path.join(assets_dir, file_def.filename))
  if context.executing_eagerly():
    return recreated, setattr
  # Graph mode: expose the path through the standard asset collection.
  ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, recreated.asset_path)
  return recreated, setattr
def __init__(self):
  """Builds a string -> int64 lookup table backed by an `Asset` file.

  Each whole line of `static_hashtable_asset.txt` becomes a key mapped to
  its line number; keys missing from the file look up to -1.
  """
  vocab_file = test.test_src_dir_path(
      "cc/saved_model/testdata/static_hashtable_asset.txt")
  self.asset = tracking.Asset(vocab_file)
  # Key: the full text of a line; value: that line's index in the file.
  initializer = lookup_ops.TextFileInitializer(
      self.asset,
      dtypes.string,
      lookup_ops.TextFileIndex.WHOLE_LINE,
      dtypes.int64,
      lookup_ops.TextFileIndex.LINE_NUMBER)
  self.table = lookup_ops.StaticHashTable(initializer, -1)
def test_unused_asset(self):
  """A model with an attached but unused `Asset` still saves and serves.

  The only function doubles its input and never references the asset; the
  test saves the model and checks inference on the reloaded copy.
  """
  root = tracking.AutoTrackable()
  float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
  root.f = def_function.function(
      lambda x: 2. * x, input_signature=[float_spec])
  root.asset = tracking.Asset(self._vocab_path)

  save_path = os.path.join(self.get_temp_dir(), "save_dir")
  save.save(root, save_path)

  outputs = _import_and_infer(save_path, {"x": [0.1]})
  self.assertAllClose({"output_0": [0.2]}, outputs)
def test_asset_path_returned(self):
  """An asset path returned by a signature points into the loaded directory.

  The SavedModel is moved after saving; the path produced at inference time
  must contain the directory the model was loaded from, not the original.
  """
  root = tracking.AutoTrackable()
  root.path = tracking.Asset(self._vocab_path)
  original_dir = os.path.join(self.get_temp_dir(), "saved_model")
  root.get_asset = def_function.function(lambda: root.path.asset_path)
  signature = root.get_asset.get_concrete_function()
  save.save(root, original_dir, signatures=signature)

  # Relocate the SavedModel so any path baked in at save time would be stale.
  moved_dir = os.path.join(self.get_temp_dir(), "second_dir")
  file_io.rename(original_dir, moved_dir)

  result = _import_and_infer(moved_dir, {})
  imported_path = result["output_0"]
  self.assertIn(
      compat.as_str_any(moved_dir), compat.as_str_any(imported_path))
def __init__(self,
             resource_name,
             filename,
             maximum_cached_engines,
             device="GPU"):
  """Creates a TRT engine resource whose serialized engine file is tracked.

  Args:
    resource_name: name of the engine resource; also handed to the deleter.
    filename: path of the serialized engine file; wrapped in an `Asset` and
      tracked under "_serialized_trt_resource_filename".
    maximum_cached_engines: stored as `self._maximum_cached_engines`.
    device: passed to both the base resource and the deleter; "GPU" by
      default.
  """
  deleter = _TRTEngineResourceDeleter(resource_name, device)
  super(_TRTEngineResource, self).__init__(device=device, deleter=deleter)
  self._resource_name = resource_name
  # Track the serialized engine file in the SavedModel.
  engine_asset = tracking.Asset(filename)
  self._filename = self._track_trackable(
      engine_asset, "_serialized_trt_resource_filename")
  self._maximum_cached_engines = maximum_cached_engines
def _recreate_asset(self, proto):
  """Rebuilds an `Asset` object from its serialized proto.

  The file name is taken from `self._asset_file_def` at the proto's
  `asset_file_def_index` and joined onto the SavedModel's assets directory.

  Args:
    proto: serialized asset proto; only `asset_file_def_index` is read.

  Returns:
    A `(asset, setattr)` pair: the new `Asset` and the builtin `setattr`.
  """
  assets_dir = saved_model_utils.get_assets_dir(self._export_dir)
  file_name = self._asset_file_def[proto.asset_file_def_index].filename
  asset_path = os.path.join(assets_dir, file_name)
  return tracking.Asset(asset_path), setattr
def __init__(self):
  """Wraps the `test_asset.txt` test-data file in an `Asset` on `self.asset`."""
  asset_file = test.test_src_dir_path(
      "cc/saved_model/testdata/test_asset.txt")
  self.asset = tracking.Asset(asset_file)
def load(self, tags):
  """Creates an object from the MetaGraph identified by `tags`.

  The MetaGraph's function library and node defs are renamed with a per-call
  suffix. The graph is wrapped in a function, variables are restored, and an
  initializer that feeds asset paths is built. The result is bundled onto a
  fresh `tracking.AutoTrackable`.

  Args:
    tags: selects the MetaGraph; passed to `self.get_meta_graph_def_from_tags`.

  Returns:
    A `tracking.AutoTrackable` with attributes `initializer`, `asset_paths`,
    `signatures`, `variables`, `tensorflow_version`,
    `tensorflow_git_version`, `graph` and `prune`.
  """
  meta_graph_def = self.get_meta_graph_def_from_tags(tags)
  # Suffix made unique per call via `ops.uid()`, used when renaming functions
  # and shared names below.
  load_shared_name_suffix = "_load_{}".format(ops.uid())
  functions = function_deserialization.load_function_def_library(
      meta_graph_def.graph_def.library,
      load_shared_name_suffix=load_shared_name_suffix)
  # Replace existing functions in the MetaGraphDef with renamed functions so
  # we don't have duplicates or name collisions.
  meta_graph_def.graph_def.library.Clear()
  for function in functions.values():
    meta_graph_def.graph_def.library.function.add().CopyFrom(
        function.function_def)
  # We've renamed functions and shared names. We need the same operation on
  # the GraphDef itself for consistency.
  for node_def in meta_graph_def.graph_def.node:
    function_deserialization.fix_node_def(
        node_def, functions, load_shared_name_suffix,
        debug_name="MetaGraph import")

  # Single-slot out-parameter; presumably filled with the saver by
  # `self.load_graph` (it is unpacked into `saver` right after wrapping).
  load_graph_returns = [None]
  wrapped = wrap_function.wrap_function(
      functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
      signature=[])
  saver, = load_graph_returns
  self.restore_variables(wrapped, saver)
  with wrapped.graph.as_default():
    # Use the MetaGraph's own init op; fall back to the default local init op
    # when `get_init_op` returns nothing truthy.
    init_op = loader_impl.get_init_op(
        meta_graph_def
    ) or monitored_session.Scaffold.default_local_init_op()
    # Add a dummy Tensor we know we can fetch to add control dependencies to.
    init_anchor = constant_op.constant(0., name="dummy_fetch")

  root = tracking.AutoTrackable()
  # Built in lockstep so feeds[i] corresponds to asset_paths[i].
  asset_feed_tensors = []
  asset_paths = []
  for tensor_name, value in loader_impl.get_asset_tensors(
      self._export_dir, meta_graph_def).items():
    asset_feed_tensors.append(
        wrapped.graph.as_graph_element(tensor_name))
    asset_paths.append(tracking.Asset(value))
  # Prune to a function fed by the asset tensors that fetches the anchor and
  # the init op.
  init_fn = wrapped.prune(
      feeds=asset_feed_tensors,
      fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)])
  initializer = _Initializer(init_fn, asset_paths)
  # `_initialize()` returns a pair; only the first element is used here.
  # pylint: disable=protected-access
  local_init_op, _ = initializer._initialize()
  # pylint: enable=protected-access
  with ops.init_scope():
    # Graph mode only: register the init op as a table initializer and make it
    # the initializer op of every local variable in the wrapped graph.
    if not context.executing_eagerly():
      ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, local_init_op)
      for variable in wrapped.graph.get_collection_ref(
          ops.GraphKeys.LOCAL_VARIABLES):
        # pylint: disable=protected-access
        variable._initializer_op = local_init_op
        # pylint: enable=protected-access
  root.initializer = initializer
  root.asset_paths = asset_paths
  signature_functions = self._extract_signatures(wrapped, meta_graph_def)

  root.signatures = signature_serialization.create_signature_map(
      signature_functions)
  root.variables = list(wrapped.graph.variables)
  root.tensorflow_version = (
      meta_graph_def.meta_info_def.tensorflow_version)
  root.tensorflow_git_version = (
      meta_graph_def.meta_info_def.tensorflow_git_version)
  root.graph = wrapped.graph
  root.prune = wrapped.prune
  return root