def testDeferredSlotRestoration(self):
  checkpoint_directory = self.get_temp_dir()
  root = trackable_utils.Checkpoint()
  root.var = trackable_utils.add_variable(
      root, name="var", initializer=0.)
  optimizer = adam.AdamOptimizer(0.1)
  if context.executing_eagerly():
    optimizer.minimize(root.var.read_value)
  else:
    train_op = optimizer.minimize(root.var)
    # Note that `optimizer` has not been added as a dependency of
    # `root`. Create a one-off grouping so that slot variables for `root.var`
    # get initialized too.
    self.evaluate(trackable_utils.gather_initializers(
        trackable_utils.Checkpoint(root=root, optimizer=optimizer)))
    self.evaluate(train_op)
  self.evaluate(state_ops.assign(root.var, 12.))
  no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots"))
  root.optimizer = optimizer
  self.evaluate(state_ops.assign(root.var, 13.))
  self.evaluate(state_ops.assign(
      optimizer.get_slot(name="m", var=root.var), 14.))
  slots_path = root.save(os.path.join(checkpoint_directory, "with_slots"))
  new_root = trackable_utils.Checkpoint()
  # Load the slot-containing checkpoint (deferred), then immediately overwrite
  # the non-slot variable (also deferred).
  slot_status = new_root.restore(slots_path)
  no_slot_status = new_root.restore(no_slots_path)
  with self.assertRaises(AssertionError):
    no_slot_status.assert_consumed()
  new_root.var = trackable_utils.add_variable(
      new_root, name="var", shape=[])
  no_slot_status.assert_consumed()
  no_slot_status.run_restore_ops()
  self.assertEqual(12., self.evaluate(new_root.var))
  new_root.optimizer = adam.AdamOptimizer(0.1)
  slot_status.assert_existing_objects_matched()
  with self.assertRaisesRegex(AssertionError, "beta1_power"):
    slot_status.assert_consumed()
  self.assertEqual(12., self.evaluate(new_root.var))
  if context.executing_eagerly():
    # Slot variables are only created with restoring initializers when
    # executing eagerly.
    self.assertEqual(14., self.evaluate(
        new_root.optimizer.get_slot(name="m", var=new_root.var)))
  else:
    self.assertIs(
        new_root.optimizer.get_slot(name="m", var=new_root.var), None)
  if context.executing_eagerly():
    new_root.optimizer.minimize(new_root.var.read_value)
  else:
    train_op = new_root.optimizer.minimize(new_root.var)
    # The slot variable now exists; restore() didn't create it, but we should
    # now have a restore op for it.
    slot_status.run_restore_ops()
    self.assertEqual(14., self.evaluate(
        new_root.optimizer.get_slot(name="m", var=new_root.var)))
    self.evaluate(train_op)
  slot_status.assert_consumed()
def __init__(self):
  super().__init__()
  self.a_variable = trackable_utils.add_variable(
      self, name="a_variable", shape=[])
def testDeferredSlotRestoration(self):
  with self.test_session():
    checkpoint_directory = self.get_temp_dir()
    root = tf.train.Checkpoint()
    root.var = trackable_utils.add_variable(
        root, name="var", initializer=0.0)
    optimizer = adam.Adam(0.1)
    variables = [root.var]
    gradients = [1.0]
    train_op = optimizer.apply_gradients(zip(gradients, variables))
    # Note that `optimizer` has not been added as a dependency of
    # `root`. Create a one-off grouping so that slot variables for
    # `root.var` get initialized too.
    self.evaluate(
        trackable_utils.gather_initializers(
            tf.train.Checkpoint(root=root, optimizer=optimizer)))
    self.evaluate(train_op)
    self.evaluate(tf.compat.v1.assign(root.var, 12.0))
    no_slots_path = root.save(
        os.path.join(checkpoint_directory, "no_slots"))
    root.optimizer = optimizer
    self.evaluate(tf.compat.v1.assign(root.var, 13.0))
    self.evaluate(
        tf.compat.v1.assign(
            optimizer.get_slot(slot_name="m", var=root.var), 14.0))
    slots_path = root.save(
        os.path.join(checkpoint_directory, "with_slots"))
    new_root = tf.train.Checkpoint()
    # Load the slot-containing checkpoint (deferred), then immediately
    # overwrite the non-slot variable (also deferred).
    slot_status = new_root.restore(slots_path)
    no_slot_status = new_root.restore(no_slots_path)
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    new_root.var = trackable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    self.assertEqual(12.0, self.evaluate(new_root.var))
    new_root.optimizer = adam.Adam(0.1)
    slot_status.assert_existing_objects_matched()
    if not tf.executing_eagerly():
      with self.assertRaisesRegex(AssertionError, "Unresolved object"):
        slot_status.assert_consumed()
    self.assertEqual(12.0, self.evaluate(new_root.var))
    if tf.executing_eagerly():
      # Slot variables are only created with restoring initializers
      # when executing eagerly.
      self.assertEqual(
          14.0,
          self.evaluate(
              new_root.optimizer.get_slot(slot_name="m", var=new_root.var)))
    else:
      # Slot variables are not created eagerly when graph building.
      with self.assertRaises(KeyError):
        new_root.optimizer.get_slot(slot_name="m", var=new_root.var)
    variables = [new_root.var]
    gradients = [1.0]
    train_op = new_root.optimizer.apply_gradients(
        zip(gradients, variables))
    # The slot variable now exists; restore() didn't create it, but we
    # should now have a restore op for it.
    slot_status.run_restore_ops()
    if not tf.executing_eagerly():
      # The train op hasn't run when graph building, so the slot
      # variable has its restored value. It has run in eager, so the
      # value will be different.
      self.assertEqual(
          14.0,
          self.evaluate(
              new_root.optimizer.get_slot(slot_name="m", var=new_root.var)))
    self.evaluate(train_op)
    slot_status.assert_consumed()
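# A minimal standalone sketch (not part of the tests above) of the deferred
# restoration behavior these tests exercise, written against the public
# `tf.train.Checkpoint` API in eager mode. The function name and the temporary
# checkpoint location below are illustrative, not taken from the test file.
def _deferred_restore_sketch():
  import os
  import tempfile

  import tensorflow as tf

  root = tf.train.Checkpoint()
  root.var = tf.Variable(12.0)
  save_path = root.save(os.path.join(tempfile.mkdtemp(), "ckpt"))

  # Restore into a fresh Checkpoint before a matching `var` dependency exists;
  # the restoration stays deferred until the attribute is attached.
  new_root = tf.train.Checkpoint()
  status = new_root.restore(save_path)
  new_root.var = tf.Variable(0.0)  # Overwritten with the saved 12.0 on attach.
  status.assert_existing_objects_matched()
  assert new_root.var.numpy() == 12.0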