def do_training(sessionTf,
                saver_all,
                cnn3d,
                trainer,
                tensorboard_loggers,
                log,
                fileToSaveTrainedCnnModelTo,
                val_on_samples,
                savePredictedSegmAndProbsDict,
                namesForSavingSegmAndProbs,
                suffixForSegmAndProbsDict,
                paths_per_chan_per_subj_train,
                paths_per_chan_per_subj_val,
                paths_to_lbls_per_subj_train,
                paths_to_lbls_per_subj_val,
                paths_to_wmaps_per_sampl_cat_per_subj_train,
                paths_to_wmaps_per_sampl_cat_per_subj_val,
                paths_to_masks_per_subj_train,
                paths_to_masks_per_subj_val,
                n_epochs,  # Every epoch the CNN model is saved.
                n_subepochs,  # per epoch. Every subepoch Accuracy is reported
                max_n_cases_per_subep_train,  # Max num of subjects loaded every subep for sampling
                n_samples_per_subep_train,
                n_samples_per_subep_val,
                num_parallel_proc_sampling,  # -1: seq. 0: thread for sampling. >0: multiprocess sampling
                # -------Sampling Type---------
                sampling_type_inst_tr,  # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                sampling_type_inst_val,
                batchsize_train,
                batchsize_val_samples,
                batchsize_val_whole,
                # -------Data Augmentation-------
                augm_img_prms,
                augm_sample_prms,
                # Validation
                val_on_whole_volumes,
                n_epochs_between_val_on_whole_vols,
                # --------For FM visualisation---------
                save_fms_flag,
                idxs_fms_to_save,
                namesForSavingFms,
                # --- Data Compatibility Checks ---
                run_input_checks,
                # -------- Pre-processing ------
                pad_input,
                norm_prms
                ):
    """Train `cnn3d` until the trainer's epochs-trained counter reaches `n_epochs`.

    Each epoch consists of `n_subepochs` subepochs. In every subepoch:
      1. If `val_on_samples`, validation samples are extracted with
         `get_samples_for_subepoch` and the model is evaluated on them in
         batches of `batchsize_val_samples` (via `process_in_batches`).
      2. Training samples are extracted and the model is trained on them in
         batches of `batchsize_train`.
    At the end of each epoch, the accuracy monitors report their metrics.
    `trainer.run_updates_end_of_ep` is then called with the epoch's mean
    validation accuracy (None if `val_on_samples` is False), and a checkpoint
    is saved to `fileToSaveTrainedCnnModelTo + "." + <timestamp> + ".model.ckpt"`.
    If `val_on_whole_volumes` is set, every `n_epochs_between_val_on_whole_vols`
    epochs the whole validation subjects are segmented with
    `inference_on_whole_volumes` and the resulting metrics are reported.

    When `num_parallel_proc_sampling > -1`, a single-worker ThreadPool extracts
    the samples for the next validation/training step while the current one
    runs. The value of `num_parallel_proc_sampling` is also forwarded to
    `get_samples_for_subepoch`.

    Args (selection):
        tensorboard_loggers: None, or a mapping with keys 'train' and 'val'
            whose values are passed to the per-epoch accuracy monitors.
        n_epochs: Training stops once the trainer's epoch counter reaches it.
        n_subepochs: Number of sample/validate/train rounds per epoch.

    Returns:
        0 if training finished and the final model (".final." checkpoint) was
        saved. 1 if an exception (including KeyboardInterrupt) was caught. In
        that case the error and traceback are logged, the worker pool is
        terminated, and no final model is saved.
    """
    id_str = "[MAIN|PID:" + str(os.getpid()) + "]"
    start_time_train = time.time()

    # I cannot pass cnn3d to the sampling function, because the pp module used to reload theano.
    # This created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)

    args_for_sampling_tr = (log, "train", num_parallel_proc_sampling, run_input_checks, cnn3dWrapper,
                            max_n_cases_per_subep_train, n_samples_per_subep_train, sampling_type_inst_tr,
                            paths_per_chan_per_subj_train,
                            paths_to_lbls_per_subj_train,
                            paths_to_masks_per_subj_train,
                            paths_to_wmaps_per_sampl_cat_per_subj_train,
                            pad_input,
                            norm_prms,
                            augm_img_prms,
                            augm_sample_prms)
    args_for_sampling_val = (log, "val", num_parallel_proc_sampling, run_input_checks, cnn3dWrapper,
                             max_n_cases_per_subep_train, n_samples_per_subep_val, sampling_type_inst_val,
                             paths_per_chan_per_subj_val,
                             paths_to_lbls_per_subj_val,
                             paths_to_masks_per_subj_val,
                             paths_to_wmaps_per_sampl_cat_per_subj_val,
                             pad_input,
                             norm_prms,
                             None,  # no augmentation in val.
                             None)  # no augmentation in val.

    sampling_job_submitted_train = False
    sampling_job_submitted_val = False
    # For parallel extraction of samples for next train/val while processing previous iteration.
    mp_pool = None
    if num_parallel_proc_sampling > -1:  # Use multiprocessing.
        mp_pool = ThreadPool(processes=1)  # Or multiprocessing.Pool(...), same API.

    try:
        n_eps_trained_model = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while n_eps_trained_model < n_epochs:
            epoch = n_eps_trained_model

            tb_log_tr = tensorboard_loggers['train'] if tensorboard_loggers is not None else None
            acc_monitor_ep_tr = AccuracyMonitorForEpSegm(log, 0, n_eps_trained_model, cnn3d.num_classes,
                                                         n_subepochs, tb_log_tr)
            tb_log_val = tensorboard_loggers['val'] if tensorboard_loggers is not None else None
            acc_monitor_ep_val = None
            if val_on_samples or val_on_whole_volumes:
                acc_monitor_ep_val = AccuracyMonitorForEpSegm(log, 1, n_eps_trained_model, cnn3d.num_classes,
                                                              n_subepochs, tb_log_val)

            val_on_whole_vols_after_this_ep = False
            if val_on_whole_volumes and (n_eps_trained_model + 1) % n_epochs_between_val_on_whole_vols == 0:
                val_on_whole_vols_after_this_ep = True

            # run_updates_end_of_ep() increases the epochs-trained counter at the end of this epoch, so
            # when this holds the while-loop ends after this epoch. (Should the loop continue anyway, the
            # next subepoch #0 falls back to the synchronous "not previously submitted" sampling path.)
            is_last_ep_of_training = epoch + 1 >= n_epochs

            log.print3("")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~\t\t\t Starting new Epoch! Epoch #" + str(epoch) + "/" + str(n_epochs) + " \t\t\t~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_time_ep = time.time()

            for subep in range(n_subepochs):
                log.print3("")
                log.print3("***********************************************************************************")
                log.print3("*\t\t\t Starting new Subepoch: #" + str(subep) + "/" + str(n_subepochs) + " \t\t\t*")
                log.print3("***********************************************************************************")

                # -------------------- GET DATA FOR THIS SUBEPOCH's VALIDATION -----------------------
                if val_on_samples:
                    if mp_pool is None:  # Sequential processing.
                        log.print3(id_str + " NO MULTIPROC: Sampling for subepoch #" + str(subep) +
                                   " [VALIDATION] will be done by main thread.")
                        (channs_samples_per_path_val,
                         lbls_samples_per_path_val) = get_samples_for_subepoch(*args_for_sampling_val)
                    elif sampling_job_submitted_val:  # done parallel with training of previous epoch.
                        (channs_samples_per_path_val, lbls_samples_per_path_val) = sampling_job_val.get()
                        sampling_job_submitted_val = False
                    else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                        assert subep == 0
                        log.print3(id_str + " MULTIPROC: Before Validation in subepoch #" + str(subep) +
                                   ", submitting sampling job for next [VALIDATION].")
                        sampling_job_val = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_val)
                        (channs_samples_per_path_val, lbls_samples_per_path_val) = sampling_job_val.get()
                        sampling_job_submitted_val = False

                    # ----------- SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING -----------------
                    if mp_pool is not None:
                        log.print3(id_str + " MULTIPROC: Before Validation in subepoch #" + str(subep) +
                                   ", submitting sampling job for next [TRAINING].")
                        sampling_job_tr = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_tr)
                        sampling_job_submitted_train = True

                    # ------------------------------------DO VALIDATION--------------------------------
                    log.print3("V-V-V-V- Validating for subepoch before starting training iterations -V-V-V-V")
                    start_time_val_subep = time.time()
                    # Calc num of batches from extracted samples, in case not extracted as much as requested.
                    n_batches_val = len(channs_samples_per_path_val[0]) // batchsize_val_samples
                    process_in_batches(log, sessionTf, "val", n_batches_val, batchsize_val_samples, cnn3d,
                                       acc_monitor_ep_val, channs_samples_per_path_val, lbls_samples_per_path_val)
                    log.print3("TIMING: Validation on batches of subepoch #" + str(subep) +
                               " lasted: {0:.1f}".format(time.time() - start_time_val_subep) + " secs.")

                # ----------------------- GET DATA FOR THIS SUBEPOCH's TRAINING ------------------------------
                if mp_pool is None:  # Sequential processing.
                    log.print3(id_str + " NO MULTIPROC: Sampling for subepoch #" + str(subep) +
                               " [TRAINING] will be done by main thread.")
                    (channs_samples_per_path_tr,
                     lbls_samples_per_path_tr) = get_samples_for_subepoch(*args_for_sampling_tr)
                elif sampling_job_submitted_train:  # done parallel with train/val of previous epoch.
                    (channs_samples_per_path_tr, lbls_samples_per_path_tr) = sampling_job_tr.get()
                    sampling_job_submitted_train = False
                else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                    assert subep == 0
                    log.print3(id_str + " MULTIPROC: Before Training in subepoch #" + str(subep) +
                               ", submitting sampling job for next [TRAINING].")
                    sampling_job_tr = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_tr)
                    (channs_samples_per_path_tr, lbls_samples_per_path_tr) = sampling_job_tr.get()
                    sampling_job_submitted_train = False

                # ----- SUBMIT PARALLEL JOB TO GET VAL / TRAIN (if no val) DATA FOR NEXT SUBEPOCH -----
                # After the last subepoch of an epoch there is no next subepoch to prepare for if
                # (a) whole-volume validation follows (sampling restarts synchronously next epoch), or
                # (b) this is the final epoch of training. In case (b) a submitted job would never be
                # consumed, and mp_pool.close()/join() at the end would block until it finished.
                is_last_subep = (subep == n_subepochs - 1)
                no_next_sampling = is_last_subep and (val_on_whole_vols_after_this_ep or is_last_ep_of_training)
                if mp_pool is not None and not no_next_sampling:
                    if val_on_samples:
                        log.print3(id_str + " MULTIPROC: Before Training in subepoch #" + str(subep) +
                                   ", submitting sampling job for next [VALIDATION].")
                        sampling_job_val = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_val)
                        sampling_job_submitted_val = True
                    else:
                        log.print3(id_str + " MULTIPROC: Before Training in subepoch #" + str(subep) +
                                   ", submitting sampling job for next [TRAINING].")
                        sampling_job_tr = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_tr)
                        sampling_job_submitted_train = True

                # ------------------------------ START TRAINING IN BATCHES -----------------------------
                log.print3("-T-T-T-T- Training for this subepoch... May take a few minutes... -T-T-T-T-")
                start_time_train_subep = time.time()
                # Calc num of batches from extracted samples, in case not extracted as much as requested.
                n_batches_train = len(channs_samples_per_path_tr[0]) // batchsize_train
                process_in_batches(log, sessionTf, "train", n_batches_train, batchsize_train, cnn3d,
                                   acc_monitor_ep_tr, channs_samples_per_path_tr, lbls_samples_per_path_tr)
                log.print3("TIMING: Training on batches of this subepoch #" + str(subep) +
                           " lasted: {0:.1f}".format(time.time() - start_time_train_subep) + " secs.")

            log.print3("")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            if val_on_samples:
                acc_monitor_ep_val.report_metrics_samples_ep()
            acc_monitor_ep_tr.report_metrics_samples_ep()

            mean_val_acc_of_ep = acc_monitor_ep_val.get_avg_accuracy_ep() if val_on_samples else None
            # Updates LR schedule if needed, and increases number of epochs trained.
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep)
            n_eps_trained_model = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)

            log.print3("SAVING: Epoch #" + str(epoch) + " finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetime_now_str()
            saver_all.save(sessionTf, filename_to_save_with + ".model.ckpt", write_meta_graph=False)

            log.print3("TIMING: Whole Epoch #" + str(epoch) +
                       " lasted: {0:.1f}".format(time.time() - start_time_ep) + " secs.")
            log.print3("~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_whole_vols_after_this_ep:
                log.print3("***Start validation by segmenting whole subjects for Epoch #" + str(epoch) + "***")
                mean_metrics_val_whole_vols = inference_on_whole_volumes(
                    sessionTf,
                    cnn3d,
                    log,
                    "val",
                    savePredictedSegmAndProbsDict,
                    paths_per_chan_per_subj_val,
                    paths_to_lbls_per_subj_val,
                    paths_to_masks_per_subj_val,
                    namesForSavingSegmAndProbs,
                    suffixForSegmAndProbsDict,
                    # Hyper parameters
                    batchsize_val_whole,
                    # Data compatibility checks
                    run_input_checks,
                    # Pre-Processing
                    pad_input,
                    norm_prms,
                    # Saving feature maps
                    save_fms_flag,
                    idxs_fms_to_save,
                    namesForSavingFms)
                acc_monitor_ep_val.report_metrics_whole_vols(mean_metrics_val_whole_vols)

            del acc_monitor_ep_tr
            del acc_monitor_ep_val

        log.print3("TIMING: Training process lasted: {0:.1f}".format(time.time() - start_time_train) + " secs.")

    except (Exception, KeyboardInterrupt) as e:
        log.print3("\n\n ERROR: Caught exception in do_training(): " + str(e) + "\n")
        log.print3(traceback.format_exc())
        if mp_pool is not None:
            log.print3("Terminating worker pool.")
            mp_pool.terminate()
            mp_pool.join()  # Will wait. A KeybInt will kill this (py3)
        return 1
    else:
        if mp_pool is not None:
            log.print3("Closing worker pool.")
            mp_pool.close()
            mp_pool.join()

        # Save the final trained model.
        filename_to_save_with = fileToSaveTrainedCnnModelTo + ".final." + datetime_now_str()
        log.print3("Saving the final model at:" + str(filename_to_save_with))
        saver_all.save(sessionTf, filename_to_save_with + ".model.ckpt", write_meta_graph=False)

        log.print3("The whole do_training() function has finished.")
        return 0
def do_training(sessionTf,
                saver_all,
                cnn3d,
                trainer,
                log,
                fileToSaveTrainedCnnModelTo,
                val_on_samples_during_train,
                savePredictedSegmAndProbsDict,
                namesForSavingSegmAndProbs,
                suffixForSegmAndProbsDict,
                listOfFilepathsToEachChannelOfEachPatientTraining,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                listOfFilepathsToGtLabelsOfEachPatientTraining,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                paths_to_wmaps_per_sampl_cat_per_subj_train,
                paths_to_wmaps_per_sampl_cat_per_subj_val,
                listOfFilepathsToRoiMaskOfEachPatientTraining,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                n_epochs,  # Every epoch the CNN model is saved.
                num_subepochs,  # per epoch. Every subepoch Accuracy is reported
                max_n_cases_per_subep_train,  # Max num of subjects loaded every subepoch for segments extraction.
                n_samples_per_subep_train,
                n_samples_per_subep_val,
                num_parallel_proc_sampling,  # -1: seq. 0: thread for sampling. >0: multiprocess sampling
                #-------Sampling Type---------
                samplingTypeInstanceTraining,  # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                samplingTypeInstanceValidation,
                batchsize_train,
                batchsize_val_samples,
                batchsize_val_whole,
                #-------Preprocessing-----------
                pad_input_imgs,
                #-------Data Augmentation-------
                augm_img_prms,
                augm_sample_prms,
                # Validation
                val_on_whole_volumes,
                num_epochs_between_val_on_whole_volumes,
                #--------For FM visualisation---------
                saveIndividualFmImagesForVisualisation,
                saveMultidimensionalImageWithAllFms,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                namesForSavingFms,
                #-------- Others --------
                run_input_checks
                ):
    """Train `cnn3d` until the trainer's epochs-trained counter reaches `n_epochs`.

    Each epoch consists of `num_subepochs` subepochs. In every subepoch:
      1. If `val_on_samples_during_train`, validation samples are extracted with
         `getSampledDataAndLabelsForSubepoch` and the model is evaluated on them
         in batches of `batchsize_val_samples` (via `trainOrValidateForSubepoch`).
      2. Training samples are extracted and the model is trained on them in
         batches of `batchsize_train`.
    At the end of each epoch, the accuracy monitors report their mean accuracy.
    `trainer.run_updates_end_of_ep` is then called with the epoch's mean
    validation accuracy (None if validation on samples is off), and a checkpoint
    is saved to `fileToSaveTrainedCnnModelTo + "." + <timestamp> + ".model.ckpt"`.
    If `val_on_whole_volumes` is set, every
    `num_epochs_between_val_on_whole_volumes` epochs the validation subjects are
    fully segmented with `inferenceWholeVolumes`.

    When `num_parallel_proc_sampling > -1`, a single-worker ThreadPool extracts
    the samples for the next validation/training step while the current one
    runs. The value of `num_parallel_proc_sampling` is also forwarded to
    `getSampledDataAndLabelsForSubepoch`.

    Returns:
        0 if training finished and the final model (".final." checkpoint) was
        saved. 1 if an exception (including KeyboardInterrupt) was caught. In
        that case the error and traceback are logged, the worker pool is
        terminated, and no final model is saved.
    """
    id_str = "[MAIN|PID:" + str(os.getpid()) + "]"
    start_time_train = time.time()

    # I cannot pass cnn3d to the sampling function, because the pp module used to reload theano.
    # This created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)

    args_for_sampling_train = (log, "train", num_parallel_proc_sampling, run_input_checks, cnn3dWrapper,
                               max_n_cases_per_subep_train, n_samples_per_subep_train,
                               samplingTypeInstanceTraining,
                               listOfFilepathsToEachChannelOfEachPatientTraining,
                               listOfFilepathsToGtLabelsOfEachPatientTraining,
                               listOfFilepathsToRoiMaskOfEachPatientTraining,
                               paths_to_wmaps_per_sampl_cat_per_subj_train,
                               pad_input_imgs,
                               augm_img_prms,
                               augm_sample_prms)
    args_for_sampling_val = (log, "val", num_parallel_proc_sampling, run_input_checks, cnn3dWrapper,
                             max_n_cases_per_subep_train, n_samples_per_subep_val,
                             samplingTypeInstanceValidation,
                             listOfFilepathsToEachChannelOfEachPatientValidation,
                             listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                             listOfFilepathsToRoiMaskOfEachPatientValidation,
                             paths_to_wmaps_per_sampl_cat_per_subj_val,
                             pad_input_imgs,
                             None,  # no augmentation in val.
                             None)  # no augmentation in val.

    sampling_job_submitted_train = False
    sampling_job_submitted_val = False
    # For parallel extraction of samples for next train/val while processing previous iteration.
    worker_pool = None
    if num_parallel_proc_sampling > -1:  # Use multiprocessing.
        worker_pool = ThreadPool(processes=1)  # Or multiprocessing.Pool(...), same API.

    try:
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while model_num_epochs_trained < n_epochs:
            epoch = model_num_epochs_trained

            acc_monitor_for_ep_train = AccuracyOfEpochMonitorSegmentation(log, 0, model_num_epochs_trained,
                                                                          cnn3d.num_classes, num_subepochs)
            acc_monitor_for_ep_val = None
            if val_on_samples_during_train:
                acc_monitor_for_ep_val = AccuracyOfEpochMonitorSegmentation(log, 1, model_num_epochs_trained,
                                                                            cnn3d.num_classes, num_subepochs)

            val_on_whole_volumes_after_ep = False
            if val_on_whole_volumes and (model_num_epochs_trained + 1) % num_epochs_between_val_on_whole_volumes == 0:
                val_on_whole_volumes_after_ep = True

            # run_updates_end_of_ep() increases the epochs-trained counter at the end of this epoch, so
            # when this holds the while-loop ends after this epoch. (Should the loop continue anyway, the
            # next subepoch #0 falls back to the synchronous "not previously submitted" sampling path.)
            is_last_ep_of_training = epoch + 1 >= n_epochs

            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~\t Starting new Epoch! Epoch #" + str(epoch) + "/" + str(n_epochs) + " \t~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_time_ep = time.time()

            for subepoch in range(num_subepochs):
                log.print3("***************************************************************************************")
                log.print3("*******\t\t Starting new Subepoch: #" + str(subepoch) + "/" + str(num_subepochs) + " \t\t********")
                log.print3("***************************************************************************************")

                #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
                if val_on_samples_during_train:
                    if worker_pool is None:  # Sequential processing.
                        log.print3(id_str + " NO MULTIPROC: Sampling for subepoch #" + str(subepoch) +
                                   " [VALIDATION] will be done by main thread.")
                        (channsOfSegmentsForSubepPerPathwayVal,
                         labelsForCentralOfSegmentsForSubepVal) = getSampledDataAndLabelsForSubepoch(*args_for_sampling_val)
                    elif sampling_job_submitted_val:  # It was done in parallel with the training of the previous epoch, just grab results.
                        (channsOfSegmentsForSubepPerPathwayVal,
                         labelsForCentralOfSegmentsForSubepVal) = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False
                    else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                        assert subepoch == 0
                        log.print3(id_str + " MULTIPROC: Before Validation in subepoch #" + str(subepoch) +
                                   ", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch,
                                                                                        args_for_sampling_val)
                        (channsOfSegmentsForSubepPerPathwayVal,
                         labelsForCentralOfSegmentsForSubepVal) = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False

                    #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                    if worker_pool is not None:
                        log.print3(id_str + " MULTIPROC: Before Validation in subepoch #" + str(subepoch) +
                                   ", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch,
                                                                                      args_for_sampling_train)
                        sampling_job_submitted_train = True

                    #------------------------------------DO VALIDATION--------------------------------
                    log.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-")
                    start_time_val_subep = time.time()
                    # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                    num_batches_val = len(channsOfSegmentsForSubepPerPathwayVal[0]) // batchsize_val_samples
                    trainOrValidateForSubepoch(log, sessionTf, "val", num_batches_val, batchsize_val_samples,
                                               cnn3d, subepoch, acc_monitor_for_ep_val,
                                               channsOfSegmentsForSubepPerPathwayVal,
                                               labelsForCentralOfSegmentsForSubepVal)
                    end_time_val_subep = time.time()
                    log.print3("TIMING: Validation on batches of this subepoch #" + str(subepoch) +
                               " lasted: {0:.1f}".format(end_time_val_subep - start_time_val_subep) + " secs.")

                #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
                if worker_pool is None:  # Sequential processing.
                    log.print3(id_str + " NO MULTIPROC: Sampling for subepoch #" + str(subepoch) +
                               " [TRAINING] will be done by main thread.")
                    (channsOfSegmentsForSubepPerPathwayTrain,
                     labelsForCentralOfSegmentsForSubepTrain) = getSampledDataAndLabelsForSubepoch(*args_for_sampling_train)
                elif sampling_job_submitted_train:  # Sampling job should have been done in parallel with previous train/val. Just grab results.
                    (channsOfSegmentsForSubepPerPathwayTrain,
                     labelsForCentralOfSegmentsForSubepTrain) = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False
                else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                    assert subepoch == 0
                    log.print3(id_str + " MULTIPROC: Before Training in subepoch #" + str(subepoch) +
                               ", submitting sampling job for next [TRAINING].")
                    parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch,
                                                                                  args_for_sampling_train)
                    (channsOfSegmentsForSubepPerPathwayTrain,
                     labelsForCentralOfSegmentsForSubepTrain) = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False

                #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
                # After the last subepoch of an epoch there is no next subepoch to prepare for if
                # (a) whole-volume validation follows (sampling restarts synchronously next epoch), or
                # (b) this is the final epoch of training. In case (b) a submitted job would never be
                # consumed, and worker_pool.close()/join() at the end would block until it finished.
                is_last_subep = (subepoch == num_subepochs - 1)
                no_next_sampling = is_last_subep and (val_on_whole_volumes_after_ep or is_last_ep_of_training)
                if worker_pool is not None and not no_next_sampling:
                    if val_on_samples_during_train:
                        log.print3(id_str + " MULTIPROC: Before Training in subepoch #" + str(subepoch) +
                                   ", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch,
                                                                                        args_for_sampling_val)
                        sampling_job_submitted_val = True
                    else:
                        log.print3(id_str + " MULTIPROC: Before Training in subepoch #" + str(subepoch) +
                                   ", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch,
                                                                                      args_for_sampling_train)
                        sampling_job_submitted_train = True

                #-------------------------------START TRAINING IN BATCHES------------------------------
                log.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-")
                start_time_train_subep = time.time()
                # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                num_batches_train = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // batchsize_train
                trainOrValidateForSubepoch(log, sessionTf, "train", num_batches_train, batchsize_train,
                                           cnn3d, subepoch, acc_monitor_for_ep_train,
                                           channsOfSegmentsForSubepPerPathwayTrain,
                                           labelsForCentralOfSegmentsForSubepTrain)
                end_time_train_subep = time.time()
                log.print3("TIMING: Training on batches of this subepoch #" + str(subepoch) +
                           " lasted: {0:.1f}".format(end_time_train_subep - start_time_train_subep) + " secs.")

            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_samples_during_train:
                acc_monitor_for_ep_val.reportMeanAccyracyOfEpoch()
            acc_monitor_for_ep_train.reportMeanAccyracyOfEpoch()

            mean_val_acc_of_ep = acc_monitor_for_ep_val.getMeanEmpiricalAccuracyOfEpoch() if val_on_samples_during_train else None
            # Updates LR schedule if needed, and increases number of epochs trained.
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep)
            model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)

            del acc_monitor_for_ep_train
            del acc_monitor_for_ep_val

            log.print3("SAVING: Epoch #" + str(epoch) + " finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr()
            saver_all.save(sessionTf, filename_to_save_with + ".model.ckpt", write_meta_graph=False)

            end_time_ep = time.time()
            log.print3("TIMING: The whole Epoch #" + str(epoch) +
                       " lasted: {0:.1f}".format(end_time_ep - start_time_ep) + " secs.")
            log.print3("~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_whole_volumes_after_ep:
                log.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #" +
                           str(epoch) + "...***")
                # The return value of inferenceWholeVolumes() was never used; only its side effects matter here.
                inferenceWholeVolumes(sessionTf,
                                      cnn3d,
                                      log,
                                      "val",
                                      savePredictedSegmAndProbsDict,
                                      listOfFilepathsToEachChannelOfEachPatientValidation,
                                      listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                      listOfFilepathsToRoiMaskOfEachPatientValidation,
                                      namesForSavingSegmAndProbs=namesForSavingSegmAndProbs,
                                      suffixForSegmAndProbsDict=suffixForSegmAndProbsDict,
                                      # Hyper parameters
                                      batchsize=batchsize_val_whole,
                                      #----Preprocessing------
                                      pad_input_imgs=pad_input_imgs,
                                      #--------For FM visualisation---------
                                      saveIndividualFmImagesForVisualisation=saveIndividualFmImagesForVisualisation,
                                      saveMultidimensionalImageWithAllFms=saveMultidimensionalImageWithAllFms,
                                      indicesOfFmsToVisualisePerPathwayTypeAndPerLayer=indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                                      namesForSavingFms=namesForSavingFms)

        end_time_train = time.time()
        log.print3("TIMING: Training process lasted: {0:.1f}".format(end_time_train - start_time_train) + " secs.")

    except (Exception, KeyboardInterrupt) as e:
        log.print3("\n\n ERROR: Caught exception in do_training(): " + str(e) + "\n")
        log.print3(traceback.format_exc())
        if worker_pool is not None:
            log.print3("Terminating worker pool.")
            worker_pool.terminate()
            worker_pool.join()  # Will wait. A KeybInt will kill this (py3)
        return 1
    else:
        if worker_pool is not None:
            log.print3("Closing worker pool.")
            worker_pool.close()
            worker_pool.join()

        # Save the final trained model.
        filename_to_save_with = fileToSaveTrainedCnnModelTo + ".final." + datetimeNowAsStr()
        log.print3("Saving the final model at:" + str(filename_to_save_with))
        saver_all.save(sessionTf, filename_to_save_with + ".model.ckpt", write_meta_graph=False)

        log.print3("The whole do_training() function has finished.")
        return 0
def do_training(myLogger, fileToSaveTrainedCnnModelTo, cnn3dInstance, performValidationOnSamplesDuringTrainingProcessBool, #REQUIRED FOR AUTO SCHEDULE. savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation, listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice, listOfFilepathsToEachChannelOfEachPatientTraining, listOfFilepathsToEachChannelOfEachPatientValidation, listOfFilepathsToGtLabelsOfEachPatientTraining, providedGtForValidationBool, listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc, providedWeightMapsToSampleForEachCategoryTraining, forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining, providedWeightMapsToSampleForEachCategoryValidation, forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation, providedRoiMaskForTrainingBool, listOfFilepathsToRoiMaskOfEachPatientTraining, # Also needed for normalization-augmentation providedRoiMaskForValidationBool, listOfFilepathsToRoiMaskOfEachPatientValidation, borrowFlag, n_epochs, # Every epoch the CNN model is saved. number_of_subepochs, # per epoch. Every subepoch Accuracy is reported maxNumSubjectsLoadedPerSubepoch, # Max num of cases loaded every subepoch for segments extraction. The more, the longer loading. 
imagePartsLoadedInGpuPerSubepoch, imagePartsLoadedInGpuPerSubepochValidation, #-------Sampling Type--------- samplingTypeInstanceTraining, # Instance of the deepmedic/samplingType.SamplingType class for training and validation samplingTypeInstanceValidation, #-------Preprocessing----------- padInputImagesBool, smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage, #-------Data Augmentation------- normAugmNone0OnImages1OrSegments2AlreadyNormalized1SubtrUpToPropOfStdAndDivideWithUpToPerc, reflectImageWithHalfProbDuringTraining, useSameSubChannelsAsSingleScale, listOfFilepathsToEachSubsampledChannelOfEachPatientTraining, # deprecated, not supported listOfFilepathsToEachSubsampledChannelOfEachPatientValidation, # deprecated, not supported #Learning Rate Schedule: lowerLrByStable0orAuto1orPredefined2orExponential3Schedule, minIncreaseInValidationAccuracyConsideredForLrSchedule, numEpochsToWaitBeforeLowerLR, divideLrBy, lowerLrAtTheEndOfTheseEpochsPredefinedScheduleList, exponentialScheduleForLrAndMom, #Weighting Classes differently in the CNN's cost function during training: numberOfEpochsToWeightTheClassesInTheCostFunction, performFullInferenceOnValidationImagesEveryFewEpochsBool, #Even if not providedGtForValidationBool, inference will be performed if this == True, to save the results, eg for visual. everyThatManyEpochsComputeDiceOnTheFullValidationImages=1, # Should not be == 0, except if performFullInferenceOnValidationImagesEveryFewEpochsBool == False #--------For FM visualisation--------- saveIndividualFmImagesForVisualisation=False, saveMultidimensionalImageWithAllFms=False, indicesOfFmsToVisualisePerPathwayTypeAndPerLayer="placeholder", listOfNamesToGiveToFmVisualisationsIfSaving="placeholder" ): start_training_time = time.clock() # Used because I cannot pass cnn3dInstance to the sampling function. #This is because the parallel process then loads theano again. And creates problems in the GPU when cnmem is used. 
cnn3dWrapper = CnnWrapperForSampling(cnn3dInstance) #---------To run PARALLEL the extraction of parts for the next subepoch--- ppservers = () # tuple of all parallel python servers to connect with job_server = pp.Server(ncpus=1, ppservers=ppservers) # Creates jobserver with automatically detected number of workers tupleWithParametersForTraining = (myLogger, 0, cnn3dWrapper, maxNumSubjectsLoadedPerSubepoch, imagePartsLoadedInGpuPerSubepoch, samplingTypeInstanceTraining, listOfFilepathsToEachChannelOfEachPatientTraining, listOfFilepathsToGtLabelsOfEachPatientTraining, providedRoiMaskForTrainingBool, listOfFilepathsToRoiMaskOfEachPatientTraining, providedWeightMapsToSampleForEachCategoryTraining, forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining, useSameSubChannelsAsSingleScale, listOfFilepathsToEachSubsampledChannelOfEachPatientTraining, padInputImagesBool, smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage, normAugmNone0OnImages1OrSegments2AlreadyNormalized1SubtrUpToPropOfStdAndDivideWithUpToPerc, reflectImageWithHalfProbDuringTraining ) tupleWithParametersForValidation = (myLogger, 1, cnn3dWrapper, maxNumSubjectsLoadedPerSubepoch, imagePartsLoadedInGpuPerSubepochValidation, samplingTypeInstanceValidation, listOfFilepathsToEachChannelOfEachPatientValidation, listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc, providedRoiMaskForValidationBool, listOfFilepathsToRoiMaskOfEachPatientValidation, providedWeightMapsToSampleForEachCategoryValidation, forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation, useSameSubChannelsAsSingleScale, listOfFilepathsToEachSubsampledChannelOfEachPatientValidation, padInputImagesBool, smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage, [0, -1,-1,-1], #don't perform intensity-augmentation during validation. [0,0,0] #don't perform reflection-augmentation during validation. 
) tupleWithLocalFunctionsThatWillBeCalledByTheMainJob = ( ) tupleWithModulesToImportWhichAreUsedByTheJobFunctions = ( "from __future__ import absolute_import, print_function, division", "from six.moves import xrange", "time", "numpy as np", "from deepmedic.dataManagement.sampling import *" ) boolItIsTheVeryFirstSubepochOfThisProcess = True #to know so that in the very first I sequencially load the data for it. #------End for parallel------ while cnn3dInstance.numberOfEpochsTrained < n_epochs : epoch = cnn3dInstance.numberOfEpochsTrained trainingAccuracyMonitorForEpoch = AccuracyOfEpochMonitorSegmentation(myLogger, 0, cnn3dInstance.numberOfEpochsTrained, cnn3dInstance.numberOfOutputClasses, number_of_subepochs) validationAccuracyMonitorForEpoch = None if not performValidationOnSamplesDuringTrainingProcessBool else \ AccuracyOfEpochMonitorSegmentation(myLogger, 1, cnn3dInstance.numberOfEpochsTrained, cnn3dInstance.numberOfOutputClasses, number_of_subepochs ) myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") myLogger.print3("~~~~~~~~~~~~~~~~~~~~Starting new Epoch! Epoch #"+str(epoch)+"/"+str(n_epochs)+" ~~~~~~~~~~~~~~~~~~~~~~~~~") myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") start_epoch_time = time.clock() for subepoch in xrange(number_of_subepochs): #per subepoch I randomly load some images in the gpu. Random order. 
myLogger.print3("**************************************************************************************************") myLogger.print3("************* Starting new Subepoch: #"+str(subepoch)+"/"+str(number_of_subepochs)+" *************") myLogger.print3("**************************************************************************************************") #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION--------------------------------- if performValidationOnSamplesDuringTrainingProcessBool : if boolItIsTheVeryFirstSubepochOfThisProcess : [channsOfSegmentsForSubepPerPathwayVal, labelsForCentralOfSegmentsForSubepVal] = getSampledDataAndLabelsForSubepoch(myLogger, 1, cnn3dWrapper, maxNumSubjectsLoadedPerSubepoch, imagePartsLoadedInGpuPerSubepochValidation, samplingTypeInstanceValidation, listOfFilepathsToEachChannelOfEachPatientValidation, listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc, providedRoiMaskForValidationBool, listOfFilepathsToRoiMaskOfEachPatientValidation, providedWeightMapsToSampleForEachCategoryValidation, forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation, useSameSubChannelsAsSingleScale, listOfFilepathsToEachSubsampledChannelOfEachPatientValidation, padInputImagesBool, smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage, normAugmNone0OnImages1OrSegments2AlreadyNormalized1SubtrUpToPropOfStdAndDivideWithUpToPerc=[0,-1,-1,-1], reflectImageWithHalfProbDuringTraining = [0,0,0] ) boolItIsTheVeryFirstSubepochOfThisProcess = False else : #It was done in parallel with the training of the previous epoch, just grab the results... [channsOfSegmentsForSubepPerPathwayVal, labelsForCentralOfSegmentsForSubepVal] = parallelJobToGetDataForNextValidation() #fromParallelProcessing that had started from last loop when it was submitted. 
#------------------------------LOAD DATA FOR VALIDATION---------------------- myLogger.print3("Loading Validation data for subepoch #"+str(subepoch)+" on shared variable...") start_loadingToGpu_time = time.clock() numberOfBatchesValidation = len(channsOfSegmentsForSubepPerPathwayVal[0]) // cnn3dInstance.batchSizeValidation #Computed with number of extracted samples, in case I dont manage to extract as many as I wanted initially. myLogger.print3("DEBUG: For Validation, loading to shared variable that many Segments: " + str(len(channsOfSegmentsForSubepPerPathwayVal[0]))) cnn3dInstance.sharedInpXVal.set_value(channsOfSegmentsForSubepPerPathwayVal[0], borrow=borrowFlag) # Primary pathway for index in xrange(len(channsOfSegmentsForSubepPerPathwayVal[1:])) : cnn3dInstance.sharedInpXPerSubsListVal[index].set_value(channsOfSegmentsForSubepPerPathwayVal[1+index], borrow=borrowFlag) cnn3dInstance.sharedLabelsYVal.set_value(labelsForCentralOfSegmentsForSubepVal, borrow=borrowFlag) channsOfSegmentsForSubepPerPathwayVal = "" labelsForCentralOfSegmentsForSubepVal = "" end_loadingToGpu_time = time.clock() myLogger.print3("TIMING: Loading sharedVariables for Validation in epoch|subepoch="+str(epoch)+"|"+str(subepoch)+" took time: "+str(end_loadingToGpu_time-start_loadingToGpu_time)+"(s)") #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING----------------- #submit the parallel job myLogger.print3("PARALLEL: Before Validation in subepoch #" +str(subepoch) + ", the parallel job for extracting Segments for the next Training is submitted.") parallelJobToGetDataForNextTraining = job_server.submit(getSampledDataAndLabelsForSubepoch, #local function to call and execute in parallel. 
tupleWithParametersForTraining, #tuple with the arguments required tupleWithLocalFunctionsThatWillBeCalledByTheMainJob, #tuple of local functions that I need to call tupleWithModulesToImportWhichAreUsedByTheJobFunctions) #tuple of the external modules that I need, of which I am calling functions (not the mods of the ext-functions). #------------------------------------DO VALIDATION-------------------------------- myLogger.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-") start_validationForSubepoch_time = time.clock() train0orValidation1 = 1 #validation vectorWithWeightsOfTheClassesForCostFunctionOfTraining = 'placeholder' #only used in training doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(myLogger, train0orValidation1, numberOfBatchesValidation, # Computed by the number of extracted samples. So, adapts. cnn3dInstance, vectorWithWeightsOfTheClassesForCostFunctionOfTraining, subepoch, validationAccuracyMonitorForEpoch) cnn3dInstance.freeGpuValidationData() end_validationForSubepoch_time = time.clock() myLogger.print3("TIMING: Validating on the batches of this subepoch #" + str(subepoch) + " took time: "+str(end_validationForSubepoch_time-start_validationForSubepoch_time)+"(s)") #Update cnn's top achieved validation accuracy if needed: (for the autoReduction of Learning Rate.) 
cnn3dInstance.checkMeanValidationAccOfLastEpochAndUpdateCnnsTopAccAchievedIfNeeded(myLogger, validationAccuracyMonitorForEpoch.getMeanEmpiricalAccuracyOfEpoch(), minIncreaseInValidationAccuracyConsideredForLrSchedule) #-------------------END OF THE VALIDATION-DURING-TRAINING-LOOP------------------------- #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING--------------------------------- if (not performValidationOnSamplesDuringTrainingProcessBool) and boolItIsTheVeryFirstSubepochOfThisProcess : [channsOfSegmentsForSubepPerPathwayTrain, labelsForCentralOfSegmentsForSubepTrain] = getSampledDataAndLabelsForSubepoch(myLogger, 0, cnn3dWrapper, maxNumSubjectsLoadedPerSubepoch, imagePartsLoadedInGpuPerSubepoch, samplingTypeInstanceTraining, listOfFilepathsToEachChannelOfEachPatientTraining, listOfFilepathsToGtLabelsOfEachPatientTraining, providedRoiMaskForTrainingBool, listOfFilepathsToRoiMaskOfEachPatientTraining, providedWeightMapsToSampleForEachCategoryTraining, forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining, useSameSubChannelsAsSingleScale, listOfFilepathsToEachSubsampledChannelOfEachPatientTraining, padInputImagesBool, smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage, normAugmNone0OnImages1OrSegments2AlreadyNormalized1SubtrUpToPropOfStdAndDivideWithUpToPerc, reflectImageWithHalfProbDuringTraining ) boolItIsTheVeryFirstSubepochOfThisProcess = False else : #It was done in parallel with the validation (or with previous training iteration, in case I am not performing validation). [channsOfSegmentsForSubepPerPathwayTrain, labelsForCentralOfSegmentsForSubepTrain] = parallelJobToGetDataForNextTraining() #fromParallelProcessing that had started from last loop when it was submitted. #-------------------------COMPUTE CLASS-WEIGHTS, TO WEIGHT COST FUNCTION AND COUNTER CLASS IMBALANCE---------------------- #Do it for only few epochs, until I get to an ok local minima neighbourhood. 
if cnn3dInstance.numberOfEpochsTrained < numberOfEpochsToWeightTheClassesInTheCostFunction : numOfPatchesInTheSubepoch_notParts = np.prod(labelsForCentralOfSegmentsForSubepTrain.shape) actualNumOfPatchesPerClassInTheSubepoch_notParts = np.bincount(np.ravel(labelsForCentralOfSegmentsForSubepTrain).astype(int)) # yx - y1 = (x - x1) * (y2 - y1)/(x2 - x1) # yx = the multiplier I currently want, y1 = the multiplier at the begining, y2 = the multiplier at the end # x = current epoch, x1 = epoch where linear decrease starts, x2 = epoch where linear decrease ends y1 = (1./(actualNumOfPatchesPerClassInTheSubepoch_notParts+TINY_FLOAT)) * (numOfPatchesInTheSubepoch_notParts*1.0/cnn3dInstance.numberOfOutputClasses) y2 = 1. x1 = 0. * number_of_subepochs # linear decrease starts from epoch=0 x2 = numberOfEpochsToWeightTheClassesInTheCostFunction * number_of_subepochs x = cnn3dInstance.numberOfEpochsTrained * number_of_subepochs + subepoch yx = (x - x1) * (y2 - y1)/(x2 - x1) + y1 vectorWithWeightsOfTheClassesForCostFunctionOfTraining = np.asarray(yx, dtype="float32") myLogger.print3("UPDATE: [Weight of Classes] Setting the weights of the classes in the cost function to: " +str(vectorWithWeightsOfTheClassesForCostFunctionOfTraining)) else : vectorWithWeightsOfTheClassesForCostFunctionOfTraining = np.ones(cnn3dInstance.numberOfOutputClasses, dtype='float32') #------------------- Learning Rate Schedule ------------------------ # I must make a learning-rate-manager to encapsulate all these... Very ugly currently... All othere LR schedules are at the outer loop, per epoch. if (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 4) : myLogger.print3("DEBUG: Going to change Learning Rate according to POLY schedule:") #newLearningRate = initLr * ( 1 - iter/max_iter) ^ power. Power = 0.9 in parsenet, which we validated to behave ok. 
currentIteration = cnn3dInstance.numberOfEpochsTrained * number_of_subepochs + subepoch max_iterations = n_epochs * number_of_subepochs newLearningRate = cnn3dInstance.initialLearningRate * pow( 1.0 - 1.0*currentIteration/max_iterations , 0.9) myLogger.print3("DEBUG: new learning rate was calculated: " +str(newLearningRate)) cnn3dInstance.change_learning_rate_of_a_cnn(newLearningRate, myLogger) #----------------------------------LOAD TRAINING DATA ON GPU------------------------------- myLogger.print3("Loading Training data for subepoch #"+str(subepoch)+" on shared variable...") start_loadingToGpu_time = time.clock() numberOfBatchesTraining = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // cnn3dInstance.batchSize #Computed with number of extracted samples, in case I dont manage to extract as many as I wanted initially. cnn3dInstance.sharedInpXTrain.set_value(channsOfSegmentsForSubepPerPathwayTrain[0], borrow=borrowFlag) # Primary pathway for index in xrange(len(channsOfSegmentsForSubepPerPathwayTrain[1:])) : cnn3dInstance.sharedInpXPerSubsListTrain[index].set_value(channsOfSegmentsForSubepPerPathwayTrain[1+index], borrow=borrowFlag) cnn3dInstance.sharedLabelsYTrain.set_value(labelsForCentralOfSegmentsForSubepTrain, borrow=borrowFlag) channsOfSegmentsForSubepPerPathwayTrain = "" labelsForCentralOfSegmentsForSubepTrain = "" end_loadingToGpu_time = time.clock() myLogger.print3("TIMING: Loading sharedVariables for Training in epoch|subepoch="+str(epoch)+"|"+str(subepoch)+" took time: "+str(end_loadingToGpu_time-start_loadingToGpu_time)+"(s)") #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH----------------- if performValidationOnSamplesDuringTrainingProcessBool : #submit the parallel job myLogger.print3("PARALLEL: Before Training in subepoch #" +str(subepoch) + ", submitting the parallel job for extracting Segments for the next Validation.") parallelJobToGetDataForNextValidation = 
job_server.submit(getSampledDataAndLabelsForSubepoch, #local function to call and execute in parallel. tupleWithParametersForValidation, #tuple with the arguments required tupleWithLocalFunctionsThatWillBeCalledByTheMainJob, #tuple of local functions that I need to call tupleWithModulesToImportWhichAreUsedByTheJobFunctions) #tuple of the external modules that I need, of which I am calling functions (not the mods of the ext-functions). else : #extract in parallel the samples for the next subepoch's training. myLogger.print3("PARALLEL: Before Training in subepoch #" +str(subepoch) + ", submitting the parallel job for extracting Segments for the next Training.") parallelJobToGetDataForNextTraining = job_server.submit(getSampledDataAndLabelsForSubepoch, #local function to call and execute in parallel. tupleWithParametersForTraining, #tuple with the arguments required tupleWithLocalFunctionsThatWillBeCalledByTheMainJob, #tuple of local functions that I need to call tupleWithModulesToImportWhichAreUsedByTheJobFunctions) #tuple of the external modules that I need, of which I am calling #-------------------------------START TRAINING IN BATCHES------------------------------ myLogger.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... 
-T-T-T-T-T-") start_trainingForSubepoch_time = time.clock() train0orValidation1 = 0 #training doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(myLogger, train0orValidation1, numberOfBatchesTraining, cnn3dInstance, vectorWithWeightsOfTheClassesForCostFunctionOfTraining, subepoch, trainingAccuracyMonitorForEpoch) cnn3dInstance.freeGpuTrainingData() end_trainingForSubepoch_time = time.clock() myLogger.print3("TIMING: Training on the batches of this subepoch #" + str(subepoch) + " took time: "+str(end_trainingForSubepoch_time-start_trainingForSubepoch_time)+"(s)") myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" ) myLogger.print3("~~~~~~~~~~~~~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~~~~~~~~~~~~" ) myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" ) if performValidationOnSamplesDuringTrainingProcessBool : validationAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch() trainingAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch() del trainingAccuracyMonitorForEpoch; del validationAccuracyMonitorForEpoch; #=======================Learning Rate Schedule.========================= if (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 0) and (numEpochsToWaitBeforeLowerLR > 0) and (cnn3dInstance.numberOfEpochsTrained % numEpochsToWaitBeforeLowerLR)==0 : # STABLE LR SCHEDULE" myLogger.print3("DEBUG: Going to lower Learning Rate because of STABLE schedule! The CNN has now been trained for: " + str(cnn3dInstance.numberOfEpochsTrained) + " epochs. I need to decrease LR every: " + str(numEpochsToWaitBeforeLowerLR) + " epochs.") cnn3dInstance.divide_learning_rate_of_a_cnn_by(divideLrBy, myLogger) elif (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 1) and (numEpochsToWaitBeforeLowerLR > 0) : # AUTO LR SCHEDULE! 
if not performValidationOnSamplesDuringTrainingProcessBool : #This flag should have been set True from the start if training should do Auto-schedule. If we get in here, this is a bug. myLogger.print3("ERROR: For Auto-schedule I need to be performing validation-on-samples during the training-process. The flag performValidationOnSamplesDuringTrainingProcessBool should have been set to True. Instead it seems it was False and no validation was performed. This is a bug. Contact the developer, this should not have happened. Try another Learning Rate schedule for now! Exiting.") exit(1) if (cnn3dInstance.numberOfEpochsTrained >= cnn3dInstance.topMeanValidationAccuracyAchievedInEpoch[1] + numEpochsToWaitBeforeLowerLR) and \ (cnn3dInstance.numberOfEpochsTrained >= cnn3dInstance.lastEpochAtTheEndOfWhichLrWasLowered + numEpochsToWaitBeforeLowerLR) : myLogger.print3("DEBUG: Going to lower Learning Rate because of AUTO schedule! The CNN has now been trained for: " + str(cnn3dInstance.numberOfEpochsTrained) + " epochs. Epoch with last highest achieved validation accuracy: " + str(cnn3dInstance.topMeanValidationAccuracyAchievedInEpoch[1]) + ", and epoch that Learning Rate was last lowered: " + str(cnn3dInstance.lastEpochAtTheEndOfWhichLrWasLowered) + ". I waited for increase in accuracy for: " +str(numEpochsToWaitBeforeLowerLR) + " epochs. Going to lower Learning Rate...") cnn3dInstance.divide_learning_rate_of_a_cnn_by(divideLrBy, myLogger) elif (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 2) and (cnn3dInstance.numberOfEpochsTrained in lowerLrAtTheEndOfTheseEpochsPredefinedScheduleList) : #Predefined Schedule. myLogger.print3("DEBUG: Going to lower Learning Rate because of PREDEFINED schedule! The CNN has now been trained for: " + str(cnn3dInstance.numberOfEpochsTrained) + " epochs. 
I need to decrease after that many epochs: " + str(lowerLrAtTheEndOfTheseEpochsPredefinedScheduleList)) cnn3dInstance.divide_learning_rate_of_a_cnn_by(divideLrBy, myLogger) elif (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 3 and cnn3dInstance.numberOfEpochsTrained >= exponentialScheduleForLrAndMom[0]) : myLogger.print3("DEBUG: Going to lower Learning Rate and Increase Momentum because of EXPONENTIAL schedule! The CNN has now been trained for: " + str(cnn3dInstance.numberOfEpochsTrained) + " epochs.") minEpochToLowerLr = exponentialScheduleForLrAndMom[0] #newLearningRate = initialLearningRate * gamma^t. gamma = {t-th}root(valueIwantLrToHaveAtTimepointT / initialLearningRate) gammaForExpSchedule = pow( ( cnn3dInstance.initialLearningRate*exponentialScheduleForLrAndMom[1] * 1.0) / cnn3dInstance.initialLearningRate, 1.0 / (n_epochs-minEpochToLowerLr)) newLearningRate = cnn3dInstance.initialLearningRate * pow(gammaForExpSchedule, cnn3dInstance.numberOfEpochsTrained-minEpochToLowerLr + 1.0) #Momentum increased linearly. newMomentum = ((cnn3dInstance.numberOfEpochsTrained - minEpochToLowerLr + 1) - (n_epochs-minEpochToLowerLr))*1.0 / (n_epochs - minEpochToLowerLr) * (exponentialScheduleForLrAndMom[2] - cnn3dInstance.initialMomentum) + exponentialScheduleForLrAndMom[2] print("DEBUG: new learning rate was calculated: ", newLearningRate, " and new Momentum: ", newMomentum) cnn3dInstance.change_learning_rate_of_a_cnn(newLearningRate, myLogger) cnn3dInstance.change_momentum_of_a_cnn(newMomentum, myLogger) #================== Everything for epoch has finished. ======================= #Training finished. Update the number of epochs that the cnn was trained. cnn3dInstance.increaseNumberOfEpochsTrained() myLogger.print3("SAVING: Epoch #"+str(epoch)+" finished. 
Saving CNN model.") dump_cnn_to_gzip_file_dotSave(cnn3dInstance, fileToSaveTrainedCnnModelTo+"."+datetimeNowAsStr(), myLogger) end_epoch_time = time.clock() myLogger.print3("TIMING: The whole Epoch #"+str(epoch)+" took time: "+str(end_epoch_time-start_epoch_time)+"(s)") myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") if performFullInferenceOnValidationImagesEveryFewEpochsBool and (cnn3dInstance.numberOfEpochsTrained != 0) and (cnn3dInstance.numberOfEpochsTrained % everyThatManyEpochsComputeDiceOnTheFullValidationImages == 0) : myLogger.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #"+str(epoch)+"...***") validation0orTesting1 = 0 #do_validation_or_testing(myLogger, performInferenceOnWholeVolumes(myLogger, validation0orTesting1, savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation, cnn3dInstance, listOfFilepathsToEachChannelOfEachPatientValidation, providedGtForValidationBool, listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc, providedRoiMaskForValidationBool, listOfFilepathsToRoiMaskOfEachPatientValidation, borrowFlag, listOfNamesToGiveToPredictionsIfSavingResults = "Placeholder" if not savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation else listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice, #----Preprocessing------ padInputImagesBool=padInputImagesBool, smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage=smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage, #for the cnn extension useSameSubChannelsAsSingleScale=useSameSubChannelsAsSingleScale, listOfFilepathsToEachSubsampledChannelOfEachPatient=listOfFilepathsToEachSubsampledChannelOfEachPatientValidation, #--------For FM visualisation--------- saveIndividualFmImagesForVisualisation=saveIndividualFmImagesForVisualisation, 
saveMultidimensionalImageWithAllFms=saveMultidimensionalImageWithAllFms, indicesOfFmsToVisualisePerPathwayTypeAndPerLayer=indicesOfFmsToVisualisePerPathwayTypeAndPerLayer, listOfNamesToGiveToFmVisualisationsIfSaving=listOfNamesToGiveToFmVisualisationsIfSaving ) dump_cnn_to_gzip_file_dotSave(cnn3dInstance, fileToSaveTrainedCnnModelTo+".final."+datetimeNowAsStr(), myLogger) end_training_time = time.clock() myLogger.print3("TIMING: Training process took time: "+str(end_training_time-start_training_time)+"(s)") myLogger.print3("The whole do_training() function has finished.")
def do_training(sessionTf,
                saver_all,
                cnn3d,
                trainer,
                log,
                fileToSaveTrainedCnnModelTo,
                performValidationOnSamplesDuringTrainingProcessBool,
                savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation,
                listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice,
                listOfFilepathsToEachChannelOfEachPatientTraining,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                listOfFilepathsToGtLabelsOfEachPatientTraining,
                providedGtForValidationBool,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                providedWeightMapsToSampleForEachCategoryTraining,
                forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
                providedWeightMapsToSampleForEachCategoryValidation,
                forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
                providedRoiMaskForTrainingBool,
                listOfFilepathsToRoiMaskOfEachPatientTraining,  # Also needed for normalization-augmentation
                providedRoiMaskForValidationBool,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                n_epochs,  # Every epoch the CNN model is saved.
                number_of_subepochs,  # per epoch. Every subepoch Accuracy is reported
                maxNumSubjectsLoadedPerSubepoch,  # Max num of cases loaded every subepoch for segments extraction. The more, the longer loading.
                imagePartsLoadedInGpuPerSubepoch,
                imagePartsLoadedInGpuPerSubepochValidation,
                # -------Sampling Type---------
                samplingTypeInstanceTraining,  # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                samplingTypeInstanceValidation,
                # -------Preprocessing-----------
                padInputImagesBool,
                # -------Data Augmentation-------
                doIntAugm_shiftMuStd_multiMuStd,
                reflectImageWithHalfProbDuringTraining,
                useSameSubChannelsAsSingleScale,
                listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,  # deprecated, not supported
                listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,  # deprecated, not supported
                # Validation
                performFullInferenceOnValidationImagesEveryFewEpochsBool,  # Even if not providedGtForValidationBool, inference will be performed if this == True, to save the results, eg for visual.
                everyThatManyEpochsComputeDiceOnTheFullValidationImages,  # Should not be == 0, except if performFullInferenceOnValidationImagesEveryFewEpochsBool == False
                # --------For FM visualisation---------
                saveIndividualFmImagesForVisualisation,
                saveMultidimensionalImageWithAllFms,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                listOfNamesToGiveToFmVisualisationsIfSaving,
                # -------- Others --------
                run_input_checks):
    """Train `cnn3d` epoch by epoch until the trainer's epoch counter reaches `n_epochs`.

    The number of epochs already trained is read from the session through
    `trainer.get_num_epochs_trained_tfv()`. Training therefore continues from
    wherever that counter currently stands.

    Each epoch consists of `number_of_subepochs` subepochs. In every subepoch:
      1. If `performValidationOnSamplesDuringTrainingProcessBool` is set, the
         model is validated on a freshly sampled set of validation segments.
      2. It is then trained on a freshly sampled set of training segments.

    Segment sampling for the *next* stage is submitted to a Parallel Python
    (pp) worker, so that it overlaps with the current validation/training on
    the main process. Only the very first sampling of the process is done
    sequentially.

    At the end of every epoch the function:
      - reports the epoch's mean accuracies;
      - calls `trainer.run_updates_end_of_ep(...)`, which updates the LR
        schedule if needed and increments the epoch counter;
      - saves a checkpoint via `saver_all` to
        `<fileToSaveTrainedCnnModelTo>.<datetime>.model.ckpt`, without the
        meta graph.

    If `performFullInferenceOnValidationImagesEveryFewEpochsBool` is set,
    full-volume inference on the validation subjects runs every
    `everyThatManyEpochsComputeDiceOnTheFullValidationImages` epochs.

    Returns None. The pp job server is destroyed on exit, including when an
    exception propagates.
    """
    start_training_time = time.time()

    # Used because I cannot pass cnn3d to the sampling function.
    # The parallel process used to load theano again, which created problems in
    # the GPU when cnmem was used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)

    # ---------To run PARALLEL the extraction of parts for the next subepoch---
    ppservers = ()  # tuple of all parallel python servers to connect with
    # Local job server with a single worker process (ncpus=1).
    job_server = pp.Server(ncpus=1, ppservers=ppservers)

    tupleWithParametersForTraining = (
        log, "train", run_input_checks, cnn3dWrapper,
        maxNumSubjectsLoadedPerSubepoch,
        imagePartsLoadedInGpuPerSubepoch,
        samplingTypeInstanceTraining,
        listOfFilepathsToEachChannelOfEachPatientTraining,
        listOfFilepathsToGtLabelsOfEachPatientTraining,
        providedRoiMaskForTrainingBool,
        listOfFilepathsToRoiMaskOfEachPatientTraining,
        providedWeightMapsToSampleForEachCategoryTraining,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
        useSameSubChannelsAsSingleScale,
        listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,
        padInputImagesBool,
        doIntAugm_shiftMuStd_multiMuStd,
        reflectImageWithHalfProbDuringTraining)
    tupleWithParametersForValidation = (
        log, "val", run_input_checks, cnn3dWrapper,
        maxNumSubjectsLoadedPerSubepoch,
        imagePartsLoadedInGpuPerSubepochValidation,
        samplingTypeInstanceValidation,
        listOfFilepathsToEachChannelOfEachPatientValidation,
        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
        providedRoiMaskForValidationBool,
        listOfFilepathsToRoiMaskOfEachPatientValidation,
        providedWeightMapsToSampleForEachCategoryValidation,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
        useSameSubChannelsAsSingleScale,
        listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
        padInputImagesBool,
        # Don't perform intensity-augmentation during validation. Same format
        # as the sequential first-subepoch validation call below, so both
        # sampling paths receive identical settings.
        [False, [], []],
        [0, 0, 0])  # don't perform reflection-augmentation during validation.
    tupleWithLocalFunctionsThatWillBeCalledByTheMainJob = ()
    tupleWithModulesToImportWhichAreUsedByTheJobFunctions = (
        "from __future__ import absolute_import, print_function, division",
        "time", "numpy as np", "from deepmedic.dataManagement.sampling import *")

    def submit_sampling_job(params_for_sampling):
        """Submit getSampledDataAndLabelsForSubepoch to the pp worker.

        Returns the pp job; calling it blocks and returns the sampled
        [channels-per-pathway, labels] result.
        """
        return job_server.submit(
            getSampledDataAndLabelsForSubepoch,  # local function to call and execute in parallel.
            params_for_sampling,  # tuple with the arguments required
            tupleWithLocalFunctionsThatWillBeCalledByTheMainJob,  # tuple of local functions that I need to call
            tupleWithModulesToImportWhichAreUsedByTheJobFunctions)  # external modules whose functions the job calls.

    # In the very first subepoch nothing has been prefetched yet, so data is sampled sequentially.
    boolItIsTheVeryFirstSubepochOfThisProcess = True
    # ------End for parallel------

    try:
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while model_num_epochs_trained < n_epochs:
            epoch = model_num_epochs_trained

            trainingAccuracyMonitorForEpoch = AccuracyOfEpochMonitorSegmentation(
                log, 0, model_num_epochs_trained, cnn3d.num_classes, number_of_subepochs)
            validationAccuracyMonitorForEpoch = None if not performValidationOnSamplesDuringTrainingProcessBool else \
                AccuracyOfEpochMonitorSegmentation(log, 1, model_num_epochs_trained, cnn3d.num_classes, number_of_subepochs)

            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~Starting new Epoch! Epoch #" + str(epoch) + "/" + str(n_epochs) + " ~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_epoch_time = time.time()

            for subepoch in range(number_of_subepochs):  # per subepoch I randomly load some images in the gpu. Random order.
                log.print3("**************************************************************************************************")
                log.print3("************* Starting new Subepoch: #" + str(subepoch) + "/" + str(number_of_subepochs) + " *************")
                log.print3("**************************************************************************************************")

                # -------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
                if performValidationOnSamplesDuringTrainingProcessBool:
                    if boolItIsTheVeryFirstSubepochOfThisProcess:
                        [channsOfSegmentsForSubepPerPathwayVal,
                         labelsForCentralOfSegmentsForSubepVal] = getSampledDataAndLabelsForSubepoch(
                            log, "val", run_input_checks, cnn3dWrapper,
                            maxNumSubjectsLoadedPerSubepoch,
                            imagePartsLoadedInGpuPerSubepochValidation,
                            samplingTypeInstanceValidation,
                            listOfFilepathsToEachChannelOfEachPatientValidation,
                            listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                            providedRoiMaskForValidationBool,
                            listOfFilepathsToRoiMaskOfEachPatientValidation,
                            providedWeightMapsToSampleForEachCategoryValidation,
                            forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
                            useSameSubChannelsAsSingleScale,
                            listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
                            padInputImagesBool,
                            doIntAugm_shiftMuStd_multiMuStd=[False, [], []],
                            reflectImageWithHalfProbDuringTraining=[0, 0, 0])
                        boolItIsTheVeryFirstSubepochOfThisProcess = False
                    else:
                        # Submitted before training in the previous subepoch; just collect the results.
                        [channsOfSegmentsForSubepPerPathwayVal,
                         labelsForCentralOfSegmentsForSubepVal] = parallelJobToGetDataForNextValidation()

                    # Computed from the number of actually extracted samples, in case fewer
                    # were extracted than requested. Leftover samples (< one batch) are not used.
                    numberOfBatchesValidation = len(channsOfSegmentsForSubepPerPathwayVal[0]) // cnn3d.batchSize["val"]

                    # ------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------
                    log.print3("PARALLEL: Before Validation in subepoch #" + str(subepoch) + ", the parallel job for extracting Segments for the next Training is submitted.")
                    parallelJobToGetDataForNextTraining = submit_sampling_job(tupleWithParametersForTraining)

                    # ------------------------------------DO VALIDATION--------------------------------
                    log.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-")
                    start_validationForSubepoch_time = time.time()
                    doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(
                        log, sessionTf, "val",
                        numberOfBatchesValidation,  # Computed by the number of extracted samples. So, adapts.
                        cnn3d, subepoch, validationAccuracyMonitorForEpoch,
                        channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal)
                    end_validationForSubepoch_time = time.time()
                    log.print3("TIMING: Validating on the batches of this subepoch #" + str(subepoch) + " took time: " + str(end_validationForSubepoch_time - start_validationForSubepoch_time) + "(s)")
                # -------------------END OF THE VALIDATION-DURING-TRAINING-LOOP-------------------------

                # -------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
                if (not performValidationOnSamplesDuringTrainingProcessBool) and boolItIsTheVeryFirstSubepochOfThisProcess:
                    [channsOfSegmentsForSubepPerPathwayTrain,
                     labelsForCentralOfSegmentsForSubepTrain] = getSampledDataAndLabelsForSubepoch(
                        log, "train", run_input_checks, cnn3dWrapper,
                        maxNumSubjectsLoadedPerSubepoch,
                        imagePartsLoadedInGpuPerSubepoch,
                        samplingTypeInstanceTraining,
                        listOfFilepathsToEachChannelOfEachPatientTraining,
                        listOfFilepathsToGtLabelsOfEachPatientTraining,
                        providedRoiMaskForTrainingBool,
                        listOfFilepathsToRoiMaskOfEachPatientTraining,
                        providedWeightMapsToSampleForEachCategoryTraining,
                        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
                        useSameSubChannelsAsSingleScale,
                        listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,
                        padInputImagesBool,
                        doIntAugm_shiftMuStd_multiMuStd,
                        reflectImageWithHalfProbDuringTraining)
                    boolItIsTheVeryFirstSubepochOfThisProcess = False
                else:
                    # Sampled in parallel with validation (or with the previous training
                    # iteration, when validation is not performed). Collect the results.
                    [channsOfSegmentsForSubepPerPathwayTrain,
                     labelsForCentralOfSegmentsForSubepTrain] = parallelJobToGetDataForNextTraining()

                # Computed from the number of actually extracted samples, in case fewer were extracted than requested.
                numberOfBatchesTraining = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // cnn3d.batchSize["train"]

                # --------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH--------
                if performValidationOnSamplesDuringTrainingProcessBool:
                    log.print3("PARALLEL: Before Training in subepoch #" + str(subepoch) + ", submitting the parallel job for extracting Segments for the next Validation.")
                    parallelJobToGetDataForNextValidation = submit_sampling_job(tupleWithParametersForValidation)
                else:
                    # Extract in parallel the samples for the next subepoch's training.
                    log.print3("PARALLEL: Before Training in subepoch #" + str(subepoch) + ", submitting the parallel job for extracting Segments for the next Training.")
                    parallelJobToGetDataForNextTraining = submit_sampling_job(tupleWithParametersForTraining)

                # -------------------------------START TRAINING IN BATCHES------------------------------
                log.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-")
                start_trainingForSubepoch_time = time.time()
                doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(
                    log, sessionTf, "train",
                    numberOfBatchesTraining,
                    cnn3d, subepoch, trainingAccuracyMonitorForEpoch,
                    channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain)
                end_trainingForSubepoch_time = time.time()
                log.print3("TIMING: Training on the batches of this subepoch #" + str(subepoch) + " took time: " + str(end_trainingForSubepoch_time - start_trainingForSubepoch_time) + "(s)")

            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if performValidationOnSamplesDuringTrainingProcessBool:
                validationAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch()
            trainingAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch()

            mean_val_acc_of_ep = validationAccuracyMonitorForEpoch.getMeanEmpiricalAccuracyOfEpoch() \
                if performValidationOnSamplesDuringTrainingProcessBool else None
            # Updates LR schedule if needed, and increases number of epochs trained.
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep)
            model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)

            del trainingAccuracyMonitorForEpoch
            del validationAccuracyMonitorForEpoch
            # ================== Everything for epoch has finished. =======================

            log.print3("SAVING: Epoch #" + str(epoch) + " finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr()
            saver_all.save(sessionTf, filename_to_save_with + ".model.ckpt", write_meta_graph=False)

            end_epoch_time = time.time()
            log.print3("TIMING: The whole Epoch #" + str(epoch) + " took time: " + str(end_epoch_time - start_epoch_time) + "(s)")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

            # Short-circuit keeps the modulo from running (and dividing by 0) when full inference is disabled.
            if performFullInferenceOnValidationImagesEveryFewEpochsBool and (model_num_epochs_trained != 0) and \
                    (model_num_epochs_trained % everyThatManyEpochsComputeDiceOnTheFullValidationImages == 0):
                log.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #" + str(epoch) + "...***")
                performInferenceOnWholeVolumes(
                    sessionTf, cnn3d, log, "val",
                    savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation,
                    listOfFilepathsToEachChannelOfEachPatientValidation,
                    providedGtForValidationBool,
                    listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                    providedRoiMaskForValidationBool,
                    listOfFilepathsToRoiMaskOfEachPatientValidation,
                    listOfNamesToGiveToPredictionsIfSavingResults="Placeholder" if not savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation else listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice,
                    # ----Preprocessing------
                    padInputImagesBool=padInputImagesBool,
                    # for the cnn extension
                    useSameSubChannelsAsSingleScale=useSameSubChannelsAsSingleScale,
                    listOfFilepathsToEachSubsampledChannelOfEachPatient=listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
                    # --------For FM visualisation---------
                    saveIndividualFmImagesForVisualisation=saveIndividualFmImagesForVisualisation,
                    saveMultidimensionalImageWithAllFms=saveMultidimensionalImageWithAllFms,
                    indicesOfFmsToVisualisePerPathwayTypeAndPerLayer=indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                    listOfNamesToGiveToFmVisualisationsIfSaving=listOfNamesToGiveToFmVisualisationsIfSaving)
    finally:
        # pp workers are separate OS processes. Release them even if training raised.
        # The prefetch job submitted during the last subepoch is never collected, and is discarded here.
        job_server.destroy()

    end_training_time = time.time()
    log.print3("TIMING: Training process took time: " + str(end_training_time - start_training_time) + "(s)")
    log.print3("The whole do_training() function has finished.")