def get_meta_filename(self, start_new_model, train_dir):
    if start_new_model:
        logging.info("{}: Flag 'start_new_model' is set. Building a new model."
                     .format(task_as_string(self.task)))
        return None

    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if not latest_checkpoint:
        logging.info("{}: No checkpoint file found. Building a new model."
                     .format(task_as_string(self.task)))
        return None

    meta_filename = latest_checkpoint + ".meta"
    if not gfile.Exists(meta_filename):
        logging.info("{}: No meta graph file found. Building a new model."
                     .format(task_as_string(self.task)))
        return None

    return meta_filename

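# Note: `task_as_string` is used throughout this module but defined elsewhere
# in the repo. A minimal sketch consistent with how it is called here (kept
# commented out so it does not shadow the real helper; the actual
# implementation may differ):
#
# def task_as_string(task):
#     return "/job:{}/task:{}".format(task.type, task.index)
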
def remove_training_directory(self, train_dir):
    """Removes the training directory, e.g. when starting a new model."""
    try:
        logging.info("{}: Removing existing training directory."
                     .format(task_as_string(self.task)))
        gfile.DeleteRecursively(train_dir)
    except tf.errors.OpError:
        logging.error("{}: Failed to delete directory {}."
                      .format(task_as_string(self.task), train_dir))

def export_model(self, global_step_val, saver, save_path, session):
    # Skip the export if the model was already exported at this step.
    if global_step_val == self.last_model_export_step:
        return

    last_checkpoint = saver.save(session, save_path, global_step_val)
    model_dir = "{0}/export/step_{1}".format(self.train_dir, global_step_val)
    logging.info("{}: Exporting the model at step {} to {}."
                 .format(task_as_string(self.task), global_step_val, model_dir))
    self.model_exporter.export_model(
        model_dir=model_dir,
        global_step_val=global_step_val,
        last_checkpoint=last_checkpoint)

def main(unused_args):
    # Load the cluster/task description from the environment.
    env = json.loads(os.environ.get("TF_CONFIG", "{}"))
    task_data = env.get("task", None) or {"type": "master", "index": 0}
    task = type("TaskSpec", (object,), task_data)

    logging.set_verbosity(tf.logging.INFO)
    logging.info("{}: Tensorflow version: {}."
                 .format(task_as_string(task), tf.__version__))

    model = models.MDLSTMCTCModel()
    reader = get_reader()
    model_exporter = export_model.ModelExporter(
        slice_features=FLAGS.slice_features, model=model, reader=reader)

    Trainer(task, FLAGS.train_dir, model, reader, model_exporter,
            FLAGS.log_device_placement, FLAGS.max_steps,
            FLAGS.export_model_steps).run(start_new_model=FLAGS.start_new_model)

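# For reference, `main` expects TF_CONFIG to carry the cluster task spec. A
# hypothetical single-machine value (assumed shape, not taken from this repo)
# would be:
#
#   TF_CONFIG='{"task": {"type": "master", "index": 0}}'
#
# When TF_CONFIG is unset, the fallback above produces a master task with
# index 0, so the script also runs standalone.
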
def recover_model(self, meta_filename):
    logging.info("{}: Restoring from meta graph file {}."
                 .format(task_as_string(self.task), meta_filename))
    return tf.train.import_meta_graph(meta_filename)

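# Usage sketch for `recover_model` (hypothetical paths): the returned
# tf.train.Saver is built from the meta graph and can restore the matching
# checkpoint, which is what the Supervisor in `run` does automatically:
#
#   saver = trainer.recover_model("/tmp/train_dir/model.ckpt-1234.meta")
#   with tf.Session() as sess:
#       saver.restore(sess, "/tmp/train_dir/model.ckpt-1234")
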
def run(self, start_new_model=False):
    if start_new_model:
        self.remove_training_directory(self.train_dir)

    meta_filename = self.get_meta_filename(start_new_model, self.train_dir)
    with tf.Graph().as_default() as graph:
        if meta_filename:
            saver = self.recover_model(meta_filename)
        else:
            saver = self.build_model(self.model, self.reader)

        # Fetch the ops and tensors published to collections by the model.
        global_step = tf.get_collection("global_step")[0]
        loss = tf.get_collection("loss")[0]
        predictions = tf.get_collection("predictions")[0]
        labels = tf.get_collection("labels")[0]
        train_batch = tf.get_collection("train_batch")[0]
        train_op = tf.get_collection("train_op")[0]
        decodedPrediction = []
        for i in range(FLAGS.beam_size):
            decodedPrediction.append(
                tf.get_collection("decodedPrediction{}".format(i))[0])
        ler = tf.get_collection("ler")[0]
        init_op = tf.global_variables_initializer()

    sv = tf.train.Supervisor(
        graph,
        logdir=self.train_dir,
        init_op=init_op,
        is_chief=True,
        global_step=global_step,
        save_model_secs=15 * 60,
        save_summaries_secs=120,
        saver=saver)

    logging.info("{}: Starting managed session."
                 .format(task_as_string(self.task)))
    with sv.managed_session("", config=self.config) as sess:
        try:
            logging.info("{}: Entering training loop."
                         .format(task_as_string(self.task)))
            decodedPr = None
            labels_val = None
            global_step_val = None
            while (not sv.should_stop()) and (not self.max_steps_reached):
                batch_start_time = time.time()
                _, global_step_val = sess.run([train_op, global_step])
                seconds_per_batch = time.time() - batch_start_time
                feed = {}

                if self.max_steps and self.max_steps <= global_step_val:
                    self.max_steps_reached = True

                if global_step_val % FLAGS.display_step == 0:
                    # Evaluate the current training batch.
                    (global_step_val, loss_val, predictions_val, labels_val,
                     labelRateError, decodedPr) = sess.run(
                        [global_step, loss, predictions, labels, ler,
                         decodedPrediction], feed)
                    # Evaluate again with train_batch disabled to get
                    # held-out statistics.
                    feed[train_batch] = False
                    (global_step_val_te, loss_val_te, predictions_val_te,
                     label_val_te, labelRateError_te, decodedPr_te) = sess.run(
                        [global_step, loss, predictions, labels, ler,
                         decodedPrediction], feed)
                    examples_per_second = len(labels_val) / seconds_per_batch

                    if global_step_val % FLAGS.display_step_lme == 0:
                        lme, lme_te = 0., 0.
                        eval_util.show_prediction(
                            decodedPr_te, label_val_te, None, top_k=3)
                    else:
                        lme, lme_te = -1., -1.
                    logging.info(
                        "{}: training step {}".format(
                            task_as_string(self.task), global_step_val) +
                        " ler: {}".format(labelRateError) +
                        " ler-te: {}".format(labelRateError_te) +
                        " Loss: {}".format(loss_val) +
                        " Loss-te: {}".format(loss_val_te))

                    sv.summary_writer.add_summary(
                        utils.makeSummary("model/labelRateError_train",
                                          labelRateError), global_step_val)
                    sv.summary_writer.add_summary(
                        utils.makeSummary("model/labelRateError_test",
                                          labelRateError_te), global_step_val)
                    sv.summary_writer.add_summary(
                        utils.makeSummary("model/lme_train", lme),
                        global_step_val)
                    sv.summary_writer.add_summary(
                        utils.makeSummary("model/lme_test", lme_te),
                        global_step_val)
                    sv.summary_writer.add_summary(
                        utils.makeSummary("model/loss_train", loss_val),
                        global_step_val)
                    sv.summary_writer.add_summary(
                        utils.makeSummary("model/loss_test", loss_val_te),
                        global_step_val)
                    sv.summary_writer.add_summary(
                        utils.makeSummary("global_step/Examples/Second",
                                          examples_per_second),
                        global_step_val)
                    sv.summary_writer.flush()

                    # Export the model periodically.
                    time_to_export = (
                        (self.last_model_export_step == 0) or
                        (global_step_val - self.last_model_export_step >=
                         self.export_model_steps))
                    if time_to_export:
                        eval_util.show_prediction(decodedPr, labels_val)
                        self.export_model(global_step_val, sv.saver,
                                          sv.save_path, sess)
                        self.last_model_export_step = global_step_val

            # Export the final model once the training loop exits.
            eval_util.show_prediction(decodedPr, labels_val)
            self.export_model(global_step_val, sv.saver, sv.save_path, sess)

        except tf.errors.OutOfRangeError:
            logging.info("{}: Done training -- epoch limit reached."
                         .format(task_as_string(self.task)))

    logging.info("{}: Exited training loop.".format(task_as_string(self.task)))
    sv.stop()
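# `utils.makeSummary` is assumed to wrap a scalar into a tf.Summary proto so
# values computed in Python can be written without a dedicated summary op. A
# minimal sketch of such a helper (the repo's own version may differ):
#
# def makeSummary(tag, value):
#     return tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])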