def testDeferredSlotRestoration(self):
    """Restore requested before variables/slots exist is applied later.

    Two checkpoints are written from the same object: one saved before the
    optimizer is attached to it ("no_slots") and one saved afterwards
    ("with_slots"). Both are then restored into an empty object, and the
    test checks which values show up as the variable, the optimizer and
    its slot are created.
    """
    ckpt_dir = self.get_temp_dir()
    no_slots_prefix = os.path.join(ckpt_dir, "no_slots")
    with_slots_prefix = os.path.join(ckpt_dir, "with_slots")

    source = checkpointable.Checkpointable()
    source.var = checkpointable_utils.add_variable(
        source, name="var", initializer=0.)
    source_optimizer = CheckpointableAdam(0.1)
    if not context.in_graph_mode():
        source_optimizer.minimize(source.var.read_value)
    else:
        first_step = source_optimizer.minimize(source.var)
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(first_step)

    # Saved before `source_optimizer` is attached to `source`.
    self.evaluate(state_ops.assign(source.var, 12.))
    no_slots_path = checkpointable_utils.save(no_slots_prefix, source)

    # Saved after attaching the optimizer, with distinct values for the
    # variable (13) and its "m" slot (14).
    source.optimizer = source_optimizer
    self.evaluate(state_ops.assign(source.var, 13.))
    source_slot = source_optimizer.get_slot(name="m", var=source.var)
    self.evaluate(state_ops.assign(source_slot, 14.))
    slots_path = checkpointable_utils.save(with_slots_prefix, source)

    target = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately
    # overwrite the non-slot variable (also deferred).
    slot_status = checkpointable_utils.restore(slots_path, target)
    no_slot_status = checkpointable_utils.restore(no_slots_path, target)
    with self.assertRaises(AssertionError):
        no_slot_status.assert_consumed()
    target.var = checkpointable_utils.add_variable(
        target, name="var", shape=[])
    no_slot_status.assert_consumed()
    self.evaluate(no_slot_status.restore_ops)
    self.assertEqual(12., self.evaluate(target.var))

    target.optimizer = CheckpointableAdam(0.1)
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
        slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(target.var))

    pending_slot = target.optimizer.get_slot(name="m", var=target.var)
    if context.in_eager_mode():
        # Slot variables are only created with restoring initializers when
        # executing eagerly.
        self.assertEqual(14., self.evaluate(pending_slot))
    else:
        self.assertIs(pending_slot, None)

    if not context.in_graph_mode():
        target.optimizer.minimize(target.var.read_value)
    else:
        second_step = target.optimizer.minimize(target.var)
        # The slot variable now exists; restore() didn't create it, but we
        # should now have a restore op for it.
        self.evaluate(slot_status.restore_ops)
        self.assertEqual(
            14.,
            self.evaluate(
                target.optimizer.get_slot(name="m", var=target.var)))
        self.evaluate(second_step)
    slot_status.assert_consumed()
def testNumberedPath(self):
    """The key of a variable on a child object is prefixed by the child name."""
    parent = checkpointable.Checkpointable()
    child = checkpointable.Checkpointable()
    parent.leaf = child
    checkpointable_utils.add_variable(child, name="v", shape=[])
    serialized = checkpointable_utils._serialize_object_graph(parent)
    names_to_variables, _ = serialized
    # Unpacking into a single target also asserts there is exactly one key.
    (only_key,) = names_to_variables.keys()
    self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", only_key)
def _get_checkpoint_name(self, name):
    """Return the checkpoint key produced for one variable called `name`.

    Args:
      name: Local name given to the variable when it is added to a fresh
        `Checkpointable` object.

    Returns:
      The single key found in the serialized object graph.
    """
    holder = checkpointable.Checkpointable()
    checkpointable_utils.add_variable(
        holder, name=name, shape=[1, 2], dtype=dtypes.float64)
    names_to_variables, _ = checkpointable_utils._serialize_object_graph(
        holder)
    (key,) = names_to_variables.keys()
    # Make sure we can use this as an op name if we prefix it.
    with ops.name_scope("root/" + key):
        pass
    return key
