# NOTE: these test callables depend on TF-internal Keras helpers; the imports
# they assume are roughly (exact module paths vary across TF versions):
#   from tensorflow.python.keras import callbacks
#   from tensorflow.python.keras.distribute import (
#       multi_worker_training_state as training_state)
def callableForTestBackupModelNotRemovedIfInterrupted(model, test_obj,
                                                        train_ds, num_epoch,
                                                        steps, strategy,
                                                        saving_filepath,
                                                        **kwargs):

    # `barrier` object needs to be passed in from parent
    # thread so both threads refer to the same object.
    barrier = kwargs['barrier']

    num_epoch = 4

    # Testing the backup filepath `multi_worker_training_state` uses.
    _, backup_filepath = training_state._get_backup_filepath(saving_filepath)

    # The backup_filepath shouldn't exist at the beginning.
    test_obj.assertFalse(training_state.checkpoint_exists(backup_filepath))

    # Callback to interrupt in the middle of training.
    class InterruptingCallback(callbacks.Callback):

      def on_epoch_begin(self, epoch, logs=None):
        if epoch == 2:
          raise RuntimeError('Interrupting!')

    try:
      model.fit(
          x=train_ds,
          epochs=num_epoch,
          steps_per_epoch=steps,
          callbacks=[
              callbacks.ModelCheckpoint(
                  filepath=saving_filepath, save_weights_only=True),
              InterruptingCallback()
          ])
    except RuntimeError as e:
      # `e.message` does not exist in Python 3; match against str(e) instead.
      if 'Interrupting!' not in str(e):
        raise

    # Sync on the two threads.
    barrier.wait()

    # The backup file should exist after interruption of `model.fit()`.
    test_obj.assertTrue(training_state.checkpoint_exists(backup_filepath))
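
# A minimal, self-contained sketch of the interrupt-and-resume pattern the
# callable above exercises, using only public tf.keras APIs and toy data.
# The model, the random data, and the '/tmp/ckpt' path are illustrative
# assumptions, not part of the original test.
import numpy as np
import tensorflow as tf


class _InterruptingCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    if epoch == 2:
      raise RuntimeError('Interrupting!')


toy_model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
toy_model.compile(optimizer='sgd', loss='mse')
x, y = np.random.rand(32, 4), np.random.rand(32, 1)

try:
  toy_model.fit(
      x, y, epochs=4,
      callbacks=[
          tf.keras.callbacks.ModelCheckpoint(
              filepath='/tmp/ckpt', save_weights_only=True),
          _InterruptingCallback(),
      ])
except RuntimeError as e:
  if 'Interrupting!' not in str(e):
    raise

# The checkpoint written at the end of the last completed epoch survives the
# interruption and can be reloaded before resuming training.
toy_model.load_weights('/tmp/ckpt')
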
def callableForTestBackupModelRemoved(model, test_obj, train_ds, num_epoch,
                                      steps, strategy, saving_filepath,
                                      **kwargs):

    # `barrier` object needs to be passed in from parent
    # thread so both threads refer to the same object.
    barrier = kwargs['barrier']

    num_epoch = 3

    # Testing the backup filepath `multi_worker_training_state` uses.
    _, backup_filepath = training_state._get_backup_filepath(saving_filepath)

    # The backup_filepath shouldn't exist at the beginning.
    test_obj.assertFalse(training_state.checkpoint_exists(backup_filepath))

    # Callback to verify that the backup file exists in the middle of training.
    class BackupFilepathVerifyingCallback(callbacks.Callback):
        def on_epoch_begin(self, epoch, logs=None):
            if epoch > 1:
                # Asserting that after the first two epochs, the backup file
                # should exist.
                test_obj.assertTrue(
                    training_state.checkpoint_exists(backup_filepath))

    model.fit(x=train_ds,
              epochs=num_epoch,
              steps_per_epoch=steps,
              callbacks=[
                  callbacks.ModelCheckpoint(filepath=saving_filepath,
                                            save_weights_only=True),
                  BackupFilepathVerifyingCallback()
              ])

    # Sync on the two threads so we make sure the backup file is removed before
    # we move on.
    barrier.wait()

    # The backup file should not exist at successful exit of `model.fit()`.
    test_obj.assertFalse(training_state.checkpoint_exists(backup_filepath))
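
# Both callables above synchronize through a barrier received via **kwargs. A
# minimal sketch of that mechanism, independent of the TF test harness (the
# `_worker` function and the two-participant count are illustrative
# assumptions):
import threading


def _worker(worker_id, **kwargs):
  # The same Barrier object, created once by the parent thread, is handed to
  # every worker so they all block on the same synchronization point.
  barrier = kwargs['barrier']
  print('worker %d reached the sync point' % worker_id)
  barrier.wait()  # Blocks until all participants have called wait().
  print('worker %d resumed' % worker_id)


if __name__ == '__main__':
  barrier = threading.Barrier(2)  # Two participants, e.g. chief and worker.
  threads = [
      threading.Thread(target=_worker, args=(i,), kwargs={'barrier': barrier})
      for i in range(2)
  ]
  for t in threads:
    t.start()
  for t in threads:
    t.join()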