# Example #1
# 0
def deepMedicTrainMain(trainConfigFilepath, absPathToSavedModelFromCmdLine,
                       cnnInstancePreLoaded,
                       filenameAndPathWherePreLoadedModelWas, resetOptimizer):
    """Run a complete training session configured by a train-config file.

    The config file is executed into a TrainConfig, the session's output
    folders and logger are created, and the CNN is obtained (either the given
    pre-loaded instance or one loaded from a gzip file). Then the per-case
    listing files for training and validation are parsed, and a
    TrainSessionParameters is built from all of it. Finally, the model's
    training state is (re)initialized if needed, the train/validation/test
    functions are compiled, and do_training() is called.

    Args:
        trainConfigFilepath: Path of the training config file. It is executed
            as Python code, so it must be trusted. Paths given inside it are
            resolved with getAbsPathEvenIfRelativeIsGiven(path, trainConfigFilepath).
        absPathToSavedModelFromCmdLine: Path of a saved CNN model given on the
            command line, or a falsy value. If set, it is used instead of the
            config's CNN_MODEL_FILEPATH. It is ignored when a pre-loaded
            instance is given.
        cnnInstancePreLoaded: An already-built CNN instance (e.g. from the
            create-new-model procedure), or a falsy value to load it from file.
        filenameAndPathWherePreLoadedModelWas: Filepath recorded for the model
            when cnnInstancePreLoaded is used.
        resetOptimizer: If true, the training (optimizer) state is
            re-initialized even if the model already has one.

    Returns:
        None.
    """
    print("Given Training-Configuration File: ", trainConfigFilepath)

    #Parse the config file in this naive fashion: execute it into the config's namespace.
    #NOTE: this runs arbitrary Python code, so only trusted config files must be used.
    #The file is read inside a `with` so its handle gets closed, and it is compiled with its
    #real name so that errors raised while executing the config point at the right file.
    trainConfig = TrainConfig()
    with open(trainConfigFilepath) as trainConfigFile:
        trainConfigSource = trainConfigFile.read()
    exec(compile(trainConfigSource, trainConfigFilepath, "exec"), trainConfig.configStruct)
    configGet = trainConfig.get  #Main interface
    #TODO: validate the config (required fields, listing files, optional parameters) before use.

    def absPathFromConfig(pathInConfig):
        #Paths given in the config may be relative; resolve them via the config file's path.
        return getAbsPathEvenIfRelativeIsGiven(pathInConfig, trainConfigFilepath)

    def parseListingFileIfGiven(configKey, parseLines=parseAbsFileLinesInList):
        #Parse the listing file that configKey points to, or return None if the key is not set.
        listingFilepath = configGet(configKey)
        return parseLines(absPathFromConfig(listingFilepath)) if listingFilepath else None

    def parseListingFilesPerCategoryIfGiven(configKey):
        #One parsed listing file per entry of configKey:
        #[[case1-cat1, ..., caseN-cat1], [case1-cat2, ..., caseN-cat2]], or None if not set.
        listingFilepaths = configGet(configKey)
        if not listingFilepaths:
            return None
        return [parseAbsFileLinesInList(absPathFromConfig(listingFilepath))
                for listingFilepath in listingFilepaths]

    def transposeToPerCase(listPerChannel):
        #[[case1-ch1, ..., caseN-ch1], ...] -> [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
        if listPerChannel is None:
            return None
        return [list(item) for item in zip(*listPerChannel)]

    #Create Folders and Logger
    mainOutputAbsFolder = absPathFromConfig(configGet(trainConfig.FOLDER_FOR_OUTPUT))
    sessionName = configGet(trainConfig.SESSION_NAME) or TrainSessionParameters.getDefaultSessionName()
    [folderForSessionCnnModels, folderForLogs, folderForPredictions,
     folderForFeatures] = makeFoldersNeededForTrainingSession(mainOutputAbsFolder, sessionName)
    loggerFileName = folderForLogs + "/" + sessionName + ".txt"
    sessionLogger = myLoggerModule.MyLogger(loggerFileName)

    sessionLogger.print3(
        "CONFIG: The configuration file for the training session was loaded from: "
        + str(trainConfigFilepath))

    #Load the cnn Model if not given straight from a createCnnModel-session...
    cnnModelFilepathInConfig = configGet(trainConfig.CNN_MODEL_FILEPATH)
    if cnnInstancePreLoaded:
        sessionLogger.print3(
            "====== CNN-Instance already loaded (from Create-New-Model procedure) =======")
        if cnnModelFilepathInConfig:
            sessionLogger.print3(
                "WARN: Any paths to cnn-models specified in the train-config file will be ommitted! Working with the pre-loaded model!")
        cnn3dInstance = cnnInstancePreLoaded
        filepathToCnnModel = filenameAndPathWherePreLoadedModelWas
    else:
        sessionLogger.print3(
            "=========== Loading the CNN model for training... ===============")
        #If CNN-Model was specified in command line, completely override the one in the config file.
        #The command line is checked on its own, so that it also wins when the config names no model.
        if absPathToSavedModelFromCmdLine:
            if cnnModelFilepathInConfig:
                sessionLogger.print3(
                    "WARN: A CNN-Model to use was specified both in the command line input and in the train-config-file! The input by the command line will be used: "
                    + str(absPathToSavedModelFromCmdLine))
            filepathToCnnModel = absPathToSavedModelFromCmdLine
        else:
            filepathToCnnModel = absPathFromConfig(cnnModelFilepathInConfig)
        sessionLogger.print3(
            "...Loading the network can take a few minutes if the model is big...")
        cnn3dInstance = load_object_from_gzip_file(filepathToCnnModel)
        sessionLogger.print3("The CNN model was loaded successfully from: " +
                             str(filepathToCnnModel))
    #TODO: check the parameters that depend on the loaded model (e.g. SAVE_PROBMAPS_PER_CLASS,
    #INDICES_OF_FMS_TO_SAVE, number of channels), e.g. via
    #trainConfig.checkIfConfigIsCorrectForParticularCnnModel(cnn3dInstance).

    #Fill in the session's parameters. Parsing order is kept: train first, then validation.
    #=======TRAINING==========
    listWithAListPerCaseWithFilepathPerChannelTrain = transposeToPerCase(
        parseListingFilesPerCategoryIfGiven(trainConfig.CHANNELS_TR))
    gtLabelsFilepathsTrain = parseListingFileIfGiven(trainConfig.GT_LABELS_TR)
    roiMasksFilepathsTrain = parseListingFileIfGiven(trainConfig.ROI_MASKS_TR)
    #~~~~~ Advanced Training Sampling~~~~~~~~~
    listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain = parseListingFilesPerCategoryIfGiven(
        trainConfig.WEIGHT_MAPS_PER_CAT_FILEPATHS_TR)

    #=======VALIDATION==========
    listWithAListPerCaseWithFilepathPerChannelVal = transposeToPerCase(
        parseListingFilesPerCategoryIfGiven(trainConfig.CHANNELS_VAL))
    gtLabelsFilepathsVal = parseListingFileIfGiven(trainConfig.GT_LABELS_VAL)
    roiMasksFilepathsVal = parseListingFileIfGiven(trainConfig.ROI_MASKS_VAL)
    #~~~~~Full Inference~~~~~~
    #CAREFUL: parsed with parseFileLinesInList, not parseAbsFileLinesInList
    #(presumably because these lines are output names rather than paths to resolve).
    namesToSavePredsAndFeatsVal = parseListingFileIfGiven(
        trainConfig.NAMES_FOR_PRED_PER_CASE_VAL, parseLines=parseFileLinesInList)
    #~~~~~Advanced Validation Sampling~~~~~~~~
    listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal = parseListingFilesPerCategoryIfGiven(
        trainConfig.WEIGHT_MAPS_PER_CAT_FILEPATHS_VAL)

    trainSessionParameters = TrainSessionParameters(
        sessionName=sessionName,
        sessionLogger=sessionLogger,
        mainOutputAbsFolder=mainOutputAbsFolder,
        cnn3dInstance=cnn3dInstance,
        cnnModelFilepath=filepathToCnnModel,

        #==================TRAINING====================
        folderForSessionCnnModels=folderForSessionCnnModels,
        listWithAListPerCaseWithFilepathPerChannelTrain=listWithAListPerCaseWithFilepathPerChannelTrain,
        gtLabelsFilepathsTrain=gtLabelsFilepathsTrain,

        #[Optionals]
        #~~~~~~~~~Sampling~~~~~~~~~
        roiMasksFilepathsTrain=roiMasksFilepathsTrain,
        percentOfSamplesToExtractPositTrain=configGet(trainConfig.PERC_POS_SAMPLES_TR),
        #~~~~~~~~~Advanced Sampling~~~~~~~
        useDefaultTrainingSamplingFromGtAndRoi=configGet(trainConfig.DEFAULT_TR_SAMPLING),
        samplingTypeTraining=configGet(trainConfig.TYPE_OF_SAMPLING_TR),
        proportionOfSamplesPerCategoryTrain=configGet(trainConfig.PROP_OF_SAMPLES_PER_CAT_TR),
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain=listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain,

        #~~~~~~~~Training Cycle ~~~~~~~
        numberOfEpochs=configGet(trainConfig.NUM_EPOCHS),
        numberOfSubepochs=configGet(trainConfig.NUM_SUBEP),
        numOfCasesLoadedPerSubepoch=configGet(trainConfig.NUM_CASES_LOADED_PERSUB),
        segmentsLoadedOnGpuPerSubepochTrain=configGet(trainConfig.NUM_TR_SEGMS_LOADED_PERSUB),

        #~~~~~~~ Learning Rate Schedule ~~~~~~~
        #Auto requires performValidationOnSamplesThroughoutTraining and providedGtForValidationBool
        stable0orAuto1orPredefined2orExponential3LrSchedule=configGet(trainConfig.LR_SCH_0123),

        #Stable + Auto + Predefined.
        whenDecreasingDivideLrBy=configGet(trainConfig.DIV_LR_BY),
        #Stable + Auto
        numEpochsToWaitBeforeLoweringLr=configGet(trainConfig.NUM_EPOCHS_WAIT),
        #Auto:
        minIncreaseInValidationAccuracyThatResetsWaiting=configGet(trainConfig.AUTO_MIN_INCR_VAL_ACC),
        #Predefined.
        predefinedSchedule=configGet(trainConfig.PREDEF_SCH),
        #Exponential
        exponentialSchedForLrAndMom=configGet(trainConfig.EXPON_SCH),

        #~~~~~~~ Augmentation~~~~~~~~~~~~
        reflectImagesPerAxis=configGet(trainConfig.REFL_AUGM_PER_AXIS),
        performIntAugm=configGet(trainConfig.PERF_INT_AUGM_BOOL),
        sampleIntAugmShiftWithMuAndStd=configGet(trainConfig.INT_AUGM_SHIF_MUSTD),
        sampleIntAugmMultiWithMuAndStd=configGet(trainConfig.INT_AUGM_MULT_MUSTD),

        #==================VALIDATION=====================
        performValidationOnSamplesThroughoutTraining=configGet(trainConfig.PERFORM_VAL_SAMPLES),
        performFullInferenceOnValidationImagesEveryFewEpochs=configGet(trainConfig.PERFORM_VAL_INFERENCE),
        #Required:
        listWithAListPerCaseWithFilepathPerChannelVal=listWithAListPerCaseWithFilepathPerChannelVal,
        gtLabelsFilepathsVal=gtLabelsFilepathsVal,

        segmentsLoadedOnGpuPerSubepochVal=configGet(trainConfig.NUM_VAL_SEGMS_LOADED_PERSUB),

        #[Optionals]
        roiMasksFilepathsVal=roiMasksFilepathsVal,  #For default sampling and for fast inference. Optional. Otherwise from whole image.

        #~~~~~~~~Full Inference~~~~~~~~
        numberOfEpochsBetweenFullInferenceOnValImages=configGet(trainConfig.NUM_EPOCHS_BETWEEN_VAL_INF),
        #Output
        namesToSavePredictionsAndFeaturesVal=namesToSavePredsAndFeatsVal,
        #predictions
        saveSegmentationVal=configGet(trainConfig.SAVE_SEGM_VAL),
        saveProbMapsBoolPerClassVal=configGet(trainConfig.SAVE_PROBMAPS_PER_CLASS_VAL),
        folderForPredictionsVal=folderForPredictions,
        #features:
        saveIndividualFmImagesVal=configGet(trainConfig.SAVE_INDIV_FMS_VAL),
        saveMultidimensionalImageWithAllFmsVal=configGet(trainConfig.SAVE_4DIM_FMS_VAL),
        indicesOfFmsToVisualisePerPathwayAndLayerVal=[
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_NORMAL_VAL),
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_SUBSAMPLED_VAL),
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_FC_VAL)],
        folderForFeaturesVal=folderForFeatures,

        #~~~~~~~~ Advanced Validation Sampling ~~~~~~~~~~
        useDefaultUniformValidationSampling=configGet(trainConfig.DEFAULT_VAL_SAMPLING),
        samplingTypeValidation=configGet(trainConfig.TYPE_OF_SAMPLING_VAL),
        proportionOfSamplesPerCategoryVal=configGet(trainConfig.PROP_OF_SAMPLES_PER_CAT_VAL),
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal=listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal,

        #====Optimization=====
        learningRate=configGet(trainConfig.LRATE),
        optimizerSgd0Adam1Rms2=configGet(trainConfig.OPTIMIZER),
        classicMom0Nesterov1=configGet(trainConfig.MOM_TYPE),
        momentumValue=configGet(trainConfig.MOM),
        momNonNormalized0Normalized1=configGet(trainConfig.MOM_NORM_NONNORM),
        #Adam
        b1Adam=configGet(trainConfig.B1_ADAM),
        b2Adam=configGet(trainConfig.B2_ADAM),
        eAdam=configGet(trainConfig.EPS_ADAM),
        #Rms
        rhoRms=configGet(trainConfig.RHO_RMS),
        eRms=configGet(trainConfig.EPS_RMS),
        #Regularization
        l1Reg=configGet(trainConfig.L1_REG),
        l2Reg=configGet(trainConfig.L2_REG),

        #~~~~~~~ Freeze Layers ~~~~~~~
        layersToFreezePerPathwayType=[configGet(trainConfig.LAYERS_TO_FREEZE_NORM),
                                      configGet(trainConfig.LAYERS_TO_FREEZE_SUBS),
                                      configGet(trainConfig.LAYERS_TO_FREEZE_FC)],

        #==============Generic and Preprocessing===============
        padInputImagesBool=configGet(trainConfig.PAD_INPUT)
    )

    logPrint = trainSessionParameters.sessionLogger.print3

    logPrint("\n===========       NEW TRAINING SESSION         ===============")
    trainSessionParameters.printParametersOfThisSession()

    logPrint("\n=======================================================")
    logPrint("=========== Compiling the Training Function ===========")
    logPrint("=======================================================")
    #Query the state once and reuse it for both the decision and the log message.
    trainingStateInitialized = cnn3dInstance.checkTrainingStateAttributesInitialized()
    if not trainingStateInitialized or resetOptimizer:
        logPrint("(Re)Initializing parameters for the optimization. "
                 "Reason: Uninitialized: [" + str(not trainingStateInitialized) +
                 "], Reset requested: [" + str(resetOptimizer) + "]")
        cnn3dInstance.initializeTrainingState(
            *trainSessionParameters.getTupleForInitializingTrainingState())
    cnn3dInstance.compileTrainFunction(
        *trainSessionParameters.getTupleForCompilationOfTrainFunc())
    logPrint("\n=========== Compiling the Validation Function =========")
    cnn3dInstance.compileValidationFunction(
        *trainSessionParameters.getTupleForCompilationOfValFunc())
    logPrint("\n=========== Compiling the Testing Function ============")
    #The test function is used for validation with full segmentation.
    cnn3dInstance.compileTestAndVisualisationFunction(
        *trainSessionParameters.getTupleForCompilationOfTestFunc())

    logPrint("\n=======================================================")
    logPrint("============== Training the CNN model =================")
    logPrint("=======================================================")
    do_training(*trainSessionParameters.getTupleForCnnTraining())
    logPrint("\n=======================================================")
    logPrint("=========== Training session finished =================")
    logPrint("=======================================================")
