def _finish_process(self, sess, coodinator, threads, model_saver,
                    save_model_path, global_step_val, loss_val,
                    best_loss_val):
    """Stop the input threads and write the end-of-training checkpoint(s).

    Args:
        sess: session forwarded to ``model_saver.save``.
        coodinator: coordinator used to stop and join ``threads``.
        threads: threads joined through ``coodinator``.
        model_saver: saver that writes the checkpoints.
        save_model_path: base checkpoint path; when ``None`` only the
            threads are shut down and nothing is saved.
        global_step_val: final global step, used as the checkpoint suffix.
        loss_val: loss at the final step.
        best_loss_val: best loss tracked during training.
    """
    coodinator.request_stop()
    coodinator.join(threads)
    if save_model_path is None:
        return

    save_best_path, save_steps_path = self._get_save_path(save_model_path)

    def append_loss_record(ckpt_path, loss):
        # Append one "step-loss" line to the record file next to ckpt_path.
        record_path = osp.join(osp.dirname(ckpt_path), "best_loss_records.txt")
        with open(record_path, "a+") as recordf:
            recordf.write("step-loss: %s - %s\n" % (global_step_val, loss))

    # Periodic-checkpoint mode (or extra save path): always save the final step.
    if not self.arguments["save_best"] or self.extra_save_path:
        logger.info("save model.")
        model_saver.save(sess, save_steps_path, global_step_val)
        # Record the best loss so far rather than the raw final loss, so this
        # line agrees with what _save_model_step appends to the same file.
        append_loss_record(save_steps_path, min(loss_val, best_loss_val))

    # Best-checkpoint mode (or extra save path).
    if self.arguments["save_best"] or self.extra_save_path:
        # Steps divisible by save_best_check_steps are presumably already
        # evaluated by the per-step best save; only an unchecked final step
        # is considered here.
        if (global_step_val % self.arguments["save_best_check_steps"] != 0
                and loss_val < best_loss_val):
            logger.info("save model.")
            model_saver.save(sess, save_best_path, global_step_val)
            append_loss_record(save_best_path, loss_val)
def _save_model_step(self, sess, model_saver, save_model_path, loss_val,
                     best_loss_val, global_step_val):
    """Write the checkpoints due at this step and return the updated best loss.

    Two independent rules apply:

    * best mode (``save_best`` or ``extra_save_path``): on steps divisible by
      ``save_best_check_steps``, save to the "best" path when ``loss_val``
      beats ``best_loss_val``;
    * periodic mode (not ``save_best``, or ``extra_save_path``): on steps
      divisible by ``save_steps``, always save to the "steps" path.

    Each save appends a ``step-loss`` line with the best loss so far to
    ``best_loss_records.txt`` in the checkpoint's directory.

    Returns:
        The best loss after this step (``best_loss_val`` unchanged when no
        save happened or ``save_model_path`` is ``None``).
    """
    if save_model_path is None:
        return best_loss_val

    best_path, steps_path = self._get_save_path(save_model_path)
    args = self.arguments

    def write_record(ckpt_path, value):
        # Append "step-loss: <step> - <value>" next to the checkpoint.
        record_file = osp.join(osp.dirname(ckpt_path), "best_loss_records.txt")
        with open(record_file, "a+") as out:
            out.write("step-loss: %s - %s\n" % (global_step_val, value))

    # Best-checkpoint rule; the modulo is only evaluated once the loss improved.
    wants_best = args["save_best"] or self.extra_save_path
    if (wants_best and loss_val < best_loss_val
            and global_step_val % args["save_best_check_steps"] == 0):
        logger.info("save best to {}".format(best_path))
        model_saver.save(sess, best_path, global_step_val)
        best_loss_val = loss_val
        write_record(best_path, best_loss_val)

    # Periodic-checkpoint rule.
    wants_steps = not args["save_best"] or self.extra_save_path
    if wants_steps and global_step_val % args["save_steps"] == 0:
        logger.info("save model to {}".format(steps_path))
        model_saver.save(sess, steps_path, global_step_val)
        best_loss_val = min(best_loss_val, loss_val)
        write_record(steps_path, best_loss_val)

    return best_loss_val
