Exemplo n.º 1
0
def parse_args():
    """Parse the command-line arguments of the training script.

    Help texts that show a default use argparse's ``%(default)s`` placeholder
    so they cannot drift from the actual ``default=`` value (the ``--cpu``
    help previously claimed 8 while the default was 2).

    NOTE(review): the ``--data_root``, ``--train_file`` and ``--test_file``
    defaults are absolute paths from one developer's machine; pass explicit
    values anywhere else.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser()
    default_model_prefix = os.path.join(data_dir(), 'models', 'model-v{}'.format(__version__))

    parser.add_argument("--dataset",
                        help="use which kind of dataset, captcha or cn_ocr",
                        choices=['captcha', 'cn_ocr'],
                        type=str, default='captcha')
    parser.add_argument("--data_root", help="Path to image files", type=str,
                        default='/Users/king/Documents/WhatIHaveDone/Test/text_renderer/output/wechat_simulator')
    parser.add_argument("--train_file", help="Path to train txt file", type=str,
                        default='/Users/king/Documents/WhatIHaveDone/Test/text_renderer/output/wechat_simulator/train.txt')
    parser.add_argument("--test_file", help="Path to test txt file", type=str,
                        default='/Users/king/Documents/WhatIHaveDone/Test/text_renderer/output/wechat_simulator/test.txt')
    parser.add_argument("--cpu",
                        help="Number of CPUs for training [Default %(default)s]. Ignored if --gpu is specified.",
                        type=int, default=2)
    # No default: --gpu stays None when omitted (per the --cpu help, CPUs are then used).
    parser.add_argument("--gpu", help="Number of GPUs for training [Default: unset, train on CPUs]", type=int)
    parser.add_argument('--load_epoch', type=int,
                        help='load the model on an epoch using the model-load-prefix [Default: no trained model will be loaded]')
    # argparse %-formats help strings, so any literal '%' in the path must be escaped.
    parser.add_argument("--prefix",
                        help="Checkpoint prefix [Default '{}']".format(default_model_prefix.replace('%', '%%')),
                        default=default_model_prefix)
    parser.add_argument("--loss", help="'ctc' or 'warpctc' loss [Default 'ctc']", default='ctc')
    parser.add_argument("--num_proc", help="Number of CAPTCHA generating processes [Default %(default)s]",
                        type=int, default=4)
    parser.add_argument("--font_path", help="Path to ttf font file or directory containing ttf files")
    return parser.parse_args()
Exemplo n.º 2
0
    def __init__(self, root=data_dir(), model_epoch=MODEL_EPOCE):
        """Set up the recognizer from the model files found under ``root``.

        Args:
            root: base directory; the model directory is ``<root>/models``.
                The default value is evaluated once, when the class body runs.
            model_epoch: checkpoint epoch to use, stored as
                ``self._model_epoch`` (defaults to the module constant
                ``MODEL_EPOCE``).
        """
        self._model_dir = os.path.join(root, 'models')
        self._model_epoch = model_epoch
        # Runs after _model_dir is assigned; the helper may depend on it.
        self._assert_and_prepare_model_files(root)

        charset_path = os.path.join(self._model_dir, 'label_cn.txt')
        alphabet, _inverse_lookup = read_charset(charset_path)
        self._alphabet = alphabet

        self._hp = Hyperparams()
        self._mods = {}
Exemplo n.º 3
0
    def __init__(self,
                 model_name='conv-lite-fc',
                 model_epoch=None,
                 cand_alphabet=None,
                 root=data_dir(),
                 gpus=0):
        """Load a recognition model and prepare it for inference.

        :param model_name: model name; validated by ``check_model_name``.
        :param model_epoch: model iteration (epoch) number. When falsy, the
            first entry of ``AVAILABLE_MODELS[model_name]`` is used instead.
        :param cand_alphabet: candidate set of characters to recognize.
            The default ``None`` means recognition is not restricted.
        :param root: root directory of the model files. On Linux/Mac the
            default is ``~/.cnocr``, so model files live in a folder like
            ``~/.cnocr/1.1.0/conv-lite-fc-0027``. The Windows default is
            left unspecified by the original documentation.
        :param gpus: number of GPUs; a value ``> 0`` creates one MXNet GPU
            context per device index, otherwise a single CPU context is used.
        """
        check_model_name(model_name)
        self._model_name = model_name
        self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, model_name)
        if model_epoch:
            self._model_epoch = model_epoch
        else:
            self._model_epoch = AVAILABLE_MODELS[model_name][0]

        # Model files are looked up under <root>/<MODEL_VERSION>/<model_name>.
        versioned_root = os.path.join(root, MODEL_VERSION)
        self._model_dir = os.path.join(versioned_root, self._model_name)
        self._assert_and_prepare_model_files()
        charset_file = os.path.join(self._model_dir, 'label_cn.txt')
        self._alphabet, char_to_idx = read_charset(charset_file)

        if cand_alphabet is None:
            self._cand_alph_idx = None
        else:
            # Index 0 is always included alongside the requested characters.
            self._cand_alph_idx = sorted(
                [0] + [char_to_idx[ch] for ch in cand_alphabet])

        self._hp = Hyperparams()
        self._hp._loss_type = None  # infer mode

        # DCMMC: MXNet contexts, one per GPU device or a single CPU context.
        self.context = ([mx.context.gpu(dev) for dev in range(gpus)]
                        if gpus > 0 else [mx.context.cpu()])
        self._mod = self._get_module()
Exemplo n.º 4
0
    def __init__(
            self,
            model_name='densenet-lite-fc',
            model_epoch=None,
            cand_alphabet=None,
            root=data_dir(),
            context='cpu',
            name=None,
    ):
        """Load a recognition model and prepare it for inference.

        :param model_name: model name; validated by ``check_model_name``.
        :param model_epoch: model iteration (epoch) number. When falsy, the
            first entry of ``AVAILABLE_MODELS[model_name]`` is used instead.
        :param cand_alphabet: candidate set of characters to recognize.
            The default ``None`` means recognition is not restricted.
        :param root: root directory of the model files. On Linux/Mac the
            default is ``~/.cnocr``, so model files live in a folder like
            ``~/.cnocr/1.1.0/conv-lite-fc-0027``. The Windows default is
            left unspecified by the original documentation.
        :param context: 'cpu' or 'gpu'; whether prediction runs on the CPU
            or the GPU. Defaults to the CPU.
        :param name: name of the instance being initialized. When several
            instances are created at the same time, give each a different
            name. An empty string is treated like ``None``.
        """
        check_model_name(model_name)
        self._model_name = model_name
        self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, model_name)
        self._model_epoch = (model_epoch if model_epoch
                             else AVAILABLE_MODELS[model_name][0])

        # Model files are looked up under <root>/<MODEL_VERSION>/<model_name>.
        self._model_dir = os.path.join(root, MODEL_VERSION, self._model_name)
        self._assert_and_prepare_model_files()
        charset_file = os.path.join(self._model_dir, 'label_cn.txt')
        self._alphabet, self._inv_alph_dict = read_charset(charset_file)

        self._cand_alph_idx = None
        self.set_cand_alphabet(cand_alphabet)

        hyperparams = Hyperparams()
        hyperparams._loss_type = None  # infer mode
        self._hp = hyperparams

        # An empty name is normalized to None.
        self._net_prefix = name if name != '' else None

        self._mod = self._get_module(context)
Exemplo n.º 5
0
def main():
    """Print the character stored at each index listed in ``BAD_CHARS``.

    Loads the charset ``<data_dir()>/models/label_cn.txt`` via
    ``read_charset`` and prints one ``idx: ..., char: ...`` line per index.
    """
    charset_fp = os.path.join(data_dir(), 'models', 'label_cn.txt')
    # read_charset also returns a char -> index mapping, which is not needed here.
    alphabet, _ = read_charset(charset_fp)
    for idx in BAD_CHARS:
        print('idx: {}, char: {}'.format(idx, alphabet[idx]))
Exemplo n.º 6
0
def parse_args():
    """Parse the command-line arguments of the training script.

    ``--dataset`` and ``--charset`` are required; all other options have
    defaults. Help texts that mention a default use argparse's
    ``%(default)s`` placeholder so they stay in sync with ``default=``.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--emb_model_type",
        help="which embedding model to use",
        choices=EMB_MODEL_TYPES,
        type=str,
        default='conv-lite',
    )
    parser.add_argument(
        "--seq_model_type",
        help='which sequence model to use',
        default='fc',
        type=str,
        choices=SEQ_MODEL_TYPES,
    )
    parser.add_argument(
        "--train_file",
        help="Path to train txt file",
        type=str,
        default='data/sample-data-lst/train.txt',
    )
    parser.add_argument(
        "--test_file",
        help="Path to test txt file",
        type=str,
        default='data/sample-data-lst/test.txt',
    )
    parser.add_argument('--dataset',
                        help='file path for dataset hdf5',
                        type=str,
                        required=True)
    parser.add_argument('--charset',
                        help='file path for the character set of labels',
                        type=str,
                        required=True)
    parser.add_argument(
        '--debug',
        help='debug mode',
        action='store_true',
    )
    parser.add_argument(
        "--use_train_image_aug",
        action='store_true',
        help="Whether to use image augmentation for training",
    )
    parser.add_argument(
        "--gpu",
        help="Number of GPUs for training [Default 0, means using cpu]",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--optimizer",
        help="optimizer for training [Default: Adam]",
        type=str,
        default='Adam',
    )
    parser.add_argument('--epoch',
                        type=int,
                        default=20,
                        help='train epochs [Default: 20]')
    parser.add_argument(
        '--load_epoch',
        type=int,
        help=
        'load the model on an epoch using the model-load-prefix [Default: no trained model will be loaded]',
    )
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate [Default: %(default)s]')
    parser.add_argument('--wd',
                        type=float,
                        default=0.0,
                        help='weight decay factor [Default: 0.0]')
    parser.add_argument(
        '--clip_gradient',
        type=float,
        default=None,
        help=
        'value for clip gradient [Default: None, means no gradient clipping]',
    )
    parser.add_argument(
        "--out_model_dir",
        help='output model directory',
        default=os.path.join(data_dir(), __version__),
    )
    return parser.parse_args()
