コード例 #1
0
ファイル: crnn.py プロジェクト: zhoul14/chineseocr
def crnnSource():
    """Build the CRNN recognizer and its label converter from module settings.

    Reads the module-level names ``chinsesModel`` (alphabet choice), ``GPU``
    and ``LSTMFLAG`` (device / network variant) and ``ocrModel`` (checkpoint
    path).

    Returns:
        ``(model, converter)``: the model has its weights loaded and is in
        eval mode; the converter is built from the selected alphabet.
    """
    alphabet = keys.alphabetChinese if chinsesModel else keys.alphabetEnglish
    converter = util.strLabelConverter(alphabet)

    use_cuda = torch.cuda.is_available() and GPU
    # Output classes = alphabet size + 1 extra class.
    # LSTMFLAG=True selects the LSTM (CRNN) head, otherwise the dense OCR head.
    network = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1, lstmFlag=LSTMFLAG)
    model = network.cuda() if use_cuda else network.cpu()

    # map_location returns each storage unchanged, i.e. deserializes on CPU,
    # so a checkpoint saved from a GPU also opens on a CPU-only machine.
    checkpoint = torch.load(ocrModel,
                            map_location=lambda storage, loc: storage)
    # Drop "module." from parameter names (added when weights are saved from a
    # torch.nn.DataParallel wrapper) so they match the bare model.
    cleaned = OrderedDict(
        (key.replace('module.', ''), tensor)
        for key, tensor in checkpoint.items())

    model.load_state_dict(cleaned)
    model.eval()
    return model, converter
コード例 #2
0
def crnnSource():
    """Create the CRNN OCR model and its label converter from ``cfg``.

    Settings read: ``cfg.chinese_model`` (Chinese vs. English alphabet),
    ``cfg.GPU`` (use CUDA when available), ``cfg.lstm_flag`` (network
    variant) and ``cfg.ocr_model`` (checkpoint path).

    Returns:
        ``(model, converter)``: weights loaded, model in eval mode.
    """
    if not cfg.chinese_model:
        alphabet = keys.alphabetEnglish
    else:
        alphabet = keys.alphabetChinese

    converter = strLabelConverter(alphabet)

    want_cuda = torch.cuda.is_available() and cfg.GPU
    n_class = len(alphabet) + 1  # one extra class on top of the alphabet
    network = crnn.CRNN(32, 1, n_class, 256, 1, lstmFlag=cfg.lstm_flag)
    model = network.cuda() if want_cuda else network.cpu()

    # Deserialize onto CPU storage regardless of where the checkpoint was saved.
    raw_weights = torch.load(cfg.ocr_model,
                             map_location=lambda storage, loc: storage)

    # Remove "module." from parameter names; checkpoints saved from a
    # torch.nn.DataParallel wrapper carry it (a torch-version compatibility
    # quirk).
    weights = OrderedDict()
    for key in raw_weights:
        weights[key.replace('module.', '')] = raw_weights[key]

    model.load_state_dict(weights)
    model.eval()
    return model, converter
コード例 #3
0
def predict_img(imgpath, ocrModel='./ocr-dense.pth'):
    """Recognize the text in a single image file with the dense CRNN model.

    NOTE(review): the model is rebuilt and its checkpoint reloaded on every
    call; callers processing many images should load it once instead.

    Args:
        imgpath: path of the image to read; it is converted to grayscale.
        ocrModel: checkpoint to load (defaults to the previously hard-coded
            ``'./ocr-dense.pth'``).

    Returns:
        The string produced by ``converter.decode(..., raw=False)``.
    """
    converter = util.strLabelConverter(alphabet)
    model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1,
                      lstmFlag=LSTMFLAG).cpu()
    # Deserialize onto CPU to match the CPU model above.
    state_dict = torch.load(ocrModel,
                            map_location=lambda storage, loc: storage)
    # Drop "module." (added by torch.nn.DataParallel) from parameter names.
    new_state_dict = OrderedDict(
        (k.replace('module.', ''), v) for k, v in state_dict.items())
    model.load_state_dict(new_state_dict)
    model.eval()

    # The context manager releases the file handle once the copy is made.
    with Image.open(imgpath) as pil_image:
        image = pil_image.convert('L')
    # Resize to height 32, scaling the width by the same factor.
    scale = image.size[1] * 1.0 / 32
    w = int(image.size[0] / scale)
    transformer = dataset.resizeNormalize((w, 32))
    image = transformer(image).cpu()
    image = image.view(1, *image.size())  # add a leading batch dimension

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        preds = model(image)
    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)
    preds_size = torch.IntTensor([preds.size(0)])
    return converter.decode(preds.data, preds_size.data, raw=False)
