Пример #1
0
    def tensor_define(self, model_path, charset_path, label_file):
        """Build the CRNN inference graph, restore weights and cache handles on ``self``.

        Sets ``self.sess``, ``self.converter``, ``self.dataset``, ``self.input``,
        ``self.output`` and ``self.global_step``.

        :param model_path: checkpoint location handed to ``self.load_model``
        :param charset_path: charset file used to build the ``LabelConverter``
        :param label_file: label file read by an unshuffled, batch-size-1 ``ImgDataset``
        """
        # Dataset ops are created before the model, as in the original graph layout.
        label_converter = LabelConverter(chars_file=charset_path)
        img_dataset = ImgDataset(label_file,
                                 label_converter,
                                 batch_size=1,
                                 shuffle=False)
        crnn = CRNN(self.cfg, num_classes=label_converter.num_classes)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session = tf.Session(config=session_config)
        session.run(img_dataset.init_op)

        # load_model hands back the session that should be used from here on.
        session = self.load_model(model_path=model_path, sess=session)
        restored_step = session.run(crnn.global_step)

        self.sess, self.converter, self.dataset = session, label_converter, img_dataset

        # Tensors fed / fetched by the caller at inference time.
        self.input = [
            crnn.inputs, crnn.labels, crnn.bat_labels, crnn.len_labels,
            crnn.char_num, crnn.char_pos_init, crnn.is_training
        ]
        self.output = [crnn.dense_decoded, crnn.char_pos, crnn.embedding]
        self.global_step = restored_step
Пример #2
0
def infer(args):
    """Evaluate a restored CRNN checkpoint on ``args.infer_dir``.

    Tensors are looked up by name in the restored graph and handed to
    ``validation`` with ``name='infer'`` and per-batch printing enabled.

    :param args: parsed CLI namespace; reads ``chars_file``, ``infer_dir``,
        ``infer_batch_size``, ``ckpt_dir``, ``result_dir`` and ``infer_copy_failed``
    """
    converter = LabelConverter(chars_file=args.chars_file)
    dataset = ImgDataset(args.infer_dir, converter, args.infer_batch_size, shuffle=False)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        restore_ckpt(sess, args.ckpt_dir)
        by_name = sess.graph.get_tensor_by_name

        # A sparse placeholder is restored as three separate tensors; rebuild it so
        # a SparseTensorValue can be fed to it. See:
        # https://stackoverflow.com/questions/46912721/tensorflow-restore-model-with-sparse-placeholder
        sparse_labels = tf.SparseTensor(
            values=by_name('labels/values:0'),
            indices=by_name('labels/indices:0'),
            dense_shape=by_name('labels/shape:0'))

        feeds = dict(inputs=by_name('inputs:0'),
                     is_training=by_name('is_training:0'),
                     labels=sparse_labels)

        dense_decoded = by_name('SparseToDense:0')
        mean_edit_distance = by_name('Mean_1:0')
        batch_edit_distances = by_name('edit_distance:0')
        fetches = [dense_decoded, mean_edit_distance, batch_edit_distances]

        validation(sess, feeds, fetches,
                   dataset, converter, args.result_dir, name='infer',
                   print_batch_info=True, copy_failed=args.infer_copy_failed)
Пример #3
0
    def __init__(self, args):
        """Load config, build datasets, model and session for training.

        The learning-rate schedule is written into the config: one boundary per
        epoch listed in ``cfg.lr_decay_epochs`` (converted to batch steps), and
        ``len(boundaries) + 1`` geometrically decaying values.

        :param args: parsed CLI namespace; reads ``cfg_name``, ``chars_file``,
            ``train_dir``, ``val_dir`` and ``test_dir``
        """
        self.args = args
        cfg = self.cfg = load_config(args.cfg_name)

        self.converter = LabelConverter(chars_file=args.chars_file)
        self.tr_ds = ImgDataset(args.train_dir, self.converter, cfg.batch_size)

        steps_per_epoch = self.tr_ds.num_batches
        cfg.lr_boundaries = [steps_per_epoch * decay_epoch
                             for decay_epoch in cfg.lr_decay_epochs]
        cfg.lr_values = [cfg.lr * cfg.lr_decay_rate ** k
                         for k in range(len(cfg.lr_boundaries) + 1)]

        self.val_ds = (None if args.val_dir is None else
                       ImgDataset(args.val_dir,
                                  self.converter,
                                  cfg.batch_size,
                                  shuffle=False))

        # Test images often have different size, so set batch_size to 1
        self.test_ds = (None if args.test_dir is None else
                        ImgDataset(args.test_dir,
                                   self.converter,
                                   shuffle=False,
                                   batch_size=1))

        self.model = CRNN(cfg, num_classes=self.converter.num_classes)
        session_config = tf.ConfigProto(allow_soft_placement=True)
        self.sess = tf.Session(config=session_config)

        # Resume position; overwritten when restoring from a checkpoint.
        self.epoch_start_index = self.batch_start_index = 0
