Example #1
 def testAssertConsumedWithUnusedPythonState(self):
     has_config = base.Checkpointable()
     has_config.get_config = lambda: {}
     saved = util.Checkpoint(obj=has_config)
     save_path = saved.save(os.path.join(self.get_temp_dir(), "ckpt"))
     restored = util.Checkpoint(obj=base.Checkpointable())
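     # Even with the extra Python-side state attached above, restoring into a plain
     # Checkpointable is expected to leave nothing in the checkpoint unconsumed.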
     restored.restore(save_path).assert_consumed()
Example #2
 def testAssertConsumedFailsWithUsedPythonState(self):
     has_config = base.Checkpointable()
     attributes = {
         "foo_attr":
         functools.partial(base.PythonStringStateSaveable,
                           state_callback=lambda: "",
                           restore_callback=lambda x: None)
     }
     has_config._gather_saveables_for_checkpoint = lambda: attributes
     saved = util.Checkpoint(obj=has_config)
     save_path = saved.save(os.path.join(self.get_temp_dir(), "ckpt"))
     restored = util.Checkpoint(obj=base.Checkpointable())
     status = restored.restore(save_path)
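     # "foo_attr" was saved as registered Python state, but the freshly created
     # Checkpointable never recreates it, so assert_consumed() should fail and name it.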
     with self.assertRaisesRegexp(AssertionError, "foo_attr"):
         status.assert_consumed()
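
For orientation, the round trip these two tests exercise maps onto the public tf.train.Checkpoint API. A minimal sketch, assuming TensorFlow 2.x eager execution; the directory name is made up for the example:

    import os
    import tensorflow as tf

    ckpt_dir = "/tmp/assert_consumed_demo"  # hypothetical scratch directory
    os.makedirs(ckpt_dir, exist_ok=True)

    saved = tf.train.Checkpoint(obj=tf.Variable(1.))
    save_path = saved.write(os.path.join(ckpt_dir, "ckpt"))  # write() skips the save counter

    restored = tf.train.Checkpoint(obj=tf.Variable(0.))
    status = restored.read(save_path)
    status.assert_consumed()  # raises AssertionError if any saved value went unmatched
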
Example #3
  def testNames(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    x1 = resource_variable_ops.ResourceVariable(2.)
    x2 = resource_variable_ops.ResourceVariable(3.)
    x3 = resource_variable_ops.ResourceVariable(4.)
    y = resource_variable_ops.ResourceVariable(5.)
    slots = containers.UniqueNameTracker()
    slots.track(x1, "x")
    slots.track(x2, "x")
    slots.track(x3, "x_1")
    slots.track(y, "y")
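    # UniqueNameTracker de-duplicates requested names: the second "x" becomes "x_1",
    # so the explicit "x_1" request becomes "x_1_1"; the restore below uses those names.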
    self.evaluate((x1.initializer, x2.initializer, x3.initializer,
                   y.initializer))
    save_root = checkpointable_utils.Checkpoint(slots=slots)
    save_path = save_root.save(checkpoint_prefix)

    restore_slots = checkpointable.Checkpointable()
    restore_root = checkpointable_utils.Checkpoint(
        slots=restore_slots)
    status = restore_root.restore(save_path)
    restore_slots.x = resource_variable_ops.ResourceVariable(0.)
    restore_slots.x_1 = resource_variable_ops.ResourceVariable(0.)
    restore_slots.x_1_1 = resource_variable_ops.ResourceVariable(0.)
    restore_slots.y = resource_variable_ops.ResourceVariable(0.)
    status.assert_consumed().run_restore_ops()
    self.assertEqual(2., self.evaluate(restore_slots.x))
    self.assertEqual(3., self.evaluate(restore_slots.x_1))
    self.assertEqual(4., self.evaluate(restore_slots.x_1_1))
    self.assertEqual(5., self.evaluate(restore_slots.y))
Example #4
 def testOverwrite(self):
     root = base.Checkpointable()
     leaf = base.Checkpointable()
     root._track_checkpointable(leaf, name="leaf")
     (current_name, current_dependency), = root._checkpoint_dependencies
     self.assertIs(leaf, current_dependency)
     self.assertEqual("leaf", current_name)
     duplicate_name_dep = base.Checkpointable()
     with self.assertRaises(ValueError):
         root._track_checkpointable(duplicate_name_dep, name="leaf")
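     # Re-tracking under the same name is only allowed with overwrite=True, which
     # replaces the existing "leaf" dependency.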
     root._track_checkpointable(duplicate_name_dep,
                                name="leaf",
                                overwrite=True)
     (current_name, current_dependency), = root._checkpoint_dependencies
     self.assertIs(duplicate_name_dep, current_dependency)
     self.assertEqual("leaf", current_name)
 def testManySavesGraph(self):
     """Saves after the first should not modify the graph."""
     with context.graph_mode():
         graph = ops.Graph()
         with graph.as_default(), self.test_session(graph):
             checkpoint_directory = self.get_temp_dir()
             checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
             obj = checkpointable.Checkpointable()
             obj.var = variable_scope.get_variable(name="v", initializer=0.)
             obj.opt = adam.AdamOptimizer(0.1)
             obj.opt.minimize(obj.var.read_value())
             self.evaluate(checkpointable_utils.gather_initializers(obj))
             saver = checkpointable_utils.CheckpointableSaver(obj)
             saver.save(checkpoint_prefix)
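             # A second save must not add new ops to the graph.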
             before_ops = graph.get_operations()
             saver.save(checkpoint_prefix)
             self.assertEqual(before_ops, graph.get_operations())
Example #6
 def testAddVariableOverwrite(self):
     root = base.Checkpointable()
     a = root._add_variable_with_custom_getter(
         name="v", shape=[], getter=variable_scope.get_variable)
     self.assertEqual([root, a], util.list_objects(root))
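     # Adding "v" again requires overwrite=True; in the second graph below,
     # overwrite=False raises because "v" is already tracked as a dependency.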
     with ops.Graph().as_default():
         b = root._add_variable_with_custom_getter(
             name="v",
             shape=[],
             overwrite=True,
             getter=variable_scope.get_variable)
         self.assertEqual([root, b], util.list_objects(root))
     with ops.Graph().as_default():
         with self.assertRaisesRegexp(ValueError,
                                      "already declared as a dependency"):
             root._add_variable_with_custom_getter(
                 name="v",
                 shape=[],
                 overwrite=False,
                 getter=variable_scope.get_variable)
    def testDeferredSlotRestoration(self):
        checkpoint_directory = self.get_temp_dir()

        root = checkpointable.Checkpointable()
        root.var = checkpointable_utils.add_variable(root,
                                                     name="var",
                                                     initializer=0.)
        optimizer = adam.AdamOptimizer(0.1)
        if context.executing_eagerly():
            optimizer.minimize(root.var.read_value)
        else:
            train_op = optimizer.minimize(root.var)
            # Note that `optimizer` has not been added as a dependency of
            # `root`. Create a one-off grouping so that slot variables for `root.var`
            # get initialized too.
            self.evaluate(
                checkpointable_utils.gather_initializers(
                    checkpointable_utils.Checkpoint(root=root,
                                                    optimizer=optimizer)))
            self.evaluate(train_op)
        self.evaluate(state_ops.assign(root.var, 12.))
        no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
            os.path.join(checkpoint_directory, "no_slots"))
        root.optimizer = optimizer
        self.evaluate(state_ops.assign(root.var, 13.))
        self.evaluate(
            state_ops.assign(optimizer.get_slot(name="m", var=root.var), 14.))
        slots_path = checkpointable_utils.CheckpointableSaver(root).save(
            os.path.join(checkpoint_directory, "with_slots"))
        new_root = checkpointable.Checkpointable()
        # Load the slot-containing checkpoint (deferred), then immediately overwrite
        # the non-slot variable (also deferred).
        slot_status = checkpointable_utils.CheckpointableSaver(
            new_root).restore(slots_path)
        no_slot_status = checkpointable_utils.CheckpointableSaver(
            new_root).restore(no_slots_path)
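        # new_root has no "var" dependency yet, so the no-slot restore cannot be
        # fully consumed until the variable is created below.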
        with self.assertRaises(AssertionError):
            no_slot_status.assert_consumed()
        new_root.var = checkpointable_utils.add_variable(new_root,
                                                         name="var",
                                                         shape=[])
        no_slot_status.assert_consumed()
        no_slot_status.run_restore_ops()
        self.assertEqual(12., self.evaluate(new_root.var))
        new_root.optimizer = adam.AdamOptimizer(0.1)
        with self.assertRaisesRegexp(AssertionError, "beta1_power"):
            slot_status.assert_consumed()
        self.assertEqual(12., self.evaluate(new_root.var))
        if context.executing_eagerly():
            # Slot variables are only created with restoring initializers when
            # executing eagerly.
            self.assertEqual(
                14.,
                self.evaluate(
                    new_root.optimizer.get_slot(name="m", var=new_root.var)))
        else:
            self.assertIs(
                new_root.optimizer.get_slot(name="m", var=new_root.var), None)
        if context.executing_eagerly():
            new_root.optimizer.minimize(new_root.var.read_value)
        else:
            train_op = new_root.optimizer.minimize(new_root.var)
            # The slot variable now exists; restore() didn't create it, but we should
            # now have a restore op for it.
            slot_status.run_restore_ops()
            self.assertEqual(
                14.,
                self.evaluate(
                    new_root.optimizer.get_slot(name="m", var=new_root.var)))
            self.evaluate(train_op)
        slot_status.assert_consumed()
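
The deferred restoration exercised above is also visible through the public tf.train.Checkpoint API: a value read before its dependency exists is queued and assigned as soon as a matching object is attached. A minimal sketch under the same assumptions as before (TensorFlow 2.x eager execution, made-up path):

    import os
    import tensorflow as tf

    ckpt_dir = "/tmp/deferred_restore_demo"  # hypothetical scratch directory
    os.makedirs(ckpt_dir, exist_ok=True)

    src = tf.train.Checkpoint(v=tf.Variable(12.))
    path = src.write(os.path.join(ckpt_dir, "ckpt"))

    dst = tf.train.Checkpoint()
    status = dst.read(path)   # "v" does not exist yet, so its restore is deferred
    dst.v = tf.Variable(0.)   # attaching the dependency triggers the queued restore
    status.assert_consumed()
    assert dst.v.numpy() == 12.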