コード例 #4
0
ファイル: network_keras.py プロジェクト: miketes/OCR_Invoice
 def predict(self,image):
     """Recognize one text-line image with the Keras model.

     The image goes through ``resizeNormalize(image, 32)``, is cast to
     float32 and wrapped in two leading singleton dimensions; prediction
     runs inside the shared module-level ``graph`` context.

     Returns:
         The result of ``strLabelConverter`` on the flattened argmax (over
         axis 2) of the model output.
     """
     global graph
     normalized = resizeNormalize(image, 32).astype(np.float32)
     batch = np.array([[normalized]])  # two leading singleton dimensions
     with graph.as_default():
         scores = self.model.predict(batch)
     labels = np.argmax(scores, axis=2).reshape((-1,))
     return strLabelConverter(labels, self.alphabet)
コード例 #5
0
ファイル: web_test.py プロジェクト: lihow/chinese-ocr-win
def crnnSource():
    """Build the CRNN recognizer from ``./crnn/samples/model_acc97.pth``.

    Reads the module-level ``GPU`` flag and the ``keys1.alphabet`` character
    set.

    Returns:
        ``(model, converter)``: weights loaded, model in eval mode.
    """
    alphabet = keys1.alphabet
    converter = util.strLabelConverter(alphabet)
    if torch.cuda.is_available() and GPU:
        model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1).cuda()
    else:
        model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1).cpu()
    path = './crnn/samples/model_acc97.pth'
    # Deserialize onto CPU. Without map_location, a checkpoint saved from a
    # GPU fails to load on a CPU-only machine, which is exactly the case the
    # else-branch targets. load_state_dict then copies onto the model's device.
    state_dict = torch.load(path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model.eval()
    return model, converter
コード例 #6
0
ファイル: network_dnn.py プロジェクト: yunWJR/bd_ocr
 def predict(self, image):
     """Recognize one text-line image with an OpenCV DNN network.

     The normalized float32 image gets two leading singleton dimensions
     and is fed through ``self.model.setInput`` / ``forward``.

     Returns:
         The result of ``strLabelConverter`` on the flattened per-position
         argmax of the network output.
     """
     blob = np.array([[resizeNormalize(image, 32).astype(np.float32)]])
     self.model.setInput(blob)
     # Reorder the 4-D output axes to (0, 2, 3, 1) and keep the first item;
     # argmax over axis 2 then picks one label per position.
     # NOTE(review): assumes the class scores sit on axis 1 of the raw
     # output -- confirm against the exported model.
     scores = self.model.forward().transpose(0, 2, 3, 1)[0]
     labels = np.argmax(scores, axis=2).reshape((-1, ))
     return strLabelConverter(labels, self.alphabet)
コード例 #7
0
def crnnSource(net_path, alphabet):
    """Build a CRNN for ``alphabet`` and load the weights at ``net_path``.

    When ``config.PLATFORM`` is ``"GPU"`` the model is moved to CUDA and the
    checkpoint is loaded as saved; otherwise the checkpoint is mapped onto
    CPU storage.

    NOTE(review): unlike other loaders, this one does not call
    ``model.eval()``; callers presumably switch modes themselves.

    Returns:
        ``(model, converter)``.
    """
    converter = util.strLabelConverter(alphabet)

    on_gpu = config.PLATFORM == "GPU"
    model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1)
    if on_gpu:
        model = model.cuda()
        weights = torch.load(net_path)
    else:
        weights = torch.load(net_path,
                             map_location=lambda storage, loc: storage)
    model.load_state_dict(weights)
    return model, converter
コード例 #8
0
ファイル: crnnport.py プロジェクト: hyb1234hi/argus
def crnnSource(net_path, alphabet):
    """Return ``(model, converter)`` for a CRNN loaded from ``net_path``.

    ``cfg.PLATFORM == "GPU"`` selects a CUDA model loaded with the
    checkpoint's own device mapping; any other value keeps the model on the
    default device and deserializes the checkpoint onto CPU storage.
    """
    converter = util.strLabelConverter(alphabet)

    if cfg.PLATFORM != "GPU":
        cpu_model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1)
        cpu_model.load_state_dict(
            torch.load(net_path, map_location=lambda storage, loc: storage))
        return cpu_model, converter

    gpu_model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1).cuda()
    gpu_model.load_state_dict(torch.load(net_path))
    return gpu_model, converter
