def testDepAfterVar(self):
  """Restores into a dependency that is attached after its variable exists.

  `DepAfterVar.add_dep` builds a `Dependency` (creating its variable) and
  only afterwards assigns it to `self.dep`. The test saves an instance whose
  variable was set to -14, requests a restore on a fresh instance *before*
  `add_dep()` is called, and then checks that the value is recovered once
  the dependency has been added and the restore ops have run.
  """

  class Dependency(checkpointable.Checkpointable):
    """Checkpointable object owning a single variable named "var"."""

    def build(self):
      self.var = checkpointable_utils.add_variable(
          self, "var", initializer=0.)

  class DepAfterVar(checkpointable.Checkpointable):
    """Root object that attaches its dependency after building it."""

    def add_dep(self):
      # The variable is created first; the object is tracked as `dep` last.
      child = Dependency()
      child.build()
      self.dep = child

  saved_root = DepAfterVar()
  saved_root.add_dep()
  set_value_op = state_ops.assign(saved_root.dep.var, -14.)
  self.evaluate(set_value_op)

  ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
  writer = checkpointable_utils.CheckpointableSaver(saved_root)
  save_path = writer.save(ckpt_prefix)

  # Request the restore while the new root has no dependencies yet.
  restored_root = DepAfterVar()
  reader = checkpointable_utils.CheckpointableSaver(restored_root)
  status = reader.restore(save_path)
  restored_root.add_dep()
  status.assert_consumed()
  status.run_restore_ops()
  restored_value = self.evaluate(restored_root.dep.var)
  self.assertEqual(-14., restored_value)
def testRestoreOnAssign(self):
  """Checks that a variable is restored once it is assigned as an attribute.

  Two variables are saved from one graph. In a second graph only `var2` is
  attached before `restore()`; the variable named "outside_var" is created
  untracked, given a different value, and is expected to hold the saved
  value (4) only after being assigned to `second.var1` and the restore ops
  being run again.
  """
  ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")

  graph_for_save = ops.Graph()
  with graph_for_save.as_default(), self.test_session(graph_for_save):
    source = checkpointable.Checkpointable()
    source.var1 = variable_scope.get_variable(
        name="outside_var", initializer=0.)
    source.var2 = variable_scope.get_variable(
        name="blah", initializer=0.)
    for variable, value in ((source.var1, 4.), (source.var2, 8.)):
      self.evaluate(variable.assign(value))
    writer = checkpointable_utils.CheckpointableSaver(source)
    save_path = writer.save(ckpt_prefix)

  graph_for_restore = ops.Graph()
  with graph_for_restore.as_default(), self.test_session(graph_for_restore):
    target = checkpointable.Checkpointable()
    target.var2 = variable_scope.get_variable(
        name="blah", initializer=0.)
    reader = checkpointable_utils.CheckpointableSaver(target)
    status = reader.restore(save_path)
    # Created after restore() and not yet attached to `target`.
    untracked_var1 = variable_scope.get_variable(
        name="outside_var", initializer=0.)
    status.run_restore_ops()
    self.assertEqual(8., self.evaluate(target.var2))

    # While untracked, the variable keeps whatever value it is given.
    self.evaluate(untracked_var1.assign(-2.))
    self.assertEqual(-2., self.evaluate(untracked_var1))

    # After attaching it, running the restore ops yields the saved value.
    target.var1 = untracked_var1
    status.run_restore_ops()
    self.assertEqual(4., self.evaluate(untracked_var1))
