Example 1
def test_loader(args):
    kernel_num = 7
    min_scale = 0.4
    start_epoch = 0

    data_loader = OcrDataLoader(args, is_transform=True, img_size=args.img_size,
                                kernel_num=kernel_num, min_scale=min_scale, debug=True)
    # data_loader = IC15Loader(is_transform=True, img_size=args.img_size, kernel_num=kernel_num, min_scale=min_scale, debug=True)

    # data_loader = CTW1500Loader(is_transform=True, img_size=args.img_size, kernel_num=kernel_num, min_scale=min_scale)
    out_dir = 'outputs/ld_%s/' % (args.dataset)
    mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
    for i in range(args.n):
        # tensors come out of the loader in CHW layout
        img, gt_text, gt_kernals, training_mask = data_loader[i]
        img = img.cpu().numpy()
        # undo the (img - mean) / std normalization and rescale to [0, 255]
        img = ((img * std + mean) * 255).astype('uint8')
        # CHW -> HWC for OpenCV
        img = img.transpose((1, 2, 0))
        if i == 0:
            print('img.shape: ', img.shape)

        # ground-truth text mask, scaled to [0, 255] for visualization
        gt_text = gt_text.cpu().numpy()
        gt_text = (gt_text * 255).astype('uint8')
        out_path = out_dir + '%d.jpg' % i
        makedirs(out_path)
        cv2.imwrite(out_path, img)
        cv2.imwrite(out_path + '.gt.jpg', gt_text)
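Each example below writes through a makedirs helper that takes an output file path rather than a directory. The helper itself is not shown in the source; a minimal sketch, assuming its job is to create the parent directory of the target file:

import os

def makedirs(path):
    # Assumed helper (not part of the excerpt): make sure the parent
    # directory of an output file exists before cv2.imwrite / torch.save.
    dir_name = os.path.dirname(path)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)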
Example 2
def test_dataset(args):
    train_data = get_dataset_by_name(args.dataset, filter=args.filter)
    train_data.verbose()

    out_dir = 'outputs/ds_%s/' % (args.dataset)
    for i in range(args.n):
        item = train_data.getData(i)
        img = item['img']
        # the dataset returns RGB; convert to BGR for OpenCV drawing/saving
        img = img[:, :, [2, 1, 0]].copy()
        path = item['path']
        img_name = os.path.basename(path)
        bboxes, tags = item['bboxes'], item['tags']
        num = len(bboxes)
        for idx, box in enumerate(bboxes):
            bboxes[idx] = np.array(box).astype('int32')

        # print(bboxes)
        gt_text = np.zeros(img.shape[0:2], dtype='uint8')
        training_mask = np.full(img.shape[0:2], 255, dtype='uint8')
        # `i` is already the outer sample index, so use a different loop variable
        for j in range(num):
            # fill each text region into the ground-truth mask
            cv2.drawContours(gt_text, [bboxes[j]], -1, 255, -1)
            img = cv2.drawContours(img, [bboxes[j]], -1, (0, 255, 0), 2)
            if not tags[j]:
                # ignored ("don't care") boxes: redraw in red and zero them in the training mask
                cv2.drawContours(img, [bboxes[j]], -1, (0, 0, 255), 2)
                cv2.drawContours(training_mask, [bboxes[j]], -1, 0, -1)

        out_path = out_dir + img_name
        makedirs(out_path)
        cv2.imwrite(out_path, img)
Example 3
def save_checkpoint(ckpt_prefix, net, step, epoch):
    msg = 'Saving checkpoint @ step:{} epoch:{}'.format(step, epoch)
    print(msg)
    # unwrap DataParallel so the saved keys carry no 'module.' prefix
    if isinstance(net, torch.nn.DataParallel):
        net_state_dict = net.module.state_dict()
    else:
        net_state_dict = net.state_dict()
    _fname = 'step-%05d_epoch-%03d.ckpt' % (step, epoch)
    save_path = os.path.join(ckpt_prefix, _fname)
    makedirs(save_path)
    torch.save({
        'step': step,
        'epoch': epoch,
        'net_state_dict': net_state_dict},
        save_path)
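To resume from a checkpoint written this way, the usual PyTorch counterpart looks like the sketch below (the loading side is not part of the source; load_checkpoint is a hypothetical name):

import torch

def load_checkpoint(ckpt_path, net):
    # Restore weights saved by save_checkpoint(); unwrap DataParallel as above.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    target = net.module if isinstance(net, torch.nn.DataParallel) else net
    target.load_state_dict(ckpt['net_state_dict'])
    return ckpt['step'], ckpt['epoch']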
Example 4
def make_splits(src, splits, dst, shuffle, norm):
    files = glob.glob(src)
    if shuffle:
        random.shuffle(files)
    shuffled = files
    # parse splits
    segs = splits.split('/')
    segs = [float(s) for s in segs]
    # norm ratios
    if norm:
        sum_w = sum(segs)
        ratios = [x / sum_w for x in segs]
    else:
        ratios = segs
    # write lists
    titles = ['train.txt', 'val.txt', 'test.txt']
    start_idx = 0
    total = len(shuffled)
    print('total:{}, splits:{}'.format(total, ratios))
    for idx, r in enumerate(ratios):
        # at most three splits: train / val / test
        if idx >= 3:
            break
        # skip (near-)zero ratios
        if r < 0.001:
            continue

        n = int(total * r)
        if idx == 2:
            # last split: take the whole remainder if fewer than 3 files would be left over
            left = total - start_idx
            if left - n < 3:
                n = left
        end_idx = start_idx + n
        seg = shuffled[start_idx:end_idx]
        list_path = os.path.join(dst, titles[idx])
        makedirs(list_path)
        # print(seg)
        write_lines(list_path, seg)
        start_idx = end_idx
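write_lines is an external helper; below is a minimal sketch of it, assuming it writes one entry per line, plus a plausible call (both are illustrative, not from the source):

def write_lines(path, lines):
    # Assumed helper: dump one item per line.
    with open(path, 'w') as f:
        for line in lines:
            f.write('%s\n' % line)

# hypothetical usage: an 8/1/1 train/val/test split of all JPEGs under data/imgs
make_splits('data/imgs/*.jpg', '8/1/1', 'data/lists', shuffle=True, norm=True)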
Example 5
def save_ret_json(path, ret_list):
    makedirs(path)
    with open(path, 'w') as f:
        json.dump(ret_list, f, cls=NumpyEncoder, indent=2)
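NumpyEncoder is not defined in this excerpt; json.dump needs it because the result dicts typically contain NumPy arrays and scalars. A minimal sketch of the usual pattern, assuming that is its role:

import json
import numpy as np

class NumpyEncoder(json.JSONEncoder):
    # Assumed implementation: convert NumPy types into JSON-serializable ones.
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return super().default(obj)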
Example 6
def do_verify(args, extractor):
    output_dir = '.'
    # parse args
    image_size = args.image_size
    model_name = args.model_name
    test_set = args.test_set
    dist_type = args.dist_type
    do_mirror = args.do_mirror

    # load images
    data_ext = os.path.splitext(args.data)[1]
    if data_ext == '.np':
        pos_img, neg_img = pickle.load(open(args.data, 'rb'))
        # pos_img, neg_img = pickle.load(open(lfw_data, 'rb'), encoding='iso-8859-1')
    elif data_ext == '.txt':
        if args.test_set == 'ytf':
            pos_img, neg_img = load_ytf_pairs(args.data, args.prefix)
        else:
            pos_img, neg_img = load_image_paris(args.data, args.prefix)
    elif data_ext == '.bin':
        pos_img, neg_img = load_mxnet_bin(args.data)
    else:
        if args.test_set.startswith('cfp'):
            from deeploader.dataset.dataset_cfp import CFPDataset
            pos_list_, neg_list_ = CFPDataset(args.data).get_pairs('FP')
            pos_img = load_image_list(pos_list_)
            neg_img = load_image_list(neg_list_)
        else:
            # guard against pos_img/neg_img being unbound for unsupported inputs
            raise ValueError('unsupported data file: %s' % args.data)
    # keep the raw (uncropped) pairs so error cases can be visualized later
    pos_raw = pos_img
    neg_raw = neg_img

    # summary of the test configuration
    print('Dataset  \t: %s (%s,%s)' % (args.test_set, args.data, args.prefix))
    print('Pairs    \t: %d/%d' % (len(pos_img), len(neg_img)))
    print('Testing  \t: %s' % model_name)
    print('Distance \t: %s' % dist_type)
    print('Do mirror\t: {}'.format(do_mirror))
    print('Image size\t: {}'.format(image_size))
    print('Do norm  \t: {}'.format(args.do_norm))
    print('Output   \t: {}'.format(args.error_dir))
    # crop
    pos_img = crop_pair_list(pos_img, image_size)
    neg_img = crop_pair_list(neg_img, image_size)
    # norm
    if args.do_norm:
        print('Norm images')
        pos_img = norm_pair_list(pos_img)
        neg_img = norm_pair_list(neg_img)
    # compute feature
    print('Extracting features ...')
    pos_list = extract_feature(extractor, pos_img)
    print('  Done positive pairs')
    neg_list = extract_feature(extractor, neg_img)
    print('  Done negative pairs')

    # evaluate
    print('Evaluating ...')
    precision, std, threshold, pos, neg, _ = verification(pos_list,
                                                          neg_list,
                                                          dist_type=dist_type)
    # _, title = os.path.split(extractor.weight)
    #draw_chart(title, output_dir, {'pos': pos, 'neg': neg}, precision, threshold)
    print('------------------------------------------------------------')
    print('Precision on %s : %1.5f+-%1.5f \nBest threshold   : %f' %
          (args.test_set, precision, std, threshold))
    # save errors
    if args.error_dir:
        pos_dist, neg_dist = compute_distance(pos_list, neg_list, dist_type)
        h, w, c = pos_raw[0][0].shape

        # positive pairs scored above the threshold are false negatives
        target_dir = os.path.join(args.error_dir, 'pos')
        false_neg = 0
        for i in range(len(pos_dist)):
            dist = pos_dist[i][0]
            if dist < threshold:
                continue
            false_neg += 1
            pair = pos_raw[i]

            # save
            canvas = draw_error_pair(pair, dist)
            #img_path = target_dir + '/%.3f_%d_%d.jpg' % (dist, i, false_neg)
            img_path = target_dir + '/%d.jpg' % (i)
            makedirs(img_path)
            cv2.imwrite(img_path, canvas)
        # negative pairs scored below the threshold are false positives
        target_dir = os.path.join(args.error_dir, 'neg')
        false_pos = 0
        for i in range(len(neg_dist)):
            dist = neg_dist[i][0]
            if dist > threshold:
                continue
            false_pos += 1
            pair = neg_raw[i]
            # save
            canvas = draw_error_pair(pair, dist)
            #img_path = target_dir + '/%.3f_%d_%d.jpg' % (dist, i, false_pos)
            img_path = target_dir + '/%d.jpg' % (i)
            makedirs(img_path)
            cv2.imwrite(img_path, canvas)

    return precision, std
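verification() and compute_distance() are external helpers, so the exact protocol (e.g. LFW-style k-fold evaluation) is not visible here. For orientation, picking the accuracy-maximizing threshold over 1-D distance scores usually looks like this sketch (illustrative only; best_threshold is a hypothetical name):

import numpy as np

def best_threshold(pos_dist, neg_dist, steps=400):
    # Sweep candidate thresholds; distances below the threshold count as "same".
    dists = np.concatenate([pos_dist, neg_dist])
    best_acc, best_th = 0.0, 0.0
    for th in np.linspace(dists.min(), dists.max(), steps):
        tp = np.sum(pos_dist < th)   # true positives
        tn = np.sum(neg_dist >= th)  # true negatives
        acc = (tp + tn) / float(len(pos_dist) + len(neg_dist))
        if acc > best_acc:
            best_acc, best_th = acc, th
    return best_acc, best_th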
Example 7
def visualize_outputs(args, model, data_loader):
    title = data_loader.name
    # collect labels
    gt_bbox_list = []
    gt_tag_list = []
    pred_list = []
    bar = tqdm(total=len(data_loader))
    for idx, item in enumerate(data_loader):
        # read in RGB
        org_img = item['img']
        gt_bboxes = item['bboxes']
        gt_tags = item['tags']
        img_path = item['path']
        gt_bbox_list.append(gt_bboxes)
        gt_tag_list.append(gt_tags)
        # output dir
        img_name = os.path.basename(img_path)
        dst_path = 'outputs/pred_%s/%s' % (title, img_name)
        makedirs(dst_path)

        # get predictions
        if model:
            img = img_preprocess(org_img, args.long_size)
            torch.cuda.synchronize()
            pred_rets, score = run_PSENet(args,
                                          model,
                                          img,
                                          org_img.shape,
                                          out_type=args.out_type,
                                          return_score=True)
            torch.cuda.synchronize()
            pred_bboxes = [ret['bbox'] for ret in pred_rets]
            pred_list.append(pred_bboxes)
            # save pred bboxes
            save_ret_json(dst_path + '.json', pred_rets)
        else:
            # load predictions saved to JSON by a previous run
            pred_rets = load_ret_json(dst_path + '.json')
            pred_bboxes = [ret['bbox'] for ret in pred_rets]
            pred_list.append(pred_bboxes)

        # matching
        rets = score_by_IC15(gt_bboxes,
                             gt_tags,
                             pred_bboxes,
                             th=args.th,
                             tr=args.tr,
                             tp=args.tp,
                             tc=args.tc,
                             wr=args.wr,
                             wp=args.wp)

        gt_ret = rets[4]
        pred_ret = rets[5]
        # back to BGR for OpenCV drawing
        org_img = org_img[:, :, [2, 1, 0]]
        # resize so the long side matches args.long_size
        org_img, scale = img_scale_max(org_img, args.long_size)
        if model:
            # alpha-blend the score map (orange) over the image
            score = score.reshape(score.shape[0], score.shape[1], 1) * 0.6
            blend = score * np.array([0, 165, 255]) + (1.0 - score) * org_img
            org_img = np.clip(blend, 0, 255.0)

        # rescale all boxes to the resized image
        for ret in gt_ret:
            dbox = ret['bbox'].astype('float32') * scale
            ret['bbox'] = dbox.astype('int32')
        for ret in pred_ret:
            dbox = ret['bbox'].astype('float32') * scale
            ret['bbox'] = dbox.astype('int32')
        # one-to-one matches carry an IoU score, one-to-many matches a cover score;
        # gt boxes are recalled / missed / ignored, pred boxes tp / fp / ignored
        gt_colors = [COLOR_BLUE, COLOR_GREEN, COLOR_RED]
        pred_colors = [COLOR_YELLOW, COLOR_CYAN, COLOR_PINK]

        def _draw_match_ret(org_img, match_ret, colors):
            for bbox in match_ret:
                if not bbox['valid']:
                    color = colors[0]  # ignored
                elif len(bbox['matches']) >= 1:
                    color = colors[1]  # matched
                else:
                    color = colors[2]  # unmatched

                text = ''
                if 'iou' in bbox:
                    text = 'IoU:%.2f' % bbox['iou']
                if 'cover' in bbox:
                    text = 'Cover:%.2f' % bbox['cover']
                # draw contour
                org_img = draw_contour(org_img, bbox['bbox'], color, 1)
                # draw text, if any
                if text:
                    bound = get_contour_rect(bbox['bbox'])
                    text_pos = (bound[0], bound[1])
                    if text_pos[1] < 20:
                        text_pos = (bound[0], bound[1] + bound[3])
                    cv2.putText(org_img, text, text_pos, 1, 1, color)
            return org_img

        org_img = _draw_match_ret(org_img, gt_ret, gt_colors)
        org_img = _draw_match_ret(org_img, pred_ret, pred_colors)
        # save image
        cv2.imwrite(dst_path, org_img)
        bar.update(1)
    bar.close()
    return eval_score(args, pred_list, gt_bbox_list, gt_tag_list)
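score_by_IC15 is an external helper; its th/tr/tp parameters suggest DetEval-style matching, where one-to-one matches pass an overlap test and one-to-many matches pass coverage thresholds. An overlap (IoU) test between quadrilateral boxes can be sketched with Shapely as below (an assumption; the actual helper may rasterize polygons instead):

from shapely.geometry import Polygon

def polygon_iou(box_a, box_b):
    # Illustrative IoU between two polygons given as (x, y) point lists.
    pa, pb = Polygon(box_a), Polygon(box_b)
    if not (pa.is_valid and pb.is_valid):
        return 0.0
    inter = pa.intersection(pb).area
    union = pa.area + pb.area - inter
    return inter / union if union > 0 else 0.0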