示例#1
0
def get_line_confidence(line, labels, aligned_letters=None, log_probs=None):
    """Estimate a confidence in [0, 1] for every transcribed character.

    For each label, its probability at the aligned frame is compared with the
    strongest competing non-blank class inside the label's frame window. The
    window runs from the midpoint after the previous letter to the midpoint
    before the next one, and the last letter's window extends to the end of
    the line. The label itself and its immediate neighbours in ``labels`` are
    not counted as competitors.

    Args:
        line: Text line object. It is only used (via ``get_full_logprobs``)
            when ``log_probs`` is not given.
        labels: Sequence of class indices, one per transcribed character.
        aligned_letters: Frame index of each label. It is computed with
            ``align_text`` when omitted.
        log_probs: 2-D array (frames x classes) of log-probabilities whose
            last column is treated as the blank class. It is read from
            ``line`` when omitted.

    Returns:
        numpy array with one value per label: the probability margin, clipped
        below at 0.
    """
    if log_probs is None:
        log_probs = line.get_full_logprobs()

    if aligned_letters is None:
        aligned_letters = align_text(-log_probs, labels,
                                     log_probs.shape[1] - 1)

    probs = np.exp(log_probs)
    n_frames = probs.shape[0]
    n_labels = len(labels)
    confidences = np.zeros(n_labels)
    last_border = 0
    for i, label in enumerate(labels):
        position = aligned_letters[i]
        label_prob = probs[position, label]
        if i + 1 < n_labels:
            # Frames between two neighbouring letters are split in half.
            next_border = (position + 1 + aligned_letters[i + 1]) // 2
        else:
            # The last letter owns every remaining frame. A fixed sentinel
            # position would give an empty window on long lines.
            next_border = n_frames

        competitors = np.copy(probs[last_border:next_border])
        competitors[:, label] = 0
        # Neighbouring letters often bleed into this window; ignore them.
        if i > 0:
            competitors[:, labels[i - 1]] = 0
        if i + 1 < n_labels:
            competitors[:, labels[i + 1]] = 0
        # The last column (blank) never competes.
        other_prob = competitors[:, :-1].max()
        confidences[i] = max(0, label_prob - other_prob)
        last_border = next_border

    return confidences
示例#2
0
def process_batch(f, batch_img, batch_txt, ocr_engine, decoder, variations,
                  limit):
    """Write soft alignment targets for a batch of line images to ``f``.

    Each image in ``batch_img`` is run through ``ocr_engine``. For the text
    ``batch_txt[i]``, up to ``limit`` hypotheses from
    ``variations[batch_txt[i]]`` are aligned against the line's
    log-probabilities with ``align_text``. Each hypothesis adds its weight to
    the (frame, class) cells it aligns to, and hypotheses that fail to align
    are skipped. One output line is written per image: the text, the logits
    shape, and the target matrix rounded to 2 decimals in ``coo_matrix`` text
    form, collapsed onto a single line.

    Args:
        f: Writable text stream.
        batch_img: Line images passed to ``ocr_engine.process_lines``.
        batch_txt: Transcriptions matching ``batch_img``; also keys into
            ``variations``.
        ocr_engine: Engine providing ``characters`` and ``process_lines``.
        decoder: Unused.
        variations: Maps a transcription to a sequence of hypotheses; each
            hypothesis is presumably ``(weight, text)``, since ``hyp[0]`` is
            added to the targets and ``hyp[1]`` is iterated per character.
        limit: Maximum number of hypotheses used per line.
    """
    # If a character appears twice, the later index wins (same as dict(zip())).
    char_to_num = {char: idx for idx, char in enumerate(ocr_engine.characters)}

    _, processed_logits, _ = ocr_engine.process_lines(batch_img)
    for i, raw_logits in enumerate(processed_logits):
        logits = prepare_dense_logits(raw_logits)

        hyps = variations[batch_txt[i]]
        ground_truth = np.zeros(logits.shape)
        blank_idx = logits.shape[1] - 1
        # Does not depend on the hypothesis, so compute it once per line.
        log_probs = log_softmax(logits)
        for hyp in hyps[:limit]:
            # Unknown characters, and indices at or after the blank column,
            # fall back to class 0.
            label = []
            for char in hyp[1]:
                idx = char_to_num.get(char, 0)
                label.append(idx if idx < blank_idx else 0)
            try:
                positions = align_text(-log_probs, np.array(label), blank_idx)
            except (ValueError, IndexError):
                # Best effort: hypotheses that cannot be aligned are ignored.
                continue
            ground_truth[positions, label] += hyp[0]

        sparse_text = str(coo_matrix(np.around(ground_truth, decimals=2)))
        f.write("{} {} {}\n".format(batch_txt[i], logits.shape,
                                    ' '.join(sparse_text.split())))