def testDeferredSlotRestoration(self):
    """Restore requested before variables/slots exist is applied later.

    Uses `checkpointable_utils.Saver` to write one checkpoint before the
    optimizer is attached ("no_slots") and one afterwards ("with_slots"),
    then restores both into an empty object and checks the values seen as
    the variable, the optimizer and its slot get created.
    """
    ckpt_dir = self.get_temp_dir()
    no_slots_prefix = os.path.join(ckpt_dir, "no_slots")
    with_slots_prefix = os.path.join(ckpt_dir, "with_slots")

    source = checkpointable.Checkpointable()
    source.var = checkpointable_utils.add_variable(
        source, name="var", initializer=0.)
    source_optimizer = CheckpointableAdam(0.1)
    if not context.in_graph_mode():
        source_optimizer.minimize(source.var.read_value)
    else:
        first_step = source_optimizer.minimize(source.var)
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(first_step)

    # Saved before `source_optimizer` is attached to `source`.
    self.evaluate(state_ops.assign(source.var, 12.))
    no_slots_path = checkpointable_utils.Saver(source).save(no_slots_prefix)

    # Saved after attaching the optimizer, with distinct values for the
    # variable (13) and its "m" slot (14).
    source.optimizer = source_optimizer
    self.evaluate(state_ops.assign(source.var, 13.))
    source_slot = source_optimizer.get_slot(name="m", var=source.var)
    self.evaluate(state_ops.assign(source_slot, 14.))
    slots_path = checkpointable_utils.Saver(source).save(with_slots_prefix)

    target = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately
    # overwrite the non-slot variable (also deferred).
    slot_status = checkpointable_utils.Saver(target).restore(slots_path)
    no_slot_status = checkpointable_utils.Saver(target).restore(
        no_slots_path)
    with self.assertRaises(AssertionError):
        no_slot_status.assert_consumed()
    target.var = checkpointable_utils.add_variable(
        target, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    self.assertEqual(12., self.evaluate(target.var))

    target.optimizer = CheckpointableAdam(0.1)
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
        slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(target.var))

    pending_slot = target.optimizer.get_slot(name="m", var=target.var)
    if context.in_eager_mode():
        # Slot variables are only created with restoring initializers when
        # executing eagerly.
        self.assertEqual(14., self.evaluate(pending_slot))
    else:
        self.assertIs(pending_slot, None)

    if not context.in_graph_mode():
        target.optimizer.minimize(target.var.read_value)
    else:
        second_step = target.optimizer.minimize(target.var)
        # The slot variable now exists; restore() didn't create it, but we
        # should now have a restore op for it.
        slot_status.run_restore_ops()
        self.assertEqual(
            14.,
            self.evaluate(
                target.optimizer.get_slot(name="m", var=target.var)))
        self.evaluate(second_step)
    slot_status.assert_consumed()
def testLocalNameValidation(self):
    """A dependency named ".ATTRIBUTES" is escaped in the checkpoint key."""
    parent = checkpointable.Checkpointable()
    child = checkpointable.Checkpointable()
    # Dots are escaped, which avoids conflicts with reserved names.
    parent._track_checkpointable(child, name=".ATTRIBUTES")
    checkpointable_utils.add_variable(
        checkpointable=child, name="a", shape=[])
    serialized = checkpointable_utils._serialize_object_graph(parent)
    names_to_variables, _ = serialized
    # Unpacking into a single target also asserts there is exactly one key.
    (only_key,) = names_to_variables.keys()
    self.assertEqual(only_key, "..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE")
def testInitNotCalled(self):
    """add_variable() works on a subclass whose __init__ skips super()."""

    class NoInit(checkpointable.Checkpointable):

        def __init__(self):
            # Deliberately does not chain to the base class.
            pass

    # __init__ for Checkpointable will be called implicitly.
    uninitialized = NoInit()
    checkpointable_utils.add_variable(uninitialized, "var", shape=[])