def deepMedicTrainMain(trainConfigFilepath, absPathToSavedModelFromCmdLine,
                       cnnInstancePreLoaded,
                       filenameAndPathWherePreLoadedModelWas, resetOptimizer):
    """Run a complete training session configured by a train-config file.

    The config file is executed into a TrainConfig, the session's output
    folders and logger are created, and the CNN is obtained (either the given
    pre-loaded instance or one loaded from a gzip file). Then the per-case
    listing files for training and validation are parsed, and a
    TrainSessionParameters is built from all of it. Finally, the model's
    training state is (re)initialized if needed, the train/validation/test
    functions are compiled, and do_training() is called.

    Written with Python 3 syntax (print function, exec/compile instead of
    execfile), so it runs under Python 3.

    Args:
        trainConfigFilepath: Path of the training config file. It is executed
            as Python code, so it must be trusted. Paths given inside it are
            resolved with getAbsPathEvenIfRelativeIsGiven(path, trainConfigFilepath).
        absPathToSavedModelFromCmdLine: Path of a saved CNN model given on the
            command line, or a falsy value. If set, it is used instead of the
            config's CNN_MODEL_FILEPATH. It is ignored when a pre-loaded
            instance is given.
        cnnInstancePreLoaded: An already-built CNN instance (e.g. from the
            create-new-model procedure), or a falsy value to load it from file.
        filenameAndPathWherePreLoadedModelWas: Filepath recorded for the model
            when cnnInstancePreLoaded is used.
        resetOptimizer: If true, the training (optimizer) state is
            re-initialized even if the model already has one.

    Returns:
        None.
    """
    print("Given Training-Configuration File: ", trainConfigFilepath)

    #Parse the config file in this naive fashion: execute it into the config's namespace.
    #NOTE: this runs arbitrary Python code, so only trusted config files must be used.
    #exec(compile(...)) replaces the Python-2-only execfile(); the `with` closes the handle and
    #the real filename makes errors raised while executing the config point at the right file.
    trainConfig = TrainConfig()
    with open(trainConfigFilepath) as trainConfigFile:
        trainConfigSource = trainConfigFile.read()
    exec(compile(trainConfigSource, trainConfigFilepath, "exec"), trainConfig.configStruct)
    configGet = trainConfig.get  #Main interface
    #TODO: validate the config (required fields, listing files, optional parameters) before use.

    def absPathFromConfig(pathInConfig):
        #Paths given in the config may be relative; resolve them via the config file's path.
        return getAbsPathEvenIfRelativeIsGiven(pathInConfig, trainConfigFilepath)

    def parseListingFileIfGiven(configKey, parseLines=parseAbsFileLinesInList):
        #Parse the listing file that configKey points to, or return None if the key is not set.
        listingFilepath = configGet(configKey)
        return parseLines(absPathFromConfig(listingFilepath)) if listingFilepath else None

    def parseListingFilesPerCategoryIfGiven(configKey):
        #One parsed listing file per entry of configKey:
        #[[case1-cat1, ..., caseN-cat1], [case1-cat2, ..., caseN-cat2]], or None if not set.
        listingFilepaths = configGet(configKey)
        if not listingFilepaths:
            return None
        return [parseAbsFileLinesInList(absPathFromConfig(listingFilepath))
                for listingFilepath in listingFilepaths]

    def transposeToPerCase(listPerChannel):
        #[[case1-ch1, ..., caseN-ch1], ...] -> [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
        if listPerChannel is None:
            return None
        return [list(item) for item in zip(*listPerChannel)]

    #Create Folders and Logger
    mainOutputAbsFolder = absPathFromConfig(configGet(trainConfig.FOLDER_FOR_OUTPUT))
    sessionName = configGet(trainConfig.SESSION_NAME) or TrainSessionParameters.getDefaultSessionName()
    [folderForSessionCnnModels, folderForLogs, folderForPredictions,
     folderForFeatures] = makeFoldersNeededForTrainingSession(mainOutputAbsFolder, sessionName)
    loggerFileName = folderForLogs + "/" + sessionName + ".txt"
    sessionLogger = myLoggerModule.MyLogger(loggerFileName)

    sessionLogger.print3(
        "CONFIG: The configuration file for the training session was loaded from: "
        + str(trainConfigFilepath))

    #Load the cnn Model if not given straight from a createCnnModel-session...
    cnnModelFilepathInConfig = configGet(trainConfig.CNN_MODEL_FILEPATH)
    if cnnInstancePreLoaded:
        sessionLogger.print3(
            "====== CNN-Instance already loaded (from Create-New-Model procedure) =======")
        if cnnModelFilepathInConfig:
            sessionLogger.print3(
                "WARN: Any paths to cnn-models specified in the train-config file will be ommitted! Working with the pre-loaded model!")
        cnn3dInstance = cnnInstancePreLoaded
        filepathToCnnModel = filenameAndPathWherePreLoadedModelWas
    else:
        sessionLogger.print3(
            "=========== Loading the CNN model for training... ===============")
        #If CNN-Model was specified in command line, completely override the one in the config file.
        #The command line is checked on its own, so that it also wins when the config names no model.
        if absPathToSavedModelFromCmdLine:
            if cnnModelFilepathInConfig:
                sessionLogger.print3(
                    "WARN: A CNN-Model to use was specified both in the command line input and in the train-config-file! The input by the command line will be used: "
                    + str(absPathToSavedModelFromCmdLine))
            filepathToCnnModel = absPathToSavedModelFromCmdLine
        else:
            filepathToCnnModel = absPathFromConfig(cnnModelFilepathInConfig)
        sessionLogger.print3(
            "...Loading the network can take a few minutes if the model is big...")
        cnn3dInstance = load_object_from_gzip_file(filepathToCnnModel)
        sessionLogger.print3("The CNN model was loaded successfully from: " +
                             str(filepathToCnnModel))
    #TODO: check the parameters that depend on the loaded model (e.g. SAVE_PROBMAPS_PER_CLASS,
    #INDICES_OF_FMS_TO_SAVE, number of channels), e.g. via
    #trainConfig.checkIfConfigIsCorrectForParticularCnnModel(cnn3dInstance).

    #Fill in the session's parameters. Parsing order is kept: train first, then validation.
    #=======TRAINING==========
    listWithAListPerCaseWithFilepathPerChannelTrain = transposeToPerCase(
        parseListingFilesPerCategoryIfGiven(trainConfig.CHANNELS_TR))
    gtLabelsFilepathsTrain = parseListingFileIfGiven(trainConfig.GT_LABELS_TR)
    roiMasksFilepathsTrain = parseListingFileIfGiven(trainConfig.ROI_MASKS_TR)
    #~~~~~ Advanced Training Sampling~~~~~~~~~
    listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain = parseListingFilesPerCategoryIfGiven(
        trainConfig.WEIGHT_MAPS_PER_CAT_FILEPATHS_TR)

    #=======VALIDATION==========
    listWithAListPerCaseWithFilepathPerChannelVal = transposeToPerCase(
        parseListingFilesPerCategoryIfGiven(trainConfig.CHANNELS_VAL))
    gtLabelsFilepathsVal = parseListingFileIfGiven(trainConfig.GT_LABELS_VAL)
    roiMasksFilepathsVal = parseListingFileIfGiven(trainConfig.ROI_MASKS_VAL)
    #~~~~~Full Inference~~~~~~
    #CAREFUL: parsed with parseFileLinesInList, not parseAbsFileLinesInList
    #(presumably because these lines are output names rather than paths to resolve).
    namesToSavePredsAndFeatsVal = parseListingFileIfGiven(
        trainConfig.NAMES_FOR_PRED_PER_CASE_VAL, parseLines=parseFileLinesInList)
    #~~~~~Advanced Validation Sampling~~~~~~~~
    listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal = parseListingFilesPerCategoryIfGiven(
        trainConfig.WEIGHT_MAPS_PER_CAT_FILEPATHS_VAL)

    trainSessionParameters = TrainSessionParameters(
        sessionName=sessionName,
        sessionLogger=sessionLogger,
        mainOutputAbsFolder=mainOutputAbsFolder,
        cnn3dInstance=cnn3dInstance,
        cnnModelFilepath=filepathToCnnModel,

        #==================TRAINING====================
        folderForSessionCnnModels=folderForSessionCnnModels,
        listWithAListPerCaseWithFilepathPerChannelTrain=listWithAListPerCaseWithFilepathPerChannelTrain,
        gtLabelsFilepathsTrain=gtLabelsFilepathsTrain,

        #[Optionals]
        #~~~~~~~~~Sampling~~~~~~~~~
        roiMasksFilepathsTrain=roiMasksFilepathsTrain,
        percentOfSamplesToExtractPositTrain=configGet(trainConfig.PERC_POS_SAMPLES_TR),
        #~~~~~~~~~Advanced Sampling~~~~~~~
        useDefaultTrainingSamplingFromGtAndRoi=configGet(trainConfig.DEFAULT_TR_SAMPLING),
        samplingTypeTraining=configGet(trainConfig.TYPE_OF_SAMPLING_TR),
        proportionOfSamplesPerCategoryTrain=configGet(trainConfig.PROP_OF_SAMPLES_PER_CAT_TR),
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain=listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesTrain,

        #~~~~~~~~Training Cycle ~~~~~~~
        numberOfEpochs=configGet(trainConfig.NUM_EPOCHS),
        numberOfSubepochs=configGet(trainConfig.NUM_SUBEP),
        numOfCasesLoadedPerSubepoch=configGet(trainConfig.NUM_CASES_LOADED_PERSUB),
        segmentsLoadedOnGpuPerSubepochTrain=configGet(trainConfig.NUM_TR_SEGMS_LOADED_PERSUB),

        #~~~~~~~ Learning Rate Schedule ~~~~~~~
        #Auto requires performValidationOnSamplesThroughoutTraining and providedGtForValidationBool
        stable0orAuto1orPredefined2orExponential3LrSchedule=configGet(trainConfig.LR_SCH_0123),

        #Stable + Auto + Predefined.
        whenDecreasingDivideLrBy=configGet(trainConfig.DIV_LR_BY),
        #Stable + Auto
        numEpochsToWaitBeforeLoweringLr=configGet(trainConfig.NUM_EPOCHS_WAIT),
        #Auto:
        minIncreaseInValidationAccuracyThatResetsWaiting=configGet(trainConfig.AUTO_MIN_INCR_VAL_ACC),
        #Predefined.
        predefinedSchedule=configGet(trainConfig.PREDEF_SCH),
        #Exponential
        exponentialSchedForLrAndMom=configGet(trainConfig.EXPON_SCH),

        #~~~~~~~ Augmentation~~~~~~~~~~~~
        reflectImagesPerAxis=configGet(trainConfig.REFL_AUGM_PER_AXIS),
        performIntAugm=configGet(trainConfig.PERF_INT_AUGM_BOOL),
        sampleIntAugmShiftWithMuAndStd=configGet(trainConfig.INT_AUGM_SHIF_MUSTD),
        sampleIntAugmMultiWithMuAndStd=configGet(trainConfig.INT_AUGM_MULT_MUSTD),

        #==================VALIDATION=====================
        performValidationOnSamplesThroughoutTraining=configGet(trainConfig.PERFORM_VAL_SAMPLES),
        performFullInferenceOnValidationImagesEveryFewEpochs=configGet(trainConfig.PERFORM_VAL_INFERENCE),
        #Required:
        listWithAListPerCaseWithFilepathPerChannelVal=listWithAListPerCaseWithFilepathPerChannelVal,
        gtLabelsFilepathsVal=gtLabelsFilepathsVal,

        segmentsLoadedOnGpuPerSubepochVal=configGet(trainConfig.NUM_VAL_SEGMS_LOADED_PERSUB),

        #[Optionals]
        roiMasksFilepathsVal=roiMasksFilepathsVal,  #For default sampling and for fast inference. Optional. Otherwise from whole image.

        #~~~~~~~~Full Inference~~~~~~~~
        numberOfEpochsBetweenFullInferenceOnValImages=configGet(trainConfig.NUM_EPOCHS_BETWEEN_VAL_INF),
        #Output
        namesToSavePredictionsAndFeaturesVal=namesToSavePredsAndFeatsVal,
        #predictions
        saveSegmentationVal=configGet(trainConfig.SAVE_SEGM_VAL),
        saveProbMapsBoolPerClassVal=configGet(trainConfig.SAVE_PROBMAPS_PER_CLASS_VAL),
        folderForPredictionsVal=folderForPredictions,
        #features:
        saveIndividualFmImagesVal=configGet(trainConfig.SAVE_INDIV_FMS_VAL),
        saveMultidimensionalImageWithAllFmsVal=configGet(trainConfig.SAVE_4DIM_FMS_VAL),
        indicesOfFmsToVisualisePerPathwayAndLayerVal=[
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_NORMAL_VAL),
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_SUBSAMPLED_VAL),
            configGet(trainConfig.INDICES_OF_FMS_TO_SAVE_FC_VAL)],
        folderForFeaturesVal=folderForFeatures,

        #~~~~~~~~ Advanced Validation Sampling ~~~~~~~~~~
        useDefaultUniformValidationSampling=configGet(trainConfig.DEFAULT_VAL_SAMPLING),
        samplingTypeValidation=configGet(trainConfig.TYPE_OF_SAMPLING_VAL),
        proportionOfSamplesPerCategoryVal=configGet(trainConfig.PROP_OF_SAMPLES_PER_CAT_VAL),
        listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal=listOfAListPerWeightMapCategoryWithFilepathsOfAllCasesVal,

        #====Optimization=====
        learningRate=configGet(trainConfig.LRATE),
        optimizerSgd0Adam1Rms2=configGet(trainConfig.OPTIMIZER),
        classicMom0Nesterov1=configGet(trainConfig.MOM_TYPE),
        momentumValue=configGet(trainConfig.MOM),
        momNonNormalized0Normalized1=configGet(trainConfig.MOM_NORM_NONNORM),
        #Adam
        b1Adam=configGet(trainConfig.B1_ADAM),
        b2Adam=configGet(trainConfig.B2_ADAM),
        eAdam=configGet(trainConfig.EPS_ADAM),
        #Rms
        rhoRms=configGet(trainConfig.RHO_RMS),
        eRms=configGet(trainConfig.EPS_RMS),
        #Regularization
        l1Reg=configGet(trainConfig.L1_REG),
        l2Reg=configGet(trainConfig.L2_REG),

        #~~~~~~~ Freeze Layers ~~~~~~~
        layersToFreezePerPathwayType=[configGet(trainConfig.LAYERS_TO_FREEZE_NORM),
                                      configGet(trainConfig.LAYERS_TO_FREEZE_SUBS),
                                      configGet(trainConfig.LAYERS_TO_FREEZE_FC)],

        #==============Generic and Preprocessing===============
        padInputImagesBool=configGet(trainConfig.PAD_INPUT)
    )

    logPrint = trainSessionParameters.sessionLogger.print3

    logPrint("\n===========       NEW TRAINING SESSION         ===============")
    trainSessionParameters.printParametersOfThisSession()

    logPrint("\n=======================================================")
    logPrint("=========== Compiling the Training Function ===========")
    logPrint("=======================================================")
    #Query the state once and reuse it for both the decision and the log message.
    trainingStateInitialized = cnn3dInstance.checkTrainingStateAttributesInitialized()
    if not trainingStateInitialized or resetOptimizer:
        logPrint("(Re)Initializing parameters for the optimization. "
                 "Reason: Uninitialized: [" + str(not trainingStateInitialized) +
                 "], Reset requested: [" + str(resetOptimizer) + "]")
        cnn3dInstance.initializeTrainingState(
            *trainSessionParameters.getTupleForInitializingTrainingState())
    cnn3dInstance.compileTrainFunction(
        *trainSessionParameters.getTupleForCompilationOfTrainFunc())
    logPrint("\n=========== Compiling the Validation Function =========")
    cnn3dInstance.compileValidationFunction(
        *trainSessionParameters.getTupleForCompilationOfValFunc())
    logPrint("\n=========== Compiling the Testing Function ============")
    #The test function is used for validation with full segmentation.
    cnn3dInstance.compileTestAndVisualisationFunction(
        *trainSessionParameters.getTupleForCompilationOfTestFunc())

    logPrint("\n=======================================================")
    logPrint("============== Training the CNN model =================")
    logPrint("=======================================================")
    do_training(*trainSessionParameters.getTupleForCnnTraining())
    logPrint("\n=======================================================")
    logPrint("=========== Training session finished =================")
    logPrint("=======================================================")