def train_cnocr(args):
    """Train a CnOCR model configured by the command-line namespace ``args``.

    Side effects: configures root logging at DEBUG level, sets
    ``args.model_name`` and ``args.prefix``, creates the output directory
    ``<args.out_model_dir>/<args.model_name>`` if needed, and then runs ``fit``.

    :param args: parsed CLI namespace; this function reads ``emb_model_type``,
        ``seq_model_type``, ``out_model_dir``, ``train_file``, ``test_file``,
        ``use_train_image_aug``, ``dataset``, ``charset`` and ``debug``, and
        forwards the whole namespace to ``_update_hp`` and ``fit``.
    """
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    args.model_name = args.emb_model_type + '-' + args.seq_model_type
    out_dir = os.path.join(args.out_model_dir, args.model_name)
    # Lazy %-args: the message is only formatted if the record is emitted.
    logger.info('save models to dir: %s', out_dir)
    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(out_dir, exist_ok=True)
    args.prefix = os.path.join(
        out_dir, 'cnocr-v{}-{}'.format(__version__, args.model_name))

    hp = _update_hp(CnHyperparams(), args)
    network, hp = gen_network(args.model_name, hp)
    metrics = CtcMetrics(hp.seq_length)

    data_train, data_val = _gen_iters(
        hp, args.train_file, args.test_file, args.use_train_image_aug,
        args.dataset, args.charset, args.debug)
    data_names = ['data']
    fit(
        network=network,
        data_train=data_train,
        data_val=data_val,
        metrics=metrics,
        args=args,
        hp=hp,
        data_names=data_names,
    )
def _gen_line_pred_chars(self, line_prob, img_width, max_img_width):
    """
    Decode one line of per-step class probabilities into characters.

    Steps whose best probability is not above 0.5 are treated as class 0,
    steps beyond the image's real width are also set to class 0, and the
    resulting id sequence is collapsed with ``CtcMetrics.ctc_label``.

    :param line_prob: with shape of [seq_length, num_classes]
    :param img_width: width of this line's image
    :param max_img_width: maximum image width (presumably the padded batch width)
    :return: list of predicted characters; '<space>' entries are returned as ' '
    """
    best_ids = np.argmax(line_prob, axis=-1)
    # Delete low confidence result: force class 0 where the top prob <= 0.5.
    confident = np.max(line_prob, axis=-1) > 0.5
    best_ids[~confident] = 0

    if img_width < max_img_width:
        # Steps past the compressed image width are set to class 0.
        cutoff = img_width // self._hp.seq_len_cmpr_ratio
        if cutoff < len(best_ids):
            best_ids[cutoff:] = 0

    label_ids, _ = CtcMetrics.ctc_label(best_ids.tolist())

    alphabet = self._alphabet
    chars = []
    for idx in label_ids:
        ch = alphabet[idx]
        chars.append(' ' if ch == '<space>' else ch)
    return chars
def _gen_line_pred_chars(self, line_prob, img_width, max_img_width):
    """
    Get the predicted characters, replacing low-confidence steps.

    Each step takes its most likely class. Steps whose top probability is not
    above 0.8 are replaced by class id 6425. Steps beyond the image's real
    width are set to class 0. The id sequence is then collapsed with
    ``CtcMetrics.ctc_label`` and mapped through the alphabet.

    :param line_prob: with shape of [seq_length, num_classes]
    :param img_width: width of this line's image
    :param max_img_width: maximum image width (presumably the padded batch width)
    :return: list of predicted characters; '<space>' entries are returned as ' '
    """
    line_prob = np.asarray(line_prob)
    class_ids = np.argmax(line_prob, axis=-1)
    # NOTE(review): 6425 is a hard-coded class id whose meaning is not visible
    # here (possibly an "unknown" entry of the charset) -- confirm it is a
    # valid index into self._alphabet.
    confident = np.max(line_prob, axis=-1) > 0.8
    class_ids = np.where(confident, class_ids, 6425)

    if img_width < max_img_width:
        comp_ratio = self._hp.seq_len_cmpr_ratio
        end_idx = img_width // comp_ratio
        if end_idx < len(class_ids):
            # Kept as a numpy array so this slice assignment is valid; the
            # previous Python-list version raised TypeError here.
            class_ids[end_idx:] = 0

    prediction, start_end_idx = CtcMetrics.ctc_label(class_ids.tolist())
    alphabet = self._alphabet
    res = [
        alphabet[p] if alphabet[p] != '<space>' else ' ' for p in prediction
    ]
    return res