def testOverlappingRestores(self):
    """Two deferred restores that end up targeting the same dependency.

    The checkpoints "first" and "second" hold 12 and 13 for the same
    variable. Each is restored into its own root, and one shared dependency
    object is then attached to both roots in turn.
    """
    ckpt_dir = self.get_temp_dir()
    first_prefix = os.path.join(ckpt_dir, "first")
    second_prefix = os.path.join(ckpt_dir, "second")

    saved = checkpointable.Checkpointable()
    saved.dep = checkpointable.Checkpointable()
    saved.dep.var = checkpointable_utils.add_variable(
        saved.dep, name="var", initializer=0.)
    self.evaluate(state_ops.assign(saved.dep.var, 12.))
    first_path = checkpointable_utils.save(first_prefix, saved)
    self.evaluate(state_ops.assign(saved.dep.var, 13.))
    second_path = checkpointable_utils.save(second_prefix, saved)

    def make_dependency():
        # A fresh object owning an uninitialized scalar named "var".
        dep = checkpointable.Checkpointable()
        dep.var = checkpointable_utils.add_variable(
            dep, name="var", shape=[])
        return dep

    # restore(first) is requested before restore(second).
    first_root = checkpointable.Checkpointable()
    second_root = checkpointable.Checkpointable()
    first_status = checkpointable_utils.restore(first_path, first_root)
    second_status = checkpointable_utils.restore(second_path, second_root)
    load_dep = make_dependency()
    first_root.dep = load_dep
    first_status.assert_consumed()
    self.evaluate(first_status.restore_ops)
    self.assertEqual([], second_status.restore_ops)
    self.assertEqual(12., self.evaluate(load_dep.var))
    second_root.dep = load_dep
    second_status.assert_consumed()
    self.evaluate(second_status.restore_ops)
    self.assertEqual(13., self.evaluate(load_dep.var))

    # Try again with the order of the restore() reversed. The last restore
    # determines the final value.
    first_root = checkpointable.Checkpointable()
    second_root = checkpointable.Checkpointable()
    second_status = checkpointable_utils.restore(second_path, second_root)
    first_status = checkpointable_utils.restore(first_path, first_root)
    load_dep = make_dependency()
    first_root.dep = load_dep
    first_status.assert_consumed()
    self.assertEqual([], second_status.restore_ops)
    self.evaluate(first_status.restore_ops)
    self.assertEqual(12., self.evaluate(load_dep.var))
    second_root.dep = load_dep
    second_status.assert_consumed()
    self.evaluate(second_status.restore_ops)
    self.assertEqual(12., self.evaluate(load_dep.var))
def testShapeDtype(self):
    """add_variable() honours an explicit dtype, with or without a shape."""
    owner = checkpointable.Checkpointable()

    # Constant initializer, no shape given.
    scalar = checkpointable_utils.add_variable(
        owner, name="v1", initializer=3., dtype=dtypes.float64)
    self.assertEqual(dtypes.float64, scalar.dtype)

    # Initializer passed as a class, with an explicit shape.
    vector = checkpointable_utils.add_variable(
        owner,
        name="v2",
        shape=[3],
        initializer=init_ops.ones_initializer,
        dtype=dtypes.float64)
    self.assertEqual(dtypes.float64, vector.dtype)
    self.assertAllEqual([1., 1., 1.], self.evaluate(vector))
