Exemplo n.º 1
0
 def test_dedup_assets(self):
   """Two assets built from the same file import with an identical path."""
   shared_file = self._make_asset("contents")
   root = tracking.Checkpointable()
   for attr_name in ("asset1", "asset2"):
     setattr(root, attr_name, tracking.TrackableAsset(shared_file))
   loaded = self.cycle(root)
   first_path = loaded.asset1.asset_path.numpy()
   second_path = loaded.asset2.asset_path.numpy()
   self.assertEqual(first_path, second_path)
Exemplo n.º 2
0
    def test_assets_dedup(self):
        """Assets pointing at the same file share a path after `self.cycle`.

        The root also holds a `def_function.function` with an input signature
        alongside the two duplicate assets.
        """
        vocab_file = self._make_asset("contents")
        root = tracking.Checkpointable()
        float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
        root.f = def_function.function(lambda x: 2. * x,
                                       input_signature=[float_spec])

        root.asset1 = tracking.TrackableAsset(vocab_file)
        root.asset2 = tracking.TrackableAsset(vocab_file)

        round_tripped = self.cycle(root)

        left = round_tripped.asset1.asset_path.numpy()
        right = round_tripped.asset2.asset_path.numpy()
        self.assertEqual(left, right)
Exemplo n.º 3
0
 def load(self, tags):
   """Creates an object from the MetaGraph identified by `tags`.

   The MetaGraph is imported inside a wrapped function and its variables are
   restored. If the MetaGraph has an init op, that op is run with every asset
   tensor fed the path of a freshly created `tracking.TrackableAsset`.

   Args:
     tags: Tags selecting the MetaGraph via `get_meta_graph_def_from_tags`.

   Returns:
     A `tracking.AutoCheckpointable`. Its `signatures` attribute holds the
     MetaGraph's signature functions, and its `variables` attribute lists the
     wrapped graph's variables.
   """
   meta_graph = self.get_meta_graph_def_from_tags(tags)
   # `load_graph` stores the saver it creates in this one-element list.
   saver_box = [None]
   importer = functools.partial(self.load_graph, saver_box, meta_graph)
   wrapped = wrap_function.wrap_function(importer, signature=[])
   (saver,) = saver_box
   self.restore_variables(wrapped, saver)

   with wrapped.graph.as_default():
     init_op = loader_impl.get_init_op(meta_graph)

   if init_op is not None:
     feeds = []
     assets = []
     asset_tensors = loader_impl.get_asset_tensors(self._export_dir, meta_graph)
     for tensor_name, asset_file in asset_tensors.items():
       feeds.append(wrapped.graph.as_graph_element(tensor_name))
       assets.append(tracking.TrackableAsset(asset_file))
     initialize = wrapped.prune(
         feeds=feeds,
         fetches=[wrapped.graph.as_graph_element(init_op)])
     # Feed each asset tensor the path of its matching TrackableAsset.
     initialize(*[asset.asset_path for asset in assets])

   signatures = self._extract_signatures(wrapped, meta_graph)
   root = tracking.AutoCheckpointable()
   root.signatures = signature_serialization.create_signature_map(signatures)
   root.variables = list(wrapped.graph.variables)
   return root
Exemplo n.º 4
0
    def test_assets_dedup(self):
        """Duplicate assets share one path after `save.save` + `load.load`."""
        vocab = self._make_asset("contents")
        root = tracking.Checkpointable()
        float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
        root.f = def_function.function(lambda x: 2. * x,
                                       input_signature=[float_spec])

        root.asset1 = tracking.TrackableAsset(vocab)
        root.asset2 = tracking.TrackableAsset(vocab)

        save_path = os.path.join(self.get_temp_dir(), "save_dir")
        save.save(root, save_path)
        reloaded = load.load(save_path)

        paths = [reloaded.asset1.asset_path.numpy(),
                 reloaded.asset2.asset_path.numpy()]
        self.assertEqual(*paths)
Exemplo n.º 5
0
    def test_unused_asset(self):
        """An asset that `f` never reads does not disturb serving `f`."""
        root = tracking.AutoCheckpointable()
        spec = tensor_spec.TensorSpec(None, dtypes.float32)
        root.f = def_function.function(lambda x: 2. * x, input_signature=[spec])
        root.asset = tracking.TrackableAsset(self._vocab_path)

        export_dir = os.path.join(self.get_temp_dir(), "save_dir")
        save.save(root, export_dir)
        outputs = _import_and_infer(export_dir, {"x": [0.1]})
        self.assertAllClose({"output_0": [0.2]}, outputs)
