def run_session(self, *args):
    """Build the CNN testing graph, obtain its weights, then segment whole volumes.

    ``args`` must hold exactly two items, ``(sess_device, model_params)``:

    * ``sess_device`` -- device spec handed to ``tf.Graph.device`` while the
      network graph is built.
    * ``model_params`` -- object whose ``get_args_for_arch()`` result is splatted
      into ``Cnn3d.make_cnn_model``.

    When ``self._params.get_path_to_load_model_from()`` is None, the user is
    consulted through ``self._ask_user_if_test_with_random()`` and the global
    variables of the "net" scope are then initialised. Otherwise the weights are
    restored with a ``tf.train.Saver``: a directory path is resolved to its latest
    checkpoint, and restore failures are handed to ``handle_exception_tf_restore``.
    Inference is delegated to ``inferenceWholeVolumes`` with the session, the
    model and ``self._params.get_args_for_testing()``.
    """
    def emit(*messages):
        # Forward every message, in order, to the session logger.
        for message in messages:
            self._log.print3(message)

    sess_device, model_params = args

    graph = tf.Graph()
    with graph.as_default():
        # Original note: throws an error if a GPU is specified but not available
        # (the session below does not enable soft placement).
        with graph.device(sess_device):
            emit("=========== Making the CNN graph... ===============")
            cnn3d = Cnn3d()
            with tf.variable_scope("net"):
                # Only the network itself is built here; no optimizer for testing.
                arch_args = model_params.get_args_for_arch()
                cnn3d.make_cnn_model(*arch_args)

        emit("=========== Compiling the Testing Function ============",
             "=======================================================\n")
        cnn3d.setup_ops_n_feeds_to_test(self._log,
                                        self._params.indices_fms_per_pathtype_per_layer_to_save)
        saver = tf.train.Saver()  # A saver restricted to the "net" scope would suffice.

    session_config = tf.ConfigProto(log_device_placement=False,
                                    device_count={'CPU': 999, 'GPU': 99})
    with tf.Session(graph=graph, config=session_config) as session:
        load_path = self._params.get_path_to_load_model_from()
        if load_path is None:
            # Per the original comment, this exits unless the user agrees to
            # test a randomly initialised model.
            self._ask_user_if_test_with_random()
            emit("", "=========== Initializing network variables ===============")
            net_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net")
            tf.variables_initializer(var_list=net_variables).run()
            emit("Model variables were initialized.")
        else:
            emit("=========== Loading parameters from specified saved model ===============")
            if os.path.isdir(load_path):
                checkpoint = tf.train.latest_checkpoint(load_path)
            else:
                checkpoint = load_path
            emit("Loading parameters from:" + str(checkpoint))
            try:
                saver.restore(session, checkpoint)
                emit("Parameters were loaded.")
            except Exception as err:
                handle_exception_tf_restore(self._log, err)

        emit("",
             "======================================================",
             "=========== Testing with the CNN model ===============",
             "======================================================\n")
        inferenceWholeVolumes(*([session, cnn3d] + self._params.get_args_for_testing()))

    emit("",
         "======================================================",
         "=========== Testing session finished =================",
         "======================================================")