def testOverlappingRestores(self):
    """Two deferred restores that end up targeting the same dependency.

    The checkpoints "first" and "second" hold 12 and 13 for the same
    variable. Each is restored into its own root through a
    `CheckpointableSaver`, and one shared dependency object is then
    attached to both roots in turn.
    """
    ckpt_dir = self.get_temp_dir()
    first_prefix = os.path.join(ckpt_dir, "first")
    second_prefix = os.path.join(ckpt_dir, "second")

    saved = checkpointable.Checkpointable()
    saved.dep = checkpointable.Checkpointable()
    saved.dep.var = checkpointable_utils.add_variable(
        saved.dep, name="var", initializer=0.)
    self.evaluate(state_ops.assign(saved.dep.var, 12.))
    # One saver writes both checkpoints.
    writer = checkpointable_utils.CheckpointableSaver(saved)
    first_path = writer.save(first_prefix)
    self.evaluate(state_ops.assign(saved.dep.var, 13.))
    second_path = writer.save(second_prefix)

    def make_dependency():
        # A fresh object owning an uninitialized scalar named "var".
        dep = checkpointable.Checkpointable()
        dep.var = checkpointable_utils.add_variable(
            dep, name="var", shape=[])
        return dep

    # restore(first) is requested before restore(second).
    first_root = checkpointable.Checkpointable()
    second_root = checkpointable.Checkpointable()
    first_status = checkpointable_utils.CheckpointableSaver(
        first_root).restore(first_path)
    second_status = checkpointable_utils.CheckpointableSaver(
        second_root).restore(second_path)
    load_dep = make_dependency()
    first_root.dep = load_dep
    first_status.assert_consumed()
    first_status.run_restore_ops()
    self.assertEqual(12., self.evaluate(load_dep.var))
    second_root.dep = load_dep
    second_status.assert_consumed()
    second_status.run_restore_ops()
    self.assertEqual(13., self.evaluate(load_dep.var))

    # Try again with the order of the restore() reversed. The last restore
    # determines the final value.
    first_root = checkpointable.Checkpointable()
    second_root = checkpointable.Checkpointable()
    second_status = checkpointable_utils.CheckpointableSaver(
        second_root).restore(second_path)
    first_status = checkpointable_utils.CheckpointableSaver(
        first_root).restore(first_path)
    load_dep = make_dependency()
    first_root.dep = load_dep
    first_status.assert_consumed()
    first_status.run_restore_ops()
    self.assertEqual(12., self.evaluate(load_dep.var))
    second_root.dep = load_dep
    second_status.assert_consumed()
    second_status.run_restore_ops()
    self.assertEqual(12., self.evaluate(load_dep.var))
def testAddVariable(self):
    """Checks add_variable() errors, variable names and checkpoint keys."""
    obj = NonLayerCheckpointable()

    # A constant initializer combined with an explicit shape is rejected.
    with self.assertRaisesRegexp(ValueError, "do not specify shape"):
        checkpointable_utils.add_variable(
            obj, name="shape_specified_twice", shape=[], initializer=1)

    from_constant = checkpointable_utils.add_variable(
        obj, name="constant_initializer", initializer=1)
    with variable_scope.variable_scope("some_variable_scope"):
        from_instance = checkpointable_utils.add_variable(
            obj,
            name="ones_initializer",
            shape=[2],
            initializer=init_ops.ones_initializer(dtype=dtypes.float32))
    from_class = checkpointable_utils.add_variable(
        obj,
        name="bare_initializer",
        shape=[2, 2],
        dtype=dtypes.float64,
        initializer=init_ops.zeros_initializer)

    # Even in graph mode, there are no naming conflicts between objects,
    # only naming conflicts within an object.
    untracked_duplicate = resource_variable_ops.ResourceVariable(
        name="duplicate", initial_value=1.)
    tracked_duplicate = checkpointable_utils.add_variable(
        obj, name="duplicate", shape=[])
    with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"):
        checkpointable_utils.add_variable(obj, name="duplicate", shape=[])

    if context.in_graph_mode():
        self.evaluate(variables.global_variables_initializer())

    self.assertEqual("constant_initializer:0", from_constant.name)
    self.assertEqual(1, self.evaluate(from_constant))
    self.assertEqual(
        "some_variable_scope/ones_initializer:0", from_instance.name)
    self.assertAllEqual([1, 1], self.evaluate(from_instance))
    self.assertAllEqual(
        [[0., 0.], [0., 0.]], self.evaluate(from_class))
    self.assertEqual("a_variable:0", obj.a_variable.name)
    self.assertEqual("duplicate:0", untracked_duplicate.name)

    # In graph mode the .name attribute may be globally influenced, but the
    # checkpoint name won't be (tested below). When executing eagerly,
    # there's no uniquification of variable names, so the checkpoint name
    # will be the same.
    expected_duplicate_name = (
        "duplicate_1:0" if context.in_graph_mode() else "duplicate:0")
    self.assertEqual(expected_duplicate_name, tracked_duplicate.name)

    names_to_variables, _ = checkpointable_utils._serialize_object_graph(
        obj)
    expected_checkpoint_names = (
        "a_variable/.ATTRIBUTES/VARIABLE_VALUE",
        "bare_initializer/.ATTRIBUTES/VARIABLE_VALUE",
        "constant_initializer/.ATTRIBUTES/VARIABLE_VALUE",
        "duplicate/.ATTRIBUTES/VARIABLE_VALUE",
        "ones_initializer/.ATTRIBUTES/VARIABLE_VALUE",
    )
    six.assertCountEqual(
        self, expected_checkpoint_names, names_to_variables.keys())
def testDeferredSlotRestoration(self):
    """Restore requested before variables/slots exist is applied on creation.

    One checkpoint is written before the optimizer is attached
    ("no_slots") and one afterwards ("with_slots"). Both are restored into
    an empty object; this variant never runs restore ops explicitly and
    expects the restored values right after each variable is created.
    """
    ckpt_dir = self.get_temp_dir()
    no_slots_prefix = os.path.join(ckpt_dir, "no_slots")
    with_slots_prefix = os.path.join(ckpt_dir, "with_slots")

    source = checkpointable.Checkpointable()
    source.var = checkpointable_utils.add_variable(
        source, name="var", initializer=0.)
    source_optimizer = CheckpointableAdam(0.1)
    if not context.in_graph_mode():
        source_optimizer.minimize(source.var.read_value)
    else:
        first_step = source_optimizer.minimize(source.var)
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(first_step)

    # Saved before `source_optimizer` is attached to `source`.
    self.evaluate(state_ops.assign(source.var, 12.))
    no_slots_path = checkpointable_utils.save(no_slots_prefix, source)

    # Saved after attaching the optimizer, with distinct values for the
    # variable (13) and its "m" slot (14).
    source.optimizer = source_optimizer
    self.evaluate(state_ops.assign(source.var, 13.))
    source_slot = source_optimizer.get_slot(name="m", var=source.var)
    self.evaluate(state_ops.assign(source_slot, 14.))
    slots_path = checkpointable_utils.save(with_slots_prefix, source)

    target = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately
    # overwrite the non-slot variable (also deferred).
    slot_status = checkpointable_utils.restore(slots_path, target)
    no_slot_status = checkpointable_utils.restore(no_slots_path, target)
    with self.assertRaises(AssertionError):
        no_slot_status.assert_consumed()
    target.var = checkpointable_utils.add_variable(
        target, name="var", shape=[])
    self.assertEqual(12., self.evaluate(target.var))
    no_slot_status.assert_consumed()

    target.optimizer = CheckpointableAdam(0.1)
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
        slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(target.var))
    restored_slot = target.optimizer.get_slot(name="m", var=target.var)
    self.assertEqual(14., self.evaluate(restored_slot))

    if not context.in_graph_mode():
        target.optimizer.minimize(target.var.read_value)
    else:
        second_step = target.optimizer.minimize(target.var)
        self.evaluate(second_step)
    slot_status.assert_consumed()
