Example #1
 def run_session(self, *args):
     (sess_device,
      model_params,) = args
     
     graphTf = tf.Graph()
     
     with graphTf.as_default():
         with graphTf.device(sess_device): # Throws an error if GPU is specified but not available.
             self._log.print3("=========== Making the CNN graph... ===============")
             cnn3d = Cnn3d()
             with tf.variable_scope("net"):
                 cnn3d.make_cnn_model( *model_params.get_args_for_arch() ) # Creates the network's graph (without optimizer).
                 
         self._log.print3("=========== Compiling the Testing Function ============")
         self._log.print3("=======================================================\n")
         
         cnn3d.setup_ops_n_feeds_to_test( self._log,
                                          self._params.indices_fms_per_pathtype_per_layer_to_save )
         # Create the saver
         saver_all = tf.train.Saver() # saver_net would suffice
         
     with tf.Session( graph=graphTf, config=tf.ConfigProto(log_device_placement=False, device_count={'CPU':999, 'GPU':99}) ) as sessionTf:
         file_to_load_params_from = self._params.get_path_to_load_model_from()
         if file_to_load_params_from is not None: # Load params
             self._log.print3("=========== Loading parameters from specified saved model ===============")
             chkpt_fname = tf.train.latest_checkpoint( file_to_load_params_from ) if os.path.isdir( file_to_load_params_from ) else file_to_load_params_from
             self._log.print3("Loading parameters from:" + str(chkpt_fname))
             try:
                 saver_all.restore(sessionTf, chkpt_fname)
                 self._log.print3("Parameters were loaded.")
             except Exception as e: handle_exception_tf_restore(self._log, e)
             
         else:
             self._ask_user_if_test_with_random() # Asks user whether to continue with randomly initialized model. It exits if no is given.
             self._log.print3("")
             self._log.print3("=========== Initializing network variables  ===============")
             tf.variables_initializer( var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net") ).run()
             self._log.print3("Model variables were initialized.")
             
             
         self._log.print3("")
         self._log.print3("======================================================")
         self._log.print3("=========== Testing with the CNN model ===============")
         self._log.print3("======================================================\n")
         
         performInferenceOnWholeVolumes( *( [sessionTf, cnn3d] + self._params.get_args_for_testing() ) )
     
     self._log.print3("")
     self._log.print3("======================================================")
     self._log.print3("=========== Testing session finished =================")
     self._log.print3("======================================================")
Example #2
def do_training(myLogger,
                fileToSaveTrainedCnnModelTo,
                cnn3dInstance,
                
                performValidationOnSamplesDuringTrainingProcessBool, #REQUIRED FOR AUTO SCHEDULE.
                savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation,
                
                listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice,
                
                listOfFilepathsToEachChannelOfEachPatientTraining,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                
                listOfFilepathsToGtLabelsOfEachPatientTraining,
                providedGtForValidationBool,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                
                providedWeightMapsToSampleForEachCategoryTraining,
                forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
                providedWeightMapsToSampleForEachCategoryValidation,
                forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
                
                providedRoiMaskForTrainingBool,
                listOfFilepathsToRoiMaskOfEachPatientTraining, # Also needed for normalization-augmentation
                providedRoiMaskForValidationBool,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                
                borrowFlag,
                n_epochs, # Every epoch the CNN model is saved.
                number_of_subepochs, # per epoch. Every subepoch Accuracy is reported
                maxNumSubjectsLoadedPerSubepoch,  # Max num of cases loaded every subepoch for segments extraction. The more, the longer loading.
                imagePartsLoadedInGpuPerSubepoch,
                imagePartsLoadedInGpuPerSubepochValidation,
                
                #-------Sampling Type---------
                samplingTypeInstanceTraining, # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                samplingTypeInstanceValidation,
                
                #-------Preprocessing-----------
                padInputImagesBool,
                smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage,
                #-------Data Augmentation-------
                normAugmNone0OnImages1OrSegments2AlreadyNormalized1SubtrUpToPropOfStdAndDivideWithUpToPerc,
                reflectImageWithHalfProbDuringTraining,
                
                useSameSubChannelsAsSingleScale,
                
                listOfFilepathsToEachSubsampledChannelOfEachPatientTraining, # deprecated, not supported
                listOfFilepathsToEachSubsampledChannelOfEachPatientValidation, # deprecated, not supported
                
                #Learning Rate Schedule:
                lowerLrByStable0orAuto1orPredefined2orExponential3Schedule,
                minIncreaseInValidationAccuracyConsideredForLrSchedule,
                numEpochsToWaitBeforeLowerLR,
                divideLrBy,
                lowerLrAtTheEndOfTheseEpochsPredefinedScheduleList,
                exponentialScheduleForLrAndMom,
                
                #Weighting Classes differently in the CNN's cost function during training:
                numberOfEpochsToWeightTheClassesInTheCostFunction,
                
                performFullInferenceOnValidationImagesEveryFewEpochsBool, #Even if not providedGtForValidationBool, inference will be performed if this == True, to save the results, e.g. for visual inspection.
                everyThatManyEpochsComputeDiceOnTheFullValidationImages=1, # Should not be == 0, except if performFullInferenceOnValidationImagesEveryFewEpochsBool == False
                
                #--------For FM visualisation---------
                saveIndividualFmImagesForVisualisation=False,
                saveMultidimensionalImageWithAllFms=False,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer="placeholder",
                listOfNamesToGiveToFmVisualisationsIfSaving="placeholder"
                ):
    
    start_training_time = time.clock()
    
    # Used because I cannot pass cnn3dInstance to the sampling function:
    # the parallel process would then load theano again, which creates problems on the GPU when cnmem is used.
    cnn3dWrapper = CnnWrapperForSampling(cnn3dInstance) 
    
    #---------To run PARALLEL the extraction of parts for the next subepoch---
    ppservers = () # tuple of all parallel python servers to connect with
    job_server = pp.Server(ncpus=1, ppservers=ppservers) # Creates jobserver with automatically detected number of workers
    
    tupleWithParametersForTraining = (myLogger,
                                    0,
                                    cnn3dWrapper,
                                    maxNumSubjectsLoadedPerSubepoch,
                                    
                                    imagePartsLoadedInGpuPerSubepoch,
                                    samplingTypeInstanceTraining,
                                    
                                    listOfFilepathsToEachChannelOfEachPatientTraining,
                                    
                                    listOfFilepathsToGtLabelsOfEachPatientTraining,
                                    
                                    providedRoiMaskForTrainingBool,
                                    listOfFilepathsToRoiMaskOfEachPatientTraining,
                                    
                                    providedWeightMapsToSampleForEachCategoryTraining,
                                    forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
                                    
                                    useSameSubChannelsAsSingleScale,
                                    
                                    listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,
                                    
                                    padInputImagesBool,
                                    smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage,
                                    normAugmNone0OnImages1OrSegments2AlreadyNormalized1SubtrUpToPropOfStdAndDivideWithUpToPerc,
                                    reflectImageWithHalfProbDuringTraining
                                    )
    tupleWithParametersForValidation = (myLogger,
                                    1,
                                    cnn3dWrapper,
                                    maxNumSubjectsLoadedPerSubepoch,
                                    
                                    imagePartsLoadedInGpuPerSubepochValidation,
                                    samplingTypeInstanceValidation,
                                    
                                    listOfFilepathsToEachChannelOfEachPatientValidation,
                                    
                                    listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                    
                                    providedRoiMaskForValidationBool,
                                    listOfFilepathsToRoiMaskOfEachPatientValidation,
                                    
                                    providedWeightMapsToSampleForEachCategoryValidation,
                                    forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
                                    
                                    useSameSubChannelsAsSingleScale,
                                    
                                    listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
                                    
                                    padInputImagesBool,
                                    smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage,
                                    [0, -1,-1,-1], #don't perform intensity-augmentation during validation.
                                    [0,0,0] #don't perform reflection-augmentation during validation.
                                    )
    tupleWithLocalFunctionsThatWillBeCalledByTheMainJob = ( )
    tupleWithModulesToImportWhichAreUsedByTheJobFunctions = ( "from __future__ import absolute_import, print_function, division", "from six.moves import xrange",
                "time", "numpy as np", "from deepmedic.dataManagement.sampling import *" )
    boolItIsTheVeryFirstSubepochOfThisProcess = True # To know that in the very first subepoch I must load the data sequentially.
    #------End for parallel------
    
    while cnn3dInstance.numberOfEpochsTrained < n_epochs :
        epoch = cnn3dInstance.numberOfEpochsTrained
        
        trainingAccuracyMonitorForEpoch = AccuracyOfEpochMonitorSegmentation(myLogger, 0, cnn3dInstance.numberOfEpochsTrained, cnn3dInstance.numberOfOutputClasses, number_of_subepochs)
        validationAccuracyMonitorForEpoch = None if not performValidationOnSamplesDuringTrainingProcessBool else \
                                        AccuracyOfEpochMonitorSegmentation(myLogger, 1, cnn3dInstance.numberOfEpochsTrained, cnn3dInstance.numberOfOutputClasses, number_of_subepochs ) 
                                        
        myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        myLogger.print3("~~~~~~~~~~~~~~~~~~~~Starting new Epoch! Epoch #"+str(epoch)+"/"+str(n_epochs)+" ~~~~~~~~~~~~~~~~~~~~~~~~~")
        myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        start_epoch_time = time.clock()
        
        for subepoch in xrange(number_of_subepochs): #per subepoch I randomly load some images in the gpu. Random order.
            myLogger.print3("**************************************************************************************************")
            myLogger.print3("************* Starting new Subepoch: #"+str(subepoch)+"/"+str(number_of_subepochs)+" *************")
            myLogger.print3("**************************************************************************************************")
            
            #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
            
            if performValidationOnSamplesDuringTrainingProcessBool :
                if boolItIsTheVeryFirstSubepochOfThisProcess :
                    [channsOfSegmentsForSubepPerPathwayVal,
                    labelsForCentralOfSegmentsForSubepVal] = getSampledDataAndLabelsForSubepoch(myLogger,
                                                                        1,
                                                                        cnn3dWrapper,
                                                                        maxNumSubjectsLoadedPerSubepoch,
                                                                        imagePartsLoadedInGpuPerSubepochValidation,
                                                                        samplingTypeInstanceValidation,
                                                                        
                                                                        listOfFilepathsToEachChannelOfEachPatientValidation,
                                                                        
                                                                        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                                                        
                                                                        providedRoiMaskForValidationBool,
                                                                        listOfFilepathsToRoiMaskOfEachPatientValidation,
                                                                        
                                                                        providedWeightMapsToSampleForEachCategoryValidation,
                                                                        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
                                                                        
                                                                        useSameSubChannelsAsSingleScale,
                                                                        
                                                                        listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
                                                                        
                                                                        padInputImagesBool,
                                                                        smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage,
                                                                        normAugmNone0OnImages1OrSegments2AlreadyNormalized1SubtrUpToPropOfStdAndDivideWithUpToPerc=[0,-1,-1,-1],
                                                                        reflectImageWithHalfProbDuringTraining = [0,0,0]
                                                                        )
                    boolItIsTheVeryFirstSubepochOfThisProcess = False
                else : #It was done in parallel with the training of the previous epoch, just grab the results...
                    [channsOfSegmentsForSubepPerPathwayVal,
                    labelsForCentralOfSegmentsForSubepVal] = parallelJobToGetDataForNextValidation() #fromParallelProcessing that had started from last loop when it was submitted.
                    
                #------------------------------LOAD DATA FOR VALIDATION----------------------
                myLogger.print3("Loading Validation data for subepoch #"+str(subepoch)+" on shared variable...")
                start_loadingToGpu_time = time.clock()
                
                numberOfBatchesValidation = len(channsOfSegmentsForSubepPerPathwayVal[0]) // cnn3dInstance.batchSizeValidation # Computed with the number of extracted samples, in case I don't manage to extract as many as I wanted initially.
                
                myLogger.print3("DEBUG: For Validation, loading to shared variable that many Segments: " + str(len(channsOfSegmentsForSubepPerPathwayVal[0])))
                
                cnn3dInstance.sharedInpXVal.set_value(channsOfSegmentsForSubepPerPathwayVal[0], borrow=borrowFlag) # Primary pathway
                for index in xrange(len(channsOfSegmentsForSubepPerPathwayVal[1:])) :
                    cnn3dInstance.sharedInpXPerSubsListVal[index].set_value(channsOfSegmentsForSubepPerPathwayVal[1+index], borrow=borrowFlag)
                cnn3dInstance.sharedLabelsYVal.set_value(labelsForCentralOfSegmentsForSubepVal, borrow=borrowFlag)
                channsOfSegmentsForSubepPerPathwayVal = ""
                labelsForCentralOfSegmentsForSubepVal = ""
                
                end_loadingToGpu_time = time.clock()
                myLogger.print3("TIMING: Loading sharedVariables for Validation in epoch|subepoch="+str(epoch)+"|"+str(subepoch)+" took time: "+str(end_loadingToGpu_time-start_loadingToGpu_time)+"(s)")
                
                
                #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                #submit the parallel job
                myLogger.print3("PARALLEL: Before Validation in subepoch #" +str(subepoch) + ", the parallel job for extracting Segments for the next Training is submitted.")
                parallelJobToGetDataForNextTraining = job_server.submit(getSampledDataAndLabelsForSubepoch, #local function to call and execute in parallel.
                                                                        tupleWithParametersForTraining, #tuple with the arguments required
                                                                        tupleWithLocalFunctionsThatWillBeCalledByTheMainJob, #tuple of local functions that I need to call
                                                                        tupleWithModulesToImportWhichAreUsedByTheJobFunctions) #tuple of the external modules that I need, of which I am calling functions (not the mods of the ext-functions).
                
                #------------------------------------DO VALIDATION--------------------------------
                myLogger.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-")
                start_validationForSubepoch_time = time.clock()
                
                train0orValidation1 = 1 #validation
                vectorWithWeightsOfTheClassesForCostFunctionOfTraining = 'placeholder' #only used in training
                
                doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(myLogger,
                                                                            train0orValidation1,
                                                                            numberOfBatchesValidation, # Computed by the number of extracted samples. So, adapts.
                                                                            cnn3dInstance,
                                                                            vectorWithWeightsOfTheClassesForCostFunctionOfTraining,
                                                                            subepoch,
                                                                            validationAccuracyMonitorForEpoch)
                cnn3dInstance.freeGpuValidationData()
                
                end_validationForSubepoch_time = time.clock()
                myLogger.print3("TIMING: Validating on the batches of this subepoch #" + str(subepoch) + " took time: "+str(end_validationForSubepoch_time-start_validationForSubepoch_time)+"(s)")
                
                #Update cnn's top achieved validation accuracy if needed: (for the autoReduction of Learning Rate.)
                cnn3dInstance.checkMeanValidationAccOfLastEpochAndUpdateCnnsTopAccAchievedIfNeeded(myLogger,
                                                                                    validationAccuracyMonitorForEpoch.getMeanEmpiricalAccuracyOfEpoch(),
                                                                                    minIncreaseInValidationAccuracyConsideredForLrSchedule)
            #-------------------END OF THE VALIDATION-DURING-TRAINING-LOOP-------------------------
            
            
            #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
            if (not performValidationOnSamplesDuringTrainingProcessBool) and boolItIsTheVeryFirstSubepochOfThisProcess :                    
                [channsOfSegmentsForSubepPerPathwayTrain,
                labelsForCentralOfSegmentsForSubepTrain] = getSampledDataAndLabelsForSubepoch(myLogger,
                                                                        0,
                                                                        cnn3dWrapper,
                                                                        maxNumSubjectsLoadedPerSubepoch,
                                                                        imagePartsLoadedInGpuPerSubepoch,
                                                                        samplingTypeInstanceTraining,
                                                                        
                                                                        listOfFilepathsToEachChannelOfEachPatientTraining,
                                                                        
                                                                        listOfFilepathsToGtLabelsOfEachPatientTraining,
                                                                        
                                                                        providedRoiMaskForTrainingBool,
                                                                        listOfFilepathsToRoiMaskOfEachPatientTraining,
                                                                        
                                                                        providedWeightMapsToSampleForEachCategoryTraining,
                                                                        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
                                                                        
                                                                        useSameSubChannelsAsSingleScale,
                                                                        
                                                                        listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,
                                                                        
                                                                        padInputImagesBool,
                                                                        smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage,
                                                                        normAugmNone0OnImages1OrSegments2AlreadyNormalized1SubtrUpToPropOfStdAndDivideWithUpToPerc,
                                                                        reflectImageWithHalfProbDuringTraining
                                                                        )
                boolItIsTheVeryFirstSubepochOfThisProcess = False
            else :
                #It was done in parallel with the validation (or with previous training iteration, in case I am not performing validation).
                [channsOfSegmentsForSubepPerPathwayTrain,
                labelsForCentralOfSegmentsForSubepTrain] = parallelJobToGetDataForNextTraining() #fromParallelProcessing that had started from last loop when it was submitted.
                
            #-------------------------COMPUTE CLASS-WEIGHTS, TO WEIGHT COST FUNCTION AND COUNTER CLASS IMBALANCE----------------------
            #Do it for only few epochs, until I get to an ok local minima neighbourhood.
            if cnn3dInstance.numberOfEpochsTrained < numberOfEpochsToWeightTheClassesInTheCostFunction :
                numOfPatchesInTheSubepoch_notParts = np.prod(labelsForCentralOfSegmentsForSubepTrain.shape)
                actualNumOfPatchesPerClassInTheSubepoch_notParts = np.bincount(np.ravel(labelsForCentralOfSegmentsForSubepTrain).astype(int))
                # yx - y1 = (x - x1) * (y2 - y1)/(x2 - x1)
                # yx = the multiplier I currently want, y1 = the multiplier at the beginning, y2 = the multiplier at the end
                # x = current epoch, x1 = epoch where linear decrease starts, x2 = epoch where linear decrease ends
                y1 = (1./(actualNumOfPatchesPerClassInTheSubepoch_notParts+TINY_FLOAT)) * (numOfPatchesInTheSubepoch_notParts*1.0/cnn3dInstance.numberOfOutputClasses)
                y2 = 1.
                x1 = 0. * number_of_subepochs # linear decrease starts from epoch=0
                x2 = numberOfEpochsToWeightTheClassesInTheCostFunction * number_of_subepochs
                x = cnn3dInstance.numberOfEpochsTrained * number_of_subepochs + subepoch
                yx = (x - x1) * (y2 - y1)/(x2 - x1) + y1
                vectorWithWeightsOfTheClassesForCostFunctionOfTraining = np.asarray(yx, dtype="float32")
                myLogger.print3("UPDATE: [Weight of Classes] Setting the weights of the classes in the cost function to: " +str(vectorWithWeightsOfTheClassesForCostFunctionOfTraining))
            else :
                vectorWithWeightsOfTheClassesForCostFunctionOfTraining = np.ones(cnn3dInstance.numberOfOutputClasses, dtype='float32')
                
            #------------------- Learning Rate Schedule ------------------------
            # I must make a learning-rate-manager to encapsulate all these... Very ugly currently... All other LR schedules are applied in the outer loop, per epoch.
            if (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 4) :
                myLogger.print3("DEBUG: Going to change Learning Rate according to POLY schedule:")
                #newLearningRate = initLr * ( 1 - iter/max_iter) ^ power. Power = 0.9 in parsenet, which we validated to behave ok.
                currentIteration = cnn3dInstance.numberOfEpochsTrained * number_of_subepochs + subepoch
                max_iterations = n_epochs * number_of_subepochs
                newLearningRate = cnn3dInstance.initialLearningRate * pow( 1.0 - 1.0*currentIteration/max_iterations , 0.9)
                myLogger.print3("DEBUG: new learning rate was calculated: " +str(newLearningRate))
                cnn3dInstance.change_learning_rate_of_a_cnn(newLearningRate, myLogger)
                
            #----------------------------------LOAD TRAINING DATA ON GPU-------------------------------
            myLogger.print3("Loading Training data for subepoch #"+str(subepoch)+" on shared variable...")
            start_loadingToGpu_time = time.clock()
            
            numberOfBatchesTraining = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // cnn3dInstance.batchSize # Computed with the number of extracted samples, in case I don't manage to extract as many as I wanted initially.
            
            cnn3dInstance.sharedInpXTrain.set_value(channsOfSegmentsForSubepPerPathwayTrain[0], borrow=borrowFlag) # Primary pathway
            for index in xrange(len(channsOfSegmentsForSubepPerPathwayTrain[1:])) :
                cnn3dInstance.sharedInpXPerSubsListTrain[index].set_value(channsOfSegmentsForSubepPerPathwayTrain[1+index], borrow=borrowFlag)
            cnn3dInstance.sharedLabelsYTrain.set_value(labelsForCentralOfSegmentsForSubepTrain, borrow=borrowFlag)
            channsOfSegmentsForSubepPerPathwayTrain = ""
            labelsForCentralOfSegmentsForSubepTrain = ""
            
            end_loadingToGpu_time = time.clock()
            myLogger.print3("TIMING: Loading sharedVariables for Training in epoch|subepoch="+str(epoch)+"|"+str(subepoch)+" took time: "+str(end_loadingToGpu_time-start_loadingToGpu_time)+"(s)")
            
            
            #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
            if performValidationOnSamplesDuringTrainingProcessBool :
                #submit the parallel job
                myLogger.print3("PARALLEL: Before Training in subepoch #" +str(subepoch) + ", submitting the parallel job for extracting Segments for the next Validation.")
                parallelJobToGetDataForNextValidation = job_server.submit(getSampledDataAndLabelsForSubepoch, #local function to call and execute in parallel.
                                                                            tupleWithParametersForValidation, #tuple with the arguments required
                                                                            tupleWithLocalFunctionsThatWillBeCalledByTheMainJob, #tuple of local functions that I need to call
                                                                            tupleWithModulesToImportWhichAreUsedByTheJobFunctions) #tuple of the external modules that I need, of which I am calling functions (not the mods of the ext-functions).
            else : #extract in parallel the samples for the next subepoch's training.
                myLogger.print3("PARALLEL: Before Training in subepoch #" +str(subepoch) + ", submitting the parallel job for extracting Segments for the next Training.")
                parallelJobToGetDataForNextTraining = job_server.submit(getSampledDataAndLabelsForSubepoch, #local function to call and execute in parallel.
                                                                            tupleWithParametersForTraining, #tuple with the arguments required
                                                                            tupleWithLocalFunctionsThatWillBeCalledByTheMainJob, #tuple of local functions that I need to call
                                                                            tupleWithModulesToImportWhichAreUsedByTheJobFunctions) #tuple of the external modules that I need, of which I am calling
                
            #-------------------------------START TRAINING IN BATCHES------------------------------
            myLogger.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-")
            start_trainingForSubepoch_time = time.clock()
            
            train0orValidation1 = 0 #training
            doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(myLogger,
                                                                        train0orValidation1,
                                                                        numberOfBatchesTraining,
                                                                        cnn3dInstance,
                                                                        vectorWithWeightsOfTheClassesForCostFunctionOfTraining,
                                                                        subepoch,
                                                                        trainingAccuracyMonitorForEpoch)
            cnn3dInstance.freeGpuTrainingData()
            
            end_trainingForSubepoch_time = time.clock()
            myLogger.print3("TIMING: Training on the batches of this subepoch #" + str(subepoch) + " took time: "+str(end_trainingForSubepoch_time-start_trainingForSubepoch_time)+"(s)")
            
        myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" )
        myLogger.print3("~~~~~~~~~~~~~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~~~~~~~~~~~~" )
        myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" )
        
        if performValidationOnSamplesDuringTrainingProcessBool :
            validationAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch()
        trainingAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch()
        
        del trainingAccuracyMonitorForEpoch; del validationAccuracyMonitorForEpoch;
        
        #=======================Learning Rate Schedule.=========================
        if (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 0) and (numEpochsToWaitBeforeLowerLR > 0) and (cnn3dInstance.numberOfEpochsTrained % numEpochsToWaitBeforeLowerLR)==0 :
            # STABLE LR SCHEDULE"
            myLogger.print3("DEBUG: Going to lower Learning Rate because of STABLE schedule! The CNN has now been trained for: " + str(cnn3dInstance.numberOfEpochsTrained) + " epochs. I need to decrease LR every: " + str(numEpochsToWaitBeforeLowerLR) + " epochs.")
            cnn3dInstance.divide_learning_rate_of_a_cnn_by(divideLrBy, myLogger)
        elif (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 1) and (numEpochsToWaitBeforeLowerLR > 0) :
            # AUTO LR SCHEDULE!
            if not performValidationOnSamplesDuringTrainingProcessBool : #This flag should have been set True from the start if training should do Auto-schedule. If we get in here, this is a bug.
                myLogger.print3("ERROR: For Auto-schedule I need to be performing validation-on-samples during the training-process. The flag performValidationOnSamplesDuringTrainingProcessBool should have been set to True. Instead it seems it was False and no validation was performed. This is a bug. Contact the developer, this should not have happened. Try another Learning Rate schedule for now! Exiting.")
                exit(1)
            if (cnn3dInstance.numberOfEpochsTrained >= cnn3dInstance.topMeanValidationAccuracyAchievedInEpoch[1] + numEpochsToWaitBeforeLowerLR) and \
                    (cnn3dInstance.numberOfEpochsTrained >= cnn3dInstance.lastEpochAtTheEndOfWhichLrWasLowered + numEpochsToWaitBeforeLowerLR) :
                myLogger.print3("DEBUG: Going to lower Learning Rate because of AUTO schedule! The CNN has now been trained for: " + str(cnn3dInstance.numberOfEpochsTrained) + " epochs. Epoch with last highest achieved validation accuracy: " + str(cnn3dInstance.topMeanValidationAccuracyAchievedInEpoch[1]) + ", and epoch that Learning Rate was last lowered: " + str(cnn3dInstance.lastEpochAtTheEndOfWhichLrWasLowered) + ". I waited for increase in accuracy for: " +str(numEpochsToWaitBeforeLowerLR) + " epochs. Going to lower Learning Rate...")
                cnn3dInstance.divide_learning_rate_of_a_cnn_by(divideLrBy, myLogger)
        elif (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 2) and (cnn3dInstance.numberOfEpochsTrained in lowerLrAtTheEndOfTheseEpochsPredefinedScheduleList) :
            #Predefined Schedule.
            myLogger.print3("DEBUG: Going to lower Learning Rate because of PREDEFINED schedule! The CNN has now been trained for: " + str(cnn3dInstance.numberOfEpochsTrained) + " epochs. I need to decrease after that many epochs: " + str(lowerLrAtTheEndOfTheseEpochsPredefinedScheduleList))
            cnn3dInstance.divide_learning_rate_of_a_cnn_by(divideLrBy, myLogger)
        elif (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 3 and cnn3dInstance.numberOfEpochsTrained >= exponentialScheduleForLrAndMom[0]) :
            myLogger.print3("DEBUG: Going to lower Learning Rate and Increase Momentum because of EXPONENTIAL schedule! The CNN has now been trained for: " + str(cnn3dInstance.numberOfEpochsTrained) + " epochs.")
            minEpochToLowerLr = exponentialScheduleForLrAndMom[0]          
            #newLearningRate = initialLearningRate * gamma^t. gamma = {t-th}root(valueIwantLrToHaveAtTimepointT / initialLearningRate)
            gammaForExpSchedule = pow( ( cnn3dInstance.initialLearningRate*exponentialScheduleForLrAndMom[1] * 1.0) / cnn3dInstance.initialLearningRate, 1.0 / (n_epochs-minEpochToLowerLr))
            newLearningRate = cnn3dInstance.initialLearningRate * pow(gammaForExpSchedule, cnn3dInstance.numberOfEpochsTrained-minEpochToLowerLr + 1.0)
            #Momentum increased linearly.
            newMomentum = ((cnn3dInstance.numberOfEpochsTrained - minEpochToLowerLr + 1) - (n_epochs-minEpochToLowerLr))*1.0 / (n_epochs - minEpochToLowerLr) * (exponentialScheduleForLrAndMom[2] - cnn3dInstance.initialMomentum) + exponentialScheduleForLrAndMom[2]
            print("DEBUG: new learning rate was calculated: ", newLearningRate, " and new Momentum: ", newMomentum)
            cnn3dInstance.change_learning_rate_of_a_cnn(newLearningRate, myLogger)
            cnn3dInstance.change_momentum_of_a_cnn(newMomentum, myLogger)
            
        #================== Everything for epoch has finished. =======================
        #Training finished. Update the number of epochs that the cnn was trained.
        cnn3dInstance.increaseNumberOfEpochsTrained()
        
        myLogger.print3("SAVING: Epoch #"+str(epoch)+" finished. Saving CNN model.")
        dump_cnn_to_gzip_file_dotSave(cnn3dInstance, fileToSaveTrainedCnnModelTo+"."+datetimeNowAsStr(), myLogger)
        end_epoch_time = time.clock()
        myLogger.print3("TIMING: The whole Epoch #"+str(epoch)+" took time: "+str(end_epoch_time-start_epoch_time)+"(s)")
        myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        
        if performFullInferenceOnValidationImagesEveryFewEpochsBool and (cnn3dInstance.numberOfEpochsTrained != 0) and (cnn3dInstance.numberOfEpochsTrained % everyThatManyEpochsComputeDiceOnTheFullValidationImages == 0) :
            myLogger.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #"+str(epoch)+"...***")
            validation0orTesting1 = 0
            #do_validation_or_testing(myLogger,
            performInferenceOnWholeVolumes(myLogger,
                                    validation0orTesting1,
                                    savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation,
                                    cnn3dInstance,
                                    
                                    listOfFilepathsToEachChannelOfEachPatientValidation,
                                    
                                    providedGtForValidationBool,
                                    listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                    
                                    providedRoiMaskForValidationBool,
                                    listOfFilepathsToRoiMaskOfEachPatientValidation,
                                    
                                    borrowFlag,
                                    listOfNamesToGiveToPredictionsIfSavingResults = "Placeholder" if not savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation else listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice,
                                    
                                    #----Preprocessing------
                                    padInputImagesBool=padInputImagesBool,
                                    smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage=smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage,
                                    
                                    #for the cnn extension
                                    useSameSubChannelsAsSingleScale=useSameSubChannelsAsSingleScale,
                                    
                                    listOfFilepathsToEachSubsampledChannelOfEachPatient=listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
                                    
                                    #--------For FM visualisation---------
                                    saveIndividualFmImagesForVisualisation=saveIndividualFmImagesForVisualisation,
                                    saveMultidimensionalImageWithAllFms=saveMultidimensionalImageWithAllFms,
                                    indicesOfFmsToVisualisePerPathwayTypeAndPerLayer=indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                                    listOfNamesToGiveToFmVisualisationsIfSaving=listOfNamesToGiveToFmVisualisationsIfSaving
                                    )
            
    dump_cnn_to_gzip_file_dotSave(cnn3dInstance, fileToSaveTrainedCnnModelTo+".final."+datetimeNowAsStr(), myLogger)
    
    end_training_time = time.clock()
    myLogger.print3("TIMING: Training process took time: "+str(end_training_time-start_training_time)+"(s)")
    myLogger.print3("The whole do_training() function has finished.")
    
    
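Example #2 hides the cost of segment extraction by submitting the sampling function to a Parallel Python job server while the GPU works on the current subepoch, then calling the returned job object to collect the prefetched result one subepoch later. Below is a minimal sketch of that submit-then-collect pattern, assuming the pp package is installed; extract_segments is a placeholder standing in for getSampledDataAndLabelsForSubepoch.

import pp

def extract_segments(subepoch):
    # Placeholder for the expensive sampling/IO work that runs outside the main process.
    return "segments for subepoch %d" % subepoch

job_server = pp.Server(ncpus=1, ppservers=())   # job server with one worker process
# Submit the job that prepares the data for the NEXT subepoch...
next_job = job_server.submit(extract_segments, (1,))
# ...do the current subepoch's training here, then block and collect the prefetched result.
segments_for_next_subepoch = next_job()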
Example #3
def deepMedicTestMain(testConfigFilepath, absPathToSavedModelFromCmdLine):
    print("Given configuration file: ", testConfigFilepath)
    #Parse the config file in this naive fashion...
    testConfig = TestConfig()
    #configStruct = testConfig.configStruct
    exec(open(testConfigFilepath).read(), testConfig.configStruct)
    configGet = testConfig.get  #Main interface

    #Do checks.
    #checkIfMainTestConfigIsCorrect(testConfig, testConfigFilepath, absPathToSavedModelFromCmdLine) #Checks REQUIRED fields are complete.
    #checkIfFilesThatListFilesPerCaseAreCorrect(testConfig, testConfigFilepath) #Checks listing-files (whatever given).
    #checkIfOptionalParametersAreGivenCorrectly(testConfig, testConfigFilepath)

    #At this point it was checked that all parameters (that could be checked) and filepaths are correct, pointing to files/dirs and all files/dirs exist.

    #Create Folders and Logger
    mainOutputAbsFolder = getAbsPathEvenIfRelativeIsGiven(
        configGet(testConfig.FOLDER_FOR_OUTPUT), testConfigFilepath)
    sessionName = configGet(testConfig.SESSION_NAME) if configGet(
        testConfig.SESSION_NAME
    ) else TestSessionParameters.getDefaultSessionName()
    [folderForLogs, folderForPredictions, folderForFeatures
     ] = makeFoldersNeededForTestingSession(mainOutputAbsFolder, sessionName)
    loggerFileName = folderForLogs + "/" + sessionName + ".txt"
    sessionLogger = logger.Logger(loggerFileName)

    sessionLogger.print3("CONFIG: Given THEANO_FLAGS: " +
                         str(os.environ['THEANO_FLAGS']))
    sessionLogger.print3(
        "CONFIG: The configuration file for the session was loaded from: " +
        str(testConfigFilepath))

    #Load the CNN Model!
    sessionLogger.print3(
        "=========== Loading the CNN model for testing... ===============")
    #If CNN-Model was specified in command line, completely override the one in the config file.
    filepathToCnnModelToLoad = None
    if absPathToSavedModelFromCmdLine and configGet(
            testConfig.CNN_MODEL_FILEPATH):
        sessionLogger.print3(
            "WARN: A CNN-Model to use was specified both in the command line input and in the test-config-file! The input by the command line will be used: "
            + str(absPathToSavedModelFromCmdLine))
        filepathToCnnModelToLoad = absPathToSavedModelFromCmdLine
    elif absPathToSavedModelFromCmdLine:
        filepathToCnnModelToLoad = absPathToSavedModelFromCmdLine
    else:
        filepathToCnnModelToLoad = getAbsPathEvenIfRelativeIsGiven(
            configGet(testConfig.CNN_MODEL_FILEPATH), testConfigFilepath)
    sessionLogger.print3(
        "...Loading the network can take a few minutes if the model is big...")
    cnn3dInstance = load_object_from_gzip_file(filepathToCnnModelToLoad)
    sessionLogger.print3("The CNN model was loaded successfully from: " +
                         str(filepathToCnnModelToLoad))
    #Do final checks of the parameters. Check the ones that need check in comparison to the model's parameters! Such as: SAVE_PROBMAPS_PER_CLASS, INDICES_OF_FMS_TO_SAVE, Number of Channels!
    #testConfig.checkIfConfigIsCorrectForParticularCnnModel(cnn3dInstance)

    #Fill in the session's parameters.
    #[[case1-ch1, ..., caseN-ch1], [case1-ch2,...,caseN-ch2]]
    listOfAListPerChannelWithFilepathsOfAllCases = [
        parseAbsFileLinesInList(
            getAbsPathEvenIfRelativeIsGiven(channelConfPath,
                                            testConfigFilepath))
        for channelConfPath in configGet(testConfig.CHANNELS)
    ]
    #[[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
    listWithAListPerCaseWithFilepathPerChannel = [
        list(item)
        for item in zip(*tuple(listOfAListPerChannelWithFilepathsOfAllCases))
    ]
    gtLabelsFilepaths = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(configGet(testConfig.GT_LABELS),
                                        testConfigFilepath)) if configGet(
                                            testConfig.GT_LABELS) else None
    roiMasksFilepaths = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(configGet(testConfig.ROI_MASKS),
                                        testConfigFilepath)) if configGet(
                                            testConfig.ROI_MASKS) else None
    namesToSavePredsAndFeats = parseFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(
            configGet(testConfig.NAMES_FOR_PRED_PER_CASE),
            testConfigFilepath)) if configGet(
                testConfig.NAMES_FOR_PRED_PER_CASE
            ) else None  #CAREFUL: Here we use a different parsing function!

    testSessionParameters = TestSessionParameters(
                    sessionName = sessionName,
                    sessionLogger = sessionLogger,
                    mainOutputAbsFolder = mainOutputAbsFolder,
                    cnn3dInstance = cnn3dInstance,
                    cnnModelFilepath = filepathToCnnModelToLoad,

                    #Input:
                    listWithAListPerCaseWithFilepathPerChannel = listWithAListPerCaseWithFilepathPerChannel,
                    gtLabelsFilepaths = gtLabelsFilepaths,
                    roiMasksFilepaths = roiMasksFilepaths,

                    #Output
                    namesToSavePredictionsAndFeatures = namesToSavePredsAndFeats,
                    #predictions
                    saveSegmentation = configGet(testConfig.SAVE_SEGM),
                    saveProbMapsBoolPerClass = configGet(testConfig.SAVE_PROBMAPS_PER_CLASS),
                    folderForPredictions = folderForPredictions,

                    #features:
                    saveIndividualFmImages = configGet(testConfig.SAVE_INDIV_FMS),
                    saveMultidimensionalImageWithAllFms = configGet(testConfig.SAVE_4DIM_FMS),
                    indicesOfFmsToVisualisePerPathwayAndLayer = [configGet(testConfig.INDICES_OF_FMS_TO_SAVE_NORMAL)] +\
                                                                [configGet(testConfig.INDICES_OF_FMS_TO_SAVE_SUBSAMPLED)] +\
                                                                [configGet(testConfig.INDICES_OF_FMS_TO_SAVE_FC)
                                                                ],
                    folderForFeatures = folderForFeatures,

                    padInputImagesBool = configGet(testConfig.PAD_INPUT),
                    )

    testSessionParameters.sessionLogger.print3(
        "\n===========       NEW TESTING SESSION         ===============")

    testSessionParameters.printParametersOfThisSession()

    testSessionParameters.sessionLogger.print3(
        "\n=======================================================")
    testSessionParameters.sessionLogger.print3(
        "=========== Compiling the Testing Function ============")
    testSessionParameters.sessionLogger.print3(
        "=======================================================")

    cnn3dInstance.compileTestAndVisualisationFunction(
        *testSessionParameters.getTupleForCompilationOfTestFunc())

    testSessionParameters.sessionLogger.print3(
        "\n======================================================")
    testSessionParameters.sessionLogger.print3(
        "=========== Testing with the CNN model ===============")
    testSessionParameters.sessionLogger.print3(
        "======================================================")

    checkCpuOrGpu(sessionLogger,
                  cnn3dInstance.cnnTestAndVisualiseAllFmsFunction)
    performInferenceOnWholeVolumes(
        *testSessionParameters.getTupleForCnnTesting())

    testSessionParameters.sessionLogger.print3(
        "\n======================================================")
    testSessionParameters.sessionLogger.print3(
        "=========== Testing session finished =================")
    testSessionParameters.sessionLogger.print3(
        "======================================================")
Example #4
def do_training(
        sessionTf,
        saver_all,
        cnn3d,
        trainer,
        log,
        fileToSaveTrainedCnnModelTo,
        performValidationOnSamplesDuringTrainingProcessBool,
        savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation,
        listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice,
        listOfFilepathsToEachChannelOfEachPatientTraining,
        listOfFilepathsToEachChannelOfEachPatientValidation,
        listOfFilepathsToGtLabelsOfEachPatientTraining,
        providedGtForValidationBool,
        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
        providedWeightMapsToSampleForEachCategoryTraining,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
        providedWeightMapsToSampleForEachCategoryValidation,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
        providedRoiMaskForTrainingBool,
        listOfFilepathsToRoiMaskOfEachPatientTraining,  # Also needed for normalization-augmentation
        providedRoiMaskForValidationBool,
        listOfFilepathsToRoiMaskOfEachPatientValidation,
        n_epochs,  # Every epoch the CNN model is saved.
        number_of_subepochs,  # per epoch. Every subepoch Accuracy is reported
        maxNumSubjectsLoadedPerSubepoch,  # Max num of cases loaded every subepoch for segments extraction. The more, the longer loading.
        imagePartsLoadedInGpuPerSubepoch,
        imagePartsLoadedInGpuPerSubepochValidation,

        #-------Sampling Type---------
        samplingTypeInstanceTraining,  # Instance of the deepmedic/samplingType.SamplingType class for training and validation
        samplingTypeInstanceValidation,

        #-------Preprocessing-----------
        padInputImagesBool,
        #-------Data Augmentation-------
        doIntAugm_shiftMuStd_multiMuStd,
        reflectImageWithHalfProbDuringTraining,
        useSameSubChannelsAsSingleScale,
        listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,  # deprecated, not supported
        listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,  # deprecated, not supported

        # Validation
        performFullInferenceOnValidationImagesEveryFewEpochsBool,  # Even if not providedGtForValidationBool, inference will be performed if this == True, to save the results, e.g. for visual inspection.
        everyThatManyEpochsComputeDiceOnTheFullValidationImages,  # Should not be == 0, except if performFullInferenceOnValidationImagesEveryFewEpochsBool == False

        #--------For FM visualisation---------
        saveIndividualFmImagesForVisualisation,
        saveMultidimensionalImageWithAllFms,
        indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
        listOfNamesToGiveToFmVisualisationsIfSaving,

        #-------- Others --------
        run_input_checks):

    start_training_time = time.time()

    # Used because I cannot pass cnn3d to the sampling function:
    # the parallel process used to load theano again, which created problems on the GPU when cnmem was used. Not sure this is still needed with Tensorflow, but probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)

    #---------To run PARALLEL the extraction of parts for the next subepoch---
    ppservers = ()  # tuple of all parallel python servers to connect with
    job_server = pp.Server(
        ncpus=1, ppservers=ppservers
    )  # Creates jobserver with automatically detected number of workers

    tupleWithParametersForTraining = (
        log, "train", run_input_checks, cnn3dWrapper,
        maxNumSubjectsLoadedPerSubepoch, imagePartsLoadedInGpuPerSubepoch,
        samplingTypeInstanceTraining,
        listOfFilepathsToEachChannelOfEachPatientTraining,
        listOfFilepathsToGtLabelsOfEachPatientTraining,
        providedRoiMaskForTrainingBool,
        listOfFilepathsToRoiMaskOfEachPatientTraining,
        providedWeightMapsToSampleForEachCategoryTraining,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
        useSameSubChannelsAsSingleScale,
        listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,
        padInputImagesBool, doIntAugm_shiftMuStd_multiMuStd,
        reflectImageWithHalfProbDuringTraining)
    tupleWithParametersForValidation = (
        log,
        "val",
        run_input_checks,
        cnn3dWrapper,
        maxNumSubjectsLoadedPerSubepoch,
        imagePartsLoadedInGpuPerSubepochValidation,
        samplingTypeInstanceValidation,
        listOfFilepathsToEachChannelOfEachPatientValidation,
        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
        providedRoiMaskForValidationBool,
        listOfFilepathsToRoiMaskOfEachPatientValidation,
        providedWeightMapsToSampleForEachCategoryValidation,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
        useSameSubChannelsAsSingleScale,
        listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
        padInputImagesBool,
        [0, -1, -1,
         -1],  #don't perform intensity-augmentation during validation.
        [0, 0, 0]  #don't perform reflection-augmentation during validation.
    )
    tupleWithLocalFunctionsThatWillBeCalledByTheMainJob = ()
    tupleWithModulesToImportWhichAreUsedByTheJobFunctions = (
        "from __future__ import absolute_import, print_function, division",
        "time", "numpy as np",
        "from deepmedic.dataManagement.sampling import *")
    boolItIsTheVeryFirstSubepochOfThisProcess = True  # To know that in the very first subepoch I must load the data sequentially.
    #------End for parallel------

    model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(
        session=sessionTf)
    while model_num_epochs_trained < n_epochs:
        epoch = model_num_epochs_trained

        trainingAccuracyMonitorForEpoch = AccuracyOfEpochMonitorSegmentation(
            log, 0, model_num_epochs_trained, cnn3d.num_classes,
            number_of_subepochs)
        validationAccuracyMonitorForEpoch = None if not performValidationOnSamplesDuringTrainingProcessBool else \
                                        AccuracyOfEpochMonitorSegmentation(log, 1, model_num_epochs_trained, cnn3d.num_classes, number_of_subepochs )

        log.print3(
            "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
        )
        log.print3("~~~~~~~~~~~~~~~~~~~~Starting new Epoch! Epoch #" +
                   str(epoch) + "/" + str(n_epochs) +
                   " ~~~~~~~~~~~~~~~~~~~~~~~~~")
        log.print3(
            "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
        )
        start_epoch_time = time.time()

        for subepoch in range(number_of_subepochs):  # Per subepoch, a random subset of the images is loaded into the GPU, in random order.
            log.print3(
                "**************************************************************************************************"
            )
            log.print3("************* Starting new Subepoch: #" +
                       str(subepoch) + "/" + str(number_of_subepochs) +
                       " *************")
            log.print3(
                "**************************************************************************************************"
            )

            #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------

            if performValidationOnSamplesDuringTrainingProcessBool:
                if boolItIsTheVeryFirstSubepochOfThisProcess:
                    [
                        channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal
                    ] = getSampledDataAndLabelsForSubepoch(
                        log,
                        "val",
                        run_input_checks,
                        cnn3dWrapper,
                        maxNumSubjectsLoadedPerSubepoch,
                        imagePartsLoadedInGpuPerSubepochValidation,
                        samplingTypeInstanceValidation,
                        listOfFilepathsToEachChannelOfEachPatientValidation,
                        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                        providedRoiMaskForValidationBool,
                        listOfFilepathsToRoiMaskOfEachPatientValidation,
                        providedWeightMapsToSampleForEachCategoryValidation,
                        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
                        useSameSubChannelsAsSingleScale,
                        listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
                        padInputImagesBool,
                        doIntAugm_shiftMuStd_multiMuStd=[False, [], []],
                        reflectImageWithHalfProbDuringTraining=[0, 0, 0])
                    boolItIsTheVeryFirstSubepochOfThisProcess = False
                else:  # Extraction already ran in parallel with the previous subepoch's training; just grab the results...
                    [
                        channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal
                    ] = parallelJobToGetDataForNextValidation()  # From the parallel job submitted during the previous subepoch.

                # Computed from the number of samples actually extracted, in case fewer could be extracted than originally requested.
                numberOfBatchesValidation = len(channsOfSegmentsForSubepPerPathwayVal[0]) // cnn3d.batchSize["val"]
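                # Illustrative arithmetic (hypothetical numbers): if 5000 validation
                # segments were extracted and cnn3d.batchSize["val"] is 50, this gives
                # 100 batches; leftover segments that do not fill a batch are unused.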

                #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                #submit the parallel job
                log.print3(
                    "PARALLEL: Before Validation in subepoch #" +
                    str(subepoch) +
                    ", the parallel job for extracting Segments for the next Training is submitted."
                )
                parallelJobToGetDataForNextTraining = job_server.submit(
                    getSampledDataAndLabelsForSubepoch,  # Local function to call and execute in parallel.
                    tupleWithParametersForTraining,  # Tuple with the required arguments.
                    tupleWithLocalFunctionsThatWillBeCalledByTheMainJob,  # Tuple of local functions that the job needs to call.
                    tupleWithModulesToImportWhichAreUsedByTheJobFunctions
                )  # Tuple of external modules whose functions the job calls (not the modules those functions use internally).

                #------------------------------------DO VALIDATION--------------------------------
                log.print3(
                    "-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-"
                )
                start_validationForSubepoch_time = time.time()

                doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(
                    log,
                    sessionTf,
                    "val",
                    numberOfBatchesValidation,  # Computed by the number of extracted samples. So, adapts.
                    cnn3d,
                    subepoch,
                    validationAccuracyMonitorForEpoch,
                    channsOfSegmentsForSubepPerPathwayVal,
                    labelsForCentralOfSegmentsForSubepVal)

                end_validationForSubepoch_time = time.time()
                log.print3(
                    "TIMING: Validating on the batches of this subepoch #" +
                    str(subepoch) + " took time: " +
                    str(end_validationForSubepoch_time -
                        start_validationForSubepoch_time) + "(s)")

            #-------------------END OF THE VALIDATION-DURING-TRAINING-LOOP-------------------------

            #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
            if (not performValidationOnSamplesDuringTrainingProcessBool) and boolItIsTheVeryFirstSubepochOfThisProcess:
                [
                    channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain
                ] = getSampledDataAndLabelsForSubepoch(
                    log, "train", run_input_checks, cnn3dWrapper,
                    maxNumSubjectsLoadedPerSubepoch,
                    imagePartsLoadedInGpuPerSubepoch,
                    samplingTypeInstanceTraining,
                    listOfFilepathsToEachChannelOfEachPatientTraining,
                    listOfFilepathsToGtLabelsOfEachPatientTraining,
                    providedRoiMaskForTrainingBool,
                    listOfFilepathsToRoiMaskOfEachPatientTraining,
                    providedWeightMapsToSampleForEachCategoryTraining,
                    forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
                    useSameSubChannelsAsSingleScale,
                    listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,
                    padInputImagesBool, doIntAugm_shiftMuStd_multiMuStd,
                    reflectImageWithHalfProbDuringTraining)
                boolItIsTheVeryFirstSubepochOfThisProcess = False
            else:
                # Extraction already ran in parallel with the validation (or with the previous subepoch's training, if validation is not being performed).
                [
                    channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain
                ] = parallelJobToGetDataForNextTraining()  # From the parallel job submitted during the previous iteration.

            # Computed from the number of samples actually extracted, in case fewer could be extracted than originally requested.
            numberOfBatchesTraining = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // cnn3d.batchSize["train"]

            #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
            if performValidationOnSamplesDuringTrainingProcessBool:
                #submit the parallel job
                log.print3(
                    "PARALLEL: Before Training in subepoch #" + str(subepoch) +
                    ", submitting the parallel job for extracting Segments for the next Validation."
                )
                parallelJobToGetDataForNextValidation = job_server.submit(
                    getSampledDataAndLabelsForSubepoch,  # Local function to call and execute in parallel.
                    tupleWithParametersForValidation,  # Tuple with the required arguments.
                    tupleWithLocalFunctionsThatWillBeCalledByTheMainJob,  # Tuple of local functions that the job needs to call.
                    tupleWithModulesToImportWhichAreUsedByTheJobFunctions
                )  # Tuple of external modules whose functions the job calls (not the modules those functions use internally).
            else:  # Extract in parallel the samples for the next subepoch's training.
                log.print3(
                    "PARALLEL: Before Training in subepoch #" + str(subepoch) +
                    ", submitting the parallel job for extracting Segments for the next Training."
                )
                parallelJobToGetDataForNextTraining = job_server.submit(
                    getSampledDataAndLabelsForSubepoch,  # Local function to call and execute in parallel.
                    tupleWithParametersForTraining,  # Tuple with the required arguments.
                    tupleWithLocalFunctionsThatWillBeCalledByTheMainJob,  # Tuple of local functions that the job needs to call.
                    tupleWithModulesToImportWhichAreUsedByTheJobFunctions
                )  # Tuple of external modules whose functions the job calls (not the modules those functions use internally).

            #-------------------------------START TRAINING IN BATCHES------------------------------
            log.print3(
                "-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-"
            )
            start_trainingForSubepoch_time = time.time()

            doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(
                log, sessionTf, "train", numberOfBatchesTraining, cnn3d,
                subepoch, trainingAccuracyMonitorForEpoch,
                channsOfSegmentsForSubepPerPathwayTrain,
                labelsForCentralOfSegmentsForSubepTrain)

            end_trainingForSubepoch_time = time.time()
            log.print3("TIMING: Training on the batches of this subepoch #" +
                       str(subepoch) + " took time: " +
                       str(end_trainingForSubepoch_time -
                           start_trainingForSubepoch_time) + "(s)")

        log.print3(
            "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
        )
        log.print3(
            "~~~~~~~~~~~~~~~~~~ Epoch #" + str(epoch) +
            " finished. Reporting Accuracy over whole epoch. ~~~~~~~~~~~~~~~~~~"
        )
        log.print3(
            "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
        )

        if performValidationOnSamplesDuringTrainingProcessBool:
            validationAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch()
        trainingAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch()

        mean_val_acc_of_ep = validationAccuracyMonitorForEpoch.getMeanEmpiricalAccuracyOfEpoch() \
            if performValidationOnSamplesDuringTrainingProcessBool else None
        trainer.run_updates_end_of_ep(
            log, sessionTf, mean_val_acc_of_ep
        )  # Updates LR schedule if needed, and increases number of epochs trained.
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(
            session=sessionTf)
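        # Re-reading the epochs-trained variable here (which run_updates_end_of_ep()
        # is expected to have incremented) is what advances the while-loop condition.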

        del trainingAccuracyMonitorForEpoch
        del validationAccuracyMonitorForEpoch
        #================== Everything for epoch has finished. =======================

        log.print3("SAVING: Epoch #" + str(epoch) +
                   " finished. Saving CNN model.")
        filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr()
        saver_all.save(sessionTf,
                       filename_to_save_with + ".model.ckpt",
                       write_meta_graph=False)
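        # Note: with write_meta_graph=False only the variable values are written to the
        # checkpoint; the graph definition is not stored, so restoring later requires
        # rebuilding the same graph in code before calling the saver's restore().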

        end_epoch_time = time.time()
        log.print3("TIMING: The whole Epoch #" + str(epoch) + " took time: " +
                   str(end_epoch_time - start_epoch_time) + "(s)")
        log.print3(
            "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
        )

        if performFullInferenceOnValidationImagesEveryFewEpochsBool \
                and model_num_epochs_trained != 0 \
                and model_num_epochs_trained % everyThatManyEpochsComputeDiceOnTheFullValidationImages == 0:
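            # Hypothetical example: with everyThatManyEpochsComputeDiceOnTheFullValidationImages = 5,
            # this branch runs once every 5 completed epochs (and never right after the first
            # epoch of a fresh model, because of the model_num_epochs_trained != 0 check).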
            log.print3(
                "***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #"
                + str(epoch) + "...***")

            performInferenceOnWholeVolumes(
                sessionTf,
                cnn3d,
                log,
                "val",
                savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                providedGtForValidationBool,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                providedRoiMaskForValidationBool,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                listOfNamesToGiveToPredictionsIfSavingResults=("Placeholder"
                    if not savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation
                    else listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice),

                #----Preprocessing------
                padInputImagesBool=padInputImagesBool,

                # For the CNN extension.
                useSameSubChannelsAsSingleScale=useSameSubChannelsAsSingleScale,
                listOfFilepathsToEachSubsampledChannelOfEachPatient=listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,

                # -------- For FM visualisation ---------
                saveIndividualFmImagesForVisualisation=saveIndividualFmImagesForVisualisation,
                saveMultidimensionalImageWithAllFms=saveMultidimensionalImageWithAllFms,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer=indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                listOfNamesToGiveToFmVisualisationsIfSaving=listOfNamesToGiveToFmVisualisationsIfSaving)

    end_training_time = time.time()
    log.print3("TIMING: Training process took time: " +
               str(end_training_time - start_training_time) + "(s)")
    log.print3("The whole do_training() function has finished.")