# standard-library / torch imports used by the code below.
# NOTE: project-internal names (BaseNet, HtmlLogger, SlimLayer, FlopsLoss, TrainRegime,
# TrainWeights, EpochTrainWeights, EpochData, TrainingStats, AlphaTrainingStats,
# saveArgsToJSON, saveCheckpoint, save_checkpoint, loadModel, ...) are assumed to be
# imported from the surrounding package.
from abc import abstractmethod
from argparse import Namespace
from os import environ, getpid, makedirs
from os.path import exists
from socket import gethostname
from sys import argv
from time import time

from torch import no_grad
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau


def loadPreTrained(model: BaseNet, path: str, logger: HtmlLogger) -> dict:
    optimizerStateDict = None

    if path is not None:
        if exists(path):
            # load checkpoint
            checkpoint = loadModel(path, map_location=lambda storage, loc: storage.cuda())
            # load weights
            model.loadPreTrained(checkpoint['state_dict'])
            # # load optimizer state dict
            # optimizerStateDict = checkpoint['optimizer']
            # add info rows about checkpoint
            loggerRows = []
            loggerRows.append(['Path', '{}'.format(path)])
            validationAccRows = [['Ratio', 'Accuracy']] + HtmlLogger.dictToRows(checkpoint['best_prec1'], nElementPerRow=1)
            loggerRows.append(['Validation accuracy', validationAccRows])
            # set optimizer table row
            optimizerRow = HtmlLogger.dictToRows(optimizerStateDict, nElementPerRow=3) if optimizerStateDict else optimizerStateDict
            loggerRows.append(['Optimizer', optimizerRow])

            logger.addInfoTable('Pre-trained model', loggerRows)
        else:
            raise ValueError('Failed to load pre-trained from [{}], path does not exist'.format(path))

    return optimizerStateDict
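# For reference, a minimal sketch (not part of the original code) of a checkpoint that
# satisfies the keys loadPreTrained() reads: 'state_dict' for the model weights,
# 'best_prec1' as a {width ratio: validation accuracy} dict, and optionally 'optimizer'.
# The helper name saveCompatibleCheckpoint is illustrative only.
from torch import save


def saveCompatibleCheckpoint(model, optimizer, bestPrec1PerRatio: dict, path: str):
    checkpoint = {
        'state_dict': model.state_dict(),
        # e.g. {0.25: 68.4, 0.5: 72.1, 1.0: 76.0}
        'best_prec1': bestPrec1PerRatio,
        'optimizer': optimizer.state_dict()
    }
    save(checkpoint, path)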
def logParameters(logger, args, model):
    if not logger:
        return

    # log command line
    logger.addInfoTable(title='Command line',
                        rows=[[' '.join(argv)], ['PID', getpid()], ['Hostname', gethostname()],
                              ['CUDA_VISIBLE_DEVICES', environ.get('CUDA_VISIBLE_DEVICES')]])

    # # calc number of permutations
    # permutationStr = model.nPerms
    # for p in [12, 9, 6, 3]:
    #     v = model.nPerms / (10 ** p)
    #     if v > 1:
    #         permutationStr = '{:.3f} * 10<sup>{}</sup>'.format(v, p)
    #         break
    #
    # # log other parameters
    # logger.addInfoTable('Parameters', HtmlLogger.dictToRows(
    #     {
    #         'Learnable params': len([param for param in model.parameters() if param.requires_grad]),
    #         'Widths per layer': [layer.nWidths() for layer in model.layersList()],
    #         'Permutations': permutationStr
    #     }, nElementPerRow=2))

    # init args dict sorting functions
    sortFuncsDict = {k: lambda kv: kv[-1] for k in BaseNet.keysToSortByValue()}
    # transform args to dictionary
    argsDict = vars(args)
    # remove model flops list from args dict before logging / saving
    modelFlopsKey = BaseNet.modelFlopsKey()
    modelFlops = argsDict[modelFlopsKey]
    del argsDict[modelFlopsKey]
    # log args to html
    logger.addInfoTable('args', HtmlLogger.dictToRows(argsDict, 3, lambda kv: kv[0], sortFuncsDict))
    # print args
    print(args)
    # save to json
    saveArgsToJSON(args)
    # bring back model flops list
    argsDict[modelFlopsKey] = modelFlops
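# Rough sketch (an assumption, not the project's saveArgsToJSON) of what serializing the
# parsed args to JSON might look like; logParameters() pops the model-flops entry first
# because such values may be large or not JSON-serializable.
import json


def dumpArgsToJSON(args, jsonPath: str):
    with open(jsonPath, 'w') as f:
        json.dump(vars(args), f, indent=4, sort_keys=True, default=str)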
class SearchRegime(TrainRegime):
    # init train logger key
    trainLoggerKey = TrainWeights.trainLoggerKey
    summaryKey = TrainWeights.summaryKey
    # init table columns names
    archLossKey = 'Arch Loss'
    pathsListKey = 'Paths list'
    gradientsKey = 'Gradients'
    # get keys from TrainWeights
    batchNumKey = TrainWeights.batchNumKey
    epochNumKey = TrainWeights.epochNumKey
    forwardCountersKey = TrainWeights.forwardCountersKey
    timeKey = TrainWeights.timeKey
    trainLossKey = TrainWeights.trainLossKey
    trainAccKey = TrainWeights.trainAccKey
    validLossKey = TrainWeights.validLossKey
    validAccKey = TrainWeights.validAccKey
    widthKey = TrainWeights.widthKey
    lrKey = TrainWeights.lrKey
    validFlopsRatioKey = TrainWeights.flopsRatioKey
    # init number of top alphas to show in the alphas table
    k = 2
    alphasTableTitle = 'Alphas (top [{}])'.format(k)
    # init table columns
    colsTrainAlphas = [batchNumKey, archLossKey, alphasTableTitle, pathsListKey, gradientsKey]
    colsMainLogger = [epochNumKey, archLossKey, trainLossKey, trainAccKey, validLossKey, validAccKey,
                      validFlopsRatioKey, widthKey, lrKey]
    # init statistics (plots) keys templates
    batchLossAvgTemplate = '{}_Loss_Avg_(Batch)'
    epochLossAvgTemplate = '{}_Loss_Avg_(Epoch)'
    batchLossVarianceTemplate = '{}_Loss_Variance_(Batch)'
    # init statistics (plots) keys
    entropyKey = 'Alphas_Entropy'
    batchAlphaDistributionKey = 'Alphas_Distribution_(Batch)'
    epochAlphaDistributionKey = 'Alphas_Distribution_(Epoch)'
    # init formats for keys
    formats = {archLossKey: lambda x: HtmlLogger.dictToRows(x, nElementPerRow=1)}

    def __init__(self, args, logger):
        self.lossClass = FlopsLoss

        super(SearchRegime, self).__init__(args, logger)

        # init number of epochs
        self.nEpochs = args.search_epochs
        # init main table
        logger.createDataTable('Search summary', self.colsMainLogger)
        # update max table cell length
        logger.setMaxTableCellLength(30)

        # add TrainWeights formats to self.formats
        self.formats.update(TrainWeights.getFormats())
        # update epoch key format
        self.formats[self.epochNumKey] = lambda x: '{}/{}'.format(x, self.nEpochs)

        # init flops loss
        self.flopsLoss = FlopsLoss(args, getattr(args, self.model.baselineFlopsKey()))
        self.flopsLoss = self.flopsLoss.cuda()
        # create search queue
        self.search_queue = self.createSearchQueue()

        # load model pre-trained weights
        TrainWeights.loadPreTrained(self.model, args.pre_trained, self.logger)
        # reset args.pre_trained, we don't want to load these weights anymore
        args.pre_trained = None

        # init model replications
        self.replicator = self.initReplicator()

        # create folder for jobs checkpoints
        self.jobsPath = '{}/jobs'.format(args.save)
        makedirs(self.jobsPath)
        # init data table row keys to replace
        self.rowKeysToReplace = [self.validLossKey, self.validAccKey]

    def buildStatsRules(self):
        _alphaDistributionMaxVal = 1.1
        return {self.batchAlphaDistributionKey: _alphaDistributionMaxVal,
                self.epochAlphaDistributionKey: _alphaDistributionMaxVal}

    # apply defined format functions on dict values by keys
    def _applyFormats(self, dict):
        for k in dict.keys():
            if k in self.formats:
                dict[k] = self.formats[k](dict[k])

    @abstractmethod
    def initReplicator(self):
        raise NotImplementedError('subclasses must override initReplicator()!')

    @abstractmethod
    def _pathsListToRows(self, batchLossDictsList: list) -> list:
        raise NotImplementedError('subclasses must override _pathsListToRows()!')

    @abstractmethod
    def _containerPerAlpha(self, model: BaseNet) -> list:
        raise NotImplementedError('subclasses must override _containerPerAlpha()!')

    @abstractmethod
    def _alphaGradTitle(self, layer: SlimLayer, alphaIdx: int):
        raise NotImplementedError('subclasses must override _alphaGradTitle()!')
    @abstractmethod
    def _calcAlphasDistribStats(self, model: BaseNet, alphaDistributionKey: str):
        raise NotImplementedError('subclasses must override _calcAlphasDistribStats()!')

    # add epoch loss to statistics plots
    @abstractmethod
    def _updateEpochLossStats(self, epochLossDict: dict):
        raise NotImplementedError('subclasses must override _updateEpochLossStats()!')

    # updates alphas gradients
    # updates statistics
    @abstractmethod
    def _updateAlphasGradients(self, lossDictsList: list) -> dict:
        raise NotImplementedError('subclasses must override _updateAlphasGradients()!')

    def _alphaPlotTitle(self, layer: SlimLayer, alphaIdx: int) -> str:
        return '{} ({})'.format(layer.widthRatioByIdx(alphaIdx), layer.widthByIdx(alphaIdx))

    def _getNextSearchQueueDataLoader(self):
        if len(self.search_queue) == 0:
            # create search queue again, because we iterate over all samples
            self.search_queue = self.createSearchQueue()
        # get next DataLoader
        dataLoader = self.search_queue[0]
        # remove DataLoader from search_queue list
        del self.search_queue[0]

        return dataLoader

    def TrainWeightsClass(self):
        return EpochTrainWeights

    def _addValuesToStatistics(self, getListFunc: callable, templateStr: str, valuesDict: dict):
        for k, v in valuesDict.items():
            self.statistics.addValue(getListFunc(templateStr.format(k)), v)

    def trainAlphas(self, search_queue, optimizer, epoch, loggers):
        print('*** trainAlphas() ***')
        model = self.model
        replicator = self.replicator
        # init training stats instance
        trainStats = AlphaTrainingStats(self.flopsLoss.lossKeys(), useAvg=False)

        def createInfoTable(dict, key, logger, rows):
            dict[key] = logger.createInfoTable('Show', rows)

        def createAlphasTable(k, rows):
            createInfoTable(dataRow, self.alphasTableTitle, trainLogger, rows)

        nBatches = len(search_queue)
        # update batch num key format
        self.formats[self.batchNumKey] = lambda x: '{}/{}'.format(x, nBatches)

        startTime = time()
        # choose nSamples paths, train them, evaluate them over search_queue
        # lossDictsList is a list of lists, where each inner list contains the losses of a specific batch
        lossDictsList = replicator.loss(model, search_queue)
        calcTime = time() - startTime

        trainLogger = loggers.get(self.trainLoggerKey)
        if trainLogger:
            trainLogger.createDataTable('Alphas - Epoch:[{}] - Time:[{:.3f}]'.format(epoch, calcTime), self.colsTrainAlphas)

        for batchNum, batchLossDictsList in enumerate(lossDictsList):
            # reset optimizer gradients
            optimizer.zero_grad()
            # update statistics and alphas gradients based on loss
            lossAvgDict = self._updateAlphasGradients(batchLossDictsList)
            # perform optimizer step
            optimizer.step()

            # update training stats
            for lossName, loss in lossAvgDict.items():
                trainStats.update(lossName, loss)

            # save alphas to csv
            model.saveAlphasCsv(data=[epoch, batchNum])
            # update batch alphas distribution statistics (after optimizer step)
            self._calcAlphasDistribStats(model, self.batchAlphaDistributionKey)

            if trainLogger:
                # parse paths list to InfoTable rows
                pathsListRows = self._pathsListToRows(batchLossDictsList)
                # parse alphas gradients to InfoTable rows
                gradientRows = [['Layer #', self.gradientsKey]]
                for layerIdx, (layer, alphas) in enumerate(zip(model.layersList(), model.alphas())):
                    gradientRows.append([layerIdx,
                                         [[self._alphaGradTitle(layer, idx), '{:.5f}'.format(alphas.grad[idx])]
                                          for idx in range(len(alphas))]])
                # init data row
                dataRow = {
                    self.batchNumKey: batchNum,
                    self.archLossKey: trainStats.batchLoss(),
                    self.pathsListKey: trainLogger.createInfoTable('Show', pathsListRows),
                    self.gradientsKey: trainLogger.createInfoTable('Show', gradientRows)
                }
                # add alphas distribution table
                model.logTopAlphas(self.k, [createAlphasTable])
                # apply formats
                self._applyFormats(dataRow)
                # add row to data table
                trainLogger.addDataRow(dataRow)

        epochLossDict = trainStats.epochLoss()
        # log summary row
        summaryDataRow = {self.batchNumKey: self.summaryKey, self.archLossKey: epochLossDict}
        # delete batch num key format
        del self.formats[self.batchNumKey]
        # apply formats
        self._applyFormats(summaryDataRow)
        # add row to data table
        if trainLogger:
            trainLogger.addSummaryDataRow(summaryDataRow)

        # update epoch alphas distribution statistics (after optimizer step)
        self._calcAlphasDistribStats(model, self.epochAlphaDistributionKey)
        # update epoch loss statistics
        self._updateEpochLossStats(epochLossDict)
        # update statistics plots
        self.statistics.plotData()

        return epochLossDict, summaryDataRow

    @staticmethod
    def _getEpochRange(nEpochs: int) -> range:
        return range(1, nEpochs + 1)

    @staticmethod
    def _generateTableValue(jobName, key) -> dict:
        return {BaseNet.partitionKey(): '{}_{}'.format(jobName, key)}

    def _createPartitionInfoTable(self, partition):
        rows = [['Layer #', self.widthKey]] + [[layerIdx, w] for layerIdx, w in enumerate(partition)]
        table = self.logger.createInfoTable('Show', rows)

        return table

    def _createJob(self, epoch: int, id: int, choosePathFunc: callable) -> dict:
        model = self.model
        args = self.args
        # clone args
        job = Namespace(**vars(args))
        # init job name
        epochStr = epoch if epoch >= 10 else '0{}'.format(epoch)
        idStr = id if id >= 10 else '0{}'.format(id)
        jobName = '[{}]-[{}]-[{}]'.format(args.time, epochStr, idStr)
        # create job data row
        dataRow = {k: self._generateTableValue(jobName, k) for k in self.rowKeysToReplace}

        # sample path from alphas distribution
        choosePathFunc()
        # set attributes
        job.partition = model.currWidthRatio()
        job.epoch = epoch
        job.id = id
        job.jobName = jobName
        job.tableKeys = dataRow
        job.width = [0.25, 0.5, 0.75, 1.0]
        # init model flops key
        modelFlopsKey = BaseNet.modelFlopsKey()
        # reset model flops dict
        setattr(job, modelFlopsKey, None)

        # save job
        jobPath = '{}/{}.pth.tar'.format(self.jobsPath, job.jobName)
        saveCheckpoint(job, jobPath)

        # add flops ratio to data row
        dataRow[self.validFlopsRatioKey] = model.flopsRatio()
        # add path width ratio to data row
        dataRow[self.widthKey] = self._createPartitionInfoTable(job.partition)
        # add epoch number to data row
        dataRow[self.epochNumKey] = epoch
        # apply formats
        self._applyFormats(dataRow)

        return dataRow

    def _createEpochJobs(self, epoch: int) -> list:
        # init epoch data rows list
        epochDataRows = []
        # init model path chooser function
        choosePathFunc = self.model.choosePathAlphasAsPartition
        for id in self._getEpochRange(self.args.nJobs):
            jobDataRow = self._createJob(epoch, id, choosePathFunc)
            epochDataRows.append(jobDataRow)
            # only the 1st job should be based on alphas max, the rest should sample from the alphas distribution
            choosePathFunc = self.model.choosePathByAlphas

        return epochDataRows

    def train(self):
        args = self.args
        model = self.model
        logger = self.logger
        epochRange = self._getEpochRange(self.nEpochs)

        # init optimizer
        optimizer = SGD(model.alphas(), args.search_learning_rate,
                        momentum=args.search_momentum, weight_decay=args.search_weight_decay)
        # init scheduler
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.95,
                                      patience=args.search_patience, min_lr=args.search_learning_rate_min)

        for epoch in epochRange:
            print('========== Epoch:[{}/{}] =============='.format(epoch, self.nEpochs))
            # init epoch train logger
            trainLogger = HtmlLogger(self.trainFolderPath, epoch)
            # set loggers dictionary
            loggersDict = {self.trainLoggerKey: trainLogger}

            # create epoch jobs
            epochDataRows = self._createEpochJobs(epoch)
            # add epoch data rows
            for jobDataRow in epochDataRows:
                logger.addDataRow(jobDataRow, trType='<tr bgcolor="#2CBDD6">')

            # train alphas
            # epochLossDict, alphasDataRow = self.trainAlphas(self._getNextSearchQueueDataLoader(), optimizer, epoch, loggersDict)
            epochLossDict, alphasDataRow = self.trainAlphas(self.valid_queue, optimizer, epoch, loggersDict)

            # update scheduler
            scheduler.step(epochLossDict.get(self.flopsLoss.totalKey()))

            # calc model choosePathAlphasAsPartition flops ratio
            model.choosePathAlphasAsPartition()
            # add values to alphas data row
            additionalData = {
                self.epochNumKey: epoch,
                self.lrKey: optimizer.param_groups[0]['lr'],
                self.validFlopsRatioKey: model.flopsRatio()
            }
            self._applyFormats(additionalData)
            # add alphas data row
            alphasDataRow.update(additionalData)
            logger.addDataRow(alphasDataRow)

            # save checkpoint
            save_checkpoint(self.trainFolderPath, model, optimizer, epochLossDict)
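# Illustrative driver sketch (hypothetical, not from the repo): a concrete SearchRegime
# subclass implementing the abstract hooks above (initReplicator, _updateAlphasGradients,
# _pathsListToRows, ...) is constructed from parsed args and an HtmlLogger, and train()
# then runs nEpochs of alphas optimization, writing per-epoch partition job checkpoints
# under '<args.save>/jobs'. ConcreteSearchRegime and runSearch are placeholder names.
def runSearch(args, ConcreteSearchRegime):
    # HtmlLogger(folder, name) follows the constructor usage seen above; 'search' is an assumed name
    logger = HtmlLogger(args.save, 'search')
    regime = ConcreteSearchRegime(args, logger)
    regime.train()
    return regime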
class TrainWeights:
    # init train logger key
    trainLoggerKey = 'train'
    summaryKey = 'Summary'
    # init tables keys
    trainLossKey = 'Training loss'
    trainAccKey = 'Training acc'
    validLossKey = 'Validation loss'
    validAccKey = 'Validation acc'
    flopsRatioKey = 'Flops ratio'
    epochNumKey = 'Epoch #'
    batchNumKey = 'Batch #'
    timeKey = 'Time'
    lrKey = 'Optimizer lr'
    widthKey = 'Width'
    forwardCountersKey = 'Forward counters'
    # init formats for keys
    formats = {
        timeKey: lambda x: '{:.3f}'.format(x),
        lrKey: lambda x: '{:.8f}'.format(x),
        trainLossKey: lambda x: HtmlLogger.dictToRows(x, nElementPerRow=1),
        trainAccKey: lambda x: HtmlLogger.dictToRows(x, nElementPerRow=1),
        validLossKey: lambda x: HtmlLogger.dictToRows(x, nElementPerRow=1),
        validAccKey: lambda x: HtmlLogger.dictToRows(x, nElementPerRow=1),
        flopsRatioKey: lambda x: '{:.3f}'.format(x)
    }
    # init tables columns
    colsTrainWeights = [batchNumKey, trainLossKey, trainAccKey, timeKey]
    colsValidation = [batchNumKey, validLossKey, validAccKey, timeKey]

    # def __init__(self, regime):
    def __init__(self, getModel, getModelParallel, getArgs, getLogger, getTrainQueue, getValidQueue, getTrainFolderPath):
        # init functions
        self.getModel = getModel
        self.getModelParallel = getModelParallel
        self.getArgs = getArgs
        self.getLogger = getLogger
        self.getTrainQueue = getTrainQueue
        self.getValidQueue = getValidQueue
        self.getTrainFolderPath = getTrainFolderPath
        # self.regime = regime

        # init cross entropy loss
        self.cross_entropy = CrossEntropyLoss().cuda()

        # load pre-trained model & optimizer
        self.optimizerStateDict = self.loadPreTrained(self.getModel(), self.getArgs().pre_trained, self.getLogger())

    # apply defined format functions on dict values by keys
    def _applyFormats(self, dict):
        for k in dict.keys():
            if k in self.formats:
                dict[k] = self.formats[k](dict[k])

    @staticmethod
    def getFormats():
        return TrainWeights.formats

    @abstractmethod
    def stopCondition(self, epoch):
        raise NotImplementedError('subclasses must override stopCondition()!')

    @abstractmethod
    # returns (widthRatio, idxList) list or generator
    def widthList(self):
        raise NotImplementedError('subclasses must override widthList()!')

    @abstractmethod
    def schedulerMetric(self, validLoss):
        raise NotImplementedError('subclasses must override schedulerMetric()!')

    @abstractmethod
    def postEpoch(self, epoch, optimizer, trainData: EpochData, validData: EpochData):
        raise NotImplementedError('subclasses must override postEpoch()!')

    @abstractmethod
    def postTrain(self):
        raise NotImplementedError('subclasses must override postTrain()!')

    # generic epoch flow
    def _genericEpoch(self, forwardFunc, data_queue, loggers, lossKey, accKey, tableTitle, tableCols, forwardCountersTitle) -> EpochData:
        trainStats = TrainingStats([k for k, v in self.widthList()])

        trainLogger = loggers.get(self.trainLoggerKey)
        if trainLogger:
            trainLogger.createDataTable(tableTitle, tableCols)

        nBatches = len(data_queue)

        for batchNum, (input, target) in enumerate(data_queue):
            startTime = time()
            input = input.cuda().clone().detach().requires_grad_(False)
            target = target.cuda(non_blocking=True).clone().detach().requires_grad_(False)

            # do forward
            forwardFunc(input, target, trainStats)

            endTime = time()

            if trainLogger:
                dataRow = {
                    self.batchNumKey: '{}/{}'.format(batchNum, nBatches),
                    self.timeKey: (endTime - startTime),
                    lossKey: trainStats.batchLoss(),
                    accKey: trainStats.prec1()
                }
                # apply formats
                self._applyFormats(dataRow)
                # add row to data table
                trainLogger.addDataRow(dataRow)

        epochLossDict = trainStats.epochLoss()
        epochAccDict = trainStats.top1()
        # # add epoch data to statistics plots
        # self.statistics.addBatchData(epochLossDict, epochAccDict)

        # log accuracy, loss, etc.
        summaryData = {lossKey: epochLossDict, accKey: epochAccDict, self.batchNumKey: self.summaryKey}
        # apply formats
        self._applyFormats(summaryData)

        for logger in loggers.values():
            if logger:
                logger.addSummaryDataRow(summaryData)

        # log forward counters. if loggerFuncs == [] then it just resets the counters
        func = [lambda rows: trainLogger.addInfoTable(title=forwardCountersTitle, rows=rows)] if trainLogger else []
        self.getModel().logForwardCounters(loggerFuncs=func)

        return EpochData(epochLossDict, epochAccDict, summaryData)

    def _slimForward(self, input, target, trainStats):
        model = self.getModel()
        modelParallel = self.getModelParallel()
        crit = self.cross_entropy
        # init loss list
        lossList = []
        # iterate & forward widths
        for widthRatio, idxList in self.widthList():
            # set model layers current width index
            model.setCurrWidthIdx(idxList)
            # forward
            logits = modelParallel(input)
            # calc loss
            loss = crit(logits, target)
            # add to loss list
            lossList.append(loss)
            # update training stats
            trainStats.update(widthRatio, logits, target, loss)

        return lossList

    # performs a single epoch of model weights training
    def weightsEpoch(self, optimizer, epoch, loggers) -> EpochData:
        # print('*** weightsEpoch() ***')
        model = self.getModel()
        modelParallel = self.getModelParallel()

        modelParallel.train()
        assert (model.training is True)

        def forwardFunc(input, target, trainStats):
            # optimize model weights
            optimizer.zero_grad()
            # forward
            lossList = self._slimForward(input, target, trainStats)
            # back propagate
            for loss in lossList:
                loss.backward()
            # update weights
            optimizer.step()

        tableTitle = 'Epoch:[{}] - Training weights'.format(epoch)
        forwardCountersTitle = '{} - Training'.format(self.forwardCountersKey)

        return self._genericEpoch(forwardFunc, self.getTrainQueue(), loggers, self.trainLossKey, self.trainAccKey,
                                  tableTitle, self.colsTrainWeights, forwardCountersTitle)

    # performs a single epoch of model inference
    def inferEpoch(self, nEpoch, loggers) -> EpochData:
        print('*** inferEpoch() ***')
        model = self.getModel()
        modelParallel = self.getModelParallel()

        modelParallel.eval()
        assert (model.training is False)

        def forwardFunc(input, target, trainStats):
            with no_grad():
                self._slimForward(input, target, trainStats)

        tableTitle = 'Epoch:[{}] - Validation'.format(nEpoch)
        forwardCountersTitle = '{} - Validation'.format(self.forwardCountersKey)

        return self._genericEpoch(forwardFunc, self.getValidQueue(), loggers, self.validLossKey, self.validAccKey,
                                  tableTitle, self.colsValidation, forwardCountersTitle)

    def _initOptimizer(self):
        modelParallel = self.getModelParallel()
        args = self.getArgs()

        optimizer = SGD(modelParallel.parameters(), args.learning_rate,
                        momentum=args.momentum, weight_decay=args.weight_decay)
        # load optimizer pre-trained state dict if it exists
        if self.optimizerStateDict:
            optimizer.load_state_dict(self.optimizerStateDict)

        return optimizer

    def train(self, trainFolderName='init_weights_train'):
        args = self.getArgs()
        # create train folder
        folderPath = '{}/{}'.format(self.getTrainFolderPath(), trainFolderName)
        if not exists(folderPath):
            makedirs(folderPath)

        # init optimizer
        optimizer = self._initOptimizer()
        # init scheduler
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.95,
                                      patience=args.weights_patience, min_lr=args.learning_rate_min)

        epoch = 0
        trainLoggerFlag = True

        while not self.stopCondition(epoch):
            # update epoch number
            epoch += 1
            # init train logger
            trainLogger = None
            if trainLoggerFlag:
                trainLogger = HtmlLogger(folderPath, epoch)
                trainLogger.addInfoTable('Learning rates',
                                         [['optimizer_lr', self.formats[self.lrKey](optimizer.param_groups[0]['lr'])]])

            # update train logger condition for next epoch
            trainLoggerFlag = ((epoch + 1) % args.logInterval) == 0

            # set loggers dictionary
            loggersDict = {self.trainLoggerKey: trainLogger}

            print('========== Epoch:[{}] =============='.format(epoch))
            # train
            trainData = self.weightsEpoch(optimizer, epoch, loggersDict)
            # validation
            validData = self.inferEpoch(epoch, loggersDict)

            # update scheduler
            scheduler.step(self.schedulerMetric(validData.lossDict()))

            self.postEpoch(epoch, optimizer, trainData, validData)

        self.postTrain()

    @staticmethod
    def loadPreTrained(model: BaseNet, path: str, logger: HtmlLogger) -> dict:
        optimizerStateDict = None

        if path is not None:
            if exists(path):
                # load checkpoint
                checkpoint = loadModel(path, map_location=lambda storage, loc: storage.cuda())
                # load weights
                model.loadPreTrained(checkpoint['state_dict'])
                # # load optimizer state dict
                # optimizerStateDict = checkpoint['optimizer']
                # add info rows about checkpoint
                loggerRows = []
                loggerRows.append(['Path', '{}'.format(path)])
                validationAccRows = [['Ratio', 'Accuracy']] + HtmlLogger.dictToRows(checkpoint['best_prec1'], nElementPerRow=1)
                loggerRows.append(['Validation accuracy', validationAccRows])
                # set optimizer table row
                optimizerRow = HtmlLogger.dictToRows(optimizerStateDict, nElementPerRow=3) if optimizerStateDict else optimizerStateDict
                loggerRows.append(['Optimizer', optimizerRow])

                logger.addInfoTable('Pre-trained model', loggerRows)
            else:
                raise ValueError('Failed to load pre-trained from [{}], path does not exist'.format(path))

        return optimizerStateDict
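# Hypothetical minimal subclass sketch (not from the repo) showing how TrainWeights'
# abstract hooks might be filled in. The class name, maxEpochs, and the choice of
# training only the widest configuration are illustrative assumptions; real subclasses
# (e.g. EpochTrainWeights) live elsewhere in the project.
class FixedEpochsTrainWeights(TrainWeights):
    def __init__(self, maxEpochs, getModel, getModelParallel, getArgs, getLogger,
                 getTrainQueue, getValidQueue, getTrainFolderPath):
        super(FixedEpochsTrainWeights, self).__init__(getModel, getModelParallel, getArgs, getLogger,
                                                      getTrainQueue, getValidQueue, getTrainFolderPath)
        self._maxEpochs = maxEpochs

    def stopCondition(self, epoch):
        # stop after a fixed number of epochs
        return epoch >= self._maxEpochs

    def widthList(self):
        # train a single configuration: the widest width index of every layer
        # (assumes the last width index per layer corresponds to ratio 1.0)
        return [(1.0, [layer.nWidths() - 1 for layer in self.getModel().layersList()])]

    def schedulerMetric(self, validLoss):
        # validLoss maps width ratio -> epoch loss; average it for ReduceLROnPlateau
        return sum(validLoss.values()) / len(validLoss)

    def postEpoch(self, epoch, optimizer, trainData, validData):
        pass

    def postTrain(self):
        pass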