Example #1
0
 def test_dedup_assets(self, cycles):
   """Two assets backed by the same file resolve to one path after a cycle."""
   vocab_path = self._make_asset("contents")
   root = tracking.AutoTrackable()
   # Attach the same underlying file twice under different attribute names.
   root.asset1 = tracking.TrackableAsset(vocab_path)
   root.asset2 = tracking.TrackableAsset(vocab_path)
   imported = self.cycle(root, cycles)
   # After save/load both attributes should point at the same asset file.
   first_path = imported.asset1.asset_path.numpy()
   second_path = imported.asset2.asset_path.numpy()
   self.assertEqual(first_path, second_path)
 def __init__(self, resource_name, filename, maximum_cached_engines):
   """Wraps a serialized TRT engine file as a trackable resource.

   Args:
     resource_name: Name identifying the TRT engine resource.
     filename: Path to the serialized engine file to track.
     maximum_cached_engines: Maximum number of engines kept cached.
   """
   super(TRTEngineResource, self).__init__()
   self._resource_name = resource_name
   # Register the engine file as a SavedModel asset so it gets exported.
   engine_asset = tracking.TrackableAsset(filename)
   self._filename = self._track_trackable(
       engine_asset, "_serialized_trt_engine_filename")
   self._maximum_cached_engines = maximum_cached_engines
