Example #1
from typing import Tuple

from tesserocr import PSM, PyTessBaseAPI


class TextExtractor:
    def __init__(self, image_path, seg_mode=PSM.SPARSE_TEXT):
        self.api = PyTessBaseAPI()
        self.api.SetPageSegMode(seg_mode)
        self.api.SetImageFile(image_path)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def _extract(self) -> Tuple:
        text = self.api.GetUTF8Text()
        conf = self.api.MeanTextConf()
        return text, conf

    def _extract_from_rect(self, x, y, w, h) -> Tuple:
        self.api.SetRectangle(x, y, w, h)
        return self._extract()

    def extract(self, x=None, y=None, w=None, h=None) -> Tuple:
        # Zero is a valid coordinate; only fall back when a value is unset.
        if all(v is not None for v in (x, y, w, h)):
            return self._extract_from_rect(x, y, w, h)
        else:
            return self._extract()

    def close(self):
        self.api.End()
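
A minimal usage sketch for the class above ('page.png' is a hypothetical path; assumes tesserocr is installed):

with TextExtractor('page.png') as extractor:
    text, conf = extractor.extract()                        # whole image
    header, header_conf = extractor.extract(0, 0, 300, 60)  # top-left band
    print(text, conf, header, header_conf)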
Example #2
def read_text_with_confidence(image,
                              lang='fast_ind',
                              path='/usr/share/tesseract-ocr/5/tessdata',
                              psm=4,
                              whitelist=''):
    height, width = image.shape[:2]

    if height <= 0 or width <= 0:
        return '', 0

    image_pil = Image.fromarray(image)

    api = PyTessBaseAPI(lang=lang, psm=psm, path=path, oem=OEM.LSTM_ONLY)

    # Defaults so the final return is defined even if Recognize() raises.
    text, confidence = '', 0

    try:
        api.SetImage(image_pil)

        if whitelist != '':
            api.SetVariable('tessedit_char_whitelist', whitelist)

        api.Recognize()

        text = api.GetUTF8Text()
        confidence = api.MeanTextConf()
    except Exception:
        print("[ERROR] Tesseract exception")
    finally:
        api.End()

    return text, confidence
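
A usage sketch, assuming OpenCV is installed, 'receipt.png' is a hypothetical image, and the 'fast_ind' model actually exists under the default tessdata path:

import cv2

image = cv2.imread('receipt.png')
if image is not None:
    text, confidence = read_text_with_confidence(image, whitelist='0123456789.,')
    print(text, confidence)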
Example #3
    def extract_text_from_image(self, data):
        """Extract text from a binary string of data."""
        tessdata = '/usr/share/tesseract-ocr/'
        tessdata = self.manager.get_env('TESSDATA_PREFIX', tessdata)
        languages = self.get_languages(self.result.languages)

        key = sha1(data)
        key.update(languages.encode('utf-8'))
        key = key.hexdigest()
        text = self.manager.get_cache(key)
        if text is not None:
            return text

        api = PyTessBaseAPI(lang=languages, path=tessdata)
        try:
            image = Image.open(BytesIO(data))
            # TODO: play with contrast and sharpening the images.
            api.SetImage(image)
            text = api.GetUTF8Text()
        except DecompressionBombWarning as dce:
            log.warning("Image too large: %r", dce)
            return None
        except IOError as ioe:
            log.warning("Unknown image format: %r", ioe)
            return None
        finally:
            api.Clear()

        log.debug('[%s] OCR: %s, %s characters extracted', self.result,
                  languages, len(text))
        self.manager.set_cache(key, text)
        return text
Example #4
def lambda_handler(event, context):
    bucket = event['Records'][0]['s3']['bucket']['name']
    key = urllib.parse.unquote_plus(event['Records'][0]['s3']['object']['key'])
    shutil.copyfile("tesseract", tmp_dir + '/tesseract')
    shutil.copyfile("test.png", tmp_dir + '/test.png')
    os.chmod(tmp_dir + "/tesseract", 0o755)
    os.chmod(tmp_dir, 0o755)
    os.chmod('/tmp', 0o755)
    print("before image file from s3")
    image_file = download_file(bucket, key)
    print("before image file to PIL")
    print("before OCR")
    result_file = tesseract(image_file)
    print("Print files in firectory")
    for file in os.listdir('/tmp'):
        print(file)
    try:
        print("before PyTessBaseAPI set 2")
        api = PyTessBaseAPI(path=os.path.join(SCRIPT_DIR, 'tessdata'), lang='eng', psm=PSM.AUTO_OSD)
        print("After API set")
        api.SetImageFile(image_file)
        print("After API set image")
        print("TEXT from tesserocr: %s" % api.GetUTF8Text())
        print("CONFIDENCE from tesserocr: %s" % api.AllWordConfidences())
    except Exception as e:
        print("Tesseract failed: %s" % e)
Example #5
def get_lines(filename):
    '''
    Args:
        filename (str): Image file relative or absolute path.

    Returns:
        list: List of lines as text from the image. Every line contains the
            stop times for a certain trip.
    '''
    api = PyTessBaseAPI()
    api.SetImageFile(filename)
    text = api.GetUTF8Text()

    textual_lines = []
    line = ''
    line_num = 0
    for char in text:
        line += char
        if char == "\n":
            # ignore lines with less than 5 chars (H:MM)
            if len(line) < 5:
                line = ''
                continue
            else:
                line_num += 1
                # debug
                # print('line: "', line, '" number: ', line_num,
                #       'line length: ', len(line))
                textual_lines.append(line)
                line = ''

    return textual_lines
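
A usage sketch ('schedule.png' is a hypothetical timetable image); each returned line keeps its trailing newline and is at least 5 characters long (H:MM):

for line in get_lines('schedule.png'):
    print(line.rstrip())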
Example #6
def handler(event, context):
  api = PyTessBaseAPI()
  api.SetImageFile("sample.jpg")
  txt = api.GetUTF8Text()
  logging.info(txt)
  logging.info(api.AllWordConfidences())
  return txt
Example #7
class Ocr:
    def __init__(self):
        self.api = None

    def __enter__(self):
        self.api = PyTessBaseAPI().__enter__()
        self.api.SetVariable('tessedit_char_whitelist',
                             OCR_CHARACTER_WHITELIST)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.api.__exit__(exc_type, exc_val, exc_tb)

    def get_characters(self, image):
        h, w = image.shape[:2]

        if h < 1 or w < 1:
            raise NoImageError()

        img_pil = Image.fromarray(image)
        self.api.SetImage(img_pil)
        cell_text = self.api.GetUTF8Text().strip()
        confidence = self.api.MeanTextConf()

        return cell_text, confidence
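
A usage sketch, assuming OCR_CHARACTER_WHITELIST is defined in the surrounding module; the blank array stands in for a real table-cell crop:

import numpy

cell = numpy.zeros((40, 120), dtype=numpy.uint8)
with Ocr() as ocr:
    text, confidence = ocr.get_characters(cell)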
Example #8
    async def get_ocr_text(self, thresh_img):
        # Convert the image so it can be fed to the OCR engine
        pil_img = Image.fromarray(thresh_img)
        # Run OCR
        api = PyTessBaseAPI(psm=PSM.AUTO, lang='jpn')
        api.SetImage(pil_img)
        # Return the result with spaces and newlines stripped out
        return api.GetUTF8Text().replace(' ', '').replace('\n', '')
Example #9
def add_ocrinfo(tree, imgfile):
    imgpil = Image.open(imgfile)
    (orig_width, orig_height) = (imgpil.width, imgpil.height)

    #root_width = tree[min(tree)]['width']
    ratio = 1.0 * orig_width / config.width
    #imgpil = imgpil.convert("RGB").resize(
    #    (orig_width * OCR_RATIO, orig_height * OCR_RATIO))

    tesapi = PyTessBaseAPI(lang='eng')
    tesapi.SetImage(imgpil)
    tesapi.SetSourceResolution(config.ocr_resolution)

    for nodeid in tree:
        node = tree[nodeid]

        if node['children'] and node['text'] == '':
            node['ocr'] = ''
            continue

        x = max(node['x'] * ratio - 1, 0)
        y = max(node['y'] * ratio - 1, 0)
        x2 = min((node['x'] + node['width']) * ratio + 1, orig_width)
        y2 = min((node['y'] + node['height']) * ratio + 1, orig_height)
        width = int(x2 - x)
        height = int(y2 - y)

        if width > 3 and height > 3:
            #tesapi.SetRectangle(int(x * OCR_RATIO), int(y * OCR_RATIO),
            #                    int(width * OCR_RATIO), int(height * OCR_RATIO))
            #print(int(x), int(y), int(width), int(height), orig_width, orig_height)
            tesapi.SetRectangle(int(x), int(y), int(width), int(height))
            ocr_text = tesapi.GetUTF8Text().strip().replace('\n', ' ')
            if ocr_text.strip() == '':
                x = min(x + width * 0.05, orig_width)
                y = min(y + height * 0.05, orig_height)
                width *= 0.9
                height *= 0.9
                tesapi.SetRectangle(int(x), int(y), int(width), int(height))
                ocr_text = tesapi.GetUTF8Text().strip().replace('\n', ' ')

        else:
            ocr_text = ''

        node['ocr'] = ocr_text
Example #10
def read_char(image, whitelist=None):
    """ OCR a single character from an image. Useful for captchas."""
    api = PyTessBaseAPI()
    api.SetPageSegMode(10)
    if whitelist is not None:
        api.SetVariable("tessedit_char_whitelist", whitelist)
    api.SetImage(image)
    api.Recognize()
    return api.GetUTF8Text().strip()
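
A usage sketch ('digit.png' is hypothetical); page segmentation mode 10 treats the image as a single character, so the input should be a tight crop around one glyph:

from PIL import Image

char = read_char(Image.open('digit.png'), whitelist='0123456789')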
Example #11
class OCR(object):
    MAX_MODELS = 5
    MIN_WIDTH = 10
    MIN_HEIGHT = 10

    def __init__(self):
        # Tesseract language types:
        _, self.supported = get_languages()
        self.reset_engine('eng')

    def language_list(self, languages):
        models = [c for c in alpha3(languages) if c in self.supported]
        if len(models) > self.MAX_MODELS:
            log.warning("Too many models, limit: %s", self.MAX_MODELS)
            models = models[:self.MAX_MODELS]
        models.append('eng')
        return '+'.join(sorted(set(models)))

    def reset_engine(self, languages):
        if hasattr(self, 'api'):
            self.api.Clear()
            self.api.End()
        self.api = PyTessBaseAPI(lang=languages, oem=OEM.LSTM_ONLY)

    def extract_text(self, data, languages=None, mode=PSM.AUTO_OSD):
        """Extract text from a binary string of data."""
        languages = self.language_list(languages)
        if languages != self.api.GetInitLanguagesAsString():
            self.reset_engine(languages)

        try:
            image = Image.open(BytesIO(data))
            # TODO: play with contrast and sharpening the images.
            if image.width <= self.MIN_WIDTH:
                return
            if image.height <= self.MIN_HEIGHT:
                return

            if mode != self.api.GetPageSegMode():
                self.api.SetPageSegMode(mode)

            self.api.SetImage(image)
            text = self.api.GetUTF8Text()
            confidence = self.api.MeanTextConf()
            log.info("%s chars (w: %s, h: %s, langs: %s, confidence: %s)",
                     len(text), image.width, image.height, languages,
                     confidence)
            return text
        except Exception:
            log.exception("Failed to OCR: %s", languages)
        finally:
            self.api.Clear()
Example #12
    def getText(self):
        file = open("test.png", "wb")
        file.write(self.document_text)
        file.close()
        text = ""
        out, err = Popen(
            'python ../../models/tutorials/image/imagenet/classify_image.py --image_file test.png',
            shell=True,
            stdout=PIPE).communicate()
        text += out.decode("utf-8")
        api = PyTessBaseAPI()
        api.SetImageFile("test.png")
        text += api.GetUTF8Text()
        os.remove("test.png")
        return text
Example #13
def preprocess_title(filename):
    title = ''
    api = PyTessBaseAPI()
    api.SetImageFile(filename)
    boxes = api.GetComponentImages(RIL.TEXTLINE, True)
    for i, (im, box, _, _) in enumerate(boxes):
        api.SetRectangle(box['x'], box['y'], box['w'], box['h'])
        ocrResult = api.GetUTF8Text()
        text = ' '.join(alpha_re.findall(ocrResult.strip()))
        if len(text) < 5:
            continue

        title = text
        break

    if title:
        logger.info("%s: %s", filename, title)
    return title
Example #14
    def getText(self):
        file = open("test.pdf", "wb")
        file.write(self.document_text)
        file.close()
        rsrcmgr = PDFResourceManager()
        retstr = StringIO()
        codec = 'utf-8'
        laparams = LAParams()
        device = TextConverter(rsrcmgr, retstr, codec=codec, laparams=laparams)
        try:
            fp = open("test.pdf", 'rb')
            interpreter = PDFPageInterpreter(rsrcmgr, device)
            password = ""
            maxpages = 0
            caching = True
            pagenos = set()

            for page in PDFPage.get_pages(fp,
                                          pagenos,
                                          maxpages=maxpages,
                                          password=password,
                                          caching=caching,
                                          check_extractable=True):
                interpreter.process_page(page)

            text = retstr.getvalue()

            fp.close()
            device.close()
            retstr.close()
            text = "".join(text.split("\n"))
            os.remove("test.pdf")
            return text
        except Exception:
            text = ""
            with Image(filename="test.pdf") as img:
                img.save(filename="kek.png")
            for file in os.listdir(os.curdir):
                if file.endswith(".png") and file.startswith("kek"):
                    api = PyTessBaseAPI()
                    api.SetImageFile(file)
                    text += api.GetUTF8Text()
                    os.remove(file)
            return text
Example #15
def tess_ocr(img):
    """Get text from an image.

    Args:
        img: The file path of image.

    Returns:
        A string.
    Raises:
        IOError: An error occurred accessing the img object.

    """
    with c_locale():
        from tesserocr import PyTessBaseAPI, PSM
        api = PyTessBaseAPI(lang='chi_sim', psm=PSM.AUTO_OSD)
        api.SetImageFile(img)
        text = api.GetUTF8Text()
        api.End()
    return text
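
A usage sketch; c_locale is a context manager defined elsewhere in this codebase that forces the C locale, since some tesserocr builds assert LC_ALL == 'C' at import time:

text = tess_ocr('scan.png')  # 'scan.png' is a hypothetical chi_sim-language scan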
Example #16
def read_word(image, whitelist=None, chars=None, spaces=False):
    """ OCR a single word from an image. Useful for captchas.
        Image should be pre-processed to remove noise etc. """
    api = PyTessBaseAPI()
    api.SetPageSegMode(8)
    if whitelist is not None:
        api.SetVariable("tessedit_char_whitelist", whitelist)
    api.SetImage(image)
    api.Recognize()
    guess = api.GetUTF8Text()

    if not spaces:
        guess = ''.join([c for c in guess if c != " "])
        guess = guess.strip()

    if chars is not None and len(guess) != chars:
        return guess, None

    return guess, api.MeanTextConf()
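
A usage sketch ('captcha.png' is hypothetical); mode 8 treats the image as a single word, and a length mismatch returns (guess, None) so the caller can reject the read:

from PIL import Image

guess, conf = read_word(Image.open('captcha.png'),
                        whitelist='23456789ABCDEFGH', chars=6)
if conf is None:
    print('unexpected length:', guess)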
Example #17
def runTessTest(folderpath):
    images = [x for x in os.listdir(folderpath) if x[-3:] == 'png']
    images.sort()
    actualLabels = [x[x.rfind('_') + 1:-4] for x in images]
    letterCorrectCounts = [0 for x in actualLabels]
    countDict = dict(Counter(actualLabels))
    correctDict = \
        {actualLabels[i]:letterCorrectCounts[i] for i in range(len(actualLabels))}
    api = PyTessBaseAPI(lang='frk', psm=10)
    correctCount = 0
    for i in range(len(images)):
        img = images[i]
        api.SetImageFile(os.path.join(folderpath, img))
        predictLabel = api.GetUTF8Text().rstrip()
        if actualLabels[i] == predictLabel:
            correctCount += 1
            correctDict[predictLabel] += 1
    accuracy = correctCount / len(images)
    accuracyDict = \
        {x:round(correctDict[x]/countDict[x], 3) for x in countDict}
    return (accuracy, accuracyDict)
Example #18
class OcrWrapper(BaseImageToString):

    _OPTIONS = ('tessedit_char_whitelist', '0123456789ABCDEF.-')

    def __init__(self):
        if sys.platform == 'win32':
            self._ocr = PyTessBaseAPI(
                path="C:\\Program Files\\Tesseract-OCR\\tessdata")
        else:
            self._ocr = PyTessBaseAPI()

        self._ocr.SetVariable(self._OPTIONS[0], self._OPTIONS[1])

    def image_to_string(self, image: Image) -> str:
        image.format = 'PNG'
        self._ocr.SetImage(image)
        raw_data = self._ocr.GetUTF8Text()
        return raw_data

    def end(self):
        self._ocr.End()
Example #19
    def read_box(self, crop, filtered, read_primes, text, table, old_read_primes, num):
        cur_time = datetime.now().strftime(self.datetime_format)
        #print("reading box")
        api = None
        try:
            api = self.api.get(block=False)
            #print("reading box")
        except queue.Empty:
            api = PyTessBaseAPI()

        api.SetImage(Image.fromarray(filtered[crop[1]:crop[3], crop[0]:crop[2]]))
        ocr_output = api.GetUTF8Text()
        self.api.put(api)
        #self.log.write("{}: Succeeded reading image x={} num_api={}\n".format(cur_time, crop[0], self.api.qsize()))
        #self.log.flush()

        sanitized = self.sanitize(ocr_output)
        ocr_text = self.title_case(sanitized)
        #self.log.write("{}: ocr text={}\n".format(cur_time, ocr_text))
        text[crop[0] + crop[1]] = ocr_text
        dict_text = self.dict_match(ocr_text)
        self.update_table(dict_text, table, read_primes, old_read_primes)
Example #20
    def extract(self, page: Poppler.Page):
        from tesserocr import PyTessBaseAPI  # NOQA Stupid assert on LC_* == 'C'

        ocr = PyTessBaseAPI(lang=settings.CAMPUSONLINE_BULLETIN_OCR_LANGUAGE)
        text = page.text(QRectF()).strip()
        if len(text) > settings.CAMPUSONLINE_BULLETIN_OCR_THRESHOLD:
            self.clean = True
            self.text = text
            return
        dpi = settings.CAMPUSONLINE_BULLETIN_OCR_DPI
        buf = QBuffer()
        buf.open(QIODevice.ReadWrite)
        page.renderToImage(dpi, dpi).save(buf, "PNG")
        bio = BytesIO()
        bio.write(buf.data())
        buf.close()
        bio.seek(0)
        img = Image.open(bio)
        ocr.SetImage(img)
        scanned = ocr.GetUTF8Text().strip()
        img.close()
        bio.close()
        self.clean = False
        self.text = scanned
Example #21
class VariationParser:
    def __init__(self):
        self.tesseract = PyTessBaseAPI(path='./', psm=PSM.SINGLE_LINE)
        self.item_db = json.load(open('en-us-var.json', 'rb'))

        self.items: Set[str] = set()
        self.active_section = 0
        self.section_name = None
        self.for_sale = False

        self._tesseract_cache = {}
        self._item_cache = {}

    def annotate_frame(self, frame: numpy.ndarray) -> numpy.ndarray:
        """Parses various parts of a frame for catalog items and annotates it."""

        # Detect whether we are in the Nook Shopping catalog.
        if not numpy.array_equal(frame[500, 20], (182, 249, 255)):
            text = 'Navigate to Nook catalog to start!'
            opts = {
                'org': (200, 70),
                'fontFace': cv2.FONT_HERSHEY_PLAIN,
                'lineType': cv2.LINE_AA,
                'fontScale': 3
            }
            frame = cv2.putText(frame,
                                text,
                                color=(0, 0, 0),
                                thickness=7,
                                **opts)
            return cv2.putText(frame,
                               text,
                               color=(100, 100, 255),
                               thickness=3,
                               **opts)

        # Show controls on screen.
        cv2.rectangle(frame, (0, 0), (300, 130), (106, 226, 240), -1)
        for i, line in enumerate(TUTORIAL_LINES):
            if line.startswith('F'):
                line += ' (%s)' % ('ON' if self.for_sale else 'OFF')
            frame = cv2.putText(frame, line, (30, 25 + i * 30), 0, 0.8, 0, 2)

        # Show user the item count at the bottom.
        count_text = 'Item count: %d' % len(self.items)
        if not self.section_name:
            count_text = 'Items saved to disk'
        frame = cv2.putText(frame, count_text, (500, 700), 0, 1, 0)

        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        section = numpy.nonzero(gray[20, 250:] == 156)[0]
        if section.any() and section[0] != self.active_section:
            # Grab the new section name
            x1, *_, x2 = 250 + section
            section_region = 255 - gray[8:32, x1 + 5:x2 - 5]
            self.section_name = self.image_to_text(section_region)

            # Reset item selection on section change
            self.active_section = section[0]
            self.items = set()
        elif not self.active_section:
            return frame  # Return early if no section is found.

        item_name = None
        variation_name = None

        selected = self.get_selected_item(frame)
        if not selected:  # Quit early if no item is selected
            return frame

        price_region = gray[selected.y1:selected.y2, 1070:1220]
        if self.for_sale and price_region.min() > 100:
            # Skip items not for sale
            p1, p2 = (selected.x1, selected.y1 + 20), (selected.x2,
                                                       selected.y1 + 20)
            return cv2.line(frame, p1, p2, color=(0, 0, 255), thickness=2)

        # Parse item name and display a rectangle around it.
        item_name = self.image_to_text(gray[selected.slice])
        frame = cv2.putText(frame, item_name, selected.p1, 0, 1, 0)

        # Parse variation and draw rectangle around it if there is one.
        variation = self.get_variation(gray)
        if variation:
            frame = cv2.rectangle(frame, variation.p1, variation.p2, 0)
            variation_name = self.image_to_text(gray[variation.slice])
            frame = cv2.putText(frame, variation_name, variation.p1, 1, 2, 0)

        # Match the name and optional variation against database and register it.
        full_name = self.resolve_name(item_name, variation_name)
        if full_name:
            self.items.add(full_name)

        return frame

    def get_selected_item(self, frame: numpy.ndarray) -> Optional[Rectangle]:
        """Returns the rectangle around the selected item name if there is one."""
        # Search for the yellow selected region along the item list area.
        select_region = numpy.nonzero(frame[140:640, 1052, 0] < 100)[0]
        if not select_region.any():
            return None

        rect = Rectangle(x1=635, x2=1050)

        # Find the top/bottom boundaries of the selected area
        rect.y1 = 140 + select_region[0] + 8
        rect.y2 = 140 + select_region[-1] - 4

        if rect.y2 - rect.y1 < 35:
            return None

        # Detect the width of the name by collapsing along the x axis
        # and finding the right-most dark pixel (text).
        item_region = frame[rect.y1:rect.y2, rect.x1:rect.x2, 1]
        detected_text = numpy.nonzero(item_region.min(axis=0) < 50)[0]
        if not detected_text.any():
            return None

        rect.x2 = 635 + detected_text[-1] + 10
        return rect

    def get_variation(self, gray: numpy.ndarray) -> Optional[Rectangle]:
        """Returns the rectangle around the variation text if there is one."""
        # There's a white box if the item has a variation.
        if gray[650, 25] != 250:
            return None

        variation = Rectangle(x1=30, y1=628, y2=665)
        # Find the width of the variation box by horizontal floodfill.
        variation.x2 = numpy.argmax(gray[variation.y1, :] < 250) - 15
        return variation

    def resolve_name(self, item: Optional[str],
                     variation: Optional[str]) -> Optional[str]:
        """Resolves an item and optional variation name against the item database."""
        key = (item, variation)
        if key not in self._item_cache:
            item = best_match(item, self.item_db)
            variation = best_match(variation, self.item_db.get(item))
            if variation:
                self._item_cache[key] = f'{item} [{variation}]'
            elif item and not self.item_db[item]:
                self._item_cache[key] = item
            else:
                self._item_cache[key] = None
        return self._item_cache[key]

    def image_to_text(self, text_area: numpy.ndarray) -> str:
        """Runs OCR over a given image and returns the parsed text."""
        img_hash = str(cv2.img_hash.averageHash(text_area)[0])
        if img_hash not in self._tesseract_cache:
            image = Image.fromarray(text_area)
            self.tesseract.SetImage(image)
            text = self.tesseract.GetUTF8Text().strip()
            self._tesseract_cache[img_hash] = text
        return self._tesseract_cache[img_hash]

    def save_items(self) -> None:
        """"Saves the collected items to a text file on disk and clears list."""
        if not self.items or not self.section_name:
            return

        date = datetime.datetime.now().strftime('%d-%m-%Y %H-%M-%S')
        with open(f'{self.section_name} ({date}).txt', 'w') as fp:
            fp.write('\n'.join(sorted(self.items)))

        self.section_name = None
        self.items = set()
Example #22
class TesseractOCR:

    #private static   TESSERACT_ENGINE_MODE = TessAPI1.TessOcrEngineMode.OEM_DEFAULT

    #
    # bpp - bits per pixel, represents the bit depth of the image, with 1 for
    # binary bitmap, 8 for gray, and 24 for color RGB.
    #
    BBP = 8
    DEFAULT_CONFIDENT_THRESHOLD = 60.0
    MINIMUM_DESKEW_THRESHOLD = 0.05

    def __init__(self, rgbaImage, dipCalculator, language):
        self.mRgbaImage = rgbaImage
        self.mDipCalculator = dipCalculator
        self.mHandle = PyTessBaseAPI()

        self.mOcrTextWrappers = []
        self.mOcrBlockWrappers = []
        self.mOcrLineWrappers = []
        self.raWrappers = []
        #         self.mLanguage = language

        self.mBufferedImageRgbaImage = Image.fromarray(self.mRgbaImage)
        self.initOCR()

    def baseInit(self, iteratorLevel):
        width = 0
        height = 0
        channels = 1

        if len(self.mRgbaImage.shape) == 2:
            height, width = self.mRgbaImage.shape
        else:
            height, width, channels = self.mRgbaImage.shape

        return self.baseInitIter(self.mRgbaImage, Rect(0, 0, width, height),
                                 channels, iteratorLevel)

    def baseInitIter(self, imageMat, rect, channels, iteratorLevel):
        listdata = []
        parentX = rect.x
        parentY = rect.y
        #        subMat = imageMat[rect.y:rect.y+rect.height, rect.x:rect.width+rect.x]
        #
        #        if(channels != 1):
        #            subMat = imageMat[rect.y:rect.y+rect.height, rect.x:rect.width+rect.x, 0:channels]

        #tessAPI = PyTessBaseAPI()
        #Convert to PIL image
        imgPIL = Image.fromarray(imageMat)
        self.mHandle.SetImage(imgPIL)
        boxes = self.mHandle.GetComponentImages(iteratorLevel, True)

        for i, (im, box, _, _) in enumerate(boxes):

            wrapper = OCRTextWrapper.OCRTextWrapper()
            self.mHandle.SetRectangle(box['x'], box['y'], box['w'], box['h'])
            ocrResult = self.mHandle.GetUTF8Text()
            wrapper.text = ocrResult
            conf = self.mHandle.MeanTextConf()
            wrapper.confidence = conf
            self.mHandle.Recognize()
            iterator = self.mHandle.GetIterator()
            fontAttribute = iterator.WordFontAttributes()
            wrapper.x = box['x'] + parentX
            wrapper.y = box['y'] + parentY
            wrapper.width = box['w']
            wrapper.height = box['h']
            wrapper.rect = Rect(wrapper.x, wrapper.y, wrapper.width,
                                wrapper.height)
            #            print(box)
            #
            if (fontAttribute != None):
                wrapper.fontName = fontAttribute['font_name']
                wrapper.bold = fontAttribute['bold']
                wrapper.italic = fontAttribute['italic']
                wrapper.underlined = fontAttribute['underlined']
                wrapper.monospace = fontAttribute['monospace']
                wrapper.serif = fontAttribute['serif']
                wrapper.smallcaps = fontAttribute['smallcaps']
                wrapper.fontSize = fontAttribute['pointsize']
                wrapper.fontId = fontAttribute['font_id']

            listdata.append(wrapper)

        return listdata

    def getBlockWithLocation(self, rect):
        wrappers = []
        for ocrTextWrapper in self.mOcrBlockWrappers:
            bound = ocrTextWrapper.rect
            if (RectUtil.contains(rect, bound)):
                wrappers.append(OCRTextWrapper.OCRTextWrapper(ocrTextWrapper))

        return wrappers

    def getImage(self, rect):
        x2 = rect.x + rect.width
        y2 = rect.y + rect.height
        mat = self.mRgbaImage[rect.y:y2, rect.x:x2]
        return Image.fromarray(mat)

    def getText(self, rect):
        try:
            self.mHandle.SetImage(self.mBufferedImageRgbaImage)
            self.mHandle.SetRectangle(rect.x, rect.y, rect.width, rect.height)
            text = self.mHandle.GetUTF8Text()
            return text
        except Exception as error:
            print('Caught this error: ' + repr(error))

        return ""

    def getLineText(self, rect):
        try:
            self.mHandle.SetImage(self.mBufferedImageRgbaImage)
            self.mHandle.SetRectangle(rect.x, rect.y, rect.width, rect.height)
            text = self.mHandle.GetUTF8Text()
            if (TextUtils.isEmpty(text)):
                self.mHandle = PyTessBaseAPI(psm=PSM.SINGLE_LINE)
                self.mHandle.SetImage(self.mBufferedImageRgbaImage)
                self.mHandle.SetRectangle(rect.x, rect.y, rect.width,
                                          rect.height)
                text = self.mHandle.GetUTF8Text()
                if (TextUtils.isEmpty(text)):
                    self.mHandle.SetImage(self.getImage(rect))
                    text = self.mHandle.GetUTF8Text()

                self.mHandle = PyTessBaseAPI(psm=PSM.AUTO)
            return text
        except Exception as error:
            print('Caught this error: ' + repr(error))

        return ""

    def getRectWordForLowConfidence(self, ocr):
        try:
            rect = ocr.bound()
            self.mHandle = PyTessBaseAPI(psm=PSM.SINGLE_WORD)
            self.mHandle.SetImage(self.mBufferedImageRgbaImage)
            self.mHandle.SetRectangle(rect.x, rect.y, rect.width, rect.height)
            ocr.text = self.mHandle.GetUTF8Text()
            ocr.confidence = self.mHandle.MeanTextConf()
            if (ocr.confidence <= Constants.TEXT_CONFIDENT_THRESHOLD):
                self.mHandle.SetImage(self.getImage(rect))
                ocr.text = self.mHandle.GetUTF8Text()
                ocr.confidence = self.mHandle.MeanTextConf()
            if (ocr.confidence <= Constants.TEXT_CONFIDENT_THRESHOLD):
                return False
            self.mHandle.Recognize()
            iterator = self.mHandle.GetIterator()
            fontAttribute = iterator.WordFontAttributes()
            if (fontAttribute != None):
                ocr.fontName = fontAttribute['font_name']
                ocr.bold = fontAttribute['bold']
                ocr.italic = fontAttribute['italic']
                ocr.underlined = fontAttribute['underlined']
                ocr.monospace = fontAttribute['monospace']
                ocr.serif = fontAttribute['serif']
                ocr.smallcaps = fontAttribute['smallcaps']
                ocr.fontSize = fontAttribute['pointsize']
                ocr.fontId = fontAttribute['font_id']
#                ocr.fontsize = self.getPreferenceFontSize(ocr)

            self.mHandle = PyTessBaseAPI(psm=PSM.AUTO)
            return True
        except Exception as error:
            print('Caught this error: ' + repr(error))

        return False

    def getWordsIn(self, rect):
        wrappers = []
        for ocrTextWrapper in self.mOcrTextWrappers:
            bound = ocrTextWrapper.bound()
            if (RectUtil.contains(rect, bound)):
                wrappers.append(OCRTextWrapper.OCRTextWrapper(ocrTextWrapper))

        return wrappers

    def initOCR(self):

        #
        self.initText()

        self.initBlock()
        #        self.initPara()
        self.initLine()
#

    def initBlock(self):
        self.mOcrBlockWrappers = self.baseInit(RIL.BLOCK)

    def initLine(self):
        self.mOcrLineWrappers = self.baseInit(RIL.TEXTLINE)
        invalidLineWrappers = []
        # a line cannot contain other lines
        for ocrLine in self.mOcrLineWrappers:
            for otherOcrLine in self.mOcrLineWrappers:
                if (ocrLine != otherOcrLine and RectUtil.contains(
                        ocrLine.bound(), otherOcrLine.bound())):
                    invalidLineWrappers.append(ocrLine)
        self.mOcrLineWrappers = [
            x for x in self.mOcrLineWrappers if x not in invalidLineWrappers
        ]

    def initPara(self):
        self.mOcrParaWrappers = self.baseInit(RIL.PARA)

    def initText(self):
        self.mOcrTextWrappers = self.baseInit(RIL.WORD)

    def isOverlapText(self, rect, confident):
        for ocrTextWrapper in self.mOcrTextWrappers:
            bound = ocrTextWrapper.bound()
            if (ocrTextWrapper.getConfidence() >= confident
                    and RectUtil.intersects(rect, bound)):
                return True
        return False

    def reset(self):
        self.mOcrTextWrappers = []
        self.mOcrLineWrappers = []
        self.initOCR()

#    def rotateImage(bi) :
#        iden = ImageDeskew(bi)
#        imageSkewAngle = iden.getSkewAngle() # determine skew angle
#        if imageSkewAngle > MINIMUM_DESKEW_THRESHOLD or imageSkewAngle < -MINIMUM_DESKEW_THRESHOLD :
#            bi = ImageHelper.rotateImage(bi, -imageSkewAngle) # deskew
#        return bi

    def getPreferenceFontSize(self, ocrTextWrapper, parentHeight):

        #        TODO TODO
        fontName = ocrTextWrapper.fontName
        fontSize = ocrTextWrapper.fontSize

        height = ocrTextWrapper.height * Constants.TEXT_BOX_AND_TEXT_HEIGHT_RATIO

        #        height = ocrTextWrapper.height
        textHeight = int(
            self.mDipCalculator.pxToHeightDip(min(parentHeight, height)))
        #        font = QFont(fontName, fontSize)
        newFontSize = fontSize
        if (self.getTextHeightUsingFontMetrics(ocrTextWrapper, fontName,
                                               fontSize) == textHeight):
            newFontSize = fontSize

        elif (self.getTextHeightUsingFontMetrics(ocrTextWrapper, fontName,
                                                 fontSize) < textHeight):
            while (self.getTextHeightUsingFontMetrics(ocrTextWrapper, fontName,
                                                      fontSize) < textHeight):
                fontSize = fontSize + 1
            newFontSize = fontSize

        else:
            while (self.getTextHeightUsingFontMetrics(ocrTextWrapper, fontName,
                                                      fontSize) > textHeight):
                fontSize = fontSize - 1

            newFontSize = fontSize

        return newFontSize

    def getTextHeightUsingFontMetrics(self, ocrTextWrapper, fontName,
                                      fontSize):
        #        class SIZE(ctypes.Structure):
        #            _fields_ = [("cx", ctypes.c_long), ("cy", ctypes.c_long)]
        #        hdc = ctypes.windll.user32.GetDC(0)
        #        hfont = ctypes.windll.gdi32.CreateFontA(-fontSize, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, fontName)
        #        hfont_old = ctypes.windll.gdi32.SelectObject(hdc, hfont)
        #        size = SIZE(0, 0)
        #        ctypes.windll.gdi32.GetTextExtentPoint32A(hdc, text, len(text), ctypes.byref(size))
        #        ctypes.windll.gdi32.SelectObject(hdc, hfont_old)
        #        ctypes.windll.gdi32.DeleteObject(hfont)
        #        return size.cy
        file = "fonts//" + fontName + ".ttf"

        font = ImageFont.truetype(file, fontSize)
        fontSize = font.getsize(ocrTextWrapper.text)
        return fontSize[1]

    def validCharacter(self, word):
        return self.mHandle.IsValidCharacter(word)

        # Do not have this method return TessAPI1.TessBaseAPIIsValidWord(mHandle, word) != 0
#        return True

#TODO
#    def getTextHeightUsingTextLayout(self,ocrTextWrapper, font) :
#        frc = self.mGraphics.getFontRenderContext()
#        loc = Point(0, 0)
#        layout = TextLayout(ocrTextWrapper.text, font, frc)
#        layout.draw(self.mGraphics, float(loc.x, loc.y))
#        bounds = layout.getBounds()
#        height = bounds.getHeight()
#        return height

#    def isValidTextUsingConfidentAndBoundaryCheck(self, ocrTextWrapper) :
#        if (ocrTextWrapper.getConfidence() > Constants.TEXT_CONFIDENT_THRESHOLD + Constants.TEXT_CONFIDENT_THRESHOLD_SECONDARY_RANGE) :
#            return True
#
#        elif (ocrTextWrapper.getConfidence() <= Constants.TEXT_CONFIDENT_THRESHOLD) :
#            return False
#
#        return self.isValidTextUsingBoundaryCheck(ocrTextWrapper)
#
#

    def getTextDimensions(self, text, fontName, fontSize):
        file = "fonts//" + fontName + ".ttf"
        try:
            font = ImageFont.truetype(file, fontSize)
            fontSize = font.getsize(text)
            return fontSize
        except OSError:
            print(file)

    def isValidTextUsingBoundaryCheck(self, ocrTextWrapper):
        # confident between TextProcessor.TEXT_CONFIDENT_THRESHOLD and
        # TextProcessor.TEXT_CONFIDENT_THRESHOLD +
        # TextProcessor.TEXT_CONFIDENT_THRESHOLD_SECONDARY_RANGE
        if (TextUtils.isEmpty(ocrTextWrapper.text)):
            # We cannot calculate width of empty text
            return True
#        return True

#        frc = mGraphics.getFontRenderContext()
#        font = QFont(ocrTextWrapper.fontName,ocrTextWrapper.fontSize)
#        loc = Point(0, 0)
#        layout = TextLayout(ocrTextWrapper.text,font, frc)
#        layout.draw(mGraphics,  loc.getX(), loc.getY())
#        bound = layout.getBounds()
        width, height = self.getTextDimensions(ocrTextWrapper.text,
                                               ocrTextWrapper.fontName,
                                               ocrTextWrapper.fontSize)

        fontRatio = float(height / width)
        boundRatio = float(ocrTextWrapper.height / ocrTextWrapper.width)
        fontArea = self.mDipCalculator.dipToHeightPx(
            height) * self.mDipCalculator.dipToWidthPx(width)
        boundArea = float(ocrTextWrapper.width * ocrTextWrapper.height)
        #
        # the different between dimensions of the text should be smaller than
        # 10% of the max dimension.
        # System.out.prln(" Ratio: " + fontRatio + ", " + boundRatio + ", "
        # + Math.abs(boundRatio - fontRatio)
        # / Math.max(boundRatio, fontRatio) + "," + fontArea + ", "
        # + boundArea + ", " + Math.min(fontArea, boundArea)
        # / Math.max(fontArea, boundArea))

        # If the bound is square, it is less likely that this text is correct
        # TODO: this rule may not be needed
        #        if (float(min(ocrTextWrapper.getWidth(),ocrTextWrapper.getHeight()) / max( ocrTextWrapper.getWidth(),
        #						ocrTextWrapper.getHeight())) > 0.95) :
        #			# if drawing text cannot create square, sorry -> invalid
        #            if (float(min(width, height) / max(width, height)) <= 0.95 and not validWord(ocrTextWrapper.text)) :
        #                return False
        #
        #
        #

        #        print(self.mDipCalculator.dipToWidthPx(width), self.mDipCalculator.dipToHeightPx(height))
        #        print( ocrTextWrapper.width, ocrTextWrapper.height)
        dimension = abs(boundRatio - fontRatio) / max(boundRatio, fontRatio)
        #        print(dimension)

        dimensionCheck = abs(boundRatio - fontRatio) / max(
            boundRatio, fontRatio
        ) <= Constants.TEXT_CONFIDENT_ACCEPTANCE_DIMENSION_RATIO_DIFFERENCE_THRESHOLD

        areaCheckVal = min(fontArea, boundArea) / max(fontArea, boundArea)
        #        print(areaCheckVal)
        #        print(ocrTextWrapper.text)
        areaCheck = min(fontArea, boundArea) / max(
            fontArea,
            boundArea) >= Constants.TEXT_AREA_ACCEPTANCE_DIFFERENCE_THRESHOLD

        return dimensionCheck and areaCheck

    def destroy(self):
        self.mHandle.End()
Example #23
class FeatureCollector(object):
    def __init__(self, tree, imgfile):
        self.tree = tree
        self.collect_texts()
        self.imgfile = imgfile
        self.tesapi = PyTessBaseAPI(lang='eng')
        self.set_tes_image()

    def set_tes_image(self):
        #imgpil = Image.fromarray(numpy.uint8(imgdata * 255))
        imgpil = Image.open(self.imgfile)
        (self.imgwidth, self.imgheight) = (imgpil.width, imgpil.height)
        imgpil = imgpil.convert("RGB").resize(
            (imgpil.width * OCR_RATIO, imgpil.height * OCR_RATIO))
        self.tesapi.SetImage(imgpil)

    def add_ctx_attr(self, ctx, data, attr_re, word_limit=1000):
        if ctx not in self.point:
            self.point[ctx] = ''
        words = attr_re.findall(data)
        if word_limit and len(words) > word_limit:
            return
        for word in words:
            if self.point[ctx]:
                self.point[ctx] += ' '
            self.point[ctx] += '%s' % word.lower()

    def add_ctx(self, ctx, node, attrs, attr_re=anything_re, word_limit=None):
        for attr in attrs:
            self.add_ctx_attr(ctx, node[attr], attr_re, word_limit)

    def collect_texts(self):
        for nodeid in self.tree:
            node = self.tree[nodeid]
            if 'fulltext' not in node:
                self.collect_text(node)

    def collect_text(self, node):
        """ Collect text from node and all its children """
        if 'fulltext' in node:
            return node['fulltext']

        cls = node['class']
        if cls == 'View':
            text = node['desc']
        else:
            text = node['text']

        for child in node['children']:
            text = text.strip() + ' ' + self.collect_text(self.tree[child])

        node['fulltext'] = text
        return text

    def prepare_neighbour(self):
        node = self.tree[self.nodeid]
        self.point['neighbour_ctx'] = ''
        self.point['adj_ctx'] = ''
        neighbour_count = 0
        for other in self.tree:
            if self.tree[other]['parent'] == node[
                    'parent'] and other != self.nodeid:
                self.add_ctx_attr('neighbour_ctx',
                                  self.collect_text(self.tree[other]), text_re)
                self.add_ctx('neighbour_ctx', self.tree[other], ['id'], id_re)
                if self.tree[other]['class'] == node['class']:
                    neighbour_count += 1

            # left sibling
            if (self.tree[other]['parent'] == node['parent']
                    and other != self.nodeid
                    and self.tree[other]['childid'] < node['childid']
                    and self.tree[other]['childid'] > node['childid'] - 2):
                self.add_ctx_attr('adj_ctx',
                                  self.collect_text(self.tree[other]), text_re)
                self.add_ctx('adj_ctx', self.tree[other], ['id'], id_re)

        self.point['neighbour_count'] = neighbour_count

    def ctx_append(self, ctx, kind, clz, detail):
        ret = ctx
        ret += ' ' + kind + clz
        regex = word_re
        for part in regex.findall(detail):
            ret += ' ' + kind + part
        return ret

    def collect_subtree_info(self, node, root):
        if ignore_node(node):
            return {'ctx': '', 'text': '', 'count': 0}
        ctx = ''
        count = 1

        clz = node['class']

        if clz == 'View':
            text = node['desc'][:30]
        else:
            text = node['text'][:30]
        desc = node['desc'][:30]

        ctx += clz + ' '
        ctx += node['id'] + ' '
        ctx += text + ' '
        ctx += desc + ' '
        ctx += gen_ngram(text) + ' '
        ctx += gen_ngram(desc) + ' '

        if root is not None:
            if node['width'] > 0.6 * config.width:
                ctx = self.ctx_append(ctx, "WIDE", clz, node['id'])

            if node['height'] > 0.6 * config.height:
                ctx = self.ctx_append(ctx, "TALL", clz, node['id'])

            if node['y'] + node['height'] < root['y'] + 0.3 * root['height']:
                ctx = self.ctx_append(ctx, "TOP", clz, node['id'])

            if node['x'] + node['width'] < root['x'] + 0.3 * root['width']:
                ctx = self.ctx_append(ctx, "LEFT", clz, node['id'])

            if node['y'] > root['y'] + 0.7 * root['height']:
                ctx = self.ctx_append(ctx, "BOTTOM", clz, node['id'])

            if node['x'] > root['x'] + 0.7 * root['width']:
                ctx = self.ctx_append(ctx, "RIGHT", clz, node['id'])

        for child in node['children']:
            child_info = self.collect_subtree_info(
                self.tree[child], root if root is not None else node)
            ctx = ctx.strip() + ' ' + child_info['ctx']
            count += child_info['count']
            text = text.strip() + ' ' + child_info['text']

        return {'ctx': ctx, 'text': text, 'count': count}

    def prepare_children(self):
        node = self.tree[self.nodeid]
        #self.add_ctx_attr('node_subtree_text', self.collect_text(node), text_re, 10)
        subtree_info = self.collect_subtree_info(node, None)
        self.point['subtree'] = subtree_info['ctx']
        self.point['node_subtree_text'] = subtree_info['text']
        self.point['node_childs'] = subtree_info['count']

    def prepare_ancestor(self):
        node = self.tree[self.nodeid]
        parent = node['parent']
        self.point['parent_ctx'] = ''
        parent_click = parent_scroll = parent_manychild = False
        parent_depth = 0
        while parent != 0 and parent != -1 and parent_depth < PARENT_DEPTH_LIMIT:
            self.add_ctx('parent_ctx', self.tree[parent], ['class', 'id'],
                         id_re)
            parent_click |= self.tree[parent]['click']
            parent_scroll |= self.tree[parent]['scroll']
            parent_manychild |= len(self.tree[parent]['children']) > 1
            parent = self.tree[parent]['parent']
            parent_depth += 1

        self.point['parent_prop'] = [
            parent_click, parent_scroll, parent_manychild
        ]

    def prepare_self(self):
        node = self.tree[self.nodeid]
        # AUX info
        self.point['id'] = self.nodeid
        self.point['str'] = util.describe_node(node, None)

        self.add_ctx('node_text', node, ['text'], text_re, 10)
        self.add_ctx('node_ctx', node, ['desc'], text_re)
        self.add_ctx('node_ctx', node, ['id'], id_re)
        self.add_ctx('node_class', node, ['class'], id_re)
        if 'Recycler' in node['class'] or 'ListView' in node['class']:
            self.point['node_class'] += " ListContainer"
        self.point['node_x'] = node['x']
        self.point['node_y'] = node['y']
        self.point['node_w'] = node['width']
        self.point['node_h'] = node['height']

    def prepare_point(self, nodeid, app, scr, caseid, imgdata, treeinfo, path):
        """Convert a node in the tree into a data point for ML"""
        self.nodeid = nodeid
        self.point = {}

        # AUX info
        self.point['app'] = app
        self.point['scr'] = scr
        self.point['case'] = caseid

        self.prepare_self()
        self.prepare_neighbour()
        self.prepare_ancestor()
        self.prepare_global(path, treeinfo)
        self.prepare_img(imgdata)
        self.prepare_ocr()
        self.prepare_children()

        return self.point

    def prepare_global(self, path, treeinfo):
        node = self.tree[self.nodeid]
        has_dupid = False
        is_itemlike = False
        is_listlike = False
        for _id in node['raw']:
            if _id in treeinfo['dupid']:
                has_dupid = True
                break
        for _id in node['raw']:
            if _id in treeinfo['itemlike']:
                is_itemlike = True
                break
        for _id in node['raw']:
            if _id in treeinfo['listlike']:
                is_listlike = True
                break
        self.point['node_prop'] = [
            node['click'], node['scroll'],
            len(node['children']) > 1, has_dupid, is_itemlike, is_listlike
        ]

        self.point['path'] = path

    def prepare_img(self, imgdata):
        node = self.tree[self.nodeid]
        # your widget should be inside the screenshot
        # not always!
        self.min_x = max(node['x'], 0)
        self.min_y = max(node['y'], 0)
        self.max_x = min(node['x'] + node['width'], self.imgwidth)
        self.max_y = min(node['y'] + node['height'], self.imgheight)
        self.empty = self.max_x <= self.min_x or self.max_y <= self.min_y
        self.point['empty'] = self.empty

    def prepare_ocr(self):
        node = self.tree[self.nodeid]
        if config.region_use_ocr and not self.empty:
            if 'ocr' in node:
                ocr_text = node['ocr']
            else:
                self.tesapi.SetRectangle(self.min_x * OCR_RATIO,
                                         self.min_y * OCR_RATIO,
                                         (self.max_x - self.min_x) * OCR_RATIO,
                                         (self.max_y - self.min_y) * OCR_RATIO)
                try:
                    ocr_text = self.tesapi.GetUTF8Text()
                except Exception:
                    logger.warning("tesseract failed to recognize")
                    ocr_text = ''
                ocr_text = ocr_text.strip().replace('\n', ' ')
        else:
            ocr_text = 'dummy'
        self.point['node_ocr'] = ocr_text
        #if point['node_text'].strip() == '':
        #    point['node_text'] = ocr_text
        logger.debug("%s VS %s" % (ocr_text, node['text']))

        (missing, found, other) = (node['ocr_missing'], node['ocr_found'],
                                   node['ocr_other'])
        self.point['ocr_missing'] = missing
        self.point['ocr_found'] = found
        self.point['ocr_other'] = other
        self.point['ocr_ratio'] = (1.0 * missing / (missing + other)
                                   if missing + other > 0 else 0.0)
        self.point['ocr_visible'] = node['visible']
Example #24
def run_ocr_in_chart(chart, pad=0, psm=PSM.SINGLE_LINE):
    """
    Run OCR for all the boxes.
    :param img:
    :param boxes:
    :param pad: padding before applying ocr
    :param psm: PSM.SINGLE_WORD or PSM.SINGLE_LINE
    :return:
    """
    img = chart.image

    # add a padding to the initial figure
    fpad = 1
    img = cv2.copyMakeBorder(img.copy(), fpad, fpad, fpad, fpad, cv2.BORDER_CONSTANT, value=(255, 255, 255))
    fh, fw, _ = img.shape

    api = PyTessBaseAPI(psm=psm, lang='eng')
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(4, 4))

    for tbox in chart.texts:
        # adding a pad to original image. Some case in quartz corpus, the text touch the border.
        x, y, w, h = ru.wrap_rect(u.ttoi(tbox.rect), (fh, fw), padx=pad, pady=pad)
        x, y = x + fpad, y + fpad

        if w * h == 0:
            tbox.text = ''
            continue

        # crop region of interest
        roi = img[y:y + h, x:x + w]
        #  to gray scale
        roi_gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
        #
        roi_gray = cv2.resize(roi_gray, None, fx=3, fy=3, interpolation=cv2.INTER_CUBIC)
        # binarization
        _, roi_bw = cv2.threshold(roi_gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
        # removing noise from borders
        roi_bw = 255 - clear_border(255-roi_bw)

        # roi_gray = cv2.copyMakeBorder(roi_gray, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=255)

        # when testing boxes from csv files
        if tbox.num_comp == 0:
            # Apply Contrast Limited Adaptive Histogram Equalization
            roi_gray2 = clahe.apply(roi_gray)
            _, roi_bw2 = cv2.threshold(roi_gray2, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
            _, num_comp = morphology.label(roi_bw2, return_num=True, background=255)
            tbox.regions.extend(range(num_comp))

        pil_img = smp.toimage(roi_bw)
        if SHOW:
            pil_img.show()
        max_conf = -np.inf
        min_dist = np.inf
        correct_text = ''
        correct_angle = 0
        u.log('---------------')
        for angle in [0, -90, 90]:
            rot_img = pil_img.rotate(angle, expand=1)

            api.SetImage(rot_img)
            conf = api.MeanTextConf()
            text = api.GetUTF8Text().strip()
            dist = abs(len(text.replace(' ', '')) - tbox.num_comp)

            u.log('text: %s  conf: %f  dist: %d' % (text, conf, dist))
            if conf > max_conf and dist <= min_dist:
                max_conf = conf
                correct_text = text
                correct_angle = angle
                min_dist = dist

        tbox.text = post_process_text(lossy_unicode_to_ascii(correct_text))
        tbox.text_conf = max_conf
        tbox.text_dist = min_dist
        tbox.text_angle = correct_angle

        u.log('num comp %d' % tbox.num_comp)
        u.log(u'** text: {} conf: {} angle: {}'.format(correct_text, max_conf, correct_angle))

    api.End()
Example #25
class TT2Predictor:
    """holds the several trainer predictor instances and common operations """
    def __init__(self, **kwargs):
        self.trainers_predictors_list = []
        self.text_predictors_list = [
            ("previous_level", (1212, 231, 1230, 280), "0123456789", "8"),
            ("main_level", (1203, 323, 1223, 399), "0123456789", "8"),
            ("next_level", (1212, 445, 1230, 493), "0123456789", "8"),
            ("sub_level", (1177, 625, 1203, 692), "0123456789/", "8"),
            ("gold", (1091, 283, 1126, 471),
             "0123456789.abcdefghijklmnopqrstuvwxyz", "7"),
            ("current_dps_down_no_tab", (389, 562, 423, 709),
             "0123456789.abcdefghijklmnopqrstuvwxyz", "8"),
            ("last_hero", (124, 109, 148, 430),
             "0123456789.ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
             "7")
        ]
        self.api = PyTessBaseAPI()
        self.api.Init()
        print(tesserocr.tesseract_version())
        print(tesserocr.get_languages())
        self.global_image = None
        self.status = CurrentStatus()

        boss_trainer = TrainerPredictor(
            "boss_active_predictor",
            ["boss_active", "boss_inactive", "no_boss"],
            (1224, 555, 1248, 648), 12, 46, 255.0, [200, 30])
        egg_trainer = TrainerPredictor("egg_active_predictor",
                                       ["egg_active", "egg_inactive"],
                                       (741, 31, 761, 64), 10, 16, 255.0,
                                       [200, 30])
        gold_pet_trainer = TrainerPredictor(
            "gold_pet_predictor",
            ["goldpet", "nopet", "normalpet", "partial pet"],
            (624, 364, 734, 474), 40, 40, 255.0, [200, 30])
        tab_predictor = TrainerPredictor("tab_predictor", [
            "skills_tab", "heroes_tab", "equipment_tab", "pet_tab",
            "relic_tab", "shop_tab", "no_tab"
        ], (51, 1, 59, 717), 2, 179, 255.0, [200, 30])
        self.trainers_predictors_list.append(boss_trainer)
        self.trainers_predictors_list.append(egg_trainer)
        self.trainers_predictors_list.append(gold_pet_trainer)
        self.trainers_predictors_list.append(tab_predictor)
        for trainer in self.trainers_predictors_list:
            pass
            #trainer.crop_images()
            #trainer.process_images()
            #trainer.read_and_pickle()
            #trainer.train_graph()
        saved_classes_file = glo.DATA_FOLDER + "/dataforclassifier/TrainerPredictor_list.pickle"
        save_pickle(saved_classes_file, self.trainers_predictors_list)

    def parse_raw_image(self):
        with open(glo.RAW_FULL_FILE, 'rb') as f:
            image = Image.frombytes('RGBA', (1280, 720), f.read())
        for class_predictor in self.trainers_predictors_list:
            class_predictor.predict_crop(image)
        self.global_image = image
        image.save(glo.UNCLASSIFIED_GLOBAL_CAPTURES_FOLDER + "/fullcapture" +
                   time.strftime("%Y%m%d-%H%M%S-%f") +
                   ".png")  # save original capture copy

    def parse_image_text(self, predict_map):
        return_dict = {}
        for text_predictor in self.text_predictors_list:
            if text_predictor[0] in predict_map:
                img = self.global_image.crop(text_predictor[1])

                img = img.convert('L')
                img = img.rotate(90, expand=True)
                self.api.SetImage(img)
                self.api.SetVariable("tessedit_char_whitelist",
                                     text_predictor[2])
                self.api.SetVariable("tessedit_pageseg_mode",
                                     text_predictor[3])
                self.api.SetVariable("language_model_penalty_non_dict_word",
                                     "0")
                self.api.SetVariable("doc_dict_enable", "0")
                text_capture = self.api.GetUTF8Text().encode('utf-8').strip()
                return_dict[text_predictor[0]] = text_capture
                print("raw text capture ", text_predictor[0], ":",
                      text_capture)
                self.api.Clear()
        return return_dict

    def predict_parsed_all(self):
        pred_dict = {}
        for class_predictor in self.trainers_predictors_list:
            pred_dict[class_predictor.name] = class_predictor.predict_parsed()
        return pred_dict

    def predict_parsed(self, predict_map, predict_map_text, **kwargs):
        pred_dict = {"transition_level": False}

        # Check whether the image is mid level-transition (trivial prediction).
        # NB: kwargs is a dict, so key lookup is used here rather than
        # hasattr(), which never matches dictionary keys.
        if kwargs.get("empty_image") is not False:
            img = self.global_image.crop((0, 0, 100, 100))  # black corner
            extrema = img.convert("L").getextrema()
            if extrema[0] == extrema[1]:  # only one color
                print("warning: level transitioning")
                pred_dict["transition_level"] = True

        for class_predictor in self.trainers_predictors_list:
            if class_predictor.name in predict_map:
                pred_dict[
                    class_predictor.name] = class_predictor.predict_parsed()
        pred_dict_text = self.parse_image_text(predict_map_text)
        pred_dict.update(pred_dict_text)
        self.status.update_status(pred_dict, self.trainers_predictors_list)
        return pred_dict

    def predict(self):
        self.parse_raw_image()
        return self.predict_parsed_all()

    def check_predict(self, pred_dict, predictor, classification):
        for class_predictor in self.trainers_predictors_list:
            if class_predictor.name == predictor:
                return int(
                    pred_dict[predictor]
                ) == class_predictor.pred_classes.index(classification)
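
A minimal standalone sketch of the per-region setup used above: a character whitelist and a page-segmentation mode applied before reading a crop. (tesserocr is assumed to be installed; "crop.png" is a hypothetical pre-cropped image.)

from PIL import Image
from tesserocr import PSM, PyTessBaseAPI

# restrict recognition to digits and a dot, as the predictors above do
with PyTessBaseAPI(psm=PSM.SINGLE_LINE) as api:
    api.SetVariable("tessedit_char_whitelist", "0123456789.")
    api.SetImage(Image.open("crop.png"))  # hypothetical input crop
    print(api.GetUTF8Text().strip(), api.MeanTextConf())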
Example #26
    class __impl:
        def __init__(self, vs, imgOutput):
            self.__vs = vs
            self.__imgOutput = imgOutput
            self.image = None
            self.logger = Logger()
            self.state = State()
            self.tesseract = PyTessBaseAPI(psm=PSM.SINGLE_CHAR,
                                           oem=OEM.LSTM_ONLY,
                                           lang="digits")
            self.filter = Filter()

            self.signalThresholdY = 160
            self.LAPPatternSesibility = 5

            self.recordStamp = time.strftime(self.logger.timeFormat)
            self.recordNum = 0
            self.recordFolder = None
            self.cntNum = 0

            if (self.state.RecordImage):
                root = 'record'
                if not os.path.isdir(root):
                    os.mkdir(root)
                self.recordFolder = os.path.join(root, self.recordStamp)
                if not os.path.isdir(self.recordFolder):
                    os.mkdir(self.recordFolder)

        def showImg(self, window, image):
            if self.__imgOutput:
                cv2.imshow(window, image)

        def warmup(self):
            time.sleep(2.0)
            self.tesserOCR(np.zeros((1, 1, 3), np.uint8))

        def tesserOCR(self, image):
            self.tesseract.SetImage(Image.fromarray(image))
            return (self.tesseract.GetUTF8Text(),
                    self.tesseract.AllWordConfidences())

        def dominantColor(self, img, clusters=2):
            # NB: k-means is run with a single cluster below, so the
            # `clusters` argument is effectively unused in this example.
            data = np.reshape(img, (-1, 3))
            data = np.float32(data)

            criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10,
                        1.0)
            flags = cv2.KMEANS_RANDOM_CENTERS
            _, _, centers = cv2.kmeans(data, 1, None, criteria, 10, flags)
            return centers[0].astype(np.int32)

        def analyzeRect(self, image, warped, box, x, y):
            # find amount of color blue in warped area, assuming over X% is the lap signal
            if (self.getAmountOfColor(warped, Colors.lower_blue_color,
                                      Colors.upper_blue_color, True) > 0.1):
                self.logger.info("Rundensignal")
                self.state.setCurrentSignal(Signal.LAP)
                return "Rundensignal"

        def analyzeSquare(self, image, warped, box, x, y):

            #dominantColor, percent, _ = self.dominantColor(warped, 3)
            # dominantColor = self.dominantColor(
            #    cv2.cvtColor(warped, cv2.COLOR_BGR2HSV), 3)
            """  color = 'k'
             # find amount of color black in warped area, assuming over X% is a numeric signal
             if ((dominantColor <= 70).all()):
                 color = 'Black'

             elif ((dominantColor >= 180).all()):
                 color = 'White'

             if (color): """
            resizedWarp = cv2.resize(warped,
                                     None,
                                     fx=2.0,
                                     fy=2.0,
                                     interpolation=cv2.INTER_CUBIC)

            # gray
            optimized = cv2.cvtColor(resizedWarp, cv2.COLOR_BGR2GRAY)

            # blur
            optimized = cv2.GaussianBlur(optimized, (5, 5), 0)

            # binary image
            optimized = cv2.threshold(optimized, 0, 255,
                                      cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

            # binary inversion if dominant color is black
            """ if (color == 'Black'):
                optimized = cv2.bitwise_not(optimized) """

            # check a 2 px frame along the left, right and top edges; there
            # shouldn't be any noise there, since it's a clean signal background
            h, w = optimized.shape[0:2]
            clean = optimized[0, 0]
            for iFrame in range(0, 2):
                for iHeight in range(h):
                    if not (optimized[iHeight, iFrame] == clean) or not (
                            optimized[iHeight, w - 1 - iFrame] == clean):
                        return False
                for iWidth in range(w):
                    # or not(optimized[h - iFrame, iWidth])
                    if not (optimized[iFrame, iWidth] == clean):
                        return False

            # cv2.imwrite("records/opt/" + str(self.cntNum) + ".jpg", optimized)

            output, confidence = self.tesserOCR(optimized)

            # if the resulting text is below X% confidence threshold, we skip it
            if not output or confidence[0] < 95:
                return False

            # clean up output from tesseract
            output = output.replace('\n', '')
            output = output.replace(' ', '')

            if output.isdigit() and 0 < int(output) < 10:
                """ self.showImg("opt " + str(self.cntNum),
                                np.hstack((resizedWarp, cv2.cvtColor(optimized, cv2.COLOR_GRAY2BGR)))) """
                if y <= self.signalThresholdY:
                    self.logger.info('Stop Signal OCR: ' + output + ' X: ' +
                                     str(x) + ' Y: ' + str(y) +
                                     ' Confidence: ' + str(confidence[0]) +
                                     '%')  # + ' DC: ' + str(dominantColor))
                    self.state.setStopSignal(int(output))
                    return 'S: ' + output
                elif self.state.StopSignalNum != 0:
                    self.logger.info('Info Signal OCR: ' + output + ' X: ' +
                                     str(x) + ' Y: ' + str(y) +
                                     ' Confidence: ' + str(confidence[0]) +
                                     '%')  # + ' DC: ' + str(dominantColor))
                    self.state.setCurrentSignal(Signal.UPPER, int(output))
                    return 'I: ' + output

        def getAmountOfColor(self,
                             img,
                             lowerColor,
                             upperColor,
                             convert2hsv=True):
            if (convert2hsv):
                img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

            # create mask from color range
            maskColor = cv2.inRange(img, lowerColor, upperColor)
            # get ratio of active pixels
            ratio_color = cv2.countNonZero(maskColor) / (img.size)
            return ratio_color

        # color picker for manual debugging
        def pick_color(self, event, x, y, flags, param):
            if event == cv2.EVENT_LBUTTONDOWN:
                pixel = self.image[y, x]
                color = np.array([pixel[0], pixel[1], pixel[2]])
                self.logger.info(pixel)

        # capture frames from the camera
        def capture(self, savedImg=None):
            if (savedImg is not None):
                image = savedImg
            else:
                image = self.__vs.read()
                if (self.state.InvertCamera):
                    image = imutils.rotate(image, angle=180)

            self.image = image

            if (self.state.RecordImage):
                self.recordNum += 1
                cv2.imwrite(
                    os.path.join(self.recordFolder,
                                 str(self.recordNum) + ".jpg"), image)
                return

            if (self.state.Approaching == Signal.UPPER
                    or self.state.Approaching == Signal.LOWER):
                self.findNumberSignal(image)
            elif (self.state.Approaching == Signal.LAP):
                self.findLapSignal(image)

        def findLapSignal(self, image):
            contourImage = image.copy()

            blur = cv2.GaussianBlur(image, (3, 3), 0)
            hsv = cv2.cvtColor(blur, cv2.COLOR_BGR2HSV)
            self.image = hsv
            mask = cv2.inRange(hsv, Colors.lower_blue_color,
                               Colors.upper_blue_color)

            cnts = imutils.grab_contours(
                cv2.findContours(mask.copy(), cv2.RETR_LIST,
                                 cv2.CHAIN_APPROX_SIMPLE))

            if len(cnts) > 0:

                # transform all contours to rects
                rects = [cv2.boundingRect(cnt) for cnt in cnts]

                # now iterate over all rects, trying to find an approximate
                # sibling shifted in the Y-direction
                for rect in rects:
                    (x, y, w, h) = rect
                    cv2.rectangle(contourImage, (x, y), (x + w, y + h),
                                  (0, 0, 255), 2)

                    # try to match the pattern from a given rect in all rects
                    counterPart = [
                        counterRect for counterRect in rects
                        if (counterRect != rect and x - 5 <= counterRect[0] <=
                            x + 5 and 2 * -(h + 5) <= y - counterRect[1] <= 2 *
                            (h + 5) and w - 5 <= counterRect[2] <= w + 5)
                        and h - 5 <= counterRect[3] <= h + 5
                    ]

                    if (counterPart):
                        (x, y, w, h) = counterPart[0]
                        cv2.rectangle(contourImage, (x, y), (x + w, y + h),
                                      (0, 255, 0), 2)
                        self.logger.info('LAP Signal')
                        self.state.captureLapSignal()
                        break

            self.showImg(
                'contourImage',
                np.hstack(
                    (hsv, contourImage, cv2.cvtColor(mask,
                                                     cv2.COLOR_GRAY2BGR))))
            cv2.setMouseCallback('contourImage', self.pick_color)

        def findNumberSignal(self, image):

            image_height = np.size(image, 0)
            image_width = np.size(image, 1)

            contourImage = image.copy()

            # focus only on the part of the image, where a signal could occur
            # image = image[int(image.shape[0] * 0.2):int(image.shape[0] * 0.8), 0:int(image.shape[1]*0.666)]

            mask = self.filter.autoCanny(image, 2, 3)

            # get a list of contours in the mask, chaining to just endpoints
            cnts = imutils.grab_contours(
                cv2.findContours(mask.copy(), cv2.RETR_LIST,
                                 cv2.CHAIN_APPROX_SIMPLE))

            # only proceed if at least one contour was found
            if len(cnts) > 0:
                # loop contours
                for self.cntNum, cnt in enumerate(cnts):

                    rect = cv2.minAreaRect(cnt)
                    _, _, angle = rect

                    # approximate shape
                    approx = cv2.approxPolyDP(cnt,
                                              0.02 * cv2.arcLength(cnt, True),
                                              True)

                    # the rectangle must not be rotated too much (+/-10 deg)
                    # and the approximation must have at least 3 corner points
                    if len(approx) >= 3 and (-90 <= angle <= -80
                                             or angle >= -10):

                        box = cv2.boxPoints(rect)
                        box = np.int0(box)

                        (x, y, w, h) = cv2.boundingRect(approx)

                        # limit viewing range
                        if (y <= image_height * 0.2 or x >= image_width * 0.8):
                            continue

                        if (w <= 5 or h <= 5):
                            continue

                        # we are in approaching mode, thus we only care for the lower signals <= threshold
                        if ((self.state.Approaching == Signal.UPPER
                             and y >= self.signalThresholdY)
                                and not self.state.Standalone):
                            continue
                        elif ((self.state.Approaching == Signal.LOWER
                               and y <= self.signalThresholdY)
                              and not self.state.Standalone):
                            continue

                        sideRatio = w / float(h)

                        absoluteSizeToImageRatio = (
                            100 / (image_width * image_height)) * (w * h)

                        # calculate area of the bounding rectangle
                        rArea = w * float(h)

                        # calculate area of the contour
                        cArea = cv2.contourArea(cnt)
                        if (cArea):
                            rectContAreaRatio = (100 / rArea) * cArea
                        else:
                            continue

                        # cv2.drawContours(contourImage, [box], 0, (255, 0, 0), 1)
                        result = None

                        # is the rectangle sideways, check for lap signal
                        # if (h*2 < w and y <= self.signalThresholdY and rectContAreaRatio >= 80):
                        #result = self.analyzeRect(image, four_point_transform(image, [box][0]), box, x, y)
                        # find all contours looking like a signal, with a
                        # minimum area of 0.01% of the image
                        if absoluteSizeToImageRatio >= 0.01:
                            # is it approx a square, or standing rect? then check for info or stop signal
                            if 0.2 <= sideRatio <= 1.1:
                                # transform ROI
                                if (sideRatio <= 0.9):
                                    coords, size, angle = rect
                                    size = size[0] + 8, size[1] + 4
                                    coords = coords[0] + 1, coords[1] + 1

                                    rect = coords, size, angle
                                    box = cv2.boxPoints(rect)
                                    box = np.int0(box)
                                """ cv2.drawContours(
                                    contourImage, [box], 0, (0, 255, 0), 1) """

                                warp = four_point_transform(image, [box][0])

                                result = self.analyzeSquare(
                                    image, warp, box, x, y)

                        if (result):
                            if (self.__imgOutput):
                                color = None
                                if (y >= self.signalThresholdY):
                                    color = (0, 0, 255)
                                else:
                                    color = (255, 0, 0)

                                cv2.drawContours(contourImage, [box], 0, color,
                                                 1)
                                cv2.drawContours(contourImage, [cnt], -1,
                                                 color, 2)
                                """ M = cv2.moments(cnt)
                                cX = int(M["m10"] / M["m00"])
                                cY = int(M["m01"] / M["m00"])
                                cv2.putText(contourImage, str(
                                    self.cntNum), (cX - 30, cY - 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) """

                                self.logger.debug(
                                    "[" + str(self.cntNum) + "] SideRatio: " +
                                    str(sideRatio) + " AreaRatio: " +
                                    str(rectContAreaRatio) + " ContArea: " +
                                    str(cArea) + " RectArea: " + str(rArea) +
                                    " AbsSize: " +
                                    str(absoluteSizeToImageRatio) +
                                    " CntPoints: " + str(len(approx)) +
                                    " Size: " + str(w) + "x" + str(h))
            """ if (self.__imgOutput):  # hsv img output
                hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
                cv2.namedWindow('contourImage')
                cv2.setMouseCallback('contourImage', self.pick_color)
                # self.showImg("hsv", hsv) """

            self.showImg(
                "contourImage",
                np.hstack((contourImage, cv2.cvtColor(mask,
                                                      cv2.COLOR_GRAY2BGR))))
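
For reference, the resize -> grayscale -> blur -> Otsu pipeline that analyzeSquare applies before OCR can be isolated into a small helper. A sketch, assuming `warped` is any BGR numpy array:

import cv2
import numpy as np

def preprocess_for_ocr(warped: np.ndarray) -> np.ndarray:
    # upscale so Tesseract has more pixels to work with
    resized = cv2.resize(warped, None, fx=2.0, fy=2.0,
                         interpolation=cv2.INTER_CUBIC)
    gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
    # a light blur suppresses noise before thresholding
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    # Otsu picks the binarization threshold automatically
    return cv2.threshold(blurred, 0, 255,
                         cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]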
Example #27
def tesser_ocr(
    image: np.ndarray,
    expected_type: Optional[Callable[[str], T]] = None,
    whitelist: Optional[str] = None,
    invert: bool = False,
    scale: float = 1,
    blur: Optional[float] = None,
    engine: tesserocr.PyTessBaseAPI = tesseract_only,
    warn_on_fail: bool = False,
) -> Optional[T]:

    with lock:

        if image.shape[0] <= 1 or image.shape[1] <= 1:
            if not expected_type or expected_type is str:
                return ""
            else:
                return None

        if whitelist is None:
            if expected_type is int:
                whitelist = string.digits
            elif expected_type is float:
                whitelist = string.digits + "."
            else:
                whitelist = string.digits + string.ascii_letters + string.punctuation + " "

        # print('>', whitelist)

        engine.SetVariable("tessedit_char_whitelist", whitelist)
        if invert:
            image = 255 - image
        if scale != 1:
            image = cv2.resize(image, (0, 0),
                               fx=scale,
                               fy=scale,
                               interpolation=cv2.INTER_LINEAR)

        if blur:
            image = cv2.GaussianBlur(image, (0, 0), blur)

        # if debug:
        #     cv2.imshow('tesser_ocr', image)
        #     cv2.waitKey(0)

        if len(image.shape) == 2:
            height, width = image.shape
            channels = 1
        else:
            height, width, channels = image.shape
        engine.SetImageBytes(image.tobytes(), width, height, channels,
                             width * channels)
        text: str = engine.GetUTF8Text()
        if " " not in whitelist:
            text = text.replace(" ", "")
        if "\n" not in whitelist:
            text = text.replace("\n", "")

        if not any(c in whitelist for c in string.ascii_lowercase):
            text = text.upper()

        if expected_type:
            try:
                return expected_type(text)
            except Exception:
                try:
                    caller = inspect.stack()[1]
                    logger.log(
                        logging.WARNING if warn_on_fail else logging.DEBUG,
                        f"{os.path.basename(caller.filename)}:{caller.lineno} {caller.function} | "
                        f"Got exception interpreting {text!r} as {expected_type.__name__}",
                    )
                except Exception:
                    logger.log(
                        logging.WARNING if warn_on_fail else logging.DEBUG,
                        f"Got exception interpreting {text!r} as {expected_type.__name__}",
                    )
                return None
        else:
            return text
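
A hedged usage sketch for tesser_ocr above ("score.png" is a hypothetical grayscale crop; the module-level lock and default engine come from the surrounding code):

import cv2
import tesserocr

engine = tesserocr.PyTessBaseAPI(psm=tesserocr.PSM.SINGLE_LINE)
image = cv2.imread("score.png", cv2.IMREAD_GRAYSCALE)
score = tesser_ocr(image, expected_type=int, scale=2.0, engine=engine)
print(score)  # None when the recognized text cannot be parsed as an int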
Example #28
from tesserocr import PyTessBaseAPI
from PIL import Image

api = PyTessBaseAPI(lang='script/Thai')
api.SetImage(Image.open("sample/thaiid.jpeg"))
print(api.GetUTF8Text())
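
Note that the 'script/Thai' model must be present in tessdata for the snippet above to initialize. A quick availability check (sketch):

import tesserocr

# returns the tessdata path and the list of installed language models
path, langs = tesserocr.get_languages()
print(path, 'script/Thai' in langs)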

Example #29
class ViewerWindow(Gtk.Window):
    def __init__(self, filenames, kind, show, ml):
        Gtk.Window.__init__(self)
        self.ptx = 0
        self.pty = 0
        self.focus_id = -1
        self.file_idx = 0
        self.kind = kind
        self.show_hidden = show
        self.ml = ml
        self.screen_hint = ''
        self.in_hint_screen = False
        self.colors = {}
        self.memory = {}
        self.elem_models = {}
        self.filenames = filenames
        self.tesapi = PyTessBaseAPI(lang='eng')
        self.tesapi.SetVariable("tessedit_char_whitelist", WHITELIST)
        self.init_ui()
        self.load()

    def init_ui(self):
        self.connect("delete-event", Gtk.main_quit)

        darea = Gtk.DrawingArea()
        darea.connect("draw", self.on_draw)
        darea.connect("motion-notify-event", self.move_over)
        darea.connect("button-release-event", self.click_evt)
        darea.connect("scroll-event", self.scroll_evt)
        darea.connect("key-release-event", self.key_evt)
        darea.set_events(Gdk.EventMask.POINTER_MOTION_MASK
                         | Gdk.EventMask.BUTTON_RELEASE_MASK
                         | Gdk.EventMask.BUTTON_PRESS_MASK
                         | Gdk.EventMask.SCROLL_MASK
                         | Gdk.EventMask.KEY_PRESS_MASK
                         | Gdk.EventMask.KEY_RELEASE_MASK)
        darea.set_can_focus(True)
        self.add(darea)

        self.show_all()

    def load(self, prev=False):
        if self.file_idx == len(self.filenames):
            Gtk.main_quit()
            return
        if prev:
            self.file_idx -= 2
        filename = self.filenames[self.file_idx]
        (self.app, self.scr) = util.get_aux_info(filename)
        if self.app not in self.memory:
            self.memory[self.app] = {}
        self.set_title(filename)
        self.file_idx += 1
        print("Loading %s" % filename)
        self.pngfile = os.path.splitext(filename)[0] + '.png'
        self.descname = os.path.splitext(filename)[0] + '.%s.txt' % self.kind

        starttime = time.time()
        self.tree = analyze.load_tree(filename)
        hidden.find_hidden_ocr(self.tree)
        hidden.mark_children_hidden_ocr(self.tree)
        util.print_tree(self.tree, show_hidden=self.show_hidden)

        if self.ml:
            self.get_ml_rets()
        else:
            self.load_desc()

        endtime = time.time()
        print("Load time: %.3fs" % (endtime - starttime))

        self.focus_id = -1
        self.colors = {}
        self.ptx = self.pty = 0

        self.img = cairo.ImageSurface.create_from_png(self.pngfile)
        print('Image:', self.img.get_width(), self.img.get_height())

        root_item_id = min(self.tree)
        root_node = self.tree[root_item_id]
        print('Root node:', root_node['width'], root_node['height'])
        self.scale = 1.0 * self.img.get_width() / config.width
        #self.scale = analyze.find_closest(self.scale, analyze.SCALE_RATIOS)
        print('Scale: %.3f' % self.scale)

        self.resize(self.img.get_width(), self.img.get_height())

        self.mark_depth(self.tree)

        for item_id in self.tree:
            color_r = random.random() / 2
            color_g = random.random() / 2
            color_b = random.random() / 2

            self.colors[item_id] = (color_r, color_g, color_b)

        imgocr = Image.open(self.pngfile)
        self.imgwidth = imgocr.width
        self.imgheight = imgocr.height
        #imgocr2 = imgocr.convert("RGB").resize(
        #    (imgocr.width * OCR_RATIO, imgocr.height * OCR_RATIO))
        self.tesapi.SetImage(imgocr)
        self.tesapi.SetSourceResolution(config.ocr_resolution)

        self.dump_memory()

    def remember(self, node, desc):
        nodeid = node['id']
        if not nodeid:
            return

        if nodeid in self.memory[self.app]:
            if desc != self.memory[self.app][nodeid]:
                # seen with multiple different descriptions
                self.memory[self.app][nodeid] = 'MUL'
        else:
            self.memory[self.app][nodeid] = desc

    def forget(self, node):
        if node['id'] in self.memory[self.app]:
            del self.memory[self.app][node['id']]

    def get_elem_model(self, app):
        elem_clas = elements.getmodel("../model/", "../guis/", app,
                                      "../guis-extra/",
                                      config.extra_element_scrs)
        self.elem_models[app] = elem_clas

    def get_ml_rets(self):
        if self.app not in self.elem_models:
            self.get_elem_model(self.app)

        guess_descs = {}
        guess_items = {}  # type: Dict[str, List[int]]
        guess_score = {}
        elem_clas = self.elem_models[self.app]
        elem_clas.set_imgfile(self.pngfile)
        treeinfo = analyze.collect_treeinfo(self.tree)
        for itemid in self.tree:
            (guess_element,
             score) = elem_clas.classify(self.scr, self.tree, itemid, None,
                                         treeinfo)
            if guess_element != 'NONE':
                if tags.single(guess_element,
                               self.scr) and guess_element in guess_items:
                    old_item = guess_items[guess_element][0]
                    if guess_score[old_item] < score:
                        guess_items[guess_element] = [itemid]
                        guess_score[itemid] = score
                        del guess_descs[old_item]
                        guess_descs[itemid] = guess_element
                else:
                    guess_descs[itemid] = guess_element
                    guess_score[itemid] = score
                    guess_items[guess_element] = (
                        guess_items.get(guess_element, []) + [itemid])
        for nodeid in guess_descs:
            self.tree[nodeid]['label'] = guess_descs[nodeid]

    def load_desc(self):
        if os.path.exists(self.descname):
            with open(self.descname) as inf:
                for line in inf.read().split('\n'):
                    if not line:
                        continue
                    (item_id, desc) = line.split(' ', 1)
                    item_id = int(item_id)
                    found = False
                    for nodeid in self.tree:
                        node = self.tree[nodeid]
                        if item_id in node['raw']:
                            if 'label' in node:
                                node['label'] += ' ' + desc
                            else:
                                node['label'] = desc
                            print(nodeid, '(', item_id, ')', '->', desc)

                            self.remember(node, desc)

                            found = True
                            break
                    if not found:
                        print("WARNING: %s (%s) is missing!" % (item_id, desc))

    def mark_depth(self, tree):
        for item_id in tree:
            node = tree[item_id]
            if 'depth' in node:
                continue
            self.mark_depth_node(tree, item_id, 0)

    def mark_depth_node(self, tree, node_id, depth):
        node = tree[node_id]
        node['depth'] = depth
        node['descs'] = []
        for child in node['children']:
            descs = self.mark_depth_node(tree, child, depth + 1)
            node['descs'] += descs

        return node['descs'] + [node_id]

    def get_node_info(self, node):
        (x, y, width, height, depth) = (node['x'], node['y'], node['width'],
                                        node['height'], node['depth'])
        x *= self.scale
        y *= self.scale
        width *= self.scale
        height *= self.scale

        width = min(width, self.imgwidth)
        height = min(height, self.imgheight)

        if x < 0:
            width += x
            x = 0

        if y < 0:
            height += y
            y = 0

        return (x, y, width, height, depth)

    def find_containing_widget(self, px, py):
        max_depth = 0
        max_id = -1

        for item_id in self.tree:
            node = self.tree[item_id]
            if self.ignore_node(node):
                continue
            if self.inside(node, px, py):
                if node['depth'] > max_depth:
                    max_depth = node['depth']
                    max_id = item_id

        return max_id

    def inside(self, node, px, py):
        (x, y, width, height, depth) = self.get_node_info(node)
        return x <= px <= x + width and y <= py <= y + height

    def ignore_node(self, node):
        if node['class'].upper() == 'OPTION':
            return True
        if node.get('visible', '') == 'hidden':
            return True
        return False

    def on_draw(self, wid, ctx):
        ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL,
                             cairo.FONT_WEIGHT_BOLD)

        ctx.set_source_surface(self.img, 0, 0)
        ctx.paint()

        ctx.set_font_size(20)
        ctx.set_line_width(5)
        ctx.set_source_rgb(1.0, 0.0, 0.0)

        max_click_id = -1
        max_click_depth = 0

        max_id = self.find_containing_widget(self.ptx, self.pty)

        for item_id in self.tree:
            node = self.tree[item_id]
            depth = node['depth']
            if max_id in node['descs'] and node['click']:
                if depth > max_click_depth:
                    max_click_depth = depth
                    max_click_id = item_id

        for item_id in self.tree:
            node = self.tree[item_id]
            if self.ignore_node(node):
                continue

            if item_id == max_id:
                region_mode = False
            else:
                region_mode = True

            (x, y, width, height, depth) = self.get_node_info(node)

            if not self.inside(node, self.ptx, self.pty):
                continue

            self.show_widget(ctx, item_id, not region_mode, not region_mode)

        if max_click_id != -1 and max_click_id != max_id:
            self.show_widget(ctx, max_click_id, False, True)

        if self.focus_id >= 0:
            self.show_widget(ctx, self.focus_id, True, True, (1, 0, 0))

        for itemid in self.tree:
            node = self.tree[itemid]
            if 'label' in node:
                if itemid == self.focus_id:
                    color = (0, 1, 0)
                else:
                    color = (0, 0, 1)
                self.show_widget(ctx, itemid, True, False, (0, 0, 1))
                self.show_desc(ctx, node, color)

        #s.write_to_png('test.png')
        #os.system("%s %s" % (config.picviewer_path, 'test.png'))
        #report_time(start_time, "displayed")

    def move_sibling(self, to_next):
        leaf_list = []
        any_list = []
        for itemid in self.tree:
            node = self.tree[itemid]
            if not self.inside(node, self.clickx, self.clicky):
                continue

            if len(node['children']) == 0:
                leaf_list.append(itemid)
            any_list.append(itemid)

        for i in range(len(leaf_list)):
            if leaf_list[i] == self.focus_id:
                if to_next:
                    idx = (i + 1) % len(leaf_list)
                else:
                    idx = (i - 1) % len(leaf_list)
                self.focus_id = leaf_list[idx]
                return

        if len(leaf_list) == 0:
            for i in range(len(any_list)):
                if any_list[i] == self.focus_id:
                    if to_next:
                        idx = (i + 1) % len(any_list)
                    else:
                        idx = (i - 1) % len(any_list)
                    self.focus_id = any_list[idx]
                    return
            self.focus_id = any_list[0]
        else:
            self.focus_id = leaf_list[0]

    def show_widget(self, ctx, item_id, fill, show_text, colors=None):
        node = self.tree[item_id]

        (x, y, width, height, depth) = self.get_node_info(node)

        if colors is None:
            color_r = self.colors[item_id][0]
            color_g = self.colors[item_id][1]
            color_b = self.colors[item_id][2]
        else:
            (color_r, color_g, color_b) = colors

        ctx.rectangle(x, y, width, height)
        if fill:
            ctx.set_source_rgba(color_r, color_g, color_b, 0.3)
            ctx.fill()
        else:
            ctx.set_source_rgba(color_r, color_g, color_b, 1)
            ctx.set_line_width(5)
            ctx.stroke()

        if show_text:
            max_char = int(width / ctx.text_extents("a")[2])
            text = str(item_id)
            if node['click']:
                text = 'C' + text
            if node['text']:
                text = text + ':' + node['text'][:(max_char - 5)]
            elif node['id']:
                text += '#' + node['id'][:(max_char - 5)]

            self.show_text(ctx, x + width / 2, y + height / 2, text, color_r,
                           color_g, color_b)

    def show_desc(self, ctx, node, color=(0, 0, 1)):
        desc = node['label']
        (x, y, width, height, depth) = self.get_node_info(node)
        self.show_text(ctx, x + width / 2, y + height / 2, desc, color[0],
                       color[1], color[2])

    def show_text(self, ctx, x, y, text, color_r, color_g, color_b):
        x_bearing, y_bearing, text_width, text_height = ctx.text_extents(
            text)[:4]

        ctx.move_to(x - text_width / 2, y + text_height / 2)
        ctx.set_source_rgba(1, 1, 1, 1)
        ctx.set_line_width(5)
        ctx.text_path(text)
        ctx.stroke()

        ctx.move_to(x - text_width / 2, y + text_height / 2)
        ctx.set_source_rgba(color_r, color_g, color_b, 1)
        ctx.text_path(text)
        ctx.fill()

    def move_over(self, widget, evt):
        self.ptx = evt.x
        self.pty = evt.y
        self.queue_draw()

    def click_evt(self, widget, evt):
        if self.in_hint_screen:
            self.process_screen_hint_click(evt)
            return

        if evt.button == 3:
            self.focus_id = -1
        else:
            self.clickx = evt.x
            self.clicky = evt.y
            self.focus_id = self.find_containing_widget(evt.x, evt.y)

        self.queue_draw()

    def scroll_evt(self, widget, evt):
        if self.focus_id == -1:
            return

        scroll_up = evt.direction == Gdk.ScrollDirection.UP
        if scroll_up:
            self.focus_id = self.find_parent_widget(self.focus_id)
        else:
            self.focus_id = self.find_child_widget(self.focus_id)

        self.queue_draw()

    def find_parent_widget(self, wid):
        for itemid in self.tree:
            node = self.tree[itemid]
            if self.ignore_node(node):
                continue
            if wid in node['children']:
                return itemid
        return wid

    def find_child_widget(self, wid):
        for itemid in self.tree[wid]['children']:
            node = self.tree[itemid]
            if self.ignore_node(node):
                continue
            if self.inside(node, self.clickx, self.clicky):
                return itemid
        return wid

    def mark_direct(self):
        enter = self.get_text('Please enter id_label', 'format: <id> <label>')
        if enter is None:
            return
        if ' ' in enter:
            nodeid, label = enter.split(' ')
        else:
            nodeid = enter
            label = ''
        nodeid = int(nodeid)
        if nodeid not in self.tree:
            print('missing node', nodeid)
            return
        node = self.tree[nodeid]

        self.mark_node(node, label)

    def mark_focused(self):
        if self.focus_id < 0:
            return
        node = self.tree[self.focus_id]
        label = self.get_text(
            'Please enter label', 'label for %s: %s (%s) #%s' %
            (self.focus_id, node['text'], node['desc'], node['id']))
        if label is None:
            return

        if self.ml:
            if label == '':
                if 'label' not in self.tree[self.focus_id]:
                    return

                self.generate_negative_hint(self.tree[self.focus_id]['label'])
                del self.tree[self.focus_id]['label']
            else:
                self.generate_hint_for_widget(self.focus_id, label)
                self.add_label(node, label)
        else:
            self.mark_node(node, label)

    def generate_hint_for_widget(self, nodeid, label):
        return self.generate_hint(label,
                                  locator.get_locator(self.tree, nodeid))

    def generate_negative_hint(self, label):
        return self.generate_hint(label, 'notexist')

    def generate_hint(self, label, hint):
        print("@%s.%s %s" % (self.scr, label, hint))

    def mark_node(self, node, label):
        if label == '':
            if 'label' in node:
                del node['label']
                self.forget(node)
        else:
            self.add_label(node, label)
            self.remember(node, label)

        self.save_labels()

    def ocr_text(self):
        node = self.tree[self.focus_id]
        (x, y, width, height, _) = self.get_node_info(node)
        print(x, y, width, height)
        x = max(x - 1, 0)
        y = max(y - 1, 0)
        width = min(width + 2, self.imgwidth)
        height = min(height + 2, self.imgheight)
        #self.tesapi.SetRectangle(x * OCR_RATIO, y * OCR_RATIO,
        #                         width * OCR_RATIO, height * OCR_RATIO)
        # tesserocr expects integer pixel coordinates
        self.tesapi.SetRectangle(int(x), int(y), int(width), int(height))
        print("OCR ret:", self.tesapi.GetUTF8Text())

        x = min(x + width * 0.05, self.imgwidth)
        y = min(y + height * 0.05, self.imgheight)
        width *= 0.9
        height *= 0.9
        self.tesapi.SetRectangle(int(x), int(y), int(width), int(height))
        print("OCR ret:", self.tesapi.GetUTF8Text())

    def save_region(self):
        if self.focus_id == -1:
            return
        node = self.tree[self.focus_id]
        (x, y, width, height, _) = self.get_node_info(node)
        x = max(x - 1, 0)
        y = max(y - 1, 0)
        width = min(width + 2, self.imgwidth)
        height = min(height + 2, self.imgheight)

        regimg = cairo.ImageSurface(cairo.FORMAT_RGB24, int(width),
                                    int(height))
        ctx = cairo.Context(regimg)
        ctx.set_source_surface(self.img, -x, -y)
        ctx.paint()

        regimg.write_to_png("/tmp/region.png")

    def dump_memory(self):
        for _id in self.memory[self.app]:
            print('MEM %s -> %s' % (_id, self.memory[self.app][_id]))

    def add_label(self, node, desc):
        print('%s -> %s' % (util.describe_node(node, short=True), desc))
        node['label'] = desc

    def auto_label(self):
        for nodeid in self.tree:
            node = self.tree[nodeid]
            if 'label' not in node and node['id'] in self.memory[self.app]:
                if self.memory[self.app][node['id']] != 'MUL':
                    self.add_label(node, self.memory[self.app][node['id']])
                else:
                    print('skip MUL id: %s' % node['id'])
        self.save_labels()

    def remove_all(self):
        for nodeid in self.tree:
            node = self.tree[nodeid]
            if 'label' in node:
                del node['label']

    def process_screen_hint_click(self, evt):
        click_id = self.find_containing_widget(evt.x, evt.y)
        if click_id == -1:
            print('Invalid widget selected')
            return

        hint = locator.get_locator(self.tree, click_id)
        if hint is None:
            print('Cannot generate hint for this widget')
            return

        hint = str(hint)
        if evt.button == 3:
            # negate
            hint = 'not ' + hint

        print('Widget hint: "%s"' % hint)
        self.add_screen_hint(hint)

    def add_screen_hint(self, hint):
        if self.screen_hint == '':
            self.screen_hint = hint
        else:
            self.screen_hint += ' && ' + hint

    def hint_screen(self):
        if not self.in_hint_screen:
            label = self.get_text('Please enter screen name',
                                  'screen name like "signin"')
            if label is None:
                return
            self.screen_hint_label = label

            self.in_hint_screen = True
            self.screen_hint = ''
        else:
            self.in_hint_screen = False
            print("%%%s %s" % (self.screen_hint_label, self.screen_hint))

    def key_evt(self, widget, evt):
        if evt.keyval == Gdk.KEY_space:
            self.mark_focused()
        elif evt.keyval == Gdk.KEY_Tab:
            self.load()
        elif evt.keyval == Gdk.KEY_Left:
            self.move_sibling(to_next=True)
        elif evt.keyval == Gdk.KEY_Right:
            self.move_sibling(to_next=False)
        elif evt.keyval == Gdk.KEY_v:
            self.ocr_text()
        elif evt.keyval == Gdk.KEY_a:
            self.auto_label()
        elif evt.keyval == Gdk.KEY_p:
            self.load(prev=True)
        elif evt.keyval == Gdk.KEY_l:
            self.mark_direct()
        elif evt.keyval == Gdk.KEY_r:
            self.remove_all()
        elif evt.keyval == Gdk.KEY_s:
            self.save_region()
        elif evt.keyval == Gdk.KEY_x:
            self.hint_screen()
        self.queue_draw()

    def save_labels(self):
        with open(self.descname, 'w') as outf:
            for itemid in sorted(self.tree):
                node = self.tree[itemid]
                if 'label' in node:
                    outf.write("%s %s\n" % (itemid, node['label']))

    def get_text(self, title, prompt):
        #base this on a message dialog
        dialog = Gtk.MessageDialog(self, 0, Gtk.MessageType.QUESTION,
                                   Gtk.ButtonsType.OK_CANCEL, title)
        dialog.format_secondary_text(prompt)
        #create the text input field
        entry = Gtk.Entry()
        #allow the user to press enter to do ok
        entry.connect("activate",
                      lambda entry: dialog.response(Gtk.ResponseType.OK))
        #create a horizontal box to pack the entry and a label
        hbox = Gtk.HBox()
        hbox.pack_start(Gtk.Label("Label:"), False, 5, 5)
        hbox.pack_end(entry, True, 0, 0)
        #add it and show it
        dialog.vbox.pack_end(hbox, True, True, 0)
        dialog.show_all()
        #go go go
        response = dialog.run()
        if response == Gtk.ResponseType.OK:
            text = entry.get_text()
        else:
            text = None
        dialog.destroy()
        return text
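
A minimal sketch of the SetRectangle-based region OCR that ocr_text() performs above ("screenshot.png" is a hypothetical capture):

from PIL import Image
from tesserocr import PyTessBaseAPI

with PyTessBaseAPI(lang='eng') as api:
    api.SetImage(Image.open("screenshot.png"))
    # left, top, width, height in pixels, relative to the full image
    api.SetRectangle(10, 20, 200, 40)
    print(api.GetUTF8Text())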
Example #30
    action_number = int(key) - 1

    # get the tick number
    (death_left, death_top, death_width, death_height) = death_location
    tick_shot_pil = pyautogui.screenshot(region=(death_left + death_width,
                                                 death_top, tick_width,
                                                 tick_height))
    tick_shot_cv = np.array(tick_shot_pil)

    tick_shot_cv_gray = cv2.cvtColor(tick_shot_cv, cv2.COLOR_RGB2GRAY)
    (_, tick_shot_cv_black) = cv2.threshold(tick_shot_cv_gray, 190, 255,
                                            cv2.THRESH_BINARY_INV)

    pil_cap_black_text_demoUI = Image.fromarray(tick_shot_cv_black)
    tessocr_api.SetImage(pil_cap_black_text_demoUI)
    text = tessocr_api.GetUTF8Text()

    cur_tick_string_match = re.search(r'k[^\d]*(\d+)[^\d]*/[^\d]*(\d+)', text)
    if cur_tick_string_match:
        cur_tick = int(cur_tick_string_match.group(1))
    else:
        print("skipping as didn't find tick")
        continue

    addActionDict(action_number, cur_tick)
    processing_end_time = time.time()
    time_budget = 0.4 - (processing_end_time - processing_start_time)
    if time_budget > 0:
        print(f'under time budget, sleeping for {time_budget:.3f}s')
        time.sleep(time_budget)
    else: