Ejemplo n.º 1
0
def do_training(sessionTf,
                saver_all,
                cnn3d,
                trainer,
                tensorboard_loggers,

                log,
                fileToSaveTrainedCnnModelTo,

                val_on_samples,
                savePredictedSegmAndProbsDict,

                namesForSavingSegmAndProbs,
                suffixForSegmAndProbsDict,

                paths_per_chan_per_subj_train,
                paths_per_chan_per_subj_val,

                paths_to_lbls_per_subj_train,
                paths_to_lbls_per_subj_val,

                paths_to_wmaps_per_sampl_cat_per_subj_train,
                paths_to_wmaps_per_sampl_cat_per_subj_val,

                paths_to_masks_per_subj_train,
                paths_to_masks_per_subj_val,

                n_epochs,  # Every epoch the CNN model is saved.
                n_subepochs,  # per epoch. Every subepoch Accuracy is reported
                max_n_cases_per_subep_train,  # Max num of subjects loaded every subep for sampling
                n_samples_per_subep_train,
                n_samples_per_subep_val,
                num_parallel_proc_sampling,  # -1: seq. 0: thread for sampling. >0: multiprocess sampling

                # -------Sampling Type---------
                sampling_type_inst_tr,
                # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                sampling_type_inst_val,
                batchsize_train,
                batchsize_val_samples,
                batchsize_val_whole,

                # -------Data Augmentation-------
                augm_img_prms,
                augm_sample_prms,

                # Validation
                val_on_whole_volumes,
                n_epochs_between_val_on_whole_vols,

                # --------For FM visualisation---------
                save_fms_flag,
                idxs_fms_to_save,
                namesForSavingFms,

                # --- Data Compatibility Checks ---
                run_input_checks,

                # -------- Pre-processing ------
                pad_input,
                norm_prms
                ):
    """Train `cnn3d` epoch by epoch, optionally validating, and save the model after every epoch.

    The loop starts from the number of epochs the trainer reports as already
    trained and runs until `n_epochs` is reached. Every epoch consists of
    `n_subepochs` subepochs. In each subepoch samples are (optionally)
    extracted and processed for validation, then samples are extracted and
    processed for training. At the end of each epoch the accumulated metrics
    are reported, `trainer.run_updates_end_of_ep` is called, the model is
    saved with `saver_all`, and (if scheduled) whole-volume validation is run.
    A final model is saved once more after the loop finishes.

    If `num_parallel_proc_sampling > -1`, a single-worker `ThreadPool` is used
    so that the samples for the next step are extracted while the current
    step is being processed. Otherwise sampling is done sequentially by the
    calling thread.

    Args:
        sessionTf: Session passed to `.eval(session=...)`, to the saver, to the
            trainer and to the batch-processing / inference helpers.
        saver_all: Object whose `save(session, path, write_meta_graph=False)`
            is used to store the model.
        cnn3d: The model. Its `num_classes` attribute is read here; it is also
            wrapped in `CnnWrapperForSampling` for the sampling function.
        trainer: Provides `get_num_epochs_trained_tfv()` and
            `run_updates_end_of_ep(log, session, mean_val_acc)`.
        tensorboard_loggers: None, or a mapping with keys 'train' and 'val'.
        log: Logger exposing `print3(str)`.
        fileToSaveTrainedCnnModelTo: Path prefix for saved models; a timestamp
            and ".model.ckpt" are appended.
        val_on_samples: Whether to validate on samples in every subepoch.
        savePredictedSegmAndProbsDict, namesForSavingSegmAndProbs,
        suffixForSegmAndProbsDict: Forwarded unchanged to
            `inference_on_whole_volumes`.
        paths_per_chan_per_subj_train / _val,
        paths_to_lbls_per_subj_train / _val,
        paths_to_wmaps_per_sampl_cat_per_subj_train / _val,
        paths_to_masks_per_subj_train / _val: Data paths forwarded to the
            sampling function (and, for validation, to whole-volume inference).
        n_epochs: Total number of epochs the model should reach.
        n_subepochs: Number of subepochs per epoch.
        max_n_cases_per_subep_train: Max number of subjects loaded per subepoch
            for sampling. Note that it is passed for validation sampling too.
        n_samples_per_subep_train, n_samples_per_subep_val: Forwarded to the
            sampling function (presumably the number of samples requested).
        num_parallel_proc_sampling: -1: sequential. 0: thread for sampling.
            >0: multiprocess sampling.
        sampling_type_inst_tr, sampling_type_inst_val: Sampling-type instances
            for training / validation, forwarded to the sampling function.
        batchsize_train, batchsize_val_samples: Batch sizes. The number of
            batches is the floor of (extracted samples / batch size).
        batchsize_val_whole: Batch size forwarded to whole-volume inference.
        augm_img_prms, augm_sample_prms: Augmentation parameters, used for
            training sampling only (validation gets None for both).
        val_on_whole_volumes: Whether to validate on whole volumes.
        n_epochs_between_val_on_whole_vols: Whole-volume validation runs after
            every epoch `e` for which `(e + 1)` is a multiple of this value.
        save_fms_flag, idxs_fms_to_save, namesForSavingFms: Forwarded unchanged
            to `inference_on_whole_volumes`.
        run_input_checks: Forwarded to sampling and whole-volume inference.
        pad_input, norm_prms: Pre-processing settings forwarded to sampling
            and whole-volume inference.

    Returns:
        0 if training completed, 1 if an exception (including
        KeyboardInterrupt) was caught. In the latter case the error and
        traceback are logged, the worker pool (if any) is terminated, and the
        final model is NOT saved.
    """
    id_str = "[MAIN|PID:" + str(os.getpid()) + "]"
    start_time_train = time.time()

    # I cannot pass cnn3d to the sampling function, because the pp module used to reload theano.
    # This created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)

    # Positional arguments for get_samples_for_subepoch(); the order must match its signature.
    args_for_sampling_tr = (log,
                            "train",
                            num_parallel_proc_sampling,
                            run_input_checks,
                            cnn3dWrapper,
                            max_n_cases_per_subep_train,
                            n_samples_per_subep_train,
                            sampling_type_inst_tr,
                            paths_per_chan_per_subj_train,
                            paths_to_lbls_per_subj_train,
                            paths_to_masks_per_subj_train,
                            paths_to_wmaps_per_sampl_cat_per_subj_train,
                            pad_input,
                            norm_prms,
                            augm_img_prms,
                            augm_sample_prms
                            )
    args_for_sampling_val = (log,
                             "val",
                             num_parallel_proc_sampling,
                             run_input_checks,
                             cnn3dWrapper,
                             max_n_cases_per_subep_train,
                             n_samples_per_subep_val,
                             sampling_type_inst_val,
                             paths_per_chan_per_subj_val,
                             paths_to_lbls_per_subj_val,
                             paths_to_masks_per_subj_val,
                             paths_to_wmaps_per_sampl_cat_per_subj_val,
                             pad_input,
                             norm_prms,
                             None,  # no augmentation in val.
                             None  # no augmentation in val.
                             )

    # These flags record whether an async sampling job is pending, i.e. whether
    # sampling_job_tr / sampling_job_val hold a result that has not been fetched yet.
    sampling_job_submitted_train = False
    sampling_job_submitted_val = False
    # For parallel extraction of samples for next train/val while processing previous iteration.
    mp_pool = None
    if num_parallel_proc_sampling > -1:  # Use multiprocessing.
        mp_pool = ThreadPool(processes=1)  # Or multiprocessing.Pool(...), same API.

    try:
        n_eps_trained_model = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while n_eps_trained_model < n_epochs:
            epoch = n_eps_trained_model

            tb_log_tr = tensorboard_loggers['train'] if tensorboard_loggers is not None else None
            acc_monitor_ep_tr = AccuracyMonitorForEpSegm(log, 0,
                                                         n_eps_trained_model,
                                                         cnn3d.num_classes,
                                                         n_subepochs,
                                                         tb_log_tr)

            tb_log_val = tensorboard_loggers['val'] if tensorboard_loggers is not None else None
            # A validation monitor is needed for either kind of validation
            # (it also reports the whole-volume metrics at the end of the epoch).
            acc_monitor_ep_val = None
            if val_on_samples or val_on_whole_volumes:
                acc_monitor_ep_val = AccuracyMonitorForEpSegm(log, 1,
                                                              n_eps_trained_model,
                                                              cnn3d.num_classes,
                                                              n_subepochs,
                                                              tb_log_val)

            val_on_whole_vols_after_this_ep = False
            if val_on_whole_volumes and (n_eps_trained_model + 1) % n_epochs_between_val_on_whole_vols == 0:
                val_on_whole_vols_after_this_ep = True

            log.print3("")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~\t\t\t Starting new Epoch! Epoch #" + str(epoch) + "/" + str(n_epochs) + "  \t\t\t~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_time_ep = time.time()

            for subep in range(n_subepochs):
                log.print3("")
                log.print3("***********************************************************************************")
                log.print3("*\t\t\t Starting new Subepoch: #" + str(subep) + "/" + str(n_subepochs) + " \t\t\t*")
                log.print3("***********************************************************************************")

                # -------------------- GET DATA FOR THIS SUBEPOCH's VALIDATION -----------------------
                if val_on_samples:
                    if mp_pool is None:  # Sequential processing.
                        log.print3(id_str + " NO MULTIPROC: Sampling for subepoch #" + str(subep) +\
                                   " [VALIDATION] will be done by main thread.")
                        (channs_samples_per_path_val,
                         lbls_samples_per_path_val) = get_samples_for_subepoch(*args_for_sampling_val)
                    elif sampling_job_submitted_val:  # done parallel with training of previous epoch.
                        (channs_samples_per_path_val,
                         lbls_samples_per_path_val) = sampling_job_val.get()
                        sampling_job_submitted_val = False
                    else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                        assert subep == 0
                        log.print3(id_str + " MULTIPROC: Before Validation in subepoch #" + str(subep) +\
                                   ", submitting sampling job for next [VALIDATION].")
                        sampling_job_val = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_val)
                        (channs_samples_per_path_val,
                         lbls_samples_per_path_val) = sampling_job_val.get()
                        sampling_job_submitted_val = False

                    # ----------- SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING -----------------
                    if mp_pool is not None:
                        log.print3(id_str + " MULTIPROC: Before Validation in subepoch #" + str(subep) +\
                                   ", submitting sampling job for next [TRAINING].")
                        sampling_job_tr = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_tr)
                        sampling_job_submitted_train = True

                    # ------------------------------------DO VALIDATION--------------------------------
                    log.print3("V-V-V-V- Validating for subepoch before starting training iterations -V-V-V-V")
                    start_time_val_subep = time.time()
                    # Calc num of batches from extracted samples, in case not extracted as much as requested.
                    n_batches_val = len(channs_samples_per_path_val[0]) // batchsize_val_samples
                    process_in_batches(log,
                                       sessionTf,
                                       "val",
                                       n_batches_val,
                                       batchsize_val_samples,
                                       cnn3d,
                                       acc_monitor_ep_val,
                                       channs_samples_per_path_val,
                                       lbls_samples_per_path_val)
                    log.print3("TIMING: Validation on batches of subepoch #" + str(subep) +\
                               " lasted: {0:.1f}".format(time.time() - start_time_val_subep) + " secs.")

                # ----------------------- GET DATA FOR THIS SUBEPOCH's TRAINING ------------------------------
                if mp_pool is None:  # Sequential processing.
                    log.print3(id_str + " NO MULTIPROC: Sampling for subepoch #" + str(subep) +\
                               " [TRAINING] will be done by main thread.")
                    (channs_samples_per_path_tr,
                     lbls_samples_per_path_tr) = get_samples_for_subepoch(*args_for_sampling_tr)
                elif sampling_job_submitted_train:  # done parallel with train/val of previous epoch.
                    (channs_samples_per_path_tr,
                     lbls_samples_per_path_tr) = sampling_job_tr.get()
                    sampling_job_submitted_train = False
                else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                    assert subep == 0
                    log.print3(id_str + " MULTIPROC: Before Training in subepoch #" + str(subep) +\
                               ", submitting sampling job for next [TRAINING].")
                    sampling_job_tr = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_tr)
                    (channs_samples_per_path_tr,
                     lbls_samples_per_path_tr) = sampling_job_tr.get()
                    sampling_job_submitted_train = False

                # ----- SUBMIT PARALLEL JOB TO GET VAL / TRAIN (if no val) DATA FOR NEXT SUBEPOCH -----
                # No prefetching after the last subepoch of an epoch when either:
                #  (a) whole-volume validation follows this epoch, or
                #  (b) this is the last scheduled epoch. The prefetched samples would never be used,
                #      yet pool.close() + pool.join() at the end would still wait for the job to finish.
                # If training does continue with another epoch after a skipped prefetch, the
                # "not previously submitted" branches above submit a fresh job at subep == 0.
                is_last_subep = subep == n_subepochs - 1
                is_last_scheduled_ep = epoch + 1 >= n_epochs
                skip_prefetch = is_last_subep and (val_on_whole_vols_after_this_ep or is_last_scheduled_ep)
                if mp_pool is not None and not skip_prefetch:
                    if val_on_samples:
                        log.print3(id_str + " MULTIPROC: Before Training in subepoch #" + str(subep) +\
                                   ", submitting sampling job for next [VALIDATION].")
                        sampling_job_val = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_val)
                        sampling_job_submitted_val = True
                    else:
                        log.print3(id_str + " MULTIPROC: Before Training in subepoch #" + str(subep) +\
                                   ", submitting sampling job for next [TRAINING].")
                        sampling_job_tr = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_tr)
                        sampling_job_submitted_train = True

                # ------------------------------ START TRAINING IN BATCHES -----------------------------
                log.print3("-T-T-T-T- Training for this subepoch... May take a few minutes... -T-T-T-T-")
                start_time_train_subep = time.time()
                # Calc num of batches from extracted samples, in case not extracted as much as requested.
                n_batches_train = len(channs_samples_per_path_tr[0]) // batchsize_train
                process_in_batches(log,
                                   sessionTf,
                                   "train",
                                   n_batches_train,
                                   batchsize_train,
                                   cnn3d,
                                   acc_monitor_ep_tr,
                                   channs_samples_per_path_tr,
                                   lbls_samples_per_path_tr)
                log.print3("TIMING: Training on batches of this subepoch #" + str(subep) +\
                           " lasted: {0:.1f}".format(time.time() - start_time_train_subep) + " secs.")

            log.print3("")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_samples:
                acc_monitor_ep_val.report_metrics_samples_ep()
            acc_monitor_ep_tr.report_metrics_samples_ep()

            mean_val_acc_of_ep = acc_monitor_ep_val.get_avg_accuracy_ep() if val_on_samples else None
            # Updates LR schedule if needed, and increases number of epochs trained.
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep)
            # Re-read the counter from the trainer; it drives the while-loop condition.
            n_eps_trained_model = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)

            log.print3("SAVING: Epoch #" + str(epoch) + " finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetime_now_str()
            saver_all.save(sessionTf, filename_to_save_with + ".model.ckpt", write_meta_graph=False)

            log.print3("TIMING: Whole Epoch #" + str(epoch) +\
                       " lasted: {0:.1f}".format(time.time() - start_time_ep) + " secs.")
            log.print3("~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_whole_vols_after_this_ep:
                log.print3("***Start validation by segmenting whole subjects for Epoch #" + str(epoch) + "***")

                mean_metrics_val_whole_vols = inference_on_whole_volumes(sessionTf,
                                                                         cnn3d,
                                                                         log,
                                                                         "val",
                                                                         savePredictedSegmAndProbsDict,
                                                                         paths_per_chan_per_subj_val,
                                                                         paths_to_lbls_per_subj_val,
                                                                         paths_to_masks_per_subj_val,
                                                                         namesForSavingSegmAndProbs,
                                                                         suffixForSegmAndProbsDict,
                                                                         # Hyper parameters
                                                                         batchsize_val_whole,
                                                                         # Data compatibility checks
                                                                         run_input_checks,
                                                                         # Pre-Processing
                                                                         pad_input,
                                                                         norm_prms,
                                                                         # Saving feature maps
                                                                         save_fms_flag,
                                                                         idxs_fms_to_save,
                                                                         namesForSavingFms)

                acc_monitor_ep_val.report_metrics_whole_vols(mean_metrics_val_whole_vols)

            del acc_monitor_ep_tr
            del acc_monitor_ep_val

        log.print3("TIMING: Training process lasted: {0:.1f}".format(time.time() - start_time_train) + " secs.")

    except (Exception, KeyboardInterrupt) as e:
        # Deliberately broad: log the failure, stop the sampling worker, and signal it via the return code.
        log.print3("\n\n ERROR: Caught exception in do_training(): " + str(e) + "\n")
        log.print3(traceback.format_exc())
        if mp_pool is not None:
            log.print3("Terminating worker pool.")
            mp_pool.terminate()
            mp_pool.join()  # Will wait. A KeybInt will kill this (py3)
        return 1
    else:
        if mp_pool is not None:
            log.print3("Closing worker pool.")
            mp_pool.close()
            mp_pool.join()

    # Save the final trained model.
    filename_to_save_with = fileToSaveTrainedCnnModelTo + ".final." + datetime_now_str()
    log.print3("Saving the final model at:" + str(filename_to_save_with))
    saver_all.save(sessionTf, filename_to_save_with + ".model.ckpt", write_meta_graph=False)

    log.print3("The whole do_training() function has finished.")
    return 0
Ejemplo n.º 2
0
def do_training(sessionTf,
                saver_all,
                cnn3d,
                trainer,
                log,
                
                fileToSaveTrainedCnnModelTo,
                
                val_on_samples_during_train,
                savePredictedSegmAndProbsDict,
                
                namesForSavingSegmAndProbs,
                suffixForSegmAndProbsDict,
                
                listOfFilepathsToEachChannelOfEachPatientTraining,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                
                listOfFilepathsToGtLabelsOfEachPatientTraining,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                
                paths_to_wmaps_per_sampl_cat_per_subj_train,
                paths_to_wmaps_per_sampl_cat_per_subj_val,
                
                listOfFilepathsToRoiMaskOfEachPatientTraining,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                
                n_epochs, # Every epoch the CNN model is saved.
                num_subepochs, # per epoch. Every subepoch Accuracy is reported
                max_n_cases_per_subep_train,  # Max num of subjects loaded every subepoch for segments extraction.
                n_samples_per_subep_train,
                n_samples_per_subep_val,
                num_parallel_proc_sampling, # -1: seq. 0: thread for sampling. >0: multiprocess sampling
                
                #-------Sampling Type---------
                samplingTypeInstanceTraining, # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                samplingTypeInstanceValidation,
                batchsize_train,
                batchsize_val_samples,
                batchsize_val_whole,
                
                #-------Preprocessing-----------
                pad_input_imgs,
                #-------Data Augmentation-------
                augm_img_prms,
                augm_sample_prms,
                
                # Validation
                val_on_whole_volumes,
                num_epochs_between_val_on_whole_volumes,
                
                #--------For FM visualisation---------
                saveIndividualFmImagesForVisualisation,
                saveMultidimensionalImageWithAllFms,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                namesForSavingFms,
                
                #-------- Others --------
                run_input_checks
                ):
    
    id_str = "[MAIN|PID:"+str(os.getpid())+"]"
    start_time_train = time.time()
    
    # I cannot pass cnn3d to the sampling function, because the pp module used to reload theano. 
    # This created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)
    
    args_for_sampling_train = ( log,
                                "train",
                                num_parallel_proc_sampling,
                                run_input_checks,
                                cnn3dWrapper,
                                max_n_cases_per_subep_train,
                                n_samples_per_subep_train,
                                samplingTypeInstanceTraining,
                                listOfFilepathsToEachChannelOfEachPatientTraining,
                                listOfFilepathsToGtLabelsOfEachPatientTraining,
                                listOfFilepathsToRoiMaskOfEachPatientTraining,
                                paths_to_wmaps_per_sampl_cat_per_subj_train,
                                pad_input_imgs,
                                augm_img_prms,
                                augm_sample_prms )
    args_for_sampling_val = (   log,
                                "val",
                                num_parallel_proc_sampling,
                                run_input_checks,
                                cnn3dWrapper,
                                max_n_cases_per_subep_train,
                                n_samples_per_subep_val,
                                samplingTypeInstanceValidation,
                                listOfFilepathsToEachChannelOfEachPatientValidation,
                                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                listOfFilepathsToRoiMaskOfEachPatientValidation,
                                paths_to_wmaps_per_sampl_cat_per_subj_val,
                                pad_input_imgs,
                                None, # no augmentation in val.
                                None ) # no augmentation in val.
    
    sampling_job_submitted_train = False
    sampling_job_submitted_val = False
    # For parallel extraction of samples for next train/val while processing previous iteration.
    worker_pool = None
    if num_parallel_proc_sampling > -1 : # Use multiprocessing.
        worker_pool = ThreadPool(processes=1) # Or multiprocessing.Pool(...), same API.
    
    try:
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while model_num_epochs_trained < n_epochs :
            epoch = model_num_epochs_trained
            
            acc_monitor_for_ep_train = AccuracyOfEpochMonitorSegmentation(log, 0, model_num_epochs_trained, cnn3d.num_classes, num_subepochs)
            acc_monitor_for_ep_val = None if not val_on_samples_during_train else \
                                            AccuracyOfEpochMonitorSegmentation(log, 1, model_num_epochs_trained, cnn3d.num_classes, num_subepochs )
                                            
            val_on_whole_volumes_after_ep = False
            if val_on_whole_volumes and (model_num_epochs_trained+1) % num_epochs_between_val_on_whole_volumes == 0:
                val_on_whole_volumes_after_ep = True
                
                
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~\t Starting new Epoch! Epoch #"+str(epoch)+"/"+str(n_epochs)+" \t~~~~~~~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_time_ep = time.time()
            
            for subepoch in range(num_subepochs):
                log.print3("***************************************************************************************")
                log.print3("*******\t\t Starting new Subepoch: #"+str(subepoch)+"/"+str(num_subepochs)+" \t\t********")
                log.print3("***************************************************************************************")
                
                #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
                if val_on_samples_during_train :
                    if worker_pool is None: # Sequential processing.
                        log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [VALIDATION] will be done by main thread.")
                        (channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal) = getSampledDataAndLabelsForSubepoch( *args_for_sampling_val )
                    elif sampling_job_submitted_val : #It was done in parallel with the training of the previous epoch, just grab results.
                        (channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal) = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False
                    else : # Not previously submitted in case of first epoch or after a full-volumes validation.
                        assert subepoch == 0
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        (channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal) = parallelJobToGetDataForNextValidation.get()
                        sampling_job_submitted_val = False

                    
                    #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                    if worker_pool is not None:
                        log.print3(id_str+" MULTIPROC: Before Validation in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True
                    
                    #------------------------------------DO VALIDATION--------------------------------
                    log.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-")
                    start_time_val_subep = time.time()
                    # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                    num_batches_val = len(channsOfSegmentsForSubepPerPathwayVal[0]) // batchsize_val_samples
                    trainOrValidateForSubepoch( log,
                                                sessionTf,
                                                "val",
                                                num_batches_val,
                                                batchsize_val_samples,
                                                cnn3d,
                                                subepoch,
                                                acc_monitor_for_ep_val,
                                                channsOfSegmentsForSubepPerPathwayVal,
                                                labelsForCentralOfSegmentsForSubepVal )
                    end_time_val_subep = time.time()
                    log.print3("TIMING: Validation on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_val_subep-start_time_val_subep)+" secs.")
                
                #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
                if worker_pool is None: # Sequential processing.
                    log.print3(id_str+" NO MULTIPROC: Sampling for subepoch #"+str(subepoch)+" [TRAINING] will be done by main thread.")
                    (channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain) = getSampledDataAndLabelsForSubepoch( *args_for_sampling_train )
                elif sampling_job_submitted_train: # Sampling job should have been done in parallel with previous train/val. Just grab results.
                    (channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain) = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False
                else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                    assert subepoch == 0
                    log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                    parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                    (channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain) = parallelJobToGetDataForNextTraining.get()
                    sampling_job_submitted_train = False

                        
                #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
                if worker_pool is not None and not (val_on_whole_volumes_after_ep and (subepoch == num_subepochs-1)):
                    if val_on_samples_during_train :
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [VALIDATION].")
                        parallelJobToGetDataForNextValidation = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_val)
                        sampling_job_submitted_val = True
                    else :
                        log.print3(id_str+" MULTIPROC: Before Training in subepoch #"+str(subepoch)+", submitting sampling job for next [TRAINING].")
                        parallelJobToGetDataForNextTraining = worker_pool.apply_async(getSampledDataAndLabelsForSubepoch, args_for_sampling_train)
                        sampling_job_submitted_train = True
                
                #-------------------------------START TRAINING IN BATCHES------------------------------
                log.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-")
                start_time_train_subep = time.time()
                # Compute num of batches from num of extracted samples, in case we did not extract as many as initially requested.
                num_batches_train = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // batchsize_train
                trainOrValidateForSubepoch( log,
                                            sessionTf,
                                            "train",
                                            num_batches_train,
                                            batchsize_train,
                                            cnn3d,
                                            subepoch,
                                            acc_monitor_for_ep_train,
                                            channsOfSegmentsForSubepPerPathwayTrain,
                                            labelsForCentralOfSegmentsForSubepTrain )
                end_time_train_subep = time.time()
                log.print3("TIMING: Training on batches of this subepoch #"+str(subepoch)+" lasted: {0:.1f}".format(end_time_train_subep-start_time_train_subep)+" secs.")
                
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~" )
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            
            if val_on_samples_during_train:
                acc_monitor_for_ep_val.reportMeanAccyracyOfEpoch()
            acc_monitor_for_ep_train.reportMeanAccyracyOfEpoch()
            
            mean_val_acc_of_ep = acc_monitor_for_ep_val.getMeanEmpiricalAccuracyOfEpoch() if val_on_samples_during_train else None
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep) # Updates LR schedule if needed, and increases number of epochs trained.
            model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
            
            del acc_monitor_for_ep_train; del acc_monitor_for_ep_val;
            
            log.print3("SAVING: Epoch #"+str(epoch)+" finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr()
            saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )
            
            end_time_ep = time.time()
            log.print3("TIMING: The whole Epoch #"+str(epoch)+" lasted: {0:.1f}".format(end_time_ep-start_time_ep)+" secs.")
            log.print3("~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~")
            
            
            if val_on_whole_volumes_after_ep:
                log.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #"+str(epoch)+"...***")
                
                res_code = inferenceWholeVolumes(
                                sessionTf,
                                cnn3d,
                                log,
                                "val",
                                savePredictedSegmAndProbsDict,
                                listOfFilepathsToEachChannelOfEachPatientValidation,
                                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                listOfFilepathsToRoiMaskOfEachPatientValidation,
                                namesForSavingSegmAndProbs = namesForSavingSegmAndProbs,
                                suffixForSegmAndProbsDict = suffixForSegmAndProbsDict,
                                # Hyper parameters
                                batchsize = batchsize_val_whole,
                                
                                #----Preprocessing------
                                pad_input_imgs = pad_input_imgs,
                                
                                #--------For FM visualisation---------
                                saveIndividualFmImagesForVisualisation = saveIndividualFmImagesForVisualisation,
                                saveMultidimensionalImageWithAllFms = saveMultidimensionalImageWithAllFms,
                                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer = indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                                namesForSavingFms = namesForSavingFms
                                )
        end_time_train = time.time()
        log.print3("TIMING: Training process lasted: {0:.1f}".format(end_time_train-start_time_train)+" secs.")
        
    except (Exception, KeyboardInterrupt) as e:
        log.print3("\n\n ERROR: Caught exception in do_training(): " + str(e) + "\n")
        log.print3( traceback.format_exc() )
        if worker_pool is not None:
            log.print3("Terminating worker pool.")
            worker_pool.terminate()
            worker_pool.join() # Will wait. A KeybInt will kill this (py3)
        return 1
    else:
        if worker_pool is not None:
            log.print3("Closing worker pool.")
            worker_pool.close()
            worker_pool.join()
    
    # Save the final trained model.
    filename_to_save_with = fileToSaveTrainedCnnModelTo + ".final." + datetimeNowAsStr()
    log.print3("Saving the final model at:" + str(filename_to_save_with))
    saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )
            
    log.print3("The whole do_training() function has finished.")
    return 0
Ejemplo n.º 3
0
def do_training(myLogger,
                fileToSaveTrainedCnnModelTo,
                cnn3dInstance,
                
                performValidationOnSamplesDuringTrainingProcessBool, #REQUIRED FOR AUTO SCHEDULE.
                savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation,
                
                listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice,
                
                listOfFilepathsToEachChannelOfEachPatientTraining,
                listOfFilepathsToEachChannelOfEachPatientValidation,
                
                listOfFilepathsToGtLabelsOfEachPatientTraining,
                providedGtForValidationBool,
                listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                
                providedWeightMapsToSampleForEachCategoryTraining,
                forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
                providedWeightMapsToSampleForEachCategoryValidation,
                forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
                
                providedRoiMaskForTrainingBool,
                listOfFilepathsToRoiMaskOfEachPatientTraining, # Also needed for normalization-augmentation
                providedRoiMaskForValidationBool,
                listOfFilepathsToRoiMaskOfEachPatientValidation,
                
                borrowFlag,
                n_epochs, # Every epoch the CNN model is saved.
                number_of_subepochs, # per epoch. Every subepoch Accuracy is reported
                maxNumSubjectsLoadedPerSubepoch,  # Max num of cases loaded every subepoch for segments extraction. The more, the longer loading.
                imagePartsLoadedInGpuPerSubepoch,
                imagePartsLoadedInGpuPerSubepochValidation,
                
                #-------Sampling Type---------
                samplingTypeInstanceTraining, # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                samplingTypeInstanceValidation,
                
                #-------Preprocessing-----------
                padInputImagesBool,
                smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage,
                #-------Data Augmentation-------
                normAugmNone0OnImages1OrSegments2AlreadyNormalized1SubtrUpToPropOfStdAndDivideWithUpToPerc,
                reflectImageWithHalfProbDuringTraining,
                
                useSameSubChannelsAsSingleScale,
                
                listOfFilepathsToEachSubsampledChannelOfEachPatientTraining, # deprecated, not supported
                listOfFilepathsToEachSubsampledChannelOfEachPatientValidation, # deprecated, not supported
                
                #Learning Rate Schedule:
                lowerLrByStable0orAuto1orPredefined2orExponential3Schedule,
                minIncreaseInValidationAccuracyConsideredForLrSchedule,
                numEpochsToWaitBeforeLowerLR,
                divideLrBy,
                lowerLrAtTheEndOfTheseEpochsPredefinedScheduleList,
                exponentialScheduleForLrAndMom,
                
                #Weighting Classes differently in the CNN's cost function during training:
                numberOfEpochsToWeightTheClassesInTheCostFunction,
                
                performFullInferenceOnValidationImagesEveryFewEpochsBool, #Even if not providedGtForValidationBool, inference will be performed if this == True, to save the results, eg for visual.
                everyThatManyEpochsComputeDiceOnTheFullValidationImages=1, # Should not be == 0, except if performFullInferenceOnValidationImagesEveryFewEpochsBool == False
                
                #--------For FM visualisation---------
                saveIndividualFmImagesForVisualisation=False,
                saveMultidimensionalImageWithAllFms=False,
                indicesOfFmsToVisualisePerPathwayTypeAndPerLayer="placeholder",
                listOfNamesToGiveToFmVisualisationsIfSaving="placeholder"
                ):
    
    start_training_time = time.clock()
    
    # A wrapper is used because cnn3dInstance itself cannot be passed to the sampling function:
    # the parallel process would then load theano again, which creates problems on the GPU when cnmem is used.
    cnn3dWrapper = CnnWrapperForSampling(cnn3dInstance) 
    
    #---------To run PARALLEL the extraction of parts for the next subepoch---
    ppservers = () # tuple of all parallel python servers to connect with
    job_server = pp.Server(ncpus=1, ppservers=ppservers) # Creates jobserver with automatically detected number of workers
    
    tupleWithParametersForTraining = (myLogger,
                                    0,
                                    cnn3dWrapper,
                                    maxNumSubjectsLoadedPerSubepoch,
                                    
                                    imagePartsLoadedInGpuPerSubepoch,
                                    samplingTypeInstanceTraining,
                                    
                                    listOfFilepathsToEachChannelOfEachPatientTraining,
                                    
                                    listOfFilepathsToGtLabelsOfEachPatientTraining,
                                    
                                    providedRoiMaskForTrainingBool,
                                    listOfFilepathsToRoiMaskOfEachPatientTraining,
                                    
                                    providedWeightMapsToSampleForEachCategoryTraining,
                                    forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
                                    
                                    useSameSubChannelsAsSingleScale,
                                    
                                    listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,
                                    
                                    padInputImagesBool,
                                    smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage,
                                    normAugmNone0OnImages1OrSegments2AlreadyNormalized1SubtrUpToPropOfStdAndDivideWithUpToPerc,
                                    reflectImageWithHalfProbDuringTraining
                                    )
    tupleWithParametersForValidation = (myLogger,
                                    1,
                                    cnn3dWrapper,
                                    maxNumSubjectsLoadedPerSubepoch,
                                    
                                    imagePartsLoadedInGpuPerSubepochValidation,
                                    samplingTypeInstanceValidation,
                                    
                                    listOfFilepathsToEachChannelOfEachPatientValidation,
                                    
                                    listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                    
                                    providedRoiMaskForValidationBool,
                                    listOfFilepathsToRoiMaskOfEachPatientValidation,
                                    
                                    providedWeightMapsToSampleForEachCategoryValidation,
                                    forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
                                    
                                    useSameSubChannelsAsSingleScale,
                                    
                                    listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
                                    
                                    padInputImagesBool,
                                    smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage,
                                    [0, -1,-1,-1], #don't perform intensity-augmentation during validation.
                                    [0,0,0] #don't perform reflection-augmentation during validation.
                                    )
    tupleWithLocalFunctionsThatWillBeCalledByTheMainJob = ( )
    tupleWithModulesToImportWhichAreUsedByTheJobFunctions = ( "from __future__ import absolute_import, print_function, division", "from six.moves import xrange",
                "time", "numpy as np", "from deepmedic.dataManagement.sampling import *" )
    boolItIsTheVeryFirstSubepochOfThisProcess = True #to know so that in the very first I sequencially load the data for it.
    #------End for parallel------
    
    while cnn3dInstance.numberOfEpochsTrained < n_epochs :
        epoch = cnn3dInstance.numberOfEpochsTrained
        
        trainingAccuracyMonitorForEpoch = AccuracyOfEpochMonitorSegmentation(myLogger, 0, cnn3dInstance.numberOfEpochsTrained, cnn3dInstance.numberOfOutputClasses, number_of_subepochs)
        validationAccuracyMonitorForEpoch = None if not performValidationOnSamplesDuringTrainingProcessBool else \
                                        AccuracyOfEpochMonitorSegmentation(myLogger, 1, cnn3dInstance.numberOfEpochsTrained, cnn3dInstance.numberOfOutputClasses, number_of_subepochs ) 
                                        
        myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        myLogger.print3("~~~~~~~~~~~~~~~~~~~~Starting new Epoch! Epoch #"+str(epoch)+"/"+str(n_epochs)+" ~~~~~~~~~~~~~~~~~~~~~~~~~")
        myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        start_epoch_time = time.clock()
        
        for subepoch in xrange(number_of_subepochs): #per subepoch I randomly load some images in the gpu. Random order.
            myLogger.print3("**************************************************************************************************")
            myLogger.print3("************* Starting new Subepoch: #"+str(subepoch)+"/"+str(number_of_subepochs)+" *************")
            myLogger.print3("**************************************************************************************************")
            
            #-------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
            
            if performValidationOnSamplesDuringTrainingProcessBool :
                if boolItIsTheVeryFirstSubepochOfThisProcess :
                    [channsOfSegmentsForSubepPerPathwayVal,
                    labelsForCentralOfSegmentsForSubepVal] = getSampledDataAndLabelsForSubepoch(myLogger,
                                                                        1,
                                                                        cnn3dWrapper,
                                                                        maxNumSubjectsLoadedPerSubepoch,
                                                                        imagePartsLoadedInGpuPerSubepochValidation,
                                                                        samplingTypeInstanceValidation,
                                                                        
                                                                        listOfFilepathsToEachChannelOfEachPatientValidation,
                                                                        
                                                                        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                                                        
                                                                        providedRoiMaskForValidationBool,
                                                                        listOfFilepathsToRoiMaskOfEachPatientValidation,
                                                                        
                                                                        providedWeightMapsToSampleForEachCategoryValidation,
                                                                        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
                                                                        
                                                                        useSameSubChannelsAsSingleScale,
                                                                        
                                                                        listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
                                                                        
                                                                        padInputImagesBool,
                                                                        smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage,
                                                                        normAugmNone0OnImages1OrSegments2AlreadyNormalized1SubtrUpToPropOfStdAndDivideWithUpToPerc=[0,-1,-1,-1],
                                                                        reflectImageWithHalfProbDuringTraining = [0,0,0]
                                                                        )
                    boolItIsTheVeryFirstSubepochOfThisProcess = False
                else : #It was done in parallel with the training of the previous epoch, just grab the results...
                    [channsOfSegmentsForSubepPerPathwayVal,
                    labelsForCentralOfSegmentsForSubepVal] = parallelJobToGetDataForNextValidation() #fromParallelProcessing that had started from last loop when it was submitted.
                    
                #------------------------------LOAD DATA FOR VALIDATION----------------------
                myLogger.print3("Loading Validation data for subepoch #"+str(subepoch)+" on shared variable...")
                start_loadingToGpu_time = time.clock()
                
                numberOfBatchesValidation = len(channsOfSegmentsForSubepPerPathwayVal[0]) // cnn3dInstance.batchSizeValidation #Computed with number of extracted samples, in case I dont manage to extract as many as I wanted initially.
                
                myLogger.print3("DEBUG: For Validation, loading to shared variable that many Segments: " + str(len(channsOfSegmentsForSubepPerPathwayVal[0])))
                
                cnn3dInstance.sharedInpXVal.set_value(channsOfSegmentsForSubepPerPathwayVal[0], borrow=borrowFlag) # Primary pathway
                for index in xrange(len(channsOfSegmentsForSubepPerPathwayVal[1:])) :
                    cnn3dInstance.sharedInpXPerSubsListVal[index].set_value(channsOfSegmentsForSubepPerPathwayVal[1+index], borrow=borrowFlag)
                cnn3dInstance.sharedLabelsYVal.set_value(labelsForCentralOfSegmentsForSubepVal, borrow=borrowFlag)
                channsOfSegmentsForSubepPerPathwayVal = ""
                labelsForCentralOfSegmentsForSubepVal = ""
                
                end_loadingToGpu_time = time.clock()
                myLogger.print3("TIMING: Loading sharedVariables for Validation in epoch|subepoch="+str(epoch)+"|"+str(subepoch)+" took time: "+str(end_loadingToGpu_time-start_loadingToGpu_time)+"(s)")
                
                
                #------------------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                #submit the parallel job
                myLogger.print3("PARALLEL: Before Validation in subepoch #" +str(subepoch) + ", the parallel job for extracting Segments for the next Training is submitted.")
                parallelJobToGetDataForNextTraining = job_server.submit(getSampledDataAndLabelsForSubepoch, #local function to call and execute in parallel.
                                                                        tupleWithParametersForTraining, #tuple with the arguments required
                                                                        tupleWithLocalFunctionsThatWillBeCalledByTheMainJob, #tuple of local functions that I need to call
                                                                        tupleWithModulesToImportWhichAreUsedByTheJobFunctions) #tuple of the external modules that I need, of which I am calling functions (not the mods of the ext-functions).
                
                #------------------------------------DO VALIDATION--------------------------------
                myLogger.print3("-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-")
                start_validationForSubepoch_time = time.clock()
                
                train0orValidation1 = 1 #validation
                vectorWithWeightsOfTheClassesForCostFunctionOfTraining = 'placeholder' #only used in training
                
                doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(myLogger,
                                                                            train0orValidation1,
                                                                            numberOfBatchesValidation, # Computed by the number of extracted samples. So, adapts.
                                                                            cnn3dInstance,
                                                                            vectorWithWeightsOfTheClassesForCostFunctionOfTraining,
                                                                            subepoch,
                                                                            validationAccuracyMonitorForEpoch)
                cnn3dInstance.freeGpuValidationData()
                
                end_validationForSubepoch_time = time.clock()
                myLogger.print3("TIMING: Validating on the batches of this subepoch #" + str(subepoch) + " took time: "+str(end_validationForSubepoch_time-start_validationForSubepoch_time)+"(s)")
                
                #Update cnn's top achieved validation accuracy if needed: (for the autoReduction of Learning Rate.)
                cnn3dInstance.checkMeanValidationAccOfLastEpochAndUpdateCnnsTopAccAchievedIfNeeded(myLogger,
                                                                                    validationAccuracyMonitorForEpoch.getMeanEmpiricalAccuracyOfEpoch(),
                                                                                    minIncreaseInValidationAccuracyConsideredForLrSchedule)
            #-------------------END OF THE VALIDATION-DURING-TRAINING-LOOP-------------------------
            
            
            #-------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
            if (not performValidationOnSamplesDuringTrainingProcessBool) and boolItIsTheVeryFirstSubepochOfThisProcess :                    
                [channsOfSegmentsForSubepPerPathwayTrain,
                labelsForCentralOfSegmentsForSubepTrain] = getSampledDataAndLabelsForSubepoch(myLogger,
                                                                        0,
                                                                        cnn3dWrapper,
                                                                        maxNumSubjectsLoadedPerSubepoch,
                                                                        imagePartsLoadedInGpuPerSubepoch,
                                                                        samplingTypeInstanceTraining,
                                                                        
                                                                        listOfFilepathsToEachChannelOfEachPatientTraining,
                                                                        
                                                                        listOfFilepathsToGtLabelsOfEachPatientTraining,
                                                                        
                                                                        providedRoiMaskForTrainingBool,
                                                                        listOfFilepathsToRoiMaskOfEachPatientTraining,
                                                                        
                                                                        providedWeightMapsToSampleForEachCategoryTraining,
                                                                        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
                                                                        
                                                                        useSameSubChannelsAsSingleScale,
                                                                        
                                                                        listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,
                                                                        
                                                                        padInputImagesBool,
                                                                        smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage,
                                                                        normAugmNone0OnImages1OrSegments2AlreadyNormalized1SubtrUpToPropOfStdAndDivideWithUpToPerc,
                                                                        reflectImageWithHalfProbDuringTraining
                                                                        )
                boolItIsTheVeryFirstSubepochOfThisProcess = False
            else :
                #It was done in parallel with the validation (or with previous training iteration, in case I am not performing validation).
                [channsOfSegmentsForSubepPerPathwayTrain,
                labelsForCentralOfSegmentsForSubepTrain] = parallelJobToGetDataForNextTraining() #fromParallelProcessing that had started from last loop when it was submitted.
                
            #-------------------------COMPUTE CLASS-WEIGHTS, TO WEIGHT COST FUNCTION AND COUNTER CLASS IMBALANCE----------------------
            #Do it for only few epochs, until I get to an ok local minima neighbourhood.
            if cnn3dInstance.numberOfEpochsTrained < numberOfEpochsToWeightTheClassesInTheCostFunction :
                numOfPatchesInTheSubepoch_notParts = np.prod(labelsForCentralOfSegmentsForSubepTrain.shape)
                actualNumOfPatchesPerClassInTheSubepoch_notParts = np.bincount(np.ravel(labelsForCentralOfSegmentsForSubepTrain).astype(int))
                # yx - y1 = (x - x1) * (y2 - y1)/(x2 - x1)
                # yx = the multiplier I currently want, y1 = the multiplier at the begining, y2 = the multiplier at the end
                # x = current epoch, x1 = epoch where linear decrease starts, x2 = epoch where linear decrease ends
                y1 = (1./(actualNumOfPatchesPerClassInTheSubepoch_notParts+TINY_FLOAT)) * (numOfPatchesInTheSubepoch_notParts*1.0/cnn3dInstance.numberOfOutputClasses)
                y2 = 1.
                x1 = 0. * number_of_subepochs # linear decrease starts from epoch=0
                x2 = numberOfEpochsToWeightTheClassesInTheCostFunction * number_of_subepochs
                x = cnn3dInstance.numberOfEpochsTrained * number_of_subepochs + subepoch
                yx = (x - x1) * (y2 - y1)/(x2 - x1) + y1
                vectorWithWeightsOfTheClassesForCostFunctionOfTraining = np.asarray(yx, dtype="float32")
                myLogger.print3("UPDATE: [Weight of Classes] Setting the weights of the classes in the cost function to: " +str(vectorWithWeightsOfTheClassesForCostFunctionOfTraining))
            else :
                vectorWithWeightsOfTheClassesForCostFunctionOfTraining = np.ones(cnn3dInstance.numberOfOutputClasses, dtype='float32')
                
            #------------------- Learning Rate Schedule ------------------------
            # TODO: make a learning-rate manager to encapsulate all of this; it is very ugly currently. All other LR schedules are applied in the outer loop, per epoch.
            if (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 4) :
                myLogger.print3("DEBUG: Going to change Learning Rate according to POLY schedule:")
                #newLearningRate = initLr * ( 1 - iter/max_iter) ^ power. Power = 0.9 in parsenet, which we validated to behave ok.
                currentIteration = cnn3dInstance.numberOfEpochsTrained * number_of_subepochs + subepoch
                max_iterations = n_epochs * number_of_subepochs
                newLearningRate = cnn3dInstance.initialLearningRate * pow( 1.0 - 1.0*currentIteration/max_iterations , 0.9)
                myLogger.print3("DEBUG: new learning rate was calculated: " +str(newLearningRate))
                cnn3dInstance.change_learning_rate_of_a_cnn(newLearningRate, myLogger)
                
            #----------------------------------LOAD TRAINING DATA ON GPU-------------------------------
            myLogger.print3("Loading Training data for subepoch #"+str(subepoch)+" on shared variable...")
            start_loadingToGpu_time = time.clock()
            
            numberOfBatchesTraining = len(channsOfSegmentsForSubepPerPathwayTrain[0]) // cnn3dInstance.batchSize #Computed with number of extracted samples, in case I dont manage to extract as many as I wanted initially.
            
            cnn3dInstance.sharedInpXTrain.set_value(channsOfSegmentsForSubepPerPathwayTrain[0], borrow=borrowFlag) # Primary pathway
            for index in xrange(len(channsOfSegmentsForSubepPerPathwayTrain[1:])) :
                cnn3dInstance.sharedInpXPerSubsListTrain[index].set_value(channsOfSegmentsForSubepPerPathwayTrain[1+index], borrow=borrowFlag)
            cnn3dInstance.sharedLabelsYTrain.set_value(labelsForCentralOfSegmentsForSubepTrain, borrow=borrowFlag)
            channsOfSegmentsForSubepPerPathwayTrain = ""
            labelsForCentralOfSegmentsForSubepTrain = ""
            
            end_loadingToGpu_time = time.clock()
            myLogger.print3("TIMING: Loading sharedVariables for Training in epoch|subepoch="+str(epoch)+"|"+str(subepoch)+" took time: "+str(end_loadingToGpu_time-start_loadingToGpu_time)+"(s)")
            
            
            #------------------------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH-----------------
            if performValidationOnSamplesDuringTrainingProcessBool :
                #submit the parallel job
                myLogger.print3("PARALLEL: Before Training in subepoch #" +str(subepoch) + ", submitting the parallel job for extracting Segments for the next Validation.")
                parallelJobToGetDataForNextValidation = job_server.submit(getSampledDataAndLabelsForSubepoch, #local function to call and execute in parallel.
                                                                            tupleWithParametersForValidation, #tuple with the arguments required
                                                                            tupleWithLocalFunctionsThatWillBeCalledByTheMainJob, #tuple of local functions that I need to call
                                                                            tupleWithModulesToImportWhichAreUsedByTheJobFunctions) #tuple of the external modules that I need, of which I am calling functions (not the mods of the ext-functions).
            else : #extract in parallel the samples for the next subepoch's training.
                myLogger.print3("PARALLEL: Before Training in subepoch #" +str(subepoch) + ", submitting the parallel job for extracting Segments for the next Training.")
                parallelJobToGetDataForNextTraining = job_server.submit(getSampledDataAndLabelsForSubepoch, #local function to call and execute in parallel.
                                                                            tupleWithParametersForTraining, #tuple with the arguments required
                                                                            tupleWithLocalFunctionsThatWillBeCalledByTheMainJob, #tuple of local functions that I need to call
                                                                            tupleWithModulesToImportWhichAreUsedByTheJobFunctions) #tuple of the external modules that I need, of which I am calling
                
            #-------------------------------START TRAINING IN BATCHES------------------------------
            myLogger.print3("-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-")
            start_trainingForSubepoch_time = time.clock()
            
            train0orValidation1 = 0 #training
            doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(myLogger,
                                                                        train0orValidation1,
                                                                        numberOfBatchesTraining,
                                                                        cnn3dInstance,
                                                                        vectorWithWeightsOfTheClassesForCostFunctionOfTraining,
                                                                        subepoch,
                                                                        trainingAccuracyMonitorForEpoch)
            cnn3dInstance.freeGpuTrainingData()
            
            end_trainingForSubepoch_time = time.clock()
            myLogger.print3("TIMING: Training on the batches of this subepoch #" + str(subepoch) + " took time: "+str(end_trainingForSubepoch_time-start_trainingForSubepoch_time)+"(s)")
            
        myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" )
        myLogger.print3("~~~~~~~~~~~~~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~~~~~~~~~~~~" )
        myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" )
        
        if performValidationOnSamplesDuringTrainingProcessBool :
            validationAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch()
        trainingAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch()
        
        del trainingAccuracyMonitorForEpoch; del validationAccuracyMonitorForEpoch;
        
        #=======================Learning Rate Schedule.=========================
        if (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 0) and (numEpochsToWaitBeforeLowerLR > 0) and (cnn3dInstance.numberOfEpochsTrained % numEpochsToWaitBeforeLowerLR)==0 :
            # STABLE LR SCHEDULE"
            myLogger.print3("DEBUG: Going to lower Learning Rate because of STABLE schedule! The CNN has now been trained for: " + str(cnn3dInstance.numberOfEpochsTrained) + " epochs. I need to decrease LR every: " + str(numEpochsToWaitBeforeLowerLR) + " epochs.")
            cnn3dInstance.divide_learning_rate_of_a_cnn_by(divideLrBy, myLogger)
        elif (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 1) and (numEpochsToWaitBeforeLowerLR > 0) :
            # AUTO LR SCHEDULE!
            if not performValidationOnSamplesDuringTrainingProcessBool : #This flag should have been set True from the start if training should do Auto-schedule. If we get in here, this is a bug.
                myLogger.print3("ERROR: For Auto-schedule I need to be performing validation-on-samples during the training-process. The flag performValidationOnSamplesDuringTrainingProcessBool should have been set to True. Instead it seems it was False and no validation was performed. This is a bug. Contact the developer, this should not have happened. Try another Learning Rate schedule for now! Exiting.")
                exit(1)
            if (cnn3dInstance.numberOfEpochsTrained >= cnn3dInstance.topMeanValidationAccuracyAchievedInEpoch[1] + numEpochsToWaitBeforeLowerLR) and \
                    (cnn3dInstance.numberOfEpochsTrained >= cnn3dInstance.lastEpochAtTheEndOfWhichLrWasLowered + numEpochsToWaitBeforeLowerLR) :
                myLogger.print3("DEBUG: Going to lower Learning Rate because of AUTO schedule! The CNN has now been trained for: " + str(cnn3dInstance.numberOfEpochsTrained) + " epochs. Epoch with last highest achieved validation accuracy: " + str(cnn3dInstance.topMeanValidationAccuracyAchievedInEpoch[1]) + ", and epoch that Learning Rate was last lowered: " + str(cnn3dInstance.lastEpochAtTheEndOfWhichLrWasLowered) + ". I waited for increase in accuracy for: " +str(numEpochsToWaitBeforeLowerLR) + " epochs. Going to lower Learning Rate...")
                cnn3dInstance.divide_learning_rate_of_a_cnn_by(divideLrBy, myLogger)
        elif (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 2) and (cnn3dInstance.numberOfEpochsTrained in lowerLrAtTheEndOfTheseEpochsPredefinedScheduleList) :
            #Predefined Schedule.
            myLogger.print3("DEBUG: Going to lower Learning Rate because of PREDEFINED schedule! The CNN has now been trained for: " + str(cnn3dInstance.numberOfEpochsTrained) + " epochs. I need to decrease after that many epochs: " + str(lowerLrAtTheEndOfTheseEpochsPredefinedScheduleList))
            cnn3dInstance.divide_learning_rate_of_a_cnn_by(divideLrBy, myLogger)
        elif (lowerLrByStable0orAuto1orPredefined2orExponential3Schedule == 3 and cnn3dInstance.numberOfEpochsTrained >= exponentialScheduleForLrAndMom[0]) :
            myLogger.print3("DEBUG: Going to lower Learning Rate and Increase Momentum because of EXPONENTIAL schedule! The CNN has now been trained for: " + str(cnn3dInstance.numberOfEpochsTrained) + " epochs.")
            minEpochToLowerLr = exponentialScheduleForLrAndMom[0]          
            #newLearningRate = initialLearningRate * gamma^t. gamma = {t-th}root(valueIwantLrToHaveAtTimepointT / initialLearningRate)
            gammaForExpSchedule = pow( ( cnn3dInstance.initialLearningRate*exponentialScheduleForLrAndMom[1] * 1.0) / cnn3dInstance.initialLearningRate, 1.0 / (n_epochs-minEpochToLowerLr))
            newLearningRate = cnn3dInstance.initialLearningRate * pow(gammaForExpSchedule, cnn3dInstance.numberOfEpochsTrained-minEpochToLowerLr + 1.0)
            #Momentum increased linearly.
            newMomentum = ((cnn3dInstance.numberOfEpochsTrained - minEpochToLowerLr + 1) - (n_epochs-minEpochToLowerLr))*1.0 / (n_epochs - minEpochToLowerLr) * (exponentialScheduleForLrAndMom[2] - cnn3dInstance.initialMomentum) + exponentialScheduleForLrAndMom[2]
            print("DEBUG: new learning rate was calculated: ", newLearningRate, " and new Momentum: ", newMomentum)
            cnn3dInstance.change_learning_rate_of_a_cnn(newLearningRate, myLogger)
            cnn3dInstance.change_momentum_of_a_cnn(newMomentum, myLogger)
            
        #================== Everything for epoch has finished. =======================
        #Training finished. Update the number of epochs that the cnn was trained.
        cnn3dInstance.increaseNumberOfEpochsTrained()
        
        myLogger.print3("SAVING: Epoch #"+str(epoch)+" finished. Saving CNN model.")
        dump_cnn_to_gzip_file_dotSave(cnn3dInstance, fileToSaveTrainedCnnModelTo+"."+datetimeNowAsStr(), myLogger)
        end_epoch_time = time.clock()
        myLogger.print3("TIMING: The whole Epoch #"+str(epoch)+" took time: "+str(end_epoch_time-start_epoch_time)+"(s)")
        myLogger.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        
        if performFullInferenceOnValidationImagesEveryFewEpochsBool and (cnn3dInstance.numberOfEpochsTrained != 0) and (cnn3dInstance.numberOfEpochsTrained % everyThatManyEpochsComputeDiceOnTheFullValidationImages == 0) :
            myLogger.print3("***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #"+str(epoch)+"...***")
            validation0orTesting1 = 0
            #do_validation_or_testing(myLogger,
            performInferenceOnWholeVolumes(myLogger,
                                    validation0orTesting1,
                                    savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation,
                                    cnn3dInstance,
                                    
                                    listOfFilepathsToEachChannelOfEachPatientValidation,
                                    
                                    providedGtForValidationBool,
                                    listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                                    
                                    providedRoiMaskForValidationBool,
                                    listOfFilepathsToRoiMaskOfEachPatientValidation,
                                    
                                    borrowFlag,
                                    listOfNamesToGiveToPredictionsIfSavingResults = "Placeholder" if not savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation else listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice,
                                    
                                    #----Preprocessing------
                                    padInputImagesBool=padInputImagesBool,
                                    smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage=smoothChannelsWithGaussFilteringStdsForNormalAndSubsampledImage,
                                    
                                    #for the cnn extension
                                    useSameSubChannelsAsSingleScale=useSameSubChannelsAsSingleScale,
                                    
                                    listOfFilepathsToEachSubsampledChannelOfEachPatient=listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
                                    
                                    #--------For FM visualisation---------
                                    saveIndividualFmImagesForVisualisation=saveIndividualFmImagesForVisualisation,
                                    saveMultidimensionalImageWithAllFms=saveMultidimensionalImageWithAllFms,
                                    indicesOfFmsToVisualisePerPathwayTypeAndPerLayer=indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                                    listOfNamesToGiveToFmVisualisationsIfSaving=listOfNamesToGiveToFmVisualisationsIfSaving
                                    )
            
    dump_cnn_to_gzip_file_dotSave(cnn3dInstance, fileToSaveTrainedCnnModelTo+".final."+datetimeNowAsStr(), myLogger)
    
    end_training_time = time.clock()
    myLogger.print3("TIMING: Training process took time: "+str(end_training_time-start_training_time)+"(s)")
    myLogger.print3("The whole do_training() function has finished.")
    
    
Ejemplo n.º 4
0
def do_training(
        sessionTf,
        saver_all,
        cnn3d,
        trainer,
        log,
        fileToSaveTrainedCnnModelTo,
        performValidationOnSamplesDuringTrainingProcessBool,
        savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation,
        listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice,
        listOfFilepathsToEachChannelOfEachPatientTraining,
        listOfFilepathsToEachChannelOfEachPatientValidation,
        listOfFilepathsToGtLabelsOfEachPatientTraining,
        providedGtForValidationBool,
        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
        providedWeightMapsToSampleForEachCategoryTraining,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
        providedWeightMapsToSampleForEachCategoryValidation,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
        providedRoiMaskForTrainingBool,
        listOfFilepathsToRoiMaskOfEachPatientTraining,  # Also needed for normalization-augmentation
        providedRoiMaskForValidationBool,
        listOfFilepathsToRoiMaskOfEachPatientValidation,
        n_epochs,  # Every epoch the CNN model is saved.
        number_of_subepochs,  # per epoch. Every subepoch Accuracy is reported
        maxNumSubjectsLoadedPerSubepoch,  # Max num of cases loaded every subepoch for segments extraction. The more, the longer loading.
        imagePartsLoadedInGpuPerSubepoch,
        imagePartsLoadedInGpuPerSubepochValidation,

        # -------Sampling Type---------
        samplingTypeInstanceTraining,  # Instance of the deepmedic/samplingType.SamplingType class for training and validation
        samplingTypeInstanceValidation,

        # -------Preprocessing-----------
        padInputImagesBool,
        # -------Data Augmentation-------
        doIntAugm_shiftMuStd_multiMuStd,
        reflectImageWithHalfProbDuringTraining,
        useSameSubChannelsAsSingleScale,
        listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,  # deprecated, not supported
        listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,  # deprecated, not supported

        # Validation
        performFullInferenceOnValidationImagesEveryFewEpochsBool,  # Even if not providedGtForValidationBool, inference will be performed if this == True, to save the results, eg for visual.
        everyThatManyEpochsComputeDiceOnTheFullValidationImages,  # Should not be == 0, except if performFullInferenceOnValidationImagesEveryFewEpochsBool == False

        # --------For FM visualisation---------
        saveIndividualFmImagesForVisualisation,
        saveMultidimensionalImageWithAllFms,
        indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
        listOfNamesToGiveToFmVisualisationsIfSaving,

        # -------- Others --------
        run_input_checks):
    """Run the epoch/subepoch training loop and checkpoint the model after every epoch.

    The loop starts from the epoch counter reported by
    ``trainer.get_num_epochs_trained_tfv()`` and continues while that counter
    is below ``n_epochs``. Every epoch consists of ``number_of_subepochs``
    subepochs. In each subepoch the function:

    1. optionally validates on sampled segments (when
       ``performValidationOnSamplesDuringTrainingProcessBool`` is true);
    2. trains on sampled segments.

    Segments are extracted by ``getSampledDataAndLabelsForSubepoch``. Only the
    very first extraction of the process is done inline; every later one is
    submitted to a ``pp`` job server just before the current
    validation/training step starts, and its result is fetched when needed.

    At the end of each epoch the accuracy monitors report, the trainer's
    end-of-epoch updates run (``trainer.run_updates_end_of_ep``), a checkpoint
    is written with ``saver_all.save`` and, if requested and due, inference on
    the whole validation volumes is performed.

    The job server is always shut down before returning or propagating an
    exception, so no sampling worker is left behind.

    Returns:
        None.
    """
    start_training_time = time.time()

    # Used because I cannot pass cnn3d to the sampling function.
    # This is because the parallel process used to load theano again. And created problems in the GPU
    # when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)

    # ---------To run PARALLEL the extraction of parts for the next subepoch---
    ppservers = ()  # tuple of all parallel python servers to connect with
    # ncpus=1: a single local worker is requested (no auto-detection of the worker count).
    job_server = pp.Server(ncpus=1, ppservers=ppservers)

    # Positional arguments of getSampledDataAndLabelsForSubepoch for training. The same tuple serves
    # the inline call of the very first subepoch and every parallel job.
    tupleWithParametersForTraining = (
        log, "train", run_input_checks, cnn3dWrapper,
        maxNumSubjectsLoadedPerSubepoch, imagePartsLoadedInGpuPerSubepoch,
        samplingTypeInstanceTraining,
        listOfFilepathsToEachChannelOfEachPatientTraining,
        listOfFilepathsToGtLabelsOfEachPatientTraining,
        providedRoiMaskForTrainingBool,
        listOfFilepathsToRoiMaskOfEachPatientTraining,
        providedWeightMapsToSampleForEachCategoryTraining,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientTraining,
        useSameSubChannelsAsSingleScale,
        listOfFilepathsToEachSubsampledChannelOfEachPatientTraining,
        padInputImagesBool, doIntAugm_shiftMuStd_multiMuStd,
        reflectImageWithHalfProbDuringTraining)

    # Validation arguments shared by the inline call and the parallel job; the two trailing
    # augmentation settings are appended separately for each of them below.
    commonParametersForValidation = (
        log,
        "val",
        run_input_checks,
        cnn3dWrapper,
        maxNumSubjectsLoadedPerSubepoch,
        imagePartsLoadedInGpuPerSubepochValidation,
        samplingTypeInstanceValidation,
        listOfFilepathsToEachChannelOfEachPatientValidation,
        listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
        providedRoiMaskForValidationBool,
        listOfFilepathsToRoiMaskOfEachPatientValidation,
        providedWeightMapsToSampleForEachCategoryValidation,
        forEachSamplingCategory_aListOfFilepathsToWeightMapsOfEachPatientValidation,
        useSameSubChannelsAsSingleScale,
        listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,
        padInputImagesBool)
    # NOTE(review): the parallel job disables intensity augmentation with [0, -1, -1, -1] whereas the
    # inline call of the first subepoch uses [False, [], []]. Both are kept as they were; confirm
    # against the sampler that the two spellings are equivalent.
    tupleWithParametersForValidation = commonParametersForValidation + (
        [0, -1, -1, -1],  # don't perform intensity-augmentation during validation.
        [0, 0, 0],  # don't perform reflection-augmentation during validation.
    )
    tupleWithLocalFunctionsThatWillBeCalledByTheMainJob = ()
    tupleWithModulesToImportWhichAreUsedByTheJobFunctions = (
        "from __future__ import absolute_import, print_function, division",
        "time", "numpy as np",
        "from deepmedic.dataManagement.sampling import *")

    def submit_sampling_job(job_parameters):
        """Submit one segment-extraction job; the returned object is called to fetch the result."""
        return job_server.submit(
            getSampledDataAndLabelsForSubepoch,  # function to execute in parallel.
            job_parameters,  # tuple with the arguments required
            tupleWithLocalFunctionsThatWillBeCalledByTheMainJob,  # tuple of local functions that the job calls
            tupleWithModulesToImportWhichAreUsedByTheJobFunctions)  # modules the job functions need imported

    # To know so that in the very first subepoch I sequentially load the data for it.
    boolItIsTheVeryFirstSubepochOfThisProcess = True
    # ------End for parallel------

    try:
        model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(
            session=sessionTf)
        while model_num_epochs_trained < n_epochs:
            epoch = model_num_epochs_trained

            trainingAccuracyMonitorForEpoch = AccuracyOfEpochMonitorSegmentation(
                log, 0, model_num_epochs_trained, cnn3d.num_classes,
                number_of_subepochs)
            if performValidationOnSamplesDuringTrainingProcessBool:
                validationAccuracyMonitorForEpoch = AccuracyOfEpochMonitorSegmentation(
                    log, 1, model_num_epochs_trained, cnn3d.num_classes,
                    number_of_subepochs)
            else:
                validationAccuracyMonitorForEpoch = None

            log.print3(
                "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
            )
            log.print3("~~~~~~~~~~~~~~~~~~~~Starting new Epoch! Epoch #" +
                       str(epoch) + "/" + str(n_epochs) +
                       " ~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3(
                "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
            )
            start_epoch_time = time.time()

            for subepoch in range(number_of_subepochs):
                log.print3(
                    "**************************************************************************************************"
                )
                log.print3("************* Starting new Subepoch: #" +
                           str(subepoch) + "/" + str(number_of_subepochs) +
                           " *************")
                log.print3(
                    "**************************************************************************************************"
                )

                # -------------------------GET DATA FOR THIS SUBEPOCH's VALIDATION---------------------------------
                if performValidationOnSamplesDuringTrainingProcessBool:
                    if boolItIsTheVeryFirstSubepochOfThisProcess:
                        # Nothing was prefetched yet, so sample inline.
                        (channsOfSegmentsForSubepPerPathwayVal,
                         labelsForCentralOfSegmentsForSubepVal
                         ) = getSampledDataAndLabelsForSubepoch(
                             *commonParametersForValidation,
                             doIntAugm_shiftMuStd_multiMuStd=[False, [], []],
                             reflectImageWithHalfProbDuringTraining=[0, 0, 0])
                        boolItIsTheVeryFirstSubepochOfThisProcess = False
                    else:
                        # Submitted before the training step of the previous subepoch; just grab the results.
                        (channsOfSegmentsForSubepPerPathwayVal,
                         labelsForCentralOfSegmentsForSubepVal
                         ) = parallelJobToGetDataForNextValidation()

                    # Computed with the number of extracted samples, in case fewer than requested were extracted.
                    numberOfBatchesValidation = len(
                        channsOfSegmentsForSubepPerPathwayVal[0]
                    ) // cnn3d.batchSize["val"]

                    # ----------------SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING-----------------
                    log.print3(
                        "PARALLEL: Before Validation in subepoch #" +
                        str(subepoch) +
                        ", the parallel job for extracting Segments for the next Training is submitted."
                    )
                    parallelJobToGetDataForNextTraining = submit_sampling_job(
                        tupleWithParametersForTraining)

                    # ------------------------------------DO VALIDATION--------------------------------
                    log.print3(
                        "-V-V-V-V-V- Now Validating for this subepoch before commencing the training iterations... -V-V-V-V-V-"
                    )
                    start_validationForSubepoch_time = time.time()

                    doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(
                        log,
                        sessionTf,
                        "val",
                        numberOfBatchesValidation,  # Computed by the number of extracted samples. So, adapts.
                        cnn3d,
                        subepoch,
                        validationAccuracyMonitorForEpoch,
                        channsOfSegmentsForSubepPerPathwayVal,
                        labelsForCentralOfSegmentsForSubepVal)

                    end_validationForSubepoch_time = time.time()
                    log.print3(
                        "TIMING: Validating on the batches of this subepoch #" +
                        str(subepoch) + " took time: " +
                        str(end_validationForSubepoch_time -
                            start_validationForSubepoch_time) + "(s)")

                # -------------------END OF THE VALIDATION-DURING-TRAINING-LOOP-------------------------

                # -------------------------GET DATA FOR THIS SUBEPOCH's TRAINING---------------------------------
                if (not performValidationOnSamplesDuringTrainingProcessBool
                        ) and boolItIsTheVeryFirstSubepochOfThisProcess:
                    # No validation and nothing prefetched yet: sample inline with the same
                    # arguments that the parallel training jobs receive.
                    (channsOfSegmentsForSubepPerPathwayTrain,
                     labelsForCentralOfSegmentsForSubepTrain
                     ) = getSampledDataAndLabelsForSubepoch(
                         *tupleWithParametersForTraining)
                    boolItIsTheVeryFirstSubepochOfThisProcess = False
                else:
                    # It was done in parallel with the validation (or with the previous training
                    # iteration, in case validation is not performed).
                    (channsOfSegmentsForSubepPerPathwayTrain,
                     labelsForCentralOfSegmentsForSubepTrain
                     ) = parallelJobToGetDataForNextTraining()

                # Computed with the number of extracted samples, in case fewer than requested were extracted.
                numberOfBatchesTraining = len(
                    channsOfSegmentsForSubepPerPathwayTrain[0]
                ) // cnn3d.batchSize["train"]

                # ------SUBMIT PARALLEL JOB TO GET VALIDATION/TRAINING DATA (if val is/not performed) FOR NEXT SUBEPOCH------
                if performValidationOnSamplesDuringTrainingProcessBool:
                    log.print3(
                        "PARALLEL: Before Training in subepoch #" + str(subepoch) +
                        ", submitting the parallel job for extracting Segments for the next Validation."
                    )
                    parallelJobToGetDataForNextValidation = submit_sampling_job(
                        tupleWithParametersForValidation)
                else:  # extract in parallel the samples for the next subepoch's training.
                    log.print3(
                        "PARALLEL: Before Training in subepoch #" + str(subepoch) +
                        ", submitting the parallel job for extracting Segments for the next Training."
                    )
                    parallelJobToGetDataForNextTraining = submit_sampling_job(
                        tupleWithParametersForTraining)

                # -------------------------------START TRAINING IN BATCHES------------------------------
                log.print3(
                    "-T-T-T-T-T- Now Training for this subepoch... This may take a few minutes... -T-T-T-T-T-"
                )
                start_trainingForSubepoch_time = time.time()

                doTrainOrValidationOnBatchesAndReturnMeanAccuraciesOfSubepoch(
                    log, sessionTf, "train", numberOfBatchesTraining, cnn3d,
                    subepoch, trainingAccuracyMonitorForEpoch,
                    channsOfSegmentsForSubepPerPathwayTrain,
                    labelsForCentralOfSegmentsForSubepTrain)

                end_trainingForSubepoch_time = time.time()
                log.print3("TIMING: Training on the batches of this subepoch #" +
                           str(subepoch) + " took time: " +
                           str(end_trainingForSubepoch_time -
                               start_trainingForSubepoch_time) + "(s)")

            log.print3(
                "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
            )
            log.print3(
                "~~~~~~~~~~~~~~~~~~ Epoch #" + str(epoch) +
                " finished. Reporting Accuracy over whole epoch. ~~~~~~~~~~~~~~~~~~"
            )
            log.print3(
                "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
            )

            if performValidationOnSamplesDuringTrainingProcessBool:
                validationAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch()
            trainingAccuracyMonitorForEpoch.reportMeanAccyracyOfEpoch()

            if performValidationOnSamplesDuringTrainingProcessBool:
                mean_val_acc_of_ep = validationAccuracyMonitorForEpoch.getMeanEmpiricalAccuracyOfEpoch()
            else:
                mean_val_acc_of_ep = None
            # Updates LR schedule if needed, and increases number of epochs trained.
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep)
            # Re-read the counter: it drives both the while-condition and the whole-volume check below.
            model_num_epochs_trained = trainer.get_num_epochs_trained_tfv().eval(
                session=sessionTf)

            del trainingAccuracyMonitorForEpoch
            del validationAccuracyMonitorForEpoch
            # ================== Everything for epoch has finished. =======================

            log.print3("SAVING: Epoch #" + str(epoch) +
                       " finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetimeNowAsStr()
            saver_all.save(sessionTf,
                           filename_to_save_with + ".model.ckpt",
                           write_meta_graph=False)

            end_epoch_time = time.time()
            log.print3("TIMING: The whole Epoch #" + str(epoch) + " took time: " +
                       str(end_epoch_time - start_epoch_time) + "(s)")
            log.print3(
                "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
            )

            if performFullInferenceOnValidationImagesEveryFewEpochsBool and (
                    model_num_epochs_trained != 0
            ) and (model_num_epochs_trained %
                   everyThatManyEpochsComputeDiceOnTheFullValidationImages == 0):
                log.print3(
                    "***Starting validation with Full Inference / Segmentation on validation subjects for Epoch #"
                    + str(epoch) + "...***")

                if savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation:
                    namesForPredictions = listOfNamesToGiveToPredictionsValidationIfSavingWhenEvalDice
                else:
                    namesForPredictions = "Placeholder"

                performInferenceOnWholeVolumes(
                    sessionTf,
                    cnn3d,
                    log,
                    "val",
                    savePredictionImagesSegmentationAndProbMapsListWhenEvaluatingDiceForValidation,
                    listOfFilepathsToEachChannelOfEachPatientValidation,
                    providedGtForValidationBool,
                    listOfFilepathsToGtLabelsOfEachPatientValidationOnSamplesAndDsc,
                    providedRoiMaskForValidationBool,
                    listOfFilepathsToRoiMaskOfEachPatientValidation,
                    listOfNamesToGiveToPredictionsIfSavingResults=namesForPredictions,

                    # ----Preprocessing------
                    padInputImagesBool=padInputImagesBool,

                    # for the cnn extension
                    useSameSubChannelsAsSingleScale=useSameSubChannelsAsSingleScale,
                    listOfFilepathsToEachSubsampledChannelOfEachPatient=
                    listOfFilepathsToEachSubsampledChannelOfEachPatientValidation,

                    # --------For FM visualisation---------
                    saveIndividualFmImagesForVisualisation=
                    saveIndividualFmImagesForVisualisation,
                    saveMultidimensionalImageWithAllFms=
                    saveMultidimensionalImageWithAllFms,
                    indicesOfFmsToVisualisePerPathwayTypeAndPerLayer=
                    indicesOfFmsToVisualisePerPathwayTypeAndPerLayer,
                    listOfNamesToGiveToFmVisualisationsIfSaving=
                    listOfNamesToGiveToFmVisualisationsIfSaving)

        end_training_time = time.time()
    finally:
        # The job submitted during the last subepoch is never consumed, and the server is local to
        # this function: shut it down so its worker process does not outlive the training run,
        # including when an exception is propagating.
        job_server.destroy()

    log.print3("TIMING: Training process took time: " +
               str(end_training_time - start_training_time) + "(s)")
    log.print3("The whole do_training() function has finished.")