Example #3
0
 def __init__(self,
              vocab_file_path,
              oov_buckets,
              num_lines_to_ignore=0,
              num_lines_to_use=None):
   """Builds an embedding model from pretrained vectors in a vocab file.

   Args:
     vocab_file_path: Path to the pretrained vocabulary/vector file.
     oov_buckets: Number of out-of-vocabulary hash buckets to append.
     num_lines_to_ignore: Number of leading lines to skip when loading.
     num_lines_to_use: Optional cap on the number of lines to load.
   """
   super(TextEmbeddingModel, self).__init__()
   # NOTE(review): assumes `load` returns (token sequence, 2-D numpy array
   # of vectors) — confirm against the module-level `load` helper.
   self._vocabulary, self._pretrained_vectors = load(vocab_file_path,
                                                     parse_line,
                                                     num_lines_to_ignore,
                                                     num_lines_to_use)
   self._oov_buckets = oov_buckets
   # Make the vocabulary file a `TrackableAsset` to ensure it is saved along
   # with the model.
   self._vocabulary_file = tracking.TrackableAsset(
       write_vocabulary_file(self._vocabulary))
   # Hash-bucketed word -> id lookup over the tracked vocabulary file.
   self._table = lookup_ops.index_table_from_file(
       vocabulary_file=self._vocabulary_file,
       num_oov_buckets=self._oov_buckets,
       hasher_spec=lookup_ops.FastHashSpec)
   # Grow the pretrained table in place by `oov_buckets` rows, then write
   # zeros into the appended rows (ndarray.resize already zero-fills new
   # elements; the explicit assignment makes the intent visible).
   oovs = np.zeros([oov_buckets, self._pretrained_vectors.shape[1]])
   self._pretrained_vectors.resize([
       self._pretrained_vectors.shape[0] + oov_buckets,
       self._pretrained_vectors.shape[1]
   ])
   self._pretrained_vectors[self._pretrained_vectors.shape[0] -
                            oov_buckets:, :] = oovs
   self.embeddings = tf.Variable(self._pretrained_vectors)
   self.variables = [self.embeddings]
   self.trainable_variables = self.variables
  def load(self, tags):
    """Creates an object from the MetaGraph identified by `tags`.

    Wraps the MetaGraph's graph in a concrete function, restores its
    variables, and packages the result as an `AutoTrackable` carrying
    signatures, variables, asset paths, and (when the MetaGraph has an
    init op) an initializer.

    Args:
      tags: Tags identifying which MetaGraph in the SavedModel to load.

    Returns:
      A `tracking.AutoTrackable` root object.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    # `load_graph` fills in the saver through this one-element list, used
    # here as an out-parameter.
    load_graph_returns = [None]
    wrapped = wrap_function.wrap_function(
        functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
        signature=[])
    saver, = load_graph_returns
    self.restore_variables(wrapped, saver)
    with wrapped.graph.as_default():
      init_op = loader_impl.get_init_op(meta_graph_def)
    root = tracking.AutoTrackable()
    if init_op is not None:
      # Re-bind each asset placeholder tensor to a TrackableAsset so the
      # asset files are tracked and fed back in when the initializer runs.
      asset_feed_tensors = []
      asset_paths = []
      for tensor_name, value in loader_impl.get_asset_tensors(
          self._export_dir, meta_graph_def).items():
        asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
        asset_paths.append(tracking.TrackableAsset(value))
      init_fn = wrapped.prune(
          feeds=asset_feed_tensors,
          fetches=[wrapped.graph.as_graph_element(init_op)])
      initializer = _Initializer(init_fn, asset_paths)
      # Run initialization eagerly at load time (presumably executes
      # `init_fn` with the asset paths fed — confirm in `_Initializer`).
      initializer.initialize()
      root.initializer = initializer
      root.asset_paths = asset_paths
    else:
      root.asset_paths = []
    signature_functions = self._extract_signatures(wrapped, meta_graph_def)

    root.signatures = signature_serialization.create_signature_map(
        signature_functions)
    root.variables = list(wrapped.graph.variables)
    return root
    def load(self, tags):
        """Creates an object from the MetaGraph identified by `tags`.

        Restores the MetaGraph into a wrapped function, runs its (local)
        initializer, and returns an `AutoTrackable` carrying signatures,
        variables, assets, the wrapped graph, and its pruning helper.

        Args:
            tags: Tags identifying which MetaGraph in the SavedModel to
                load.

        Returns:
            A `tracking.AutoTrackable` root object.
        """
        meta_graph_def = self.get_meta_graph_def_from_tags(tags)
        # `load_graph` fills in the saver through this one-element list,
        # used here as an out-parameter.
        load_graph_returns = [None]
        wrapped = wrap_function.wrap_function(functools.partial(
            self.load_graph, load_graph_returns, meta_graph_def),
                                              signature=[])
        saver, = load_graph_returns
        self.restore_variables(wrapped, saver)
        with wrapped.graph.as_default():
            # Fall back to the default local-init op when the MetaGraph
            # defines no explicit init op.
            init_op = loader_impl.get_init_op(
                meta_graph_def
            ) or monitored_session.Scaffold.default_local_init_op()
            # Add a dummy Tensor we know we can fetch to add control dependencies to.
            init_anchor = constant_op.constant(0., name="dummy_fetch")

        root = tracking.AutoTrackable()
        # Re-bind each asset placeholder tensor to a TrackableAsset so the
        # asset files are tracked and fed back in when the initializer runs.
        asset_feed_tensors = []
        asset_paths = []
        for tensor_name, value in loader_impl.get_asset_tensors(
                self._export_dir, meta_graph_def).items():
            asset_feed_tensors.append(
                wrapped.graph.as_graph_element(tensor_name))
            asset_paths.append(tracking.TrackableAsset(value))
        init_fn = wrapped.prune(
            feeds=asset_feed_tensors,
            fetches=[init_anchor,
                     wrapped.graph.as_graph_element(init_op)])
        initializer = _Initializer(init_fn, asset_paths)
        # pylint: disable=protected-access
        local_init_op, _ = initializer._initialize()
        # pylint: enable=protected-access
        with ops.init_scope():
            if not context.executing_eagerly():
                # Graph mode: register the init op in TABLE_INITIALIZERS so
                # standard session setup runs it, and point each local
                # variable's initializer at it.
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      local_init_op)
                for variable in wrapped.graph.get_collection_ref(
                        ops.GraphKeys.LOCAL_VARIABLES):
                    # pylint: disable=protected-access
                    variable._initializer_op = local_init_op
                    # pylint: enable=protected-access
        root.initializer = initializer
        root.asset_paths = asset_paths
        signature_functions = self._extract_signatures(wrapped, meta_graph_def)

        root.signatures = signature_serialization.create_signature_map(
            signature_functions)
        root.variables = list(wrapped.graph.variables)
        root.tensorflow_version = (
            meta_graph_def.meta_info_def.tensorflow_version)
        root.tensorflow_git_version = (
            meta_graph_def.meta_info_def.tensorflow_git_version)
        root.graph = wrapped.graph
        root.prune = wrapped.prune
        return root
 def test_capture_assets(self, cycles):
     """A function capturing an asset path keeps working after save/load."""
     root = tracking.AutoTrackable()
     root.vocab = tracking.TrackableAsset(self._make_asset("contents"))
     root.f = def_function.function(lambda: root.vocab.asset_path,
                                    input_signature=[])
     imported = self.cycle(root, cycles)
     before = root.f().numpy()
     after = imported.f().numpy()
     # The asset file is copied into the SavedModel, so its path changes...
     self.assertNotEqual(before, after)
     # ...while the file contents survive the round trip.
     with open(after, "r") as f:
         self.assertEqual("contents", f.read())
Example #7
0
    def test_unused_asset(self):
        """Saving succeeds even when a tracked asset is never captured."""
        root = tracking.AutoTrackable()
        doubler = def_function.function(
            lambda x: 2. * x,
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
        root.f = doubler
        # The asset is tracked on the root but not used by any function.
        root.asset = tracking.TrackableAsset(self._vocab_path)

        export_dir = os.path.join(self.get_temp_dir(), "save_dir")
        save.save(root, export_dir)
        expected = {"output_0": [0.2]}
        actual = _import_and_infer(export_dir, {"x": [0.1]})
        self.assertAllClose(expected, actual)
Example #8
0
 def test_asset_path_returned(self):
   """A returned asset path tracks the SavedModel's current location."""
   root = tracking.AutoTrackable()
   root.path = tracking.TrackableAsset(self._vocab_path)
   root.get_asset = def_function.function(lambda: root.path.asset_path)
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(root, save_dir, signatures=root.get_asset.get_concrete_function())
   # Relocate the SavedModel; the served path must follow the new directory.
   second_dir = os.path.join(self.get_temp_dir(), "second_dir")
   file_io.rename(save_dir, second_dir)
   imported_path = _import_and_infer(second_dir, {})["output_0"]
   self.assertIn(compat.as_str_any(second_dir),
                 compat.as_str_any(imported_path))
