def deepMedicTestMain(testConfigFilepath, absPathToSavedModelFromCmdLine):
    """Run a testing (inference) session described by a test-configuration file.

    Steps, in order: execute the configuration file to populate a TestConfig,
    create the output folders and the session logger, load the CNN model,
    collect the per-case input listings, compile the test function and run
    inference on the whole volumes.

    Args:
        testConfigFilepath: Path to the test-configuration file. It is also
            handed to getAbsPathEvenIfRelativeIsGiven as the reference
            location for every path read from the configuration.
        absPathToSavedModelFromCmdLine: Path to a saved CNN model given on
            the command line, or a falsy value. When given, it takes
            precedence over the model path of the configuration file.
    """
    print("Given Test-Configuration File: ", testConfigFilepath)

    # Parse the config file in this naive fashion...
    # NOTE(security): the configuration file is executed as Python code, so
    # only trusted configuration files must be passed in.
    testConfig = TestConfig()
    with open(testConfigFilepath) as configFile:  # Make sure the handle gets closed.
        configSource = configFile.read()
    exec(configSource, testConfig.configStruct)
    configGet = testConfig.get  # Main interface

    # NOTE: The validation of the configuration (required fields, listing-files,
    # optional parameters) is currently disabled, so nothing is checked here.

    # Create Folders and Logger
    mainOutputAbsFolder = getAbsPathEvenIfRelativeIsGiven(
        configGet(testConfig.FOLDER_FOR_OUTPUT), testConfigFilepath)
    sessionName = (configGet(testConfig.SESSION_NAME)
                   or TestSessionParameters.getDefaultSessionName())
    [folderForLogs,
     folderForPredictions,
     folderForFeatures] = makeFoldersNeededForTestingSession(mainOutputAbsFolder, sessionName)
    loggerFileName = folderForLogs + "/" + sessionName + ".txt"
    sessionLogger = myLoggerModule.MyLogger(loggerFileName)

    sessionLogger.print3(
        "CONFIG: The configuration file for the testing session was loaded from: "
        + str(testConfigFilepath))

    # Load the CNN Model!
    sessionLogger.print3("=========== Loading the CNN model for testing... ===============")
    # If CNN-Model was specified in command line, completely override the one in the config file.
    if absPathToSavedModelFromCmdLine:
        if configGet(testConfig.CNN_MODEL_FILEPATH):
            sessionLogger.print3(
                "WARN: A CNN-Model to use was specified both in the command line input and in the test-config-file! The input by the command line will be used: "
                + str(absPathToSavedModelFromCmdLine))
        filepathToCnnModelToLoad = absPathToSavedModelFromCmdLine
    else:
        filepathToCnnModelToLoad = getAbsPathEvenIfRelativeIsGiven(
            configGet(testConfig.CNN_MODEL_FILEPATH), testConfigFilepath)
    sessionLogger.print3("...Loading the network can take a few minutes if the model is big...")
    cnn3dInstance = load_object_from_gzip_file(filepathToCnnModelToLoad)
    sessionLogger.print3("The CNN model was loaded successfully from: "
                         + str(filepathToCnnModelToLoad))

    # NOTE: The final check of the configuration against the loaded model
    # (checkIfConfigIsCorrectForParticularCnnModel) is currently disabled.

    # Fill in the session's parameters.
    # [[case1-ch1, ..., caseN-ch1], [case1-ch2,...,caseN-ch2]]
    listOfAListPerChannelWithFilepathsOfAllCases = [
        parseAbsFileLinesInList(
            getAbsPathEvenIfRelativeIsGiven(channelConfPath, testConfigFilepath))
        for channelConfPath in configGet(testConfig.CHANNELS)
    ]
    # [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
    listWithAListPerCaseWithFilepathPerChannel = [
        list(item) for item in zip(*listOfAListPerChannelWithFilepathsOfAllCases)
    ]

    # Optional listing-files: each one is parsed only if given in the config.
    gtLabelsConf = configGet(testConfig.GT_LABELS)
    gtLabelsFilepaths = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(gtLabelsConf, testConfigFilepath)
    ) if gtLabelsConf else None

    roiMasksConf = configGet(testConfig.ROI_MASKS)
    roiMasksFilepaths = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(roiMasksConf, testConfigFilepath)
    ) if roiMasksConf else None

    # CAREFUL: Here we use a different parsing function!
    namesForPredConf = configGet(testConfig.NAMES_FOR_PRED_PER_CASE)
    namesToSavePredsAndFeats = parseFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(namesForPredConf, testConfigFilepath)
    ) if namesForPredConf else None

    testSessionParameters = TestSessionParameters(
        sessionName=sessionName,
        sessionLogger=sessionLogger,
        mainOutputAbsFolder=mainOutputAbsFolder,
        cnn3dInstance=cnn3dInstance,
        cnnModelFilepath=filepathToCnnModelToLoad,

        # Input:
        listWithAListPerCaseWithFilepathPerChannel=listWithAListPerCaseWithFilepathPerChannel,
        gtLabelsFilepaths=gtLabelsFilepaths,
        roiMasksFilepaths=roiMasksFilepaths,

        # Output
        namesToSavePredictionsAndFeatures=namesToSavePredsAndFeats,
        # predictions
        saveSegmentation=configGet(testConfig.SAVE_SEGM),
        saveProbMapsBoolPerClass=configGet(testConfig.SAVE_PROBMAPS_PER_CLASS),
        folderForPredictions=folderForPredictions,

        # features:
        saveIndividualFmImages=configGet(testConfig.SAVE_INDIV_FMS),
        saveMultidimensionalImageWithAllFms=configGet(testConfig.SAVE_4DIM_FMS),
        # One entry per pathway type, in the order: normal, subsampled, FC.
        indicesOfFmsToVisualisePerPathwayAndLayer=[
            configGet(testConfig.INDICES_OF_FMS_TO_SAVE_NORMAL),
            configGet(testConfig.INDICES_OF_FMS_TO_SAVE_SUBSAMPLED),
            configGet(testConfig.INDICES_OF_FMS_TO_SAVE_FC),
        ],
        folderForFeatures=folderForFeatures,

        padInputImagesBool=configGet(testConfig.PAD_INPUT),
    )

    log = testSessionParameters.sessionLogger.print3

    log("\n=========== NEW TESTING SESSION ===============")
    testSessionParameters.printParametersOfThisSession()

    log("\n=======================================================")
    log("=========== Compiling the Testing Function ============")
    log("=======================================================")
    cnn3dInstance.compileTestAndVisualisationFunction(
        *testSessionParameters.getTupleForCompilationOfTestFunc())

    log("\n======================================================")
    log("=========== Testing with the CNN model ===============")
    log("======================================================")
    performInferenceForTestingOnWholeVolumes(
        *testSessionParameters.getTupleForCnnTesting())

    log("\n======================================================")
    log("=========== Testing session finished =================")
    log("======================================================")
