示例#1
0
def deepMedicTestMain(testConfigFilepath, absPathToSavedModelFromCmdLine):
    """Run a testing (inference) session driven by a test-configuration file.

    The configuration file is executed as Python code and the names it
    defines are collected in ``TestConfig.configStruct``. The function then
    creates the output folders and the session logger, loads the saved CNN
    model, builds a ``TestSessionParameters`` object, compiles the test
    function of the model and runs inference on the whole volumes.

    Args:
        testConfigFilepath: Path to the test-configuration file. Relative
            paths found inside the configuration are resolved through
            ``getAbsPathEvenIfRelativeIsGiven`` with this path as reference.
        absPathToSavedModelFromCmdLine: Path to a saved model given on the
            command line. When truthy it overrides the model path specified
            in the configuration file.

    Returns:
        None. Results are produced by the called session routines.
    """
    print("Given Test-Configuration File: ", testConfigFilepath)

    # Parse the config file in this naive fashion...
    # NOTE(security): the configuration file is executed as arbitrary Python
    # code. Only run configuration files that come from a trusted source.
    testConfig = TestConfig()
    # Read through a context manager so the file handle is closed promptly,
    # and compile with the real filename so that errors inside the
    # configuration are reported against that file instead of "<string>".
    with open(testConfigFilepath) as configFile:
        configSource = configFile.read()
    exec(compile(configSource, str(testConfigFilepath), "exec"),
         testConfig.configStruct)
    configGet = testConfig.get  # Main interface

    # Validation of the configuration is currently disabled:
    #checkIfMainTestConfigIsCorrect(testConfig, testConfigFilepath, absPathToSavedModelFromCmdLine) #Checks REQUIRED fields are complete.
    #checkIfFilesThatListFilesPerCaseAreCorrect(testConfig, testConfigFilepath) #Checks listing-files (whatever given).
    #checkIfOptionalParametersAreGivenCorrectly(testConfig, testConfigFilepath)

    # Create Folders and Logger
    mainOutputAbsFolder = getAbsPathEvenIfRelativeIsGiven(
        configGet(testConfig.FOLDER_FOR_OUTPUT), testConfigFilepath)
    # Fall back to the default session name when none (or an empty one) is configured.
    sessionName = (configGet(testConfig.SESSION_NAME)
                   or TestSessionParameters.getDefaultSessionName())
    [folderForLogs, folderForPredictions, folderForFeatures
     ] = makeFoldersNeededForTestingSession(mainOutputAbsFolder, sessionName)
    loggerFileName = folderForLogs + "/" + sessionName + ".txt"
    sessionLogger = myLoggerModule.MyLogger(loggerFileName)

    sessionLogger.print3(
        "CONFIG: The configuration file for the testing session was loaded from: "
        + str(testConfigFilepath))

    # Load the CNN Model!
    sessionLogger.print3(
        "=========== Loading the CNN model for testing... ===============")
    # If CNN-Model was specified in command line, completely override the one in the config file.
    if absPathToSavedModelFromCmdLine:
        if configGet(testConfig.CNN_MODEL_FILEPATH):
            sessionLogger.print3(
                "WARN: A CNN-Model to use was specified both in the command line input and in the test-config-file! The input by the command line will be used: "
                + str(absPathToSavedModelFromCmdLine))
        filepathToCnnModelToLoad = absPathToSavedModelFromCmdLine
    else:
        filepathToCnnModelToLoad = getAbsPathEvenIfRelativeIsGiven(
            configGet(testConfig.CNN_MODEL_FILEPATH), testConfigFilepath)
    sessionLogger.print3(
        "...Loading the network can take a few minutes if the model is big...")
    cnn3dInstance = load_object_from_gzip_file(filepathToCnnModelToLoad)
    sessionLogger.print3("The CNN model was loaded successfully from: " +
                         str(filepathToCnnModelToLoad))
    # Final checks against the loaded model are currently disabled:
    #testConfig.checkIfConfigIsCorrectForParticularCnnModel(cnn3dInstance)

    # Fill in the session's parameters.
    # One listing-file per channel: [[case1-ch1, ..., caseN-ch1], [case1-ch2,...,caseN-ch2]]
    listOfAListPerChannelWithFilepathsOfAllCases = [
        parseAbsFileLinesInList(
            getAbsPathEvenIfRelativeIsGiven(channelConfPath,
                                            testConfigFilepath))
        for channelConfPath in configGet(testConfig.CHANNELS)
    ]
    # Transposed to one entry per case: [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
    listWithAListPerCaseWithFilepathPerChannel = [
        list(filepathsOfOneCase) for filepathsOfOneCase in zip(
            *listOfAListPerChannelWithFilepathsOfAllCases)
    ]

    # Optional listing-files: each stays None when not given in the config.
    gtLabelsConfPath = configGet(testConfig.GT_LABELS)
    gtLabelsFilepaths = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(
            gtLabelsConfPath, testConfigFilepath)) if gtLabelsConfPath else None

    roiMasksConfPath = configGet(testConfig.ROI_MASKS)
    roiMasksFilepaths = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(
            roiMasksConfPath, testConfigFilepath)) if roiMasksConfPath else None

    # CAREFUL: Here we use a different parsing function!
    namesForPredConfPath = configGet(testConfig.NAMES_FOR_PRED_PER_CASE)
    namesToSavePredsAndFeats = parseFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(
            namesForPredConfPath,
            testConfigFilepath)) if namesForPredConfPath else None

    testSessionParameters = TestSessionParameters(
        sessionName=sessionName,
        sessionLogger=sessionLogger,
        mainOutputAbsFolder=mainOutputAbsFolder,
        cnn3dInstance=cnn3dInstance,
        cnnModelFilepath=filepathToCnnModelToLoad,

        # Input:
        listWithAListPerCaseWithFilepathPerChannel=listWithAListPerCaseWithFilepathPerChannel,
        gtLabelsFilepaths=gtLabelsFilepaths,
        roiMasksFilepaths=roiMasksFilepaths,

        # Output
        namesToSavePredictionsAndFeatures=namesToSavePredsAndFeats,
        # predictions
        saveSegmentation=configGet(testConfig.SAVE_SEGM),
        saveProbMapsBoolPerClass=configGet(testConfig.SAVE_PROBMAPS_PER_CLASS),
        folderForPredictions=folderForPredictions,

        # features:
        saveIndividualFmImages=configGet(testConfig.SAVE_INDIV_FMS),
        saveMultidimensionalImageWithAllFms=configGet(testConfig.SAVE_4DIM_FMS),
        # One entry per pathway type, in the order: normal, subsampled, FC.
        indicesOfFmsToVisualisePerPathwayAndLayer=[
            configGet(testConfig.INDICES_OF_FMS_TO_SAVE_NORMAL),
            configGet(testConfig.INDICES_OF_FMS_TO_SAVE_SUBSAMPLED),
            configGet(testConfig.INDICES_OF_FMS_TO_SAVE_FC),
        ],
        folderForFeatures=folderForFeatures,

        padInputImagesBool=configGet(testConfig.PAD_INPUT),
    )

    logLine = testSessionParameters.sessionLogger.print3

    logLine("\n===========       NEW TESTING SESSION         ===============")
    testSessionParameters.printParametersOfThisSession()

    logLine("\n=======================================================")
    logLine("=========== Compiling the Testing Function ============")
    logLine("=======================================================")
    cnn3dInstance.compileTestAndVisualisationFunction(
        *testSessionParameters.getTupleForCompilationOfTestFunc())

    logLine("\n======================================================")
    logLine("=========== Testing with the CNN model ===============")
    logLine("======================================================")
    performInferenceForTestingOnWholeVolumes(
        *testSessionParameters.getTupleForCnnTesting())
    logLine("\n======================================================")
    logLine("=========== Testing session finished =================")
    logLine("======================================================")
