def testObjectsCombined(self):
  """Two saved objects may currently be restored into one Python object.

  At save time `dep_one` and `dep_two` are distinct objects, each owning one
  variable. At load time both names refer to the same object, which then
  receives both variables.
  """
  checkpoint_directory = self.get_temp_dir()
  saved_root = checkpointable.Checkpointable()
  saved_root.dep_one = checkpointable.Checkpointable()
  saved_root.dep_two = checkpointable.Checkpointable()
  for owner, var_name, initial_value in (
      (saved_root.dep_one, "var1", 32.),
      (saved_root.dep_two, "var2", 64.)):
    checkpointable_utils.add_variable(
        owner, name=var_name, initializer=initial_value,
        dtype=dtypes.float64)
  self.evaluate(checkpointable_utils.gather_initializers(saved_root))
  save_path = checkpointable_utils.CheckpointableSaver(saved_root).save(
      os.path.join(checkpoint_directory, "ckpt"))

  merged_root = checkpointable.Checkpointable()
  merged_root.dep_one = checkpointable.Checkpointable()
  # Alias: both dependency names now point at one and the same object.
  merged_root.dep_two = merged_root.dep_one
  restored_var1, restored_var2 = [
      checkpointable_utils.add_variable(
          merged_root.dep_one, name=var_name, shape=[], dtype=dtypes.float64)
      for var_name in ("var1", "var2")]
  checkpointable_utils.CheckpointableSaver(merged_root).restore(
      save_path).assert_consumed().run_restore_ops()
  self.assertEqual(32., self.evaluate(restored_var1))
  self.assertEqual(64., self.evaluate(restored_var2))
def testRestoreOnAssign(self):
  """A variable attached after restore() still receives its saved value."""
  checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")

  writer_graph = ops.Graph()
  with writer_graph.as_default(), self.test_session(writer_graph):
    saved = checkpointable.Checkpointable()
    saved.var1 = variable_scope.get_variable(
        name="outside_var", initializer=0.)
    saved.var2 = variable_scope.get_variable(name="blah", initializer=0.)
    self.evaluate(saved.var1.assign(4.))
    self.evaluate(saved.var2.assign(8.))
    save_path = checkpointable_utils.CheckpointableSaver(saved).save(
        checkpoint_prefix)

  reader_graph = ops.Graph()
  with reader_graph.as_default(), self.test_session(reader_graph):
    loaded = checkpointable.Checkpointable()
    loaded.var2 = variable_scope.get_variable(name="blah", initializer=0.)
    status = checkpointable_utils.CheckpointableSaver(loaded).restore(
        save_path)
    late_var = variable_scope.get_variable(
        name="outside_var", initializer=0.)
    status.run_restore_ops()
    self.assertEqual(8., self.evaluate(loaded.var2))
    # `late_var` is not yet a dependency of `loaded`, so it keeps whatever
    # value is assigned to it here.
    self.evaluate(late_var.assign(-2.))
    self.assertEqual(-2., self.evaluate(late_var))
    # Once attached, the pending restore picks it up.
    loaded.var1 = late_var
    status.run_restore_ops()
    self.assertEqual(4., self.evaluate(late_var))
