Example #1
 def testCustomCheckpointPrefix(self):
   directory = self.get_temp_dir()
   checkpoint = util.Checkpoint()
   manager = checkpoint_management.CheckpointManager(
       checkpoint, directory, max_to_keep=2, checkpoint_name="ckpt_name")
   path = manager.save(checkpoint_number=5)
   self.assertEqual(os.path.basename(path), "ckpt_name-5")
   manager = checkpoint_management.CheckpointManager(
       checkpoint, directory, max_to_keep=2)
   path = manager.save(checkpoint_number=5)
   self.assertEqual(os.path.basename(path), "ckpt-5")
 @test.mock.patch.object(checkpoint_management, "time")
 def testClockReset(self, mock_time):
     directory = self.get_temp_dir()
     mock_time.time.return_value = 10000.
     checkpoint = util.Checkpoint()
     first_manager = checkpoint_management.CheckpointManager(
         checkpoint,
         directory,
         max_to_keep=1,
         keep_checkpoint_every_n_hours=1.)
     first_path = first_manager.save()
     mock_time.time.return_value += 3600.
     second_path = first_manager.save()
     mock_time.time.return_value += 3600.
     third_path = first_manager.save()
     self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertEqual([third_path], first_manager.checkpoints)
     state = checkpoint_management.get_checkpoint_state(directory)
     self.assertEqual(13600., state.last_preserved_timestamp)
     # Set the clock back in time
     mock_time.time.return_value = 5000.
     del first_manager
     with test.mock.patch.object(logging, "warning") as mock_log:
         second_manager = checkpoint_management.CheckpointManager(
             checkpoint, directory, max_to_keep=1)
          self.assertRegex(
             str(mock_log.call_args),
             "behind the last preserved checkpoint timestamp")
     # We should err on the side of keeping checkpoints around when we're not
     # sure whether they were preserved or not due to clock funkiness.
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     # We know about the existing checkpoints, but they'll never be deleted and
     # so won't go in the CheckpointState proto on save.
     self.assertEqual(third_path, second_manager.latest_checkpoint)
     self.assertEqual([], second_manager.checkpoints)
     mock_time.time.return_value += 10.
     fourth_path = second_manager.save()
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertEqual(fourth_path, second_manager.latest_checkpoint)
     self.assertEqual([fourth_path], second_manager.checkpoints)
     mock_time.time.return_value += 10.
     fifth_path = second_manager.save()
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertEqual([fifth_path], second_manager.checkpoints)
     state = checkpoint_management.get_checkpoint_state(directory)
     self.assertEqual(5000., state.last_preserved_timestamp)
     self.assertEqual([5020.], state.all_model_checkpoint_timestamps)
Example #3
    def testRestoreOrInitialize(self):
        directory = self.get_temp_dir()

        # Create a checkpoint for initializing.
        init_prefix = os.path.join(directory, "init")
        init_v = variables.Variable(2.0)
        init_ckpt = util.Checkpoint(v=init_v)
        self.evaluate(init_v.initializer)
        init_path = init_ckpt.save(init_prefix)

        # Create the checkpoint manager.
        ckpt_dir = os.path.join(directory, "ckpt")
        v = variables.Variable(1.0)
        checkpoint = util.Checkpoint(v=v)
        manager = checkpoint_management.CheckpointManager(
            checkpoint,
            ckpt_dir,
            max_to_keep=None,
            init_fn=lambda: checkpoint.restore(init_path).run_restore_ops())
        self.evaluate(v.initializer)

        # First call should call `init_fn`.
        self.assertIsNone(manager.restore_or_initialize())
        self.assertEqual(2.0, self.evaluate(v))

        # Save a checkpoint; the second call should restore from it.
        manager.save()
        self.assertIsNotNone(manager.restore_or_initialize())
Example #4
 def testAgnosticUsage(self):
     """Graph/eager agnostic usage."""
     # Does create garbage when executing eagerly due to ops.Graph() creation.
     num_training_steps = 10
     checkpoint_directory = self.get_temp_dir()
     for training_continuation in range(3):
         with test_util.device(use_gpu=True):
             model = MyModel()
             optimizer = adam.AdamOptimizer(0.001)
             root = trackable_utils.Checkpoint(
                 optimizer=optimizer,
                 model=model,
                 global_step=training_util.get_or_create_global_step())
             manager = checkpoint_management.CheckpointManager(
                 root, checkpoint_directory, max_to_keep=1)
             status = root.restore(save_path=manager.latest_checkpoint)
             input_value = constant_op.constant([[3.]])
             train_fn = functools.partial(optimizer.minimize,
                                          functools.partial(
                                              model, input_value),
                                          global_step=root.global_step)
             if not context.executing_eagerly():
                 train_fn = functools.partial(self.evaluate, train_fn())
             status.initialize_or_restore()
             for _ in range(num_training_steps):
                 train_fn()
             manager.save()
             self.assertEqual(
                 (training_continuation + 1) * num_training_steps,
                 self.evaluate(root.global_step))
             self.assertEqual(training_continuation + 1,
                              self.evaluate(root.save_counter))
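A minimal public-API sketch of the init_fn pattern from testRestoreOrInitialize (Example #3) above, assuming a recent TF 2.x where tf.train.CheckpointManager accepts init_fn; the paths are illustrative, not from the original tests:

import os

import tensorflow as tf

directory = "/tmp/init_demo"  # illustrative path
init_v = tf.Variable(2.0)
init_path = tf.train.Checkpoint(v=init_v).save(os.path.join(directory, "init"))

v = tf.Variable(1.0)
checkpoint = tf.train.Checkpoint(v=v)
manager = tf.train.CheckpointManager(
    checkpoint, os.path.join(directory, "ckpt"), max_to_keep=None,
    init_fn=lambda: checkpoint.restore(init_path))

# No managed checkpoint exists yet, so init_fn runs and None is returned.
assert manager.restore_or_initialize() is None
assert v.numpy() == 2.0

# After a save, later calls restore from the managed checkpoint instead.
manager.save()
assert manager.restore_or_initialize() is not None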
Example #5
def save(self, path, compression=None, shard_func=None, checkpoint_args=None):
    """Implements the save function and checkpoint functionality."""
    if context.executing_eagerly() and checkpoint_args:
        save_dataset = _SaveDataset(self, path, shard_func, compression)
        save_iterator = iter(save_dataset)

        if "checkpoint" in checkpoint_args:
            raise ValueError(
                "'Invalid `checkpoint_args`. `checkpoint_args` are not allowed "
                "to include 'checkpoint'.")
        checkpoint = tracking_util.Checkpoint(iterator=save_iterator)
        checkpoint_args["checkpoint"] = checkpoint
        manager = checkpoint_management.CheckpointManager(**checkpoint_args)
        checkpoint.restore(manager.latest_checkpoint)

        for _ in enumerate(save_iterator):
            if "step_counter" in checkpoint_args:
                checkpoint_args["step_counter"].assign_add(delta=1)
            manager.save(check_interval=True)
    else:
        dataset, shard_func, use_shard_func, path = set_save_dataset_attributes(
            self, shard_func, path)
        ged_ops.save_dataset(
            dataset._variant_tensor,  # pylint: disable=protected-access
            path=path,
            shard_func_other_args=shard_func.captured_inputs,
            compression=compression,
            shard_func=shard_func,
            use_shard_func=use_shard_func)
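A hedged usage sketch for the checkpoint_args branch above, assuming TF >= 2.7 where tf.data.experimental.save accepts checkpoint_args (every entry except the reserved "checkpoint" key is forwarded to CheckpointManager); paths are illustrative:

import tensorflow as tf

dataset = tf.data.Dataset.range(100)
checkpoint_args = {
    "checkpoint_interval": 50,
    "step_counter": tf.Variable(0, trainable=False, dtype=tf.int64),
    "directory": "/tmp/save_ckpts",  # where save progress is checkpointed
    "max_to_keep": 20,
}
tf.data.experimental.save(
    dataset, "/tmp/saved_data", checkpoint_args=checkpoint_args)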
Example #6
 def testAgnosticUsage(self):
   """Graph/eager agnostic usage."""
   # Does create garbage when executing eagerly due to ops.Graph() creation.
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   def _train_fn(model, input_value):
     with backprop.GradientTape() as tape:
       loss = model(input_value)
     variables = model.trainable_variables
     gradients = tape.gradient(loss, variables)
     return optimizer.apply_gradients(zip(gradients, variables))
   for training_continuation in range(3):
     with test_util.device(use_gpu=True):
       model = MyModel()
       optimizer = adam.Adam(0.001)
       root = trackable_utils.Checkpoint(
           optimizer=optimizer, model=model)
       manager = checkpoint_management.CheckpointManager(
           root, checkpoint_directory, max_to_keep=1)
       status = root.restore(save_path=manager.latest_checkpoint)
       input_value = constant_op.constant([[3.]])
       train_fn = functools.partial(_train_fn, model, input_value)
       if not context.executing_eagerly():
         train_fn = functools.partial(self.evaluate, train_fn())
       status.initialize_or_restore()
       for _ in range(num_training_steps):
         train_fn()
       manager.save()
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        self.evaluate(root.optimizer.iterations))
       self.assertEqual(training_continuation + 1,
                        self.evaluate(root.save_counter))
Example #7
 def testContinueFromUnmanaged(self):
   directory = self.get_temp_dir()
   prefix = os.path.join(directory, "unusual_prefix")
   checkpoint = util.Checkpoint()
   first_path = checkpoint.save(prefix)
   second_path = checkpoint.save(prefix)
   del checkpoint
   checkpoint = util.Checkpoint()
   manager = checkpoint_management.CheckpointManager(
       checkpoint, directory, max_to_keep=2)
   checkpoint.restore(manager.latest_checkpoint).run_restore_ops()
   self.assertEqual(2, self.evaluate(checkpoint.save_counter))
   third_path = manager.save()
   self.assertEqual([third_path], manager.checkpoints)
   fourth_path = manager.save()
   self.assertEqual([third_path, fourth_path],
                    manager.checkpoints)
   fifth_path = manager.save()
   self.assertEqual([fourth_path, fifth_path],
                    manager.checkpoints)
   self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
   self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
   self.assertFalse(checkpoint_management.checkpoint_exists(third_path))
   self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
   self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
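A hedged public-API sketch of the adoption behavior tested above: a manager notices checkpoints written without it, but never rotates them out. Paths are illustrative:

import os

import tensorflow as tf

directory = "/tmp/unmanaged_demo"  # illustrative path
checkpoint = tf.train.Checkpoint(v=tf.Variable(1.0))
manual_path = checkpoint.save(os.path.join(directory, "unusual_prefix"))

manager = tf.train.CheckpointManager(checkpoint, directory, max_to_keep=2)
# The manager picks up the unmanaged checkpoint as the latest one...
assert manager.latest_checkpoint == manual_path
# ...but only checkpoints written through manager.save() take part in the
# max_to_keep rotation; the manual one is never deleted.
manager.save()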
Example #8
    def test_training_loop(self):
        for _ in range(5):
            layer = _Dense(5)
            checkpoint = tracking.Checkpoint(layer=layer)
            manager = checkpoint_management.CheckpointManager(
                checkpoint, directory=self.get_temp_dir(), max_to_keep=5)
            manager.restore_or_initialize()

            for _ in range(10):
                x = self.device.pack([
                    constant_op.constant([[-0.5]]),
                    constant_op.constant([[0.5]])
                ])
                with self.device:
                    with backprop.GradientTape() as tape:
                        y = layer(x)
                        loss = (y - math_ops.range(5.))**2.
                    parameters = layer.trainable_variables
                    unreduced_gradients = tape.gradient(loss, parameters)
                    reduced_gradients = _collective_sum(
                        unreduced_gradients,
                        num_replicas=len(self.device.components))
                    for grad, param in zip(reduced_gradients, parameters):
                        param.assign_sub(0.01 * grad)

                manager.save()
Example #9
 def _assertNotCheckpointable(self, dataset):
   iterator = iter(dataset)
   ckpt = trackable_utils.Checkpoint(
       step=variables.Variable(0), iterator=iterator)
   manager = checkpoint_management.CheckpointManager(
       ckpt, self.get_temp_dir(), max_to_keep=3)
   with self.assertRaises(errors.FailedPreconditionError):
     manager.save()
Example #10
    def testLatestCheckpointFSpathDirectory(self):
        directory = pathlib.Path(self.get_temp_dir())
        checkpoint = util.Checkpoint()
        manager = checkpoint_management.CheckpointManager(
            checkpoint, directory, max_to_keep=2, checkpoint_name="ckpt_name")
        manager.save()

        cp_dir = checkpoint_management.latest_checkpoint(directory)
        self.assertEqual(str(directory / "ckpt_name-1"), cp_dir)
 def testKeepAll(self):
     checkpoint = util.Checkpoint()
     directory = os.path.join(
         self.get_temp_dir(),
         # Avoid sharing directories between eager and graph
         # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories
         str(context.executing_eagerly()))
     manager = checkpoint_management.CheckpointManager(checkpoint,
                                                       directory,
                                                       max_to_keep=None)
     first_path = manager.save()
     second_path = manager.save()
     third_path = manager.save()
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
     self.assertEqual(third_path, manager.latest_checkpoint)
     self.assertEqual([first_path, second_path, third_path],
                      manager.checkpoints)
     del manager
     manager = checkpoint_management.CheckpointManager(checkpoint,
                                                       directory,
                                                       max_to_keep=None)
     fourth_path = manager.save()
     self.assertEqual([first_path, second_path, third_path, fourth_path],
                      manager.checkpoints)
     del manager
     manager = checkpoint_management.CheckpointManager(checkpoint,
                                                       directory,
                                                       max_to_keep=3)
     self.assertEqual([first_path, second_path, third_path, fourth_path],
                      manager.checkpoints)
     self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
     fifth_path = manager.save()
     self.assertEqual([third_path, fourth_path, fifth_path],
                      manager.checkpoints)
     self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertFalse(checkpoint_management.checkpoint_exists(second_path))
     self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
Example #12
 def testCheckpointLargeBatches(self):
   # Each batch is 512MB (2 x (64, 1024, 1024) float32).
   dataset = dataset_ops.Dataset.from_tensors(
       array_ops.ones((64, 1024, 1024), dtype=dtypes.float32)).repeat()
   dataset = dataset.batch(2, num_parallel_calls=5)
   iterator = iter(dataset)
   next(iterator)  # request an element to fill the buffer
   ckpt = trackable_utils.Checkpoint(iterator=iterator)
   manager = checkpoint_management.CheckpointManager(
       ckpt, self.get_temp_dir(), max_to_keep=1)
   manager.save()
Example #13
    def __init__(self, model, checkpoint_dir):
        self._model = model

        # The epoch at which the checkpoint is saved. Used for fault tolerance.
        # On GPU devices, VarHandleOp is only registered for the int64 dtype.
        self._ckpt_saved_epoch = variables.Variable(
            initial_value=constant_op.constant(CKPT_SAVED_EPOCH_UNUSED_VALUE,
                                               dtype=dtypes.int64),
            name='ckpt_saved_epoch')

        # Variable initialization.
        backend.set_value(self._ckpt_saved_epoch,
                          CKPT_SAVED_EPOCH_UNUSED_VALUE)

        # _ckpt_saved_epoch gets tracked and is included in the checkpoint file
        # when backing up.
        checkpoint = trackable_util.Checkpoint(
            model=self._model, ckpt_saved_epoch=self._ckpt_saved_epoch)

        # In single-worker training, write_checkpoint_manager and
        # read_checkpoint_manager share the same checkpoint_dir.
        #
        # In multi-worker training, workers that should not save checkpoints
        # get a write_checkpoint_manager pointed at a temporary directory,
        # whose files are removed at the end of the back_up() call. This is
        # necessary because the SyncOnReadVariable can only be read after it
        # is synced across all workers, which requires every worker to call
        # save(). All workers still restore from the same checkpoint_dir via
        # read_checkpoint_manager.
        self.read_checkpoint_manager = checkpoint_management.CheckpointManager(
            checkpoint,
            directory=os.path.join(checkpoint_dir, 'chief'),
            max_to_keep=1)
        write_checkpoint_dir = distributed_file_utils.write_dirpath(
            checkpoint_dir, self._model.distribute_strategy)
        if self._model.distribute_strategy.extended.should_checkpoint:
            self.write_checkpoint_manager = self.read_checkpoint_manager
        else:
            self.write_checkpoint_manager = checkpoint_management.CheckpointManager(
                checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
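A rough sketch (not the verbatim Keras implementation) of the companion back_up() method that the comment above describes, reusing the module aliases from the snippet: every worker calls save() so the SyncOnReadVariable can be read, and non-chief workers' temporary files are removed afterwards.

def back_up(self, epoch):
    # Record the epoch in the variable tracked by the checkpoint.
    backend.set_value(self._ckpt_saved_epoch, epoch)
    # All workers save; on workers that should not keep checkpoints this
    # writes into the temporary directory, which is cleaned up right after.
    if self.write_checkpoint_manager.save():
        distributed_file_utils.remove_temp_dirpath(
            self.write_checkpoint_manager.directory,
            self._model.distribute_strategy)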
Example #14
 def testDeletion(self):
   checkpoint = util.Checkpoint()
   manager = checkpoint_management.CheckpointManager(
       checkpoint, self.get_temp_dir(), max_to_keep=3)
   first_path = manager.save()
   second_path = manager.save()
   third_path = manager.save()
   fourth_path = manager.save()
   self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
   self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
   self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
   self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
Example #15
    def testSidecarEvaluatorOutputsSummary(self):
        # Create a model with synthetic data, and fit for one epoch.
        model = keras.models.Sequential([keras.layers.Dense(10)])
        model.compile(gradient_descent.SGD(),
                      loss='mse',
                      metrics=keras.metrics.CategoricalAccuracy())
        data = np.random.random((1000, 32))
        labels = np.random.random((1000, 10))
        dataset = dataset_ops.Dataset.from_tensor_slices((data, labels))
        dataset = dataset.batch(32)
        model.fit(dataset, epochs=1)

        # Save a checkpoint.
        checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
        log_dir = os.path.join(self.get_temp_dir(), 'summary')
        logging.info('checkpoint_dir = %s, log_dir = %s', checkpoint_dir,
                     log_dir)
        checkpoint = tracking_util.Checkpoint(model=model,
                                              optimizer=model.optimizer)
        checkpoint_manager = checkpoint_management.CheckpointManager(
            checkpoint, checkpoint_dir, max_to_keep=2)
        logging.info('Checkpoint manager saved to: %s',
                     checkpoint_manager.save())

        # Have a sidecar_evaluator evaluate once.
        sidecar_evaluator_lib.SidecarEvaluator(model,
                                               data=dataset,
                                               checkpoint_dir=checkpoint_dir,
                                               log_dir=log_dir,
                                               max_evaluations=1).start()

        # Asserts summary files do get written when log_dir is provided.
        summary_files = file_io.list_directory_v2(log_dir)
        self.assertNotEmpty(
            file_io.list_directory_v2(checkpoint_dir),
            'Checkpoint should have been written and '
            'checkpoint_dir should not be empty.')
        self.assertNotEmpty(
            summary_files, 'Summary should have been written and '
            'log_dir should not be empty.')

        # Asserts the content of the summary file.
        event_pb_written = False
        for event_pb in summary_iterator.summary_iterator(
                os.path.join(log_dir, summary_files[0])):
            if event_pb.step > 0:
                self.assertEqual(event_pb.step, 32)
                self.assertEqual(event_pb.summary.value[0].tag,
                                 'categorical_accuracy')
                event_pb_written = True

        # Verifying at least one non-zeroth step is written to summary.
        self.assertTrue(event_pb_written)
Example #16
 def testCheckpointLargeShuffleBuffer(self):
   # Each element is a 512MB tensor ((128, 1024, 1024) float32).
   dataset = dataset_ops.Dataset.from_tensors(
       array_ops.ones((128, 1024, 1024), dtype=dtypes.float32))
   dataset = dataset.repeat()
   # Set shuffle buffer size to 5 to exceed the 2GB protobuf limit.
   dataset = dataset.shuffle(5)
   iterator = iter(dataset)
   next(iterator)  # request an element to fill the shuffle buffer
   ckpt = trackable_utils.Checkpoint(iterator=iterator)
   manager = checkpoint_management.CheckpointManager(
       ckpt, self.get_temp_dir(), max_to_keep=1)
   manager.save()
Example #17
    def testCheckpointManagerFSpathDirectory(self):
        directory = pathlib.Path(self.get_temp_dir())
        v = variables.Variable(0.0)
        checkpoint = util.Checkpoint(v=v)
        self.evaluate(v.initializer)
        manager = checkpoint_management.CheckpointManager(
            checkpoint, directory, max_to_keep=2, checkpoint_name="ckpt_name")
        save_path = manager.save()
        expected = str(directory / "ckpt_name-1")
        self.assertEqual(expected, save_path)

        restore_path = manager.restore_or_initialize()
        self.assertEqual(str(directory / "ckpt_name-1"), restore_path)
Example #18
  def testSaveRestoreModifiedDataset(self):
    ckpt_dir = self.get_temp_dir()
    dataset = dataset_ops.Dataset.range(10)
    iterator = iter(dataset)
    ckpt = trackable_utils.Checkpoint(iterator=iterator)
    manager = checkpoint_management.CheckpointManager(
        ckpt, ckpt_dir, max_to_keep=3)

    for _ in range(5):
      next(iterator)
    manager.save()

    # Define a different dataset and try to restore into its iterator.
    dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    iterator = iter(dataset)
    ckpt = trackable_utils.Checkpoint(iterator=iterator)
    manager = checkpoint_management.CheckpointManager(
        ckpt, ckpt_dir, max_to_keep=3)
    with self.assertRaisesRegex(
        errors.NotFoundError,
        "Make sure the dataset definition has not changed"):
      ckpt.restore(manager.latest_checkpoint)
Example #19
 def testCheckpointLargeBatches(self):
   if pywrap_sanitizers.is_tsan_enabled():
     self.skipTest('Creating a large buffer causes OOM when using tsan.')
   # Each batch is 512MB (2 x (64, 1024, 1024) float32).
   dataset = dataset_ops.Dataset.from_tensors(
       array_ops.ones((64, 1024, 1024), dtype=dtypes.float32)).repeat()
   dataset = dataset.batch(2, num_parallel_calls=5)
   iterator = iter(dataset)
   next(iterator)  # request an element to fill the buffer
   ckpt = trackable_utils.Checkpoint(iterator=iterator)
   manager = checkpoint_management.CheckpointManager(
       ckpt, self.get_temp_dir(), max_to_keep=1)
   manager.save()
Example #20
  def testModelNotBuiltRaiseError(self, model_type):
    model = _test_model_builder(
        model_type=model_type, compile_model=False, build_model=False)

    checkpoint_dir = self.get_temp_dir()
    checkpoint = tracking_util.Checkpoint(model=model)
    checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint, checkpoint_dir, max_to_keep=2)
    checkpoint_manager.save()

    sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
        model, data=None, checkpoint_dir=checkpoint_dir)
    with self.assertRaisesRegex(AssertionError, 'Nothing to load.'):
      sidecar_evaluator.start()
Example #21
  def testCheckpointLargeShuffleBuffer(self):
    # Each element is a 100MB tensor ((25, 1000, 1000) float32).
    dataset = dataset_ops.Dataset.from_tensors(
        array_ops.ones((25, 1000, 1000), dtype=dtypes.float32))
    dataset = dataset.repeat()
    # Shuffle 25 tensors to exceed the 2GB protocol buffer limit
    dataset = dataset.shuffle(25)

    iterator = iter(dataset)
    next(iterator)  # request an element to fill the shuffle buffer
    ckpt = trackable_utils.Checkpoint(iterator=iterator)
    manager = checkpoint_management.CheckpointManager(
        ckpt, self.get_temp_dir(), max_to_keep=1)
    with self.assertRaisesRegex(errors.UnknownError, "Failed to serialize"):
      manager.save()
Example #22
  def testCheckpointFinishedCache(self):
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.cache()

    iterator = iter(ds)
    for i in range(num_elements):
      self.assertEqual(next(iterator).numpy(), i)
    ckpt = trackable_utils.Checkpoint(iterator=iterator)
    manager = checkpoint_management.CheckpointManager(
        ckpt, self.get_temp_dir(), max_to_keep=1)
    manager.save()
    manager.restore_or_initialize()
    with self.assertRaises(StopIteration):
      next(iterator)
Example #23
  def testIterationsNotSavedWillRaiseError(self):
    model = self.createTestModel(compile_model=False)

    checkpoint_dir = self.get_temp_dir()
    checkpoint = tracking_util.Checkpoint(model=model)
    checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint, checkpoint_dir, max_to_keep=2)
    checkpoint_manager.save()

    sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
        model, data=None, checkpoint_dir=checkpoint_dir, log_dir=None)
    with self.assertRaisesRegex(
        RuntimeError, '`iterations` cannot be loaded '
        'from the checkpoint file.'):
      sidecar_evaluator.start()
Example #24
  def testCheckpointLargeCache(self):
    # Each element is a 100MB tensor ((25, 1000, 1000) float32).
    dataset = dataset_ops.Dataset.from_tensors(
        array_ops.ones((25, 1000, 1000), dtype=dtypes.float32))
    # Repeat 25 times to exceed the 2G proto limit
    dataset = dataset.repeat(25)
    dataset = dataset.cache()

    # Iterate to fill the cache.
    iterator = iter(dataset)
    for _ in range(23):
      next(iterator)
    ckpt = trackable_utils.Checkpoint(iterator=iterator)
    manager = checkpoint_management.CheckpointManager(
        ckpt, self.get_temp_dir(), max_to_keep=1)
    manager.save()
Example #25
  def testSaveRestoreReshuffleDataset(self):
    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.shuffle(10, reshuffle_each_iteration=True)
    iterator = iter(dataset)
    ckpt = trackable_utils.Checkpoint(
        step=variables.Variable(0), iterator=iterator)
    manager = checkpoint_management.CheckpointManager(
        ckpt, self.get_temp_dir(), max_to_keep=3)

    iter1 = [next(iterator).numpy() for _ in range(5)]

    manager.save()
    iter2 = [next(iterator).numpy() for _ in range(5)]

    ckpt.restore(manager.latest_checkpoint)
    iter3 = [next(iterator).numpy() for _ in range(5)]

    self.assertNotEqual(iter1, iter2)
    self.assertCountEqual(iter2, iter3)
Example #26
  def testSidecarEvaluatorOutputsSummary(self, model_type, build_model):
    # Create a model with synthetic data, and fit for one epoch.
    model = _test_model_builder(
        model_type=model_type, compile_model=True, build_model=False)
    data = np.random.random((1000, 32))
    labels = np.random.random((1000, 10))
    dataset = dataset_ops.Dataset.from_tensor_slices((data, labels))
    dataset = dataset.batch(32)
    model.fit(dataset, epochs=1)

    # Save a checkpoint.
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
    log_dir = os.path.join(self.get_temp_dir(), 'summary')
    logging.info('checkpoint_dir = %s, log_dir = %s', checkpoint_dir, log_dir)
    checkpoint = tracking_util.Checkpoint(
        model=model, optimizer=model.optimizer)
    checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint, checkpoint_dir, max_to_keep=2)
    logging.info('Checkpoint manager saved to: %s', checkpoint_manager.save())
    self.assertNotEmpty(
        file_io.list_directory_v2(checkpoint_dir),
        'Checkpoint should have been written and '
        'checkpoint_dir should not be empty.')

    # Create a new model used for evaluation.
    eval_model = _test_model_builder(
        model_type=model_type, compile_model=True, build_model=build_model)
    # Have a sidecar_evaluator evaluate once.
    sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
        eval_model,
        data=dataset,
        checkpoint_dir=checkpoint_dir,
        max_evaluations=1,
        callbacks=[keras.callbacks.TensorBoard(log_dir=log_dir)])
    sidecar_evaluator.start()
    # Eval model has been restored to the same state as the original model, so
    # their weights should match. If not, restoration of the model didn't
    # work.
    self.assertModelsSameVariables(model, eval_model)

    self.assertSummaryEventsWritten(os.path.join(log_dir, 'validation'))
Example #27
    def testCheckpointInterval(self):
        v = variables.Variable(1.0)
        step_counter = variables.Variable(0)
        self.evaluate([v.initializer, step_counter.initializer])
        checkpoint = util.Checkpoint(v=v)
        manager = checkpoint_management.CheckpointManager(
            checkpoint,
            self.get_temp_dir(),
            max_to_keep=None,
            step_counter=step_counter,
            checkpoint_interval=2)

        # step_counter: 0, save an initial checkpoint.
        path = manager.save(check_interval=True)
        self.assertTrue(checkpoint_management.checkpoint_exists(path))

        # step_counter: 1, no checkpoint saved.
        self.evaluate(step_counter.assign_add(1))
        path = manager.save(check_interval=True)
        self.assertIsNone(path)

        # step_counter: 2, checkpoint saved.
        self.evaluate(step_counter.assign_add(1))
        path = manager.save(check_interval=True)
        self.assertTrue(checkpoint_management.checkpoint_exists(path))

        # no checkpoint saved when calling `save` with the same step counter.
        path = manager.save(check_interval=True)
        self.assertIsNone(path)

        # step_counter: 3, no checkpoint saved.
        self.evaluate(step_counter.assign_add(1))
        path = manager.save(check_interval=True)
        self.assertIsNone(path)

        # Always save the checkpoint.
        path = manager.save(check_interval=False)
        self.assertTrue(checkpoint_management.checkpoint_exists(path))
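A hedged public-API sketch of the same interval gating inside a training loop (tf.train.CheckpointManager exposes the step_counter and checkpoint_interval parameters); the directory and loop are illustrative:

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    ckpt, "/tmp/interval_demo", max_to_keep=None,
    step_counter=step, checkpoint_interval=2)

for _ in range(6):
    step.assign_add(1)
    # save() returns a path only when `step` has advanced at least
    # checkpoint_interval steps past the last saved checkpoint; otherwise
    # it returns None (steps 1, 3 and 5 produce files here).
    manager.save(check_interval=True)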
Example #28
 def testDeferredRestorationUsageEager(self):
   """An idiomatic eager execution example."""
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   for training_continuation in range(3):
     with self.test_scope():
       model = Subclassed()
       optimizer = adam.Adam(0.001)
       root = checkpointable_utils.Checkpoint(
           optimizer=optimizer, model=model)
       manager = checkpoint_management.CheckpointManager(
           root, checkpoint_directory, max_to_keep=2)
       root.restore(manager.latest_checkpoint)
       for _ in range(num_training_steps):
         input_value = constant_op.constant([[3.]])
         with backprop.GradientTape() as tape:
           loss = model(input_value)
         variables = model.trainable_variables
         gradients = tape.gradient(loss, variables)
         optimizer.apply_gradients(zip(gradients, variables))
       manager.save()
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        root.optimizer.iterations.numpy())
Example #29
 def testCustomNumbering(self):
   directory = self.get_temp_dir()
   step = variables.Variable(0, dtype=dtypes.int64)
   checkpoint = util.Checkpoint(step=step)
   manager = checkpoint_management.CheckpointManager(
       checkpoint, directory, max_to_keep=2)
   self.evaluate(step.initializer)
   for i in range(5):
     path = manager.save(checkpoint_number=step)
     expected_suffix = "-%d" % (2 * i,)
     if not path.endswith(expected_suffix):
       self.fail("%s should have suffix %s" % (path, expected_suffix))
     self.evaluate(step.assign_add(2))
   self.assertEqual(5, self.evaluate(checkpoint.save_counter))
   # Test regular integers
   last_path = manager.save(checkpoint_number=32)
   self.assertIn("-32", last_path)
   self.assertEqual(last_path, manager.latest_checkpoint)
   self.assertEqual(
       last_path, checkpoint_management.latest_checkpoint(directory))
   state = checkpoint_management.get_checkpoint_state(directory)
   # Only the most recent two checkpoints are saved
   self.assertEqual([path, last_path], state.all_model_checkpoint_paths)
Example #30
    def testCheckpointIntervalWithRestore(self):
        directory = self.get_temp_dir()
        v = variables.Variable(1.0)
        step_counter = variables.Variable(0)
        self.evaluate([v.initializer, step_counter.initializer])

        # Prepare a checkpoint.
        checkpoint = util.Checkpoint(v=v)
        checkpoint.save(os.path.join(directory, "ckpt"))

        manager = checkpoint_management.CheckpointManager(
            checkpoint,
            directory,
            max_to_keep=None,
            step_counter=step_counter,
            checkpoint_interval=2)

        # Restore from the checkpoint.
        self.assertIsNotNone(manager.restore_or_initialize())

        # step_counter: 0, no checkpoint saved because it is restored from the
        # checkpoint with the same step.
        path = manager.save()
        self.assertIsNone(path)
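A hedged public-API sketch of the restore-then-skip behavior tested above; the directory is illustrative:

import os

import tensorflow as tf

directory = "/tmp/restore_demo"  # illustrative path
step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
ckpt.save(os.path.join(directory, "ckpt"))  # unmanaged checkpoint at step 0

manager = tf.train.CheckpointManager(
    ckpt, directory, max_to_keep=None,
    step_counter=step, checkpoint_interval=2)
assert manager.restore_or_initialize() is not None
# The restored checkpoint counts as the last save for the interval check,
# so saving again at the same step is skipped and save() returns None.
assert manager.save() is None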