def tensor_define(self, model_path, charset_path, label_file):

        converter = LabelConverter(chars_file=charset_path)
        dataset = ImgDataset(label_file,
                             converter,
                             batch_size=1,
                             shuffle=False)
        model = CRNN(self.cfg, num_classes=converter.num_classes)

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(dataset.init_op)

        sess = self.load_model(model_path=model_path, sess=sess)
        global_step = sess.run(model.global_step)

        self.sess = sess
        self.converter = converter
        self.dataset = dataset

        self.input = [
            model.inputs, model.labels, model.bat_labels, model.len_labels,
            model.char_num, model.char_pos_init, model.is_training
        ]
        self.output = [model.dense_decoded, model.char_pos, model.embedding]
        self.global_step = global_step
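
A minimal sketch of how the feed/fetch lists stored above might be consumed at inference time; the batch values and the call site are assumptions, not part of the original snippet:

        # Hypothetical follow-up: pair one batch with the tensors collected
        # in self.input (is_training fed as False) and run self.output.
        #
        #   values = [imgs, labels, bat_labels, len_labels,
        #             char_num, char_pos_init, False]
        #   feed = dict(zip(self.input, values))
        #   dense_decoded, char_pos, embedding = self.sess.run(
        #       self.output, feed_dict=feed)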
Example #2
def infer(args):
    converter = LabelConverter(chars_file=args.chars_file)
    dataset = ImgDataset(args.infer_dir, converter, args.infer_batch_size, shuffle=False)

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        restore_ckpt(sess, args.ckpt_dir)

        # for node in sess.graph.as_graph_def().node:
        #     print(node.name)

        # https://stackoverflow.com/questions/46912721/tensorflow-restore-model-with-sparse-placeholder
        labels_placeholder = tf.SparseTensor(
            values=sess.graph.get_tensor_by_name('labels/values:0'),
            indices=sess.graph.get_tensor_by_name('labels/indices:0'),
            dense_shape=sess.graph.get_tensor_by_name('labels/shape:0')
        )
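        # The three component tensors exist because the training graph
        # presumably defined the labels with
        # tf.sparse_placeholder(tf.int32, name='labels'), which registers
        # 'labels/indices', 'labels/values' and 'labels/shape' placeholders;
        # the Python-side SparseTensor wrapper is not stored in the
        # checkpoint, so it is reassembled here by name.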

        feeds = {
            'inputs': sess.graph.get_tensor_by_name('inputs:0'),
            'is_training': sess.graph.get_tensor_by_name('is_training:0'),
            'labels': labels_placeholder
        }

        fetches = [
            sess.graph.get_tensor_by_name('SparseToDense:0'),  # dense_decoded
            sess.graph.get_tensor_by_name('Mean_1:0'),  # mean edit distance
            sess.graph.get_tensor_by_name('edit_distance:0')  # batch edit distances
        ]
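        # Note: 'SparseToDense:0' and 'Mean_1:0' are auto-generated op names,
        # so these lookups only stay valid while the graph is built exactly
        # the same way; naming the ops explicitly at training time (e.g.
        # tf.identity(x, name='dense_decoded')) would be more robust.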

        validation(sess, feeds, fetches,
                   dataset, converter, args.result_dir, name='infer',
                   print_batch_info=True, copy_failed=args.infer_copy_failed)
Example #3
    def __init__(self, model_path):
        os.environ['CUDA_VISIBLE_DEVICES'] = '2'
        self.cfg = load_config('resnet')
        self.label_converter = LabelConverter(lexicon)

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.5
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph, config=config)

        with self.session.as_default():
            with self.graph.as_default():
                self.net = CRNN(self.cfg, num_classes=self.label_converter.num_classes)
                saver = tf.train.Saver()
                saver.restore(self.session, model_path)

        logging.info('CRNN model initialized.')
Example #4
    def __init__(self, args):
        self.args = args
        self.cfg = load_config(args.cfg_name)

        self.converter = LabelConverter(chars_file=args.chars_file)

        self.tr_ds = ImgDataset(args.train_dir, self.converter,
                                self.cfg.batch_size)

        self.cfg.lr_boundaries = [
            self.tr_ds.num_batches * epoch
            for epoch in self.cfg.lr_decay_epochs
        ]
        self.cfg.lr_values = [
            self.cfg.lr * (self.cfg.lr_decay_rate**i)
            for i in range(len(self.cfg.lr_boundaries) + 1)
        ]
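        # lr_values deliberately has one more entry than lr_boundaries:
        # that is the shape a piecewise-constant schedule such as
        # tf.train.piecewise_constant(global_step, boundaries, values)
        # expects (an assumption about how the model consumes these fields).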

        if args.val_dir is None:
            self.val_ds = None
        else:
            self.val_ds = ImgDataset(args.val_dir,
                                     self.converter,
                                     self.cfg.batch_size,
                                     shuffle=False)

        if args.test_dir is None:
            self.test_ds = None
        else:
            # Test images often have different size, so set batch_size to 1
            self.test_ds = ImgDataset(args.test_dir,
                                      self.converter,
                                      shuffle=False,
                                      batch_size=1)

        self.model = CRNN(self.cfg, num_classes=self.converter.num_classes)
        self.sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))

        self.epoch_start_index = 0
        self.batch_start_index = 0
Example #5
    def train(self, log_dir, restore, log_step, ckpt_dir, val_step, cfg_name,
              chars_file, train_txt, val_txt, test_txt, result_dir):

        cfg = load_config(cfg_name)

        converter = LabelConverter(chars_file=chars_file)

        tr_ds = ImgDataset(train_txt, converter, cfg.batch_size)

        cfg.lr_boundaries = [10000]
        cfg.lr_values = [cfg.lr * (cfg.lr_decay_rate ** i)
                         for i in range(len(cfg.lr_boundaries) + 1)]

        if val_txt is None:
            val_ds = None
        else:
            val_ds = ImgDataset(val_txt, converter, cfg.batch_size, shuffle=False)

        if test_txt is None:
            test_ds = None
        else:
            # Test images often have different size, so set batch_size to 1
            test_ds = ImgDataset(test_txt, converter, shuffle=False, batch_size=1)

        model = CRNN(cfg, num_classes=converter.num_classes)

        epoch_start_index = 0
        batch_start_index = 0

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.8
        sess = tf.Session(config=config)
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=8)
        train_writer = tf.summary.FileWriter(log_dir, sess.graph)

        if restore:
            self._restore(sess, saver, model, tr_ds, ckpt_dir)

        print('Begin training...')
        for epoch in range(epoch_start_index, cfg.epochs):
            sess.run(tr_ds.init_op)

            for batch in range(batch_start_index, tr_ds.num_batches):
                batch_start_time = time.time()

                if batch != 0 and (batch % log_step == 0):
                    batch_cost, global_step, lr = self._train_with_summary(
                        model, tr_ds, sess, train_writer, converter)
                else:
                    batch_cost, global_step, lr = self._train(model, tr_ds, sess)

                print("epoch: {}, batch: {}/{}, step: {}, time: {:.02f}s, loss: {:.05}, lr: {:.05}"
                      .format(epoch, batch, tr_ds.num_batches, global_step, time.time() - batch_start_time,
                              batch_cost, lr))

                if global_step != 0 and (global_step % val_step == 0):
                    val_acc = self._do_val(val_ds, epoch, global_step, "val",
                                           sess, model, converter,
                                           train_writer, cfg, result_dir)
                    test_acc = self._do_val(test_ds, epoch, global_step,
                                            "test", sess, model, converter,
                                            train_writer, cfg, result_dir)
                    self._save_checkpoint(ckpt_dir, global_step, saver, sess, val_acc, test_acc)

            batch_start_index = 0  # epochs after a restored one start from batch 0
Example #6
            self.embedding: embedding features
            self.char_label: labels corresponding to the embedding features
            self.char_pos: character positions corresponding to the embedding features

        """
        with tf.variable_scope('pos'):
            # True wherever the prediction is a real character (the blank is
            # index num_classes - 1)
            is_char = tf.less(raw_pred, self.num_classes - 1)

            # Shifted comparison: a timestep counts as "repeated" if the
            # next timestep predicts the same class
            char_rep = tf.equal(raw_pred[:, :-1], raw_pred[:, 1:])
            # effectively an all-False column (class indices never exceed
            # num_classes - 1), used to pad char_rep back to full length
            tail = tf.greater(raw_pred[:, :1], self.num_classes - 1)
            char_rep = tf.concat([char_rep, tail], axis=1)

            # Character positions after collapsing repeats; a repeated
            # character keeps the position of its last occurrence
            char_no_rep = tf.math.logical_and(is_char,
                                              tf.math.logical_not(char_rep))

            # Get character positions and their labels; images whose
            # predicted character count differs from the ground truth are
            # skipped
            self.char_pos, self.char_label = self.get_char_pos_and_label(
                preds=char_no_rep, label=label, char_num=char_num, poses=poses)
            # Gather each character's embedding by its position
            self.embedding = self.get_features(self.char_pos, embedding)


if __name__ == '__main__':
    from libs.label_converter import LabelConverter

    cfg = load_config('raw')
    converter = LabelConverter(chars_file='./data/chars/lexicon.txt')
    model = CRNN(cfg, num_classes=converter.num_classes)
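
The shifted-comparison trick in the 'pos' scope can be checked outside the graph. A minimal NumPy sketch (illustrative values; the blank is assumed to be index num_classes - 1, as in the code above):

import numpy as np

num_classes = 10                                      # blank index = 9
raw_pred = np.array([[3, 3, 7, 7, 7, 9, 5, 5]])       # one decoded sequence

is_char = raw_pred < num_classes - 1                  # not the blank
char_rep = raw_pred[:, :-1] == raw_pred[:, 1:]        # repeated at next step
tail = np.zeros((raw_pred.shape[0], 1), dtype=bool)   # pad to full length
char_rep = np.concatenate([char_rep, tail], axis=1)
char_no_rep = is_char & ~char_rep                     # last step of each run

print(np.where(char_no_rep[0])[0])  # [1 4 7]
print(raw_pred[0][char_no_rep[0]])  # [3 7 5] -> collapsed sequence "375"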
Example #7
class Infer(object):
    def __init__(self, model_path):
        os.environ['CUDA_VISIBLE_DEVICES'] = '2'
        self.cfg = load_config('resnet')
        self.label_converter = LabelConverter(lexicon)
        self.cfg.lr_boundaries = [10000]
        self.cfg.lr_values = [
            self.cfg.lr * (self.cfg.lr_decay_rate**i)
            for i in range(len(self.cfg.lr_boundaries) + 1)
        ]
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.5
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph, config=config)

        with self.session.as_default():
            with self.graph.as_default():
                self.net = CRNN(self.cfg,
                                num_classes=self.label_converter.num_classes)
                saver = tf.train.Saver()
                saver.restore(self.session, model_path)

        logging.info('CRNN model initialized.')

    def normalize_image(self, img):
        """
        将图像归一化到高为32
        :param img:
        :return:
        """
        img = img.convert('L')
        w, h = img.size
        ratio = h / 32.0
        w0 = int(round(w / ratio))
        img = img.resize((max(w0, 1), 32), Image.BICUBIC)
        return img

    def predict(self, image, long_info=True):
        """
        单张预测
        :param image:
        :return:
        """
        start_time = time.time()
        image = image_to_pil(image)
        image_width = image.width
        image = self.normalize_image(image)
        if image.width <= 4:
            return ''

        image = np.reshape(image, (1, 32, image.width, 1))
        image = (image.astype(np.float32) - 128.0) / 128.0

        feed = {
            self.net.feeds()['inputs']: image,
            self.net.feeds()['is_training']: False
        }
        predict_label, predict_prob, logits = self.session.run(
            self.net.fetches(), feed_dict=feed)
        p, weights, positions = ctc_label(
            predict_label[0],
            predict_prob[0],
            image_width,
            blank_index=self.label_converter.num_classes - 1)
        txt = self.label_converter.decode(p, invalid_index=-1)
        ret = dict()
        ret['text'] = txt
        ret['weights'] = [float(weight) for weight in weights[:len(txt)]]
        ret['positions'] = [
            float(position) for position in positions[:len(txt)]
        ]
        ret['direction'] = 0
        print('predict time is %.4f ms' % ((time.time() - start_time) * 1000))
        if long_info:
            return json.dumps(ret, ensure_ascii=False)
        else:
            return ret['text']

    def normalize_batch(self, image_batch):
        """
        将一个batch内的图像归一化到相同的尺寸
        :param image_batch:
        :return:
        """
        input_batch_size = len(image_batch)
        normalized_batch = []
        chars_count = []
        image_width_list = [int(img.width) for img in image_batch]
        batch_image_width = max(image_width_list)
        max_width_image_idx = np.argmax(image_width_list)
        for i in range(input_batch_size):
            # Paste each image onto a copy of the widest image in the batch
            # so that every sample ends up with the same canvas width
            base_image = copy.deepcopy(image_batch[max_width_image_idx])
            base_image.paste(image_batch[i], (0, 0))
            base_image.paste(PAD_IMAGE, (image_batch[i].width, 0))
            normalized_image = np.reshape(base_image,
                                          (32, batch_image_width, 1))
            normalized_image = normalized_image.astype(
                np.float32) / 128.0 - 1.0
            normalized_batch.append(normalized_image)
            # Per-image sequence length, assuming the CNN downsamples the
            # width by a factor of 4
            chars_count.append(image_batch[i].width / 4)

        return normalized_batch, chars_count

    def predict_batch(self, batch_images, long_info=True):
        """
        batch预测
        :param batch_images:
        :return:
        """
        start_time = time.time()
        batch_texts = []
        batch_images_idx = []
        invalid_images_idx = []
        image_widths = []
        image_heights = []
        for i, image in enumerate(batch_images):
            image = image_to_pil(image)
            image_widths.append(image.width)
            image_heights.append(image.height)
            image = self.normalize_image(image)
            if image.width <= 4:
                invalid_images_idx.append(i)
                batch_images_idx.append(i)
                batch_images[i] = Image.new('L', (32, 32), color=255)
                image_widths[i] = 32
                image_heights[i] = 32
                continue
            batch_images[i] = image
            batch_images_idx.append(i)

        images_with_idx = zip(batch_images, image_widths, image_heights,
                              batch_images_idx)
        batch_images, image_widths, image_heights, batch_images_idx = zip(
            *sorted(images_with_idx, key=lambda x: x[0].width))
        rets = []

        number_images = len(batch_images)
        number_batches = number_images // BATCH_SIZE
        number_remained = number_images % BATCH_SIZE
        if number_remained == 0:
            iters = number_batches
        else:
            iters = number_batches + 1

        for step in range(iters):
            offset = step * BATCH_SIZE
            batch_array = batch_images[offset:min(offset +
                                                  BATCH_SIZE, number_images)]
            batch_array, chars_count = self.normalize_batch(batch_array)

            feed = {
                self.net.feeds()['inputs']: batch_array,
                self.net.feeds()['is_training']: False
            }
            predict_label, predict_prob, logits, cnn_out = self.session.run(
                self.net.fetches(), feed_dict=feed)

            if number_remained > 0 and step == number_batches:
                predict_label = predict_label[:number_remained]
                predict_prob = predict_prob[:number_remained]
                chars_count = chars_count[:number_remained]
            for i in range(len(predict_label)):
                width = image_widths[step * BATCH_SIZE + i]
                height = image_heights[step * BATCH_SIZE + i]
                count = int(chars_count[i])
                label = predict_label[i][:count]
                prob = predict_prob[i][:count]
                p, weights, positions = ctc_label(
                    label,
                    prob,
                    width,
                    blank_index=self.label_converter.num_classes - 1)
                txt = self.label_converter.decode(p, invalid_index=-1)
                ret = dict()
                ret['label'] = label
                ret['text'] = txt
                ret['weights'] = [
                    float(weight) for weight in weights[:len(txt)]
                ]
                ret['positions'] = [
                    float(position) for position in positions[:len(txt)]
                ]
                ret['direction'] = 0

                if ret['text'] != '':
                    ret['score'] = float(np.mean(ret['weights']))
                else:
                    ret['score'] = 0

                rets.append(ret)

        for i in range(len(batch_images)):
            ret = rets[batch_images_idx.index(i)]
            batch_texts.append(ret)
        for i in invalid_images_idx:
            ret = dict()
            ret['label'] = [0]
            ret['weights'] = [0]
            ret['positions'] = [0]
            ret['direction'] = 0
            ret['score'] = 0
            ret['text'] = ''
            batch_texts[i] = ret
        print('predict_batch time is %.4f ms' %
              ((time.time() - start_time) * 1000))
        if long_info:
            return [
                json.dumps(text, ensure_ascii=False) for text in batch_texts
            ]
        else:
            return [ret['text'] for ret in batch_texts]
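
A hedged usage sketch of the Infer class; the checkpoint path and the test image are placeholders, not part of the original code:

if __name__ == '__main__':
    infer = Infer('/path/to/crnn_checkpoint')  # hypothetical checkpoint path
    img = Image.open('demo.png')               # hypothetical test image
    print(infer.predict(img))                  # JSON with text/weights/positions
    print(infer.predict_batch([img], long_info=False))  # plain texts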
Example #8
        if self.img_channels == 3:
            img_decoded = tf.image.rgb_to_grayscale(img_decoded)

        img_decoded = tf.cast(img_decoded, tf.float32)
        img_decoded = (img_decoded - 128.0) / 128.0  # map [0, 255] to [-1, 1)

        return img_decoded, label, img_path


if __name__ == '__main__':
    from libs.label_converter import LabelConverter

    demo_path = '/home/cwq/ssd_data/more_bg_corpus/val'
    chars_file = '/home/cwq/code/tf_crnn/data/chars/chn.txt'
    epochs = 5
    batch_size = 128

    converter = LabelConverter(chars_file=chars_file)
    ds = ImgDataset(demo_path, converter, batch_size=batch_size)

    num_batches = int(np.floor(ds.size / batch_size))

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        for epoch in range(epochs):
            sess.run(ds.init_op)
            print('------------Epoch(%d)------------' % epoch)
            for batch in range(num_batches):
                _, _, labels, _ = ds.get_next_batch(sess)
                print(labels[0])
Example #9
class Trainer(object):
    def __init__(self, args):
        self.args = args
        self.cfg = load_config(args.cfg_name)

        self.converter = LabelConverter(chars_file=args.chars_file)

        self.tr_ds = ImgDataset(args.train_dir, self.converter,
                                self.cfg.batch_size)

        self.cfg.lr_boundaries = [
            self.tr_ds.num_batches * epoch
            for epoch in self.cfg.lr_decay_epochs
        ]
        self.cfg.lr_values = [
            self.cfg.lr * (self.cfg.lr_decay_rate**i)
            for i in range(len(self.cfg.lr_boundaries) + 1)
        ]

        if args.val_dir is None:
            self.val_ds = None
        else:
            self.val_ds = ImgDataset(args.val_dir,
                                     self.converter,
                                     self.cfg.batch_size,
                                     shuffle=False)

        if args.test_dir is None:
            self.test_ds = None
        else:
            # Test images often have different size, so set batch_size to 1
            self.test_ds = ImgDataset(args.test_dir,
                                      self.converter,
                                      shuffle=False,
                                      batch_size=1)

        self.model = CRNN(self.cfg, num_classes=self.converter.num_classes)
        self.sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))

        self.epoch_start_index = 0
        self.batch_start_index = 0

    def train(self):
        self.sess.run(tf.global_variables_initializer())

        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=8)
        self.train_writer = tf.summary.FileWriter(self.args.log_dir,
                                                  self.sess.graph)

        if self.args.restore:
            self._restore()

        print('Begin training...')
        for epoch in range(self.epoch_start_index, self.cfg.epochs):
            self.sess.run(self.tr_ds.init_op)

            for batch in range(self.batch_start_index, self.tr_ds.num_batches):
                batch_start_time = time.time()

                if batch != 0 and (batch % self.args.log_step == 0):
                    batch_cost, global_step, lr = self._train_with_summary()
                else:
                    batch_cost, global_step, lr = self._train()

                print(
                    "epoch: {}, batch: {}/{}, step: {}, time: {:.02f}s, loss: {:.05}, lr: {:.05}"
                    .format(epoch, batch, self.tr_ds.num_batches, global_step,
                            time.time() - batch_start_time, batch_cost, lr))

                if global_step != 0 and (global_step % self.args.val_step
                                         == 0):
                    val_acc = self._do_val(self.val_ds, epoch, global_step,
                                           "val")
                    test_acc = self._do_val(self.test_ds, epoch, global_step,
                                            "test")
                    self._save_checkpoint(self.args.ckpt_dir, global_step,
                                          val_acc, test_acc)

            self.batch_start_index = 0

    def _restore(self):
        utils.restore_ckpt(self.sess, self.saver, self.args.ckpt_dir)

        step_restored = self.sess.run(self.model.global_step)

        self.epoch_start_index = math.floor(step_restored /
                                            self.tr_ds.num_batches)
        self.batch_start_index = step_restored % self.tr_ds.num_batches

        print("Restored global step: %d" % step_restored)
        print("Restored epoch: %d" % self.epoch_start_index)
        print("Restored batch_start_index: %d" % self.batch_start_index)

    def _train(self):
        img_batch, label_batch, labels, _ = self.tr_ds.get_next_batch(
            self.sess)
        feed = {
            self.model.inputs: img_batch,
            self.model.labels: label_batch,
            self.model.is_training: True
        }

        fetches = [
            self.model.total_loss, self.model.ctc_loss,
            self.model.regularization_loss, self.model.global_step,
            self.model.lr, self.model.train_op
        ]

        batch_cost, _, _, global_step, lr, _ = self.sess.run(fetches, feed)
        return batch_cost, global_step, lr

    def _train_with_summary(self):
        img_batch, label_batch, labels, _ = self.tr_ds.get_next_batch(
            self.sess)
        feed = {
            self.model.inputs: img_batch,
            self.model.labels: label_batch,
            self.model.is_training: True
        }

        fetches = [
            self.model.total_loss, self.model.ctc_loss,
            self.model.regularization_loss, self.model.global_step,
            self.model.lr, self.model.merged_summay, self.model.dense_decoded,
            self.model.edit_distance, self.model.train_op
        ]

        batch_cost, _, _, global_step, lr, summary, predicts, edit_distance, _ = self.sess.run(
            fetches, feed)
        self.train_writer.add_summary(summary, global_step)

        predicts = [
            self.converter.decode(p, CRNN.CTC_INVALID_INDEX) for p in predicts
        ]
        accuracy, _ = infer.calculate_accuracy(predicts, labels)

        tf_utils.add_scalar_summary(self.train_writer, "train_accuracy",
                                    accuracy, global_step)
        tf_utils.add_scalar_summary(self.train_writer, "train_edit_distance",
                                    edit_distance, global_step)

        return batch_cost, global_step, lr

    def _do_val(self, dataset, epoch, step, name):
        if dataset is None:
            return None

        accuracy, edit_distance = infer.validation(self.sess,
                                                   self.model.feeds(),
                                                   self.model.fetches(),
                                                   dataset, self.converter,
                                                   self.args.result_dir, name,
                                                   step)

        tf_utils.add_scalar_summary(self.train_writer, "%s_accuracy" % name,
                                    accuracy, step)
        tf_utils.add_scalar_summary(self.train_writer,
                                    "%s_edit_distance" % name, edit_distance,
                                    step)

        print("epoch: %d/%d, %s accuracy = %.3f" %
              (epoch, self.cfg.epochs, name, accuracy))
        return accuracy

    def _save_checkpoint(self, ckpt_dir, step, val_acc=None, test_acc=None):
        ckpt_name = "crnn_%d" % step
        if val_acc is not None:
            ckpt_name += '_val_%.03f' % val_acc
        if test_acc is not None:
            ckpt_name += '_test_%.03f' % test_acc

        name = os.path.join(ckpt_dir, ckpt_name)
        print("save checkpoint %s" % name)

        meta_exists, meta_file_name = self._meta_file_exist(ckpt_dir)

        self.saver.save(self.sess, name)

        # remove old meta file to save disk space
        if meta_exists:
            try:
                os.remove(os.path.join(ckpt_dir, meta_file_name))
            except OSError:
                print('Remove meta file failed: %s' % meta_file_name)

    def _meta_file_exist(self, ckpt_dir):
        fnames = os.listdir(ckpt_dir)
        meta_exists = False
        meta_file_name = ''
        for n in fnames:
            if 'meta' in n:
                meta_exists = True
                meta_file_name = n
                break

        return meta_exists, meta_file_name
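
A hedged sketch of wiring the Trainer up; the argparse flags simply mirror the attributes the class reads, and the defaults are assumptions borrowed from the other examples:

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg_name', default='raw')
    parser.add_argument('--chars_file', default='./data/chars/lexicon.txt')
    parser.add_argument('--train_dir', required=True)
    parser.add_argument('--val_dir', default=None)
    parser.add_argument('--test_dir', default=None)
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--ckpt_dir', default='./ckpt')
    parser.add_argument('--result_dir', default='./results')
    parser.add_argument('--log_step', type=int, default=100)
    parser.add_argument('--val_step', type=int, default=1000)
    parser.add_argument('--restore', action='store_true')

    trainer = Trainer(parser.parse_args())
    trainer.train()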