def testDeferredSlotRestoration(self):
  """Slot variables are restored once their optimizer creates them.

  Two checkpoints are written from the same root:
    * "no_slots": saved before the optimizer is attached as a dependency.
    * "with_slots": saved after the optimizer is attached.
  Both are restored into a fresh root before any of its variables exist.
  The test then checks that the non-slot variable and the Adam "m" slot
  receive their checkpointed values as those objects are created.
  """
  checkpoint_directory = self.get_temp_dir()
  root = checkpointable.Checkpointable()
  root.var = checkpointable_utils.add_variable(
      root, name="var", initializer=0.)
  optimizer = CheckpointableAdam(0.1)
  # `context.executing_eagerly()` replaces the deprecated (and later removed)
  # `context.in_graph_mode()` / `context.in_eager_mode()` helpers.
  if context.executing_eagerly():
    optimizer.minimize(root.var.read_value)
  else:
    train_op = optimizer.minimize(root.var)
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(train_op)
  self.evaluate(state_ops.assign(root.var, 12.))
  no_slots_path = checkpointable_utils.save(
      os.path.join(checkpoint_directory, "no_slots"), root)
  root.optimizer = optimizer
  self.evaluate(state_ops.assign(root.var, 13.))
  self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                 14.))
  slots_path = checkpointable_utils.save(
      os.path.join(checkpoint_directory, "with_slots"), root)
  new_root = checkpointable.Checkpointable()
  # Load the slot-containing checkpoint (deferred), then immediately overwrite
  # the non-slot variable (also deferred).
  slot_status = checkpointable_utils.restore(
      slots_path, new_root)
  no_slot_status = checkpointable_utils.restore(
      no_slots_path, new_root)
  with self.assertRaises(AssertionError):
    no_slot_status.assert_consumed()
  new_root.var = checkpointable_utils.add_variable(
      new_root, name="var", shape=[])
  no_slot_status.assert_consumed()
  self.evaluate(no_slot_status.restore_ops)
  self.assertEqual(12., self.evaluate(new_root.var))
  new_root.optimizer = CheckpointableAdam(0.1)
  with self.assertRaisesRegexp(AssertionError, "beta1_power"):
    slot_status.assert_consumed()
  self.assertEqual(12., self.evaluate(new_root.var))
  if context.executing_eagerly():
    # Slot variables are only created with restoring initializers when
    # executing eagerly.
    self.assertEqual(14., self.evaluate(
        new_root.optimizer.get_slot(name="m", var=new_root.var)))
    new_root.optimizer.minimize(new_root.var.read_value)
  else:
    self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                  None)
    train_op = new_root.optimizer.minimize(new_root.var)
    # The slot variable now exists; restore() didn't create it, but we should
    # now have a restore op for it.
    self.evaluate(slot_status.restore_ops)
    self.assertEqual(14., self.evaluate(
        new_root.optimizer.get_slot(name="m", var=new_root.var)))
    self.evaluate(train_op)
  slot_status.assert_consumed()
def testNumberedPath(self):
  """A nested variable is keyed by its attribute path from the root."""
  root = checkpointable.Checkpointable()
  leaf = checkpointable.Checkpointable()
  root.leaf = leaf
  checkpointable_utils.add_variable(leaf, name="v", shape=[])
  serialized_variables, _ = checkpointable_utils._serialize_object_graph(root)
  # Exactly one variable is expected; unpacking fails otherwise.
  [checkpoint_key] = serialized_variables.keys()
  self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", checkpoint_key)
def testLocalNameValidation(self):
  """Dependency names starting with "." are escaped in checkpoint keys."""
  root = checkpointable.Checkpointable()
  child = checkpointable.Checkpointable()
  # The leading dot is escaped (doubled). A user-chosen ".ATTRIBUTES" name
  # therefore cannot collide with the reserved ".ATTRIBUTES" path component.
  root._track_checkpointable(child, name=".ATTRIBUTES")
  checkpointable_utils.add_variable(checkpointable=child, name="a", shape=[])
  serialized_variables, _ = checkpointable_utils._serialize_object_graph(root)
  [checkpoint_key] = serialized_variables.keys()
  self.assertEqual(checkpoint_key, "..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE")
def testNoDependency(self):
  """Attributes wrapped in NoDependency are stored but not tracked."""
  root = checkpointable.Checkpointable()
  tracked = checkpointable.Checkpointable()
  root.hasdep = tracked
  untracked = checkpointable.Checkpointable()
  root.nodep = checkpointable.NoDependency(untracked)
  dependencies = root._checkpoint_dependencies
  # Only `hasdep` is recorded as a checkpoint dependency...
  self.assertEqual(1, len(dependencies))
  self.assertIs(dependencies[0].ref, root.hasdep)
  # ...yet both attributes read back as the plain objects that were assigned.
  self.assertIs(root.hasdep, tracked)
  self.assertIs(root.nodep, untracked)
def testMultipleAssignment(self):
  """Re-tracking an existing dependency name requires overwrite=True."""
  root = checkpointable.Checkpointable()
  root.leaf = checkpointable.Checkpointable()
  # Re-assigning the same object under the same name does not raise.
  root.leaf = root.leaf
  replacement = checkpointable.Checkpointable()
  with self.assertRaises(ValueError):
    root._track_checkpointable(replacement, name="leaf")
  # No error here. __setattr__ is overridden and must stay backward
  # compatible, so plain attribute assignment cannot refuse the new object.
  root.leaf = replacement
  root._track_checkpointable(replacement, name="leaf", overwrite=True)
