def train(model_name, index_dir, train_config_fp, resume_from_checkpoint, pretrained_model_fp):
    check_model_name(model_name)
    train_transform = T.Compose(
        [
            RandomStretchAug(min_ratio=0.5, max_ratio=1.5),
            # RandomCrop((8, 10)),
            T.RandomInvert(p=0.2),
            T.RandomApply([T.RandomRotation(degrees=1)], p=0.4),
            # T.RandomAutocontrast(p=0.05),
            # T.RandomPosterize(bits=4, p=0.3),
            # T.RandomAdjustSharpness(sharpness_factor=0.5, p=0.3),
            # T.RandomEqualize(p=0.3),
            # T.RandomApply([T.GaussianBlur(kernel_size=3)], p=0.5),
            NormalizeAug(),
            # RandomPaddingAug(p=0.5, max_pad_len=72),
        ]
    )
    val_transform = NormalizeAug()

    train_config = json.load(open(train_config_fp))

    data_mod = OcrDataModule(
        index_dir=index_dir,
        vocab_fp=train_config['vocab_fp'],
        img_folder=train_config['img_folder'],
        train_transforms=train_transform,
        val_transforms=val_transform,
        batch_size=train_config['batch_size'],
        num_workers=train_config['num_workers'],
        pin_memory=train_config['pin_memory'],
    )

    # train_ds = data_mod.train
    # for i in range(min(100, len(train_ds))):
    #     visualize_example(train_transform(train_ds[i][0]), 'debugs/train-1-%d' % i)
    #     visualize_example(train_transform(train_ds[i][0]), 'debugs/train-2-%d' % i)
    #     visualize_example(train_transform(train_ds[i][0]), 'debugs/train-3-%d' % i)
    # val_ds = data_mod.val
    # for i in range(min(10, len(val_ds))):
    #     visualize_example(val_transform(val_ds[i][0]), 'debugs/val-1-%d' % i)
    #     visualize_example(val_transform(val_ds[i][0]), 'debugs/val-2-%d' % i)
    #     visualize_example(val_transform(val_ds[i][0]), 'debugs/val-3-%d' % i)
    # return

    trainer = PlTrainer(
        train_config, ckpt_fn=['cnocr', 'v%s' % MODEL_VERSION, model_name]
    )
    model = gen_model(model_name, data_mod.vocab)
    logger.info(model)

    if pretrained_model_fp is not None:
        load_model_params(model, pretrained_model_fp)

    trainer.fit(
        model, datamodule=data_mod, resume_from_checkpoint=resume_from_checkpoint
    )
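
# A hedged usage sketch of `train` (the paths and model name below are placeholders,
# not the project's shipped defaults). The JSON config must at least provide the keys
# read above (vocab_fp, img_folder, batch_size, num_workers, pin_memory), plus whatever
# PlTrainer itself consumes.
train(
    model_name='densenet_lite_136-fc',
    index_dir='data/index',                # assumed folder holding the train/val index files
    train_config_fp='train_config.json',   # assumed config path
    resume_from_checkpoint=None,           # or the path of a checkpoint to resume from
    pretrained_model_fp=None,              # or a checkpoint to warm-start the weights from
)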
def __init__(self, model_name='conv-lite-fc', model_epoch=None, cand_alphabet=None,
             root=data_dir(), gpus=0):
    """
    :param model_name: model name
    :param model_epoch: training epoch of the model
    :param cand_alphabet: candidate set of characters to recognize.
        Defaults to `None`, meaning the recognizable characters are not restricted
    :param root: root directory for the model files.
        Defaults to `~/.cnocr` on Linux/Mac, so the model folder looks like
        `~/.cnocr/1.1.0/conv-lite-fc-0027`.
        On Windows the default is ``.
    """
    check_model_name(model_name)
    self._model_name = model_name
    self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, model_name)
    self._model_epoch = model_epoch or AVAILABLE_MODELS[model_name][0]

    root = os.path.join(root, MODEL_VERSION)
    self._model_dir = os.path.join(root, self._model_name)
    self._assert_and_prepare_model_files()
    self._alphabet, inv_alph_dict = read_charset(
        os.path.join(self._model_dir, 'label_cn.txt')
    )

    self._cand_alph_idx = None
    if cand_alphabet is not None:
        self._cand_alph_idx = [0] + [inv_alph_dict[word] for word in cand_alphabet]
        self._cand_alph_idx.sort()

    self._hp = Hyperparams()
    self._hp._loss_type = None  # infer mode

    # DCMMC: gpu context for mxnet
    if gpus > 0:
        self.context = [mx.context.gpu(i) for i in range(gpus)]
    else:
        self.context = [mx.context.cpu()]
    self._mod = self._get_module()
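
# A hedged sketch of constructing this MXNet-era recognizer (values are illustrative;
# passing gpus >= 1 assumes a CUDA-enabled MXNet build and an available GPU):
ocr = CnOcr(model_name='conv-lite-fc', cand_alphabet='0123456789', gpus=0)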
def __init__(
    self,
    model_name='densenet-lite-fc',
    model_epoch=None,
    cand_alphabet=None,
    root=data_dir(),
    context='cpu',
    name=None,
):
    """
    :param model_name: model name
    :param model_epoch: training epoch of the model
    :param cand_alphabet: candidate set of characters to recognize.
        Defaults to `None`, meaning the recognizable characters are not restricted
    :param root: root directory for the model files.
        Defaults to `~/.cnocr` on Linux/Mac, so the model folder looks like
        `~/.cnocr/1.1.0/conv-lite-fc-0027`.
        On Windows the default is ``.
    :param context: 'cpu' or 'gpu'; whether prediction runs on the CPU or the GPU.
        Defaults to CPU
    :param name: name of the instance being initialized. When several instances need
        to be initialized at the same time, each must be given a distinct name
    """
    check_model_name(model_name)
    self._model_name = model_name
    self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, model_name)
    self._model_epoch = model_epoch or AVAILABLE_MODELS[model_name][0]

    root = os.path.join(root, MODEL_VERSION)
    self._model_dir = os.path.join(root, self._model_name)
    self._assert_and_prepare_model_files()
    self._alphabet, self._inv_alph_dict = read_charset(
        os.path.join(self._model_dir, 'label_cn.txt')
    )

    self._cand_alph_idx = None
    self.set_cand_alphabet(cand_alphabet)

    self._hp = Hyperparams()
    self._hp._loss_type = None  # infer mode

    # if '' is passed in, treat it the same as passing None
    self._net_prefix = None if name == '' else name

    self._mod = self._get_module(context)
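
# A hedged sketch: per the docstring, give each instance a distinct `name` when several
# recognizers are initialized in the same process (model name, context, and names are illustrative):
ocr_a = CnOcr(model_name='densenet-lite-fc', context='cpu', name='ocr-a')
ocr_b = CnOcr(model_name='densenet-lite-fc', context='gpu', name='ocr-b')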
def __init__(
    self,
    model_name: str = 'densenet_lite_136-fc',
    *,
    cand_alphabet: Optional[Union[Collection, str]] = None,
    context: str = 'cpu',  # ['cpu', 'gpu', 'cuda']
    model_fp: Optional[str] = None,
    root: Union[str, Path] = data_dir(),
    **kwargs,
):
    """
    Initialize the recognition model.

    Args:
        model_name (str): model name. Defaults to `densenet_lite_136-fc`
        cand_alphabet (Optional[Union[Collection, str]]): candidate set of characters to recognize.
            Defaults to `None`, meaning the recognizable characters are not restricted
        context (str): 'cpu' or 'gpu'; whether prediction runs on the CPU or the GPU. Defaults to `cpu`
        model_fp (Optional[str]): instead of using a bundled model, the model file to use
            (a '.ckpt' file) can be specified directly via this parameter
        root (Union[str, Path]): root directory for the model files.
            Defaults to `~/.cnocr` on Linux/Mac, so the model folder looks like
            `~/.cnocr/2.1/densenet_lite_136-fc`.
            On Windows the default is `C:/Users/<username>/AppData/Roaming/cnocr`.
        **kwargs: currently unused.

    Examples:
        Using the default parameters:
        >>> ocr = CnOcr()

        Using a specific model:
        >>> ocr = CnOcr(model_name='densenet_lite_136-fc')

        Recognizing digits only:
        >>> ocr = CnOcr(model_name='densenet_lite_136-fc', cand_alphabet='0123456789')
    """
    if 'name' in kwargs:
        logger.warning(
            'param `name` is useless and deprecated since version %s' % MODEL_VERSION
        )
    check_model_name(model_name)
    check_context(context)
    self._model_name = model_name
    if context == 'gpu':
        context = 'cuda'
    self.context = context

    self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, model_name)
    model_epoch = AVAILABLE_MODELS.get(model_name, [None])[0]

    if model_epoch is not None:
        self._model_file_prefix = '%s-epoch=%03d' % (
            self._model_file_prefix,
            model_epoch,
        )

    self._assert_and_prepare_model_files(model_fp, root)
    self._vocab, self._letter2id = read_charset(VOCAB_FP)

    self._candidates = None
    self.set_cand_alphabet(cand_alphabet)

    self._model = self._get_model(context)
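
# A hedged sketch of two options not covered by the docstring examples: pointing
# `model_fp` at a custom '.ckpt' file and requesting GPU inference (the path is a
# placeholder; 'gpu' is mapped to 'cuda' internally, as shown above):
ocr = CnOcr(model_fp='/path/to/custom-model.ckpt', context='gpu')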
def gen_model(model_name, vocab):
    check_model_name(model_name)
    model = OcrModel.from_name(model_name, vocab)
    return model
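
# A minimal sketch of `gen_model` on its own: build the vocab with `read_charset`
# (as the CnOcr constructors above do) and instantiate the network by name.
# The charset file path is a placeholder.
vocab, _letter2id = read_charset('label_cn.txt')
model = gen_model('densenet_lite_136-fc', vocab)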