def do_training(sessionTf,
                saver_all,
                cnn3d,
                trainer,
                log,
                fileToSaveTrainedCnnModelTo,
                val_on_samples_during_train,
                savePredictedSegmAndProbsDict,
                namesForSavingSegmAndProbs,
                suffixForSegmAndProbsDict,
                listOfFilepathsToEachChannelOfEachPatientTraining,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                listOfFilepathsToGtLabelsOfEachPatientTraining,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                paths_to_wmaps_per_sampl_cat_per_subj_train,
                paths_to_wmaps_per_sampl_cat_per_subj_val,
                listOfFilepathsToRoiMaskOfEachPatientTraining,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                n_epochs, # Every epoch the CNN model is saved.
                num_subepochs, # per epoch. Every subepoch Accuracy is reported
                max_n_cases_per_subep_train, # Max num of subjects loaded every subepoch for segments extraction.
                n_samples_per_subep_train,
                n_samples_per_subep_val,
                num_parallel_proc_sampling, # -1: seq. 0: thread for sampling. >0: multiprocess sampling
                #-------Sampling Type---------
                samplingTypeInstanceTraining, # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                samplingTypeInstanceValidation,
                batchsize_train,
                batchsize_val_samples,
                batchsize_val_whole,
                #-------Preprocessing-----------
                pad_input_imgs,
                #-------Data Augmentation-------
                augm_img_prms,
                augm_sample_prms,
                # Validation
                val_on_whole_volumes,
                num_epochs_between_val_on_whole_volumes,
                #--------For FM visualisation---------
                saveIndividualFmImagesForVisualisation,
                saveMultidimensionalImageWithAllFms,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                namesForSavingFms,
                #-------- Others --------
                run_input_checks
                ):
    """Train ``cnn3d`` for the remaining epochs, validating and checkpointing along the way.

    Training resumes from the epoch count read from
    ``trainer.get_num_epochs_trained_tfv()`` and loops until that count reaches
    ``n_epochs``. Each epoch is split into ``num_subepochs`` subepochs. In every
    subepoch:

    1. If ``val_on_samples_during_train`` is set, a set of validation samples is
       obtained and evaluated with ``trainOrValidateForSubepoch(..., "val", ...)``.
    2. A set of training samples is obtained and trained on with
       ``trainOrValidateForSubepoch(..., "train", ...)``.

    Samples come from ``getSampledDataAndLabelsForSubepoch`` called with
    ``args_for_sampling_train`` / ``args_for_sampling_val`` (built below).
    When ``num_parallel_proc_sampling > -1``, a one-thread ``ThreadPool`` prefetches
    the next set of samples while the current batches are processed. Otherwise
    sampling runs on the calling thread.

    At the end of each epoch:

    * the per-epoch accuracy monitors report;
    * ``trainer.run_updates_end_of_ep`` is called with the mean validation
      accuracy (or None when not validating on samples);
    * the model is saved to
      ``fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr() + ".model.ckpt"``.

    If ``val_on_whole_volumes`` is set, after every
    ``num_epochs_between_val_on_whole_volumes``-th epoch the validation subjects
    are also segmented with ``inferenceWholeVolumes``.

    Returns:
        0 on success. In that case the worker pool (if any) has been closed and
        joined, and the final model has been saved to
        ``fileToSaveTrainedCnnModelTo + ".final." + datetimeNowAsStr() + ".model.ckpt"``.

        1 if an ``Exception`` or ``KeyboardInterrupt`` escaped the training loop.
        The error and traceback are logged, the worker pool (if any) is
        terminated, and no final model is saved.
    """
    id_str = "[MAIN|PID:"+str(os.getpid())+"]"
    start_time_train = time.time()

    # I cannot pass cnn3d to the sampling function, because the pp module used to reload theano.
    # This created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)

    # Positional argument tuples for getSampledDataAndLabelsForSubepoch; kept
    # ready so they can be passed either directly or to worker_pool.apply_async.
    args_for_sampling_train = ( log, "train", num_parallel_proc_sampling, run_input_checks,
                                cnn3dWrapper, max_n_cases_per_subep_train, n_samples_per_subep_train, samplingTypeInstanceTraining,
                                listOfFilepathsToEachChannelOfEachPatientTraining,
                                listOfFilepathsToGtLabelsOfEachPatientTraining,
                                listOfFilepathsToRoiMaskOfEachPatientTraining,
                                paths_to_wmaps_per_sampl_cat_per_subj_train,
                                pad_input_imgs,
                                augm_img_prms,
                                augm_sample_prms )
    # NOTE(review): validation sampling is capped by max_n_cases_per_subep_train,
    # as there is no separate validation limit parameter -- confirm this is intended.
    args_for_sampling_val = ( log, "val", num_parallel_proc_sampling, run_input_checks,
                              cnn3dWrapper, max_n_cases_per_subep_train, n_samples_per_subep_val, samplingTypeInstanceValidation,
                              listOfFilepathsToEachChannelOfEachPatientValidation,
                              listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                              listOfFilepathsToRoiMaskOfEachPatientValidation,
                              paths_to_wmaps_per_sampl_cat_per_subj_val,
                              pad_input_imgs,
                              None, # no augmentation in val.
                              None ) # no augmentation in val.

    # Flags recording whether a prefetch job for the next train/val sample set
    # is pending on the pool (its AsyncResult is in parallelJobToGetDataForNext*).
    sampling_job_submitted_train = False
    sampling_job_submitted_val = False
    # For parallel extraction of samples for next train/val while processing previous iteration.
    worker_pool = None
    if num_parallel_proc_sampling > -1 : # Use multiprocessing.
        worker_pool = ThreadPool(processes=1) # Or multiprocessing.Pool(...), same API.

    try:
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while model_num_epochs_trained < n_epochs :
            epoch = model_num_epochs_trained

            # Fresh accuracy monitors for this epoch (second arg 0 = train, 1 = val).
            acc_monitor_for_ep_train = AccuracyOfEpochMonitorSegmentation(log, 0, model_num_epochs_trained, cnn3d.num_classes, num_subepochs)
            acc_monitor_for_ep_val = None if not val_on_samples_during_train else \
                                        AccuracyOfEpochMonitorSegmentation(log, 1, model_num_epochs_trained, cnn3d.num_classes, num_subepochs )

            # Whole-volume validation runs after every N-th epoch (epochs counted from 1 here).
            val_on_whole_volumes_after_ep = False
            if val_on_whole_volumes and (model_num_epochs_trained+1) % num_epochs_between_val_on_whole_volumes == 0:
                val_on_whole_volumes_after_ep = True

            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~\t Starting new Epoch! Epoch #"+str(epoch)+"/"+str(n_epochs)+" \t~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_time_ep = time.time()

            for subepoch in range(num_subepochs):
                log.print3("***************************************************************************************")
                log.print3("*******\t\t Starting new Subepoch: #"+str(subepoch)+"/"+str(num_subepochs)+" \t\t********")
                log.print3("***************************************************************************************")

                #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
                if val_on_samples_during_train :
                    if worker_pool is None: # Sequential processing.
                        log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [VALIDATION] will be done by main thread.")
                        (channsOfSegmentsForSubepPerPathwayVal, labelsForCentralOfSegmentsForSubepVal) = getSampledDataAndLabelsForSubepoch( *args_for_sampling_val )
                    elif sampling_job_submitted_val : #It was done in parallel with the training of the previous epoch, just grab results.
                        # .get() blocks until the prefetch job has finished.
                        (channsOfSegmentsForSubepPerPathwayVal, labelsForCentralOfSegmentsForSubepVal) = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False
                    else : # Not previously submitted in case of first epoch or after a full-volumes validation.
                        # Invariant check only; `assert` is skipped under `python -O`.
                        assert subepoch == 0
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        (channsOfSegmentsForSubepPerPathwayVal, labelsForCentralOfSegmentsForSubepVal) = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False

                    #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                    # Training samples for this subepoch are prefetched while validation runs below.
                    if worker_pool is not None:
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True

                    #------------------------------------DO VALIDATION--------------------------------
                    log.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-")
                    start_time_val_subep = time.time()
                    # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                    num_batches_val = len(channsOfSegmentsForSubepPerPathwayVal[0]) // batchsize_val_samples
                    trainOrValidateForSubepoch( log, sessionTf, "val", num_batches_val, batchsize_val_samples, cnn3d, subepoch, acc_monitor_for_ep_val, channsOfSegmentsForSubepPerPathwayVal, labelsForCentralOfSegmentsForSubepVal )
                    end_time_val_subep = time.time()
                    log.print3("TIMING: Validation on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_val_subep-start_time_val_subep)+" secs.")

                #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
                if worker_pool is None: # Sequential processing.
                    log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [TRAINING] will be done by main thread.")
                    (channsOfSegmentsForSubepPerPathwayTrain, labelsForCentralOfSegmentsForSubepTrain) = getSampledDataAndLabelsForSubepoch( *args_for_sampling_train )
                elif sampling_job_submitted_train: # Sampling job should have been done in parallel with previous train/val. Just grab results.
                    (channsOfSegmentsForSubepPerPathwayTrain, labelsForCentralOfSegmentsForSubepTrain) = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False
                else: # Not previously submitted in case of first epoch or after a full-volumes validation.
                    # Invariant check only; `assert` is skipped under `python -O`.
                    assert subepoch == 0
                    log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                    parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                    (channsOfSegmentsForSubepPerPathwayTrain, labelsForCentralOfSegmentsForSubepTrain) = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False

                #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
                # Skipped only before a whole-volume validation, so nothing is sampled
                # concurrently with it; the next epoch then samples in its first subepoch.
                # NOTE(review): after the final subepoch of the final epoch (when no
                # whole-volume validation follows) a job is still submitted but never
                # consumed; the success path's close()+join() waits for it to finish.
                if worker_pool is not None and not (val_on_whole_volumes_after_ep and (subepoch == num_subepochs-1)):
                    if val_on_samples_during_train :
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        sampling_job_submitted_val = True
                    else :
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True

                #-------------------------------START TRAINING IN BATCHES------------------------------
                log.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-")
                start_time_train_subep = time.time()
                # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                num_batches_train = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // batchsize_train
                trainOrValidateForSubepoch( log, sessionTf, "train", num_batches_train, batchsize_train, cnn3d, subepoch, acc_monitor_for_ep_train, channsOfSegmentsForSubepPerPathwayTrain, labelsForCentralOfSegmentsForSubepTrain )
                end_time_train_subep = time.time()
                log.print3("TIMING: Training on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_train_subep-start_time_train_subep)+" secs.")

            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~" )
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_samples_during_train:
                acc_monitor_for_ep_val.reportMeanAccyracyOfEpoch()
            acc_monitor_for_ep_train.reportMeanAccyracyOfEpoch()

            mean_val_acc_of_ep = acc_monitor_for_ep_val.getMeanEmpiricalAccuracyOfEpoch() if val_on_samples_during_train else None
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep) # Updates LR schedule if needed, and increases number of epochs trained.
            # Re-read the counter: this is what advances (and eventually ends) the while loop.
            model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)

            # Drop this epoch's monitors; new ones are built at the start of the next epoch.
            del acc_monitor_for_ep_train; del acc_monitor_for_ep_val;

            log.print3("SAVING: Epoch #"+str(epoch)+" finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr()
            saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )

            end_time_ep = time.time()
            log.print3("TIMING: The whole Epoch #"+str(epoch)+" lasted: {0:.1f}".format(end_time_ep-start_time_ep)+" secs.")
            log.print3("~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_whole_volumes_after_ep:
                log.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #"+str(epoch)+"...***")

                # Return code is not inspected here.
                res_code = inferenceWholeVolumes(
                                        sessionTf,
                                        cnn3d,
                                        log,
                                        "val",
                                        savePredictedSegmAndProbsDict,
                                        listOfFilepathsToEachChannelOfEachPatientValidation,
                                        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                        listOfFilepathsToRoiMaskOfEachPatientValidation,
                                        namesForSavingSegmAndProbs = namesForSavingSegmAndProbs,
                                        suffixForSegmAndProbsDict = suffixForSegmAndProbsDict,
                                        # Hyper parameters
                                        batchsize = batchsize_val_whole,
                                        #----Preprocessing------
                                        pad_input_imgs = pad_input_imgs,
                                        #--------For FM visualisation---------
                                        saveIndividualFmImagesForVisualisation = saveIndividualFmImagesForVisualisation,
                                        saveMultidimensionalImageWithAllFms = saveMultidimensionalImageWithAllFms,
                                        indicesOfFmsToVisualisePerPathwayTypeAndPerLayer = indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                                        namesForSavingFms = namesForSavingFms
                                        )

        end_time_train = time.time()
        log.print3("TIMING: Training process lasted: {0:.1f}".format(end_time_train-start_time_train)+" secs.")

    # KeyboardInterrupt is caught too, so that an interrupted run still tears down the pool.
    except (Exception, KeyboardInterrupt) as e:
        log.print3("\n\n ERROR: Caught exception in do_training(): " + str(e) + "\n")
        log.print3( traceback.format_exc() )
        if worker_pool is not None:
            log.print3("Terminating worker pool.")
            # terminate() stops workers without finishing pending prefetch jobs.
            worker_pool.terminate()
            worker_pool.join() # Will wait. A KeybInt will kill this (py3)
        return 1
    else:
        if worker_pool is not None:
            log.print3("Closing worker pool.")
            # close() refuses new jobs; join() waits for outstanding ones to complete.
            worker_pool.close()
            worker_pool.join()

        # Save the final trained model.
        filename_to_save_with = fileToSaveTrainedCnnModelTo + ".final." + datetimeNowAsStr()
        log.print3("Saving the final model at:" + str(filename_to_save_with))
        saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )

    log.print3("The whole do_training() function has finished.")
    return 0
