示例#1
0
文件: train.py 项目: rixingw/DocuScan
 def get_meta_filename(self, start_new_model, train_dir):
     """Return the latest checkpoint's .meta file path, or None.

     None means "build a new model": either 'start_new_model' was
     requested, no checkpoint exists in train_dir, or the checkpoint's
     meta-graph file is missing.

     Args:
         start_new_model: bool, force building a fresh model.
         train_dir: directory searched for checkpoints.

     Returns:
         Path to the meta-graph file, or None.
     """
     if start_new_model:
         logging.info("{}: Flag 'start_new_model' is set. Building a new model.".format(task_as_string(self.task)))
         return None
     latest_checkpoint = tf.train.latest_checkpoint(train_dir)
     if not latest_checkpoint:
         # Fixed: logging's lazy args are %-style and never interpolate a
         # '{}' placeholder, so format the message explicitly.
         logging.info("{}: No checkpoint file found. Building a new model.".format(
             task_as_string(self.task)))
         return None
     meta_filename = latest_checkpoint + ".meta"
     if not gfile.Exists(meta_filename):
         # Fixed the same logging issue plus the "graphh" typo.
         logging.info("{}: No meta graph file found. Building a new model.".format(
             task_as_string(self.task)))
         return None
     return meta_filename
示例#2
0
文件: train.py 项目: rixingw/DocuScan
 def remove_training_directory(self, train_dir):
     """Recursively delete train_dir; log (without raising) on failure.

     Deletion is best-effort: a failure is reported but training setup
     continues, matching the original behavior.
     """
     try:
         logging.info("{}: Removing Existing training directory.".format(task_as_string(self.task)))
         gfile.DeleteRecursively(train_dir)
     except Exception as e:
         # Fixed: narrowed the bare 'except' (which also swallowed
         # KeyboardInterrupt/SystemExit) and include the directory and the
         # underlying error so the failure is diagnosable.
         logging.error("{}: Failed to delete directory {}: {}".format(
             task_as_string(self.task), train_dir, e))
示例#3
0
文件: train.py 项目: rixingw/DocuScan
    def export_model(self, global_step_val, saver, save_path, session):
        """Checkpoint the session and export the model for this global step.

        Does nothing if a model was already exported at this exact step.
        """
        # Guard: avoid re-exporting the same step twice.
        if global_step_val == self.last_model_export_step:
            return

        checkpoint_path = saver.save(session, save_path, global_step_val)
        export_dir = "{0}/export/step_{1}".format(self.train_dir, global_step_val)

        logging.info(
            "{}: Exporting the model at step {} to {}".format(
                task_as_string(self.task), global_step_val, export_dir))

        self.model_exporter.export_model(
            model_dir=export_dir,
            global_step_val=global_step_val,
            last_checkpoint=checkpoint_path)
示例#4
0
文件: train.py 项目: rixingw/DocuScan
def main(unused_args):
    """Entry point: assemble model, reader and exporter, then run training."""
    # Task description comes from the TF_CONFIG environment variable; fall
    # back to a single default task when it is absent or empty.
    env = json.loads(os.environ.get("TF_CONFIG", "{}"))
    task_data = env.get("task", None) or {"type": "aster", "index": 0}
    task = type("TaskSpec", (object,), task_data)

    logging.set_verbosity(tf.logging.INFO)
    logging.info("{}: Tensorflow version: {}".format(task_as_string(task), tf.__version__))

    model = models.MDLSTMCTCModel()
    reader = get_reader()

    model_exporter = export_model.ModelExporter(
        slice_features=FLAGS.slice_features,
        model=model,
        reader=reader)

    trainer = Trainer(task, FLAGS.train_dir, model, reader, model_exporter,
                      FLAGS.log_device_placement, FLAGS.max_steps,
                      FLAGS.export_model_steps)
    trainer.run(start_new_model=FLAGS.start_new_model)
示例#5
0
文件: train.py 项目: rixingw/DocuScan
 def recover_model(self, meta_filename):
     """Import the meta graph from meta_filename and return its Saver.

     Fixed the log call: the message was missing the leading '{}' for the
     task and relied on logging's lazy %-style args, which never
     interpolate '{}' placeholders.
     """
     logging.info("{}: Restoring from meta graph file {}".format(
         task_as_string(self.task), meta_filename))
     return tf.train.import_meta_graph(meta_filename)