def testObjectsCombined(self):
  """Loads two saved objects into a single Python object.

  The saved root has two distinct dependencies, each owning one float64
  variable. The loading root uses the same object for both `dep_one` and
  `dep_two`; the test expects the restore to be fully consumed and both
  variables to hold their saved values.
  """
  # Currently fine to load two checkpoint objects into one Python object
  ckpt_dir = self.get_temp_dir()

  saved = checkpointable.Checkpointable()
  saved.dep_one = checkpointable.Checkpointable()
  saved.dep_two = checkpointable.Checkpointable()
  checkpointable_utils.add_variable(
      saved.dep_one, name="var1", initializer=32., dtype=dtypes.float64)
  checkpointable_utils.add_variable(
      saved.dep_two, name="var2", initializer=64., dtype=dtypes.float64)
  init_ops = checkpointable_utils.gather_initializers(saved)
  self.evaluate(init_ops)
  writer = checkpointable_utils.CheckpointableSaver(saved)
  save_path = writer.save(os.path.join(ckpt_dir, "ckpt"))

  loaded = checkpointable.Checkpointable()
  loaded.dep_one = checkpointable.Checkpointable()
  # Both dependency names point at the very same object.
  loaded.dep_two = loaded.dep_one
  v1 = checkpointable_utils.add_variable(
      loaded.dep_one, name="var1", shape=[], dtype=dtypes.float64)
  v2 = checkpointable_utils.add_variable(
      loaded.dep_one, name="var2", shape=[], dtype=dtypes.float64)
  reader = checkpointable_utils.CheckpointableSaver(loaded)
  status = reader.restore(save_path).assert_consumed()
  status.run_restore_ops()
  self.assertEqual(32., self.evaluate(v1))
  self.assertEqual(64., self.evaluate(v2))
def testLateDependencyTracking(self):
  """Checks that a restore is consumed only after the dependency is added.

  `restore()` is called on an object with no dependencies, so
  `assert_consumed()` is expected to raise `AssertionError`. After
  `add_dep()` creates the dependency and its variable, the same status is
  expected to be consumed and the saved value (123) restored.
  """

  class Dependency(checkpointable.Checkpointable):
    """Checkpointable object owning a single variable named "var"."""

    def build(self):
      self.var = checkpointable_utils.add_variable(
          self, "var", initializer=0.)

  class LateDependencies(checkpointable.Checkpointable):
    """Root object whose dependency is tracked before it is built."""

    def add_dep(self):
      self.dep = Dependency()
      self.dep.build()

  saved_root = LateDependencies()
  saved_root.add_dep()
  set_value_op = state_ops.assign(saved_root.dep.var, 123.)
  self.evaluate(set_value_op)

  ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
  writer = checkpointable_utils.CheckpointableSaver(saved_root)
  save_path = writer.save(ckpt_prefix)

  restored_root = LateDependencies()
  reader = checkpointable_utils.CheckpointableSaver(restored_root)
  status = reader.restore(save_path)
  # Nothing has been attached yet, so the checkpoint is not consumed.
  with self.assertRaises(AssertionError):
    status.assert_consumed()

  restored_root.add_dep()
  status.assert_consumed()
  status.run_restore_ops()
  restored_value = self.evaluate(restored_root.dep.var)
  self.assertEqual(123., restored_value)
def testOverlappingRestores(self):
  """Checks which value wins when two restores target the same object.

  Two checkpoints are written from one saver, with the variable set to 12
  and then 13. A single dependency object is attached to two roots, each
  with its own pending restore. The scenario is run twice, with the order
  of the two `restore()` calls swapped between runs.
  """
  ckpt_dir = self.get_temp_dir()

  saved = checkpointable.Checkpointable()
  saved.dep = checkpointable.Checkpointable()
  saved.dep.var = checkpointable_utils.add_variable(
      saved.dep, name="var", initializer=0.)
  self.evaluate(state_ops.assign(saved.dep.var, 12.))
  writer = checkpointable_utils.CheckpointableSaver(saved)
  first_path = writer.save(os.path.join(ckpt_dir, "first"))
  self.evaluate(state_ops.assign(saved.dep.var, 13.))
  second_path = writer.save(os.path.join(ckpt_dir, "second"))

  def make_dep():
    # Fresh dependency with an uninitialized-from-checkpoint scalar variable.
    dep = checkpointable.Checkpointable()
    dep.var = checkpointable_utils.add_variable(
        dep, name="var", shape=[])
    return dep

  def attach_and_verify(root, status, dep, expected):
    # Attach `dep` to `root`, run that root's restore, check the value.
    root.dep = dep
    status.assert_consumed()
    status.run_restore_ops()
    self.assertEqual(expected, self.evaluate(dep.var))

  # restore(first) is requested before restore(second).
  root_a = checkpointable.Checkpointable()
  root_b = checkpointable.Checkpointable()
  status_a = checkpointable_utils.CheckpointableSaver(
      root_a).restore(first_path)
  status_b = checkpointable_utils.CheckpointableSaver(
      root_b).restore(second_path)
  shared_dep = make_dep()
  attach_and_verify(root_a, status_a, shared_dep, 12.)
  attach_and_verify(root_b, status_b, shared_dep, 13.)

  # Try again with the order of the restore() reversed. The last restore
  # determines the final value.
  root_a = checkpointable.Checkpointable()
  root_b = checkpointable.Checkpointable()
  status_b = checkpointable_utils.CheckpointableSaver(
      root_b).restore(second_path)
  status_a = checkpointable_utils.CheckpointableSaver(
      root_a).restore(first_path)
  shared_dep = make_dep()
  attach_and_verify(root_a, status_a, shared_dep, 12.)
  attach_and_verify(root_b, status_b, shared_dep, 12.)