def do_training(sessionTf,
                saver_all,
                cnn3d,
                trainer,
                log,
                fileToSaveTrainedCnnModelTo,
                val_on_samples_during_train,
                savePredictedSegmAndProbsDict,
                namesForSavingSegmAndProbs,
                suffixForSegmAndProbsDict,
                listOfFilepathsToEachChannelOfEachPatientTraining,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                listOfFilepathsToGtLabelsOfEachPatientTraining,
                providedGtForValidationBool,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                providedWeightMapsToSampleForEachCategoryTraining,
                paths_to_wmaps_per_sampl_cat_per_subj_train,
                providedWeightMapsToSampleForEachCategoryValidation,
                paths_to_wmaps_per_sampl_cat_per_subj_val,
                providedRoiMaskForTrainingBool,
                listOfFilepathsToRoiMaskOfEachPatientTraining, # Also needed for normalization-augmentation
                providedRoiMaskForValidationBool,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                n_epochs, # Every epoch the CNN model is saved.
                num_subepochs, # per epoch. Every subepoch Accuracy is reported
                max_n_cases_per_subep_train, # Max num of subjects loaded every subepoch for segments extraction.
                n_samples_per_subep_train,
                n_samples_per_subep_val,
                num_parallel_proc_sampling, # -1: seq. 0: thread for sampling. >0: multiprocess sampling
                #-------Sampling Type---------
                samplingTypeInstanceTraining, # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                samplingTypeInstanceValidation,
                batchsize_train,
                batchsize_val_samples,
                batchsize_val_whole,
                #-------Preprocessing-----------
                padInputImagesBool,
                #-------Data Augmentation-------
                augm_params,
                # Validation
                val_on_whole_volumes,
                num_epochs_between_val_on_whole_volumes,
                #--------For FM visualisation---------
                saveIndividualFmImagesForVisualisation,
                saveMultidimensionalImageWithAllFms,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                namesForSavingFms,
                #-------- Others --------
                run_input_checks
                ):
    """Train ``cnn3d`` for the remaining epochs, validating and checkpointing along the way.

    Training resumes from the epoch count read from
    ``trainer.get_num_epochs_trained_tfv()`` and loops until that count reaches
    ``n_epochs``. Each epoch is split into ``num_subepochs`` subepochs. In every
    subepoch:

    1. If ``val_on_samples_during_train`` is set, a set of validation samples is
       obtained and evaluated with ``trainOrValidateForSubepoch(..., "val", ...)``.
    2. A set of training samples is obtained and trained on with
       ``trainOrValidateForSubepoch(..., "train", ...)``.

    Samples come from ``getSampledDataAndLabelsForSubepoch`` called with
    ``args_for_sampling_train`` / ``args_for_sampling_val`` (built below). Those
    tuples forward the ``provided*`` flags next to their corresponding path lists.
    When ``num_parallel_proc_sampling > -1``, a one-thread ``ThreadPool`` prefetches
    the next set of samples while the current batches are processed. Otherwise
    sampling runs on the calling thread.

    At the end of each epoch:

    * the per-epoch accuracy monitors report;
    * ``trainer.run_updates_end_of_ep`` is called with the mean validation
      accuracy (or None when not validating on samples);
    * the model is saved to
      ``fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr() + ".model.ckpt"``.

    If ``val_on_whole_volumes`` is set, after every
    ``num_epochs_between_val_on_whole_volumes``-th epoch the validation subjects
    are also segmented with ``inferenceWholeVolumes``.

    Returns:
        0 on success. In that case the worker pool (if any) has been closed and
        joined, and the final model has been saved to
        ``fileToSaveTrainedCnnModelTo + ".final." + datetimeNowAsStr() + ".model.ckpt"``.

        1 if an ``Exception`` or ``KeyboardInterrupt`` escaped the training loop.
        The error and traceback are logged, the worker pool (if any) is
        terminated, and no final model is saved.
    """
    id_str = "[MAIN|PID:"+str(os.getpid())+"]"
    start_time_train = time.time()

    # I cannot pass cnn3d to the sampling function, because the pp module used to reload theano.
    # This created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)

    # Positional argument tuples for getSampledDataAndLabelsForSubepoch; kept
    # ready so they can be passed either directly or to worker_pool.apply_async.
    args_for_sampling_train = ( log, "train", num_parallel_proc_sampling, run_input_checks,
                                cnn3dWrapper, max_n_cases_per_subep_train, n_samples_per_subep_train, samplingTypeInstanceTraining,
                                listOfFilepathsToEachChannelOfEachPatientTraining,
                                listOfFilepathsToGtLabelsOfEachPatientTraining,
                                providedRoiMaskForTrainingBool,
                                listOfFilepathsToRoiMaskOfEachPatientTraining,
                                providedWeightMapsToSampleForEachCategoryTraining,
                                paths_to_wmaps_per_sampl_cat_per_subj_train,
                                padInputImagesBool,
                                augm_params )
    # NOTE(review): validation sampling is capped by max_n_cases_per_subep_train,
    # as there is no separate validation limit parameter -- confirm this is intended.
    args_for_sampling_val = ( log, "val", num_parallel_proc_sampling, run_input_checks,
                              cnn3dWrapper, max_n_cases_per_subep_train, n_samples_per_subep_val, samplingTypeInstanceValidation,
                              listOfFilepathsToEachChannelOfEachPatientValidation,
                              listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                              providedRoiMaskForValidationBool,
                              listOfFilepathsToRoiMaskOfEachPatientValidation,
                              providedWeightMapsToSampleForEachCategoryValidation,
                              paths_to_wmaps_per_sampl_cat_per_subj_val,
                              padInputImagesBool,
                              None ) # no augmentation in validation.

    # Flags recording whether a prefetch job for the next train/val sample set
    # is pending on the pool (its AsyncResult is in parallelJobToGetDataForNext*).
    sampling_job_submitted_train = False
    sampling_job_submitted_val = False
    # For parallel extraction of samples for next train/val while processing previous iteration.
    worker_pool = None
    if num_parallel_proc_sampling > -1 : # Use multiprocessing.
        worker_pool = ThreadPool(processes=1) # Or multiprocessing.Pool(...), same API.

    try:
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while model_num_epochs_trained < n_epochs :
            epoch = model_num_epochs_trained

            # Fresh accuracy monitors for this epoch (second arg 0 = train, 1 = val).
            acc_monitor_for_ep_train = AccuracyOfEpochMonitorSegmentation(log, 0, model_num_epochs_trained, cnn3d.num_classes, num_subepochs)
            acc_monitor_for_ep_val = None if not val_on_samples_during_train else \
                                        AccuracyOfEpochMonitorSegmentation(log, 1, model_num_epochs_trained, cnn3d.num_classes, num_subepochs )

            # Whole-volume validation runs after every N-th epoch (epochs counted from 1 here).
            val_on_whole_volumes_after_ep = False
            if val_on_whole_volumes and (model_num_epochs_trained+1) % num_epochs_between_val_on_whole_volumes == 0:
                val_on_whole_volumes_after_ep = True

            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~\t Starting new Epoch! Epoch #"+str(epoch)+"/"+str(n_epochs)+" \t~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_time_ep = time.time()

            for subepoch in range(num_subepochs):
                log.print3("***************************************************************************************")
                log.print3("*******\t\t Starting new Subepoch: #"+str(subepoch)+"/"+str(num_subepochs)+" \t\t********")
                log.print3("***************************************************************************************")

                #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
                if val_on_samples_during_train :
                    if worker_pool is None: # Sequential processing.
                        log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [VALIDATION] will be done by main thread.")
                        [channsOfSegmentsForSubepPerPathwayVal, labelsForCentralOfSegmentsForSubepVal] = getSampledDataAndLabelsForSubepoch( *args_for_sampling_val )
                    elif sampling_job_submitted_val : #It was done in parallel with the training of the previous epoch, just grab results.
                        # .get() blocks until the prefetch job has finished.
                        [channsOfSegmentsForSubepPerPathwayVal, labelsForCentralOfSegmentsForSubepVal] = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False
                    else : # Not previously submitted in case of first epoch or after a full-volumes validation.
                        # Invariant check only; `assert` is skipped under `python -O`.
                        assert subepoch == 0
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        [channsOfSegmentsForSubepPerPathwayVal, labelsForCentralOfSegmentsForSubepVal] = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False

                    #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                    # Training samples for this subepoch are prefetched while validation runs below.
                    if worker_pool is not None:
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True

                    #------------------------------------DO VALIDATION--------------------------------
                    log.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-")
                    start_time_val_subep = time.time()
                    # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                    num_batches_val = len(channsOfSegmentsForSubepPerPathwayVal[0]) // batchsize_val_samples
                    trainOrValidateForSubepoch( log, sessionTf, "val", num_batches_val, batchsize_val_samples, cnn3d, subepoch, acc_monitor_for_ep_val, channsOfSegmentsForSubepPerPathwayVal, labelsForCentralOfSegmentsForSubepVal )
                    end_time_val_subep = time.time()
                    log.print3("TIMING: Validation on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_val_subep-start_time_val_subep)+" secs.")

                #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
                if worker_pool is None: # Sequential processing.
                    log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [TRAINING] will be done by main thread.")
                    [channsOfSegmentsForSubepPerPathwayTrain, labelsForCentralOfSegmentsForSubepTrain] = getSampledDataAndLabelsForSubepoch( *args_for_sampling_train )
                elif sampling_job_submitted_train: # Sampling job should have been done in parallel with previous train/val. Just grab results.
                    [channsOfSegmentsForSubepPerPathwayTrain, labelsForCentralOfSegmentsForSubepTrain] = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False
                else: # Not previously submitted in case of first epoch or after a full-volumes validation.
                    # Invariant check only; `assert` is skipped under `python -O`.
                    assert subepoch == 0
                    log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                    parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                    [channsOfSegmentsForSubepPerPathwayTrain, labelsForCentralOfSegmentsForSubepTrain] = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False

                #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
                # Skipped only before a whole-volume validation, so nothing is sampled
                # concurrently with it; the next epoch then samples in its first subepoch.
                # NOTE(review): after the final subepoch of the final epoch (when no
                # whole-volume validation follows) a job is still submitted but never
                # consumed; the success path's close()+join() waits for it to finish.
                if worker_pool is not None and not (val_on_whole_volumes_after_ep and (subepoch == num_subepochs-1)):
                    if val_on_samples_during_train :
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        sampling_job_submitted_val = True
                    else :
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True

                #-------------------------------START TRAINING IN BATCHES------------------------------
                log.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-")
                start_time_train_subep = time.time()
                # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                num_batches_train = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // batchsize_train
                trainOrValidateForSubepoch( log, sessionTf, "train", num_batches_train, batchsize_train, cnn3d, subepoch, acc_monitor_for_ep_train, channsOfSegmentsForSubepPerPathwayTrain, labelsForCentralOfSegmentsForSubepTrain )
                end_time_train_subep = time.time()
                log.print3("TIMING: Training on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_train_subep-start_time_train_subep)+" secs.")

            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~" )
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_samples_during_train:
                acc_monitor_for_ep_val.reportMeanAccyracyOfEpoch()
            acc_monitor_for_ep_train.reportMeanAccyracyOfEpoch()

            mean_val_acc_of_ep = acc_monitor_for_ep_val.getMeanEmpiricalAccuracyOfEpoch() if val_on_samples_during_train else None
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep) # Updates LR schedule if needed, and increases number of epochs trained.
            # Re-read the counter: this is what advances (and eventually ends) the while loop.
            model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)

            # Drop this epoch's monitors; new ones are built at the start of the next epoch.
            del acc_monitor_for_ep_train; del acc_monitor_for_ep_val;

            log.print3("SAVING: Epoch #"+str(epoch)+" finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr()
            saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )

            end_time_ep = time.time()
            log.print3("TIMING: The whole Epoch #"+str(epoch)+" lasted: {0:.1f}".format(end_time_ep-start_time_ep)+" secs.")
            log.print3("~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_whole_volumes_after_ep:
                log.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #"+str(epoch)+"...***")

                # Return code is not inspected here.
                res_code = inferenceWholeVolumes(
                                        sessionTf,
                                        cnn3d,
                                        log,
                                        "val",
                                        savePredictedSegmAndProbsDict,
                                        listOfFilepathsToEachChannelOfEachPatientValidation,
                                        providedGtForValidationBool,
                                        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                        providedRoiMaskForValidationBool,
                                        listOfFilepathsToRoiMaskOfEachPatientValidation,
                                        namesForSavingSegmAndProbs = namesForSavingSegmAndProbs,
                                        suffixForSegmAndProbsDict = suffixForSegmAndProbsDict,
                                        # Hyper parameters
                                        batchsize = batchsize_val_whole,
                                        #----Preprocessing------
                                        padInputImagesBool=padInputImagesBool,
                                        #--------For FM visualisation---------
                                        saveIndividualFmImagesForVisualisation=saveIndividualFmImagesForVisualisation,
                                        saveMultidimensionalImageWithAllFms=saveMultidimensionalImageWithAllFms,
                                        indicesOfFmsToVisualisePerPathwayTypeAndPerLayer=indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                                        namesForSavingFms=namesForSavingFms
                                        )

        end_time_train = time.time()
        log.print3("TIMING: Training process lasted: {0:.1f}".format(end_time_train-start_time_train)+" secs.")

    # KeyboardInterrupt is caught too, so that an interrupted run still tears down the pool.
    except (Exception, KeyboardInterrupt) as e:
        log.print3("\n\n ERROR: Caught exception in do_training(): " + str(e) + "\n")
        log.print3( traceback.format_exc() )
        if worker_pool is not None:
            log.print3("Terminating worker pool.")
            # terminate() stops workers without finishing pending prefetch jobs.
            worker_pool.terminate()
            worker_pool.join() # Will wait. A KeybInt will kill this (py3)
        return 1
    else:
        if worker_pool is not None:
            log.print3("Closing worker pool.")
            # close() refuses new jobs; join() waits for outstanding ones to complete.
            worker_pool.close()
            worker_pool.join()

        # Save the final trained model.
        filename_to_save_with = fileToSaveTrainedCnnModelTo + ".final." + datetimeNowAsStr()
        log.print3("Saving the final model at:" + str(filename_to_save_with))
        saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )

    log.print3("The whole do_training() function has finished.")
    return 0
