def __init__(self, root=data_dir(), model_epoch=MODEL_EPOCE):
    """Set up the recognizer from the model files stored under ``root``.

    Args:
        root: base directory; model files are expected in its ``models``
            sub-directory. The default is evaluated once, when the function
            is defined.
        model_epoch: checkpoint epoch, kept as ``self._model_epoch``.
    """
    # All model artifacts live in "<root>/models".
    models_path = os.path.join(root, 'models')
    self._model_dir = models_path
    self._model_epoch = model_epoch
    # Called with the original root (not the models dir), before the charset
    # is read -- presumably it makes sure label_cn.txt is present.
    self._assert_and_prepare_model_files(root)

    charset_file = os.path.join(self._model_dir, 'label_cn.txt')
    alphabet, _inverse_lookup = read_charset(charset_file)
    self._alphabet = alphabet

    self._hp = Hyperparams()
    # Starts empty.
    self._mods = {}
def main():
    """Command-line entry point: recognize the text in a single image.

    Builds a CRNN-LSTM network, loads the checkpoint ``--prefix``/``--epoch``,
    runs one forward pass over the image given by ``--file`` and prints the
    decoded result. If ``--charset_file`` is given, the result is printed as
    characters; otherwise it is printed as digits.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", help="use which kind of dataset, captcha or cn_ocr",
                        choices=['captcha', 'cn_ocr'], type=str, default='captcha')
    # Required: both branches below read the image from this path, so running
    # without it could only fail later with an obscure error.
    parser.add_argument("--file", help="Path to the CAPTCHA image file", required=True)
    # %(default)s keeps the help text in sync with the real defaults
    # (it previously advertised 'ocr' and 100, which were wrong).
    parser.add_argument("--prefix", help="Checkpoint prefix [Default %(default)s]",
                        default='./models/model')
    parser.add_argument("--epoch", help="Checkpoint epoch [Default %(default)s]",
                        type=int, default=20)
    # Help text: "file storing which id each character maps to".
    parser.add_argument('--charset_file', type=str, help='存储了每个字对应哪个id的关系.')
    args = parser.parse_args()

    if args.dataset == 'cn_ocr':
        hp = Hyperparams()
        img = read_ocr_img(args.file, hp)
    else:
        hp = Hyperparams2()
        img = read_captcha_img(args.file, hp)

    # Batch of one image.
    sample = SimpleBatch(data_names=['data'], data=[mx.nd.array([img])])
    network = crnn_lstm(hp)
    mod = load_module(args.prefix, args.epoch, sample.data_names, sample.provide_data,
                      network=network)
    mod.forward(sample)
    prob = mod.get_outputs()[0].asnumpy()
    # Take the argmax class along the last axis, then let CtcMetrics.ctc_label
    # turn that sequence into the predicted labels.
    prediction, _ = CtcMetrics.ctc_label(np.argmax(prob, axis=-1).tolist())

    if args.charset_file:
        alphabet, _ = read_charset(args.charset_file)
        print("Predicted Chars:", [alphabet[p] for p in prediction])
    else:
        # Predictions are 1 to 10 for digits 0 to 9 respectively
        # (prediction 0 means no-digit).
        print("Digits:", [p - 1 for p in prediction])
def run_cn_ocr(args):
    """Train the CRNN-LSTM network on the Chinese OCR dataset.

    Reads ``args.data_root``, ``args.train_file``, ``args.test_file`` and
    ``args.num_proc``. ``args`` is also forwarded unchanged to ``fit``.

    Both multi-process data loaders are reset in ``finally`` clauses, so their
    workers are torn down even if building the iterators or training raises.
    """
    hp = Hyperparams()
    network = crnn_lstm(hp)
    img_size = (hp.img_width, hp.img_height)

    mp_data_train = MPOcrImages(args.data_root, args.train_file, img_size, hp.num_label,
                                num_processes=args.num_proc,
                                max_queue_size=hp.batch_size * 100)
    # Validation: half the worker processes (at least one) and a smaller queue.
    mp_data_test = MPOcrImages(args.data_root, args.test_file, img_size, hp.num_label,
                               num_processes=max(args.num_proc // 2, 1),
                               max_queue_size=hp.batch_size * 10)

    mp_data_train.start()
    try:
        mp_data_test.start()
        try:
            data_names = ['data']
            data_train = OCRIter(
                hp.train_epoch_size // hp.batch_size, hp.batch_size,
                captcha=mp_data_train, num_label=hp.num_label, name='train')
            data_val = OCRIter(
                hp.eval_epoch_size // hp.batch_size, hp.batch_size,
                captcha=mp_data_test, num_label=hp.num_label, name='val')

            head = '%(asctime)-15s %(message)s'
            logging.basicConfig(level=logging.DEBUG, format=head)

            metrics = CtcMetrics(hp.seq_length)
            fit(network=network, data_train=data_train, data_val=data_val,
                metrics=metrics, args=args, hp=hp, data_names=data_names)
        finally:
            # Previously skipped whenever fit() raised, leaking the workers.
            mp_data_test.reset()
    finally:
        mp_data_train.reset()
def __init__(self, model_name='conv-lite-fc', model_epoch=None, cand_alphabet=None,
             root=data_dir(), gpus=0):
    """Load an OCR model and build the MXNet module used for inference.

    :param model_name: model name; validated with ``check_model_name``.
    :param model_epoch: checkpoint epoch. When falsy (e.g. the default
        ``None``), the first entry of ``AVAILABLE_MODELS[model_name]`` is used.
    :param cand_alphabet: characters to restrict recognition to. ``None``
        (default) means no restriction. Every character must appear in the
        model's ``label_cn.txt``.
    :param root: base directory for model files. The model directory is
        ``<root>/<MODEL_VERSION>/<model_name>``. On Linux/Mac the default
        root is ``~/.cnocr``.
    :param gpus: number of GPUs to use; ``0`` (default) means CPU only.
    """
    check_model_name(model_name)
    self._model_name = model_name
    self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, model_name)
    if not model_epoch:
        model_epoch = AVAILABLE_MODELS[model_name][0]
    self._model_epoch = model_epoch

    versioned_root = os.path.join(root, MODEL_VERSION)
    self._model_dir = os.path.join(versioned_root, self._model_name)
    self._assert_and_prepare_model_files()

    self._alphabet, char_to_idx = read_charset(
        os.path.join(self._model_dir, 'label_cn.txt'))

    # Index 0 is always included next to the candidates' indices
    # (presumably the CTC blank -- confirm), and the list is kept sorted.
    cand_idx = None
    if cand_alphabet is not None:
        cand_idx = sorted([0] + [char_to_idx[ch] for ch in cand_alphabet])
    self._cand_alph_idx = cand_idx

    self._hp = Hyperparams()
    self._hp._loss_type = None  # inference mode
    # DCMMC: MXNet contexts -- one per requested GPU, otherwise a single CPU.
    self.context = ([mx.context.gpu(i) for i in range(gpus)] if gpus > 0
                    else [mx.context.cpu()])
    self._mod = self._get_module()
def init(
    self,
    model_name='densenet-lite-gru',
    model_epoch=None,
    cand_alphabet=None,
    root=data_dir(),
    context='cpu',
    name=None,
):
    """(Re)initialize this instance from model files located directly in ``root``.

    :param model_name: model name; validated with ``check_model_name``.
    :param model_epoch: checkpoint epoch, stored as given (no fallback to a
        default epoch here, so it may remain ``None``).
    :param cand_alphabet: accepted but currently unused. The candidate
        alphabet is expected to be set before ``ocr`` is called, and
        ``self._cand_alph_idx`` starts as ``None``.
    :param root: directory containing the model files and ``label_cn.txt``.
        It is used as-is as the model directory (no version/model-name
        sub-folders).
    :param context: accepted but currently unused. The module is built with
        ``AlOcr.CNOCR_CONTEXT`` instead. NOTE(review): confirm this is intended.
    :param name: name of this instance, stored as ``self._net_prefix``. Give
        each instance a distinct name when several are created at once;
        ``''`` is treated the same as ``None``.
    """
    check_model_name(model_name)
    self._model_name = model_name
    self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, model_name)
    self._model_epoch = model_epoch
    # Changed folder structure: root itself is the model directory.
    self._model_dir = root
    self._assert_and_prepare_model_files()

    charset_path = os.path.join(self._model_dir, 'label_cn.txt')
    self._alphabet, self._inv_alph_dict = read_charset(charset_path)
    # The candidate alphabet is deliberately not applied here.
    self._cand_alph_idx = None

    hp = Hyperparams()
    hp._loss_type = None  # inference mode
    hp._num_classes = len(self._alphabet)
    self._hp = hp

    # An empty-string name is normalized to None.
    if name == '':
        name = None
    self._net_prefix = name
    self._mod = self._get_module(AlOcr.CNOCR_CONTEXT)
def __init__(
    self,
    model_name='densenet-lite-fc',
    model_epoch=None,
    cand_alphabet=None,
    root=data_dir(),
    context='cpu',
    name=None,
):
    """Load an OCR model and build the module used for inference.

    :param model_name: model name; validated with ``check_model_name``.
    :param model_epoch: checkpoint epoch. When falsy (e.g. the default
        ``None``), the first entry of ``AVAILABLE_MODELS[model_name]`` is used.
    :param cand_alphabet: characters to restrict recognition to. It is passed
        to ``set_cand_alphabet``; ``None`` (default) means no restriction.
    :param root: base directory for model files. The model directory is
        ``<root>/<MODEL_VERSION>/<model_name>``. On Linux/Mac the default
        root is ``~/.cnocr``.
    :param context: ``'cpu'`` or ``'gpu'``; passed to ``_get_module`` to choose
        the device used for prediction. Defaults to CPU.
    :param name: name of this instance, stored as ``self._net_prefix``. Give
        each instance a distinct name when several are created at once;
        ``''`` is treated the same as ``None``.
    """
    check_model_name(model_name)
    self._model_name = model_name
    self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, model_name)
    if not model_epoch:
        model_epoch = AVAILABLE_MODELS[model_name][0]
    self._model_epoch = model_epoch

    self._model_dir = os.path.join(root, MODEL_VERSION, self._model_name)
    self._assert_and_prepare_model_files()

    charset_path = os.path.join(self._model_dir, 'label_cn.txt')
    self._alphabet, self._inv_alph_dict = read_charset(charset_path)
    self._cand_alph_idx = None
    self.set_cand_alphabet(cand_alphabet)

    hp = Hyperparams()
    hp._loss_type = None  # inference mode
    self._hp = hp

    # An empty-string name is normalized to None.
    if name == '':
        name = None
    self._net_prefix = name
    self._mod = self._get_module(context)