def testDependencyLoop(self):
    """Save and restore two objects that reference each other."""
    # Note: this test creates garbage during eager execution because it
    # purposefully creates a reference cycle.
    first = checkpointable.Checkpointable()
    second = checkpointable.Checkpointable()
    first.second = second
    second.first = first
    first.v = checkpointable_utils.add_variable(
        first, "v1", initializer=[3., 1., 4.])
    second.v = checkpointable_utils.add_variable(
        second, "v2", initializer=[1., 1., 2., 3.])
    self.evaluate(checkpointable_utils.gather_initializers(first))

    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    writer = checkpointable_utils.CheckpointableSaver(first)
    save_path = writer.save(ckpt_prefix)

    # Test deferred loading: restore() is requested on an empty object and
    # the cycle and variables are built afterwards.
    first_load = checkpointable.Checkpointable()
    deferred_loader = checkpointable_utils.CheckpointableSaver(first_load)
    status = deferred_loader.restore(save_path)
    second_load = checkpointable.Checkpointable()
    first_load.second = second_load
    second_load.first = first_load
    with self.assertRaises(AssertionError):
        status.assert_consumed()
    first_load.v = checkpointable_utils.add_variable(
        first_load, "v1", shape=[3])
    second_load.v = checkpointable_utils.add_variable(
        second_load, "v2", shape=[4])
    status.assert_consumed()
    status.run_restore_ops()
    self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v))
    self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v))

    # Test loading when variables have already been created: clobber both
    # values, then restore over them.
    self.evaluate(first_load.v.assign([2., 7., 1.]))
    self.assertAllEqual([2., 7., 1.], self.evaluate(first_load.v))
    self.evaluate(second_load.v.assign([2., 7., 1., 8.]))
    self.assertAllEqual([2., 7., 1., 8.], self.evaluate(second_load.v))
    eager_loader = checkpointable_utils.CheckpointableSaver(first_load)
    status = eager_loader.restore(save_path).assert_consumed()
    status.run_restore_ops()
    self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v))
    self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v))
def testAmbiguousLoad(self):
    """Loading one saved object into two distinct objects raises."""
    # Not OK to split one checkpoint object into two
    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")

    # Saved graph: dep_one and dep_two both point at the same dep_three.
    save_root = checkpointable.Checkpointable()
    save_root.dep_one = checkpointable.Checkpointable()
    save_root.dep_two = checkpointable.Checkpointable()
    shared = checkpointable.Checkpointable()
    save_root.dep_one.dep_three = shared
    save_root.dep_two.dep_three = shared
    checkpointable_utils.add_variable(shared, name="var", initializer=0.)
    self.evaluate(checkpointable_utils.gather_initializers(save_root))
    writer = checkpointable_utils.CheckpointableSaver(save_root)
    save_path = writer.save(ckpt_prefix)

    # Loaded graph: each branch is given its own dep_three object.
    load_root = checkpointable.Checkpointable()
    checkpointable_utils.CheckpointableSaver(load_root).restore(save_path)
    load_root.dep_one = checkpointable.Checkpointable()
    load_root.dep_two = checkpointable.Checkpointable()
    load_root.dep_one.dep_three = checkpointable.Checkpointable()
    with self.assertRaisesRegexp(AssertionError,
                                 "resolved to different objects"):
        load_root.dep_two.dep_three = checkpointable.Checkpointable()
def testObjectsCombined(self):
    """Two saved objects can be loaded into a single Python object."""
    # Currently fine to load two checkpoint objects into one Python object
    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")

    # Saved graph: two separate dependencies, one variable each.
    save_root = checkpointable.Checkpointable()
    save_root.dep_one = checkpointable.Checkpointable()
    save_root.dep_two = checkpointable.Checkpointable()
    checkpointable_utils.add_variable(
        save_root.dep_one, name="var1", initializer=32.,
        dtype=dtypes.float64)
    checkpointable_utils.add_variable(
        save_root.dep_two, name="var2", initializer=64.,
        dtype=dtypes.float64)
    self.evaluate(checkpointable_utils.gather_initializers(save_root))
    writer = checkpointable_utils.CheckpointableSaver(save_root)
    save_path = writer.save(ckpt_prefix)

    # Loaded graph: both dependency names refer to the same object, which
    # owns both variables.
    load_root = checkpointable.Checkpointable()
    load_root.dep_one = checkpointable.Checkpointable()
    load_root.dep_two = load_root.dep_one
    loaded_var1 = checkpointable_utils.add_variable(
        load_root.dep_one, name="var1", shape=[], dtype=dtypes.float64)
    loaded_var2 = checkpointable_utils.add_variable(
        load_root.dep_one, name="var2", shape=[], dtype=dtypes.float64)
    reader = checkpointable_utils.CheckpointableSaver(load_root)
    status = reader.restore(save_path).assert_consumed()
    status.run_restore_ops()
    self.assertEqual(32., self.evaluate(loaded_var1))
    self.assertEqual(64., self.evaluate(loaded_var2))
