def _initialized_model(self):
  input_value = constant_op.constant([[3.]])
  model = MyModel()
  optimizer = adam.AdamOptimizer(0.001)
  optimizer_step = training_util.get_or_create_global_step()
  root_checkpointable = util.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=optimizer_step)
  train_op = optimizer.minimize(
      functools.partial(model, input_value),
      global_step=optimizer_step)
  self.evaluate(util.gather_initializers(root_checkpointable))
  self.evaluate(train_op)
  # A regular variable, a slot variable, and a non-slot Optimizer variable
  # with known values to check when loading.
  self.evaluate(model._named_dense.bias.assign([1.]))
  self.evaluate(
      optimizer.get_slot(var=model._named_dense.bias, name="m").assign([2.]))
  beta_1_power, _ = optimizer._get_beta_accumulators()
  self.evaluate(beta_1_power.assign(3.))
  return root_checkpointable

def testDictionariesBasic(self):
  a = training.Model()
  b = training.Model()
  a.attribute = {"b": b}

  c = training.Model()
  a.attribute["c"] = []
  a.attribute["c"].append(c)

  a_deps = util.list_objects(a)
  self.assertIn(b, a_deps)
  self.assertIn(c, a_deps)
  self.assertIs(b, a.attribute["b"])
  six.assertCountEqual(
      self, ["b", "c"],
      [dep.name for dep in a.attribute._checkpoint_dependencies])
  self.assertEqual([b, c], a.layers)
  self.assertEqual([b, c], a.attribute.layers)
  self.assertEqual([c], a.attribute["c"].layers)
  checkpoint = util.Checkpoint(a=a)
  save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
  checkpoint.restore(save_path).assert_consumed()

def testDeferredRestorationUsageEager(self):
  """An idiomatic eager execution example."""
  num_training_steps = 10
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  for training_continuation in range(3):
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001)
    root = checkpointable_utils.Checkpoint(
        optimizer=optimizer, model=model,
        optimizer_step=training_util.get_or_create_global_step())
    root.restore(core_saver.latest_checkpoint(checkpoint_directory))
    for _ in range(num_training_steps):
      # TODO(allenl): Use a Dataset and serialize/checkpoint it.
      input_value = constant_op.constant([[3.]])
      optimizer.minimize(
          lambda: model(input_value),  # pylint: disable=cell-var-from-loop
          global_step=root.optimizer_step)
    root.save(file_prefix=checkpoint_prefix)
    self.assertEqual((training_continuation + 1) * num_training_steps,
                     root.optimizer_step.numpy())

def testWithDefun(self):
  num_training_steps = 2
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  for training_continuation in range(3):
    with test_util.device(use_gpu=True):
      model = MyModel()
      # Don't actually train so we can test variable values
      optimizer = adam.AdamOptimizer(0.)
      root = checkpointable_utils.Checkpoint(
          optimizer=optimizer, model=model,
          global_step=training_util.get_or_create_global_step())
      checkpoint_path = checkpoint_management.latest_checkpoint(
          checkpoint_directory)
      status = root.restore(save_path=checkpoint_path)

      def train_fn():

        @def_function.function
        def _call_model(x):
          return model(x)

        with backprop.GradientTape() as tape:
          loss = _call_model(constant_op.constant([[3.]]))
        gradients = tape.gradient(loss, model.variables)
        return optimizer.apply_gradients(zip(gradients, model.variables),
                                         global_step=root.global_step)

      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      status.initialize_or_restore()
      for _ in range(num_training_steps):
        train_fn()
      if training_continuation > 0:
        status.assert_consumed()
        self.assertAllClose([[42.]], self.evaluate(model.variables[0]))
      else:
        self.evaluate(model.variables[0].assign([[42.]]))
      root.save(file_prefix=checkpoint_prefix)
      self.assertEqual((training_continuation + 1) * num_training_steps,
                       self.evaluate(root.global_step))
      self.assertEqual(training_continuation + 1,
                       self.evaluate(root.save_counter))

def test_table(self, cycles):
  # TODO(b/123408779): Handle generic TrackableResources and enable this test
  self.skipTest("Need to handle generic TrackableResources")
  vocab_path = self._make_asset("alpha\nbeta\ngamma\n")
  initializer = lookup_ops.TextFileInitializer(
      vocab_path,
      key_dtype=dtypes.string,
      key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
      value_dtype=dtypes.int64,
      value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
  root = util.Checkpoint(
      table=lookup_ops.HashTable(initializer, default_value=-1))
  root.table_user = def_function.function(
      root.table.lookup,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
  self.assertEqual(
      2, root.table_user(constant_op.constant("gamma")).numpy())
  imported = self.cycle(root, cycles)
  self.assertEqual(
      2, imported.table_user(constant_op.constant("gamma")).numpy())

def testUsageGraph(self):
  """Expected usage when graph building."""
  with context.graph_mode():
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    for training_continuation in range(3):
      with ops.Graph().as_default():
        model = MyModel()
        optimizer = adam.AdamOptimizer(0.001)
        root = checkpointable_utils.Checkpoint(
            optimizer=optimizer, model=model,
            global_step=training_util.get_or_create_global_step())
        input_value = constant_op.constant([[3.]])
        train_op = optimizer.minimize(
            model(input_value),
            global_step=root.global_step)
        checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
        with self.test_session(graph=ops.get_default_graph()) as session:
          status = root.restore(save_path=checkpoint_path)
          status.initialize_or_restore(session=session)
          if checkpoint_path is None:
            self.assertEqual(0, training_continuation)
            with self.assertRaises(AssertionError):
              status.assert_consumed()
          else:
            status.assert_consumed()
          for _ in range(num_training_steps):
            session.run(train_op)
          root.save(file_prefix=checkpoint_prefix, session=session)
          self.assertEqual((training_continuation + 1) * num_training_steps,
                           session.run(root.global_step))
          self.assertEqual(training_continuation + 1,
                           session.run(root.save_counter))

def testCustomNumbering(self):
  directory = self.get_temp_dir()
  step = variables.Variable(0, dtype=dtypes.int64)
  checkpoint = util.Checkpoint(step=step)
  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=2)
  self.evaluate(step.initializer)
  for i in range(5):
    path = manager.save(checkpoint_number=step)
    expected_suffix = "-%d" % (2 * i,)
    if not path.endswith(expected_suffix):
      self.fail("%s should have suffix %s" % (path, expected_suffix))
    self.evaluate(step.assign_add(2))
  self.assertEqual(5, self.evaluate(checkpoint.save_counter))
  # Test regular integers
  last_path = manager.save(checkpoint_number=32)
  self.assertIn("-32", last_path)
  self.assertEqual(last_path, manager.latest_checkpoint)
  self.assertEqual(
      last_path, checkpoint_management.latest_checkpoint(directory))
  state = checkpoint_management.get_checkpoint_state(directory)
  # Only the most recent two checkpoints are saved
  self.assertEqual([path, last_path], state.all_model_checkpoint_paths)

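# The following is a minimal, self-contained sketch (not part of the original
# tests) of the custom-numbering behaviour checked above, written against the
# public tf.train API under the assumption of TF 2.x eager execution. The
# function name and variable names are illustrative only:
# `CheckpointManager.save` accepts an explicit `checkpoint_number`, which may
# be a Python int or an integer Variable.
def _example_custom_numbering(directory):
  import tensorflow as tf  # local import keeps the sketch self-contained

  step = tf.Variable(0, dtype=tf.int64)
  ckpt = tf.train.Checkpoint(step=step)
  manager = tf.train.CheckpointManager(ckpt, directory, max_to_keep=2)
  path = manager.save(checkpoint_number=step)  # path ends with "-0"
  step.assign_add(2)
  path = manager.save(checkpoint_number=step)  # path ends with "-2"
  return path
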
def testDeferredRestorationUsageEager(self):
  """An idiomatic eager execution example."""
  num_training_steps = 10
  checkpoint_directory = self.get_temp_dir()
  for training_continuation in range(3):
    with self.test_scope():
      model = Subclassed()
      optimizer = adam.Adam(0.001)
      root = checkpointable_utils.Checkpoint(
          optimizer=optimizer, model=model)
      manager = checkpoint_management.CheckpointManager(
          root, checkpoint_directory, max_to_keep=2)
      root.restore(manager.latest_checkpoint)
      for _ in range(num_training_steps):
        input_value = constant_op.constant([[3.]])
        with backprop.GradientTape() as tape:
          loss = model(input_value)
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))
      manager.save()
      self.assertEqual((training_continuation + 1) * num_training_steps,
                       root.optimizer.iterations.numpy())

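# A minimal sketch (not part of the original tests) of the public-API
# equivalent of the idiomatic eager save/restore pattern exercised above,
# assuming TF 2.x. The names `_example_training_loop`, `net`, and `opt` are
# illustrative only, not anything defined by this test suite.
def _example_training_loop(checkpoint_directory, num_steps=10):
  import tensorflow as tf  # local import keeps the sketch self-contained

  net = tf.keras.layers.Dense(1)
  opt = tf.keras.optimizers.Adam(0.001)
  ckpt = tf.train.Checkpoint(optimizer=opt, model=net)
  manager = tf.train.CheckpointManager(
      ckpt, checkpoint_directory, max_to_keep=2)
  # Deferred restoration: a no-op on the first run (no checkpoint yet);
  # variable values are filled in as the variables are created.
  ckpt.restore(manager.latest_checkpoint)
  for _ in range(num_steps):
    x = tf.constant([[3.]])
    with tf.GradientTape() as tape:
      loss = tf.reduce_sum(net(x))
    grads = tape.gradient(loss, net.trainable_variables)
    opt.apply_gradients(zip(grads, net.trainable_variables))
  return manager.save()
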
def testGraphDistributionStrategy(self):
  self.skipTest("b/121381184")
  num_training_steps = 10
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

  def _train_fn(optimizer, model):
    input_value = constant_op.constant([[3.]])
    return optimizer.minimize(
        functools.partial(model, input_value),
        global_step=root.optimizer_step)

  for training_continuation in range(3):
    with ops.Graph().as_default():
      strategy = mirrored_strategy.MirroredStrategy()
      with strategy.scope():
        model = MyModel()
        optimizer = adam.AdamOptimizer(0.001)
        root = checkpointable_utils.Checkpoint(
            optimizer=optimizer, model=model,
            optimizer_step=training_util.get_or_create_global_step())
        status = root.restore(checkpoint_management.latest_checkpoint(
            checkpoint_directory))
        train_op = strategy.extended.call_for_each_replica(
            functools.partial(_train_fn, optimizer, model))
        with self.session() as session:
          if training_continuation > 0:
            status.assert_consumed()
          status.initialize_or_restore()
          for _ in range(num_training_steps):
            session.run(train_op)
          root.save(file_prefix=checkpoint_prefix)
          self.assertEqual((training_continuation + 1) * num_training_steps,
                           root.optimizer_step.numpy())

def testMultipleGraphsNonSlotVariables(self):
  with context.graph_mode():
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    optimizer = adam.AdamOptimizer(0.001)
    # Construct a model in one graph
    first_graph = ops.Graph()
    first_session = session_lib.Session(graph=first_graph)
    with first_graph.as_default(), first_session.as_default():
      first_variable = resource_variable_ops.ResourceVariable([1.])
      first_root_checkpointable = checkpointable_utils.Checkpoint(
          optimizer=optimizer, variable=first_variable)
      train_op = optimizer.minimize(first_variable.read_value)
      self.evaluate(checkpointable_utils.gather_initializers(
          first_root_checkpointable))
      self.evaluate(train_op)
      self.evaluate(first_variable.assign([1.]))
      self.evaluate(optimizer.get_slot(
          var=first_variable, name="m").assign([2.]))
      beta1_power, _ = optimizer._get_beta_accumulators()
      self.evaluate(beta1_power.assign(3.))

    # Save and load in a second graph
    second_graph = ops.Graph()
    with second_graph.as_default(), session_lib.Session(graph=second_graph):
      second_variable = resource_variable_ops.ResourceVariable([1.])
      second_root_checkpointable = checkpointable_utils.Checkpoint(
          optimizer=optimizer, variable=second_variable)
      train_op = optimizer.minimize(second_variable.read_value)
      second_root_checkpointable.restore(None).initialize_or_restore()
      self.evaluate(train_op)
      self.evaluate(second_variable.assign([4.]))
      self.evaluate(optimizer.get_slot(
          var=second_variable, name="m").assign([5.]))
      beta1_power, _ = optimizer._get_beta_accumulators()
      self.evaluate(beta1_power.assign(6.))
      save_path = second_root_checkpointable.save(checkpoint_prefix)
      self.evaluate(second_variable.assign([7.]))
      self.evaluate(optimizer.get_slot(
          var=second_variable, name="m").assign([8.]))
      beta1_power, _ = optimizer._get_beta_accumulators()
      self.assertAllEqual(6., self.evaluate(beta1_power))
      status = second_root_checkpointable.restore(save_path)
      status.assert_consumed().run_restore_ops()
      self.assertAllEqual([4.], self.evaluate(second_variable))
      self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
          var=second_variable, name="m")))
      beta1_power, _ = optimizer._get_beta_accumulators()
      self.assertAllEqual(6., self.evaluate(beta1_power))

    # Check that the first graph is unmolested
    with first_graph.as_default(), first_session.as_default():
      self.assertAllEqual([1.], self.evaluate(first_variable))
      self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
          var=first_variable, name="m")))
      beta1_power, _ = optimizer._get_beta_accumulators()
      self.assertAllEqual(3., self.evaluate(beta1_power))

def testDeferredSlotRestoration(self):
  checkpoint_directory = self.get_temp_dir()

  root = checkpointable.Checkpointable()
  root.var = checkpointable_utils.add_variable(
      root, name="var", initializer=0.)
  optimizer = adam.AdamOptimizer(0.1)
  if context.executing_eagerly():
    optimizer.minimize(root.var.read_value)
  else:
    train_op = optimizer.minimize(root.var)
    # Note that `optimizer` has not been added as a dependency of
    # `root`. Create a one-off grouping so that slot variables for `root.var`
    # get initialized too.
    self.evaluate(checkpointable_utils.gather_initializers(
        checkpointable_utils.Checkpoint(root=root, optimizer=optimizer)))
    self.evaluate(train_op)
  self.evaluate(state_ops.assign(root.var, 12.))
  no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
      os.path.join(checkpoint_directory, "no_slots"))
  root.optimizer = optimizer
  self.evaluate(state_ops.assign(root.var, 13.))
  self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                 14.))
  slots_path = checkpointable_utils.CheckpointableSaver(root).save(
      os.path.join(checkpoint_directory, "with_slots"))
  new_root = checkpointable.Checkpointable()
  # Load the slot-containing checkpoint (deferred), then immediately overwrite
  # the non-slot variable (also deferred).
  slot_status = checkpointable_utils.CheckpointableSaver(
      new_root).restore(slots_path)
  no_slot_status = checkpointable_utils.CheckpointableSaver(
      new_root).restore(no_slots_path)
  with self.assertRaises(AssertionError):
    no_slot_status.assert_consumed()
  new_root.var = checkpointable_utils.add_variable(
      new_root, name="var", shape=[])
  no_slot_status.assert_consumed()
  no_slot_status.run_restore_ops()
  self.assertEqual(12., self.evaluate(new_root.var))
  new_root.optimizer = adam.AdamOptimizer(0.1)
  with self.assertRaisesRegexp(AssertionError, "beta1_power"):
    slot_status.assert_consumed()
  self.assertEqual(12., self.evaluate(new_root.var))
  if context.executing_eagerly():
    # Slot variables are only created with restoring initializers when
    # executing eagerly.
    self.assertEqual(14., self.evaluate(
        new_root.optimizer.get_slot(name="m", var=new_root.var)))
  else:
    self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                  None)
  if context.executing_eagerly():
    new_root.optimizer.minimize(new_root.var.read_value)
  else:
    train_op = new_root.optimizer.minimize(new_root.var)
    # The slot variable now exists; restore() didn't create it, but we should
    # now have a restore op for it.
    slot_status.run_restore_ops()
    self.assertEqual(14., self.evaluate(
        new_root.optimizer.get_slot(name="m", var=new_root.var)))
    self.evaluate(train_op)
  slot_status.assert_consumed()

def testSaveRestore(self):
  model = MyModel()
  optimizer = adam.AdamOptimizer(0.001)
  root_checkpointable = checkpointable_utils.Checkpoint(
      optimizer=optimizer, model=model)
  input_value = constant_op.constant([[3.]])
  if context.executing_eagerly():
    optimizer.minimize(lambda: model(input_value))
  else:
    train_op = optimizer.minimize(model(input_value))
    # TODO(allenl): Make initialization more pleasant when graph building.
    root_checkpointable.save_counter  # pylint: disable=pointless-statement
    self.evaluate(checkpointable_utils.gather_initializers(
        root_checkpointable))
    self.evaluate(train_op)
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
  m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
  self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
  save_path = root_checkpointable.save(file_prefix=prefix)
  self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
  self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3))
  optimizer_variables = self.evaluate(optimizer.variables())
  self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
  # Immediate restoration
  status = root_checkpointable.restore(save_path=save_path).assert_consumed()
  status.run_restore_ops()
  self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1]))
  self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
  self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
  if not context.executing_eagerly():
    return  # Restore-on-create is only supported when executing eagerly
  on_create_model = MyModel()
  on_create_optimizer = adam.AdamOptimizer(
      0.001,
      # Preserve beta1_power and beta2_power when applying gradients so we can
      # test that they've been restored correctly.
      beta1=1.0,
      beta2=1.0)
  on_create_root = checkpointable_utils.Checkpoint(
      optimizer=on_create_optimizer, model=on_create_model)
  # Deferred restoration
  status = on_create_root.restore(save_path=save_path)
  on_create_model(constant_op.constant([[3.]]))  # create variables
  self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
  self.assertAllEqual([42.],
                      self.evaluate(
                          on_create_model._named_dense.variables[1]))
  on_create_m_bias_slot = on_create_optimizer.get_slot(
      on_create_model._named_dense.variables[1], "m")
  # Optimizer slot variables are created when the original variable is
  # restored.
  self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
  self.assertAllEqual(optimizer_variables[2:],
                      self.evaluate(on_create_optimizer.variables()))
  dummy_var = resource_variable_ops.ResourceVariable([1.])
  on_create_optimizer.minimize(loss=dummy_var.read_value)
  status.assert_consumed()
  beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
  self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
  self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power))

def testNamingWithOptimizer(self):
  input_value = constant_op.constant([[3.]])
  model = MyModel()
  # A nuisance Model using the same optimizer. Its slot variables should not
  # go in the checkpoint, since it is never depended on.
  other_model = MyModel()
  optimizer = adam.AdamOptimizer(0.001)
  optimizer_step = training_util.get_or_create_global_step()
  root_checkpointable = checkpointable_utils.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=optimizer_step)
  if context.executing_eagerly():
    optimizer.minimize(
        lambda: model(input_value),
        global_step=optimizer_step)
    optimizer.minimize(
        lambda: other_model(input_value),
        global_step=optimizer_step)
  else:
    train_op = optimizer.minimize(
        model(input_value), global_step=optimizer_step)
    optimizer.minimize(
        other_model(input_value), global_step=optimizer_step)
    self.evaluate(checkpointable_utils.gather_initializers(
        root_checkpointable))
    self.evaluate(train_op)
  named_variables, serialized_graph, _ = (
      checkpointable_utils._serialize_object_graph(
          root_checkpointable, saveables_cache=None))
  expected_checkpoint_names = (
      # Created in the root node, so no prefix.
      "optimizer_step",
      "model/_second/kernel",
      "model/_named_dense/kernel",
      "model/_named_dense/bias",
      # non-Layer dependency of the model
      "model/_non_layer/a_variable",
      # The optimizer creates two non-slot variables
      "optimizer/beta1_power",
      "optimizer/beta2_power",
      # Slot variables
      "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
      "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
      "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
      "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
      "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
      "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
  )
  suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
  expected_checkpoint_names = [
      name + suffix for name in expected_checkpoint_names]
  # The Dense layers also save get_config() JSON
  expected_checkpoint_names.extend([
      "model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
      "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
  named_variables = {v.name: v for v in named_variables}
  six.assertCountEqual(self, expected_checkpoint_names,
                       named_variables.keys())
  # Check that we've mapped to the right variable objects (not exhaustive)
  self.assertEqual("global_step",
                   named_variables["optimizer_step" + suffix].full_name)
  self.assertEqual(
      "my_model/dense_1/kernel",
      named_variables["model/_second/kernel" + suffix].full_name)
  self.assertEqual(
      "my_model/dense/kernel",
      named_variables["model/_named_dense/kernel" + suffix].full_name)
  self.assertEqual(
      "beta1_power",
      named_variables["optimizer/beta1_power" + suffix].full_name)
  self.assertEqual(
      "beta2_power",
      named_variables["optimizer/beta2_power" + suffix].full_name)
  # Spot check the generated protocol buffers.
  self.assertEqual("optimizer",
                   serialized_graph.nodes[0].children[1].local_name)
  optimizer_node = serialized_graph.nodes[
      serialized_graph.nodes[0].children[1].node_id]
  self.assertEqual("beta1_power", optimizer_node.children[0].local_name)
  self.assertEqual(
      "beta1_power",
      serialized_graph.nodes[
          optimizer_node.children[0].node_id].attributes[0].full_name)
  self.assertEqual(
      "my_model/dense/kernel",
      serialized_graph.nodes[optimizer_node.slot_variables[
          0].original_variable_node_id].attributes[0].full_name)
  # We strip off the :0 suffix, as variable.name-based saving does.
  self.assertEqual(
      "my_model/dense/kernel/Adam",
      serialized_graph.nodes[optimizer_node.slot_variables[
          0].slot_variable_node_id].attributes[0].full_name)
  self.assertEqual(
      "my_model/dense/kernel/Adam:0",
      optimizer.get_slot(var=model._named_dense.kernel, name="m").name)
  self.assertEqual(
      "model/_named_dense/kernel" + suffix,
      serialized_graph.nodes[optimizer_node.slot_variables[
          0].original_variable_node_id].attributes[0].checkpoint_key)
  self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
  self.assertEqual(
      "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix,
      serialized_graph.nodes[optimizer_node.slot_variables[
          0].slot_variable_node_id].attributes[0].checkpoint_key)

def testSaveRestoreState(self, mock_time):
  directory = self.get_temp_dir()
  mock_time.time.return_value = 3.
  checkpoint = util.Checkpoint()
  first_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=2)
  first_time = 10000.
  first_name = os.path.join(directory, "ckpt-1")
  mock_time.time.return_value = first_time
  first_manager.save()
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual([first_time], state.all_model_checkpoint_timestamps)
  self.assertEqual(3., state.last_preserved_timestamp)
  second_time = first_time + 3610.
  second_name = os.path.join(directory, "ckpt-2")
  mock_time.time.return_value = second_time
  first_manager.save()
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual([first_time, second_time],
                   state.all_model_checkpoint_timestamps)
  self.assertEqual(3., state.last_preserved_timestamp)
  self.assertEqual([first_name, second_name], first_manager.checkpoints)
  self.assertEqual(second_name, first_manager.latest_checkpoint)
  del first_manager

  second_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory,
      max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
  self.assertEqual([first_name, second_name], second_manager.checkpoints)
  self.assertEqual(second_name, second_manager.latest_checkpoint)
  third_name = os.path.join(directory, "ckpt-3")
  third_time = second_time + 3600. * 0.2
  mock_time.time.return_value = third_time
  second_manager.save()
  self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
  self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
  self.assertEqual([second_name, third_name], second_manager.checkpoints)
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual(first_time, state.last_preserved_timestamp)
  fourth_time = third_time + 3600. * 0.5
  mock_time.time.return_value = fourth_time
  fourth_name = os.path.join(directory, "ckpt-4")
  second_manager.save()
  self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
  self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
  self.assertEqual([third_name, fourth_name], second_manager.checkpoints)
  fifth_time = fourth_time + 3600. * 0.5
  mock_time.time.return_value = fifth_time
  fifth_name = os.path.join(directory, "ckpt-5")
  second_manager.save()
  self.assertEqual([fourth_name, fifth_name], second_manager.checkpoints)
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual(first_time, state.last_preserved_timestamp)
  del second_manager

  third_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory,
      max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
  self.assertEqual(fifth_name, third_manager.latest_checkpoint)
  mock_time.time.return_value += 10.
  third_manager.save()
  sixth_name = os.path.join(directory, "ckpt-6")
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual(fourth_time, state.last_preserved_timestamp)
  self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
  self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
  self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
  self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
  self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
  self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
  self.assertEqual([fifth_name, sixth_name], third_manager.checkpoints)

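# A brief sketch (not part of the original tests, public tf.train API assumed)
# of the retention policy exercised above: `max_to_keep=2` keeps only the two
# most recent checkpoints in `manager.checkpoints`, while
# `keep_checkpoint_every_n_hours=1.5` additionally preserves an older
# checkpoint on disk roughly once per 1.5-hour window instead of deleting it.
# The function name is illustrative only.
def _example_retention_policy(directory, num_saves=5):
  import tensorflow as tf  # local import keeps the sketch self-contained

  ckpt = tf.train.Checkpoint(step=tf.Variable(0))
  manager = tf.train.CheckpointManager(
      ckpt, directory, max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
  for _ in range(num_saves):
    manager.save()
  # At most two paths are tracked here; preserved checkpoints remain on disk.
  return manager.checkpoints
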
def test_initialize_if_not_restoring(self):
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
  with test_util.device(use_gpu=True):
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001)
    root = checkpointable_utils.Checkpoint(
        model=model,  # Do not save the optimizer with the checkpoint.
        global_step=training_util.get_or_create_global_step())
    optimizer_checkpoint = checkpointable_utils.Checkpoint(
        optimizer=optimizer)

    checkpoint_path = checkpoint_management.latest_checkpoint(
        checkpoint_directory)
    status = root.restore(save_path=checkpoint_path)
    input_value = constant_op.constant([[3.]])
    train_fn = functools.partial(
        optimizer.minimize,
        functools.partial(model, input_value),
        global_step=root.global_step)
    if not context.executing_eagerly():
      train_fn = functools.partial(self.evaluate, train_fn())
    status.initialize_or_restore()
    self.evaluate([v.initializer for v in optimizer.variables()])
    train_fn()
    model_save_path = root.save(file_prefix=checkpoint_prefix)
    self.evaluate(optimizer.variables()[0].assign(42.))
    optimizer_save_path = optimizer_checkpoint.save(optimizer_only_prefix)

  # Restore into a graph with the optimizer
  with test_util.device(use_gpu=True):
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001)
    root = checkpointable_utils.Checkpoint(
        optimizer=optimizer, model=model,
        global_step=training_util.get_or_create_global_step())
    status = root.restore(save_path=model_save_path)
    input_value = constant_op.constant([[3.]])
    train_fn = functools.partial(
        optimizer.minimize,
        functools.partial(model, input_value),
        global_step=root.global_step)
    if not context.executing_eagerly():
      train_fn = functools.partial(self.evaluate, train_fn())
    status.initialize_or_restore()
    train_fn()
    with self.assertRaises(AssertionError):
      status.assert_existing_objects_matched()
    with self.assertRaises(AssertionError):
      status.assert_consumed()

  # Make sure initialization doesn't clobber later restores
  with test_util.device(use_gpu=True):
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001, beta1=1.0)
    root = checkpointable_utils.Checkpoint(
        optimizer=optimizer, model=model,
        global_step=training_util.get_or_create_global_step())
    opt_root = checkpointable_utils.Checkpoint(
        optimizer=optimizer)
    status = root.restore(save_path=model_save_path)
    init_only_optimizer_status = opt_root.restore(save_path=None)
    optimizer_status = opt_root.restore(save_path=optimizer_save_path)
    input_value = constant_op.constant([[3.]])
    train_fn = functools.partial(
        optimizer.minimize,
        functools.partial(model, input_value),
        global_step=root.global_step)
    if not context.executing_eagerly():
      train_fn = functools.partial(self.evaluate, train_fn())
    optimizer_status.run_restore_ops()
    status.initialize_or_restore()
    init_only_optimizer_status.initialize_or_restore()
    train_fn()
    self.assertEqual(42., self.evaluate(optimizer.variables()[0]))

def test_checkpointing_not_implemented(self):
  checkpoint_directory = self.get_temp_dir()
  checkpoint = checkpointable_utils.Checkpoint(net=MyNetwork())
  with self.assertRaises(NotImplementedError):
    checkpoint.save(checkpoint_directory)