Ejemplo n.º 1
0
#\\\ Determine processing unit:
# The first CUDA device is used only when it was requested (useGPU) and torch
# reports that CUDA is available; otherwise everything runs on the CPU.
device = 'cuda:0' if useGPU and torch.cuda.is_available() else 'cpu'
if device == 'cuda:0':
    # Release cached GPU memory before starting.
    torch.cuda.empty_cache()
# Notify:
if doPrint:
    print("Device selected: %s" % device)

#\\\ Logging options
if doLogging:
    # The visualizer is imported only when logging is on. It is initialized
    # to write under the 'logsTB' subfolder of saveDir.
    from Utils.visualTools import Visualizer
    logsTB = os.path.join(saveDir, 'logsTB')
    logger = Visualizer(logsTB, name='visualResults')

#\\\ Save variables during evaluation.
# One dictionary per kind of checkpoint (best / last). Each key is a model in
# modelList and each value is a list with one slot per data split realization
# (nDataSplits slots, all initialized to None). As the original notes state,
# these lists are later converted to numpy to compute the mean and standard
# deviation across the split dimension.
accBest, accLast = {}, {}  # Accuracy for the best / last model
for thisModel in modelList:
    # Independent lists for each dictionary and each model.
    accBest[thisModel] = [None for _ in range(nDataSplits)]
    accLast[thisModel] = [None for _ in range(nDataSplits)]