def _restore_pretained_variables(self, sess, pretrained_model_path,
                                 variables_to_restore, save_model_path=None,
                                 saver_for_restore=None):
    """Restore variables from a pretrained checkpoint, if a path is given.

    Args:
        sess: session to restore into.
        pretrained_model_path: checkpoint path of the pretrained model; when
            falsy nothing happens.
        variables_to_restore: variables for a new ``tf.train.Saver``; required
            (non-empty) when ``saver_for_restore`` is not given.
        save_model_path: optional path of this run's checkpoints; its
            directory (or ``None``) is forwarded to
            ``_maybe_restore_pretrained_model``.
        saver_for_restore: optional ready-made saver to restore with.

    Raises:
        ValueError: if neither ``saver_for_restore`` nor a non-empty
            ``variables_to_restore`` is provided.
    """
    if not pretrained_model_path:
        return

    if variables_to_restore is None and saver_for_restore is not None:
        # Only used for the log message: show what the given saver covers.
        variables_to_restore = saver_for_restore._var_list
    logger.info('Will attempt restore from %s: %s', pretrained_model_path,
                variables_to_restore)

    if saver_for_restore is None:
        # Explicit check instead of `assert`, which is stripped under -O.
        if not variables_to_restore:
            raise ValueError(
                "variables_to_restore must be non-empty when "
                "saver_for_restore is not given")
        saver_for_restore = tf.train.Saver(variables_to_restore)

    # save_model_path defaults to None, and os.path.dirname(None) raises
    # TypeError, so forward None rather than crashing.
    # NOTE(review): confirm _maybe_restore_pretrained_model accepts None here.
    save_model_dir = (osp.dirname(save_model_path)
                      if save_model_path is not None else None)
    self._maybe_restore_pretrained_model(
        sess, saver_for_restore, osp.dirname(pretrained_model_path),
        save_model_dir)
def _initialize_process(self, sess, save_model_path):
    """Create saver/summary objects, initialize variables and start queues.

    Args:
        sess: session used for initialization and the queue runners.
        save_model_path: checkpoint path; its directory receives the summary
            writer. When ``None`` no summary writer is created.

    Returns:
        Tuple ``(model_saver, summary_writer, merged_summary, coodinator,
        threads, current_steps)``. ``model_saver`` is built with
        ``max_to_keep=1``, ``summary_writer`` is ``None`` when
        ``save_model_path`` is ``None``, and ``current_steps`` is the value of
        ``self.global_step`` after initialization.
    """
    saver = tf.train.Saver(max_to_keep=1)
    writer = (None if save_model_path is None else
              tf.summary.FileWriter(osp.dirname(save_model_path), sess.graph))
    summary_op = tf.summary.merge_all()

    # Initialize variables before reading the global step.
    sess.run(tf.global_variables_initializer())
    step = sess.run(self.global_step)

    coord = tf.train.Coordinator()
    queue_threads = tf.train.start_queue_runners(sess, coord)
    return saver, writer, summary_op, coord, queue_threads, step
def _resotre_training_model(self, sess, save_model_path,
                            saver_for_model_restore=None):
    """Resume training from the latest checkpoint next to ``save_model_path``.

    Does nothing unless ``self.arguments["should_restore_if_could"]`` is set,
    ``save_model_path`` is not ``None``, and ``tf.train.get_checkpoint_state``
    reports a checkpoint path in that path's directory. Before restoring it
    calls ``self.check_best_loss_val()``.

    Args:
        sess: session to restore into.
        save_model_path: checkpoint path whose directory is searched.
        saver_for_model_restore: saver used for the restore. Defaults to a
            plain ``tf.train.Saver()``. Pass a saver built over a filtered
            variable list (e.g. excluding variables whose op name contains
            "ExponentialMovingAverage") to restore only part of the graph.
    """
    if not (self.arguments["should_restore_if_could"]
            and save_model_path is not None):
        return

    model_ckpt = tf.train.get_checkpoint_state(osp.dirname(save_model_path))
    if not (model_ckpt and model_ckpt.model_checkpoint_path):
        return

    ckpt_path = model_ckpt.model_checkpoint_path
    self.check_best_loss_val()
    # Lazy logging args: formatted only if the record is emitted.
    logger.info("resotre model from %s", ckpt_path)
    if saver_for_model_restore is None:
        saver_for_model_restore = tf.train.Saver()
    saver_for_model_restore.restore(sess, ckpt_path)