예제 #1
0
    def on_train_begin(self, logs=None):
        """Recover training state, or reload saved weights, when training starts.

        In multi-worker mode a `MultiWorkerTrainingState` is created, stored on
        both the model and this callback, and asked to restore itself. If that
        restore reports success, nothing else is done. Otherwise, when
        `load_weights_on_restart` is set, the most recently modified file
        matching `self.filepath` is loaded into the model.

        Args:
            logs: Not used by this method.

        Raises:
            ValueError: If loading the weights file raises `IOError` or
                `ValueError`.
        """
        # pylint: disable=protected-access
        if self.model._in_multi_worker_mode():
            # The training state object handles preemption-recovery of a
            # worker in multi-worker training.
            self.model._training_state = training_state.MultiWorkerTrainingState(
                self.model, self.filepath)
            self._training_state = self.model._training_state
            if self._training_state.restore():
                # Recovered from a previous failure (or preemption): the user
                # specified file path must not override the restored state.
                return

        # Reaching this point means: single-worker training, no restore was
        # needed, or the restore did not succeed.
        if not self.load_weights_on_restart:
            return

        # In multi-worker mode only workers allowed to load checkpoints go on.
        if (self.model._in_multi_worker_mode()
                and not multi_worker_util.should_load_checkpoint()):
            return

        # `filepath` may contain placeholders such as `{epoch:02d}`, so the
        # newest file whose name matches the pattern is the one to load.
        latest_path = self._get_most_recently_modified_file_matching_pattern(
            self.filepath)
        if latest_path is None or not training_state.checkpoint_exists(latest_path):
            return

        try:
            self.model.load_weights(latest_path)
        except (IOError, ValueError) as err:
            raise ValueError(
                'Error loading file from {}. Reason: {}'.format(
                    latest_path, err))
예제 #2
0
def init_restore_or_wait_for_variables():
  """Initialize variables, or wait until they have been initialized.

  A process that has a worker context but should not load checkpoints waits
  for variable initialization. Every other process initializes the variables
  itself.
  """
  session = K._get_session()  # pylint: disable=protected-access
  must_wait = (multi_worker_util.has_worker_context()
               and not multi_worker_util.should_load_checkpoint())
  if must_wait:
    _wait_for_variable_initialization(session)
    return
  # TODO(yuefengz): if checkpoints exist, restore from checkpoint.
  K._initialize_variables(session)  # pylint: disable=protected-access
예제 #3
0
 def restore(self):
     """Restore the training state from the backed up checkpoint file.

     Returns:
       True if a checkpoint was found and restored. False if the training
       state doesn't need to be restored on this worker, or if no checkpoint
       was restored.
     """
     # For multi-worker training, it should not restore a model in certain
     # worker setting (e.g. non-chief worker in ParameterServerStrategy).
     # pylint: disable=protected-access
     if (self._model._in_multi_worker_mode()
             and not multi_worker_util.should_load_checkpoint()):
         # Explicit False (rather than a bare `return`) so the result matches
         # the documented boolean contract.
         return False
     # NOTE(review): assumes `read_checkpoint_manager` follows the
     # `tf.train.CheckpointManager` contract, where `restore_or_initialize()`
     # returns the restored checkpoint path, or None when nothing was
     # restored -- confirm against where the manager is constructed.
     restored_path = self.read_checkpoint_manager.restore_or_initialize()
     return restored_path is not None
  def restore(self):
    """Reload the backed up training state, if this worker should do so.

    Returns:
      True when the weights were loaded from the backup file path. False when
      this worker should not load checkpoints or the backup directory does not
      exist.

    Raises:
      ValueError: If loading the backup file raises `IOError` or `ValueError`.
    """
    self._assert_in_multi_worker_mode()
    # Some worker settings must not restore a model (e.g. non-chief worker in
    # ParameterServerStrategy); without a backup directory there is nothing to
    # restore either.
    can_restore = (multi_worker_util.should_load_checkpoint()
                   and file_io.file_exists(self._backup_dir))
    if not can_restore:
      return False
    try:
      # Load the weights plus CKPT_SAVED_EPOCH variable.
      self._model.load_weights(self._backup_filepath)
    except (IOError, ValueError) as err:
      raise ValueError('Error loading file from {}. Reason: {}'.format(
          self._backup_filepath, err))
    return True
예제 #5
0
    def on_train_begin(self, logs=None):
        """Recover training state, or resume from the latest checkpoint.

        In multi-worker mode a `MultiWorkerTrainingState` is created, stored on
        both the model and this callback, and asked to restore itself. If that
        restore reports success, nothing else is done. Otherwise the checkpoint
        reported by `self.manager.latest_checkpoint` is used: when
        `save_best_only` is false the checkpoint is restored and
        `self.start_epoch[0]` is set to the checkpoint step plus one; when it
        is true only `self._load_metric` is called with the checkpoint path.

        Args:
            logs: Not used by this method.

        Raises:
            ValueError: If restoring raises `IOError` or `ValueError`.
        """
        # pylint: disable=protected-access
        if self.model._in_multi_worker_mode():
            # The training state object handles preemption-recovery of a
            # worker in multi-worker training.
            self.model._training_state = training_state.MultiWorkerTrainingState(
                self.model, self.filepath)
            self._training_state = self.model._training_state
            if self._training_state.restore():
                # Recovered from a previous failure (or preemption): the user
                # specified file path must not override the restored state.
                return

        # Reaching this point means: single-worker training, no restore was
        # needed, or the restore did not succeed. In multi-worker mode only
        # workers allowed to load checkpoints go on.
        if (self.model._in_multi_worker_mode()
                and not multi_worker_util.should_load_checkpoint()):
            return

        filepath_to_load = self.manager.latest_checkpoint
        if (filepath_to_load is None
                or not training_state.checkpoint_exists(filepath_to_load)):
            return

        try:
            if self.save_best_only:
                # Try to restore best metric
                self._load_metric(filepath_to_load)
            else:
                self.ckpt.restore(filepath_to_load).expect_partial()
                # Resume from the epoch after the one stored in the checkpoint.
                self.start_epoch[0] = self.ckpt.step.numpy() + 1
                logging.info(
                    f"Restored checkpoint from {filepath_to_load}. "
                    f"Start from epoch {self.start_epoch[0]}.")
        except (IOError, ValueError) as err:
            raise ValueError(
                'Error loading file from {}. Reason: {}'.format(
                    filepath_to_load, err))