コード例 #1
0
def train(model_name, index_dir, train_config_fp, resume_from_checkpoint,
          pretrained_model_fp):
    """Train the OCR model named ``model_name``.

    Builds the train/validation transforms, reads the JSON training
    configuration, creates the data module and the trainer, optionally loads
    pretrained parameters into the model, and finally starts fitting.

    :param model_name: name of the model to build; validated with
        ``check_model_name`` before anything else happens.
    :param index_dir: forwarded to ``OcrDataModule`` as ``index_dir``.
    :param train_config_fp: path of a JSON file with the training
        configuration. The keys read here are ``vocab_fp``, ``img_folder``,
        ``batch_size``, ``num_workers`` and ``pin_memory``; the whole dict is
        also handed to ``PlTrainer``.
    :param resume_from_checkpoint: forwarded unchanged to ``trainer.fit``.
    :param pretrained_model_fp: when not ``None``, passed to
        ``load_model_params`` together with the freshly built model before
        training starts.
    """
    check_model_name(model_name)
    # The commented-out entries are alternative augmentations kept as toggles.
    train_transform = T.Compose([
        RandomStretchAug(min_ratio=0.5, max_ratio=1.5),
        # RandomCrop((8, 10)),
        T.RandomInvert(p=0.2),
        T.RandomApply([T.RandomRotation(degrees=1)], p=0.4),
        # T.RandomAutocontrast(p=0.05),
        # T.RandomPosterize(bits=4, p=0.3),
        # T.RandomAdjustSharpness(sharpness_factor=0.5, p=0.3),
        # T.RandomEqualize(p=0.3),
        # T.RandomApply([T.GaussianBlur(kernel_size=3)], p=0.5),
        NormalizeAug(),
        # RandomPaddingAug(p=0.5, max_pad_len=72),
    ])
    # Validation only normalizes; no random augmentation is applied.
    val_transform = NormalizeAug()

    # Use a context manager so the config file handle is closed right after
    # parsing instead of being left for the garbage collector.
    with open(train_config_fp) as config_file:
        train_config = json.load(config_file)

    data_mod = OcrDataModule(
        index_dir=index_dir,
        vocab_fp=train_config['vocab_fp'],
        img_folder=train_config['img_folder'],
        train_transforms=train_transform,
        val_transforms=val_transform,
        batch_size=train_config['batch_size'],
        num_workers=train_config['num_workers'],
        pin_memory=train_config['pin_memory'],
    )

    # Debug aid: uncomment to dump a few transformed samples and exit early.
    # train_ds = data_mod.train
    # for i in range(min(100, len(train_ds))):
    #     visualize_example(train_transform(train_ds[i][0]), 'debugs/train-1-%d' % i)
    #     visualize_example(train_transform(train_ds[i][0]), 'debugs/train-2-%d' % i)
    #     visualize_example(train_transform(train_ds[i][0]), 'debugs/train-3-%d' % i)
    # val_ds = data_mod.val
    # for i in range(min(10, len(val_ds))):
    #     visualize_example(val_transform(val_ds[i][0]), 'debugs/val-1-%d' % i)
    #     visualize_example(val_transform(val_ds[i][0]), 'debugs/val-2-%d' % i)
    #     visualize_example(val_transform(val_ds[i][0]), 'debugs/val-3-%d' % i)
    # return

    trainer = PlTrainer(train_config,
                        ckpt_fn=['cnocr',
                                 'v%s' % MODEL_VERSION, model_name])
    model = gen_model(model_name, data_mod.vocab)
    logger.info(model)

    if pretrained_model_fp is not None:
        load_model_params(model, pretrained_model_fp)

    trainer.fit(model,
                datamodule=data_mod,
                resume_from_checkpoint=resume_from_checkpoint)