示例#2
0
def deepMedicTrainMain(trainConfigFilepath, absPathToSavedModelFromCmdLine,
                       cnnInstancePreLoaded,
                       filenameAndPathWherePreLoadedModelWas, resetOptimizer):
    """Run a training session driven by a training-configuration file.

    The configuration file is executed as Python code and the names it
    defines are collected in ``TrainConfig.configStruct``. The function then
    creates the output folders and the session logger, obtains the CNN model
    (pre-loaded, or loaded from disk), builds a ``TrainSessionParameters``
    object, compiles the train / validation / test functions of the model and
    finally starts the training.

    Args:
        trainConfigFilepath: Path to the training-configuration file.
            Relative paths found inside the configuration are resolved
            through ``getAbsPathEvenIfRelativeIsGiven`` with this path as
            reference.
        absPathToSavedModelFromCmdLine: Path to a saved model given on the
            command line. When truthy (and no model is pre-loaded) it
            overrides the model path specified in the configuration file.
        cnnInstancePreLoaded: An already constructed model instance. When
            truthy, no model is loaded from disk.
        filenameAndPathWherePreLoadedModelWas: Filepath recorded as the
            model's filepath when ``cnnInstancePreLoaded`` is used.
        resetOptimizer: When truthy, the training state of the model is
            (re)initialized even if it was already initialized.

    Returns:
        None. Results are produced by the called session routines.
    """
    print("Given Training-Configuration File: ", trainConfigFilepath)

    # Parse the config file in this naive fashion...
    # NOTE(security): the configuration file is executed as arbitrary Python
    # code. Only run configuration files that come from a trusted source.
    trainConfig = TrainConfig()
    # Read through a context manager so the file handle is closed promptly,
    # and compile with the real filename so that errors inside the
    # configuration are reported against that file instead of "<string>".
    with open(trainConfigFilepath) as configFile:
        configSource = configFile.read()
    exec(compile(configSource, str(trainConfigFilepath), "exec"),
         trainConfig.configStruct)
    configGet = trainConfig.get  # Main interface

    # Validation of the configuration is currently disabled (no checks are
    # performed on required fields, listing-files or optional parameters).

    # Create Folders and Logger
    mainOutputAbsFolder = getAbsPathEvenIfRelativeIsGiven(
        configGet(trainConfig.FOLDER_FOR_OUTPUT), trainConfigFilepath)
    # Fall back to the default session name when none (or an empty one) is configured.
    sessionName = (configGet(trainConfig.SESSION_NAME)
                   or TrainSessionParameters.getDefaultSessionName())
    [
        folderForSessionCnnModels, folderForLogs, folderForPredictions,
        folderForFeatures
    ] = makeFoldersNeededForTrainingSession(mainOutputAbsFolder, sessionName)
    loggerFileName = folderForLogs + "/" + sessionName + ".txt"
    sessionLogger = myLoggerModule.MyLogger(loggerFileName)

    sessionLogger.print3(
        "CONFIG: The configuration file for the training session was loaded from: "
        + str(trainConfigFilepath))

    # Load the cnn Model if not given straight from a createCnnModel-session...
    if cnnInstancePreLoaded:
        sessionLogger.print3(
            "====== CNN-Instance already loaded (from Create-New-Model procedure) ======="
        )
        if configGet(trainConfig.CNN_MODEL_FILEPATH):
            sessionLogger.print3(
                "WARN: Any paths to cnn-models specified in the train-config file will be ommitted! Working with the pre-loaded model!"
            )
        cnn3dInstance = cnnInstancePreLoaded
        filepathToCnnModel = filenameAndPathWherePreLoadedModelWas
    else:
        sessionLogger.print3(
            "=========== Loading the CNN model for training... ==============="
        )
        # If CNN-Model was specified in command line, completely override the one in the config file.
        # The command-line path must win even when the config file names no
        # model at all; previously that case fell through to the (empty)
        # config entry and the command-line model was ignored.
        if absPathToSavedModelFromCmdLine:
            if configGet(trainConfig.CNN_MODEL_FILEPATH):
                sessionLogger.print3(
                    "WARN: A CNN-Model to use was specified both in the command line input and in the train-config-file! The input by the command line will be used: "
                    + str(absPathToSavedModelFromCmdLine))
            filepathToCnnModel = absPathToSavedModelFromCmdLine
        else:
            filepathToCnnModel = getAbsPathEvenIfRelativeIsGiven(
                configGet(trainConfig.CNN_MODEL_FILEPATH), trainConfigFilepath)
        sessionLogger.print3(
            "...Loading the network can take a few minutes if the model is big..."
        )
        cnn3dInstance = load_object_from_gzip_file(filepathToCnnModel)
        sessionLogger.print3("The CNN model was loaded successfully from: " +
                             str(filepathToCnnModel))

    # Final checks of the configuration against the loaded model are
    # currently disabled.

    # Fill in the session's parameters.
    # =======TRAINING==========
    channelsConfPathsTrain = configGet(trainConfig.CHANNELS_TR)
    if channelsConfPathsTrain:
        # One listing-file per channel: [[case1-ch1, ..., caseN-ch1], [case1-ch2,...,caseN-ch2]]
        listOfAListPerChannelWithFilepathsOfAllCasesTrain = [
            parseAbsFileLinesInList(
                getAbsPathEvenIfRelativeIsGiven(channelConfPath,
                                                trainConfigFilepath))
            for channelConfPath in channelsConfPathsTrain
        ]
        # Transposed to one entry per case: [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
        listWithAListPerCaseWithFilepathPerChannelTrain = [
            list(filepathsOfOneCase) for filepathsOfOneCase in zip(
                *listOfAListPerChannelWithFilepathsOfAllCasesTrain)
        ]
    else:
        listWithAListPerCaseWithFilepathPerChannelTrain = None

    # Optional listing-files: each stays None when not given in the config.
    gtLabelsConfPathTrain = configGet(trainConfig.GT_LABELS_TR)
    gtLabelsFilepathsTrain = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(
            gtLabelsConfPathTrain,
            trainConfigFilepath)) if gtLabelsConfPathTrain else None

    roiMasksConfPathTrain = configGet(trainConfig.ROI_MASKS_TR)
    roiMasksFilepathsTrain = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(
            roiMasksConfPathTrain,
            trainConfigFilepath)) if roiMasksConfPathTrain else None

    # ~~~~~ Advanced Training Sampling~~~~~~~~~
    weightMapsConfPathsTrain = configGet(
        trainConfig.WEIGHT_MAPS_PER_CAT_FILEPATHS_TR)
    if weightMapsConfPathsTrain:
        # [[case1-weightMap1, ..., caseN-weightMap1], [case1-weightMap2,...,caseN-weightMap2]]
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain = [
            parseAbsFileLinesInList(
                getAbsPathEvenIfRelativeIsGiven(weightMapConfPath,
                                                trainConfigFilepath))
            for weightMapConfPath in weightMapsConfPathsTrain
        ]
    else:
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain = None

    # =======VALIDATION==========
    channelsConfPathsVal = configGet(trainConfig.CHANNELS_VAL)
    if channelsConfPathsVal:
        listOfAListPerChannelWithFilepathsOfAllCasesVal = [
            parseAbsFileLinesInList(
                getAbsPathEvenIfRelativeIsGiven(channelConfPath,
                                                trainConfigFilepath))
            for channelConfPath in channelsConfPathsVal
        ]
        # Transposed to one entry per case: [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
        listWithAListPerCaseWithFilepathPerChannelVal = [
            list(filepathsOfOneCase) for filepathsOfOneCase in zip(
                *listOfAListPerChannelWithFilepathsOfAllCasesVal)
        ]
    else:
        listWithAListPerCaseWithFilepathPerChannelVal = None

    gtLabelsConfPathVal = configGet(trainConfig.GT_LABELS_VAL)
    gtLabelsFilepathsVal = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(
            gtLabelsConfPathVal,
            trainConfigFilepath)) if gtLabelsConfPathVal else None

    roiMasksConfPathVal = configGet(trainConfig.ROI_MASKS_VAL)
    roiMasksFilepathsVal = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(
            roiMasksConfPathVal,
            trainConfigFilepath)) if roiMasksConfPathVal else None

    # ~~~~~Full Inference~~~~~~
    # CAREFUL: Here we use a different parsing function!
    namesForPredConfPathVal = configGet(
        trainConfig.NAMES_FOR_PRED_PER_CASE_VAL)
    namesToSavePredsAndFeatsVal = parseFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(
            namesForPredConfPathVal,
            trainConfigFilepath)) if namesForPredConfPathVal else None

    # ~~~~~Advanced Validation Sampling~~~~~~~~
    weightMapsConfPathsVal = configGet(
        trainConfig.WEIGHT_MAPS_PER_CAT_FILEPATHS_VAL)
    if weightMapsConfPathsVal:
        # [[case1-weightMap1, ..., caseN-weightMap1], [case1-weightMap2,...,caseN-weightMap2]]
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal = [
            parseAbsFileLinesInList(
                getAbsPathEvenIfRelativeIsGiven(weightMapConfPath,
                                                trainConfigFilepath))
            for weightMapConfPath in weightMapsConfPathsVal
        ]
    else:
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal = None

    trainSessionParameters = TrainSessionParameters(
        sessionName=sessionName,
        sessionLogger=sessionLogger,
        mainOutputAbsFolder=mainOutputAbsFolder,
        cnn3dInstance=cnn3dInstance,
        cnnModelFilepath=filepathToCnnModel,

        # ==================TRAINING====================
        folderForSessionCnnModels=folderForSessionCnnModels,
        listWithAListPerCaseWithFilepathPerChannelTrain=listWithAListPerCaseWithFilepathPerChannelTrain,
        gtLabelsFilepathsTrain=gtLabelsFilepathsTrain,

        # [Optionals]
        # ~~~~~~~~~Sampling~~~~~~~~~
        roiMasksFilepathsTrain=roiMasksFilepathsTrain,
        percentOfSamplesToExtractPositTrain=configGet(trainConfig.PERC_POS_SAMPLES_TR),
        # ~~~~~~~~~Advanced Sampling~~~~~~~
        useDefaultTrainingSamplingFromGtAndRoi=configGet(trainConfig.DEFAULT_TR_SAMPLING),
        samplingTypeTraining=configGet(trainConfig.TYPE_OF_SAMPLING_TR),
        proportionOfSamplesPerCategoryTrain=configGet(trainConfig.PROP_OF_SAMPLES_PER_CAT_TR),
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain=listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain,

        # ~~~~~~~~Training Cycle ~~~~~~~
        numberOfEpochs=configGet(trainConfig.NUM_EPOCHS),
        numberOfSubepochs=configGet(trainConfig.NUM_SUBEP),
        numOfCasesLoadedPerSubepoch=configGet(trainConfig.NUM_CASES_LOADED_PERSUB),
        segmentsLoadedOnGpuPerSubepochTrain=configGet(trainConfig.NUM_TR_SEGMS_LOADED_PERSUB),

        # ~~~~~~~ Learning Rate Schedule ~~~~~~~
        # Auto requires performValidationOnSamplesThroughoutTraining and providedGtForValidationBool
        stable0orAuto1orPredefined2orExponential3LrSchedule=configGet(trainConfig.LR_SCH_0123),

        # Stable + Auto + Predefined.
        whenDecreasingDivideLrBy=configGet(trainConfig.DIV_LR_BY),
        # Stable + Auto
        numEpochsToWaitBeforeLoweringLr=configGet(trainConfig.NUM_EPOCHS_WAIT),
        # Auto:
        minIncreaseInValidationAccuracyThatResetsWaiting=configGet(trainConfig.AUTO_MIN_INCR_VAL_ACC),
        # Predefined.
        predefinedSchedule=configGet(trainConfig.PREDEF_SCH),
        # Exponential
        exponentialSchedForLrAndMom=configGet(trainConfig.EXPON_SCH),

        # ~~~~~~~ Augmentation~~~~~~~~~~~~
        reflectImagesPerAxis=configGet(trainConfig.REFL_AUGM_PER_AXIS),
        performIntAugm=configGet(trainConfig.PERF_INT_AUGM_BOOL),
        sampleIntAugmShiftWithMuAndStd=configGet(trainConfig.INT_AUGM_SHIF_MUSTD),
        sampleIntAugmMultiWithMuAndStd=configGet(trainConfig.INT_AUGM_MULT_MUSTD),

        # ==================VALIDATION=====================
        performValidationOnSamplesThroughoutTraining=configGet(trainConfig.PERFORM_VAL_SAMPLES),
        performFullInferenceOnValidationImagesEveryFewEpochs=configGet(trainConfig.PERFORM_VAL_INFERENCE),
        # Required:
        listWithAListPerCaseWithFilepathPerChannelVal=listWithAListPerCaseWithFilepathPerChannelVal,
        gtLabelsFilepathsVal=gtLabelsFilepathsVal,

        segmentsLoadedOnGpuPerSubepochVal=configGet(trainConfig.NUM_VAL_SEGMS_LOADED_PERSUB),

        # [Optionals]
        # For default sampling and for fast inference. Optional. Otherwise from whole image.
        roiMasksFilepathsVal=roiMasksFilepathsVal,

        # ~~~~~~~~Full Inference~~~~~~~~
        numberOfEpochsBetweenFullInferenceOnValImages=configGet(trainConfig.NUM_EPOCHS_BETWEEN_VAL_INF),
        # Output
        namesToSavePredictionsAndFeaturesVal=namesToSavePredsAndFeatsVal,
        # predictions
        saveSegmentationVal=configGet(trainConfig.SAVE_SEGM_VAL),
        saveProbMapsBoolPerClassVal=configGet(trainConfig.SAVE_PROBMAPS_PER_CLASS_VAL),
        folderForPredictionsVal=folderForPredictions,
        # features:
        saveIndividualFmImagesVal=configGet(trainConfig.SAVE_INDIV_FMS_VAL),
        saveMultidimensionalImageWithAllFmsVal=configGet(trainConfig.SAVE_4DIM_FMS_VAL),
        # One entry per pathway type, in the order: normal, subsampled, FC.
        indicesOfFmsToVisualisePerPathwayAndLayerVal=[
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_NORMAL_VAL),
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_SUBSAMPLED_VAL),
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_FC_VAL),
        ],
        folderForFeaturesVal=folderForFeatures,

        # ~~~~~~~~ Advanced Validation Sampling ~~~~~~~~~~
        useDefaultUniformValidationSampling=configGet(trainConfig.DEFAULT_VAL_SAMPLING),
        samplingTypeValidation=configGet(trainConfig.TYPE_OF_SAMPLING_VAL),
        proportionOfSamplesPerCategoryVal=configGet(trainConfig.PROP_OF_SAMPLES_PER_CAT_VAL),
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal=listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal,

        # ====Optimization=====
        learningRate=configGet(trainConfig.LRATE),
        optimizerSgd0Adam1Rms2=configGet(trainConfig.OPTIMIZER),
        classicMom0Nesterov1=configGet(trainConfig.MOM_TYPE),
        momentumValue=configGet(trainConfig.MOM),
        momNonNormalized0Normalized1=configGet(trainConfig.MOM_NORM_NONNORM),
        # Adam
        b1Adam=configGet(trainConfig.B1_ADAM),
        b2Adam=configGet(trainConfig.B2_ADAM),
        eAdam=configGet(trainConfig.EPS_ADAM),
        # Rms
        rhoRms=configGet(trainConfig.RHO_RMS),
        eRms=configGet(trainConfig.EPS_RMS),
        # Regularization
        l1Reg=configGet(trainConfig.L1_REG),
        l2Reg=configGet(trainConfig.L2_REG),

        # ~~~~~~~ Freeze Layers ~~~~~~~
        layersToFreezePerPathwayType=[
            configGet(trainConfig.LAYERS_TO_FREEZE_NORM),
            configGet(trainConfig.LAYERS_TO_FREEZE_SUBS),
            configGet(trainConfig.LAYERS_TO_FREEZE_FC),
        ],

        # ==============Generic and Preprocessing===============
        padInputImagesBool=configGet(trainConfig.PAD_INPUT),
    )

    logLine = trainSessionParameters.sessionLogger.print3

    logLine("\n===========       NEW TRAINING SESSION         ===============")
    trainSessionParameters.printParametersOfThisSession()

    logLine("\n=======================================================")
    logLine("=========== Compiling the Training Function ===========")
    logLine("=======================================================")
    # Query the model once and reuse the answer for both the decision and the
    # log message, so the two can never disagree.
    trainingStateUninitialized = not cnn3dInstance.checkTrainingStateAttributesInitialized()
    if trainingStateUninitialized or resetOptimizer:
        logLine("(Re)Initializing parameters for the optimization. "
                "Reason: Uninitialized: [" + str(trainingStateUninitialized) +
                "], Reset requested: [" + str(resetOptimizer) + "]")
        cnn3dInstance.initializeTrainingState(
            *trainSessionParameters.getTupleForInitializingTrainingState())
    cnn3dInstance.compileTrainFunction(
        *trainSessionParameters.getTupleForCompilationOfTrainFunc())
    logLine("\n=========== Compiling the Validation Function =========")
    cnn3dInstance.compileValidationFunction(
        *trainSessionParameters.getTupleForCompilationOfValFunc())
    logLine("\n=========== Compiling the Testing Function ============")
    # For validation with full segmentation
    cnn3dInstance.compileTestAndVisualisationFunction(
        *trainSessionParameters.getTupleForCompilationOfTestFunc())

    logLine("\n=======================================================")
    logLine("============== Training the CNN model =================")
    logLine("=======================================================")
    do_training(*trainSessionParameters.getTupleForCnnTraining())
    logLine("\n=======================================================")
    logLine("=========== Training session finished =================")
    logLine("=======================================================")
