def resume(self, train_info: TrainingInfo, model: Model) -> dict:
    """Load the checkpoint for ``train_info.start_epoch_idx`` into ``model`` and restore state.

    The model weights are read from ``self.checkpoint_filename(epoch)`` and applied with
    ``model.load_state_dict``. The hidden-state object is then read from
    ``self.checkpoint_hidden_filename(epoch)`` and handed to
    ``self.checkpoint_strategy.restore`` and ``train_info.restore``, in that order.

    Returns:
        The hidden-state object loaded from disk.
    """
    epoch = train_info.start_epoch_idx

    # Weights go in first; the hidden-state file is only read after
    # model.load_state_dict has returned without raising.
    # NOTE(review): no map_location is passed, so torch.load restores tensors onto the
    # device they were saved from (a GPU checkpoint will not load on a CPU-only host).
    weights = torch.load(self.checkpoint_filename(epoch))
    model.load_state_dict(weights)

    hidden = torch.load(self.checkpoint_hidden_filename(epoch))
    self.checkpoint_strategy.restore(hidden)
    train_info.restore(hidden)
    return hidden
def load(self, train_info: TrainingInfo) -> (dict, dict):
    """Read the checkpoint for ``train_info.start_epoch_idx`` and return it unapplied.

    Both files for that epoch are read with ``torch.load``: the model weights from
    ``self.checkpoint_filename(epoch)`` and the hidden state from
    ``self.checkpoint_hidden_filename(epoch)``. The hidden state is passed to
    ``self.checkpoint_strategy.restore`` and then ``train_info.restore``; the model
    weights are not applied to any model here.

    Returns:
        A ``(model_state, hidden_state)`` tuple.
    """
    # NOTE(review): ``(dict, dict)`` in the signature is a tuple literal rather than a
    # typing construct; ``Tuple[dict, dict]`` would be understood by type checkers.
    epoch = train_info.start_epoch_idx
    model_state = torch.load(self.checkpoint_filename(epoch))
    hidden_state = torch.load(self.checkpoint_hidden_filename(epoch))

    # The checkpoint strategy is restored first, then the training info.
    for target in (self.checkpoint_strategy, train_info):
        target.restore(hidden_state)

    return model_state, hidden_state
def resume_training(self, learner, optimizer, callbacks, metrics) -> TrainingInfo:
    """Create the ``TrainingInfo`` for a run, restoring saved state unless a reset is requested.

    When ``self.model_config.reset`` is truthy, training starts at epoch 0 with an empty
    hidden state. Otherwise ``self.storage.resume_learning(learner.model)`` supplies the
    starting epoch and the hidden state. Saved state is applied only when the starting
    epoch is greater than zero: first through ``self.restore_state`` (together with the
    optimizer and callbacks), then through the new ``TrainingInfo``'s ``restore``.

    Returns:
        The newly built ``TrainingInfo``.
    """
    start_epoch, hidden_state = (
        (0, {}) if self.model_config.reset
        else self.storage.resume_learning(learner.model)
    )

    info = TrainingInfo(start_epoch_idx=start_epoch, metrics=metrics, callbacks=callbacks)

    resuming = start_epoch > 0
    if resuming:
        self.restore_state(hidden_state, optimizer, callbacks)
        info.restore(hidden_state)

    return info
def resume_training(self, learner, callbacks, metrics) -> (TrainingInfo, dict):
    """Create the ``TrainingInfo`` for a run and return it with the hidden state used.

    The run starts at epoch 0 with an empty hidden state when ``self.model_config.reset``
    is truthy; otherwise the starting epoch and hidden state come from
    ``self.storage.resume_learning(learner.model)``. When the starting epoch is greater
    than zero, every callback receives the hidden state via ``load_state_dict`` and the
    new ``TrainingInfo`` is then restored from it.

    Returns:
        A ``(training_info, hidden_state)`` tuple.
    """
    start_epoch, hidden_state = 0, {}
    if not self.model_config.reset:
        start_epoch, hidden_state = self.storage.resume_learning(learner.model)

    info = TrainingInfo(start_epoch_idx=start_epoch, metrics=metrics, callbacks=callbacks)

    if start_epoch > 0:
        # NOTE(review): ``info`` was also given ``callbacks``; if TrainingInfo.restore
        # forwards the hidden state to them as well, each callback is restored twice.
        for cb in callbacks:
            cb.load_state_dict(hidden_state)
        info.restore(hidden_state)

    return info, hidden_state