Ejemplo n.º 1
0
def crnnSource():
    """
    Build the CRNN recognizer, load its trained weights, and return it.

    The alphabet is chosen by the module-level ``chinsesModel`` flag:
    ``keys.alphabetChinese`` (Chinese + English model) when truthy,
    ``keys.alphabetEnglish`` (English-only model) otherwise. The network is
    moved to the GPU only when CUDA is available and ``GPU`` is truthy;
    otherwise it is kept on the CPU. Weights are read from ``ocrModel``.

    Returns:
        tuple: ``(model, converter)`` -- the CRNN with the weights loaded and
        the ``strLabelConverter`` built from the selected alphabet.
    """
    alphabet = keys.alphabetChinese if chinsesModel else keys.alphabetEnglish
    converter = strLabelConverter(alphabet)

    use_gpu = torch.cuda.is_available() and GPU
    # LSTMFLAG=True selects the CRNN variant; otherwise the dense OCR variant.
    model = CRNN(32, 1, len(alphabet) + 1, 256, 1, lstmFlag=LSTMFLAG)
    model = model.cuda() if use_gpu else model.cpu()

    # Identity map_location: keep the deserialized storages as loaded instead
    # of moving them to the device recorded in the checkpoint.
    trainWeights = torch.load(
        ocrModel, map_location=lambda storage, loc: storage)

    # Remove every `module.` occurrence from the parameter names (presumably
    # left behind by a DataParallel wrapper at training time) so the keys
    # line up with the bare model.
    modelWeights = OrderedDict(
        (key.replace('module.', ''), tensor)
        for key, tensor in trainWeights.items()
    )
    model.load_state_dict(modelWeights)

    return model, converter
Ejemplo n.º 2
0
        m.bias.data.fill_(0)


# Fresh CRNN sized for the English alphabet (+1 output class), with its
# parameters initialised through weights_init.
model = CRNN(32, 1, len(alphabetEnglish) + 1, 256, 1, lstmFlag=LSTMFLAG)
model.apply(weights_init)

# Weights previously trained in this project; the identity map_location keeps
# the deserialized storages as loaded rather than on the checkpoint's device.
preWeightDict = torch.load(ocrModel, map_location=lambda storage, loc: storage)

# Start from the model's current state and overlay the pretrained tensors,
# leaving the last layer ('rnn.1.embedding') at its fresh initialisation.
modelWeightDict = model.state_dict()
for k, v in preWeightDict.items():
    name = k.replace('module.', '')  # drop every `module.` occurrence
    if 'rnn.1.embedding' in name:
        # Do not load the final layer's weights.
        continue
    modelWeightDict[name] = v
model.load_state_dict(modelWeightDict)

# Optimisation setup: Adadelta over all parameters, label converter and loss.
lr = 0.1
optimizer = optim.Adadelta(model.parameters(), lr=lr)
converter = strLabelConverter(''.join(alphabetEnglish))
criterion = CTCLoss()

# Pre-allocated tensors for images, label codes and label lengths
# (presumably refilled for every batch by the training loop -- not shown here).
image = torch.FloatTensor(batchSize, 3, imgH, imgH)
text = torch.IntTensor(batchSize * 5)
length = torch.IntTensor(batchSize)

if torch.cuda.is_available():
    # Module.cuda() moves the parameters in place and returns the module, so
    # it can be handed straight to the DataParallel wrapper (device 0 only).
    model = torch.nn.DataParallel(model.cuda(), device_ids=[0])
    image = image.cuda()
    criterion = criterion.cuda()