def save_counter(self):
    """Return the integer save-counter variable, creating it on first use.

    The variable starts at zero and is incremented on save; it is used to
    number checkpoints. It is created lazily and cached in
    `self._save_counter`.

    Returns:
      The save counter variable.
    """
    if self._save_counter is not None:
        return self._save_counter
    # Initialized to 0 and incremented before saving.
    self._save_counter = checkpointable_utils.add_variable(
        self, name="save_counter", initializer=0, dtype=dtypes.int64)
    return self._save_counter
def __init__(self):
    """Initialize the base class, then add a scalar named "a_variable"."""
    super(NonLayerCheckpointable, self).__init__()
    # The local name doubles as the attribute name on this object.
    self.a_variable = checkpointable_utils.add_variable(
        self,
        name="a_variable",
        shape=[])
def testDeferredSlotRestoration(self):
    """Restore requested before variables/slots exist is applied later.

    Uses `adam.AdamOptimizer` and `CheckpointableSaver` to write one
    checkpoint before the optimizer is attached ("no_slots") and one
    afterwards ("with_slots"), then restores both into an empty object and
    checks the values seen as the variable, the optimizer and its slot get
    created.
    """
    ckpt_dir = self.get_temp_dir()
    no_slots_prefix = os.path.join(ckpt_dir, "no_slots")
    with_slots_prefix = os.path.join(ckpt_dir, "with_slots")

    source = checkpointable.Checkpointable()
    source.var = checkpointable_utils.add_variable(
        source, name="var", initializer=0.)
    source_optimizer = adam.AdamOptimizer(0.1)
    if not context.in_graph_mode():
        source_optimizer.minimize(source.var.read_value)
    else:
        first_step = source_optimizer.minimize(source.var)
        # Note that `source_optimizer` has not been added as a dependency of
        # `source`. Create a one-off grouping so that slot variables for
        # `source.var` get initialized too.
        grouping = checkpointable_utils.Checkpoint(
            root=source, optimizer=source_optimizer)
        self.evaluate(checkpointable_utils.gather_initializers(grouping))
        self.evaluate(first_step)

    # Saved before `source_optimizer` is attached to `source`.
    self.evaluate(state_ops.assign(source.var, 12.))
    no_slots_path = checkpointable_utils.CheckpointableSaver(source).save(
        no_slots_prefix)

    # Saved after attaching the optimizer, with distinct values for the
    # variable (13) and its "m" slot (14).
    source.optimizer = source_optimizer
    self.evaluate(state_ops.assign(source.var, 13.))
    source_slot = source_optimizer.get_slot(name="m", var=source.var)
    self.evaluate(state_ops.assign(source_slot, 14.))
    slots_path = checkpointable_utils.CheckpointableSaver(source).save(
        with_slots_prefix)

    target = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately
    # overwrite the non-slot variable (also deferred).
    slot_status = checkpointable_utils.CheckpointableSaver(
        target).restore(slots_path)
    no_slot_status = checkpointable_utils.CheckpointableSaver(
        target).restore(no_slots_path)
    with self.assertRaises(AssertionError):
        no_slot_status.assert_consumed()
    target.var = checkpointable_utils.add_variable(
        target, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    self.assertEqual(12., self.evaluate(target.var))

    target.optimizer = adam.AdamOptimizer(0.1)
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
        slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(target.var))

    pending_slot = target.optimizer.get_slot(name="m", var=target.var)
    if context.in_eager_mode():
        # Slot variables are only created with restoring initializers when
        # executing eagerly.
        self.assertEqual(14., self.evaluate(pending_slot))
    else:
        self.assertIs(pending_slot, None)

    if not context.in_graph_mode():
        target.optimizer.minimize(target.var.read_value)
    else:
        second_step = target.optimizer.minimize(target.var)
        # The slot variable now exists; restore() didn't create it, but we
        # should now have a restore op for it.
        slot_status.run_restore_ops()
        self.assertEqual(
            14.,
            self.evaluate(
                target.optimizer.get_slot(name="m", var=target.var)))
        self.evaluate(second_step)
    slot_status.assert_consumed()
