# Example no. 1
# 0
def demo(args):
    """Run detection on a directory of images and save visualized results.

    For non-DOTA datasets, each image in ``args.ims_dir`` is run through the
    model, detections are drawn onto the image, and the result is written to
    ``outputs/``. For DOTA, detection runs tiled over large images via
    ``evaluate`` and per-tile results are merged with ``ResultMerge``.

    Args:
        args: parsed CLI namespace with ``hyp``, ``dataset``, ``backbone``,
            ``weight``, ``ims_dir`` and ``target_size`` attributes.
    """
    hyps = hyp_parse(args.hyp)
    ds = DATASETS[args.dataset](level=1)
    model = RetinaNet(backbone=args.backbone, hyps=hyps)
    if args.weight.endswith('.pth'):
        chkpt = torch.load(args.weight)
        # Checkpoints may wrap the state dict under a 'model' key (full
        # training checkpoint) or be a bare state dict; accept both.
        if 'model' in chkpt.keys():
            model.load_state_dict(chkpt['model'])
        else:
            model.load_state_dict(chkpt)
        print('load weight from: {}'.format(args.weight))
    model.eval()

    # Results are written into 'outputs/'; make sure it exists
    # (the original creation code was commented out, causing cv2.imwrite
    # to fail silently when the directory is missing).
    os.makedirs('outputs', exist_ok=True)

    t0 = time.time()
    if args.dataset != 'DOTA':
        ims_list = [x for x in os.listdir(args.ims_dir) if is_image(x)]
        for idx, im_name in enumerate(ims_list):
            s = ''
            t = time.time()
            im_path = os.path.join(args.ims_dir, im_name)
            s += 'image %g/%g %s: ' % (idx, len(ims_list), im_path)
            src = cv2.imread(im_path, cv2.IMREAD_COLOR)
            # Model expects RGB; OpenCV loads BGR.
            im = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
            cls_dets = im_detect(model, im, target_sizes=args.target_size)
            for j in range(len(cls_dets)):
                cls, scores = cls_dets[j, 0], cls_dets[j, 1]
                bbox = cls_dets[j, 2:]
                if len(bbox) == 4:
                    # Axis-aligned box: [x1, y1, x2, y2].
                    draw_caption(src, bbox, '{:1.3f}'.format(scores))
                    cv2.rectangle(src, (int(bbox[0]), int(bbox[1])),
                                  (int(bbox[2]), int(bbox[3])),
                                  color=(0, 0, 255),
                                  thickness=2)
                else:
                    # Rotated box: first 5 values are [cx, cy, w, h, angle]
                    # (presumably — converted to a 4-point quad for drawing).
                    pts = np.array([rbox_2_quad(bbox[:5]).reshape((4, 2))],
                                   dtype=np.int32)
                    cv2.drawContours(src,
                                     pts,
                                     0,
                                     color=(0, 255, 0),
                                     thickness=2)
                    put_label = True
                    if put_label:
                        label = ds.return_class(cls) + str(' %.2f' % scores)
                        fontScale = 0.7
                        font = cv2.FONT_HERSHEY_COMPLEX
                        thickness = 1
                        t_size = cv2.getTextSize(label,
                                                 font,
                                                 fontScale=fontScale,
                                                 thickness=thickness)[0]
                        c1 = tuple(bbox[:2].astype('int'))
                        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 5
                        cv2.rectangle(src, c1, c2, [0, 255, 0], -1)  # filled
                        cv2.putText(src,
                                    label, (c1[0], c1[1] - 5),
                                    font,
                                    fontScale, [0, 0, 0],
                                    thickness=thickness,
                                    lineType=cv2.LINE_AA)

            print('%sDone. (%.3fs) %d objs' %
                  (s, time.time() - t, len(cls_dets)))
            # Save the annotated image (BGR buffer, as read).
            out_path = os.path.join('outputs', os.path.split(im_path)[1])
            cv2.imwrite(out_path, src)
    # DOTA: detect on large images tile-by-tile, then merge per-tile results.
    else:
        evaluate(args.target_size,
                 args.ims_dir,
                 'DOTA',
                 args.backbone,
                 args.weight,
                 hyps=hyps,
                 conf=0.05)
        if os.path.exists('outputs/dota_out'):
            shutil.rmtree('outputs/dota_out')
        os.mkdir('outputs/dota_out')
        # BUG FIX: the original called the builtin exec() on a shell command
        # string ('cd outputs && rm -rf ...'), which raises SyntaxError at
        # runtime — exec() runs Python code, not shell commands. Replicate
        # the intended `rm -rf` with shutil.rmtree.
        # NOTE(review): these directories are cleared before ResultMerge
        # references them — presumably ResultMerge (re)creates them; confirm.
        for stale in ('outputs/detections', 'outputs/integrated',
                      'outputs/merged'):
            shutil.rmtree(stale, ignore_errors=True)
        ResultMerge('outputs/detections', 'outputs/integrated',
                    'outputs/merged', 'outputs/dota_out')
        img_path = os.path.join(args.ims_dir, 'images')
        label_path = 'outputs/dota_out'
        save_imgs = False
        if save_imgs:
            show_dota_results(img_path, label_path)
    print('Done. (%.3fs)' % (time.time() - t0))
# Example no. 2
# 0
if __name__ == '__main__':
    # CLI entry point: parse hyper-parameters and run evaluation.
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--backbone',
                        dest='backbone',
                        default='res50',
                        type=str)
    parser.add_argument('--weight', type=str, default='weights/best.pth')
    # BUG FIX: the default is a list ([800]) but without nargs an explicit
    # `--target_size 800` parsed to a bare int, so downstream code received
    # inconsistent types. nargs='+' always yields a list of ints.
    parser.add_argument('--target_size',
                        dest='target_size',
                        nargs='+',
                        default=[800],
                        type=int)
    parser.add_argument('--hyp',
                        type=str,
                        default='hyp.py',
                        help='hyper-parameter path')
    parser.add_argument('--dataset', nargs='?', type=str, default='HRSC2016')
    parser.add_argument('--test_path', type=str, default='./HRSC2016/test.txt')

    arg = parser.parse_args()
    hyps = hyp_parse(arg.hyp)
    evaluate(arg.target_size,
             arg.test_path,
             arg.dataset,
             arg.backbone,
             arg.weight,
             hyps=hyps)