Esempio n. 1
0
    def run_session(self, *args):
        """Build the training graph, load or initialise variables, and train.

        The graph is built in three steps: the network under variable scope
        "net" (on the requested device), the trainer/optimizer under variable
        scope "trainer", and then the train/val/test ops. Three savers are
        created: one for all variables, one for the "net" variables and one
        for the "trainer" variables. A session is then opened, parameters are
        either restored from a checkpoint or freshly initialised (in which
        case an ".initial." checkpoint is saved), and `do_training` is called.

        Args:
            *args: exactly three positional values, unpacked as
                `(sess_device, model_params, reset_trainer)`:
                sess_device: device passed to `graph.device(...)` for the
                    network's ops.
                model_params: object providing `get_args_for_arch()`.
                reset_trainer: when loading from a checkpoint, if true the
                    trainer variables are re-initialised instead of restored.

        Returns:
            Whatever `do_training` returned. Previously this value was
            dropped and the method always returned None, so a failed
            training run could not be detected by the caller.
        """
        (sess_device, model_params, reset_trainer) = args

        graphTf = tf.Graph()

        with graphTf.as_default():
            # Explicit device assignment for the network's ops only.
            with graphTf.device(sess_device):
                self._log.print3(
                    "=========== Making the CNN graph... ===============")
                cnn3d = Cnn3d()
                with tf.variable_scope("net"):
                    # Creates the CNN graph, but not yet the optimizer's graph.
                    cnn3d.make_cnn_model(*model_params.get_args_for_arch())

            # No explicit device assignment for the rest. Original author's
            # note: the trainer has piecewise_constant that is only on cpu,
            # and so is the saver.
            with tf.variable_scope("trainer"):
                self._log.print3("=========== Building Trainer ===========\n")
                trainer = Trainer(*(self._params.get_args_for_trainer() +
                                    [cnn3d]))
                # Trainer and net connect here.
                trainer.create_optimizer(
                    *self._params.get_args_for_optimizer())

            # The below should not create any new tf.variables.
            self._log.print3(
                "=========== Compiling the Training Function ===========")
            self._log.print3(
                "=======================================================\n")
            cnn3d.setup_ops_n_feeds_to_train(
                self._log,
                trainer.get_total_cost(),
                trainer.get_param_updates_wrt_total_cost()  # list of ops
            )

            self._log.print3(
                "=========== Compiling the Validation Function =========")
            cnn3d.setup_ops_n_feeds_to_val(self._log)

            self._log.print3(
                "=========== Compiling the Testing Function ============")
            # For validation with full segmentation.
            cnn3d.setup_ops_n_feeds_to_test(
                self._log,
                self._params.indices_fms_per_pathtype_per_layer_to_save)

            # Create the savers.
            # saver_all: used during training for saving everything.
            saver_all = tf.train.Saver()
            # saver_net: used to load only the net's parameters.
            collection_vars_net = tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES, scope="net")
            saver_net = tf.train.Saver(var_list=collection_vars_net)
            # saver_trainer: used to load only the trainer's parameters.
            collection_vars_trainer = tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES, scope="trainer")
            saver_trainer = tf.train.Saver(var_list=collection_vars_trainer)

        # self._print_vars_in_collection(collection_vars_net, "net")
        # self._print_vars_in_collection(collection_vars_trainer, "trainer")

        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False,
                                        device_count={
                                            'CPU': 999,
                                            'GPU': 99
                                        })
        # Entering the session as a context manager makes it the default
        # session, which the `.run()` calls on the initializer ops rely on.
        with tf.Session(graph=graphTf, config=session_config) as sessionTf:
            # Load or initialize parameters.
            file_to_load_params_from = (
                self._params.get_path_to_load_model_from())
            if file_to_load_params_from is not None:  # Load params
                self._log.print3(
                    "=========== Loading parameters from specified saved model ==============="
                )
                # A directory means "use the latest checkpoint in it";
                # anything else is used as the checkpoint path directly.
                if os.path.isdir(file_to_load_params_from):
                    chkpt_fname = tf.train.latest_checkpoint(
                        file_to_load_params_from)
                else:
                    chkpt_fname = file_to_load_params_from
                self._log.print3("Loading checkpoint file:" + str(chkpt_fname))
                self._log.print3("Loading network parameters...")
                try:
                    saver_net.restore(sessionTf, chkpt_fname)
                    self._log.print3("Network parameters were loaded.")
                except Exception as e:
                    handle_exception_tf_restore(self._log, e)

                if not reset_trainer:
                    self._log.print3("Loading trainer parameters...")
                    saver_trainer.restore(sessionTf, chkpt_fname)
                    self._log.print3("Trainer parameters were loaded.")
                else:
                    self._log.print3(
                        "Reset of trainer parameters was requested. Re-initializing them..."
                    )
                    tf.variables_initializer(
                        var_list=collection_vars_trainer).run()
                    self._log.print3("Trainer parameters re-initialized.")
            else:
                self._log.print3(
                    "=========== Initializing network and trainer variables  ==============="
                )
                # Initialize the two collections separately rather than all
                # global variables at once, so that a variable that belongs
                # to neither scope surfaces as an error instead of being
                # silently initialised.
                tf.variables_initializer(var_list=collection_vars_net).run()
                tf.variables_initializer(
                    var_list=collection_vars_trainer).run()
                self._log.print3("All variables were initialized.")

                filename_to_save_with = (
                    self._params.filepath_to_save_models + ".initial." +
                    datetimeNowAsStr())
                self._log.print3("Saving the initial model at:" +
                                 str(filename_to_save_with))
                saver_all.save(sessionTf,
                               filename_to_save_with + ".model.ckpt",
                               write_meta_graph=False)

            self._log.print3("")
            self._log.print3(
                "=======================================================")
            self._log.print3(
                "============== Training the CNN model =================")
            self._log.print3(
                "=======================================================\n")

            res_code = do_training(
                *([sessionTf, saver_all, cnn3d, trainer] +
                  self._params.get_args_for_train_routine()))

        # Surface a failed training run instead of silently discarding the
        # code. `if res_code` also tolerates a routine that returns None.
        if res_code:
            self._log.print3(
                "ERROR: The training routine exited with code " +
                str(res_code) + ".")

        self._log.print3(
            "\n=======================================================")
        self._log.print3(
            "=========== Training session finished =================")
        self._log.print3(
            "=======================================================")
        return res_code
