Ejemplo n.º 1
0
def crnnSource():
    """Build the CRNN recognizer and load its trained weights.

    Reads the module-level settings ``chinsesModel``, ``GPU``, ``LSTMFLAG``
    and ``ocrModel``.

    Returns:
        tuple: ``(model, converter)`` -- the CRNN instance with the weights
        from ``ocrModel`` loaded, and the ``strLabelConverter`` built from
        the selected alphabet.
    """
    # Chinese + English alphabet when chinsesModel is set, English-only otherwise.
    alphabet = keys.alphabetChinese if chinsesModel else keys.alphabetEnglish
    converter = strLabelConverter(alphabet)

    use_cuda = torch.cuda.is_available() and GPU
    # LSTMFLAG=True selects the CRNN variant, otherwise the dense OCR variant.
    network = CRNN(32, 1, len(alphabet) + 1, 256, 1, lstmFlag=LSTMFLAG)
    model = network.cuda() if use_cuda else network.cpu()

    # The identity map_location keeps every loaded tensor on CPU storage.
    checkpoint = torch.load(ocrModel, map_location=lambda storage, loc: storage)

    # Drop every 'module.' occurrence from the parameter names
    # (presumably the prefix added by a DataParallel wrapper -- confirm).
    modelWeights = OrderedDict(
        (key.replace('module.', ''), tensor)
        for key, tensor in checkpoint.items()
    )
    model.load_state_dict(modelWeights)

    return model, converter
Ejemplo n.º 2
0
                                               imgW=imgW,
                                               keep_ratio=keep_ratio))

# Create an iterator over train_loader.
train_iter = iter(train_loader)


def weights_init(m):
    """Initialize the parameters of Conv and BatchNorm modules in place.

    Intended to be passed to ``Module.apply``, which calls it once per
    submodule. Modules are selected by class name:

    * name contains ``'Conv'``: weight drawn from N(0.0, 0.02)
    * name contains ``'BatchNorm'``: weight drawn from N(1.0, 0.02) and
      bias filled with 0

    Any other module is left untouched. A matching module whose ``weight``
    or ``bias`` is missing or ``None`` (e.g. ``BatchNorm2d(affine=False)``,
    or a container class that merely has "Conv" in its name) is skipped
    instead of raising ``AttributeError``.

    Args:
        m: The module to initialize.
    """
    classname = m.__class__.__name__
    # Name matching is substring-based, so the parameters may not exist.
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            bias.data.fill_(0)


# Build a CRNN sized for the English alphabet and apply weights_init to
# every submodule.
model = CRNN(32, 1, len(alphabetEnglish) + 1, 256, 1, lstmFlag=LSTMFLAG)
model.apply(weights_init)

# Load the weights trained by the project; the identity map_location keeps
# every loaded tensor on CPU storage.
preWeightDict = torch.load(ocrModel,
                           map_location=lambda storage, loc: storage)

# Start from the freshly initialized parameters and overwrite them with the
# loaded ones, entry by entry.
modelWeightDict = model.state_dict()

for k, v in preWeightDict.items():
    name = k.replace('module.', '')  # drop every `module.` occurrence
    if 'rnn.1.embedding' in name:
        # Do not load the last layer's weights; it keeps its fresh init.
        continue
    modelWeightDict[name] = v

model.load_state_dict(modelWeightDict)

# Optimizer over all model parameters.
lr = 0.1
optimizer = optim.Adadelta(model.parameters(), lr=lr)
Ejemplo n.º 3
0
        alphabet = alphabetChinese
        if LSTMFLAG:
            ocrModel = ocrModelTorchLstm
        else:
            ocrModel = ocrModelTorchDense

    else:
        ocrModel = ocrModelTorchEng
        alphabet = alphabetEnglish
        LSTMFLAG = True

    nclass = len(alphabet) + 1
    crnn = CRNN(32,
                1,
                nclass,
                256,
                leakyRelu=False,
                lstmFlag=LSTMFLAG,
                GPU=GPU,
                alphabet=alphabet)
    print('[INFO] Successfully initialize CRNN recognizer')

    if os.path.exists(ocrModel):
        crnn.load_weights(ocrModel)
        print("[INFO] Successfully load Torch-ocr model...")
    else:
        print("download model or tranform model with tools!")

    ocr = crnn.predict_job

if __name__ == "__main__":