def crnnSource():
    """Build the CRNN recognizer and load its trained weights.

    Reads module-level configuration:

    * ``chinsesModel`` -- when truthy, use ``keys.alphabetChinese``
      (Chinese+English model); otherwise ``keys.alphabetEnglish``
      (English-only model).
    * ``LSTMFLAG`` -- passed to CRNN as ``lstmFlag`` (True: LSTM CRNN,
      otherwise the dense OCR variant).
    * ``GPU`` -- the model is moved to CUDA only if this is truthy AND
      ``torch.cuda.is_available()``.
    * ``ocrModel`` -- path of the checkpoint passed to ``torch.load``.

    Returns:
        tuple: ``(model, converter)``. ``converter`` is the
        ``strLabelConverter`` built from the selected alphabet.

    NOTE(review): the model is returned without calling ``.eval()``;
    confirm callers switch it to inference mode if that is required.
    """
    alphabet = keys.alphabetChinese if chinsesModel else keys.alphabetEnglish
    converter = strLabelConverter(alphabet)

    # One output per alphabet character plus one extra class (presumably
    # the CTC blank -- confirm against CRNN / strLabelConverter).
    model = CRNN(32, 1, len(alphabet) + 1, 256, 1, lstmFlag=LSTMFLAG)
    if torch.cuda.is_available() and GPU:
        model = model.cuda()
    else:
        model = model.cpu()

    # Deserialize every tensor onto the CPU; load_state_dict then copies the
    # values into the model's existing parameters.
    trainWeights = torch.load(ocrModel, map_location=lambda storage, loc: storage)

    # Checkpoints saved from a DataParallel-wrapped model prefix each key
    # with "module.". Strip only that leading prefix: str.replace would also
    # mangle interior matches such as "x.submodule.weight".
    prefix = 'module.'
    modelWeights = OrderedDict()
    for k, v in trainWeights.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        modelWeights[name] = v

    model.load_state_dict(modelWeights)
    return model, converter
imgW=imgW, keep_ratio=keep_ratio))
train_iter = iter(train_loader)


def weights_init(m):
    """Initialize one module's weights in place; meant to be used via ``model.apply``.

    Modules whose class name contains 'Conv' get weights drawn from
    N(0.0, 0.02). Modules whose class name contains 'BatchNorm' get weights
    drawn from N(1.0, 0.02) and a zero bias. Every other module, and conv
    biases, are left untouched.

    Args:
        m: a module visited by ``model.apply``.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


# Build the model over the English alphabet; the "+ 1" output is presumably
# the CTC blank class -- confirm against CRNN / the label converter.
model = CRNN(32, 1, len(alphabetEnglish) + 1, 256, 1, lstmFlag=LSTMFLAG)
model.apply(weights_init)
# Load the weights previously trained for this project (deserialized onto the CPU).
preWeightDict = torch.load(
    ocrModel, map_location=lambda storage, loc: storage)
# Start from the freshly initialized state dict and overwrite entries with
# the pretrained tensors, so any key that is skipped keeps its new init.
modelWeightDict = model.state_dict()
for k, v in preWeightDict.items():
    # remove `module.` (presumably a DataParallel prefix); note that
    # str.replace removes every occurrence, not only a leading one
    name = k.replace('module.', '')
    # Do not load the last layer's weights ('rnn.1.embedding'), presumably
    # because its size depends on the alphabet -- confirm.
    if 'rnn.1.embedding' not in name:
        modelWeightDict[name] = v
model.load_state_dict(modelWeightDict)
lr = 0.1
optimizer = optim.Adadelta(model.parameters(), lr=lr)
alphabet = alphabetChinese if LSTMFLAG: ocrModel = ocrModelTorchLstm else: ocrModel = ocrModelTorchDense else: ocrModel = ocrModelTorchEng alphabet = alphabetEnglish LSTMFLAG = True nclass = len(alphabet) + 1 crnn = CRNN(32, 1, nclass, 256, leakyRelu=False, lstmFlag=LSTMFLAG, GPU=GPU, alphabet=alphabet) print('[INFO] Successfully initialize CRNN recognizer') if os.path.exists(ocrModel): crnn.load_weights(ocrModel) print("[INFO] Successfully load Torch-ocr model...") else: print("download model or tranform model with tools!") ocr = crnn.predict_job if __name__ == "__main__":