Example #1
    def loadPreTrained(model: BaseNet, path: str, logger: HtmlLogger) -> dict:
        optimizerStateDict = None

        if path is not None:
            if exists(path):
                # load checkpoint
                checkpoint = loadModel(
                    path, map_location=lambda storage, loc: storage.cuda())
                # load weights
                model.loadPreTrained(checkpoint['state_dict'])
                # # load optimizer state dict
                # optimizerStateDict = checkpoint['optimizer']
                # add info rows about checkpoint
                loggerRows = []
                loggerRows.append(['Path', '{}'.format(path)])
                validationAccRows = [['Ratio', 'Accuracy']] + HtmlLogger.dictToRows(
                    checkpoint['best_prec1'], nElementPerRow=1)
                loggerRows.append(['Validation accuracy', validationAccRows])
                # set optimizer table row
                optimizerRow = HtmlLogger.dictToRows(
                    optimizerStateDict, nElementPerRow=3
                ) if optimizerStateDict else optimizerStateDict
                loggerRows.append(['Optimizer', optimizerRow])
                logger.addInfoTable('Pre-trained model', loggerRows)
            else:
                raise ValueError(
                    'Failed to load pre-trained from [{}], path does not exist'
                    .format(path))

        return optimizerStateDict
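
A minimal call sketch (hedged: it assumes a BaseNet model, an HtmlLogger logger, and an args.pre_trained checkpoint path already exist; in the class example further down, this helper lives on TrainWeights as a static method):

optimizerStateDict = TrainWeights.loadPreTrained(model, args.pre_trained, logger)
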
Example #2
    def train(self, trainFolderName='init_weights_train'):
        args = self.getArgs()

        # create train folder
        folderPath = '{}/{}'.format(self.getTrainFolderPath(), trainFolderName)
        if not exists(folderPath):
            makedirs(folderPath)

        # init optimizer
        optimizer = self._initOptimizer()
        # init scheduler
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode='min',
                                      factor=0.95,
                                      patience=args.weights_patience,
                                      min_lr=args.learning_rate_min)

        epoch = 0
        trainLoggerFlag = True

        while not self.stopCondition(epoch):
            # update epoch number
            epoch += 1
            # init train logger
            trainLogger = None
            if trainLoggerFlag:
                trainLogger = HtmlLogger(folderPath, epoch)
                trainLogger.addInfoTable('Learning rates', [[
                    'optimizer_lr', self.formats[self.lrKey](
                        optimizer.param_groups[0]['lr'])
                ]])

            # update train logger condition for next epoch
            trainLoggerFlag = ((epoch + 1) % args.logInterval) == 0

            # set loggers dictionary
            loggersDict = {self.trainLoggerKey: trainLogger}

            print('========== Epoch:[{}] =============='.format(epoch))
            # train
            trainData = self.weightsEpoch(optimizer, epoch, loggersDict)
            # validation
            validData = self.inferEpoch(epoch, loggersDict)

            # update scheduler
            scheduler.step(self.schedulerMetric(validData.lossDict()))

            self.postEpoch(epoch, optimizer, trainData, validData)

        self.postTrain()
Example #3
    def train(self):
        args = self.args
        model = self.model
        logger = self.logger
        epochRange = self._getEpochRange(self.nEpochs)

        # init optimizer
        optimizer = SGD(model.alphas(),
                        args.search_learning_rate,
                        momentum=args.search_momentum,
                        weight_decay=args.search_weight_decay)
        # init scheduler
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode='min',
                                      factor=0.95,
                                      patience=args.search_patience,
                                      min_lr=args.search_learning_rate_min)

        for epoch in epochRange:
            print('========== Epoch:[{}/{}] =============='.format(
                epoch, self.nEpochs))
            # init epoch train logger
            trainLogger = HtmlLogger(self.trainFolderPath, epoch)
            # set loggers dictionary
            loggersDict = {self.trainLoggerKey: trainLogger}

            # create epoch jobs
            epochDataRows = self._createEpochJobs(epoch)
            # add epoch data rows
            for jobDataRow in epochDataRows:
                logger.addDataRow(jobDataRow, trType='<tr bgcolor="#2CBDD6">')

            # train alphas
            # epochLossDict, alphasDataRow = self.trainAlphas(self._getNextSearchQueueDataLoader(), optimizer, epoch, loggersDict)
            epochLossDict, alphasDataRow = self.trainAlphas(
                self.valid_queue, optimizer, epoch, loggersDict)
            # update scheduler
            scheduler.step(epochLossDict.get(self.flopsLoss.totalKey()))

            # calc model choosePathAlphasAsPartition flops ratio
            model.choosePathAlphasAsPartition()
            # add values to alphas data row
            additionalData = {
                self.epochNumKey: epoch,
                self.lrKey: optimizer.param_groups[0]['lr'],
                self.validFlopsRatioKey: model.flopsRatio()
            }
            self._applyFormats(additionalData)
            # add alphas data row
            alphasDataRow.update(additionalData)
            logger.addDataRow(alphasDataRow)

            # save checkpoint
            save_checkpoint(self.trainFolderPath, model, optimizer,
                            epochLossDict)
Example #4
def train(scriptArgs):
    # load args from file
    args = loadCheckpoint(scriptArgs.json,
                          map_location=lambda storage, loc: storage.cuda())

    # terminate if validAcc exists
    _validAccKey = TrainWeights.validAccKey
    if hasattr(args, _validAccKey):
        print('[{}] exists'.format(_validAccKey))
        exit(0)

    # no need to save model random weights
    args.saveRandomWeights = False

    # update args parameters
    args.seed = datetime.now().microsecond
    # update cudnn parameters
    random.seed(args.seed)
    set_device(scriptArgs.gpu[0])
    cudnn.benchmark = True
    torch_manual_seed(args.seed)
    cudnn.enabled = True
    cuda_manual_seed(args.seed)
    # copy scriptArgs values to args
    for k, v in vars(scriptArgs).items():
        setattr(args, k, v)

    # load model flops
    _modelFlopsPathKey = BaseNet.modelFlopsPathKey()
    modelFlopsPath = getattr(args, _modelFlopsPathKey)
    if modelFlopsPath and exists(modelFlopsPath):
        setattr(args, BaseNet.modelFlopsKey(), loadCheckpoint(modelFlopsPath))

    folderNotExists = not exists(args.save)
    if folderNotExists:
        create_exp_dir(args.save)
        # init logger
        logger = HtmlLogger(args.save, 'log')

        if scriptArgs.individual:
            args.width = []

        alphasRegime = OptimalRegime(args, logger)
        # train according to chosen regime
        alphasRegime.train()

    return folderNotExists
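
A hypothetical driver sketch for this entry point; the real argument parser is not shown in these examples, so the Namespace fields below merely mirror the attributes that train() reads (json, gpu, individual), and the checkpoint path is invented for illustration:

from argparse import Namespace

# hypothetical script arguments; field names follow the attributes accessed in train()
scriptArgs = Namespace(json='checkpoint_args.json.pth.tar',  # hypothetical checkpoint path
                       gpu=[0],
                       individual=False)
startedNewRun = train(scriptArgs)
print('training started' if startedNewRun else 'save folder already exists')
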
Example #5
def logParameters(logger, args, model):
    if not logger:
        return

    # log command line
    logger.addInfoTable(
        title='Command line',
        rows=[[' '.join(argv)],
              ['PID', getpid()],
              ['Hostname', gethostname()],
              ['CUDA_VISIBLE_DEVICES', environ.get('CUDA_VISIBLE_DEVICES')]])

    # # calc number of permutations
    # permutationStr = model.nPerms
    # for p in [12, 9, 6, 3]:
    #     v = model.nPerms / (10 ** p)
    #     if v > 1:
    #         permutationStr = '{:.3f} * 10<sup>{}</sup>'.format(v, p)
    #         break
    #
    # # log other parameters
    # logger.addInfoTable('Parameters', HtmlLogger.dictToRows(
    #     {
    #         'Learnable params': len([param for param in model.parameters() if param.requires_grad]),
    #         'Widths per layer': [layer.nWidths() for layer in model.layersList()],
    #         'Permutations': permutationStr
    #     }, nElementPerRow=2))

    # init args dict sorting function
    sortFuncsDict = {k: lambda kv: kv[-1] for k in BaseNet.keysToSortByValue()}
    # transform args to dictionary
    argsDict = vars(args)
    # remove model flops list from args dict (restored below)
    modelFlopsKey = BaseNet.modelFlopsKey()
    modelFlops = argsDict[modelFlopsKey]
    del argsDict[modelFlopsKey]
    # log args to html
    logger.addInfoTable(
        'args',
        HtmlLogger.dictToRows(argsDict, 3, lambda kv: kv[0], sortFuncsDict))
    # print args
    print(args)
    # save to json
    saveArgsToJSON(args)
    # bring back model flops list
    argsDict[modelFlopsKey] = modelFlops
Example #6
class SearchRegime(TrainRegime):
    # init train logger key
    trainLoggerKey = TrainWeights.trainLoggerKey
    summaryKey = TrainWeights.summaryKey
    # init table columns names
    archLossKey = 'Arch Loss'
    pathsListKey = 'Paths list'
    gradientsKey = 'Gradients'
    # get keys from TrainWeights
    batchNumKey = TrainWeights.batchNumKey
    epochNumKey = TrainWeights.epochNumKey
    forwardCountersKey = TrainWeights.forwardCountersKey
    timeKey = TrainWeights.timeKey
    trainLossKey = TrainWeights.trainLossKey
    trainAccKey = TrainWeights.trainAccKey
    validLossKey = TrainWeights.validLossKey
    validAccKey = TrainWeights.validAccKey
    widthKey = TrainWeights.widthKey
    lrKey = TrainWeights.lrKey
    validFlopsRatioKey = TrainWeights.flopsRatioKey

    # init table columns
    k = 2
    alphasTableTitle = 'Alphas (top [{}])'.format(k)
    # init table columns names
    colsTrainAlphas = [
        batchNumKey, archLossKey, alphasTableTitle, pathsListKey, gradientsKey
    ]
    colsMainLogger = [
        epochNumKey, archLossKey, trainLossKey, trainAccKey, validLossKey,
        validAccKey, validFlopsRatioKey, widthKey, lrKey
    ]

    # init statistics (plots) keys template
    batchLossAvgTemplate = '{}_Loss_Avg_(Batch)'
    epochLossAvgTemplate = '{}_Loss_Avg_(Epoch)'
    batchLossVarianceTemplate = '{}_Loss_Variance_(Batch)'
    # init statistics (plots) keys
    entropyKey = 'Alphas_Entropy'
    batchAlphaDistributionKey = 'Alphas_Distribution_(Batch)'
    epochAlphaDistributionKey = 'Alphas_Distribution_(Epoch)'

    # init formats for keys
    formats = {
        archLossKey: lambda x: HtmlLogger.dictToRows(x, nElementPerRow=1)
    }

    def __init__(self, args, logger):
        self.lossClass = FlopsLoss
        super(SearchRegime, self).__init__(args, logger)

        # init number of epochs
        self.nEpochs = args.search_epochs
        # init main table
        logger.createDataTable('Search summary', self.colsMainLogger)
        # update max table cell length
        logger.setMaxTableCellLength(30)

        # add TrainWeights formats to self.formats
        self.formats.update(TrainWeights.getFormats())
        # update epoch key format
        self.formats[
            self.epochNumKey] = lambda x: '{}/{}'.format(x, self.nEpochs)

        # init flops loss
        self.flopsLoss = FlopsLoss(
            args, getattr(args, self.model.baselineFlopsKey()))
        self.flopsLoss = self.flopsLoss.cuda()

        # create search queue
        self.search_queue = self.createSearchQueue()

        # load model pre-trained weights
        TrainWeights.loadPreTrained(self.model, args.pre_trained, self.logger)
        # reset args.pre_trained, we don't want to load these weights anymore
        args.pre_trained = None
        # init model replications
        self.replicator = self.initReplicator()

        # create folder for jobs checkpoints
        self.jobsPath = '{}/jobs'.format(args.save)
        makedirs(self.jobsPath)
        # init data table row keys to replace
        self.rowKeysToReplace = [self.validLossKey, self.validAccKey]

    def buildStatsRules(self):
        _alphaDistributionMaxVal = 1.1
        return {
            self.batchAlphaDistributionKey: _alphaDistributionMaxVal,
            self.epochAlphaDistributionKey: _alphaDistributionMaxVal
        }

    # apply defined format functions on dict values by keys
    def _applyFormats(self, dict):
        for k in dict.keys():
            if k in self.formats:
                dict[k] = self.formats[k](dict[k])

    @abstractmethod
    def initReplicator(self):
        raise NotImplementedError('subclasses must override initReplicator()!')

    @abstractmethod
    def _pathsListToRows(self, batchLossDictsList: list) -> list:
        raise NotImplementedError(
            'subclasses must override _pathsListToRows()!')

    @abstractmethod
    def _containerPerAlpha(self, model: BaseNet) -> list:
        raise NotImplementedError(
            'subclasses must override _containerPerAlpha()!')

    @abstractmethod
    def _alphaGradTitle(self, layer: SlimLayer, alphaIdx: int):
        raise NotImplementedError(
            'subclasses must override _alphaGradTitle()!')

    @abstractmethod
    def _calcAlphasDistribStats(self, model: BaseNet,
                                alphaDistributionKey: str):
        raise NotImplementedError(
            'subclasses must override _calcAlphasDistribStats()!')

    # add epoch loss to statistics plots
    @abstractmethod
    def _updateEpochLossStats(self, epochLossDict: dict):
        raise NotImplementedError(
            'subclasses must override _updateEpochLossStats()!')

    # updates alphas gradients
    # updates statistics
    @abstractmethod
    def _updateAlphasGradients(self, lossDictsList: list) -> dict:
        raise NotImplementedError(
            'subclasses must override _updateAlphasGradients()!')

    def _alphaPlotTitle(self, layer: SlimLayer, alphaIdx: int) -> str:
        return '{} ({})'.format(layer.widthRatioByIdx(alphaIdx),
                                layer.widthByIdx(alphaIdx))

    def _getNextSearchQueueDataLoader(self):
        if len(self.search_queue) == 0:
            # create search queue again, because we iterate over all samples
            self.search_queue = self.createSearchQueue()
        # get next DataLoader
        dataLoader = self.search_queue[0]
        # remove DataLoader from search_queue list
        del self.search_queue[0]

        return dataLoader

    def TrainWeightsClass(self):
        return EpochTrainWeights

    def _addValuesToStatistics(self, getListFunc: callable, templateStr: str,
                               valuesDict: dict):
        for k, v in valuesDict.items():
            self.statistics.addValue(getListFunc(templateStr.format(k)), v)

    def trainAlphas(self, search_queue, optimizer, epoch, loggers):
        print('*** trainAlphas() ***')
        model = self.model
        replicator = self.replicator
        # init trainingStats instance
        trainStats = AlphaTrainingStats(self.flopsLoss.lossKeys(),
                                        useAvg=False)

        def createInfoTable(dict, key, logger, rows):
            dict[key] = logger.createInfoTable('Show', rows)

        def createAlphasTable(k, rows):
            createInfoTable(dataRow, self.alphasTableTitle, trainLogger, rows)

        nBatches = len(search_queue)
        # update batch num key format
        self.formats[self.batchNumKey] = lambda x: '{}/{}'.format(x, nBatches)

        startTime = time()
        # choose nSamples paths, train them, evaluate them over search_queue
        # lossDictsList is a list of lists where each list contains losses of specific batch
        lossDictsList = replicator.loss(model, search_queue)
        calcTime = time() - startTime

        trainLogger = loggers.get(self.trainLoggerKey)
        if trainLogger:
            trainLogger.createDataTable(
                'Alphas - Epoch:[{}] - Time:[{:.3f}]'.format(epoch, calcTime),
                self.colsTrainAlphas)

        for batchNum, batchLossDictsList in enumerate(lossDictsList):
            # reset optimizer gradients
            optimizer.zero_grad()
            # update statistics and alphas gradients based on loss
            lossAvgDict = self._updateAlphasGradients(batchLossDictsList)
            # perform optimizer step
            optimizer.step()

            # update training stats
            for lossName, loss in lossAvgDict.items():
                trainStats.update(lossName, loss)
            # save alphas to csv
            model.saveAlphasCsv(data=[epoch, batchNum])
            # update batch alphas distribution statistics (after optimizer step)
            self._calcAlphasDistribStats(model, self.batchAlphaDistributionKey)

            if trainLogger:
                # parse paths list to InfoTable rows
                pathsListRows = self._pathsListToRows(batchLossDictsList)
                # parse alphas gradients to InfoTable rows
                gradientRows = [['Layer #', self.gradientsKey]]
                for layerIdx, (layer, alphas) in enumerate(
                        zip(model.layersList(), model.alphas())):
                    gradientRows.append([
                        layerIdx,
                        [[
                            self._alphaGradTitle(layer, idx),
                            '{:.5f}'.format(alphas.grad[idx])
                        ] for idx in range(len(alphas))]
                    ])
                # init data row
                dataRow = {
                    self.batchNumKey: batchNum,
                    self.archLossKey: trainStats.batchLoss(),
                    self.pathsListKey: trainLogger.createInfoTable('Show', pathsListRows),
                    self.gradientsKey: trainLogger.createInfoTable('Show', gradientRows)
                }
                # add alphas distribution table
                model.logTopAlphas(self.k, [createAlphasTable])
                # apply formats
                self._applyFormats(dataRow)
                # add row to data table
                trainLogger.addDataRow(dataRow)

        epochLossDict = trainStats.epochLoss()
        # log summary row
        summaryDataRow = {
            self.batchNumKey: self.summaryKey,
            self.archLossKey: epochLossDict
        }
        # delete batch num key format
        del self.formats[self.batchNumKey]
        # apply formats
        self._applyFormats(summaryDataRow)
        # add row to data table
        trainLogger.addSummaryDataRow(summaryDataRow)

        # update epoch alphas distribution statistics (after optimizer step)
        self._calcAlphasDistribStats(model, self.epochAlphaDistributionKey)
        # update epoch loss statistics
        self._updateEpochLossStats(epochLossDict)
        # update statistics plots
        self.statistics.plotData()

        return epochLossDict, summaryDataRow

    @staticmethod
    def _getEpochRange(nEpochs: int) -> range:
        return range(1, nEpochs + 1)

    @staticmethod
    def _generateTableValue(jobName, key) -> dict:
        return {BaseNet.partitionKey(): '{}_{}'.format(jobName, key)}

    def _createPartitionInfoTable(self, partition):
        rows = [['Layer #', self.widthKey]
                ] + [[layerIdx, w] for layerIdx, w in enumerate(partition)]
        table = self.logger.createInfoTable('Show', rows)
        return table

    def _createJob(self, epoch: int, id: int,
                   choosePathFunc: callable) -> dict:
        model = self.model
        args = self.args
        # clone args
        job = Namespace(**vars(args))
        # init job name
        epochStr = epoch if epoch >= 10 else '0{}'.format(epoch)
        idStr = id if id >= 10 else '0{}'.format(id)
        jobName = '[{}]-[{}]-[{}]'.format(args.time, epochStr, idStr)
        # create job data row
        dataRow = {
            k: self._generateTableValue(jobName, k)
            for k in self.rowKeysToReplace
        }
        # sample path from alphas distribution
        choosePathFunc()
        # set attributes
        job.partition = model.currWidthRatio()
        job.epoch = epoch
        job.id = id
        job.jobName = jobName
        job.tableKeys = dataRow
        job.width = [0.25, 0.5, 0.75, 1.0]
        # init model flops key
        modelFlopsKey = BaseNet.modelFlopsKey()
        # reset model flops dict
        setattr(job, modelFlopsKey, None)
        # save job
        jobPath = '{}/{}.pth.tar'.format(self.jobsPath, job.jobName)
        saveCheckpoint(job, jobPath)

        # add flops ratio to data row
        dataRow[self.validFlopsRatioKey] = model.flopsRatio()
        # add path width ratio to data row
        dataRow[self.widthKey] = self._createPartitionInfoTable(job.partition)
        # add epoch number to data row
        dataRow[self.epochNumKey] = epoch
        # apply formats
        self._applyFormats(dataRow)

        return dataRow

    def _createEpochJobs(self, epoch: int) -> list:
        # init epoch data rows list
        epochDataRows = []
        # init model path chooser function
        choosePathFunc = self.model.choosePathAlphasAsPartition
        for id in self._getEpochRange(self.args.nJobs):
            jobDataRow = self._createJob(epoch, id, choosePathFunc)
            epochDataRows.append(jobDataRow)
            # only 1st job should be based on alphas max, the rest should sample from alphas distribution
            choosePathFunc = self.model.choosePathByAlphas

        return epochDataRows

    def train(self):
        args = self.args
        model = self.model
        logger = self.logger
        epochRange = self._getEpochRange(self.nEpochs)

        # init optimizer
        optimizer = SGD(model.alphas(),
                        args.search_learning_rate,
                        momentum=args.search_momentum,
                        weight_decay=args.search_weight_decay)
        # init scheduler
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode='min',
                                      factor=0.95,
                                      patience=args.search_patience,
                                      min_lr=args.search_learning_rate_min)

        for epoch in epochRange:
            print('========== Epoch:[{}/{}] =============='.format(
                epoch, self.nEpochs))
            # init epoch train logger
            trainLogger = HtmlLogger(self.trainFolderPath, epoch)
            # set loggers dictionary
            loggersDict = {self.trainLoggerKey: trainLogger}

            # create epoch jobs
            epochDataRows = self._createEpochJobs(epoch)
            # add epoch data rows
            for jobDataRow in epochDataRows:
                logger.addDataRow(jobDataRow, trType='<tr bgcolor="#2CBDD6">')

            # train alphas
            # epochLossDict, alphasDataRow = self.trainAlphas(self._getNextSearchQueueDataLoader(), optimizer, epoch, loggersDict)
            epochLossDict, alphasDataRow = self.trainAlphas(
                self.valid_queue, optimizer, epoch, loggersDict)
            # update scheduler
            scheduler.step(epochLossDict.get(self.flopsLoss.totalKey()))

            # calc model choosePathAlphasAsPartition flops ratio
            model.choosePathAlphasAsPartition()
            # add values to alphas data row
            additionalData = {
                self.epochNumKey: epoch,
                self.lrKey: optimizer.param_groups[0]['lr'],
                self.validFlopsRatioKey: model.flopsRatio()
            }
            self._applyFormats(additionalData)
            # add alphas data row
            alphasDataRow.update(additionalData)
            logger.addDataRow(alphasDataRow)

            # save checkpoint
            save_checkpoint(self.trainFolderPath, model, optimizer,
                            epochLossDict)
Example #7
class TrainWeights:
    # init train logger key
    trainLoggerKey = 'train'
    summaryKey = 'Summary'
    # init tables keys
    trainLossKey = 'Training loss'
    trainAccKey = 'Training acc'
    validLossKey = 'Validation loss'
    validAccKey = 'Validation acc'
    flopsRatioKey = 'Flops ratio'
    epochNumKey = 'Epoch #'
    batchNumKey = 'Batch #'
    timeKey = 'Time'
    lrKey = 'Optimizer lr'
    widthKey = 'Width'
    forwardCountersKey = 'Forward counters'

    # init formats for keys
    formats = {
        timeKey: lambda x: '{:.3f}'.format(x),
        lrKey: lambda x: '{:.8f}'.format(x),
        trainLossKey: lambda x: HtmlLogger.dictToRows(x, nElementPerRow=1),
        trainAccKey: lambda x: HtmlLogger.dictToRows(x, nElementPerRow=1),
        validLossKey: lambda x: HtmlLogger.dictToRows(x, nElementPerRow=1),
        validAccKey: lambda x: HtmlLogger.dictToRows(x, nElementPerRow=1),
        flopsRatioKey: lambda x: '{:.3f}'.format(x)
    }

    # init tables columns
    colsTrainWeights = [batchNumKey, trainLossKey, trainAccKey, timeKey]
    colsValidation = [batchNumKey, validLossKey, validAccKey, timeKey]

    # def __init__(self, regime):
    def __init__(self, getModel, getModelParallel, getArgs, getLogger,
                 getTrainQueue, getValidQueue, getTrainFolderPath):
        # init functions
        self.getModel = getModel
        self.getModelParallel = getModelParallel
        self.getArgs = getArgs
        self.getLogger = getLogger
        self.getTrainQueue = getTrainQueue
        self.getValidQueue = getValidQueue
        self.getTrainFolderPath = getTrainFolderPath

        # self.regime = regime
        # init cross entropy loss
        self.cross_entropy = CrossEntropyLoss().cuda()

        # load pre-trained model & optimizer
        self.optimizerStateDict = self.loadPreTrained(
            self.getModel(),
            self.getArgs().pre_trained, self.getLogger())

    # apply defined format functions on dict values by keys
    def _applyFormats(self, dict):
        for k in dict.keys():
            if k in self.formats:
                dict[k] = self.formats[k](dict[k])

    @staticmethod
    def getFormats():
        return TrainWeights.formats

    @abstractmethod
    def stopCondition(self, epoch):
        raise NotImplementedError('subclasses must override stopCondition()!')

    @abstractmethod
    # returns (widthRatio, idxList) list or generator
    def widthList(self):
        raise NotImplementedError('subclasses must override widthList()!')

    @abstractmethod
    def schedulerMetric(self, validLoss):
        raise NotImplementedError(
            'subclasses must override schedulerMetric()!')

    @abstractmethod
    def postEpoch(self, epoch, optimizer, trainData: EpochData,
                  validData: EpochData):
        raise NotImplementedError('subclasses must override postEpoch()!')

    @abstractmethod
    def postTrain(self):
        raise NotImplementedError('subclasses must override postTrain()!')

    # generic epoch flow
    def _genericEpoch(self, forwardFunc, data_queue, loggers, lossKey, accKey,
                      tableTitle, tableCols,
                      forwardCountersTitle) -> EpochData:
        trainStats = TrainingStats([k for k, v in self.widthList()])

        trainLogger = loggers.get(self.trainLoggerKey)
        if trainLogger:
            trainLogger.createDataTable(tableTitle, tableCols)

        nBatches = len(data_queue)

        for batchNum, (input, target) in enumerate(data_queue):
            startTime = time()

            input = input.cuda().clone().detach().requires_grad_(False)
            target = target.cuda(
                non_blocking=True).clone().detach().requires_grad_(False)

            # do forward
            forwardFunc(input, target, trainStats)

            endTime = time()

            if trainLogger:
                dataRow = {
                    self.batchNumKey: '{}/{}'.format(batchNum, nBatches),
                    self.timeKey: (endTime - startTime),
                    lossKey: trainStats.batchLoss(),
                    accKey: trainStats.prec1()
                }
                # apply formats
                self._applyFormats(dataRow)
                # add row to data table
                trainLogger.addDataRow(dataRow)

        epochLossDict = trainStats.epochLoss()
        epochAccDict = trainStats.top1()
        # # add epoch data to statistics plots
        # self.statistics.addBatchData(epochLossDict, epochAccDict)
        # log accuracy, loss, etc.
        summaryData = {
            lossKey: epochLossDict,
            accKey: epochAccDict,
            self.batchNumKey: self.summaryKey
        }
        # apply formats
        self._applyFormats(summaryData)

        for logger in loggers.values():
            if logger:
                logger.addSummaryDataRow(summaryData)

        # log forward counters. if loggerFuncs == [] it just resets the counters
        func = [
            lambda rows: trainLogger.addInfoTable(title=forwardCountersTitle,
                                                  rows=rows)
        ] if trainLogger else []
        self.getModel().logForwardCounters(loggerFuncs=func)

        return EpochData(epochLossDict, epochAccDict, summaryData)

    def _slimForward(self, input, target, trainStats):
        model = self.getModel()
        modelParallel = self.getModelParallel()
        crit = self.cross_entropy
        # init loss list
        lossList = []
        # iterate & forward widths
        for widthRatio, idxList in self.widthList():
            # set model layers current width index
            model.setCurrWidthIdx(idxList)
            # forward
            logits = modelParallel(input)
            # calc loss
            loss = crit(logits, target)
            # add to loss list
            lossList.append(loss)
            # update training stats
            trainStats.update(widthRatio, logits, target, loss)

        return lossList

    # performs single epoch of model weights training
    def weightsEpoch(self, optimizer, epoch, loggers) -> EpochData:
        # print('*** weightsEpoch() ***')
        model = self.getModel()
        modelParallel = self.getModelParallel()

        modelParallel.train()
        assert (model.training is True)

        def forwardFunc(input, target, trainStats):
            # optimize model weights
            optimizer.zero_grad()
            # forward
            lossList = self._slimForward(input, target, trainStats)
            # back propagate
            for loss in lossList:
                loss.backward()
            # update weights
            optimizer.step()

        tableTitle = 'Epoch:[{}] - Training weights'.format(epoch)
        forwardCountersTitle = '{} - Training'.format(self.forwardCountersKey)
        return self._genericEpoch(forwardFunc, self.getTrainQueue(), loggers,
                                  self.trainLossKey, self.trainAccKey,
                                  tableTitle, self.colsTrainWeights,
                                  forwardCountersTitle)

    # performs single epoch of model inference
    def inferEpoch(self, nEpoch, loggers) -> EpochData:
        print('*** inferEpoch() ***')
        model = self.getModel()
        modelParallel = self.getModelParallel()

        modelParallel.eval()
        assert (model.training is False)

        def forwardFunc(input, target, trainStats):
            with no_grad():
                self._slimForward(input, target, trainStats)

        tableTitle = 'Epoch:[{}] - Validation'.format(nEpoch)
        forwardCountersTitle = '{} - Validation'.format(
            self.forwardCountersKey)
        return self._genericEpoch(forwardFunc, self.getValidQueue(), loggers,
                                  self.validLossKey, self.validAccKey,
                                  tableTitle, self.colsValidation,
                                  forwardCountersTitle)

    def _initOptimizer(self):
        modelParallel = self.getModelParallel()
        args = self.getArgs()

        optimizer = SGD(modelParallel.parameters(),
                        args.learning_rate,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)
        # load optimizer pre-trained state dict if exists
        if self.optimizerStateDict:
            optimizer.load_state_dict(self.optimizerStateDict)

        return optimizer

    def train(self, trainFolderName='init_weights_train'):
        args = self.getArgs()

        # create train folder
        folderPath = '{}/{}'.format(self.getTrainFolderPath(), trainFolderName)
        if not exists(folderPath):
            makedirs(folderPath)

        # init optimizer
        optimizer = self._initOptimizer()
        # init scheduler
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode='min',
                                      factor=0.95,
                                      patience=args.weights_patience,
                                      min_lr=args.learning_rate_min)

        epoch = 0
        trainLoggerFlag = True

        while not self.stopCondition(epoch):
            # update epoch number
            epoch += 1
            # init train logger
            trainLogger = None
            if trainLoggerFlag:
                trainLogger = HtmlLogger(folderPath, epoch)
                trainLogger.addInfoTable('Learning rates', [[
                    'optimizer_lr', self.formats[self.lrKey](
                        optimizer.param_groups[0]['lr'])
                ]])

            # update train logger condition for next epoch
            trainLoggerFlag = ((epoch + 1) % args.logInterval) == 0

            # set loggers dictionary
            loggersDict = {self.trainLoggerKey: trainLogger}

            print('========== Epoch:[{}] =============='.format(epoch))
            # train
            trainData = self.weightsEpoch(optimizer, epoch, loggersDict)
            # validation
            validData = self.inferEpoch(epoch, loggersDict)

            # update scheduler
            scheduler.step(self.schedulerMetric(validData.lossDict()))

            self.postEpoch(epoch, optimizer, trainData, validData)

        self.postTrain()

    @staticmethod
    def loadPreTrained(model: BaseNet, path: str, logger: HtmlLogger) -> dict:
        optimizerStateDict = None

        if path is not None:
            if exists(path):
                # load checkpoint
                checkpoint = loadModel(
                    path, map_location=lambda storage, loc: storage.cuda())
                # load weights
                model.loadPreTrained(checkpoint['state_dict'])
                # # load optimizer state dict
                # optimizerStateDict = checkpoint['optimizer']
                # add info rows about checkpoint
                loggerRows = []
                loggerRows.append(['Path', '{}'.format(path)])
                validationAccRows = [['Ratio', 'Accuracy']] + HtmlLogger.dictToRows(
                    checkpoint['best_prec1'], nElementPerRow=1)
                loggerRows.append(['Validation accuracy', validationAccRows])
                # set optimizer table row
                optimizerRow = HtmlLogger.dictToRows(
                    optimizerStateDict, nElementPerRow=3
                ) if optimizerStateDict else optimizerStateDict
                loggerRows.append(['Optimizer', optimizerRow])
                logger.addInfoTable('Pre-trained model', loggerRows)
            else:
                raise ValueError(
                    'Failed to load pre-trained from [{}], path does not exist'
                    .format(path))

        return optimizerStateDict
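
TrainWeights is abstract; the subclass below is a hypothetical sketch (all names invented for illustration) showing which hooks a concrete trainer overrides before calling train(). The seven getter callables expected by __init__ would be supplied by the surrounding regime object, as in the regime examples above:

class FixedEpochsTrainWeights(TrainWeights):
    # hypothetical trainer: stop after a fixed number of epochs
    nEpochs = 100

    def stopCondition(self, epoch):
        return epoch >= self.nEpochs

    def widthList(self):
        # hypothetical: a single (widthRatio, idxList) pair selecting width index 0 in every layer
        return [(1.0, [0] * len(self.getModel().layersList()))]

    def schedulerMetric(self, validLoss):
        # drive ReduceLROnPlateau with the sum of the validation losses
        return sum(validLoss.values())

    def postEpoch(self, epoch, optimizer, trainData, validData):
        pass

    def postTrain(self):
        pass
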
Example #8
from datetime import datetime

# assumed alias: nprandom.seed(...) below matches numpy.random.seed
from numpy import random as nprandom

import torch.backends.cudnn as cudnn
from torch.cuda import is_available, set_device
from torch.cuda import manual_seed as cuda_manual_seed
from torch import manual_seed as torch_manual_seed

from trainRegimes.PreTrainedRegime import PreTrainedRegime
from utils.HtmlLogger import HtmlLogger
from utils.emails import sendEmail
from utils.args import parseArgs

if __name__ == '__main__':
    # load command line arguments
    args = parseArgs()
    # init main logger
    logger = HtmlLogger(args.save, 'log')

    if not is_available():
        print('no gpu device available')
        exit(1)

    args.seed = datetime.now().microsecond
    nprandom.seed(args.seed)
    set_device(args.gpu[0])
    cudnn.benchmark = True
    torch_manual_seed(args.seed)
    cudnn.enabled = True
    cuda_manual_seed(args.seed)

    try:
        # build regime for alphas optimization
Example #9
    def printToFile(self, saveFolder):
        fileName = 'model'

        logger = HtmlLogger(saveFolder, fileName)
        if exists(logger.fullPath):
            return

        logger.setMaxTableCellLength(1000)

        layerIdxKey = 'Layer#'
        nFiltersKey = 'Filters#'
        widthsKey = 'Width'
        layerArchKey = 'Layer Architecture'

        logger.createDataTable(
            'Model architecture',
            [layerIdxKey, nFiltersKey, widthsKey, layerArchKey])
        for layerIdx, layer in enumerate(self._layers.flops()):
            widths = layer.widthList()

            dataRow = {
                layerIdxKey: layerIdx,
                nFiltersKey: layer.outputChannels(),
                widthsKey: [widths],
                layerArchKey: layer
            }
            logger.addDataRow(dataRow)

        layerIdx += 1
        # log additional layers, like Linear, MaxPool2d, AvgPool2d
        for layer in self.additionalLayersToLog():
            dataRow = {layerIdxKey: layerIdx, layerArchKey: layer}
            logger.addDataRow(dataRow)
            layerIdx += 1

        # log layers alphas distribution
        self.logTopAlphas(len(widths),
                          loggerFuncs=[
                              lambda k, rows: logger.addInfoTable(
                                  self._alphasDistributionKey, rows)
                          ],
                          logLayer=True)
        # reset table max cell length
        logger.resetMaxTableCellLength()
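
A minimal call sketch (assuming model is the BaseNet instance exposing this method and args.save is the experiment folder, as in the training examples above); the method returns early if the HTML log file already exists:

model.printToFile(args.save)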