Example #1
0
  def test_registered_saver_is_called_before_save_after_load(self):
    """Checks the hooks of a registered checkpoint saver during a write.

    The registered `save_fn` asserts that the checkpoint directory is still
    empty at the moment it runs. `restore_fn` asserts the root variable's
    value.

    NOTE(review): the checkpoint is only written, never read back, so
    `restore_fn` is not invoked by this test as written.
    """
    if not context.executing_eagerly():
      # save_fn calls `file_prefix.numpy()`, which presumably needs an eager
      # tensor.
      self.skipTest("This test must run under eager mode.")

    class RestoreClass(autotrackable.AutoTrackable):
      pass

    def save_fn(trackables, file_prefix):
      del trackables  # Unused.
      # Nothing else should have been written to the checkpoint directory
      # yet when the registered saver runs.
      files = gfile.ListDirectory(os.path.dirname(file_prefix.numpy()))
      self.assertEmpty(files)

    def restore_fn(trackables, merged_prefix):
      del merged_prefix  # Unused.
      # `.values()` returns a view, which is iterable but not an iterator.
      # Calling `next()` on it directly raises TypeError, so wrap it in iter().
      root = next(iter(trackables.values()))
      self.assertEqual(root.v.numpy(), 123)

    registration.register_checkpoint_saver(
        name="OptionalRestore",
        predicate=lambda x: isinstance(x, RestoreClass),
        save_fn=save_fn,
        restore_fn=restore_fn)

    root = RestoreClass()
    root.v = variables.Variable(123.0)

    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(root).write(ckpt_path)
    def test_registration(self):
        """Registers a predicate-based saver and exercises its lookup helpers.

        Covers saver-name resolution, retrieval of the save/restore
        functions, and `validate_restore_function` for a matching object, an
        unknown saver name, and an object that fails the predicate.
        """
        registration.register_checkpoint_saver(
            package="Testing",
            name="test_predicate",
            predicate=lambda obj: hasattr(obj, "check_attr"),
            save_fn=lambda: "save",
            restore_fn=lambda: "restore")
        full_name = "Testing.test_predicate"

        marked = base.Trackable()
        # The predicate only matches once `check_attr` exists on the object.
        self.assertIsNone(registration.get_registered_saver_name(marked))

        marked.check_attr = 1
        resolved_name = registration.get_registered_saver_name(marked)
        self.assertEqual(resolved_name, full_name)

        saver_save = registration.get_save_function(resolved_name)
        self.assertEqual(saver_save(), "save")
        saver_restore = registration.get_restore_function(resolved_name)
        self.assertEqual(saver_restore(), "restore")

        # A matching object/name pair validates without raising.
        registration.validate_restore_function(marked, full_name)
        with self.assertRaisesRegex(ValueError, "saver cannot be found"):
            registration.validate_restore_function(marked, "Invalid.name")

        # An object without `check_attr` does not satisfy the predicate.
        unmarked = base.Trackable()
        with self.assertRaisesRegex(ValueError, "saver cannot be used"):
            registration.validate_restore_function(unmarked, full_name)
Example #3
0
  def test_strict_predicate(self):
    """A strict-predicate saver refuses to restore into a non-matching root."""

    class StrictPredicateClass(autotrackable.AutoTrackable):
      pass

    registration.register_checkpoint_saver(
        name="StrictPredicate",
        predicate=lambda obj: isinstance(obj, StrictPredicateClass),
        save_fn=lambda **kwargs: [],
        restore_fn=lambda **kwargs: None,
        strict_predicate_restore=True)

    saved_root = StrictPredicateClass()
    checkpoint_file = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(saved_root).write(checkpoint_file)

    # A plain AutoTrackable fails the predicate; with strict restore enabled,
    # reading the checkpoint into it is expected to raise.
    mismatched_root = autotrackable.AutoTrackable()
    with self.assertRaisesRegex(ValueError, "saver cannot be used"):
      util.Checkpoint(mismatched_root).read(checkpoint_file)
Example #4
0
  def test_non_strict_predicate(self):
    """A non-strict-predicate saver tolerates restoring into any root."""

    class NonStrictPredicateClass(autotrackable.AutoTrackable):
      pass

    registration.register_checkpoint_saver(
        name="NonStrictPredicate",
        predicate=lambda obj: isinstance(obj, NonStrictPredicateClass),
        save_fn=lambda **kwargs: [],
        restore_fn=lambda **kwargs: None,
        strict_predicate_restore=False)

    saved_root = NonStrictPredicateClass()
    checkpoint_file = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(saved_root).write(checkpoint_file)

    # The new root fails the predicate, but strict restore is disabled, so
    # the read is expected to succeed without raising.
    other_root = autotrackable.AutoTrackable()
    util.Checkpoint(other_root).read(checkpoint_file)
