Example #1
0
def genGT(file) :
    """Parse a ground-truth annotation file into grouped invoice sections.

    Each input line holds comma-separated quad coordinates
    ("x1,y1,...,x4,y4"), a space, then the label text.
    Returns a list of {'title': ..., 'items': ...} dicts, one per
    invoice section.
    """
    with open(file, 'r', encoding = 'utf-8') as f :
        gt = []
        for raw in f :
            fields = raw.split(' ')
            # Everything after the first space is text; fragments are joined
            # without separators and newlines are stripped.
            text = re.sub(r'\n', '', ''.join(fields[1 :]))
            quad = np.array([float(v) for v in fields[0].split(',')]).astype('int32')
            quad = order_point(quad.reshape(4, 2)).reshape(-1)
            angle, w, h, cx, cy = solve(quad)
            gt.append({'angle' : angle, 'w' : w, 'h' : h,
                       'cx' : cx, 'cy' : cy, 'text' : text})

        # Extractors are invoked in this exact order on the full box list.
        basic = getBasics(gt)
        buyer = getBuyer(gt)
        content = getContent(gt)
        seller = getSeller(gt)
        summation = getSummation(gt)

        groundtruth = [{'title' : r'发票基本信息', 'items' : basic},
                       {'title' : r'购买方', 'items' : buyer},
                       {'title' : r'销售方', 'items' : seller},
                       {'title' : r'货物或应税劳务、服务', 'items' : content},
                       {'title' : r'合计', 'items' : summation}]

    return groundtruth
Example #2
0
def evaluate(file, errfile, canfile, predict = None) :
    """Score *predict* against the ground truth stored in *file*.

    The annotation file comes in two layouts, distinguished by comma count:
      * 8 comma-separated fields: the last coordinate and the text are
        space-delimited inside field 8 ("space" layout);
      * 9+ comma-separated fields: fields 1-8 are coordinates, the rest is
        text ("comma" layout).

    Args:
        file:    path to the ground-truth annotation file (utf-8).
        errfile: error-report path, forwarded to calc_precision1.
        canfile: candidate-report path, forwarded to calc_precision1.
        predict: predicted result compared against the parsed ground truth.

    Returns:
        A one-element list containing the precision from calc_precision1.
    """
    with open(file, 'r', encoding = 'utf-8') as f :
        gt = []
        precision = []
        for line in f.readlines() :
            line = line.split(',')

            if len(line) == 8 :  # "space" layout
                points = line[:7]
                tail = line[7].split(' ')
                points.append(tail[0])
                # Join the text fragments without separators; drop newlines.
                # (Replaces the original quadratic `text += ...` loop plus a
                # redundant ''.join over an already-joined string.)
                text = re.sub(r'\n', '', ''.join(tail[1 :]))
            else :  # "comma" layout
                points = line[:8]
                text = re.sub(r'\n', '', ''.join(line[8 :]))

            box = np.array([float(p) for p in points]).astype('int32')
            box = box.reshape(4, 2)
            box = order_point(box)
            box = box.reshape(-1)
            angle, w, h, cx, cy = solve(box)
            gt.append({'angle' : angle, 'w' : w, 'h' : h, 'cx' : cx, 'cy' : cy, 'text' : text})

        basic = getBasics(gt)
        buyer = getBuyer(gt)
        content = getContent(gt)
        seller = getSeller(gt)
        summation = getSummation(gt)

        gt_len = len(gt)

        groundtruth = [{'title' : r'发票基本信息', 'items' : basic},
                       {'title' : r'购买方', 'items' : buyer},
                       {'title' : r'销售方', 'items' : seller},
                       {'title' : r'货物或应税劳务、服务', 'items' : content},
                       {'title' : r'合计', 'items' : summation}]
        precision1 = calc_precision1(predict, groundtruth, errfile, canfile)
        # precision2 = calc_precision2(gt_len, errfile)  # inaccurate
        precision.append(precision1)
        # precision.append(precision2)
        return precision
