Example #1
  def test_checkpointable_save_restore(self):

    def _templated():
      v = variable_scope.get_variable(
          "v", shape=[1], initializer=init_ops.zeros_initializer())
      v2 = variable_scope.get_variable(
          "v2", shape=[1], initializer=init_ops.zeros_initializer())
      return v, v + 1., v2

    save_template = template.make_template("s1", _templated)
    save_root = checkpointable_utils.Checkpoint(my_template=save_template)
    v1_save, _, v2_save = save_template()
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    load_template = template.make_template("s2", _templated)
    load_root = checkpointable_utils.Checkpoint(my_template=load_template)
    status = load_root.restore(save_path)
    var, var_plus_one, var2 = load_template()
    self.assertEqual(2, len(load_template._checkpoint_dependencies))
    self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
    self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))
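Stripped of test scaffolding, the pattern above is a plain object-based save/restore round trip. A minimal sketch, assuming a later TensorFlow release where this API is exposed as tf.train.Checkpoint (the checkpointable_utils.Checkpoint used throughout these tests is its contrib-era predecessor):

import os
import tensorflow as tf

v = tf.Variable([0.])
checkpoint = tf.train.Checkpoint(v=v)
v.assign([12.])
save_path = checkpoint.save(os.path.join("/tmp/demo", "ckpt"))
v.assign([0.])
status = checkpoint.restore(save_path)
status.assert_consumed()  # every value in the file matched an object
print(v.numpy())  # [12.] when executing eagerly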
Example #2
  def testSaveRestore(self):
    network = MyNetwork()
    optimizer = adam.AdamOptimizer(0.001)
    root_checkpointable = checkpointable_utils.Checkpoint(
        optimizer=optimizer, network=network)
    input_value = constant_op.constant([[3.]])
    if context.in_eager_mode():
      optimizer.minimize(
          lambda: network(input_value))
    else:
      train_op = optimizer.minimize(network(input_value))
      # TODO(allenl): Make initialization more pleasant when graph building.
      root_checkpointable.save_counter  # pylint: disable=pointless-statement
      self.evaluate(checkpointable_utils.gather_initializers(
          root_checkpointable))
      self.evaluate(train_op)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(state_ops.assign(network._named_dense.variables[1], [42.]))
    m_bias_slot = optimizer.get_slot(network._named_dense.variables[1], "m")
    self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
    save_path = root_checkpointable.save(file_prefix=prefix)
    self.evaluate(state_ops.assign(network._named_dense.variables[1], [43.]))
    self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3))
    optimizer_variables = self.evaluate(optimizer.variables())
    self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
    # Immediate restoration
    status = root_checkpointable.restore(save_path=save_path).assert_consumed()
    status.run_restore_ops()
    self.assertAllEqual([42.], self.evaluate(network._named_dense.variables[1]))
    self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
    self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
    if context.in_graph_mode():
      return  # Restore-on-create is only supported when executing eagerly
    on_create_network = MyNetwork()
    on_create_optimizer = adam.AdamOptimizer(0.001)
    on_create_root = checkpointable_utils.Checkpoint(
        optimizer=on_create_optimizer, network=on_create_network)
    # Deferred restoration
    status = on_create_root.restore(save_path=save_path)
    on_create_network(constant_op.constant([[3.]]))  # create variables
    self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
    self.assertAllEqual([42.],
                        self.evaluate(
                            on_create_network._named_dense.variables[1]))
    on_create_m_bias_slot = on_create_optimizer.get_slot(
        on_create_network._named_dense.variables[1], "m")
    # Optimizer slot variables are created when the original variable is
    # restored.
    self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
    # beta1_power and beta2_power haven't been created yet, but everything
    # else matches.
    self.assertAllEqual(optimizer_variables[2:],
                        self.evaluate(on_create_optimizer.variables()))
    # _create_slots also creates Adam's non-slot variables (the beta
    # accumulators), letting their deferred restorations complete.
    on_create_optimizer._create_slots(
        [resource_variable_ops.ResourceVariable([1.])])
    status.assert_consumed()
    beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
    self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
    self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power))
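The second half of Example #2 exercises deferred ("restore-on-create") restoration: restore() is called before any variables exist, and each saved value is applied as the matching variable is created. A minimal eager sketch of the same behavior, assuming tf.train.Checkpoint and a Keras layer:

import tensorflow as tf

source = tf.train.Checkpoint(layer=tf.keras.layers.Dense(1))
source.layer(tf.constant([[3.]]))  # build the layer, creating kernel and bias
save_path = source.save("/tmp/on_create_demo/ckpt")

layer = tf.keras.layers.Dense(1)
checkpoint = tf.train.Checkpoint(layer=layer)
status = checkpoint.restore(save_path)  # nothing to restore into yet; queued
layer(tf.constant([[3.]]))  # creating the variables triggers the restores
status.assert_consumed()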
Example #3
  def testMultipleGraphsNonSlotVariables(self):
    with context.graph_mode():
      checkpoint_directory = self.get_temp_dir()
      checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
      optimizer = adam.AdamOptimizer(0.001)
      # Construct a model in one graph
      first_graph = ops.Graph()
      first_session = session_lib.Session(graph=first_graph)
      with first_graph.as_default(), first_session.as_default():
        first_variable = resource_variable_ops.ResourceVariable([1.])
        first_root_checkpointable = checkpointable_utils.Checkpoint(
            optimizer=optimizer, variable=first_variable)
        train_op = optimizer.minimize(first_variable.read_value)
        self.evaluate(checkpointable_utils.gather_initializers(
            first_root_checkpointable))
        self.evaluate(train_op)
        self.evaluate(first_variable.assign([1.]))
        self.evaluate(optimizer.get_slot(
            var=first_variable, name="m").assign([2.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.evaluate(beta1_power.assign(3.))

      # Save and load in a second graph
      second_graph = ops.Graph()
      with second_graph.as_default(), session_lib.Session(graph=second_graph):
        second_variable = resource_variable_ops.ResourceVariable([1.])
        second_root_checkpointable = checkpointable_utils.Checkpoint(
            optimizer=optimizer, variable=second_variable)
        train_op = optimizer.minimize(second_variable.read_value)
        second_root_checkpointable.restore(None).initialize_or_restore()
        self.evaluate(train_op)
        self.evaluate(second_variable.assign([4.]))
        self.evaluate(optimizer.get_slot(
            var=second_variable, name="m").assign([5.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.evaluate(beta1_power.assign(6.))
        save_path = second_root_checkpointable.save(checkpoint_prefix)
        self.evaluate(second_variable.assign([7.]))
        self.evaluate(optimizer.get_slot(
            var=second_variable, name="m").assign([8.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(6., self.evaluate(beta1_power))
        status = second_root_checkpointable.restore(save_path)
        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([4.], self.evaluate(second_variable))
        self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
            var=second_variable, name="m")))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(6., self.evaluate(beta1_power))

      # Check that the first graph is unmolested
      with first_graph.as_default(), first_session.as_default():
        self.assertAllEqual([1.], self.evaluate(first_variable))
        self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
            var=first_variable, name="m")))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(3., self.evaluate(beta1_power))
Example #4
  def testSaveRestoreMultipleIterator(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
    dataset = dataset.map(math_ops.square).batch(2)
    iterator_1 = datasets.Iterator(dataset)
    iterator_2 = datasets.Iterator(dataset)
    dataset_2 = Dataset.range(10)
    iterator_3 = datasets.Iterator(dataset_2)

    checkpoint = checkpointable_utils.Checkpoint(
        iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
    self.assertAllEqual([1, 4], iterator_1.get_next().numpy())
    self.assertEqual(0, iterator_3.get_next().numpy())
    self.assertEqual(1, iterator_3.get_next().numpy())
    self.assertEqual(2, iterator_3.get_next().numpy())

    save_path = checkpoint.save(checkpoint_prefix)
    self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
    self.assertAllEqual([9, 16], iterator_2.get_next().numpy())
    self.assertEqual(3, iterator_3.get_next().numpy())
    checkpoint.restore(save_path)
    self.assertAllEqual([9, 16], iterator_1.get_next().numpy())
    self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
    self.assertEqual(3, iterator_3.get_next().numpy())
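An iterator's position is just more object state to the checkpointing machinery. A compact sketch of the same idea, assuming a TF 2.x runtime where a tf.data iterator is directly trackable:

import tensorflow as tf

dataset = tf.data.Dataset.range(5)
iterator = iter(dataset)
checkpoint = tf.train.Checkpoint(iterator=iterator)
print(next(iterator).numpy())  # 0
save_path = checkpoint.save("/tmp/iterator_demo/ckpt")
print(next(iterator).numpy())  # 1
checkpoint.restore(save_path)
print(next(iterator).numpy())  # 1 again: the saved position was restored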
Example #5
  def testAnonymousVarsInInit(self):

    class Model(training.Model):

      def __init__(self):
        super(Model, self).__init__()
        self.w = resource_variable_ops.ResourceVariable(0.0)
        self.b = resource_variable_ops.ResourceVariable(0.0)
        self.vars = [self.w, self.b]

      def call(self, x):
        return x * self.w + self.b

    with context.eager_mode():
      model = Model()
      optimizer = adam.AdamOptimizer(learning_rate=0.05)
      checkpoint_directory = self.get_temp_dir()
      checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
      checkpoint = checkpointable_utils.Checkpoint(model=model,
                                                   optimizer=optimizer)
      for _ in range(2):
        checkpoint.save(checkpoint_prefix)
        with backprop.GradientTape() as tape:
          loss = (constant_op.constant(1.) -
                  model(constant_op.constant(1.)))**2
        grad = tape.gradient(loss, model.vars)
        optimizer.apply_gradients([(g, v)
                                   for g, v in zip(grad, model.vars)])
Example #6
  def testAgnosticUsage(self):
    """Graph/eager agnostic usage."""
    # Does create garbage when executing eagerly due to ops.Graph() creation.
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    for training_continuation in range(3):
      with ops.Graph().as_default(), self.test_session(
          graph=ops.get_default_graph()):
        network = MyNetwork()
        optimizer = adam.AdamOptimizer(0.001)
        root = checkpointable_utils.Checkpoint(
            optimizer=optimizer, network=network,
            global_step=training_util.get_or_create_global_step())
        checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
        status = root.restore(save_path=checkpoint_path)
        input_value = constant_op.constant([[3.]])
        train_fn = functools.partial(
            optimizer.minimize,
            functools.partial(network, input_value),
            global_step=root.global_step)
        if context.in_graph_mode():
          train_fn = functools.partial(self.evaluate, train_fn())
        status.initialize_or_restore()
        for _ in range(num_training_steps):
          train_fn()
        root.save(file_prefix=checkpoint_prefix)
        self.assertEqual((training_continuation + 1) * num_training_steps,
                         self.evaluate(root.global_step))
        self.assertEqual(training_continuation + 1,
                         self.evaluate(root.save_counter))
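The latest_checkpoint/initialize_or_restore dance in testAgnosticUsage is what tf.train.CheckpointManager later packaged up. A sketch, assuming a release that includes it:

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/agnostic_demo", max_to_keep=3)
checkpoint.restore(manager.latest_checkpoint)  # a no-op on the first run
for _ in range(10):
  step.assign_add(1)
manager.save()  # the next run resumes from the saved step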
Example #7
  def testUsageGraph(self):
    """Expected usage when graph building."""
    with context.graph_mode():
      num_training_steps = 10
      checkpoint_directory = self.get_temp_dir()
      checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
      for training_continuation in range(3):
        with ops.Graph().as_default():
          network = MyNetwork()
          optimizer = adam.AdamOptimizer(0.001)
          root = checkpointable_utils.Checkpoint(
              optimizer=optimizer, network=network,
              global_step=training_util.get_or_create_global_step())
          input_value = constant_op.constant([[3.]])
          train_op = optimizer.minimize(
              network(input_value),
              global_step=root.global_step)
          checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
          with self.test_session(graph=ops.get_default_graph()) as session:
            status = root.restore(save_path=checkpoint_path)
            status.initialize_or_restore(session=session)
            if checkpoint_path is None:
              self.assertEqual(0, training_continuation)
              with self.assertRaises(AssertionError):
                status.assert_consumed()
            else:
              status.assert_consumed()
            for _ in range(num_training_steps):
              session.run(train_op)
            root.save(file_prefix=checkpoint_prefix, session=session)
            self.assertEqual((training_continuation + 1) * num_training_steps,
                             session.run(root.global_step))
            self.assertEqual(training_continuation + 1,
                             session.run(root.save_counter))
Example #8
  def test_checkpointable_save_restore_nested(self):

    def _inner_template():
      v = variable_scope.get_variable(
          "v", shape=[1], initializer=init_ops.zeros_initializer())
      return v

    def _outer_template():
      first_inner = template.make_template("i1", _inner_template)
      second_inner = template.make_template("i2", _inner_template)
      v1 = first_inner()
      v2 = second_inner()
      v3 = second_inner()
      return (first_inner, second_inner), (v1, v2, v3)

    with variable_scope.variable_scope("ignored"):
      save_template = template.make_template("s1", _outer_template)
      save_root = checkpointable_utils.Checkpoint(my_template=save_template)
      (inner_template_one, inner_template_two), _ = save_template()
    self.evaluate(inner_template_one.variables[0].assign([20.]))
    self.evaluate(inner_template_two.variables[0].assign([25.]))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    load_template = template.make_template("s2", _outer_template)
    load_root = checkpointable_utils.Checkpoint(my_template=load_template)
    status = load_root.restore(save_path)
    (inner_template_one, inner_template_two), (v1, v2, v3) = load_template()
    outer_template_dependencies = load_root.my_template._checkpoint_dependencies
    self.assertEqual(2, len(outer_template_dependencies))
    self.assertEqual("i1", outer_template_dependencies[0].name)
    self.assertIs(inner_template_one, outer_template_dependencies[0].ref)
    self.assertEqual("i2", outer_template_dependencies[1].name)
    self.assertIs(inner_template_two, outer_template_dependencies[1].ref)
    self.assertEqual(1, len(inner_template_one._checkpoint_dependencies))
    self.assertEqual("v", inner_template_one._checkpoint_dependencies[0].name)
    self.assertEqual(1, len(inner_template_two._checkpoint_dependencies))
    self.assertEqual("v", inner_template_two._checkpoint_dependencies[0].name)
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([20.], self.evaluate(v1))
    self.assertAllEqual([25.], self.evaluate(v2))
    self.assertAllEqual([25.], self.evaluate(v3))
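Checkpoint keys compose along the dependency path from the root object, which is why the nested templates above produce names like "i1/v". A small sketch with hypothetical attribute names, assuming tf.train.Checkpoint and tf.train.list_variables:

import tensorflow as tf

inner = tf.train.Checkpoint(v=tf.Variable([20.]))
outer = tf.train.Checkpoint(inner=inner)
save_path = outer.save("/tmp/nested_demo/ckpt")
# Keys follow the object path from the root, e.g.
# "inner/v/.ATTRIBUTES/VARIABLE_VALUE", plus bookkeeping entries.
print([name for name, _ in tf.train.list_variables(save_path)])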
Example #9
  def testRestoreExhaustedIterator(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.range(3)
    iterator = datasets.Iterator(dataset)

    checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
    self.assertEqual(0, iterator.get_next().numpy())
    self.assertEqual(1, iterator.get_next().numpy())
    save_path = checkpoint.save(checkpoint_prefix)
    self.assertEqual(2, iterator.get_next().numpy())
    checkpoint.restore(save_path)
    self.assertEqual(2, iterator.get_next().numpy())
Example #10
  def testSaveRestore(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
    dataset = dataset.map(math_ops.square).batch(2)
    iterator = datasets.Iterator(dataset)
    checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
    self.assertAllEqual([1, 4], iterator.get_next().numpy())
    save_path = checkpoint.save(checkpoint_prefix)
    self.assertAllEqual([9, 16], iterator.get_next().numpy())
    self.assertAllEqual([25, 36], iterator.get_next().numpy())
    checkpoint.restore(save_path)
    self.assertAllEqual([9, 16], iterator.get_next().numpy())
    self.assertAllEqual([25, 36], iterator.get_next().numpy())
Example #11
  def testSaveRestore(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    mean = metrics.Mean()
    checkpoint = checkpointable_utils.Checkpoint(mean=mean)
    mean.build()
    mean._built = True
    self.evaluate(mean.init_variables())
    self.evaluate(mean(100.))
    self.evaluate(mean(200.))
    save_path = checkpoint.save(checkpoint_prefix)
    self.evaluate(mean(1000.))
    checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    self.evaluate(mean(300.))
    self.assertAllEqual(200., self.evaluate(mean.value()))

    restore_mean = metrics.Mean()
    restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean)
    status = restore_checkpoint.restore(save_path)
    restore_update = restore_mean(300.)
    status.assert_consumed().run_restore_ops()
    self.evaluate(restore_update)
    self.assertAllEqual(200., self.evaluate(restore_mean.value()))
    self.assertEqual(3, self.evaluate(restore_mean.denom))
Example #12
  def testWithDefun(self):
    num_training_steps = 2
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    for training_continuation in range(3):
      with ops.Graph().as_default(), self.test_session(
          graph=ops.get_default_graph()), test_util.device(use_gpu=True):
        model = MyModel()
        # Don't actually train so we can test variable values
        optimizer = adam.AdamOptimizer(0.)
        root = checkpointable_utils.Checkpoint(
            optimizer=optimizer, model=model,
            global_step=training_util.get_or_create_global_step())
        checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
        status = root.restore(save_path=checkpoint_path)

        def train_fn():

          @function.defun
          def _call_model(x):
            return model(x)

          with backprop.GradientTape() as tape:
            loss = _call_model(constant_op.constant([[3.]]))
          gradients = tape.gradient(loss, model.variables)
          return optimizer.apply_gradients(
              zip(gradients, model.variables),
              global_step=root.global_step)

        if not context.executing_eagerly():
          train_fn = functools.partial(self.evaluate, train_fn())
        status.initialize_or_restore()
        for _ in range(num_training_steps):
          train_fn()
        if training_continuation > 0:
          status.assert_consumed()
          self.assertAllClose([[42.]], self.evaluate(model.variables[0]))
        else:
          self.evaluate(model.variables[0].assign([[42.]]))
        root.save(file_prefix=checkpoint_prefix)
        self.assertEqual(
            (training_continuation + 1) * num_training_steps,
            self.evaluate(root.global_step))
        self.assertEqual(training_continuation + 1,
                         self.evaluate(root.save_counter))
Example #13
  def _initialized_model(self):
    input_value = constant_op.constant([[3.]])
    network = MyNetwork()
    optimizer = adam.AdamOptimizer(0.001)
    optimizer_step = training_util.get_or_create_global_step()
    root_checkpointable = checkpointable_utils.Checkpoint(
        optimizer=optimizer, network=network, optimizer_step=optimizer_step)
    train_op = optimizer.minimize(
        functools.partial(network, input_value),
        global_step=optimizer_step)
    self.evaluate(checkpointable_utils.gather_initializers(
        root_checkpointable))
    self.evaluate(train_op)
    # A regular variable, a slot variable, and a non-slot Optimizer variable
    # with known values to check when loading.
    self.evaluate(network._named_dense.bias.assign([1.]))
    self.evaluate(optimizer.get_slot(
        var=network._named_dense.bias, name="m").assign([2.]))
    beta1_power, _ = optimizer._get_beta_accumulators()
    self.evaluate(beta1_power.assign(3.))
    return root_checkpointable
Example #14
  def testDeferredRestorationUsageEager(self):
    """An idiomatic eager execution example."""
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    for training_continuation in range(3):
      network = MyNetwork()
      optimizer = adam.AdamOptimizer(0.001)
      root = checkpointable_utils.Checkpoint(
          optimizer=optimizer, network=network,
          optimizer_step=training_util.get_or_create_global_step())
      root.restore(core_saver.latest_checkpoint(checkpoint_directory))
      for _ in range(num_training_steps):
        # TODO(allenl): Use a Dataset and serialize/checkpoint it.
        input_value = constant_op.constant([[3.]])
        optimizer.minimize(
            lambda: network(input_value),  # pylint: disable=cell-var-from-loop
            global_step=root.optimizer_step)
      root.save(file_prefix=checkpoint_prefix)
      self.assertEqual((training_continuation + 1) * num_training_steps,
                       root.optimizer_step.numpy())
Example #15
  def testNamingWithOptimizer(self):
    input_value = constant_op.constant([[3.]])
    network = MyNetwork()
    # A nuisance Network using the same optimizer. Its slot variables should not
    # go in the checkpoint, since it is never depended on.
    other_network = MyNetwork()
    optimizer = adam.AdamOptimizer(0.001)
    optimizer_step = training_util.get_or_create_global_step()
    root_checkpointable = checkpointable_utils.Checkpoint(
        optimizer=optimizer, network=network, optimizer_step=optimizer_step)
    if context.in_eager_mode():
      optimizer.minimize(
          lambda: network(input_value),
          global_step=optimizer_step)
      optimizer.minimize(
          lambda: other_network(input_value),
          global_step=optimizer_step)
    else:
      train_op = optimizer.minimize(
          network(input_value), global_step=optimizer_step)
      optimizer.minimize(
          other_network(input_value),
          global_step=optimizer_step)
      self.evaluate(checkpointable_utils.gather_initializers(
          root_checkpointable))
      self.evaluate(train_op)
    named_variables, serialized_graph = (
        checkpointable_utils._serialize_object_graph(root_checkpointable))
    expected_checkpoint_names = (
        # Created in the root node, so no prefix.
        "optimizer_step",
        # No name provided to track_checkpointable(), so the position is used
        # instead (one-based).
        "network/via_track_layer/kernel",
        # track_checkpointable() with a name provided, so that's used
        "network/_named_dense/kernel",
        "network/_named_dense/bias",
        # non-Layer dependency of the network
        "network/_non_layer/a_variable",
        # The optimizer creates two non-slot variables
        "optimizer/beta1_power",
        "optimizer/beta2_power",
        # Slot variables
        "network/via_track_layer/kernel/.OPTIMIZER_SLOT/optimizer/m",
        "network/via_track_layer/kernel/.OPTIMIZER_SLOT/optimizer/v",
        "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
        "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
        "network/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
        "network/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
    )
    suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
    expected_checkpoint_names = [
        name + suffix for name in expected_checkpoint_names]
    six.assertCountEqual(self, expected_checkpoint_names,
                         named_variables.keys())
    # Check that we've mapped to the right variable objects (not exhaustive)
    self.assertEqual(
        "global_step:0",
        named_variables["optimizer_step" + suffix].name)
    self.assertEqual(
        "my_network/dense_1/kernel:0",
        named_variables["network/via_track_layer/kernel" + suffix].name)
    self.assertEqual(
        "my_network/dense/kernel:0",
        named_variables["network/_named_dense/kernel" + suffix].name)
    self.assertEqual(
        "beta1_power:0",
        named_variables["optimizer/beta1_power" + suffix].name)
    self.assertEqual(
        "beta2_power:0",
        named_variables["optimizer/beta2_power" + suffix].name)
    # Spot check the generated protocol buffers.
    self.assertEqual("optimizer",
                     serialized_graph.nodes[0].children[1].local_name)
    optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[
        1].node_id]
    self.assertEqual("beta1_power",
                     optimizer_node.children[0].local_name)
    self.assertEqual("beta1_power",
                     serialized_graph.nodes[optimizer_node.children[0].node_id]
                     .attributes[0].full_name)
    self.assertEqual(
        "my_network/dense/kernel",
        serialized_graph.nodes[optimizer_node.slot_variables[0]
                               .original_variable_node_id]
        .attributes[0].full_name)
    # We strip off the :0 suffix, as variable.name-based saving does.
    self.assertEqual(
        "my_network/dense/kernel/Adam",
        serialized_graph.nodes[optimizer_node.slot_variables[0]
                               .slot_variable_node_id]
        .attributes[0].full_name)
    self.assertEqual(
        "my_network/dense/kernel/Adam:0",
        optimizer.get_slot(
            var=named_variables["network/_named_dense/kernel" + suffix],
            name="m").name)
    self.assertEqual(
        "network/_named_dense/kernel" + suffix,
        serialized_graph.nodes[
            optimizer_node.slot_variables[0]
            .original_variable_node_id].attributes[0].checkpoint_key)
    self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
    self.assertEqual(
        "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix,
        serialized_graph.nodes[
            optimizer_node.slot_variables[0]
            .slot_variable_node_id].attributes[0].checkpoint_key)
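The expected_checkpoint_names asserted above can be verified against a real checkpoint file. A sketch, assuming tf.train.list_variables from later releases:

import tensorflow as tf

net = tf.keras.layers.Dense(1)
net(tf.constant([[3.]]))  # build, creating kernel and bias
save_path = tf.train.Checkpoint(net=net).save("/tmp/naming_demo/ckpt")
for name, shape in tf.train.list_variables(save_path):
  print(name, shape)
# Variable values live under ".ATTRIBUTES/VARIABLE_VALUE"; optimizer slots,
# when present, appear under keys like ".OPTIMIZER_SLOT/optimizer/m".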
Example #16
  def testSaveRestore(self):
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001)
    root_checkpointable = checkpointable_utils.Checkpoint(
        optimizer=optimizer, model=model)
    input_value = constant_op.constant([[3.]])
    if context.executing_eagerly():
      optimizer.minimize(lambda: model(input_value))
    else:
      train_op = optimizer.minimize(model(input_value))
      # TODO(allenl): Make initialization more pleasant when graph building.
      root_checkpointable.save_counter  # pylint: disable=pointless-statement
      self.evaluate(
          checkpointable_utils.gather_initializers(root_checkpointable))
      self.evaluate(train_op)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
    m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
    self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
    save_path = root_checkpointable.save(file_prefix=prefix)
    self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
    self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3))
    optimizer_variables = self.evaluate(optimizer.variables())
    self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
    # Immediate restoration
    status = root_checkpointable.restore(
        save_path=save_path).assert_consumed()
    status.run_restore_ops()
    self.assertAllEqual([42.],
                        self.evaluate(model._named_dense.variables[1]))
    self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
    self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
    if not context.executing_eagerly():
      return  # Restore-on-create is only supported when executing eagerly
    on_create_model = MyModel()
    on_create_optimizer = adam.AdamOptimizer(
        0.001,
        # Preserve beta1_power and beta2_power when applying gradients so we can
        # test that they've been restored correctly.
        beta1=1.0,
        beta2=1.0)
    on_create_root = checkpointable_utils.Checkpoint(
        optimizer=on_create_optimizer, model=on_create_model)
    # Deferred restoration
    status = on_create_root.restore(save_path=save_path)
    on_create_model(constant_op.constant([[3.]]))  # create variables
    self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
    self.assertAllEqual([42.],
                        self.evaluate(
                            on_create_model._named_dense.variables[1]))
    on_create_m_bias_slot = on_create_optimizer.get_slot(
        on_create_model._named_dense.variables[1], "m")
    # Optimizer slot variables are created when the original variable is
    # restored.
    self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
    self.assertAllEqual(optimizer_variables[2:],
                        self.evaluate(on_create_optimizer.variables()))
    dummy_var = resource_variable_ops.ResourceVariable([1.])
    on_create_optimizer.minimize(loss=dummy_var.read_value)
    status.assert_consumed()
    beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
    self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
    self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power))
Example #17
  def testDeferredSlotRestoration(self):
    checkpoint_directory = self.get_temp_dir()

    root = checkpointable.Checkpointable()
    root.var = checkpointable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = adam.AdamOptimizer(0.1)
    if context.in_graph_mode():
      train_op = optimizer.minimize(root.var)
      # Note that `optimizer` has not been added as a dependency of
      # `root`. Create a one-off grouping so that slot variables for `root.var`
      # get initialized too.
      self.evaluate(checkpointable_utils.gather_initializers(
          checkpointable_utils.Checkpoint(root=root, optimizer=optimizer)))
      self.evaluate(train_op)
    else:
      optimizer.minimize(root.var.read_value)
    self.evaluate(state_ops.assign(root.var, 12.))
    no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
        os.path.join(checkpoint_directory, "no_slots"))
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                   14.))
    slots_path = checkpointable_utils.CheckpointableSaver(root).save(
        os.path.join(checkpoint_directory, "with_slots"))
    new_root = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately overwrite
    # the non-slot variable (also deferred).
    slot_status = checkpointable_utils.CheckpointableSaver(
        new_root).restore(slots_path)
    no_slot_status = checkpointable_utils.CheckpointableSaver(
        new_root).restore(no_slots_path)
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    new_root.var = checkpointable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    self.assertEqual(12., self.evaluate(new_root.var))
    new_root.optimizer = adam.AdamOptimizer(0.1)
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
    if context.in_eager_mode():
      # Slot variables are only created with restoring initializers when
      # executing eagerly.
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
    else:
      self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                    None)
    if context.in_graph_mode():
      train_op = new_root.optimizer.minimize(new_root.var)
      # The slot variable now exists; restore() didn't create it, but we should
      # now have a restore op for it.
      slot_status.run_restore_ops()
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
      self.evaluate(train_op)
    else:
      new_root.optimizer.minimize(new_root.var.read_value)
    slot_status.assert_consumed()