示例#1
0
 def _finish_process(self, sess, coodinator, threads, model_saver,
                     save_model_path, global_step_val, loss_val,
                     best_loss_val):
     """Shut down queue-runner threads and write the final checkpoints.

     A step checkpoint is written unless only best-model saving is
     enabled; a best checkpoint is written when the loss improved and
     this step was not already covered by the periodic best-model check
     (``global_step_val % save_best_check_steps != 0``).
     """
     coodinator.request_stop()
     coodinator.join(threads)
     if save_model_path is None:
         return
     save_best_path, save_steps_path = self._get_save_path(
         save_model_path)
     save_best = self.arguments["save_best"]
     if not save_best or self.extra_save_path:
         logger.info("save model.")
         model_saver.save(sess, save_steps_path, global_step_val)
         record_path = osp.join(osp.dirname(save_steps_path),
                                "best_loss_records.txt")
         with open(record_path, "a+") as recordf:
             recordf.write("step-loss: %s - %s\n" %
                           (global_step_val, loss_val))
     if save_best or self.extra_save_path:
         # Skip steps the periodic best-save already handled.
         missed_check = global_step_val % self.arguments[
             "save_best_check_steps"] != 0
         if missed_check and loss_val < best_loss_val:
             logger.info("save model.")
             model_saver.save(sess, save_best_path, global_step_val)
             record_path = osp.join(osp.dirname(save_best_path),
                                    "best_loss_records.txt")
             with open(record_path, "a+") as recordf:
                 recordf.write("step-loss: %s - %s\n" %
                               (global_step_val, loss_val))
示例#2
0
 def _append_loss_record(self, checkpoint_path, global_step_val, loss_val):
     """Append a "step - loss" line to best_loss_records.txt beside *checkpoint_path*."""
     record_path = osp.join(osp.dirname(checkpoint_path),
                            "best_loss_records.txt")
     with open(record_path, "a+") as recordf:
         recordf.write("step-loss: %s - %s\n" %
                       (global_step_val, loss_val))

 def _save_model_step(self, sess, model_saver, save_model_path, loss_val,
                      best_loss_val, global_step_val):
     """Save checkpoints for the current training step.

     Writes a "best" checkpoint when the loss improved on a
     best-check step, and a periodic checkpoint every ``save_steps``
     steps (unless only best-model saving is enabled).

     Args:
         sess: active tf.Session used by the saver.
         model_saver: tf.train.Saver performing the writes.
         save_model_path: base checkpoint path; no-op when None.
         loss_val: loss of the current step.
         best_loss_val: best loss observed so far.
         global_step_val: current global step.

     Returns:
         The (possibly updated) best loss value.
     """
     if save_model_path is None:
         return best_loss_val
     save_best_path, save_steps_path = self._get_save_path(
         save_model_path)
     save_best = self.arguments["save_best"]
     # save best
     if ((save_best or self.extra_save_path)
             and loss_val < best_loss_val
             and global_step_val %
             self.arguments["save_best_check_steps"] == 0):
         logger.info("save best to {}".format(save_best_path))
         model_saver.save(sess, save_best_path, global_step_val)
         best_loss_val = loss_val
         self._append_loss_record(save_best_path, global_step_val,
                                  best_loss_val)
     # save model per save_steps
     if ((not save_best or self.extra_save_path)
             and global_step_val % self.arguments["save_steps"] == 0):
         logger.info("save model to {}".format(save_steps_path))
         model_saver.save(sess, save_steps_path, global_step_val)
         if loss_val < best_loss_val:
             best_loss_val = loss_val
             self._append_loss_record(save_steps_path, global_step_val,
                                      best_loss_val)
     return best_loss_val
示例#3
0
 def _restore_pretained_variables(self,
                                  sess,
                                  pretrained_model_path,
                                  variables_to_restore,
                                  save_model_path=None,
                                  saver_for_restore=None):
     """Restore variables from a pretrained checkpoint, if one is given.

     Args:
         sess: active tf.Session to restore into.
         pretrained_model_path: checkpoint path; falsy values make this
             a no-op.
         variables_to_restore: variables to load; when None they are
             taken from *saver_for_restore*'s variable list.
         save_model_path: optional path whose directory is forwarded to
             ``_maybe_restore_pretrained_model``; may be None.
         saver_for_restore: optional tf.train.Saver; built on demand
             from *variables_to_restore* when omitted.
     """
     if not pretrained_model_path:
         return
     if variables_to_restore is None and saver_for_restore:
         # NOTE: reads the Saver's private _var_list attribute.
         variables_to_restore = saver_for_restore._var_list
     logger.info('Will attempt restore from %s: %s',
                 pretrained_model_path, variables_to_restore)
     if saver_for_restore is None:
         assert variables_to_restore
         saver_for_restore = tf.train.Saver(variables_to_restore)
     # Bug fix: osp.dirname(None) raises TypeError when the default
     # save_model_path=None is used — forward None unchanged instead.
     # NOTE(review): assumes _maybe_restore_pretrained_model tolerates a
     # None save directory — confirm against its implementation.
     save_dir = (osp.dirname(save_model_path)
                 if save_model_path is not None else None)
     self._maybe_restore_pretrained_model(
         sess, saver_for_restore, osp.dirname(pretrained_model_path),
         save_dir)
示例#4
0
    def _initialize_process(self, sess, save_model_path):
        """Set up checkpointing, summaries, variable init and queue runners.

        Returns:
            Tuple of (saver, summary_writer, merged_summary, coordinator,
            queue_threads, current_step) for use by the training loop;
            summary_writer is None when no save path is given.
        """
        saver = tf.train.Saver(max_to_keep=1)
        writer = None
        if save_model_path is not None:
            writer = tf.summary.FileWriter(osp.dirname(save_model_path),
                                           sess.graph)
        merged = tf.summary.merge_all()
        sess.run(tf.global_variables_initializer())
        step_now = sess.run(self.global_step)
        coordinator = tf.train.Coordinator()
        queue_threads = tf.train.start_queue_runners(sess, coordinator)
        return (saver, writer, merged, coordinator, queue_threads,
                step_now)
示例#5
0
 def _resotre_training_model(self,
                             sess,
                             save_model_path,
                             saver_for_model_restore=None):
     """Resume training from the latest checkpoint under *save_model_path*.

     Only acts when the "should_restore_if_could" argument is set and a
     save path is given; otherwise the session keeps its freshly
     initialized variables.

     Args:
         sess: active tf.Session to restore into.
         save_model_path: file path whose directory holds checkpoints.
         saver_for_model_restore: optional tf.train.Saver; a default
             Saver over all variables is created when omitted.
     """
     if not (self.arguments["should_restore_if_could"]
             and save_model_path is not None):
         return
     model_ckpt = tf.train.get_checkpoint_state(
         osp.dirname(save_model_path))
     if not (model_ckpt and model_ckpt.model_checkpoint_path):
         return
     # NOTE(review): presumably re-syncs the best-loss bookkeeping
     # before resuming — confirm against check_best_loss_val().
     self.check_best_loss_val()
     # Fixed log typo: "resotre" -> "restore".
     logger.info("restore model from %s" %
                 model_ckpt.model_checkpoint_path)
     if saver_for_model_restore is None:
         saver_for_model_restore = tf.train.Saver()
     saver_for_model_restore.restore(
         sess, model_ckpt.model_checkpoint_path)