示例#1
0
    def run_session(self, *args):
        """Build the CNN and trainer graph, initialise or restore its variables, then train.

        ``args`` is unpacked positionally as ``(sess_device, model_params, reset_trainer)``:

        - sess_device: passed to ``tf.device`` while building the CNN part of the
          graph (presumably a device string such as "/CPU:0" or "/GPU:0" -- confirm
          with the caller).
        - model_params: object providing ``get_args_for_arch()`` and
          ``get_inp_dims_hr_path('train' | 'val' | 'test')``.
        - reset_trainer: only consulted when a saved model is loaded. If truthy, the
          variables under the "trainer" scope are re-initialised instead of being
          restored from the checkpoint.

        Trainer/optimizer arguments, the load/save paths, tensorboard settings and
        the training-routine arguments all come from ``self._params``. Progress is
        reported via ``self._log.print3``. The actual training is delegated to
        ``do_training``.
        """
        (sess_device, model_params, reset_trainer) = args

        graphTf = tf.Graph()

        with graphTf.as_default():
            # Explicit device assignment, throws an error if GPU is specified but not available
            # (the session config below does not enable allow_soft_placement).
            with tf.device(sess_device):
                self._log.print3(
                    "=========== Making the CNN graph... ===============")
                cnn3d = Cnn3d()
                with tf.compat.v1.variable_scope("net"):
                    cnn3d.make_cnn_model(*model_params.get_args_for_arch())
                    # The CNN graph now exists, but not yet the optimizer's graph.
                    # Create input placeholders and apply the network separately
                    # for each of the train / val / test stages.
                    inp_plchldrs_train, inp_shapes_per_path_train = cnn3d.create_inp_plchldrs(
                        model_params.get_inp_dims_hr_path('train'), 'train')
                    inp_plchldrs_val, inp_shapes_per_path_val = cnn3d.create_inp_plchldrs(
                        model_params.get_inp_dims_hr_path('val'), 'val')
                    inp_plchldrs_test, inp_shapes_per_path_test = cnn3d.create_inp_plchldrs(
                        model_params.get_inp_dims_hr_path('test'), 'test')
                    p_y_given_x_train = cnn3d.apply(inp_plchldrs_train,
                                                    'train',
                                                    'train',
                                                    verbose=True,
                                                    log=self._log)
                    p_y_given_x_val = cnn3d.apply(inp_plchldrs_val,
                                                  'infer',
                                                  'val',
                                                  verbose=True,
                                                  log=self._log)
                    p_y_given_x_test = cnn3d.apply(inp_plchldrs_test,
                                                   'infer',
                                                   'test',
                                                   verbose=True,
                                                   log=self._log)

            # No explicit device assignment for the rest, because the trainer uses
            # piecewise_constant, which is only available on CPU, and so is the saver.
            with tf.compat.v1.variable_scope("trainer"):
                self._log.print3("=========== Building Trainer ===========\n")
                trainer = Trainer(*(self._params.get_args_for_trainer() +
                                    [cnn3d]))
                trainer.compute_costs(self._log, p_y_given_x_train)
                trainer.create_optimizer(*self._params.get_args_for_optimizer(
                ))  # Trainer and net connect here.

            tensorboard_loggers = self.create_tensorboard_loggers(
                ['train', 'val'],
                graphTf,
                create_log=self._params.get_tensorboard_bool())

            # The below should not create any new tf.variables.
            self._log.print3(
                "=========== Compiling the Training Function ===========")
            self._log.print3(
                "=======================================================\n")
            cnn3d.setup_ops_n_feeds_to_train(
                self._log,
                inp_plchldrs_train,
                p_y_given_x_train,
                trainer.get_total_cost(),
                trainer.get_param_updates_wrt_total_cost()  # list of ops
            )

            self._log.print3(
                "=========== Compiling the Validation Function =========")
            cnn3d.setup_ops_n_feeds_to_val(self._log, inp_plchldrs_val,
                                           p_y_given_x_val)

            self._log.print3(
                "=========== Compiling the Testing Function ============")
            # For validation with full segmentation
            cnn3d.setup_ops_n_feeds_to_test(
                self._log, inp_plchldrs_test, p_y_given_x_test,
                self._params.indices_fms_per_pathtype_per_layer_to_save)

            # Create the savers
            saver_all = tf.compat.v1.train.Saver(
            )  # Will be used during training for saving everything.
            # Alternative: tf.train.Saver([v for v in tf.all_variables() if v.name.startswith("net"])
            collection_vars_net = tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope="net")
            saver_net = tf.compat.v1.train.Saver(
                var_list=collection_vars_net
            )  # Used to load the net's parameters.
            collection_vars_trainer = tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope="trainer")
            saver_trainer = tf.compat.v1.train.Saver(
                var_list=collection_vars_trainer
            )  # Used to load the trainer's parameters.

            # TF2: dict_vars_net = {'net_var'+str(i): v for i, v in enumerate(collection_vars_net)}
            # TF2: dict_vars_trainer = {'trainer_var'+str(i): v for i, v in enumerate(collection_vars_trainer)}
            # TF2: dict_vars_all = dict_vars_net.copy()
            # TF2: for key in dict_vars_trainer:
            # TF2:     dict_vars_all[key] = dict_vars_trainer[key]
            # TF2: ckpt_all = tf.train.Checkpoint(**dict_vars_all)
            # TF2: ckpt_net = tf.train.Checkpoint(**dict_vars_net)
            # TF2: ckpt_trainer = tf.train.Checkpoint(**dict_vars_trainer)

        # self._print_vars_in_collection(collection_vars_net, "net")
        # self._print_vars_in_collection(collection_vars_trainer, "trainer")

        with tf.compat.v1.Session(graph=graphTf,
                                  config=tf.compat.v1.ConfigProto(
                                      log_device_placement=False,
                                      device_count={
                                          'CPU': 999,
                                          'GPU': 99
                                      })) as sessionTf:
            # Load or initialize parameters
            file_to_load_params_from = self._params.get_path_to_load_model_from(
            )
            if file_to_load_params_from is not None:  # Load params
                self._log.print3(
                    "=========== Loading parameters from specified saved model ==============="
                )
                # A directory is resolved to its latest checkpoint. latest_checkpoint
                # returns None if the directory holds none; the restore below then
                # fails and is routed to handle_exception_tf_restore.
                chkpt_fname = tf.train.latest_checkpoint(file_to_load_params_from) \
                    if os.path.isdir(file_to_load_params_from) else file_to_load_params_from
                self._log.print3("Loading checkpoint file:" + str(chkpt_fname))
                self._log.print3("Loading network parameters...")
                try:
                    saver_net.restore(sessionTf, chkpt_fname)
                    # TF2: status = ckpt_net.restore(chkpt_fname); #status.assert_consumed() # Passes if ckpt and program vars match exactly.

                    self._log.print3("Network parameters were loaded.")
                except Exception as e:
                    handle_exception_tf_restore(self._log, e)

                if not reset_trainer:
                    self._log.print3("Loading trainer parameters...")
                    # Guarded like the network restore above, so a failing trainer
                    # restore goes through the same handler instead of escaping as
                    # a raw traceback.
                    try:
                        saver_trainer.restore(sessionTf, chkpt_fname)
                        # TF2: status = ckpt_trainer.restore(chkpt_fname); #status.assert_consumed() # Passes if ckpt and program vars match exactly.
                        self._log.print3("Trainer parameters were loaded.")
                    except Exception as e:
                        handle_exception_tf_restore(self._log, e)
                else:
                    self._log.print3(
                        "Reset of trainer parameters was requested. Re-initializing them..."
                    )
                    # .run() uses sessionTf, the default session inside this `with`.
                    tf.compat.v1.variables_initializer(
                        var_list=collection_vars_trainer).run()
                    self._log.print3("Trainer parameters re-initialized.")
            else:
                self._log.print3(
                    "=========== Initializing network and trainer variables  ==============="
                )
                # Initialize net and trainer collections separately rather than all
                # globals at once, so that a variable missing from both collections
                # surfaces as an error instead of being silently initialized.
                tf.compat.v1.variables_initializer(
                    var_list=collection_vars_net).run()
                tf.compat.v1.variables_initializer(
                    var_list=collection_vars_trainer).run()
                self._log.print3("All variables were initialized.")

                filename_to_save_with = self._params.filepath_to_save_models + ".initial." + datetime_now_str(
                )
                self._log.print3("Saving the initial model at:" +
                                 str(filename_to_save_with))
                saver_all.save(sessionTf,
                               filename_to_save_with + ".model.ckpt",
                               write_meta_graph=False)
                # TF2: ckpt_all.save(file_prefix = filename_to_save_with+".all.ckpt2")
                # TF2: ckpt_net.save(file_prefix = filename_to_save_with+".net.ckpt2")
                # TF2: ckpt_trainer.save(file_prefix = filename_to_save_with+".trainer.ckpt2")

                # tf.train.write_graph(graph_or_graph_def=sessionTf.graph.as_graph_def(),
                #                      logdir="", name=filename_to_save_with+".graph.pb", as_text=False)

            self._log.print3("")
            self._log.print3(
                "=======================================================")
            self._log.print3(
                "============== Training the CNN model =================")
            self._log.print3(
                "=======================================================")

            do_training(*([sessionTf, saver_all, cnn3d, trainer, tensorboard_loggers] +
                          self._params.get_args_for_train_routine() +
                          [inp_shapes_per_path_train, inp_shapes_per_path_val, inp_shapes_per_path_test]))

            # TF2: ckpt_all.save(file_prefix = filename_to_save_with+".all.FINAL.ckpt2")
            # TF2: ckpt_net.save(file_prefix = filename_to_save_with+".net.FINAL.ckpt2")
            # TF2: ckpt_trainer.save(file_prefix = filename_to_save_with+".trainer.FINAL.ckpt2")

        self._log.print3(
            "\n=======================================================")
        self._log.print3(
            "=========== Training session finished =================")
        self._log.print3(
            "=======================================================")
