Beispiel #1
0
class TextExtractor:
    """Thin wrapper around one Tesseract handle bound to a single image file.

    Instances can be used as context managers; leaving the ``with`` block
    calls :meth:`close`, which ends the underlying handle.
    """

    def __init__(self, image_path, seg_mode=PSM.SPARSE_TEXT):
        """Create the handle, set its page-segmentation mode and load the image.

        Args:
            image_path: Path handed to ``SetImageFile``.
            seg_mode: Page-segmentation mode handed to ``SetPageSegMode``.
        """
        self.api = PyTessBaseAPI()
        try:
            self.api.SetPageSegMode(seg_mode)
            self.api.SetImageFile(image_path)
        except Exception:
            # Do not leak the native handle when configuration fails.
            self.api.End()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def _extract(self) -> Tuple:
        """Return ``(text, mean_confidence)`` for the currently selected area."""
        text = self.api.GetUTF8Text()
        conf = self.api.MeanTextConf()
        return text, conf

    def _extract_from_rect(self, x, y, w, h) -> Tuple:
        """Restrict recognition to the given rectangle, then extract."""
        self.api.SetRectangle(x, y, w, h)
        return self._extract()

    def extract(self, x=None, y=None, w=None, h=None) -> Tuple:
        """Extract text and mean confidence, optionally from a sub-rectangle.

        The rectangle is used only when all four values are supplied. A value
        of ``0`` counts as supplied, so regions touching the top/left edge of
        the image are honoured.

        Returns:
            A ``(text, mean_confidence)`` tuple.
        """
        # Compare against None explicitly: a plain truthiness test would
        # discard rectangles whose x or y is 0.
        if any(value is None for value in (x, y, w, h)):
            return self._extract()
        return self._extract_from_rect(x, y, w, h)

    def close(self):
        """End the underlying Tesseract handle."""
        self.api.End()
Beispiel #2
0
def _ocr_region(tesapi, x, y, width, height):
    """Run OCR on one rectangle and return the text flattened to a single line."""
    tesapi.SetRectangle(int(x), int(y), int(width), int(height))
    return tesapi.GetUTF8Text().strip().replace('\n', ' ')


def add_ocrinfo(tree, imgfile):
    """Store an ``'ocr'`` string on every node of ``tree`` using ``imgfile``.

    Node coordinates are scaled by ``image width / config.width`` and padded
    by one pixel on each side (clamped to the image). Nodes that have
    children and empty text get an empty ``'ocr'`` value without running OCR,
    as do regions that are 3 pixels or smaller in either dimension. When the
    first pass yields nothing, the region is shrunk to its central 90% and
    read once more.

    Args:
        tree: Mapping of node id to node dict (reads ``children``, ``text``,
            ``x``, ``y``, ``width``, ``height``; writes ``ocr``).
        imgfile: Path of the image to recognise.
    """
    # Both the image file and the Tesseract handle are released on exit,
    # including when OCR raises part-way through the tree.
    with Image.open(imgfile) as imgpil, PyTessBaseAPI(lang='eng') as tesapi:
        orig_width, orig_height = imgpil.width, imgpil.height
        ratio = 1.0 * orig_width / config.width

        tesapi.SetImage(imgpil)
        tesapi.SetSourceResolution(config.ocr_resolution)

        for nodeid in tree:
            node = tree[nodeid]

            if node['children'] and node['text'] == '':
                node['ocr'] = ''
                continue

            x = max(node['x'] * ratio - 1, 0)
            y = max(node['y'] * ratio - 1, 0)
            x2 = min((node['x'] + node['width']) * ratio + 1, orig_width)
            y2 = min((node['y'] + node['height']) * ratio + 1, orig_height)
            width = int(x2 - x)
            height = int(y2 - y)

            if width <= 3 or height <= 3:
                node['ocr'] = ''
                continue

            ocr_text = _ocr_region(tesapi, x, y, width, height)
            if ocr_text.strip() == '':
                # Retry on the inner 90% of the box in case the border
                # confused the recogniser.
                x = min(x + width * 0.05, orig_width)
                y = min(y + height * 0.05, orig_height)
                ocr_text = _ocr_region(tesapi, x, y, width * 0.9,
                                       height * 0.9)

            node['ocr'] = ocr_text
Beispiel #3
0
def preprocess_title(filename):
    """Return the first text line of the image that has at least 5 characters.

    Each detected text line is recognised in turn; its text is reduced to the
    matches of ``alpha_re`` joined by single spaces. The first result of
    length 5 or more is logged and returned.

    Args:
        filename: Path of the image handed to ``SetImageFile``.

    Returns:
        The title text, or ``''`` when no line qualifies.
    """
    title = ''
    # The context manager ends the Tesseract handle even if OCR raises.
    with PyTessBaseAPI() as api:
        api.SetImageFile(filename)
        boxes = api.GetComponentImages(RIL.TEXTLINE, True)
        for _, box, _, _ in boxes:
            api.SetRectangle(box['x'], box['y'], box['w'], box['h'])
            ocr_result = api.GetUTF8Text()
            text = ' '.join(alpha_re.findall(ocr_result.strip()))
            if len(text) >= 5:
                title = text
                break

    if title:
        logger.info("%s: %s", filename, title)
    return title
