def run_session(self, *args):
    """Build the CNN + trainer graph, load or initialise parameters, and run do_training().

    Args:
        *args: Exactly three positional values, unpacked as
            (sess_device, model_params, reset_trainer):
            sess_device: device string the CNN graph is placed on via graphTf.device().
            model_params: object whose get_args_for_arch() supplies the arguments of
                Cnn3d.make_cnn_model().
            reset_trainer: when a saved model is loaded and this is truthy, the trainer's
                variables are re-initialised instead of being restored from the checkpoint.

    Returns:
        The status value returned by do_training(). It was previously discarded, so callers
        could not tell a failed training run from a successful one.
    """
    (sess_device,
     model_params,
     reset_trainer) = args

    graphTf = tf.Graph()

    with graphTf.as_default():
        with graphTf.device(sess_device):  # Explicit device assignment, throws an error if GPU is specified but not available.
            self._log.print3("=========== Making the CNN graph... ===============")
            cnn3d = Cnn3d()
            with tf.variable_scope("net"):
                cnn3d.make_cnn_model(*model_params.get_args_for_arch())
                # I have now created the CNN graph. But not yet the Optimizer's graph.

        # No explicit device assignment for the rest. Because trained has piecewise_constant that is only on cpu, and so is saver.
        with tf.variable_scope("trainer"):
            self._log.print3("=========== Building Trainer ===========\n")
            trainer = Trainer(*(self._params.get_args_for_trainer() + [cnn3d]))
            trainer.create_optimizer(*self._params.get_args_for_optimizer())  # Trainer and net connect here.

        # The below should not create any new tf.variables.
        self._log.print3("=========== Compiling the Training Function ===========")
        self._log.print3("=======================================================\n")
        cnn3d.setup_ops_n_feeds_to_train(self._log,
                                         trainer.get_total_cost(),
                                         trainer.get_param_updates_wrt_total_cost()  # list of ops
                                         )

        self._log.print3("=========== Compiling the Validation Function =========")
        cnn3d.setup_ops_n_feeds_to_val(self._log)

        self._log.print3("=========== Compiling the Testing Function ============")
        cnn3d.setup_ops_n_feeds_to_test(self._log,
                                        self._params.indices_fms_per_pathtype_per_layer_to_save)  # For validation with full segmentation

        # Create the savers
        saver_all = tf.train.Saver()  # Will be used during training for saving everything.
        # Alternative: tf.train.Saver([v for v in tf.all_variables() if v.name.startswith("net"])
        collection_vars_net = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net")
        saver_net = tf.train.Saver(var_list=collection_vars_net)  # Used to load the net's parameters.
        collection_vars_trainer = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="trainer")
        saver_trainer = tf.train.Saver(var_list=collection_vars_trainer)  # Used to load the trainer's parameters.

    # self._print_vars_in_collection(collection_vars_net, "net")
    # self._print_vars_in_collection(collection_vars_trainer, "trainer")

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False,
                                    device_count={'CPU': 999, 'GPU': 99})
    with tf.Session(graph=graphTf, config=session_config) as sessionTf:
        # Load or initialize parameters
        file_to_load_params_from = self._params.get_path_to_load_model_from()
        if file_to_load_params_from is not None:  # Load params
            self._log.print3("=========== Loading parameters from specified saved model ===============")
            # A directory means "use its latest checkpoint"; otherwise the path is the checkpoint itself.
            chkpt_fname = tf.train.latest_checkpoint(file_to_load_params_from) \
                if os.path.isdir(file_to_load_params_from) else file_to_load_params_from
            self._log.print3("Loading checkpoint file:" + str(chkpt_fname))
            self._log.print3("Loading network parameters...")
            try:
                saver_net.restore(sessionTf, chkpt_fname)
                self._log.print3("Network parameters were loaded.")
            except Exception as e:
                handle_exception_tf_restore(self._log, e)

            if not reset_trainer:
                self._log.print3("Loading trainer parameters...")
                saver_trainer.restore(sessionTf, chkpt_fname)
                self._log.print3("Trainer parameters were loaded.")
            else:
                self._log.print3("Reset of trainer parameters was requested. Re-initializing them...")
                tf.variables_initializer(var_list=collection_vars_trainer).run()
                self._log.print3("Trainer parameters re-initialized.")
        else:
            self._log.print3("=========== Initializing network and trainer variables ===============")
            # Initialize separate as below, so that in case I miss a variable, I will get an error and I will know.
            tf.variables_initializer(var_list=collection_vars_net).run()
            tf.variables_initializer(var_list=collection_vars_trainer).run()
            self._log.print3("All variables were initialized.")

            filename_to_save_with = self._params.filepath_to_save_models + ".initial." + datetimeNowAsStr()
            self._log.print3("Saving the initial model at:" + str(filename_to_save_with))
            saver_all.save(sessionTf, filename_to_save_with + ".model.ckpt", write_meta_graph=False)

        self._log.print3("")
        self._log.print3("=======================================================")
        self._log.print3("============== Training the CNN model =================")
        self._log.print3("=======================================================\n")

        res_code = do_training(*([sessionTf, saver_all, cnn3d, trainer] + self._params.get_args_for_train_routine()))

    self._log.print3("\n=======================================================")
    self._log.print3("=========== Training session finished =================")
    self._log.print3("=======================================================")
    # Surface a failure status instead of silently discarding it. Truthiness is used because a
    # do_training() variant may return None on success.
    if res_code:
        self._log.print3("WARNING: do_training() returned error code: " + str(res_code))
    return res_code
def _submit_sampling_job_to_pool(log, id_str, worker_pool, args_for_sampling, subepoch, phase_str, data_str):
    """Submit getSampledDataAndLabelsForSubepoch(*args_for_sampling) to worker_pool.

    Args:
        log: logger exposing print3().
        id_str: prefix identifying the main process in log lines.
        worker_pool: pool exposing apply_async() (a ThreadPool in do_training()).
        args_for_sampling: positional arguments for getSampledDataAndLabelsForSubepoch.
        subepoch: index of the current subepoch (logging only).
        phase_str: main-thread phase about to run, "Validation" or "Training" (logging only).
        data_str: data the job samples, "VALIDATION" or "TRAINING" (logging only).

    Returns:
        The AsyncResult of the submitted job.
    """
    log.print3(id_str + " MULTIPROC: Before " + phase_str + " in subepoch #" + str(subepoch) +
               ", submitting sampling job for next [" + data_str + "].")
    return worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling)