def deepMedicTestMain(testConfigFilepath, absPathToSavedModelFromCmdLine):
    """Run a testing (inference) session driven by a test-configuration file.

    The configuration file is executed as Python code and the names it
    defines are collected in ``TestConfig.configStruct``. The function then
    creates the output folders and the session logger, loads the saved CNN
    model, builds a ``TestSessionParameters`` object and runs inference on
    the whole volumes.

    This variant was written with the Python 2-only ``print`` statement and
    ``execfile``; both are replaced by forms that are valid in Python 3, as
    used by the other entry points in this file.

    Args:
        testConfigFilepath: Path to the test-configuration file. Relative
            paths found inside the configuration are resolved through
            ``getAbsPathEvenIfRelativeIsGiven`` with this path as reference.
        absPathToSavedModelFromCmdLine: Path to a saved model given on the
            command line. When truthy it overrides the model path specified
            in the configuration file.

    Returns:
        None. Results are produced by the called session routines.
    """
    print("Given Test-Configuration File: ", testConfigFilepath)

    # Parse the config file in this naive fashion...
    # NOTE(security): the configuration file is executed as arbitrary Python
    # code. Only run configuration files that come from a trusted source.
    testConfig = TestConfig()
    # Replacement for execfile(): read the file through a context manager so
    # the handle is closed, and compile with the real filename so errors in
    # the configuration are still reported against that file.
    with open(testConfigFilepath) as configFile:
        configSource = configFile.read()
    exec(compile(configSource, str(testConfigFilepath), "exec"),
         testConfig.configStruct)
    configGet = testConfig.get  # Main interface

    # Validation of the configuration is currently disabled:
    #checkIfMainTestConfigIsCorrect(testConfig, testConfigFilepath, absPathToSavedModelFromCmdLine) #Checks REQUIRED fields are complete.
    #checkIfFilesThatListFilesPerCaseAreCorrect(testConfig, testConfigFilepath) #Checks listing-files (whatever given).
    #checkIfOptionalParametersAreGivenCorrectly(testConfig, testConfigFilepath)

    # Create Folders and Logger
    mainOutputAbsFolder = getAbsPathEvenIfRelativeIsGiven(
        configGet(testConfig.FOLDER_FOR_OUTPUT), testConfigFilepath)
    # Fall back to the default session name when none (or an empty one) is configured.
    sessionName = (configGet(testConfig.SESSION_NAME)
                   or TestSessionParameters.getDefaultSessionName())
    [folderForLogs, folderForPredictions, folderForFeatures
     ] = makeFoldersNeededForTestingSession(mainOutputAbsFolder, sessionName)
    loggerFileName = folderForLogs + "/" + sessionName + ".txt"
    sessionLogger = myLoggerModule.MyLogger(loggerFileName)

    sessionLogger.print3(
        "CONFIG: The configuration file for the testing session was loaded from: "
        + str(testConfigFilepath))

    # Load the CNN Model!
    sessionLogger.print3(
        "=========== Loading the CNN model for testing... ===============")
    # If CNN-Model was specified in command line, completely override the one in the config file.
    if absPathToSavedModelFromCmdLine:
        if configGet(testConfig.CNN_MODEL_FILEPATH):
            sessionLogger.print3(
                "WARN: A CNN-Model to use was specified both in the command line input and in the test-config-file! The input by the command line will be used: "
                + str(absPathToSavedModelFromCmdLine))
        filepathToCnnModelToLoad = absPathToSavedModelFromCmdLine
    else:
        filepathToCnnModelToLoad = getAbsPathEvenIfRelativeIsGiven(
            configGet(testConfig.CNN_MODEL_FILEPATH), testConfigFilepath)
    sessionLogger.print3(
        "...Loading the network can take a few minutes if the model is big...")
    cnn3dInstance = load_object_from_gzip_file(filepathToCnnModelToLoad)
    sessionLogger.print3("The CNN model was loaded successfully from: " +
                         str(filepathToCnnModelToLoad))
    # Final checks against the loaded model are currently disabled:
    #testConfig.checkIfConfigIsCorrectForParticularCnnModel(cnn3dInstance)

    # Fill in the session's parameters.
    # One listing-file per channel: [[case1-ch1, ..., caseN-ch1], [case1-ch2,...,caseN-ch2]]
    listOfAListPerChannelWithFilepathsOfAllCases = [
        parseAbsFileLinesInList(
            getAbsPathEvenIfRelativeIsGiven(channelConfPath,
                                            testConfigFilepath))
        for channelConfPath in configGet(testConfig.CHANNELS)
    ]
    # Transposed to one entry per case: [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
    listWithAListPerCaseWithFilepathPerChannel = [
        list(filepathsOfOneCase) for filepathsOfOneCase in zip(
            *listOfAListPerChannelWithFilepathsOfAllCases)
    ]

    # Optional listing-files: each stays None when not given in the config.
    gtLabelsConfPath = configGet(testConfig.GT_LABELS)
    gtLabelsFilepaths = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(
            gtLabelsConfPath, testConfigFilepath)) if gtLabelsConfPath else None

    roiMasksConfPath = configGet(testConfig.ROI_MASKS)
    roiMasksFilepaths = parseAbsFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(
            roiMasksConfPath, testConfigFilepath)) if roiMasksConfPath else None

    # CAREFUL: Here we use a different parsing function!
    namesForPredConfPath = configGet(testConfig.NAMES_FOR_PRED_PER_CASE)
    namesToSavePredsAndFeats = parseFileLinesInList(
        getAbsPathEvenIfRelativeIsGiven(
            namesForPredConfPath,
            testConfigFilepath)) if namesForPredConfPath else None

    testSessionParameters = TestSessionParameters(
        sessionName=sessionName,
        sessionLogger=sessionLogger,
        mainOutputAbsFolder=mainOutputAbsFolder,
        cnn3dInstance=cnn3dInstance,
        cnnModelFilepath=filepathToCnnModelToLoad,

        # Input:
        listWithAListPerCaseWithFilepathPerChannel=listWithAListPerCaseWithFilepathPerChannel,
        gtLabelsFilepaths=gtLabelsFilepaths,
        roiMasksFilepaths=roiMasksFilepaths,

        # Output
        namesToSavePredictionsAndFeatures=namesToSavePredsAndFeats,
        # predictions
        saveSegmentation=configGet(testConfig.SAVE_SEGM),
        saveProbMapsBoolPerClass=configGet(testConfig.SAVE_PROBMAPS_PER_CLASS),
        folderForPredictions=folderForPredictions,

        # features:
        saveIndividualFmImages=configGet(testConfig.SAVE_INDIV_FMS),
        saveMultidimensionalImageWithAllFms=configGet(testConfig.SAVE_4DIM_FMS),
        # One entry per pathway type, in the order: normal, subsampled, FC.
        indicesOfFmsToVisualisePerPathwayAndLayer=[
            configGet(testConfig.INDICES_OF_FMS_TO_SAVE_NORMAL),
            configGet(testConfig.INDICES_OF_FMS_TO_SAVE_SUBSAMPLED),
            configGet(testConfig.INDICES_OF_FMS_TO_SAVE_FC),
        ],
        folderForFeatures=folderForFeatures,

        padInputImagesBool=configGet(testConfig.PAD_INPUT),
    )

    logLine = testSessionParameters.sessionLogger.print3

    logLine("===========       NEW TESTING SESSION         ===============")
    testSessionParameters.printParametersOfThisSession()

    logLine("======================================================")
    logLine("=========== Testing with the CNN model ===============")
    logLine("======================================================")
    performInferenceForTestingOnWholeVolumes(
        *testSessionParameters.getTupleForCnnTesting())
    logLine("======================================================")
    logLine("=========== Testing session finished =================")
    logLine("======================================================")