def build(self):
    """Create the variable named "var" (initial value 0.) on this object."""
    self.var = checkpointable_utils.add_variable(
        self,
        "var",
        initializer=0.)
def testDeferredSlotRestoration(self):
    """Restore requested before variables/slots exist is applied later.

    Uses `adam.AdamOptimizer` and `CheckpointableSaver` to write one
    checkpoint before the optimizer is attached ("no_slots") and one
    afterwards ("with_slots"), then restores both into an empty object and
    checks the values seen as the variable, the optimizer and its slot get
    created. Mode checks go through `context.executing_eagerly()`.
    """
    ckpt_dir = self.get_temp_dir()
    no_slots_prefix = os.path.join(ckpt_dir, "no_slots")
    with_slots_prefix = os.path.join(ckpt_dir, "with_slots")

    source = checkpointable.Checkpointable()
    source.var = checkpointable_utils.add_variable(
        source, name="var", initializer=0.)
    source_optimizer = adam.AdamOptimizer(0.1)
    if not context.executing_eagerly():
        first_step = source_optimizer.minimize(source.var)
        # Note that `source_optimizer` has not been added as a dependency of
        # `source`. Create a one-off grouping so that slot variables for
        # `source.var` get initialized too.
        grouping = checkpointable_utils.Checkpoint(
            root=source, optimizer=source_optimizer)
        self.evaluate(checkpointable_utils.gather_initializers(grouping))
        self.evaluate(first_step)
    else:
        source_optimizer.minimize(source.var.read_value)

    # Saved before `source_optimizer` is attached to `source`.
    self.evaluate(state_ops.assign(source.var, 12.))
    no_slots_path = checkpointable_utils.CheckpointableSaver(source).save(
        no_slots_prefix)

    # Saved after attaching the optimizer, with distinct values for the
    # variable (13) and its "m" slot (14).
    source.optimizer = source_optimizer
    self.evaluate(state_ops.assign(source.var, 13.))
    source_slot = source_optimizer.get_slot(name="m", var=source.var)
    self.evaluate(state_ops.assign(source_slot, 14.))
    slots_path = checkpointable_utils.CheckpointableSaver(source).save(
        with_slots_prefix)

    target = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately
    # overwrite the non-slot variable (also deferred).
    slot_status = checkpointable_utils.CheckpointableSaver(
        target).restore(slots_path)
    no_slot_status = checkpointable_utils.CheckpointableSaver(
        target).restore(no_slots_path)
    with self.assertRaises(AssertionError):
        no_slot_status.assert_consumed()
    target.var = checkpointable_utils.add_variable(
        target, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    self.assertEqual(12., self.evaluate(target.var))

    target.optimizer = adam.AdamOptimizer(0.1)
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
        slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(target.var))

    pending_slot = target.optimizer.get_slot(name="m", var=target.var)
    if context.executing_eagerly():
        # Slot variables are only created with restoring initializers when
        # executing eagerly.
        self.assertEqual(14., self.evaluate(pending_slot))
    else:
        self.assertIs(pending_slot, None)

    if not context.executing_eagerly():
        second_step = target.optimizer.minimize(target.var)
        # The slot variable now exists; restore() didn't create it, but we
        # should now have a restore op for it.
        slot_status.run_restore_ops()
        self.assertEqual(
            14.,
            self.evaluate(
                target.optimizer.get_slot(name="m", var=target.var)))
        self.evaluate(second_step)
    else:
        target.optimizer.minimize(target.var.read_value)
    slot_status.assert_consumed()