Пример #4
0
def _write_result_entry(f, index, t_label, p_label, edit_distance, with_match):
    """Write one sample's ground-truth / prediction record to an open text file."""
    f.write("{:08d}\n".format(index))
    f.write("input:   {:17s} length: {}\n".format(t_label, len(t_label)))
    f.write("predict: {:17s} length: {}\n".format(p_label, len(p_label)))
    if with_match:
        f.write("all match:  {}\n".format(1 if t_label == p_label else 0))
    f.write("edit distance:  {}\n".format(edit_distance))
    f.write('-' * 30 + '\n')


def _copy_failed_samples(file_path, img_paths, predicts, labels, edit_distances):
    """Copy images with non-zero edit distance, plus a result.txt, into ``<file_path stem>_failed``."""
    failed_infer_img_dir = os.path.splitext(file_path)[0] + "_failed"
    # Start from an empty directory so stale failures from a previous run don't linger.
    if os.path.isdir(failed_infer_img_dir):
        shutil.rmtree(failed_infer_img_dir)
    utils.check_dir_exist(failed_infer_img_dir)

    failed_image_indices = [i for i, dist in enumerate(edit_distances) if dist != 0]

    for i in failed_image_indices:
        img_path = img_paths[i]
        # NOTE(review): images sharing a basename in different directories overwrite each other.
        dst_path = os.path.join(failed_infer_img_dir, os.path.basename(img_path))
        shutil.copyfile(img_path, dst_path)

    failed_infer_result_file_path = os.path.join(failed_infer_img_dir, "result.txt")
    with open(failed_infer_result_file_path, 'w', encoding='utf-8') as f:
        for i in failed_image_indices:
            _write_result_entry(f, i, labels[i], predicts[i], edit_distances[i],
                                with_match=False)


def validation(sess,
               feeds,
               fetches,
               dataset: ImgDataset,
               converter,
               result_dir,
               name,
               step=None,
               print_batch_info=False,
               copy_failed=False):
    """Run the model over every batch of ``dataset`` and write a result report.

    The report is saved to ``{result_dir}/{name}/{acc}_{step}.txt``, or to
    ``{result_dir}/{name}/{acc}.txt`` when ``step`` is None.

    :param sess: tensorflow session
    :param feeds: dict with keys 'inputs', 'labels' and 'is_training' mapping to
        the graph tensors to feed
    :param fetches: three tensors, unpacked as [dense decoded predictions,
        batch mean edit distance, per-sample edit distances]
    :param dataset: dataset to evaluate; re-initialised via ``dataset.init_op``
    :param converter: label converter used to decode predictions
    :param result_dir: root directory for result reports
    :param name: sub-directory name, e.g. val, test, infer
    :param step: optional global step, included in the report file name
    :param print_batch_info: print accuracy / edit distance for every batch
    :param copy_failed: also copy every image whose edit distance is non-zero
        (plus a result.txt) into ``<report name>_failed``
    :return: (accuracy, mean edit distance) over the whole dataset
    """
    sess.run(dataset.init_op)

    img_paths = []
    predicts = []
    labels = []
    edit_distances = []
    total_batch_time = 0

    for batch in range(dataset.num_batches):
        img_batch, label_batch, batch_labels, batch_img_paths = dataset.get_next_batch(
            sess)

        # Timing covers the forward pass, decoding and accuracy, not data loading.
        batch_start_time = time.time()

        feed = {
            feeds['inputs']: img_batch,
            feeds['labels']: label_batch,
            feeds['is_training']: False
        }

        batch_predicts, edit_distance, batch_edit_distances = sess.run(
            fetches, feed)
        batch_predicts = [
            converter.decode(p, CRNN.CTC_INVALID_INDEX) for p in batch_predicts
        ]

        img_paths.extend(batch_img_paths)
        predicts.extend(batch_predicts)
        labels.extend(batch_labels)
        edit_distances.extend(batch_edit_distances)

        acc, correct_count = calculate_accuracy(batch_predicts, batch_labels)
        batch_time = time.time() - batch_start_time
        total_batch_time += batch_time
        if print_batch_info:
            # The last batch can be smaller than dataset.batch_size; report the real count.
            print(
                "Batch [{}/{}] {:.03f}s accuracy: {:.03f} ({}/{}), edit_distance: {:.03f}"
                .format(batch, dataset.num_batches, batch_time, acc,
                        correct_count, len(batch_labels), edit_distance))

    acc, correct_count = calculate_accuracy(predicts, labels)
    edit_distance_mean = calculate_edit_distance_mean(edit_distances)
    # Guard against an empty dataset (num_batches == 0).
    avg_batch_time = total_batch_time / max(dataset.num_batches, 1)
    acc_str = "Accuracy: {:.03f} ({}/{}), Average edit distance: {:.03f}, Average batch time: {:.03f}" \
        .format(acc, correct_count, dataset.size, edit_distance_mean, avg_batch_time)

    print(acc_str)

    save_dir = os.path.join(result_dir, name)
    utils.check_dir_exist(save_dir)
    if step is not None:
        file_path = os.path.join(save_dir, '%.3f_%d.txt' % (acc, step))
    else:
        file_path = os.path.join(save_dir, '%.3f.txt' % acc)

    print("Write result to %s" % file_path)
    with open(file_path, 'w', encoding='utf-8') as f:
        for i, (t_label, p_label, dist) in enumerate(
                zip(labels, predicts, edit_distances)):
            _write_result_entry(f, i, t_label, p_label, dist, with_match=True)
        f.write(acc_str + "\n")

    if copy_failed:
        _copy_failed_samples(file_path, img_paths, predicts, labels,
                             edit_distances)

    return acc, edit_distance_mean
