def testNetworkSaveRestoreAlreadyBuilt(self):
  """Checks saving fails before the first call and round-trips afterwards."""
  model = MyNetwork(name="abcd")
  unbuilt_message = "Attempt to save the Network before it was first called"
  # Saving is expected to be rejected until the Network has been called.
  with self.assertRaisesRegexp(ValueError, unbuilt_message):
    network.save_network_checkpoint(model, self.get_temp_dir())
  model(constant_op.constant([[2.0]]))
  kernel = model.trainable_variables[0]
  self.evaluate(kernel.assign([[17.0]]))
  # Exercise the round trip first without, then with, an explicit global step.
  for step in (None, 10):
    self._save_modify_load_network_built(model, global_step=step)
def testRestoreIntoSubNetwork(self):
  """Checks restoring a single-Network checkpoint into part of a larger one.

  Deferred (pre-build) restorations are requested in both orders to verify
  that a later restoration overwrites an earlier one, as it would when the
  restorations are not deferred.
  """

  class Parent(network.Network):
    """Tracks two MyNetwork layers and applies `second`, then `first`."""

    def __init__(self, name=None):
      super(Parent, self).__init__(name=name)
      self.first = self.track_layer(MyNetwork())
      self.second = self.track_layer(MyNetwork())

    def call(self, x):
      hidden = self.second(x)
      return self.first(hidden)

  sample = constant_op.constant([[3.]])

  # Checkpoint covering the whole two-layer model (values 15 and 16).
  full_model = Parent()
  full_model(sample)
  for position, value in enumerate(([[15.]], [[16.]])):
    self.evaluate(full_model.variables[position].assign(value))
  full_checkpoint = network.save_network_checkpoint(
      full_model, self.get_temp_dir())

  # Checkpoint covering a single standalone MyNetwork (value 5).
  single_model = MyNetwork()
  single_model(sample)
  self.evaluate(single_model.variables[0].assign([[5.]]))
  single_checkpoint = network.save_network_checkpoint(
      single_model, self.get_temp_dir())

  target = Parent()
  network.restore_network_checkpoint(target, full_checkpoint)
  # Requesting the same deferred load more than once is fine.
  for _ in range(2):
    network.restore_network_checkpoint(target.first, single_checkpoint)
  target(sample)  # deferred loading
  self.assertAllEqual([[5.]], self.evaluate(target.variables[0]))
  self.assertAllEqual([[16.]], self.evaluate(target.variables[1]))

  # Try again with the opposite ordering, and we should get different results
  # (deferred restoration should happen the same way non-deferred happens,
  # with later restorations overwriting older ones).
  target = Parent()
  network.restore_network_checkpoint(target.first, single_checkpoint)
  network.restore_network_checkpoint(target, full_checkpoint)
  target(sample)  # deferred loading
  # The whole-model restore has overwritten the sub-Network restore.
  self.assertAllEqual([[15.]], self.evaluate(target.variables[0]))
  self.assertAllEqual([[16.]], self.evaluate(target.variables[1]))

  # Now that `target` is built, restore straight into its second sub-Network.
  for position, value in enumerate(([[3.]], [[4.]])):
    self.evaluate(target.variables[position].assign(value))
  network.restore_network_checkpoint(target.second, single_checkpoint)
  self.assertAllEqual([[5.]], self.evaluate(target.variables[1]))
  with self.assertRaisesRegexp(errors_impl.NotFoundError,
                               "not found in checkpoint"):
    # The checkpoint is incompatible.
    network.restore_network_checkpoint(target, single_checkpoint)