コード例 #9
0
ファイル: network_torch.py プロジェクト: miketes/OCR_Invoice
    def predict(self, image):
        """Recognize the text in a single line image.

        Args:
            image: input accepted by ``resizeNormalize(image, 32)``
                (presumably a 2-D grayscale array -- confirm).

        Returns:
            The result of ``strLabelConverter`` on the per-step argmax labels.
        """
        image = resizeNormalize(image, 32).astype(np.float32)
        tensor = torch.from_numpy(image)
        if torch.cuda.is_available() and self.GPU:
            tensor = tensor.cuda()
        else:
            tensor = tensor.cpu()

        # Prepend batch and channel dimensions of size 1.
        tensor = tensor.view(1, 1, *tensor.size())
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            preds = self(tensor)
        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        return strLabelConverter(preds, self.alphabet)
コード例 #10
0
    def predict_batch(self, boxes, batch_size=1):
        """
        predict on batch

        Recognize the text of every box, running ``batch_size`` boxes per
        forward pass. Each ``box['img']`` goes through
        ``resizeNormalize(img, 32)``; the images of a batch are right-padded
        with zeros to the widest one.

        Args:
            boxes: sequence of dicts with an ``'img'`` entry; each dict gets a
                ``'text'`` entry added in place.
            batch_size: number of boxes per forward pass; must be >= 1.

        Returns:
            ``boxes`` (the same object), with ``'text'`` filled in.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(
                "batch_size must be >= 1, got {}".format(batch_size))

        res = []
        for start in range(0, len(boxes), batch_size):
            images = [resizeNormalize(box['img'], 32)
                      for box in boxes[start:start + batch_size]]
            # Zero-pad every image on the right to the widest one in the batch.
            imgW = max(img.shape[1] for img in images)
            imageArray = np.zeros((len(images), 1, 32, imgW),
                                  dtype=np.float32)
            for j, img in enumerate(images):
                imageArray[j, 0, :, :img.shape[1]] = img

            image = torch.from_numpy(imageArray)
            if torch.cuda.is_available() and self.GPU:
                image = image.cuda()
            else:
                image = image.cpu()

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                preds = self(image).argmax(2)
            # Column j holds the per-step labels of sample j in this batch.
            for j in range(preds.shape[1]):
                res.append(strLabelConverter(preds[:, j], self.alphabet))

        for i, box in enumerate(boxes):
            box['text'] = res[i]
        return boxes
コード例 #11
0
    def job(self, filenames, res_save=False):
        """Detect text regions in every image and run OCR on each region.

        Detection runs batch-wise through ``textdetect.YoloImageGenerator``
        (batch size ``yolo_batchsize``); the crops of each image are then
        recognized in batches of 16 with ``GPU=False``.

        Args:
            filenames: image paths to process; stored on ``self.filenames``.
            res_save: when true, also write the result to
                ``result/result.json`` (the directory is created if missing).

        Returns:
            dict mapping each filename to its list of decoded text lines.
        """
        import os

        self.filenames = filenames
        self.yolo_dataloader = textdetect.YoloImageGenerator(
            self.filenames, batch_size=yolo_batchsize)
        result = {}
        start = time.time()
        print("\n" + "#" * 30 + f"开始检测,时间{start}" + "#" * 30 + "\n")
        print("\n" + f"[INFO] 一共{len(self.filenames)}张图片" + "\n")
        for batch_img, batch_shape, batch_shape_padded, batch_filenames in tqdm(
                self.yolo_dataloader):
            # Run text detection on one batch of the input images.
            batch_preds = self.text_detector(batch_img)
            # NOTE(review): batch_shape_padded is passed for both shape
            # arguments and batch_shape is never used -- check
            # net_output_process's signature to see whether the unpadded
            # shape was intended for one of them.
            batch_boxes, batch_scores = net_output_process(
                batch_preds, batch_shape_padded, batch_shape_padded)
            for img, filename, boxes in zip(batch_img, batch_filenames,
                                            batch_boxes):
                # OCR the cropped regions of each image in turn.
                result[filename] = []
                partImgs = cut_batch(img, filename, boxes, save=False)
                temp_partImg_loader = ocr.OcrDataGenerator(partImgs,
                                                           batch_size=16,
                                                           GPU=False)
                for batch_partImg in temp_partImg_loader:
                    # Model output: size of [seq_len, batchsize, nclass].
                    preds = self.ocr(batch_partImg)
                    preds = preds.argmax(axis=2)
                    # -> [batchsize, seq_len]: one row of labels per crop.
                    preds = preds.permute(1, 0)
                    for line in preds:
                        # Decode one text line at a time.
                        result[filename].append(
                            strLabelConverter(line, self.ocr.alphabet))
        end = time.time()
        elapsed = end - start
        # An empty filename list would otherwise divide by zero here.
        n_files = len(self.filenames)
        average = elapsed / n_files if n_files else 0.0
        print("\n" + "#" * 30 + f"结束检测,时间{end}" + "#" * 30 + "\n")
        print(
            "\n" +
            f"[INFO]一共用时{elapsed}秒,每张图片平均用时{average}秒。"
            + "\n")
        print(result)

        if res_save:
            # Create the output directory so a finished run isn't lost to a
            # missing folder.
            os.makedirs('result', exist_ok=True)
            with open('result/result.json', 'w') as f:
                json.dump(result, f)
        return result
コード例 #12
0
def crnnSource():
    """Build the CRNN recognizer from the module-level ``ocrModel`` checkpoint.

    Reads the module-level ``GPU`` flag and the ``keys.alphabet`` set.

    Returns:
        ``(model, converter)``: weights loaded, model in eval mode.
    """
    alphabet = keys.alphabet
    converter = util.strLabelConverter(alphabet)
    if torch.cuda.is_available() and GPU:
        model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1).cuda()
    else:
        model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1).cpu()

    # Deserialize onto CPU. Without map_location, a checkpoint saved from a
    # GPU cannot be loaded on a CPU-only machine (the else-branch case above);
    # load_state_dict copies the tensors onto the model's device afterwards.
    state_dict = torch.load(ocrModel,
                            map_location=lambda storage, loc: storage)
    # Drop "module." (added by torch.nn.DataParallel) from parameter names.
    new_state_dict = OrderedDict(
        (k.replace('module.', ''), v) for k, v in state_dict.items())

    model.load_state_dict(new_state_dict)
    model.eval()
    return model, converter
コード例 #13
0
def crnn_single(img, path='./crnn/samples/model_acc97.pth'):
    """Recognize the text in one image with a CPU CRNN model.

    NOTE(review): the model is rebuilt and its weights reloaded on every call.

    Args:
        img: anything ``np.array`` can turn into an image for
            ``Image.fromarray``; it is converted to grayscale.
        path: checkpoint to load (defaults to the previously hard-coded
            ``'./crnn/samples/model_acc97.pth'``).

    Returns:
        The decoded text, lower-cased and passed through ``deletedot``.
    """
    alphabet = keys_crnn.alphabet
    converter = util.strLabelConverter(alphabet)
    model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1)
    # The model stays on CPU, so deserialize onto CPU as well; otherwise a
    # checkpoint saved from a GPU fails to load on a CPU-only machine.
    model.load_state_dict(
        torch.load(path, map_location=lambda storage, loc: storage))
    model.eval()

    image = Image.fromarray(np.array(img)).convert('L')
    # Resize to height 32, scaling the width by the same factor.
    scale = image.size[1] * 1.0 / 32
    w = int(image.size[0] / scale)
    transformer = dataset.resizeNormalize((w, 32))
    image = transformer(image)
    image = image.view(1, *image.size())  # add a leading batch dimension

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        preds = model(image)
    _, preds = preds.max(2)
    preds = preds.squeeze(1)
    preds = preds.transpose(-1, 0).contiguous().view(-1)

    preds_size = torch.IntTensor([preds.size(0)])
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
    return deletedot(sim_pred.lower())
コード例 #14
0
ファイル: network_keras.py プロジェクト: miketes/OCR_Invoice
    def predict_batch(self,boxes,batch_size=1):
        """
        predict on batch

        Recognize every box's ``'img'`` with the Keras model, ``batch_size``
        boxes at a time, zero-padding each batch on the right to its widest
        normalized image. The decoded text is stored in ``box['text']`` and
        ``boxes`` is returned.
        """
        global graph
        total = len(boxes)
        n_batches, leftover = divmod(total, batch_size)
        if leftover:
            n_batches += 1  # a final, partially filled batch

        texts = []
        for b in range(n_batches):
            chunk = boxes[b * batch_size:(b + 1) * batch_size]
            normalized = [resizeNormalize(box['img'], 32) for box in chunk]
            padded_w = max([img.shape[1] for img in normalized], default=0)

            batch_arr = np.zeros((len(normalized), 1, 32, padded_w),
                                 dtype=np.float32)
            for j, img in enumerate(normalized):
                single = np.array([img])
                batch_arr[j][:, :, :single.shape[2]] = single

            with graph.as_default():
                scores = self.model.predict(batch_arr, batch_size=batch_size)

            labels = scores.argmax(axis=2)
            texts.extend(strLabelConverter(row.tolist(), self.alphabet)
                         for row in labels)

        for idx in range(total):
            boxes[idx]['text'] = texts[idx]
        return boxes
コード例 #15
0
    def __init__(self, crnn_model):
        """Prepare data, optimizer and loss for training ``crnn_model``.

        Images matching ``../data/ocr/*/*.jpg`` are split with
        ``train_test_split(test_size=0.1)``; the training part is served by a
        DataLoader whose collate function targets height 32 / width 280 with
        aspect ratio kept.
        """
        self.crnn_model = crnn_model

        # Network constants / hyper-parameters.
        self.batchSize = 2
        self.nepochs = 10
        self.acc = 0
        num_workers = 1
        img_h, img_w = 32, 280
        keep_ratio = True
        learning_rate = 0.1

        # Pre-allocated buffers (torch.FloatTensor/IntTensor with sizes leave
        # the contents uninitialized).
        # NOTE(review): both spatial sizes use img_h -- presumably the buffer
        # is resized before use; confirm.
        self.image = torch.FloatTensor(self.batchSize, 3, img_h, img_h)
        self.text = torch.IntTensor(self.batchSize * 5)
        self.length = torch.IntTensor(self.batchSize)
        self.converter = strLabelConverter(''.join(alphabetChinese))
        self.optimizer = optim.Adadelta(crnn_model.parameters(),
                                        lr=learning_rate)

        all_paths = glob('../data/ocr/*/*.jpg')
        # This split does not try to balance characters between the subsets.
        train_paths, test_paths = train_test_split(all_paths, test_size=0.1)
        train_set = PathDataset(train_paths, alphabetChinese)
        self.testdataset = PathDataset(test_paths, alphabetChinese)
        self.criterion = CTCLoss()

        collate = alignCollate(imgH=img_h, imgW=img_w, keep_ratio=keep_ratio)
        self.train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=self.batchSize,
            shuffle=False,
            sampler=None,
            num_workers=int(num_workers),
            collate_fn=collate)
        # Model-evaluation interval, in batches: half the loader's length.
        self.interval = len(self.train_loader) // 2