class TesseractOCR:
    """Runs Tesseract over one image and caches word, block and line results.

    The constructor immediately performs OCR at word, block and line level
    (see :meth:`initOCR`). Call :meth:`destroy` when done to end the
    Tesseract handle.
    """

    # bpp - bits per pixel, represents the bit depth of the image, with 1 for
    # binary bitmap, 8 for gray, and 24 for color RGB.
    BBP = 8
    DEFAULT_CONFIDENT_THRESHOLD = 60.0
    MINIMUM_DESKEW_THRESHOLD = 0.05

    def __init__(self, rgbaImage, dipCalculator, language):
        """Store the inputs, create the Tesseract handle and run the initial OCR.

        Args:
            rgbaImage: Image array; its ``shape`` is read and it is converted
                with ``Image.fromarray``.
            dipCalculator: Object providing ``pxToHeightDip``,
                ``dipToHeightPx`` and ``dipToWidthPx``.
            language: Accepted for interface compatibility; currently unused.
        """
        self.mRgbaImage = rgbaImage
        self.mDipCalculator = dipCalculator
        self.mHandle = PyTessBaseAPI()

        self.mOcrTextWrappers = []
        self.mOcrBlockWrappers = []
        self.mOcrLineWrappers = []
        self.raWrappers = []

        self.mBufferedImageRgbaImage = Image.fromarray(self.mRgbaImage)
        self.initOCR()

    def _swapHandle(self, psm):
        """Install a fresh handle using ``psm`` and end the previous one."""
        previous = self.mHandle
        self.mHandle = PyTessBaseAPI(psm=psm)
        # End the old handle only after the new one exists, so self.mHandle
        # never points at an ended handle if construction fails.
        previous.End()

    @staticmethod
    def _copyFontAttributes(target, fontAttribute):
        """Copy Tesseract word font attributes onto ``target`` (no-op for None)."""
        if fontAttribute is None:
            return
        target.fontName = fontAttribute['font_name']
        target.bold = fontAttribute['bold']
        target.italic = fontAttribute['italic']
        target.underlined = fontAttribute['underlined']
        target.monospace = fontAttribute['monospace']
        target.serif = fontAttribute['serif']
        target.smallcaps = fontAttribute['smallcaps']
        target.fontSize = fontAttribute['pointsize']
        target.fontId = fontAttribute['font_id']

    def baseInit(self, iteratorLevel):
        """Run :meth:`baseInitIter` over the whole image at ``iteratorLevel``."""
        shape = self.mRgbaImage.shape
        if len(shape) == 2:
            height, width = shape
            channels = 1
        else:
            height, width, channels = shape

        return self.baseInitIter(self.mRgbaImage, Rect(0, 0, width, height),
                                 channels, iteratorLevel)

    def baseInitIter(self, imageMat, rect, channels, iteratorLevel):
        """Recognise every component of ``imageMat`` at ``iteratorLevel``.

        Args:
            imageMat: Image array passed to ``Image.fromarray``.
            rect: Parent rectangle; its ``x``/``y`` are added to each box.
            channels: Unused; kept for interface compatibility.
            iteratorLevel: Level passed to ``GetComponentImages``.

        Returns:
            A list of ``OCRTextWrapper`` objects, one per component.
        """
        listdata = []
        parentX = rect.x
        parentY = rect.y

        self.mHandle.SetImage(Image.fromarray(imageMat))
        boxes = self.mHandle.GetComponentImages(iteratorLevel, True)

        for _, box, _, _ in boxes:
            wrapper = OCRTextWrapper.OCRTextWrapper()
            self.mHandle.SetRectangle(box['x'], box['y'], box['w'], box['h'])
            wrapper.text = self.mHandle.GetUTF8Text()
            wrapper.confidence = self.mHandle.MeanTextConf()
            self.mHandle.Recognize()
            iterator = self.mHandle.GetIterator()
            fontAttribute = iterator.WordFontAttributes()
            wrapper.x = box['x'] + parentX
            wrapper.y = box['y'] + parentY
            wrapper.width = box['w']
            wrapper.height = box['h']
            wrapper.rect = Rect(wrapper.x, wrapper.y, wrapper.width,
                                wrapper.height)
            self._copyFontAttributes(wrapper, fontAttribute)
            listdata.append(wrapper)

        return listdata

    def getBlockWithLocation(self, rect):
        """Return copies of the cached block wrappers contained in ``rect``."""
        return [
            OCRTextWrapper.OCRTextWrapper(blockWrapper)
            for blockWrapper in self.mOcrBlockWrappers
            if RectUtil.contains(rect, blockWrapper.rect)
        ]

    def getImage(self, rect):
        """Return the ``rect`` region of the source array as a PIL image."""
        x2 = rect.x + rect.width
        y2 = rect.y + rect.height
        mat = self.mRgbaImage[rect.y:y2, rect.x:x2]
        return Image.fromarray(mat)

    def getText(self, rect):
        """Return the text recognised inside ``rect``, or ``""`` on error."""
        try:
            self.mHandle.SetImage(self.mBufferedImageRgbaImage)
            self.mHandle.SetRectangle(rect.x, rect.y, rect.width, rect.height)
            return self.mHandle.GetUTF8Text()
        except Exception as error:
            # Best effort: report and fall through to the empty result.
            print('Caught this error: ' + repr(error))

        return ""

    def getLineText(self, rect):
        """Return the text inside ``rect``, retrying in single-line mode.

        When the first attempt yields empty text, a SINGLE_LINE handle is
        tried on the same rectangle and then on the cropped image. The
        handle is always put back into AUTO mode afterwards. Returns ``""``
        on error.
        """
        try:
            self.mHandle.SetImage(self.mBufferedImageRgbaImage)
            self.mHandle.SetRectangle(rect.x, rect.y, rect.width, rect.height)
            text = self.mHandle.GetUTF8Text()
            if not TextUtils.isEmpty(text):
                return text

            self._swapHandle(PSM.SINGLE_LINE)
            try:
                self.mHandle.SetImage(self.mBufferedImageRgbaImage)
                self.mHandle.SetRectangle(rect.x, rect.y, rect.width,
                                          rect.height)
                text = self.mHandle.GetUTF8Text()
                if TextUtils.isEmpty(text):
                    self.mHandle.SetImage(self.getImage(rect))
                    text = self.mHandle.GetUTF8Text()
            finally:
                # Restore the default mode even if recognition raised.
                self._swapHandle(PSM.AUTO)
            return text
        except Exception as error:
            print('Caught this error: ' + repr(error))

        return ""

    def getRectWordForLowConfidence(self, ocr):
        """Re-recognise ``ocr`` in single-word mode and update it in place.

        ``ocr.text`` and ``ocr.confidence`` are overwritten; font attributes
        are copied only when the confidence ends up above
        ``Constants.TEXT_CONFIDENT_THRESHOLD``.

        Returns:
            True when the confidence ends up above the threshold, False
            otherwise or on error.
        """
        try:
            rect = ocr.bound()
            self._swapHandle(PSM.SINGLE_WORD)
            try:
                self.mHandle.SetImage(self.mBufferedImageRgbaImage)
                self.mHandle.SetRectangle(rect.x, rect.y, rect.width,
                                          rect.height)
                ocr.text = self.mHandle.GetUTF8Text()
                ocr.confidence = self.mHandle.MeanTextConf()
                if ocr.confidence <= Constants.TEXT_CONFIDENT_THRESHOLD:
                    # Second chance on the cropped image.
                    self.mHandle.SetImage(self.getImage(rect))
                    ocr.text = self.mHandle.GetUTF8Text()
                    ocr.confidence = self.mHandle.MeanTextConf()
                if ocr.confidence <= Constants.TEXT_CONFIDENT_THRESHOLD:
                    return False
                self.mHandle.Recognize()
                iterator = self.mHandle.GetIterator()
                self._copyFontAttributes(ocr, iterator.WordFontAttributes())
                return True
            finally:
                # Without this, an early return or an exception would leave
                # the shared handle stuck in SINGLE_WORD mode.
                self._swapHandle(PSM.AUTO)
        except Exception as error:
            print('Caught this error: ' + repr(error))

        return False

    def getWordsIn(self, rect):
        """Return copies of the cached word wrappers contained in ``rect``."""
        return [
            OCRTextWrapper.OCRTextWrapper(wordWrapper)
            for wordWrapper in self.mOcrTextWrappers
            if RectUtil.contains(rect, wordWrapper.bound())
        ]

    def initOCR(self):
        """Populate the word, block and line caches (paragraphs are skipped)."""
        self.initText()
        self.initBlock()
        self.initLine()

    def initBlock(self):
        self.mOcrBlockWrappers = self.baseInit(RIL.BLOCK)

    def initLine(self):
        """Cache line-level results, dropping lines that contain another line."""
        lines = self.baseInit(RIL.TEXTLINE)
        # A line cannot contain other lines.
        self.mOcrLineWrappers = [
            line for line in lines
            if not any(line != other
                       and RectUtil.contains(line.bound(), other.bound())
                       for other in lines)
        ]

    def initPara(self):
        self.mOcrParaWrappers = self.baseInit(RIL.PARA)

    def initText(self):
        self.mOcrTextWrappers = self.baseInit(RIL.WORD)

    def isOverlapText(self, rect, confident):
        """Return True if a cached word with confidence >= ``confident`` intersects ``rect``."""
        return any(
            wordWrapper.getConfidence() >= confident
            and RectUtil.intersects(rect, wordWrapper.bound())
            for wordWrapper in self.mOcrTextWrappers)

    def reset(self):
        """Clear the cached results and run the OCR passes again."""
        self.mOcrTextWrappers = []
        self.mOcrBlockWrappers = []
        self.mOcrLineWrappers = []
        self.initOCR()

    def getPreferenceFontSize(self, ocrTextWrapper, parentHeight):
        """Adjust the wrapper's font size until the rendered height fits.

        The target is the smaller of ``parentHeight`` and the wrapper height
        scaled by ``Constants.TEXT_BOX_AND_TEXT_HEIGHT_RATIO``, converted with
        ``pxToHeightDip`` and truncated to int. The size grows while the text
        renders shorter than the target, otherwise shrinks while it renders
        taller.

        Returns:
            The adjusted font size.
        """
        fontName = ocrTextWrapper.fontName
        fontSize = ocrTextWrapper.fontSize

        height = ocrTextWrapper.height * Constants.TEXT_BOX_AND_TEXT_HEIGHT_RATIO
        textHeight = int(
            self.mDipCalculator.pxToHeightDip(min(parentHeight, height)))

        def renderedHeight(size):
            return self.getTextHeightUsingFontMetrics(ocrTextWrapper,
                                                      fontName, size)

        if renderedHeight(fontSize) < textHeight:
            while renderedHeight(fontSize) < textHeight:
                fontSize = fontSize + 1
        else:
            # Stop at size 1 so we never ask for a font of size 0 or less.
            while fontSize > 1 and renderedHeight(fontSize) > textHeight:
                fontSize = fontSize - 1

        return fontSize

    def getTextHeightUsingFontMetrics(self, ocrTextWrapper, fontName,
                                      fontSize):
        """Return the rendered height of the wrapper's text in the given font.

        The font is loaded from ``fonts//<fontName>.ttf``.
        """
        file = "fonts//" + fontName + ".ttf"

        font = ImageFont.truetype(file, fontSize)
        return font.getsize(ocrTextWrapper.text)[1]

    def validCharacter(self, word):
        """Delegate to the handle's ``IsValidCharacter``."""
        return self.mHandle.IsValidCharacter(word)

    def getTextDimensions(self, text, fontName, fontSize):
        """Return the ``font.getsize(text)`` result for the named font.

        The font is loaded from ``fonts//<fontName>.ttf``. When loading
        raises ``OSError`` the path is printed and ``None`` is returned.
        """
        file = "fonts//" + fontName + ".ttf"
        try:
            font = ImageFont.truetype(file, fontSize)
            return font.getsize(text)
        except OSError:
            print(file)
            return None

    def isValidTextUsingBoundaryCheck(self, ocrTextWrapper):
        """Compare the rendered text size with the OCR bounding box.

        The text passes when the aspect-ratio difference is within
        ``Constants.TEXT_CONFIDENT_ACCEPTANCE_DIMENSION_RATIO_DIFFERENCE_THRESHOLD``
        and the area similarity is at least
        ``Constants.TEXT_AREA_ACCEPTANCE_DIFFERENCE_THRESHOLD``.

        Returns:
            True when both checks pass, or when the text cannot be measured.
        """
        if TextUtils.isEmpty(ocrTextWrapper.text):
            # We cannot calculate width of empty text
            return True

        dimensions = self.getTextDimensions(ocrTextWrapper.text,
                                            ocrTextWrapper.fontName,
                                            ocrTextWrapper.fontSize)
        # NOTE(review): unmeasurable text (font file missing, or zero rendered
        # size) is accepted, mirroring the empty-text rule above - confirm
        # this is the desired policy. Previously this path raised TypeError.
        if dimensions is None:
            return True
        width, height = dimensions
        if width <= 0 or height <= 0:
            return True

        fontRatio = float(height / width)
        boundRatio = float(ocrTextWrapper.height / ocrTextWrapper.width)
        fontArea = self.mDipCalculator.dipToHeightPx(
            height) * self.mDipCalculator.dipToWidthPx(width)
        boundArea = float(ocrTextWrapper.width * ocrTextWrapper.height)

        # Relative difference between the two aspect ratios.
        ratioDifference = abs(boundRatio - fontRatio) / max(
            boundRatio, fontRatio)
        # Smaller area as a fraction of the larger one.
        areaSimilarity = min(fontArea, boundArea) / max(fontArea, boundArea)

        dimensionCheck = (
            ratioDifference <=
            Constants.TEXT_CONFIDENT_ACCEPTANCE_DIMENSION_RATIO_DIFFERENCE_THRESHOLD
        )
        areaCheck = (areaSimilarity >=
                     Constants.TEXT_AREA_ACCEPTANCE_DIFFERENCE_THRESHOLD)

        return dimensionCheck and areaCheck

    def destroy(self):
        """End the Tesseract handle."""
        self.mHandle.End()
