Example #1
0
 def testNetworkSaveRestoreAlreadyBuilt(self):
   """Saving needs a built Network; once built, save/modify/load round-trips."""
   model = MyNetwork(name="abcd")
   # The Network has not been called yet, so saving must be refused.
   with self.assertRaisesRegexp(
       ValueError, "Attempt to save the Network before it was first called"):
     network.save_network_checkpoint(model, self.get_temp_dir())
   model(constant_op.constant([[2.0]]))
   kernel = model.trainable_variables[0]
   self.evaluate(kernel.assign([[17.0]]))
   # Exercise saving both without and with an explicit global step.
   for step in (None, 10):
     self._save_modify_load_network_built(model, global_step=step)
Example #2
0
  def testRestoreIntoSubNetwork(self):
    """Restores whole-model and sub-Network checkpoints into a Parent Network.

    Covers three things:
    - deferred restoration (restoring before the first call);
    - later restorations overwriting earlier ones, in both the deferred and
      the immediate case;
    - the NotFoundError raised when a checkpoint does not match the Network.
    """

    class Parent(network.Network):

      def __init__(self, name=None):
        super(Parent, self).__init__(name=name)
        self.first = self.track_layer(MyNetwork())
        self.second = self.track_layer(MyNetwork())

      def call(self, x):
        return self.first(self.second(x))

    one = constant_op.constant([[3.]])
    # Checkpoint of a whole Parent with a distinct value in each sub-Network.
    whole_model_saver = Parent()
    whole_model_saver(one)
    self.evaluate(whole_model_saver.variables[0].assign([[15.]]))
    self.evaluate(whole_model_saver.variables[1].assign([[16.]]))
    whole_model_checkpoint = network.save_network_checkpoint(
        whole_model_saver, self.get_temp_dir())

    # Checkpoint of a standalone MyNetwork, later restored into a sub-Network.
    save_from = MyNetwork()
    save_from(one)
    self.evaluate(save_from.variables[0].assign([[5.]]))
    checkpoint = network.save_network_checkpoint(save_from, self.get_temp_dir())
    # Whole model first, then the sub-Network: the sub-Network value wins for
    # `first`, while `second` keeps the whole-model value.
    save_into_parent = Parent()
    network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint)
    network.restore_network_checkpoint(save_into_parent.first, checkpoint)
    # deferred loading multiple times is fine
    network.restore_network_checkpoint(save_into_parent.first, checkpoint)
    save_into_parent(one)  # deferred loading
    self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[0]))
    self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1]))

    # Try again with the opposite ordering, and we should get different results
    # (deferred restoration should happen the same way non-deferred happens,
    # with later restorations overwriting older ones).
    save_into_parent = Parent()
    # Sub-Network first, then the whole model. Both restorations stay pending
    # until the first call below.
    network.restore_network_checkpoint(save_into_parent.first, checkpoint)
    network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint)
    save_into_parent(one)  # deferred loading
    # We've overwritten the sub-Network restore.
    self.assertAllEqual([[15.]], self.evaluate(save_into_parent.variables[0]))
    self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1]))

    # Immediate (non-deferred) restore into the already-built `second`.
    self.evaluate(save_into_parent.variables[0].assign([[3.]]))
    self.evaluate(save_into_parent.variables[1].assign([[4.]]))
    network.restore_network_checkpoint(save_into_parent.second, checkpoint)
    self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[1]))
    with self.assertRaisesRegexp(errors_impl.NotFoundError,
                                 "not found in checkpoint"):
      # The checkpoint is incompatible.
      network.restore_network_checkpoint(save_into_parent, checkpoint)