def testDefaultMapCollisionErrors(self):
  """Checks the default name mapping reports collisions on save and restore.

  A `Parent` Network tracks a Dense layer that was created and built outside
  of it, alongside a Dense layer it creates itself. Both saving such a
  Network and restoring into it are expected to raise a naming-conflict
  `ValueError`.
  """
  one = constant_op.constant([[1.]])
  first = core.Dense(1, name="dense", use_bias=False)
  first(one)

  class Parent(network.Network):
    """Tracks the externally built `first` layer plus a new Dense layer."""

    def __init__(self, name=None):
      super(Parent, self).__init__(name=name)
      self.first = self.track_layer(first)
      self.second = self.track_layer(core.Dense(1, use_bias=False))

    def call(self, x):
      return self.first(self.second(x))

  make_checkpoint = Parent()
  make_checkpoint(one)
  self.evaluate(make_checkpoint.variables[0].assign([[2.]]))
  self.evaluate(make_checkpoint.variables[1].assign([[3.]]))
  with self.assertRaisesRegexp(
      ValueError,
      ("The default checkpoint variable name mapping strategy for Network "
       "'parent' resulted in a naming conflict.")):
    network.save_network_checkpoint(make_checkpoint, self.get_temp_dir())

  class Compatible(network.Network):
    """Single-Dense Network used to produce a checkpoint that saves cleanly."""

    def __init__(self, name=None):
      super(Compatible, self).__init__(name=name)
      self.first = self.track_layer(core.Dense(1, use_bias=False))

    def call(self, x):
      return self.first(x)

  successful_checkpoint = Compatible()
  successful_checkpoint(one)
  self.evaluate(successful_checkpoint.variables[0].assign([[-1.]]))
  checkpoint_path = network.save_network_checkpoint(
      successful_checkpoint, self.get_temp_dir())

  # This is the second `Parent` instance, hence the 'parent_1' in the error.
  load_checkpoint = Parent()
  load_checkpoint(one)
  with self.assertRaisesRegexp(
      ValueError,
      ("The default checkpoint variable name mapping strategy for Network "
       "'parent_1' resulted in a naming conflict.")):
    network.restore_network_checkpoint(load_checkpoint, checkpoint_path)
def testSaveRestoreDefaultGlobalStep(self):
  """Checks the default global step's value shows up in the save path."""
  model = MyNetwork(name="abcd")
  model(constant_op.constant([[2.0]]))
  kernel = model.variables[0]
  self.evaluate(kernel.assign([[3.]]))
  step_variable = training_util.get_or_create_global_step()
  self.evaluate(step_variable.assign(4242))
  # No global_step is passed, so the path should pick up the default one.
  checkpoint_path = network.save_network_checkpoint(
      model, self.get_temp_dir())
  self.assertIn("abcd-4242", checkpoint_path)
def testCustomMapCollisionErrors(self):
  """Checks a `map_func` that maps every name to "foo" is rejected.

  The collision is expected on save, on a deferred restore (surfacing when
  the Network is first called), and on an immediate restore.
  """

  class Parent(network.Network):
    """Tracks two MyNetwork layers and applies `second`, then `first`."""

    def __init__(self, name=None):
      super(Parent, self).__init__(name=name)
      self.first = self.track_layer(MyNetwork())
      self.second = self.track_layer(MyNetwork())

    def call(self, x):
      hidden = self.second(x)
      return self.first(hidden)

  source = Parent()
  sample = constant_op.constant([[1.]])
  source(sample)
  for position, value in enumerate(([[2.]], [[3.]])):
    self.evaluate(source.variables[position].assign(value))

  save_collision = (
      "The map_func passed to save_network_checkpoint for the Network "
      "'parent' resulted in two variables named 'foo'")
  with self.assertRaisesRegexp(ValueError, save_collision):
    network.save_network_checkpoint(
        source, self.get_temp_dir(), map_func=lambda n: "foo")
  # Saving only the `first` sub-Network with the same mapping is expected to
  # succeed; its checkpoint is reused for the restore checks below.
  foo_checkpoint = network.save_network_checkpoint(
      network=source.first,
      save_path=self.get_temp_dir(),
      map_func=lambda n: "foo")

  # Deferred restore: the error is raised once the Network gets called.
  deferred_target = Parent()
  network.restore_network_checkpoint(
      deferred_target, foo_checkpoint, map_func=lambda n: "foo")
  deferred_collision = (
      "The map_func passed to restore_network_checkpoint for the Network"
      " 'parent_1' resulted in two variables named 'foo'")
  with self.assertRaisesRegexp(ValueError, deferred_collision):
    deferred_target(sample)

  # Immediate restore: the Network is called first, so the restore call
  # itself raises.
  built_target = Parent()
  built_target(sample)
  immediate_collision = (
      "The map_func passed to restore_network_checkpoint for the Network"
      " 'parent_2' resulted in two variables named 'foo'")
  with self.assertRaisesRegexp(ValueError, immediate_collision):
    network.restore_network_checkpoint(
        built_target, foo_checkpoint, map_func=lambda n: "foo")
