Exemple #1
0
 def run_session(self, *args):
     """Build the test-time graph, load or initialise the weights, run inference.

     ``args`` must contain exactly two items, in this order: the device
     handed to ``graph.device()`` and the model-parameters object that
     supplies the architecture arguments and the test input dimensions.

     The value returned by ``inference_on_whole_volumes`` is discarded;
     this method returns ``None``.
     """
     log = self._log
     sess_device, model_params = args

     graph = tf.Graph()
     with graph.as_default():
         # Throws an error if GPU is specified but not available.
         with graph.device(sess_device):
             log.print3("=========== Making the CNN graph... ===============")
             cnn3d = Cnn3d()
             with tf.compat.v1.variable_scope("net"):
                 # Creates network's graph (no optimizer).
                 cnn3d.make_cnn_model(*model_params.get_args_for_arch())
                 inp_dims_test = model_params.get_inp_dims_hr_path('test')
                 inp_plchldrs, inp_shapes_per_path = cnn3d.create_inp_plchldrs(inp_dims_test, 'test')
                 p_y_given_x = cnn3d.apply(inp_plchldrs, 'infer', 'test', verbose=True, log=log)

         for header_line in ("=========== Compiling the Testing Function ============",
                             "=======================================================\n"):
             log.print3(header_line)

         fms_to_save = self._params.inds_fms_per_pathtype_per_layer_to_save
         cnn3d.setup_ops_n_feeds_to_test(log, inp_plchldrs, p_y_given_x, fms_to_save)

         # The saver only needs to cover the variables created under the "net" scope.
         net_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope="net")
         saver_net = tf.compat.v1.train.Saver(var_list=net_vars)
         # TF2: dict_vars_net = {'net_var'+str(i): v for i, v in enumerate(coll_vars_net)}
         # TF2: ckpt_net = tf.train.Checkpoint(**dict_vars_net)

     session_config = tf.compat.v1.ConfigProto(log_device_placement=False,
                                               device_count={'CPU': 999, 'GPU': 99})
     with tf.compat.v1.Session(graph=graph, config=session_config) as sess:
         model_path = self._params.get_path_to_load_model_from()
         if model_path is None:
             # Asks user whether to continue with randomly initialized model.
             self._ask_user_if_test_with_random()
             log.print3("")
             log.print3("=========== Initializing network variables  ===============")
             # .run() relies on `sess` being the default session inside this `with`.
             tf.compat.v1.variables_initializer(var_list=net_vars).run()
             log.print3("Model variables were initialized.")
         else:
             log.print3("=========== Loading parameters from specified saved model ===============")
             # A directory means "use its latest checkpoint"; anything else is used as given.
             if os.path.isdir(model_path):
                 chkpt_fname = tf.train.latest_checkpoint(model_path)
             else:
                 chkpt_fname = model_path
             log.print3("Loading parameters from:" + str(chkpt_fname))
             try:
                 saver_net.restore(sess, chkpt_fname)
                 # TF2: ckpt_net.restore(chkpt_fname)
                 log.print3("Parameters were loaded.")
             except Exception as e:
                 handle_exception_tf_restore(log, e)

         for banner_line in ("",
                             "======================================================",
                             "=========== Testing with the CNN model ===============",
                             "======================================================"):
             log.print3(banner_line)

         # Positional argument list: session and model first, then the testing
         # arguments from the parameters object, then the per-path input shapes.
         testing_args = [sess, cnn3d] + self._params.get_args_for_testing() + [inp_shapes_per_path]
         inference_on_whole_volumes(*testing_args)

     for banner_line in ("",
                         "======================================================",
                         "=========== Testing session finished =================",
                         "======================================================"):
         log.print3(banner_line)