def deepMedicTrainMain(trainConfigFilepath,
                       absPathToSavedModelFromCmdLine,
                       cnnInstancePreLoaded,
                       filenameAndPathWherePreLoadedModelWas,
                       resetOptimizer):
    """Run a training session described by a train-configuration file.

    Steps, in order: execute the configuration file to populate a
    TrainConfig, create the output folders and the session logger, obtain the
    CNN model (pre-loaded instance or loaded from disk), collect the training
    and validation listings, (re)initialize the training state if needed,
    compile the train/validation/test functions and run the training.

    Args:
        trainConfigFilepath: Path to the train-configuration file. It is also
            handed to getAbsPathEvenIfRelativeIsGiven as the reference
            location for every path read from the configuration.
        absPathToSavedModelFromCmdLine: Path to a saved CNN model given on
            the command line, or a falsy value. When given (and no pre-loaded
            instance is passed), it takes precedence over the model path of
            the configuration file.
        cnnInstancePreLoaded: An already constructed CNN instance (e.g. from
            a create-model session), or a falsy value. When given, no model
            is loaded from disk.
        filenameAndPathWherePreLoadedModelWas: Filepath recorded as the
            model's filepath when cnnInstancePreLoaded is used.
        resetOptimizer: If truthy, the training state of the model is
            (re)initialized even if it was already initialized.
    """
    print("Given Training-Configuration File: ", trainConfigFilepath)

    # Parse the config file in this naive fashion...
    # NOTE(security): the configuration file is executed as Python code, so
    # only trusted configuration files must be passed in.
    trainConfig = TrainConfig()
    with open(trainConfigFilepath) as configFile:  # Make sure the handle gets closed.
        configSource = configFile.read()
    exec(configSource, trainConfig.configStruct)
    configGet = trainConfig.get  # Main interface

    # NOTE: The validation of the configuration (required fields, listing-files,
    # optional parameters) is currently disabled, so nothing is checked here.

    # Create Folders and Logger
    mainOutputAbsFolder = getAbsPathEvenIfRelativeIsGiven(
        configGet(trainConfig.FOLDER_FOR_OUTPUT), trainConfigFilepath)
    sessionName = (configGet(trainConfig.SESSION_NAME)
                   or TrainSessionParameters.getDefaultSessionName())
    [folderForSessionCnnModels,
     folderForLogs,
     folderForPredictions,
     folderForFeatures] = makeFoldersNeededForTrainingSession(mainOutputAbsFolder, sessionName)
    loggerFileName = folderForLogs + "/" + sessionName + ".txt"
    sessionLogger = myLoggerModule.MyLogger(loggerFileName)

    sessionLogger.print3(
        "CONFIG: The configuration file for the training session was loaded from: "
        + str(trainConfigFilepath))

    # Load the cnn Model if not given straight from a createCnnModel-session...
    if cnnInstancePreLoaded:
        sessionLogger.print3(
            "====== CNN-Instance already loaded (from Create-New-Model procedure) =======")
        if configGet(trainConfig.CNN_MODEL_FILEPATH):
            sessionLogger.print3(
                "WARN: Any paths to cnn-models specified in the train-config file will be ommitted! Working with the pre-loaded model!")
        cnn3dInstance = cnnInstancePreLoaded
        filepathToCnnModel = filenameAndPathWherePreLoadedModelWas
    else:
        sessionLogger.print3("=========== Loading the CNN model for training... ===============")
        # If CNN-Model was specified in command line, completely override the one in the config file.
        # The command-line path must be honoured also when the config file
        # specifies no model at all; previously that case fell through to the
        # config path and the command-line model was ignored.
        if absPathToSavedModelFromCmdLine:
            if configGet(trainConfig.CNN_MODEL_FILEPATH):
                sessionLogger.print3(
                    "WARN: A CNN-Model to use was specified both in the command line input and in the train-config-file! The input by the command line will be used: "
                    + str(absPathToSavedModelFromCmdLine))
            filepathToCnnModel = absPathToSavedModelFromCmdLine
        else:
            filepathToCnnModel = getAbsPathEvenIfRelativeIsGiven(
                configGet(trainConfig.CNN_MODEL_FILEPATH), trainConfigFilepath)
        sessionLogger.print3("...Loading the network can take a few minutes if the model is big...")
        cnn3dInstance = load_object_from_gzip_file(filepathToCnnModel)
        sessionLogger.print3("The CNN model was loaded successfully from: "
                             + str(filepathToCnnModel))

    # NOTE: The final check of the configuration against the loaded model
    # (checkIfConfigIsCorrectForParticularCnnModel) is currently disabled.

    # Fill in the session's parameters.
    # =======TRAINING==========
    channelsConfTrain = configGet(trainConfig.CHANNELS_TR)
    if channelsConfTrain:
        # [[case1-ch1, ..., caseN-ch1], [case1-ch2,...,caseN-ch2]]
        listOfAListPerChannelWithFilepathsOfAllCasesTrain = [
            parseAbsFileLinesInList(
                getAbsPathEvenIfRelativeIsGiven(channelConfPath, trainConfigFilepath))
            for channelConfPath in channelsConfTrain
        ]
        # [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
        listWithAListPerCaseWithFilepathPerChannelTrain = [
            list(item) for item in zip(*listOfAListPerChannelWithFilepathsOfAllCasesTrain)
        ]
    else:
        listWithAListPerCaseWithFilepathPerChannelTrain = None

    gtLabelsConfTrain = configGet(trainConfig.GT_LABELS_TR)
    gtLabelsFilepathsTrain = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(gtLabelsConfTrain, trainConfigFilepath)
    ) if gtLabelsConfTrain else None

    roiMasksConfTrain = configGet(trainConfig.ROI_MASKS_TR)
    roiMasksFilepathsTrain = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(roiMasksConfTrain, trainConfigFilepath)
    ) if roiMasksConfTrain else None

    # ~~~~~ Advanced Training Sampling~~~~~~~~~
    weightMapsConfTrain = configGet(trainConfig.WEIGHT_MAPS_PER_CAT_FILEPATHS_TR)
    if weightMapsConfTrain:
        # [[case1-weightMap1, ..., caseN-weightMap1], [case1-weightMap2,...,caseN-weightMap2]]
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain = [
            parseAbsFileLinesInList(
                getAbsPathEvenIfRelativeIsGiven(weightMapConfPath, trainConfigFilepath))
            for weightMapConfPath in weightMapsConfTrain
        ]
    else:
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain = None

    # =======VALIDATION==========
    channelsConfVal = configGet(trainConfig.CHANNELS_VAL)
    if channelsConfVal:
        # [[case1-ch1, ..., caseN-ch1], [case1-ch2,...,caseN-ch2]]
        listOfAListPerChannelWithFilepathsOfAllCasesVal = [
            parseAbsFileLinesInList(
                getAbsPathEvenIfRelativeIsGiven(channelConfPath, trainConfigFilepath))
            for channelConfPath in channelsConfVal
        ]
        # [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
        listWithAListPerCaseWithFilepathPerChannelVal = [
            list(item) for item in zip(*listOfAListPerChannelWithFilepathsOfAllCasesVal)
        ]
    else:
        listWithAListPerCaseWithFilepathPerChannelVal = None

    gtLabelsConfVal = configGet(trainConfig.GT_LABELS_VAL)
    gtLabelsFilepathsVal = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(gtLabelsConfVal, trainConfigFilepath)
    ) if gtLabelsConfVal else None

    roiMasksConfVal = configGet(trainConfig.ROI_MASKS_VAL)
    roiMasksFilepathsVal = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(roiMasksConfVal, trainConfigFilepath)
    ) if roiMasksConfVal else None

    # ~~~~~Full Inference~~~~~~
    # CAREFUL: Here we use a different parsing function!
    namesForPredConfVal = configGet(trainConfig.NAMES_FOR_PRED_PER_CASE_VAL)
    namesToSavePredsAndFeatsVal = parseFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(namesForPredConfVal, trainConfigFilepath)
    ) if namesForPredConfVal else None

    # ~~~~~Advanced Validation Sampling~~~~~~~~
    weightMapsConfVal = configGet(trainConfig.WEIGHT_MAPS_PER_CAT_FILEPATHS_VAL)
    if weightMapsConfVal:
        # [[case1-weightMap1, ..., caseN-weightMap1], [case1-weightMap2,...,caseN-weightMap2]]
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal = [
            parseAbsFileLinesInList(
                getAbsPathEvenIfRelativeIsGiven(weightMapConfPath, trainConfigFilepath))
            for weightMapConfPath in weightMapsConfVal
        ]
    else:
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal = None

    trainSessionParameters = TrainSessionParameters(
        sessionName=sessionName,
        sessionLogger=sessionLogger,
        mainOutputAbsFolder=mainOutputAbsFolder,
        cnn3dInstance=cnn3dInstance,
        cnnModelFilepath=filepathToCnnModel,

        # ==================TRAINING====================
        folderForSessionCnnModels=folderForSessionCnnModels,
        listWithAListPerCaseWithFilepathPerChannelTrain=listWithAListPerCaseWithFilepathPerChannelTrain,
        gtLabelsFilepathsTrain=gtLabelsFilepathsTrain,

        # [Optionals]
        # ~~~~~~~~~Sampling~~~~~~~~~
        roiMasksFilepathsTrain=roiMasksFilepathsTrain,
        percentOfSamplesToExtractPositTrain=configGet(trainConfig.PERC_POS_SAMPLES_TR),
        # ~~~~~~~~~Advanced Sampling~~~~~~~
        useDefaultTrainingSamplingFromGtAndRoi=configGet(trainConfig.DEFAULT_TR_SAMPLING),
        samplingTypeTraining=configGet(trainConfig.TYPE_OF_SAMPLING_TR),
        proportionOfSamplesPerCategoryTrain=configGet(trainConfig.PROP_OF_SAMPLES_PER_CAT_TR),
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain=listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain,

        # ~~~~~~~~Training Cycle ~~~~~~~
        numberOfEpochs=configGet(trainConfig.NUM_EPOCHS),
        numberOfSubepochs=configGet(trainConfig.NUM_SUBEP),
        numOfCasesLoadedPerSubepoch=configGet(trainConfig.NUM_CASES_LOADED_PERSUB),
        segmentsLoadedOnGpuPerSubepochTrain=configGet(trainConfig.NUM_TR_SEGMS_LOADED_PERSUB),

        # ~~~~~~~ Learning Rate Schedule ~~~~~~~
        # Auto requires performValidationOnSamplesThroughoutTraining and providedGtForValidationBool
        stable0orAuto1orPredefined2orExponential3LrSchedule=configGet(trainConfig.LR_SCH_0123),
        # Stable + Auto + Predefined.
        whenDecreasingDivideLrBy=configGet(trainConfig.DIV_LR_BY),
        # Stable + Auto
        numEpochsToWaitBeforeLoweringLr=configGet(trainConfig.NUM_EPOCHS_WAIT),
        # Auto:
        minIncreaseInValidationAccuracyThatResetsWaiting=configGet(trainConfig.AUTO_MIN_INCR_VAL_ACC),
        # Predefined.
        predefinedSchedule=configGet(trainConfig.PREDEF_SCH),
        # Exponential
        exponentialSchedForLrAndMom=configGet(trainConfig.EXPON_SCH),

        # ~~~~~~~ Augmentation~~~~~~~~~~~~
        reflectImagesPerAxis=configGet(trainConfig.REFL_AUGM_PER_AXIS),
        performIntAugm=configGet(trainConfig.PERF_INT_AUGM_BOOL),
        sampleIntAugmShiftWithMuAndStd=configGet(trainConfig.INT_AUGM_SHIF_MUSTD),
        sampleIntAugmMultiWithMuAndStd=configGet(trainConfig.INT_AUGM_MULT_MUSTD),

        # ==================VALIDATION=====================
        performValidationOnSamplesThroughoutTraining=configGet(trainConfig.PERFORM_VAL_SAMPLES),
        performFullInferenceOnValidationImagesEveryFewEpochs=configGet(trainConfig.PERFORM_VAL_INFERENCE),
        # Required:
        listWithAListPerCaseWithFilepathPerChannelVal=listWithAListPerCaseWithFilepathPerChannelVal,
        gtLabelsFilepathsVal=gtLabelsFilepathsVal,

        segmentsLoadedOnGpuPerSubepochVal=configGet(trainConfig.NUM_VAL_SEGMS_LOADED_PERSUB),

        # [Optionals]
        # For default sampling and for fast inference. Optional. Otherwise from whole image.
        roiMasksFilepathsVal=roiMasksFilepathsVal,

        # ~~~~~~~~Full Inference~~~~~~~~
        numberOfEpochsBetweenFullInferenceOnValImages=configGet(trainConfig.NUM_EPOCHS_BETWEEN_VAL_INF),

        # Output
        namesToSavePredictionsAndFeaturesVal=namesToSavePredsAndFeatsVal,
        # predictions
        saveSegmentationVal=configGet(trainConfig.SAVE_SEGM_VAL),
        saveProbMapsBoolPerClassVal=configGet(trainConfig.SAVE_PROBMAPS_PER_CLASS_VAL),
        folderForPredictionsVal=folderForPredictions,
        # features:
        saveIndividualFmImagesVal=configGet(trainConfig.SAVE_INDIV_FMS_VAL),
        saveMultidimensionalImageWithAllFmsVal=configGet(trainConfig.SAVE_4DIM_FMS_VAL),
        # One entry per pathway type, in the order: normal, subsampled, FC.
        indicesOfFmsToVisualisePerPathwayAndLayerVal=[
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_NORMAL_VAL),
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_SUBSAMPLED_VAL),
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_FC_VAL),
        ],
        folderForFeaturesVal=folderForFeatures,

        # ~~~~~~~~ Advanced Validation Sampling ~~~~~~~~~~
        useDefaultUniformValidationSampling=configGet(trainConfig.DEFAULT_VAL_SAMPLING),
        samplingTypeValidation=configGet(trainConfig.TYPE_OF_SAMPLING_VAL),
        proportionOfSamplesPerCategoryVal=configGet(trainConfig.PROP_OF_SAMPLES_PER_CAT_VAL),
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal=listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal,

        # ====Optimization=====
        learningRate=configGet(trainConfig.LRATE),
        optimizerSgd0Adam1Rms2=configGet(trainConfig.OPTIMIZER),
        classicMom0Nesterov1=configGet(trainConfig.MOM_TYPE),
        momentumValue=configGet(trainConfig.MOM),
        momNonNormalized0Normalized1=configGet(trainConfig.MOM_NORM_NONNORM),
        # Adam
        b1Adam=configGet(trainConfig.B1_ADAM),
        b2Adam=configGet(trainConfig.B2_ADAM),
        eAdam=configGet(trainConfig.EPS_ADAM),
        # Rms
        rhoRms=configGet(trainConfig.RHO_RMS),
        eRms=configGet(trainConfig.EPS_RMS),
        # Regularization
        l1Reg=configGet(trainConfig.L1_REG),
        l2Reg=configGet(trainConfig.L2_REG),

        # ~~~~~~~ Freeze Layers ~~~~~~~
        layersToFreezePerPathwayType=[
            configGet(trainConfig.LAYERS_TO_FREEZE_NORM),
            configGet(trainConfig.LAYERS_TO_FREEZE_SUBS),
            configGet(trainConfig.LAYERS_TO_FREEZE_FC),
        ],

        # ==============Generic and Preprocessing===============
        padInputImagesBool=configGet(trainConfig.PAD_INPUT)
    )

    log = trainSessionParameters.sessionLogger.print3

    log("\n=========== NEW TRAINING SESSION ===============")
    trainSessionParameters.printParametersOfThisSession()

    log("\n=======================================================")
    log("=========== Compiling the Training Function ===========")
    log("=======================================================")
    # The training state is set up only when it is missing or a reset was requested.
    trainingStateUninitialized = not cnn3dInstance.checkTrainingStateAttributesInitialized()
    if trainingStateUninitialized or resetOptimizer:
        log("(Re)Initializing parameters for the optimization. "
            "Reason: Uninitialized: [" + str(trainingStateUninitialized)
            + "], Reset requested: [" + str(resetOptimizer) + "]")
        cnn3dInstance.initializeTrainingState(
            *trainSessionParameters.getTupleForInitializingTrainingState())
    cnn3dInstance.compileTrainFunction(
        *trainSessionParameters.getTupleForCompilationOfTrainFunc())

    log("\n=========== Compiling the Validation Function =========")
    cnn3dInstance.compileValidationFunction(
        *trainSessionParameters.getTupleForCompilationOfValFunc())

    log("\n=========== Compiling the Testing Function ============")
    # For validation with full segmentation
    cnn3dInstance.compileTestAndVisualisationFunction(
        *trainSessionParameters.getTupleForCompilationOfTestFunc())

    log("\n=======================================================")
    log("============== Training the CNN model =================")
    log("=======================================================")
    do_training(*trainSessionParameters.getTupleForCnnTraining())

    log("\n=======================================================")
    log("=========== Training session finished =================")
    log("=======================================================")