Example #9
0
  def test_assets(self, cycles):
    """Assets survive saving, deletion of originals, and a directory move."""
    file1 = self._make_asset("contents 1")
    file2 = self._make_asset("contents 2")

    root = tracking.AutoTrackable()
    root.asset1 = tracking.TrackableAsset(file1)
    root.asset2 = tracking.TrackableAsset(file2)

    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
    save.save(root, save_dir)

    # Remove the source files and relocate the SavedModel: the copies
    # inside the SavedModel's assets directory must be the ones loaded.
    for original in (file1, file2):
      file_io.delete_file(original)
    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
    file_io.rename(save_dir, load_dir)

    imported = load.load(load_dir)
    for asset, expected in ((imported.asset1, "contents 1"),
                            (imported.asset2, "contents 2")):
      with open(self.evaluate(asset.asset_path), "r") as f:
        self.assertEqual(expected, f.read())
Example #10
0
 def __init__(self,
              resource_name,
              filename,
              maximum_cached_engines,
              device="GPU"):
   """Tracks a serialized TRT engine file and its device-bound resource.

   Args:
     resource_name: Name identifying the TRT engine resource.
     filename: Path to the serialized engine file; exported as an asset.
     maximum_cached_engines: Maximum number of engines kept cached.
     device: Device the resource lives on; defaults to "GPU".
   """
   deleter = _TRTEngineResourceDeleter(resource_name, device)
   super(_TRTEngineResource, self).__init__(device=device, deleter=deleter)
   self._resource_name = resource_name
   # Register the engine file as a SavedModel asset so it gets exported.
   engine_asset = tracking.TrackableAsset(filename)
   self._filename = self._track_trackable(
       engine_asset, "_serialized_trt_resource_filename")
   self._maximum_cached_engines = maximum_cached_engines