Exemple #2
0
def do_training(sessionTf,
                saver_all,
                cnn3d,
                trainer,
                tensorboard_loggers,
                
                log,
                fileToSaveTrainedCnnModelTo,

                val_on_samples,
                savePredictedSegmAndProbsDict,

                namesForSavingSegmAndProbs,
                suffixForSegmAndProbsDict,

                paths_per_chan_per_subj_train,
                paths_per_chan_per_subj_val,

                paths_to_lbls_per_subj_train,
                paths_to_lbls_per_subj_val,

                paths_to_wmaps_per_sampl_cat_per_subj_train,
                paths_to_wmaps_per_sampl_cat_per_subj_val,

                paths_to_masks_per_subj_train,
                paths_to_masks_per_subj_val,

                n_epochs,  # Every epoch the CNN model is saved.
                n_subepochs,  # per epoch. Every subepoch Accuracy is reported
                max_n_cases_per_subep_train,  # Max num of subjects loaded every subep for sampling
                n_samples_per_subep_train,
                n_samples_per_subep_val,
                num_parallel_proc_sampling,  # -1: seq. 0: thread for sampling. >0: multiprocess sampling

                # -------Sampling Type---------
                sampling_type_inst_tr,
                # Instance of the deepmedic/samplingType.SamplingType class for training and validation
                sampling_type_inst_val,
                batchsize_train,
                batchsize_val_samples,
                batchsize_val_whole,

                # -------Data Augmentation-------
                augm_img_prms,
                augm_sample_prms,

                # Validation
                val_on_whole_volumes,
                n_epochs_between_val_on_whole_vols,

                # --------For FM visualisation---------
                save_fms_flag,
                idxs_fms_to_save,
                namesForSavingFms,

                # --- Data Compatibility Checks ---
                run_input_checks,

                # -------- Pre-processing ------
                pad_input,
                norm_prms
                ):
    """Run the epoch/subepoch training loop, checkpointing after every epoch.

    The loop continues while the trainer's epoch counter (evaluated in
    `sessionTf`) is below `n_epochs`. Each epoch runs `n_subepochs` subepochs;
    a subepoch optionally validates on sampled segments (`val_on_samples`) and
    then trains on freshly sampled segments via `process_in_batches`. At the
    end of an epoch the accuracy monitors report their metrics,
    `trainer.run_updates_end_of_ep` is called, a checkpoint is written with
    `saver_all`, and, when scheduled, `inference_on_whole_volumes` is run on
    the validation subjects. After the last epoch a ".final." checkpoint is
    saved.

    When `num_parallel_proc_sampling > -1`, sampling calls are submitted to a
    single-worker `ThreadPool` so that samples for the next step are extracted
    while the current step is processed. The order in which jobs are submitted
    and collected below is what keeps the two "submitted" flags consistent.

    Args:
        sessionTf: Session used to evaluate the epoch counter and passed on to
            batch processing, saving and whole-volume inference.
        saver_all: Object whose `save(session, path, write_meta_graph=False)`
            writes the checkpoints.
        cnn3d: The model. Its `num_classes` is read here; it is also wrapped
            in `CnnWrapperForSampling` for the sampling calls.
        trainer: Provides `get_num_epochs_trained_tfv()` and
            `run_updates_end_of_ep(...)`.
        tensorboard_loggers: `None`, or a mapping indexed with 'train' and 'val'.
        log: Logger exposing `print3`.
        fileToSaveTrainedCnnModelTo: Path prefix for saved checkpoints; a
            timestamp and ".model.ckpt" are appended.
        val_on_samples: Whether to validate on sampled segments each subepoch.
        savePredictedSegmAndProbsDict, namesForSavingSegmAndProbs,
        suffixForSegmAndProbsDict: Forwarded to `inference_on_whole_volumes`.
        paths_per_chan_per_subj_*, paths_to_lbls_per_subj_*,
        paths_to_wmaps_per_sampl_cat_per_subj_*, paths_to_masks_per_subj_*:
            Train/val data paths, forwarded to sampling (and, for val, to
            whole-volume inference).
        n_epochs: Total number of epochs the model should reach.
        n_subepochs: Subepochs per epoch.
        max_n_cases_per_subep_train: Forwarded to sampling (used for both the
            train and the val sampling argument tuples).
        n_samples_per_subep_train, n_samples_per_subep_val: Forwarded to sampling.
        num_parallel_proc_sampling: Values > -1 enable the background sampling
            pool here; the value is also forwarded to `get_samples_for_subepoch`.
        sampling_type_inst_tr, sampling_type_inst_val: Forwarded to sampling.
        batchsize_train, batchsize_val_samples: Batch sizes used to derive the
            number of batches from the extracted samples.
        batchsize_val_whole: Forwarded to `inference_on_whole_volumes`.
        augm_img_prms, augm_sample_prms: Augmentation settings, used for
            training sampling only (validation passes `None`).
        val_on_whole_volumes: Whether whole-volume validation is enabled.
        n_epochs_between_val_on_whole_vols: Whole-volume validation runs after
            epochs where `(epochs_trained + 1)` is a multiple of this value.
        save_fms_flag, idxs_fms_to_save, namesForSavingFms: Forwarded to
            `inference_on_whole_volumes`.
        run_input_checks, pad_input, norm_prms: Forwarded to sampling and to
            whole-volume inference.

    Returns:
        0 when training completed and the final model was saved; 1 when an
        exception (including `KeyboardInterrupt`) was caught, in which case the
        final model is NOT saved.
    """
    id_str = "[MAIN|PID:" + str(os.getpid()) + "]"
    start_time_train = time.time()

    # I cannot pass cnn3d to the sampling function, because the pp module used to reload theano. 
    # This created problems in the GPU when cnmem is used. Not sure this is needed with Tensorflow. Probably.
    cnn3dWrapper = CnnWrapperForSampling(cnn3d)

    # Positional argument tuples for get_samples_for_subepoch(); train and val
    # must list their entries in the same order.
    args_for_sampling_tr = (log,
                            "train",
                            num_parallel_proc_sampling,
                            run_input_checks,
                            cnn3dWrapper,
                            max_n_cases_per_subep_train,
                            n_samples_per_subep_train,
                            sampling_type_inst_tr,
                            paths_per_chan_per_subj_train,
                            paths_to_lbls_per_subj_train,
                            paths_to_masks_per_subj_train,
                            paths_to_wmaps_per_sampl_cat_per_subj_train,
                            pad_input,
                            norm_prms,
                            augm_img_prms,
                            augm_sample_prms
                            )
    # NOTE(review): validation reuses max_n_cases_per_subep_train; this signature
    # has no separate validation counterpart. Confirm this is intended.
    args_for_sampling_val = (log,
                             "val",
                             num_parallel_proc_sampling,
                             run_input_checks,
                             cnn3dWrapper,
                             max_n_cases_per_subep_train,
                             n_samples_per_subep_val,
                             sampling_type_inst_val,
                             paths_per_chan_per_subj_val,
                             paths_to_lbls_per_subj_val,
                             paths_to_masks_per_subj_val,
                             paths_to_wmaps_per_sampl_cat_per_subj_val,
                             pad_input,
                             norm_prms,
                             None,  # no augmentation in val.
                             None  # no augmentation in val.
                             )

    # Track whether an async sampling job is currently pending for train / val.
    sampling_job_submitted_train = False
    sampling_job_submitted_val = False
    # For parallel extraction of samples for next train/val while processing previous iteration.
    mp_pool = None
    if num_parallel_proc_sampling > -1:  # Sample in the background.
        # One worker only: submitted sampling jobs run one at a time, in submission order.
        mp_pool = ThreadPool(processes=1)  # Or multiprocessing.Pool(...), same API.

    try:
        # The epoch counter lives in the TF graph, so it is read back through the session.
        n_eps_trained_model = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)
        while n_eps_trained_model < n_epochs:
            epoch = n_eps_trained_model

            tb_log_tr = tensorboard_loggers['train'] if tensorboard_loggers is not None else None
            acc_monitor_ep_tr = AccuracyMonitorForEpSegm(log, 0,
                                                         n_eps_trained_model,
                                                         cnn3d.num_classes,
                                                         n_subepochs,
                                                         tb_log_tr)

            tb_log_val = tensorboard_loggers['val'] if tensorboard_loggers is not None else None
            # A validation monitor is only needed when some form of validation is enabled.
            acc_monitor_ep_val = None
            if val_on_samples or val_on_whole_volumes:
                acc_monitor_ep_val = AccuracyMonitorForEpSegm(log, 1,
                                                              n_eps_trained_model,
                                                              cnn3d.num_classes,
                                                              n_subepochs,
                                                              tb_log_val)
            
            # Whole-volume validation runs every n_epochs_between_val_on_whole_vols epochs.
            val_on_whole_vols_after_this_ep = False
            if val_on_whole_volumes and (n_eps_trained_model + 1) % n_epochs_between_val_on_whole_vols == 0:
                val_on_whole_vols_after_this_ep = True
                
            log.print3("")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~\t\t\t Starting new Epoch! Epoch #" + str(epoch) + "/" + str(n_epochs) + "  \t\t\t~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            start_time_ep = time.time()

            for subep in range(n_subepochs):
                log.print3("")
                log.print3("***********************************************************************************")
                log.print3("*\t\t\t Starting new Subepoch: #" + str(subep) + "/" + str(n_subepochs) + " \t\t\t*")
                log.print3("***********************************************************************************")

                # -------------------- GET DATA FOR THIS SUBEPOCH's VALIDATION -----------------------
                if val_on_samples:
                    if mp_pool is None:  # Sequential processing.
                        log.print3(id_str + " NO MULTIPROC: Sampling for subepoch #" + str(subep) +\
                                   " [VALIDATION] will be done by main thread.")
                        (channs_samples_per_path_val,
                         lbls_samples_per_path_val) = get_samples_for_subepoch(*args_for_sampling_val)
                    elif sampling_job_submitted_val:  # done parallel with training of previous epoch.
                        (channs_samples_per_path_val,
                         lbls_samples_per_path_val) = sampling_job_val.get()
                        sampling_job_submitted_val = False
                    else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                        assert subep == 0
                        log.print3(id_str + " MULTIPROC: Before Validation in subepoch #" + str(subep) +\
                                   ", submitting sampling job for next [VALIDATION].")
                        # Nothing was prefetched, so submit and immediately block on the result.
                        sampling_job_val = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_val)
                        (channs_samples_per_path_val,
                         lbls_samples_per_path_val) = sampling_job_val.get()
                        sampling_job_submitted_val = False

                    # ----------- SUBMIT PARALLEL JOB TO GET TRAINING DATA FOR NEXT TRAINING -----------------
                    if mp_pool is not None:
                        log.print3(id_str + " MULTIPROC: Before Validation in subepoch #" + str(subep) +\
                                   ", submitting sampling job for next [TRAINING].")
                        # Training samples are extracted while validation runs below.
                        sampling_job_tr = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_tr)
                        sampling_job_submitted_train = True

                    # ------------------------------------DO VALIDATION--------------------------------
                    log.print3("V-V-V-V- Validating for subepoch before starting training iterations -V-V-V-V")
                    start_time_val_subep = time.time()
                    # Calc num of batches from extracted samples, in case not extracted as much as requested.
                    n_batches_val = len(channs_samples_per_path_val[0]) // batchsize_val_samples
                    process_in_batches(log,
                                       sessionTf,
                                       "val",
                                       n_batches_val,
                                       batchsize_val_samples,
                                       cnn3d,
                                       acc_monitor_ep_val,
                                       channs_samples_per_path_val,
                                       lbls_samples_per_path_val)
                    log.print3("TIMING: Validation on batches of subepoch #" + str(subep) +\
                               " lasted: {0:.1f}".format(time.time() - start_time_val_subep) + " secs.")

                # ----------------------- GET DATA FOR THIS SUBEPOCH's TRAINING ------------------------------
                if mp_pool is None:  # Sequential processing.
                    log.print3(id_str + " NO MULTIPROC: Sampling for subepoch #" + str(subep) +\
                               " [TRAINING] will be done by main thread.")
                    (channs_samples_per_path_tr,
                     lbls_samples_per_path_tr) = get_samples_for_subepoch(*args_for_sampling_tr)
                elif sampling_job_submitted_train:  # done parallel with train/val of previous epoch.
                    (channs_samples_per_path_tr,
                     lbls_samples_per_path_tr) = sampling_job_tr.get()
                    sampling_job_submitted_train = False
                else:  # Not previously submitted in case of first epoch or after a full-volumes validation.
                    assert subep == 0
                    log.print3(id_str + " MULTIPROC: Before Training in subepoch #" + str(subep) +\
                               ", submitting sampling job for next [TRAINING].")
                    # Nothing was prefetched, so submit and immediately block on the result.
                    sampling_job_tr = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_tr)
                    (channs_samples_per_path_tr,
                     lbls_samples_per_path_tr) = sampling_job_tr.get()
                    sampling_job_submitted_train = False

                # ----- SUBMIT PARALLEL JOB TO GET VAL / TRAIN (if no val) DATA FOR NEXT SUBEPOCH -----
                # Skip prefetching on the last subepoch of an epoch that is followed by
                # whole-volume validation; the next epoch then takes the "not submitted" path.
                if mp_pool is not None and not (val_on_whole_vols_after_this_ep and (subep == n_subepochs - 1)):
                    if val_on_samples:
                        log.print3(id_str + " MULTIPROC: Before Training in subepoch #" + str(subep) +\
                                   ", submitting sampling job for next [VALIDATION].")
                        sampling_job_val = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_val)
                        sampling_job_submitted_val = True
                    else:
                        log.print3(id_str + " MULTIPROC: Before Training in subepoch #" + str(subep) +\
                                   ", submitting sampling job for next [TRAINING].")
                        sampling_job_tr = mp_pool.apply_async(get_samples_for_subepoch, args_for_sampling_tr)
                        sampling_job_submitted_train = True

                # ------------------------------ START TRAINING IN BATCHES -----------------------------
                log.print3("-T-T-T-T- Training for this subepoch... May take a few minutes... -T-T-T-T-")
                start_time_train_subep = time.time()
                # Calc num of batches from extracted samples, in case not extracted as much as requested.
                n_batches_train = len(channs_samples_per_path_tr[0]) // batchsize_train
                process_in_batches(log,
                                   sessionTf,
                                   "train",
                                   n_batches_train,
                                   batchsize_train,
                                   cnn3d,
                                   acc_monitor_ep_tr,
                                   channs_samples_per_path_tr,
                                   lbls_samples_per_path_tr)
                log.print3("TIMING: Training on batches of this subepoch #" + str(subep) +\
                           " lasted: {0:.1f}".format(time.time() - start_time_train_subep) + " secs.")

            log.print3("")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
            log.print3("~~~~~~ Epoch #" + str(epoch) + " finished. Reporting Accuracy over whole epoch. ~~~~~~~")
            log.print3("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_samples:
                acc_monitor_ep_val.report_metrics_samples_ep()
            acc_monitor_ep_tr.report_metrics_samples_ep()

            # The mean validation accuracy is only available when validating on samples.
            mean_val_acc_of_ep = acc_monitor_ep_val.get_avg_accuracy_ep() if val_on_samples else None
            # Updates LR schedule if needed, and increases number of epochs trained.
            trainer.run_updates_end_of_ep(log, sessionTf, mean_val_acc_of_ep)
            # Re-read the counter; this is what advances the enclosing while-loop.
            n_eps_trained_model = trainer.get_num_epochs_trained_tfv().eval(session=sessionTf)

            log.print3("SAVING: Epoch #" + str(epoch) + " finished. Saving CNN model.")
            filename_to_save_with = fileToSaveTrainedCnnModelTo + "." + datetime_now_str()
            saver_all.save(sessionTf, filename_to_save_with + ".model.ckpt", write_meta_graph=False)

            log.print3("TIMING: Whole Epoch #" + str(epoch) +\
                       " lasted: {0:.1f}".format(time.time() - start_time_ep) + " secs.")
            log.print3("~~~~~~~~~~~~~~~~~~~ End of Training Epoch. Model was Saved. ~~~~~~~~~~~~~~~~~~~~~~~~")

            if val_on_whole_vols_after_this_ep:
                log.print3("***Start validation by segmenting whole subjects for Epoch #" + str(epoch) + "***")

                mean_metrics_val_whole_vols = inference_on_whole_volumes(sessionTf,
                                                                         cnn3d,
                                                                         log,
                                                                         "val",
                                                                         savePredictedSegmAndProbsDict,
                                                                         paths_per_chan_per_subj_val,
                                                                         paths_to_lbls_per_subj_val,
                                                                         paths_to_masks_per_subj_val,
                                                                         namesForSavingSegmAndProbs,
                                                                         suffixForSegmAndProbsDict,
                                                                         # Hyper parameters
                                                                         batchsize_val_whole,
                                                                         # Data compatibility checks
                                                                         run_input_checks,
                                                                         # Pre-Processing
                                                                         pad_input,
                                                                         norm_prms,
                                                                         # Saving feature maps
                                                                         save_fms_flag,
                                                                         idxs_fms_to_save,
                                                                         namesForSavingFms)
                
                acc_monitor_ep_val.report_metrics_whole_vols(mean_metrics_val_whole_vols)

            # Drop this epoch's monitors; new ones are created at the top of the next epoch.
            del acc_monitor_ep_tr
            del acc_monitor_ep_val

        log.print3("TIMING: Training process lasted: {0:.1f}".format(time.time() - start_time_train) + " secs.")

    except (Exception, KeyboardInterrupt) as e:
        # Any failure (or Ctrl-C) is logged, the pool is torn down, and an error code is
        # returned instead of re-raising. The final model is not saved on this path.
        log.print3("\n\n ERROR: Caught exception in do_training(): " + str(e) + "\n")
        log.print3(traceback.format_exc())
        if mp_pool is not None:
            log.print3("Terminating worker pool.")
            mp_pool.terminate()
            mp_pool.join()  # Will wait. A KeybInt will kill this (py3)
        return 1
    else:
        # Normal completion: let the pool finish gracefully.
        if mp_pool is not None:
            log.print3("Closing worker pool.")
            mp_pool.close()
            mp_pool.join()

    # Save the final trained model.
    filename_to_save_with = fileToSaveTrainedCnnModelTo + ".final." + datetime_now_str()
    log.print3("Saving the final model at:" + str(filename_to_save_with))
    saver_all.save(sessionTf, filename_to_save_with + ".model.ckpt", write_meta_graph=False)

    log.print3("The whole do_training() function has finished.")
    return 0
Exemple #3
0
    def run_session(self, *args):
        """Build the test-time graph, load or initialise the weights, run inference.

        ``args`` must contain exactly two items, in this order: the device
        handed to ``graph.device()`` and the model-parameters object that
        supplies the architecture arguments.

        The value returned by ``inference_on_whole_volumes`` is discarded;
        this method returns ``None``.
        """
        log = self._log
        sess_device, model_params = args

        graph = tf.Graph()
        with graph.as_default():
            # Throws an error if GPU is specified but not available.
            with graph.device(sess_device):
                log.print3("=========== Making the CNN graph... ===============")
                cnn3d = Cnn3d()
                with tf.variable_scope("net"):
                    # Creates the network's graph (without optimizer).
                    cnn3d.make_cnn_model(*model_params.get_args_for_arch())

            for header_line in ("=========== Compiling the Testing Function ============",
                                "=======================================================\n"):
                log.print3(header_line)

            fms_to_save = self._params.indices_fms_per_pathtype_per_layer_to_save
            cnn3d.setup_ops_n_feeds_to_test(log, fms_to_save)
            # Create the saver (a saver restricted to the net's variables would suffice).
            saver_all = tf.train.Saver()

        session_config = tf.ConfigProto(log_device_placement=False,
                                        device_count={'CPU': 999, 'GPU': 99})
        with tf.Session(graph=graph, config=session_config) as sess:
            model_path = self._params.get_path_to_load_model_from()
            if model_path is None:
                # Asks user whether to continue with randomly initialized model.
                # It exits if no is given.
                self._ask_user_if_test_with_random()
                log.print3("")
                log.print3("=========== Initializing network variables  ===============")
                net_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net")
                # .run() relies on `sess` being the default session inside this `with`.
                tf.variables_initializer(var_list=net_vars).run()
                log.print3("Model variables were initialized.")
            else:
                log.print3("=========== Loading parameters from specified saved model ===============")
                # A directory means "use its latest checkpoint"; anything else is used as given.
                if os.path.isdir(model_path):
                    chkpt_fname = tf.train.latest_checkpoint(model_path)
                else:
                    chkpt_fname = model_path
                log.print3("Loading parameters from:" + str(chkpt_fname))
                try:
                    saver_all.restore(sess, chkpt_fname)
                    log.print3("Parameters were loaded.")
                except Exception as e:
                    handle_exception_tf_restore(log, e)

            for banner_line in ("",
                                "======================================================",
                                "=========== Testing with the CNN model ===============",
                                "======================================================"):
                log.print3(banner_line)

            # Positional argument list: session and model first, then the
            # testing arguments supplied by the parameters object.
            testing_args = [sess, cnn3d] + self._params.get_args_for_testing()
            inference_on_whole_volumes(*testing_args)

        for banner_line in ("",
                            "======================================================",
                            "=========== Testing session finished =================",
                            "======================================================"):
            log.print3(banner_line)