Example #5
0
  # NOTE(review): the enclosing `def` line is not part of this excerpt;
  # presumably this is the body of `restore_stacks_and_parts(trackables,
  # merged_prefix)`, which is registered as the "stacks" restore_fn.
  #
  # Collect the checkpoint keys, slice specs, template tensors (used here only
  # for their dtypes) and the trackables they belong to.
  tensor_names, shapes_and_slices, tensors, restored_trackables = (
      get_tensor_slices(trackables))
  dtypes = [t.dtype for t in tensors]
  # Read every requested tensor from the checkpoint in one restore_v2 call.
  restored_tensors = io_ops.restore_v2(merged_prefix, tensor_names,
                                       shapes_and_slices, dtypes)
  for trackable, restored_tensor in zip(restored_trackables, restored_tensors):
    # Reshape the restored value to the shape of `trackable.value()`, then
    # unstack it and assign one slice to each entry of `trackable.parts`.
    # NOTE(review): `zip` silently truncates if the number of unstacked slices
    # differs from `len(trackable.parts)`; confirm they always match.
    expected_shape = trackable.value().get_shape()
    restored_tensor = array_ops.reshape(restored_tensor, expected_shape)
    parts = array_ops.unstack(restored_tensor)
    for part, restored_part in zip(trackable.parts, parts):
      part.assign(restored_part)


# Route any Stack or Part object to the custom "stacks" checkpoint saver,
# whose save/restore callbacks are save_stacks_and_parts and
# restore_stacks_and_parts.
registration.register_checkpoint_saver(
    name="stacks",
    predicate=lambda candidate: isinstance(candidate, (Stack, Part)),
    save_fn=save_stacks_and_parts,
    restore_fn=restore_stacks_and_parts,
)


def cycle(obj, cycles, signatures=None, options=None):
  """Round-trips `obj` through SavedModel save/load `cycles` times.

  Each iteration saves the current object into a fresh temporary directory
  and loads it back. The loaded object and its signatures become the input
  to the next iteration.

  Args:
    obj: The object saved by the first cycle.
    cycles: Number of save/load round trips to perform.
    signatures: Signatures passed to the first `save.save` call; later
      cycles reuse the signatures of the previously loaded object.
    options: Save options forwarded to every `save.save` call.

  Returns:
    The object loaded by the final cycle, or `obj` unchanged when `cycles`
    is 0.
  """
  to_save = obj
  for _ in range(cycles):
    path = tempfile.mkdtemp(prefix=test.get_temp_dir())
    # If available, we'll run the save and restore preferring the GPU. This
    # just makes sure we aren't throwing errors and have enough
    # device("CPU") blocks to satisfy the placer.
    with test_util.use_gpu():
      save.save(to_save, path, signatures, options=options)
      loaded = load.load(path)
      signatures = loaded.signatures
    # Chain the round trips: the next cycle saves what was just loaded.
    to_save = loaded
  # Returning `to_save` (not `loaded`) avoids an unbound name when cycles == 0.
  return to_save
 def test_invalid_registration(self):
     """Checks that malformed checkpoint saver registrations are rejected.

     Each case must raise: a non-string package or name (TypeError), a name
     or a package containing "/" (ValueError), and a non-callable predicate,
     save_fn or restore_fn (TypeError).
     """
     # Non-string package.
     with self.assertRaisesRegex(TypeError, "must be string"):
         registration.register_checkpoint_saver(package=None,
                                                name="test",
                                                predicate=lambda: None,
                                                save_fn=lambda: None,
                                                restore_fn=lambda: None)
     # Non-string name.
     with self.assertRaisesRegex(TypeError, "must be string"):
         registration.register_checkpoint_saver(name=None,
                                                predicate=lambda: None,
                                                save_fn=lambda: None,
                                                restore_fn=lambda: None)
     # Malformed name with a valid package.
     with self.assertRaisesRegex(ValueError,
                                 "Invalid registered checkpoint saver."):
         registration.register_checkpoint_saver(package="package",
                                                name="t/est",
                                                predicate=lambda: None,
                                                save_fn=lambda: None,
                                                restore_fn=lambda: None)
     # Malformed package with a valid name; this must be rejected just like
     # a malformed name.
     with self.assertRaisesRegex(ValueError,
                                 "Invalid registered checkpoint saver."):
         registration.register_checkpoint_saver(package="t/est",
                                                name="test",
                                                predicate=lambda: None,
                                                save_fn=lambda: None,
                                                restore_fn=lambda: None)
     # Non-callable predicate.
     with self.assertRaisesRegex(
             TypeError,
             "The predicate registered to a checkpoint saver must be callable"
     ):
         registration.register_checkpoint_saver(name="test",
                                                predicate=None,
                                                save_fn=lambda: None,
                                                restore_fn=lambda: None)
     # Non-callable save_fn.
     with self.assertRaisesRegex(TypeError, "The save_fn must be callable"):
         registration.register_checkpoint_saver(name="test",
                                                predicate=lambda: None,
                                                save_fn=None,
                                                restore_fn=lambda: None)
     # Non-callable restore_fn.
     with self.assertRaisesRegex(TypeError,
                                 "The restore_fn must be callable"):
         registration.register_checkpoint_saver(name="test",
                                                predicate=lambda: None,
                                                save_fn=lambda: None,
                                                restore_fn=None)