def run_session(self, *args):
    """Build the CNN testing graph, load or initialise its weights and run inference.

    Args:
        *args: exactly two positional values, ``(sess_device, model_params)``.
            ``sess_device`` is the device spec used with ``tf.Graph.device`` while
            the network graph is built. ``model_params.get_args_for_arch()``
            supplies the arguments of ``Cnn3d.make_cnn_model``.

    If ``self._params.get_path_to_load_model_from()`` is not None, the weights are
    restored from it. A directory is resolved to its latest checkpoint, and
    restore errors are passed to ``handle_exception_tf_restore``. Otherwise the
    user is consulted through ``self._ask_user_if_test_with_random()``, and the
    global variables in the "net" scope are then initialised. Whole-volume
    inference is delegated to ``inferenceWholeVolumes`` with the session, the
    model and ``self._params.get_args_for_testing()``.
    """
    (sess_device, model_params,) = args

    graphTf = tf.Graph()
    with graphTf.as_default():
        # NOTE: the session below is created with allow_soft_placement=True.
        # TensorFlow may therefore place ops on the CPU (e.g. when the requested
        # GPU is not available or an op has no GPU kernel) instead of raising.
        with graphTf.device(sess_device):
            self._log.print3("=========== Making the CNN graph... ===============")
            cnn3d = Cnn3d()
            with tf.variable_scope("net"):
                # Creates the network's graph (without optimizer).
                cnn3d.make_cnn_model(*model_params.get_args_for_arch())

        self._log.print3("=========== Compiling the Testing Function ============")
        self._log.print3("=======================================================\n")
        cnn3d.setup_ops_n_feeds_to_test(
            self._log, self._params.indices_fms_per_pathtype_per_layer_to_save)
        # Create the saver
        saver_all = tf.train.Saver()  # saver_net would suffice

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False,
                                    device_count={'CPU': 999, 'GPU': 99})
    with tf.Session(graph=graphTf, config=session_config) as sessionTf:
        file_to_load_params_from = self._params.get_path_to_load_model_from()
        if file_to_load_params_from is not None:  # Load params
            self._log.print3(
                "=========== Loading parameters from specified saved model ===============")
            # A directory is resolved to the most recent checkpoint inside it.
            chkpt_fname = tf.train.latest_checkpoint(
                file_to_load_params_from) if os.path.isdir(
                    file_to_load_params_from) else file_to_load_params_from
            self._log.print3("Loading parameters from:" + str(chkpt_fname))
            try:
                saver_all.restore(sessionTf, chkpt_fname)
                self._log.print3("Parameters were loaded.")
            except Exception as e:
                handle_exception_tf_restore(self._log, e)
        else:
            # Asks user whether to continue with randomly initialized model. It exits if no is given.
            self._ask_user_if_test_with_random()
            self._log.print3("")
            self._log.print3("=========== Initializing network variables ===============")
            # Only the variables created under the "net" scope are initialised.
            tf.variables_initializer(var_list=tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES, scope="net")).run()
            self._log.print3("Model variables were initialized.")

        self._log.print3("")
        self._log.print3("======================================================")
        self._log.print3("=========== Testing with the CNN model ===============")
        self._log.print3("======================================================\n")
        # The return code of inferenceWholeVolumes is not used here.
        inferenceWholeVolumes(
            *([sessionTf, cnn3d] + self._params.get_args_for_testing()))

    self._log.print3("")
    self._log.print3("======================================================")
    self._log.print3("=========== Testing session finished =================")
    self._log.print3("======================================================")