示例#2
0
    def run_session(self, *args):
        """Build the CNN and trainer graph, initialise or restore its variables, then train.

        ``args`` is unpacked positionally as ``(sess_device, model_params, reset_trainer)``:

        - sess_device: passed to ``graphTf.device`` while building the CNN part of
          the graph (presumably a device string such as "/cpu:0" or "/gpu:0" --
          confirm with the caller).
        - model_params: object providing ``get_args_for_arch()``.
        - reset_trainer: only consulted when a saved model is loaded. If truthy, the
          variables under the "trainer" scope are re-initialised instead of being
          restored from the checkpoint.

        Trainer/optimizer arguments, the load/save paths and the training-routine
        arguments come from ``self._params``. Progress is reported via
        ``self._log.print3``. The actual training is delegated to ``do_training``.
        """
        (sess_device, model_params, reset_trainer) = args

        graphTf = tf.Graph()

        with graphTf.as_default():
            # Explicit device assignment, throws an error if GPU is specified but not available.
            # NOTE(review): the session below sets allow_soft_placement=True, which lets
            # TF place ops on another device instead of failing -- this contradicts the
            # comment above. Confirm which behaviour is intended.
            with graphTf.device(sess_device):
                self._log.print3(
                    "=========== Making the CNN graph... ===============")
                cnn3d = Cnn3d()
                with tf.variable_scope("net"):
                    cnn3d.make_cnn_model(*model_params.get_args_for_arch())
                    # The CNN graph now exists, but not yet the optimizer's graph.

            # No explicit device assignment for the rest, because the trainer uses
            # piecewise_constant, which is only available on CPU, and so is the saver.
            with tf.variable_scope("trainer"):
                self._log.print3("=========== Building Trainer ===========\n")
                trainer = Trainer(*(self._params.get_args_for_trainer() +
                                    [cnn3d]))
                trainer.create_optimizer(*self._params.get_args_for_optimizer(
                ))  # Trainer and net connect here.

            # The below should not create any new tf.variables.
            self._log.print3(
                "=========== Compiling the Training Function ===========")
            self._log.print3(
                "=======================================================\n")
            cnn3d.setup_ops_n_feeds_to_train(
                self._log,
                trainer.get_total_cost(),
                trainer.get_param_updates_wrt_total_cost()  # list of ops
            )

            self._log.print3(
                "=========== Compiling the Validation Function =========")
            cnn3d.setup_ops_n_feeds_to_val(self._log)

            self._log.print3(
                "=========== Compiling the Testing Function ============")
            cnn3d.setup_ops_n_feeds_to_test(
                self._log,
                self._params.indices_fms_per_pathtype_per_layer_to_save
            )  # For validation with full segmentation

            # Create the savers
            saver_all = tf.train.Saver(
            )  # Will be used during training for saving everything.
            collection_vars_net = tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES, scope="net"
            )  # Alternative: tf.train.Saver([v for v in tf.all_variables() if v.name.startswith("net"])
            saver_net = tf.train.Saver(var_list=collection_vars_net
                                       )  # Used to load the net's parameters.
            collection_vars_trainer = tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES, scope="trainer")
            saver_trainer = tf.train.Saver(
                var_list=collection_vars_trainer
            )  # Used to load the trainer's parameters.

        # self._print_vars_in_collection(collection_vars_net, "net")
        # self._print_vars_in_collection(collection_vars_trainer, "trainer")

        with tf.Session(graph=graphTf,
                        config=tf.ConfigProto(allow_soft_placement=True,
                                              log_device_placement=False,
                                              device_count={
                                                  'CPU': 999,
                                                  'GPU': 99
                                              })) as sessionTf:
            # Load or initialize parameters
            file_to_load_params_from = self._params.get_path_to_load_model_from(
            )
            if file_to_load_params_from is not None:  # Load params
                self._log.print3(
                    "=========== Loading parameters from specified saved model ==============="
                )
                # A directory is resolved to its latest checkpoint. latest_checkpoint
                # returns None if the directory holds none; the restore below then
                # fails and is routed to handle_exception_tf_restore.
                chkpt_fname = tf.train.latest_checkpoint(
                    file_to_load_params_from) if os.path.isdir(
                        file_to_load_params_from) else file_to_load_params_from
                self._log.print3("Loading checkpoint file:" + str(chkpt_fname))
                self._log.print3("Loading network parameters...")
                try:
                    saver_net.restore(sessionTf, chkpt_fname)
                    self._log.print3("Network parameters were loaded.")
                except Exception as e:
                    handle_exception_tf_restore(self._log, e)

                if not reset_trainer:
                    self._log.print3("Loading trainer parameters...")
                    # Guarded like the network restore above, so a failing trainer
                    # restore goes through the same handler instead of escaping as
                    # a raw traceback.
                    try:
                        saver_trainer.restore(sessionTf, chkpt_fname)
                        self._log.print3("Trainer parameters were loaded.")
                    except Exception as e:
                        handle_exception_tf_restore(self._log, e)
                else:
                    self._log.print3(
                        "Reset of trainer parameters was requested. Re-initializing them..."
                    )
                    # .run() uses sessionTf, the default session inside this `with`.
                    tf.variables_initializer(
                        var_list=collection_vars_trainer).run()
                    self._log.print3("Trainer parameters re-initialized.")
            else:
                self._log.print3(
                    "=========== Initializing network and trainer variables  ==============="
                )
                # Initialize net and trainer collections separately rather than all
                # globals at once, so that a variable missing from both collections
                # surfaces as an error instead of being silently initialized.
                tf.variables_initializer(var_list=collection_vars_net).run()
                tf.variables_initializer(
                    var_list=collection_vars_trainer).run()
                self._log.print3("All variables were initialized.")

                filename_to_save_with = self._params.filepath_to_save_models + ".initial." + datetimeNowAsStr(
                )
                self._log.print3("Saving the initial model at:" +
                                 str(filename_to_save_with))
                saver_all.save(sessionTf,
                               filename_to_save_with + ".model.ckpt",
                               write_meta_graph=False)
                # tf.train.write_graph( graph_or_graph_def=sessionTf.graph.as_graph_def(), logdir="", name=filename_to_save_with+".graph.pb", as_text=False)

            self._log.print3("")
            self._log.print3(
                "=======================================================")
            self._log.print3(
                "============== Training the CNN model =================")
            self._log.print3(
                "=======================================================\n")

            # The return value of do_training is not used here.
            do_training(
                *([sessionTf, saver_all, cnn3d, trainer] +
                  self._params.get_args_for_train_routine()))

        self._log.print3(
            "\n=======================================================")
        self._log.print3(
            "=========== Training session finished =================")
        self._log.print3(
            "=======================================================")
