def setupCheckpoint(self, TrainDevice):
    LatestCheckpointDict = None
    AllCheckpoints = glob.glob(os.path.join(self.ExptDirPath, '*.tar'))
    if len(AllCheckpoints) > 0:
        LatestCheckpointDict = ptUtils.loadLatestPyTorchCheckpoint(self.ExptDirPath, map_location=TrainDevice)
        print('[ INFO ]: Loading from last checkpoint.')

    if LatestCheckpointDict is not None:
        # Make sure experiment names match
        if self.Config.Args.expt_name == LatestCheckpointDict['Name']:
            self.load_state_dict(LatestCheckpointDict['ModelStateDict'])
            self.StartEpoch = LatestCheckpointDict['Epoch']
            self.Optimizer.load_state_dict(LatestCheckpointDict['OptimizerStateDict'])
            self.LossHistory = LatestCheckpointDict['LossHistory']
            if 'ValLossHistory' in LatestCheckpointDict:
                self.ValLossHistory = LatestCheckpointDict['ValLossHistory']
            else:
                self.ValLossHistory = self.LossHistory

            # Move optimizer state to the training device if needed.
            # See https://github.com/pytorch/pytorch/issues/2830
            if TrainDevice != 'cpu':
                for state in self.Optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to(TrainDevice)
        else:
            print('[ INFO ]: Experiment names do not match. Training from scratch.')
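# Illustrative sketch (not part of this class): setupCheckpoint() expects each
# *.tar checkpoint to be a dict containing the keys read above. A counterpart
# save step could look like the following; the exact save helper and the
# 'Epoch' variable are assumptions here, not this project's confirmed API.
#
#   CheckpointDict = {
#       'Name': self.Config.Args.expt_name,
#       'ModelStateDict': self.state_dict(),
#       'OptimizerStateDict': self.Optimizer.state_dict(),
#       'LossHistory': self.LossHistory,
#       'ValLossHistory': self.ValLossHistory,
#       'Epoch': Epoch,
#   }
#   torch.save(CheckpointDict, os.path.join(self.ExptDirPath, 'checkpoint.tar'))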
def loadCheckpoint(self, Path=None, Device='cpu'):
    if Path is None:
        # Load the latest checkpoint from the experiment directory
        self.ExptDirPath = os.path.join(ptUtils.expandTilde(self.Config.Args.output_dir), self.Config.Args.expt_name)
        print('[ INFO ]: Loading from latest checkpoint.')
        CheckpointDict = ptUtils.loadLatestPyTorchCheckpoint(self.ExptDirPath, map_location=Device)
    else:
        # Load a specific checkpoint
        print('[ INFO ]: Loading from checkpoint {}'.format(Path))
        CheckpointDict = ptUtils.loadPyTorchCheckpoint(Path)

    self.load_state_dict(CheckpointDict['ModelStateDict'])
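# Illustrative usage sketch (assumptions: 'Model' is an instance of a class
# that mixes in the methods above and already has self.Config, self.Optimizer,
# and self.ExptDirPath set up; the device string and path are placeholders).
#
#   Model.setupCheckpoint(TrainDevice='cuda:0')  # resume training if a matching checkpoint exists
#   Model.loadCheckpoint(Path='~/experiments/expt/checkpoint.tar', Device='cpu')  # load a specific checkpoint for evaluation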