  def testDeferredRestorationUsageEager(self):
    """An idiomatic eager execution example."""
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    latest_object_graph = None  # Will be saved with the checkpoint eventually.
    for training_continuation in range(3):
      with context.eager_mode():
        network = MyNetwork()
        optimizer = CheckpointableAdam(0.001)
        root = Root(optimizer=optimizer, network=network)
        checkpointable.restore(
            save_path=core_saver.latest_checkpoint(checkpoint_directory),
            root_checkpointable=root,
            object_graph_proto=latest_object_graph)
        for _ in range(num_training_steps):
          # TODO(allenl): Use a Dataset and serialize/checkpoint it.
          input_value = constant_op.constant([[3.]])
          optimizer.minimize(
              lambda: network(input_value),  # pylint: disable=cell-var-from-loop
              global_step=root.global_step)
        latest_object_graph, _ = checkpointable.save(
            file_prefix=checkpoint_prefix,
            root_checkpointable=root)
        self.assertEqual((training_continuation + 1) * num_training_steps,
                         root.global_step.numpy())
  def testUsageGraph(self):
    """Expected usage when graph building."""
    with context.graph_mode():
      num_training_steps = 10
      checkpoint_directory = self.get_temp_dir()
      checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
      latest_object_graph = None
      for training_continuation in range(3):
        with ops.Graph().as_default():
          network = MyNetwork()
          optimizer = CheckpointableAdam(0.001)
          root = Root(optimizer=optimizer, network=network)
          input_value = constant_op.constant([[3.]])
          train_op = optimizer.minimize(
              network(input_value), global_step=root.global_step)
          init_op = variables.global_variables_initializer()
          checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
          with self.test_session(graph=ops.get_default_graph()) as session:
            if checkpoint_path is None:
              self.assertEqual(0, training_continuation)
              session.run(init_op)
              # Another alternative would be to run initializers automatically
              # if no checkpoint is being loaded. This would make deferred
              # loading a bit more useful with graph execution; a sketch of
              # that alternative follows this test.
            else:
              checkpointable.restore(
                  save_path=checkpoint_path,
                  root_checkpointable=root,
                  object_graph_proto=latest_object_graph,
                  session=session)
            for _ in range(num_training_steps):
              session.run(train_op)
            latest_object_graph, _ = checkpointable.save(
                file_prefix=checkpoint_prefix,
                root_checkpointable=root,
                session=session)
            self.assertEqual((training_continuation + 1) * num_training_steps,
                             session.run(root.global_step))
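  # A minimal sketch of the "run initializers automatically when no checkpoint
  # is being loaded" alternative mentioned in the comment above. It only uses
  # calls already exercised by these tests; the helper name `_init_or_restore`
  # is hypothetical and not part of the checkpointable API.
  def _init_or_restore(self, session, root, checkpoint_directory,
                       latest_object_graph):
    checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
    if checkpoint_path is None:
      # First run: no checkpoint exists yet, so fall back to initializers.
      session.run(variables.global_variables_initializer())
    else:
      checkpointable.restore(
          save_path=checkpoint_path,
          root_checkpointable=root,
          object_graph_proto=latest_object_graph,
          session=session)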
  def testSaveRestore(self):
    network = MyNetwork()
    optimizer = CheckpointableAdam(0.001)
    root_checkpointable = Root(optimizer=optimizer, network=network)
    input_value = constant_op.constant([[3.]])
    if context.in_eager_mode():
      optimizer.minimize(
          lambda: network(input_value),
          global_step=root_checkpointable.global_step)
    else:
      train_op = optimizer.minimize(
          network(input_value), global_step=root_checkpointable.global_step)
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(train_op)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(state_ops.assign(network._named_dense.variables[1], [42.]))
    m_bias_slot = optimizer.get_slot(network._named_dense.variables[1], "m")
    self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
    serialized_graph, save_path = checkpointable.save(
        file_prefix=prefix,
        root_checkpointable=root_checkpointable,
        global_step=root_checkpointable.global_step)
    self.evaluate(state_ops.assign(network._named_dense.variables[1], [43.]))
    self.evaluate(state_ops.assign(root_checkpointable.global_step, 3))
    optimizer_variables = self.evaluate(optimizer.variables())
    self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
    # Immediate restoration
    checkpointable.restore(
        save_path=save_path,
        root_checkpointable=root_checkpointable,
        object_graph_proto=serialized_graph)
    self.assertAllEqual([42.],
                        self.evaluate(network._named_dense.variables[1]))
    self.assertAllEqual(1, self.evaluate(root_checkpointable.global_step))
    self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
    with ops.Graph().as_default():
      on_create_network = MyNetwork()
      on_create_optimizer = CheckpointableAdam(0.001)
      on_create_root = Root(
          optimizer=on_create_optimizer, network=on_create_network)
      with self.test_session(graph=ops.get_default_graph()):
        # Deferred restoration
        checkpointable.restore(
            save_path=save_path,
            root_checkpointable=on_create_root,
            object_graph_proto=serialized_graph)
        on_create_network(constant_op.constant([[3.]]))  # create variables
        self.assertAllEqual(1, self.evaluate(on_create_root.global_step))
        self.assertAllEqual([42.],
                            self.evaluate(
                                on_create_network._named_dense.variables[1]))
        on_create_m_bias_slot = on_create_optimizer.get_slot(
            on_create_network._named_dense.variables[1], "m")
        # Optimizer slot variables are created when the original variable is
        # restored.
        self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
        # beta1_power and beta2_power haven't been created yet, but everything
        # else matches.
        self.assertAllEqual(optimizer_variables[2:],
                            self.evaluate(on_create_optimizer.variables()))
        on_create_optimizer._create_slots(
            [resource_variable_ops.ResourceVariable([1.])])
        beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
        self.assertAllEqual(optimizer_variables[0],
                            self.evaluate(beta1_power))
        self.assertAllEqual(optimizer_variables[1],
                            self.evaluate(beta2_power))
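  # A reading aid, not part of the test suite: a minimal sketch of the shape
  # these tests assume for the MyNetwork and Root fixtures defined elsewhere
  # in this file. The base classes and tracking calls below are assumptions
  # for illustration; the real definitions may differ.
  #
  #   class MyNetwork(CheckpointableNetwork):
  #
  #     def __init__(self):
  #       super(MyNetwork, self).__init__()
  #       # `_named_dense.variables[1]` is the bias the tests assign to.
  #       self._named_dense = self.track_layer(CheckpointableDense(1))
  #
  #     def call(self, values):
  #       return self._named_dense(values)
  #
  #   class Root(checkpointable.Checkpointable):
  #
  #     def __init__(self, optimizer, network):
  #       # Tracking makes both objects (and, transitively, the network's
  #       # variables and the optimizer's slot variables) part of the object
  #       # graph rooted here, which checkpointable.save/restore traverse.
  #       self.track_checkpointable(optimizer, name="optimizer")
  #       self.track_checkpointable(network, name="network")
  #       self.global_step = training_util.create_global_step()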