Example #1
0
def evaluate():
    """Evaluate a CnOcr model on a labeled index file and dump reports.

    Command-line driven: reads (image filename, labels) pairs from
    ``--input-fp``, runs OCR batch by batch, and writes ``badcases.txt``,
    ``miss_words_stat.txt`` and ``redundant_words_stat.txt`` into
    ``--output-dir``.  Relies on module-level helpers ``gen_context``,
    ``read_input_file`` and ``compare_preds_to_reals``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name",
                        help="model name",
                        type=str,
                        default='densenet-lite-lstm')
    parser.add_argument("--model-epoch",
                        type=int,
                        default=None,
                        help="model epoch")
    parser.add_argument(
        "--gpu",
        help="Number of GPUs for training [Default 0, means using cpu]"
        "目前限制gpu <= 1,因为 gpu > 1时预测结果有问题,与 gpu = 1时不同,暂未发现原因。",
        type=int,
        default=0,
    )
    parser.add_argument(
        "-i",
        "--input-fp",
        default='test.txt',
        help="the file path with image names and labels",
    )
    parser.add_argument("--image-prefix-dir",
                        default='.',
                        help="图片所在文件夹,相对于索引文件中记录的图片位置")
    parser.add_argument("--batch-size",
                        type=int,
                        default=128,
                        help="batch size")
    parser.add_argument(
        "-v",
        "--verbose",
        action='store_true',
        help="whether to print details to screen",
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        # BUG FIX: the default used to be ``False``; ``Path(False)`` raises
        # TypeError, so running without ``-o`` always crashed.
        default='eval_results',
        help="the output directory which records the analysis results",
    )
    args = parser.parse_args()
    # gpu > 1 gives wrong predictions (see the --gpu help text), so refuse it.
    assert args.gpu <= 1
    context = gen_context(args.gpu)

    ocr = CnOcr(model_name=args.model_name,
                model_epoch=args.model_epoch,
                context=context)
    alphabet = ocr._alphabet

    fn_labels_list = read_input_file(args.input_fp)
    # Guard against an empty index file: the final stats divide by the count.
    if not fn_labels_list:
        logger.info('no evaluation samples found in %s', args.input_fp)
        return

    miss_cnt, redundant_cnt = Counter(), Counter()
    model_time_cost = 0.0
    start_idx = 0
    bad_cnt = 0
    badcases = []
    while start_idx < len(fn_labels_list):
        logger.info('start_idx: %d', start_idx)
        batch = fn_labels_list[start_idx:start_idx + args.batch_size]
        batch_img_fns = []
        batch_labels = []
        batch_imgs = []
        for fn, labels in batch:
            batch_labels.append(labels)
            img_fp = os.path.join(args.image_prefix_dir, fn)
            batch_img_fns.append(img_fp)
            img = mx.image.imread(img_fp, 1).asnumpy()
            batch_imgs.append(img)

        # Only time the model call itself, not image loading.
        start_time = time.time()
        batch_preds = ocr.ocr_for_single_lines(batch_imgs)
        model_time_cost += time.time() - start_time
        for bad_info in compare_preds_to_reals(batch_preds, batch_labels,
                                               batch_img_fns, alphabet):
            if args.verbose:
                logger.info('\t'.join(bad_info))
            # Edit distance between real and predicted words; becomes the
            # leading column and the badcase sort key.
            distance = Levenshtein.distance(bad_info[1], bad_info[2])
            bad_info.insert(0, distance)
            badcases.append(bad_info)
            miss_cnt.update(list(bad_info[-2]))
            redundant_cnt.update(list(bad_info[-1]))
            bad_cnt += 1

        start_idx += args.batch_size

    # Worst cases (largest edit distance) first.
    badcases.sort(key=itemgetter(0), reverse=True)

    output_dir = Path(args.output_dir)
    # Idempotent creation, including missing parent directories.
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / 'badcases.txt', 'w') as f:
        f.write('\t'.join([
            'distance',
            'image_fp',
            'real_words',
            'pred_words',
            'miss_words',
            'redundant_words',
        ]) + '\n')
        for bad_info in badcases:
            f.write('\t'.join(map(str, bad_info)) + '\n')
    with open(output_dir / 'miss_words_stat.txt', 'w') as f:
        for word, num in miss_cnt.most_common():
            f.write('\t'.join([word, str(num)]) + '\n')
    with open(output_dir / 'redundant_words_stat.txt', 'w') as f:
        for word, num in redundant_cnt.most_common():
            f.write('\t'.join([word, str(num)]) + '\n')

    logger.info(
        "number of total cases: %d, number of bad cases: %d, acc: %.4f, time cost per image: %f"
        % (
            len(fn_labels_list),
            bad_cnt,
            # BUG FIX: this used to log ``bad_cnt / total`` (the error rate)
            # under the label "acc"; accuracy is its complement.
            1.0 - bad_cnt / len(fn_labels_list),
            model_time_cost / len(fn_labels_list),
        ))
Example #2
0
def evaluate(
    model_name,
    pretrained_model_fp,
    context,
    eval_index_fp,
    img_folder,
    batch_size,
    output_dir,
    verbose,
):
    """Evaluate a pretrained CnOcr model against a labeled index file.

    Reads (image filename, labels) pairs from *eval_index_fp*, predicts each
    line image under *img_folder* in slices of *batch_size*, and writes
    ``badcases.txt``, ``miss_words_stat.txt`` and
    ``redundant_words_stat.txt`` into *output_dir*.  Overall accuracy and
    per-image time cost are logged at the end.
    """
    ocr = CnOcr(
        model_name=model_name,
        model_fp=pretrained_model_fp,
        context=context,
    )

    samples = read_input_file(eval_index_fp)

    miss_cnt = Counter()
    redundant_cnt = Counter()
    total_time_cost = 0.0
    bad_cnt = 0
    badcases = []

    for offset in range(0, len(samples), batch_size):
        logger.info('start_idx: %d', offset)
        chunk = samples[offset:offset + batch_size]
        img_fps = [os.path.join(img_folder, name) for name, _ in chunk]
        reals = [labels for _, labels in chunk]

        imgs = [read_img(fp) for fp in img_fps]
        # Only the model call is timed, not image loading.
        begin = time.time()
        outs = ocr.ocr_for_single_lines(imgs, batch_size=1)
        total_time_cost += time.time() - begin

        preds = [out[0] for out in outs]
        for bad_info in compare_preds_to_reals(preds, reals, img_fps):
            if verbose:
                logger.info('\t'.join(bad_info))
            # Prepend the edit distance between real and predicted words;
            # it later serves as the badcase sort key.
            bad_info.insert(
                0, Levenshtein.distance(bad_info[1], bad_info[2]))
            badcases.append(bad_info)
            miss_cnt.update(list(bad_info[-2]))
            redundant_cnt.update(list(bad_info[-1]))
            bad_cnt += 1

    # Worst cases (largest edit distance) first.
    badcases.sort(key=itemgetter(0), reverse=True)

    output_dir = Path(output_dir)
    if not output_dir.exists():
        os.makedirs(output_dir)

    columns = [
        'distance',
        'image_fp',
        'real_words',
        'pred_words',
        'miss_words',
        'redundant_words',
    ]
    with open(output_dir / 'badcases.txt', 'w') as f:
        f.write('\t'.join(columns) + '\n')
        for bad_info in badcases:
            f.write('\t'.join(map(str, bad_info)) + '\n')

    for stat_fname, counter in (
        ('miss_words_stat.txt', miss_cnt),
        ('redundant_words_stat.txt', redundant_cnt),
    ):
        with open(output_dir / stat_fname, 'w') as f:
            for word, num in counter.most_common():
                f.write('\t'.join([word, str(num)]) + '\n')

    logger.info(
        "number of total cases: %d, number of bad cases: %d, acc: %.4f, time cost per image: %f",
        len(samples),
        bad_cnt,
        1.0 - bad_cnt / len(samples),
        total_time_cost / len(samples),
    )
Example #3
0
def _predict_batches(ocr, dataset, start, stop, batch_size, log_cp=50):
    """Run OCR over dataset samples ``start <= idx < stop`` in batches.

    :param ocr: a ``CnOcr`` instance
    :param dataset: open hdf5 file whose keys are stringified sample indices,
        each group holding an ``img`` array and a ``y`` gold label
    :param start: first sample index (inclusive)
    :param stop: last sample index (exclusive)
    :param batch_size: number of samples predicted per OCR call
    :param log_cp: log progress every *log_cp* batches
    :return: ``(gold, preds)`` — gold label strings and predicted strings,
        both in dataset order
    """
    gold, preds = [], []
    num_batch = int(np.ceil((stop - start) / batch_size))
    s_t = time()
    for idx_batch in range(num_batch):
        lo = start + idx_batch * batch_size
        hi = min(lo + batch_size, stop)
        data = []
        for idx in range(lo, hi):
            # Look the hdf5 group up once per sample instead of twice.
            sample = dataset[str(idx)]
            data.append(sample['img'][...])
            gold.append(str(sample['y'][...]))
        res = ocr.ocr_for_single_lines(data)
        preds += [''.join(r) for r in res]
        if idx_batch % log_cp == 0:
            log_str = f'batch [{idx_batch + 1}/{num_batch}]: '
            log_str += '{:.2f}s/batch'.format((time() - s_t) / (idx_batch + 1))
            logger.info(log_str)
    return gold, preds


def main():
    """Predict every sample of an hdf5 dataset with CnOcr and write the
    predictions / gold labels of the train and val splits to parallel
    ``*.tok.src`` / ``*.tok.trg`` text files.

    DCMMC: A very critical flaw: results predicted on multi-gpu are
    misplaced!  (So keep ``--gpus`` <= 1 until the cause is found.)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name", help="model name", type=str, default='conv-lite-fc'
    )
    parser.add_argument("--model_epoch", type=int, default=None, help="model epoch")
    parser.add_argument(
        '--dataset',
        type=str,
        help='location of hdf5 dataset'
    )
    parser.add_argument('--train_ratio', type=float, default=0.8,
                        help='train ratio of the dataset')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='batch size')
    parser.add_argument('--gpus', type=int, default=0,
                        help='number of gpus, 0 indicates cpu only')
    args = parser.parse_args()

    ocr = CnOcr(model_name=args.model_name, model_epoch=args.model_epoch,
                gpus=args.gpus)
    dataset = h5py.File(args.dataset, 'r')
    batch_size = args.batch_size
    # Split: first train_ratio of the samples is "train", the rest "val".
    num_train = int(len(dataset) * args.train_ratio)
    num_train_batch = int(np.ceil(num_train / batch_size))
    num_val_batch = int(np.ceil((len(dataset) - num_train) / batch_size))
    logger.info(f'total num samples={len(dataset)}')
    logger.info(f'num_train_batch={num_train_batch}, num_val_batch={num_val_batch}')

    # The two splits used to be near-identical inline loops; both now go
    # through _predict_batches.
    logger.info('start train dataset')
    gold_train, res_train = _predict_batches(
        ocr, dataset, 0, num_train, batch_size)
    logger.info('start val dataset')
    gold_val, res_val = _predict_batches(
        ocr, dataset, num_train, len(dataset), batch_size)

    assert len(res_val) == len(gold_val)
    assert len(res_train) == len(gold_train)

    # Exact-match accuracy per split.
    acc_train = sum(r == p for r, p in zip(res_train, gold_train)) / len(res_train)
    acc_val = sum(r == p for r, p in zip(res_val, gold_val)) / len(res_val)
    logger.info(f'acc_train={acc_train}, acc_val={acc_val}')

    def dist_fn(r, p):
        # Levenshtein distance normalized by the longer string
        # (a def instead of a lambda bound to a name, per PEP 8).
        return distance(r, p) / max(len(r), len(p))

    dist_train = sum(dist_fn(r, p) for r, p in zip(res_train, gold_train)) / len(res_train)
    dist_val = sum(dist_fn(r, p) for r, p in zip(res_val, gold_val)) / len(res_val)
    logger.info(f'dist_train={dist_train}, dist_val={dist_val}')

    logger.info(f'write to file. #train={len(res_train)}, #val={len(res_val)}')
    for fname, lines in (
        ('train_no_bpe.tok.src', res_train),
        ('train_no_bpe.tok.trg', gold_train),
        ('dev_no_bpe.tok.src', res_val),
        ('dev_no_bpe.tok.trg', gold_val),
    ):
        with open(fname, 'w') as f:
            f.writelines([line + '\n' for line in lines])
    logger.info('write done.')