def deepMedicTestMain(testConfigFilepath, absPathToSavedModelFromCmdLine):
    """Run a testing (inference) session described by a test-configuration file.

    Steps, in order: execute the configuration file to populate a TestConfig,
    create the output folders and the session logger, load the CNN model,
    collect the per-case input listings and run inference on the whole
    volumes.

    Args:
        testConfigFilepath: Path to the test-configuration file. It is also
            handed to getAbsPathEvenIfRelativeIsGiven as the reference
            location for every path read from the configuration.
        absPathToSavedModelFromCmdLine: Path to a saved CNN model given on
            the command line, or a falsy value. When given, it takes
            precedence over the model path of the configuration file.
    """
    # print() call instead of the Python-2-only print statement.
    print("Given Test-Configuration File: ", testConfigFilepath)

    # Parse the config file in this naive fashion...
    # execfile() does not exist in Python 3: read the file (closing it
    # afterwards), compile it with its real filename so that tracebacks still
    # point at the configuration file, and execute it.
    # NOTE(security): the configuration file is executed as Python code, so
    # only trusted configuration files must be passed in.
    testConfig = TestConfig()
    with open(testConfigFilepath) as configFile:
        configSource = configFile.read()
    exec(compile(configSource, testConfigFilepath, "exec"), testConfig.configStruct)
    configGet = testConfig.get  # Main interface

    # NOTE: The validation of the configuration (required fields, listing-files,
    # optional parameters) is currently disabled, so nothing is checked here.

    # Create Folders and Logger
    mainOutputAbsFolder = getAbsPathEvenIfRelativeIsGiven(
        configGet(testConfig.FOLDER_FOR_OUTPUT), testConfigFilepath)
    sessionName = (configGet(testConfig.SESSION_NAME)
                   or TestSessionParameters.getDefaultSessionName())
    [folderForLogs,
     folderForPredictions,
     folderForFeatures] = makeFoldersNeededForTestingSession(mainOutputAbsFolder, sessionName)
    loggerFileName = folderForLogs + "/" + sessionName + ".txt"
    sessionLogger = myLoggerModule.MyLogger(loggerFileName)

    sessionLogger.print3(
        "CONFIG: The configuration file for the testing session was loaded from: "
        + str(testConfigFilepath))

    # Load the CNN Model!
    sessionLogger.print3("=========== Loading the CNN model for testing... ===============")
    # If CNN-Model was specified in command line, completely override the one in the config file.
    if absPathToSavedModelFromCmdLine:
        if configGet(testConfig.CNN_MODEL_FILEPATH):
            sessionLogger.print3(
                "WARN: A CNN-Model to use was specified both in the command line input and in the test-config-file! The input by the command line will be used: "
                + str(absPathToSavedModelFromCmdLine))
        filepathToCnnModelToLoad = absPathToSavedModelFromCmdLine
    else:
        filepathToCnnModelToLoad = getAbsPathEvenIfRelativeIsGiven(
            configGet(testConfig.CNN_MODEL_FILEPATH), testConfigFilepath)
    sessionLogger.print3("...Loading the network can take a few minutes if the model is big...")
    cnn3dInstance = load_object_from_gzip_file(filepathToCnnModelToLoad)
    sessionLogger.print3("The CNN model was loaded successfully from: "
                         + str(filepathToCnnModelToLoad))

    # NOTE: The final check of the configuration against the loaded model
    # (checkIfConfigIsCorrectForParticularCnnModel) is currently disabled.

    # Fill in the session's parameters.
    # [[case1-ch1, ..., caseN-ch1], [case1-ch2,...,caseN-ch2]]
    listOfAListPerChannelWithFilepathsOfAllCases = [
        parseAbsFileLinesInList(
            getAbsPathEvenIfRelativeIsGiven(channelConfPath, testConfigFilepath))
        for channelConfPath in configGet(testConfig.CHANNELS)
    ]
    # [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
    listWithAListPerCaseWithFilepathPerChannel = [
        list(item) for item in zip(*listOfAListPerChannelWithFilepathsOfAllCases)
    ]

    # Optional listing-files: each one is parsed only if given in the config.
    gtLabelsConf = configGet(testConfig.GT_LABELS)
    gtLabelsFilepaths = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(gtLabelsConf, testConfigFilepath)
    ) if gtLabelsConf else None

    roiMasksConf = configGet(testConfig.ROI_MASKS)
    roiMasksFilepaths = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(roiMasksConf, testConfigFilepath)
    ) if roiMasksConf else None

    # CAREFUL: Here we use a different parsing function!
    namesForPredConf = configGet(testConfig.NAMES_FOR_PRED_PER_CASE)
    namesToSavePredsAndFeats = parseFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(namesForPredConf, testConfigFilepath)
    ) if namesForPredConf else None

    testSessionParameters = TestSessionParameters(
        sessionName=sessionName,
        sessionLogger=sessionLogger,
        mainOutputAbsFolder=mainOutputAbsFolder,
        cnn3dInstance=cnn3dInstance,
        cnnModelFilepath=filepathToCnnModelToLoad,

        # Input:
        listWithAListPerCaseWithFilepathPerChannel=listWithAListPerCaseWithFilepathPerChannel,
        gtLabelsFilepaths=gtLabelsFilepaths,
        roiMasksFilepaths=roiMasksFilepaths,

        # Output
        namesToSavePredictionsAndFeatures=namesToSavePredsAndFeats,
        # predictions
        saveSegmentation=configGet(testConfig.SAVE_SEGM),
        saveProbMapsBoolPerClass=configGet(testConfig.SAVE_PROBMAPS_PER_CLASS),
        folderForPredictions=folderForPredictions,

        # features:
        saveIndividualFmImages=configGet(testConfig.SAVE_INDIV_FMS),
        saveMultidimensionalImageWithAllFms=configGet(testConfig.SAVE_4DIM_FMS),
        # One entry per pathway type, in the order: normal, subsampled, FC.
        indicesOfFmsToVisualisePerPathwayAndLayer=[
            configGet(testConfig.INDICES_OF_FMS_TO_SAVE_NORMAL),
            configGet(testConfig.INDICES_OF_FMS_TO_SAVE_SUBSAMPLED),
            configGet(testConfig.INDICES_OF_FMS_TO_SAVE_FC),
        ],
        folderForFeatures=folderForFeatures,

        padInputImagesBool=configGet(testConfig.PAD_INPUT),
    )

    log = testSessionParameters.sessionLogger.print3

    log("=========== NEW TESTING SESSION ===============")
    testSessionParameters.printParametersOfThisSession()

    # NOTE(review): no compilation step is performed on the model here before
    # inference; confirm that the loaded model does not need one.
    log("======================================================")
    log("=========== Testing with the CNN model ===============")
    log("======================================================")
    performInferenceForTestingOnWholeVolumes(
        *testSessionParameters.getTupleForCnnTesting())

    log("======================================================")
    log("=========== Testing session finished =================")
    log("======================================================")