示例#3
0
 def run_session(self, *args):
     """Build the CNN and trainer graph, initialise or restore its variables, then train.

     ``args`` is unpacked positionally as ``(sess_device, model_params, reset_trainer)``:

     - sess_device: passed to ``graphTf.device`` while building the CNN part of the
       graph (presumably a device string such as "/cpu:0" or "/gpu:0" -- confirm
       with the caller).
     - model_params: object providing ``get_args_for_arch()``.
     - reset_trainer: only consulted when a saved model is loaded. If truthy, the
       variables under the "trainer" scope are re-initialised instead of being
       restored from the checkpoint.

     Trainer/optimizer arguments, the load/save paths and the training-routine
     arguments come from ``self._params``. Progress is reported via
     ``self._log.print3``. The actual training is delegated to ``do_training``.
     """
     (sess_device,
      model_params,
      reset_trainer) = args

     graphTf = tf.Graph()

     with graphTf.as_default():
         # Explicit device assignment, throws an error if GPU is specified but not available
         # (the session config below does not enable allow_soft_placement).
         with graphTf.device(sess_device):
             self._log.print3("=========== Making the CNN graph... ===============")
             cnn3d = Cnn3d()
             with tf.variable_scope("net"):
                 cnn3d.make_cnn_model( *model_params.get_args_for_arch() )
                 # The CNN graph now exists, but not yet the optimizer's graph.

         # No explicit device assignment for the rest, because the trainer uses
         # piecewise_constant, which is only available on CPU, and so is the saver.
         with tf.variable_scope("trainer"):
             self._log.print3("=========== Building Trainer ===========\n")
             trainer = Trainer( *( self._params.get_args_for_trainer() + [cnn3d] ) )
             trainer.create_optimizer( *self._params.get_args_for_optimizer() ) # Trainer and net connect here.

         # The below should not create any new tf.variables.
         self._log.print3("=========== Compiling the Training Function ===========")
         self._log.print3("=======================================================\n")
         cnn3d.setup_ops_n_feeds_to_train( self._log,
                                           trainer.get_total_cost(),
                                           trainer.get_param_updates_wrt_total_cost() # list of ops
                                         )

         self._log.print3("=========== Compiling the Validation Function =========")
         cnn3d.setup_ops_n_feeds_to_val( self._log )

         self._log.print3("=========== Compiling the Testing Function ============")
         cnn3d.setup_ops_n_feeds_to_test( self._log,
                                          self._params.indices_fms_per_pathtype_per_layer_to_save ) # For validation with full segmentation

         # Create the savers
         saver_all = tf.train.Saver() # Will be used during training for saving everything.
         collection_vars_net = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net") # Alternative: tf.train.Saver([v for v in tf.all_variables() if v.name.startswith("net"])
         saver_net = tf.train.Saver( var_list = collection_vars_net ) # Used to load the net's parameters.
         collection_vars_trainer = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="trainer")
         saver_trainer = tf.train.Saver( var_list = collection_vars_trainer ) # Used to load the trainer's parameters.

     # self._print_vars_in_collection(collection_vars_net, "net")
     # self._print_vars_in_collection(collection_vars_trainer, "trainer")

     with tf.Session( graph=graphTf, config=tf.ConfigProto(log_device_placement=False, device_count={'CPU':999, 'GPU':99}) ) as sessionTf:
         # Load or initialize parameters
         file_to_load_params_from = self._params.get_path_to_load_model_from()
         if file_to_load_params_from is not None: # Load params
             self._log.print3("=========== Loading parameters from specified saved model ===============")
             # A directory is resolved to its latest checkpoint. latest_checkpoint returns
             # None if the directory holds none; the restore below then fails and is
             # routed to handle_exception_tf_restore.
             chkpt_fname = tf.train.latest_checkpoint( file_to_load_params_from ) if os.path.isdir( file_to_load_params_from ) else file_to_load_params_from
             self._log.print3("Loading checkpoint file:" + str(chkpt_fname))
             self._log.print3("Loading network parameters...")
             try:
                 saver_net.restore(sessionTf, chkpt_fname)
                 self._log.print3("Network parameters were loaded.")
             except Exception as e:
                 handle_exception_tf_restore(self._log, e)

             if not reset_trainer:
                 self._log.print3("Loading trainer parameters...")
                 # Guarded like the network restore above, so a failing trainer restore
                 # goes through the same handler instead of escaping as a raw traceback.
                 try:
                     saver_trainer.restore(sessionTf, chkpt_fname)
                     self._log.print3("Trainer parameters were loaded.")
                 except Exception as e:
                     handle_exception_tf_restore(self._log, e)
             else:
                 self._log.print3("Reset of trainer parameters was requested. Re-initializing them...")
                 # .run() uses sessionTf, the default session inside this `with`.
                 tf.variables_initializer(var_list = collection_vars_trainer).run()
                 self._log.print3("Trainer parameters re-initialized.")
         else:
             self._log.print3("=========== Initializing network and trainer variables  ===============")
             # Initialize net and trainer collections separately rather than all globals
             # at once, so that a variable missing from both collections surfaces as an
             # error instead of being silently initialized.
             tf.variables_initializer(var_list = collection_vars_net).run()
             tf.variables_initializer(var_list = collection_vars_trainer).run()
             self._log.print3("All variables were initialized.")

             filename_to_save_with = self._params.filepath_to_save_models + ".initial." + datetimeNowAsStr()
             self._log.print3("Saving the initial model at:" + str(filename_to_save_with))
             saver_all.save( sessionTf, filename_to_save_with+".model.ckpt", write_meta_graph=False )
             # tf.train.write_graph( graph_or_graph_def=sessionTf.graph.as_graph_def(), logdir="", name=filename_to_save_with+".graph.pb", as_text=False)

         self._log.print3("")
         self._log.print3("=======================================================")
         self._log.print3("============== Training the CNN model =================")
         self._log.print3("=======================================================\n")

         # The return value of do_training is not used here.
         do_training( *( [sessionTf, saver_all, cnn3d, trainer] + self._params.get_args_for_train_routine() ) )

     self._log.print3("\n=======================================================")
     self._log.print3("=========== Training session finished =================")
     self._log.print3("=======================================================")