コード例 #16
0
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
from crnn import keys
from crnn import util
from crnn import dataset
from crnn.models import crnn as crnn
import torch
import torch.utils.data
from collections import OrderedDict
from PIL import Image
from torch.autograd import Variable

# ---- Model setup (runs at import time) ----
# Chinese character set; the network is built with len(alphabet) + 1 output
# classes (presumably the extra one is the CTC blank -- confirm in crnn.models).
alphabet = keys.alphabetChinese
# lstmFlag=False presumably selects the non-LSTM ("dense") head; the checkpoint
# name below ends in "_dense". TODO confirm.
LSTMFLAG = False

converter = util.strLabelConverter(alphabet)
# The model is built on CPU.
model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1, lstmFlag=LSTMFLAG).cpu()
ocrModel = './models/epoch9_step7000_model_dense.pth'
# ocrModel = './models/ocr-dense.pth'
# map_location returns each storage unchanged, i.e. keeps every tensor on the
# CPU even if the checkpoint was saved from a GPU.
state_dict = torch.load(ocrModel, map_location=lambda storage, loc: storage)
# Remove "module." (added when weights are saved from a torch.nn.DataParallel
# wrapper) from every parameter name so the keys match the bare model.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k.replace('module.', '')  # remove `module.`
    new_state_dict[name] = v