Esempio n. 2
0
def do_training(sessionTf,
                saver_all,
                cnn3d,
                trainer,
                log,
                
                fileToSaveTrainedCnnModelTo,
                
                val_on_samples_during_train,
                savePredictedSegmAndProbsDict,
                
                namesForSavingSegmAndProbs,
                suffixForSegmAndProbsDict,
                
                listOfFilepathsToEachChannelOfEachPatientTraining,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                
                listOfFilepathsToGtLabelsOfEachPatientTraining,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                
                paths_to_wmaps_per_sampl_cat_per_subj_train,
                paths_to_wmaps_per_sampl_cat_per_subj_val,
                
                listOfFilepathsToRoiMaskOfEachPatientTraining,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                
                n_epochs, # Every epoch the CNN model is saved.
                num_subepochs, # per epoch. Every subepoch Accuracy is reported
                max_n_cases_per_subep_train,  # Max num of subjects loaded every subepoch for segments extraction.
                n_samples_per_subep_train,
                n_samples_per_subep_val,
                num_parallel_proc_sampling, # -1: seq. 0: thread for sampling. >0: multiprocess sampling
                
                #-------Sampling Type---------
                samplingTypeInstanceTraining, # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                samplingTypeInstanceValidation,
                batchsize_train,
                batchsize_val_samples,
                batchsize_val_whole,
                
                #-------Preprocessing-----------
                pad_input_imgs,
                #-------Data Augmentation-------
                augm_img_prms,
                augm_sample_prms,
                
                # Validation
                val_on_whole_volumes,
                num_epochs_between_val_on_whole_volumes,
                
                #--------For FM visualisation---------
                saveIndividualFmImagesForVisualisation,
                saveMultidimensionalImageWithAllFms,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                namesForSavingFms,
                
                #-------- Others --------
                run_input_checks
                ):
    """Run the epoch/subepoch training loop and save checkpoints.

    The loop runs while the trainer's epoch counter (read from the session)
    is below `n_epochs`. Every epoch consists of `num_subepochs` subepochs.
    In each subepoch, if `val_on_samples_during_train` is set, validation
    samples are obtained and validated on, and then training samples are
    obtained and trained on. After each epoch the accuracy monitors report,
    the trainer's end-of-epoch updates run, and a checkpoint is saved through
    `saver_all`. If `val_on_whole_volumes` is set, every
    `num_epochs_between_val_on_whole_volumes` epochs `inferenceWholeVolumes`
    is called on the validation subjects. A ".final." checkpoint is saved
    when the loop completes normally.

    Sampling: if `num_parallel_proc_sampling > -1`, a single-worker
    `ThreadPool` is created and the samples for the next train/val step are
    requested from it while the current step runs. Otherwise sampling is done
    sequentially by the calling thread.

    Args:
        sessionTf, saver_all, cnn3d, trainer, log: session, saver, network,
            trainer and logger used throughout.
        fileToSaveTrainedCnnModelTo: path prefix for saved checkpoints.
        n_epochs, num_subepochs: loop bounds, as described above.
        batchsize_train, batchsize_val_samples: used to derive the number of
            batches from the number of samples actually extracted.
        batchsize_val_whole: batch size forwarded to `inferenceWholeVolumes`.
        num_epochs_between_val_on_whole_volumes: used as a modulus, so it
            must be non-zero when `val_on_whole_volumes` is true.
        Remaining arguments are forwarded unchanged to
            `getSampledDataAndLabelsForSubepoch` and/or
            `inferenceWholeVolumes`.

    Returns:
        0 on success; 1 if an Exception or KeyboardInterrupt was caught, in
        which case the error is logged, the worker pool (if any) is
        terminated, and no final model is saved.
    """
    
    id_str = "[MAIN|PID:"+str(os.getpid())+"]"
    start_time_train = time.time()
    
    # I cannot pass cnn3d to the sampling function, because the pp module used to reload theano. 
    # This created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)
    
    args_for_sampling_train = ( log,
                                "train",
                                num_parallel_proc_sampling,
                                run_input_checks,
                                cnn3dWrapper,
                                max_n_cases_per_subep_train,
                                n_samples_per_subep_train,
                                samplingTypeInstanceTraining,
                                listOfFilepathsToEachChannelOfEachPatientTraining,
                                listOfFilepathsToGtLabelsOfEachPatientTraining,
                                listOfFilepathsToRoiMaskOfEachPatientTraining,
                                paths_to_wmaps_per_sampl_cat_per_subj_train,
                                pad_input_imgs,
                                augm_img_prms,
                                augm_sample_prms )
    args_for_sampling_val = (   log,
                                "val",
                                num_parallel_proc_sampling,
                                run_input_checks,
                                cnn3dWrapper,
                                max_n_cases_per_subep_train,
                                n_samples_per_subep_val,
                                samplingTypeInstanceValidation,
                                listOfFilepathsToEachChannelOfEachPatientValidation,
                                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                listOfFilepathsToRoiMaskOfEachPatientValidation,
                                paths_to_wmaps_per_sampl_cat_per_subj_val,
                                pad_input_imgs,
                                None, # no augmentation in val.
                                None ) # no augmentation in val.
    
    # These flags record whether a prefetch job is outstanding on the pool,
    # i.e. whether `.get()` can be called on the corresponding async result.
    sampling_job_submitted_train = False
    sampling_job_submitted_val = False
    # For parallel extraction of samples for next train/val while processing previous iteration.
    worker_pool = None
    if num_parallel_proc_sampling > -1 : # Use multiprocessing.
        worker_pool = ThreadPool(processes=1) # Or multiprocessing.Pool(...), same API.
    
    try:
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while model_num_epochs_trained < n_epochs :
            epoch = model_num_epochs_trained
            
            acc_monitor_for_ep_train = AccuracyOfEpochMonitorSegmentation(log, 0, model_num_epochs_trained, cnn3d.num_classes, num_subepochs)
            acc_monitor_for_ep_val = None if not val_on_samples_during_train else \
                                            AccuracyOfEpochMonitorSegmentation(log, 1, model_num_epochs_trained, cnn3d.num_classes, num_subepochs )
                                            
            val_on_whole_volumes_after_ep = False
            if val_on_whole_volumes and (model_num_epochs_trained+1) % num_epochs_between_val_on_whole_volumes == 0:
                val_on_whole_volumes_after_ep = True
            
            # NOTE(review): assumes trainer.run_updates_end_of_ep() advances the
            # epoch counter by one. If it does not, the next epoch's first
            # subepoch falls back to sampling on demand (the `assert subepoch == 0`
            # branches below), so skipping the last prefetch remains safe.
            is_final_epoch = (model_num_epochs_trained + 1) >= n_epochs
                
                
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~\t Starting new Epoch! Epoch #"+str(epoch)+"/"+str(n_epochs)+" \t~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_time_ep = time.time()
            
            for subepoch in range(num_subepochs):
                log.print3("***************************************************************************************")
                log.print3("*******\t\t Starting new Subepoch: #"+str(subepoch)+"/"+str(num_subepochs)+" \t\t********")
                log.print3("***************************************************************************************")
                
                #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
                if val_on_samples_during_train :
                    if worker_pool is None: # Sequential processing.
                        log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [VALIDATION] will be done by main thread.")
                        (channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal) = getSampledDataAndLabelsForSubepoch( *args_for_sampling_val )
                    elif sampling_job_submitted_val : #It was done in parallel with the training of the previous epoch, just grab results.
                        (channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal) = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False
                    else : # Not previously submitted in case of first epoch or after a full-volumes validation.
                        assert subepoch == 0
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        # Submit to the pool and wait for the result immediately.
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        (channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal) = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False

                    
                    #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                    # This job is consumed later in this same subepoch, so it
                    # overlaps with the validation that follows.
                    if worker_pool is not None:
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True
                    
                    #------------------------------------DO VALIDATION--------------------------------
                    log.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-")
                    start_time_val_subep = time.time()
                    # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                    num_batches_val = len(channsOfSegmentsForSubepPerPathwayVal[0]) // batchsize_val_samples
                    trainOrValidateForSubepoch( log,
                                                sessionTf,
                                                "val",
                                                num_batches_val,
                                                batchsize_val_samples,
                                                cnn3d,
                                                subepoch,
                                                acc_monitor_for_ep_val,
                                                channsOfSegmentsForSubepPerPathwayVal,
                                                labelsForCentralOfSegmentsForSubepVal )
                    end_time_val_subep = time.time()
                    log.print3("TIMING: Validation on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_val_subep-start_time_val_subep)+" secs.")
                
                #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
                if worker_pool is None: # Sequential processing.
                    log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [TRAINING] will be done by main thread.")
                    (channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain) = getSampledDataAndLabelsForSubepoch( *args_for_sampling_train )
                elif sampling_job_submitted_train: # Sampling job should have been done in parallel with previous train/val. Just grab results.
                    (channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain) = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False
                else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                    assert subepoch == 0
                    log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                    parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                    (channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain) = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False

                        
                #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
                # No prefetch after the last subepoch when either (a) a whole-volume
                # validation follows, or (b) this is the final epoch. In case (b) the
                # result would never be consumed and the closing worker_pool.join()
                # would block until the useless sampling job finished.
                is_last_subepoch = (subepoch == num_subepochs-1)
                skip_prefetch = is_last_subepoch and (val_on_whole_volumes_after_ep or is_final_epoch)
                if worker_pool is not None and not skip_prefetch:
                    if val_on_samples_during_train :
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        sampling_job_submitted_val = True
                    else :
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True
                
                #-------------------------------START TRAINING IN BATCHES------------------------------
                log.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-")
                start_time_train_subep = time.time()
                # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                num_batches_train = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // batchsize_train
                trainOrValidateForSubepoch( log,
                                            sessionTf,
                                            "train",
                                            num_batches_train,
                                            batchsize_train,
                                            cnn3d,
                                            subepoch,
                                            acc_monitor_for_ep_train,
                                            channsOfSegmentsForSubepPerPathwayTrain,
                                            labelsForCentralOfSegmentsForSubepTrain )
                end_time_train_subep = time.time()
                log.print3("TIMING: Training on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_train_subep-start_time_train_subep)+" secs.")
                
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~" )
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            
            if val_on_samples_during_train:
                acc_monitor_for_ep_val.reportMeanAccyracyOfEpoch()
            acc_monitor_for_ep_train.reportMeanAccyracyOfEpoch()
            
            mean_val_acc_of_ep = acc_monitor_for_ep_val.getMeanEmpiricalAccuracyOfEpoch() if val_on_samples_during_train else None
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep) # Updates LR schedule if needed, and increases number of epochs trained.
            # Re-read the counter from the session; it drives the while condition.
            model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
            
            del acc_monitor_for_ep_train; del acc_monitor_for_ep_val;
            
            log.print3("SAVING: Epoch #"+str(epoch)+" finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr()
            saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )
            
            end_time_ep = time.time()
            log.print3("TIMING: The whole Epoch #"+str(epoch)+" lasted: {0:.1f}".format(end_time_ep-start_time_ep)+" secs.")
            log.print3("~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~")
            
            
            if val_on_whole_volumes_after_ep:
                log.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #"+str(epoch)+"...***")
                
                # NOTE(review): the returned code is not inspected here.
                res_code = inferenceWholeVolumes(
                                sessionTf,
                                cnn3d,
                                log,
                                "val",
                                savePredictedSegmAndProbsDict,
                                listOfFilepathsToEachChannelOfEachPatientValidation,
                                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                listOfFilepathsToRoiMaskOfEachPatientValidation,
                                namesForSavingSegmAndProbs = namesForSavingSegmAndProbs,
                                suffixForSegmAndProbsDict = suffixForSegmAndProbsDict,
                                # Hyper parameters
                                batchsize = batchsize_val_whole,
                                
                                #----Preprocessing------
                                pad_input_imgs = pad_input_imgs,
                                
                                #--------For FM visualisation---------
                                saveIndividualFmImagesForVisualisation = saveIndividualFmImagesForVisualisation,
                                saveMultidimensionalImageWithAllFms = saveMultidimensionalImageWithAllFms,
                                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer = indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                                namesForSavingFms = namesForSavingFms
                                )
        end_time_train = time.time()
        log.print3("TIMING: Training process lasted: {0:.1f}".format(end_time_train-start_time_train)+" secs.")
        
    except (Exception, KeyboardInterrupt) as e:
        # KeyboardInterrupt is listed explicitly because it does not derive
        # from Exception; both paths log, tear the pool down, and return 1.
        log.print3("\n\n ERROR: Caught exception in do_training(): " + str(e) + "\n")
        log.print3( traceback.format_exc() )
        if worker_pool is not None:
            log.print3("Terminating worker pool.")
            worker_pool.terminate()
            worker_pool.join() # Will wait. A KeybInt will kill this (py3)
        return 1
    else:
        if worker_pool is not None:
            log.print3("Closing worker pool.")
            worker_pool.close()
            worker_pool.join()
    
    # Save the final trained model.
    filename_to_save_with = fileToSaveTrainedCnnModelTo + ".final." + datetimeNowAsStr()
    log.print3("Saving the final model at:" + str(filename_to_save_with))
    saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )
            
    log.print3("The whole do_training() function has finished.")
    return 0