def testNames(self):
  """UniqueNameTracker de-duplicates names so every variable round-trips."""
  checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
  x1 = resource_variable_ops.ResourceVariable(2.)
  x2 = resource_variable_ops.ResourceVariable(3.)
  x3 = resource_variable_ops.ResourceVariable(4.)
  y = resource_variable_ops.ResourceVariable(5.)
  slots = containers.UniqueNameTracker()
  # "x" is requested twice and "x_1" once. The expected attribute names below
  # show how the tracker resolves those collisions.
  for tracked_var, requested_name in ((x1, "x"), (x2, "x"), (x3, "x_1"),
                                      (y, "y")):
    slots.track(tracked_var, requested_name)
  self.evaluate(tuple(v.initializer for v in (x1, x2, x3, y)))
  save_root = checkpointable_utils.Checkpoint(slots=slots)
  save_path = save_root.save(checkpoint_prefix)

  restore_slots = checkpointable.Checkpointable()
  restore_root = checkpointable_utils.Checkpoint(slots=restore_slots)
  status = restore_root.restore(save_path)
  expected_values = (("x", 2.), ("x_1", 3.), ("x_1_1", 4.), ("y", 5.))
  for attribute, _ in expected_values:
    setattr(restore_slots, attribute,
            resource_variable_ops.ResourceVariable(0.))
  status.assert_consumed().run_restore_ops()
  for attribute, expected in expected_values:
    self.assertEqual(expected,
                     self.evaluate(getattr(restore_slots, attribute)))
def _get_checkpoint_name(self, name):
  """Return the checkpoint key of a variable called `name` on a fresh root.

  Also checks that the key, prefixed with "root/", can be entered as a name
  scope.
  """
  holder = checkpointable.Checkpointable()
  checkpointable_utils.add_variable(
      holder, name=name, shape=[1, 2], dtype=dtypes.float64)
  serialized_variables, _ = checkpointable_utils._serialize_object_graph(
      holder)
  [checkpoint_key] = serialized_variables.keys()
  # Sanity check: the prefixed key must be usable as an op name.
  with ops.name_scope("root/" + checkpoint_key):
    pass
  return checkpoint_key
def testDependencyLoop(self):
  """Objects that reference each other in a cycle save and restore."""
  # Note: this test creates garbage during eager execution because it
  # purposefully creates a reference cycle.
  saved_first = checkpointable.Checkpointable()
  saved_second = checkpointable.Checkpointable()
  saved_first.second = saved_second
  saved_second.first = saved_first
  saved_first.v = checkpointable_utils.add_variable(
      saved_first, "v1", initializer=[3., 1., 4.])
  saved_second.v = checkpointable_utils.add_variable(
      saved_second, "v2", initializer=[1., 1., 2., 3.])
  self.evaluate(checkpointable_utils.gather_initializers(saved_first))
  checkpoint_directory = self.get_temp_dir()
  save_path = checkpointable_utils.CheckpointableSaver(saved_first).save(
      os.path.join(checkpoint_directory, "ckpt"))

  # Deferred loading: restore() runs before the variables exist.
  loaded_first = checkpointable.Checkpointable()
  status = checkpointable_utils.CheckpointableSaver(
      loaded_first).restore(save_path)
  loaded_second = checkpointable.Checkpointable()
  loaded_first.second = loaded_second
  loaded_second.first = loaded_first

  def assert_checkpoint_values():
    self.assertAllEqual([3., 1., 4.], self.evaluate(loaded_first.v))
    self.assertAllEqual([1., 1., 2., 3.], self.evaluate(loaded_second.v))

  with self.assertRaises(AssertionError):
    status.assert_consumed()
  loaded_first.v = checkpointable_utils.add_variable(
      loaded_first, "v1", shape=[3])
  loaded_second.v = checkpointable_utils.add_variable(
      loaded_second, "v2", shape=[4])
  status.assert_consumed()
  status.run_restore_ops()
  assert_checkpoint_values()

  # Immediate loading: the variables already exist and hold other values.
  self.evaluate(loaded_first.v.assign([2., 7., 1.]))
  self.assertAllEqual([2., 7., 1.], self.evaluate(loaded_first.v))
  self.evaluate(loaded_second.v.assign([2., 7., 1., 8.]))
  self.assertAllEqual([2., 7., 1., 8.], self.evaluate(loaded_second.v))
  checkpointable_utils.CheckpointableSaver(loaded_first).restore(
      save_path).assert_consumed().run_restore_ops()
  assert_checkpoint_values()
