def __init__(self, crnn_model):
    """Build all state needed to train ``crnn_model``.

    Sets up the fixed training constants, the pre-allocated input/label
    tensors, the label converter, an Adadelta optimizer over the model's
    parameters, a 90/10 train/test split of the images matched by
    ``../data/ocr/*/*.jpg``, the CTC criterion and the training DataLoader.

    Args:
        crnn_model: the model to train; its ``.parameters()`` are handed to
            ``optim.Adadelta``.
    """
    self.crnn_model = crnn_model

    # --- Network / training constants ---------------------------------
    self.batchSize = 2
    num_workers = 1
    img_h, img_w = 32, 280
    keep_aspect = True
    self.nepochs = 10
    self.acc = 0
    learning_rate = 0.1

    # --- Pre-allocated (uninitialised) batch buffers --------------------
    # NOTE(review): the image buffer is (batch, 3, imgH, imgH) -- the height
    # is used for both spatial dims. Presumably the training loop resizes it
    # to the real batch shape before use; confirm against that code.
    self.image = torch.FloatTensor(self.batchSize, 3, img_h, img_h)
    self.text = torch.IntTensor(self.batchSize * 5)
    self.length = torch.IntTensor(self.batchSize)

    self.converter = strLabelConverter(''.join(alphabetChinese))
    self.optimizer = optim.Adadelta(crnn_model.parameters(), lr=learning_rate)

    # --- Train / test split ----------------------------------------------
    # NOTE: the split is random and does not try to balance the character
    # distribution between the train and test sets.
    image_paths = glob('../data/ocr/*/*.jpg')
    train_paths, test_paths = train_test_split(image_paths, test_size=0.1)
    train_set = PathDataset(train_paths, alphabetChinese)
    self.testdataset = PathDataset(test_paths, alphabetChinese)

    self.criterion = CTCLoss()

    # --- Training loader (no shuffling, no custom sampler) ---------------
    collate = alignCollate(imgH=img_h, imgW=img_w, keep_ratio=keep_aspect)
    self.train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=self.batchSize,
        shuffle=False,
        sampler=None,
        num_workers=num_workers,
        collate_fn=collate,
    )

    # len(DataLoader) is the number of batches, so this is half an epoch;
    # presumably used as the evaluation/logging period by the training loop.
    self.interval = len(self.train_loader) // 2


## Evaluate the model
testdataset = PathDataset(testP, alphabetChinese)

# Training constants for this script section.
batchSize = 32
workers = 1
imgH = 32
imgW = 280
keep_ratio = True
cuda = True
ngpu = 1
nh = 256

# NOTE(review): `sampler` is constructed here but the DataLoader below is
# given sampler=None and shuffle=False, so batches follow dataset order.
# Confirm whether `sampler=sampler` was intended.
sampler = randomSequentialSampler(traindataset, batchSize)
train_loader = torch.utils.data.DataLoader(
    traindataset,
    batch_size=batchSize,
    shuffle=False,
    sampler=None,
    num_workers=int(workers),
    collate_fn=alignCollate(imgH=imgH, imgW=imgW, keep_ratio=keep_ratio),
)
train_iter = iter(train_loader)


## Load pretrained model weights
# (section header -- `weights_init` below is a random initialiser, not a
# weight loader)
def weights_init(m):
    """Re-initialise the parameters of module ``m`` in place.

    Presumably applied to a whole network via ``model.apply(weights_init)``,
    which calls it once for every submodule.

    Modules are matched by class name:

    * name contains ``'Conv'``: weight ~ N(0.0, 0.02); bias untouched.
    * name contains ``'BatchNorm'``: weight ~ N(1.0, 0.02), bias = 0.
    * anything else: left untouched.

    A parameter that is missing or ``None`` is skipped instead of raising.
    Examples are a container whose class name merely contains "Conv", or
    ``BatchNorm*(affine=False)``, which has no weight/bias.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        weight = getattr(m, 'weight', None)
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)


from crnn.models.crnn import CRNN