示例#4
0
def deepMedicNewModelMain(modelConfigFilepath,
                          absPathToPreTrainedModelGivenInCmdLine,
                          listOfLayersToTransfer):
    """Create a new CNN model from a model-configuration file and save it.

    The configuration file is executed as Python code with the config
    struct of a fresh ModelConfig as its global namespace. The resulting
    values drive the creation of a Cnn3d instance, which is optionally
    initialised with parameters transferred from a pre-trained model and
    is then dumped to disk.

    Args:
        modelConfigFilepath: Path of the model-configuration file. It is
            also used as the reference when resolving the output folder.
        absPathToPreTrainedModelGivenInCmdLine: Path of a previously saved
            model to transfer parameters from, or None to skip the transfer.
        listOfLayersToTransfer: Forwarded unchanged to
            cnnTransferParameters.transferParametersBetweenModels; only
            used when a pre-trained model is given.

    Returns:
        Tuple (cnn3dInstance, filenameAndPathWhereModelWasSaved).
    """
    print("Given Model-Configuration File: ", modelConfigFilepath)
    # The config file is plain Python that is executed into the config struct.
    # NOTE(review): exec runs whatever the file contains; only use trusted
    # configuration files.
    modelConfig = ModelConfig()
    with open(modelConfigFilepath) as configFile:  # closes the handle deterministically
        exec(configFile.read(), modelConfig.configStruct)
    configGet = modelConfig.get  #Main interface

    # Evaluated once; decides both the parameter transfer and the save name.
    usePretrainedModel = absPathToPreTrainedModelGivenInCmdLine is not None

    #Create Folders and Logger
    mainOutputAbsFolder = getAbsPathEvenIfRelativeIsGiven(
        configGet(modelConfig.FOLDER_FOR_OUTPUT), modelConfigFilepath)
    # Fall back to the default name when the config does not provide one.
    modelName = configGet(modelConfig.MODEL_NAME) if configGet(
        modelConfig.MODEL_NAME
    ) else CreateModelSessionParameters.getDefaultModelName()
    [folderForCnnModels, folderForLogs
     ] = makeFoldersNeededForCreateModelSession(mainOutputAbsFolder, modelName)
    loggerFileName = folderForLogs + "/" + modelName + ".txt"
    sessionLogger = myLoggerModule.MyLogger(loggerFileName)

    sessionLogger.print3(
        "CONFIG: The configuration file for the model-creation session was loaded from: "
        + str(modelConfigFilepath))

    #Fill in the session's parameters.
    createModelSessionParameters = CreateModelSessionParameters(
        cnnModelName=modelName,
        sessionLogger=sessionLogger,
        mainOutputAbsFolder=mainOutputAbsFolder,
        folderForSessionCnnModels=folderForCnnModels,
        #===MODEL PARAMETERS===
        numberClasses=configGet(modelConfig.NUMB_CLASSES),
        numberOfInputChannelsNormal=configGet(
            modelConfig.NUMB_INPUT_CHANNELS_NORMAL),
        #===Normal pathway===
        numFMsNormal=configGet(modelConfig.N_FMS_NORM),
        kernDimNormal=configGet(modelConfig.KERN_DIM_NORM),
        residConnAtLayersNormal=configGet(ModelConfig.RESID_CONN_LAYERS_NORM),
        lowerRankLayersNormal=configGet(ModelConfig.LOWER_RANK_LAYERS_NORM),
        #==Subsampled pathway==
        useSubsampledBool=configGet(modelConfig.USE_SUBSAMPLED),
        numFMsSubsampled=configGet(modelConfig.N_FMS_SUBS),
        kernDimSubsampled=configGet(modelConfig.KERN_DIM_SUBS),
        subsampleFactor=configGet(modelConfig.SUBS_FACTOR),
        residConnAtLayersSubsampled=configGet(
            ModelConfig.RESID_CONN_LAYERS_SUBS),
        lowerRankLayersSubsampled=configGet(
            ModelConfig.LOWER_RANK_LAYERS_SUBS),
        #==FC Layers====
        numFMsFc=configGet(modelConfig.N_FMS_FC),
        kernelDimensionsFirstFcLayer=configGet(modelConfig.KERN_DIM_1ST_FC),
        residConnAtLayersFc=configGet(ModelConfig.RESID_CONN_LAYERS_FC),
        #==Size of Image Segments ==
        segmDimTrain=configGet(modelConfig.SEG_DIM_TRAIN),
        segmDimVal=configGet(modelConfig.SEG_DIM_VAL),
        segmDimInfer=configGet(modelConfig.SEG_DIM_INFERENCE),
        #== Batch Sizes ==
        batchSizeTrain=configGet(modelConfig.BATCH_SIZE_TR),
        batchSizeVal=configGet(modelConfig.BATCH_SIZE_VAL),
        batchSizeInfer=configGet(modelConfig.BATCH_SIZE_INFER),
        #===Other Architectural Parameters ===
        activationFunction=configGet(modelConfig.ACTIV_FUNCTION),
        #==Dropout Rates==
        dropNormal=configGet(modelConfig.DROP_R_NORM),
        dropSubsampled=configGet(modelConfig.DROP_R_SUBS),
        dropFc=configGet(modelConfig.DROP_R_FC),
        #== Weight Initialization==
        initialMethod=configGet(modelConfig.INITIAL_METHOD),
        #== Batch Normalization ==
        bnRollingAverOverThatManyBatches=configGet(
            modelConfig.BN_ROLL_AV_BATCHES),
    )

    createModelSessionParameters.sessionLogger.print3(
        "\n===========    NEW CREATE-MODEL SESSION    ============")
    createModelSessionParameters.printParametersOfThisSession()

    createModelSessionParameters.sessionLogger.print3(
        "\n=========== Creating the CNN model ===============")
    cnn3dInstance = Cnn3d()
    cnn3dInstance.make_cnn_model(
        *createModelSessionParameters.getTupleForCnnCreation())

    if usePretrainedModel:
        # Transfer parameters from a previously trained model to the new one.
        createModelSessionParameters.sessionLogger.print3(
            "\n=========== Pre-training the new model ===============")
        sessionLogger.print3(
            "...Loading the pre-trained network. This can take a few minutes if the model is big..."
        )
        cnnPretrainedInstance = load_object_from_gzip_file(
            absPathToPreTrainedModelGivenInCmdLine)
        sessionLogger.print3(
            "The pre-trained model was loaded successfully from: " +
            str(absPathToPreTrainedModelGivenInCmdLine))
        # Imported lazily: only needed when a transfer is requested.
        from deepmedic import cnnTransferParameters
        cnn3dInstance = cnnTransferParameters.transferParametersBetweenModels(
            sessionLogger, cnn3dInstance, cnnPretrainedInstance,
            listOfLayersToTransfer)

    createModelSessionParameters.sessionLogger.print3(
        "\n=========== Saving the model ===============")
    # The filename records whether the initial weights came from a
    # pre-trained model, followed by a timestamp.
    initialStateSuffix = (".initial.pretrained."
                          if usePretrainedModel else ".initial.")
    filenameAndPathToSaveModel = (
        createModelSessionParameters.getPathAndFilenameToSaveModel() +
        initialStateSuffix + datetimeNowAsStr())
    filenameAndPathWhereModelWasSaved = dump_cnn_to_gzip_file_dotSave(
        cnn3dInstance, filenameAndPathToSaveModel, sessionLogger)
    createModelSessionParameters.sessionLogger.print3(
        "=========== Creation of the model: \"" +
        str(createModelSessionParameters.cnnModelName) +
        "\" finished =================")

    return (cnn3dInstance, filenameAndPathWhereModelWasSaved)
