 def testDeferredRestorationUsageEager(self):
   """An idiomatic eager execution example."""
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   latest_object_graph = None  # Will be saved with the checkpoint eventually.
   for training_continuation in range(3):
     with ops.Graph().as_default():
       network = MyNetwork()
       optimizer = CheckpointableAdam(0.001)
       root = Root(optimizer=optimizer, network=network)
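        # Restoration is deferred: on the first continuation nothing has been
        # saved yet, so latest_checkpoint() returns None; on later
        # continuations, saved values are matched up as variables are created.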
       checkpointable.restore(
           save_path=core_saver.latest_checkpoint(checkpoint_directory),
           root_checkpointable=root,
           object_graph_proto=latest_object_graph)
       for _ in range(num_training_steps):
         # TODO(allenl): Use a Dataset and serialize/checkpoint it.
         input_value = constant_op.constant([[3.]])
         optimizer.minimize(
             lambda: network(input_value),  # pylint: disable=cell-var-from-loop
             global_step=root.global_step)
       latest_object_graph, _ = checkpointable.save(
           file_prefix=checkpoint_prefix,
           root_checkpointable=root)
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        root.global_step.numpy())
 def testUsageGraph(self):
   """Expected usage when graph building."""
   with context.graph_mode():
     num_training_steps = 10
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     latest_object_graph = None
     for training_continuation in range(3):
       with ops.Graph().as_default():
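          # Each continuation builds a fresh graph, so all variables are
          # re-created and must be initialized or restored from the checkpoint.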
         network = MyNetwork()
         optimizer = CheckpointableAdam(0.001)
         root = Root(optimizer=optimizer, network=network)
         input_value = constant_op.constant([[3.]])
         train_op = optimizer.minimize(
             network(input_value),
             global_step=root.global_step)
         init_op = variables.global_variables_initializer()
         checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
         with self.test_session(graph=ops.get_default_graph()) as session:
           if checkpoint_path is None:
             self.assertEqual(0, training_continuation)
             session.run(init_op)
              # An alternative would be to run initializers automatically when
              # no checkpoint is being loaded, which would make deferred
              # loading a bit more useful with graph execution.
           else:
             checkpointable.restore(
                 save_path=checkpoint_path,
                 root_checkpointable=root,
                 object_graph_proto=latest_object_graph,
                 session=session)
           for _ in range(num_training_steps):
             session.run(train_op)
           latest_object_graph, _ = checkpointable.save(
               file_prefix=checkpoint_prefix,
               root_checkpointable=root,
               session=session)
           self.assertEqual((training_continuation + 1) * num_training_steps,
                            session.run(root.global_step))
 def testSaveRestore(self):
   network = MyNetwork()
   optimizer = CheckpointableAdam(0.001)
   root_checkpointable = Root(optimizer=optimizer, network=network)
   input_value = constant_op.constant([[3.]])
   if context.in_eager_mode():
     optimizer.minimize(
         lambda: network(input_value),
         global_step=root_checkpointable.global_step)
   else:
     train_op = optimizer.minimize(
         network(input_value), global_step=root_checkpointable.global_step)
     self.evaluate(variables.global_variables_initializer())
     self.evaluate(train_op)
   prefix = os.path.join(self.get_temp_dir(), "ckpt")
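    # Assign distinguishable values to the bias and its Adam "m" slot so that
    # a successful restore is observable below.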
   self.evaluate(state_ops.assign(network._named_dense.variables[1], [42.]))
   m_bias_slot = optimizer.get_slot(network._named_dense.variables[1], "m")
   self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
   serialized_graph, save_path = checkpointable.save(
       file_prefix=prefix,
       root_checkpointable=root_checkpointable,
       global_step=root_checkpointable.global_step)
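    # Mutate the saved values (and snapshot the optimizer's variables for
    # later comparison); restoration should undo the mutations.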
   self.evaluate(state_ops.assign(network._named_dense.variables[1], [43.]))
   self.evaluate(state_ops.assign(root_checkpointable.global_step, 3))
   optimizer_variables = self.evaluate(optimizer.variables())
   self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
   # Immediate restoration
   checkpointable.restore(
       save_path=save_path,
       root_checkpointable=root_checkpointable,
       object_graph_proto=serialized_graph)
   self.assertAllEqual([42.], self.evaluate(network._named_dense.variables[1]))
   self.assertAllEqual(1, self.evaluate(root_checkpointable.global_step))
   self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
   with ops.Graph().as_default():
     on_create_network = MyNetwork()
     on_create_optimizer = CheckpointableAdam(0.001)
     on_create_root = Root(
         optimizer=on_create_optimizer, network=on_create_network)
     with self.test_session(graph=ops.get_default_graph()):
       # Deferred restoration
       checkpointable.restore(
           save_path=save_path,
           root_checkpointable=on_create_root,
           object_graph_proto=serialized_graph)
       on_create_network(constant_op.constant([[3.]]))  # create variables
       self.assertAllEqual(1, self.evaluate(on_create_root.global_step))
       self.assertAllEqual([42.],
                           self.evaluate(
                               on_create_network._named_dense.variables[1]))
       on_create_m_bias_slot = on_create_optimizer.get_slot(
           on_create_network._named_dense.variables[1], "m")
       # Optimizer slot variables are created when the original variable is
       # restored.
       self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
       # beta1_power and beta2_power haven't been created yet, but everything
       # else matches.
       self.assertAllEqual(optimizer_variables[2:],
                           self.evaluate(on_create_optimizer.variables()))
       on_create_optimizer._create_slots(
           [resource_variable_ops.ResourceVariable([1.])])
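        # Creating slots also creates the beta accumulators, which then pick
        # up their values from the deferred restoration, as asserted below.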
       beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
       self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
       self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power))