def validate(self, ValDataLoader, Objective, Device='cpu'):
    self.eval()  # Switch to evaluation mode
    ValLosses = []
    Tic = ptUtils.getCurrentEpochTime()
    # print('Val length:', len(ValDataLoader))
    for i, (Data, Targets) in enumerate(ValDataLoader, 0):  # Get each batch
        DataTD = ptUtils.sendToDevice(Data, Device)
        TargetsTD = ptUtils.sendToDevice(Targets, Device)

        Output = self.forward(DataTD)
        Loss = Objective(Output, TargetsTD)
        ValLosses.append(Loss.item())

        # Print stats
        Toc = ptUtils.getCurrentEpochTime()
        Elapsed = math.floor((Toc - Tic) * 1e-6)
        done = int(50 * (i + 1) / len(ValDataLoader))
        sys.stdout.write(('\r[{}>{}] val loss - {:.8f}, elapsed - {}').format(
            '+' * done, '-' * (50 - done),
            np.mean(np.asarray(ValLosses)), ptUtils.getTimeDur(Elapsed)))
        sys.stdout.flush()
    sys.stdout.write('\n')

    self.train()  # Switch back to train mode
    return ValLosses
def fit(self, TrainDataLoader, Optimizer=None, Objective=nn.MSELoss(), TrainDevice='cpu', ValDataLoader=None):
    if Optimizer is None:
        # Optimizer = optim.SGD(NN.parameters(), lr=Args.learning_rate)  # , momentum=0.9)
        self.Optimizer = optim.Adam(self.parameters(), lr=self.Config.Args.learning_rate,
                                    weight_decay=1e-5)  # PARAM
    else:
        self.Optimizer = Optimizer

    self.setupCheckpoint(TrainDevice)

    print('[ INFO ]: Training on {}'.format(TrainDevice))
    self.to(TrainDevice)
    CurrLegend = ['Train loss']

    AllTic = ptUtils.getCurrentEpochTime()
    for Epoch in range(self.Config.Args.epochs):
        try:
            EpochLosses = []  # For all batches in an epoch
            Tic = ptUtils.getCurrentEpochTime()
            for i, (Data, Targets) in enumerate(TrainDataLoader, 0):  # Get each batch
                DataTD = ptUtils.sendToDevice(Data, TrainDevice)
                TargetsTD = ptUtils.sendToDevice(Targets, TrainDevice)

                self.Optimizer.zero_grad()

                # Forward, backward, optimize
                Output = self.forward(DataTD)
                Loss = Objective(Output, TargetsTD)
                Loss.backward()
                self.Optimizer.step()
                EpochLosses.append(Loss.item())

                gc.collect()  # Collect garbage after each batch

                # Terminate early if loss is NaN
                isTerminateEarly = False
                if math.isnan(EpochLosses[-1]):
                    print('[ WARN ]: NaN loss encountered. Terminating training and saving current model checkpoint (might be junk).')
                    isTerminateEarly = True
                    break

                # Print stats
                Toc = ptUtils.getCurrentEpochTime()
                Elapsed = math.floor((Toc - Tic) * 1e-6)
                TotalElapsed = math.floor((Toc - AllTic) * 1e-6)
                # Compute ETA
                TimePerBatch = (Toc - AllTic) / ((Epoch * len(TrainDataLoader)) + (i + 1))  # Time per batch
                ETA = math.floor(TimePerBatch * self.Config.Args.epochs * len(TrainDataLoader) * 1e-6)
                done = int(50 * (i + 1) / len(TrainDataLoader))
                ProgressStr = ('\r[{}>{}] epoch - {}/{}, train loss - {:.8f} | epoch - {}, total - {} ETA - {} |').format(
                    '=' * done, '-' * (50 - done),
                    self.StartEpoch + Epoch + 1, self.StartEpoch + self.Config.Args.epochs,
                    np.mean(np.asarray(EpochLosses)),
                    ptUtils.getTimeDur(Elapsed), ptUtils.getTimeDur(TotalElapsed),
                    ptUtils.getTimeDur(ETA - TotalElapsed))
                sys.stdout.write(ProgressStr.ljust(150))
                sys.stdout.flush()
            sys.stdout.write('\n')
            self.LossHistory.append(np.mean(np.asarray(EpochLosses)))

            if ValDataLoader is not None:
                ValLosses = self.validate(ValDataLoader, Objective, TrainDevice)
                self.ValLossHistory.append(np.mean(np.asarray(ValLosses)))
                # print('Last epoch val loss - {:.16f}'.format(self.ValLossHistory[-1]))
                CurrLegend = ['Train loss', 'Val loss']

            # Always save a checkpoint after an epoch. It is overwritten each epoch and is independent of the requested checkpointing frequency.
            self.saveCheckpoint(Epoch, CurrLegend, TimeString='eot', PrintStr='~' * 3)

            isLastLoop = (Epoch == self.Config.Args.epochs - 1) and (i == len(TrainDataLoader) - 1)
            if (Epoch + 1) % self.SaveFrequency == 0 or isTerminateEarly or isLastLoop:
                self.saveCheckpoint(Epoch, CurrLegend)

            if isTerminateEarly:
                break
        except (KeyboardInterrupt, SystemExit):
            print('\n[ INFO ]: KeyboardInterrupt detected. Saving checkpoint.')
            self.saveCheckpoint(Epoch, CurrLegend, TimeString='eot', PrintStr='$' * 3)
            break
        except Exception as e:
            print(traceback.format_exc())
            print('\n[ WARN ]: Exception detected. *NOT* saving checkpoint. {}'.format(e))
            # self.saveCheckpoint(Epoch, CurrLegend, TimeString='eot', PrintStr='$'*3)
            break

    AllToc = ptUtils.getCurrentEpochTime()
    print('[ INFO ]: All done in {}.'.format(ptUtils.getTimeDur((AllToc - AllTic) * 1e-6)))
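# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module): assumes a concrete
# subclass -- hypothetically named `MyNet` here -- that implements forward() and
# exposes the Config.Args fields referenced above (learning_rate, epochs,
# batch_size), with torch, nn, and the DataLoader datasets imported/defined at
# module level. The DataLoader and device setup is the standard PyTorch pattern,
# not an API guaranteed by this trainer.
#
#   Device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   Net = MyNet(config=Config)
#   TrainLoader = torch.utils.data.DataLoader(TrainSet, batch_size=Config.Args.batch_size, shuffle=True)
#   ValLoader = torch.utils.data.DataLoader(ValSet, batch_size=Config.Args.batch_size, shuffle=False)
#   Net.fit(TrainLoader, Objective=nn.MSELoss(), TrainDevice=Device, ValDataLoader=ValLoader)
# ---------------------------------------------------------------------------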