checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    with context.graph_mode():
      save_graph = ops.Graph()
      with save_graph.as_default(), self.test_session(
          graph=save_graph) as session:
        root = self._initialized_model()
        save_path = root.save(
            session=session, file_prefix=checkpoint_prefix)
    with context.eager_mode():
      root = self._initialized_model()
      self._set_sentinels(root)
      root.restore(save_path).assert_consumed()
      self._check_sentinels(root)

  def testSaveEagerLoadGraph(self):
    """Restores, in graph mode, a checkpoint that was written eagerly.

    The model is saved with eager execution enabled. It is then rebuilt in a
    freshly created graph and its values are overwritten via `_set_sentinels`.
    Finally it is restored from the saved path, and `_check_sentinels` is
    called on it.
    """
    prefix = os.path.join(self.get_temp_dir(), "ckpt")

    with context.eager_mode():
      # Build and save the model under eager execution.
      model = self._initialized_model()
      save_path = model.save(file_prefix=prefix)

    with context.graph_mode():
      # This graph is only used to load the checkpoint back in.
      load_graph = ops.Graph()
      with load_graph.as_default(), self.test_session(graph=load_graph):
        model = self._initialized_model()
        self._set_sentinels(model)
        status = model.restore(save_path)
        # In graph mode the restore ops must be run explicitly.
        status.assert_consumed().run_restore_ops()
        self._check_sentinels(model)

if __name__ == "__main__":
  # Script entry point: delegate to the imported `test` module's main().
  test.main()
Example #2
0
            def _restore_from_tensors(self, restored_tensors):
                self.a.assign(restored_tensors["-a"])
                self.b.assign(restored_tensors["-b"])

        new = NewTrackable()

        # Test with the checkpoint conversion flag disabled (normal compatibility).
        saveable_compat.force_checkpoint_conversion(False)
        checkpoint_path = os.path.join(self.get_temp_dir(), "ckpt")
        checkpoint.Checkpoint(new).write(checkpoint_path)

        dep = DeprecatedTrackable()
        checkpoint.Checkpoint(dep).read(checkpoint_path).assert_consumed()
        self.assertEqual(3, self.evaluate(dep.a))
        self.assertEqual(4, self.evaluate(dep.b))

        # Now test with the checkpoint conversion flag enabled (forward compat).
        # The deprecated object will try to load from the new checkpoint.
        saveable_compat.force_checkpoint_conversion()
        checkpoint_path = os.path.join(self.get_temp_dir(), "ckpt2")
        checkpoint.Checkpoint(new).write(checkpoint_path)

        dep = DeprecatedTrackable()
        checkpoint.Checkpoint(dep).read(checkpoint_path).assert_consumed()
        self.assertEqual(3, self.evaluate(dep.a))
        self.assertEqual(4, self.evaluate(dep.b))


if __name__ == "__main__":
    # Script entry point: delegate to the imported `test` module's main().
    test.main()