def get_line_confidence(line, labels, aligned_letters=None, log_probs=None):
    """Estimate a confidence in [0, 1] for every transcribed character.

    For each label, the margin between its probability at the aligned frame
    and the strongest other non-blank class in the label's frame window is
    mapped from [-1, 1] to [0, 1]. The window runs from the midpoint after the
    previous letter to the midpoint before the next one, and the last
    letter's window extends to the end of the line. Only the label itself is
    excluded from the competitors.

    Args:
        line: Text line object. It is only used (via ``get_full_logprobs``)
            when ``log_probs`` is not given.
        labels: Sequence of class indices, one per transcribed character.
        aligned_letters: Optional precomputed frame index of each label. It is
            computed with ``align_text`` when omitted.
        log_probs: Optional 2-D array (frames x classes) of log-probabilities
            whose last column is treated as the blank class.

    Returns:
        numpy array with one value per label, ``(margin / 2) + 0.5``.
    """
    if log_probs is None:
        log_probs = line.get_full_logprobs()

    if aligned_letters is None:
        aligned_letters = align_text(-log_probs, labels,
                                     log_probs.shape[1] - 1)

    probs = np.exp(log_probs)
    n_frames = probs.shape[0]
    n_labels = len(labels)
    margins = np.zeros(n_labels)
    last_border = 0
    for i, label in enumerate(labels):
        position = aligned_letters[i]
        if i + 1 < n_labels:
            # Frames between two neighbouring letters are split in half.
            next_border = (position + 1 + aligned_letters[i + 1]) // 2
        else:
            # The last letter owns every remaining frame. A fixed sentinel
            # position would give an empty window on long lines.
            next_border = n_frames

        competitors = np.copy(probs[last_border:next_border])
        competitors[:, label] = 0
        # The last column (blank) never competes.
        other_prob = competitors[:, :-1].max()
        margins[i] = probs[position, label] - other_prob
        last_border = next_border

    # Both probabilities lie in [0, 1], so the margin lies in [-1, 1].
    return margins / 2 + 0.5
示例#4
0
文件: layout.py 项目: DCGM/pero-ocr
    def get_quality(self, x=None, y=None, width=None, height=None, power=6):
        """Return a power mean of per-letter confidences on the page.

        Every line with a transcription is aligned to its log-probabilities.
        Lines that fail to align are skipped. For each non-space character,
        the minimum x/y of its crop columns is taken as its position. When all
        of ``x``, ``y``, ``width`` and ``height`` are given, only characters
        whose position lies inside that box are counted.

        Side effect: each aligned line's ``transcription_confidence`` is set
        to the median of its letter confidences.

        Args:
            x, y, width, height: Optional box in the coordinate system of
                ``EngineLineCropper.get_crop_inputs`` (presumably page
                pixels). The box is ignored unless all four are given.
            power: Exponent of the power mean; larger values give more weight
                to high confidences.

        Returns:
            ``(sum(c ** power) / len(c)) ** (1 / power)`` over the selected
            confidences ``c``, or ``-1`` when no character was selected.
        """
        # 0 is a valid box coordinate, so test for None instead of truthiness.
        use_bbox = all(v is not None for v in (x, y, width, height))
        bbox_confidences = []
        for block in self.regions:
            for line in block.lines:
                if not line.transcription:
                    continue

                # Unknown characters, and indices at or after the blank
                # column, fall back to class 0.
                char_to_num = {
                    char: idx
                    for idx, char in enumerate(line.characters)
                }
                blank_idx = line.logits.shape[1] - 1
                label = []
                for char in line.transcription:
                    idx = char_to_num.get(char, 0)
                    label.append(idx if idx < blank_idx else 0)

                start, end = line.logit_coords[0], line.logit_coords[1]
                logits = line.get_dense_logits()[start:end]
                logprobs = line.get_full_logprobs()[start:end]
                try:
                    aligned_letters = align_text(-logprobs, np.array(label),
                                                 blank_idx)
                except (ValueError, IndexError):
                    # Best effort: lines that cannot be aligned are skipped.
                    continue

                crop_engine = EngineLineCropper(poly=2)
                line_coords = crop_engine.get_crop_inputs(
                    line.baseline, line.heights, 16)
                # Number of crop columns per logit frame.
                lm_const = line_coords.shape[1] / logits.shape[0]
                confidences = get_line_confidence(line, np.array(label),
                                                  aligned_letters, logprobs)
                line.transcription_confidence = np.quantile(
                    confidences, .50)

                if line_coords.size == 0:
                    # No crop geometry: the window search below could never
                    # find coordinates, so this line contributes no letters.
                    continue

                space_positions = {
                    pos
                    for pos, char in enumerate(line.transcription)
                    if char == ' '
                }
                for i, frame in enumerate(aligned_letters):
                    if i in space_positions:
                        continue
                    # Widen the window around the letter's frame until it
                    # covers at least one crop column.
                    extension = 2
                    while True:
                        lo = max(0, int((frame - extension) * lm_const))
                        hi = int((frame + extension) * lm_const)
                        all_x = line_coords[:, lo:hi, 0]
                        all_y = line_coords[:, lo:hi, 1]
                        if all_x.size != 0 and all_y.size != 0:
                            break
                        extension += 1

                    vpos = int(np.min(all_y))
                    hpos = int(np.min(all_x))
                    if not use_bbox or (y <= vpos <= y + height
                                        and x <= hpos <= x + width):
                        bbox_confidences.append(confidences[i])

        if not bbox_confidences:
            return -1
        return (1 / len(bbox_confidences) *
                (np.power(bbox_confidences, power).sum()))**(1 / power)
