Example #1
def train_cnocr(args):
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    args.model_name = args.emb_model_type + '-' + args.seq_model_type
    out_dir = os.path.join(args.out_model_dir, args.model_name)
    logger.info('save models to dir: %s' % out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    args.prefix = os.path.join(
        out_dir, 'cnocr-v{}-{}'.format(__version__, args.model_name))

    hp = CnHyperparams()
    hp = _update_hp(hp, args)

    network, hp = gen_network(args.model_name, hp)
    metrics = CtcMetrics(hp.seq_length)

    data_train, data_val = _gen_iters(hp, args.train_file, args.test_file,
                                      args.use_train_image_aug, args.dataset,
                                      args.charset, args.debug)
    data_names = ['data']
    fit(
        network=network,
        data_train=data_train,
        data_val=data_val,
        metrics=metrics,
        args=args,
        hp=hp,
        data_names=data_names,
    )
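
A minimal sketch of driving train_cnocr() directly, without the project's CLI. The attribute names mirror the ones the function above reads from args; the concrete values (model types, file paths) are placeholders, not taken from the snippet.

from argparse import Namespace

args = Namespace(
    emb_model_type='conv',      # placeholder model names, not from the snippet
    seq_model_type='lstm',
    out_model_dir='./models',
    train_file='train.tsv',     # placeholder file paths
    test_file='test.tsv',
    use_train_image_aug=True,
    dataset='cn_ocr',
    charset=None,
    debug=False,
)
# train_cnocr(args)  # fit() reads additional training options from this same args object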
Example #2
    def _gen_line_pred_chars(self, line_prob, img_width, max_img_width):
        """
        Get the predicted characters.
        :param line_prob: class probabilities with shape [seq_length, num_classes]
        :param img_width: width of this (unpadded) image
        :param max_img_width: width of the widest image in the batch
        :return: list of predicted characters
        """
        class_ids = np.argmax(line_prob, axis=-1)

        # Zero out low-confidence positions so they are treated as the CTC blank
        class_ids *= np.max(line_prob, axis=-1) > 0.5

        if img_width < max_img_width:
            comp_ratio = self._hp.seq_len_cmpr_ratio
            end_idx = img_width // comp_ratio
            if end_idx < len(class_ids):
                class_ids[end_idx:] = 0
        prediction, start_end_idx = CtcMetrics.ctc_label(class_ids.tolist())
        alphabet = self._alphabet
        res = [
            alphabet[p] if alphabet[p] != '<space>' else ' '
            for p in prediction
        ]

        return res
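
A toy, self-contained illustration of the confidence-thresholding step used above: positions whose top probability is at or below 0.5 are forced to class 0, the label that the CTC collapse later drops.

import numpy as np

line_prob = np.array([
    [0.10, 0.80, 0.10],   # confident -> keeps class 1
    [0.40, 0.35, 0.25],   # max prob 0.4 <= 0.5 -> forced to class 0
    [0.05, 0.05, 0.90],   # confident -> keeps class 2
])
class_ids = np.argmax(line_prob, axis=-1)
class_ids *= np.max(line_prob, axis=-1) > 0.5
print(class_ids)  # [1 0 2]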
Example #3
    def _gen_line_pred_chars(self, line_prob, img_width, max_img_width):
        """
        Get the predicted characters.
        :param line_prob: class probabilities with shape [seq_length, num_classes]
        :param img_width: width of this (unpadded) image
        :param max_img_width: width of the widest image in the batch
        :return: list of predicted characters
        """
        # Keep the argmax class only where the top probability is above 0.8;
        # low-confidence positions fall back to a fixed class id (6425 here).
        drop = [1 if l.max() > 0.8 else 0 for l in line_prob]
        class_ids_ = np.argmax(line_prob, axis=-1)
        class_ids = []
        for c, d in zip(class_ids_, drop):
            class_ids += [c] if d == 1 else [6425]

        if img_width < max_img_width:
            comp_ratio = self._hp.seq_len_cmpr_ratio
            end_idx = img_width // comp_ratio
            if end_idx < len(class_ids):
                # class_ids is a plain list here, so assign a list rather than a scalar
                class_ids[end_idx:] = [0] * (len(class_ids) - end_idx)
        #prediction, start_end_idx = CtcMetrics.ctc_label(class_ids.tolist())
        prediction, start_end_idx = CtcMetrics.ctc_label(class_ids)
        alphabet = self._alphabet
        res = [
            alphabet[p] if alphabet[p] != '<space>' else ' '
            for p in prediction
        ]

        return res
Example #4
File: cn_ocr.py Project: templeblock/cnocr
    def ocr_for_single_line(self, img_fp):
        """
        Recognize characters from an image that contains a single line of text
        :param img_fp: image file path; or gray image mx.nd.NDArray; or gray image np.ndarray
        :return: character list, such as ['你', '好']
        """
        hp = deepcopy(self._hp)
        if isinstance(img_fp, str) and os.path.isfile(img_fp):
            img = read_ocr_img(img_fp)
        elif isinstance(img_fp, mx.nd.NDArray) or isinstance(
                img_fp, np.ndarray):
            img = img_fp
        else:
            raise TypeError('Inappropriate argument type.')
        img = rescale_img(img, hp)

        init_state_names, init_state_arrays = lstm_init_states(batch_size=1,
                                                               hp=hp)

        sample = SimpleBatch(data_names=['data'] + init_state_names,
                             data=[mx.nd.array([img])] + init_state_arrays)

        mod = self._get_module(hp, sample)

        mod.forward(sample)
        prob = mod.get_outputs()[0].asnumpy()

        prediction, start_end_idx = CtcMetrics.ctc_label(
            np.argmax(prob, axis=-1).tolist())
        # print(start_end_idx)

        alphabet = self._alphabet
        res = [alphabet[p] for p in prediction]
        return res
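
A minimal usage sketch, assuming the method above is exposed through cnocr's CnOcr class and that the pretrained model files are already available locally; the image path is a placeholder.

from cnocr import CnOcr

ocr = CnOcr()                                          # loads the default pretrained model
chars = ocr.ocr_for_single_line('examples/line.png')   # placeholder path to a one-line image
print(''.join(chars))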
Example #5
def run_captcha(args):
    hp = Hyperparams2()

    network = crnn_lstm(hp)
    # arg_shape, out_shape, aux_shape = network.infer_shape(data=(128, 1, 32, 100), label=(128, 10),
    #                                                       l0_init_h=(128, 100), l1_init_h=(128, 100), l2_init_h=(128, 100), l3_init_h=(128, 100))
    # print(dict(zip(network.list_arguments(), arg_shape)))
    # import pdb; pdb.set_trace()

    # Start a multi-process captcha image generator
    mp_captcha = MPDigitCaptcha(font_paths=get_fonts(args.font_path),
                                h=hp.img_width,
                                w=hp.img_height,
                                num_digit_min=3,
                                num_digit_max=4,
                                num_processes=args.num_proc,
                                max_queue_size=hp.batch_size * 2)
    mp_captcha.start()
    # img, num = mp_captcha.get()
    # print(img.shape)
    # import numpy as np
    # import cv2
    # img = np.transpose(img, (1, 0))
    # cv2.imwrite('captcha1.png', img * 255)
    # import pdb; pdb.set_trace()

    init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden))
              for l in range(hp.num_lstm_layer * 2)]
    init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden))
              for l in range(hp.num_lstm_layer * 2)]
    init_states = init_c + init_h
    data_names = ['data'] + [x[0] for x in init_states]

    data_train = OCRIter(hp.train_epoch_size // hp.batch_size,
                         hp.batch_size,
                         init_states,
                         captcha=mp_captcha,
                         num_label=hp.num_label,
                         name='train')
    data_val = OCRIter(hp.eval_epoch_size // hp.batch_size,
                       hp.batch_size,
                       init_states,
                       captcha=mp_captcha,
                       num_label=hp.num_label,
                       name='val')

    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    metrics = CtcMetrics(hp.seq_length)

    fit(network=network,
        data_train=data_train,
        data_val=data_val,
        metrics=metrics,
        args=args,
        hp=hp,
        data_names=data_names)

    mp_captcha.reset()
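
For reference, a small standalone illustration of the init-state naming scheme built above, assuming a hypothetical configuration with 2 (bidirectional) LSTM layers, batch_size=4 and num_hidden=100.

num_lstm_layer, batch_size, num_hidden = 2, 4, 100   # hypothetical values
init_c = [('l%d_init_c' % l, (batch_size, num_hidden)) for l in range(num_lstm_layer * 2)]
init_h = [('l%d_init_h' % l, (batch_size, num_hidden)) for l in range(num_lstm_layer * 2)]
print([name for name, _ in init_c + init_h])
# ['l0_init_c', 'l1_init_c', 'l2_init_c', 'l3_init_c', 'l0_init_h', 'l1_init_h', 'l2_init_h', 'l3_init_h']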
Example #6
    def _gen_line_pred_chars(self, line_prob, img_width, max_img_width):
        """
        Get the predicted characters.
        :param line_prob: class probabilities with shape [seq_length, num_classes]
        :param img_width: width of this (unpadded) image
        :param max_img_width: width of the widest image in the batch
        :return: list of predicted characters
        """
        # DCMMC: Greedy decoder for CTC
        class_ids = np.argmax(line_prob, axis=-1)

        if img_width < max_img_width:
            comp_ratio = self._hp.seq_len_cmpr_ratio
            end_idx = img_width // comp_ratio
            # DCMMC: the original images are right-padded...
            # while my dataset uses left and right padding
            if end_idx < len(class_ids):
                class_ids[end_idx:] = 0
        prediction, start_end_idx = CtcMetrics.ctc_label(class_ids.tolist())
        alphabet = self._alphabet
        res = [
            alphabet[p] if alphabet[p] != '<space>' else ' '
            for p in prediction
        ]

        return res
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset",
                        help="use which kind of dataset, captcha or cn_ocr",
                        choices=['captcha', 'cn_ocr'],
                        type=str,
                        default='captcha')
    parser.add_argument("--file", help="Path to the CAPTCHA image file")
    parser.add_argument("--prefix",
                        help="Checkpoint prefix [Default './models/model']",
                        default='./models/model')
    parser.add_argument("--epoch",
                        help="Checkpoint epoch [Default 20]",
                        type=int,
                        default=20)
    parser.add_argument('--charset_file', type=str, help='File that stores the mapping from each character to its id.')
    args = parser.parse_args()
    if args.dataset == 'cn_ocr':
        hp = Hyperparams()
        img = read_ocr_img(args.file, hp)
    else:
        hp = Hyperparams2()
        img = read_captcha_img(args.file, hp)

    # init_state_names, init_state_arrays = lstm_init_states(batch_size=1, hp=hp)
    # import pdb; pdb.set_trace()

    sample = SimpleBatch(data_names=['data'], data=[mx.nd.array([img])])

    network = crnn_lstm(hp)
    mod = load_module(args.prefix,
                      args.epoch,
                      sample.data_names,
                      sample.provide_data,
                      network=network)

    mod.forward(sample)
    prob = mod.get_outputs()[0].asnumpy()

    prediction, start_end_idx = CtcMetrics.ctc_label(
        np.argmax(prob, axis=-1).tolist())

    if args.charset_file:
        alphabet, _ = read_charset(args.charset_file)
        res = [alphabet[p] for p in prediction]
        print("Predicted Chars:", res)
    else:
        # Predictions are 1 to 10 for digits 0 to 9 respectively (prediction 0 means no-digit)
        prediction = [p - 1 for p in prediction]
        print("Digits:", prediction)
    return
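
Hypothetical command lines for the script above; the script name predict.py and the file names are placeholders, only the flags come from the parser shown here.

python predict.py --dataset captcha --file captcha1.png --prefix ./models/model --epoch 20
python predict.py --dataset cn_ocr --file line.png --charset_file label_cn.txt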
Example #8
def run_cn_ocr(args):
    hp = Hyperparams()

    network = crnn_lstm(hp)

    mp_data_train = MPOcrImages(args.data_root, args.train_file, (hp.img_width, hp.img_height), hp.num_label,
                                num_processes=args.num_proc, max_queue_size=hp.batch_size * 100)
    # img, num = mp_data_train.get()
    # print(img.shape)
    # print(mp_data_train.shape)
    # import pdb; pdb.set_trace()
    # import numpy as np
    # import cv2
    # img = np.transpose(img, (1, 0))
    # cv2.imwrite('captcha1.png', img * 255)
    # import pdb; pdb.set_trace()
    mp_data_test = MPOcrImages(args.data_root, args.test_file, (hp.img_width, hp.img_height), hp.num_label,
                               num_processes=max(args.num_proc // 2, 1), max_queue_size=hp.batch_size * 10)
    mp_data_train.start()
    mp_data_test.start()

    # init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
    # init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
    # init_states = init_c + init_h
    # data_names = ['data'] + [x[0] for x in init_states]
    data_names = ['data']

    data_train = OCRIter(
        hp.train_epoch_size // hp.batch_size, hp.batch_size, captcha=mp_data_train, num_label=hp.num_label,
        name='train')
    data_val = OCRIter(
        hp.eval_epoch_size // hp.batch_size, hp.batch_size, captcha=mp_data_test, num_label=hp.num_label,
        name='val')
    # data_train = ImageIterLstm(
    #     args.data_root, args.train_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="train")
    # data_val = ImageIterLstm(
    #     args.data_root, args.test_file,  hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="val")

    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    metrics = CtcMetrics(hp.seq_length)

    fit(network=network,
        data_train=data_train,
        data_val=data_val,
        metrics=metrics,
        args=args,
        hp=hp,
        data_names=data_names)

    mp_data_train.reset()
    mp_data_test.reset()
Example #9
    def _gen_line_pred_chars(self, line_prob, img_width, max_img_width):
        """
        Get the predicted characters.
        :param line_prob: class probabilities with shape [seq_length, num_classes]
        :param img_width: width of this (unpadded) image
        :param max_img_width: width of the widest image in the batch
        :return: list of predicted characters
        """
        class_ids = np.argmax(line_prob, axis=-1)
        # idxs = list(zip(range(len(class_ids)), class_ids))
        # probs = [line_prob[e[0], e[1]] for e in idxs]

        if img_width < max_img_width:
            comp_ratio = self._hp.seq_len_cmpr_ratio
            end_idx = img_width // comp_ratio
            if end_idx < len(class_ids):
                class_ids[end_idx:] = 0
        prediction, start_end_idx = CtcMetrics.ctc_label(class_ids.tolist())
        # print(start_end_idx)
        alphabet = self._alphabet
        res = [alphabet[p] for p in prediction]

        # res = self._insert_space_char(res, start_end_idx)
        return res
Example #10
def test_ctc_metrics(input, expected):
    input = list(map(int, list(input)))
    expected = list(map(int, list(expected)))
    p, _ = CtcMetrics.ctc_label(input)
    assert expected == p
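
The parametrization is not shown in the snippet. A minimal sketch of how the cases could be supplied, assuming the usual CTC collapse rule (merge consecutive repeats, then drop the blank label 0):

import pytest

@pytest.mark.parametrize(
    "input, expected",
    [
        ("0010220", "12"),      # blanks (0) are dropped
        ("112002211", "1221"),  # repeats merge; the blank keeps the two 2s separate
    ],
)
def test_ctc_metrics(input, expected):
    input = list(map(int, list(input)))
    expected = list(map(int, list(expected)))
    p, _ = CtcMetrics.ctc_label(input)
    assert expected == p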