# Example 1
class crnn(object):
    """CRNN text-line recognizer: builds the network, loads weights, decodes images.

    Configuration comes from module-level names defined elsewhere in the file:
    ``chinsesModel`` (alphabet choice), ``GPU`` (device choice), ``LSTMFLAG``
    (network variant) and ``ocrModel`` (checkpoint path).
    """

    def __init__(self):
        # Pick the character set the checkpoint was trained with.
        if chinsesModel:
            alphabet = keys.alphabetChinese
        else:
            alphabet = keys.alphabetEnglish

        self.converter = util.strLabelConverter(alphabet)

        # Decide the device once; crnnOcr moves its input to the same device.
        use_cuda = torch.cuda.is_available() and GPU
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        # LSTMFLAG=True selects the LSTM-based crnn, otherwise the dense ocr variant.
        # Class count is len(alphabet) + 1 (presumably the extra slot is the CTC
        # blank label — confirm against the converter).
        self.model = CRNN(32,
                          1,
                          len(alphabet) + 1,
                          256,
                          1,
                          lstmFlag=LSTMFLAG).to(self.device)

        # Map storages to CPU first so a GPU-saved checkpoint also loads on
        # CPU-only hosts; load_state_dict then copies into the model's device.
        state_dict = torch.load(ocrModel,
                                map_location=lambda storage, loc: storage)

        # Checkpoints saved from nn.DataParallel prefix every key with
        # `module.`; strip only that leading prefix (str.replace would also
        # mangle keys that contain `module.` further in).
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items())

        self.model.load_state_dict(new_state_dict)

    def crnnOcr(self, image):
        """Recognize the text in a single line image.

        Args:
            image: PIL-style image whose ``size`` is ``(width, height)``;
                presumably already the channel layout the model expects
                (1 input channel) — TODO confirm with callers.

        Returns:
            The string produced by ``self.converter.decode(..., raw=False)``.
        """
        # Resize to height 32 while keeping the aspect ratio; clamp the width
        # so very narrow crops do not collapse to zero columns.
        scale = image.size[1] * 1.0 / 32
        w = max(int(image.size[0] / scale), 1)
        transformer = dataset.resizeNormalize((w, 32))
        tensor = transformer(image).to(self.device)
        tensor = tensor.view(1, *tensor.size())  # prepend a batch dimension of 1

        self.model.eval()
        # Pure inference: do not build an autograd graph.
        with torch.no_grad():
            preds = self.model(tensor)

        # Greedy decoding: best class per step. The transpose(1, 0) assumes the
        # model output is (steps, batch, classes) — confirm against CRNN.
        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        preds_size = torch.IntTensor([preds.size(0)])
        return self.converter.decode(preds, preds_size, raw=False)
# Example 2
def crnn_recognition(part_image, app):
    """Load the CRNN checkpoint and recognize the text in one image crop.

    Args:
        part_image: PIL image crop; converted to grayscale ('L') here.
        app: any object exposing ``logger.info`` (used to log the model path).

    Returns:
        The string produced by ``converter.decode(..., raw=False)``.
    """
    # NOTE(review): the checkpoint is rebuilt and reloaded from disk on every
    # call; consider caching the model if this runs per request.
    model = CRNN(32, 1, nclass, 256)
    app.logger.info('loading pretrained model from {0}'.format(
        params.crnn_model_path))

    # Map storages to CPU so GPU-saved checkpoints load on CPU-only hosts.
    train_weights = torch.load(params.crnn_model_path,
                               map_location=lambda storage, loc: storage)
    # Strip only the leading `module.` prefix added by nn.DataParallel.
    prefix = 'module.'
    model_weights = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in train_weights.items())
    model.load_state_dict(model_weights)

    converter = crnn.utils.strLabelConverter(alphabet)

    image = part_image.convert('L')
    # Target width is the original width scaled by 160/280 (project-specific
    # ratio, kept as-is); clamp so narrow crops never yield a zero width.
    w = max(int(image.size[0] / (280 * 1.0 / 160)), 1)
    transformer = crnn.dataset.resizeNormalize((w, 32))
    image = transformer(image)
    image = image.view(1, *image.size())  # prepend a batch dimension of 1

    model.eval()
    # Pure inference: do not build an autograd graph.
    with torch.no_grad():
        preds = model(image)

    # Greedy decoding: best class per step. The transpose(1, 0) assumes the
    # output is (steps, batch, classes) — confirm against CRNN.
    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)
    preds_size = torch.IntTensor([preds.size(0)])
    return converter.decode(preds, preds_size, raw=False)
# Example 3
    # Fragment of a training loop. The enclosing `def` and the definitions of
    # model, criterion, optimizer, train_loader, testdataset, interval,
    # batchSize, trainBatch, val, Progbar and acc are outside this snippet.
    # NOTE(review): `acc` must already be bound before the loop (e.g. a
    # parameter) — it is read by pbar.update before the first evaluation.
    n = len(train_loader)
    pbar = Progbar(target=n)  # presumably a Keras-style progress bar, one step per batch
    train_iter = iter(train_loader)
    loss = 0  # running sum of per-batch costs, used only for the progress display
    for j in range(n):
        # Re-enable gradients on every parameter before each step (presumably
        # undoing a requires_grad=False set elsewhere, e.g. in val — confirm).
        for p in model.named_parameters():
            p[1].requires_grad = True

        model.train()
        # NOTE(review): `.next()` only exists on Python 2-style / older PyTorch
        # DataLoader iterators; `next(train_iter)` is the portable spelling.
        cpu_images, cpu_texts = train_iter.next()
        cost = trainBatch(model, criterion, optimizer, cpu_images, cpu_texts)

        # NOTE(review): `.numpy()` fails on CUDA tensors — confirm trainBatch
        # returns a CPU tensor.
        loss += cost.data.numpy()

        # Every `interval` batches, score the model with val(...) and save the
        # weights whenever the score beats the best seen so far.
        if (j + 1) % interval == 0:
            curAcc = val(model, testdataset, max_iter=1024)
            if curAcc > acc:
                acc = curAcc
                torch.save(model.state_dict(), 'train/ocr/modellstm.pth')

        # NOTE(review): divides the summed cost by (batches seen * batchSize);
        # this is a per-sample average only if trainBatch returns a per-batch
        # sum and every batch is full — confirm.
        pbar.update(j + 1, values=[('loss', loss / ((j + 1) * batchSize)), ('acc', acc)])


## Prediction demo: run the model on one randomly chosen test sample and
## print the ground-truth label next to the prediction.
model.eval()  # switch the network to inference mode (nn.Module.eval)
N = len(testdataset)
if N == 0:
    # np.random.randint(0, 0) would fail with an opaque "low >= high" error.
    raise ValueError('testdataset is empty; nothing to predict')
im, label = testdataset[np.random.randint(0, N)]
pred = predict(im)
print('true:{},pred:{}'.format(label, pred))
# Notebook leftover: as the last expression of a cell this displays the
# sampled image; when run as a plain script it has no effect.
im