Esempio n. 1
0
                print("Start evaluate: ", PATH_TO_DATA_SET_CATELOG,
                      ", and find the best threshold...")
                loss, frameAccuracy, threshold, videoAccuracy = evaluator.Evaluate(
                    session, currentEpoch_=0, threshold_=None)
                endEvaluateTime = time.time()
                PrintResults(loss_=loss,
                             frameAccuracy_=frameAccuracy,
                             isThresholdOptimized_=True,
                             threshold_=threshold,
                             videoAccuracy_=videoAccuracy,
                             duration_=(endEvaluateTime - startEvaluateTime))

            else:
                threshold = int(sys.argv[2])
                print("Start evaluate: ", PATH_TO_DATA_SET_CATELOG,
                      ", with threshold : ", threshold)
                loss, frameAccuracy, threshold, videoAccuracy = evaluator.Evaluate(
                    session, currentEpoch_=0, threshold_=threshold)
                endEvaluateTime = time.time()
                PrintResults(loss_=loss,
                             frameAccuracy_=frameAccuracy,
                             isThresholdOptimized_=False,
                             threshold_=threshold,
                             videoAccuracy_=videoAccuracy,
                             duration_=(endEvaluateTime - startEvaluateTime))

        evaluator.Release()

    else:
        PrintHelp()