# load params

model.load_state_dict(new_state_dict)
model.eval()
# ---- Single-image inference ----
imgpath = 'ss_350.png'
# Load as single-channel grayscale ("L" mode).
image = Image.open(imgpath).convert('L')
# Factor that maps the image height to 32 px; presumably the width is divided
# by it next (this excerpt ends here).
scale = image.size[1] * 1.0 / 32
コード例 #17
0
# Re-initialize all layers, then overlay the pretrained weights below.
model.apply(weights_init)
preWeightDict = torch.load(
    ocrModel, map_location=lambda storage, loc: storage)  ## load the project's trained weights (onto CPU storage)

modelWeightDict = model.state_dict()

# Copy pretrained tensors into the model's own state dict, stripping the
# "module." (DataParallel) prefix from names.
for k, v in preWeightDict.items():
    name = k.replace('module.', '')  # remove `module.`
    if 'rnn.1.embedding' not in name:  ## skip the last layer's weights (presumably because the output alphabet differs -- alphabetEnglish below)
        modelWeightDict[name] = v

model.load_state_dict(modelWeightDict)

# Optimizer, label converter and CTC loss.
lr = 0.1
optimizer = optim.Adadelta(model.parameters(), lr=lr)
converter = strLabelConverter(''.join(alphabetEnglish))
criterion = CTCLoss()