Example #3
0
  def testDefaultMapCollisionErrors(self):
    """Default checkpoint name mapping must reject colliding variable names.

    `first` is a Dense layer named "dense". It is built outside of Parent and
    then tracked by Parent alongside Parent's own Dense layer. Per the
    expected error messages, the default name mapping for such a Parent
    produces a naming conflict both when saving and when restoring.
    """

    one = constant_op.constant([[1.]])
    # Built outside of any Network, then shared into Parent below.
    first = core.Dense(1, name="dense", use_bias=False)
    first(one)

    class Parent(network.Network):

      def __init__(self, name=None):
        super(Parent, self).__init__(name=name)
        self.first = self.track_layer(first)
        self.second = self.track_layer(core.Dense(1, use_bias=False))

      def call(self, x):
        return self.first(self.second(x))

    make_checkpoint = Parent()
    make_checkpoint(one)
    self.evaluate(make_checkpoint.variables[0].assign([[2.]]))
    self.evaluate(make_checkpoint.variables[1].assign([[3.]]))
    with self.assertRaisesRegexp(
        ValueError,
        ("The default checkpoint variable name mapping strategy for Network "
         "'parent' resulted in a naming conflict.")):
      network.save_network_checkpoint(make_checkpoint, self.get_temp_dir())

    # A Network with a single Dense layer can be saved without conflict.
    class Compatible(network.Network):

      def __init__(self, name=None):
        super(Compatible, self).__init__(name=name)
        self.first = self.track_layer(core.Dense(1, use_bias=False))

      def call(self, x):
        return self.first(x)

    successful_checkpoint = Compatible()
    successful_checkpoint(one)
    self.evaluate(successful_checkpoint.variables[0].assign([[-1.]]))
    checkpoint_path = network.save_network_checkpoint(
        successful_checkpoint, self.get_temp_dir())
    # Restoring into a second Parent ("parent_1") hits the conflict at
    # restore time instead.
    load_checkpoint = Parent()
    load_checkpoint(one)
    with self.assertRaisesRegexp(
        ValueError,
        ("The default checkpoint variable name mapping strategy for Network "
         "'parent_1' resulted in a naming conflict.")):
      network.restore_network_checkpoint(load_checkpoint, checkpoint_path)
Example #4
0
 def testSaveRestoreDefaultGlobalStep(self):
   """With no explicit global_step, the default step appears in the path."""
   model = MyNetwork(name="abcd")
   model(constant_op.constant([[2.0]]))
   self.evaluate(model.variables[0].assign([[3.]]))
   step_variable = training_util.get_or_create_global_step()
   self.evaluate(step_variable.assign(4242))
   # global_step is not passed, so the returned path should carry the value
   # of the default global step variable.
   self.assertIn(
       "abcd-4242",
       network.save_network_checkpoint(model, self.get_temp_dir()))
Example #5
0
  def testCustomMapCollisionErrors(self):
    """A map_func that sends several variables to one name must be rejected.

    This is checked in three places:
    - on save;
    - on a deferred restore, where the error surfaces at the first call;
    - on an immediate restore into an already-built Network.
    """

    def _everything_is_foo(unused_name):
      return "foo"

    class Parent(network.Network):

      def __init__(self, name=None):
        super(Parent, self).__init__(name=name)
        self.first = self.track_layer(MyNetwork())
        self.second = self.track_layer(MyNetwork())

      def call(self, x):
        return self.first(self.second(x))

    make_checkpoint = Parent()
    one = constant_op.constant([[1.]])
    make_checkpoint(one)
    self.evaluate(make_checkpoint.variables[0].assign([[2.]]))
    self.evaluate(make_checkpoint.variables[1].assign([[3.]]))
    # Saving the whole Parent maps both of its variables to "foo".
    with self.assertRaisesRegexp(
        ValueError,
        "The map_func passed to save_network_checkpoint for the Network "
        "'parent' resulted in two variables named 'foo'"):
      network.save_network_checkpoint(
          make_checkpoint, self.get_temp_dir(), map_func=_everything_is_foo)
    # Saving just one sub-Network with the same map_func succeeds.
    checkpoint = network.save_network_checkpoint(
        network=make_checkpoint.first,
        save_path=self.get_temp_dir(),
        map_func=_everything_is_foo)

    # Deferred restore: the collision is only reported once `loader` is built.
    loader = Parent()
    network.restore_network_checkpoint(
        loader, checkpoint, map_func=_everything_is_foo)
    with self.assertRaisesRegexp(
        ValueError,
        ("The map_func passed to restore_network_checkpoint for the Network"
         " 'parent_1' resulted in two variables named 'foo'")):
      loader(one)

    # Immediate restore into a Network that is already built.
    loader = Parent()
    loader(one)
    with self.assertRaisesRegexp(
        ValueError,
        ("The map_func passed to restore_network_checkpoint for the Network"
         " 'parent_2' resulted in two variables named 'foo'")):
      network.restore_network_checkpoint(
          loader, checkpoint, map_func=_everything_is_foo)