def testOverlappingRestores(self):
  """When two restores target one object, the last restore() call wins."""
  checkpoint_directory = self.get_temp_dir()
  save_root = checkpointable.Checkpointable()
  save_root.dep = checkpointable.Checkpointable()
  save_root.dep.var = checkpointable_utils.add_variable(
      save_root.dep, name="var", initializer=0.)
  self.evaluate(state_ops.assign(save_root.dep.var, 12.))
  first_path = checkpointable_utils.save(
      os.path.join(checkpoint_directory, "first"), save_root)
  self.evaluate(state_ops.assign(save_root.dep.var, 13.))
  second_path = checkpointable_utils.save(
      os.path.join(checkpoint_directory, "second"), save_root)

  def new_dependency():
    """Build an object with a scalar "var" to receive restored values."""
    dependency = checkpointable.Checkpointable()
    dependency.var = checkpointable_utils.add_variable(
        dependency, name="var", shape=[])
    return dependency

  # Order 1: restore "first" (12.), then "second" (13.).
  first_target = checkpointable.Checkpointable()
  second_target = checkpointable.Checkpointable()
  first_status = checkpointable_utils.restore(first_path, first_target)
  second_status = checkpointable_utils.restore(second_path, second_target)
  shared_dep = new_dependency()
  first_target.dep = shared_dep
  first_status.assert_consumed()
  self.evaluate(first_status.restore_ops)
  self.assertEqual([], second_status.restore_ops)
  self.assertEqual(12., self.evaluate(shared_dep.var))
  second_target.dep = shared_dep
  second_status.assert_consumed()
  self.evaluate(second_status.restore_ops)
  self.assertEqual(13., self.evaluate(shared_dep.var))

  # Order 2: restore "second" first, then "first". The final value comes from
  # the later restore() call, not from the later attachment.
  first_target = checkpointable.Checkpointable()
  second_target = checkpointable.Checkpointable()
  second_status = checkpointable_utils.restore(second_path, second_target)
  first_status = checkpointable_utils.restore(first_path, first_target)
  shared_dep = new_dependency()
  first_target.dep = shared_dep
  first_status.assert_consumed()
  self.assertEqual([], second_status.restore_ops)
  self.evaluate(first_status.restore_ops)
  self.assertEqual(12., self.evaluate(shared_dep.var))
  second_target.dep = shared_dep
  second_status.assert_consumed()
  self.evaluate(second_status.restore_ops)
  self.assertEqual(12., self.evaluate(shared_dep.var))
def testShapeDtype(self):
  """add_variable honors `dtype`, both with and without an explicit shape."""
  root = checkpointable.Checkpointable()
  scalar_var = checkpointable_utils.add_variable(
      root, name="v1", initializer=3., dtype=dtypes.float64)
  self.assertEqual(dtypes.float64, scalar_var.dtype)
  vector_var = checkpointable_utils.add_variable(
      root, name="v2", shape=[3], initializer=init_ops.ones_initializer,
      dtype=dtypes.float64)
  self.assertEqual(dtypes.float64, vector_var.dtype)
  self.assertAllEqual([1., 1., 1.], self.evaluate(vector_var))
def testAmbiguousLoad(self):
  """A single saved object may not be split into two objects on load."""
  checkpoint_directory = self.get_temp_dir()
  saved_root = checkpointable.Checkpointable()
  saved_root.dep_one = checkpointable.Checkpointable()
  saved_root.dep_two = checkpointable.Checkpointable()
  # `shared` is reachable through two different paths in the saved graph.
  shared = checkpointable.Checkpointable()
  saved_root.dep_one.dep_three = shared
  saved_root.dep_two.dep_three = shared
  checkpointable_utils.add_variable(shared, name="var", initializer=0.)
  self.evaluate(checkpointable_utils.gather_initializers(saved_root))
  save_path = checkpointable_utils.CheckpointableSaver(saved_root).save(
      os.path.join(checkpoint_directory, "ckpt"))

  loaded_root = checkpointable.Checkpointable()
  checkpointable_utils.CheckpointableSaver(loaded_root).restore(save_path)
  loaded_root.dep_one = checkpointable.Checkpointable()
  loaded_root.dep_two = checkpointable.Checkpointable()
  loaded_root.dep_one.dep_three = checkpointable.Checkpointable()
  # Binding a second, distinct object to the shared checkpoint node fails.
  with self.assertRaisesRegexp(AssertionError,
                               "resolved to different objects"):
    loaded_root.dep_two.dep_three = checkpointable.Checkpointable()
def testManySavesGraph(self):
  """Only the first save adds ops; later saves reuse them."""
  with context.graph_mode():
    graph = ops.Graph()
    with graph.as_default(), self.test_session(graph):
      checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
      tracked = checkpointable.Checkpointable()
      tracked.var = variable_scope.get_variable(name="v", initializer=0.)
      tracked.opt = adam.AdamOptimizer(0.1)
      tracked.opt.minimize(tracked.var.read_value())
      self.evaluate(checkpointable_utils.gather_initializers(tracked))
      saver = checkpointable_utils.CheckpointableSaver(tracked)
      saver.save(checkpoint_prefix)
      ops_after_first_save = graph.get_operations()
      saver.save(checkpoint_prefix)
      self.assertEqual(ops_after_first_save, graph.get_operations())
