def load_module(prefix, epoch, data_names, data_shapes, network=None, context='cpu'):
    """
    Loads the model from checkpoint specified by prefix and epoch, binds it to an executor,
    and sets its parameters and returns a mx.mod.Module

    Args:
        prefix: checkpoint path prefix, passed to ``mx.model.load_checkpoint``.
        epoch: checkpoint epoch, passed to ``mx.model.load_checkpoint``.
        data_names: names of the module's data inputs.
        data_shapes: data shapes used when binding the module.
        network: optional symbol that replaces the one stored in the checkpoint.
            It must expose an internal output named ``pred_fc_output``.
        context: a string ('cpu', 'gpu' or 'cuda', case-insensitive) or any other
            value accepted by ``check_context``; non-string values are passed to
            ``mx.mod.Module`` unchanged.

    Returns:
        An ``mx.mod.Module`` bound for inference (``for_training=False``) with the
        checkpoint's arg/aux parameters set.

    Raises:
        NotImplementedError: if ``check_context`` rejects ``context``.
    """
    # Validate the context before the (potentially expensive) checkpoint load,
    # so a bad argument fails fast.
    if not check_context(context):
        raise NotImplementedError('illegal value %s for parameter context' % context)
    if isinstance(context, str):
        # 'cuda' is accepted as an alias of 'gpu'; without this it would
        # silently bind to the CPU.
        context = mx.gpu() if context.lower() in ('gpu', 'cuda') else mx.cpu()

    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    if network is not None:
        sym = network

    # We don't need CTC loss for prediction, just a simple softmax will suffice.
    # We get the output of the layer just before the loss layer ('pred_fc') and add softmax on top
    pred_fc = sym.get_internals()['pred_fc_output']
    sym = mx.sym.softmax(data=pred_fc)

    mod = mx.mod.Module(symbol=sym, context=context, data_names=data_names, label_names=None)
    mod.bind(for_training=False, data_shapes=data_shapes)
    # Every parameter of the inference graph must come from the checkpoint.
    mod.set_params(arg_params, aux_params, allow_missing=False)
    return mod
def test_check_context(context, expected):
    """Verify that ``check_context`` returns *expected* for the given *context*."""
    verdict = check_context(context)
    assert verdict == expected
def __init__(
    self,
    model_name: str = 'densenet_lite_136-fc',
    *,
    cand_alphabet: Optional[Union[Collection, str]] = None,
    context: str = 'cpu',  # ['cpu', 'gpu', 'cuda']
    model_fp: Optional[str] = None,
    # NOTE: data_dir() is evaluated once, when this function is defined.
    root: Union[str, Path] = data_dir(),
    **kwargs,
):
    """
    Initialize the recognition model.

    Args:
        model_name (str): model name. Defaults to `densenet_lite_136-fc`.
        cand_alphabet (Optional[Union[Collection, str]]): candidate set of characters
            to recognize. Defaults to `None`, meaning the recognized characters are
            not restricted.
        context (str): 'cpu', 'gpu' or 'cuda'; whether prediction runs on CPU or GPU.
            'gpu' is stored as 'cuda'. Defaults to `cpu`.
        model_fp (Optional[str]): to use a model other than the bundled one, pass
            the model file ('.ckpt' file) directly via this parameter.
        root (Union[str, Path]): root directory holding the model files.
            On Linux/Mac the default is `~/.cnocr`, i.e. model files live in a
            folder like `~/.cnocr/2.1/densenet_lite_136-fc`.
            On Windows the default is `C:/Users/<username>/AppData/Roaming/cnocr`.
        **kwargs: currently unused (a deprecated `name` key only triggers a warning).

    Raises:
        NotImplementedError: if `check_context` rejects `context`.
        (`check_model_name` may also raise for an unknown model — TODO confirm.)

    Examples:
        With default parameters:
        >>> ocr = CnOcr()

        With a specific model:
        >>> ocr = CnOcr(model_name='densenet_lite_136-fc')

        Recognize digits only:
        >>> ocr = CnOcr(model_name='densenet_lite_136-fc', cand_alphabet='0123456789')
    """
    if 'name' in kwargs:
        # Lazy %-args: same message, formatted only if the record is emitted.
        logger.warning(
            'param `name` is useless and deprecated since version %s', MODEL_VERSION
        )

    check_model_name(model_name)
    # check_context returns a verdict; it previously was called and its result
    # discarded, so invalid contexts were accepted silently.
    if not check_context(context):
        raise NotImplementedError('illegal value %s for parameter context' % context)

    self._model_name = model_name
    if context == 'gpu':
        context = 'cuda'
    self.context = context

    self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, model_name)
    # Models with a registered epoch get it appended to the file prefix.
    model_epoch = AVAILABLE_MODELS.get(model_name, [None])[0]
    if model_epoch is not None:
        self._model_file_prefix = '%s-epoch=%03d' % (
            self._model_file_prefix,
            model_epoch,
        )
    self._assert_and_prepare_model_files(model_fp, root)

    self._vocab, self._letter2id = read_charset(VOCAB_FP)
    self._candidates = None
    self.set_cand_alphabet(cand_alphabet)

    self._model = self._get_model(context)