Example #6
0
 def testNetworkSaveAndRestoreIntoUnbuilt(self):
   """A restore issued before the first call takes effect once built."""
   checkpoint_dir = self.get_temp_dir()
   source = MyNetwork()
   inputs = constant_op.constant([[2.0]])
   source(inputs)
   self.evaluate(source.trainable_variables[0].assign([[17.0]]))
   checkpoint_path = network.save_network_checkpoint(source, checkpoint_dir)
   # Restore into a fresh Network that has not been called yet.
   target = MyNetwork()
   network.restore_network_checkpoint(target, checkpoint_path)
   self.assertAllEqual(self.evaluate(source(inputs)),
                       self.evaluate(target(inputs)))
   # Separate variable objects that nonetheless hold the same value.
   source_kernel = source.variables[0]
   target_kernel = target.variables[0]
   self.assertIsNot(source_kernel, target_kernel)
   self.assertAllEqual(self.evaluate(source_kernel),
                       self.evaluate(target_kernel))
Example #7
0
 def _save_modify_load_network_built(self, net, global_step=None):
   """Saves `net`, perturbs its variables, restores, and checks the output.

   Restores twice: once from the checkpoint directory and once from the
   explicit path returned by save_network_checkpoint.

   Args:
     net: a Network that has already been called (built).
     global_step: forwarded unchanged to save_network_checkpoint.
   """
   save_dir = self.get_temp_dir()
   explicit_path = network.save_network_checkpoint(
       network=net, save_path=save_dir, global_step=global_step)
   probe = constant_op.constant([[42.0]])

   def forward():
     return self.evaluate(net(probe))

   def shift_variables(delta):
     for variable in net.variables:
       self.evaluate(variable.assign(variable + delta))

   expected = forward()
   shift_variables(1.)
   # Make sure the perturbation actually changed the output.
   self.assertGreater(forward(), expected)
   # Either the returned explicit checkpoint path or the directory should work.
   network.restore_network_checkpoint(net, save_path=save_dir)
   self.assertAllEqual(expected, forward())
   shift_variables(2.)
   network.restore_network_checkpoint(net, save_path=explicit_path)
   self.assertAllEqual(expected, forward())
Example #8
0
 def testVariableScopeStripping(self):
   """A Network built under variable scopes restores into an unscoped one."""
   with variable_scope.variable_scope("scope1"):
     with variable_scope.variable_scope("scope2"):
       scoped = MyNetwork()
   scoped(constant_op.constant([[2.0]]))
   self.evaluate(scoped.variables[0].assign([[42.]]))
   self.assertEqual(scoped.name, "scope1/scope2/my_network")
   self.assertStartsWith(
       expected_start="scope1/scope2/my_network/dense/",
       actual=scoped.trainable_weights[0].name)
   checkpoint_path = network.save_network_checkpoint(
       scoped, self.get_temp_dir())
   # The returned save path contains the scoped name with "/" as "_".
   self.assertIn("scope1_scope2_my_network", checkpoint_path)
   unscoped = MyNetwork()

   def assert_restored():
     self.assertAllEqual([[42.]],
                         self.evaluate(unscoped.variables[0]))

   # Delayed restoration
   network.restore_network_checkpoint(unscoped, checkpoint_path)
   unscoped(constant_op.constant([[1.0]]))
   assert_restored()
   self.evaluate(unscoped.variables[0].assign([[-1.]]))
   # Immediate restoration
   network.restore_network_checkpoint(unscoped, checkpoint_path)
   assert_restored()