def testNetworkSaveAndRestoreIntoUnbuilt(self):
  """Checks restoring into a not-yet-called Network reproduces saved values."""
  checkpoint_dir = self.get_temp_dir()
  source_net = MyNetwork()
  sample = constant_op.constant([[2.0]])
  source_net(sample)
  self.evaluate(source_net.trainable_variables[0].assign([[17.0]]))
  checkpoint_path = network.save_network_checkpoint(
      source_net, checkpoint_dir)

  # With a pre-build restore we should have the same value.
  restored_net = MyNetwork()
  network.restore_network_checkpoint(restored_net, checkpoint_path)
  source_output = self.evaluate(source_net(sample))
  restored_output = self.evaluate(restored_net(sample))
  self.assertAllEqual(source_output, restored_output)

  # The two Networks hold distinct variable objects with equal contents.
  source_variable = source_net.variables[0]
  restored_variable = restored_net.variables[0]
  self.assertIsNot(source_variable, restored_variable)
  self.assertAllEqual(self.evaluate(source_variable),
                      self.evaluate(restored_variable))
def _save_modify_load_network_built(self, net, global_step=None):
  """Saves `net`, perturbs its variables, and checks restoring undoes that.

  The restore is checked twice: once from the checkpoint directory and once
  from the explicit path returned by `save_network_checkpoint`.

  Args:
    net: The Network to checkpoint and then restore in place.
    global_step: Forwarded to `save_network_checkpoint`.
  """
  save_dir = self.get_temp_dir()
  saved_path = network.save_network_checkpoint(
      network=net, save_path=save_dir, global_step=global_step)
  probe = constant_op.constant([[42.0]])
  baseline = self.evaluate(net(probe))

  # Shift every variable so the output moves away from the baseline.
  for variable in net.variables:
    self.evaluate(variable.assign(variable + 1.))
  perturbed = self.evaluate(net(probe))
  self.assertGreater(perturbed, baseline)

  # Either the returned explicit checkpoint path or the directory should work.
  network.restore_network_checkpoint(net, save_path=save_dir)
  self.assertAllEqual(baseline, self.evaluate(net(probe)))

  for variable in net.variables:
    self.evaluate(variable.assign(variable + 2.))
  network.restore_network_checkpoint(net, save_path=saved_path)
  self.assertAllEqual(baseline, self.evaluate(net(probe)))
def testVariableScopeStripping(self):
  """Checks a Network created inside variable scopes restores outside them."""
  with variable_scope.variable_scope("scope1"):
    with variable_scope.variable_scope("scope2"):
      scoped_net = MyNetwork()
  scoped_net(constant_op.constant([[2.0]]))
  self.evaluate(scoped_net.variables[0].assign([[42.]]))

  # The enclosing scopes show up in the Network and variable names.
  self.assertEqual(scoped_net.name, "scope1/scope2/my_network")
  kernel_name = scoped_net.trainable_weights[0].name
  self.assertStartsWith(
      expected_start="scope1/scope2/my_network/dense/",
      actual=kernel_name)
  checkpoint_path = network.save_network_checkpoint(
      scoped_net, self.get_temp_dir())
  self.assertIn("scope1_scope2_my_network", checkpoint_path)

  unscoped_net = MyNetwork()
  # Delayed restoration
  network.restore_network_checkpoint(unscoped_net, checkpoint_path)
  unscoped_net(constant_op.constant([[1.0]]))
  self.assertAllEqual([[42.]], self.evaluate(unscoped_net.variables[0]))

  self.evaluate(unscoped_net.variables[0].assign([[-1.]]))
  # Immediate restoration
  network.restore_network_checkpoint(unscoped_net, checkpoint_path)
  self.assertAllEqual([[42.]], self.evaluate(unscoped_net.variables[0]))