Example #7
0
from tensorflow.python.checkpoint import graph_view
from tensorflow.python.checkpoint import save_util_v1
from tensorflow.python.eager import test
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util import object_identity


class TrackableWithRegisteredSaver(autotrackable.AutoTrackable):
  """Marker trackable matched by the "RegisteredSaver" saver predicate."""


# Register a no-op checkpoint saver (save returns [], restore returns None)
# for TrackableWithRegisteredSaver instances, so that serialization tests can
# observe how objects handled by a registered saver are gathered.
registration.register_checkpoint_saver(
    name="RegisteredSaver",
    predicate=lambda obj: isinstance(obj, TrackableWithRegisteredSaver),
    save_fn=lambda trackables, file_prefix: [],
    restore_fn=lambda trackables, merged_prefix: None,
)


class SerializationTest(test.TestCase):
  """Tests for `save_util_v1.serialize_gathered_objects`."""

  def test_serialize_gathered_objects(self):
    """Variables and registered-saver objects are reported separately."""
    root = autotrackable.AutoTrackable()
    root.v = variables.Variable(1.0)
    root.registered = TrackableWithRegisteredSaver()
    named_saveable_objects, _, _, registered_savers = (
        save_util_v1.serialize_gathered_objects(
            graph_view.ObjectGraphView(root)))

    # Only the variable yields a regular saveable object.
    self.assertLen(named_saveable_objects, 1)
    # The single object matched by the "RegisteredSaver" predicate is expected
    # to surface through exactly one registered saver entry.
    self.assertLen(registered_savers, 1)
Example #8
0
    def test_migration_backwards_compatibility(self):
        """A migrated (registered-saver) object can read an old checkpoint.

        `NoRegisteredSaver` writes its name through `_serialize_to_tensors`;
        `RegisteredSaver` is handled by a registered checkpoint saver. Reading
        the pre-migration checkpoint into the migrated object must still
        restore the saved name.
        """

        class NoRegisteredSaver(autotrackable.AutoTrackable):
            def __init__(self, name):
                self.name = name

            def _serialize_to_tensors(self):
                return {"name": constant_op.constant(self.name)}

        class RegisteredSaver(autotrackable.AutoTrackable):
            def __init__(self, name):
                self.name = name

        def _get_tensors(trackables, append_name=True):
            # Build the parallel argument lists for save_v2/restore_v2: one
            # entry per (prefix, object), keyed as `<prefix>name` or, when
            # append_name is False, as the bare prefix.
            keys, slice_specs, values, owners = [], [], [], []
            for prefix, owner in trackables.items():
                keys.append((prefix + "name") if append_name else prefix)
                slice_specs.append("")
                values.append(constant_op.constant(owner.name))
                owners.append(owner)
            return keys, slice_specs, values, owners

        def save_fn(trackables, file_prefix):
            keys, slice_specs, values, _ = _get_tensors(trackables)
            io_ops.save_v2(file_prefix, keys, slice_specs, values)
            return file_prefix

        def restore_fn(trackables, merged_prefix):
            keys, slice_specs, values, owners = _get_tensors(trackables)
            value_dtypes = [value.dtype for value in values]
            try:
                loaded_values = io_ops.restore_v2(merged_prefix, keys,
                                                  slice_specs, value_dtypes)
            except errors_impl.NotFoundError:
                # A NotFoundError means the checkpoint predates the saver
                # registration migration, so retry with the un-suffixed keys.
                keys, slice_specs, values, owners = _get_tensors(
                    trackables, append_name=False)
                loaded_values = io_ops.restore_v2(merged_prefix, keys,
                                                  slice_specs, value_dtypes)
            for owner, name_tensor in zip(owners, loaded_values):
                owner.name = name_tensor

        registration.register_checkpoint_saver(
            name="MigratedSaver",
            predicate=lambda obj: isinstance(obj, RegisteredSaver),
            save_fn=save_fn,
            restore_fn=restore_fn,
        )

        old_style = NoRegisteredSaver("before")
        migrated = RegisteredSaver("after")

        # Checkpoint written before the migration (no registered saver).
        old_ckpt_path = os.path.join(self.get_temp_dir(), "before_ckpt")
        util.Checkpoint(old_style).write(old_ckpt_path)

        # Checkpoint written after the migration (through the registered saver).
        migrated_ckpt = util.Checkpoint(migrated)
        migrated_ckpt_path = os.path.join(self.get_temp_dir(), "after_ckpt")
        migrated_ckpt.write(migrated_ckpt_path)

        # The migrated object must be able to load the pre-migration checkpoint.
        migrated_ckpt.read(old_ckpt_path)
        self.assertEqual(b"before", self.evaluate(migrated.name))