コード例 #1
0
 def __init__(self, config, config_path=''):
     """Create the line cropper from the INTERP / LINE_SCALE / LINE_HEIGHT options.

     ``config`` must expose ``getint`` and ``getfloat`` accessors (for
     example a configparser section -- assumption, confirm against callers).
     ``config_path`` is accepted for interface compatibility but unused here.
     """
     # Keyword arguments are evaluated left to right, so the options are
     # read in the same order as before: INTERP, LINE_SCALE, LINE_HEIGHT.
     self.crop_engine = cropper.EngineLineCropper(
         poly=config.getint('INTERP'),
         scale=config.getfloat('LINE_SCALE'),
         line_height=config.getint('LINE_HEIGHT'))
コード例 #2
0
ファイル: repair_engine.py プロジェクト: DCGM/pero-enhance
    def __init__(self, json_path, use_cpu=False):
        """Load the repair engine's configuration, character tables and models.

        Args:
            json_path: Path to the engine's JSON config. The paths stored in it
                (``chars_path``, ``repair_model``, ``inpainting_model``) are
                resolved relative to the directory that contains this file.
            use_cpu: If True, TensorFlow sessions are created with no GPU
                devices; otherwise one GPU is allowed, with memory growth on.

        ``self.repair_session`` / ``self.inpainting_session`` are set to
        ``None`` when the corresponding model entry in the config is empty.
        """
        parent_folder = os.path.dirname(json_path)
        with open(json_path, 'r', encoding='utf8') as f:
            config = json.load(f)

        # NOTE(review): pickle.load can execute arbitrary code; only load
        # character tables from trusted model directories.
        with open(os.path.join(parent_folder, config['chars_path']),
                  'rb') as handle:
            rb = pickle.load(handle)
        self.from_char = rb['from_char']
        self.to_char = rb['to_char']
        self.chars = rb['chars']
        self.max_labels = config['max_labels']
        self.height = config['height']
        self.max_width = config['max_width']

        repair_model = config['repair_model']
        inpainting_model = config['inpainting_model']

        self.cropper = crop_engine.EngineLineCropper(
            line_height=self.height, poly=2, scale=1)

        self.repair_session = self._restore_session(
            parent_folder, repair_model, use_cpu)
        self.inpainting_session = self._restore_session(
            parent_folder, inpainting_model, use_cpu)

    @staticmethod
    def _restore_session(parent_folder, model_name, use_cpu):
        """Restore checkpoint ``model_name`` into a new session, or return None if unset."""
        if not model_name:
            return None
        # Start from an empty default graph so this model's imported graph is
        # not mixed with a previously loaded one; the session created below
        # binds to this graph.
        tf.reset_default_graph()
        checkpoint = os.path.join(parent_folder, model_name)
        saver = tf.train.import_meta_graph(checkpoint + '.meta')
        if use_cpu:
            tf_config = tf.ConfigProto(device_count={'GPU': 0})
        else:
            tf_config = tf.ConfigProto(device_count={'GPU': 1})
            tf_config.gpu_options.allow_growth = True
        session = tf.Session(config=tf_config)
        saver.restore(session, checkpoint)
        return session
コード例 #3
0
ファイル: repair_page.py プロジェクト: kackamac/pero-ocr
def _crop_line(crop_engine, img, line):
    """Crop ``line`` from ``img`` with its coordinate mapping: (crop, mapping, offset)."""
    return crop_engine.crop(img, line.baseline, line.heights,
                            return_mapping=True)


def _selection_columns(line_mapping, offset, points):
    """Convert the two clicked page points into column bounds inside the line crop."""
    mid_row = line_mapping.shape[0] // 2
    # Clip both indices into the mapping. Without the clip, a click left of
    # the line start gives a negative index that silently wraps around to
    # the end of the coordinate map.
    start = np.clip(points[0][1] - offset[1], 0, line_mapping.shape[1] - 1)
    end = np.clip(points[1][1] - offset[1], 0, line_mapping.shape[1] - 2)
    y1 = np.round(line_mapping[mid_row, start, 1]).astype(np.uint16)
    y2 = np.round(line_mapping[mid_row, end, 1]).astype(np.uint16)
    if points[1][1] - offset[1] > line_mapping.shape[1] - 10:
        # Dirty fix: the coordinate map is noisy near its end, so snap to the
        # largest mapped column instead.
        y2 = np.amax(line_mapping[:, :, 1].astype(np.uint16))
    return y1, y2


