class Infer(object):
    """CRNN text-line recognizer backed by a restored TensorFlow session.

    Input images are converted to grayscale, resized to a height of 32 px and
    scaled with ``x / 128 - 1`` before being fed to the network.
    """

    def __init__(self, model_path, visible_devices='2'):
        """Build the CRNN graph and restore its weights.

        :param model_path: checkpoint path passed to ``tf.train.Saver.restore``.
        :param visible_devices: value written to ``CUDA_VISIBLE_DEVICES`` before
            the session is created. Defaults to ``'2'`` (the previously
            hard-coded value); pass ``None`` to leave the environment untouched.
        """
        if visible_devices is not None:
            os.environ['CUDA_VISIBLE_DEVICES'] = visible_devices
        self.cfg = load_config('resnet')
        self.label_converter = LabelConverter(lexicon)

        # Learning-rate schedule fields are filled in even though no training
        # happens here -- presumably the CRNN constructor reads them (TODO confirm).
        self.cfg.lr_boundaries = [10000]
        self.cfg.lr_values = [
            self.cfg.lr * (self.cfg.lr_decay_rate**i)
            for i in range(len(self.cfg.lr_boundaries) + 1)
        ]

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.5

        # A dedicated graph keeps this model's ops out of the default graph.
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph, config=config)
        with self.session.as_default():
            with self.graph.as_default():
                self.net = CRNN(self.cfg,
                                num_classes=self.label_converter.num_classes)
                saver = tf.train.Saver()
                saver.restore(self.session, model_path)
        logging.info('CRNN model initialized.')

    def normalize_image(self, img):
        """Convert ``img`` to grayscale and resize it to a height of 32 px.

        The aspect ratio is preserved; the new width is rounded and clamped to
        at least 1 px.

        :param img: PIL image.
        :return: grayscale (mode ``'L'``) PIL image of height 32.
        """
        img = img.convert('L')
        w, h = img.size
        ratio = h / 32.0
        w0 = int(round(w / ratio))
        img = img.resize((max(w0, 1), 32), Image.BICUBIC)
        return img

    def _run_network(self, inputs):
        """Run the network on ``inputs`` and return ``(labels, probs)``.

        Only the first two fetched tensors are used. Indexing (rather than
        unpacking a fixed number of values) keeps this working regardless of
        how many extra tensors ``self.net.fetches()`` returns.
        """
        feeds = self.net.feeds()
        feed = {feeds['inputs']: inputs, feeds['is_training']: False}
        outputs = self.session.run(self.net.fetches(), feed_dict=feed)
        return outputs[0], outputs[1]

    def _ctc_decode(self, label, prob, width):
        """CTC-decode one sample.

        The blank symbol is the last class (``num_classes - 1``).

        :return: ``(text, weights, positions)`` where ``weights`` and
            ``positions`` are lists of Python floats truncated to ``len(text)``.
        """
        p, weights, positions = ctc_label(
            label,
            prob,
            width,
            blank_index=self.label_converter.num_classes - 1)
        txt = self.label_converter.decode(p, invalid_index=-1)
        return (txt,
                [float(weight) for weight in weights[:len(txt)]],
                [float(position) for position in positions[:len(txt)]])

    def predict(self, image, long_info=True):
        """Recognize the text in a single image.

        :param image: anything accepted by ``image_to_pil``.
        :param long_info: if True return a JSON string with keys ``text``,
            ``weights``, ``positions`` and ``direction``; otherwise return only
            the recognized text.
        :return: JSON string or plain text (see ``long_info``). Images whose
            normalized width is <= 4 px yield ``''`` regardless of
            ``long_info``.
        """
        start_time = time.time()
        image = image_to_pil(image)
        image_width = image.width  # original width, passed to ctc_label
        image = self.normalize_image(image)
        if image.width <= 4:
            # NOTE(review): returns '' even when long_info=True, whereas
            # predict_batch returns a full result dict for such images; kept
            # as-is for backward compatibility.
            return ''

        inputs = np.asarray(image).reshape((1, 32, image.width, 1))
        inputs = (inputs.astype(np.float32) - 128.0) / 128.0

        predict_label, predict_prob = self._run_network(inputs)
        txt, weights, positions = self._ctc_decode(
            predict_label[0], predict_prob[0], image_width)

        ret = dict()
        ret['text'] = txt
        ret['weights'] = weights
        ret['positions'] = positions
        ret['direction'] = 0
        logging.info('predict time is %.4f ms',
                     (time.time() - start_time) * 1000)
        if long_info:
            return json.dumps(ret, ensure_ascii=False)
        return ret['text']

    def normalize_batch(self, image_batch):
        """Pad every image in a batch to the width of the widest one.

        Each image is pasted at the top-left of a copy of the widest image, and
        ``PAD_IMAGE`` is pasted immediately to its right to cover the rest.

        :param image_batch: non-empty sequence of height-32 grayscale PIL
            images.
        :return: ``(normalized_batch, chars_count)``. ``normalized_batch`` is a
            list of float32 arrays of shape ``(32, max_width, 1)`` scaled with
            ``x / 128 - 1``. ``chars_count`` holds ``width / 4`` for each
            original image (presumably the number of CTC time steps -- TODO
            confirm against the network's downsampling).
        """
        widths = [int(img.width) for img in image_batch]
        batch_image_width = max(widths)
        widest = image_batch[int(np.argmax(widths))]

        normalized_batch = []
        chars_count = []
        for img in image_batch:
            base_image = copy.deepcopy(widest)
            base_image.paste(img, (0, 0))
            base_image.paste(PAD_IMAGE, (img.width, 0))
            normalized_image = np.reshape(base_image,
                                          (32, batch_image_width, 1))
            normalized_image = normalized_image.astype(
                np.float32) / 128.0 - 1.0
            normalized_batch.append(normalized_image)
            chars_count.append(img.width / 4)
        return normalized_batch, chars_count

    def predict_batch(self, batch_images, long_info=True):
        """Recognize the text in many images using mini-batches.

        Images are sorted by normalized width so that each mini-batch of
        ``BATCH_SIZE`` images needs little padding; results are returned in
        the original input order. Images whose normalized width is <= 4 px
        are replaced by a 32x32 white placeholder for inference and reported
        with empty text. The caller's ``batch_images`` is not modified.

        :param batch_images: sequence of images accepted by ``image_to_pil``.
        :param long_info: if True return JSON strings with keys ``label``,
            ``text``, ``weights``, ``positions``, ``direction`` and ``score``;
            otherwise return only the recognized texts.
        :return: list with one entry per input image, in input order.
        """
        if len(batch_images) == 0:
            return []

        start_time = time.time()

        images = []  # normalized PIL images, one per input
        widths = []  # original (pre-normalization) widths for ctc_label
        invalid_images_idx = []
        for i, raw_image in enumerate(batch_images):
            image = image_to_pil(raw_image)
            width = image.width
            image = self.normalize_image(image)
            if image.width <= 4:
                invalid_images_idx.append(i)
                image = Image.new('L', (32, 32), color=255)
                width = 32
            images.append(image)
            widths.append(width)

        number_images = len(images)
        # Stable sort by normalized width: equal widths keep input order.
        order = sorted(range(number_images), key=lambda k: images[k].width)

        results = [None] * number_images
        for offset in range(0, number_images, BATCH_SIZE):
            chunk = order[offset:offset + BATCH_SIZE]
            batch_array, chars_count = self.normalize_batch(
                [images[k] for k in chunk])
            predict_label, predict_prob = self._run_network(batch_array)

            for j, idx in enumerate(chunk):
                count = int(chars_count[j])
                label = predict_label[j][:count]
                prob = predict_prob[j][:count]
                txt, weights, positions = self._ctc_decode(
                    label, prob, widths[idx])

                ret = dict()
                # Plain list so the result stays JSON-serializable.
                ret['label'] = np.asarray(label).tolist()
                ret['text'] = txt
                ret['weights'] = weights
                ret['positions'] = positions
                ret['direction'] = 0
                if ret['text'] != '':
                    ret['score'] = float(np.mean(ret['weights']))
                else:
                    ret['score'] = 0
                results[idx] = ret

        for i in invalid_images_idx:
            ret = dict()
            ret['label'] = [0]
            ret['weights'] = [0]
            ret['positions'] = [0]
            ret['direction'] = 0
            ret['score'] = 0
            ret['text'] = ''
            results[i] = ret

        logging.info('predict_batch time is %.4f ms',
                     (time.time() - start_time) * 1000)
        if long_info:
            return [json.dumps(ret, ensure_ascii=False) for ret in results]
        return [ret['text'] for ret in results]
