Example #1
    def back_up(self, epoch):
        """Back up the current state of training into a checkpoint file.

        Arguments:
          epoch: The current epoch information to be saved.
        """
        # pylint: disable=protected-access
        self._assert_in_multi_worker_mode()

        # Update `_ckpt_saved_epoch`.
        K.set_value(self._ckpt_saved_epoch, epoch)

        # If this is multi-worker training and this worker should not save the
        # checkpoint, we replace the filepath with a dummy filepath so it
        # writes to a file that will be removed at the end of the `back_up()`
        # call. This is because the SyncOnReadVariable needs to be synced
        # across all the workers in order to be read, and all workers need to
        # initiate that.
        if multi_worker_util.should_save_checkpoint():
            save_filepath = self._backup_filepath
        else:
            save_filepath = self._temp_filepath

        # Save the weights plus CKPT_SAVED_EPOCH variable.
        self._model.save_weights(save_filepath, overwrite=True)

        if not multi_worker_util.should_save_checkpoint():
            # Remove the file in multi-worker training where this worker should
            # not checkpoint. It is a dummy file previously saved for sync distributed
            # training.
            _remove_dir(self._temp_dir)
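For context, `back_up` is meant to be called once per epoch. Below is a minimal, illustrative sketch of how such a call could be wired into a standard Keras callback; the `BackupAtEpochEnd` name and the `training_state` object are assumptions, not part of the code above.

    import tensorflow as tf

    class BackupAtEpochEnd(tf.keras.callbacks.Callback):
        """Illustrative callback that backs up training state after each epoch."""

        def __init__(self, training_state):
            super().__init__()
            # `training_state` is assumed to expose the `back_up(epoch)` method
            # shown in Example #1.
            self._training_state = training_state

        def on_epoch_end(self, epoch, logs=None):
            # Record the epoch that just finished so a restarted worker can
            # resume from it instead of starting over.
            self._training_state.back_up(epoch)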
Example #2
    def __init__(self, model, original_filepath):
        self._model = model

        # The directory and filepath that store the training state backup file.
        self._backup_dir, self._backup_filepath = _get_backup_filepath(
            original_filepath)

        # For those who should not checkpoint (e.g. non-chief worker in sync
        # training), create a temporary directory to write to (that will be
        # removed later).
        if not multi_worker_util.should_save_checkpoint():
            self._temp_dir, self._temp_filepath = _get_temp_filepath(
                original_filepath)

        # The epoch at which the checkpoint is saved. Used for fault-tolerance.
        # The GPU device only has a VarHandleOp registered for the int64 dtype,
        # hence int64 here.
        self._ckpt_saved_epoch = variables.Variable(
            initial_value=constant_op.constant(CKPT_SAVED_EPOCH_UNUSED_VALUE,
                                               dtype=dtypes.int64),
            name='ckpt_saved_epoch')

        # Variable initialization.
        K.set_value(self._ckpt_saved_epoch, CKPT_SAVED_EPOCH_UNUSED_VALUE)

        # Call `AutoTrackable.__setattr__` so the variable is not added as a
        # weight of the model (which is what `Layer.__setattr__` would do, and
        # which breaks saving/loading in hdf5 format). Once it becomes an attr
        # of `model`, `_ckpt_saved_epoch` gets tracked and will be included in
        # the checkpoint file when backing up.
        tracking.AutoTrackable.__setattr__(self._model, CKPT_SAVED_EPOCH,
                                           self._ckpt_saved_epoch)
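The constructor relies on two helpers that are not shown here, `_get_backup_filepath` and `_get_temp_filepath`. A minimal sketch of how such helpers might derive their paths, assuming the backup simply lives in a `backup` directory next to the original checkpoint (the real helpers may differ):

    import os
    import tempfile

    def _get_backup_filepath(original_filepath):
        # Hypothetical: keep the backup in a dedicated `backup` directory next
        # to the user's checkpoint path.
        backup_dir = os.path.join(os.path.dirname(original_filepath), 'backup')
        return backup_dir, os.path.join(backup_dir, 'training_state')

    def _get_temp_filepath(original_filepath):
        # Hypothetical: workers that should not checkpoint write to a throwaway
        # temporary directory that is deleted once the save completes.
        temp_dir = tempfile.mkdtemp()
        return temp_dir, os.path.join(temp_dir, 'temp_training_state')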
Example #3
    def _maybe_remove_file(self):
        # Remove the checkpoint directory in multi-worker training where this worker
        # should not checkpoint. It is a dummy directory previously saved for sync
        # distributed training.

        if (self.model._in_multi_worker_mode() and  # pylint: disable=protected-access
                not multi_worker_util.should_save_checkpoint()):
            file_io.delete_recursively(self._temp_file_dir)
            del self._temp_file_dir
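`_maybe_remove_file` only makes sense if it runs right after the weights have been written, so the dummy directory created in `_get_file_path` (Examples #6 to #8) never outlives a single save. A simplified, illustrative ordering of those calls inside a `_save_model`-style method (the real method has more logic, e.g. `save_best_only` handling):

    def _save_model(self, epoch, logs):
        # Resolve the target path (a temporary dummy path for workers that
        # should not checkpoint), write the weights, then clean up.
        filepath = self._get_file_path(epoch, logs)
        self.model.save_weights(filepath, overwrite=True)
        self._maybe_remove_file()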
Example #4
    def delete_backup(self):
        """Delete the backup directories.

        Delete the backup directories which should not exist after `fit()`
        successfully finishes.
        """
        self._assert_in_multi_worker_mode()
        tracking.AutoTrackable.__delattr__(self._model, CKPT_SAVED_EPOCH)
        if multi_worker_util.should_save_checkpoint():
            _remove_dir(self._backup_dir)
        else:
            assert not file_io.file_exists(self._temp_dir)
Example #5
    def delete_backup(self):
        """Delete the backup directories.

        Delete the backup directories which should not exist after `fit()`
        successfully finishes.
        """
        self._assert_in_multi_worker_mode()
        # The model may not have this attribute if a failure occurred before
        # the attribute was added to it.
        if hasattr(self._model, CKPT_SAVED_EPOCH):
            tracking.AutoTrackable.__delattr__(self._model, CKPT_SAVED_EPOCH)
        if multi_worker_util.should_save_checkpoint():
            _remove_dir(self._backup_dir)
        else:
            assert not file_io.file_exists(self._temp_dir)
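Both variants delegate the actual deletion to a `_remove_dir` helper that is not shown. Since the surrounding code already uses `file_io`, one plausible sketch of that helper (an assumption, not the actual implementation) is:

    from tensorflow.python.lib.io import file_io

    def _remove_dir(dirpath):
        # Recursively delete the directory if it still exists; tolerate the
        # case where another worker or a retry already removed it.
        if file_io.file_exists(dirpath):
            file_io.delete_recursively(dirpath)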
Example #6
    def _get_file_path(self, epoch, logs):
        """Returns the file path for checkpoint."""
        # pylint: disable=protected-access
        if (not self.model._in_multi_worker_mode()
                or multi_worker_util.should_save_checkpoint()):
            return self.filepath.format(epoch=epoch, **logs)
        else:
            # If this is multi-worker training, and this worker should not
            # save checkpoint, we use a temp filepath to store a dummy checkpoint, so
            # it writes to a file that will be removed at the end of `_save_model()`
            # call. This is because the SyncOnReadVariable needs to be synced across
            # all the workers in order to be read, and all workers need to initiate
            # that.
            self._temp_file_dir = tempfile.mkdtemp()
            extension = os.path.splitext(self.filepath)[1]
            return os.path.join(self._temp_file_dir, 'temp' + extension)
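The `else` branch keeps only the extension of the configured `filepath`, so the dummy checkpoint is written in the same format (for example `.h5`) as the real one. A standalone illustration of that path construction, with a made-up filepath pattern:

    import os
    import tempfile

    filepath = '/tmp/ckpt/weights.{epoch:02d}.h5'  # hypothetical pattern
    temp_file_dir = tempfile.mkdtemp()
    extension = os.path.splitext(filepath)[1]      # '.h5'
    dummy_path = os.path.join(temp_file_dir, 'temp' + extension)
    # e.g. '/tmp/tmpXXXXXXXX/temp.h5', later removed by `_maybe_remove_file()`.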
Example #7
    def _get_file_path(self, logs, epoch):
        """Returns the file path for checkpoint, similar to
        `tf.keras.callbacks.ModelCheckpoint`.
        """
        # noinspection PyProtectedMember
        if (not self.model._in_multi_worker_mode()
                or multi_worker_util.should_save_checkpoint()):
            try:
                return self.filepath.format(epoch=epoch + 1, **logs)
            except KeyError as e:
                raise KeyError(
                    "Failed to format this callback filepath: \"{}\". "
                    "Reason: {}".format(self.filepath, e))
        else:
            self._temp_file_dir = tempfile.mkdtemp()
            extension = os.path.splitext(self.filepath)[1]
            return os.path.join(self._temp_file_dir, "temp" + extension)
Example #8
    def _get_file_path(self, epoch, logs):
        """Returns the file path for checkpoint."""
        # pylint: disable=protected-access
        if (not self.model._in_multi_worker_mode()
                or multi_worker_util.should_save_checkpoint()):
            try:
                # `filepath` may contain placeholders such as `{epoch:02d}` and
                # `{mape:.2f}`. A mismatch between logged metrics and the path's
                # placeholders can cause formatting to fail.
                return self.filepath.format(epoch=epoch + 1, **logs)
            except KeyError as e:
                raise KeyError(
                    'Failed to format this callback filepath: "{}". '
                    'Reason: {}'.format(self.filepath, e))
        else:
            # If this is multi-worker training, and this worker should not
            # save checkpoint, we use a temp filepath to store a dummy checkpoint, so
            # it writes to a file that will be removed at the end of `_save_model()`
            # call. This is because the SyncOnReadVariable needs to be synced across
            # all the workers in order to be read, and all workers need to initiate
            # that.
            self._temp_file_dir = tempfile.mkdtemp()
            extension = os.path.splitext(self.filepath)[1]
            return os.path.join(self._temp_file_dir, 'temp' + extension)
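The `try`/`except` exists because `str.format` raises `KeyError` when the filepath pattern names a metric that is missing from `logs`. A quick illustration of both outcomes, using made-up metric values:

    logs = {'loss': 0.42, 'val_loss': 0.57}

    # Succeeds: every placeholder can be filled from `epoch` and `logs`.
    print('weights.{epoch:02d}-{val_loss:.2f}.h5'.format(epoch=5, **logs))
    # weights.05-0.57.h5

    # Fails: `mape` was never logged, so str.format raises KeyError('mape'),
    # which the callback re-raises with a more descriptive message.
    try:
        'weights.{epoch:02d}-{mape:.2f}.h5'.format(epoch=5, **logs)
    except KeyError as e:
        print('Missing metric in logs:', e)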