コード例 #1
0
ファイル: testSession.py プロジェクト: Kamnitsask/deepmedic
 def run_session(self, *args):
     """Build the testing graph, obtain network weights and run whole-volume inference.

     Exactly two positional arguments are expected, in this order: the device
     on which the network's graph is placed, and the model-parameters object
     that provides the architecture arguments (``get_args_for_arch()``).

     Weights are restored from the path given by
     ``self._params.get_path_to_load_model_from()``. If that path is ``None``
     the user is asked whether to continue, and the variables under the
     ``"net"`` scope are initialised instead.
     """
     sess_device, model_params = args

     graph = tf.Graph()
     with graph.as_default():
         # Placing on a GPU that is not available throws an error here.
         with graph.device(sess_device):
             self._log.print3("=========== Making the CNN graph... ===============")
             cnn3d = Cnn3d()
             with tf.variable_scope("net"):
                 # Builds the network's graph only (no optimizer).
                 arch_args = model_params.get_args_for_arch()
                 cnn3d.make_cnn_model(*arch_args)

         self._log.print3("=========== Compiling the Testing Function ============")
         self._log.print3("=======================================================\n")

         fms_to_save = self._params.indices_fms_per_pathtype_per_layer_to_save
         cnn3d.setup_ops_n_feeds_to_test(self._log, fms_to_save)
         # Saver over every variable of the graph (a saver for "net" alone would suffice).
         saver_all = tf.train.Saver()

     session_config = tf.ConfigProto(log_device_placement=False,
                                     device_count={'CPU':999, 'GPU':99})
     with tf.Session(graph=graph, config=session_config) as sessionTf:
         model_path = self._params.get_path_to_load_model_from()
         if model_path is None:
             # No saved model was given: confirm with the user (exits on "no"),
             # then fall back to freshly initialised variables.
             self._ask_user_if_test_with_random()
             self._log.print3("")
             self._log.print3("=========== Initializing network variables  ===============")
             net_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net")
             init_op = tf.variables_initializer(var_list=net_vars)
             init_op.run()
             self._log.print3("Model variables were initialized.")
         else:
             self._log.print3("=========== Loading parameters from specified saved model ===============")
             # A directory means "use its latest checkpoint"; anything else is
             # taken as the checkpoint path itself.
             if os.path.isdir(model_path):
                 chkpt_fname = tf.train.latest_checkpoint(model_path)
             else:
                 chkpt_fname = model_path
             self._log.print3("Loading parameters from:" + str(chkpt_fname))
             try:
                 saver_all.restore(sessionTf, chkpt_fname)
                 self._log.print3("Parameters were loaded.")
             except Exception as e:
                 handle_exception_tf_restore(self._log, e)

         self._log.print3("")
         self._log.print3("======================================================")
         self._log.print3("=========== Testing with the CNN model ===============")
         self._log.print3("======================================================\n")

         # Session and network go first, followed by the configured testing arguments.
         inference_args = [sessionTf, cnn3d] + self._params.get_args_for_testing()
         res_code = inferenceWholeVolumes(*inference_args)

     self._log.print3("")
     self._log.print3("======================================================")
     self._log.print3("=========== Testing session finished =================")
     self._log.print3("======================================================")
