def filter_distributed_callbacks(callbacks_list):
  """Filter Callbacks based on the worker context when running multi-worker.

  Arguments:
    callbacks_list: A list of `Callback` instances.

  Returns:
    The list of `Callback` instances that should be run on this worker.
  """
  if not K.in_multi_worker_mode():
    raise ValueError(
        'filter_distributed_callbacks() should only be called when Keras '
        'is in multi worker mode.')

  worker_context = dc_context.get_current_worker_context()
  callbacks_list = callbacks_list or []
  if not [
      c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint)
  ]:
    # TODO(rchao): Consider providing a ModelCheckpoint here if the user
    # fails to.
    logging.warning('ModelCheckpoint callback is not provided. '
                    'Workers will need to restart training if any fails.')
  # TODO(rchao): Add similar warning for restoring callback (to be designed).

  if callbacks_list is None or worker_context.is_chief:
    return callbacks_list

  # Some Callbacks should only run on the chief worker.
  return [
      callback for callback in callbacks_list
      if not callback._chief_worker_only
  ]  # pylint: disable=protected-access
def filter_distributed_callbacks(callbacks_list):
  """Filter Callbacks based on the worker context when running multi-worker.

  Arguments:
    callbacks_list: A list of `Callback` instances.

  Returns:
    The list of `Callback` instances that should be run on this worker.
  """
  if not K.in_multi_worker_mode():
    raise ValueError(
        'filter_distributed_callbacks() should only be called when Keras '
        'is in multi worker mode.')

  callbacks_list = callbacks_list or []
  if not [
      c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint)
  ]:
    # TODO(rchao): Consider providing a ModelCheckpoint here if the user
    # fails to (possibly with tempfile directory).
    logging.warning('ModelCheckpoint callback is not provided. '
                    'Workers will need to restart training if any fails.')

  if callbacks_list is None or is_current_worker_chief():
    return callbacks_list

  # Some Callbacks should only run on the chief worker.
  return [
      callback for callback in callbacks_list
      if not callback._chief_worker_only
  ]  # pylint: disable=protected-access
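# Illustrative usage sketch (not part of the code above): how a multi-worker
# setup path might apply `filter_distributed_callbacks` before attaching
# callbacks on each worker. The helper name `_setup_worker_callbacks` and the
# TensorBoard example are assumptions made for this sketch.
def _setup_worker_callbacks(user_callbacks):
  if K.in_multi_worker_mode():
    # On non-chief workers, drop callbacks flagged `_chief_worker_only`
    # (e.g. TensorBoard) so that only the chief writes summaries and other
    # chief-only artifacts.
    return filter_distributed_callbacks(user_callbacks)
  return user_callbacks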
def _assert_in_multi_worker_mode(self):
  if not K.in_multi_worker_mode():
    raise ValueError(
        'MultiWorkerTrainingState is only supposed to be used '
        'in multi-worker training. This indicates some error '
        'that needs to be fixed. Please submit a bug issue to '
        'tf.keras team.')
def _maybe_remove_file(self, file_handle, filepath):
  # Remove the file in multi-worker training where this worker should
  # not checkpoint. It is a dummy file previously saved for sync distributed
  # training.
  if K.in_multi_worker_mode(
  ) and not dc_context.get_current_worker_context().should_checkpoint:
    os.close(file_handle)
    os.remove(filepath)
def _save_model(self, epoch, logs):
  """Saves the model.

  Arguments:
    epoch: the epoch this iteration is in.
    logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.
  """
  logs = logs or {}

  if isinstance(self.save_freq,
                int) or self.epochs_since_last_save >= self.period:
    self.epochs_since_last_save = 0
    file_handle, filepath = self._get_file_handle_and_path(epoch, logs)

    if self.save_best_only:
      current = logs.get(self.monitor)
      if current is None:
        logging.warning('Can save best model only with %s available, '
                        'skipping.', self.monitor)
      else:
        if self.monitor_op(current, self.best):
          if self.verbose > 0:
            print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                  ' saving model to %s' % (epoch + 1, self.monitor, self.best,
                                           current, filepath))
          self.best = current
          if self.save_weights_only:
            filepath = os.path.join(filepath, 'cp')
            self.model.save_weights(filepath, overwrite=True)
          else:
            self.model.save(filepath)
        else:
          if self.verbose > 0:
            print('\nEpoch %05d: %s did not improve from %0.5f' %
                  (epoch + 1, self.monitor, self.best))
    else:
      if self.verbose > 0:
        print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
      if self.save_weights_only:
        if K.in_multi_worker_mode():
          # TODO(rchao): Save to an additional training state file for FT,
          # instead of adding an attr to weight file. With this we can support
          # the cases of all combinations with `save_weights_only`,
          # `save_best_only`, and `save_format` parameters.
          # pylint: disable=protected-access
          self.model._ckpt_saved_epoch = epoch
        filepath = os.path.join(filepath, 'cp')
        self.model.save_weights(filepath, overwrite=True)
      else:
        self.model.save(filepath)

    self._maybe_remove_file(file_handle, filepath)
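# Illustrative configuration sketch (assumption, not from the code above):
# with `save_best_only=True`, the `_save_model` logic only writes a checkpoint
# when `monitor_op(current, best)` improves, and it logs a warning and skips
# saving if the monitored key is absent from `logs`. The filepath below is
# hypothetical; its `{epoch:02d}` placeholder is filled from the epoch number
# at save time.
best_only_checkpoint = callbacks.ModelCheckpoint(
    filepath='/tmp/training/best.{epoch:02d}.h5',  # hypothetical path
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=True)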
def on_epoch_end(self, epoch, logs=None):
  self.epochs_since_last_save += 1
  if self.save_freq == 'epoch':
    self._save_model(epoch=epoch, logs=logs)
    if K.in_multi_worker_mode():
      # For multi-worker training, back up the weights and current training
      # state for possible future recovery.
      # TODO(rchao): Call `back_up` at finer period such as N steps.
      self._training_state.back_up(epoch)
def on_train_end(self, logs=None):
  if K.in_multi_worker_mode():
    # In multi-worker training, on successful exit of training, delete the
    # training state backup file that was saved for the purpose of worker
    # recovery.
    self._training_state.delete_backup()
    # Restore the training state so the model is ready for next (possible)
    # multi worker training.
    del self._training_state
    del self.model._training_state
def on_train_begin(self, logs=None):
  if K.in_multi_worker_mode():
    # pylint: disable=protected-access
    # MultiWorkerTrainingState is used to manage the training state needed
    # for preemption-recovery of a worker in multi-worker training.
    self.model._training_state = (
        multi_worker_training_state.MultiWorkerTrainingState(
            self.model, self.filepath))
    self._training_state = self.model._training_state
    if self._training_state.restore():
      # If the training state needs to be and is successfully restored,
      # it is recovering from a previous failure (or preemption). In such
      # case, do not load the weights from user specified file path.
      return

  # If this is not multi worker training, restoring is not needed, or
  # restoring failed, check if it should load weights on restart.
  # TODO(rchao): Also restore the epoch in single-worker training when
  # `self.load_weights_on_restart=True`.
  if self.load_weights_on_restart:
    # In multi worker training, it only should if `experimental_should_init`
    # is True.
    # TODO(rchao): Reference `experimental_should_init` api from a util file.
    if not K.in_multi_worker_mode(
    ) or dc_context.get_current_worker_context().experimental_should_init:
      filepath_to_load = (
          self._get_most_recently_modified_file_matching_pattern(
              self.filepath))
      if filepath_to_load is not None and os.path.exists(filepath_to_load):
        try:
          # `filepath` may contain placeholders such as `{epoch:02d}`, and
          # thus it attempts to load the most recently modified file with file
          # name matching the pattern.
          self.model.load_weights(filepath_to_load)
        except (IOError, ValueError) as e:
          raise ValueError('Error loading file from {}. Reason: {}'.format(
              filepath_to_load, e))
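# Illustrative sketch (assumption): enabling the restart path exercised by
# `on_train_begin` above. With `load_weights_on_restart=True`, a restarted run
# reloads the most recently modified file matching the `filepath` pattern; in
# multi-worker mode a successful `MultiWorkerTrainingState` restore takes
# precedence and the pattern-based reload is skipped. The filepath below is
# hypothetical.
restartable_checkpoint = callbacks.ModelCheckpoint(
    filepath='/tmp/training/weights.{epoch:02d}.h5',  # hypothetical path
    save_weights_only=True,
    load_weights_on_restart=True)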
def _get_file_handle_and_path(self, epoch, logs):
  """Returns the file handle and path."""
  # TODO(rchao): Replace dc_context reference with
  # distributed_training_utils.should_current_worker_checkpoint() once
  # distributed_training_utils.py no longer depends on callbacks.py.
  if not K.in_multi_worker_mode(
  ) or dc_context.get_current_worker_context().should_checkpoint:
    return None, self.filepath.format(epoch=epoch + 1, **logs)
  else:
    # If this is multi-worker training, and this worker should not
    # save checkpoint, we replace the filepath with a dummy filepath so
    # it writes to a file that will be removed at the end of _save_model()
    # call. This is because the SyncOnReadVariable needs to be synced across
    # all the workers in order to be read, and all workers need to initiate
    # that.
    file_handle, temp_file_name = tempfile.mkstemp()
    extension = os.path.splitext(self.filepath)[1]
    return file_handle, temp_file_name + extension
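# Minimal sketch (assumption, standard library only) of the dummy-file
# mechanics used above on workers that should not checkpoint: `mkstemp()`
# returns an open OS-level handle plus a unique path, the original filepath's
# extension is re-attached so the save format is preserved, and
# `_maybe_remove_file` later closes the handle and deletes the dummy file.
import os
import tempfile

file_handle, temp_file_name = tempfile.mkstemp()
extension = os.path.splitext('/tmp/training/weights.{epoch:02d}.h5')[1]  # '.h5'
dummy_filepath = temp_file_name + extension
# ... the non-checkpointing worker writes its discarded save to `dummy_filepath` ...
os.close(file_handle)
if os.path.exists(dummy_filepath):
  os.remove(dummy_filepath)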