def _get_samples_for_subepoch(log, id_str, worker_pool, pending_job, args_for_sampling, subepoch, phase_str, data_str):
    """Return the sampled (channels, labels) result for the current subepoch.

    - Sequential mode (worker_pool is None): sample in the main thread.
    - A job was submitted ahead of time (pending_job is not None): wait for its result.
    - Otherwise submit a job now and wait for it. This only happens at subepoch 0: the first
      epoch of the process, or after a subepoch that deliberately did not prefetch.
    """
    if worker_pool is None:  # Sequential processing.
        log.print3(id_str + " NO MULTIPROC: Sampling for subepoch #" + str(subepoch) +
                   " [" + data_str + "] will be done by main thread.")
        return getSampledDataAndLabelsForSubepoch(*args_for_sampling)
    if pending_job is None:
        # Not previously submitted in case of first epoch or after a subepoch without prefetch.
        assert subepoch == 0
        pending_job = _submit_sampling_job_to_pool(log, id_str, worker_pool, args_for_sampling,
                                                   subepoch, phase_str, data_str)
    return pending_job.get()


def do_training(sessionTf,
                saver_all,
                cnn3d,
                trainer,
                log,
                fileToSaveTrainedCnnModelTo,
                val_on_samples_during_train,
                savePredictedSegmAndProbsDict,
                namesForSavingSegmAndProbs,
                suffixForSegmAndProbsDict,
                listOfFilepathsToEachChannelOfEachPatientTraining,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                listOfFilepathsToGtLabelsOfEachPatientTraining,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                paths_to_wmaps_per_sampl_cat_per_subj_train,
                paths_to_wmaps_per_sampl_cat_per_subj_val,
                listOfFilepathsToRoiMaskOfEachPatientTraining,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                n_epochs,  # Every epoch the CNN model is saved.
                num_subepochs,  # per epoch. Every subepoch Accuracy is reported
                max_n_cases_per_subep_train,  # Max num of subjects loaded every subepoch for segments extraction.
                n_samples_per_subep_train,
                n_samples_per_subep_val,
                num_parallel_proc_sampling,  # -1: seq. 0: thread for sampling. >0: multiprocess sampling
                #-------Sampling Type---------
                samplingTypeInstanceTraining,  # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                samplingTypeInstanceValidation,
                batchsize_train,
                batchsize_val_samples,
                batchsize_val_whole,
                #-------Preprocessing-----------
                pad_input_imgs,
                #-------Data Augmentation-------
                augm_img_prms,
                augm_sample_prms,
                # Validation
                val_on_whole_volumes,
                num_epochs_between_val_on_whole_volumes,
                #--------For FM visualisation---------
                saveIndividualFmImagesForVisualisation,
                saveMultidimensionalImageWithAllFms,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                namesForSavingFms,
                #-------- Others --------
                run_input_checks
                ):
    """Train cnn3d for the remaining epochs, optionally validating on samples and on whole volumes.

    Each epoch runs num_subepochs subepochs. A subepoch optionally validates on freshly sampled
    validation segments and then trains on freshly sampled training segments. When
    num_parallel_proc_sampling > -1, sampling for the next step runs in a single-thread worker
    pool while the main thread validates or trains. The model is saved with saver_all at the end
    of every epoch and once more after training finishes successfully.

    Returns:
        0 if training finished, 1 if an exception (or KeyboardInterrupt) was caught and logged.
    """
    id_str = "[MAIN|PID:" + str(os.getpid()) + "]"
    start_time_train = time.time()

    # I cannot pass cnn3d to the sampling function, because the pp module used to reload theano.
    # This created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)

    args_for_sampling_train = (log,
                               "train",
                               num_parallel_proc_sampling,
                               run_input_checks,
                               cnn3dWrapper,
                               max_n_cases_per_subep_train,
                               n_samples_per_subep_train,
                               samplingTypeInstanceTraining,
                               listOfFilepathsToEachChannelOfEachPatientTraining,
                               listOfFilepathsToGtLabelsOfEachPatientTraining,
                               listOfFilepathsToRoiMaskOfEachPatientTraining,
                               paths_to_wmaps_per_sampl_cat_per_subj_train,
                               pad_input_imgs,
                               augm_img_prms,
                               augm_sample_prms)
    args_for_sampling_val = (log,
                             "val",
                             num_parallel_proc_sampling,
                             run_input_checks,
                             cnn3dWrapper,
                             max_n_cases_per_subep_train,
                             n_samples_per_subep_val,
                             samplingTypeInstanceValidation,
                             listOfFilepathsToEachChannelOfEachPatientValidation,
                             listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                             listOfFilepathsToRoiMaskOfEachPatientValidation,
                             paths_to_wmaps_per_sampl_cat_per_subj_val,
                             pad_input_imgs,
                             None,  # no augmentation in val.
                             None)  # no augmentation in val.

    # AsyncResults of sampling jobs submitted ahead of time; None when nothing is pending.
    pending_job_train = None
    pending_job_val = None
    # For parallel extraction of samples for next train/val while processing previous iteration.
    worker_pool = None
    if num_parallel_proc_sampling > -1:  # Use multiprocessing.
        worker_pool = ThreadPool(processes=1)  # Or multiprocessing.Pool(...), same API.

    try:
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while model_num_epochs_trained < n_epochs:
            epoch = model_num_epochs_trained

            acc_monitor_for_ep_train = AccuracyOfEpochMonitorSegmentation(log, 0, model_num_epochs_trained, cnn3d.num_classes, num_subepochs)
            acc_monitor_for_ep_val = None if not val_on_samples_during_train else \
                AccuracyOfEpochMonitorSegmentation(log, 1, model_num_epochs_trained, cnn3d.num_classes, num_subepochs)

            val_on_whole_volumes_after_ep = False
            if val_on_whole_volumes and (model_num_epochs_trained + 1) % num_epochs_between_val_on_whole_volumes == 0:
                val_on_whole_volumes_after_ep = True
            # Presumably the last epoch, assuming run_updates_end_of_ep() advances the counter by one.
            # Only used to avoid prefetching data for a subepoch that will never run. If the guess is
            # wrong, the next subepoch is subepoch 0 and _get_samples_for_subepoch() samples on demand.
            is_last_ep_of_training = (epoch + 1 >= n_epochs)

            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~\t Starting new Epoch! Epoch #" + str(epoch) + "/" + str(n_epochs) + " \t~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_time_ep = time.time()

            for subepoch in range(num_subepochs):
                log.print3("***************************************************************************************")
                log.print3("*******\t\t Starting new Subepoch: #" + str(subepoch) + "/" + str(num_subepochs) + " \t\t********")
                log.print3("***************************************************************************************")

                if val_on_samples_during_train:
                    #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
                    (channsOfSegmentsForSubepPerPathwayVal,
                     labelsForCentralOfSegmentsForSubepVal) = _get_samples_for_subepoch(
                        log, id_str, worker_pool, pending_job_val, args_for_sampling_val,
                        subepoch, "Validation", "VALIDATION")
                    pending_job_val = None

                    #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                    if worker_pool is not None:
                        pending_job_train = _submit_sampling_job_to_pool(
                            log, id_str, worker_pool, args_for_sampling_train,
                            subepoch, "Validation", "TRAINING")

                    #------------------------------------DO VALIDATION--------------------------------
                    log.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-")
                    start_time_val_subep = time.time()
                    # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                    num_batches_val = len(channsOfSegmentsForSubepPerPathwayVal[0]) // batchsize_val_samples
                    trainOrValidateForSubepoch(log,
                                               sessionTf,
                                               "val",
                                               num_batches_val,
                                               batchsize_val_samples,
                                               cnn3d,
                                               subepoch,
                                               acc_monitor_for_ep_val,
                                               channsOfSegmentsForSubepPerPathwayVal,
                                               labelsForCentralOfSegmentsForSubepVal)
                    end_time_val_subep = time.time()
                    log.print3("TIMING: Validation on batches of this subepoch #" + str(subepoch) + " lasted: {0:.1f}".format(end_time_val_subep - start_time_val_subep) + " secs.")

                #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
                (channsOfSegmentsForSubepPerPathwayTrain,
                 labelsForCentralOfSegmentsForSubepTrain) = _get_samples_for_subepoch(
                    log, id_str, worker_pool, pending_job_train, args_for_sampling_train,
                    subepoch, "Training", "TRAINING")
                pending_job_train = None

                #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
                # No prefetch after the last subepoch of an epoch followed by whole-volume validation,
                # nor after the very last subepoch of training (nothing would consume the result, and
                # closing the pool would wait for the whole useless sampling pass).
                is_last_subep_of_ep = (subepoch == num_subepochs - 1)
                skip_prefetch = is_last_subep_of_ep and (val_on_whole_volumes_after_ep or is_last_ep_of_training)
                if worker_pool is not None and not skip_prefetch:
                    if val_on_samples_during_train:
                        pending_job_val = _submit_sampling_job_to_pool(
                            log, id_str, worker_pool, args_for_sampling_val,
                            subepoch, "Training", "VALIDATION")
                    else:
                        pending_job_train = _submit_sampling_job_to_pool(
                            log, id_str, worker_pool, args_for_sampling_train,
                            subepoch, "Training", "TRAINING")

                #-------------------------------START TRAINING IN BATCHES------------------------------
                log.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-")
                start_time_train_subep = time.time()
                # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                num_batches_train = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // batchsize_train
                trainOrValidateForSubepoch(log,
                                           sessionTf,
                                           "train",
                                           num_batches_train,
                                           batchsize_train,
                                           cnn3d,
                                           subepoch,
                                           acc_monitor_for_ep_train,
                                           channsOfSegmentsForSubepPerPathwayTrain,
                                           labelsForCentralOfSegmentsForSubepTrain)
                end_time_train_subep = time.time()
                log.print3("TIMING: Training on batches of this subepoch #" + str(subepoch) + " lasted: {0:.1f}".format(end_time_train_subep - start_time_train_subep) + " secs.")

            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_samples_during_train:
                acc_monitor_for_ep_val.reportMeanAccyracyOfEpoch()
            acc_monitor_for_ep_train.reportMeanAccyracyOfEpoch()

            mean_val_acc_of_ep = acc_monitor_for_ep_val.getMeanEmpiricalAccuracyOfEpoch() if val_on_samples_during_train else None
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep)  # Updates LR schedule if needed, and increases number of epochs trained.
            model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)

            del acc_monitor_for_ep_train
            del acc_monitor_for_ep_val

            log.print3("SAVING: Epoch #" + str(epoch) + " finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr()
            saver_all.save(sessionTf, filename_to_save_with + ".model.ckpt", write_meta_graph=False)
            end_time_ep = time.time()
            log.print3("TIMING: The whole Epoch #" + str(epoch) + " lasted: {0:.1f}".format(end_time_ep - start_time_ep) + " secs.")
            log.print3("~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_whole_volumes_after_ep:
                log.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #" + str(epoch) + "...***")
                inferenceWholeVolumes(sessionTf,
                                      cnn3d,
                                      log,
                                      "val",
                                      savePredictedSegmAndProbsDict,
                                      listOfFilepathsToEachChannelOfEachPatientValidation,
                                      listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                      listOfFilepathsToRoiMaskOfEachPatientValidation,
                                      namesForSavingSegmAndProbs=namesForSavingSegmAndProbs,
                                      suffixForSegmAndProbsDict=suffixForSegmAndProbsDict,
                                      # Hyper parameters
                                      batchsize=batchsize_val_whole,
                                      #----Preprocessing------
                                      pad_input_imgs=pad_input_imgs,
                                      #--------For FM visualisation---------
                                      saveIndividualFmImagesForVisualisation=saveIndividualFmImagesForVisualisation,
                                      saveMultidimensionalImageWithAllFms=saveMultidimensionalImageWithAllFms,
                                      indicesOfFmsToVisualisePerPathwayTypeAndPerLayer=indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                                      namesForSavingFms=namesForSavingFms)

        end_time_train = time.time()
        log.print3("TIMING: Training process lasted: {0:.1f}".format(end_time_train - start_time_train) + " secs.")

    except (Exception, KeyboardInterrupt) as e:
        log.print3("\n\n ERROR: Caught exception in do_training(): " + str(e) + "\n")
        log.print3(traceback.format_exc())
        if worker_pool is not None:
            log.print3("Terminating worker pool.")
            worker_pool.terminate()
            worker_pool.join()  # Will wait. A KeybInt will kill this (py3)
        return 1
    else:
        if worker_pool is not None:
            log.print3("Closing worker pool.")
            worker_pool.close()
            worker_pool.join()

        # Save the final trained model.
        filename_to_save_with = fileToSaveTrainedCnnModelTo + ".final." + datetimeNowAsStr()
        log.print3("Saving the final model at:" + str(filename_to_save_with))
        saver_all.save(sessionTf, filename_to_save_with + ".model.ckpt", write_meta_graph=False)

        log.print3("The whole do_training() function has finished.")
        return 0
def do_training(
        sessionTf,
        saver_all,
        cnn3d,
        trainer,
        log,
        fileToSaveTrainedCnnModelTo,
        performValidationOnSamplesDuringTrainingProcessBool,
        savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation,
        listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice,
        listOfFilepathsToEachChannelOfEachPatientTraining,
        listOfFilepathsToEachChannelOfEachPatientValidation,
        listOfFilepathsToGtLabelsOfEachPatientTraining,
        providedGtForValidationBool,
        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
        providedWeightMapsToSampleForEachCategoryTraining,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
        providedWeightMapsToSampleForEachCategoryValidation,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
        providedRoiMaskForTrainingBool,
        listOfFilepathsToRoiMaskOfEachPatientTraining,  # Also needed for normalization-augmentation
        providedRoiMaskForValidationBool,
        listOfFilepathsToRoiMaskOfEachPatientValidation,
        n_epochs,  # Every epoch the CNN model is saved.
        number_of_subepochs,  # per epoch. Every subepoch Accuracy is reported
        maxNumSubjectsLoadedPerSubepoch,  # Max num of cases loaded every subepoch for segments extraction. The more, the longer loading.
        imagePartsLoadedInGpuPerSubepoch,
        imagePartsLoadedInGpuPerSubepochValidation,
        #-------Sampling Type---------
        samplingTypeInstanceTraining,  # Instance of the deepmedic/samplingType.SamplingType class for training and validation
        samplingTypeInstanceValidation,
        #-------Preprocessing-----------
        padInputImagesBool,
        #-------Data Augmentation-------
        doIntAugm_shiftMuStd_multiMuStd,
        reflectImageWithHalfProbDuringTraining,
        useSameSubChannelsAsSingleScale,
        listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,  # deprecated, not supported
        listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,  # deprecated, not supported
        # Validation
        performFullInferenceOnValidationImagesEveryFewEpochsBool,  # Even if not providedGtForValidationBool, inference will be performed if this == True, to save the results, eg for visual.
        everyThatManyEpochsComputeDiceOnTheFullValidationImages,  # Should not be == 0, except if performFullInferenceOnValidationImagesEveryFewEpochsBool == False
        #--------For FM visualisation---------
        saveIndividualFmImagesForVisualisation,
        saveMultidimensionalImageWithAllFms,
        indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
        listOfNamesToGiveToFmVisualisationsIfSaving,
        #-------- Others --------
        run_input_checks):
    """Train cnn3d for the remaining epochs, using a Parallel Python job server for sampling.

    The very first sampling of the process runs in the main thread. After that, sampling of the
    data for the next validation/training step is submitted to a pp.Server with one worker while
    the main thread validates or trains. The model is saved with saver_all after every epoch.
    Every `everyThatManyEpochsComputeDiceOnTheFullValidationImages` epochs, full inference on
    the validation volumes is optionally performed.

    The pp job server is always destroyed when training ends, including when an exception
    propagates. This stops its worker and any still-running prefetch job.
    """
    start_training_time = time.time()

    # Used because I cannot pass cnn3d to the sampling function.
    # This is because the parallel process used to load theano again. And created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)

    #---------To run PARALLEL the extraction of parts for the next subepoch---
    ppservers = ()  # tuple of all parallel python servers to connect with
    job_server = pp.Server(ncpus=1, ppservers=ppservers)  # Creates jobserver with automatically detected number of workers

    tupleWithParametersForTraining = (log,
                                      "train",
                                      run_input_checks,
                                      cnn3dWrapper,
                                      maxNumSubjectsLoadedPerSubepoch,
                                      imagePartsLoadedInGpuPerSubepoch,
                                      samplingTypeInstanceTraining,
                                      listOfFilepathsToEachChannelOfEachPatientTraining,
                                      listOfFilepathsToGtLabelsOfEachPatientTraining,
                                      providedRoiMaskForTrainingBool,
                                      listOfFilepathsToRoiMaskOfEachPatientTraining,
                                      providedWeightMapsToSampleForEachCategoryTraining,
                                      forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
                                      useSameSubChannelsAsSingleScale,
                                      listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,
                                      padInputImagesBool,
                                      doIntAugm_shiftMuStd_multiMuStd,
                                      reflectImageWithHalfProbDuringTraining)
    tupleWithParametersForValidation = (log,
                                        "val",
                                        run_input_checks,
                                        cnn3dWrapper,
                                        maxNumSubjectsLoadedPerSubepoch,
                                        imagePartsLoadedInGpuPerSubepochValidation,
                                        samplingTypeInstanceValidation,
                                        listOfFilepathsToEachChannelOfEachPatientValidation,
                                        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                        providedRoiMaskForValidationBool,
                                        listOfFilepathsToRoiMaskOfEachPatientValidation,
                                        providedWeightMapsToSampleForEachCategoryValidation,
                                        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
                                        useSameSubChannelsAsSingleScale,
                                        listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
                                        padInputImagesBool,
                                        [0, -1, -1, -1],  # don't perform intensity-augmentation during validation.
                                        [0, 0, 0]  # don't perform reflection-augmentation during validation.
                                        )
    tupleWithLocalFunctionsThatWillBeCalledByTheMainJob = ()
    tupleWithModulesToImportWhichAreUsedByTheJobFunctions = (
        "from __future__ import absolute_import, print_function, division",
        "time",
        "numpy as np",
        "from deepmedic.dataManagement.sampling import *")
    boolItIsTheVeryFirstSubepochOfThisProcess = True  # to know so that in the very first I sequencially load the data for it.

    def submit_sampling_job(tupleWithParameters):
        # Submit getSampledDataAndLabelsForSubepoch(*tupleWithParameters) to the pp job server.
        # The returned job object is called later to block on and fetch the result.
        return job_server.submit(getSampledDataAndLabelsForSubepoch,  # local function to call and execute in parallel.
                                 tupleWithParameters,  # tuple with the arguments required
                                 tupleWithLocalFunctionsThatWillBeCalledByTheMainJob,  # tuple of local functions that I need to call
                                 tupleWithModulesToImportWhichAreUsedByTheJobFunctions)  # tuple of the external modules that I need, of which I am calling functions (not the mods of the ext-functions).
    #------End for parallel------

    try:
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while model_num_epochs_trained < n_epochs:
            epoch = model_num_epochs_trained

            trainingAccuracyMonitorForEpoch = AccuracyOfEpochMonitorSegmentation(log, 0, model_num_epochs_trained, cnn3d.num_classes, number_of_subepochs)
            validationAccuracyMonitorForEpoch = None if not performValidationOnSamplesDuringTrainingProcessBool else \
                AccuracyOfEpochMonitorSegmentation(log, 1, model_num_epochs_trained, cnn3d.num_classes, number_of_subepochs)

            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~Starting new Epoch! Epoch #" + str(epoch) + "/" + str(n_epochs) + " ~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_epoch_time = time.time()

            for subepoch in range(number_of_subepochs):  # per subepoch I randomly load some images in the gpu. Random order.
                log.print3("**************************************************************************************************")
                log.print3("************* Starting new Subepoch: #" + str(subepoch) + "/" + str(number_of_subepochs) + " *************")
                log.print3("**************************************************************************************************")

                #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
                if performValidationOnSamplesDuringTrainingProcessBool:
                    if boolItIsTheVeryFirstSubepochOfThisProcess:
                        [channsOfSegmentsForSubepPerPathwayVal,
                         labelsForCentralOfSegmentsForSubepVal] = getSampledDataAndLabelsForSubepoch(
                            log,
                            "val",
                            run_input_checks,
                            cnn3dWrapper,
                            maxNumSubjectsLoadedPerSubepoch,
                            imagePartsLoadedInGpuPerSubepochValidation,
                            samplingTypeInstanceValidation,
                            listOfFilepathsToEachChannelOfEachPatientValidation,
                            listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                            providedRoiMaskForValidationBool,
                            listOfFilepathsToRoiMaskOfEachPatientValidation,
                            providedWeightMapsToSampleForEachCategoryValidation,
                            forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
                            useSameSubChannelsAsSingleScale,
                            listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
                            padInputImagesBool,
                            doIntAugm_shiftMuStd_multiMuStd=[False, [], []],
                            reflectImageWithHalfProbDuringTraining=[0, 0, 0])
                        boolItIsTheVeryFirstSubepochOfThisProcess = False
                    else:  # It was done in parallel with the training of the previous epoch, just grab the results...
                        [channsOfSegmentsForSubepPerPathwayVal,
                         labelsForCentralOfSegmentsForSubepVal] = parallelJobToGetDataForNextValidation()  # fromParallelProcessing that had started from last loop when it was submitted.

                    # Below is computed with number of extracted samples, in case I dont manage to extract as many as I wanted initially.
                    numberOfBatchesValidation = len(channsOfSegmentsForSubepPerPathwayVal[0]) // cnn3d.batchSize["val"]

                    #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                    log.print3("PARALLEL: Before Validation in subepoch #" + str(subepoch) + ", the parallel job for extracting Segments for the next Training is submitted.")
                    parallelJobToGetDataForNextTraining = submit_sampling_job(tupleWithParametersForTraining)

                    #------------------------------------DO VALIDATION--------------------------------
                    log.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-")
                    start_validationForSubepoch_time = time.time()
                    doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(
                        log,
                        sessionTf,
                        "val",
                        numberOfBatchesValidation,  # Computed by the number of extracted samples. So, adapts.
                        cnn3d,
                        subepoch,
                        validationAccuracyMonitorForEpoch,
                        channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal)
                    end_validationForSubepoch_time = time.time()
                    log.print3("TIMING: Validating on the batches of this subepoch #" + str(subepoch) + " took time: " + str(end_validationForSubepoch_time - start_validationForSubepoch_time) + "(s)")
                #-------------------END OF THE VALIDATION-DURING-TRAINING-LOOP-------------------------

                #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
                if (not performValidationOnSamplesDuringTrainingProcessBool) and boolItIsTheVeryFirstSubepochOfThisProcess:
                    [channsOfSegmentsForSubepPerPathwayTrain,
                     labelsForCentralOfSegmentsForSubepTrain] = getSampledDataAndLabelsForSubepoch(
                        log,
                        "train",
                        run_input_checks,
                        cnn3dWrapper,
                        maxNumSubjectsLoadedPerSubepoch,
                        imagePartsLoadedInGpuPerSubepoch,
                        samplingTypeInstanceTraining,
                        listOfFilepathsToEachChannelOfEachPatientTraining,
                        listOfFilepathsToGtLabelsOfEachPatientTraining,
                        providedRoiMaskForTrainingBool,
                        listOfFilepathsToRoiMaskOfEachPatientTraining,
                        providedWeightMapsToSampleForEachCategoryTraining,
                        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
                        useSameSubChannelsAsSingleScale,
                        listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,
                        padInputImagesBool,
                        doIntAugm_shiftMuStd_multiMuStd,
                        reflectImageWithHalfProbDuringTraining)
                    boolItIsTheVeryFirstSubepochOfThisProcess = False
                else:  # It was done in parallel with the validation (or with previous training iteration, in case I am not performing validation).
                    [channsOfSegmentsForSubepPerPathwayTrain,
                     labelsForCentralOfSegmentsForSubepTrain] = parallelJobToGetDataForNextTraining()  # fromParallelProcessing that had started from last loop when it was submitted.

                # Computed with number of extracted samples, in case I dont manage to extract as many as I wanted initially.
                numberOfBatchesTraining = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // cnn3d.batchSize["train"]

                #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
                if performValidationOnSamplesDuringTrainingProcessBool:
                    log.print3("PARALLEL: Before Training in subepoch #" + str(subepoch) + ", submitting the parallel job for extracting Segments for the next Validation.")
                    parallelJobToGetDataForNextValidation = submit_sampling_job(tupleWithParametersForValidation)
                else:  # extract in parallel the samples for the next subepoch's training.
                    log.print3("PARALLEL: Before Training in subepoch #" + str(subepoch) + ", submitting the parallel job for extracting Segments for the next Training.")
                    parallelJobToGetDataForNextTraining = submit_sampling_job(tupleWithParametersForTraining)

                #-------------------------------START TRAINING IN BATCHES------------------------------
                log.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-")
                start_trainingForSubepoch_time = time.time()
                doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(
                    log,
                    sessionTf,
                    "train",
                    numberOfBatchesTraining,
                    cnn3d,
                    subepoch,
                    trainingAccuracyMonitorForEpoch,
                    channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain)
                end_trainingForSubepoch_time = time.time()
                log.print3("TIMING: Training on the batches of this subepoch #" + str(subepoch) + " took time: " + str(end_trainingForSubepoch_time - start_trainingForSubepoch_time) + "(s)")

            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if performValidationOnSamplesDuringTrainingProcessBool:
                validationAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch()
            trainingAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch()

            mean_val_acc_of_ep = validationAccuracyMonitorForEpoch.getMeanEmpiricalAccuracyOfEpoch() \
                if performValidationOnSamplesDuringTrainingProcessBool else None
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep)  # Updates LR schedule if needed, and increases number of epochs trained.
            model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)

            del trainingAccuracyMonitorForEpoch
            del validationAccuracyMonitorForEpoch
            #================== Everything for epoch has finished. =======================

            log.print3("SAVING: Epoch #" + str(epoch) + " finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr()
            saver_all.save(sessionTf, filename_to_save_with + ".model.ckpt", write_meta_graph=False)
            end_epoch_time = time.time()
            log.print3("TIMING: The whole Epoch #" + str(epoch) + " took time: " + str(end_epoch_time - start_epoch_time) + "(s)")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if performFullInferenceOnValidationImagesEveryFewEpochsBool and (model_num_epochs_trained != 0) and \
                    (model_num_epochs_trained % everyThatManyEpochsComputeDiceOnTheFullValidationImages == 0):
                log.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #" + str(epoch) + "...***")
                performInferenceOnWholeVolumes(
                    sessionTf,
                    cnn3d,
                    log,
                    "val",
                    savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation,
                    listOfFilepathsToEachChannelOfEachPatientValidation,
                    providedGtForValidationBool,
                    listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                    providedRoiMaskForValidationBool,
                    listOfFilepathsToRoiMaskOfEachPatientValidation,
                    listOfNamesToGiveToPredictionsIfSavingResults="Placeholder" if not savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation else listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice,
                    #----Preprocessing------
                    padInputImagesBool=padInputImagesBool,
                    #for the cnn extension
                    useSameSubChannelsAsSingleScale=useSameSubChannelsAsSingleScale,
                    listOfFilepathsToEachSubsampledChannelOfEachPatient=listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
                    #--------For FM visualisation---------
                    saveIndividualFmImagesForVisualisation=saveIndividualFmImagesForVisualisation,
                    saveMultidimensionalImageWithAllFms=saveMultidimensionalImageWithAllFms,
                    indicesOfFmsToVisualisePerPathwayTypeAndPerLayer=indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                    listOfNamesToGiveToFmVisualisationsIfSaving=listOfNamesToGiveToFmVisualisationsIfSaving)
    finally:
        # Release the pp worker (and kill the trailing prefetch job submitted in the last subepoch,
        # whose result is never used) instead of leaving it alive after training, even on error.
        job_server.destroy()

    end_training_time = time.time()
    log.print3("TIMING: Training process took time: " + str(end_training_time - start_training_time) + "(s)")
    log.print3("The whole do_training() function has finished.")
def do_training(sessionTf,
                saver_all,
                cnn3d,
                trainer,
                log,
                fileToSaveTrainedCnnModelTo,
                val_on_samples_during_train,
                savePredictedSegmAndProbsDict,
                namesForSavingSegmAndProbs,
                suffixForSegmAndProbsDict,
                listOfFilepathsToEachChannelOfEachPatientTraining,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                listOfFilepathsToGtLabelsOfEachPatientTraining,
                providedGtForValidationBool,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                providedWeightMapsToSampleForEachCategoryTraining,
                paths_to_wmaps_per_sampl_cat_per_subj_train,
                providedWeightMapsToSampleForEachCategoryValidation,
                paths_to_wmaps_per_sampl_cat_per_subj_val,
                providedRoiMaskForTrainingBool,
                listOfFilepathsToRoiMaskOfEachPatientTraining,  # Also needed for normalization-augmentation
                providedRoiMaskForValidationBool,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                n_epochs,  # Every epoch the CNN model is saved.
                num_subepochs,  # per epoch. Every subepoch Accuracy is reported
                max_n_cases_per_subep_train,  # Max num of subjects loaded every subepoch for segments extraction.
                n_samples_per_subep_train,
                n_samples_per_subep_val,
                num_parallel_proc_sampling,  # -1: seq. 0: thread for sampling. >0: multiprocess sampling
                #-------Sampling Type---------
                samplingTypeInstanceTraining,  # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                samplingTypeInstanceValidation,
                batchsize_train,
                batchsize_val_samples,
                batchsize_val_whole,
                #-------Preprocessing-----------
                padInputImagesBool,
                #-------Data Augmentation-------
                augm_params,
                # Validation
                val_on_whole_volumes,
                num_epochs_between_val_on_whole_volumes,
                #--------For FM visualisation---------
                saveIndividualFmImagesForVisualisation,
                saveMultidimensionalImageWithAllFms,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                namesForSavingFms,
                #-------- Others --------
                run_input_checks):
    """Run the epoch/subepoch training loop until the trainer reports n_epochs trained.

    Each subepoch optionally validates on freshly sampled segments (when
    val_on_samples_during_train), then trains on freshly sampled segments.
    If num_parallel_proc_sampling > -1, sampling for the *next* train/val step is
    submitted to a background worker pool so it overlaps with the current step.
    After every epoch the trainer's end-of-epoch updates are run, the whole model
    is saved via saver_all to
    ``fileToSaveTrainedCnnModelTo + "." + <timestamp> + ".model.ckpt"``, and,
    every num_epochs_between_val_on_whole_volumes epochs (if val_on_whole_volumes),
    full-volume inference is run on the validation subjects.

    Returns:
        0 if training completed (a final checkpoint with ".final." in its name is
        saved), 1 if an exception or KeyboardInterrupt was caught (it is logged
        together with its traceback, and the worker pool is terminated).
    """
    id_str = "[MAIN|PID:"+str(os.getpid())+"]"
    start_time_train = time.time()

    # I cannot pass cnn3d to the sampling function, because the pp module used to reload theano.
    # This created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)

    args_for_sampling_train = (log, "train", num_parallel_proc_sampling, run_input_checks,
                               cnn3dWrapper,
                               max_n_cases_per_subep_train,
                               n_samples_per_subep_train,
                               samplingTypeInstanceTraining,
                               listOfFilepathsToEachChannelOfEachPatientTraining,
                               listOfFilepathsToGtLabelsOfEachPatientTraining,
                               providedRoiMaskForTrainingBool,
                               listOfFilepathsToRoiMaskOfEachPatientTraining,
                               providedWeightMapsToSampleForEachCategoryTraining,
                               paths_to_wmaps_per_sampl_cat_per_subj_train,
                               padInputImagesBool,
                               augm_params)
    # NOTE(review): validation reuses max_n_cases_per_subep_train (there is no separate val parameter).
    args_for_sampling_val = (log, "val", num_parallel_proc_sampling, run_input_checks,
                             cnn3dWrapper,
                             max_n_cases_per_subep_train,
                             n_samples_per_subep_val,
                             samplingTypeInstanceValidation,
                             listOfFilepathsToEachChannelOfEachPatientValidation,
                             listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                             providedRoiMaskForValidationBool,
                             listOfFilepathsToRoiMaskOfEachPatientValidation,
                             providedWeightMapsToSampleForEachCategoryValidation,
                             paths_to_wmaps_per_sampl_cat_per_subj_val,
                             padInputImagesBool,
                             None)  # no augmentation in validation.

    sampling_job_submitted_train = False
    sampling_job_submitted_val = False
    # For parallel extraction of samples for next train/val while processing previous iteration.
    # NOTE(review): the parameter comment says ">0: multiprocess sampling", but any value > -1
    # results in the same single-thread ThreadPool here.
    worker_pool = None
    if num_parallel_proc_sampling > -1:  # Use multiprocessing.
        worker_pool = ThreadPool(processes=1)  # Or multiprocessing.Pool(...), same API.

    try:
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while model_num_epochs_trained < n_epochs:
            epoch = model_num_epochs_trained
            # Last epoch this loop runs, assuming run_updates_end_of_ep() advances the counter by one.
            # If that assumption is wrong, the next epoch's subepoch 0 simply samples synchronously.
            is_final_ep = (model_num_epochs_trained + 1) >= n_epochs

            acc_monitor_for_ep_train = AccuracyOfEpochMonitorSegmentation(log, 0, model_num_epochs_trained, cnn3d.num_classes, num_subepochs)
            acc_monitor_for_ep_val = None if not val_on_samples_during_train else \
                AccuracyOfEpochMonitorSegmentation(log, 1, model_num_epochs_trained, cnn3d.num_classes, num_subepochs)

            val_on_whole_volumes_after_ep = False
            if val_on_whole_volumes and (model_num_epochs_trained+1) % num_epochs_between_val_on_whole_volumes == 0:
                val_on_whole_volumes_after_ep = True

            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~\t Starting new Epoch! Epoch #"+str(epoch)+"/"+str(n_epochs)+" \t~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_time_ep = time.time()

            for subepoch in range(num_subepochs):
                log.print3("***************************************************************************************")
                log.print3("*******\t\t Starting new Subepoch: #"+str(subepoch)+"/"+str(num_subepochs)+" \t\t********")
                log.print3("***************************************************************************************")

                #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
                if val_on_samples_during_train:
                    if worker_pool is None:  # Sequential processing.
                        log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [VALIDATION] will be done by main thread.")
                        [channsOfSegmentsForSubepPerPathwayVal, labelsForCentralOfSegmentsForSubepVal] = getSampledDataAndLabelsForSubepoch(*args_for_sampling_val)
                    elif sampling_job_submitted_val:  # It was done in parallel with the training of the previous epoch, just grab results.
                        [channsOfSegmentsForSubepPerPathwayVal, labelsForCentralOfSegmentsForSubepVal] = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False
                    else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                        assert subepoch == 0
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        [channsOfSegmentsForSubepPerPathwayVal, labelsForCentralOfSegmentsForSubepVal] = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False

                    #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                    # Overlaps sampling of this subepoch's training data with the validation below.
                    if worker_pool is not None:
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True

                    #------------------------------------DO VALIDATION--------------------------------
                    log.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-")
                    start_time_val_subep = time.time()
                    # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                    num_batches_val = len(channsOfSegmentsForSubepPerPathwayVal[0]) // batchsize_val_samples
                    trainOrValidateForSubepoch(log,
                                               sessionTf,
                                               "val",
                                               num_batches_val,
                                               batchsize_val_samples,
                                               cnn3d,
                                               subepoch,
                                               acc_monitor_for_ep_val,
                                               channsOfSegmentsForSubepPerPathwayVal,
                                               labelsForCentralOfSegmentsForSubepVal)
                    end_time_val_subep = time.time()
                    log.print3("TIMING: Validation on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_val_subep-start_time_val_subep)+" secs.")

                #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
                if worker_pool is None:  # Sequential processing.
                    log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [TRAINING] will be done by main thread.")
                    [channsOfSegmentsForSubepPerPathwayTrain, labelsForCentralOfSegmentsForSubepTrain] = getSampledDataAndLabelsForSubepoch(*args_for_sampling_train)
                elif sampling_job_submitted_train:  # Sampling job should have been done in parallel with previous train/val. Just grab results.
                    [channsOfSegmentsForSubepPerPathwayTrain, labelsForCentralOfSegmentsForSubepTrain] = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False
                else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                    assert subepoch == 0
                    log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                    parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                    [channsOfSegmentsForSubepPerPathwayTrain, labelsForCentralOfSegmentsForSubepTrain] = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False

                #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
                # Only prefetch if another subepoch follows immediately. After the last subepoch of an
                # epoch that ends with full-volume validation, or of the final epoch, a prefetched job
                # would never be consumed (and worker_pool.join() would wait for it to finish).
                is_last_subep = (subepoch == num_subepochs - 1)
                another_subep_follows = not (is_last_subep and (val_on_whole_volumes_after_ep or is_final_ep))
                if worker_pool is not None and another_subep_follows:
                    if val_on_samples_during_train:
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        sampling_job_submitted_val = True
                    else:
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True

                #-------------------------------START TRAINING IN BATCHES------------------------------
                log.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-")
                start_time_train_subep = time.time()
                # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                num_batches_train = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // batchsize_train
                trainOrValidateForSubepoch(log,
                                           sessionTf,
                                           "train",
                                           num_batches_train,
                                           batchsize_train,
                                           cnn3d,
                                           subepoch,
                                           acc_monitor_for_ep_train,
                                           channsOfSegmentsForSubepPerPathwayTrain,
                                           labelsForCentralOfSegmentsForSubepTrain)
                end_time_train_subep = time.time()
                log.print3("TIMING: Training on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_train_subep-start_time_train_subep)+" secs.")

            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_samples_during_train:
                acc_monitor_for_ep_val.reportMeanAccyracyOfEpoch()
            acc_monitor_for_ep_train.reportMeanAccyracyOfEpoch()

            mean_val_acc_of_ep = acc_monitor_for_ep_val.getMeanEmpiricalAccuracyOfEpoch() if val_on_samples_during_train else None
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep)  # Updates LR schedule if needed, and increases number of epochs trained.
            model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)

            del acc_monitor_for_ep_train
            del acc_monitor_for_ep_val

            log.print3("SAVING: Epoch #"+str(epoch)+" finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr()
            saver_all.save(sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False)

            end_time_ep = time.time()
            log.print3("TIMING: The whole Epoch #"+str(epoch)+" lasted: {0:.1f}".format(end_time_ep-start_time_ep)+" secs.")
            log.print3("~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_whole_volumes_after_ep:
                log.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #"+str(epoch)+"...***")
                inferenceWholeVolumes(sessionTf,
                                      cnn3d,
                                      log,
                                      "val",
                                      savePredictedSegmAndProbsDict,
                                      listOfFilepathsToEachChannelOfEachPatientValidation,
                                      providedGtForValidationBool,
                                      listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                      providedRoiMaskForValidationBool,
                                      listOfFilepathsToRoiMaskOfEachPatientValidation,
                                      namesForSavingSegmAndProbs=namesForSavingSegmAndProbs,
                                      suffixForSegmAndProbsDict=suffixForSegmAndProbsDict,
                                      # Hyper parameters
                                      batchsize=batchsize_val_whole,
                                      #----Preprocessing------
                                      padInputImagesBool=padInputImagesBool,
                                      #--------For FM visualisation---------
                                      saveIndividualFmImagesForVisualisation=saveIndividualFmImagesForVisualisation,
                                      saveMultidimensionalImageWithAllFms=saveMultidimensionalImageWithAllFms,
                                      indicesOfFmsToVisualisePerPathwayTypeAndPerLayer=indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                                      namesForSavingFms=namesForSavingFms)

        end_time_train = time.time()
        log.print3("TIMING: Training process lasted: {0:.1f}".format(end_time_train-start_time_train)+" secs.")

    except (Exception, KeyboardInterrupt) as e:
        log.print3("\n\n ERROR: Caught exception in do_training(): " + str(e) + "\n")
        log.print3(traceback.format_exc())
        if worker_pool is not None:
            log.print3("Terminating worker pool.")
            worker_pool.terminate()
            worker_pool.join()  # Will wait. A KeybInt will kill this (py3)
        return 1
    else:
        if worker_pool is not None:
            log.print3("Closing worker pool.")
            worker_pool.close()
            worker_pool.join()

        # Save the final trained model.
        filename_to_save_with = fileToSaveTrainedCnnModelTo + ".final." + datetimeNowAsStr()
        log.print3("Saving the final model at:" + str(filename_to_save_with))
        saver_all.save(sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False)

    log.print3("The whole do_training() function has finished.")
    return 0
def run_session(self, *args):
    """Build the CNN + trainer graph, load or initialize its variables, and run training.

    args: (sess_device, model_params, reset_trainer)
        sess_device   -- device string used for the explicit device assignment of the
                         CNN graph (the trainer and savers are built without one).
        model_params  -- provides get_args_for_arch() for Cnn3d.make_cnn_model().
        reset_trainer -- when loading from a saved model, re-initialize the trainer's
                         variables instead of restoring them.

    If self._params.get_path_to_load_model_from() is not None, network (and, unless
    reset_trainer, trainer) variables are restored from that checkpoint, or from the
    latest checkpoint if it is a directory. Otherwise all variables are initialized
    and an ".initial." checkpoint is saved. Training is then delegated to do_training().
    """
    (sess_device,
     model_params,
     reset_trainer) = args

    graphTf = tf.Graph()

    with graphTf.as_default():
        with graphTf.device(sess_device):  # Explicit device assignment, throws an error if GPU is specified but not available.
            self._log.print3("=========== Making the CNN graph... ===============")
            cnn3d = Cnn3d()
            with tf.variable_scope("net"):
                cnn3d.make_cnn_model(*model_params.get_args_for_arch())
                # I have now created the CNN graph. But not yet the Optimizer's graph.

        # No explicit device assignment for the rest. Because trained has piecewise_constant that is only on cpu, and so is saver.
        with tf.variable_scope("trainer"):
            self._log.print3("=========== Building Trainer ===========\n")
            trainer = Trainer(*(self._params.get_args_for_trainer() + [cnn3d]))
            trainer.create_optimizer(*self._params.get_args_for_optimizer())  # Trainer and net connect here.

        # The below should not create any new tf.variables.
        self._log.print3("=========== Compiling the Training Function ===========")
        self._log.print3("=======================================================\n")
        cnn3d.setup_ops_n_feeds_to_train(self._log,
                                         trainer.get_total_cost(),
                                         trainer.get_param_updates_wrt_total_cost()  # list of ops
                                         )

        self._log.print3("=========== Compiling the Validation Function =========")
        cnn3d.setup_ops_n_feeds_to_val(self._log)

        self._log.print3("=========== Compiling the Testing Function ============")
        cnn3d.setup_ops_n_feeds_to_test(self._log,
                                        self._params.indices_fms_per_pathtype_per_layer_to_save)  # For validation with full segmentation

        # Create the savers
        saver_all = tf.train.Saver()  # Will be used during training for saving everything.
        collection_vars_net = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net")  # Alternative: tf.train.Saver([v for v in tf.all_variables() if v.name.startswith("net"])
        saver_net = tf.train.Saver(var_list=collection_vars_net)  # Used to load the net's parameters.
        collection_vars_trainer = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="trainer")
        saver_trainer = tf.train.Saver(var_list=collection_vars_trainer)  # Used to load the trainer's parameters.

    # self._print_vars_in_collection(collection_vars_net, "net")
    # self._print_vars_in_collection(collection_vars_trainer, "trainer")

    with tf.Session(graph=graphTf, config=tf.ConfigProto(log_device_placement=False, device_count={'CPU':999, 'GPU':99})) as sessionTf:
        # Load or initialize parameters
        file_to_load_params_from = self._params.get_path_to_load_model_from()
        if file_to_load_params_from is not None:  # Load params
            self._log.print3("=========== Loading parameters from specified saved model ===============")
            chkpt_fname = tf.train.latest_checkpoint(file_to_load_params_from) if os.path.isdir(file_to_load_params_from) else file_to_load_params_from
            self._log.print3("Loading checkpoint file:" + str(chkpt_fname))
            self._log.print3("Loading network parameters...")
            try:
                saver_net.restore(sessionTf, chkpt_fname)
                self._log.print3("Network parameters were loaded.")
            except Exception as e:
                handle_exception_tf_restore(self._log, e)

            if not reset_trainer:
                self._log.print3("Loading trainer parameters...")
                # Same checkpoint, same failure modes as the network restore above,
                # so failures are reported through the same handler.
                try:
                    saver_trainer.restore(sessionTf, chkpt_fname)
                    self._log.print3("Trainer parameters were loaded.")
                except Exception as e:
                    handle_exception_tf_restore(self._log, e)
            else:
                self._log.print3("Reset of trainer parameters was requested. Re-initializing them...")
                tf.variables_initializer(var_list=collection_vars_trainer).run()
                self._log.print3("Trainer parameters re-initialized.")
        else:
            self._log.print3("=========== Initializing network and trainer variables ===============")
            # tf.variables_initializer(var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) ).run() # Initializes all.
            # Initialize separate as below, so that in case I miss a variable, I will get an error and I will know.
            tf.variables_initializer(var_list=collection_vars_net).run()
            tf.variables_initializer(var_list=collection_vars_trainer).run()
            self._log.print3("All variables were initialized.")

            filename_to_save_with = self._params.filepath_to_save_models + ".initial." + datetimeNowAsStr()
            self._log.print3("Saving the initial model at:" + str(filename_to_save_with))
            saver_all.save(sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False)
            # tf.train.write_graph( graph_or_graph_def=sessionTf.graph.as_graph_def(), logdir="", name=filename_to_save_with+".graph.pb", as_text=False)

        self._log.print3("")
        self._log.print3("=======================================================")
        self._log.print3("============== Training the CNN model =================")
        self._log.print3("=======================================================\n")

        # NOTE(review): do_training returns 0 on success and 1 on a caught exception; the code is currently unused here.
        res_code = do_training(*([sessionTf, saver_all, cnn3d, trainer] + self._params.get_args_for_train_routine()))

    self._log.print3("\n=======================================================")
    self._log.print3("=========== Training session finished =================")
    self._log.print3("=======================================================")