Beispiel #5
0
class ViewerWindow(Gtk.Window):
    def __init__(self, filenames, kind, show, ml):
        """Set up viewer state, create the OCR handle, build the UI and load the first file.

        Args:
            filenames: Sequence of files to step through.
            kind: Suffix used when building the description file name.
            show: Stored as ``show_hidden`` and passed on when printing the tree.
            ml: When truthy, labels come from the element classifier rather
                than the description file.
        """
        Gtk.Window.__init__(self)

        # Values supplied by the caller.
        self.filenames = filenames
        self.kind = kind
        self.show_hidden = show
        self.ml = ml

        # Pointer position, focused node and position in the file list.
        self.ptx = 0
        self.pty = 0
        self.focus_id = -1
        self.file_idx = 0

        # Hint-screen state.
        self.screen_hint = ''
        self.in_hint_screen = False

        # Per-node colours, per-app label memory and per-app classifiers.
        self.colors = {}
        self.memory = {}
        self.elem_models = {}

        self.tesapi = PyTessBaseAPI(lang='eng')
        self.tesapi.SetVariable("tessedit_char_whitelist", WHITELIST)

        self.init_ui()
        self.load()

    def init_ui(self):
        """Create the drawing area, wire up its event handlers and show the window."""
        self.connect("delete-event", Gtk.main_quit)

        canvas = Gtk.DrawingArea()
        signal_handlers = (
            ("draw", self.on_draw),
            ("motion-notify-event", self.move_over),
            ("button-release-event", self.click_evt),
            ("scroll-event", self.scroll_evt),
            ("key-release-event", self.key_evt),
        )
        for signal_name, handler in signal_handlers:
            canvas.connect(signal_name, handler)

        event_mask = (Gdk.EventMask.POINTER_MOTION_MASK
                      | Gdk.EventMask.BUTTON_RELEASE_MASK
                      | Gdk.EventMask.BUTTON_PRESS_MASK
                      | Gdk.EventMask.SCROLL_MASK
                      | Gdk.EventMask.KEY_PRESS_MASK
                      | Gdk.EventMask.KEY_RELEASE_MASK)
        canvas.set_events(event_mask)
        canvas.set_can_focus(True)

        self.add(canvas)
        self.show_all()

    def load(self, prev=False):
        """Load the next file (or the previous one) and reset the view state.

        Quits the Gtk main loop when there is no further file. Otherwise
        loads the tree, attaches labels (from the classifier when ``self.ml``
        is set, else from the description file), loads the PNG for drawing
        and for OCR, assigns a random colour per node and dumps the memory.

        Args:
            prev: When True, step back to the file before the one on screen.
        """
        if prev:
            # file_idx already points one past the file on screen, so going
            # back means minus two. Clamp at 0: a negative index would
            # silently wrap around to the end of the list.
            self.file_idx = max(self.file_idx - 2, 0)
        # Checked after the prev adjustment so that stepping back from the
        # last file does not quit the viewer.
        if self.file_idx >= len(self.filenames):
            Gtk.main_quit()
            return

        filename = self.filenames[self.file_idx]
        (self.app, self.scr) = util.get_aux_info(filename)
        if self.app not in self.memory:
            self.memory[self.app] = {}
        self.set_title(filename)
        self.file_idx += 1
        print("Loading %s" % filename)
        basename = os.path.splitext(filename)[0]
        self.pngfile = basename + '.png'
        self.descname = basename + '.%s.txt' % self.kind

        starttime = time.time()
        self.tree = analyze.load_tree(filename)
        hidden.find_hidden_ocr(self.tree)
        hidden.mark_children_hidden_ocr(self.tree)
        util.print_tree(self.tree, show_hidden=self.show_hidden)

        if self.ml:
            self.get_ml_rets()
        else:
            self.load_desc()

        endtime = time.time()
        print("Load time: %.3fs" % (endtime - starttime))

        self.focus_id = -1
        self.colors = {}
        self.ptx = self.pty = 0

        self.img = cairo.ImageSurface.create_from_png(self.pngfile)
        print('Image:', self.img.get_width(), self.img.get_height())

        root_item_id = min(self.tree)
        root_node = self.tree[root_item_id]
        print('Root node:', root_node['width'], root_node['height'])
        self.scale = 1.0 * self.img.get_width() / config.width
        print('Scale:', '%.3f' % self.scale, '->', '%.3f' % self.scale)

        self.resize(self.img.get_width(), self.img.get_height())

        self.mark_depth(self.tree)

        # Random colours limited to [0, 0.5) per channel.
        for item_id in self.tree:
            color_r = random.random() / 2
            color_g = random.random() / 2
            color_b = random.random() / 2

            self.colors[item_id] = (color_r, color_g, color_b)

        imgocr = Image.open(self.pngfile)
        self.imgwidth = imgocr.width
        self.imgheight = imgocr.height
        self.tesapi.SetImage(imgocr)
        self.tesapi.SetSourceResolution(config.ocr_resolution)

        self.dump_memory()

    def remember(self, node, desc):
        """Record ``desc`` for the node's id in the current app's memory.

        Nodes with a falsy id are ignored. If the id already has a different
        description, the entry is replaced by the marker ``'MUL'``.
        """
        key = node['id']
        if not key:
            return

        app_memory = self.memory[self.app]
        if key not in app_memory:
            app_memory[key] = desc
        elif desc != app_memory[key]:
            # Conflicting descriptions for the same id.
            app_memory[key] = 'MUL'

    def forget(self, node):
        """Drop the node's id from the current app's memory, if present."""
        self.memory[self.app].pop(node['id'], None)

    def get_elem_model(self, app):
        """Fetch the element model for ``app`` and cache it in ``elem_models``."""
        self.elem_models[app] = elements.getmodel(
            "../model/", "../guis/", app, "../guis-extra/",
            config.extra_element_scrs)

    def get_ml_rets(self):
        """Classify every tree node and store the guessed label on the node.

        Guesses of ``'NONE'`` are skipped. For a label that ``tags.single``
        reports as single and that is already assigned, the new item replaces
        the old one only when its score is strictly higher.
        """
        if self.app not in self.elem_models:
            self.get_elem_model(self.app)

        labels = {}
        items_by_label = {}  # type: Dict[str, List[int]]
        scores = {}
        classifier = self.elem_models[self.app]
        classifier.set_imgfile(self.pngfile)
        treeinfo = analyze.collect_treeinfo(self.tree)

        for itemid in self.tree:
            label, score = classifier.classify(self.scr, self.tree, itemid,
                                               None, treeinfo)
            if label == 'NONE':
                continue

            if tags.single(label, self.scr) and label in items_by_label:
                incumbent = items_by_label[label][0]
                if scores[incumbent] < score:
                    # The newcomer wins: move the label over to it.
                    items_by_label[label] = [itemid]
                    scores[itemid] = score
                    del labels[incumbent]
                    labels[itemid] = label
                continue

            labels[itemid] = label
            scores[itemid] = score
            items_by_label[label] = items_by_label.get(label, []) + [itemid]

        for nodeid in labels:
            self.tree[nodeid]['label'] = labels[nodeid]

    def load_desc(self):
        """Read ``<id> <description>`` lines from the description file into node labels.

        Does nothing when the file does not exist. Each description goes to
        the first node whose ``'raw'`` contains the id (appended with a space
        if the node already has a label) and is also remembered; ids that
        match no node produce a warning.
        """
        if not os.path.exists(self.descname):
            return

        with open(self.descname) as inf:
            lines = inf.read().split('\n')

            for line in lines:
                if not line:
                    continue

                raw_id, desc = line.split(' ', 1)
                item_id = int(raw_id)

                for nodeid in self.tree:
                    node = self.tree[nodeid]
                    if item_id not in node['raw']:
                        continue

                    if 'label' in node:
                        node['label'] += ' ' + desc
                    else:
                        node['label'] = desc
                    print(nodeid, '(', item_id, ')', '->', desc)

                    self.remember(node, desc)
                    break
                else:
                    # The loop finished without finding a matching node.
                    print("WARNING: %s (%s) is missing!" % (item_id, desc))

    def mark_depth(self, tree):
        """Assign depths starting at 0 from every node that has no depth yet."""
        for item_id in tree:
            if 'depth' not in tree[item_id]:
                self.mark_depth_node(tree, item_id, 0)

    def mark_depth_node(self, tree, node_id, depth):
        """Set ``depth`` and ``descs`` on a node and, recursively, its children.

        Returns:
            A new list holding the node's descendants followed by ``node_id``.
        """
        current = tree[node_id]
        current['depth'] = depth
        current['descs'] = []

        for child_id in current['children']:
            below = self.mark_depth_node(tree, child_id, depth + 1)
            current['descs'].extend(below)

        return current['descs'] + [node_id]

    def get_node_info(self, node):
        """Return the node's scaled ``(x, y, width, height, depth)``.

        Position and size are multiplied by ``self.scale``; width and height
        are capped at the image size; a negative x or y is moved to 0 and the
        matching extent reduced by the overhang.
        """
        factor = self.scale
        depth = node['depth']

        x = node['x'] * factor
        y = node['y'] * factor
        width = min(node['width'] * factor, self.imgwidth)
        height = min(node['height'] * factor, self.imgheight)

        if x < 0:
            # x is negative here, so adding it shrinks the width.
            width += x
            x = 0

        if y < 0:
            height += y
            y = 0

        return (x, y, width, height, depth)

    def find_containing_widget(self, px, py):
        """Return the id of the deepest non-ignored node containing the point.

        Only nodes with depth greater than 0 can be selected; ``-1`` is
        returned when none qualifies.
        """
        best_id = -1
        best_depth = 0

        for item_id in self.tree:
            candidate = self.tree[item_id]
            if self.ignore_node(candidate):
                continue
            if not self.inside(candidate, px, py):
                continue
            if candidate['depth'] > best_depth:
                best_depth = candidate['depth']
                best_id = item_id

        return best_id

    def inside(self, node, px, py):
        """Return True if ``(px, py)`` lies within the node's scaled box, edges included."""
        left, top, box_width, box_height, _ = self.get_node_info(node)
        return left <= px <= left + box_width and top <= py <= top + box_height

    def ignore_node(self, node):
        """Return True for ``OPTION`` nodes (case-insensitive) and nodes marked hidden."""
        is_option = node['class'].upper() == 'OPTION'
        return is_option or node.get('visible', '') == 'hidden'

    def on_draw(self, wid, ctx):
        """Paint the screenshot and overlay the nodes under the pointer.

        Drawing order: every non-ignored node containing the pointer (the
        deepest one filled and captioned, the others outlined), then the
        deepest clickable ancestor of the deepest node, then the focused
        node in red, then every labelled node with its label.

        Args:
            wid: The widget being drawn (unused).
            ctx: Cairo context to draw on.
        """
        ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL,
                             cairo.FONT_WEIGHT_BOLD)

        ctx.set_source_surface(self.img, 0, 0)
        ctx.paint()

        ctx.set_font_size(20)
        ctx.set_line_width(5)
        ctx.set_source_rgb(1.0, 0.0, 0.0)

        max_id = self.find_containing_widget(self.ptx, self.pty)

        # Deepest clickable node that has max_id among its descendants.
        max_click_id = -1
        max_click_depth = 0
        for item_id in self.tree:
            node = self.tree[item_id]
            depth = node['depth']
            if max_id in node['descs'] and node['click']:
                if depth > max_click_depth:
                    max_click_depth = depth
                    max_click_id = item_id

        for item_id in self.tree:
            node = self.tree[item_id]
            if self.ignore_node(node):
                continue
            if not self.inside(node, self.ptx, self.pty):
                continue

            # Only the deepest node under the pointer is filled and captioned.
            is_deepest = item_id == max_id
            self.show_widget(ctx, item_id, is_deepest, is_deepest)

        if max_click_id != -1 and max_click_id != max_id:
            self.show_widget(ctx, max_click_id, False, True)

        if self.focus_id >= 0:
            self.show_widget(ctx, self.focus_id, True, True, (1, 0, 0))

        for itemid in self.tree:
            node = self.tree[itemid]
            if 'label' not in node:
                continue
            # The label text is green on the focused node, blue elsewhere;
            # the fill behind it is always blue.
            color = (0, 1, 0) if itemid == self.focus_id else (0, 0, 1)
            self.show_widget(ctx, itemid, True, False, (0, 0, 1))
            self.show_desc(ctx, node, color)

        #s.write_to_png('test.png')
        #os.system("%s %s" % (config.picviewer_path, 'test.png'))
        #report_time(start_time, "displayed")

    def move_sibling(self, to_next):
        """Move the focus among the nodes that contain the last click point.

        Leaf nodes (no children) are preferred; when the click point hits no
        leaf, all containing nodes are used. If the focused node is among the
        candidates, focus steps to the next or previous one with wrap-around;
        otherwise it jumps to the first candidate. With no candidates at all
        the focus is left unchanged.

        Args:
            to_next: True to step forward, False to step backward.
        """
        leaf_list = []
        any_list = []
        for itemid in self.tree:
            node = self.tree[itemid]
            if not self.inside(node, self.clickx, self.clicky):
                continue

            if len(node['children']) == 0:
                leaf_list.append(itemid)
            any_list.append(itemid)

        candidates = leaf_list or any_list
        if not candidates:
            # Nothing under the click point; indexing [0] would raise.
            return

        if self.focus_id in candidates:
            step = 1 if to_next else -1
            position = candidates.index(self.focus_id)
            self.focus_id = candidates[(position + step) % len(candidates)]
        else:
            self.focus_id = candidates[0]

    def show_widget(self, ctx, item_id, fill, show_text, colors=None):
        """Draw one node's box and, optionally, a caption centred in it.

        Args:
            ctx: Cairo context to draw on.
            item_id: Id of the node in ``self.tree``.
            fill: True for a translucent fill, False for an outline.
            show_text: True to draw the caption (id, ``C`` prefix when
                clickable, then truncated text or node id).
            colors: Optional ``(r, g, b)``; defaults to the node's own colour.
        """
        node = self.tree[item_id]
        x, y, width, height, _ = self.get_node_info(node)

        if colors is not None:
            (color_r, color_g, color_b) = colors
        else:
            own_color = self.colors[item_id]
            color_r, color_g, color_b = own_color[0], own_color[1], own_color[2]

        ctx.rectangle(x, y, width, height)
        if not fill:
            ctx.set_source_rgba(color_r, color_g, color_b, 1)
            ctx.set_line_width(5)
            ctx.stroke()
        else:
            ctx.set_source_rgba(color_r, color_g, color_b, 0.3)
            ctx.fill()

        if not show_text:
            return

        # Rough character budget based on the advance width of "a".
        max_char = int(width / ctx.text_extents("a")[2])
        caption = str(item_id)
        if node['click']:
            caption = 'C' + caption
        if node['text']:
            caption = caption + ':' + node['text'][:(max_char - 5)]
        elif node['id']:
            caption += '#' + node['id'][:(max_char - 5)]

        self.show_text(ctx, x + width / 2, y + height / 2, caption, color_r,
                       color_g, color_b)

    def show_desc(self, ctx, node, color=(0, 0, 1)):
        """Draw the node's label centred in its box using ``color``."""
        x, y, width, height, _ = self.get_node_info(node)
        center_x = x + width / 2
        center_y = y + height / 2
        self.show_text(ctx, center_x, center_y, node['label'], color[0],
                       color[1], color[2])

    def show_text(self, ctx, x, y, text, color_r, color_g, color_b):
        """Draw ``text`` centred on ``(x, y)`` with a white outline underneath."""
        _, _, text_width, text_height = ctx.text_extents(text)[:4]
        origin_x = x - text_width / 2
        origin_y = y + text_height / 2

        # White stroke first so the coloured text stays readable.
        ctx.move_to(origin_x, origin_y)
        ctx.set_source_rgba(1, 1, 1, 1)
        ctx.set_line_width(5)
        ctx.text_path(text)
        ctx.stroke()

        # Coloured fill on top.
        ctx.move_to(origin_x, origin_y)
        ctx.set_source_rgba(color_r, color_g, color_b, 1)
        ctx.text_path(text)
        ctx.fill()

    def move_over(self, widget, evt):
        """Remember the pointer position from ``evt`` and request a redraw."""
        self.ptx, self.pty = evt.x, evt.y
        self.queue_draw()

    def click_evt(self, widget, evt):
        """Handle a mouse click.

        While in screen-hint mode the click is forwarded to
        ``process_screen_hint_click`` and nothing else happens.  Otherwise
        button 3 clears the focus, and any other button records the click
        position and focuses the widget found at that position.
        """
        if self.in_hint_screen:
            self.process_screen_hint_click(evt)
            return

        if evt.button != 3:
            self.clickx, self.clicky = evt.x, evt.y
            self.focus_id = self.find_containing_widget(evt.x, evt.y)
        else:
            self.focus_id = -1

        self.queue_draw()

    def scroll_evt(self, widget, evt):
        """Move the focus up or down the tree on a scroll event.

        Scrolling up focuses the parent of the focused widget; any other
        direction focuses a child.  Does nothing when no widget is focused
        (``focus_id == -1``).
        """
        if self.focus_id == -1:
            return

        if evt.direction == Gdk.ScrollDirection.UP:
            target = self.find_parent_widget(self.focus_id)
        else:
            target = self.find_child_widget(self.focus_id)
        self.focus_id = target

        self.queue_draw()

    def find_parent_widget(self, wid):
        """Return the id of the first non-ignored node listing ``wid`` as child.

        Falls back to ``wid`` itself when no such node exists.
        """
        for candidate_id, candidate in self.tree.items():
            if not self.ignore_node(candidate) and wid in candidate['children']:
                return candidate_id
        return wid

    def find_child_widget(self, wid):
        """Return the first non-ignored child of ``wid`` under the last click.

        A child qualifies when ``self.inside`` reports that the stored click
        position (``self.clickx``, ``self.clicky``) lies in it.  Falls back
        to ``wid`` itself when no child qualifies.
        """
        for child_id in self.tree[wid]['children']:
            child = self.tree[child_id]
            if not self.ignore_node(child) and self.inside(
                    child, self.clickx, self.clicky):
                return child_id
        return wid

    def mark_direct(self):
        """Prompt for ``<id> <label>`` and pass that node to ``mark_node``.

        The input is split on the first space only, so the label itself may
        contain spaces.  An id with no label passes an empty label to
        ``mark_node``.  A cancelled dialog, a non-numeric id or an id that is
        not in ``self.tree`` leaves everything unchanged (the latter two
        print a message).
        """
        enter = self.get_text('Please enter id_label', 'format: <id> <label>')
        if enter is None:
            return

        # partition() never raises, unlike unpacking split(' ') which fails
        # as soon as the input contains more than one space.
        id_text, _, label = enter.partition(' ')
        label = label.strip()

        try:
            nodeid = int(id_text)
        except ValueError:
            # Typing mistakes in the dialog should not crash the key handler.
            print('invalid node id', id_text)
            return

        if nodeid not in self.tree:
            print('missing node', nodeid)
            return

        self.mark_node(self.tree[nodeid], label)

    def mark_focused(self):
        """Prompt for a label for the focused node and apply it.

        Without ``self.ml`` the node and label go to ``mark_node``.  With
        ``self.ml`` a non-empty label emits a widget hint and is attached
        with ``add_label``; an empty label emits a negative hint for the
        node's current label and removes it (no-op if the node is
        unlabelled).  Nothing happens when no node is focused or the dialog
        is cancelled.
        """
        if self.focus_id < 0:
            return

        focused = self.tree[self.focus_id]
        prompt = 'label for %s: %s (%s) #%s' % (
            self.focus_id, focused['text'], focused['desc'], focused['id'])
        label = self.get_text('Please enter label', prompt)
        if label is None:
            return

        if not self.ml:
            self.mark_node(focused, label)
            return

        if label != '':
            self.generate_hint_for_widget(self.focus_id, label)
            self.add_label(focused, label)
            return

        # Empty label in ML mode: retract the existing label, if any.
        current = self.tree[self.focus_id]
        if 'label' in current:
            self.generate_negative_hint(current['label'])
            del current['label']

    def generate_hint_for_widget(self, nodeid, label):
        """Emit a hint pairing ``label`` with the locator of node ``nodeid``."""
        widget_locator = locator.get_locator(self.tree, nodeid)
        return self.generate_hint(label, widget_locator)

    def generate_negative_hint(self, label):
        """Emit a hint for ``label`` whose hint text is ``notexist``."""
        absent_marker = 'notexist'
        return self.generate_hint(label, absent_marker)

    def generate_hint(self, label, hint):
        """Print one hint line of the form ``@<scr>.<label> <hint>``."""
        line = "@%s.%s %s" % (self.scr, label, hint)
        print(line)

    def mark_node(self, node, label):
        """Set or clear ``node``'s label, then save all labels.

        A non-empty label is attached with ``add_label`` and passed to
        ``remember``.  An empty label removes an existing label and calls
        ``forget``; an unlabelled node is left untouched.  ``save_labels``
        runs in every case.
        """
        if label != '':
            self.add_label(node, label)
            self.remember(node, label)
        elif 'label' in node:
            del node['label']
            self.forget(node)

        self.save_labels()

    def ocr_text(self):
        """Run OCR on the focused node's region and print the results.

        Two passes are printed: one over the node's box padded by one pixel
        (clamped as in ``save_region``), and one over a box moved inwards by
        5% of the padded size and shrunk to 90% of it.  Does nothing when no
        node is focused.
        """
        # -1 is the "no focus" value; guard like save_region does so the key
        # binding cannot fail on the tree lookup.
        if self.focus_id == -1:
            return

        node = self.tree[self.focus_id]
        (x, y, width, height, _) = self.get_node_info(node)
        print(x, y, width, height)
        x = max(x - 1, 0)
        y = max(y - 1, 0)
        width = min(width + 2, self.imgwidth)
        height = min(height + 2, self.imgheight)
        self.tesapi.SetRectangle(int(x), int(y), int(width), int(height))
        print("OCR ret:", self.tesapi.GetUTF8Text())

        x = min(x + width * 0.05, self.imgwidth)
        y = min(y + height * 0.05, self.imgheight)
        width *= 0.9
        height *= 0.9
        # The scaled values are floats, while the rectangle is given in whole
        # pixels; truncate explicitly instead of relying on the binding.
        self.tesapi.SetRectangle(int(x), int(y), int(width), int(height))
        print("OCR ret:", self.tesapi.GetUTF8Text())

    def save_region(self):
        """Write the focused node's image region to ``/tmp/region.png``.

        The node's box is padded by one pixel on each side; the origin is
        clamped at 0 and the size at the image dimensions.  Does nothing
        when no node is focused.
        """
        if self.focus_id == -1:
            return

        focused = self.tree[self.focus_id]
        left, top, box_w, box_h, _ = self.get_node_info(focused)
        left = max(left - 1, 0)
        top = max(top - 1, 0)
        box_w = min(box_w + 2, self.imgwidth)
        box_h = min(box_h + 2, self.imgheight)

        surface = cairo.ImageSurface(cairo.FORMAT_RGB24, int(box_w),
                                     int(box_h))
        painter = cairo.Context(surface)
        # Shift the source so the region's corner lands on the origin.
        painter.set_source_surface(self.img, -left, -top)
        painter.paint()

        surface.write_to_png("/tmp/region.png")

    def dump_memory(self):
        """Print every remembered ``id -> value`` pair of the current app."""
        for key, remembered in self.memory[self.app].items():
            print('MEM %s -> %s' % (key, remembered))

    def add_label(self, node, desc):
        """Print a short description of ``node`` and store ``desc`` as its label."""
        summary = util.describe_node(node, short=True)
        print('%s -> %s' % (summary, desc))
        node['label'] = desc

    def auto_label(self):
        """Label every unlabelled node whose id is remembered for this app.

        Ids remembered as ``'MUL'`` are skipped with a message.  Labels are
        saved afterwards.
        """
        for node in self.tree.values():
            if 'label' in node:
                continue
            node_id = node['id']
            remembered = self.memory[self.app]
            if node_id not in remembered:
                continue
            if remembered[node_id] == 'MUL':
                print('skip MUL id: %s' % node_id)
            else:
                self.add_label(node, remembered[node_id])
        self.save_labels()

    def remove_all(self):
        """Drop the ``label`` entry from every node in the tree."""
        for node in self.tree.values():
            node.pop('label', None)

    def process_screen_hint_click(self, evt):
        """Turn a click into a screen-hint term and append it.

        The clicked widget's locator becomes the term; a click with button 3
        prefixes it with ``not``.  Prints a message and does nothing else
        when no widget is under the click or no locator can be generated.
        """
        target_id = self.find_containing_widget(evt.x, evt.y)
        if target_id == -1:
            print('Invalid widget selected')
            return

        widget_locator = locator.get_locator(self.tree, target_id)
        if widget_locator is None:
            print('Cannot generate hint for this widget')
            return

        # Button 3 negates the term.
        prefix = 'not ' if evt.button == 3 else ''
        hint = prefix + str(widget_locator)

        print('Widget hint: "%s"' % hint)
        self.add_screen_hint(hint)

    def add_screen_hint(self, hint):
        """Append ``hint`` to the screen hint, joining terms with `` && ``."""
        if self.screen_hint != '':
            self.screen_hint = self.screen_hint + ' && ' + hint
        else:
            self.screen_hint = hint

    def hint_screen(self):
        """Toggle screen-hint mode.

        Entering the mode asks for a screen name and resets the collected
        hint; a cancelled dialog leaves the mode unchanged.  Leaving the
        mode prints ``%<screen name> <collected hint>``.
        """
        if self.in_hint_screen:
            self.in_hint_screen = False
            print("%%%s %s" % (self.screen_hint_label, self.screen_hint))
            return

        label = self.get_text('Please enter screen name',
                              'screen name like "signin"')
        if label is None:
            return

        self.screen_hint_label = label
        self.in_hint_screen = True
        self.screen_hint = ''

    def key_evt(self, widget, evt):
        """Run the action bound to the pressed key, then request a redraw.

        Keys without a binding only trigger the redraw.
        """
        # NOTE(review): Left maps to to_next=True and Right to to_next=False,
        # which looks inverted -- confirm this is intentional.
        actions = {
            Gdk.KEY_space: self.mark_focused,
            Gdk.KEY_Tab: self.load,
            Gdk.KEY_Left: lambda: self.move_sibling(to_next=True),
            Gdk.KEY_Right: lambda: self.move_sibling(to_next=False),
            Gdk.KEY_v: self.ocr_text,
            Gdk.KEY_a: self.auto_label,
            Gdk.KEY_p: lambda: self.load(prev=True),
            Gdk.KEY_l: self.mark_direct,
            Gdk.KEY_r: self.remove_all,
            Gdk.KEY_s: self.save_region,
            Gdk.KEY_x: self.hint_screen,
        }
        action = actions.get(evt.keyval)
        if action is not None:
            action()
        self.queue_draw()

    def save_labels(self):
        """Rewrite ``self.descname`` with one ``<id> <label>`` line per labelled node.

        Nodes are written in sorted id order; unlabelled nodes are omitted.
        """
        with open(self.descname, 'w') as outf:
            outf.writelines(
                "%s %s\n" % (itemid, self.tree[itemid]['label'])
                for itemid in sorted(self.tree)
                if 'label' in self.tree[itemid])

    def get_text(self, title, prompt):
        """Ask the user for one line of text in an OK/Cancel dialog.

        ``title`` is the dialog's primary message and ``prompt`` its
        secondary text.  Returns the entered string when the dialog is
        answered with OK (pressing Enter in the entry counts as OK),
        otherwise ``None``.
        """
        prompt_dialog = Gtk.MessageDialog(self, 0, Gtk.MessageType.QUESTION,
                                          Gtk.ButtonsType.OK_CANCEL, title)
        prompt_dialog.format_secondary_text(prompt)

        text_entry = Gtk.Entry()
        # Pressing Enter in the entry answers the dialog with OK.
        text_entry.connect(
            "activate",
            lambda _entry: prompt_dialog.response(Gtk.ResponseType.OK))

        # Row holding a caption and the entry, added to the dialog body.
        row = Gtk.HBox()
        row.pack_start(Gtk.Label("Label:"), False, 5, 5)
        row.pack_end(text_entry, True, 0, 0)
        prompt_dialog.vbox.pack_end(row, True, True, 0)
        prompt_dialog.show_all()

        confirmed = prompt_dialog.run() == Gtk.ResponseType.OK
        # Read the entry before the dialog (and the entry in it) is destroyed.
        result = text_entry.get_text() if confirmed else None
        prompt_dialog.destroy()
        return result