def testLoadIntoUnbuiltSharedLayer(self):
  """Checks restoring into a Network that uses another Network's layer.

  A `User` Network is given the `first` layer of an `Owner` Network that has
  not been called yet. The test restores a checkpoint into the `User`, checks
  the values once the layers are built, and then checks that restoring raises
  a "garbage collected" `ValueError` after the owner has been deleted.
  """

  class Owner(network.Network):
    """Creates and tracks the Dense layer that is later shared."""

    def __init__(self, name=None):
      super(Owner, self).__init__(name=name)
      self.first = self.track_layer(core.Dense(
          1, name="first_layer", use_bias=False))

    def call(self, x):
      return self.first(x)

  first_owner = Owner()

  class User(network.Network):
    """Tracks a layer passed in by the caller plus its own second layer."""

    def __init__(self, use_layer, name=None):
      super(User, self).__init__(name=name)
      self.first = self.track_layer(use_layer)
      self.second = self.track_layer(core.Dense(
          1, name="second_layer", use_bias=False))

    def call(self, x):
      return self.second(self.first(x))

  class LikeUserButNotSharing(network.Network):
    """Same two-layer structure as `User`, but creates both layers itself."""

    def __init__(self, name=None):
      super(LikeUserButNotSharing, self).__init__(name=name)
      self.first = self.track_layer(core.Dense(
          1, name="first_layer", use_bias=False))
      self.second = self.track_layer(core.Dense(
          1, name="second_layer", use_bias=False))

    def call(self, x):
      return self.second(self.first(x))

  checkpoint_creator = LikeUserButNotSharing(name="checkpoint_creator")
  one = constant_op.constant([[1.0]])
  checkpoint_creator(one)
  self.assertEqual(2, len(checkpoint_creator.variables))
  self.evaluate(checkpoint_creator.variables[0].assign([[5.]]))
  self.evaluate(checkpoint_creator.variables[1].assign([[6.]]))
  # Re-map the variable names so that with default restore mapping we'll
  # attempt to restore into the unbuilt Layer.
  name_mapping = {
      "checkpoint_creator/first_layer/kernel": "owner/first_layer/kernel",
      "checkpoint_creator/second_layer/kernel": "second_layer/kernel",
  }
  save_path = network.save_network_checkpoint(
      checkpoint_creator, self.get_temp_dir(),
      map_func=lambda full_name: name_mapping[full_name])
  load_into = User(use_layer=first_owner.first)
  network.restore_network_checkpoint(load_into, save_path)
  # The owner has not been called yet, so it has no variables at this point.
  self.assertEqual(0, len(first_owner.variables))
  self.assertAllEqual(self.evaluate(checkpoint_creator(one)),
                      self.evaluate(load_into(one)))
  # Calling `load_into` built the shared layer, which the owner now reports.
  self.assertEqual(1, len(first_owner.variables))
  self.assertAllEqual([[5.]], self.evaluate(load_into.variables[0]))
  self.assertAllEqual([[6.]], self.evaluate(load_into.variables[1]))
  first_owner(one)
  self.assertAllEqual([[5.]], self.evaluate(first_owner.variables[0]))

  # Try again with a garbage collected parent.
  first_owner = Owner()
  load_into = User(use_layer=first_owner.first)
  # Drop the only local reference to the new owner, then force a collection.
  del first_owner
  gc.collect()

  def _restore_map_func(original_name):
    # Redirect names to the second Owner/User instances created above
    # (presumably auto-named "owner_1" and "user_1" -- verify against
    # Network naming).
    if original_name.startswith("owner/"):
      return original_name.replace("owner/", "owner_1/")
    else:
      return "user_1/" + original_name

  with self.assertRaisesRegexp(ValueError, "garbage collected"):
    network.restore_network_checkpoint(
        load_into, save_path, map_func=_restore_map_func)