def ocr_for_single_line(self, img_fp):
    """
    Recognize the characters in an image that contains a single line of text.

    :param img_fp: path to an existing image file, or an already-loaded gray
        image given as ``mx.nd.NDArray`` or ``np.ndarray``
    :return: list of recognized characters, such as ['你', '好']
    :raises TypeError: if ``img_fp`` is neither an existing file path nor one
        of the supported array types
    """
    hp = deepcopy(self._hp)

    is_existing_path = isinstance(img_fp, str) and os.path.isfile(img_fp)
    if is_existing_path:
        img = read_ocr_img(img_fp)
    elif isinstance(img_fp, (mx.nd.NDArray, np.ndarray)):
        img = img_fp
    else:
        raise TypeError('Inappropriate argument type.')
    img = rescale_img(img, hp)

    # Batch of one image plus the LSTM initial-state inputs.
    state_names, state_arrays = lstm_init_states(batch_size=1, hp=hp)
    batch = SimpleBatch(data_names=['data'] + state_names,
                        data=[mx.nd.array([img])] + state_arrays)

    module = self._get_module(hp, batch)
    module.forward(batch)
    probs = module.get_outputs()[0].asnumpy()

    # Most likely class per step, then collapsed by CtcMetrics.ctc_label.
    best_path = np.argmax(probs, axis=-1).tolist()
    label_ids, _ = CtcMetrics.ctc_label(best_path)

    alphabet = self._alphabet
    return [alphabet[idx] for idx in label_ids]
def run_captcha(args):
    """Train the CRNN-LSTM network on synthetic digit captchas.

    Builds the network from ``Hyperparams2`` and starts a multi-process
    ``MPDigitCaptcha`` image generator. It wraps the generator in train and
    validation ``OCRIter`` objects and calls ``fit``. The generator is reset
    in a ``finally`` block, so it is cleaned up even when setup or training
    raises (including ``KeyboardInterrupt``).

    :param args: parsed CLI namespace; ``font_path`` and ``num_proc`` are read
        here and the whole namespace is forwarded to ``fit``.
    """
    hp = Hyperparams2()
    network = crnn_lstm(hp)

    # Start a multiprocessor captcha image generator.
    # NOTE(review): h receives img_width and w receives img_height. This looks
    # swapped but presumably follows the generator's own argument convention
    # (earlier debug code transposed its images) -- confirm before changing.
    mp_captcha = MPDigitCaptcha(font_paths=get_fonts(args.font_path),
                                h=hp.img_width, w=hp.img_height,
                                num_digit_min=3, num_digit_max=4,
                                num_processes=args.num_proc,
                                max_queue_size=hp.batch_size * 2)
    mp_captcha.start()
    try:
        # Initial c/h state inputs, one of each per index in
        # range(num_lstm_layer * 2), each shaped (batch_size, num_hidden).
        state_shape = (hp.batch_size, hp.num_hidden)
        num_states = hp.num_lstm_layer * 2
        init_c = [('l%d_init_c' % i, state_shape) for i in range(num_states)]
        init_h = [('l%d_init_h' % i, state_shape) for i in range(num_states)]
        init_states = init_c + init_h
        data_names = ['data'] + [name for name, _ in init_states]

        data_train = OCRIter(hp.train_epoch_size // hp.batch_size,
                             hp.batch_size, init_states, captcha=mp_captcha,
                             num_label=hp.num_label, name='train')
        data_val = OCRIter(hp.eval_epoch_size // hp.batch_size,
                           hp.batch_size, init_states, captcha=mp_captcha,
                           num_label=hp.num_label, name='val')

        head = '%(asctime)-15s %(message)s'
        logging.basicConfig(level=logging.DEBUG, format=head)

        metrics = CtcMetrics(hp.seq_length)

        fit(network=network, data_train=data_train, data_val=data_val,
            metrics=metrics, args=args, hp=hp, data_names=data_names)
    finally:
        # Always release the generator, even if training fails.
        mp_captcha.reset()
