def filter_distributed_callbacks(callbacks_list):
  """Filter Callbacks based on the worker context when running multi-worker.

  Arguments:
    callbacks_list: A list of `Callback` instances.

  Returns:
    The list of `Callback` instances that should be run on this worker.
  """

  if not K.in_multi_worker_mode():
    raise ValueError(
        'filter_distributed_callbacks() should only be called when Keras '
        'is in multi worker mode.')

  worker_context = dc_context.get_current_worker_context()
  callbacks_list = callbacks_list or []
  if not [
      c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint)
  ]:
    # TODO(rchao): Consider providing a ModelCheckpoint here if the user
    # fails to.
    logging.warning('ModelCheckpoint callback is not provided. '
                    'Workers will need to restart training if any fails.')
  # TODO(rchao): Add similar warning for restoring callback (to be designed).

  if callbacks_list is None or worker_context.is_chief:
    return callbacks_list

  # Some Callbacks should only run on the chief worker.
  return [
      callback for callback in callbacks_list if not callback._chief_worker_only
  ]  # pylint: disable=protected-access
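To make the filtering rule concrete, here is a minimal, self-contained sketch (no TensorFlow required). `DummyCallback`, `filter_for_worker`, and the `is_chief` flag are illustrative stand-ins, not part of the Keras API:

class DummyCallback(object):
  """Stand-in for a Keras Callback carrying the chief-only flag."""

  def __init__(self, name, chief_worker_only=False):
    self.name = name
    self._chief_worker_only = chief_worker_only


def filter_for_worker(callbacks_list, is_chief):
  """Keep every callback on the chief; drop chief-only callbacks elsewhere."""
  callbacks_list = callbacks_list or []
  if is_chief:
    return callbacks_list
  return [c for c in callbacks_list if not c._chief_worker_only]


cbs = [DummyCallback('checkpoint', chief_worker_only=True),
       DummyCallback('early_stopping')]
print([c.name for c in filter_for_worker(cbs, is_chief=False)])  # ['early_stopping']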
Example #2
def filter_distributed_callbacks(callbacks_list):
    """Filter Callbacks based on the worker context when running multi-worker.

  Arguments:
    callbacks_list: A list of `Callback` instances.

  Returns:
    The list of `Callback` instances that should be run on this worker.
  """

    if not K.in_multi_worker_mode():
        raise ValueError(
            'filter_distributed_callbacks() should only be called when Keras '
            'is in multi worker mode.')

    callbacks_list = callbacks_list or []
    if not [
            c
            for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint)
    ]:
        # TODO(rchao): Consider providing a ModelCheckpoint here if the user
        # fails to (possibly with tempfile directory).
        logging.warning('ModelCheckpoint callback is not provided. '
                        'Workers will need to restart training if any fails.')

    if callbacks_list is None or is_current_worker_chief():
        return callbacks_list

    # Some Callbacks should only run on the chief worker.
    return [
        callback for callback in callbacks_list
        if not callback._chief_worker_only
    ]  # pylint: disable=protected-access
    def _assert_in_multi_worker_mode(self):
        if not K.in_multi_worker_mode():
            raise ValueError(
                'MultiWorkerTrainingState is only supposed to be used '
                'in multi-worker training. This indicates some error '
                'that needs to be fixed. Please submit a bug issue to '
                'tf.keras team.')
    def _maybe_remove_file(self, file_handle, filepath):
        # Remove the file in multi-worker training where this worker should not
        # checkpoint. It is a dummy file previously saved for sync distributed
        # training.
        if (K.in_multi_worker_mode() and
                not dc_context.get_current_worker_context().should_checkpoint):
            os.close(file_handle)
            os.remove(filepath)
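The cleanup above is the standard-library pattern of writing to a throwaway temp file and then deleting it once the synchronized save has completed. A minimal sketch of just that pattern (the suffix and the "save" step are placeholders):

import os
import tempfile

# mkstemp() returns an open OS-level handle plus a path, which is what
# _get_file_handle_and_path() hands back for workers that should not checkpoint.
file_handle, filepath = tempfile.mkstemp(suffix='.h5')

# ... the worker "saves" to `filepath` here purely so the collective save stays
# in sync with the workers that do checkpoint ...

# Cleanup mirrors _maybe_remove_file(): close the raw handle, delete the file.
os.close(file_handle)
os.remove(filepath)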
Example #5
    def _save_model(self, epoch, logs):
        """Saves the model.

            Arguments:
                epoch: the epoch this iteration is in.
                logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.
            """
        logs = logs or {}

        if isinstance(self.save_freq,
                      int) or self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            file_handle, filepath = self._get_file_handle_and_path(epoch, logs)

            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    logging.warning(
                        'Can save best model only with %s available, '
                        'skipping.', self.monitor)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print(
                                '\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                                ' saving model to %s' %
                                (epoch + 1, self.monitor, self.best, current,
                                 filepath))
                        self.best = current
                        if self.save_weights_only:
                            filepath = os.path.join(filepath, 'cp')
                            self.model.save_weights(filepath, overwrite=True)
                        else:
                            self.kash_model.save(filepath)
                    else:
                        if self.verbose > 0:
                            print(
                                '\nEpoch %05d: %s did not improve from %0.5f' %
                                (epoch + 1, self.monitor, self.best))
            else:
                if self.verbose > 0:
                    print('\nEpoch %05d: saving model to %s' %
                          (epoch + 1, filepath))
                if self.save_weights_only:
                    if K.in_multi_worker_mode():
                        # TODO(rchao): Save to an additional training state file for FT,
                        # instead of adding an attr to weight file. With this we can support
                        # the cases of all combinations with `save_weights_only`,
                        # `save_best_only`, and `save_format` parameters.
                        # pylint: disable=protected-access
                        self.model._ckpt_saved_epoch = epoch
                    filepath = os.path.join(filepath, 'cp')
                    self.model.save_weights(filepath, overwrite=True)
                else:
                    self.kash_model.save(filepath)

            self._maybe_remove_file(file_handle, filepath)
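For context, this `_save_model` logic is normally driven by constructing the callback with the public `tf.keras.callbacks.ModelCheckpoint` arguments. A usage sketch follows; the model, data, and paths are placeholders, and this particular variant adds its own `os.path.join(filepath, 'cp')` handling, so the exact on-disk layout may differ:

import tensorflow as tf

# `{epoch:02d}` and `{val_loss:.2f}` are filled in by
# self.filepath.format(epoch=..., **logs) when the checkpoint is written.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath='ckpt/weights.{epoch:02d}-{val_loss:.2f}.hdf5',
    monitor='val_loss',      # compared against self.best via self.monitor_op
    save_best_only=True,     # takes the first branch of _save_model above
    save_weights_only=True,  # save weights only rather than the full model
    verbose=1)

# model.fit(x, y, validation_data=(x_val, y_val), callbacks=[checkpoint_cb])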
    def on_epoch_end(self, epoch, logs=None):
        self.epochs_since_last_save += 1
        if self.save_freq == 'epoch':
            self._save_model(epoch=epoch, logs=logs)
        if K.in_multi_worker_mode():
            # For multi-worker training, back up the weights and current training
            # state for possible future recovery.
            # TODO(rchao): Call `back_up` at finer period such as N steps.
            self._training_state.back_up(epoch)

    def on_train_end(self, logs=None):
        if K.in_multi_worker_mode():
            # In multi-worker training, on successful exit of training, delete the
            # training state backup file that was saved for the purpose of worker
            # recovery.
            self._training_state.delete_backup()
            # Reset the training state so the model is ready for the next
            # (possible) multi-worker training.
            del self._training_state
            del self.model._training_state
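The `back_up` / `restore` / `delete_backup` calls above sketch the fault-tolerance contract. Below is a toy, single-file stand-in for that contract; it is purely illustrative and much simpler than the real `MultiWorkerTrainingState`, which also backs up model weights:

import json
import os


class TinyTrainingState(object):
  """Toy stand-in: persist the last finished epoch so an interrupted worker can
  resume, and discard the backup once training ends cleanly."""

  def __init__(self, backup_path):
    self._backup_path = backup_path

  def back_up(self, epoch):  # called from on_epoch_end
    with open(self._backup_path, 'w') as f:
      json.dump({'epoch': epoch}, f)

  def restore(self):  # called from on_train_begin
    if not os.path.exists(self._backup_path):
      return False
    with open(self._backup_path) as f:
      self._restored_epoch = json.load(f)['epoch']
    return True

  def delete_backup(self):  # called from on_train_end
    if os.path.exists(self._backup_path):
      os.remove(self._backup_path)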
    def on_train_begin(self, logs=None):
        if K.in_multi_worker_mode():
            # pylint: disable=protected-access
            # MultiWorkerTrainingState is used to manage the training state needed
            # for preemption-recovery of a worker in multi-worker training.
            self.model._training_state = (
                multi_worker_training_state.MultiWorkerTrainingState(
                    self.model, self.filepath))
            self._training_state = self.model._training_state
            if self._training_state.restore():
                # If the training state needs to be and is successfully restored,
                # it is recovering from a previous failure (or preemption). In such
                # case, do not load the weights from user specified file path.
                return

        # If this is not multi worker training, restoring is not needed, or
        # restoring failed, check if it should load weights on restart.
        # TODO(rchao): Also restore the epoch in single-worker training when
        # `self.load_weights_on_restart=True`.
        if self.load_weights_on_restart:
            # In multi-worker training, weights should only be loaded if
            # `experimental_should_init` is True.
            # TODO(rchao): Reference `experimental_should_init` api from a util file.
            if (not K.in_multi_worker_mode() or
                    dc_context.get_current_worker_context().experimental_should_init):
                filepath_to_load = (
                    self._get_most_recently_modified_file_matching_pattern(
                        self.filepath))
                if filepath_to_load is not None and os.path.exists(
                        filepath_to_load):
                    try:
                        # `filepath` may contain placeholders such as `{epoch:02d}`, and
                        # thus it attempts to load the most recently modified file with file
                        # name matching the pattern.
                        self.model.load_weights(filepath_to_load)
                    except (IOError, ValueError) as e:
                        raise ValueError(
                            'Error loading file from {}. Reason: {}'.format(
                                filepath_to_load, e))
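The helper `_get_most_recently_modified_file_matching_pattern` is not shown in this snippet. A hypothetical standard-library reimplementation of the same idea (not the actual Keras code) looks like this:

import glob
import os
import re


def most_recent_file_matching_pattern(filepath_pattern):
  """Hypothetical stand-in: turn a pattern such as 'ckpt/weights.{epoch:02d}.h5'
  into a glob, then return the newest matching file, or None if nothing matches."""
  glob_pattern = re.sub(r'{[^}]*}', '*', filepath_pattern)
  candidates = glob.glob(glob_pattern)
  if not candidates:
    return None
  return max(candidates, key=os.path.getmtime)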
    def _get_file_handle_and_path(self, epoch, logs):
        """Returns the file handle and path."""
        # TODO(rchao): Replace dc_context reference with
        # distributed_training_utils.should_current_worker_checkpoint() once
        # distributed_training_utils.py no longer depends on callbacks.py.
        if (not K.in_multi_worker_mode() or
                dc_context.get_current_worker_context().should_checkpoint):
            return None, self.filepath.format(epoch=epoch + 1, **logs)
        else:
            # In multi-worker training, if this worker should not save the
            # checkpoint, replace the filepath with a dummy filepath so it writes
            # to a file that will be removed at the end of the _save_model() call.
            # This is because the SyncOnReadVariable needs to be synced across all
            # the workers in order to be read, and all workers need to initiate
            # that.
            file_handle, temp_file_name = tempfile.mkstemp()
            extension = os.path.splitext(self.filepath)[1]
            return file_handle, temp_file_name + extension
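Putting the two branches together, the path-selection behavior can be summarized in a small standalone sketch; `get_save_target` and `should_checkpoint` are illustrative names, not Keras API:

import os
import tempfile


def get_save_target(filepath_pattern, epoch, logs, should_checkpoint):
  """Workers that should checkpoint get the real formatted path; everyone else
  gets a throwaway temp file with the same extension, removed after the save."""
  if should_checkpoint:
    return None, filepath_pattern.format(epoch=epoch + 1, **logs)
  file_handle, temp_file_name = tempfile.mkstemp()
  extension = os.path.splitext(filepath_pattern)[1]
  return file_handle, temp_file_name + extension


# e.g. get_save_target('weights.{epoch:02d}.h5', epoch=4, logs={}, should_checkpoint=True)
# -> (None, 'weights.05.h5')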