Пример #5
0
    def train(self, log_dir, restore, log_step, ckpt_dir, val_step, cfg_name, chars_file, train_txt, val_txt, test_txt, result_dir):
        """Build a CRNN model from a config and train it.

        Every ``log_step`` batches (within an epoch) a summary-producing train
        step is run. Every ``val_step`` global steps the model is evaluated on the
        validation and test sets (when given), and a checkpoint is saved.

        :param log_dir: directory for the TensorBoard ``FileWriter``
        :param restore: if true, call ``self._restore`` before training
        :param log_step: batch interval for summary-producing train steps
        :param ckpt_dir: checkpoint directory (forwarded to restore/save helpers)
        :param val_step: global-step interval for evaluation and checkpointing
        :param cfg_name: config name passed to ``load_config``
        :param chars_file: charset file for ``LabelConverter``
        :param train_txt: training label file
        :param val_txt: validation label file, or None to skip validation
        :param test_txt: test label file, or None to skip testing
        :param result_dir: forwarded to ``self._do_val``
        """
        cfg = load_config(cfg_name)

        converter = LabelConverter(chars_file=chars_file)

        tr_ds = ImgDataset(train_txt, converter, cfg.batch_size)

        # NOTE(review): single LR boundary hard-coded at step 10000; the Trainer
        # variant derives boundaries from cfg.lr_decay_epochs instead.
        cfg.lr_boundaries = [10000]
        cfg.lr_values = [cfg.lr * (cfg.lr_decay_rate ** i) for i in
                         range(len(cfg.lr_boundaries) + 1)]

        if val_txt is None:
            val_ds = None
        else:
            val_ds = ImgDataset(val_txt, converter, cfg.batch_size, shuffle=False)

        if test_txt is None:
            test_ds = None
        else:
            # Test images often have different size, so set batch_size to 1
            test_ds = ImgDataset(test_txt, converter, shuffle=False, batch_size=1)

        model = CRNN(cfg, num_classes=converter.num_classes)

        # NOTE(review): these are locals, so self._restore below cannot move them;
        # a restored run appears to restart at epoch 0 / batch 0 — confirm.
        epoch_start_index = 0
        batch_start_index = 0

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.8
        sess = tf.Session(config=config)
        train_writer = None
        try:
            sess.run(tf.global_variables_initializer())

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=8)
            train_writer = tf.summary.FileWriter(log_dir, sess.graph)

            if restore:
                self._restore(sess, saver, model, tr_ds, ckpt_dir)

            print('Begin training...')
            for epoch in range(epoch_start_index, cfg.epochs):
                sess.run(tr_ds.init_op)

                for batch in range(batch_start_index, tr_ds.num_batches):
                    batch_start_time = time.time()

                    if batch != 0 and (batch % log_step == 0):
                        batch_cost, global_step, lr = self._train_with_summary(model, tr_ds, sess, train_writer, converter)
                    else:
                        batch_cost, global_step, lr = self._train(model, tr_ds, sess)

                    print("epoch: {}, batch: {}/{}, step: {}, time: {:.02f}s, loss: {:.05}, lr: {:.05}"
                          .format(epoch, batch, tr_ds.num_batches, global_step, time.time() - batch_start_time,
                                  batch_cost, lr))

                    if global_step != 0 and (global_step % val_step == 0):
                        val_acc = self._do_val(val_ds, epoch, global_step, "val", sess, model, converter, train_writer, cfg, result_dir)
                        test_acc = self._do_val(test_ds, epoch, global_step, "test", sess, model, converter, train_writer, cfg, result_dir)
                        self._save_checkpoint(ckpt_dir, global_step, saver, sess, val_acc, test_acc)

                batch_start_index = 0
        finally:
            # Flush buffered summaries and release the session's resources, even
            # when training is interrupted.
            if train_writer is not None:
                train_writer.close()
            sess.close()
