Пример #1
0
    opt.max_label_length,
    transforms.Compose([Rescale((32, 100)),
                        Gray(),
                        ZeroMean(),
                        ToTensor()]),
)

# Shuffled training batches of 64 samples each.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

device = opt.device
# Build the recognizer, initialise its weights and place it on the target
# device. nn.Module.apply and nn.Module.to both return the module itself,
# which is why the calls can be chained.
net = CRNN().apply(weights_init).to(device)
net.zero_grad()

ctc_loss = CTCLoss()
params = net.parameters()
# Adam with default learning rate plus a small L2 weight decay.
optimizer = optim.Adam(params, weight_decay=1e-5)
best_loss = 50

print("gc is enabled", gc.isenabled())

# Training loop: opt.epoch passes over train_loader, with tqdm progress bars
# for both the epoch and the batch level. The remainder of the loop body
# (forward/backward pass) is not part of this excerpt.
for epoch in trange(opt.epoch):
    running_loss = 0.0  # presumably accumulates per-batch loss for this epoch — not shown here
    for i, train_data in tqdm(enumerate(train_loader, 0)):
        # Each batch is a mapping; pull out images, labels and label lengths.
        inputs, labels, labels_length = (
            train_data["image"],
            train_data["label"],
            train_data["label_length"],
        )
Пример #2
0
train_loader = DataLoader(train_dataset,
                          batch_size=128,
                          shuffle=True,
                          num_workers=0)

device = opt.device
# Either resume training from a saved model or build a fresh one.
if opt.load_path:
    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # checkpoints from trusted sources. Recent PyTorch releases default to
    # weights_only=True, which rejects a pickled whole module — confirm the
    # installed version if this load fails.
    net = torch.load(opt.load_path)
else:
    net = CRNN()
    # Initialise weights only for a freshly built network: applying the
    # initializer to a loaded model would overwrite the resumed weights and
    # silently restart training from scratch.
    net.apply(weights_init)

print(net)
# Report the parameter count (printing net.parameters() itself only shows a
# generator object).
print(sum(p.numel() for p in net.parameters()))
net = net.to(device=device)
net.zero_grad()

params = net.parameters()

# Blank label at index 0.
ctc_loss = CTCLoss(blank=0)
optimizer = optim.Adam(params=params, lr=0.001, weight_decay=1e-5)
best_loss = 50
print('gc is enabel:', gc.isenabled())
# Training loop: opt.epoch passes over train_loader with tqdm progress bars.
# The rest of the loop body is not part of this excerpt.
for epoch in trange(opt.epoch):
    running_loss = 0.0  # presumably accumulates per-batch loss for this epoch — not shown here
    for i, train_data in tqdm(enumerate(train_loader, 0)):
        # Each batch is a mapping; pull out images, labels and label lengths.
        inputs, labels, labels_length = train_data['image'], train_data[
            'label'], train_data['label_length']
Пример #3
0
modelWeightDict = model.state_dict()

# Copy pretrained weights into the model's state dict, skipping the final
# layer ('rnn.1.embedding'); presumably its shape depends on the alphabet
# and it is retrained here.
for k, v in preWeightDict.items():
    # Strip only a leading `module.` (the prefix a DataParallel-wrapped model
    # puts on its state_dict keys); str.replace would also mangle any later
    # occurrence of that substring inside a key.
    name = k[len('module.'):] if k.startswith('module.') else k
    if 'rnn.1.embedding' not in name:  # do not load the last layer's weights
        modelWeightDict[name] = v

model.load_state_dict(modelWeightDict)


# Optimizer, label converter and loss.
from crnn.util import strLabelConverter
lr = 0.1
optimizer = optim.Adadelta(model.parameters(), lr=lr)
converter = strLabelConverter(''.join(alphabetChinese))
criterion = CTCLoss()


from train.ocr.dataset import resizeNormalize
from crnn.util import loadData
# Preallocated batch buffers.
# NOTE(review): the image buffer uses imgH for both spatial dimensions. This
# is harmless if the buffer is resized per batch (e.g. by loadData);
# otherwise imgW was probably intended — confirm.
image = torch.FloatTensor(batchSize, 3, imgH, imgH)
text = torch.IntTensor(batchSize * 5)
length = torch.IntTensor(batchSize)

if torch.cuda.is_available():
    model.cuda()
    # Wrap in DataParallel; with device_ids=[0] this still runs on one GPU.
    model = torch.nn.DataParallel(model, device_ids=[0])
    image = image.cuda()
    criterion = criterion.cuda()
Пример #4
0
modelWeightDict = model.state_dict()

# Copy pretrained weights into the model's state dict, skipping the final
# layer ('rnn.1.embedding'), which is not loaded.
for k, v in preWeightDict.items():
    # Strip only a leading `module.` (the prefix a DataParallel-wrapped model
    # puts on its state_dict keys); str.replace would also mangle any later
    # occurrence of that substring inside a key.
    name = k[len('module.'):] if k.startswith('module.') else k
    if 'rnn.1.embedding' not in name:  # do not load the last layer's weights
        modelWeightDict[name] = v

model.load_state_dict(modelWeightDict)
print('model has been loaded')

# Optimizer choice: SGD for the dense model variant, Adadelta for the LSTM
# variant (that alternative was optim.Adadelta(model.parameters(), lr=0.001)).
# NOTE(review): `learning_reate` (sic) is defined outside this excerpt.
optimizer = optim.SGD(model.parameters(), lr=learning_reate, momentum=0.6)
converter = strLabelConverter(''.join(alphabetChinese))
criterion = CTCLoss()

# Preallocated batch buffers.
# NOTE(review): the image buffer uses imgH for both spatial dimensions;
# confirm whether imgW was intended or the buffer is resized per batch.
image = torch.FloatTensor(batchSize, 3, imgH, imgH)
text = torch.IntTensor(batchSize * 5)
length = torch.IntTensor(batchSize)

if torch.cuda.is_available():
    model.cuda()
    # Wrap in DataParallel; with device_ids=[0] this still runs on one GPU.
    model = torch.nn.DataParallel(model, device_ids=[0])
    image = image.cuda()
    criterion = criterion.cuda()

acc = 0
# Number of batches between model evaluations. Clamp to at least 1: when the
# loader has fewer batches than display_inter, plain floor division yields 0,
# which is unusable as a modulus or step.
interval = max(1, len(train_loader) // display_inter)