# Pre-allocated input/label buffers (uninitialized contents).
# NOTE(review): both spatial sizes use imgH -- presumably resized before use.
image = torch.FloatTensor(batchSize, 3, imgH, imgH)
text = torch.IntTensor(batchSize * 5)
length = torch.IntTensor(batchSize)

if torch.cuda.is_available():
    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=[0])  ## wrap in DataParallel for multi-GPU training (note: device_ids=[0] uses only GPU 0)
    image = image.cuda()
    criterion = criterion.cuda()


def trainBatch(net, criterion, optimizer, cpu_images, cpu_texts):
    # data = train_iter.next()
コード例 #18
0
modelWeightDict = model.state_dict()

# Copy pretrained tensors into the model's own state dict, stripping the
# "module." (DataParallel) prefix from names.
for k, v in preWeightDict.items():
    name = k.replace('module.', '')  # remove `module.`
    if 'rnn.1.embedding' not in name:  ## skip the last layer's weights (presumably so the output layer is retrained)
        modelWeightDict[name] = v

model.load_state_dict(modelWeightDict)



## Optimizer
from crnn.util import strLabelConverter
lr = 0.1
optimizer = optim.Adadelta(model.parameters(), lr=lr)
converter = strLabelConverter(''.join(alphabetChinese))
criterion = CTCLoss()


from train.ocr.dataset import resizeNormalize
from crnn.util import loadData
# Pre-allocated input/label buffers (uninitialized contents).
# NOTE(review): both spatial sizes use imgH -- presumably resized before use.
image = torch.FloatTensor(batchSize, 3, imgH, imgH)
text = torch.IntTensor(batchSize * 5)
length = torch.IntTensor(batchSize)

if torch.cuda.is_available():
    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=[0])## wrap in DataParallel for multi-GPU training (note: device_ids=[0] uses only GPU 0)
    image = image.cuda()
    criterion = criterion.cuda()
コード例 #19
0
CRNN_API_URL = "http://text_recognition:8501/v1/models/crnn:predict"
# Seconds to wait for the model server; requests has no default timeout and
# would otherwise block indefinitely on an unresponsive server.
REQUEST_TIMEOUT = 30

# ---------------Alphabet---------------

alphabet = alphabetChinese
nclass = len(alphabet) + 1

# ---------------Process image---------------

image = cv2.imread(IMAGE_PATH)
if image is None:
    # cv2.imread reports an unreadable/missing file by returning None.
    raise FileNotFoundError("cannot read image: {}".format(IMAGE_PATH))
# OpenCV returns channels in BGR order while PIL assumes RGB; convert first so
# convert('L') applies the luma weights to the correct channels.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = Image.fromarray(image)
image = image.convert('L')
image = resizeNormalize(image, 32)
image = image.astype(np.float32)
image = np.array([image])  # add a leading dimension of size 1

# ---------------Build post---------------

post_json = {"instances": [{"input_image": image.tolist()}]}

# ---------------Test---------------

t0 = time.time()
response = requests.post(CRNN_API_URL, data=json.dumps(post_json),
                         timeout=REQUEST_TIMEOUT)
print("forward time : {}".format(time.time() - t0))
response.raise_for_status()
prediction = response.json()["predictions"]
print(prediction)
raw = strLabelConverter(prediction[0], alphabet)
print(raw)