Example #11
0
 def __init__(self, vocabulary, emb_dim, oov_buckets):
   """Builds a hash-bucketed embedding lookup over a written vocabulary.

   Args:
     vocabulary: Sequence of vocabulary tokens to write to the asset file.
     emb_dim: Embedding dimensionality.
     oov_buckets: Number of out-of-vocabulary hash buckets.
   """
   super(TextEmbeddingModel, self).__init__()
   self._oov_buckets = oov_buckets
   # Write the vocabulary to disk and track it as a SavedModel asset.
   self._vocabulary_file = tracking.TrackableAsset(
       write_vocabulary_file(vocabulary))
   # One embedding row per vocabulary entry plus one per OOV bucket.
   self._total_size = len(vocabulary) + oov_buckets
   self._table = lookup_ops.index_table_from_file(
       vocabulary_file=self._vocabulary_file,
       num_oov_buckets=self._oov_buckets,
       hasher_spec=lookup_ops.FastHashSpec)
   initial_values = tf.random.uniform(shape=[self._total_size, emb_dim])
   self.embeddings = tf.Variable(initial_values)
   self.variables = [self.embeddings]
   self.trainable_variables = self.variables
Example #12
0
 def __init__(self, vocab, emb_dim, buckets, state_size,n_layers):
   """Sets up the ULMFiT language-model encoder and vocab lookup tables.

   Args:
     vocab: Sequence of vocabulary tokens to write to the asset file.
     emb_dim: Embedding dimensionality.
     buckets: Number of out-of-vocabulary hash buckets.
     state_size: Recurrent state size of the encoder.
     n_layers: Number of recurrent layers in the encoder.
   """
   super(ULMFiTModule, self).__init__()
   self._buckets = buckets
   self._vocab_size = len(vocab)
   # Embedding/output rows cover the vocabulary plus the OOV buckets.
   self.emb_row_size = self._vocab_size + self._buckets
   self._state_size = state_size
   self.model = LanguageModelEncoder(
       self.emb_row_size, emb_dim, state_size, n_layers)
   # Track the written vocabulary file so it is exported with the model.
   self._vocabulary_file = tracking.TrackableAsset(
       write_vocabulary_file(vocab))
   # word -> id, hashing out-of-vocabulary words into the extra buckets.
   self.w2i_table = lookup_ops.index_table_from_file(
       vocabulary_file=self._vocabulary_file,
       num_oov_buckets=self._buckets,
       hasher_spec=lookup_ops.FastHashSpec)
   # id -> word, for decoding predictions back into tokens.
   self.i2w_table = lookup_ops.index_to_string_table_from_file(
       vocabulary_file=self._vocabulary_file,
       delimiter='\n',
       default_value="UNKNOWN")
   self._logit_layer = tf.keras.layers.Dense(self.emb_row_size)
   self.optimizer = tf.keras.optimizers.Adam()
    def test_capture_assets_in_graph(self, cycles):
        """Asset captures also work when loading into a graph-mode session."""
        root = tracking.AutoTrackable()
        root.vocab = tracking.TrackableAsset(self._make_asset("contents"))
        root.f = def_function.function(lambda: root.vocab.asset_path,
                                       input_signature=[])

        original_output = root.f().numpy()

        # Exercise any remaining save/load cycles eagerly before the final
        # graph-mode load below.
        if cycles > 1:
            root = self.cycle(root, cycles - 1)
        path = tempfile.mkdtemp(prefix=self.get_temp_dir())
        save.save(root, path)

        with ops.Graph().as_default():
            imported = load.load(path)
            fetched = imported.f()
            with monitored_session.MonitoredSession() as sess:
                graph_output = sess.run(fetched)
                # The asset file is copied into the SavedModel, so its path
                # changes while its contents are preserved.
                self.assertNotEqual(original_output, graph_output)
                with open(graph_output, "r") as f:
                    self.assertEqual("contents", f.read())
Example #14
0
 def _recreate_asset(self, proto):
     """Rebuilds a `TrackableAsset` from its object-graph proto entry.

     Returns a (trackable, setter) pair; plain attribute assignment
     (`setattr`) is the setter used to attach the asset to its parent.
     """
     asset_def = self._asset_file_def[proto.asset_file_def_index]
     assets_dir = saved_model_utils.get_assets_dir(self._export_dir)
     filename = os.path.join(assets_dir, asset_def.filename)
     return tracking.TrackableAsset(filename), setattr
