Example #1
def test_levenshtein_block(s1, s2):
    """
    Test the blockwise Levenshtein implementation against the simple reference implementation
    """
    reference_dist = levenshtein(s1, s2)
    reference_sim = normalize_distance(reference_dist, s1, s2)
    assert string_metric.levenshtein(s1, s2) == reference_dist
    assert isclose(string_metric.normalized_levenshtein(s1, s2), reference_sim)
    assert isclose(
        extractOne_scorer(s1, s2, string_metric.normalized_levenshtein),
        reference_sim)
    assert isclose(
        extract_scorer(s1, s2, string_metric.normalized_levenshtein),
        reference_sim)

    reference_dist = levenshtein(s1, s2, (1, 1, 2))
    reference_sim = normalize_distance(reference_dist, s1, s2, (1, 1, 2))
    assert string_metric.levenshtein(s1, s2, (1, 1, 2)) == reference_dist
    assert isclose(string_metric.normalized_levenshtein(s1, s2, (1, 1, 2)),
                   reference_sim)
    assert isclose(
        extractOne_scorer(s1,
                          s2,
                          string_metric.normalized_levenshtein,
                          weights=(1, 1, 2)), reference_sim)
    assert isclose(
        extract_scorer(s1,
                       s2,
                       string_metric.normalized_levenshtein,
                       weights=(1, 1, 2)), reference_sim)
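
The reference helpers `levenshtein`, `normalize_distance`, `extractOne_scorer` and `extract_scorer` are defined elsewhere in the test module. A minimal sketch of what the pure-Python reference could look like, assuming the weights tuple is ordered (insertion, deletion, substitution) and that the normalized score is a similarity on a 0-100 scale:

def levenshtein(s1, s2, weights=(1, 1, 1)):
    # O(len(s1) * len(s2)) dynamic-programming reference implementation.
    insert_cost, delete_cost, substitute_cost = weights
    prev = [j * insert_cost for j in range(len(s2) + 1)]
    for c1 in s1:
        curr = [prev[0] + delete_cost]
        for j, c2 in enumerate(s2):
            curr.append(min(
                prev[j + 1] + delete_cost,                        # delete c1
                curr[j] + insert_cost,                            # insert c2
                prev[j] + (0 if c1 == c2 else substitute_cost)))  # substitute
        prev = curr
    return prev[-1]


def normalize_distance(dist, s1, s2, weights=(1, 1, 1)):
    # One plausible normalization to a similarity in [0, 100]; the exact
    # formula used by string_metric.normalized_levenshtein is assumed here.
    insert_cost, delete_cost, substitute_cost = weights
    if len(s1) >= len(s2):
        max_dist = min(len(s2) * substitute_cost +
                       (len(s1) - len(s2)) * delete_cost,
                       len(s1) * delete_cost + len(s2) * insert_cost)
    else:
        max_dist = min(len(s1) * substitute_cost +
                       (len(s2) - len(s1)) * insert_cost,
                       len(s1) * delete_cost + len(s2) * insert_cost)
    return 100 if max_dist == 0 else 100 * (1 - dist / max_dist)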
Example #2
def test_levenshtein_random(s1, s2):
    """
    Test mixed strings to exercise all implementations of Levenshtein
    """
    reference_dist = levenshtein(s1, s2)
    reference_sim = normalize_distance(reference_dist, s1, s2)
    assert string_metric.levenshtein(s1, s2) == reference_dist
    assert isclose(string_metric.normalized_levenshtein(s1, s2), reference_sim)
    assert isclose(
        extractOne_scorer(s1, s2, string_metric.normalized_levenshtein),
        reference_sim)
    assert isclose(
        extract_scorer(s1, s2, string_metric.normalized_levenshtein),
        reference_sim)

    reference_dist = levenshtein(s1, s2, (1, 1, 2))
    reference_sim = normalize_distance(reference_dist, s1, s2, (1, 1, 2))
    assert string_metric.levenshtein(s1, s2, (1, 1, 2)) == reference_dist
    assert isclose(string_metric.normalized_levenshtein(s1, s2, (1, 1, 2)),
                   reference_sim)
    assert isclose(
        extractOne_scorer(s1,
                          s2,
                          string_metric.normalized_levenshtein,
                          weights=(1, 1, 2)), reference_sim)
    assert isclose(
        extract_scorer(s1,
                       s2,
                       string_metric.normalized_levenshtein,
                       weights=(1, 1, 2)), reference_sim)
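
The `extractOne_scorer` and `extract_scorer` helpers route the scorer through rapidfuzz's process machinery and return only the score; a plausible sketch of what these unshown helpers could look like:

from rapidfuzz import process


def extractOne_scorer(s1, s2, scorer, **kwargs):
    # process.extractOne returns (choice, score, key); keep the score only.
    return process.extractOne(s1, [s2], scorer=scorer, **kwargs)[1]


def extract_scorer(s1, s2, scorer, **kwargs):
    # process.extract returns a list of (choice, score, key) tuples.
    return process.extract(s1, [s2], scorer=scorer, **kwargs)[0][1]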
Example #3
def test_levenshtein_random(s1, s2):
    """
    Test mixed strings to exercise all implementations of Levenshtein
    """
    assert string_metric.levenshtein(s1, s2) == levenshtein(s1, s2)
    assert string_metric.levenshtein(s1, s2, (1, 1, 2)) == levenshtein(
        s1, s2, (1, 1, 2))
Example #4
def test_levenshtein_word(s1, s2):
    """
    Test the short-string Levenshtein implementation against the simple reference implementation
    """
    assert string_metric.levenshtein(s1, s2) == levenshtein(s1, s2)
    assert string_metric.levenshtein(s1, s2, (1, 1, 2)) == levenshtein(
        s1, s2, (1, 1, 2))
Example #5
def test_levenshtein_block(s1, s2):
    """
    Test the blockwise Levenshtein implementation against the simple reference implementation
    """
    assert string_metric.levenshtein(s1, s2) == levenshtein(s1, s2)
    assert string_metric.levenshtein(s1, s2, (1, 1, 2)) == levenshtein(
        s1, s2, (1, 1, 2))
Example #6
def test_empty_string():
    """
    When both strings are empty, this is a perfect match
    """
    assert string_metric.levenshtein("", "") == 0
    assert string_metric.levenshtein("", "", (1,1,0)) == 0
    assert string_metric.levenshtein("", "", (1,1,2)) == 0
    assert string_metric.levenshtein("", "", (1,1,5)) == 0
    assert string_metric.levenshtein("", "", (3,7,5)) == 0
Example #7
def weighted_distance(s1, s2, insert_cost=1, delete_cost=1, replace_cost=1):
    logger = logging.getLogger(__name__)
    logger.warning(
        'This function is deprecated and will be removed in v2.0.0.\n'
        'Use string_metric.levenshtein(s1, s2, weights=(%d, %d, %d)) instead'
        % (insert_cost, delete_cost, replace_cost))
    return string_metric.levenshtein(s1, s2,
                                     weights=(insert_cost, delete_cost,
                                              replace_cost))
Example #8
def distance(s1: str, s2: str):
    """Compute the Levenshtein edit distance between two Unicode strings

    Note that this is different from levenshtein() as this function knows about Unicode
    normalization and grapheme clusters. This should be the correct way to compare two
    Unicode strings.
    """
    seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", s1)))
    seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", s2)))
    return levenshtein(seq1, seq2)
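
To see why this matters, compare a precomposed character against its decomposed form; a quick check, assuming `grapheme_clusters` comes from the uniseg package (the import is not shown in the snippet):

import unicodedata
from uniseg.graphemecluster import grapheme_clusters  # assumed source of the helper

s1 = "\u0101"    # "ā" as a single precomposed code point
s2 = "a\u0304"   # "a" followed by COMBINING MACRON

# A plain code-point comparison sees two different strings...
assert s1 != s2
# ...but after NFC normalization and grapheme segmentation they compare equal,
# so distance() reports a perfect match.
assert distance(s1, s2) == 0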
Example #9
def count_matches(pred_texts, gt_texts):
    """Count the various match number for metric calculation.

    Args:
        pred_texts (list[str]): Predicted text string.
        gt_texts (list[str]): Ground truth text string.

    Returns:
        match_res: (dict[str: int]): Match number used for
            metric calculation.
    """
    match_res = {
        'gt_char_num': 0,
        'pred_char_num': 0,
        'true_positive_char_num': 0,
        'gt_word_num': 0,
        'match_word_num': 0,
        'match_word_ignore_case': 0,
        'match_word_ignore_case_symbol': 0
    }
    comp = re.compile('[^A-Za-z0-9\u4e00-\u9fa5]')  # strip all but letters, digits and CJK
    norm_ed_sum = 0.0
    for pred_text, gt_text in zip(pred_texts, gt_texts):
        if gt_text == pred_text:
            match_res['match_word_num'] += 1
        gt_text_lower = gt_text.lower()
        pred_text_lower = pred_text.lower()
        if gt_text_lower == pred_text_lower:
            match_res['match_word_ignore_case'] += 1
        gt_text_lower_ignore = comp.sub('', gt_text_lower)
        pred_text_lower_ignore = comp.sub('', pred_text_lower)
        if gt_text_lower_ignore == pred_text_lower_ignore:
            match_res['match_word_ignore_case_symbol'] += 1
        match_res['gt_word_num'] += 1

        # normalized edit distance
        edit_dist = string_metric.levenshtein(pred_text_lower_ignore,
                                              gt_text_lower_ignore)
        norm_ed = float(edit_dist) / max(1, len(gt_text_lower_ignore),
                                         len(pred_text_lower_ignore))
        norm_ed_sum += norm_ed

        # number to calculate char level recall & precision
        match_res['gt_char_num'] += len(gt_text_lower_ignore)
        match_res['pred_char_num'] += len(pred_text_lower_ignore)
        true_positive_char_num = cal_true_positive_char(
            pred_text_lower_ignore, gt_text_lower_ignore)
        match_res['true_positive_char_num'] += true_positive_char_num

    normalized_edit_distance = norm_ed_sum / max(1, len(gt_texts))
    match_res['ned'] = normalized_edit_distance

    return match_res
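
A small usage sketch with hypothetical inputs (`cal_true_positive_char` is defined elsewhere in the module):

pred_texts = ["Hello!", "wor1d"]
gt_texts = ["hello", "world"]
match_res = count_matches(pred_texts, gt_texts)
# "Hello!" only matches after lowercasing and symbol stripping, so
# match_res['match_word_ignore_case_symbol'] == 1 while the stricter word
# counters stay at 0; "wor1d" vs "world" contributes 1/5 to the summed
# normalized edit distance, giving match_res['ned'] == 0.1.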
Example #10
def word_error_rate_n(reference: Iterable,
                      compared: Iterable) -> Tuple[float, int]:
    reference_seq = list(reference)
    compared_seq = list(compared)

    d = levenshtein(reference_seq, compared_seq)
    n = len(reference_seq)

    if d == 0:
        return 0.0, n
    if n == 0:
        return float("inf"), n
    return d / n, n
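
For example, with whitespace-tokenized sequences (the distance is computed over list elements, i.e. whole words):

reference = "the cat sat".split()
hypothesis = "the hat sat".split()
wer, n = word_error_rate_n(reference, hypothesis)
assert (wer, n) == (1 / 3, 3)   # one substituted word out of three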
Example #11
def merge_strings(a: str, b: str, dil_factor: float) -> str:
    """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.

    Args:
        a: first char seq, suffix should be similar to b's prefix.
        b: second char seq, prefix should be similar to a's suffix.
        dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
            only used when the mother sequence is split on a character repetition.

    Returns:
        A merged character sequence.

    Example::
        >>> from doctr.models.recognition.utils import merge_strings
        >>> merge_strings('abcd', 'cdefgh', 1.4)
        'abcdefgh'
        >>> merge_strings('abcdi', 'cdefgh', 1.4)
        'abcdefgh'
    """
    seq_len = min(len(a), len(b))
    if seq_len == 0:  # One sequence is empty, return the other
        return b if len(a) == 0 else a

    # Initialize merging index and corresponding score (mean Levenshtein)
    min_score, index = 1.0, 0  # No overlap, just concatenate

    scores = [
        levenshtein(a[-i:], b[:i], processor=None) / i
        for i in range(1, seq_len + 1)
    ]

    # Edge case (split in the middle of char repetitions): if it starts with 2 or more 0
    if len(scores) > 1 and (scores[0], scores[1]) == (0, 0):
        # Compute n_overlap (number of overlapping chars, geometrically determined)
        n_overlap = round(len(b) * (dil_factor - 1) / dil_factor)
        # Find the number of consecutive zeros in the scores list
        # Impossible to have a zero after a non-zero score in that case
        n_zeros = sum(val == 0 for val in scores)
        # Index is bounded by the geometrical overlap to avoid collapsing repetitions
        min_score, index = 0, min(n_zeros, n_overlap)

    else:  # Common case: choose the min score index
        for i, score in enumerate(scores):
            if score < min_score:
                min_score, index = score, i + 1  # Add one because first index is an overlap of 1 char

    # Merge with correct overlap
    if index == 0:
        return a + b
    return a[:-1] + b[index - 1:]
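
Usage matching the docstring, plus the no-overlap fallback:

assert merge_strings("abcd", "cdefgh", 1.4) == "abcdefgh"
assert merge_strings("abcdi", "cdefgh", 1.4) == "abcdefgh"
# With no overlapping suffix/prefix the sequences are simply concatenated:
assert merge_strings("abc", "xyz", 1.4) == "abcxyz"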
Example #12
def do_test(args):
    """CLI method for test"""
    try:
        from rapidfuzz.string_metric import levenshtein
    except ImportError as e:
        _LOGGER.fatal("rapidfuzz library is needed for levenshtein distance")
        _LOGGER.fatal("pip install 'rapidfuzz>=1.4.1'")
        raise e

    tagger = GraphemesToPhonemes(args.model)

    if args.texts:
        lines = args.texts
    else:
        lines = sys.stdin

        if os.isatty(sys.stdin.fileno()):
            print("Reading lexicon lines from stdin...", file=sys.stderr)

    num_errors = 0
    num_missing = 0
    num_phonemes = 0

    for line in lines:
        line = line.strip()
        if not line:
            continue

        word, actual_phonemes = line.split(maxsplit=1)
        expected_phonemes = "".join(tagger(word))

        if expected_phonemes:
            distance = levenshtein(expected_phonemes, actual_phonemes)
            num_errors += distance
            num_phonemes += len(actual_phonemes)
        else:
            num_missing += 1
            _LOGGER.warning("No pronunciation for %s", word)

    # Calculate results
    per = round(num_errors / num_phonemes, 2) if num_phonemes > 0 else 0.0
    print("PER:", per, "Errors:", num_errors)

    if num_missing > 0:
        print("Total missing:", num_missing)
Example #13
def evaluate_method(gtFilePath, submFilePath, evaluationParams):
    """
    Method evaluate_method: evaluate method and returns the results
        Results. Dictionary with the following values:
        - method (required)  Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
        - samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 }
    """
    for module, alias in evaluation_imports().items():
        globals()[alias] = importlib.import_module(module)

    def polygon_from_points(points):
        """
        Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
        """
        num_points = len(points)
        # resBoxes=np.empty([1,num_points],dtype='int32')
        resBoxes = np.empty([1, num_points], dtype='float32')
        for inp in range(0, num_points, 2):
            resBoxes[0, int(inp / 2)] = float(points[int(inp)])
            resBoxes[0, int(inp / 2 + num_points / 2)] = float(points[int(inp +
                                                                          1)])
        pointMat = resBoxes[0].reshape([2, int(num_points / 2)]).T
        return plg.Polygon(pointMat)

    def rectangle_to_polygon(rect):
        resBoxes = np.empty([1, 8], dtype='int32')
        resBoxes[0, 0] = int(rect.xmin)
        resBoxes[0, 4] = int(rect.ymax)
        resBoxes[0, 1] = int(rect.xmin)
        resBoxes[0, 5] = int(rect.ymin)
        resBoxes[0, 2] = int(rect.xmax)
        resBoxes[0, 6] = int(rect.ymin)
        resBoxes[0, 3] = int(rect.xmax)
        resBoxes[0, 7] = int(rect.ymax)

        pointMat = resBoxes[0].reshape([2, 4]).T

        return plg.Polygon(pointMat)

    def rectangle_to_points(rect):
        points = [
            int(rect.xmin),
            int(rect.ymax),
            int(rect.xmax),
            int(rect.ymax),
            int(rect.xmax),
            int(rect.ymin),
            int(rect.xmin),
            int(rect.ymin)
        ]
        return points

    def get_union(pD, pG):
        areaA = pD.area()
        areaB = pG.area()
        return areaA + areaB - get_intersection(pD, pG)

    def get_intersection_over_union(pD, pG):
        try:
            return get_intersection(pD, pG) / get_union(pD, pG)
        except ZeroDivisionError:
            return 0

    def get_intersection(pD, pG):
        pInt = pD & pG
        if len(pInt) == 0:
            return 0
        return pInt.area()

    def compute_ap(confList, matchList, numGtCare):
        correct = 0
        AP = 0
        if len(confList) > 0:
            confList = np.array(confList)
            matchList = np.array(matchList)
            sorted_ind = np.argsort(-confList)
            confList = confList[sorted_ind]
            matchList = matchList[sorted_ind]
            for n in range(len(confList)):
                match = matchList[n]
                if match:
                    correct += 1
                    AP += float(correct) / (n + 1)

            if numGtCare > 0:
                AP /= numGtCare

        return AP

    def transcription_match(transGt,
                            transDet,
                            specialCharacters=str(r'!?.:,*"()·[]/\''),
                            onlyRemoveFirstLastCharacterGT=True):

        if onlyRemoveFirstLastCharacterGT:
            # Special characters in GT are allowed only at the initial or final position
            if (transGt == transDet):
                return True

            if specialCharacters.find(transGt[0]) > -1:
                if transGt[1:] == transDet:
                    return True

            if specialCharacters.find(transGt[-1]) > -1:
                if transGt[0:len(transGt) - 1] == transDet:
                    return True

            if specialCharacters.find(
                    transGt[0]) > -1 and specialCharacters.find(
                        transGt[-1]) > -1:
                if transGt[1:len(transGt) - 1] == transDet:
                    return True
            return False
        else:
            # Special characters are removed from the beginning and the end of both Detection and GroundTruth
            while len(transGt) > 0 and specialCharacters.find(transGt[0]) > -1:
                transGt = transGt[1:]

            while len(transDet) > 0 and specialCharacters.find(
                    transDet[0]) > -1:
                transDet = transDet[1:]

            while len(transGt) > 0 and specialCharacters.find(
                    transGt[-1]) > -1:
                transGt = transGt[0:len(transGt) - 1]

            while len(transDet) > 0 and specialCharacters.find(
                    transDet[-1]) > -1:
                transDet = transDet[0:len(transDet) - 1]

            return transGt == transDet

    def include_in_dictionary(transcription):
        """
        Function used in Word Spotting that finds if the Ground Truth transcription meets the rules to enter into the dictionary. If not, the transcription will be cared as don't care
        """
        # Special case: 's at the end
        if transcription[-2:] == "'s" or transcription[-2:] == "'S":
            transcription = transcription[:-2]

        # Hyphens at the start or end of the word
        transcription = transcription.strip('-')

        specialCharacters = str("'!?.:,*\"()·[]/")
        for character in specialCharacters:
            transcription = transcription.replace(character, ' ')

        transcription = transcription.strip()

        if len(transcription) != len(transcription.replace(" ", "")):
            return False

        if len(transcription) < evaluationParams['MIN_LENGTH_CARE_WORD']:
            return False

        notAllowed = str("×÷·")

        range1 = [ord(u'a'), ord(u'z')]
        range2 = [ord(u'A'), ord(u'Z')]
        range3 = [ord(u'À'), ord(u'ƿ')]
        range4 = [ord(u'DŽ'), ord(u'ɿ')]
        range5 = [ord(u'Ά'), ord(u'Ͽ')]
        range6 = [ord(u'-'), ord(u'-')]

        for char in transcription:
            charCode = ord(char)
            if char in notAllowed:
                return False

            valid = any(lo <= charCode <= hi
                        for lo, hi in (range1, range2, range3,
                                       range4, range5, range6))
            if not valid:
                return False

        return True

    def include_in_dictionary_transcription(transcription):
        """
        Function applied to the Ground Truth transcriptions used in Word Spotting. It removes special characters and terminations
        """
        # Special case: 's at the end
        if transcription[-2:] == "'s" or transcription[-2:] == "'S":
            transcription = transcription[:-2]

        # Hyphens at the start or end of the word
        transcription = transcription.strip('-')

        specialCharacters = str("'!?.:,*\"()·[]/")
        for character in specialCharacters:
            transcription = transcription.replace(character, ' ')

        transcription = transcription.strip()

        return transcription

    perSampleMetrics = {}

    matchedSum = 0
    det_only_matchedSum = 0

    Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')

    gt = rrc_evaluation_funcs.load_zip_file(
        gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
    subm = rrc_evaluation_funcs.load_zip_file(
        submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)

    numGlobalCareGt = 0
    numGlobalCareDet = 0
    det_only_numGlobalCareGt = 0
    det_only_numGlobalCareDet = 0

    arrGlobalConfidences = []
    arrGlobalMatches = []

    for resFile in gt:
        # print('resgt', resFile)
        gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile])
        if (gtFile is None):
            raise Exception("The file %s is not UTF-8" % resFile)

        recall = 0
        precision = 0
        hmean = 0
        detCorrect = 0
        detOnlyCorrect = 0
        iouMat = np.empty([1, 1])
        gtPols = []
        detPols = []
        gtTrans = []
        detTrans = []
        gtPolPoints = []
        detPolPoints = []
        gtDontCarePolsNum = []  # Keys of Ground Truth polygons marked as don't care
        det_only_gtDontCarePolsNum = []
        detDontCarePolsNum = []  # Keys of Detected polygons matched with a don't care GT
        det_only_detDontCarePolsNum = []
        detMatchedNums = []
        pairs = []

        arrSampleConfidences = []
        arrSampleMatch = []
        sampleAP = 0

        pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(
            gtFile, evaluationParams['CRLF'], evaluationParams['LTRB'], True,
            False)

        for n in range(len(pointsList)):
            points = pointsList[n]
            transcription = transcriptionsList[n]
            det_only_dontCare = dontCare = transcription == "###"  # ctw1500 and total_text gt have been modified to the same format.
            if evaluationParams['LTRB']:
                gtRect = Rectangle(*points)
                gtPol = rectangle_to_polygon(gtRect)
            else:
                gtPol = polygon_from_points(points)
            gtPols.append(gtPol)
            gtPolPoints.append(points)

            # On word spotting we filter out some transcriptions with special characters
            if evaluationParams['WORD_SPOTTING']:
                if not dontCare:
                    if not include_in_dictionary(transcription):
                        dontCare = True
                    else:
                        transcription = include_in_dictionary_transcription(
                            transcription)

            gtTrans.append(transcription)
            if dontCare:
                gtDontCarePolsNum.append(len(gtPols) - 1)
            if det_only_dontCare:
                det_only_gtDontCarePolsNum.append(len(gtPols) - 1)

        if resFile in subm:

            detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile])

            pointsList, confidencesList, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents_det(
                detFile, evaluationParams['CRLF'], evaluationParams['LTRB'],
                True, evaluationParams['CONFIDENCES'])

            for n in range(len(pointsList)):
                points = pointsList[n]
                transcription = transcriptionsList[n]

                if evaluationParams['LTRB']:
                    detRect = Rectangle(*points)
                    detPol = rectangle_to_polygon(detRect)
                else:
                    detPol = polygon_from_points(points)
                detPols.append(detPol)
                detPolPoints.append(points)
                detTrans.append(transcription)

                if len(gtDontCarePolsNum) > 0:
                    for dontCarePol in gtDontCarePolsNum:
                        dontCarePol = gtPols[dontCarePol]
                        intersected_area = get_intersection(
                            dontCarePol, detPol)
                        pdDimensions = detPol.area()
                        precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
                        if (precision >
                                evaluationParams['AREA_PRECISION_CONSTRAINT']):
                            detDontCarePolsNum.append(len(detPols) - 1)
                            break

                if len(det_only_gtDontCarePolsNum) > 0:
                    for dontCarePol in det_only_gtDontCarePolsNum:
                        dontCarePol = gtPols[dontCarePol]
                        intersected_area = get_intersection(
                            dontCarePol, detPol)
                        pdDimensions = detPol.area()
                        precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
                        if (precision >
                                evaluationParams['AREA_PRECISION_CONSTRAINT']):
                            det_only_detDontCarePolsNum.append(
                                len(detPols) - 1)
                            break

            if len(gtPols) > 0 and len(detPols) > 0:
                # Calculate IoU and precision matrices
                outputShape = [len(gtPols), len(detPols)]
                iouMat = np.empty(outputShape)
                gtRectMat = np.zeros(len(gtPols), np.int8)
                detRectMat = np.zeros(len(detPols), np.int8)
                det_only_gtRectMat = np.zeros(len(gtPols), np.int8)
                det_only_detRectMat = np.zeros(len(detPols), np.int8)
                for gtNum in range(len(gtPols)):
                    for detNum in range(len(detPols)):
                        pG = gtPols[gtNum]
                        pD = detPols[detNum]
                        iouMat[gtNum,
                               detNum] = get_intersection_over_union(pD, pG)

                for gtNum in range(len(gtPols)):
                    for detNum in range(len(detPols)):
                        if gtRectMat[gtNum] == 0 and detRectMat[
                                detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
                            if iouMat[gtNum, detNum] > evaluationParams[
                                    'IOU_CONSTRAINT']:
                                gtRectMat[gtNum] = 1
                                detRectMat[detNum] = 1
                                # Detection matched only if the transcription is equal
                                if evaluationParams['WORD_SPOTTING']:
                                    edd = string_metric.levenshtein(
                                        gtTrans[gtNum].upper(),
                                        detTrans[detNum].upper())
                                    correct = edd == 0
                                else:
                                    try:
                                        correct = transcription_match(
                                            gtTrans[gtNum].upper(),
                                            detTrans[detNum].upper(),
                                            evaluationParams[
                                                'SPECIAL_CHARACTERS'],
                                            evaluationParams[
                                                'ONLY_REMOVE_FIRST_LAST_CHARACTER'])
                                    except IndexError:  # empty transcription
                                        correct = False
                                detCorrect += (1 if correct else 0)
                                if correct:
                                    detMatchedNums.append(detNum)

                for gtNum in range(len(gtPols)):
                    for detNum in range(len(detPols)):
                        if det_only_gtRectMat[gtNum] == 0 and det_only_detRectMat[
                                detNum] == 0 and gtNum not in det_only_gtDontCarePolsNum and detNum not in det_only_detDontCarePolsNum:
                            if iouMat[gtNum, detNum] > evaluationParams[
                                    'IOU_CONSTRAINT']:
                                det_only_gtRectMat[gtNum] = 1
                                det_only_detRectMat[detNum] = 1
                                # Detection-only match: the transcription is not checked here
                                detOnlyCorrect += 1

        numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
        numDetCare = (len(detPols) - len(detDontCarePolsNum))
        det_only_numGtCare = (len(gtPols) - len(det_only_gtDontCarePolsNum))
        det_only_numDetCare = (len(detPols) - len(det_only_detDontCarePolsNum))
        if numGtCare == 0:
            recall = float(1)
            precision = float(0) if numDetCare > 0 else float(1)
        else:
            recall = float(detCorrect) / numGtCare
            precision = 0 if numDetCare == 0 else float(
                detCorrect) / numDetCare

        if det_only_numGtCare == 0:
            det_only_recall = float(1)
            det_only_precision = float(
                0) if det_only_numDetCare > 0 else float(1)
        else:
            det_only_recall = float(detOnlyCorrect) / det_only_numGtCare
            det_only_precision = 0 if det_only_numDetCare == 0 else float(
                detOnlyCorrect) / det_only_numDetCare

        hmean = 0 if (
            precision +
            recall) == 0 else 2.0 * precision * recall / (precision + recall)
        det_only_hmean = 0 if (
            det_only_precision + det_only_recall
        ) == 0 else 2.0 * det_only_precision * det_only_recall / (
            det_only_precision + det_only_recall)

        matchedSum += detCorrect
        det_only_matchedSum += detOnlyCorrect
        numGlobalCareGt += numGtCare
        numGlobalCareDet += numDetCare
        det_only_numGlobalCareGt += det_only_numGtCare
        det_only_numGlobalCareDet += det_only_numDetCare

        perSampleMetrics[resFile] = {
            'precision': precision,
            'recall': recall,
            'hmean': hmean,
            'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
            'gtPolPoints': gtPolPoints,
            'detPolPoints': detPolPoints,
            'gtTrans': gtTrans,
            'detTrans': detTrans,
            'gtDontCare': gtDontCarePolsNum,
            'detDontCare': detDontCarePolsNum,
            'evaluationParams': evaluationParams,
        }

    methodRecall = 0 if numGlobalCareGt == 0 else float(
        matchedSum) / numGlobalCareGt
    methodPrecision = 0 if numGlobalCareDet == 0 else float(
        matchedSum) / numGlobalCareDet
    methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / (
        methodRecall + methodPrecision)

    det_only_methodRecall = 0 if det_only_numGlobalCareGt == 0 else float(
        det_only_matchedSum) / det_only_numGlobalCareGt
    det_only_methodPrecision = 0 if det_only_numGlobalCareDet == 0 else float(
        det_only_matchedSum) / det_only_numGlobalCareDet
    det_only_methodHmean = 0 if det_only_methodRecall + det_only_methodPrecision == 0 else 2 * det_only_methodRecall * det_only_methodPrecision / (
        det_only_methodRecall + det_only_methodPrecision)

    methodMetrics = r"E2E_RESULTS: precision: {}, recall: {}, hmean: {}".format(
        methodPrecision, methodRecall, methodHmean)
    det_only_methodMetrics = r"DETECTION_ONLY_RESULTS: precision: {}, recall: {}, hmean: {}".format(
        det_only_methodPrecision, det_only_methodRecall, det_only_methodHmean)

    resDict = {
        'calculated': True,
        'Message': '',
        'e2e_method': methodMetrics,
        'det_only_method': det_only_methodMetrics,
        'per_sample': perSampleMetrics
    }

    return resDict
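
For reference, these are the `evaluationParams` keys the function reads, shown with hypothetical values:

evaluationParams = {
    'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt',    # hypothetical file patterns
    'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt',
    'CRLF': False,                  # input files use CRLF line endings
    'LTRB': False,                  # boxes given as rectangles vs. polygons
    'CONFIDENCES': False,           # detections carry confidence scores
    'WORD_SPOTTING': True,          # require exact transcription matches
    'MIN_LENGTH_CARE_WORD': 3,
    'SPECIAL_CHARACTERS': '!?.:,*"()·[]/\'',
    'ONLY_REMOVE_FIRST_LAST_CHARACTER': True,
    'AREA_PRECISION_CONSTRAINT': 0.5,
    'IOU_CONSTRAINT': 0.5,
}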
Example #14
def distance(s1, s2):
    logger = logging.getLogger(__name__)
    logger.warning('This function is deprecated and will be removed in v2.0.0.\n'
                   'Use string_metric.levenshtein(s1, s2) instead')
    return string_metric.levenshtein(s1, s2)
Example #15
def test_simple_unicode_tests():
    """
    Some very simple Unicode tests with scorers
    to catch relatively obvious implementation errors
    """
    s1 = u"ÁÄ"
    s2 = "ABCD"
    assert string_metric.levenshtein(s1, s2) == 4           # 2 sub + 2 ins
    assert string_metric.levenshtein(s1, s2, (1,1,0)) == 2  # 2 sub + 2 ins
    assert string_metric.levenshtein(s1, s2, (1,1,2)) == 6  # 2 del + 4 ins / 2 sub + 2 ins
    assert string_metric.levenshtein(s1, s2, (1,1,5)) == 6  # 2 del + 4 ins
    assert string_metric.levenshtein(s1, s2, (1,7,5)) == 12 # 2 sub + 2 ins
    assert string_metric.levenshtein(s2, s1, (1,7,5)) == 24 # 2 sub + 2 del

    assert string_metric.levenshtein(s1, s1) == 0
    assert string_metric.levenshtein(s1, s1, (1,1,0)) == 0
    assert string_metric.levenshtein(s1, s1, (1,1,2)) == 0
    assert string_metric.levenshtein(s1, s1, (1,1,5)) == 0
    assert string_metric.levenshtein(s1, s1, (3,7,5)) == 0
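
Reading the weights tuple as (insertion, deletion, substitution): with (1, 7, 5), turning "ÁÄ" into "ABCD" costs two substitutions plus two insertions, 2*5 + 2*1 = 12, while the reverse direction needs two substitutions plus two deletions, 2*5 + 2*7 = 24, matching the two weighted assertions above.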