def deepMedicTrainMain(trainConfigFilepath, absPathToSavedModelFromCmdLine, cnnInstancePreLoaded, filenameAndPathWherePreLoadedModelWas, resetOptimizer) :
    """Run a training session described by a training-configuration file.

    The configuration file is executed as Python code with the config
    struct of a fresh TrainConfig as its global namespace. The function then
    creates the session folders and logger, obtains the CNN model (pre-loaded
    or read from disk), gathers all session parameters, compiles the
    train/validation/test functions of the model and starts training.

    Args:
        trainConfigFilepath: Path of the training-configuration file. It is
            also the reference against which paths found in the
            configuration are resolved.
        absPathToSavedModelFromCmdLine: Path of a saved model given on the
            command line. When given, it takes precedence over the model
            path in the configuration file.
        cnnInstancePreLoaded: An already available model instance. When
            truthy, no model is loaded from disk.
        filenameAndPathWherePreLoadedModelWas: Path recorded as the model
            filepath of the session when a pre-loaded instance is used.
        resetOptimizer: When truthy, the training state of the model is
            (re)initialised even if it is already initialised.

    Returns:
        None.
    """
    print("Given Training-Configuration File: ", trainConfigFilepath)
    # The config file is plain Python that is executed into the config struct.
    # Compiling with the file path keeps that path in any traceback.
    # NOTE(review): exec runs whatever the file contains; only use trusted
    # configuration files.
    trainConfig = TrainConfig()
    with open(trainConfigFilepath) as configFile:
        exec(compile(configFile.read(), trainConfigFilepath, "exec"), trainConfig.configStruct)
    configGet = trainConfig.get #Main interface

    # Sanity checks of the configuration (required fields, listing files,
    # optional parameters) are currently disabled in this entry point.

    def absPathOf(path):
        # Resolve a path from the configuration against the config file.
        return getAbsPathEvenIfRelativeIsGiven(path, trainConfigFilepath)

    def linesOfListingFile(configKey, parseLines=parseAbsFileLinesInList):
        # Parse the listing file stored under configKey, or None if not given.
        listingPath = configGet(configKey)
        return parseLines(absPathOf(listingPath)) if listingPath else None

    def linesOfEachListingFile(configKey):
        # Parse every listing file stored under configKey, or None if not given.
        listingPaths = configGet(configKey)
        if not listingPaths:
            return None
        return [parseAbsFileLinesInList(absPathOf(listingPath)) for listingPath in listingPaths]

    def regroupPerCase(listPerChannel):
        # [[case1-ch1, ..., caseN-ch1], [case1-ch2, ..., caseN-ch2]]
        #   -> [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
        if listPerChannel is None:
            return None
        return [list(filepathsOfCase) for filepathsOfCase in zip(*listPerChannel)]

    #Create Folders and Logger
    mainOutputAbsFolder = absPathOf(configGet(trainConfig.FOLDER_FOR_OUTPUT))
    # Fall back to the default name when the config does not provide one.
    sessionName = configGet(trainConfig.SESSION_NAME) if configGet(trainConfig.SESSION_NAME) else TrainSessionParameters.getDefaultSessionName()
    [folderForSessionCnnModels,
     folderForLogs,
     folderForPredictions,
     folderForFeatures] = makeFoldersNeededForTrainingSession(mainOutputAbsFolder, sessionName)
    loggerFileName = folderForLogs + "/" + sessionName + ".txt"
    sessionLogger = myLoggerModule.MyLogger(loggerFileName)

    sessionLogger.print3("CONFIG: The configuration file for the training session was loaded from: " + str(trainConfigFilepath))

    #Load the cnn Model if not given straight from a createCnnModel-session...
    cnn3dInstance = None
    filepathToCnnModel = None
    if cnnInstancePreLoaded:
        sessionLogger.print3("====== CNN-Instance already loaded (from Create-New-Model procedure) =======")
        if configGet(trainConfig.CNN_MODEL_FILEPATH):
            sessionLogger.print3("WARN: Any paths to cnn-models specified in the train-config file will be ommitted! Working with the pre-loaded model!")
        cnn3dInstance = cnnInstancePreLoaded
        filepathToCnnModel = filenameAndPathWherePreLoadedModelWas
    else:
        sessionLogger.print3("=========== Loading the CNN model for training... ===============")
        # A model given on the command line always overrides the one in the
        # config file, including when the config file specifies none.
        modelPathFromConfig = configGet(trainConfig.CNN_MODEL_FILEPATH)
        if absPathToSavedModelFromCmdLine:
            if modelPathFromConfig:
                sessionLogger.print3("WARN: A CNN-Model to use was specified both in the command line input and in the train-config-file! The input by the command line will be used: " + str(absPathToSavedModelFromCmdLine) )
            filepathToCnnModel = absPathToSavedModelFromCmdLine
        else:
            filepathToCnnModel = absPathOf(modelPathFromConfig)
        sessionLogger.print3("...Loading the network can take a few minutes if the model is big...")
        cnn3dInstance = load_object_from_gzip_file(filepathToCnnModel)
        sessionLogger.print3("The CNN model was loaded successfully from: " + str(filepathToCnnModel))

    # A final check of the configuration against the loaded model
    # (trainConfig.checkIfConfigIsCorrectForParticularCnnModel) is currently
    # disabled in this entry point.

    #Fill in the session's parameters.
    #=======TRAINING==========
    listWithAListPerCaseWithFilepathPerChannelTrain = regroupPerCase(linesOfEachListingFile(trainConfig.CHANNELS_TR))
    gtLabelsFilepathsTrain = linesOfListingFile(trainConfig.GT_LABELS_TR)
    roiMasksFilepathsTrain = linesOfListingFile(trainConfig.ROI_MASKS_TR)
    #~~~~~ Advanced Training Sampling~~~~~~~~~
    #[[case1-weightMap1, ..., caseN-weightMap1], [case1-weightMap2,...,caseN-weightMap2]]
    listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain = linesOfEachListingFile(trainConfig.WEIGHT_MAPS_PER_CAT_FILEPATHS_TR)

    #=======VALIDATION==========
    listWithAListPerCaseWithFilepathPerChannelVal = regroupPerCase(linesOfEachListingFile(trainConfig.CHANNELS_VAL))
    gtLabelsFilepathsVal = linesOfListingFile(trainConfig.GT_LABELS_VAL)
    roiMasksFilepathsVal = linesOfListingFile(trainConfig.ROI_MASKS_VAL)
    #~~~~~Full Inference~~~~~~
    #CAREFUL: Here we use a different parsing function!
    namesToSavePredsAndFeatsVal = linesOfListingFile(trainConfig.NAMES_FOR_PRED_PER_CASE_VAL, parseFileLinesInList)
    #~~~~~Advanced Validation Sampling~~~~~~~~
    #[[case1-weightMap1, ..., caseN-weightMap1], [case1-weightMap2,...,caseN-weightMap2]]
    listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal = linesOfEachListingFile(trainConfig.WEIGHT_MAPS_PER_CAT_FILEPATHS_VAL)

    trainSessionParameters = TrainSessionParameters(
                    sessionName = sessionName,
                    sessionLogger = sessionLogger,
                    mainOutputAbsFolder = mainOutputAbsFolder,
                    cnn3dInstance = cnn3dInstance,
                    cnnModelFilepath = filepathToCnnModel,

                    #==================TRAINING====================
                    folderForSessionCnnModels = folderForSessionCnnModels,
                    listWithAListPerCaseWithFilepathPerChannelTrain = listWithAListPerCaseWithFilepathPerChannelTrain,
                    gtLabelsFilepathsTrain = gtLabelsFilepathsTrain,

                    #[Optionals]
                    #~~~~~~~~~Sampling~~~~~~~~~
                    roiMasksFilepathsTrain = roiMasksFilepathsTrain,
                    percentOfSamplesToExtractPositTrain = configGet(trainConfig.PERC_POS_SAMPLES_TR),
                    #~~~~~~~~~Advanced Sampling~~~~~~~
                    useDefaultTrainingSamplingFromGtAndRoi = configGet(trainConfig.DEFAULT_TR_SAMPLING),
                    samplingTypeTraining = configGet(trainConfig.TYPE_OF_SAMPLING_TR),
                    proportionOfSamplesPerCategoryTrain = configGet(trainConfig.PROP_OF_SAMPLES_PER_CAT_TR),
                    listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain = listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain,

                    #~~~~~~~~Training Cycle ~~~~~~~
                    numberOfEpochs = configGet(trainConfig.NUM_EPOCHS),
                    numberOfSubepochs = configGet(trainConfig.NUM_SUBEP),
                    numOfCasesLoadedPerSubepoch = configGet(trainConfig.NUM_CASES_LOADED_PERSUB),
                    segmentsLoadedOnGpuPerSubepochTrain = configGet(trainConfig.NUM_TR_SEGMS_LOADED_PERSUB),

                    #~~~~~~~ Learning Rate Schedule ~~~~~~~
                    #Auto requires performValidationOnSamplesThroughoutTraining and providedGtForValidationBool
                    stable0orAuto1orPredefined2orExponential3LrSchedule = configGet(trainConfig.LR_SCH_0123),

                    #Stable + Auto + Predefined.
                    whenDecreasingDivideLrBy = configGet(trainConfig.DIV_LR_BY),
                    #Stable + Auto
                    numEpochsToWaitBeforeLoweringLr = configGet(trainConfig.NUM_EPOCHS_WAIT),
                    #Auto:
                    minIncreaseInValidationAccuracyThatResetsWaiting = configGet(trainConfig.AUTO_MIN_INCR_VAL_ACC),
                    #Predefined.
                    predefinedSchedule = configGet(trainConfig.PREDEF_SCH),
                    #Exponential
                    exponentialSchedForLrAndMom = configGet(trainConfig.EXPON_SCH),

                    #~~~~~~~ Augmentation~~~~~~~~~~~~
                    reflectImagesPerAxis = configGet(trainConfig.REFL_AUGM_PER_AXIS),
                    performIntAugm = configGet(trainConfig.PERF_INT_AUGM_BOOL),
                    sampleIntAugmShiftWithMuAndStd = configGet(trainConfig.INT_AUGM_SHIF_MUSTD),
                    sampleIntAugmMultiWithMuAndStd = configGet(trainConfig.INT_AUGM_MULT_MUSTD),

                    #==================VALIDATION=====================
                    performValidationOnSamplesThroughoutTraining = configGet(trainConfig.PERFORM_VAL_SAMPLES),
                    performFullInferenceOnValidationImagesEveryFewEpochs = configGet(trainConfig.PERFORM_VAL_INFERENCE),
                    #Required:
                    listWithAListPerCaseWithFilepathPerChannelVal = listWithAListPerCaseWithFilepathPerChannelVal,
                    gtLabelsFilepathsVal = gtLabelsFilepathsVal,

                    segmentsLoadedOnGpuPerSubepochVal = configGet(trainConfig.NUM_VAL_SEGMS_LOADED_PERSUB),

                    #[Optionals]
                    roiMasksFilepathsVal = roiMasksFilepathsVal, #For default sampling and for fast inference. Optional. Otherwise from whole image.

                    #~~~~~~~~Full Inference~~~~~~~~
                    numberOfEpochsBetweenFullInferenceOnValImages = configGet(trainConfig.NUM_EPOCHS_BETWEEN_VAL_INF),
                    #Output
                    namesToSavePredictionsAndFeaturesVal = namesToSavePredsAndFeatsVal,
                    #predictions
                    saveSegmentationVal = configGet(trainConfig.SAVE_SEGM_VAL),
                    saveProbMapsBoolPerClassVal = configGet(trainConfig.SAVE_PROBMAPS_PER_CLASS_VAL),
                    folderForPredictionsVal = folderForPredictions,
                    #features:
                    saveIndividualFmImagesVal = configGet(trainConfig.SAVE_INDIV_FMS_VAL),
                    saveMultidimensionalImageWithAllFmsVal = configGet(trainConfig.SAVE_4DIM_FMS_VAL),
                    # One entry per pathway type, in the order: normal, subsampled, FC.
                    indicesOfFmsToVisualisePerPathwayAndLayerVal = [configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_NORMAL_VAL),
                                                                    configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_SUBSAMPLED_VAL),
                                                                    configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_FC_VAL)],
                    folderForFeaturesVal = folderForFeatures,

                    #~~~~~~~~ Advanced Validation Sampling ~~~~~~~~~~
                    useDefaultUniformValidationSampling = configGet(trainConfig.DEFAULT_VAL_SAMPLING),
                    samplingTypeValidation = configGet(trainConfig.TYPE_OF_SAMPLING_VAL),
                    proportionOfSamplesPerCategoryVal = configGet(trainConfig.PROP_OF_SAMPLES_PER_CAT_VAL),
                    listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal = listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal,

                    #====Optimization=====
                    learningRate=configGet(trainConfig.LRATE),
                    optimizerSgd0Adam1Rms2=configGet(trainConfig.OPTIMIZER),
                    classicMom0Nesterov1=configGet(trainConfig.MOM_TYPE),
                    momentumValue=configGet(trainConfig.MOM),
                    momNonNormalized0Normalized1=configGet(trainConfig.MOM_NORM_NONNORM),
                    #Adam
                    b1Adam=configGet(trainConfig.B1_ADAM),
                    b2Adam=configGet(trainConfig.B2_ADAM),
                    eAdam=configGet(trainConfig.EPS_ADAM),
                    #Rms
                    rhoRms=configGet(trainConfig.RHO_RMS),
                    eRms=configGet(trainConfig.EPS_RMS),
                    #Regularization
                    l1Reg=configGet(trainConfig.L1_REG),
                    l2Reg=configGet(trainConfig.L2_REG),

                    #~~~~~~~ Freeze Layers ~~~~~~~
                    layersToFreezePerPathwayType = [configGet(trainConfig.LAYERS_TO_FREEZE_NORM),
                                                    configGet(trainConfig.LAYERS_TO_FREEZE_SUBS),
                                                    configGet(trainConfig.LAYERS_TO_FREEZE_FC) ],

                    #==============Generic and Preprocessing===============
                    padInputImagesBool = configGet(trainConfig.PAD_INPUT)
                    )

    trainSessionParameters.sessionLogger.print3("\n===========       NEW TRAINING SESSION         ===============")
    trainSessionParameters.printParametersOfThisSession()

    trainSessionParameters.sessionLogger.print3("\n=======================================================")
    trainSessionParameters.sessionLogger.print3("=========== Compiling the Training Function ===========")
    trainSessionParameters.sessionLogger.print3("=======================================================")
    # The training state is set up when the model has none yet, or when the
    # caller explicitly asks for a reset.
    trainingStateUninitialized = not cnn3dInstance.checkTrainingStateAttributesInitialized()
    if trainingStateUninitialized or resetOptimizer:
        trainSessionParameters.sessionLogger.print3("(Re)Initializing parameters for the optimization. " \
                        "Reason: Uninitialized: ["+str(trainingStateUninitialized)+"], Reset requested: ["+str(resetOptimizer)+"]" )
        cnn3dInstance.initializeTrainingState(*trainSessionParameters.getTupleForInitializingTrainingState())
    cnn3dInstance.compileTrainFunction(*trainSessionParameters.getTupleForCompilationOfTrainFunc())
    trainSessionParameters.sessionLogger.print3("\n=========== Compiling the Validation Function =========")
    cnn3dInstance.compileValidationFunction(*trainSessionParameters.getTupleForCompilationOfValFunc())
    trainSessionParameters.sessionLogger.print3("\n=========== Compiling the Testing Function ============")
    cnn3dInstance.compileTestAndVisualisationFunction(*trainSessionParameters.getTupleForCompilationOfTestFunc()) # For validation with full segmentation

    trainSessionParameters.sessionLogger.print3("\n=======================================================")
    trainSessionParameters.sessionLogger.print3("============== Training the CNN model =================")
    trainSessionParameters.sessionLogger.print3("=======================================================")
    do_training(*trainSessionParameters.getTupleForCnnTraining())
    trainSessionParameters.sessionLogger.print3("\n=======================================================")
    trainSessionParameters.sessionLogger.print3("=========== Training session finished =================")
    trainSessionParameters.sessionLogger.print3("=======================================================")