def on_train_begin(self, logs=None):
  # pylint: disable=protected-access
  if self.model._in_multi_worker_mode():
    # MultiWorkerTrainingState is used to manage the training state needed
    # for preemption-recovery of a worker in multi-worker training.
    self.model._training_state = (
        training_state.MultiWorkerTrainingState(self.model, self.filepath))
    self._training_state = self.model._training_state
    if self._training_state.restore():
      # If the training state needs to be and is successfully restored,
      # it is recovering from a previous failure (or preemption). In such
      # case, do not load the weights from user specified file path.
      return

  # If this is not multi worker training, restoring is not needed, or
  # restoring failed, check if it should load weights on restart.
  if self.load_weights_on_restart:
    if (not self.model._in_multi_worker_mode() or
        multi_worker_util.should_load_checkpoint()):
      filepath_to_load = (
          self._get_most_recently_modified_file_matching_pattern(self.filepath))
      if (filepath_to_load is not None and
          training_state.checkpoint_exists(filepath_to_load)):
        try:
          # `filepath` may contain placeholders such as `{epoch:02d}`, and
          # thus it attempts to load the most recently modified file with file
          # name matching the pattern.
          self.model.load_weights(filepath_to_load)
        except (IOError, ValueError) as e:
          raise ValueError('Error loading file from {}. Reason: {}'.format(
              filepath_to_load, e))
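For reference, a minimal usage sketch of the restart behavior above, assuming a TensorFlow 2.x version whose ModelCheckpoint still accepts the load_weights_on_restart argument; the model, data, and paths are illustrative placeholders, not taken from the original code.

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(optimizer='adam', loss='mse')

# `filepath` contains an `{epoch:02d}` placeholder; with
# `load_weights_on_restart=True`, a restarted run reloads the most recently
# modified file matching this pattern before training resumes.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath='/tmp/ckpt/weights.{epoch:02d}.h5',
    save_weights_only=True,
    load_weights_on_restart=True)

model.fit(tf.zeros((32, 4)), tf.zeros((32, 10)), epochs=3,
          callbacks=[checkpoint_cb])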
def init_restore_or_wait_for_variables():
  """Initialize or restore variables or wait for variables to be initialized."""
  session = K._get_session()  # pylint: disable=protected-access
  if (not multi_worker_util.has_worker_context() or
      multi_worker_util.should_load_checkpoint()):
    # TODO(yuefengz): if checkpoints exist, restore from checkpoint.
    K._initialize_variables(session)  # pylint: disable=protected-access
  else:
    _wait_for_variable_initialization(session)
def restore(self):
  """Restore the training state from the backed up checkpoint file.

  Returns:
    True if the training state is successfully restored. False if the training
    state doesn't need to be restored, or error occurred so it can't.
  """
  # For multi-worker training, it should not restore a model in certain
  # worker setting (e.g. non-chief worker in ParameterServerStrategy).
  # pylint: disable=protected-access
  if (self._model._in_multi_worker_mode() and
      not multi_worker_util.should_load_checkpoint()):
    return
  self.read_checkpoint_manager.restore_or_initialize()
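The read_checkpoint_manager used here appears to be a tf.train.CheckpointManager; a small sketch of the restore_or_initialize() semantics, with a hypothetical checkpoint object and directory:

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(checkpoint, directory='/tmp/backup',
                                     max_to_keep=1)

# Restores the latest checkpoint in the directory if one exists and returns its
# path; otherwise runs the initializers and returns None.
restored_path = manager.restore_or_initialize()
print('restored from:', restored_path)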
def restore(self):
  """Restore the training state from the backed up checkpoint file.

  Returns:
    True if the training state is successfully restored. False if the training
    state doesn't need to be restored, or error occurred so it can't.
  """
  self._assert_in_multi_worker_mode()
  if not multi_worker_util.should_load_checkpoint():
    # For multi-worker training, it should not restore a model in certain
    # worker setting (e.g. non-chief worker in ParameterServerStrategy).
    return False
  if file_io.file_exists(self._backup_dir):
    try:
      # Load the weights plus CKPT_SAVED_EPOCH variable.
      self._model.load_weights(self._backup_filepath)
      return True
    except (IOError, ValueError) as e:
      raise ValueError('Error loading file from {}. Reason: {}'.format(
          self._backup_filepath, e))
  return False
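On TF 2.4 and later, this directory-based backup/restore flow is exposed through the BackupAndRestore callback; a minimal hedged usage sketch (the model, data, and backup directory are illustrative):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer='adam', loss='mse')

# On an interrupted-and-restarted run, the callback reloads the model and the
# saved epoch from `backup_dir` before training continues.
backup_cb = tf.keras.callbacks.experimental.BackupAndRestore(
    backup_dir='/tmp/train_backup')
model.fit(tf.zeros((32, 4)), tf.zeros((32, 1)), epochs=3,
          callbacks=[backup_cb])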
def on_train_begin(self, logs=None):
  # pylint: disable=protected-access
  if self.model._in_multi_worker_mode():
    # MultiWorkerTrainingState is used to manage the training state needed
    # for preemption-recovery of a worker in multi-worker training.
    self.model._training_state = (
        training_state.MultiWorkerTrainingState(self.model, self.filepath))
    self._training_state = self.model._training_state
    if self._training_state.restore():
      # If the training state needs to be and is successfully restored,
      # it is recovering from a previous failure (or preemption). In such
      # case, do not load the weights from user specified file path.
      return

  # If this is not multi worker training, restoring is not needed, or
  # restoring failed, check if it should load weights on restart.
  if (not self.model._in_multi_worker_mode() or
      multi_worker_util.should_load_checkpoint()):
    filepath_to_load = self.manager.latest_checkpoint
    if (filepath_to_load is not None and
        training_state.checkpoint_exists(filepath_to_load)):
      try:
        # `filepath` may contain placeholders such as `{epoch:02d}`, and
        # thus it attempts to load the most recently modified file with file
        # name matching the pattern.
        if not self.save_best_only:
          self.ckpt.restore(filepath_to_load).expect_partial()
          self.start_epoch[0] = self.ckpt.step.numpy() + 1
          logging.info(
              f"Restored checkpoint from {filepath_to_load}. "
              f"Start from epoch {self.start_epoch[0]}.")
        else:
          # Try to restore best metric
          self._load_metric(filepath_to_load)
      except (IOError, ValueError) as e:
        raise ValueError(
            'Error loading file from {}. Reason: {}'.format(
                filepath_to_load, e))
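The modified callback above relies on attributes that are not shown (self.ckpt, self.manager, self.start_epoch); one plausible way to wire them up, sketched under the assumption that the callback subclasses ModelCheckpoint and tracks the epoch counter in a tf.train.Checkpoint (the `_load_metric` helper referenced above is not covered here):

import tensorflow as tf

class ResumableModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
  """Hypothetical constructor matching the on_train_begin override above."""

  def __init__(self, filepath, ckpt_dir, model, **kwargs):
    super().__init__(filepath, **kwargs)
    # Track an epoch counter alongside the model so both can be restored.
    self.ckpt = tf.train.Checkpoint(step=tf.Variable(0, dtype=tf.int64),
                                    model=model)
    self.manager = tf.train.CheckpointManager(self.ckpt, directory=ckpt_dir,
                                              max_to_keep=3)
    # Mutable one-element list so the surrounding training script can read the
    # epoch to resume from after `on_train_begin` runs.
    self.start_epoch = [0]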