if not opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
else:
    sampler = None
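# DataLoader treats shuffle and sampler as mutually exclusive, hence
# shuffle=False below whenever the custom sequential sampler is in use.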
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=opt.batchSize,
                                           shuffle=False,
                                           sampler=sampler,
                                           num_workers=int(opt.workers),
                                           collate_fn=dataset.alignCollate(
                                               height=opt.height,
                                               width=opt.width,
                                               keep_ratio=opt.keep_ratio))

val_dataset = dataset.listDataset(list_file=opt.valList,
                                  transform=dataset.resizeNormalize(
                                      (opt.width, opt.height)))

nclass = len(alphabet) + 3  # classes needed by the decoder: +3 for SOS, EOS and blank
nc = 1

converter = utils.strLabelConverterForAttention(alphabet)
image = torch.FloatTensor(opt.batchSize, nc, opt.height, opt.width)  # (N, C, H, W) input buffer
criterion = torch.nn.NLLLoss()  # the decoder's final output must be log_softmax
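# NLLLoss consumes log-probabilities; a minimal sketch of the pairing it
# assumes (names and shapes hypothetical):
#   log_probs = F.log_softmax(decoder_logits, dim=1)  # (batch, nclass)
#   loss = criterion(log_probs, targets)              # targets: (batch,) class indices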

encoder = model.encoder(opt.height, nc=nc, nh=256)
decoder = model.decoder(nh=256, nclass=nclass, dropout_p=0.1)

# resume training, or use a pretrained model to initialize the encoder and decoder parameters
encoder.apply(weights_init)
decoder.apply(weights_init)
if opt.encoder:
    # hedged completion (the original snippet is truncated here): load
    # pretrained encoder weights from the checkpoint path in opt.encoder
    encoder.load_state_dict(torch.load(opt.encoder))
Example #2
# model is assumed to be an already-constructed CRNN; load its trained weights,
# mapping storages onto the CPU so a GPU-saved checkpoint also loads
model.load_state_dict(
    torch.load(model_path, map_location=lambda storage, loc: storage))

model.eval()  # inference mode: disables dropout, freezes BatchNorm statistics

# ------------------- Preprocessing -------------------

converter = utils.strLabelConverter(alphabet)
image = Image.open(img_path).convert('L')  # grayscale: the model expects a single-channel input

# rescale so the image height becomes 32 px while preserving the aspect ratio
scale = image.size[1] * 1.0 / 32
w = int(image.size[0] / scale)
transformer = dataset.resizeNormalize((w, 32))

image = transformer(image)
if torch.cuda.is_available():
    image = image.cuda()
image = image.view(1, *image.size())  # add the batch dimension -> (1, 1, 32, w)
image = Variable(image)

# ------------------- Inference -------------------

preds = model(image)  # (T, N, nclass): per-timestep class scores

_, preds = preds.max(2)  # most likely class index at each time step
preds = preds.transpose(1, 0).contiguous().view(-1)
preds_size = Variable(torch.IntTensor([preds.size(0)]))
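# Hedged continuation mirroring the usual crnn.pytorch demo: raw=True shows the
# uncollapsed CTC path, raw=False merges repeats and strips the blank symbol.
raw_pred = converter.decode(preds.data, preds_size.data, raw=True)
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
print('%-20s => %-20s' % (raw_pred, sim_pred))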
Example #3
transform = None
train_dataset = dataset.listDataset(list_file=opt.trainlist, transform=transform)
assert train_dataset
if not opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
else:
    sampler = None
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=opt.batchSize,
    shuffle=False, sampler=sampler,
    num_workers=int(opt.workers),
    collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))
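# Note: alignCollate scales each image in the batch to height imgH; with
# keep_ratio=True the batch width follows the widest aspect ratio in the batch
# rather than the fixed imgW (behaviour of the usual crnn.pytorch collate).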

if opt.mode == '1D':
    test_dataset = dataset.listDataset(list_file=opt.vallist,
                                       transform=dataset.resizeNormalize((opt.imgW, opt.imgH)))
else:
    test_dataset = dataset.listDataset(list_file=opt.vallist,
                                       transform=dataset.paddingNormalize(opt.imgH, opt.imgW))

nclass = len(alphabet) + 3          # classes needed by the decoder: +3 for SOS, EOS and blank
print('nclass:', nclass)
cfg.SEQUENCE.NUM_CHAR = nclass
nc = 1

converter = utils.strLabelConverterForAttention(alphabet)
# criterion = torch.nn.CrossEntropyLoss()
criterion = torch.nn.NLLLoss()              # the final output must be log_softmax

if opt.mode == '1D':
    encoder = crnn.CNN(opt.imgH, nc, opt.nh, cfg)             # convolutional feature extractor
    decoder = crnn.decoderV2(opt.nh, nclass, dropout_p=0.1)   # attention decoder
Example #4
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batchSize,
                                               shuffle=False,
                                               sampler=sampler,
                                               num_workers=int(opt.workers),
                                               collate_fn=dataset.alignCollate(
                                                   imgH=opt.imgH,
                                                   imgW=opt.imgW,
                                                   keep_ratio=opt.keep_ratio))

    if opt.vallist.endswith(".h5"):
        # H5Dataset signature: (h5file, datasetImage='/train/image', datasetProf='/train/prof')
        test_dataset = h5dataset.H5Dataset(opt.vallist,
                                           datasetImage='/test/image',
                                           datasetProf='/test/prof',
                                           transform=dataset.resizeNormalize(
                                               (opt.imgW, opt.imgH)))
    else:
        test_dataset = dataset.listDataset(list_file=opt.vallist,
                                           transform=dataset.resizeNormalize(
                                               (opt.imgW, opt.imgH)))

    alphabet = utils.getAlphabetStr(opt.alphabet)

    nclass = len(alphabet) + 3  # classes needed by the decoder: +3 for SOS, EOS and blank
    print(" -- Number of classes:", nclass)
    nc = 1

    converter = utils.strLabelConverterForAttention(alphabet)

    # criterion = torch.nn.CrossEntropyLoss()
Example #5
assert train_dataset
if not opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
else:
    sampler = None
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=opt.batchSize,
                                           # shuffle and sampler are mutually exclusive in DataLoader
                                           shuffle=(sampler is None),
                                           sampler=sampler,
                                           num_workers=int(opt.workers),
                                           collate_fn=dataset.alignCollate(
                                               imgH=opt.imgH,
                                               imgW=opt.imgW,
                                               keep_ratio=opt.keep_ratio))
test_dataset = dataset.lmdbDataset(root=opt.valroot,
                                   transform=dataset.resizeNormalize(
                                       (100, 32)))
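# (100, 32) above is the canonical (W, H) evaluation size from the CRNN paper.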

nclass = len(opt.alphabet) + 1  # +1 for the CTC blank class
nc = 1

converter = utils.strLabelConverter(opt.alphabet)
criterion = CTCLoss()  # assumed import: from warpctc_pytorch import CTCLoss
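# Hedged sketch of how this criterion is applied later in the usual crnn.pytorch
# training loop (warp-ctc call order: activations, targets, input lengths, target lengths):
#   preds = crnn(image)                                      # (T, N, nclass)
#   preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
#   cost = criterion(preds, text, preds_size, length) / batch_size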


# custom weights initialization called on crnn
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
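
# As in the first example above, the initializer is applied recursively with
# Module.apply, e.g. crnn.apply(weights_init) on the assembled network.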