Code example #1
File: ocr.py Project: hedinang/faster-torch
    def predict(self, model, vocab, seq, key, idx, img):

        img = process_input(img, self.config['dataset']['image_height'],
                            self.config['dataset']['image_min_width'],
                            self.config['dataset']['image_max_width'])
        img = img.to(self.config['device'])
        with torch.no_grad():
            # Encode the image once, then decode greedily one token at a time.
            src = model.cnn(img)
            memory = model.transformer.forward_encoder(src)
            # Every sequence starts with token 1 (the start-of-sequence index here).
            translated_sentence = [[1] * len(img)]
            max_length = 0
            # Stop after 128 steps or once every sequence has emitted token 2
            # (the end-of-sequence index).
            while max_length <= 128 and not all(
                    np.any(np.asarray(translated_sentence).T == 2, axis=1)):
                tgt_inp = torch.LongTensor(translated_sentence).to(self.device)
                output = model.transformer.forward_decoder(tgt_inp, memory)
                output = output.to('cpu')
                # Keep only the highest-scoring token at the newest position.
                values, indices = torch.topk(output, 5)
                indices = indices[:, -1, 0]
                indices = indices.tolist()
                translated_sentence.append(indices)
                max_length += 1
                del output
            translated_sentence = np.asarray(translated_sentence).T
        s = translated_sentence[0].tolist()
        s = vocab.decode(s)
        seq[idx] = s
Code example #2
File: predictor.py Project: PhatDatPQ/vietocr
    def predict(self, img):
        img = process_input(img)
        img = img.to(self.config['device'])

        s = translate(img, self.model)[0].tolist()
        s = self.vocab.decode(s)

        return s
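For reference, a predict() method like the one above is normally reached through the library's Predictor wrapper. The sketch below is a minimal, assumed usage based on the standard vietocr API (Cfg.load_config_from_name and Predictor); the config name, device, and image path are illustrative and should be adjusted to your setup.

from PIL import Image
from vietocr.tool.config import Cfg
from vietocr.tool.predictor import Predictor

# Assumption: the 'vgg_transformer' pretrained config is available in this
# vietocr version.
config = Cfg.load_config_from_name('vgg_transformer')
config['device'] = 'cpu'

detector = Predictor(config)
# predict() expects a PIL image of a single cropped text line.
img = Image.open('line.png')
print(detector.predict(img))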
Code example #3
    def predict(self, img):
        img = process_input(img, self.config['dataset']['image_height'],
                            self.config['dataset']['image_min_width'],
                            self.config['dataset']['image_max_width'])

        img = img.to(self.config['device'])
        #
        s, _ = translate(img, self.model)
        s = s[0].tolist()
        s = self.vocab.decode(s)
        return s
Code example #4
File: predictor.py Project: tienthienhd/vietocr
    def predict(self, img):
        img = process_input(img, self.config['dataset']['image_height'],
                            self.config['dataset']['image_min_width'],
                            self.config['dataset']['image_max_width'])
        img = img.to(self.config['device'])

        if self.config['predictor']['beamsearch']:
            sent = translate_beam_search(img, self.model)
            s = sent
        else:
            s = translate(img, self.model)[0].tolist()

        s = self.vocab.decode(s)

        return s
Code example #5
    def predict(self, img, return_prob=False):
        img = process_input(img, self.config['dataset']['image_height'],
                            self.config['dataset']['image_min_width'],
                            self.config['dataset']['image_max_width'])
        img = img.to(self.config['device'])

        if self.config['predictor']['beamsearch']:
            sent = translate_beam_search(img, self.model)
            s = sent
            prob = None
        else:
            s, prob = translate(img, self.model)
            s = s[0].tolist()
            prob = prob[0]

        s = self.vocab.decode(s)

        if return_prob:
            return s, prob
        else:
            return s
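A short usage sketch for the return_prob path, reusing the hypothetical detector from the earlier sketch; the 0.5 threshold is purely illustrative. Note that when beam search is enabled no probability is computed, so prob may be None.

# Decode the line and get the sequence probability alongside the text.
text, prob = detector.predict(img, return_prob=True)

# Illustrative post-filter: drop low-confidence results.
if prob is not None and prob < 0.5:
    text = ''
print(text, prob)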
Code example #6
    def process(self, craft, model, seq, key, sub_img):
        img_resized, target_ratio, size_heatmap = resize_aspect_ratio(
            sub_img, 2560, interpolation=cv2.INTER_LINEAR, mag_ratio=1.)
        ratio_h = ratio_w = 1 / target_ratio

        x = normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = x.unsqueeze(0)  # [c, h, w] to [b, c, h, w]
        x = x.to(self.device)
        y, feature = craft(x)
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()
        boxes, polys = getDetBoxes(score_text,
                                   score_link,
                                   text_threshold=0.7,
                                   link_threshold=0.4,
                                   low_text=0.4,
                                   poly=False)
        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None:
                polys[k] = boxes[k]
        result = []
        for i, box in enumerate(polys):
            poly = np.array(box).astype(np.int32).reshape((-1))
            result.append(poly)
        horizontal_list, free_list = group_text_box(result,
                                                    slope_ths=0.8,
                                                    ycenter_ths=0.5,
                                                    height_ths=1,
                                                    width_ths=1,
                                                    add_margin=0.1)
        # horizontal_list = [i for i in horizontal_list if i[0] > 0 and i[1] > 0]
        min_size = 20
        if min_size:
            horizontal_list = [
                i for i in horizontal_list
                if max(i[1] - i[0], i[3] - i[2]) > 10
            ]
            free_list = [
                i for i in free_list
                if max(diff([c[0] for c in i]), diff([c[1]
                                                      for c in i])) > min_size
            ]
        seq[:] = [None] * len(horizontal_list)

        for i, ele in enumerate(horizontal_list):
            # Clamp negative box coordinates to zero (use a separate name so the
            # loop index i is not shadowed).
            ele = [0 if c < 0 else c for c in ele]
            img = sub_img[ele[2]:ele[3], ele[0]:ele[1], :]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img.astype(np.uint8))
            img = process_input(img, self.config['dataset']['image_height'],
                                self.config['dataset']['image_min_width'],
                                self.config['dataset']['image_max_width'])
            img = img.to(self.config['device'])
            with torch.no_grad():
                src = model.cnn(img)
                memory = model.transformer.forward_encoder(src)
                translated_sentence = [[1] * len(img)]
                max_length = 0
                while max_length <= 128 and not all(
                        np.any(np.asarray(translated_sentence).T == 2,
                               axis=1)):
                    tgt_inp = torch.LongTensor(translated_sentence).to(
                        self.device)
                    output = model.transformer.forward_decoder(tgt_inp, memory)
                    output = output.to('cpu')
                    values, indices = torch.topk(output, 5)
                    indices = indices[:, -1, 0]
                    indices = indices.tolist()
                    translated_sentence.append(indices)
                    max_length += 1
                    del output
                translated_sentence = np.asarray(translated_sentence).T
            s = translated_sentence[0].tolist()
            s = self.vocab.decode(s)
            seq[i] = s  # store the decoded text for this box
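A hypothetical driver for process(), assuming craft_net is an already-loaded CRAFT detection model, ocr_model comes from build_model(config) as in code example #7, and detector is an instance of the class that defines process(); all three names are illustrative.

import cv2

sub_img = cv2.imread('page.png')
seq = []   # process() resizes and fills this list with one string per detected line
detector.process(craft_net, ocr_model, seq, key='page', sub_img=sub_img)
for line in seq:
    print(line)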
Code example #7
# Assumed imports and config setup (not part of the original snippet), based on
# the vietocr package layout; adjust to the local project.
import cv2
import numpy as np
import torch
from PIL import Image
from vietocr.tool.config import Cfg
from vietocr.tool.translate import build_model, process_input

config = Cfg.load_config_from_name('vgg_transformer')
config['device'] = 'cpu'
config['predictor']['beamsearch'] = False
model, vocab = build_model(config)
weights = 'transformerocr.pth'
device = torch.device('cpu')
# if config['weights'].startswith('http'):
#     weights = download_weights(config['weights'])
# else:
#     weights = config['weights']
model.load_state_dict(torch.load(weights, map_location=torch.device('cpu')))
sub_img = cv2.imread('5.png')
# cv2.imshow('aa',sub_img)
# cv2.waitKey(0)
img = cv2.cvtColor(sub_img, cv2.COLOR_BGR2RGB)
img = Image.fromarray(img.astype(np.uint8))
img = process_input(img, config['dataset']['image_height'],
                    config['dataset']['image_min_width'],
                    config['dataset']['image_max_width'])
img = img.to(config['device'])
with torch.no_grad():
    src = model.cnn(img)
    memory = model.transformer.forward_encoder(src)
    translated_sentence = [[1] * len(img)]
    max_length = 0
    while max_length <= 128 and not all(
            np.any(np.asarray(translated_sentence).T == 2, axis=1)):
        tgt_inp = torch.LongTensor(translated_sentence).to(device)
        output = model.transformer.forward_decoder(tgt_inp, memory)
        output = output.to('cpu')
        values, indices = torch.topk(output, 5)
        indices = indices[:, -1, 0]
        indices = indices.tolist()
        translated_sentence.append(indices)
        max_length += 1
        del output
    translated_sentence = np.asarray(translated_sentence).T
# Decode the first sequence back into text (same pattern as code example #1).
s = translated_sentence[0].tolist()
s = vocab.decode(s)