Exemplo n.º 7
0
    def __init__(
            self,
            model_name: str = 'densenet_lite_136-fc',
            *,
            cand_alphabet: Optional[Union[Collection, str]] = None,
            context: str = 'cpu',  # ['cpu', 'gpu', 'cuda']
            model_fp: Optional[str] = None,
            root: Union[str, Path] = data_dir(),
            **kwargs,
    ):
        """
        Initialize the recognition model.

        Args:
            model_name (str): model name. Defaults to `densenet_lite_136-fc`.
            cand_alphabet (Optional[Union[Collection, str]]): candidate set of
                characters to recognize. The default `None` means recognition
                is not restricted.
            context (str): 'cpu' or 'gpu' (or 'cuda'); whether prediction runs
                on the CPU or the GPU. 'gpu' is normalized to 'cuda'.
                Defaults to `cpu`.
            model_fp (Optional[str]): when not using the bundled model, path of
                the model file ('.ckpt' file) to use instead.
            root (Union[str, Path]): root directory of the model files.
                On Linux/Mac the default is `~/.cnocr`, so model files live in
                a folder like `~/.cnocr/2.1/densenet_lite_136-fc`.
                On Windows the default is `C:/Users/<username>/AppData/Roaming/cnocr`.
            **kwargs: currently unused. The deprecated `name` argument and any
                other unrecognized keyword are ignored with a logged warning.

        Examples:
            With default parameters:
            >>> ocr = CnOcr()

            With a specific model:
            >>> ocr = CnOcr(model_name='densenet_lite_136-fc')

            Only consider digits during recognition:
            >>> ocr = CnOcr(model_name='densenet_lite_136-fc', cand_alphabet='0123456789')

        """
        if 'name' in kwargs:
            # Lazy %-args: the message is only formatted if actually emitted.
            logger.warning(
                'param `name` is useless and deprecated since version %s',
                MODEL_VERSION)
        unknown_kwargs = sorted(set(kwargs) - {'name'})
        if unknown_kwargs:
            # Surface misspelled or no-longer-supported parameters instead of
            # dropping them silently; they are still ignored.
            logger.warning('ignoring unrecognized keyword argument(s): %s',
                           ', '.join(unknown_kwargs))
        check_model_name(model_name)
        check_context(context)
        self._model_name = model_name
        if context == 'gpu':
            context = 'cuda'
        self.context = context

        self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX,
                                                 model_name)
        # Models without a registered epoch keep the bare prefix.
        model_epoch = AVAILABLE_MODELS.get(model_name, [None])[0]

        if model_epoch is not None:
            self._model_file_prefix = '%s-epoch=%03d' % (
                self._model_file_prefix,
                model_epoch,
            )

        self._assert_and_prepare_model_files(model_fp, root)
        self._vocab, self._letter2id = read_charset(VOCAB_FP)

        self._candidates = None
        self.set_cand_alphabet(cand_alphabet)

        self._model = self._get_model(context)