def testDependencyLoop(self):
  """Saves and restores two objects that depend on each other.

  Covers two cases against the same checkpoint: a deferred load where the
  cycle and the variables are created after `restore()`, and a load into
  variables that already exist and hold other values.
  """
  # Note: this test creates garbage during eager execution because it
  # purposefully creates a reference cycle.
  saved_first = checkpointable.Checkpointable()
  saved_second = checkpointable.Checkpointable()
  saved_first.second = saved_second
  saved_second.first = saved_first
  saved_first.v = checkpointable_utils.add_variable(
      saved_first, "v1", initializer=[3., 1., 4.])
  saved_second.v = checkpointable_utils.add_variable(
      saved_second, "v2", initializer=[1., 1., 2., 3.])
  init_ops = checkpointable_utils.gather_initializers(saved_first)
  self.evaluate(init_ops)
  ckpt_dir = self.get_temp_dir()
  writer = checkpointable_utils.CheckpointableSaver(saved_first)
  save_path = writer.save(os.path.join(ckpt_dir, "ckpt"))

  # Test deferred loading
  loaded_first = checkpointable.Checkpointable()
  reader = checkpointable_utils.CheckpointableSaver(loaded_first)
  status = reader.restore(save_path)
  loaded_second = checkpointable.Checkpointable()
  loaded_first.second = loaded_second
  loaded_second.first = loaded_first
  # The objects are linked, but no variables exist yet.
  with self.assertRaises(AssertionError):
    status.assert_consumed()
  loaded_first.v = checkpointable_utils.add_variable(
      loaded_first, "v1", shape=[3])
  loaded_second.v = checkpointable_utils.add_variable(
      loaded_second, "v2", shape=[4])
  status.assert_consumed()
  status.run_restore_ops()
  self.assertAllEqual([3., 1., 4.], self.evaluate(loaded_first.v))
  self.assertAllEqual([1., 1., 2., 3.], self.evaluate(loaded_second.v))

  # Test loading when variables have already been created
  overwrite_first = loaded_first.v.assign([2., 7., 1.])
  self.evaluate(overwrite_first)
  self.assertAllEqual([2., 7., 1.], self.evaluate(loaded_first.v))
  overwrite_second = loaded_second.v.assign([2., 7., 1., 8.])
  self.evaluate(overwrite_second)
  self.assertAllEqual([2., 7., 1., 8.], self.evaluate(loaded_second.v))
  second_reader = checkpointable_utils.CheckpointableSaver(loaded_first)
  status = second_reader.restore(save_path).assert_consumed()
  status.run_restore_ops()
  self.assertAllEqual([3., 1., 4.], self.evaluate(loaded_first.v))
  self.assertAllEqual([1., 1., 2., 3.], self.evaluate(loaded_second.v))