Exemplo n.º 6
0
 def test_capture_assets(self):
     """A function capturing an asset path returns the relocated path.

     After `self.cycle`, calling the imported function must yield a path that
     differs from the original one. That path must still point at a file
     holding the asset's contents.
     """
     root = tracking.AutoCheckpointable()
     root.vocab = tracking.TrackableAsset(self._make_asset("contents"))
     root.f = def_function.function(lambda: root.vocab.asset_path,
                                    input_signature=[])
     imported = self.cycle(root)
     origin_output = root.f().numpy()
     imported_output = imported.f().numpy()
     self.assertNotEqual(origin_output, imported_output)
     with open(imported_output, "r") as f:
         self.assertEqual("contents", f.read())
Exemplo n.º 7
0
  def test_assets(self):
    """Assets stay readable after their sources are deleted and the model moves.

    The original asset files are deleted and the SavedModel directory is
    renamed before loading. The loaded assets must therefore resolve to
    copies stored with the SavedModel.
    """
    file1 = self._make_asset("contents 1")
    file2 = self._make_asset("contents 2")

    root = tracking.Checkpointable()
    root.asset1 = tracking.TrackableAsset(file1)
    root.asset2 = tracking.TrackableAsset(file2)

    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
    save.save(root, save_dir, signatures={})

    file_io.delete_file(file1)
    file_io.delete_file(file2)
    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
    file_io.rename(save_dir, load_dir)

    imported = load.load(load_dir)
    with open(imported.asset1.asset_path.numpy(), "r") as f:
      self.assertEqual("contents 1", f.read())
    with open(imported.asset2.asset_path.numpy(), "r") as f:
      self.assertEqual("contents 2", f.read())
Exemplo n.º 8
0
 def test_asset_path_returned(self):
     """A signature returning an asset path reflects where the model now lives.

     The SavedModel is renamed after saving. The path produced by serving the
     signature must contain the new directory.
     """
     root = tracking.AutoCheckpointable()
     root.path = tracking.TrackableAsset(self._vocab_path)
     root.get_asset = def_function.function(lambda: root.path.asset_path)
     original_dir = os.path.join(self.get_temp_dir(), "saved_model")
     concrete = root.get_asset.get_concrete_function()
     save.save(root, original_dir, signatures=concrete)
     moved_dir = os.path.join(self.get_temp_dir(), "second_dir")
     file_io.rename(original_dir, moved_dir)
     served_path = _import_and_infer(moved_dir, {})["output_0"]
     self.assertIn(compat.as_str_any(moved_dir),
                   compat.as_str_any(served_path))
Exemplo n.º 9
0
 def test_capture_assets(self):
   """A function capturing an asset path returns the relocated path on load.

   The model is saved with `save.save` and reloaded with `load.load`. The
   loaded function must return a path that differs from the original one and
   that still points at a file holding the asset's contents.
   """
   root = tracking.Checkpointable()
   root.vocab = tracking.TrackableAsset(self._make_asset("contents"))
   root.f = def_function.function(
       lambda: root.vocab.asset_path,
       input_signature=[])
   save_dir = os.path.join(self.get_temp_dir(), "save_dir")
   save.save(root, save_dir)
   imported = load.load(save_dir)
   origin_output = root.f().numpy()
   imported_output = imported.f().numpy()
   self.assertNotEqual(origin_output, imported_output)
   with open(imported_output, "r") as f:
     self.assertEqual("contents", f.read())