Esempio n. 2
0
class Main:
    """Train a Classifier and evaluate it on the train/validation/test sets.

    The constructor builds the classifier graph, creates one Trainer and three
    Evaluators (train, validation, test), attaches the merged summary op to
    them, creates a checkpoint Saver and initializes all global variables in a
    new tf.Session.  Run() executes the training loop.
    """

    def __init__(self):
        classifier = Classifier()
        classifier.Build()

        # Trainer, Evaluator
        print("Reading Training set...")
        self.trainer = Trainer(classifier)
        self.trainEvaluator = Evaluator("train",
                                        dataSettings.PATH_TO_TRAIN_SET_CATELOG,
                                        classifier)
        print("\t Done.\n")

        print("Reading Validation set...")
        self.validationEvaluator = Evaluator(
            "validation", dataSettings.PATH_TO_VAL_SET_CATELOG, classifier)
        print("\t Done.\n")

        print("Reading Test set...")
        self.testEvaluator = Evaluator("test",
                                       dataSettings.PATH_TO_TEST_SET_CATELOG,
                                       classifier)
        print("\t Done.\n")

        # Threshold chosen on the validation set and reused (as a given,
        # non-optimized threshold) when evaluating the train and test sets.
        self.bestThreshold = None

        # Summary: one merged op shared by the trainer and every evaluator.
        summaryOp = tf.summary.merge_all()
        for component in (self.trainer, self.trainEvaluator,
                          self.validationEvaluator, self.testEvaluator):
            component.SetMergedSummaryOp(summaryOp)

        # Time
        self.resetTimeMeasureVariables()

        # Saver
        self.modelSaver = tf.train.Saver(
            max_to_keep=trainSettings.MAX_TRAINING_SAVE_MODEL)

        # Session
        self.session = tf.Session()
        init = tf.global_variables_initializer()
        self.session.run(init)

        # NOTE(review): only the trainer and the validation evaluator receive
        # the graph; confirm the train/test evaluators do not need it.
        self.trainer.SetGraph(self.session.graph)
        self.validationEvaluator.SetGraph(self.session.graph)

    def __del__(self):
        # __init__ may have failed (or been bypassed) before the session was
        # created; closing must not raise in that case.
        session = getattr(self, "session", None)
        if session is not None:
            session.close()

    def Run(self):
        """Run the full training loop until MAX_TRAINING_EPOCH is reached.

        At every epoch boundary the validation, training and test sets are
        evaluated and (from EPOCHS_TO_START_SAVE_MODEL on) a checkpoint is
        saved.  The trainer and all evaluators are released even if training
        raises.
        """
        try:
            self.recoverFromPretrainModelIfRequired()

            self.calculateValidationBeforeTraining()
            self.resetTimeMeasureVariables()

            print("Path to save mode: ", trainSettings.PATH_TO_SAVE_MODEL)
            print("\nStart Training...\n")

            while self.trainer.currentEpoch < trainSettings.MAX_TRAINING_EPOCH:
                self.trainer.PrepareNewBatchData()
                self.trainer.Train(self.session)
                self._trainCountInOneEpoch += 1

                if self.trainer.isNewEpoch:
                    self._onEpochFinished(self.trainer.currentEpoch)
            print("Optimization finished.")
        finally:
            self.trainer.Release()
            self.trainEvaluator.Release()
            self.validationEvaluator.Release()
            self.testEvaluator.Release()

    def _onEpochFinished(self, currentEpoch_):
        """Report timing, evaluate every set and maybe checkpoint after an epoch."""
        print(
            "Epoch:", currentEpoch_,
            "======================================" +
            "======================================" +
            "======================================")

        self.printTimeMeasurement()
        self.trainer.PauseDataLoading()

        self.evaluateValidationSetAndPrint(currentEpoch_)
        self.evaluateTrainingSetAndPrint(currentEpoch_)

        # Resume data loading exactly once per epoch.  With augmentation the
        # next training batches are preloaded while the test set is evaluated;
        # otherwise loading resumes after the test evaluation.
        if trainSettings.PERFORM_DATA_AUGMENTATION:
            self.trainer.ContinueDataLoading()
            self.evaluateTestSetAndPrint(currentEpoch_)
        else:
            self.evaluateTestSetAndPrint(currentEpoch_)
            self.trainer.ContinueDataLoading()

        self.resetTimeMeasureVariables()

        if currentEpoch_ >= trainSettings.EPOCHS_TO_START_SAVE_MODEL:
            self.saveCheckpoint(currentEpoch_)

    def recoverFromPretrainModelIfRequired(self):
        """Restore variables from PRETRAIN_MODEL_PATH_NAME when it is set.

        Variables whose top-level name scope is listed in
        NAME_SCOPES_NOT_TO_RECOVER_FROM_CHECKPOINT keep their initialized
        values instead of being restored.
        """
        pretrainModelPath = trainSettings.PRETRAIN_MODEL_PATH_NAME
        if pretrainModelPath == "":
            return

        print("Load Pretrain model from: " + pretrainModelPath)
        excludedScopes = trainSettings.NAME_SCOPES_NOT_TO_RECOVER_FROM_CHECKPOINT
        variablesToBeRecovered = [
            eachVariable
            for eachVariable in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            if eachVariable.name.split('/')[0] not in excludedScopes
        ]
        modelLoader = tf.train.Saver(variablesToBeRecovered)
        modelLoader.restore(self.session, pretrainModelPath)

    def _evaluateAndPrint(self, evaluator_, jobType_, currentEpoch_,
                          threshold_, isThresholdOptimized_):
        """Evaluate one data set, print its metrics and return the threshold used.

        Passing threshold_=None lets the evaluator search for the best
        threshold; otherwise the given threshold is applied.
        """
        startEvaluateTime = time.time()
        loss, frameAccuracy, threshold, videoAccuracy = evaluator_.Evaluate(
            self.session, currentEpoch_=currentEpoch_, threshold_=threshold_)
        endEvaluateTime = time.time()

        self.printCalculationResults(jobType_=jobType_,
                                     loss_=loss,
                                     frameAccuracy_=frameAccuracy,
                                     isThresholdOptimized_=isThresholdOptimized_,
                                     threshold_=threshold,
                                     videoAccuracy_=videoAccuracy,
                                     duration_=(endEvaluateTime -
                                                startEvaluateTime))
        return threshold

    def evaluateTrainingSetAndPrint(self, currentEpoch_):
        """Evaluate the whole training set with the validation-chosen threshold.

        BATCH_SIZE may be small (4 in the author's setup), so the per-batch
        loss/accuracy fluctuates; the loss over the whole training set is
        reported instead.  To compute only the batch loss, use
        Trainer.EvaluateTrainLoss().
        """
        self._evaluateAndPrint(self.trainEvaluator, 'train', currentEpoch_,
                               threshold_=self.bestThreshold,
                               isThresholdOptimized_=False)

    def calculateValidationBeforeTraining(self):
        """Evaluate the validation set once before training if a pretrained model was loaded."""
        if trainSettings.PRETRAIN_MODEL_PATH_NAME != "":
            print(
                "Validation before Training ",
                "=============================" +
                "======================================" +
                "======================================")
            self.evaluateValidationSetAndPrint(currentEpoch_=0)

    def evaluateValidationSetAndPrint(self, currentEpoch_):
        """Evaluate the validation set, letting the evaluator pick the best threshold.

        The threshold it returns is stored in self.bestThreshold for the
        subsequent train/test evaluations.
        """
        self.bestThreshold = self._evaluateAndPrint(
            self.validationEvaluator, 'validation', currentEpoch_,
            threshold_=None, isThresholdOptimized_=True)

    def evaluateTestSetAndPrint(self, currentEpoch_):
        """Evaluate the test set with the validation-chosen threshold."""
        self._evaluateAndPrint(self.testEvaluator, 'test', currentEpoch_,
                               threshold_=self.bestThreshold,
                               isThresholdOptimized_=False)

    def printTimeMeasurement(self):
        """Print the wall-clock duration of the current epoch and the per-batch average."""
        timeForTrainOneEpoch = time.time() - self._startTrainEpochTime
        print("\t Back Propergation time measurement:")
        print("\t\t duration: ", "{0:.2f}".format(timeForTrainOneEpoch),
              "s/epoch")
        # Guard against an epoch with no training steps (avoids ZeroDivisionError).
        if self._trainCountInOneEpoch > 0:
            averagedTrainTime = timeForTrainOneEpoch / self._trainCountInOneEpoch
            print("\t\t average: ", "{0:.2f}".format(averagedTrainTime),
                  "s/batch")
        print()

    def resetTimeMeasureVariables(self):
        """Restart the epoch timer and the per-epoch training-step counter."""
        self._startTrainEpochTime = time.time()
        self._trainCountInOneEpoch = 0

    def printCalculationResults(self, jobType_, loss_, frameAccuracy_,
                                isThresholdOptimized_, threshold_,
                                videoAccuracy_, duration_):
        """Print one evaluation's metrics.

        The threshold is labelled "best" when it was optimized on this set and
        "given" when it was supplied from elsewhere.
        """
        floatPrecision = "{0:.4f}"
        if isThresholdOptimized_:
            thresholdLabel = "     best frame threshold:"
        else:
            thresholdLabel = "     given frame threshold:"

        print("\t " + jobType_ + ":")
        print("\t     loss:",
              floatPrecision.format(loss_), "     frame accuracy:",
              floatPrecision.format(frameAccuracy_),
              thresholdLabel,
              threshold_, "     video accuracy:",
              floatPrecision.format(videoAccuracy_), "     duration:",
              "{0:.2f}".format(duration_) + "(s)\n")

    def saveCheckpoint(self, currentEpoch_):
        """Save the session to <PATH_TO_SAVE_MODEL>/save_epoch_<N>/ViolenceNet.ckpt."""
        pathToSaveCheckpoint = os.path.join(trainSettings.PATH_TO_SAVE_MODEL,
                                            "save_epoch_" + str(currentEpoch_))
        # The per-epoch directory is new for every epoch; create it up front
        # so the save does not depend on the saver creating parent dirs.
        os.makedirs(pathToSaveCheckpoint, exist_ok=True)
        checkpointPathFileName = os.path.join(pathToSaveCheckpoint,
                                              "ViolenceNet.ckpt")
        self.modelSaver.save(self.session, checkpointPathFileName)