Пример #6
0
class Trainer(object):
    """Train a CRNN model, evaluating on val/test sets and checkpointing periodically."""

    def __init__(self, args):
        """Load config, build datasets, model and session.

        :param args: parsed CLI namespace; reads ``cfg_name``, ``chars_file``,
            ``train_dir``, ``val_dir``, ``test_dir`` here and ``log_dir``,
            ``restore``, ``log_step``, ``val_step``, ``ckpt_dir``,
            ``result_dir`` during training
        """
        self.args = args
        self.cfg = load_config(args.cfg_name)

        self.converter = LabelConverter(chars_file=args.chars_file)

        self.tr_ds = ImgDataset(args.train_dir, self.converter,
                                self.cfg.batch_size)

        # LR decays at the first step of every epoch listed in cfg.lr_decay_epochs.
        self.cfg.lr_boundaries = [
            self.tr_ds.num_batches * epoch
            for epoch in self.cfg.lr_decay_epochs
        ]
        self.cfg.lr_values = [
            self.cfg.lr * (self.cfg.lr_decay_rate**i)
            for i in range(len(self.cfg.lr_boundaries) + 1)
        ]

        if args.val_dir is None:
            self.val_ds = None
        else:
            self.val_ds = ImgDataset(args.val_dir,
                                     self.converter,
                                     self.cfg.batch_size,
                                     shuffle=False)

        if args.test_dir is None:
            self.test_ds = None
        else:
            # Test images often have different size, so set batch_size to 1
            self.test_ds = ImgDataset(args.test_dir,
                                      self.converter,
                                      shuffle=False,
                                      batch_size=1)

        self.model = CRNN(self.cfg, num_classes=self.converter.num_classes)
        self.sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))

        # Resume position; updated by _restore().
        self.epoch_start_index = 0
        self.batch_start_index = 0

    def train(self):
        """Run the training loop from the (possibly restored) start position."""
        self.sess.run(tf.global_variables_initializer())

        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=8)
        self.train_writer = tf.summary.FileWriter(self.args.log_dir,
                                                  self.sess.graph)
        try:
            if self.args.restore:
                self._restore()

            print('Begin training...')
            for epoch in range(self.epoch_start_index, self.cfg.epochs):
                self.sess.run(self.tr_ds.init_op)

                for batch in range(self.batch_start_index,
                                   self.tr_ds.num_batches):
                    batch_start_time = time.time()

                    if batch != 0 and (batch % self.args.log_step == 0):
                        batch_cost, global_step, lr = self._train_with_summary()
                    else:
                        batch_cost, global_step, lr = self._train()

                    print(
                        "epoch: {}, batch: {}/{}, step: {}, time: {:.02f}s, loss: {:.05}, lr: {:.05}"
                        .format(epoch, batch, self.tr_ds.num_batches,
                                global_step, time.time() - batch_start_time,
                                batch_cost, lr))

                    if global_step != 0 and (global_step % self.args.val_step
                                             == 0):
                        val_acc = self._do_val(self.val_ds, epoch, global_step,
                                               "val")
                        test_acc = self._do_val(self.test_ds, epoch,
                                                global_step, "test")
                        self._save_checkpoint(self.args.ckpt_dir, global_step,
                                              val_acc, test_acc)

                # Only the first (resumed) epoch starts mid-way.
                self.batch_start_index = 0
        finally:
            # Flush buffered summaries even if training is interrupted.
            self.train_writer.close()

    def _restore(self):
        """Restore the latest checkpoint and derive the epoch/batch to resume from."""
        utils.restore_ckpt(self.sess, self.saver, self.args.ckpt_dir)

        step_restored = self.sess.run(self.model.global_step)

        # Exact integer split of the global step into (epoch, batch-in-epoch).
        self.epoch_start_index, self.batch_start_index = divmod(
            step_restored, self.tr_ds.num_batches)

        print("Restored global step: %d" % step_restored)
        print("Restored epoch: %d" % self.epoch_start_index)
        print("Restored batch_start_index: %d" % self.batch_start_index)

    def _next_train_feed(self):
        """Fetch the next training batch; return (feed_dict, decoded labels)."""
        img_batch, label_batch, labels, _ = self.tr_ds.get_next_batch(
            self.sess)
        feed = {
            self.model.inputs: img_batch,
            self.model.labels: label_batch,
            self.model.is_training: True
        }
        return feed, labels

    def _train(self):
        """Run one optimisation step; return (total loss, global step, lr)."""
        feed, _ = self._next_train_feed()

        fetches = [
            self.model.total_loss, self.model.ctc_loss,
            self.model.regularization_loss, self.model.global_step,
            self.model.lr, self.model.train_op
        ]

        batch_cost, _, _, global_step, lr, _ = self.sess.run(fetches, feed)
        return batch_cost, global_step, lr

    def _train_with_summary(self):
        """Run one optimisation step and log summaries, train accuracy and edit distance."""
        feed, labels = self._next_train_feed()

        # `merged_summay` is the attribute name as defined on the model.
        fetches = [
            self.model.total_loss, self.model.ctc_loss,
            self.model.regularization_loss, self.model.global_step,
            self.model.lr, self.model.merged_summay, self.model.dense_decoded,
            self.model.edit_distance, self.model.train_op
        ]

        batch_cost, _, _, global_step, lr, summary, predicts, edit_distance, _ = self.sess.run(
            fetches, feed)
        self.train_writer.add_summary(summary, global_step)

        predicts = [
            self.converter.decode(p, CRNN.CTC_INVALID_INDEX) for p in predicts
        ]
        accuracy, _ = infer.calculate_accuracy(predicts, labels)

        tf_utils.add_scalar_summary(self.train_writer, "train_accuracy",
                                    accuracy, global_step)
        tf_utils.add_scalar_summary(self.train_writer, "train_edit_distance",
                                    edit_distance, global_step)

        return batch_cost, global_step, lr

    def _do_val(self, dataset, epoch, step, name):
        """Evaluate on ``dataset`` (skipped when None); return its accuracy or None."""
        if dataset is None:
            return None

        accuracy, edit_distance = infer.validation(self.sess,
                                                   self.model.feeds(),
                                                   self.model.fetches(),
                                                   dataset, self.converter,
                                                   self.args.result_dir, name,
                                                   step)

        tf_utils.add_scalar_summary(self.train_writer, "%s_accuracy" % name,
                                    accuracy, step)
        tf_utils.add_scalar_summary(self.train_writer,
                                    "%s_edit_distance" % name, edit_distance,
                                    step)

        print("epoch: %d/%d, %s accuracy = %.3f" %
              (epoch, self.cfg.epochs, name, accuracy))
        return accuracy

    def _save_checkpoint(self, ckpt_dir, step, val_acc=None, test_acc=None):
        """Save a checkpoint named after step/accuracies and drop the previous .meta file."""
        ckpt_name = "crnn_%d" % step
        if val_acc is not None:
            ckpt_name += '_val_%.03f' % val_acc
        if test_acc is not None:
            ckpt_name += '_test_%.03f' % test_acc

        name = os.path.join(ckpt_dir, ckpt_name)
        print("save checkpoint %s" % name)

        # Look for the previous meta file before saving, since saving writes a new one.
        meta_exists, meta_file_name = self._meta_file_exist(ckpt_dir)

        self.saver.save(self.sess, name)

        # Remove the old meta file to save disk space, but never the one just
        # written (happens when the same checkpoint name is saved again).
        if meta_exists and meta_file_name != ckpt_name + '.meta':
            try:
                os.remove(os.path.join(ckpt_dir, meta_file_name))
            except OSError:
                print('Remove meta file failed: %s' % meta_file_name)

    def _meta_file_exist(self, ckpt_dir):
        """Return (found, file name) for the first ``*.meta`` file in ``ckpt_dir``."""
        meta_file_name = next(
            (n for n in os.listdir(ckpt_dir) if n.endswith('.meta')), '')
        return bool(meta_file_name), meta_file_name