Exemplo n.º 10
0
    def test_assets_import(self):
        """Assets stay readable after their sources are deleted and the model moves.

        The root also holds a `def_function.function` with an input
        signature. The original asset files are deleted and the SavedModel
        directory is renamed before loading.
        """
        file1 = self._make_asset("contents 1")
        file2 = self._make_asset("contents 2")

        root = tracking.Checkpointable()
        root.f = def_function.function(
            lambda x: 2. * x,
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
        root.asset1 = tracking.TrackableAsset(file1)
        root.asset2 = tracking.TrackableAsset(file2)

        save_dir = os.path.join(self.get_temp_dir(), "save_dir")
        save.save(root, save_dir)

        file_io.delete_file(file1)
        file_io.delete_file(file2)
        load_dir = os.path.join(self.get_temp_dir(), "load_dir")
        file_io.rename(save_dir, load_dir)

        imported = load.load(load_dir)
        with open(imported.asset1.asset_path.numpy(), "r") as f:
            self.assertEqual("contents 1", f.read())
        with open(imported.asset2.asset_path.numpy(), "r") as f:
            self.assertEqual("contents 2", f.read())
Exemplo n.º 11
0
    def test_capture_assets_in_graph(self, cycles):
        """Captured asset paths resolve when the model is loaded into a graph.

        The model may first be round-tripped `cycles - 1` times through
        `self.cycle`. It is then saved once more, loaded inside a fresh
        `ops.Graph`, and executed with a `MonitoredSession`.
        """
        root = tracking.AutoCheckpointable()
        root.vocab = tracking.TrackableAsset(self._make_asset("contents"))
        # The lambda reads `root` from this scope, so it follows any later
        # rebinding of `root` below.
        root.f = def_function.function(lambda: root.vocab.asset_path,
                                       input_signature=[])
        original_output = root.f().numpy()

        remaining_cycles = cycles - 1
        if remaining_cycles > 0:
            root = self.cycle(root, remaining_cycles)
        export_path = tempfile.mkdtemp(prefix=self.get_temp_dir())
        save.save(root, export_path)

        with ops.Graph().as_default():
            loaded = load.load(export_path)
            loaded_tensor = loaded.f()
            with monitored_session.MonitoredSession() as session:
                loaded_output = session.run(loaded_tensor)
                self.assertNotEqual(original_output, loaded_output)
                with open(loaded_output, "r") as asset_file:
                    self.assertEqual("contents", asset_file.read())
Exemplo n.º 12
0
 def _recreate_asset(self, proto):
     """Rebuilds a `tracking.TrackableAsset` from its saved proto.

     Returns:
       A pair `(asset, setattr)`. The asset points at the file named by
       `self._asset_file_def[proto.asset_file_def_index]` inside the export
       directory's assets folder. The builtin `setattr` is returned as the
       setter (presumably used by the caller to attach the asset to its
       parent; confirm against the caller).
     """
     assets_dir = saved_model_utils.get_assets_dir(self._export_dir)
     file_def = self._asset_file_def[proto.asset_file_def_index]
     asset = tracking.TrackableAsset(
         os.path.join(assets_dir, file_def.filename))
     return asset, setattr
Exemplo n.º 13
0
  def __init__(self,
               filename,
               key_dtype,
               key_index,
               value_dtype,
               value_index,
               vocab_size=None,
               delimiter="\t",
               name=None):
    """Constructs a table initializer object to populate from a text file.

    It generates one key-value pair per line. The types of the table key and
    value are specified by `key_dtype` and `value_dtype`, respectively.
    Similarly, the content of the key and value is specified by `key_index`
    and `value_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string (for keys, integer data types are accepted as well).
    - A value >= 0 means use the index (starting at zero) of the split line
      based on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization.
        The path must be accessible from wherever the graph is initialized
        (e.g. trainer or eval workers). The filename may be a scalar `Tensor`.
        It is wrapped in a `TrackableAsset` tracked under the name
        `"_filename"`.
      key_dtype: The `key` data type.
      key_index: The index that represents information of a line to get the
        table 'key' values from.
      value_dtype: The `value` data type.
      value_index: The index that represents information of a line to get the
        table 'value' values from.
      vocab_size: The number of elements in the file, if known. Must be
        positive when given.
      delimiter: The delimiter to separate fields in a line.
      name: A name for the operation (optional).

    Raises:
      ValueError: when the filename is empty, when `key_index` or
        `value_index` is below -2, when `vocab_size` is not positive, or when
        the table key and value data types do not match the expected data
        types.
    """
    # An empty Python filename is rejected; Tensor filenames are not checked.
    if not isinstance(filename, ops.Tensor) and not filename:
      raise ValueError("Filename required for %s." % name)

    key_dtype = dtypes.as_dtype(key_dtype)
    value_dtype = dtypes.as_dtype(value_dtype)

    # Indices -2 and -1 are reserved for the special TextFileIndex modes
    # (presumably WHOLE_LINE / LINE_NUMBER; confirm against TextFileIndex).
    if key_index < -2:
      raise ValueError("Invalid key index %s." % (key_index))

    if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64:
      raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." %
                       (dtypes.int64, key_dtype))
    if ((key_index == TextFileIndex.WHOLE_LINE) and
        (not key_dtype.is_integer) and (key_dtype != dtypes.string)):
      raise ValueError(
          "Signature mismatch. Keys must be integer or string, got %s." %
          key_dtype)
    if value_index < -2:
      raise ValueError("Invalid value index %s." % (value_index))

    if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64:
      raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
                       (dtypes.int64, value_dtype))
    if value_index == TextFileIndex.WHOLE_LINE and value_dtype != dtypes.string:
      raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
                       (dtypes.string, value_dtype))

    if (vocab_size is not None) and (vocab_size <= 0):
      raise ValueError("Invalid vocab_size %s." % vocab_size)

    self._key_index = key_index
    self._value_index = value_index
    self._vocab_size = vocab_size
    self._delimiter = delimiter
    self._name = name
    self._filename = self._track_checkpointable(
        checkpointable.TrackableAsset(filename),
        "_filename")

    super(TextFileInitializer, self).__init__(key_dtype, value_dtype)
Exemplo n.º 14
0
 def _recreate_asset(self, proto):
     """Rebuilds a `tracking.TrackableAsset` from its saved proto.

     The asset file is resolved by joining `proto.relative_filename` onto the
     export directory's assets folder.
     """
     assets_dir = saved_model_utils.get_assets_dir(self._export_dir)
     asset_path = os.path.join(assets_dir, proto.relative_filename)
     return tracking.TrackableAsset(asset_path)