コード例 #2
0
    def __init__(self,
                 model_name='conv-lite-fc',
                 model_epoch=None,
                 cand_alphabet=None,
                 root=data_dir(),
                 gpus=0):
        """
        Resolve the model files, read the charset and build the module.

        :param model_name: name of the model; validated with `check_model_name`.
        :param model_epoch: training epoch of the model to load. A falsy value
            (such as the default `None`) selects `AVAILABLE_MODELS[model_name][0]`.
        :param cand_alphabet: candidate set of characters to recognize.
            The default `None` means the recognizable characters are not restricted.
        :param root: root directory holding the model files; they are looked
            up under `<root>/<MODEL_VERSION>/<model_name>`.
            Per the original documentation the Linux/Mac default is `~/.cnocr`
            (e.g. `~/.cnocr/1.1.0/conv-lite-fc-0027`); the Windows default is
            not stated there.
        :param gpus: number of GPUs. When greater than 0, one GPU context is
            created for each index in `range(gpus)`; otherwise a single CPU
            context is used.
        """
        check_model_name(model_name)
        self._model_name = model_name
        self._model_file_prefix = '{0}-{1}'.format(self.MODEL_FILE_PREFIX,
                                                   model_name)
        if not model_epoch:
            # No epoch requested: use the first entry registered for the model.
            model_epoch = AVAILABLE_MODELS[model_name][0]
        self._model_epoch = model_epoch

        versioned_root = os.path.join(root, MODEL_VERSION)
        self._model_dir = os.path.join(versioned_root, self._model_name)
        self._assert_and_prepare_model_files()
        charset_fp = os.path.join(self._model_dir, 'label_cn.txt')
        self._alphabet, char_to_idx = read_charset(charset_fp)

        self._cand_alph_idx = None
        if cand_alphabet is not None:
            # Index 0 is always kept alongside the indices of the candidates.
            allowed_indices = [0]
            allowed_indices.extend(char_to_idx[ch] for ch in cand_alphabet)
            allowed_indices.sort()
            self._cand_alph_idx = allowed_indices

        hyper_params = Hyperparams()
        hyper_params._loss_type = None  # infer mode
        self._hp = hyper_params

        # DCMMC: gpu context for mxnet
        self.context = (
            [mx.context.gpu(device_id) for device_id in range(gpus)]
            if gpus > 0 else [mx.context.cpu()]
        )
        self._mod = self._get_module()
コード例 #3
0
    def __init__(
            self,
            model_name='densenet-lite-fc',
            model_epoch=None,
            cand_alphabet=None,
            root=data_dir(),
            context='cpu',
            name=None,
    ):
        """
        Resolve the model files, read the charset and build the module.

        :param model_name: name of the model; validated with `check_model_name`.
        :param model_epoch: training epoch of the model to load. A falsy value
            (such as the default `None`) selects `AVAILABLE_MODELS[model_name][0]`.
        :param cand_alphabet: candidate set of characters to recognize.
            The default `None` means the recognizable characters are not restricted.
        :param root: root directory holding the model files; they are looked
            up under `<root>/<MODEL_VERSION>/<model_name>`.
            Per the original documentation the Linux/Mac default is `~/.cnocr`
            (e.g. `~/.cnocr/1.1.0/conv-lite-fc-0027`); the Windows default is
            not stated there.
        :param context: 'cpu' or 'gpu'; whether prediction runs on the CPU or
            the GPU. Defaults to the CPU.
        :param name: name of the instance being initialized. When several
            instances are created at the same time, each should be given a
            different name. An empty string is treated like `None`.
        """
        check_model_name(model_name)
        self._model_name = model_name
        self._model_file_prefix = '{0}-{1}'.format(self.MODEL_FILE_PREFIX,
                                                   model_name)
        if not model_epoch:
            # No epoch requested: use the first entry registered for the model.
            model_epoch = AVAILABLE_MODELS[model_name][0]
        self._model_epoch = model_epoch

        versioned_root = os.path.join(root, MODEL_VERSION)
        self._model_dir = os.path.join(versioned_root, self._model_name)
        self._assert_and_prepare_model_files()
        charset_fp = os.path.join(self._model_dir, 'label_cn.txt')
        self._alphabet, self._inv_alph_dict = read_charset(charset_fp)

        # Start unrestricted, then let the setter apply the candidates.
        self._cand_alph_idx = None
        self.set_cand_alphabet(cand_alphabet)

        hyper_params = Hyperparams()
        hyper_params._loss_type = None  # infer mode
        self._hp = hyper_params

        # An empty string is normalized to None as well.
        if name == '':
            name = None
        self._net_prefix = name

        self._mod = self._get_module(context)