def testAmbiguousLoad(self):
  """Checks that splitting one saved object into two is rejected.

  In the saved graph, `dep_one.dep_three` and `dep_two.dep_three` are the
  same object. When loading, assigning a second, distinct object for
  `dep_two.dep_three` is expected to raise an `AssertionError` whose
  message matches "resolved to different objects".
  """
  # Not OK to split one checkpoint object into two
  ckpt_dir = self.get_temp_dir()

  saved = checkpointable.Checkpointable()
  saved.dep_one = checkpointable.Checkpointable()
  saved.dep_two = checkpointable.Checkpointable()
  shared_leaf = checkpointable.Checkpointable()
  saved.dep_one.dep_three = shared_leaf
  saved.dep_two.dep_three = shared_leaf
  checkpointable_utils.add_variable(shared_leaf, name="var", initializer=0.)
  init_ops = checkpointable_utils.gather_initializers(saved)
  self.evaluate(init_ops)
  writer = checkpointable_utils.CheckpointableSaver(saved)
  save_path = writer.save(os.path.join(ckpt_dir, "ckpt"))

  loaded = checkpointable.Checkpointable()
  reader = checkpointable_utils.CheckpointableSaver(loaded)
  reader.restore(save_path)
  loaded.dep_one = checkpointable.Checkpointable()
  loaded.dep_two = checkpointable.Checkpointable()
  loaded.dep_one.dep_three = checkpointable.Checkpointable()
  with self.assertRaisesRegexp(AssertionError,
                               "resolved to different objects"):
    # A different Python object for what was a single saved object.
    loaded.dep_two.dep_three = checkpointable.Checkpointable()
def testSaveEagerLoadGraph(self):
  """Saves a model under eager mode and restores it under graph mode.

  The model comes from `self._initialized_model()`. After the graph-mode
  copy has its sentinels set, the restore is expected to be fully consumed
  and `self._check_sentinels` to pass once the restore ops have run.
  """
  ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")

  with context.eager_mode():
    eager_root = self._initialized_model()
    writer = checkpointable_utils.CheckpointableSaver(eager_root)
    save_path = writer.save(file_prefix=ckpt_prefix)

  with context.graph_mode():
    restore_graph = ops.Graph()
    with restore_graph.as_default():
      with self.test_session(graph=restore_graph):
        graph_root = self._initialized_model()
        self._set_sentinels(graph_root)
        status = graph_root.restore(save_path)
        status.assert_consumed().run_restore_ops()
        self._check_sentinels(graph_root)
def testLoadFromNameBasedSaver(self):
  """Loads a name-based checkpoint through the object-based saver.

  The checkpoint comes from `self._write_name_based_checkpoint()`. The
  restore status is expected to fail `assert_consumed()`, while both
  `run_restore_ops()` and `initialize_or_restore()` are expected to bring
  the model back to a state that passes `self._check_sentinels`.
  """
  save_path = self._write_name_based_checkpoint()
  model_root = self._initialized_model()

  # With sentinel values in place the check itself must fail.
  self._set_sentinels(model_root)
  with self.assertRaises(AssertionError):
    self._check_sentinels(model_root)

  status = checkpointable_utils.CheckpointableSaver(
      model_root).restore(save_path)
  with self.assertRaises(AssertionError):
    status.assert_consumed()

  status.run_restore_ops()
  self._check_sentinels(model_root)

  # Same verification through the initialize_or_restore() entry point.
  self._set_sentinels(model_root)
  status.initialize_or_restore()
  self._check_sentinels(model_root)
def testManySavesGraph(self):
  """Checks that a second save() leaves the graph's operations unchanged.

  The list of graph operations is captured after the first `save()` and
  compared with the list obtained after a second `save()` to the same
  prefix.
  """
  with context.graph_mode():
    test_graph = ops.Graph()
    with test_graph.as_default(), self.test_session(test_graph):
      ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")

      tracked = checkpointable.Checkpointable()
      tracked.var = variable_scope.get_variable(name="v", initializer=0.)
      tracked.opt = adam.AdamOptimizer(0.1)
      tracked.opt.minimize(tracked.var.read_value())
      init_ops = checkpointable_utils.gather_initializers(tracked)
      self.evaluate(init_ops)

      writer = checkpointable_utils.CheckpointableSaver(tracked)
      writer.save(ckpt_prefix)
      ops_after_first_save = test_graph.get_operations()
      writer.save(ckpt_prefix)
      ops_after_second_save = test_graph.get_operations()
      self.assertEqual(ops_after_first_save, ops_after_second_save)