コード例 #2
0
def do_training(sessionTf,
                saver_all,
                cnn3d,
                trainer,
                log,
                
                fileToSaveTrainedCnnModelTo,
                
                val_on_samples_during_train,
                savePredictedSegmAndProbsDict,
                
                namesForSavingSegmAndProbs,
                suffixForSegmAndProbsDict,
                
                listOfFilepathsToEachChannelOfEachPatientTraining,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                
                listOfFilepathsToGtLabelsOfEachPatientTraining,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                
                paths_to_wmaps_per_sampl_cat_per_subj_train,
                paths_to_wmaps_per_sampl_cat_per_subj_val,
                
                listOfFilepathsToRoiMaskOfEachPatientTraining,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                
                n_epochs, # Every epoch the CNN model is saved.
                num_subepochs, # per epoch. Every subepoch Accuracy is reported
                max_n_cases_per_subep_train,  # Max num of subjects loaded every subepoch for segments extraction.
                n_samples_per_subep_train,
                n_samples_per_subep_val,
                num_parallel_proc_sampling, # -1: seq. 0: thread for sampling. >0: multiprocess sampling
                
                #-------Sampling Type---------
                samplingTypeInstanceTraining, # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                samplingTypeInstanceValidation,
                batchsize_train,
                batchsize_val_samples,
                batchsize_val_whole,
                
                #-------Preprocessing-----------
                pad_input_imgs,
                #-------Data Augmentation-------
                augm_img_prms,
                augm_sample_prms,
                
                # Validation
                val_on_whole_volumes,
                num_epochs_between_val_on_whole_volumes,
                
                #--------For FM visualisation---------
                saveIndividualFmImagesForVisualisation,
                saveMultidimensionalImageWithAllFms,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                namesForSavingFms,
                
                #-------- Others --------
                run_input_checks
                ):
    """Run the training loop until the trainer's epoch counter reaches ``n_epochs``.

    The number of epochs already trained is read from the trainer
    (``trainer.get_num_epochs_trained_tfv()``), so the loop continues from
    wherever that counter currently stands. Each epoch consists of
    ``num_subepochs`` subepochs. In every subepoch:

      1. if ``val_on_samples_during_train`` is set, validation samples are
         obtained and a validation pass is run on them;
      2. training samples are obtained and a training pass is run on them.

    Samples come from ``getSampledDataAndLabelsForSubepoch``. When
    ``num_parallel_proc_sampling`` is greater than -1 a single-worker
    ``ThreadPool`` is created and the samples for the *next* pass are requested
    asynchronously before the current pass starts; otherwise sampling is done
    by the calling thread.

    At the end of every epoch the accuracy monitors report, the trainer runs
    its end-of-epoch updates, and a checkpoint is written to
    ``fileToSaveTrainedCnnModelTo + "." + <timestamp> + ".model.ckpt"``. When
    ``val_on_whole_volumes`` is set, ``inferenceWholeVolumes`` is additionally
    run on the validation subjects every
    ``num_epochs_between_val_on_whole_volumes`` epochs.

    Args (only what this function itself does with them is described; all
    other arguments are forwarded unchanged to the sampling function and/or to
    ``inferenceWholeVolumes``):
        sessionTf: session used to evaluate the epoch counter, run train/val
            passes, save checkpoints and run whole-volume inference.
        saver_all: object whose ``save(session, path, write_meta_graph=False)``
            is called after every epoch and once more at the end.
        cnn3d: the network; ``cnn3d.num_classes`` is read here, and the object
            is wrapped in ``CnnWrapperForSampling`` for the sampler.
        trainer: supplies the epoch counter and ``run_updates_end_of_ep``.
        log: logger exposing ``print3``.
        fileToSaveTrainedCnnModelTo: path prefix for the saved checkpoints.
        val_on_samples_during_train: enables the per-subepoch validation pass.
        n_epochs: training stops once the trainer's counter is >= this value.
        num_subepochs: number of subepochs per epoch.
        max_n_cases_per_subep_train: forwarded to the sampler for both the
            training and the validation sampling arguments.
        num_parallel_proc_sampling: values > -1 enable the background worker;
            also forwarded to the sampler.
        batchsize_train, batchsize_val_samples: used to derive the number of
            batches from the number of extracted samples (floor division).
        batchsize_val_whole: batch size passed to ``inferenceWholeVolumes``.
        val_on_whole_volumes, num_epochs_between_val_on_whole_volumes: control
            whether and how often whole-volume validation follows an epoch.

    Returns:
        0 when training completed and the final model was saved; 1 when an
        ``Exception`` or ``KeyboardInterrupt`` was caught inside the loop (the
        worker pool is terminated and the final model is not saved in that case).
    """
    
    id_str = "[MAIN|PID:"+str(os.getpid())+"]"
    start_time_train = time.time()
    
    # I cannot pass cnn3d to the sampling function, because the pp module used to reload theano.
    # This created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)
    
    # Positional argument tuples for getSampledDataAndLabelsForSubepoch.
    # Both tuples must list their entries in the same order.
    args_for_sampling_train = ( log,
                                "train",
                                num_parallel_proc_sampling,
                                run_input_checks,
                                cnn3dWrapper,
                                max_n_cases_per_subep_train,
                                n_samples_per_subep_train,
                                samplingTypeInstanceTraining,
                                listOfFilepathsToEachChannelOfEachPatientTraining,
                                listOfFilepathsToGtLabelsOfEachPatientTraining,
                                listOfFilepathsToRoiMaskOfEachPatientTraining,
                                paths_to_wmaps_per_sampl_cat_per_subj_train,
                                pad_input_imgs,
                                augm_img_prms,
                                augm_sample_prms )
    # NOTE(review): validation sampling reuses max_n_cases_per_subep_train; this function
    # has no separate validation cap parameter -- presumably intentional, confirm.
    args_for_sampling_val = (   log,
                                "val",
                                num_parallel_proc_sampling,
                                run_input_checks,
                                cnn3dWrapper,
                                max_n_cases_per_subep_train,
                                n_samples_per_subep_val,
                                samplingTypeInstanceValidation,
                                listOfFilepathsToEachChannelOfEachPatientValidation,
                                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                listOfFilepathsToRoiMaskOfEachPatientValidation,
                                paths_to_wmaps_per_sampl_cat_per_subj_val,
                                pad_input_imgs,
                                None, # no augmentation in val.
                                None ) # no augmentation in val.
    
    # These flags record whether an asynchronous sampling job is currently outstanding,
    # i.e. whether the corresponding parallelJobToGetDataFor... handle holds an unconsumed result.
    sampling_job_submitted_train = False
    sampling_job_submitted_val = False
    # For parallel extraction of samples for next train/val while processing previous iteration.
    worker_pool = None
    if num_parallel_proc_sampling > -1 : # Sample in a background worker.
        worker_pool = ThreadPool(processes=1) # Or multiprocessing.Pool(...), same API.
    
    try:
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while model_num_epochs_trained < n_epochs :
            epoch = model_num_epochs_trained
            
            # Fresh accuracy monitors for this epoch; the validation one exists only if validation on samples is enabled.
            acc_monitor_for_ep_train = AccuracyOfEpochMonitorSegmentation(log, 0, model_num_epochs_trained, cnn3d.num_classes, num_subepochs)
            acc_monitor_for_ep_val = None if not val_on_samples_during_train else \
                                            AccuracyOfEpochMonitorSegmentation(log, 1, model_num_epochs_trained, cnn3d.num_classes, num_subepochs )
                                            
            # Decide up front whether this epoch ends with whole-volume validation, because
            # that also suppresses the sample prefetch in the epoch's last subepoch (see below).
            val_on_whole_volumes_after_ep = False
            if val_on_whole_volumes and (model_num_epochs_trained+1) % num_epochs_between_val_on_whole_volumes == 0:
                val_on_whole_volumes_after_ep = True
                
                
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~\t Starting new Epoch! Epoch #"+str(epoch)+"/"+str(n_epochs)+" \t~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_time_ep = time.time()
            
            for subepoch in range(num_subepochs):
                log.print3("***************************************************************************************")
                log.print3("*******\t\t Starting new Subepoch: #"+str(subepoch)+"/"+str(num_subepochs)+" \t\t********")
                log.print3("***************************************************************************************")
                
                #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
                if val_on_samples_during_train :
                    if worker_pool is None: # Sequential processing.
                        log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [VALIDATION] will be done by main thread.")
                        (channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal) = getSampledDataAndLabelsForSubepoch( *args_for_sampling_val )
                    elif sampling_job_submitted_val : # Job was submitted before the previous training pass; just grab results.
                        (channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal) = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False
                    else : # Not previously submitted in case of first epoch or after a full-volumes validation.
                        assert subepoch == 0
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        # Nothing was prefetched, so submit the job and block on it immediately.
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        (channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal) = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False

                    
                    #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                    # Submitted before validating, so the worker samples training data while validation runs.
                    if worker_pool is not None:
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True
                    
                    #------------------------------------DO VALIDATION--------------------------------
                    log.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-")
                    start_time_val_subep = time.time()
                    # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                    # Floor division: samples beyond the last full batch are not counted.
                    num_batches_val = len(channsOfSegmentsForSubepPerPathwayVal[0]) // batchsize_val_samples
                    trainOrValidateForSubepoch( log,
                                                sessionTf,
                                                "val",
                                                num_batches_val,
                                                batchsize_val_samples,
                                                cnn3d,
                                                subepoch,
                                                acc_monitor_for_ep_val,
                                                channsOfSegmentsForSubepPerPathwayVal,
                                                labelsForCentralOfSegmentsForSubepVal )
                    end_time_val_subep = time.time()
                    log.print3("TIMING: Validation on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_val_subep-start_time_val_subep)+" secs.")
                
                #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
                if worker_pool is None: # Sequential processing.
                    log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [TRAINING] will be done by main thread.")
                    (channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain) = getSampledDataAndLabelsForSubepoch( *args_for_sampling_train )
                elif sampling_job_submitted_train: # Sampling job should have been done in parallel with previous train/val. Just grab results.
                    (channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain) = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False
                else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                    assert subepoch == 0
                    log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                    # Nothing was prefetched, so submit the job and block on it immediately.
                    parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                    (channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain) = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False

                        
                #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
                # Prefetch for the next subepoch while this one trains. Skipped only in the last subepoch
                # of an epoch that is followed by whole-volume validation.
                # NOTE(review): in the last subepoch of the final epoch (without whole-volume validation)
                # a job is still submitted here and its result is never fetched; the pool join at the
                # end waits for it. Looks like wasted work -- confirm before changing.
                if worker_pool is not None and not (val_on_whole_volumes_after_ep and (subepoch == num_subepochs-1)):
                    if val_on_samples_during_train :
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        sampling_job_submitted_val = True
                    else :
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True
                
                #-------------------------------START TRAINING IN BATCHES------------------------------
                log.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-")
                start_time_train_subep = time.time()
                # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                # Floor division: samples beyond the last full batch are not counted.
                num_batches_train = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // batchsize_train
                trainOrValidateForSubepoch( log,
                                            sessionTf,
                                            "train",
                                            num_batches_train,
                                            batchsize_train,
                                            cnn3d,
                                            subepoch,
                                            acc_monitor_for_ep_train,
                                            channsOfSegmentsForSubepPerPathwayTrain,
                                            labelsForCentralOfSegmentsForSubepTrain )
                end_time_train_subep = time.time()
                log.print3("TIMING: Training on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_train_subep-start_time_train_subep)+" secs.")
                
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~" )
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            
            if val_on_samples_during_train:
                acc_monitor_for_ep_val.reportMeanAccyracyOfEpoch()
            acc_monitor_for_ep_train.reportMeanAccyracyOfEpoch()
            
            # The mean validation accuracy (None when validation on samples is off) is handed to the trainer.
            mean_val_acc_of_ep = acc_monitor_for_ep_val.getMeanEmpiricalAccuracyOfEpoch() if val_on_samples_during_train else None
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep) # Updates LR schedule if needed, and increases number of epochs trained.
            # Re-read the counter from the trainer; this is what advances the while loop.
            model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
            
            del acc_monitor_for_ep_train; del acc_monitor_for_ep_val;
            
            # Checkpoint after every epoch, under a timestamped filename.
            log.print3("SAVING: Epoch #"+str(epoch)+" finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr()
            saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )
            
            end_time_ep = time.time()
            log.print3("TIMING: The whole Epoch #"+str(epoch)+" lasted: {0:.1f}".format(end_time_ep-start_time_ep)+" secs.")
            log.print3("~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~")
            
            
            if val_on_whole_volumes_after_ep:
                log.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #"+str(epoch)+"...***")
                
                # NOTE(review): res_code is assigned but not inspected anywhere in this function.
                res_code = inferenceWholeVolumes(
                                sessionTf,
                                cnn3d,
                                log,
                                "val",
                                savePredictedSegmAndProbsDict,
                                listOfFilepathsToEachChannelOfEachPatientValidation,
                                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                listOfFilepathsToRoiMaskOfEachPatientValidation,
                                namesForSavingSegmAndProbs = namesForSavingSegmAndProbs,
                                suffixForSegmAndProbsDict = suffixForSegmAndProbsDict,
                                # Hyper parameters
                                batchsize = batchsize_val_whole,
                                
                                #----Preprocessing------
                                pad_input_imgs = pad_input_imgs,
                                
                                #--------For FM visualisation---------
                                saveIndividualFmImagesForVisualisation = saveIndividualFmImagesForVisualisation,
                                saveMultidimensionalImageWithAllFms = saveMultidimensionalImageWithAllFms,
                                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer = indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                                namesForSavingFms = namesForSavingFms
                                )
        end_time_train = time.time()
        log.print3("TIMING: Training process lasted: {0:.1f}".format(end_time_train-start_time_train)+" secs.")
        
    except (Exception, KeyboardInterrupt) as e:
        # Any failure, or a Ctrl-C, aborts training: log it, tear the pool down and return 1.
        # The final-model save below is not reached on this path.
        log.print3("\n\n ERROR: Caught exception in do_training(): " + str(e) + "\n")
        log.print3( traceback.format_exc() )
        if worker_pool is not None:
            log.print3("Terminating worker pool.")
            worker_pool.terminate()
            worker_pool.join() # Will wait. A KeybInt will kill this (py3)
        return 1
    else:
        # Normal completion: close the pool and wait for the worker to finish.
        if worker_pool is not None:
            log.print3("Closing worker pool.")
            worker_pool.close()
            worker_pool.join()
    
    # Save the final trained model.
    filename_to_save_with = fileToSaveTrainedCnnModelTo + ".final." + datetimeNowAsStr()
    log.print3("Saving the final model at:" + str(filename_to_save_with))
    saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )
            
    log.print3("The whole do_training() function has finished.")
    return 0
