Example #1
    def setupCheckpoint(self, TrainDevice):
        LatestCheckpointDict = None
        AllCheckpoints = glob.glob(os.path.join(self.ExptDirPath, '*.tar'))
        if len(AllCheckpoints) > 0:
            LatestCheckpointDict = ptUtils.loadLatestPyTorchCheckpoint(self.ExptDirPath, map_location=TrainDevice)
            print('[ INFO ]: Loading from last checkpoint.')

        if LatestCheckpointDict is not None:
            # Make sure experiment names match
            if self.Config.Args.expt_name == LatestCheckpointDict['Name']:
                self.load_state_dict(LatestCheckpointDict['ModelStateDict'])
                self.StartEpoch = LatestCheckpointDict['Epoch']
                self.Optimizer.load_state_dict(LatestCheckpointDict['OptimizerStateDict'])
                self.LossHistory = LatestCheckpointDict['LossHistory']
                if 'ValLossHistory' in LatestCheckpointDict:
                    self.ValLossHistory = LatestCheckpointDict['ValLossHistory']
                else:
                    self.ValLossHistory = self.LossHistory

                # Move optimizer state to GPU if needed. See https://github.com/pytorch/pytorch/issues/2830
                if TrainDevice != 'cpu':
                    for state in self.Optimizer.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.to(TrainDevice)
            else:
                print('[ INFO ]: Experiment names do not match. Training from scratch.')
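These snippets rely on a project-specific helper module, ptUtils, whose implementation is not shown. As a rough illustration only, here is a minimal sketch (with a hypothetical function name) of how a latest-checkpoint loader like ptUtils.loadLatestPyTorchCheckpoint could behave, assuming checkpoints are written as *.tar files via torch.save:

import glob
import os

import torch

def loadLatestCheckpointSketch(ExptDirPath, map_location='cpu'):
    # Collect every checkpoint archive saved in the experiment directory
    AllCheckpoints = glob.glob(os.path.join(ExptDirPath, '*.tar'))
    if not AllCheckpoints:
        return None
    # Treat the most recently modified file as the latest checkpoint
    LatestPath = max(AllCheckpoints, key=os.path.getmtime)
    # map_location places the loaded tensors on the requested device (e.g. 'cpu' or 'cuda:0')
    return torch.load(LatestPath, map_location=map_location)

Returning None when no checkpoint exists matches how setupCheckpoint() above falls back to training from scratch.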
Example #2
    def loadCheckpoint(self, Path=None, Device='cpu'):
        if Path is None:
            self.ExptDirPath = os.path.join(ptUtils.expandTilde(self.Config.Args.output_dir), self.Config.Args.expt_name)
            print('[ INFO ]: Loading from latest checkpoint.')
            CheckpointDict = ptUtils.loadLatestPyTorchCheckpoint(self.ExptDirPath, map_location=Device)
        else: # Load from the explicitly given path
            print('[ INFO ]: Loading from checkpoint {}'.format(Path))
            CheckpointDict = ptUtils.loadPyTorchCheckpoint(Path)

        self.load_state_dict(CheckpointDict['ModelStateDict'])
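For completeness, a hedged sketch of the save side that would produce a dictionary with the keys these loaders read back ('Name', 'ModelStateDict', 'OptimizerStateDict', 'Epoch', 'LossHistory', 'ValLossHistory'). The function name and filename pattern are assumptions, not part of the original code:

import os

import torch

def saveCheckpointSketch(Model, Optimizer, Epoch, LossHistory, ValLossHistory, ExptDirPath, ExptName):
    # Assemble the same keys that setupCheckpoint() and loadCheckpoint() expect
    CheckpointDict = {
        'Name': ExptName,
        'ModelStateDict': Model.state_dict(),
        'OptimizerStateDict': Optimizer.state_dict(),
        'Epoch': Epoch,
        'LossHistory': LossHistory,
        'ValLossHistory': ValLossHistory,
    }
    # Saved with a .tar extension so the glob('*.tar') in setupCheckpoint() can find it
    OutPath = os.path.join(ExptDirPath, '{}_epoch_{:04d}.tar'.format(ExptName, Epoch))
    torch.save(CheckpointDict, OutPath)
    return OutPath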