def testDeferredSlotRestoration(self):
  """Exercises deferred restoration of an optimizer slot variable.

  Two checkpoints are written from the same root: "no_slots" (variable set
  to 12, optimizer not yet attached) and "with_slots" (variable set to 13,
  the Adam "m" slot set to 14, optimizer attached). Both are then restored,
  deferred, into a fresh root; the test checks the variable ends up at 12
  and the "m" slot at 14 once the optimizer and its slot exist.
  """
  ckpt_dir = self.get_temp_dir()

  def m_slot(opt, var):
    # The "m" slot the optimizer keeps for `var`.
    return opt.get_slot(name="m", var=var)

  root = checkpointable.Checkpointable()
  root.var = checkpointable_utils.add_variable(
      root, name="var", initializer=0.)
  optimizer = adam.AdamOptimizer(0.1)
  if context.in_graph_mode():
    train_op = optimizer.minimize(root.var)
    # Note that `optimizer` has not been added as a dependency of
    # `root`. Create a one-off grouping so that slot variables for `root.var`
    # get initialized too.
    grouping = checkpointable_utils.Checkpoint(
        root=root, optimizer=optimizer)
    self.evaluate(checkpointable_utils.gather_initializers(grouping))
    self.evaluate(train_op)
  else:
    optimizer.minimize(root.var.read_value)

  self.evaluate(state_ops.assign(root.var, 12.))
  no_slots_writer = checkpointable_utils.CheckpointableSaver(root)
  no_slots_path = no_slots_writer.save(os.path.join(ckpt_dir, "no_slots"))

  root.optimizer = optimizer
  self.evaluate(state_ops.assign(root.var, 13.))
  self.evaluate(state_ops.assign(m_slot(optimizer, root.var), 14.))
  slots_writer = checkpointable_utils.CheckpointableSaver(root)
  slots_path = slots_writer.save(os.path.join(ckpt_dir, "with_slots"))

  new_root = checkpointable.Checkpointable()
  # Load the slot-containing checkpoint (deferred), then immediately overwrite
  # the non-slot variable (also deferred).
  slot_status = checkpointable_utils.CheckpointableSaver(
      new_root).restore(slots_path)
  no_slot_status = checkpointable_utils.CheckpointableSaver(
      new_root).restore(no_slots_path)
  with self.assertRaises(AssertionError):
    no_slot_status.assert_consumed()

  new_root.var = checkpointable_utils.add_variable(
      new_root, name="var", shape=[])
  no_slot_status.assert_consumed()
  no_slot_status.run_restore_ops()
  self.assertEqual(12., self.evaluate(new_root.var))

  new_root.optimizer = adam.AdamOptimizer(0.1)
  with self.assertRaisesRegexp(AssertionError, "beta1_power"):
    slot_status.assert_consumed()
  self.assertEqual(12., self.evaluate(new_root.var))

  if context.in_eager_mode():
    # Slot variables are only created with restoring initializers when
    # executing eagerly.
    eager_slot = m_slot(new_root.optimizer, new_root.var)
    self.assertEqual(14., self.evaluate(eager_slot))
  else:
    self.assertIs(m_slot(new_root.optimizer, new_root.var), None)

  if context.in_graph_mode():
    train_op = new_root.optimizer.minimize(new_root.var)
    # The slot variable now exists; restore() didn't create it, but we should
    # now have a restore op for it.
    slot_status.run_restore_ops()
    graph_slot = m_slot(new_root.optimizer, new_root.var)
    self.assertEqual(14., self.evaluate(graph_slot))
    self.evaluate(train_op)
  else:
    new_root.optimizer.minimize(new_root.var.read_value)
  slot_status.assert_consumed()