Beispiel #6
0
class FeatureCollector(object):
    """Build per-node feature dictionaries ("points") from a UI tree.

    ``tree`` maps node ids to node dicts.  The methods below read keys such
    as ``class``, ``text``, ``desc``, ``id``, ``children``, ``parent``,
    ``childid``, ``x``/``y``/``width``/``height``, ``click``, ``scroll``,
    ``raw``, ``visible`` and the ``ocr_*`` counters.  A Tesseract handle is
    loaded with the screenshot so node regions can be OCRed on demand.
    """

    def __init__(self, tree, imgfile):
        """Cache every node's subtree text and load ``imgfile`` for OCR."""
        self.tree = tree
        self.collect_texts()
        self.imgfile = imgfile
        self.tesapi = PyTessBaseAPI(lang='eng')
        self.set_tes_image()

    def set_tes_image(self):
        """Load the screenshot, record its size and hand a scaled copy to Tesseract.

        ``imgwidth``/``imgheight`` hold the original size; the image given
        to Tesseract is converted to RGB and enlarged by ``OCR_RATIO``.
        """
        # The context manager closes the image file; the converted and
        # resized copy is a separate image that stays usable afterwards.
        with Image.open(self.imgfile) as imgpil:
            (self.imgwidth, self.imgheight) = (imgpil.width, imgpil.height)
            scaled = imgpil.convert("RGB").resize(
                (imgpil.width * OCR_RATIO, imgpil.height * OCR_RATIO))
        self.tesapi.SetImage(scaled)

    def add_ctx_attr(self, ctx, data, attr_re, word_limit=1000):
        """Append the lower-cased matches of ``attr_re`` in ``data`` to ``point[ctx]``.

        Matches are separated by single spaces.  ``point[ctx]`` is created
        (empty) if missing.  When ``word_limit`` is truthy and there are
        more matches than that, nothing is appended.
        """
        if ctx not in self.point:
            self.point[ctx] = ''
        words = attr_re.findall(data)
        if word_limit and len(words) > word_limit:
            return
        for word in words:
            if self.point[ctx]:
                self.point[ctx] += ' '
            self.point[ctx] += '%s' % word.lower()

    def add_ctx(self, ctx, node, attrs, attr_re=anything_re, word_limit=None):
        """Apply ``add_ctx_attr`` to each of the node attributes in ``attrs``."""
        for attr in attrs:
            self.add_ctx_attr(ctx, node[attr], attr_re, word_limit)

    def collect_texts(self):
        """Make sure every node in the tree has its ``fulltext`` cached."""
        for nodeid in self.tree:
            node = self.tree[nodeid]
            if 'fulltext' not in node:
                self.collect_text(node)

    def collect_text(self, node):
        """Collect text from node and all its children.

        The result is cached in ``node['fulltext']``.  Nodes of class
        ``View`` contribute their ``desc``; all others their ``text``.
        """
        if 'fulltext' in node:
            return node['fulltext']

        cls = node['class']
        if cls == 'View':
            text = node['desc']
        else:
            text = node['text']

        for child in node['children']:
            text = text.strip() + ' ' + self.collect_text(self.tree[child])

        node['fulltext'] = text
        return text

    def prepare_neighbour(self):
        """Fill ``neighbour_ctx``, ``adj_ctx`` and ``neighbour_count``.

        ``neighbour_ctx`` gathers text and id words of all nodes sharing the
        current node's parent; ``neighbour_count`` counts those of the same
        class.  ``adj_ctx`` gathers the same words for the sibling whose
        ``childid`` lies strictly between ``childid - 2`` and ``childid``.
        """
        node = self.tree[self.nodeid]
        self.point['neighbour_ctx'] = ''
        self.point['adj_ctx'] = ''
        neighbour_count = 0
        for other in self.tree:
            sibling = self.tree[other]
            if sibling['parent'] != node['parent'] or other == self.nodeid:
                continue

            self.add_ctx_attr('neighbour_ctx', self.collect_text(sibling),
                              text_re)
            self.add_ctx('neighbour_ctx', sibling, ['id'], id_re)
            if sibling['class'] == node['class']:
                neighbour_count += 1

            # left sibling
            if node['childid'] - 2 < sibling['childid'] < node['childid']:
                self.add_ctx_attr('adj_ctx', self.collect_text(sibling),
                                  text_re)
                self.add_ctx('adj_ctx', sibling, ['id'], id_re)

        self.point['neighbour_count'] = neighbour_count

    def ctx_append(self, ctx, kind, clz, detail):
        """Return ``ctx`` extended with ``kind``-prefixed tokens.

        One token is ``kind + clz``; one more is added for every ``word_re``
        match in ``detail``.
        """
        ret = ctx
        ret += ' ' + kind + clz
        regex = word_re
        for part in regex.findall(detail):
            ret += ' ' + kind + part
        return ret

    def collect_subtree_info(self, node, root):
        """Recursively summarise ``node`` and its descendants.

        Returns a dict with ``ctx`` (class, id, truncated text/desc,
        n-grams and position tags), ``text`` (concatenated truncated texts)
        and ``count`` (number of non-ignored nodes).  ``root`` is ``None``
        for the top-level call; descendants receive the top-level node so
        their WIDE/TALL/TOP/LEFT/BOTTOM/RIGHT tags can be computed.
        """
        if ignore_node(node):
            return {'ctx': '', 'text': '', 'count': 0}
        ctx = ''
        count = 1

        clz = node['class']

        if clz == 'View':
            text = node['desc'][:30]
        else:
            text = node['text'][:30]
        desc = node['desc'][:30]

        ctx += clz + ' '
        ctx += node['id'] + ' '
        ctx += text + ' '
        ctx += desc + ' '
        ctx += gen_ngram(text) + ' '
        ctx += gen_ngram(desc) + ' '

        if root is not None:
            # Size tags are relative to the configured screen size,
            # position tags relative to the subtree root.
            if node['width'] > 0.6 * config.width:
                ctx = self.ctx_append(ctx, "WIDE", clz, node['id'])

            if node['height'] > 0.6 * config.height:
                ctx = self.ctx_append(ctx, "TALL", clz, node['id'])

            if node['y'] + node['height'] < root['y'] + 0.3 * root['height']:
                ctx = self.ctx_append(ctx, "TOP", clz, node['id'])

            if node['x'] + node['width'] < root['x'] + 0.3 * root['width']:
                ctx = self.ctx_append(ctx, "LEFT", clz, node['id'])

            if node['y'] > root['y'] + 0.7 * root['height']:
                ctx = self.ctx_append(ctx, "BOTTOM", clz, node['id'])

            if node['x'] > root['x'] + 0.7 * root['width']:
                ctx = self.ctx_append(ctx, "RIGHT", clz, node['id'])

        for child in node['children']:
            child_info = self.collect_subtree_info(
                self.tree[child], root if root is not None else node)
            ctx = ctx.strip() + ' ' + child_info['ctx']
            count += child_info['count']
            text = text.strip() + ' ' + child_info['text']

        return {'ctx': ctx, 'text': text, 'count': count}

    def prepare_children(self):
        """Store the subtree summary of the current node in the point."""
        node = self.tree[self.nodeid]
        subtree_info = self.collect_subtree_info(node, None)
        self.point['subtree'] = subtree_info['ctx']
        self.point['node_subtree_text'] = subtree_info['text']
        self.point['node_childs'] = subtree_info['count']

    def prepare_ancestor(self):
        """Fill ``parent_ctx`` and ``parent_prop`` from the node's ancestors.

        Walks the ``parent`` links until id 0 or -1 is reached or
        ``PARENT_DEPTH_LIMIT`` ancestors have been visited.  ``parent_prop``
        is ``[any clickable, any scrollable, any with more than one child]``.
        """
        node = self.tree[self.nodeid]
        parent = node['parent']
        self.point['parent_ctx'] = ''
        parent_click = parent_scroll = parent_manychild = False
        parent_depth = 0
        while parent != 0 and parent != -1 and parent_depth < PARENT_DEPTH_LIMIT:
            self.add_ctx('parent_ctx', self.tree[parent], ['class', 'id'],
                         id_re)
            parent_click |= self.tree[parent]['click']
            parent_scroll |= self.tree[parent]['scroll']
            parent_manychild |= len(self.tree[parent]['children']) > 1
            parent = self.tree[parent]['parent']
            parent_depth += 1

        self.point['parent_prop'] = [
            parent_click, parent_scroll, parent_manychild
        ]

    def prepare_self(self):
        """Fill the features taken from the current node itself."""
        node = self.tree[self.nodeid]
        # AUX info
        self.point['id'] = self.nodeid
        self.point['str'] = util.describe_node(node, None)

        self.add_ctx('node_text', node, ['text'], text_re, 10)
        self.add_ctx('node_ctx', node, ['desc'], text_re)
        self.add_ctx('node_ctx', node, ['id'], id_re)
        self.add_ctx('node_class', node, ['class'], id_re)
        if 'Recycler' in node['class'] or 'ListView' in node['class']:
            self.point['node_class'] += " ListContainer"
        self.point['node_x'] = node['x']
        self.point['node_y'] = node['y']
        self.point['node_w'] = node['width']
        self.point['node_h'] = node['height']

    def prepare_point(self, nodeid, app, scr, caseid, imgdata, treeinfo, path):
        """Convert a node in the tree into a data point for ML"""
        self.nodeid = nodeid
        self.point = {}

        # AUX info
        self.point['app'] = app
        self.point['scr'] = scr
        self.point['case'] = caseid

        # prepare_img must run before prepare_ocr: it computes the clipped
        # box and the ``empty`` flag that prepare_ocr reads.
        self.prepare_self()
        self.prepare_neighbour()
        self.prepare_ancestor()
        self.prepare_global(path, treeinfo)
        self.prepare_img(imgdata)
        self.prepare_ocr()
        self.prepare_children()

        return self.point

    def prepare_global(self, path, treeinfo):
        """Fill ``node_prop`` and ``path``.

        ``node_prop`` is ``[click, scroll, more than one child, has_dupid,
        is_itemlike, is_listlike]``; the last three report whether any of
        the node's ``raw`` ids appears in the matching ``treeinfo`` entry.
        """
        node = self.tree[self.nodeid]
        raw_ids = node['raw']
        has_dupid = any(_id in treeinfo['dupid'] for _id in raw_ids)
        is_itemlike = any(_id in treeinfo['itemlike'] for _id in raw_ids)
        is_listlike = any(_id in treeinfo['listlike'] for _id in raw_ids)
        self.point['node_prop'] = [
            node['click'], node['scroll'],
            len(node['children']) > 1, has_dupid, is_itemlike, is_listlike
        ]

        self.point['path'] = path

    def prepare_img(self, imgdata):
        """Clip the node's box to the screenshot and record whether it is empty.

        ``imgdata`` is accepted but not used here.
        """
        node = self.tree[self.nodeid]
        # your widget should be inside the screenshot
        # not always!
        self.min_x = max(node['x'], 0)
        self.min_y = max(node['y'], 0)
        self.max_x = min(node['x'] + node['width'], self.imgwidth)
        self.max_y = min(node['y'] + node['height'], self.imgheight)
        self.empty = self.max_x <= self.min_x or self.max_y <= self.min_y
        self.point['empty'] = self.empty

    def prepare_ocr(self):
        """Fill the OCR features of the current node.

        ``node_ocr`` is the node's cached ``ocr`` text if present, otherwise
        the text recognised in the clipped box; it is ``'dummy'`` when OCR
        is disabled in the config or the box is empty.  The ``ocr_*``
        counters and ``visible`` are copied from the node, and ``ocr_ratio``
        is ``missing / (missing + other)`` (0.0 when the sum is 0).
        """
        node = self.tree[self.nodeid]
        if config.region_use_ocr and not self.empty:
            if 'ocr' in node:
                ocr_text = node['ocr']
            else:
                # Tesseract holds the image enlarged by OCR_RATIO, so the
                # rectangle is scaled by the same factor.
                self.tesapi.SetRectangle(self.min_x * OCR_RATIO,
                                         self.min_y * OCR_RATIO,
                                         (self.max_x - self.min_x) * OCR_RATIO,
                                         (self.max_y - self.min_y) * OCR_RATIO)
                try:
                    ocr_text = self.tesapi.GetUTF8Text()
                except Exception:
                    # Best effort: a recognition failure yields empty text.
                    # KeyboardInterrupt/SystemExit are deliberately not
                    # swallowed.
                    logger.warning("tessearact fail to recognize")
                    ocr_text = ''
                ocr_text = ocr_text.strip().replace('\n', ' ')
        else:
            ocr_text = 'dummy'
        self.point['node_ocr'] = ocr_text
        logger.debug("%s VS %s" % (ocr_text, node['text']))

        (missing, found, other) = (node['ocr_missing'], node['ocr_found'],
                                   node['ocr_other'])
        self.point['ocr_missing'] = missing
        self.point['ocr_found'] = found
        self.point['ocr_other'] = other
        self.point['ocr_ratio'] = (1.0 * missing / (missing + other)
                                   if missing + other > 0 else 0.0)
        self.point['ocr_visible'] = node['visible']