示例#6
0
文件: train.py 项目: rixingw/DocuScan
    def run(self, start_new_model=False):
        """Train until max_steps is reached or the input queue is exhausted.

        Builds (or recovers from checkpoint) the graph, then drives a
        tf.train.Supervisor managed session: runs train_op, periodically
        evaluates train/test metrics, writes summaries, and exports the
        model every export_model_steps steps.

        Args:
            start_new_model: when True, wipe train_dir and build from scratch.
        """
        if start_new_model:
            self.remove_training_directory(self.train_dir)

        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        with tf.Graph().as_default() as graph:
            # Recover from the latest checkpoint's meta graph when available,
            # otherwise build the graph from scratch.
            if meta_filename:
                saver = self.recover_model(meta_filename)

            if not meta_filename:
                saver = self.build_model(self.model, self.reader)

            # Tensors/ops published through graph collections by build_model.
            global_step = tf.get_collection("global_step")[0]
            loss = tf.get_collection("loss")[0]
            predictions = tf.get_collection("predictions")[0]
            labels = tf.get_collection("labels")[0]
            train_batch = tf.get_collection("train_batch")[0]
            train_op = tf.get_collection("train_op")[0]

            decodedPrediction = []
            for i in range(FLAGS.beam_size):
                decodedPrediction.append(tf.get_collection("decodedPrediction{}".format(i))[0])
            ler = tf.get_collection("ler")[0]
            init_op = tf.global_variables_initializer()

            sv = tf.train.Supervisor(
                graph,
                logdir=self.train_dir,
                init_op=init_op,
                is_chief=True,
                global_step=global_step,
                save_model_secs=15 * 60,
                save_summaries_secs=120,
                saver=saver
            )

            logging.info("{}: Starting managed session.".format(task_as_string(self.task)))
            with sv.managed_session("", config=self.config) as sess:
                try:
                    logging.info("{}: Entering training loop.".format(task_as_string(self.task)))
                    decodedPr = None
                    labels_val = None
                    global_step_val = None
                    while (not sv.should_stop()) and (not self.max_steps_reached):
                        batch_start_time = time.time()
                        _, global_step_val = sess.run([train_op, global_step])
                        seconds_per_batch = time.time() - batch_start_time

                        feed = {}

                        if self.max_steps and self.max_steps <= global_step_val:
                            self.max_steps_reached = True

                        if global_step_val % FLAGS.display_step == 0:
                            # Metrics on a training batch ...
                            global_step_val, loss_val, predictions_val, labels_val, labelRateError, decodedPr = \
                                sess.run([global_step, loss, predictions, labels, ler, decodedPrediction], feed)

                            # ... and on a test batch (train_batch=False is
                            # presumably the input-switch placeholder — confirm).
                            feed[train_batch] = False
                            global_step_val_te, loss_val_te, predictions_val_te, label_val_te, labelRateError_te, decodedPr_te = \
                                sess.run([global_step, loss, predictions, labels, ler, decodedPrediction], feed)

                            examples_per_second = len(labels_val) / seconds_per_batch

                            if global_step_val % FLAGS.display_step_lme == 0:
                                # BUG FIX: lme_te was never assigned on this
                                # branch, raising NameError at the
                                # "model/lme_test" summary write below.
                                lme, lme_te = 0, 0
                                eval_util.show_prediction(decodedPr_te, label_val_te, None, top_k=3)
                            else:
                                lme, lme_te = -1., -1.

                            logging.info("{}: training step {}".format(task_as_string(self.task), global_step_val)
                                         + " ler: {}".format(labelRateError) + "ler-te: {}".format(labelRateError_te)
                                         + " Loss: {}".format(loss_val) + " Loss-te: {}".format(loss_val_te))
                            sv.summary_writer.add_summary(
                                utils.makeSummary("model/labelRateError_train", labelRateError),
                                global_step_val
                            )
                            # BUG FIX: the *_test summaries below previously
                            # omitted global_step_val, so they were written
                            # without a step; pass it like their siblings.
                            sv.summary_writer.add_summary(
                                utils.makeSummary("model/labelRateError_test", labelRateError_te),
                                global_step_val
                            )
                            sv.summary_writer.add_summary(
                                utils.makeSummary("model/lme_train", lme), global_step_val
                            )
                            sv.summary_writer.add_summary(
                                utils.makeSummary("model/lme_test", lme_te), global_step_val
                            )
                            sv.summary_writer.add_summary(
                                utils.makeSummary("model/loss_train", loss_val), global_step_val
                            )
                            sv.summary_writer.add_summary(
                                utils.makeSummary("model/loss_test", loss_val_te), global_step_val
                            )
                            sv.summary_writer.add_summary(
                                utils.makeSummary("global_step/Examples/Second",
                                                  examples_per_second), global_step_val
                            )
                            sv.summary_writer.flush()

                            # Export on the first display step and every
                            # export_model_steps steps thereafter.
                            time_to_export = ((self.last_model_export_step == 0) or
                                              (global_step_val - self.last_model_export_step >= self.export_model_steps))

                            if time_to_export:
                                eval_util.show_prediction(decodedPr, labels_val)
                                self.export_model(global_step_val, sv.saver, sv.save_path, sess)
                                self.last_model_export_step = global_step_val

                    # Final export once the loop ends normally.
                    eval_util.show_prediction(decodedPr, labels_val)
                    self.export_model(global_step_val, sv.saver, sv.save_path, sess)

                except tf.errors.OutOfRangeError:
                    # Fixed: format explicitly — logging's lazy args are
                    # %-style and would not interpolate the '{}'.
                    logging.info("{}: Done training -- epoch limit reached.".format(
                        task_as_string(self.task)))
                logging.info("{}: Exited training loop.".format(task_as_string(self.task)))
                sv.stop()