def deepMedicNewModelMain(modelConfigFilepath,
                          absPathToPreTrainedModelGivenInCmdLine,
                          listOfLayersToTransfer):
    """Create a new CNN model as described by a model-configuration file and save it.

    Steps, in order: execute the configuration file to populate a
    ModelConfig, create the output folders and the session logger, build the
    Cnn3d model, optionally transfer parameters from a pre-trained model, and
    dump the resulting model to disk.

    Args:
        modelConfigFilepath: Path to the model-configuration file. It is also
            handed to getAbsPathEvenIfRelativeIsGiven as the reference
            location for the output folder read from the configuration.
        absPathToPreTrainedModelGivenInCmdLine: Path to a previously saved
            model whose parameters are transferred to the new one, or None
            to skip the transfer.
        listOfLayersToTransfer: Passed through to
            cnnTransferParameters.transferParametersBetweenModels; only used
            when a pre-trained model is given.

    Returns:
        A tuple (cnn3dInstance, filenameAndPathWhereModelWasSaved), the
        second element being the value returned by
        dump_cnn_to_gzip_file_dotSave.
    """
    print("Given Model-Configuration File: ", modelConfigFilepath)

    # Parse the config file in this naive fashion...
    # NOTE(security): the configuration file is executed as Python code, so
    # only trusted configuration files must be passed in.
    modelConfig = ModelConfig()
    with open(modelConfigFilepath) as configFile:  # Make sure the handle gets closed.
        configSource = configFile.read()
    exec(configSource, modelConfig.configStruct)
    configGet = modelConfig.get  # Main interface

    # Create Folders and Logger
    mainOutputAbsFolder = getAbsPathEvenIfRelativeIsGiven(
        configGet(modelConfig.FOLDER_FOR_OUTPUT), modelConfigFilepath)
    modelName = (configGet(modelConfig.MODEL_NAME)
                 or CreateModelSessionParameters.getDefaultModelName())
    [folderForCnnModels,
     folderForLogs] = makeFoldersNeededForCreateModelSession(mainOutputAbsFolder, modelName)
    loggerFileName = folderForLogs + "/" + modelName + ".txt"
    sessionLogger = myLoggerModule.MyLogger(loggerFileName)

    sessionLogger.print3(
        "CONFIG: The configuration file for the model-creation session was loaded from: "
        + str(modelConfigFilepath))

    # Fill in the session's parameters.
    createModelSessionParameters = CreateModelSessionParameters(
        cnnModelName=modelName,
        sessionLogger=sessionLogger,
        mainOutputAbsFolder=mainOutputAbsFolder,
        folderForSessionCnnModels=folderForCnnModels,
        # ===MODEL PARAMETERS===
        numberClasses=configGet(modelConfig.NUMB_CLASSES),
        numberOfInputChannelsNormal=configGet(modelConfig.NUMB_INPUT_CHANNELS_NORMAL),
        # ===Normal pathway===
        numFMsNormal=configGet(modelConfig.N_FMS_NORM),
        kernDimNormal=configGet(modelConfig.KERN_DIM_NORM),
        residConnAtLayersNormal=configGet(ModelConfig.RESID_CONN_LAYERS_NORM),
        lowerRankLayersNormal=configGet(ModelConfig.LOWER_RANK_LAYERS_NORM),
        # ==Subsampled pathway==
        useSubsampledBool=configGet(modelConfig.USE_SUBSAMPLED),
        numFMsSubsampled=configGet(modelConfig.N_FMS_SUBS),
        kernDimSubsampled=configGet(modelConfig.KERN_DIM_SUBS),
        subsampleFactor=configGet(modelConfig.SUBS_FACTOR),
        residConnAtLayersSubsampled=configGet(ModelConfig.RESID_CONN_LAYERS_SUBS),
        lowerRankLayersSubsampled=configGet(ModelConfig.LOWER_RANK_LAYERS_SUBS),
        # ==FC Layers====
        numFMsFc=configGet(modelConfig.N_FMS_FC),
        kernelDimensionsFirstFcLayer=configGet(modelConfig.KERN_DIM_1ST_FC),
        residConnAtLayersFc=configGet(ModelConfig.RESID_CONN_LAYERS_FC),
        # ==Size of Image Segments ==
        segmDimTrain=configGet(modelConfig.SEG_DIM_TRAIN),
        segmDimVal=configGet(modelConfig.SEG_DIM_VAL),
        segmDimInfer=configGet(modelConfig.SEG_DIM_INFERENCE),
        # == Batch Sizes ==
        batchSizeTrain=configGet(modelConfig.BATCH_SIZE_TR),
        batchSizeVal=configGet(modelConfig.BATCH_SIZE_VAL),
        batchSizeInfer=configGet(modelConfig.BATCH_SIZE_INFER),
        # ===Other Architectural Parameters ===
        activationFunction=configGet(modelConfig.ACTIV_FUNCTION),
        # ==Dropout Rates==
        dropNormal=configGet(modelConfig.DROP_R_NORM),
        dropSubsampled=configGet(modelConfig.DROP_R_SUBS),
        dropFc=configGet(modelConfig.DROP_R_FC),
        # == Weight Initialization==
        initialMethod=configGet(modelConfig.INITIAL_METHOD),
        # == Batch Normalization ==
        bnRollingAverOverThatManyBatches=configGet(modelConfig.BN_ROLL_AV_BATCHES),
    )

    log = createModelSessionParameters.sessionLogger.print3

    log("\n=========== NEW CREATE-MODEL SESSION ============")
    createModelSessionParameters.printParametersOfThisSession()

    log("\n=========== Creating the CNN model ===============")
    cnn3dInstance = Cnn3d()
    cnn3dInstance.make_cnn_model(*createModelSessionParameters.getTupleForCnnCreation())

    # Identity check against None (rather than "!= None"), evaluated once and
    # reused for both the transfer step and the name of the saved file.
    usePretrainedModel = absPathToPreTrainedModelGivenInCmdLine is not None

    if usePretrainedModel:
        # Transfer parameters from a previously trained model to the new one.
        log("\n=========== Pre-training the new model ===============")
        sessionLogger.print3(
            "...Loading the pre-trained network. This can take a few minutes if the model is big...")
        cnnPretrainedInstance = load_object_from_gzip_file(
            absPathToPreTrainedModelGivenInCmdLine)
        sessionLogger.print3("The pre-trained model was loaded successfully from: "
                             + str(absPathToPreTrainedModelGivenInCmdLine))
        # Imported only when a transfer is actually requested.
        from deepmedic import cnnTransferParameters
        cnn3dInstance = cnnTransferParameters.transferParametersBetweenModels(
            sessionLogger, cnn3dInstance, cnnPretrainedInstance, listOfLayersToTransfer)

    log("\n=========== Saving the model ===============")
    # The filename records whether the initial model carries transferred parameters.
    initialModelSuffix = ".initial.pretrained." if usePretrainedModel else ".initial."
    filenameAndPathToSaveModel = (
        createModelSessionParameters.getPathAndFilenameToSaveModel()
        + initialModelSuffix
        + datetimeNowAsStr())
    filenameAndPathWhereModelWasSaved = dump_cnn_to_gzip_file_dotSave(
        cnn3dInstance, filenameAndPathToSaveModel, sessionLogger)

    log("=========== Creation of the model: \""
        + str(createModelSessionParameters.cnnModelName)
        + "\" finished =================")

    return (cnn3dInstance, filenameAndPathWhereModelWasSaved)