コード例 #3
0
ファイル: training.py プロジェクト: Kamnitsask/deepmedic
def do_training(sessionTf,
                saver_all,
                cnn3d,
                trainer,
                log,
                
                fileToSaveTrainedCnnModelTo,
                
                val_on_samples_during_train,
                savePredictedSegmAndProbsDict,
                
                namesForSavingSegmAndProbs,
                suffixForSegmAndProbsDict,
                
                listOfFilepathsToEachChannelOfEachPatientTraining,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                
                listOfFilepathsToGtLabelsOfEachPatientTraining,
                providedGtForValidationBool,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                
                providedWeightMapsToSampleForEachCategoryTraining,
                paths_to_wmaps_per_sampl_cat_per_subj_train,
                providedWeightMapsToSampleForEachCategoryValidation,
                paths_to_wmaps_per_sampl_cat_per_subj_val,
                
                providedRoiMaskForTrainingBool,
                listOfFilepathsToRoiMaskOfEachPatientTraining, # Also needed for normalization-augmentation
                providedRoiMaskForValidationBool,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                
                n_epochs, # Every epoch the CNN model is saved.
                num_subepochs, # per epoch. Every subepoch Accuracy is reported
                max_n_cases_per_subep_train,  # Max num of subjects loaded every subepoch for segments extraction.
                n_samples_per_subep_train,
                n_samples_per_subep_val,
                num_parallel_proc_sampling, # -1: seq. 0: thread for sampling. >0: multiprocess sampling
                
                #-------Sampling Type---------
                samplingTypeInstanceTraining, # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                samplingTypeInstanceValidation,
                batchsize_train,
                batchsize_val_samples,
                batchsize_val_whole,
                
                #-------Preprocessing-----------
                padInputImagesBool,
                #-------Data Augmentation-------
                augm_params,
                
                # Validation
                val_on_whole_volumes,
                num_epochs_between_val_on_whole_volumes,
                
                #--------For FM visualisation---------
                saveIndividualFmImagesForVisualisation,
                saveMultidimensionalImageWithAllFms,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                namesForSavingFms,
                
                #-------- Others --------
                run_input_checks
                ):
    
    id_str = "[MAIN|PID:"+str(os.getpid())+"]"
    start_time_train = time.time()
    
    # I cannot pass cnn3d to the sampling function, because the pp module used to reload theano. 
    # This created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)
    
    args_for_sampling_train = ( log,
                                "train",
                                num_parallel_proc_sampling,
                                run_input_checks,
                                cnn3dWrapper,
                                max_n_cases_per_subep_train,
                                n_samples_per_subep_train,
                                samplingTypeInstanceTraining,
                                listOfFilepathsToEachChannelOfEachPatientTraining,
                                listOfFilepathsToGtLabelsOfEachPatientTraining,
                                providedRoiMaskForTrainingBool,
                                listOfFilepathsToRoiMaskOfEachPatientTraining,
                                providedWeightMapsToSampleForEachCategoryTraining,
                                paths_to_wmaps_per_sampl_cat_per_subj_train,
                                padInputImagesBool,
                                augm_params )
    args_for_sampling_val = (   log,
                                "val",
                                num_parallel_proc_sampling,
                                run_input_checks,
                                cnn3dWrapper,
                                max_n_cases_per_subep_train,
                                n_samples_per_subep_val,
                                samplingTypeInstanceValidation,
                                listOfFilepathsToEachChannelOfEachPatientValidation,
                                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                providedRoiMaskForValidationBool,
                                listOfFilepathsToRoiMaskOfEachPatientValidation,
                                providedWeightMapsToSampleForEachCategoryValidation,
                                paths_to_wmaps_per_sampl_cat_per_subj_val,
                                padInputImagesBool,
                                None ) # no augmentation in validation.
    
    sampling_job_submitted_train = False
    sampling_job_submitted_val = False
    # For parallel extraction of samples for next train/val while processing previous iteration.
    worker_pool = None
    if num_parallel_proc_sampling > -1 : # Use multiprocessing.
        worker_pool = ThreadPool(processes=1) # Or multiprocessing.Pool(...), same API.
    
    try:
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while model_num_epochs_trained < n_epochs :
            epoch = model_num_epochs_trained
            
            acc_monitor_for_ep_train = AccuracyOfEpochMonitorSegmentation(log, 0, model_num_epochs_trained, cnn3d.num_classes, num_subepochs)
            acc_monitor_for_ep_val = None if not val_on_samples_during_train else \
                                            AccuracyOfEpochMonitorSegmentation(log, 1, model_num_epochs_trained, cnn3d.num_classes, num_subepochs )
                                            
            val_on_whole_volumes_after_ep = False
            if val_on_whole_volumes and (model_num_epochs_trained+1) % num_epochs_between_val_on_whole_volumes == 0:
                val_on_whole_volumes_after_ep = True
                
                
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~\t Starting new Epoch! Epoch #"+str(epoch)+"/"+str(n_epochs)+" \t~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_time_ep = time.time()
            
            for subepoch in range(num_subepochs):
                log.print3("***************************************************************************************")
                log.print3("*******\t\t Starting new Subepoch: #"+str(subepoch)+"/"+str(num_subepochs)+" \t\t********")
                log.print3("***************************************************************************************")
                
                #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
                if val_on_samples_during_train :
                    if worker_pool is None: # Sequential processing.
                        log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [VALIDATION] will be done by main thread.")
                        [channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal] = getSampledDataAndLabelsForSubepoch( *args_for_sampling_val )
                    elif sampling_job_submitted_val : #It was done in parallel with the training of the previous epoch, just grab results.
                        [channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal] = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False
                    else : # Not previously submitted in case of first epoch or after a full-volumes validation.
                        assert subepoch == 0
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        [channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal] = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False

                    
                    #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                    if worker_pool is not None:
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True
                    
                    #------------------------------------DO VALIDATION--------------------------------
                    log.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-")
                    start_time_val_subep = time.time()
                    # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                    num_batches_val = len(channsOfSegmentsForSubepPerPathwayVal[0]) // batchsize_val_samples
                    trainOrValidateForSubepoch( log,
                                                sessionTf,
                                                "val",
                                                num_batches_val,
                                                batchsize_val_samples,
                                                cnn3d,
                                                subepoch,
                                                acc_monitor_for_ep_val,
                                                channsOfSegmentsForSubepPerPathwayVal,
                                                labelsForCentralOfSegmentsForSubepVal )
                    end_time_val_subep = time.time()
                    log.print3("TIMING: Validation on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_val_subep-start_time_val_subep)+" secs.")
                
                #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
                if worker_pool is None: # Sequential processing.
                    log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [TRAINING] will be done by main thread.")
                    [channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain] = getSampledDataAndLabelsForSubepoch( *args_for_sampling_train )
                elif sampling_job_submitted_train: # Sampling job should have been done in parallel with previous train/val. Just grab results.
                    [channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain] = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False
                else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                    assert subepoch == 0
                    log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                    parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                    [channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain] = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False

                        
                #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
                if worker_pool is not None and not (val_on_whole_volumes_after_ep and (subepoch == num_subepochs-1)):
                    if val_on_samples_during_train :
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        sampling_job_submitted_val = True
                    else :
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True
                
                #-------------------------------START TRAINING IN BATCHES------------------------------
                log.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-")
                start_time_train_subep = time.time()
                # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                num_batches_train = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // batchsize_train
                trainOrValidateForSubepoch( log,
                                            sessionTf,
                                            "train",
                                            num_batches_train,
                                            batchsize_train,
                                            cnn3d,
                                            subepoch,
                                            acc_monitor_for_ep_train,
                                            channsOfSegmentsForSubepPerPathwayTrain,
                                            labelsForCentralOfSegmentsForSubepTrain )
                end_time_train_subep = time.time()
                log.print3("TIMING: Training on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_train_subep-start_time_train_subep)+" secs.")
                
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~" )
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            
            if val_on_samples_during_train:
                acc_monitor_for_ep_val.reportMeanAccyracyOfEpoch()
            acc_monitor_for_ep_train.reportMeanAccyracyOfEpoch()
            
            mean_val_acc_of_ep = acc_monitor_for_ep_val.getMeanEmpiricalAccuracyOfEpoch() if val_on_samples_during_train else None
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep) # Updates LR schedule if needed, and increases number of epochs trained.
            model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
            
            del acc_monitor_for_ep_train; del acc_monitor_for_ep_val;
            
            log.print3("SAVING: Epoch #"+str(epoch)+" finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr()
            saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )
            
            end_time_ep = time.time()
            log.print3("TIMING: The whole Epoch #"+str(epoch)+" lasted: {0:.1f}".format(end_time_ep-start_time_ep)+" secs.")
            log.print3("~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~")
            
            
            if val_on_whole_volumes_after_ep:
                log.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #"+str(epoch)+"...***")
                
                res_code = inferenceWholeVolumes(
                                sessionTf,
                                cnn3d,
                                log,
                                "val",
                                savePredictedSegmAndProbsDict,
                                listOfFilepathsToEachChannelOfEachPatientValidation,
                                providedGtForValidationBool,
                                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                providedRoiMaskForValidationBool,
                                listOfFilepathsToRoiMaskOfEachPatientValidation,
                                namesForSavingSegmAndProbs = namesForSavingSegmAndProbs,
                                suffixForSegmAndProbsDict = suffixForSegmAndProbsDict,
                                # Hyper parameters
                                batchsize = batchsize_val_whole,
                                
                                #----Preprocessing------
                                padInputImagesBool=padInputImagesBool,
                                
                                #--------For FM visualisation---------
                                saveIndividualFmImagesForVisualisation=saveIndividualFmImagesForVisualisation,
                                saveMultidimensionalImageWithAllFms=saveMultidimensionalImageWithAllFms,
                                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer=indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                                namesForSavingFms=namesForSavingFms
                                )
        end_time_train = time.time()
        log.print3("TIMING: Training process lasted: {0:.1f}".format(end_time_train-start_time_train)+" secs.")
        
    except (Exception, KeyboardInterrupt) as e:
        log.print3("\n\n ERROR: Caught exception in do_training(): " + str(e) + "\n")
        log.print3( traceback.format_exc() )
        if worker_pool is not None:
            log.print3("Terminating worker pool.")
            worker_pool.terminate()
            worker_pool.join() # Will wait. A KeybInt will kill this (py3)
        return 1
    else:
        if worker_pool is not None:
            log.print3("Closing worker pool.")
            worker_pool.close()
            worker_pool.join()
    
    # Save the final trained model.
    filename_to_save_with = fileToSaveTrainedCnnModelTo + ".final." + datetimeNowAsStr()
    log.print3("Saving the final model at:" + str(filename_to_save_with))
    saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )
            
    log.print3("The whole do_training() function has finished.")
    return 0