def test(args, file=None):
    """Run text detection (PSE) and recognition (CRNN) over the test set.

    Loads the architecture selected by ``args.arch``, optionally restores a
    checkpoint from ``args.resume``, then for every test image: predicts
    kernel maps, expands them with PSE, extracts rotated boxes, writes the
    boxes to disk, and runs CRNN recognition on the crops.

    Args:
        args: namespace with long_size, arch, scale, binary_th, resume,
              min_kernel_area, min_area, min_score.
        file: optional single file forwarded to the data loader.

    Returns:
        The formatted recognition result.
        NOTE(review): only the LAST image's result survives the loop —
        confirm whether accumulating all results was intended.
    """
    result = []
    data_loader = DataLoader(long_size=args.long_size, file=file)
    test_loader = torch.utils.data.DataLoader(data_loader,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=2,
                                              drop_last=True)

    # Index of the text channel in the network output; mobilenet emits it
    # last, the resnets emit it first. (Renamed from `slice`, which shadowed
    # the builtin.)
    out_channel = 0
    # Setup Model
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True,
                                num_classes=7,
                                scale=args.scale)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True,
                                 num_classes=7,
                                 scale=args.scale)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True,
                                 num_classes=7,
                                 scale=args.scale)
    elif args.arch == "mobilenet":
        model = models.Mobilenet(pretrained=True,
                                 num_classes=6,
                                 scale=args.scale)
        out_channel = -1
    else:
        # Previously an unknown arch crashed later with NameError on `model`.
        raise ValueError("Unsupported architecture: {}".format(args.arch))

    # Inference only: freeze all parameters.
    for param in model.parameters():
        param.requires_grad = False

    # model = model.cuda()

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)

            # The checkpoint was saved from a DataParallel model; strip the
            # leading "module." (7 chars) from every key.
            d = collections.OrderedDict()
            for key, value in checkpoint['state_dict'].items():
                tmp = key[7:]
                d[tmp] = value

            try:
                model.load_state_dict(d)
            except Exception:
                # Keys were not prefixed after all — load the dict as-is.
                model.load_state_dict(checkpoint['state_dict'])

            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            sys.stdout.flush()
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            sys.stdout.flush()

    model.eval()

    total_frame = 0.0
    total_time = 0.0
    for idx, (org_img, img) in enumerate(test_loader):
        print('progress: %d / %d' % (idx, len(test_loader)))
        sys.stdout.flush()

        # img = Variable(img.cuda(), volatile=True)
        org_img = org_img.numpy().astype('uint8')[0]
        text_box = org_img.copy()

        # torch.cuda.synchronize()
        start = time.time()

        # angle detection
        # org_img, angle = detect_angle(org_img)
        outputs = model(img)

        score = torch.sigmoid(outputs[:, out_channel, :, :])
        # Binarize: sign(x - th) maps to {-1, 1}; (+1)/2 maps to {0, 1}.
        outputs = (torch.sign(outputs - args.binary_th) + 1) / 2

        text = outputs[:, out_channel, :, :]
        kernels = outputs
        # kernels = outputs[:, 0:args.kernel_num, :, :] * text

        score = score.data.cpu().numpy()[0].astype(np.float32)
        text = text.data.cpu().numpy()[0].astype(np.uint8)
        kernels = kernels.data.cpu().numpy()[0].astype(np.uint8)

        if args.arch == 'mobilenet':
            pred = pse2(kernels,
                        args.min_kernel_area / (args.scale * args.scale))
        else:
            # c++ version pse
            pred = pse(kernels,
                       args.min_kernel_area / (args.scale * args.scale))
            # python version pse
            # pred = pypse(kernels, args.min_kernel_area / (args.scale * args.scale))

        # Map prediction coordinates back to the original image size
        # (x-scale, y-scale).
        scale = (org_img.shape[1] * 1.0 / pred.shape[1],
                 org_img.shape[0] * 1.0 / pred.shape[0])
        label = pred
        label_num = np.max(label) + 1
        bboxes = []
        rects = []
        for i in range(1, label_num):
            # Pixel coordinates of component i, flipped to (x, y) order.
            points = np.array(np.where(label == i)).transpose((1, 0))[:, ::-1]

            # Drop components smaller than the area threshold.
            if points.shape[0] < args.min_area / (args.scale * args.scale):
                continue

            # Drop low-confidence components.
            score_i = np.mean(score[label == i])
            if score_i < args.min_score:
                continue

            rect = cv2.minAreaRect(points)
            bbox = cv2.boxPoints(rect) * scale
            bbox = bbox.astype('int32')
            bbox = order_point(bbox)
            # bbox = np.array([bbox[1], bbox[2], bbox[3], bbox[0]])
            bboxes.append(bbox.reshape(-1))

            # Rotated rect as [angle, h*sy, w*sx, cx*sx, cy*sy] for the
            # recognizer.
            rec = []
            rec.append(rect[-1])
            rec.append(rect[1][1] * scale[1])
            rec.append(rect[1][0] * scale[0])
            rec.append(rect[0][0] * scale[0])
            rec.append(rect[0][1] * scale[1])
            rects.append(rec)

        # torch.cuda.synchronize()
        end = time.time()
        total_frame += 1
        total_time += (end - start)
        print('fps: %.2f' % (total_frame / total_time))
        sys.stdout.flush()

        for bbox in bboxes:
            cv2.drawContours(text_box, [bbox.reshape(4, 2)], -1, (0, 255, 0),
                             2)

        image_name = data_loader.img_paths[idx].split('/')[-1].split('.')[0]
        write_result_as_txt(image_name, bboxes, 'outputs/submit_invoice/')

        text_box = cv2.resize(text_box, (text.shape[1], text.shape[0]))
        debug(idx, data_loader.img_paths, [[text_box]], 'data/images/tmp/')

        result = crnnRec(cv2.cvtColor(org_img, cv2.COLOR_BGR2RGB), rects)
        result = formatResult(result)

    # cmd = 'cd %s;zip -j %s %s/*' % ('./outputs/', 'submit_invoice.zip', 'submit_invoice')
    # print(cmd)
    # sys.stdout.flush()
    # util.cmd.Cmd(cmd)
    return result