def main():
    """Interactive page editor for repairing and inpainting text lines.

    Keys in the "Page Editor" window:
      q -- quit.
      r -- edit the chosen line's transcription, then either repair the line
           image from it or revert the line to the original page image.
      e -- with two points clicked on the chosen line, OCR the text on either
           side, blank the span between the points and inpaint new text.
    """
    args = parseargs()

    page_img = cv2.imread(args.input_img)
    page_layout = layout.PageLayout(file=args.input_page)
    # Untouched copy used by the 'revert' action.
    page_img_orig = page_img.copy()

    print('\nLoading engines...')
    ocr_engine = ocr.EngineLineOCR(args.ocr_json, gpu_id=0)
    repair_engine = repair.EngineRepairCNN(args.repair_json)
    crop_engine = cropper.EngineLineCropper(line_height=repair_engine.height,
                                            poly=2,
                                            scale=1)

    cv2.namedWindow("Page Editor", cv2.WINDOW_NORMAL)
    cv2.resizeWindow('Page Editor', 1024, 1024)
    layout_clicker = LayoutClicker(page_layout)
    cv2.setMouseCallback("Page Editor", layout_clicker.callback)

    while True:
        page_img_rendered = page_img.copy()
        if layout_clicker.chosen_line:
            page_img_rendered = layout.draw_lines(
                page_img_rendered, [layout_clicker.chosen_line.polygon],
                color=(0, 255, 0),
                close=True)
        if layout_clicker.points:
            page_img_rendered = layout.draw_lines(page_img_rendered,
                                                  [layout_clicker.points],
                                                  color=(0, 0, 255))

        cv2.imshow('Page Editor', page_img_rendered)
        key = cv2.waitKey(1)
        # Read the selection after waitKey: mouse callbacks fire while
        # OpenCV processes window events.
        chosen_line = layout_clicker.chosen_line

        if key == ord('q'):
            break

        # Both editing actions need a chosen line; without one the original
        # code crashed on None.transcription / None.baseline.
        elif key == ord('r') and chosen_line:
            text_input = TextInputRepair(chosen_line.transcription)
            action, new_transcription = text_input.run()
            chosen_line.transcription = new_transcription

            if action == 'repair':
                line_crop, line_mapping, offset = _crop_line(
                    crop_engine, page_img, chosen_line)
                line_crop = repair_engine.repair_line(
                    line_crop, chosen_line.transcription)
                page_img = crop_engine.blend_in(page_img, line_crop,
                                                line_mapping, offset)

            elif action == 'revert':
                line_crop, line_mapping, offset = _crop_line(
                    crop_engine, page_img_orig, chosen_line)
                page_img = crop_engine.blend_in(page_img, line_crop,
                                                line_mapping, offset)

        elif (key == ord('e') and chosen_line
              and len(layout_clicker.points) == 2):
            line_crop, line_mapping, offset = _crop_line(
                crop_engine, page_img, chosen_line)
            y1, y2 = _selection_columns(line_mapping, offset,
                                        layout_clicker.points)
            print('{}/{}'.format(y2, line_crop.shape[1]))
            left, right = np.minimum(y1, y2), np.maximum(y1, y2)
            transcriptions, _, _ = ocr_engine.process_lines([
                line_crop[:, :left, :],
                line_crop[:, right:, :]
            ])
            # Blank the selected span; this is the area to be inpainted.
            line_crop[:, left:right, :] = 0
            text_input = TextInputInpaint(transcriptions[0], transcriptions[1])
            action, new_transcription = text_input.run()
            if action == 'inpaint':
                chosen_line.transcription = new_transcription
                line_crop = repair_engine.inpaint_line(
                    line_crop, chosen_line.transcription)
                page_img = crop_engine.blend_in(page_img, line_crop,
                                                line_mapping, offset)
                layout_clicker.points = []

    cv2.destroyAllWindows()