Example #15
0
 def __init__(self, filename, name=None):
     """Tracks `filename` as a SavedModel asset.

     Args:
       filename: Path to the file to export as an asset.
       name: Optional name for this object.
     """
     self._name = name
     # Keep the caller-supplied path separately from the tracked asset.
     self._filename_arg = filename
     asset = trackable.TrackableAsset(filename)
     self._filename = self._track_trackable(asset, "_filename")
    def load(self, tags):
        """Creates an object from the MetaGraph identified by `tags`.

        Deduplicates and renames the MetaGraph's function library, restores
        the graph into a wrapped function, runs its (local) initializer,
        and returns an `AutoTrackable` carrying signatures, variables,
        assets, the wrapped graph, and its pruning helper.

        Args:
            tags: Tags identifying which MetaGraph in the SavedModel to
                load.

        Returns:
            A `tracking.AutoTrackable` root object.
        """
        meta_graph_def = self.get_meta_graph_def_from_tags(tags)
        # Unique suffix per load so repeated loads don't collide on shared
        # names.
        load_shared_name_suffix = "_load_{}".format(ops.uid())
        functions = function_deserialization.load_function_def_library(
            meta_graph_def.graph_def.library,
            load_shared_name_suffix=load_shared_name_suffix)
        # Replace existing functions in the MetaGraphDef with renamed functions so
        # we don't have duplicates or name collisions.
        meta_graph_def.graph_def.library.Clear()
        for function in functions.values():
            meta_graph_def.graph_def.library.function.add().CopyFrom(
                function.function_def)
        # We've renamed functions and shared names. We need the same operation on
        # the GraphDef itself for consistency.
        for node_def in meta_graph_def.graph_def.node:
            function_deserialization.fix_node_def(
                node_def,
                functions,
                load_shared_name_suffix,
                debug_name="MetaGraph import")

        # `load_graph` fills in the saver through this one-element list,
        # used here as an out-parameter.
        load_graph_returns = [None]
        wrapped = wrap_function.wrap_function(functools.partial(
            self.load_graph, load_graph_returns, meta_graph_def),
                                              signature=[])
        saver, = load_graph_returns
        self.restore_variables(wrapped, saver)
        with wrapped.graph.as_default():
            # Fall back to the default local-init op when the MetaGraph
            # defines no explicit init op.
            init_op = loader_impl.get_init_op(
                meta_graph_def
            ) or monitored_session.Scaffold.default_local_init_op()
            # Add a dummy Tensor we know we can fetch to add control dependencies to.
            init_anchor = constant_op.constant(0., name="dummy_fetch")

        root = tracking.AutoTrackable()
        # Re-bind each asset placeholder tensor to a TrackableAsset so the
        # asset files are tracked and fed back in when the initializer runs.
        asset_feed_tensors = []
        asset_paths = []
        for tensor_name, value in loader_impl.get_asset_tensors(
                self._export_dir, meta_graph_def).items():
            asset_feed_tensors.append(
                wrapped.graph.as_graph_element(tensor_name))
            asset_paths.append(tracking.TrackableAsset(value))
        init_fn = wrapped.prune(
            feeds=asset_feed_tensors,
            fetches=[init_anchor,
                     wrapped.graph.as_graph_element(init_op)])
        initializer = _Initializer(init_fn, asset_paths)
        # pylint: disable=protected-access
        local_init_op, _ = initializer._initialize()
        # pylint: enable=protected-access
        with ops.init_scope():
            if not context.executing_eagerly():
                # Graph mode: register the init op in TABLE_INITIALIZERS so
                # standard session setup runs it, and point each local
                # variable's initializer at it.
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      local_init_op)
                for variable in wrapped.graph.get_collection_ref(
                        ops.GraphKeys.LOCAL_VARIABLES):
                    # pylint: disable=protected-access
                    variable._initializer_op = local_init_op
                    # pylint: enable=protected-access
        root.initializer = initializer
        root.asset_paths = asset_paths
        signature_functions = self._extract_signatures(wrapped, meta_graph_def)

        root.signatures = signature_serialization.create_signature_map(
            signature_functions)
        root.variables = list(wrapped.graph.variables)
        root.tensorflow_version = (
            meta_graph_def.meta_info_def.tensorflow_version)
        root.tensorflow_git_version = (
            meta_graph_def.meta_info_def.tensorflow_git_version)
        root.graph = wrapped.graph
        root.prune = wrapped.prune
        return root