示例#5
0
文件: layout.py 项目: DCGM/pero-ocr
    def to_altoxml_string(self,
                          ocr_processing=None,
                          page_uuid=None,
                          min_line_confidence=0):
        """Serialize the page layout and transcriptions as an ALTO v2 string.

        Every region becomes a ``TextBlock`` and every transcribed line a
        ``TextLine``. Each line is aligned to its log-probabilities to place
        one ``String`` per word, with a ``WC`` word confidence when one is
        available. If alignment fails, the words are spread evenly across
        the line width and the line confidence is set to 0. ``PrintSpace`` is
        the bounding box of all regions, and the four margins cover the rest
        of the page.

        Args:
            ocr_processing: Optional pre-built element appended to
                ``Description``. ``create_ocr_processing_element()`` is used
                when omitted.
            page_uuid: Optional page id. It defaults to ``self.id`` with
                punctuation and spaces replaced by ``_``.
            min_line_confidence: Lines whose ``transcription_confidence`` is
                below this value are removed from the output.

        Returns:
            The pretty-printed ALTO document, decoded from UTF-8 to ``str``.

        Side effects:
            Sets ``line.transcription_confidence`` to 0 when alignment fails,
            or to the median letter confidence when it was ``None``.
        """
        arabic_helper = ArabicHelper()
        NSMAP = {
            "xlink": 'http://www.w3.org/1999/xlink',
            "xsi": 'http://www.w3.org/2001/XMLSchema-instance'
        }
        root = ET.Element("alto", nsmap=NSMAP)
        root.set("xmlns", "http://www.loc.gov/standards/alto/ns-v2#")

        description = ET.SubElement(root, "Description")
        measurement_unit = ET.SubElement(description, "MeasurementUnit")
        measurement_unit.text = "pixel"
        source_image_information = ET.SubElement(description,
                                                 "sourceImageInformation")
        file_name = ET.SubElement(source_image_information, "fileName")
        file_name.text = self.id
        if ocr_processing is None:
            ocr_processing = create_ocr_processing_element()
        description.append(ocr_processing)
        layout = ET.SubElement(root, "Layout")
        page = ET.SubElement(layout, "Page")
        if page_uuid is not None:
            page.set("ID", "id_" + page_uuid)
        else:
            page.set(
                "ID", "id_" +
                re.sub('[!\"#$%&\'()*+,/:;<=>?@[\\]^`{|}~ ]', '_', self.id))
        page.set("PHYSICAL_IMG_NR", str(1))
        page.set("HEIGHT", str(self.page_size[0]))
        page.set("WIDTH", str(self.page_size[1]))

        top_margin = ET.SubElement(page, "TopMargin")
        left_margin = ET.SubElement(page, "LeftMargin")
        right_margin = ET.SubElement(page, "RightMargin")
        bottom_margin = ET.SubElement(page, "BottomMargin")
        print_space = ET.SubElement(page, "PrintSpace")

        # Bounding box of all text blocks. Top/left start at the far page
        # edge so min() picks the first block; bottom/right start at 0 so
        # max() picks the first block.
        print_space_vpos = self.page_size[0]
        print_space_hpos = self.page_size[1]
        print_space_bottom = 0
        print_space_right = 0

        for block in self.regions:
            text_block = ET.SubElement(print_space, "TextBlock")
            text_block.set("ID", 'block_{}'.format(block.id))

            text_block_height, text_block_width, text_block_vpos, text_block_hpos = get_hwvh(
                block.polygon)
            text_block.set("HEIGHT", str(int(text_block_height)))
            text_block.set("WIDTH", str(int(text_block_width)))
            text_block.set("VPOS", str(int(text_block_vpos)))
            text_block.set("HPOS", str(int(text_block_hpos)))

            print_space_vpos = min(print_space_vpos, text_block_vpos)
            print_space_hpos = min(print_space_hpos, text_block_hpos)
            print_space_bottom = max(print_space_bottom,
                                     text_block_vpos + text_block_height)
            print_space_right = max(print_space_right,
                                    text_block_hpos + text_block_width)

            for line in block.lines:
                if not line.transcription:
                    continue
                arabic_line = arabic_helper.is_arabic_line(line.transcription)
                text_line = ET.SubElement(text_block, "TextLine")
                text_line_baseline = int(
                    np.average(np.array(line.baseline)[:, 1]))
                text_line.set("BASELINE", str(text_line_baseline))

                text_line_height, text_line_width, text_line_vpos, text_line_hpos = get_hwvh(
                    line.polygon)

                text_line.set("VPOS", str(int(text_line_vpos)))
                text_line.set("HPOS", str(int(text_line_hpos)))
                text_line.set("HEIGHT", str(int(text_line_height)))
                text_line.set("WIDTH", str(int(text_line_width)))

                # Unknown characters, and indices at or after the blank
                # column, fall back to class 0.
                char_to_num = {
                    char: idx
                    for idx, char in enumerate(line.characters)
                }
                blank_idx = line.logits.shape[1] - 1
                label = []
                for char in line.transcription:
                    idx = char_to_num.get(char, 0)
                    label.append(idx if idx < blank_idx else 0)

                start, end = line.logit_coords[0], line.logit_coords[1]
                logits = line.get_dense_logits()[start:end]
                logprobs = line.get_full_logprobs()[start:end]
                splitted_transcription = line.transcription.split()
                try:
                    aligned_letters = align_text(-logprobs, np.array(label),
                                                 blank_idx)
                except (ValueError, IndexError) as e:
                    print(
                        f'Error: Alto export, unable to align line {line.id} due to exception {e}.'
                    )
                    line.transcription_confidence = 0
                    # Without an alignment, spread the words evenly over the
                    # line width. max() guards whitespace-only lines.
                    average_word_width = text_line_width / max(
                        1, len(splitted_transcription))
                    for w, word in enumerate(splitted_transcription):
                        string = ET.SubElement(text_line, "String")
                        if arabic_line:
                            string.set(
                                "CONTENT",
                                arabic_helper.label_form_to_string(word))
                        else:
                            string.set("CONTENT", word)

                        string.set("HEIGHT", str(int(text_line_height)))
                        string.set("WIDTH", str(int(average_word_width)))
                        string.set("VPOS", str(int(text_line_vpos)))
                        string.set(
                            "HPOS",
                            str(int(text_line_hpos +
                                    (w * average_word_width))))
                else:
                    crop_engine = EngineLineCropper(poly=2)
                    line_coords = crop_engine.get_crop_inputs(
                        line.baseline, line.heights, 16)
                    space_idxs = [
                        pos for pos, char in enumerate(line.transcription)
                        if char == ' '
                    ]
                    space_idxs = [-1] + space_idxs + [len(aligned_letters)]
                    # [start, end) character spans of the non-empty words;
                    # these match splitted_transcription one to one.
                    word_spans = [(left + 1, right)
                                  for left, right in zip(
                                      space_idxs, space_idxs[1:])
                                  if right != left + 1]
                    # Number of crop columns per logit frame.
                    lm_const = line_coords.shape[1] / logits.shape[0]
                    confidences = get_line_confidence(line, np.array(label),
                                                      aligned_letters,
                                                      logprobs)
                    if line.transcription_confidence is None:
                        line.transcription_confidence = np.quantile(
                            confidences, .50)
                    last_word = len(splitted_transcription) - 1
                    for w, (word_start, word_end) in enumerate(word_spans):
                        first_frame = aligned_letters[word_start]
                        last_frame = aligned_letters[word_end - 1]
                        # Widen the window around the word's frames until it
                        # covers at least one crop column.
                        # NOTE(review): this never ends if line_coords has
                        # no columns or rows.
                        extension = 2
                        while True:
                            lo = max(0,
                                     int((first_frame - extension) * lm_const))
                            hi = int((last_frame + extension) * lm_const)
                            all_x = line_coords[:, lo:hi, 0]
                            all_y = line_coords[:, lo:hi, 1]
                            if all_x.size != 0 and all_y.size != 0:
                                break
                            extension += 1

                        word_confidence = None
                        if line.transcription_confidence == 1:
                            word_confidence = 1
                        elif confidences.size != 0:
                            # Slice by the word's real character span, so
                            # leading or repeated spaces do not shift it.
                            word_confidence = np.quantile(
                                confidences[word_start:word_end], .50)

                        string = ET.SubElement(text_line, "String")

                        if arabic_line:
                            string.set(
                                "CONTENT",
                                arabic_helper.label_form_to_string(
                                    splitted_transcription[w]))
                        else:
                            string.set("CONTENT", splitted_transcription[w])

                        string.set("HEIGHT",
                                   str(int((np.max(all_y) - np.min(all_y)))))
                        string.set("WIDTH",
                                   str(int((np.max(all_x) - np.min(all_x)))))
                        string.set("VPOS", str(int(np.min(all_y))))
                        string.set("HPOS", str(int(np.min(all_x))))

                        if word_confidence is not None:
                            string.set("WC", str(round(word_confidence, 2)))

                        if w != last_word:
                            space = ET.SubElement(text_line, "SP")

                            space.set("WIDTH", str(4))
                            space.set("VPOS", str(int(np.min(all_y))))
                            space.set("HPOS", str(int(np.max(all_x))))
                if (line.transcription_confidence is not None and
                        line.transcription_confidence < min_line_confidence):
                    text_block.remove(text_line)

        # Pages without regions keep a zero-sized print space.
        print_space_height = max(0, print_space_bottom - print_space_vpos)
        print_space_width = max(0, print_space_right - print_space_hpos)

        top_margin.set("HEIGHT", "{}".format(int(print_space_vpos)))
        top_margin.set("WIDTH", "{}".format(int(self.page_size[1])))
        top_margin.set("VPOS", "0")
        top_margin.set("HPOS", "0")

        left_margin.set("HEIGHT", "{}".format(int(self.page_size[0])))
        left_margin.set("WIDTH", "{}".format(int(print_space_hpos)))
        left_margin.set("VPOS", "0")
        left_margin.set("HPOS", "0")

        right_margin.set("HEIGHT", "{}".format(int(self.page_size[0])))
        right_margin.set(
            "WIDTH", "{}".format(
                int(self.page_size[1] -
                    (print_space_hpos + print_space_width))))
        right_margin.set("VPOS", "0")
        right_margin.set(
            "HPOS", "{}".format(int(print_space_hpos + print_space_width)))

        bottom_margin.set(
            "HEIGHT", "{}".format(
                int(self.page_size[0] -
                    (print_space_vpos + print_space_height))))
        bottom_margin.set("WIDTH", "{}".format(int(self.page_size[1])))
        bottom_margin.set(
            "VPOS", "{}".format(int(print_space_vpos + print_space_height)))
        bottom_margin.set("HPOS", "0")

        print_space.set("HEIGHT", str(int(print_space_height)))
        print_space.set("WIDTH", str(int(print_space_width)))
        print_space.set("VPOS", str(int(print_space_vpos)))
        print_space.set("HPOS", str(int(print_space_hpos)))

        return ET.tostring(root, pretty_print=True,
                           encoding="utf-8").decode("utf-8")