コード例 #4
0
ファイル: testSession.py プロジェクト: jacketlin/CAST
    def run_session(self, *args):
        """Build the test-time CNN graph, obtain its weights and run inference.

        Args:
            *args: Exactly two positional values, ``(sess_device, model_params)``.
                ``sess_device`` is passed to ``Graph.device()`` and
                ``model_params.get_args_for_arch()`` supplies the arguments
                forwarded to ``Cnn3d.make_cnn_model``.

        Weights are restored from the location returned by
        ``self._params.get_path_to_load_model_from()``. When that location is
        ``None`` the user is consulted first and the variables under the
        "net" scope are then initialised instead of restored. Afterwards
        ``inferenceWholeVolumes`` is invoked with the session, the network and
        ``self._params.get_args_for_testing()``.
        """
        sess_device, model_params = args

        def emit(*lines):
            # Forward each line, in the given order, to the session logger.
            for line in lines:
                self._log.print3(line)

        # ---------------------------- Graph ---------------------------------
        test_graph = tf.Graph()
        with test_graph.as_default():
            with test_graph.device(sess_device):
                emit("=========== Making the CNN graph... ===============")
                network = Cnn3d()
                with tf.variable_scope("net"):
                    # Only the network itself is built here (no optimizer).
                    arch_args = model_params.get_args_for_arch()
                    network.make_cnn_model(*arch_args)

            emit(
                "=========== Compiling the Testing Function ============",
                "=======================================================\n",
            )
            network.setup_ops_n_feeds_to_test(
                self._log,
                self._params.indices_fms_per_pathtype_per_layer_to_save)
            # Saves/restores every variable of the graph; a saver restricted
            # to the "net" variables would suffice.
            saver = tf.train.Saver()

        # --------------------------- Session --------------------------------
        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False,
            device_count={'CPU': 999, 'GPU': 99},
        )
        with tf.Session(graph=test_graph, config=session_config) as session:
            model_path = self._params.get_path_to_load_model_from()

            if model_path is None:
                # No saved model given: ask the user whether to carry on with
                # a randomly initialised network (the original note says the
                # helper exits when the answer is no).
                self._ask_user_if_test_with_random()
                emit(
                    "",
                    "=========== Initializing network variables  ===============",
                )
                net_variables = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope="net")
                # .run() relies on the session made default by this `with`.
                tf.variables_initializer(var_list=net_variables).run()
                emit("Model variables were initialized.")
            else:
                emit(
                    "=========== Loading parameters from specified saved model ==============="
                )
                # A directory is resolved to its latest checkpoint; anything
                # else is used as the checkpoint name as-is.
                if os.path.isdir(model_path):
                    checkpoint = tf.train.latest_checkpoint(model_path)
                else:
                    checkpoint = model_path
                emit("Loading parameters from:" + str(checkpoint))
                try:
                    saver.restore(session, checkpoint)
                    emit("Parameters were loaded.")
                except Exception as e:
                    handle_exception_tf_restore(self._log, e)

            emit(
                "",
                "======================================================",
                "=========== Testing with the CNN model ===============",
                "======================================================\n",
            )

            inference_args = [session, network] + self._params.get_args_for_testing()
            # The returned code is kept in a local only; it is not used below.
            res_code = inferenceWholeVolumes(*inference_args)

        emit(
            "",
            "======================================================",
            "=========== Testing session finished =================",
            "======================================================",
        )