Beispiel #7
0
    image = Image.fromarray(np.uint8(skimg * 255))
    #image = Image.
    #image.save('page0y.png')
    #image = Image.open('page0.png')
    title = ''
    max_h = 0
    min_y = 10000
    api.SetImage(image)
    #api.SetImageFile(filename)
    boxes = api.GetComponentImages(RIL.TEXTLINE, True)
    #print('Found {} textline image components.'.format(len(boxes)))
    for i, (im, box, _, _) in enumerate(boxes):
        # im is a PIL image object
        # box is a dict with x, y, w and h keys
        api.SetRectangle(box['x'], box['y'], box['w'], box['h'])
        ocrResult = api.GetUTF8Text().replace('\n', ' ').strip()
        conf = api.MeanTextConf()
        print("Box[{0}]: x={x}, y={y}, w={w}, h={h}, "
              "confidence: {1}, text: {2}".format(i, conf, ocrResult, **box))

        text = ' '.join(alpha_re.findall(ocrResult.strip()))
        if len(text) < 5:
            continue

        if box['y'] <= 50:
            continue

        if box['y'] < min_y:
            min_y = box['y']
            title = text
Beispiel #8
0
class GameAdaptor:
    """Capture frames of a window by title, OCR regions of them and post keys.

    ``_work`` tracks the capture thread: 1 while it should run, 0 once a
    stop has been requested (also the initial value), -1 after the thread
    has finished.
    """

    def __init__(self, window_name):
        """Look up the window titled ``window_name`` and set up Tesseract.

        Raises:
            Exception: if no window with that title is found.
        """
        self._window_name = window_name
        self._hwnd = win32gui.FindWindow(None, window_name)
        # Fail before allocating the Tesseract handle so it is not created
        # for an adaptor that can never be used.
        if self._hwnd == 0:
            raise Exception('Window Handle Not Found! xD')
        _options = dict(psm=PSM.SINGLE_LINE, oem=OEM.LSTM_ONLY)
        self._api = PyTessBaseAPI('tessdata', 'eng', **_options)
        self._image = None
        self._lock = Lock()
        self._work = 0
        self._thread = None

    def _get_window_region(self):
        """Return ``(left, top, width, height, border_left, border_top)``.

        Width and height are the window rectangle minus hard-coded border
        sizes (presumably the window frame and title bar -- confirm).
        """
        bl, bt, br, bb = 12, 31, 12, 20
        l, t, r, b = win32gui.GetWindowRect(self._hwnd)
        w = r - l - br - bl
        h = b - t - bt - bb
        return l, t, w, h, bl, bt

    @contextmanager
    def _window_device_context(self):
        """Yield ``(window_dc, compatible_dc)`` and release both afterwards."""
        wdc = win32gui.GetWindowDC(self._hwnd)
        dc_obj = win32ui.CreateDCFromHandle(wdc)
        c_dc = dc_obj.CreateCompatibleDC()
        try:
            yield dc_obj, c_dc
        finally:
            # Release the handles even when the capture inside the ``with``
            # body raises; this runs on every frame.
            dc_obj.DeleteDC()
            c_dc.DeleteDC()
            win32gui.ReleaseDC(self._hwnd, wdc)

    def _capture(self):
        """Grab the window contents as a ``(height, width, 3)`` uint8 array.

        The bitmap bytes are read as 4 channels per pixel and the last
        channel is dropped.
        """
        _, _, w, h, bx, by = self._get_window_region()
        with self._window_device_context() as (dc_obj, cdc):
            bmp = win32ui.CreateBitmap()
            bmp.CreateCompatibleBitmap(dc_obj, w, h)
            cdc.SelectObject(bmp)
            cdc.BitBlt((0, 0), (w, h), dc_obj, (bx, by), win32con.SRCCOPY)
            bmp_info = bmp.GetInfo()
            img = np.frombuffer(bmp.GetBitmapBits(True), dtype=np.uint8)
            win32gui.DeleteObject(bmp.GetHandle())
        return img.reshape(bmp_info['bmHeight'], bmp_info['bmWidth'],
                           4)[:, :, :-1]

    def _do_capture(self):
        """Capture-thread body: keep replacing ``_image`` until told to stop."""
        try:
            while self._work == 1:
                frame = self._capture()
                with self._lock:
                    self._image = frame
                sleep(0.001)
        finally:
            # Always mark the thread as finished, even if a capture raised,
            # so nobody waits forever on a dead thread.
            self._work = -1

    def start_capture(self):
        """Start the capture thread and wait for the first frame.

        Raises:
            RuntimeError: if the thread ends before producing a frame.
        """
        self._work = 1
        self._thread = Thread(target=self._do_capture)
        self._thread.start()
        while self._image is None and self._work == 1:
            sleep(0.001)
        if self._image is None:
            raise RuntimeError('capture stopped before producing a frame')

    def stop_capture(self):
        """Stop the capture thread (if any) and drop the last frame.

        Safe to call when no capture is running.
        """
        self._work = 0
        thread, self._thread = self._thread, None
        if thread is not None:
            thread.join()
            self._work = -1
        self._image = None

    def get_image(self):
        """Return the most recently captured frame (``None`` if there is none)."""
        with self._lock:
            return self._image

    def send_keys(self, *keys):
        """Post a ``WM_KEYDOWN`` message to the window for each key code."""
        for k in keys:
            win32gui.PostMessage(self._hwnd, win32con.WM_KEYDOWN, k, 0)

    def get_text(self, region):
        """Generator OCRing regions of the current frame.

        Yields the text of ``region`` (``(x, y, w, h)``); ``send()`` another
        region to OCR it on the same frame, or ``None`` to finish.
        """
        temp_pil_image = Image.fromarray(self.get_image())
        self._api.SetImage(temp_pil_image)
        while region is not None:
            x, y, w, h = region
            self._api.SetRectangle(x, y, w, h)
            self._api.Recognize(0)
            region = yield self._api.GetUTF8Text()
