def back_up(self, epoch):
  """Back up the current state of training into a checkpoint file.

  Arguments:
    epoch: The current epoch information to be saved.
  """
  # pylint: disable=protected-access
  self._assert_in_multi_worker_mode()

  # Update `_ckpt_saved_epoch`.
  K.set_value(self._ckpt_saved_epoch, epoch)

  # If this is multi-worker training, and this worker should not
  # save checkpoint, we replace the filepath with a dummy filepath so
  # it writes to a file that will be removed at the end of this `back_up()`
  # call. This is because the SyncOnReadVariable needs to be synced across
  # all the workers in order to be read, and all workers need to initiate
  # that.
  if multi_worker_util.should_save_checkpoint():
    save_filepath = self._backup_filepath
  else:
    save_filepath = self._temp_filepath

  # Save the weights plus the CKPT_SAVED_EPOCH variable.
  self._model.save_weights(save_filepath, overwrite=True)

  if not multi_worker_util.should_save_checkpoint():
    # Remove the file in multi-worker training where this worker should
    # not checkpoint. It is a dummy file previously saved for sync
    # distributed training.
    _remove_dir(self._temp_dir)
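A minimal sketch of the `_remove_dir` helper called above, assuming it simply wraps `file_io.delete_recursively`; the helper's actual body is not shown in this section, and the existence guard is an assumption.

def _remove_dir(filepath):
  # Recursively delete the given directory. Guarding on existence tolerates
  # the directory already having been removed (e.g. by another worker).
  if file_io.file_exists(filepath):
    file_io.delete_recursively(filepath)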
def __init__(self, model, original_filepath):
  self._model = model

  # The directory and filepath that store the training state backup file.
  self._backup_dir, self._backup_filepath = _get_backup_filepath(
      original_filepath)

  # For those who should not checkpoint (e.g. non-chief worker in sync
  # training), create a temporary directory to write to (that will be
  # removed later).
  if not multi_worker_util.should_save_checkpoint():
    self._temp_dir, self._temp_filepath = _get_temp_filepath(
        original_filepath)

  # The epoch at which the checkpoint is saved. Used for fault tolerance.
  # GPU devices only have an int64-dtype VarHandleOp registered.
  self._ckpt_saved_epoch = variables.Variable(
      initial_value=constant_op.constant(
          CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=dtypes.int64),
      name='ckpt_saved_epoch')

  # Variable initialization.
  K.set_value(self._ckpt_saved_epoch, CKPT_SAVED_EPOCH_UNUSED_VALUE)

  # Call `AutoTrackable.__setattr__` directly to avoid the variable getting
  # added as a weight of the model (which is what `Layer.__setattr__` does),
  # since that breaks saving/loading in HDF5 format. Once it becomes an attr
  # of `model`, `_ckpt_saved_epoch` gets tracked and will be included in the
  # checkpoint file when backing up.
  tracking.AutoTrackable.__setattr__(self._model, CKPT_SAVED_EPOCH,
                                     self._ckpt_saved_epoch)
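The constructor relies on two path helpers that are not shown here. A plausible sketch, assuming the backup lives in a `backup` subdirectory next to the original checkpoint and the dummy state goes into a fresh temp directory; the names `training_state` and `temp_training_state` are illustrative assumptions, not confirmed by the source.

import os
import tempfile


def _get_backup_filepath(original_filepath):
  # Assumed layout: keep the backup in a 'backup' directory alongside the
  # user's checkpoint file so it is easy to locate after a failure.
  backup_dir = os.path.join(os.path.dirname(original_filepath), 'backup')
  return backup_dir, os.path.join(backup_dir, 'training_state')


def _get_temp_filepath(original_filepath):
  # Assumed behavior: non-chief workers write to a throwaway directory that
  # `back_up()` deletes once the save (and variable sync) has completed.
  del original_filepath  # Unused in this sketch.
  temp_dir = tempfile.mkdtemp()
  return temp_dir, os.path.join(temp_dir, 'temp_training_state')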
def _maybe_remove_file(self):
  # Remove the checkpoint directory in multi-worker training where this
  # worker should not checkpoint. It is a dummy directory previously saved
  # for sync distributed training.
  # pylint: disable=protected-access
  if (self.model._in_multi_worker_mode() and
      not multi_worker_util.should_save_checkpoint()):
    file_io.delete_recursively(self._temp_file_dir)
    del self._temp_file_dir
def delete_backup(self):
  """Delete the backup directories.

  Delete the backup directories which should not exist after `fit()`
  successfully finishes.
  """
  self._assert_in_multi_worker_mode()
  tracking.AutoTrackable.__delattr__(self._model, CKPT_SAVED_EPOCH)
  if multi_worker_util.should_save_checkpoint():
    _remove_dir(self._backup_dir)
  else:
    assert not file_io.file_exists(self._temp_dir)
def delete_backup(self):
  """Delete the backup directories.

  Delete the backup directories which should not exist after `fit()`
  successfully finishes.
  """
  self._assert_in_multi_worker_mode()
  # The model may not have this attr if there was a failure before the attr
  # was added to the model.
  if hasattr(self._model, CKPT_SAVED_EPOCH):
    tracking.AutoTrackable.__delattr__(self._model, CKPT_SAVED_EPOCH)
  if multi_worker_util.should_save_checkpoint():
    _remove_dir(self._backup_dir)
  else:
    assert not file_io.file_exists(self._temp_dir)
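The backup methods above all call `_assert_in_multi_worker_mode`, whose body is not shown. A minimal sketch of what such a guard presumably does, assuming it checks the model's private `_in_multi_worker_mode()` flag; the error message is illustrative.

def _assert_in_multi_worker_mode(self):
  # pylint: disable=protected-access
  if not self._model._in_multi_worker_mode():
    raise ValueError('This training state utility is only meant to be used '
                     'in multi-worker training.')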
def _get_file_path(self, epoch, logs):
  """Returns the file path for checkpoint."""
  # pylint: disable=protected-access
  if (not self.model._in_multi_worker_mode() or
      multi_worker_util.should_save_checkpoint()):
    return self.filepath.format(epoch=epoch, **logs)
  else:
    # If this is multi-worker training, and this worker should not
    # save checkpoint, we use a temp filepath to store a dummy checkpoint, so
    # it writes to a file that will be removed at the end of the
    # `_save_model()` call. This is because the SyncOnReadVariable needs to
    # be synced across all the workers in order to be read, and all workers
    # need to initiate that.
    self._temp_file_dir = tempfile.mkdtemp()
    extension = os.path.splitext(self.filepath)[1]
    return os.path.join(self._temp_file_dir, 'temp' + extension)
def _get_file_path(self, logs, epoch):
  """Returns the file path for checkpoint.

  Similar to `tf.keras.callbacks.ModelCheckpoint`.
  """
  # noinspection PyProtectedMember
  if (not self.model._in_multi_worker_mode() or
      multi_worker_util.should_save_checkpoint()):
    try:
      return self.filepath.format(epoch=epoch + 1, **logs)
    except KeyError as e:
      raise KeyError(
          'Failed to format this callback filepath: "{}". Reason: {}'.format(
              self.filepath, e))
  else:
    self._temp_file_dir = tempfile.mkdtemp()
    extension = os.path.splitext(self.filepath)[1]
    return os.path.join(self._temp_file_dir, 'temp' + extension)
def _get_file_path(self, epoch, logs):
  """Returns the file path for checkpoint."""
  # pylint: disable=protected-access
  if (not self.model._in_multi_worker_mode() or
      multi_worker_util.should_save_checkpoint()):
    try:
      # `filepath` may contain placeholders such as `{epoch:02d}` and
      # `{mape:.2f}`. A mismatch between logged metrics and the path's
      # placeholders can cause formatting to fail.
      return self.filepath.format(epoch=epoch + 1, **logs)
    except KeyError as e:
      raise KeyError(
          'Failed to format this callback filepath: "{}". '
          'Reason: {}'.format(self.filepath, e))
  else:
    # If this is multi-worker training, and this worker should not
    # save checkpoint, we use a temp filepath to store a dummy checkpoint, so
    # it writes to a file that will be removed at the end of the
    # `_save_model()` call. This is because the SyncOnReadVariable needs to
    # be synced across all the workers in order to be read, and all workers
    # need to initiate that.
    self._temp_file_dir = tempfile.mkdtemp()
    extension = os.path.splitext(self.filepath)[1]
    return os.path.join(self._temp_file_dir, 'temp' + extension)
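For reference, the `str.format` call that all three variants rely on behaves as follows. This is a standalone illustration, not part of the callback, and the metric names are made up.

logs = {'val_loss': 0.1234}

# Succeeds: every placeholder in the path is supplied by `epoch` or `logs`.
path = 'model.{epoch:02d}-{val_loss:.2f}.h5'
print(path.format(epoch=4 + 1, **logs))  # model.05-0.12.h5

# Raises KeyError('mape'): the path references a metric that was never
# logged, which is exactly the failure the try/except above reports.
bad_path = 'model.{epoch:02d}-{mape:.2f}.h5'
bad_path.format(epoch=4 + 1, **logs)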