コード例 #4
0
    def process_page(self, img, page_layout: PageLayout):
        """Run the configured layout-analysis steps on ``page_layout``.

        Depending on the instance flags this:
        - (re)detects regions and/or lines for each configured orientation,
        - repeatedly merges lines inside each region,
        - re-detects straight lines inside existing region polygons,
        - re-estimates line heights and refines baselines from the parser maps.

        Args:
            img: Page image handed to ``self.engine``.
            page_layout: Layout whose regions/lines are updated in place.

        Returns:
            The same ``page_layout`` object.
        """
        if self.detect_regions or self.detect_lines:
            self._redetect_layout(img, page_layout)

        if self.merge_lines:
            for region in page_layout.regions:
                self._merge_region_lines(region)

        # The parser maps are only needed by the steps below.
        if (self.detect_straight_lines_in_regions or self.adjust_heights
                or self.adjust_baselines):
            maps, ds = self.engine.parsenet.get_maps_with_optimal_resolution(
                img)

        if self.detect_straight_lines_in_regions:
            for region in page_layout.regions:
                pb_list, ph_list, pt_list = detect_lines_in_region(
                    region.polygon, maps, ds)
                region.lines = []
                # The returned list is not used; the helper presumably fills
                # ``region.lines`` in place (confirm in helpers).
                helpers.assign_lines_to_regions(
                    pb_list, ph_list, pt_list, [region])

        if self.adjust_heights:
            for line in page_layout.lines_iterator():
                sample_points = helpers.resample_baselines([line.baseline],
                                                           num_points=40)[0]
                line.heights = self.engine.get_heights(maps, ds, sample_points)
                line.polygon = helpers.baseline_to_textline(
                    line.baseline, line.heights)

        if self.adjust_baselines:
            crop_engine = cropper.EngineLineCropper(line_height=32,
                                                    poly=0,
                                                    scale=1)
            for line in page_layout.lines_iterator():
                line.baseline = refine_baseline(line.baseline, line.heights,
                                                maps, ds, crop_engine)
                line.polygon = helpers.baseline_to_textline(
                    line.baseline, line.heights)
        return page_layout

    def _redetect_layout(self, img, page_layout):
        """Replace regions and/or lines of ``page_layout`` with fresh detections."""
        if self.detect_regions:
            page_layout.regions = []
        if self.detect_lines:
            for region in page_layout.regions:
                region.lines = []

        orientations = [0, 1, 3] if self.multi_orientation else [0]

        for rot in orientations:
            p_list, b_list, h_list, t_list = self.engine.detect(img, rot=rot)
            if self.detect_regions:
                regions = []
                for region_idx, polygon in enumerate(p_list):
                    # Rotated detections get a suffix so IDs stay unique
                    # across orientations.
                    if rot > 0:
                        region_id = 'r{:03d}_{}'.format(region_idx, rot)
                    else:
                        region_id = 'r{:03d}'.format(region_idx)
                    regions.append(RegionLayout(region_id, polygon))
            else:
                # Lines are assigned into the regions already on the page.
                regions = page_layout.regions
            if self.detect_lines:
                regions = helpers.assign_lines_to_regions(
                    b_list, h_list, t_list, regions)
            if self.detect_regions:
                page_layout.regions += regions

    @staticmethod
    def _merge_region_lines(region):
        """Merge the lines of ``region`` until a pass no longer changes their count."""
        while True:
            original_line_count = len(region.lines)
            r_b_list, r_h_list = helpers.merge_lines(
                [line.baseline for line in region.lines],
                [line.heights for line in region.lines])
            r_t_list = [
                helpers.baseline_to_textline(b, h)
                for b, h in zip(r_b_list, r_h_list)
            ]
            region.lines = []
            region = helpers.assign_lines_to_regions(
                r_b_list, r_h_list, r_t_list, [region])[0]
            if len(region.lines) == original_line_count:
                break