####################
# TRAINING OPTIONS #
Ejemplo n.º 2
0
    def train(self, data, nEpochs, batchSize, **kwargs):
        """
        Train self.archit with mini-batches, tracking the best validation
        score, and reload the best model once training is over.

        Input:
            data: dataset object. This method reads ``data.nTrain`` and calls
                ``data.getSamples('train', indices)``,
                ``data.getSamples('valid')`` and ``data.get_results(yHat, y)``
            nEpochs (int): maximum number of passes over the training set; if
                0, no training happens and the current model is saved as both
                'Best' and 'Last'
            batchSize (int): number of training samples per batch (the last
                batch is smaller if the training set does not split evenly)

        Optional keyword arguments:
            doLogging (bool, default: False): log training/validation values
                through Utils.visualTools.Visualizer
            doSaveVars (bool, default: True): save the loss/evaluation curves
                to 'trainVars.pkl' and 'trainVars.mat'
            printInterval (int): print every printInterval training steps; a
                non-positive value disables printing (default: a fifth of the
                number of full batches per epoch, at least 1)
            learningRateDecayRate, learningRateDecayPeriod: if BOTH are given,
                a torch.optim.lr_scheduler.StepLR is applied to self.optim
            validationInterval (int): run validation every validationInterval
                training steps (default: number of full batches per epoch, at
                least 1)
            earlyStoppingLag (int): if given, stop after this many validation
                rounds without improving on the best score

        Attributes set:
            self.trainingOptions (dict): the chosen options
            self.bestEpoch, self.bestBatch: zero-based position of the best
                validation score (None if no validation took place)
            self.lossTrain, self.evalTrain, self.lossValid, self.evalValid
                (np.array): only if doSaveVars is True and nEpochs > 0
        """

        ####################################
        # ARGUMENTS (Store chosen options) #
        ####################################

        # Training Options:
        doLogging = kwargs.get('doLogging', False)
        doSaveVars = kwargs.get('doSaveVars', True)

        if 'printInterval' in kwargs:
            printInterval = kwargs['printInterval']
            doPrint = printInterval > 0
        else:
            doPrint = True
            # At least 1: the interval is used as a modulus, and with fewer
            # than five batches per epoch the integer division yields 0.
            printInterval = max(1, (data.nTrain // batchSize) // 5)

        # Learning rate decay requires both the rate and the period.
        doLearningRateDecay = 'learningRateDecayRate' in kwargs \
                                    and 'learningRateDecayPeriod' in kwargs
        if doLearningRateDecay:
            learningRateDecayRate = kwargs['learningRateDecayRate']
            learningRateDecayPeriod = kwargs['learningRateDecayPeriod']

        if 'validationInterval' in kwargs:
            validationInterval = kwargs['validationInterval']
        else:
            # At least 1: it is used as a modulus, and it would be 0 whenever
            # the training set is smaller than one batch.
            validationInterval = max(1, data.nTrain // batchSize)

        doEarlyStopping = 'earlyStoppingLag' in kwargs
        earlyStoppingLag = kwargs['earlyStoppingLag'] if doEarlyStopping else 0

        # No training case:
        if nEpochs == 0:
            # If there's no training happening, there's nothing to report about
            # training losses and stuff.
            doSaveVars = False
            doLogging = False

        # The logger only exists when logging is on; keep an explicit None
        # otherwise so that it can always be stored in the options below.
        logger = None
        if doLogging:
            from Utils.visualTools import Visualizer
            logsTB = os.path.join(self.saveDir, self.name + '-logsTB')
            logger = Visualizer(logsTB, name='visualResults')

        ###########################################
        # DATA INPUT (pick up on data parameters) #
        ###########################################

        nTrain = data.nTrain  # size of the training set

        # From here on, batchSize is a list of length nBatches, where each
        # element is the number of samples in the corresponding batch. If the
        # batches do not split the training set evenly, the last batch holds
        # the remaining samples.
        if nTrain < batchSize:
            nBatches = 1
            batchSize = [nTrain]
        else:
            nBatches = int((nTrain + batchSize - 1) // batchSize)  # ceiling
            lastBatchSize = nTrain - batchSize * (nBatches - 1)
            batchSize = [batchSize] * (nBatches - 1) + [lastBatchSize]
        # batchIndex is used to determine the first and last element of each
        # batch: batchIndex[b]:batchIndex[b+1] gives the samples for batch b.
        # For example, batchSize = [20,20,20] gives batchIndex = [0,20,40,60].
        batchIndex = [0] + np.cumsum(batchSize).tolist()

        ###################
        # SAVE ATTRIBUTES #
        ###################

        self.trainingOptions = {}
        self.trainingOptions['doLogging'] = doLogging
        self.trainingOptions['logger'] = logger
        self.trainingOptions['doSaveVars'] = doSaveVars
        self.trainingOptions['doPrint'] = doPrint
        self.trainingOptions['printInterval'] = printInterval
        self.trainingOptions['doLearningRateDecay'] = doLearningRateDecay
        if doLearningRateDecay:
            self.trainingOptions['learningRateDecayRate'] = \
                                                         learningRateDecayRate
            self.trainingOptions['learningRateDecayPeriod'] = \
                                                         learningRateDecayPeriod
        self.trainingOptions['validationInterval'] = validationInterval
        self.trainingOptions['doEarlyStopping'] = doEarlyStopping
        self.trainingOptions['earlyStoppingLag'] = earlyStoppingLag

        ##############
        # TRAINING   #
        ##############

        # Learning rate scheduler:
        if doLearningRateDecay:
            learningRateScheduler = torch.optim.lr_scheduler.StepLR(
                self.optim, learningRateDecayPeriod, learningRateDecayRate)

        # Initialize counters (since we give the possibility of early stopping,
        # we had to drop the 'for' and use a 'while' instead):
        epoch = 0  # epoch counter
        lagCount = 0  # lag counter for early stopping

        # Best validation result so far (None until the first validation).
        bestScore = None
        bestEpoch = None
        bestBatch = None
        # True while the only "best" recorded is the very first validation, in
        # which case the early stopping lag is not counted yet.
        initialBest = True

        if doSaveVars:
            lossTrain = []
            evalTrain = []
            lossValid = []
            evalValid = []

        # Both loops stop on the epoch/batch count and, only if early stopping
        # is on, also when the lag count reaches earlyStoppingLag. The
        # 'or (not doEarlyStopping)' makes the lag condition always true when
        # early stopping is off, so that only the counters decide.
        while epoch < nEpochs \
                    and (lagCount < earlyStoppingLag or (not doEarlyStopping)):

            # Randomize dataset for each epoch
            randomPermutation = np.random.permutation(nTrain)
            # Convert a numpy.array of numpy.int into a list of actual int.
            idxEpoch = [int(i) for i in randomPermutation]

            # Learning decay
            if doLearningRateDecay and doPrint:
                # Report the learning rate used during this epoch, read from
                # the first parameter group of the optimizer.
                # TODO: Other parameter groups might have a different
                # learning rate, so all of them should be printed.
                print("Epoch %d, learning rate = %.8f" %
                      (epoch + 1, self.optim.param_groups[0]['lr']))

            # Initialize counter
            batch = 0  # batch counter
            while batch < nBatches \
                        and (lagCount < earlyStoppingLag
                             or (not doEarlyStopping)):

                # Global training step (counts batches across epochs)
                step = epoch * nBatches + batch

                # Extract the adequate batch
                thisBatchIndices = idxEpoch[batchIndex[batch]:
                                            batchIndex[batch + 1]]
                # Get the samples
                xTrain, yTrain = data.getSamples('train', thisBatchIndices)
                xTrain = xTrain.unsqueeze(1)  # To account for just F=1 feature

                # Set the ordering
                xTrainOrdered = xTrain[:, :, self.order]  # B x F x N

                # Reset gradients
                self.archit.zero_grad()

                # Obtain the output of the GNN
                yHatTrain = self.archit(xTrainOrdered)

                # Compute loss
                lossValueTrain = self.loss(yHatTrain, yTrain.type(torch.int64))

                # Compute gradients
                lossValueTrain.backward()

                # Optimize
                self.optim.step()

                # Compute the accuracy on a tensor detached from the graph, so
                # that no gradient operation is taken into account here.
                accTrain = data.get_results(yHatTrain.detach(), yTrain)

                thisLossTrain = lossValueTrain.item()
                thisEvalTrain = accTrain * 100

                # Save values
                if doSaveVars:
                    lossTrain += [thisLossTrain]
                    evalTrain += [thisEvalTrain]

                # Print:
                if doPrint and step % printInterval == 0:
                    print("(E: %2d, B: %3d) %6.4f / %6.2f%%" %
                          (epoch + 1, batch + 1, thisLossTrain,
                           thisEvalTrain))

                #\\\\\\\
                #\\\ TB LOGGING (for each batch)
                #\\\\\\\

                if doLogging:
                    logger.scalar_summary(mode='Training',
                                          epoch=step,
                                          **{
                                              'lossTrain': thisLossTrain,
                                              'evalTrain': thisEvalTrain
                                          })

                #\\\\\\\
                #\\\ VALIDATION
                #\\\\\\\

                if step % validationInterval == 0:
                    # Validation:
                    xValid, yValid = data.getSamples('valid')
                    xValid = xValid.unsqueeze(1)  # Add the F dimension: BxFxN

                    # Set the ordering
                    xValidOrdered = xValid[:, :, self.order]  # BxFxN
                    # Under torch.no_grad() so that the computations carried out
                    # to obtain the validation accuracy are not taken into
                    # account to update the learnable parameters.
                    with torch.no_grad():
                        # Obtain the output of the GNN
                        yHatValid = self.archit(xValidOrdered)

                        # Compute loss
                        lossValueValid = self.loss(yHatValid,
                                                   yValid.type(torch.int64))

                        # Compute accuracy:
                        accValid = data.get_results(yHatValid, yValid)

                    thisLossValid = lossValueValid.item()
                    thisEvalValid = accValid * 100

                    # Save values
                    if doSaveVars:
                        lossValid += [thisLossValid]
                        evalValid += [thisEvalValid]

                    # Print:
                    if doPrint:
                        print("[VALIDATION] %6.4f / %6.2f%%" %
                              (thisLossValid, thisEvalValid))

                    if doLogging:
                        logger.scalar_summary(mode='Validation',
                                              epoch=step,
                                              **{
                                                  'lossValid': thisLossValid,
                                                  'evalValid': thisEvalValid
                                              })

                    if bestScore is None:
                        # No previous best option, so let's record the first
                        # trial as the best option (this is the validation of
                        # the first batch of the first epoch, since step 0
                        # always triggers a validation).
                        bestScore = accValid
                        bestEpoch, bestBatch = epoch, batch
                        # Save this model as the best (so far)
                        self.save(label='Best')
                        initialBest = True
                    elif accValid > bestScore:
                        bestScore = accValid
                        bestEpoch, bestBatch = epoch, batch
                        if doPrint:
                            print("\t=> New best achieved: %.4f" % \
                                      (bestScore))
                        self.save(label='Best')
                        # Now that we have found a best that is not the
                        # initial one, we can start counting the lag (if
                        # needed)
                        initialBest = False
                        # If we achieved a new best, then we need to reset
                        # the lag count.
                        if doEarlyStopping:
                            lagCount = 0
                    elif doEarlyStopping and not initialBest:
                        # If we didn't achieve a new best, increase the lag
                        # count. Unless the only best is the initial one, in
                        # which case we haven't found any best yet, so we
                        # shouldn't be doing the early stopping count.
                        lagCount += 1

                #\\\\\\\
                #\\\ END OF BATCH:
                #\\\\\\\

                #\\\ Increase batch count:
                batch += 1

            #\\\\\\\
            #\\\ END OF EPOCH:
            #\\\\\\\

            #\\\ Update the learning rate:
            # The scheduler is stepped once per epoch AFTER the optimizer
            # steps of that epoch, so that the first learningRateDecayPeriod
            # epochs run at the initial learning rate.
            if doLearningRateDecay:
                learningRateScheduler.step()

            #\\\ Save models:
            self.save(label='Last')

            #\\\ Increase epoch count:
            epoch += 1

        #################
        # TRAINING OVER #
        #################

        if doSaveVars:
            # We convert the lists into np.arrays to be handled by both
            # Matlab(R) and matplotlib
            self.lossTrain = np.array(lossTrain)
            self.evalTrain = np.array(evalTrain)
            self.lossValid = np.array(lossValid)
            self.evalValid = np.array(evalValid)
            # And we would like to save all the relevant information from
            # training
            saveDirVars = os.path.join(self.saveDir, self.name + '-trainVars')
            os.makedirs(saveDirVars, exist_ok=True)
            pathToFile = os.path.join(saveDirVars, 'trainVars.pkl')
            with open(pathToFile, 'wb') as trainVarsFile:
                pickle.dump(
                    {
                        'nEpochs': nEpochs,
                        'nBatches': nBatches,
                        'batchSize': np.array(batchSize),
                        'batchIndex': np.array(batchIndex),
                        'lossTrain': lossTrain,
                        'evalTrain': evalTrain,
                        'lossValid': lossValid,
                        'evalValid': evalValid
                    }, trainVarsFile)
            # And because of the SP background, why not save it in matlab too?
            pathToMat = os.path.join(saveDirVars, 'trainVars.mat')
            varsMatlab = {}
            varsMatlab['nEpochs'] = nEpochs
            varsMatlab['nBatches'] = nBatches
            varsMatlab['batchSize'] = np.array(batchSize)
            varsMatlab['batchIndex'] = np.array(batchIndex)
            varsMatlab['lossTrain'] = self.lossTrain
            varsMatlab['evalTrain'] = self.evalTrain
            varsMatlab['lossValid'] = self.lossValid
            varsMatlab['evalValid'] = self.evalValid
            savemat(pathToMat, varsMatlab)

        # Once the training is over, save the last model again if the stop
        # was due to early stopping.
        if doEarlyStopping and lagCount == earlyStoppingLag and nEpochs > 0:
            #\\\ Save models:
            self.save(label='Last')

        # Now, if we didn't do any training (i.e. nEpochs = 0), then the last is
        # also the best.
        if nEpochs == 0:
            self.save(label='Best')
            self.save(label='Last')
            if doPrint:
                print(
                    "WARNING: No training. Best and Last models are the same.")

        # After training is done, reload best model before proceeding to
        # evaluation:
        self.load(label='Best')
        # These remain None if no validation ever took place (e.g. nEpochs=0).
        self.bestBatch = bestBatch
        self.bestEpoch = bestEpoch

        #\\\ Print out best:
        if doPrint and nEpochs > 0 and bestScore is not None:
            print("=> Best validation achieved (E: %d, B: %d): %.4f" %
                  (bestEpoch + 1, bestBatch + 1, bestScore))