def _parseListingFileIfGiven(listingFilepath, configFilepath, parseLines):
    """Parse one listing file with `parseLines`, or return None if no file was configured.

    `listingFilepath` may be relative; it is resolved with
    getAbsPathEvenIfRelativeIsGiven(listingFilepath, configFilepath).
    """
    if not listingFilepath:
        return None
    return parseLines(getAbsPathEvenIfRelativeIsGiven(listingFilepath, configFilepath))


def _parseListingFilesIfGiven(listingFilepaths, configFilepath):
    """Parse several listing files with parseAbsFileLinesInList.

    Returns a list holding one parsed result per listing file (same order as
    `listingFilepaths`), or None if `listingFilepaths` is empty / not configured.
    """
    if not listingFilepaths:
        return None
    return [_parseListingFileIfGiven(listingFilepath, configFilepath, parseAbsFileLinesInList)
            for listingFilepath in listingFilepaths]


def _regroupPerCase(listOfAListPerChannel):
    """Transpose [[case1-ch1, ..., caseN-ch1], [case1-ch2, ..., caseN-ch2]]
    into [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]. None passes through.
    """
    if listOfAListPerChannel is None:
        return None
    return [list(filepathsOfCase) for filepathsOfCase in zip(*listOfAListPerChannel)]


def deepMedicTrainMain(trainConfigFilepath, absPathToSavedModelFromCmdLine, cnnInstancePreLoaded,
                       filenameAndPathWherePreLoadedModelWas, resetOptimizer):
    """Run a complete training session described by a training-configuration file.

    Steps: execute the config file, create the session folders and logger, obtain
    the CNN model (pre-loaded instance, command-line path or config path), collect
    the per-case file listings, build TrainSessionParameters, compile the train /
    validation / test functions and finally call do_training.

    Parameters:
        trainConfigFilepath: path of the training-config file. It is executed as
            Python code to populate trainConfig.configStruct.
        absPathToSavedModelFromCmdLine: path of a saved model given on the command
            line. When given, it overrides CNN_MODEL_FILEPATH from the config file.
        cnnInstancePreLoaded: an already constructed CNN instance. When truthy it
            is used as is and no model is loaded from disk.
        filenameAndPathWherePreLoadedModelWas: recorded as the model filepath when
            cnnInstancePreLoaded is used.
        resetOptimizer: when truthy, the training state of the model is
            (re)initialized even if it already was initialized.

    Returns:
        None.
    """
    print("Given Training-Configuration File: ", trainConfigFilepath)
    # Parse the config file in this naive fashion...
    # NOTE(security): the config file is executed as Python code; only use trusted files.
    trainConfig = TrainConfig()
    with open(trainConfigFilepath) as configFile:
        configSource = configFile.read()
    # Compile with the real filename so tracebacks point at the config file.
    exec(compile(configSource, trainConfigFilepath, "exec"), trainConfig.configStruct)
    configGet = trainConfig.get  # Main interface

    # NOTE: validation of the config is currently disabled. The checks that used to run here:
    #   checkIfMainTestConfigIsCorrect(testConfig, testConfigFilepath, absPathToSavedModelFromCmdLine)  # REQUIRED fields.
    #   checkIfFilesThatListFilesPerCaseAreCorrect(testConfig, testConfigFilepath)  # listing-files (whatever given).
    #   checkIfOptionalParametersAreGivenCorrectly(testConfig, testConfigFilepath)

    # Create Folders and Logger
    mainOutputAbsFolder = getAbsPathEvenIfRelativeIsGiven(
        configGet(trainConfig.FOLDER_FOR_OUTPUT), trainConfigFilepath)
    sessionName = configGet(trainConfig.SESSION_NAME) or TrainSessionParameters.getDefaultSessionName()
    [folderForSessionCnnModels,
     folderForLogs,
     folderForPredictions,
     folderForFeatures] = makeFoldersNeededForTrainingSession(mainOutputAbsFolder, sessionName)
    loggerFileName = folderForLogs + "/" + sessionName + ".txt"
    sessionLogger = myLoggerModule.MyLogger(loggerFileName)

    sessionLogger.print3(
        "CONFIG: The configuration file for the training session was loaded from: " + str(trainConfigFilepath))

    # Load the cnn Model if not given straight from a createCnnModel-session...
    modelFilepathInConfig = configGet(trainConfig.CNN_MODEL_FILEPATH)
    if cnnInstancePreLoaded:
        sessionLogger.print3("====== CNN-Instance already loaded (from Create-New-Model procedure) =======")
        if modelFilepathInConfig:
            sessionLogger.print3(
                "WARN: Any paths to cnn-models specified in the train-config file will be ommitted! Working with the pre-loaded model!")
        cnn3dInstance = cnnInstancePreLoaded
        filepathToCnnModel = filenameAndPathWherePreLoadedModelWas
    else:
        sessionLogger.print3("=========== Loading the CNN model for training... ===============")
        # If CNN-Model was specified in command line, completely override the one in the config file.
        if absPathToSavedModelFromCmdLine and modelFilepathInConfig:
            sessionLogger.print3(
                "WARN: A CNN-Model to use was specified both in the command line input and in the train-config-file! The input by the command line will be used: "
                + str(absPathToSavedModelFromCmdLine))
            filepathToCnnModel = absPathToSavedModelFromCmdLine
        elif absPathToSavedModelFromCmdLine:
            # Model given only on the command line: use it directly instead of
            # trying to resolve a config entry that was never set.
            filepathToCnnModel = absPathToSavedModelFromCmdLine
        else:
            filepathToCnnModel = getAbsPathEvenIfRelativeIsGiven(modelFilepathInConfig, trainConfigFilepath)
        sessionLogger.print3("...Loading the network can take a few minutes if the model is big...")
        cnn3dInstance = load_object_from_gzip_file(filepathToCnnModel)
        sessionLogger.print3("The CNN model was loaded successfully from: " + str(filepathToCnnModel))

    # NOTE: the final check of the config against the loaded model is currently disabled
    # (SAVE_PROBMAPS_PER_CLASS, INDICES_OF_FMS_TO_SAVE, Number of Channels):
    #   trainConfig.checkIfConfigIsCorrectForParticularCnnModel(cnn3dInstance)

    # Fill in the session's parameters.
    # =======TRAINING==========
    # [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]], or None if no channels were configured.
    listWithAListPerCaseWithFilepathPerChannelTrain = _regroupPerCase(
        _parseListingFilesIfGiven(configGet(trainConfig.CHANNELS_TR), trainConfigFilepath))
    gtLabelsFilepathsTrain = _parseListingFileIfGiven(
        configGet(trainConfig.GT_LABELS_TR), trainConfigFilepath, parseAbsFileLinesInList)
    roiMasksFilepathsTrain = _parseListingFileIfGiven(
        configGet(trainConfig.ROI_MASKS_TR), trainConfigFilepath, parseAbsFileLinesInList)
    # ~~~~~ Advanced Training Sampling ~~~~~
    # [[case1-weightMap1, ..., caseN-weightMap1], [case1-weightMap2, ..., caseN-weightMap2]]
    listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain = _parseListingFilesIfGiven(
        configGet(trainConfig.WEIGHT_MAPS_PER_CAT_FILEPATHS_TR), trainConfigFilepath)

    # =======VALIDATION==========
    listWithAListPerCaseWithFilepathPerChannelVal = _regroupPerCase(
        _parseListingFilesIfGiven(configGet(trainConfig.CHANNELS_VAL), trainConfigFilepath))
    gtLabelsFilepathsVal = _parseListingFileIfGiven(
        configGet(trainConfig.GT_LABELS_VAL), trainConfigFilepath, parseAbsFileLinesInList)
    roiMasksFilepathsVal = _parseListingFileIfGiven(
        configGet(trainConfig.ROI_MASKS_VAL), trainConfigFilepath, parseAbsFileLinesInList)
    # ~~~~~ Full Inference ~~~~~
    # CAREFUL: Here we use a different parsing function!
    namesToSavePredsAndFeatsVal = _parseListingFileIfGiven(
        configGet(trainConfig.NAMES_FOR_PRED_PER_CASE_VAL), trainConfigFilepath, parseFileLinesInList)
    # ~~~~~ Advanced Validation Sampling ~~~~~
    # [[case1-weightMap1, ..., caseN-weightMap1], [case1-weightMap2, ..., caseN-weightMap2]]
    listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal = _parseListingFilesIfGiven(
        configGet(trainConfig.WEIGHT_MAPS_PER_CAT_FILEPATHS_VAL), trainConfigFilepath)

    trainSessionParameters = TrainSessionParameters(
        sessionName=sessionName,
        sessionLogger=sessionLogger,
        mainOutputAbsFolder=mainOutputAbsFolder,
        cnn3dInstance=cnn3dInstance,
        cnnModelFilepath=filepathToCnnModel,

        # ==================TRAINING====================
        folderForSessionCnnModels=folderForSessionCnnModels,
        listWithAListPerCaseWithFilepathPerChannelTrain=listWithAListPerCaseWithFilepathPerChannelTrain,
        gtLabelsFilepathsTrain=gtLabelsFilepathsTrain,

        # [Optionals]
        # ~~~~~~~~~Sampling~~~~~~~~~
        roiMasksFilepathsTrain=roiMasksFilepathsTrain,
        percentOfSamplesToExtractPositTrain=configGet(trainConfig.PERC_POS_SAMPLES_TR),
        # ~~~~~~~~~Advanced Sampling~~~~~~~
        useDefaultTrainingSamplingFromGtAndRoi=configGet(trainConfig.DEFAULT_TR_SAMPLING),
        samplingTypeTraining=configGet(trainConfig.TYPE_OF_SAMPLING_TR),
        proportionOfSamplesPerCategoryTrain=configGet(trainConfig.PROP_OF_SAMPLES_PER_CAT_TR),
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain=listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain,

        # ~~~~~~~~Training Cycle ~~~~~~~
        numberOfEpochs=configGet(trainConfig.NUM_EPOCHS),
        numberOfSubepochs=configGet(trainConfig.NUM_SUBEP),
        numOfCasesLoadedPerSubepoch=configGet(trainConfig.NUM_CASES_LOADED_PERSUB),
        segmentsLoadedOnGpuPerSubepochTrain=configGet(trainConfig.NUM_TR_SEGMS_LOADED_PERSUB),

        # ~~~~~~~ Learning Rate Schedule ~~~~~~~
        # Auto requires performValidationOnSamplesThroughoutTraining and providedGtForValidationBool
        stable0orAuto1orPredefined2orExponential3LrSchedule=configGet(trainConfig.LR_SCH_0123),
        # Stable + Auto + Predefined.
        whenDecreasingDivideLrBy=configGet(trainConfig.DIV_LR_BY),
        # Stable + Auto
        numEpochsToWaitBeforeLoweringLr=configGet(trainConfig.NUM_EPOCHS_WAIT),
        # Auto:
        minIncreaseInValidationAccuracyThatResetsWaiting=configGet(trainConfig.AUTO_MIN_INCR_VAL_ACC),
        # Predefined.
        predefinedSchedule=configGet(trainConfig.PREDEF_SCH),
        # Exponential
        exponentialSchedForLrAndMom=configGet(trainConfig.EXPON_SCH),

        # ~~~~~~~ Augmentation~~~~~~~~~~~~
        reflectImagesPerAxis=configGet(trainConfig.REFL_AUGM_PER_AXIS),
        performIntAugm=configGet(trainConfig.PERF_INT_AUGM_BOOL),
        sampleIntAugmShiftWithMuAndStd=configGet(trainConfig.INT_AUGM_SHIF_MUSTD),
        sampleIntAugmMultiWithMuAndStd=configGet(trainConfig.INT_AUGM_MULT_MUSTD),

        # ==================VALIDATION=====================
        performValidationOnSamplesThroughoutTraining=configGet(trainConfig.PERFORM_VAL_SAMPLES),
        performFullInferenceOnValidationImagesEveryFewEpochs=configGet(trainConfig.PERFORM_VAL_INFERENCE),
        # Required:
        listWithAListPerCaseWithFilepathPerChannelVal=listWithAListPerCaseWithFilepathPerChannelVal,
        gtLabelsFilepathsVal=gtLabelsFilepathsVal,

        segmentsLoadedOnGpuPerSubepochVal=configGet(trainConfig.NUM_VAL_SEGMS_LOADED_PERSUB),

        # [Optionals]
        # For default sampling and for fast inference. Optional. Otherwise from whole image.
        roiMasksFilepathsVal=roiMasksFilepathsVal,

        # ~~~~~~~~Full Inference~~~~~~~~
        numberOfEpochsBetweenFullInferenceOnValImages=configGet(trainConfig.NUM_EPOCHS_BETWEEN_VAL_INF),

        # Output
        namesToSavePredictionsAndFeaturesVal=namesToSavePredsAndFeatsVal,
        # predictions
        saveSegmentationVal=configGet(trainConfig.SAVE_SEGM_VAL),
        saveProbMapsBoolPerClassVal=configGet(trainConfig.SAVE_PROBMAPS_PER_CLASS_VAL),
        folderForPredictionsVal=folderForPredictions,
        # features:
        saveIndividualFmImagesVal=configGet(trainConfig.SAVE_INDIV_FMS_VAL),
        saveMultidimensionalImageWithAllFmsVal=configGet(trainConfig.SAVE_4DIM_FMS_VAL),
        # One entry per pathway type, in the order: normal, subsampled, FC.
        indicesOfFmsToVisualisePerPathwayAndLayerVal=[
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_NORMAL_VAL),
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_SUBSAMPLED_VAL),
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_FC_VAL),
        ],
        folderForFeaturesVal=folderForFeatures,

        # ~~~~~~~~ Advanced Validation Sampling ~~~~~~~~~~
        useDefaultUniformValidationSampling=configGet(trainConfig.DEFAULT_VAL_SAMPLING),
        samplingTypeValidation=configGet(trainConfig.TYPE_OF_SAMPLING_VAL),
        proportionOfSamplesPerCategoryVal=configGet(trainConfig.PROP_OF_SAMPLES_PER_CAT_VAL),
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal=listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal,

        # ====Optimization=====
        learningRate=configGet(trainConfig.LRATE),
        optimizerSgd0Adam1Rms2=configGet(trainConfig.OPTIMIZER),
        classicMom0Nesterov1=configGet(trainConfig.MOM_TYPE),
        momentumValue=configGet(trainConfig.MOM),
        momNonNormalized0Normalized1=configGet(trainConfig.MOM_NORM_NONNORM),
        # Adam
        b1Adam=configGet(trainConfig.B1_ADAM),
        b2Adam=configGet(trainConfig.B2_ADAM),
        eAdam=configGet(trainConfig.EPS_ADAM),
        # Rms
        rhoRms=configGet(trainConfig.RHO_RMS),
        eRms=configGet(trainConfig.EPS_RMS),
        # Regularization
        l1Reg=configGet(trainConfig.L1_REG),
        l2Reg=configGet(trainConfig.L2_REG),

        # ~~~~~~~ Freeze Layers ~~~~~~~
        layersToFreezePerPathwayType=[
            configGet(trainConfig.LAYERS_TO_FREEZE_NORM),
            configGet(trainConfig.LAYERS_TO_FREEZE_SUBS),
            configGet(trainConfig.LAYERS_TO_FREEZE_FC),
        ],

        # ==============Generic and Preprocessing===============
        padInputImagesBool=configGet(trainConfig.PAD_INPUT),
    )

    logPrint = trainSessionParameters.sessionLogger.print3
    logPrint("\n=========== NEW TRAINING SESSION ===============")
    trainSessionParameters.printParametersOfThisSession()

    logPrint("\n=======================================================")
    logPrint("=========== Compiling the Training Function ===========")
    logPrint("=======================================================")
    trainingStateUninitialized = not cnn3dInstance.checkTrainingStateAttributesInitialized()
    if trainingStateUninitialized or resetOptimizer:
        logPrint("(Re)Initializing parameters for the optimization. "
                 "Reason: Uninitialized: [" + str(trainingStateUninitialized)
                 + "], Reset requested: [" + str(resetOptimizer) + "]")
        cnn3dInstance.initializeTrainingState(*trainSessionParameters.getTupleForInitializingTrainingState())
    cnn3dInstance.compileTrainFunction(*trainSessionParameters.getTupleForCompilationOfTrainFunc())

    logPrint("\n=========== Compiling the Validation Function =========")
    cnn3dInstance.compileValidationFunction(*trainSessionParameters.getTupleForCompilationOfValFunc())

    logPrint("\n=========== Compiling the Testing Function ============")
    # For validation with full segmentation
    cnn3dInstance.compileTestAndVisualisationFunction(*trainSessionParameters.getTupleForCompilationOfTestFunc())

    logPrint("\n=======================================================")
    logPrint("============== Training the CNN model =================")
    logPrint("=======================================================")
    do_training(*trainSessionParameters.getTupleForCnnTraining())
    logPrint("\n=======================================================")
    logPrint("=========== Training session finished =================")
    logPrint("=======================================================")