Пример #1
0
    def __init__(self, crnn_model):
        """Prepare everything needed to train ``crnn_model``.

        Sets up:
        * placeholder input/label tensors and the label converter;
        * an Adadelta optimizer over the model's parameters and the CTC loss;
        * a 90/10 train/test split of the ``*.jpg`` files matched by
          ``../data/ocr/*/*.jpg``, resolved against the current working
          directory;
        * the training ``DataLoader``.

        Args:
            crnn_model: the network to train.

        Raises:
            ValueError: if the glob pattern matches no images.
        """
        self.crnn_model = crnn_model

        # Network / training constants.
        self.batchSize = 2
        workers = 1
        imgH = 32
        imgW = 280
        keep_ratio = True
        self.nepochs = 10
        self.acc = 0
        lr = 0.1

        # Placeholder buffers laid out as (batch, channels, height, width);
        # the width is imgW, not imgH.
        # NOTE(review): presumably these are resized to the real batch
        # before use. Confirm against the training loop.
        self.image = torch.FloatTensor(self.batchSize, 3, imgH, imgW)
        self.text = torch.IntTensor(self.batchSize * 5)
        self.length = torch.IntTensor(self.batchSize)
        self.converter = strLabelConverter(''.join(alphabetChinese))
        self.optimizer = optim.Adadelta(crnn_model.parameters(), lr=lr)

        pattern = '../data/ocr/*/*.jpg'
        roots = glob(pattern)
        if not roots:
            # Fail early with an explicit message when the data directory
            # is missing or empty.
            raise ValueError('no training images match %r' % pattern)
        # The split does not try to balance character frequencies between
        # the train and test sets.
        trainP, testP = train_test_split(roots, test_size=0.1)
        traindataset = PathDataset(trainP, alphabetChinese)
        self.testdataset = PathDataset(testP, alphabetChinese)
        self.criterion = CTCLoss()

        self.train_loader = torch.utils.data.DataLoader(
            traindataset,
            batch_size=self.batchSize,
            shuffle=False,
            sampler=None,
            num_workers=int(workers),
            collate_fn=alignCollate(imgH=imgH,
                                    imgW=imgW,
                                    keep_ratio=keep_ratio))
        # Evaluate the model every `interval` batches (roughly twice per
        # epoch). Clamp to 1 so a single-batch loader does not yield 0,
        # which would break a modulo-based schedule.
        # NOTE(review): confirm how `interval` is consumed.
        self.interval = max(1, len(self.train_loader) // 2)
Пример #2
0
testdataset = PathDataset(testP, alphabetChinese)

# Training hyper-parameters.
batchSize = 32
workers = 1
imgH = 32
imgW = 280
keep_ratio = True
# Auto-detect the GPU instead of hard-coding True, so the script can also run
# on CPU-only machines.
# NOTE(review): presumably this flag gates `.cuda()` calls later in the
# script. Confirm.
cuda = torch.cuda.is_available()
ngpu = 1
nh = 256

# Previously this sampler was built but the DataLoader received sampler=None,
# so it was never used. Pass it through. `shuffle` must stay False because
# DataLoader rejects shuffle=True together with a custom sampler.
sampler = randomSequentialSampler(traindataset, batchSize)
train_loader = torch.utils.data.DataLoader(
    traindataset,
    batch_size=batchSize,
    shuffle=False,
    sampler=sampler,
    num_workers=int(workers),
    collate_fn=alignCollate(imgH=imgH, imgW=imgW, keep_ratio=keep_ratio),
)

train_iter = iter(train_loader)


## 加载预训练模型权重
def weights_init(m):
    """Initialise Conv and BatchNorm layer parameters of ``m`` in place.

    * Modules whose class name contains ``'Conv'``: weight ~ N(0, 0.02).
      Biases are left untouched.
    * Modules whose class name contains ``'BatchNorm'``: weight ~ N(1, 0.02)
      and bias = 0.
    * Any other module is left unchanged.

    Parameters that are absent or ``None`` are skipped. Examples are
    ``BatchNorm*(affine=False)`` and custom containers whose class name merely
    contains "Conv" or "BatchNorm". Presumably used as
    ``model.apply(weights_init)``.

    Args:
        m: a module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            bias.data.fill_(0)


from crnn.models.crnn import CRNN