def testManyRestoresGraph(self):
  """Only the first restore adds ops; repeated restores reuse them."""
  with context.graph_mode():
    graph = ops.Graph()
    with graph.as_default(), self.test_session(graph):
      checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
      tracked = checkpointable.Checkpointable()
      tracked.var = variable_scope.get_variable(name="v", initializer=0.)
      tracked.opt = CheckpointableAdam(0.1)
      tracked.opt.minimize(tracked.var.read_value())
      self.evaluate(variables.global_variables_initializer())
      save_path = checkpointable_utils.save(
          checkpoint_prefix, root_checkpointable=tracked)
      checkpointable_utils.restore(save_path, root_checkpointable=tracked)
      ops_after_first_restore = graph.get_operations()
      checkpointable_utils.restore(save_path, root_checkpointable=tracked)
      self.assertEqual(ops_after_first_restore, graph.get_operations())
def testDeferredSlotRestoration(self):
  """Slot variables are restored once their optimizer creates them.

  Two checkpoints are written from the same root:
    * "no_slots": saved before the optimizer becomes a dependency.
    * "with_slots": saved after it does.
  Both are restored into a fresh root before any of its variables exist.
  Values then arrive as the variable and the Adam "m" slot are created.
  """
  checkpoint_directory = self.get_temp_dir()
  saved_root = checkpointable.Checkpointable()
  saved_root.var = checkpointable_utils.add_variable(
      saved_root, name="var", initializer=0.)
  saved_optimizer = adam.AdamOptimizer(0.1)
  if not context.executing_eagerly():
    train_op = saved_optimizer.minimize(saved_root.var)
    # `saved_optimizer` is not yet a dependency of `saved_root`, so group
    # them in a one-off Checkpoint to make sure the slot variables for
    # `saved_root.var` get initialized too.
    self.evaluate(
        checkpointable_utils.gather_initializers(
            checkpointable_utils.Checkpoint(
                root=saved_root, optimizer=saved_optimizer)))
    self.evaluate(train_op)
  else:
    saved_optimizer.minimize(saved_root.var.read_value)

  def save_under(name):
    return checkpointable_utils.CheckpointableSaver(saved_root).save(
        os.path.join(checkpoint_directory, name))

  self.evaluate(state_ops.assign(saved_root.var, 12.))
  no_slots_path = save_under("no_slots")
  saved_root.optimizer = saved_optimizer
  self.evaluate(state_ops.assign(saved_root.var, 13.))
  self.evaluate(state_ops.assign(
      saved_optimizer.get_slot(name="m", var=saved_root.var), 14.))
  slots_path = save_under("with_slots")

  fresh_root = checkpointable.Checkpointable()
  # Load the slot-containing checkpoint (deferred), then immediately overwrite
  # the non-slot variable (also deferred).
  slot_status = checkpointable_utils.CheckpointableSaver(
      fresh_root).restore(slots_path)
  no_slot_status = checkpointable_utils.CheckpointableSaver(
      fresh_root).restore(no_slots_path)
  with self.assertRaises(AssertionError):
    no_slot_status.assert_consumed()
  fresh_root.var = checkpointable_utils.add_variable(
      fresh_root, name="var", shape=[])
  no_slot_status.assert_consumed()
  no_slot_status.run_restore_ops()
  self.assertEqual(12., self.evaluate(fresh_root.var))
  fresh_root.optimizer = adam.AdamOptimizer(0.1)
  with self.assertRaisesRegexp(AssertionError, "beta1_power"):
    slot_status.assert_consumed()
  self.assertEqual(12., self.evaluate(fresh_root.var))

  def m_slot():
    return fresh_root.optimizer.get_slot(name="m", var=fresh_root.var)

  if context.executing_eagerly():
    # Slot variables are only created with restoring initializers when
    # executing eagerly.
    self.assertEqual(14., self.evaluate(m_slot()))
    fresh_root.optimizer.minimize(fresh_root.var.read_value)
  else:
    self.assertIs(m_slot(), None)
    train_op = fresh_root.optimizer.minimize(fresh_root.var)
    # The slot variable now exists. restore() didn't create it, but a restore
    # op for it should be available now.
    slot_status.run_restore_ops()
    self.assertEqual(14., self.evaluate(m_slot()))
    self.evaluate(train_op)
  slot_status.assert_consumed()