Example #1
0
    def __init__(self, root=data_dir(), model_epoch=MODEL_EPOCE):
        """Prepare the model files under ``root`` and load the character set.

        :param root: base directory; the model directory is ``<root>/models``.
        :param model_epoch: checkpoint epoch, stored on the instance as-is.
        """
        self._model_dir = os.path.join(root, 'models')
        self._model_epoch = model_epoch
        self._assert_and_prepare_model_files(root)

        # Read the charset from the (possibly prepared) model directory; only the
        # alphabet is kept, the inverse mapping returned alongside it is discarded.
        charset_path = os.path.join(self._model_dir, 'label_cn.txt')
        alphabet, _unused_inverse = read_charset(charset_path)
        self._alphabet = alphabet

        self._hp = Hyperparams()
        # Starts empty; it is not filled anywhere in this method.
        self._mods = {}
Example #2
0
def main():
    """Run a single-image CRNN prediction from the command line.

    Parses CLI arguments, reads one image (Chinese-OCR or CAPTCHA style,
    selected by ``--dataset``), restores the checkpoint given by
    ``--prefix``/``--epoch`` and prints the decoded prediction.

    If ``--charset_file`` is given, predicted ids are mapped to characters
    through that charset. Otherwise they are printed as digits (id - 1).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset",
                        help="use which kind of dataset, captcha or cn_ocr",
                        choices=['captcha', 'cn_ocr'],
                        type=str,
                        default='captcha')
    # Required: both image readers below are called with this path
    # unconditionally, so running without it can never succeed.
    parser.add_argument("--file", help="Path to the CAPTCHA image file",
                        required=True)
    # %(default)s makes argparse print the actual default, so the help text
    # cannot drift out of sync with the `default=` value again.
    parser.add_argument("--prefix",
                        help="Checkpoint prefix [Default %(default)s]",
                        default='./models/model')
    parser.add_argument("--epoch",
                        help="Checkpoint epoch [Default %(default)s]",
                        type=int,
                        default=20)
    # Help text (Chinese): "file storing which id each character maps to."
    parser.add_argument('--charset_file', type=str, help='存储了每个字对应哪个id的关系.')
    args = parser.parse_args()

    if args.dataset == 'cn_ocr':
        hp = Hyperparams()
        img = read_ocr_img(args.file, hp)
    else:
        hp = Hyperparams2()
        img = read_captcha_img(args.file, hp)

    # Wrap the single image in a one-element batch for the forward pass.
    sample = SimpleBatch(data_names=['data'], data=[mx.nd.array([img])])

    network = crnn_lstm(hp)
    mod = load_module(args.prefix,
                      args.epoch,
                      sample.data_names,
                      sample.provide_data,
                      network=network)

    mod.forward(sample)
    prob = mod.get_outputs()[0].asnumpy()

    # Take the arg-max class along the last axis, then decode with
    # CtcMetrics.ctc_label (its second return value is not needed here).
    prediction, _start_end_idx = CtcMetrics.ctc_label(
        np.argmax(prob, axis=-1).tolist())

    if args.charset_file:
        alphabet, _ = read_charset(args.charset_file)
        res = [alphabet[p] for p in prediction]
        print("Predicted Chars:", res)
    else:
        # Predictions are 1 to 10 for digits 0 to 9 respectively (prediction 0 means no-digit)
        prediction = [p - 1 for p in prediction]
        print("Digits:", prediction)
def run_cn_ocr(args):
    """Train the CRNN-LSTM network on the Chinese OCR dataset.

    :param args: parsed CLI namespace. This function reads ``data_root``,
        ``train_file``, ``test_file`` and ``num_proc`` from it and forwards
        the whole namespace to ``fit``.
    """
    hp = Hyperparams()

    network = crnn_lstm(hp)

    img_size = (hp.img_width, hp.img_height)
    mp_data_train = MPOcrImages(args.data_root, args.train_file, img_size, hp.num_label,
                                num_processes=args.num_proc, max_queue_size=hp.batch_size * 100)
    # The validation loader uses half the worker processes (at least one) and
    # a queue ten times smaller than the training loader.
    mp_data_test = MPOcrImages(args.data_root, args.test_file, img_size, hp.num_label,
                               num_processes=max(args.num_proc // 2, 1), max_queue_size=hp.batch_size * 10)
    mp_data_train.start()
    mp_data_test.start()

    try:
        data_names = ['data']

        data_train = OCRIter(
            hp.train_epoch_size // hp.batch_size, hp.batch_size, captcha=mp_data_train, num_label=hp.num_label,
            name='train')
        data_val = OCRIter(
            hp.eval_epoch_size // hp.batch_size, hp.batch_size, captcha=mp_data_test, num_label=hp.num_label,
            name='val')

        head = '%(asctime)-15s %(message)s'
        logging.basicConfig(level=logging.DEBUG, format=head)

        metrics = CtcMetrics(hp.seq_length)

        fit(network=network, data_train=data_train, data_val=data_val, metrics=metrics, args=args, hp=hp,
            data_names=data_names)
    finally:
        # Reset both started loaders on every exit path (normal return,
        # exception, or KeyboardInterrupt), not only after fit() succeeds.
        mp_data_train.reset()
        mp_data_test.reset()
Example #4
0
    def __init__(self,
                 model_name='conv-lite-fc',
                 model_epoch=None,
                 cand_alphabet=None,
                 root=data_dir(),
                 gpus=0):
        """Load a pretrained OCR model and build the inference module.

        :param model_name: model name; validated with ``check_model_name``.
        :param model_epoch: checkpoint epoch. A falsy value (e.g. ``None``)
            selects ``AVAILABLE_MODELS[model_name][0]``.
        :param cand_alphabet: iterable of candidate characters restricting what
            can be recognized. ``None`` (default) means no restriction.
        :param root: root directory of the model files; they are looked up in
            ``<root>/<MODEL_VERSION>/<model_name>`` (documented default root on
            Linux/Mac: ``~/.cnocr``).
        :param gpus: number of GPUs for MXNet. When ``> 0``, contexts
            ``gpu(0) .. gpu(gpus - 1)`` are used; otherwise a single CPU context.
        :raises KeyError: if a character of ``cand_alphabet`` is not in the
            model's ``label_cn.txt`` charset.
        """
        check_model_name(model_name)
        self._model_name = model_name
        self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX,
                                                 model_name)
        self._model_epoch = model_epoch or AVAILABLE_MODELS[model_name][0]

        root = os.path.join(root, MODEL_VERSION)
        self._model_dir = os.path.join(root, self._model_name)
        self._assert_and_prepare_model_files()
        self._alphabet, inv_alph_dict = read_charset(
            os.path.join(self._model_dir, 'label_cn.txt'))

        self._cand_alph_idx = None
        if cand_alphabet is not None:
            cand_idx = []
            for word in cand_alphabet:
                # Fail with a readable message instead of a bare KeyError(word);
                # the exception type is unchanged for existing callers.
                if word not in inv_alph_dict:
                    raise KeyError(
                        'candidate character {!r} is not in the model alphabet'.format(word))
                cand_idx.append(inv_alph_dict[word])
            # Index 0 is always kept together with the candidates (presumably
            # the CTC blank / padding id -- confirm against read_charset).
            self._cand_alph_idx = sorted([0] + cand_idx)

        self._hp = Hyperparams()
        self._hp._loss_type = None  # infer mode

        # DCMMC: gpu context for mxnet
        if gpus > 0:
            self.context = [mx.context.gpu(i) for i in range(gpus)]
        else:
            self.context = [mx.context.cpu()]
        self._mod = self._get_module()

    def init(
            self,
            model_name='densenet-lite-gru',
            model_epoch=None,
            cand_alphabet=None,
            root=data_dir(),
            context='cpu',
            name=None,
    ):
        """(Re)initialize this instance with a model whose files live in ``root``.

        :param model_name: model name; validated with ``check_model_name`` and
            used to build the checkpoint file prefix.
        :param model_epoch: checkpoint epoch, stored as-is. ``None`` is NOT
            replaced by a default epoch here.
        :param cand_alphabet: currently ignored by this method. The candidate
            alphabet is expected to be set separately before calling ``ocr``.
        :param root: directory that directly contains the model files and
            ``label_cn.txt`` (no version/model-name sub-folders are appended).
        :param context: currently ignored. The module is always built with
            ``AlOcr.CNOCR_CONTEXT``.
        :param name: name of this instance; give distinct names when several
            instances are initialized at the same time. ``''`` is treated as ``None``.
        """
        check_model_name(model_name)
        self._model_name = model_name
        self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX,
                                                 model_name)
        self._model_epoch = model_epoch

        self._model_dir = root  # Change folder structure.
        self._assert_and_prepare_model_files()
        self._alphabet, self._inv_alph_dict = read_charset(
            os.path.join(self._model_dir, 'label_cn.txt'))

        self._cand_alph_idx = None
        # Alphabet will be set before calling ocr.
        # self.set_cand_alphabet(cand_alphabet)

        self._hp = Hyperparams()
        self._hp._loss_type = None  # infer mode
        # Number of output classes follows the size of the loaded charset.
        self._hp._num_classes = len(self._alphabet)
        # An empty string name is normalized to None.
        self._net_prefix = None if name == '' else name

        self._mod = self._get_module(AlOcr.CNOCR_CONTEXT)
Example #6
0
    def __init__(
            self,
            model_name='densenet-lite-fc',
            model_epoch=None,
            cand_alphabet=None,
            root=data_dir(),
            context='cpu',
            name=None,
    ):
        """Load a pretrained OCR model and build the inference module.

        :param model_name: model name; validated with ``check_model_name``.
        :param model_epoch: checkpoint epoch. A falsy value (e.g. ``None``)
            selects ``AVAILABLE_MODELS[model_name][0]``.
        :param cand_alphabet: candidate characters restricting what can be
            recognized. ``None`` (default) means no restriction. Forwarded
            to ``set_cand_alphabet``.
        :param root: root directory of the model files; they are looked up in
            ``<root>/<MODEL_VERSION>/<model_name>`` (documented default root on
            Linux/Mac: ``~/.cnocr``).
        :param context: 'cpu' or 'gpu' -- whether prediction runs on CPU or GPU.
            Defaults to CPU.
        :param name: name of this instance; give distinct names when several
            instances are initialized at the same time. ``''`` is treated as ``None``.
        """
        check_model_name(model_name)
        self._model_name = model_name
        self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, model_name)
        if model_epoch:
            self._model_epoch = model_epoch
        else:
            self._model_epoch = AVAILABLE_MODELS[model_name][0]

        # Versioned layout: <root>/<MODEL_VERSION>/<model_name>.
        self._model_dir = os.path.join(root, MODEL_VERSION, self._model_name)
        self._assert_and_prepare_model_files()

        charset_path = os.path.join(self._model_dir, 'label_cn.txt')
        self._alphabet, self._inv_alph_dict = read_charset(charset_path)

        self._cand_alph_idx = None
        self.set_cand_alphabet(cand_alphabet)

        hyper = Hyperparams()
        hyper._loss_type = None  # None selects inference mode
        self._hp = hyper

        # An empty string name is normalized to None.
        self._net_prefix = name if name != '' else None

        self._mod = self._get_module(context)