class PickStuNumber:
    def __init__(self, path: str, show_img: bool = False):
        """Scan *path* (an image file or a directory of images), OCR the
        student number on each image and collect the results into the
        internal info/duplicate dicts.

        :param path: an image file, or a directory containing images
        :param show_img: show intermediate OpenCV windows while cutting
        """
        # Accepted image extensions (matched by file suffix).
        self.__ext = {'jpg', 'jpeg'}
        # Digit-only OCR: the candidate alphabet is restricted to 0-9.
        self.__ocr = CnOcr(model_name='densenet-lite-gru',
                           cand_alphabet=string.digits,
                           name=path)
        self.__std = CnStd(name=path)
        self.__info_dict = {}
        self.__dup_name_dict = {}

        # Normalize the path first (absolute, forward slashes).
        path = self.__format_path(path)

        # Decide what to process based on whether the path is a dir or file.
        if os.path.isdir(path) or os.path.isfile(path):
            files = [self.__format_path(os.path.join(path, f)) for f in os.listdir(path) if
                     (os.path.isfile(os.path.join(path, f)) and self.__is_image(f))] \
                if os.path.isdir(path) \
                else [path]
            # Pipeline per image: cut -> detect lines -> OCR -> record.
            for file in tqdm(files):
                self.__handle_info(
                    file,
                    self.__ocr_number(
                        self.__std_number(self.__cutter(file, show_img))))
        else:
            print(f'获取数据错误,“{path}”既不是文件也不是文件夹')

    @staticmethod
    def __format_path(path: str):
        return os.path.abspath(path).replace('\\', '/')

    @staticmethod
    def __get_suffix(path: str) -> str:
        """
        获取后缀
        :param path: 图片路径
        :return: 是否为图片
        """
        return path.split('.')[-1]

    def __is_image(self, path: str) -> bool:
        return self.__get_suffix(path) in self.__ext

    @staticmethod
    def __cutter(path: str, show_img: bool = False) -> numpy.ndarray:
        """
        Cut the student-number region out of the image at *path*.
        :param path: image path
        :param show_img: whether to display the intermediate images
        :return: the cropped region as an RGB ndarray
        """
        print(path)

        # Read the image in grayscale mode.
        origin_img = cv2.imread(path, 0)

        if show_img:
            # Make the window freely resizable.
            # cv2.namedWindow('bin img', 0)
            cv2.imshow('origin img', origin_img)

        # Keep only the top half; the proportion is an empirical value.
        origin_img = origin_img[:origin_img.shape[0] // 2]

        # Binarize with Otsu thresholding.
        _, origin_img = cv2.threshold(origin_img, 0, 255,
                                      cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        if show_img:
            # Make the window freely resizable.
            # cv2.namedWindow('bin img', 0)
            cv2.imshow('bin img', origin_img)

        # Morphological dilation, mainly to make the red banner detectable.
        kernel = numpy.ones((15, 15), dtype=numpy.uint8)
        # img = cv2.erode(img, kernel=kernel, iterations=1)
        img = cv2.dilate(origin_img, kernel=kernel, iterations=2)

        # Contour detection.
        contours, _ = cv2.findContours(img, 1, 2)
        # Pick the second largest contour, i.e. the red banner.
        contours = sorted(contours, key=cv2.contourArea, reverse=True)

        if len(contours) > 1:
            # Bounding rectangle around the banner.
            x, y, w, h = cv2.boundingRect(contours[1])

            # All thresholds/offsets below are empirical values.
            if w * h > 250000:
                # The student-number region to recognize.
                # Top-left corner.
                left_top_x = x
                left_top_y = y + h + 20
                # Bottom-right corner.
                right_down_x = x + w
                right_down_y = y + h + 190

                img = origin_img[left_top_y:right_down_y,
                                 left_top_x:right_down_x]
            else:
                # Banner too small to trust — fall back to a fixed crop.
                img = origin_img[120:]
        else:
            # No banner found — fall back to a fixed crop.
            img = origin_img[120:]

        # Post-process the crop to help recognition.
        kernel = numpy.ones((2, 2), dtype=numpy.uint8)
        # Erode to thicken the strokes.
        img = cv2.erode(img, kernel=kernel, iterations=1)
        # Map back to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

        if show_img:
            # Make the window freely resizable.
            # cv2.namedWindow('final img', 0)
            cv2.imshow('final img', img)
            cv2.waitKey(0)
            cv2.destroyAllWindows()

        return img

    def __ocr_number(self, img_list: List[numpy.ndarray]):
        """
        Recognize each single-line image (digits only, see the
        ``cand_alphabet`` set on the CnOcr instance in ``__init__``).
        :param img_list: list of line images to recognize
        :return: the result of ``CnOcr.ocr_for_single_lines``
        """
        return self.__ocr.ocr_for_single_lines(img_list)

    def __std_number(self, img: numpy.ndarray):
        """
        Detect text lines in *img* with CnStd.
        :param img: source image
        :return: list of cropped line images (``cropped_img`` of each box)
        """
        return [i['cropped_img'] for i in self.__std.detect(img)]

    @staticmethod
    def __handle_result_list(result_list: List[List[str]]) -> [str, bool]:
        """
        处理结果列表
        :param result_list: 结果列表
        :return: 结果,是否有效
        """
        result = result_list[0]

        if len(result) < 12 and len(result_list) > 1:
            for i in result_list:
                if len(i) >= 12:
                    result = i

        result = ''.join(result[:12] if len(result) >= 12 else result)
        print(result, re.match(r'\d{12}', result) is not None)
        return result, re.match(r'\d{12}', result) is not None

    def __handle_dup_name(self, name, path):
        dup_keys = self.__dup_name_dict.get(name)
        # 如设置过,即表明有重复的
        if dup_keys:
            # 设置重复的为 True,只要第一次重复时设置即可
            if 1 == len(dup_keys):
                self.__info_dict[dup_keys[0]]['dup'] = True
            # 将本次的 path 也添加进去
            self.__dup_name_dict[name].append(path)
            return True
        else:
            self.__dup_name_dict[name] = [path]
            return False

    def __handle_info(self, key, value):
        """
        处理每条信息
        :param key:
        :param value:
        """
        name, is_legal = self.__handle_result_list(value)
        self.__info_dict[key] = {
            'name': name,
            'suffix': self.__get_suffix(key),
            'legal': is_legal,
            'dup': self.__handle_dup_name(name, key)
        }

    def print_info(self):
        """
        Pretty-print the collected per-image recognition info.
        :return: self (for chaining)
        """
        beeprint.pp(self.__info_dict)
        return self

    def print_dup(self):
        """
        Pretty-print the duplicate-name bookkeeping dict.
        :return: self (for chaining)
        """
        beeprint.pp(self.__dup_name_dict)
        return self

    def write_out(self,
                  path: str = '.',
                  out_path_suc: str = 'output_suc',
                  out_path_dup: str = 'output_dup',
                  out_path_fail: str = 'output_fail'):
        """
        输出重命名后的图片到文件夹
        :param path: 文件夹路径
        :param out_path_suc: 合规且不重复图片所在的文件夹
        :param out_path_dup: 合规但是重复图片所在的文件夹
        :param out_path_fail: 其它图片所在文件夹
        :return: self
        """
        # 处理路径
        path = self.__format_path(path)

        if os.path.isdir(path):
            # 拼接文件路径
            suc = os.path.join(path, out_path_suc)
            fail = os.path.join(path, out_path_fail)
            dup = os.path.join(path, out_path_dup)

            #  创建结果文件夹
            not os.path.exists(suc) and os.makedirs(suc)
            not os.path.exists(fail) and os.makedirs(fail)
            not os.path.exists(dup) and os.makedirs(dup)

            # 将图片输出到相应的文件夹
            for key, value in self.__info_dict.items():
                # 合规且不重复
                if value.get('legal') is True and value.get('dup') is False:
                    copyfile(
                        key,
                        os.path.join(
                            suc, f'{value.get("name")}.{value.get("suffix")}'))
                # 合规但是重复
                elif value.get('legal') is True and value.get('dup') is True:
                    index = self.__dup_name_dict[value.get("name")].index(key)
                    copyfile(
                        key,
                        os.path.join(
                            dup,
                            f'{value.get("name")}.{index}.{value.get("suffix")}'
                        ))
                else:
                    copyfile(
                        key,
                        os.path.join(
                            fail, f'{value.get("name")}.{value.get("suffix")}'
                            or os.path.split(key)[1]))
        else:
            print(f'“{path}” 并非一个合法的路径!')

        return self