Esempio n. 3
0
def do_training(
        sessionTf,
        saver_all,
        cnn3d,
        trainer,
        log,
        fileToSaveTrainedCnnModelTo,
        performValidationOnSamplesDuringTrainingProcessBool,
        savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation,
        listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice,
        listOfFilepathsToEachChannelOfEachPatientTraining,
        listOfFilepathsToEachChannelOfEachPatientValidation,
        listOfFilepathsToGtLabelsOfEachPatientTraining,
        providedGtForValidationBool,
        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
        providedWeightMapsToSampleForEachCategoryTraining,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
        providedWeightMapsToSampleForEachCategoryValidation,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
        providedRoiMaskForTrainingBool,
        listOfFilepathsToRoiMaskOfEachPatientTraining,  # Also needed for normalization-augmentation
        providedRoiMaskForValidationBool,
        listOfFilepathsToRoiMaskOfEachPatientValidation,
        n_epochs,  # Every epoch the CNN model is saved.
        number_of_subepochs,  # per epoch. Every subepoch Accuracy is reported
        maxNumSubjectsLoadedPerSubepoch,  # Max num of cases loaded every subepoch for segments extraction. The more, the longer loading.
        imagePartsLoadedInGpuPerSubepoch,
        imagePartsLoadedInGpuPerSubepochValidation,

        #-------Sampling Type---------
        samplingTypeInstanceTraining,  # Instance of the deepmedic/samplingType.SamplingType class for training and validation
        samplingTypeInstanceValidation,

        #-------Preprocessing-----------
        padInputImagesBool,
        #-------Data Augmentation-------
        doIntAugm_shiftMuStd_multiMuStd,
        reflectImageWithHalfProbDuringTraining,
        useSameSubChannelsAsSingleScale,
        listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,  # deprecated, not supported
        listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,  # deprecated, not supported

        # Validation
    performFullInferenceOnValidationImagesEveryFewEpochsBool,  #Even if not providedGtForValidationBool, inference will be performed if this == True, to save the results, eg for visual.
        everyThatManyEpochsComputeDiceOnTheFullValidationImages,  # Should not be == 0, except if performFullInferenceOnValidationImagesEveryFewEpochsBool == False

        #--------For FM visualisation---------
    saveIndividualFmImagesForVisualisation,
        saveMultidimensionalImageWithAllFms,
        indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
        listOfNamesToGiveToFmVisualisationsIfSaving,

        #-------- Others --------
        run_input_checks):

    start_training_time = time.time()

    # Used because I cannot pass cnn3d to the sampling function.
    #This is because the parallel process used to load theano again. And created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)

    #---------To run PARALLEL the extraction of parts for the next subepoch---
    ppservers = ()  # tuple of all parallel python servers to connect with
    job_server = pp.Server(
        ncpus=1, ppservers=ppservers
    )  # Creates jobserver with automatically detected number of workers

    tupleWithParametersForTraining = (
        log, "train", run_input_checks, cnn3dWrapper,
        maxNumSubjectsLoadedPerSubepoch, imagePartsLoadedInGpuPerSubepoch,
        samplingTypeInstanceTraining,
        listOfFilepathsToEachChannelOfEachPatientTraining,
        listOfFilepathsToGtLabelsOfEachPatientTraining,
        providedRoiMaskForTrainingBool,
        listOfFilepathsToRoiMaskOfEachPatientTraining,
        providedWeightMapsToSampleForEachCategoryTraining,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
        useSameSubChannelsAsSingleScale,
        listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,
        padInputImagesBool, doIntAugm_shiftMuStd_multiMuStd,
        reflectImageWithHalfProbDuringTraining)
    tupleWithParametersForValidation = (
        log,
        "val",
        run_input_checks,
        cnn3dWrapper,
        maxNumSubjectsLoadedPerSubepoch,
        imagePartsLoadedInGpuPerSubepochValidation,
        samplingTypeInstanceValidation,
        listOfFilepathsToEachChannelOfEachPatientValidation,
        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
        providedRoiMaskForValidationBool,
        listOfFilepathsToRoiMaskOfEachPatientValidation,
        providedWeightMapsToSampleForEachCategoryValidation,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
        useSameSubChannelsAsSingleScale,
        listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
        padInputImagesBool,
        [0, -1, -1,
         -1],  #don't perform intensity-augmentation during validation.
        [0, 0, 0]  #don't perform reflection-augmentation during validation.
    )
    tupleWithLocalFunctionsThatWillBeCalledByTheMainJob = ()
    tupleWithModulesToImportWhichAreUsedByTheJobFunctions = (
        "from __future__ import absolute_import, print_function, division",
        "time", "numpy as np",
        "from deepmedic.dataManagement.sampling import *")
    boolItIsTheVeryFirstSubepochOfThisProcess = True  #to know so that in the very first I sequencially load the data for it.
    #------End for parallel------

    model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(
        session=sessionTf)
    while model_num_epochs_trained < n_epochs:
        epoch = model_num_epochs_trained

        trainingAccuracyMonitorForEpoch = AccuracyOfEpochMonitorSegmentation(
            log, 0, model_num_epochs_trained, cnn3d.num_classes,
            number_of_subepochs)
        validationAccuracyMonitorForEpoch = None if not performValidationOnSamplesDuringTrainingProcessBool else \
                                        AccuracyOfEpochMonitorSegmentation(log, 1, model_num_epochs_trained, cnn3d.num_classes, number_of_subepochs )

        log.print3(
            "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
        )
        log.print3("~~~~~~~~~~~~~~~~~~~~Starting new Epoch! Epoch #" +
                   str(epoch) + "/" + str(n_epochs) +
                   " ~~~~~~~~~~~~~~~~~~~~~~~~~")
        log.print3(
            "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
        )
        start_epoch_time = time.time()

        for subepoch in range(
                number_of_subepochs
        ):  #per subepoch I randomly load some images in the gpu. Random order.
            log.print3(
                "**************************************************************************************************"
            )
            log.print3("************* Starting new Subepoch: #" +
                       str(subepoch) + "/" + str(number_of_subepochs) +
                       " *************")
            log.print3(
                "**************************************************************************************************"
            )

            #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------

            if performValidationOnSamplesDuringTrainingProcessBool:
                if boolItIsTheVeryFirstSubepochOfThisProcess:
                    [
                        channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal
                    ] = getSampledDataAndLabelsForSubepoch(
                        log,
                        "val",
                        run_input_checks,
                        cnn3dWrapper,
                        maxNumSubjectsLoadedPerSubepoch,
                        imagePartsLoadedInGpuPerSubepochValidation,
                        samplingTypeInstanceValidation,
                        listOfFilepathsToEachChannelOfEachPatientValidation,
                        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                        providedRoiMaskForValidationBool,
                        listOfFilepathsToRoiMaskOfEachPatientValidation,
                        providedWeightMapsToSampleForEachCategoryValidation,
                        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
                        useSameSubChannelsAsSingleScale,
                        listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
                        padInputImagesBool,
                        doIntAugm_shiftMuStd_multiMuStd=[False, [], []],
                        reflectImageWithHalfProbDuringTraining=[0, 0, 0])
                    boolItIsTheVeryFirstSubepochOfThisProcess = False
                else:  #It was done in parallel with the training of the previous epoch, just grab the results...
                    [
                        channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal
                    ] = parallelJobToGetDataForNextValidation(
                    )  #fromParallelProcessing that had started from last loop when it was submitted.

                # Below is computed with number of extracted samples, in case I dont manage to extract as many as I wanted initially.
                numberOfBatchesValidation = len(
                    channsOfSegmentsForSubepPerPathwayVal[0]
                ) // cnn3d.batchSize["val"]

                #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                #submit the parallel job
                log.print3(
                    "PARALLEL: Before Validation in subepoch #" +
                    str(subepoch) +
                    ", the parallel job for extracting Segments for the next Training is submitted."
                )
                parallelJobToGetDataForNextTraining = job_server.submit(
                    getSampledDataAndLabelsForSubepoch,  #local function to call and execute in parallel.
                    tupleWithParametersForTraining,  #tuple with the arguments required
                    tupleWithLocalFunctionsThatWillBeCalledByTheMainJob,  #tuple of local functions that I need to call
                    tupleWithModulesToImportWhichAreUsedByTheJobFunctions
                )  #tuple of the external modules that I need, of which I am calling functions (not the mods of the ext-functions).

                #------------------------------------DO VALIDATION--------------------------------
                log.print3(
                    "-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-"
                )
                start_validationForSubepoch_time = time.time()

                doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(
                    log,
                    sessionTf,
                    "val",
                    numberOfBatchesValidation,  # Computed by the number of extracted samples. So, adapts.
                    cnn3d,
                    subepoch,
                    validationAccuracyMonitorForEpoch,
                    channsOfSegmentsForSubepPerPathwayVal,
                    labelsForCentralOfSegmentsForSubepVal)

                end_validationForSubepoch_time = time.time()
                log.print3(
                    "TIMING: Validating on the batches of this subepoch #" +
                    str(subepoch) + " took time: " +
                    str(end_validationForSubepoch_time -
                        start_validationForSubepoch_time) + "(s)")

            #-------------------END OF THE VALIDATION-DURING-TRAINING-LOOP-------------------------

            #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
            if (not performValidationOnSamplesDuringTrainingProcessBool
                ) and boolItIsTheVeryFirstSubepochOfThisProcess:
                [
                    channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain
                ] = getSampledDataAndLabelsForSubepoch(
                    log, "train", run_input_checks, cnn3dWrapper,
                    maxNumSubjectsLoadedPerSubepoch,
                    imagePartsLoadedInGpuPerSubepoch,
                    samplingTypeInstanceTraining,
                    listOfFilepathsToEachChannelOfEachPatientTraining,
                    listOfFilepathsToGtLabelsOfEachPatientTraining,
                    providedRoiMaskForTrainingBool,
                    listOfFilepathsToRoiMaskOfEachPatientTraining,
                    providedWeightMapsToSampleForEachCategoryTraining,
                    forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
                    useSameSubChannelsAsSingleScale,
                    listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,
                    padInputImagesBool, doIntAugm_shiftMuStd_multiMuStd,
                    reflectImageWithHalfProbDuringTraining)
                boolItIsTheVeryFirstSubepochOfThisProcess = False
            else:
                #It was done in parallel with the validation (or with previous training iteration, in case I am not performing validation).
                [
                    channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain
                ] = parallelJobToGetDataForNextTraining(
                )  #fromParallelProcessing that had started from last loop when it was submitted.

            numberOfBatchesTraining = len(
                channsOfSegmentsForSubepPerPathwayTrain[0]
            ) // cnn3d.batchSize[
                "train"]  #Computed with number of extracted samples, in case I dont manage to extract as many as I wanted initially.

            #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
            if performValidationOnSamplesDuringTrainingProcessBool:
                #submit the parallel job
                log.print3(
                    "PARALLEL: Before Training in subepoch #" + str(subepoch) +
                    ", submitting the parallel job for extracting Segments for the next Validation."
                )
                parallelJobToGetDataForNextValidation = job_server.submit(
                    getSampledDataAndLabelsForSubepoch,  #local function to call and execute in parallel.
                    tupleWithParametersForValidation,  #tuple with the arguments required
                    tupleWithLocalFunctionsThatWillBeCalledByTheMainJob,  #tuple of local functions that I need to call
                    tupleWithModulesToImportWhichAreUsedByTheJobFunctions
                )  #tuple of the external modules that I need, of which I am calling functions (not the mods of the ext-functions).
            else:  #extract in parallel the samples for the next subepoch's training.
                log.print3(
                    "PARALLEL: Before Training in subepoch #" + str(subepoch) +
                    ", submitting the parallel job for extracting Segments for the next Training."
                )
                parallelJobToGetDataForNextTraining = job_server.submit(
                    getSampledDataAndLabelsForSubepoch,  #local function to call and execute in parallel.
                    tupleWithParametersForTraining,  #tuple with the arguments required
                    tupleWithLocalFunctionsThatWillBeCalledByTheMainJob,  #tuple of local functions that I need to call
                    tupleWithModulesToImportWhichAreUsedByTheJobFunctions
                )  #tuple of the external modules that I need, of which I am calling

            #-------------------------------START TRAINING IN BATCHES------------------------------
            log.print3(
                "-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-"
            )
            start_trainingForSubepoch_time = time.time()

            doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(
                log, sessionTf, "train", numberOfBatchesTraining, cnn3d,
                subepoch, trainingAccuracyMonitorForEpoch,
                channsOfSegmentsForSubepPerPathwayTrain,
                labelsForCentralOfSegmentsForSubepTrain)

            end_trainingForSubepoch_time = time.time()
            log.print3("TIMING: Training on the batches of this subepoch #" +
                       str(subepoch) + " took time: " +
                       str(end_trainingForSubepoch_time -
                           start_trainingForSubepoch_time) + "(s)")

        log.print3(
            "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
        )
        log.print3(
            "~~~~~~~~~~~~~~~~~~ Epoch #" + str(epoch) +
            " finished. Reporting Accuracy over whole epoch. ~~~~~~~~~~~~~~~~~~"
        )
        log.print3(
            "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
        )

        if performValidationOnSamplesDuringTrainingProcessBool:
            validationAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch()
        trainingAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch()

        mean_val_acc_of_ep = validationAccuracyMonitorForEpoch.getMeanEmpiricalAccuracyOfEpoch(
        ) if performValidationOnSamplesDuringTrainingProcessBool else None
        trainer.run_updates_end_of_ep(
            log, sessionTf, mean_val_acc_of_ep
        )  # Updates LR schedule if needed, and increases number of epochs trained.
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(
            session=sessionTf)

        del trainingAccuracyMonitorForEpoch
        del validationAccuracyMonitorForEpoch
        #================== Everything for epoch has finished. =======================

        log.print3("SAVING: Epoch #" + str(epoch) +
                   " finished. Saving CNN model.")
        filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr(
        )
        saver_all.save(sessionTf,
                       filename_to_save_with + ".model.ckpt",
                       write_meta_graph=False)

        end_epoch_time = time.time()
        log.print3("TIMING: The whole Epoch #" + str(epoch) + " took time: " +
                   str(end_epoch_time - start_epoch_time) + "(s)")
        log.print3(
            "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
        )

        if performFullInferenceOnValidationImagesEveryFewEpochsBool and (
                model_num_epochs_trained != 0
        ) and (model_num_epochs_trained %
               everyThatManyEpochsComputeDiceOnTheFullValidationImages == 0):
            log.print3(
                "***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #"
                + str(epoch) + "...***")

            performInferenceOnWholeVolumes(
                sessionTf,
                cnn3d,
                log,
                "val",
                savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                providedGtForValidationBool,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                providedRoiMaskForValidationBool,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                listOfNamesToGiveToPredictionsIfSavingResults="Placeholder" if
                not savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation
                else
                listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice,

                #----Preprocessing------
                padInputImagesBool=padInputImagesBool,

                #for the cnn extension
                useSameSubChannelsAsSingleScale=useSameSubChannelsAsSingleScale,
                listOfFilepathsToEachSubsampledChannelOfEachPatient=
                listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,

                #--------For FM visualisation---------
                saveIndividualFmImagesForVisualisation=
                saveIndividualFmImagesForVisualisation,
                saveMultidimensionalImageWithAllFms=
                saveMultidimensionalImageWithAllFms,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer=
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                listOfNamesToGiveToFmVisualisationsIfSaving=
                listOfNamesToGiveToFmVisualisationsIfSaving)

    end_training_time = time.time()
    log.print3("TIMING: Training process took time: " +
               str(end_training_time - start_training_time) + "(s)")
    log.print3("The whole do_training() function has finished.")
