Ejemplo n.º 1
0
class Infer(object):
    """CRNN text-line recognizer wrapping a restored TensorFlow graph.

    Relies on names defined outside this class: ``lexicon``, ``load_config``,
    ``LabelConverter``, ``CRNN``, ``ctc_label``, ``image_to_pil``,
    ``BATCH_SIZE``, ``PAD_IMAGE``, ``Image`` (PIL) and ``tf``.
    """

    def __init__(self, model_path, visible_devices='2'):
        """Build the CRNN graph in a dedicated session and restore weights.

        :param model_path: checkpoint path handed to ``tf.train.Saver.restore``.
        :param visible_devices: value assigned to ``CUDA_VISIBLE_DEVICES``
            before the session is created (default ``'2'``); pass ``None`` to
            leave the environment variable untouched.
        """
        if visible_devices is not None:
            os.environ['CUDA_VISIBLE_DEVICES'] = str(visible_devices)
        self.cfg = load_config('resnet')
        self.label_converter = LabelConverter(lexicon)
        # A learning-rate schedule is filled in even for inference; presumably
        # the CRNN constructor reads these fields — TODO confirm.
        self.cfg.lr_boundaries = [10000]
        self.cfg.lr_values = [
            self.cfg.lr * (self.cfg.lr_decay_rate**i)
            for i in range(len(self.cfg.lr_boundaries) + 1)
        ]
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.5
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph, config=config)

        with self.session.as_default():
            with self.graph.as_default():
                self.net = CRNN(self.cfg,
                                num_classes=self.label_converter.num_classes)
                saver = tf.train.Saver()
                saver.restore(self.session, model_path)

        logging.info('CRNN model initialized.')

    def normalize_image(self, img):
        """Convert an image to grayscale and resize it to a height of 32.

        The width is scaled by the same factor (rounded, at least 1 pixel),
        so the aspect ratio is preserved.

        :param img: PIL image.
        :return: resized PIL image in mode ``'L'``.
        """
        img = img.convert('L')
        w, h = img.size
        ratio = h / 32.0
        new_w = int(round(w / ratio))
        return img.resize((max(new_w, 1), 32), Image.BICUBIC)

    def _ctc_decode(self, label, prob, width):
        """Turn raw CTC output for one image into text, weights and positions.

        :param label: per-step predicted labels for one image.
        :param prob: per-step probabilities for one image.
        :param width: image width passed through to ``ctc_label``.
        :return: ``(text, weights, positions)``; the two lists are truncated
            to ``len(text)`` and converted to plain Python floats.
        """
        p, weights, positions = ctc_label(
            label,
            prob,
            width,
            blank_index=self.label_converter.num_classes - 1)
        txt = self.label_converter.decode(p, invalid_index=-1)
        n_chars = len(txt)
        return (txt, [float(weight) for weight in weights[:n_chars]],
                [float(position) for position in positions[:n_chars]])

    def predict(self, image, long_info=True):
        """Recognize the text in a single image.

        :param image: any input accepted by ``image_to_pil``.
        :param long_info: if True, return a JSON string with keys ``text``,
            ``weights``, ``positions`` and ``direction``; otherwise return
            only the recognized text.
        :return: JSON string or plain text (see ``long_info``).
        """
        start_time = time.time()
        image = image_to_pil(image)
        # ctc_label receives the width *before* resizing.
        image_width = image.width
        image = self.normalize_image(image)
        if image.width <= 4:
            # Too narrow after resizing to be recognized: return an empty
            # result in the same format as a successful prediction.
            if long_info:
                empty = {
                    'text': '',
                    'weights': [],
                    'positions': [],
                    'direction': 0
                }
                return json.dumps(empty, ensure_ascii=False)
            return ''

        pixels = np.asarray(image).reshape((1, 32, image.width, 1))
        pixels = (pixels.astype(np.float32) - 128.0) / 128.0

        feeds = self.net.feeds()
        feed = {feeds['inputs']: pixels, feeds['is_training']: False}
        # Only the first two fetched values (labels, probabilities) are used;
        # indexing rather than unpacking tolerates any extra fetches.
        # NOTE(review): assumes fetches() returns a list/tuple — confirm.
        outputs = self.session.run(self.net.fetches(), feed_dict=feed)
        predict_label, predict_prob = outputs[0], outputs[1]

        txt, weights, positions = self._ctc_decode(predict_label[0],
                                                   predict_prob[0],
                                                   image_width)
        ret = {
            'text': txt,
            'weights': weights,
            'positions': positions,
            'direction': 0
        }
        logging.info('predict time is %.4f ms',
                     (time.time() - start_time) * 1000)
        if long_info:
            return json.dumps(ret, ensure_ascii=False)
        return txt

    def normalize_batch(self, image_batch):
        """Pad every image in a batch to the width of the batch's widest image.

        Each image is pasted at the left of a copy of the widest image, the
        area to its right is covered with ``PAD_IMAGE``, and pixel values are
        mapped from ``0..255`` to ``[-1, 1)``.

        :param image_batch: sequence of height-32 grayscale PIL images.
        :return: ``(normalized_batch, chars_count)``. ``normalized_batch`` is a
            list of float32 arrays shaped ``(32, max_width, 1)``.
            ``chars_count`` holds each image's own (unpadded) width divided
            by 4.
        """
        if not image_batch:
            return [], []
        widths = [int(img.width) for img in image_batch]
        batch_image_width = max(widths)
        widest = image_batch[int(np.argmax(widths))]

        normalized_batch = []
        chars_count = []
        for img in image_batch:
            # Copying the widest image yields a canvas of the right size and
            # mode. NOTE(review): if PAD_IMAGE is narrower than the padding
            # region, pixels of the widest image stay visible — confirm.
            canvas = widest.copy()
            canvas.paste(img, (0, 0))
            canvas.paste(PAD_IMAGE, (img.width, 0))
            pixels = np.asarray(canvas).reshape((32, batch_image_width, 1))
            normalized_batch.append(pixels.astype(np.float32) / 128.0 - 1.0)
            # Presumably the network emits one time step per 4 pixels of
            # width — TODO confirm against the CRNN definition.
            chars_count.append(img.width / 4)
        return normalized_batch, chars_count

    def predict_batch(self, batch_images, long_info=True):
        """Recognize the text in many images, running them in batches.

        Images are sorted by resized width, so each group of ``BATCH_SIZE``
        holds images of similar width. Results are returned in input order.
        Images 4 pixels wide or less after resizing go through inference as a
        blank 32x32 placeholder. Their reported result is empty.

        :param batch_images: sequence of inputs accepted by ``image_to_pil``;
            the sequence itself is not modified.
        :param long_info: if True, return JSON strings with keys ``label``,
            ``text``, ``weights``, ``positions``, ``direction`` and ``score``;
            otherwise return only the recognized texts.
        :return: list with one entry per input image, in input order.
        """
        start_time = time.time()
        images = []
        image_widths = []
        invalid_images_idx = []
        for i, raw_image in enumerate(batch_images):
            image = image_to_pil(raw_image)
            # ctc_label receives the width *before* resizing.
            width = image.width
            image = self.normalize_image(image)
            if image.width <= 4:
                invalid_images_idx.append(i)
                image = Image.new('L', (32, 32), color=255)
                width = 32
            images.append(image)
            image_widths.append(width)

        if not images:
            return []

        # Stable sort by resized width; order[pos] is the input index of the
        # image at sorted position pos.
        order = sorted(range(len(images)), key=lambda k: images[k].width)
        sorted_images = [images[k] for k in order]
        number_images = len(sorted_images)
        batch_texts = [None] * number_images
        feeds = self.net.feeds()

        for offset in range(0, number_images, BATCH_SIZE):
            chunk = sorted_images[offset:offset + BATCH_SIZE]
            batch_array, chars_count = self.normalize_batch(chunk)
            feed = {
                feeds['inputs']: batch_array,
                feeds['is_training']: False
            }
            # Only labels and probabilities are needed; see predict().
            outputs = self.session.run(self.net.fetches(), feed_dict=feed)
            predict_label, predict_prob = outputs[0], outputs[1]

            for j in range(len(chunk)):
                src = order[offset + j]
                count = int(chars_count[j])
                label = predict_label[j][:count]
                prob = predict_prob[j][:count]
                txt, weights, positions = self._ctc_decode(
                    label, prob, image_widths[src])
                batch_texts[src] = {
                    # Plain Python list so that json.dumps can serialize it.
                    'label': np.asarray(label).tolist(),
                    'text': txt,
                    'weights': weights,
                    'positions': positions,
                    'direction': 0,
                    'score': float(np.mean(weights)) if txt != '' else 0,
                }

        for i in invalid_images_idx:
            batch_texts[i] = {
                'label': [0],
                'weights': [0],
                'positions': [0],
                'direction': 0,
                'score': 0,
                'text': '',
            }
        logging.info('predict_batch time is %.4f ms',
                     (time.time() - start_time) * 1000)
        if long_info:
            return [
                json.dumps(text, ensure_ascii=False) for text in batch_texts
            ]
        return [ret['text'] for ret in batch_texts]
