Example #1
0
File: cn_ocr.py  Project: xinqiyang/cnocr
def load_module(prefix,
                epoch,
                data_names,
                data_shapes,
                network=None,
                context='cpu'):
    """
    Load a checkpoint into a ready-to-predict ``mx.mod.Module``.

    Restores the symbol and parameters saved under *prefix*/*epoch*
    (optionally substituting *network* for the stored symbol), replaces the
    training-time CTC loss head with a plain softmax, binds the module for
    inference on the requested *context*, and returns it.
    """
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    if network is not None:
        # Caller supplied an explicit symbol — prefer it over the stored one.
        sym = network

    # Prediction needs no CTC loss: take the layer just below the loss
    # ('pred_fc') and cap it with a softmax instead.
    sym = mx.sym.softmax(data=sym.get_internals()['pred_fc_output'])

    if not check_context(context):
        raise NotImplementedError('illegal value %s for parameter context' %
                                  context)
    if isinstance(context, str):
        # Translate the string form into an actual MXNet device handle.
        context = mx.cpu() if context.lower() != 'gpu' else mx.gpu()

    module = mx.mod.Module(symbol=sym,
                           context=context,
                           data_names=data_names,
                           label_names=None)
    module.bind(for_training=False, data_shapes=data_shapes)
    module.set_params(arg_params, aux_params, allow_missing=False)
    return module
Example #2
0
def test_check_context(context, expected):
    """Verify that check_context(context) yields the expected validity flag."""
    actual = check_context(context)
    assert actual == expected
Example #3
0
File: cn_ocr.py  Project: showme890/cnocr
    def __init__(
            self,
            model_name: str = 'densenet_lite_136-fc',
            *,
            cand_alphabet: Optional[Union[Collection, str]] = None,
            context: str = 'cpu',  # ['cpu', 'gpu', 'cuda']
            model_fp: Optional[str] = None,
            root: Union[str, Path] = data_dir(),
            **kwargs,
    ):
        """
        Initialize the recognition model.

        Args:
            model_name (str): Model name. Defaults to `densenet_lite_136-fc`.
            cand_alphabet (Optional[Union[Collection, str]]): Candidate set of
                characters to recognize. Defaults to `None`, meaning the
                recognizable characters are not restricted.
            context (str): 'cpu' or 'gpu'. Whether prediction runs on CPU or
                GPU. Defaults to `cpu`.
            model_fp (Optional[str]): If you do not want to use the bundled
                models, use this parameter to point directly at a model file
                (a '.ckpt' file).
            root (Union[str, Path]): Root directory holding the model files.
                On Linux/Mac the default is `~/.cnocr`, so a model folder
                looks like `~/.cnocr/2.1/densenet_lite_136-fc`.
                On Windows the default is
                `C:/Users/<username>/AppData/Roaming/cnocr`.
            **kwargs: Currently unused.

        Examples:
            With default parameters:
            >>> ocr = CnOcr()

            With a specific model:
            >>> ocr = CnOcr(model_name='densenet_lite_136-fc')

            Recognize digits only:
            >>> ocr = CnOcr(model_name='densenet_lite_136-fc', cand_alphabet='0123456789')

        """
        # `name` was removed from the public interface; warn instead of failing.
        if 'name' in kwargs:
            logger.warning(
                'param `name` is useless and deprecated since version %s' %
                MODEL_VERSION)
        # Validate inputs up front so bad arguments fail before any file I/O.
        check_model_name(model_name)
        check_context(context)
        self._model_name = model_name
        # NOTE(review): 'gpu' is mapped to 'cuda' — presumably the backend
        # expects PyTorch-style device strings; confirm against _get_model.
        if context == 'gpu':
            context = 'cuda'
        self.context = context

        self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX,
                                                 model_name)
        # AVAILABLE_MODELS maps model name -> sequence whose first item is the
        # checkpoint epoch; unknown names yield None (no epoch suffix).
        model_epoch = AVAILABLE_MODELS.get(model_name, [None])[0]

        if model_epoch is not None:
            # Append a zero-padded epoch suffix matching the checkpoint
            # file-naming scheme (e.g. '...-epoch=039').
            self._model_file_prefix = '%s-epoch=%03d' % (
                self._model_file_prefix,
                model_epoch,
            )

        # Ensure the model file exists locally (downloading if needed),
        # then load the character vocabulary and its char -> id mapping.
        self._assert_and_prepare_model_files(model_fp, root)
        self._vocab, self._letter2id = read_charset(VOCAB_FP)

        # Candidate-alphabet restriction must be set before the model is
        # created so predictions can be limited from the start.
        self._candidates = None
        self.set_cand_alphabet(cand_alphabet)

        self._model = self._get_model(context)