Example #9
0
  def testLoadIntoUnbuiltSharedLayer(self):
    """Restores into a Network whose first layer is shared from another owner.

    The shared layer is unbuilt when the restore is requested, so its value is
    loaded via deferred restoration once `load_into` is called. That value is
    also visible through the layer's owning Network. A second scenario checks
    that restoring fails with a "garbage collected" error once the owner of
    the shared layer has been deleted.
    """

    # Owns a single Dense layer named "first_layer".
    class Owner(network.Network):

      def __init__(self, name=None):
        super(Owner, self).__init__(name=name)
        self.first = self.track_layer(core.Dense(
            1, name="first_layer", use_bias=False))

      def call(self, x):
        return self.first(x)

    # Not called yet, so its layer (shared with User below) has no variables.
    first_owner = Owner()

    # Tracks a layer passed in from elsewhere plus its own "second_layer".
    class User(network.Network):

      def __init__(self, use_layer, name=None):
        super(User, self).__init__(name=name)
        self.first = self.track_layer(use_layer)
        self.second = self.track_layer(core.Dense(
            1, name="second_layer", use_bias=False))

      def call(self, x):
        return self.second(self.first(x))

    # Same layer names and structure as User, but owns both layers itself;
    # used only to produce the checkpoint.
    class LikeUserButNotSharing(network.Network):

      def __init__(self, name=None):
        super(LikeUserButNotSharing, self).__init__(name=name)
        self.first = self.track_layer(core.Dense(
            1, name="first_layer", use_bias=False))
        self.second = self.track_layer(core.Dense(
            1, name="second_layer", use_bias=False))

      def call(self, x):
        return self.second(self.first(x))

    checkpoint_creator = LikeUserButNotSharing(name="checkpoint_creator")
    one = constant_op.constant([[1.0]])
    checkpoint_creator(one)
    self.assertEqual(2, len(checkpoint_creator.variables))
    self.evaluate(checkpoint_creator.variables[0].assign([[5.]]))
    self.evaluate(checkpoint_creator.variables[1].assign([[6.]]))
    # Re-map the variable names so that with default restore mapping we'll
    # attempt to restore into the unbuilt Layer.
    name_mapping = {
        "checkpoint_creator/first_layer/kernel": "owner/first_layer/kernel",
        "checkpoint_creator/second_layer/kernel": "second_layer/kernel",
    }
    save_path = network.save_network_checkpoint(
        checkpoint_creator,
        self.get_temp_dir(),
        map_func=lambda full_name: name_mapping[full_name])
    load_into = User(use_layer=first_owner.first)
    network.restore_network_checkpoint(load_into, save_path)
    # Nothing has been built yet, so the restore is still pending.
    self.assertEqual(0, len(first_owner.variables))
    # Calling load_into builds the shared layer and applies the restore.
    self.assertAllEqual(self.evaluate(checkpoint_creator(one)),
                        self.evaluate(load_into(one)))
    # The shared layer's variable is now also tracked by its owner.
    self.assertEqual(1, len(first_owner.variables))
    self.assertAllEqual([[5.]], self.evaluate(load_into.variables[0]))
    self.assertAllEqual([[6.]], self.evaluate(load_into.variables[1]))
    first_owner(one)
    self.assertAllEqual([[5.]], self.evaluate(first_owner.variables[0]))

    # Try again with a garbage collected parent.
    first_owner = Owner()
    load_into = User(use_layer=first_owner.first)
    del first_owner
    gc.collect()
    # Maps checkpoint names onto the second Owner/User instances. The
    # "owner_1"/"user_1" prefixes presumably match the auto-generated names of
    # these second instances; NOTE(review): confirm against Network naming.
    def _restore_map_func(original_name):
      if original_name.startswith("owner/"):
        return original_name.replace("owner/", "owner_1/")
      else:
        return "user_1/" + original_name
    with self.assertRaisesRegexp(ValueError, "garbage collected"):
      network.restore_network_checkpoint(
          load_into, save_path, map_func=_restore_map_func)