class Trainer(object):
    """Train a CRNN model, periodically validating and checkpointing it.

    ``args`` must provide the attributes read below: ``cfg_name``,
    ``chars_file``, ``train_dir``, ``val_dir``, ``test_dir``, ``log_dir``,
    ``ckpt_dir``, ``result_dir``, ``restore``, ``log_step`` and ``val_step``.
    """

    def __init__(self, args):
        self.args = args
        self.cfg = load_config(args.cfg_name)
        self.converter = LabelConverter(chars_file=args.chars_file)

        self.tr_ds = ImgDataset(args.train_dir, self.converter,
                                self.cfg.batch_size)

        # Boundaries are in global steps (one per epoch listed in
        # cfg.lr_decay_epochs); lr_values[i] = lr * lr_decay_rate**i.
        # Presumably consumed by the model as a piecewise-constant schedule
        # -- TODO confirm in CRNN.
        self.cfg.lr_boundaries = [
            self.tr_ds.num_batches * epoch
            for epoch in self.cfg.lr_decay_epochs
        ]
        self.cfg.lr_values = [
            self.cfg.lr * (self.cfg.lr_decay_rate**i)
            for i in range(len(self.cfg.lr_boundaries) + 1)
        ]

        if args.val_dir is None:
            self.val_ds = None
        else:
            self.val_ds = ImgDataset(args.val_dir,
                                     self.converter,
                                     self.cfg.batch_size,
                                     shuffle=False)

        if args.test_dir is None:
            self.test_ds = None
        else:
            # Test images often have different size, so set batch_size to 1
            self.test_ds = ImgDataset(args.test_dir,
                                      self.converter,
                                      shuffle=False,
                                      batch_size=1)

        self.model = CRNN(self.cfg, num_classes=self.converter.num_classes)
        self.sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))

        # Where training (re)starts; updated by _restore().
        self.epoch_start_index = 0
        self.batch_start_index = 0

    def train(self):
        """Run the training loop up to ``cfg.epochs`` epochs.

        Every ``args.log_step``-th batch of an epoch also writes summaries.
        Every ``args.val_step`` global steps the model is evaluated on the
        validation/test sets (when configured) and a checkpoint is saved.
        """
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=8)
        self.train_writer = tf.summary.FileWriter(self.args.log_dir,
                                                  self.sess.graph)

        if self.args.restore:
            self._restore()

        print('Begin training...')
        for epoch in range(self.epoch_start_index, self.cfg.epochs):
            self.sess.run(self.tr_ds.init_op)

            for batch in range(self.batch_start_index,
                               self.tr_ds.num_batches):
                batch_start_time = time.time()

                if batch != 0 and (batch % self.args.log_step == 0):
                    batch_cost, global_step, lr = self._train_with_summary()
                else:
                    batch_cost, global_step, lr = self._train()

                print(
                    "epoch: {}, batch: {}/{}, step: {}, time: {:.02f}s, loss: {:.05}, lr: {:.05}"
                    .format(epoch, batch, self.tr_ds.num_batches, global_step,
                            time.time() - batch_start_time, batch_cost, lr))

                if global_step != 0 and (global_step % self.args.val_step
                                         == 0):
                    val_acc = self._do_val(self.val_ds, epoch, global_step,
                                           "val")
                    test_acc = self._do_val(self.test_ds, epoch, global_step,
                                            "test")
                    self._save_checkpoint(self.args.ckpt_dir, global_step,
                                          val_acc, test_acc)

            # Only the first (possibly resumed) epoch starts mid-way.
            self.batch_start_index = 0

    def _restore(self):
        """Restore the latest checkpoint and derive the resume position."""
        utils.restore_ckpt(self.sess, self.saver, self.args.ckpt_dir)

        step_restored = int(self.sess.run(self.model.global_step))

        self.epoch_start_index, self.batch_start_index = divmod(
            step_restored, self.tr_ds.num_batches)

        print("Restored global step: %d" % step_restored)
        print("Restored epoch: %d" % self.epoch_start_index)
        print("Restored batch_start_index: %d" % self.batch_start_index)

    def _next_train_feed(self):
        """Fetch the next training batch; return ``(feed_dict, labels)``."""
        img_batch, label_batch, labels, _ = self.tr_ds.get_next_batch(
            self.sess)
        feed = {
            self.model.inputs: img_batch,
            self.model.labels: label_batch,
            self.model.is_training: True,
        }
        return feed, labels

    def _train(self):
        """Run one optimization step; return ``(loss, global_step, lr)``."""
        feed, _ = self._next_train_feed()

        fetches = [
            self.model.total_loss, self.model.ctc_loss,
            self.model.regularization_loss, self.model.global_step,
            self.model.lr, self.model.train_op
        ]

        batch_cost, _, _, global_step, lr, _ = self.sess.run(fetches, feed)
        return batch_cost, global_step, lr

    def _train_with_summary(self):
        """Run one optimization step and write training summaries.

        Besides the merged summary, logs training accuracy (computed from the
        decoded predictions) and edit distance.

        :return: ``(loss, global_step, lr)``.
        """
        feed, labels = self._next_train_feed()

        # NOTE: `merged_summay` is the attribute name the model exposes.
        fetches = [
            self.model.total_loss, self.model.ctc_loss,
            self.model.regularization_loss, self.model.global_step,
            self.model.lr, self.model.merged_summay, self.model.dense_decoded,
            self.model.edit_distance, self.model.train_op
        ]

        (batch_cost, _, _, global_step, lr, summary, predicts, edit_distance,
         _) = self.sess.run(fetches, feed)
        self.train_writer.add_summary(summary, global_step)

        predicts = [
            self.converter.decode(p, CRNN.CTC_INVALID_INDEX) for p in predicts
        ]
        accuracy, _ = infer.calculate_accuracy(predicts, labels)

        tf_utils.add_scalar_summary(self.train_writer, "train_accuracy",
                                    accuracy, global_step)
        tf_utils.add_scalar_summary(self.train_writer, "train_edit_distance",
                                    edit_distance, global_step)

        return batch_cost, global_step, lr

    def _do_val(self, dataset, epoch, step, name):
        """Evaluate on ``dataset`` and log the results.

        :return: accuracy, or ``None`` when ``dataset`` is ``None``.
        """
        if dataset is None:
            return None

        accuracy, edit_distance = infer.validation(self.sess,
                                                   self.model.feeds(),
                                                   self.model.fetches(),
                                                   dataset, self.converter,
                                                   self.args.result_dir, name,
                                                   step)

        tf_utils.add_scalar_summary(self.train_writer, "%s_accuracy" % name,
                                    accuracy, step)
        tf_utils.add_scalar_summary(self.train_writer,
                                    "%s_edit_distance" % name, edit_distance,
                                    step)

        print("epoch: %d/%d, %s accuracy = %.3f" %
              (epoch, self.cfg.epochs, name, accuracy))
        return accuracy

    def _save_checkpoint(self, ckpt_dir, step, val_acc=None, test_acc=None):
        """Save a checkpoint named after ``step`` and the optional accuracies.

        A ``.meta`` file found in ``ckpt_dir`` before saving is deleted
        afterwards to save disk space, unless it has the same name as the
        checkpoint just written. A failed deletion is reported, not raised.
        """
        ckpt_name = "crnn_%d" % step
        if val_acc is not None:
            ckpt_name += '_val_%.03f' % val_acc
        if test_acc is not None:
            ckpt_name += '_test_%.03f' % test_acc

        name = os.path.join(ckpt_dir, ckpt_name)
        print("save checkpoint %s" % name)

        meta_exists, meta_file_name = self._meta_file_exist(ckpt_dir)

        self.saver.save(self.sess, name)

        # remove old meta file to save disk space (never the one just saved)
        if meta_exists and meta_file_name != ckpt_name + '.meta':
            try:
                os.remove(os.path.join(ckpt_dir, meta_file_name))
            except OSError:
                print('Remove meta file failed: %s' % meta_file_name)

    def _meta_file_exist(self, ckpt_dir):
        """Look for a ``*.meta`` file in ``ckpt_dir``.

        :return: ``(True, filename)`` for the first match in ``os.listdir``
            order, otherwise ``(False, '')``.
        """
        for fname in os.listdir(ckpt_dir):
            if fname.endswith('.meta'):
                return True, fname
        return False, ''