def _gen_line_pred_chars(self, line_prob, img_width, max_img_width):
    """
    Greedily decode one line of per-step class probabilities into characters.

    :param line_prob: with shape of [seq_length, num_classes]
    :param img_width: width of this line's image
    :param max_img_width: maximum image width (presumably the padded batch width)
    :return: list of predicted characters; '<space>' entries are returned as ' '
    """
    # DCMMC: greedy decoder for CTC -- the most likely class at every step.
    step_ids = np.argmax(line_prob, axis=-1)

    if img_width < max_img_width:
        # DCMMC: the original images are right-padded, whereas my dataset is
        # padded on both the left and the right.
        # NOTE(review): zeroing only the tail may therefore not match
        # left+right padded data -- confirm.
        cutoff = img_width // self._hp.seq_len_cmpr_ratio
        if cutoff < len(step_ids):
            step_ids[cutoff:] = 0

    decoded, _ = CtcMetrics.ctc_label(step_ids.tolist())
    alphabet = self._alphabet

    def _to_char(idx):
        # Map an id to its alphabet entry, rendering '<space>' as ' '.
        ch = alphabet[idx]
        return ' ' if ch == '<space>' else ch

    return list(map(_to_char, decoded))
def main():
    """Recognize a single image with a saved checkpoint and print the result.

    Command-line options:
      --dataset       'captcha' (default) or 'cn_ocr'; selects the
                      hyperparameters and the image reader.
      --file          path to the image to recognize (required).
      --prefix        checkpoint prefix (default './models/model').
      --epoch         checkpoint epoch (default 20).
      --charset_file  optional charset file mapping each character to its id.
                      When given, predicted ids are mapped to characters.
                      Otherwise they are printed as digits (id - 1).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", help="use which kind of dataset, captcha or cn_ocr",
                        choices=['captcha', 'cn_ocr'], type=str, default='captcha')
    # Required: without it the image readers below would be called with None.
    parser.add_argument("--file", help="Path to the CAPTCHA image file", required=True)
    # Help strings now state the real defaults (they previously claimed 'ocr' and 100).
    parser.add_argument("--prefix", help="Checkpoint prefix [Default './models/model']",
                        default='./models/model')
    parser.add_argument("--epoch", help="Checkpoint epoch [Default 20]", type=int, default=20)
    # Help text (Chinese): "stores which id each character corresponds to."
    parser.add_argument('--charset_file', type=str, help='存储了每个字对应哪个id的关系.')
    args = parser.parse_args()

    if args.dataset == 'cn_ocr':
        hp = Hyperparams()
        img = read_ocr_img(args.file, hp)
    else:
        hp = Hyperparams2()
        img = read_captcha_img(args.file, hp)

    # Batch of one image; only the 'data' input is fed to the module.
    sample = SimpleBatch(data_names=['data'], data=[mx.nd.array([img])])
    network = crnn_lstm(hp)
    mod = load_module(args.prefix, args.epoch, sample.data_names, sample.provide_data,
                      network=network)
    mod.forward(sample)
    prob = mod.get_outputs()[0].asnumpy()

    # Greedy CTC decoding: best class per step, collapsed by ctc_label.
    prediction, _ = CtcMetrics.ctc_label(np.argmax(prob, axis=-1).tolist())

    if args.charset_file:
        alphabet, _ = read_charset(args.charset_file)
        res = [alphabet[p] for p in prediction]
        print("Predicted Chars:", res)
    else:
        # Predictions are 1 to 10 for digits 0 to 9 respectively (prediction 0 means no-digit)
        prediction = [p - 1 for p in prediction]
        print("Digits:", prediction)
def run_cn_ocr(args):
    """Train the CRNN-LSTM network on OCR images listed in index files.

    Builds the network from ``Hyperparams`` and starts one multi-process
    ``MPOcrImages`` reader each for the training and test files. It wraps
    the readers in ``OCRIter`` objects and calls ``fit``. Both readers are
    reset in a ``finally`` block, so they are cleaned up even when setup or
    training raises (including ``KeyboardInterrupt``).

    :param args: parsed CLI namespace; ``data_root``, ``train_file``,
        ``test_file`` and ``num_proc`` are read here and the whole namespace
        is forwarded to ``fit``.
    """
    hp = Hyperparams()
    network = crnn_lstm(hp)

    img_size = (hp.img_width, hp.img_height)
    mp_data_train = MPOcrImages(args.data_root, args.train_file, img_size,
                                hp.num_label, num_processes=args.num_proc,
                                max_queue_size=hp.batch_size * 100)
    # The test reader uses half the worker processes (at least one).
    mp_data_test = MPOcrImages(args.data_root, args.test_file, img_size,
                               hp.num_label,
                               num_processes=max(args.num_proc // 2, 1),
                               max_queue_size=hp.batch_size * 10)
    mp_data_train.start()
    mp_data_test.start()
    try:
        # Only the image is fed as input; no LSTM init-state names are passed.
        data_names = ['data']
        data_train = OCRIter(
            hp.train_epoch_size // hp.batch_size, hp.batch_size,
            captcha=mp_data_train, num_label=hp.num_label, name='train')
        data_val = OCRIter(
            hp.eval_epoch_size // hp.batch_size, hp.batch_size,
            captcha=mp_data_test, num_label=hp.num_label, name='val')

        head = '%(asctime)-15s %(message)s'
        logging.basicConfig(level=logging.DEBUG, format=head)

        metrics = CtcMetrics(hp.seq_length)

        fit(network=network, data_train=data_train, data_val=data_val,
            metrics=metrics, args=args, hp=hp, data_names=data_names)
    finally:
        # Always release both readers, even if training fails.
        mp_data_train.reset()
        mp_data_test.reset()
def _gen_line_pred_chars(self, line_prob, img_width, max_img_width):
    """
    Greedily decode one line of per-step class probabilities into characters.

    :param line_prob: with shape of [seq_length, num_classes]
    :param img_width: width of this line's image
    :param max_img_width: maximum image width (presumably the padded batch width)
    :return: list of predicted characters (raw alphabet entries)
    """
    best_ids = np.argmax(line_prob, axis=-1)

    is_narrower = img_width < max_img_width
    if is_narrower:
        # Steps past the compressed image width are set to class 0
        # (presumably the CTC blank).
        valid_len = img_width // self._hp.seq_len_cmpr_ratio
        if valid_len < len(best_ids):
            best_ids[valid_len:] = 0

    label_ids, _ = CtcMetrics.ctc_label(best_ids.tolist())
    alphabet = self._alphabet
    return [alphabet[idx] for idx in label_ids]
def test_ctc_metrics(input, expected):
    """Check that ``CtcMetrics.ctc_label`` turns ``input`` into ``expected``.

    Each element of both arguments is converted with ``int`` before use
    (presumably they are digit strings such as "1102"). The parameter names
    are kept, including the builtin-shadowing ``input``, because test
    parametrization may bind them by name.
    """
    raw_ids = [int(ch) for ch in input]
    want = [int(ch) for ch in expected]
    got, _ = CtcMetrics.ctc_label(raw_ids)
    assert want == got