def testCustomCheckpointPrefix(self):
  directory = self.get_temp_dir()
  checkpoint = util.Checkpoint()
  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=2, checkpoint_name="ckpt_name")
  path = manager.save(checkpoint_number=5)
  self.assertEqual(os.path.basename(path), "ckpt_name-5")

  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=2)
  path = manager.save(checkpoint_number=5)
  self.assertEqual(os.path.basename(path), "ckpt-5")
def testClockReset(self, mock_time):
  directory = self.get_temp_dir()
  mock_time.time.return_value = 10000.
  checkpoint = util.Checkpoint()
  first_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=1, keep_checkpoint_every_n_hours=1.)
  first_path = first_manager.save()
  mock_time.time.return_value += 3600.
  second_path = first_manager.save()
  mock_time.time.return_value += 3600.
  third_path = first_manager.save()
  self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
  self.assertEqual([third_path], first_manager.checkpoints)
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual(13600., state.last_preserved_timestamp)

  # Set the clock back in time.
  mock_time.time.return_value = 5000.
  del first_manager
  with test.mock.patch.object(logging, "warning") as mock_log:
    second_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=1)
    self.assertRegexpMatches(
        str(mock_log.call_args),
        "behind the last preserved checkpoint timestamp")
  # We should err on the side of keeping checkpoints around when we're not
  # sure whether they were preserved or not due to clock funkiness.
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  # We know about the existing checkpoints, but they'll never be deleted and
  # so won't go in the CheckpointState proto on save.
  self.assertEqual(third_path, second_manager.latest_checkpoint)
  self.assertEqual([], second_manager.checkpoints)

  mock_time.time.return_value += 10.
  fourth_path = second_manager.save()
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
  self.assertEqual(fourth_path, second_manager.latest_checkpoint)
  self.assertEqual([fourth_path], second_manager.checkpoints)

  mock_time.time.return_value += 10.
  fifth_path = second_manager.save()
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
  self.assertEqual([fifth_path], second_manager.checkpoints)

  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual(5000., state.last_preserved_timestamp)
  self.assertEqual([5020.], state.all_model_checkpoint_timestamps)
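# A minimal sketch (public `tf.train` API assumed, directory name hypothetical)
# of the `keep_checkpoint_every_n_hours` behavior exercised above: in addition
# to the most recent `max_to_keep` checkpoints, the manager permanently
# preserves one checkpoint roughly every N hours and records the last
# preserved timestamp in the CheckpointState proto.
import tensorflow as tf

checkpoint = tf.train.Checkpoint(step=tf.Variable(0))
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/hourly_ckpts", max_to_keep=1,
    keep_checkpoint_every_n_hours=1.0)
for _ in range(3):
  checkpoint.step.assign_add(1)
  # Older checkpoints are deleted unless at least an hour has passed since the
  # last permanently preserved one.
  manager.save()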
def testRestoreOrInitialize(self):
  directory = self.get_temp_dir()

  # Create a checkpoint for initializing.
  init_prefix = os.path.join(directory, "init")
  init_v = variables.Variable(2.0)
  init_ckpt = util.Checkpoint(v=init_v)
  self.evaluate(init_v.initializer)
  init_path = init_ckpt.save(init_prefix)

  # Create the checkpoint manager.
  ckpt_dir = os.path.join(directory, "ckpt")
  v = variables.Variable(1.0)
  checkpoint = util.Checkpoint(v=v)
  manager = checkpoint_management.CheckpointManager(
      checkpoint,
      ckpt_dir,
      max_to_keep=None,
      init_fn=lambda: checkpoint.restore(init_path).run_restore_ops())
  self.evaluate(v.initializer)

  # First call should call `init_fn`.
  self.assertIsNone(manager.restore_or_initialize())
  self.assertEqual(2.0, self.evaluate(v))

  # Save a checkpoint and the second call should restore from the checkpoints.
  manager.save()
  self.assertIsNotNone(manager.restore_or_initialize())
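# A hedged usage sketch (not part of the test above) of the same
# restore-or-initialize pattern via the public `tf.train.CheckpointManager`
# API. The warm-start path "/tmp/warm_start/init-1", the training directory,
# and the model/optimizer objects are hypothetical placeholders.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam(1e-3)
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/train_ckpts", max_to_keep=3,
    # `init_fn` only runs when no checkpoint exists yet in `directory`.
    init_fn=lambda: checkpoint.restore("/tmp/warm_start/init-1"))

# Returns the restored checkpoint path, or None after falling back to `init_fn`.
restore_path = manager.restore_or_initialize()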
def testAgnosticUsage(self):
  """Graph/eager agnostic usage."""
  # Does create garbage when executing eagerly due to ops.Graph() creation.
  num_training_steps = 10
  checkpoint_directory = self.get_temp_dir()
  for training_continuation in range(3):
    with test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.AdamOptimizer(0.001)
      root = trackable_utils.Checkpoint(
          optimizer=optimizer, model=model,
          global_step=training_util.get_or_create_global_step())
      manager = checkpoint_management.CheckpointManager(
          root, checkpoint_directory, max_to_keep=1)
      status = root.restore(save_path=manager.latest_checkpoint)
      input_value = constant_op.constant([[3.]])
      train_fn = functools.partial(
          optimizer.minimize,
          functools.partial(model, input_value),
          global_step=root.global_step)
      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      status.initialize_or_restore()
      for _ in range(num_training_steps):
        train_fn()
      manager.save()
      self.assertEqual((training_continuation + 1) * num_training_steps,
                       self.evaluate(root.global_step))
      self.assertEqual(training_continuation + 1,
                       self.evaluate(root.save_counter))
def save(self, path, compression=None, shard_func=None, checkpoint_args=None):
  """Implements the save function and checkpoint functionality."""
  if context.executing_eagerly() and checkpoint_args:
    save_dataset = _SaveDataset(self, path, shard_func, compression)
    save_iterator = iter(save_dataset)
    if "checkpoint" in checkpoint_args:
      raise ValueError(
          "Invalid `checkpoint_args`. `checkpoint_args` are not allowed "
          "to include 'checkpoint'.")
    checkpoint = tracking_util.Checkpoint(iterator=save_iterator)
    checkpoint_args["checkpoint"] = checkpoint
    manager = checkpoint_management.CheckpointManager(**checkpoint_args)
    checkpoint.restore(manager.latest_checkpoint)
    for _ in enumerate(save_iterator):
      if "step_counter" in checkpoint_args:
        checkpoint_args["step_counter"].assign_add(delta=1)
      manager.save(check_interval=True)
  else:
    dataset, shard_func, use_shard_func, path = set_save_dataset_attributes(
        self, shard_func, path)
    ged_ops.save_dataset(
        dataset._variant_tensor,  # pylint: disable=protected-access
        path=path,
        shard_func_other_args=shard_func.captured_inputs,
        compression=compression,
        shard_func=shard_func,
        use_shard_func=use_shard_func)
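# A hedged usage sketch of the fault-tolerant save path above, written against
# the public `tf.data.experimental.save` API (assuming a TF version in which it
# accepts `checkpoint_args`). The directory names and the step-counter variable
# are hypothetical; `checkpoint_args` is forwarded to
# `tf.train.CheckpointManager`, so it must not contain a "checkpoint" key.
import tensorflow as tf

dataset = tf.data.Dataset.range(1000)
step_counter = tf.Variable(0, dtype=tf.int64)
tf.data.experimental.save(
    dataset,
    "/tmp/saved_dataset",
    checkpoint_args={
        "directory": "/tmp/save_progress_ckpts",  # where progress is tracked
        "max_to_keep": 1,
        "step_counter": step_counter,
        "checkpoint_interval": 100,  # checkpoint progress every 100 elements
    })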
def testAgnosticUsage(self):
  """Graph/eager agnostic usage."""
  # Does create garbage when executing eagerly due to ops.Graph() creation.
  num_training_steps = 10
  checkpoint_directory = self.get_temp_dir()

  def _train_fn(model, input_value):
    with backprop.GradientTape() as tape:
      loss = model(input_value)
    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    return optimizer.apply_gradients(zip(gradients, variables))

  for training_continuation in range(3):
    with test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.Adam(0.001)
      root = trackable_utils.Checkpoint(optimizer=optimizer, model=model)
      manager = checkpoint_management.CheckpointManager(
          root, checkpoint_directory, max_to_keep=1)
      status = root.restore(save_path=manager.latest_checkpoint)
      input_value = constant_op.constant([[3.]])
      train_fn = functools.partial(_train_fn, model, input_value)
      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      status.initialize_or_restore()
      for _ in range(num_training_steps):
        train_fn()
      manager.save()
      self.assertEqual((training_continuation + 1) * num_training_steps,
                       self.evaluate(root.optimizer.iterations))
      self.assertEqual(training_continuation + 1,
                       self.evaluate(root.save_counter))
def testContinueFromUnmanaged(self):
  directory = self.get_temp_dir()
  prefix = os.path.join(directory, "unusual_prefix")
  checkpoint = util.Checkpoint()
  first_path = checkpoint.save(prefix)
  second_path = checkpoint.save(prefix)
  del checkpoint
  checkpoint = util.Checkpoint()
  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=2)
  checkpoint.restore(manager.latest_checkpoint).run_restore_ops()
  self.assertEqual(2, self.evaluate(checkpoint.save_counter))
  third_path = manager.save()
  self.assertEqual([third_path], manager.checkpoints)
  fourth_path = manager.save()
  self.assertEqual([third_path, fourth_path], manager.checkpoints)
  fifth_path = manager.save()
  self.assertEqual([fourth_path, fifth_path], manager.checkpoints)
  self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  self.assertFalse(checkpoint_management.checkpoint_exists(third_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
def test_training_loop(self):
  for _ in range(5):
    layer = _Dense(5)
    checkpoint = tracking.Checkpoint(layer=layer)
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory=self.get_temp_dir(), max_to_keep=5)
    manager.restore_or_initialize()

    for _ in range(10):
      x = self.device.pack([
          constant_op.constant([[-0.5]]),
          constant_op.constant([[0.5]])
      ])
      with self.device:
        with backprop.GradientTape() as tape:
          y = layer(x)
          loss = (y - math_ops.range(5.))**2.
        parameters = layer.trainable_variables
        unreduced_gradients = tape.gradient(loss, parameters)
        reduced_gradients = _collective_sum(
            unreduced_gradients, num_replicas=len(self.device.components))
        for grad, param in zip(reduced_gradients, parameters):
          param.assign_sub(0.01 * grad)
      manager.save()
def _assertNotCheckpointable(self, dataset):
  iterator = iter(dataset)
  ckpt = trackable_utils.Checkpoint(
      step=variables.Variable(0), iterator=iterator)
  manager = checkpoint_management.CheckpointManager(
      ckpt, self.get_temp_dir(), max_to_keep=3)
  with self.assertRaises(errors.FailedPreconditionError):
    manager.save()
def testLatestCheckpointFSpathDirectory(self):
  directory = pathlib.Path(self.get_temp_dir())
  checkpoint = util.Checkpoint()
  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=2, checkpoint_name="ckpt_name")
  manager.save()
  cp_dir = checkpoint_management.latest_checkpoint(directory)
  self.assertEqual(str(directory / "ckpt_name-1"), cp_dir)
def testKeepAll(self):
  checkpoint = util.Checkpoint()
  directory = os.path.join(
      self.get_temp_dir(),
      # Avoid sharing directories between eager and graph.
      # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories
      str(context.executing_eagerly()))
  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=None)
  first_path = manager.save()
  second_path = manager.save()
  third_path = manager.save()
  self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
  self.assertEqual(third_path, manager.latest_checkpoint)
  self.assertEqual([first_path, second_path, third_path],
                   manager.checkpoints)
  del manager
  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=None)
  fourth_path = manager.save()
  self.assertEqual([first_path, second_path, third_path, fourth_path],
                   manager.checkpoints)
  del manager
  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=3)
  self.assertEqual([first_path, second_path, third_path, fourth_path],
                   manager.checkpoints)
  self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
  fifth_path = manager.save()
  self.assertEqual([third_path, fourth_path, fifth_path],
                   manager.checkpoints)
  self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
  self.assertFalse(checkpoint_management.checkpoint_exists(second_path))
  self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
def testCheckpointLargeBatches(self):
  # Batches of size 512M
  dataset = dataset_ops.Dataset.from_tensors(
      array_ops.ones((64, 1024, 1024), dtype=dtypes.float32)).repeat()
  dataset = dataset.batch(2, num_parallel_calls=5)
  iterator = iter(dataset)
  next(iterator)  # request an element to fill the buffer
  ckpt = trackable_utils.Checkpoint(iterator=iterator)
  manager = checkpoint_management.CheckpointManager(
      ckpt, self.get_temp_dir(), max_to_keep=1)
  manager.save()
def __init__(self, model, checkpoint_dir):
  self._model = model

  # The epoch at which the checkpoint is saved. Used for fault-tolerance.
  # GPU device only has int64 dtype registered VarHandleOp.
  self._ckpt_saved_epoch = variables.Variable(
      initial_value=constant_op.constant(
          CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=dtypes.int64),
      name='ckpt_saved_epoch')

  # Variable initialization.
  backend.set_value(self._ckpt_saved_epoch, CKPT_SAVED_EPOCH_UNUSED_VALUE)

  # _ckpt_saved_epoch gets tracked and is included in the checkpoint file
  # when backing up.
  checkpoint = trackable_util.Checkpoint(
      model=self._model, ckpt_saved_epoch=self._ckpt_saved_epoch)

  # If this is single-worker training, checkpoint_dir is the same for both
  # write_checkpoint_manager and read_checkpoint_manager.
  #
  # If this is multi-worker training, and this worker should not save
  # checkpoints, we replace the write_checkpoint_manager's checkpoint_dir
  # with a temp filepath, so it writes to a file that will be removed at the
  # end of the back_up() call. This is necessary because the
  # SyncOnReadVariable needs to be synced across all the workers in order to
  # be read, and all workers need to perform `save()`. But all workers should
  # restore from the same checkpoint_dir as passed in read_checkpoint_manager.
  self.read_checkpoint_manager = checkpoint_management.CheckpointManager(
      checkpoint,
      directory=os.path.join(checkpoint_dir, 'chief'),
      max_to_keep=1)
  write_checkpoint_dir = distributed_file_utils.write_dirpath(
      checkpoint_dir, self._model.distribute_strategy)
  if self._model.distribute_strategy.extended.should_checkpoint:
    self.write_checkpoint_manager = self.read_checkpoint_manager
  else:
    self.write_checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
def testDeletion(self):
  checkpoint = util.Checkpoint()
  manager = checkpoint_management.CheckpointManager(
      checkpoint, self.get_temp_dir(), max_to_keep=3)
  first_path = manager.save()
  second_path = manager.save()
  third_path = manager.save()
  fourth_path = manager.save()
  self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
  self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
  self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
def testSidecarEvaluatorOutputsSummary(self):
  # Create a model with synthetic data, and fit for one epoch.
  model = keras.models.Sequential([keras.layers.Dense(10)])
  model.compile(
      gradient_descent.SGD(),
      loss='mse',
      metrics=keras.metrics.CategoricalAccuracy())
  data = np.random.random((1000, 32))
  labels = np.random.random((1000, 10))
  dataset = dataset_ops.Dataset.from_tensor_slices((data, labels))
  dataset = dataset.batch(32)
  model.fit(dataset, epochs=1)

  # Save a checkpoint.
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
  log_dir = os.path.join(self.get_temp_dir(), 'summary')
  logging.info('checkpoint_dir = %s, log_dir = %s', checkpoint_dir, log_dir)
  checkpoint = tracking_util.Checkpoint(
      model=model, optimizer=model.optimizer)
  checkpoint_manager = checkpoint_management.CheckpointManager(
      checkpoint, checkpoint_dir, max_to_keep=2)
  logging.info('Checkpoint manager saved to: %s', checkpoint_manager.save())

  # Have a sidecar_evaluator evaluate once.
  sidecar_evaluator_lib.SidecarEvaluator(
      model,
      data=dataset,
      checkpoint_dir=checkpoint_dir,
      log_dir=log_dir,
      max_evaluations=1).start()

  # Asserts summary files do get written when log_dir is provided.
  summary_files = file_io.list_directory_v2(log_dir)
  self.assertNotEmpty(
      file_io.list_directory_v2(checkpoint_dir),
      'Checkpoint should have been written and '
      'checkpoint_dir should not be empty.')
  self.assertNotEmpty(
      summary_files, 'Summary should have been written and '
      'log_dir should not be empty.')

  # Asserts the content of the summary file.
  event_pb_written = False
  for event_pb in summary_iterator.summary_iterator(
      os.path.join(log_dir, summary_files[0])):
    if event_pb.step > 0:
      self.assertEqual(event_pb.step, 32)
      self.assertEqual(event_pb.summary.value[0].tag, 'categorical_accuracy')
      event_pb_written = True

  # Verify that at least one non-zeroth step is written to the summary.
  self.assertTrue(event_pb_written)
def testCheckpointLargeShuffleBuffer(self):
  # Tensor of size 512M
  dataset = dataset_ops.Dataset.from_tensors(
      array_ops.ones((128, 1024, 1024), dtype=dtypes.float32))
  dataset = dataset.repeat()
  # Set shuffle buffer size to 5 to exceed the 2GB protobuf limit.
  dataset = dataset.shuffle(5)
  iterator = iter(dataset)
  next(iterator)  # request an element to fill the shuffle buffer
  ckpt = trackable_utils.Checkpoint(iterator=iterator)
  manager = checkpoint_management.CheckpointManager(
      ckpt, self.get_temp_dir(), max_to_keep=1)
  manager.save()
def testCheckpointManagerFSpathDirectory(self):
  directory = pathlib.Path(self.get_temp_dir())
  v = variables.Variable(0.0)
  checkpoint = util.Checkpoint(v=v)
  self.evaluate(v.initializer)
  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=2, checkpoint_name="ckpt_name")
  save_path = manager.save()
  expected = str(directory / "ckpt_name-1")
  self.assertEqual(expected, save_path)
  restore_path = manager.restore_or_initialize()
  self.assertEqual(str(directory / "ckpt_name-1"), restore_path)
def testSaveRestoreModifiedDataset(self):
  ckpt_dir = self.get_temp_dir()
  dataset = dataset_ops.Dataset.range(10)
  iterator = iter(dataset)
  ckpt = trackable_utils.Checkpoint(iterator=iterator)
  manager = checkpoint_management.CheckpointManager(
      ckpt, ckpt_dir, max_to_keep=3)
  for _ in range(5):
    next(iterator)
  manager.save()

  # Define a different dataset and try to restore into its iterator.
  dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
  iterator = iter(dataset)
  ckpt = trackable_utils.Checkpoint(iterator=iterator)
  manager = checkpoint_management.CheckpointManager(
      ckpt, ckpt_dir, max_to_keep=3)
  with self.assertRaisesRegex(
      errors.NotFoundError,
      "Make sure the dataset definition has not changed"):
    ckpt.restore(manager.latest_checkpoint)
def testCheckpointLargeBatches(self):
  if pywrap_sanitizers.is_tsan_enabled():
    self.skipTest('Creating a large buffer causes OOM when using tsan.')
  # Batches of size 512M
  dataset = dataset_ops.Dataset.from_tensors(
      array_ops.ones((64, 1024, 1024), dtype=dtypes.float32)).repeat()
  dataset = dataset.batch(2, num_parallel_calls=5)
  iterator = iter(dataset)
  next(iterator)  # request an element to fill the buffer
  ckpt = trackable_utils.Checkpoint(iterator=iterator)
  manager = checkpoint_management.CheckpointManager(
      ckpt, self.get_temp_dir(), max_to_keep=1)
  manager.save()
def testModelNotBuiltRaiseError(self, model_type):
  model = _test_model_builder(
      model_type=model_type, compile_model=False, build_model=False)

  checkpoint_dir = self.get_temp_dir()
  checkpoint = tracking_util.Checkpoint(model=model)
  checkpoint_manager = checkpoint_management.CheckpointManager(
      checkpoint, checkpoint_dir, max_to_keep=2)
  checkpoint_manager.save()

  sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
      model, data=None, checkpoint_dir=checkpoint_dir)
  with self.assertRaisesRegex(AssertionError, 'Nothing to load.'):
    sidecar_evaluator.start()
def testCheckpointLargeShuffleBuffer(self):
  # Tensor of size 100M
  dataset = dataset_ops.Dataset.from_tensors(
      array_ops.ones((25, 1000, 1000), dtype=dtypes.float32))
  dataset = dataset.repeat()
  # Shuffle 25 tensors to exceed the 2GB protocol buffer limit.
  dataset = dataset.shuffle(25)
  iterator = iter(dataset)
  next(iterator)  # request an element to fill the shuffle buffer
  ckpt = trackable_utils.Checkpoint(iterator=iterator)
  manager = checkpoint_management.CheckpointManager(
      ckpt, self.get_temp_dir(), max_to_keep=1)
  with self.assertRaisesRegex(errors.UnknownError, "Failed to serialize"):
    manager.save()
def testCheckpointFinishedCache(self):
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  ds = ds.cache()
  iterator = iter(ds)
  for i in range(num_elements):
    self.assertEqual(next(iterator).numpy(), i)
  ckpt = trackable_utils.Checkpoint(iterator=iterator)
  manager = checkpoint_management.CheckpointManager(
      ckpt, self.get_temp_dir(), max_to_keep=1)
  manager.save()
  manager.restore_or_initialize()
  with self.assertRaises(StopIteration):
    next(iterator)
def testIterationsNotSavedWillRaiseError(self):
  model = self.createTestModel(compile_model=False)

  checkpoint_dir = self.get_temp_dir()
  checkpoint = tracking_util.Checkpoint(model=model)
  checkpoint_manager = checkpoint_management.CheckpointManager(
      checkpoint, checkpoint_dir, max_to_keep=2)
  checkpoint_manager.save()

  sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
      model, data=None, checkpoint_dir=checkpoint_dir, log_dir=None)
  with self.assertRaisesRegex(
      RuntimeError, '`iterations` cannot be loaded '
      'from the checkpoint file.'):
    sidecar_evaluator.start()
def testCheckpointLargeCache(self):
  # Tensor of size 100M
  dataset = dataset_ops.Dataset.from_tensors(
      array_ops.ones((25, 1000, 1000), dtype=dtypes.float32))
  # Repeat 25 times to exceed the 2G proto limit.
  dataset = dataset.repeat(25)
  dataset = dataset.cache()

  # Iterate to fill the cache.
  iterator = iter(dataset)
  for _ in range(23):
    next(iterator)

  ckpt = trackable_utils.Checkpoint(iterator=iterator)
  manager = checkpoint_management.CheckpointManager(
      ckpt, self.get_temp_dir(), max_to_keep=1)
  manager.save()
def testSaveRestoreReshuffleDataset(self):
  dataset = dataset_ops.Dataset.range(10)
  dataset = dataset.shuffle(10, reshuffle_each_iteration=True)
  iterator = iter(dataset)
  ckpt = trackable_utils.Checkpoint(
      step=variables.Variable(0), iterator=iterator)
  manager = checkpoint_management.CheckpointManager(
      ckpt, self.get_temp_dir(), max_to_keep=3)

  iter1 = [next(iterator).numpy() for _ in range(5)]

  manager.save()
  iter2 = [next(iterator).numpy() for _ in range(5)]

  ckpt.restore(manager.latest_checkpoint)
  iter3 = [next(iterator).numpy() for _ in range(5)]

  self.assertNotEqual(iter1, iter2)
  self.assertCountEqual(iter2, iter3)
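# A hedged sketch of the iterator-checkpointing pattern used throughout the
# tf.data tests above, written against the public API. The dataset and
# checkpoint directory are placeholders: an eager `tf.data` iterator is
# trackable, so attaching it to a `tf.train.Checkpoint` saves and restores its
# position in the input pipeline.
import tensorflow as tf

dataset = tf.data.Dataset.range(100)
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(0), iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, "/tmp/iter_ckpts", max_to_keep=3)

for _ in range(5):
  next(iterator)            # consume a few elements
save_path = manager.save()  # records that five elements have been consumed

for _ in range(5):
  next(iterator)            # consume five more elements

ckpt.restore(save_path)     # rewind the iterator back to element 5
print(next(iterator).numpy())  # prints 5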
def testSidecarEvaluatorOutputsSummary(self, model_type, build_model):
  # Create a model with synthetic data, and fit for one epoch.
  model = _test_model_builder(
      model_type=model_type, compile_model=True, build_model=False)
  data = np.random.random((1000, 32))
  labels = np.random.random((1000, 10))
  dataset = dataset_ops.Dataset.from_tensor_slices((data, labels))
  dataset = dataset.batch(32)
  model.fit(dataset, epochs=1)

  # Save a checkpoint.
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
  log_dir = os.path.join(self.get_temp_dir(), 'summary')
  logging.info('checkpoint_dir = %s, log_dir = %s', checkpoint_dir, log_dir)
  checkpoint = tracking_util.Checkpoint(
      model=model, optimizer=model.optimizer)
  checkpoint_manager = checkpoint_management.CheckpointManager(
      checkpoint, checkpoint_dir, max_to_keep=2)
  logging.info('Checkpoint manager saved to: %s', checkpoint_manager.save())
  self.assertNotEmpty(
      file_io.list_directory_v2(checkpoint_dir),
      'Checkpoint should have been written and '
      'checkpoint_dir should not be empty.')

  # Create a new model used for evaluation.
  eval_model = _test_model_builder(
      model_type=model_type, compile_model=True, build_model=build_model)
  # Have a sidecar_evaluator evaluate once.
  sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
      eval_model,
      data=dataset,
      checkpoint_dir=checkpoint_dir,
      max_evaluations=1,
      callbacks=[keras.callbacks.TensorBoard(log_dir=log_dir)])
  sidecar_evaluator.start()
  # Eval model has been restored to the same state as the original model, so
  # their weights should match. If not, restoration of the model didn't work.
  self.assertModelsSameVariables(model, eval_model)

  self.assertSummaryEventsWritten(os.path.join(log_dir, 'validation'))
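# A hedged sketch of the sidecar-evaluation workflow the two SidecarEvaluator
# tests above exercise, using the public `tf.keras.utils.SidecarEvaluator` API.
# The paths, model, and dataset are hypothetical placeholders; a separate
# trainer process is assumed to write checkpoints (containing the model and
# optimizer) into `checkpoint_dir` via a CheckpointManager.
import numpy as np
import tensorflow as tf

eval_model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
eval_model.compile(
    optimizer=tf.keras.optimizers.SGD(),
    loss='mse',
    metrics=[tf.keras.metrics.CategoricalAccuracy()])

eval_dataset = tf.data.Dataset.from_tensor_slices(
    (np.random.random((1000, 32)), np.random.random((1000, 10)))).batch(32)

tf.keras.utils.SidecarEvaluator(
    eval_model,
    data=eval_dataset,
    checkpoint_dir='/tmp/train_ckpts',  # polled for newly written checkpoints
    max_evaluations=None,  # keep evaluating until the trainer stops
    callbacks=[tf.keras.callbacks.TensorBoard(log_dir='/tmp/eval_logs')],
).start()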
def testCheckpointInterval(self):
  v = variables.Variable(1.0)
  step_counter = variables.Variable(0)
  self.evaluate([v.initializer, step_counter.initializer])
  checkpoint = util.Checkpoint(v=v)
  manager = checkpoint_management.CheckpointManager(
      checkpoint,
      self.get_temp_dir(),
      max_to_keep=None,
      step_counter=step_counter,
      checkpoint_interval=2)

  # step_counter: 0, save an initial checkpoint.
  path = manager.save(check_interval=True)
  self.assertTrue(checkpoint_management.checkpoint_exists(path))

  # step_counter: 1, no checkpoint saved.
  self.evaluate(step_counter.assign_add(1))
  path = manager.save(check_interval=True)
  self.assertIsNone(path)

  # step_counter: 2, checkpoint saved.
  self.evaluate(step_counter.assign_add(1))
  path = manager.save(check_interval=True)
  self.assertTrue(checkpoint_management.checkpoint_exists(path))

  # No checkpoint saved when calling `save` with the same step counter.
  path = manager.save(check_interval=True)
  self.assertIsNone(path)

  # step_counter: 3, no checkpoint saved.
  self.evaluate(step_counter.assign_add(1))
  path = manager.save(check_interval=True)
  self.assertIsNone(path)

  # Always save the checkpoint.
  path = manager.save(check_interval=False)
  self.assertTrue(checkpoint_management.checkpoint_exists(path))
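# A minimal training-loop sketch (public API assumed, model/optimizer/data are
# hypothetical placeholders) showing how `step_counter` and
# `checkpoint_interval` are typically used: `save(check_interval=True)` only
# writes a checkpoint once the step counter has advanced by at least
# `checkpoint_interval` since the last save, and returns None otherwise.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(0.01)
step = tf.Variable(0, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer, step=step)
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/interval_ckpts", max_to_keep=3,
    step_counter=step, checkpoint_interval=100)

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([256, 4]), tf.random.normal([256, 1]))).batch(32)

for x, y in dataset.repeat(10):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean((model(x) - y) ** 2)
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  step.assign_add(1)
  # Returns a save path roughly every 100 steps, otherwise None.
  manager.save(check_interval=True)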
def testDeferredRestorationUsageEager(self):
  """An idiomatic eager execution example."""
  num_training_steps = 10
  checkpoint_directory = self.get_temp_dir()
  for training_continuation in range(3):
    with self.test_scope():
      model = Subclassed()
      optimizer = adam.Adam(0.001)
      root = checkpointable_utils.Checkpoint(
          optimizer=optimizer, model=model)
      manager = checkpoint_management.CheckpointManager(
          root, checkpoint_directory, max_to_keep=2)
      root.restore(manager.latest_checkpoint)
      for _ in range(num_training_steps):
        input_value = constant_op.constant([[3.]])
        with backprop.GradientTape() as tape:
          loss = model(input_value)
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))
      manager.save()
      self.assertEqual((training_continuation + 1) * num_training_steps,
                       root.optimizer.iterations.numpy())
def testCustomNumbering(self):
  directory = self.get_temp_dir()
  step = variables.Variable(0, dtype=dtypes.int64)
  checkpoint = util.Checkpoint(step=step)
  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=2)
  self.evaluate(step.initializer)
  for i in range(5):
    path = manager.save(checkpoint_number=step)
    expected_suffix = "-%d" % (2 * i,)
    if not path.endswith(expected_suffix):
      self.fail("%s should have suffix %s" % (path, expected_suffix))
    self.evaluate(step.assign_add(2))
  self.assertEqual(5, self.evaluate(checkpoint.save_counter))
  # Test regular integers.
  last_path = manager.save(checkpoint_number=32)
  self.assertIn("-32", last_path)
  self.assertEqual(last_path, manager.latest_checkpoint)
  self.assertEqual(
      last_path, checkpoint_management.latest_checkpoint(directory))
  state = checkpoint_management.get_checkpoint_state(directory)
  # Only the most recent two checkpoints are saved.
  self.assertEqual([path, last_path], state.all_model_checkpoint_paths)
def testCheckpointIntervalWithRestore(self):
  directory = self.get_temp_dir()
  v = variables.Variable(1.0)
  step_counter = variables.Variable(0)
  self.evaluate([v.initializer, step_counter.initializer])

  # Prepare a checkpoint.
  checkpoint = util.Checkpoint(v=v)
  checkpoint.save(os.path.join(directory, "ckpt"))

  manager = checkpoint_management.CheckpointManager(
      checkpoint,
      directory,
      max_to_keep=None,
      step_counter=step_counter,
      checkpoint_interval=2)

  # Restore from the checkpoint.
  self.assertIsNotNone(manager.restore_or_initialize())

  # step_counter: 0, no checkpoint saved because it is restored from the
  # checkpoint with the same step.
  path = manager.save()
  self.assertIsNone(path)