Esempio n. 4
0
def do_training(sessionTf,
                saver_all,
                cnn3d,
                trainer,
                log,
                
                fileToSaveTrainedCnnModelTo,
                
                val_on_samples_during_train,
                savePredictedSegmAndProbsDict,
                
                namesForSavingSegmAndProbs,
                suffixForSegmAndProbsDict,
                
                listOfFilepathsToEachChannelOfEachPatientTraining,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                
                listOfFilepathsToGtLabelsOfEachPatientTraining,
                providedGtForValidationBool,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                
                providedWeightMapsToSampleForEachCategoryTraining,
                paths_to_wmaps_per_sampl_cat_per_subj_train,
                providedWeightMapsToSampleForEachCategoryValidation,
                paths_to_wmaps_per_sampl_cat_per_subj_val,
                
                providedRoiMaskForTrainingBool,
                listOfFilepathsToRoiMaskOfEachPatientTraining, # Also needed for normalization-augmentation
                providedRoiMaskForValidationBool,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                
                n_epochs, # Every epoch the CNN model is saved.
                num_subepochs, # per epoch. Every subepoch Accuracy is reported
                max_n_cases_per_subep_train,  # Max num of subjects loaded every subepoch for segments extraction.
                n_samples_per_subep_train,
                n_samples_per_subep_val,
                num_parallel_proc_sampling, # -1: seq. 0: thread for sampling. >0: multiprocess sampling
                
                #-------Sampling Type---------
                samplingTypeInstanceTraining, # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                samplingTypeInstanceValidation,
                batchsize_train,
                batchsize_val_samples,
                batchsize_val_whole,
                
                #-------Preprocessing-----------
                padInputImagesBool,
                #-------Data Augmentation-------
                augm_params,
                
                # Validation
                val_on_whole_volumes,
                num_epochs_between_val_on_whole_volumes,
                
                #--------For FM visualisation---------
                saveIndividualFmImagesForVisualisation,
                saveMultidimensionalImageWithAllFms,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                namesForSavingFms,
                
                #-------- Others --------
                run_input_checks
                ):
    
    id_str = "[MAIN|PID:"+str(os.getpid())+"]"
    start_time_train = time.time()
    
    # I cannot pass cnn3d to the sampling function, because the pp module used to reload theano. 
    # This created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)
    
    args_for_sampling_train = ( log,
                                "train",
                                num_parallel_proc_sampling,
                                run_input_checks,
                                cnn3dWrapper,
                                max_n_cases_per_subep_train,
                                n_samples_per_subep_train,
                                samplingTypeInstanceTraining,
                                listOfFilepathsToEachChannelOfEachPatientTraining,
                                listOfFilepathsToGtLabelsOfEachPatientTraining,
                                providedRoiMaskForTrainingBool,
                                listOfFilepathsToRoiMaskOfEachPatientTraining,
                                providedWeightMapsToSampleForEachCategoryTraining,
                                paths_to_wmaps_per_sampl_cat_per_subj_train,
                                padInputImagesBool,
                                augm_params )
    args_for_sampling_val = (   log,
                                "val",
                                num_parallel_proc_sampling,
                                run_input_checks,
                                cnn3dWrapper,
                                max_n_cases_per_subep_train,
                                n_samples_per_subep_val,
                                samplingTypeInstanceValidation,
                                listOfFilepathsToEachChannelOfEachPatientValidation,
                                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                providedRoiMaskForValidationBool,
                                listOfFilepathsToRoiMaskOfEachPatientValidation,
                                providedWeightMapsToSampleForEachCategoryValidation,
                                paths_to_wmaps_per_sampl_cat_per_subj_val,
                                padInputImagesBool,
                                None ) # no augmentation in validation.
    
    sampling_job_submitted_train = False
    sampling_job_submitted_val = False
    # For parallel extraction of samples for next train/val while processing previous iteration.
    worker_pool = None
    if num_parallel_proc_sampling > -1 : # Use multiprocessing.
        worker_pool = ThreadPool(processes=1) # Or multiprocessing.Pool(...), same API.
    
    try:
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while model_num_epochs_trained < n_epochs :
            epoch = model_num_epochs_trained
            
            acc_monitor_for_ep_train = AccuracyOfEpochMonitorSegmentation(log, 0, model_num_epochs_trained, cnn3d.num_classes, num_subepochs)
            acc_monitor_for_ep_val = None if not val_on_samples_during_train else \
                                            AccuracyOfEpochMonitorSegmentation(log, 1, model_num_epochs_trained, cnn3d.num_classes, num_subepochs )
                                            
            val_on_whole_volumes_after_ep = False
            if val_on_whole_volumes and (model_num_epochs_trained+1) % num_epochs_between_val_on_whole_volumes == 0:
                val_on_whole_volumes_after_ep = True
                
                
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~\t Starting new Epoch! Epoch #"+str(epoch)+"/"+str(n_epochs)+" \t~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_time_ep = time.time()
            
            for subepoch in range(num_subepochs):
                log.print3("***************************************************************************************")
                log.print3("*******\t\t Starting new Subepoch: #"+str(subepoch)+"/"+str(num_subepochs)+" \t\t********")
                log.print3("***************************************************************************************")
                
                #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
                if val_on_samples_during_train :
                    if worker_pool is None: # Sequential processing.
                        log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [VALIDATION] will be done by main thread.")
                        [channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal] = getSampledDataAndLabelsForSubepoch( *args_for_sampling_val )
                    elif sampling_job_submitted_val : #It was done in parallel with the training of the previous epoch, just grab results.
                        [channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal] = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False
                    else : # Not previously submitted in case of first epoch or after a full-volumes validation.
                        assert subepoch == 0
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        [channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal] = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False

                    
                    #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                    if worker_pool is not None:
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True
                    
                    #------------------------------------DO VALIDATION--------------------------------
                    log.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-")
                    start_time_val_subep = time.time()
                    # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                    num_batches_val = len(channsOfSegmentsForSubepPerPathwayVal[0]) // batchsize_val_samples
                    trainOrValidateForSubepoch( log,
                                                sessionTf,
                                                "val",
                                                num_batches_val,
                                                batchsize_val_samples,
                                                cnn3d,
                                                subepoch,
                                                acc_monitor_for_ep_val,
                                                channsOfSegmentsForSubepPerPathwayVal,
                                                labelsForCentralOfSegmentsForSubepVal )
                    end_time_val_subep = time.time()
                    log.print3("TIMING: Validation on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_val_subep-start_time_val_subep)+" secs.")
                
                #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
                if worker_pool is None: # Sequential processing.
                    log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [TRAINING] will be done by main thread.")
                    [channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain] = getSampledDataAndLabelsForSubepoch( *args_for_sampling_train )
                elif sampling_job_submitted_train: # Sampling job should have been done in parallel with previous train/val. Just grab results.
                    [channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain] = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False
                else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                    assert subepoch == 0
                    log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                    parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                    [channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain] = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False

                        
                #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
                if worker_pool is not None and not (val_on_whole_volumes_after_ep and (subepoch == num_subepochs-1)):
                    if val_on_samples_during_train :
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        sampling_job_submitted_val = True
                    else :
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True
                
                #-------------------------------START TRAINING IN BATCHES------------------------------
                log.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-")
                start_time_train_subep = time.time()
                # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                num_batches_train = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // batchsize_train
                trainOrValidateForSubepoch( log,
                                            sessionTf,
                                            "train",
                                            num_batches_train,
                                            batchsize_train,
                                            cnn3d,
                                            subepoch,
                                            acc_monitor_for_ep_train,
                                            channsOfSegmentsForSubepPerPathwayTrain,
                                            labelsForCentralOfSegmentsForSubepTrain )
                end_time_train_subep = time.time()
                log.print3("TIMING: Training on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_train_subep-start_time_train_subep)+" secs.")
                
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~" )
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            
            if val_on_samples_during_train:
                acc_monitor_for_ep_val.reportMeanAccyracyOfEpoch()
            acc_monitor_for_ep_train.reportMeanAccyracyOfEpoch()
            
            mean_val_acc_of_ep = acc_monitor_for_ep_val.getMeanEmpiricalAccuracyOfEpoch() if val_on_samples_during_train else None
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep) # Updates LR schedule if needed, and increases number of epochs trained.
            model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
            
            del acc_monitor_for_ep_train; del acc_monitor_for_ep_val;
            
            log.print3("SAVING: Epoch #"+str(epoch)+" finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr()
            saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )
            
            end_time_ep = time.time()
            log.print3("TIMING: The whole Epoch #"+str(epoch)+" lasted: {0:.1f}".format(end_time_ep-start_time_ep)+" secs.")
            log.print3("~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~")
            
            
            if val_on_whole_volumes_after_ep:
                log.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #"+str(epoch)+"...***")
                
                res_code = inferenceWholeVolumes(
                                sessionTf,
                                cnn3d,
                                log,
                                "val",
                                savePredictedSegmAndProbsDict,
                                listOfFilepathsToEachChannelOfEachPatientValidation,
                                providedGtForValidationBool,
                                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                providedRoiMaskForValidationBool,
                                listOfFilepathsToRoiMaskOfEachPatientValidation,
                                namesForSavingSegmAndProbs = namesForSavingSegmAndProbs,
                                suffixForSegmAndProbsDict = suffixForSegmAndProbsDict,
                                # Hyper parameters
                                batchsize = batchsize_val_whole,
                                
                                #----Preprocessing------
                                padInputImagesBool=padInputImagesBool,
                                
                                #--------For FM visualisation---------
                                saveIndividualFmImagesForVisualisation=saveIndividualFmImagesForVisualisation,
                                saveMultidimensionalImageWithAllFms=saveMultidimensionalImageWithAllFms,
                                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer=indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                                namesForSavingFms=namesForSavingFms
                                )
        end_time_train = time.time()
        log.print3("TIMING: Training process lasted: {0:.1f}".format(end_time_train-start_time_train)+" secs.")
        
    except (Exception, KeyboardInterrupt) as e:
        log.print3("\n\n ERROR: Caught exception in do_training(): " + str(e) + "\n")
        log.print3( traceback.format_exc() )
        if worker_pool is not None:
            log.print3("Terminating worker pool.")
            worker_pool.terminate()
            worker_pool.join() # Will wait. A KeybInt will kill this (py3)
        return 1
    else:
        if worker_pool is not None:
            log.print3("Closing worker pool.")
            worker_pool.close()
            worker_pool.join()
    
    # Save the final trained model.
    filename_to_save_with = fileToSaveTrainedCnnModelTo + ".final." + datetimeNowAsStr()
    log.print3("Saving the final model at:" + str(filename_to_save_with))
    saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )
            
    log.print3("The whole do_training() function has finished.")
    return 0