Ejemplo n.º 2
0
class Trainer(object):
    """Train a CRNN model with periodic validation, testing and checkpointing.

    ``args`` must provide the attributes read here: ``cfg_name``,
    ``chars_file``, ``train_dir``, ``val_dir``, ``test_dir``, ``log_dir``,
    ``ckpt_dir``, ``result_dir``, ``restore``, ``log_step`` and ``val_step``.
    """

    def __init__(self, args):
        """Load the config, datasets, model and TensorFlow session.

        :param args: parsed command-line arguments (see class docstring).
        """
        self.args = args
        self.cfg = load_config(args.cfg_name)

        self.converter = LabelConverter(chars_file=args.chars_file)

        self.tr_ds = ImgDataset(args.train_dir, self.converter,
                                self.cfg.batch_size)

        # Boundaries are expressed in batches (global steps): one per epoch in
        # lr_decay_epochs. Each subsequent value is lr_decay_rate times the
        # previous one. Presumably consumed as a piecewise-constant schedule
        # by CRNN — confirm.
        self.cfg.lr_boundaries = [
            self.tr_ds.num_batches * epoch
            for epoch in self.cfg.lr_decay_epochs
        ]
        self.cfg.lr_values = [
            self.cfg.lr * (self.cfg.lr_decay_rate**i)
            for i in range(len(self.cfg.lr_boundaries) + 1)
        ]

        if args.val_dir is None:
            self.val_ds = None
        else:
            self.val_ds = ImgDataset(args.val_dir,
                                     self.converter,
                                     self.cfg.batch_size,
                                     shuffle=False)

        if args.test_dir is None:
            self.test_ds = None
        else:
            # Test images often have different size, so set batch_size to 1
            self.test_ds = ImgDataset(args.test_dir,
                                      self.converter,
                                      shuffle=False,
                                      batch_size=1)

        self.model = CRNN(self.cfg, num_classes=self.converter.num_classes)
        self.sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))

        self.epoch_start_index = 0
        self.batch_start_index = 0

    def train(self):
        """Run the training loop, optionally resuming from a checkpoint.

        Summaries go to ``args.log_dir``. Every ``args.val_step`` global
        steps, the model is evaluated on the val/test sets (when configured)
        and a checkpoint is written to ``args.ckpt_dir``.
        """
        self.sess.run(tf.global_variables_initializer())

        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=8)
        self.train_writer = tf.summary.FileWriter(self.args.log_dir,
                                                  self.sess.graph)
        try:
            if self.args.restore:
                self._restore()

            print('Begin training...')
            for epoch in range(self.epoch_start_index, self.cfg.epochs):
                self.sess.run(self.tr_ds.init_op)

                for batch in range(self.batch_start_index,
                                   self.tr_ds.num_batches):
                    batch_start_time = time.time()

                    # Every log_step batches, also record summaries and
                    # training accuracy.
                    if batch != 0 and batch % self.args.log_step == 0:
                        batch_cost, global_step, lr = \
                            self._train_with_summary()
                    else:
                        batch_cost, global_step, lr = self._train()

                    print(
                        "epoch: {}, batch: {}/{}, step: {}, time: {:.02f}s, loss: {:.05}, lr: {:.05}"
                        .format(epoch, batch, self.tr_ds.num_batches,
                                global_step,
                                time.time() - batch_start_time, batch_cost,
                                lr))

                    if global_step != 0 and global_step % self.args.val_step == 0:
                        val_acc = self._do_val(self.val_ds, epoch,
                                               global_step, "val")
                        test_acc = self._do_val(self.test_ds, epoch,
                                                global_step, "test")
                        self._save_checkpoint(self.args.ckpt_dir, global_step,
                                              val_acc, test_acc)

                # Only the first (possibly resumed) epoch starts mid-way.
                self.batch_start_index = 0
        finally:
            # Flush buffered summaries and release the event file.
            self.train_writer.close()

    def _restore(self):
        """Restore weights via ``utils.restore_ckpt`` and compute the resume point.

        Sets ``epoch_start_index`` / ``batch_start_index`` from the restored
        global step.
        """
        utils.restore_ckpt(self.sess, self.saver, self.args.ckpt_dir)

        step_restored = int(self.sess.run(self.model.global_step))

        # Integer divmod: no float rounding for large step counts.
        self.epoch_start_index, self.batch_start_index = divmod(
            step_restored, self.tr_ds.num_batches)

        print("Restored global step: %d" % step_restored)
        print("Restored epoch: %d" % self.epoch_start_index)
        print("Restored batch_start_index: %d" % self.batch_start_index)

    def _train(self):
        """Run one optimization step on the next training batch.

        :return: ``(total_loss, global_step, lr)`` fetched in the same run.
        """
        img_batch, label_batch, labels, _ = self.tr_ds.get_next_batch(
            self.sess)
        feed = {
            self.model.inputs: img_batch,
            self.model.labels: label_batch,
            self.model.is_training: True
        }

        fetches = [
            self.model.total_loss, self.model.ctc_loss,
            self.model.regularization_loss, self.model.global_step,
            self.model.lr, self.model.train_op
        ]

        batch_cost, _, _, global_step, lr, _ = self.sess.run(fetches, feed)
        return batch_cost, global_step, lr

    def _train_with_summary(self):
        """Run one optimization step and log summaries and training accuracy.

        Writes the merged summary, ``train_accuracy`` and
        ``train_edit_distance`` to ``self.train_writer``.

        :return: ``(total_loss, global_step, lr)`` fetched in the same run.
        """
        img_batch, label_batch, labels, _ = self.tr_ds.get_next_batch(
            self.sess)
        feed = {
            self.model.inputs: img_batch,
            self.model.labels: label_batch,
            self.model.is_training: True
        }

        fetches = [
            self.model.total_loss, self.model.ctc_loss,
            self.model.regularization_loss, self.model.global_step,
            self.model.lr, self.model.merged_summay, self.model.dense_decoded,
            self.model.edit_distance, self.model.train_op
        ]

        (batch_cost, _, _, global_step, lr, summary, predicts, edit_distance,
         _) = self.sess.run(fetches, feed)
        self.train_writer.add_summary(summary, global_step)

        predicts = [
            self.converter.decode(p, CRNN.CTC_INVALID_INDEX) for p in predicts
        ]
        accuracy, _ = infer.calculate_accuracy(predicts, labels)

        tf_utils.add_scalar_summary(self.train_writer, "train_accuracy",
                                    accuracy, global_step)
        tf_utils.add_scalar_summary(self.train_writer, "train_edit_distance",
                                    edit_distance, global_step)

        return batch_cost, global_step, lr

    def _do_val(self, dataset, epoch, step, name):
        """Evaluate on ``dataset`` and log ``<name>_accuracy`` / ``<name>_edit_distance``.

        :return: the accuracy, or ``None`` when ``dataset`` is ``None``.
        """
        if dataset is None:
            return None

        accuracy, edit_distance = infer.validation(self.sess,
                                                   self.model.feeds(),
                                                   self.model.fetches(),
                                                   dataset, self.converter,
                                                   self.args.result_dir, name,
                                                   step)

        tf_utils.add_scalar_summary(self.train_writer, "%s_accuracy" % name,
                                    accuracy, step)
        tf_utils.add_scalar_summary(self.train_writer,
                                    "%s_edit_distance" % name, edit_distance,
                                    step)

        print("epoch: %d/%d, %s accuracy = %.3f" %
              (epoch, self.cfg.epochs, name, accuracy))
        return accuracy

    def _save_checkpoint(self, ckpt_dir, step, val_acc=None, test_acc=None):
        """Save a checkpoint named after the step and the given accuracies.

        The ``.meta`` file present before saving is deleted afterwards, so only
        the newest graph definition is kept on disk.
        """
        ckpt_name = "crnn_%d" % step
        if val_acc is not None:
            ckpt_name += '_val_%.03f' % val_acc
        if test_acc is not None:
            ckpt_name += '_test_%.03f' % test_acc

        # Saver.save refuses to write when the parent directory is missing.
        os.makedirs(ckpt_dir, exist_ok=True)
        name = os.path.join(ckpt_dir, ckpt_name)
        print("save checkpoint %s" % name)

        # Look up the previous meta file *before* saving, so the new one is
        # not the one found.
        meta_exists, meta_file_name = self._meta_file_exist(ckpt_dir)

        self.saver.save(self.sess, name)

        # Remove the old meta file to save disk space, unless this save just
        # overwrote that very file.
        if meta_exists and meta_file_name != ckpt_name + '.meta':
            try:
                os.remove(os.path.join(ckpt_dir, meta_file_name))
            except OSError:
                print('Remove meta file failed: %s' % meta_file_name)

    def _meta_file_exist(self, ckpt_dir):
        """Find a ``*.meta`` file in ``ckpt_dir``.

        :return: ``(True, file_name)`` for the first match in ``os.listdir``
            order, else ``(False, '')``.
        """
        for fname in os.listdir(ckpt_dir):
            # Match the suffix only; a substring test would also match
            # unrelated files such as "metadata.txt".
            if fname.endswith('.meta'):
                return True, fname
        return False, ''