Beispiel #9
0
class OCREngine():
    def __init__(self, extra_whitelist='', all_unicode=False, lang='eng'):
        """
        Args:
          extra_whitelist: string of extra chars for Tesseract to consider
              only takes effect when all_unicode is False
          all_unicode: if True, no character whitelist is set on Tesseract
          lang: OCR language
        """
        self.tess = PyTessBaseAPI(psm=PSM_MODE, lang=lang)
        self.is_closed = False

        if all_unicode:
            self.whitelist_chars = None
            return

        base_chars = ("abcdefghijklmnopqrstuvwxyz"
                      "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                      "1234567890"
                      r"~!@#$%^&*()_+-={}|[]\:;'<>?,./"
                      '"'
                      "©")
        self.whitelist_chars = base_chars + extra_whitelist
        self.tess.SetVariable('tessedit_char_whitelist', self.whitelist_chars)

    def check_engine(self):
        """Raise ``RuntimeError`` if the engine has been marked as closed."""
        if not self.is_closed:
            return
        raise RuntimeError('OCREngine has been closed.')

    def recognize(self,
                  image,
                  min_text_size=MIN_TEXT_SIZE,
                  max_text_size=MAX_TEXT_SIZE,
                  uniformity_thresh=UNIFORMITY_THRESH,
                  thin_line_thresh=THIN_LINE_THRESH,
                  conf_thresh=CONF_THRESH,
                  box_expand_factor=BOX_EXPAND_FACTOR,
                  horizontal_pooling=HORIZONTAL_POOLING):
        """
        Generator: Blob
        http://stackoverflow.com/questions/23506105/extracting-text-opencv

        Finds candidate text regions with morphological operations, runs
        Tesseract on each one and yields a Blob for every region whose OCR
        confidence exceeds conf_thresh.

        Args:
          image: can be one of the following types:
            - string: image file path
            - ndarray: numpy image
            - PIL.Image.Image: PIL image
          min_text_size:
            min text height/width in pixels, below which will be ignored
          max_text_size:
            max text height in pixels, above which will be ignored
          uniformity_thresh (0.0 <= _ < 1.0):
            remove all black or all white regions
            ignore a region if its fraction of non-zero binarized pixels is
            below [thresh] or above 1 - [thresh]
          thin_line_thresh (must be odd int):
            remove all lines thinner than [thresh] pixels.
            can be used to remove the thin borders of web page textboxes.
          conf_thresh (0 <= _ < 100):
            ignore regions with OCR confidence <= thresh.
          box_expand_factor (0.0 <= _ < 1.0):
            expand the bounding box outwards in case certain chars are cutoff.
          horizontal_pooling:
            result bounding boxes will be more connected with more pooling,
            but large pooling might lower accuracy.

        Raises:
          RuntimeError: if the engine has been closed.
          AssertionError: if a parameter is outside its allowed range.
        """
        self.check_engine()
        # param sanity check
        assert max_text_size > min_text_size > 0
        assert 0.0 <= uniformity_thresh < 1.0
        assert thin_line_thresh % 2 == 1
        assert 0 <= conf_thresh < 100
        assert 0.0 <= box_expand_factor < 1.0
        assert horizontal_pooling > 0

        image = get_np_img(image)
        img_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        img_bw = cv2.adaptiveThreshold(img_gray, 255,
                                       cv2.ADAPTIVE_THRESH_MEAN_C,
                                       cv2.THRESH_BINARY, 11, 5)
        img = img_gray
        # http://docs.opencv.org/3.0-beta/doc/py_tutorials/py_imgproc/py_morphological_ops/py_morphological_ops.html
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        img = cv2.morphologyEx(img, cv2.MORPH_GRADIENT, kernel)
        # cut off all gray pixels < 30.
        # `cv2.THRESH_BINARY | cv2.THRESH_OTSU` is also good, but might overlook certain light gray areas
        _, img = cv2.threshold(img, 30, 255, cv2.THRESH_BINARY)
        # connect horizontally oriented regions
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT,
                                           (horizontal_pooling, 1))
        img = cv2.morphologyEx(img, cv2.MORPH_CLOSE, kernel)
        # remove all thin textbox borders (e.g. web page textbox)
        if thin_line_thresh > 0:
            kernel = cv2.getStructuringElement(
                cv2.MORPH_RECT, (thin_line_thresh, thin_line_thresh))
            img = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)

        # http://docs.opencv.org/trunk/d9/d8b/tutorial_py_contours_hierarchy.html
        # findContours returns (image, contours, hierarchy) in OpenCV 3 but
        # (contours, hierarchy) in OpenCV 2 and 4; the contours are always
        # the second-to-last element.
        contours = cv2.findContours(img, cv2.RETR_CCOMP,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]
        for contour in contours:
            x, y, w, h = box = Box(*cv2.boundingRect(contour))
            # remove regions that are beyond size limits
            if (w < min_text_size or h < min_text_size or h > max_text_size):
                continue
            # remove regions that are almost uniformly white or black
            binary_region = crop(img_bw, box)
            uniformity = np.count_nonzero(binary_region) / float(w * h)
            if (uniformity > 1 - uniformity_thresh
                    or uniformity < uniformity_thresh):
                continue
            # expand the borders a little bit to include cutoff chars
            expansion = int(min(h, w) * box_expand_factor)
            x = max(0, x - expansion)
            y = max(0, y - expansion)
            h, w = h + 2 * expansion, w + 2 * expansion
            if h > w:  # further extend the long axis
                h += 2 * expansion
            elif w > h:
                w += 2 * expansion
            # image passed to Tess should be grayscale.
            # http://stackoverflow.com/questions/15606379/python-tesseract-segmentation-fault-11
            box = Box(x, y, w, h)
            img_crop = crop(img_gray, box)
            # make sure that crops passed in tesseract have minimum x-height
            # http://github.com/tesseract-ocr/tesseract/wiki/FAQ#is-there-a-minimum-text-size-it-wont-read-screen-text
            img_crop = cv2.resize(img_crop,
                                  (int(img_crop.shape[1] * CROP_RESIZE_HEIGHT /
                                       img_crop.shape[0]), CROP_RESIZE_HEIGHT))
            ocr_text, conf = self.run_tess(img_crop)
            if conf > conf_thresh:
                yield Blob(ocr_text, box, conf)

    def _experiment_segment(self,
                            img,
                            min_text_size=MIN_TEXT_SIZE,
                            max_text_size=MAX_TEXT_SIZE,
                            uniformity_thresh=UNIFORMITY_THRESH,
                            horizontal_pooling=HORIZONTAL_POOLING):
        """
        PRIVATE: experiment only

        Runs the segmentation steps on a BGR image and displays every
        intermediate result with disp().  All contour boxes are drawn on a
        copy of the input; the boxes that pass the size and uniformity
        filters are drawn on the input image itself.
        """
        img_init = img  # preserve initial image
        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img_bw = cv2.adaptiveThreshold(img_gray, 255,
                                       cv2.ADAPTIVE_THRESH_MEAN_C,
                                       cv2.THRESH_BINARY, 11, 5)

        # http://docs.opencv.org/3.0-beta/doc/py_tutorials/py_imgproc/py_morphological_ops/py_morphological_ops.html
        morph_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        # Reuse the grayscale image computed above instead of converting the
        # input a second time.
        img = cv2.morphologyEx(img_gray, cv2.MORPH_GRADIENT, morph_kernel)
        disp(img)
        #         morph_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        #         img = cv2.dilate(img, morph_kernel)
        # OTSU thresholding
        #         _, img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
        _, img = cv2.threshold(img, 30, 255, cv2.THRESH_BINARY)
        #         img = cv2.adaptiveThreshold(img,255,cv2.ADAPTIVE_THRESH_MEAN_C,cv2.THRESH_BINARY_INV,9,2)
        disp(img)
        # connect horizontally oriented regions
        morph_kernel = cv2.getStructuringElement(cv2.MORPH_RECT,
                                                 (horizontal_pooling, 1))
        img = cv2.morphologyEx(img, cv2.MORPH_CLOSE, morph_kernel)
        disp(img)

        # Experiment toggle: the first branch is switched off, the second
        # one (7x7 opening) is the active variant.
        if 0:
            morph_kernel = cv2.getStructuringElement(cv2.MORPH_CROSS,
                                                     (horizontal_pooling, 3))
            img = cv2.erode(img, morph_kernel, iterations=1)
            disp(img)
            morph_kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (6, 6))
            img = cv2.dilate(img, morph_kernel, iterations=1)
        elif 1:
            morph_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7))
            img = cv2.morphologyEx(img, cv2.MORPH_OPEN, morph_kernel)
        disp(img)

        # http://docs.opencv.org/trunk/d9/d8b/tutorial_py_contours_hierarchy.html
        # findContours returns (image, contours, hierarchy) in OpenCV 3 but
        # (contours, hierarchy) in OpenCV 2 and 4; the contours are always
        # the second-to-last element.
        contours = cv2.findContours(img, cv2.RETR_CCOMP,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]
        img_copy = np.copy(img_init)
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            draw_rect(img_copy, x, y, w, h)

            if (w < min_text_size or h < min_text_size or h > max_text_size):
                continue

            binary_region = img_bw[y:y + h, x:x + w]
            uniformity = np.count_nonzero(binary_region) / float(w * h)
            if (uniformity > 1 - uniformity_thresh
                    or uniformity < uniformity_thresh):
                # ignore mostly white or black regions
                continue
            # the image must be grayscale, otherwise Tesseract will SegFault
            # http://stackoverflow.com/questions/15606379/python-tesseract-segmentation-fault-11
            draw_rect(img_init, x, y, w, h)
        disp(img_copy)
        disp(img_init, 0)

    def run_tess(self, img):
        """
        Recognise the text of a whole image with the instance's Tesseract handle.

        Tesseract python API source code:
        https://github.com/sirfz/tesserocr/blob/master/tesserocr.pyx

        Args:
          img: image to recognise. A numpy array is converted with np2PIL
            first; any other value is handed to SetImage unchanged.

        Returns:
          (ocr_text, confidence): the text from GetUTF8Text() with leading and
          trailing whitespace stripped, and the value of MeanTextConf().
        """
        engine = self.tess
        pil_img = np2PIL(img) if isinstance(img, np.ndarray) else img
        engine.SetImage(pil_img)
        # Text is fetched before the confidence, same order as the engine
        # calls have always been issued here.
        recognised = engine.GetUTF8Text()
        return recognised.strip(), engine.MeanTextConf()

    def _deprec_run_tess(self, img):
        """
        Deprecated: recognise an image one text line at a time.

        GetComponentImages throws SegFault randomly. No way to fix. :(

        Args:
          img: image to recognise. A numpy array is converted with np2PIL
            first; any other value is handed to SetImage unchanged.

        Yields:
          (ocr_text, inner_box, conf) for every text-line component whose
          width and height are both at least MIN_TEXT_SIZE: the stripped text
          from GetUTF8Text(), the component's Box, and MeanTextConf().
        """
        if isinstance(img, np.ndarray):
            img = np2PIL(img)
        # The engine has to hold the image before its components can be
        # queried; otherwise the lookup runs on whatever was loaded earlier.
        self.tess.SetImage(img)

        components = self.tess.GetComponentImages(RIL.TEXTLINE, True)
        for _, inner_box, _block_id, _paragraph_id in components:
            # box is a dict with x, y, w and h keys
            inner_box = Box(**inner_box)
            if inner_box.w < MIN_TEXT_SIZE or inner_box.h < MIN_TEXT_SIZE:
                # too small to contain readable text
                continue
            # Restrict recognition to this line's bounding box.
            self.tess.SetRectangle(*inner_box)
            ocr_text = self.tess.GetUTF8Text().strip()
            conf = self.tess.MeanTextConf()
            yield ocr_text, inner_box, conf

    def close(self):
        """Call End() on the Tesseract handle, then flag this object as closed."""
        engine = self.tess
        engine.End()
        # Only reached when End() returned without raising.
        self.is_closed = True

    def __enter__(self):
        """Enter a ``with`` block; this object itself is bound to the ``as`` target."""
        return self

    def __exit__(self, type, value, traceback):
        """Leave a ``with`` block by calling close().

        The exception arguments are not inspected. Nothing is returned, so an
        exception raised inside the block is not suppressed.
        """
        self.close()