Esempio n. 5
0
 def run_session(self, *args):
     """Build the training graph, load or initialise its variables, then train.

     Args:
         *args: exactly three positional values, in this order:
             sess_device: device string passed to ``graphTf.device`` for the
                 network part of the graph.
             model_params: object whose ``get_args_for_arch()`` supplies the
                 arguments for ``Cnn3d.make_cnn_model``.
             reset_trainer: if truthy while loading a saved model, the trainer
                 variables are re-initialised instead of being restored.

     Returns:
         None. A non-zero code from ``do_training`` is reported in the log.

     Raises:
         ValueError: if ``args`` does not hold exactly three values.
     """
     sess_device, model_params, reset_trainer = args
     log = self._log
     
     graphTf = tf.Graph()
     
     with graphTf.as_default():
         # Explicit device assignment, throws an error if GPU is specified but not available.
         with graphTf.device(sess_device):
             log.print3("=========== Making the CNN graph... ===============")
             cnn3d = Cnn3d()
             with tf.variable_scope("net"):
                 # Builds the CNN graph only; the optimizer's graph is built further below.
                 cnn3d.make_cnn_model( *model_params.get_args_for_arch() )
         
         # No explicit device assignment for the rest. Because trainer has piecewise_constant that is only on cpu, and so is saver.
         with tf.variable_scope("trainer"):
             log.print3("=========== Building Trainer ===========\n")
             trainer_args = self._params.get_args_for_trainer() + [cnn3d]
             trainer = Trainer( *trainer_args )
             # Trainer and net connect here.
             trainer.create_optimizer( *self._params.get_args_for_optimizer() )
         
         # The below should not create any new tf.variables.
         log.print3("=========== Compiling the Training Function ===========")
         log.print3("=======================================================\n")
         total_cost = trainer.get_total_cost()
         param_updates = trainer.get_param_updates_wrt_total_cost() # list of ops
         cnn3d.setup_ops_n_feeds_to_train( log, total_cost, param_updates )
         
         log.print3("=========== Compiling the Validation Function =========")
         cnn3d.setup_ops_n_feeds_to_val( log )
         
         log.print3("=========== Compiling the Testing Function ============")
         # The FM indices are needed for validation with full segmentation.
         cnn3d.setup_ops_n_feeds_to_test( log, self._params.indices_fms_per_pathtype_per_layer_to_save )
         
         # Savers: one for everything (used while training), and one per scope
         # so that net and trainer parameters can be loaded independently.
         saver_all = tf.train.Saver()
         collection_vars_net = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net")
         saver_net = tf.train.Saver( var_list = collection_vars_net )
         collection_vars_trainer = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="trainer")
         saver_trainer = tf.train.Saver( var_list = collection_vars_trainer )
     
     # Debugging aid:
     # self._print_vars_in_collection(collection_vars_net, "net")
     # self._print_vars_in_collection(collection_vars_trainer, "trainer")
     
     session_config = tf.ConfigProto(log_device_placement=False, device_count={'CPU':999, 'GPU':99})
     with tf.Session( graph=graphTf, config=session_config ) as sessionTf:
         # Load or initialize parameters
         file_to_load_params_from = self._params.get_path_to_load_model_from()
         if file_to_load_params_from is not None: # Load params
             log.print3("=========== Loading parameters from specified saved model ===============")
             # A directory is resolved to its latest checkpoint; anything else is used as given.
             if os.path.isdir( file_to_load_params_from ):
                 chkpt_fname = tf.train.latest_checkpoint( file_to_load_params_from )
             else:
                 chkpt_fname = file_to_load_params_from
             log.print3("Loading checkpoint file:" + str(chkpt_fname))
             log.print3("Loading network parameters...")
             try:
                 if chkpt_fname is None:
                     # latest_checkpoint() found nothing. Fail with a message that names
                     # the directory, routed through the same handler as other restore errors.
                     raise ValueError("No checkpoint found in directory: " + str(file_to_load_params_from))
                 saver_net.restore(sessionTf, chkpt_fname)
                 log.print3("Network parameters were loaded.")
             except Exception as e:
                 handle_exception_tf_restore(log, e)
             
             if not reset_trainer:
                 log.print3("Loading trainer parameters...")
                 # Same handling as for the net, since both read the same checkpoint.
                 # NOTE(review): the handler is defined elsewhere; presumably it reports and aborts - confirm.
                 try:
                     saver_trainer.restore(sessionTf, chkpt_fname)
                     log.print3("Trainer parameters were loaded.")
                 except Exception as e:
                     handle_exception_tf_restore(log, e)
             else:
                 log.print3("Reset of trainer parameters was requested. Re-initializing them...")
                 sessionTf.run( tf.variables_initializer(var_list = collection_vars_trainer) )
                 log.print3("Trainer parameters re-initialized.")
         else:
             log.print3("=========== Initializing network and trainer variables  ===============")
             # Initialize the two collections separately rather than all global variables,
             # so that a variable outside both scopes stays uninitialized and causes an error.
             sessionTf.run( tf.variables_initializer(var_list = collection_vars_net) )
             sessionTf.run( tf.variables_initializer(var_list = collection_vars_trainer) )
             log.print3("All variables were initialized.")
             
             filename_to_save_with = self._params.filepath_to_save_models + ".initial." + datetimeNowAsStr()
             log.print3("Saving the initial model at:" + str(filename_to_save_with))
             saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )
         
         log.print3("")
         log.print3("=======================================================")
         log.print3("============== Training the CNN model =================")
         log.print3("=======================================================\n")
         
         train_routine_args = [sessionTf, saver_all, cnn3d, trainer] + self._params.get_args_for_train_routine()
         res_code = do_training( *train_routine_args )
         if res_code:
             # Report the failure code instead of silently discarding it.
             log.print3("ERROR: do_training() returned a non-zero code: " + str(res_code))
     
     log.print3("\n=======================================================")
     log.print3("=========== Training session finished =================")
     log.print3("=======================================================")