コード例 #4
0
ファイル: cn_ocr.py プロジェクト: showme890/cnocr
    def __init__(
            self,
            model_name: str = 'densenet_lite_136-fc',
            *,
            cand_alphabet: Optional[Union[Collection, str]] = None,
            context: str = 'cpu',  # ['cpu', 'gpu', 'cuda']
            model_fp: Optional[str] = None,
            root: Union[str, Path] = data_dir(),
            **kwargs,
    ):
        """
        Initialize the recognition model.

        Args:
            model_name (str): name of the model. Defaults to `densenet_lite_136-fc`.
            cand_alphabet (Optional[Union[Collection, str]]): candidate set of
                characters to recognize. The default `None` means the
                recognizable characters are not restricted.
            context (str): 'cpu' or 'gpu'; whether prediction runs on the CPU
                or the GPU. Defaults to `cpu`. The value 'gpu' is stored and
                used as 'cuda'.
            model_fp (Optional[str]): to use a model other than the bundled
                ones, point this at the model file (a '.ckpt' file) directly.
            root (Union[str, Path]): root directory holding the model files.
                On Linux/Mac the default is `~/.cnocr`, so the model folder
                looks like `~/.cnocr/2.1/densenet_lite_136-fc`.
                On Windows the default is `C:/Users/<username>/AppData/Roaming/cnocr`.
            **kwargs: not used, except that passing `name` logs a
                deprecation warning.

        Examples:
            With default arguments:
            >>> ocr = CnOcr()

            With a specific model:
            >>> ocr = CnOcr(model_name='densenet_lite_136-fc')

            Recognizing digits only:
            >>> ocr = CnOcr(model_name='densenet_lite_136-fc', cand_alphabet='0123456789')

        """
        if 'name' in kwargs:
            deprecation_msg = (
                'param `name` is useless and deprecated since version %s'
                % MODEL_VERSION
            )
            logger.warning(deprecation_msg)
        check_model_name(model_name)
        check_context(context)
        self._model_name = model_name

        # 'gpu' is accepted as an alias and mapped to 'cuda'.
        device = 'cuda' if context == 'gpu' else context
        self.context = device

        self._model_file_prefix = '{0}-{1}'.format(self.MODEL_FILE_PREFIX,
                                                   model_name)
        # Models missing from AVAILABLE_MODELS yield None here.
        known_epoch = AVAILABLE_MODELS.get(model_name, [None])[0]
        if known_epoch is not None:
            self._model_file_prefix = '%s-epoch=%03d' % (
                self._model_file_prefix, known_epoch)

        self._assert_and_prepare_model_files(model_fp, root)
        charset = read_charset(VOCAB_FP)
        self._vocab, self._letter2id = charset

        # Start unrestricted, then let the setter apply the candidates.
        self._candidates = None
        self.set_cand_alphabet(cand_alphabet)

        self._model = self._get_model(device)
コード例 #5
0
ファイル: cn_ocr.py プロジェクト: showme890/cnocr
def gen_model(model_name, vocab):
    """Validate ``model_name`` and build the corresponding ``OcrModel``.

    :param model_name: name of the model; checked with ``check_model_name``
        before the model is created.
    :param vocab: forwarded unchanged to ``OcrModel.from_name``.
    :return: the object returned by ``OcrModel.from_name(model_name, vocab)``.
    """
    check_model_name(model_name)
    return OcrModel.from_name(model_name, vocab)