예제 #1
0
def test_rcnn(imageset, year, root_path, devkit_path, prefix, epoch, ctx,
              vis=False, has_rpn=True, proposal='rpn'):
    """Build a test-time detector from a saved checkpoint and run ``pred_eval``.

    ``imageset``, ``year``, ``root_path`` and ``devkit_path`` are forwarded
    unchanged to the roidb loader; ``prefix`` and ``epoch`` identify the
    checkpoint given to ``load_param``; ``ctx`` is passed to both
    ``load_param`` and ``Detector``; ``vis`` is forwarded to ``pred_eval``.

    When ``has_rpn`` is truthy the symbol comes from ``get_vgg_test`` and the
    module-level ``config.TEST`` RPN settings are overwritten before the
    ground-truth roidb is loaded. Otherwise the symbol comes from
    ``get_vgg_rcnn_test`` and the roidb loader is chosen by ``proposal``
    (``load_test_<proposal>_roidb``).

    Returns None.
    """
    if not has_rpn:
        symbol = get_vgg_rcnn_test()
        # NOTE(review): the loader is resolved with eval() on a name built
        # from `proposal`. Only pass trusted values here; an unknown value
        # surfaces as whatever eval() raises for that expression.
        loader_name = 'load_test_' + proposal + '_roidb'
        roidb_loader = eval(loader_name)
        imdb, roidb = roidb_loader(imageset, year, root_path, devkit_path)
    else:
        # Keep this order: the symbol is built before config.TEST is changed,
        # and the roidb is loaded after.
        symbol = get_vgg_test()
        config.TEST.HAS_RPN = True
        config.TEST.RPN_PRE_NMS_TOP_N = 6000
        config.TEST.RPN_POST_NMS_TOP_N = 300
        imdb, roidb = load_gt_roidb(imageset, year, root_path, devkit_path)

    # test-mode iterator: one item per batch, no shuffling
    data_iter = ROIIter(roidb, batch_size=1, shuffle=False, mode='test')

    # checkpoint parameters; the third returned value is not needed here
    arg_params, aux_params, _ = load_param(prefix, epoch, convert=True, ctx=ctx)

    # run the evaluation
    pred_eval(Detector(symbol, ctx, arg_params, aux_params),
              data_iter, imdb, vis=vis)
예제 #2
0
def test_rcnn(imageset, year, root_path, devkit_path, prefix, epoch, ctx, vis=False, has_rpn=True, proposal='rpn',
              end2end=False):
    """Build a test-time detector from a saved checkpoint and run ``pred_eval``.

    ``imageset``, ``year``, ``root_path`` and ``devkit_path`` are forwarded
    unchanged to the roidb loader; ``prefix`` and ``epoch`` identify the
    checkpoint given to ``load_param``; ``ctx`` is passed to both
    ``load_param`` and ``Detector``; ``vis`` is forwarded to ``pred_eval``.

    When ``has_rpn`` is truthy, several module-level ``config`` entries are
    overwritten, the symbol is ``resnext_101(num_class=21)`` and the
    ground-truth roidb is loaded. Otherwise the symbol comes from
    ``get_vgg_rcnn_test`` and the roidb loader is chosen by ``proposal``
    (``load_test_<proposal>_roidb``).

    ``end2end`` is accepted but never read by this function.

    Returns None.
    """
    if not has_rpn:
        symbol = get_vgg_rcnn_test()
        # NOTE(review): the loader is resolved with eval() on a name built
        # from `proposal`. Only pass trusted values here.
        loader_name = 'load_test_' + proposal + '_roidb'
        roidb_loader = eval(loader_name)
        imdb, roidb = roidb_loader(imageset, year, root_path, devkit_path)
    else:
        # The RPN path uses resnext_101 rather than get_vgg_test().
        # Keep this order: these three config entries are set before the
        # symbol is built, the config.TEST entries after it.
        config.TRAIN.AGNOSTIC = True
        config.END2END = 1
        config.PIXEL_MEANS = np.array([[[0,0,0]]])
        symbol = resnext_101(num_class=21)
        config.TEST.HAS_RPN = True
        config.TEST.RPN_PRE_NMS_TOP_N = 6000
        config.TEST.RPN_POST_NMS_TOP_N = 300
        imdb, roidb = load_gt_roidb(imageset, year, root_path, devkit_path)

    # test-mode iterator: one item per batch, no shuffling
    data_iter = ROIIter(roidb, batch_size=1, shuffle=False, mode='test')

    # checkpoint parameters; the third returned value is not needed here
    arg_params, aux_params, _ = load_param(prefix, epoch, convert=True, ctx=ctx)

    # run the evaluation
    pred_eval(Detector(symbol, ctx, arg_params, aux_params),
              data_iter, imdb, vis=vis)
예제 #3
0
파일: test_rcnn.py 프로젝트: Orange15/mxnet
def test_net(imageset, year, root_path, devkit_path, prefix, epoch, ctx, vis):
    """Evaluate a checkpoint with the ``get_vgg_rcnn_test`` symbol.

    Sets the root logger to INFO, loads the roidb via ``load_test_rpn_roidb``
    from ``imageset``/``year``/``root_path``/``devkit_path``, loads the
    parameters identified by ``prefix`` and ``epoch`` onto ``ctx``, then
    passes a ``Detector`` to ``pred_eval`` with ``vis`` forwarded.

    Returns None.
    """
    # root logger at INFO level
    logging.getLogger().setLevel(logging.INFO)

    # roidb and its test-mode iterator (one item per batch, no shuffling)
    imdb, roidb = load_test_rpn_roidb(imageset, year, root_path, devkit_path)
    data_iter = ROIIter(roidb, batch_size=1, shuffle=False, mode="test")

    # checkpoint parameters first, then the network symbol
    arg_params, aux_params = load_param(prefix, epoch, convert=True, ctx=ctx)
    symbol = get_vgg_rcnn_test()

    # run the evaluation
    pred_eval(Detector(symbol, ctx, arg_params, aux_params),
              data_iter, imdb, vis=vis)
예제 #4
0
def test_net(imageset, year, root_path, devkit_path, prefix, epoch, ctx, vis):
    """Evaluate a checkpoint with the ``get_vgg_rcnn_test`` symbol.

    Sets the root logger to INFO, loads the roidb via ``load_test_rpn_roidb``
    from ``imageset``/``year``/``root_path``/``devkit_path``, loads the
    parameters identified by ``prefix`` and ``epoch`` onto ``ctx``, then
    passes a ``Detector`` to ``pred_eval`` with ``vis`` forwarded.

    Returns None.
    """
    # root logger at INFO level
    logging.getLogger().setLevel(logging.INFO)

    # roidb and its test-mode iterator (one item per batch, no shuffling)
    imdb, roidb = load_test_rpn_roidb(imageset, year, root_path, devkit_path)
    data_iter = ROIIter(roidb, batch_size=1, shuffle=False, mode='test')

    # checkpoint parameters first, then the network symbol
    arg_params, aux_params = load_param(prefix, epoch, convert=True, ctx=ctx)
    symbol = get_vgg_rcnn_test()

    # run the evaluation
    pred_eval(Detector(symbol, ctx, arg_params, aux_params),
              data_iter, imdb, vis=vis)
예제 #5
0
파일: test_rcnn.py 프로젝트: 4ker/mxnet
def test_rcnn(imageset, year, root_path, devkit_path, prefix, epoch, ctx, vis=False, has_rpn=True, proposal='rpn'):
    """Build a test-time detector from a saved checkpoint and run ``pred_eval``.

    ``imageset``, ``year``, ``root_path`` and ``devkit_path`` are forwarded
    unchanged to the roidb loader; ``prefix`` and ``epoch`` identify the
    checkpoint given to ``load_param``; ``ctx`` is passed to both
    ``load_param`` and ``Detector``; ``vis`` is forwarded to ``pred_eval``.

    When ``has_rpn`` is truthy the symbol comes from ``get_vgg_test`` and the
    module-level ``config.TEST`` RPN settings are overwritten before the
    ground-truth roidb is loaded. Otherwise the symbol comes from
    ``get_vgg_rcnn_test`` and the roidb loader is chosen by ``proposal``
    (``load_test_<proposal>_roidb``).

    Returns None.
    """
    if not has_rpn:
        symbol = get_vgg_rcnn_test()
        # NOTE(review): the loader is resolved with eval() on a name built
        # from `proposal`. Only pass trusted values here.
        loader_name = 'load_test_' + proposal + '_roidb'
        roidb_loader = eval(loader_name)
        imdb, roidb = roidb_loader(imageset, year, root_path, devkit_path)
    else:
        # Keep this order: the symbol is built before config.TEST is changed,
        # and the roidb is loaded after.
        symbol = get_vgg_test()
        config.TEST.HAS_RPN = True
        config.TEST.RPN_PRE_NMS_TOP_N = 6000
        config.TEST.RPN_POST_NMS_TOP_N = 300
        imdb, roidb = load_gt_roidb(imageset, year, root_path, devkit_path)

    # test-mode iterator: one item per batch, no shuffling
    data_iter = ROIIter(roidb, batch_size=1, shuffle=False, mode='test')

    # checkpoint parameters; the third returned value is not needed here
    arg_params, aux_params, _ = load_param(prefix, epoch, convert=True, ctx=ctx)

    # run the evaluation
    pred_eval(Detector(symbol, ctx, arg_params, aux_params),
              data_iter, imdb, vis=vis)
예제 #6
0
파일: demo.py 프로젝트: Alexbert1/mxnet
def get_net(prefix, epoch, ctx):
    """Return a ``Detector`` built from a saved checkpoint.

    Parameters are loaded with ``load_param(prefix, epoch, convert=True,
    ctx=ctx)`` and paired with the symbol from ``get_vgg_rcnn_test``.
    """
    # Load the parameters first; the symbol is built afterwards.
    arg_params, aux_params = load_param(prefix, epoch, convert=True, ctx=ctx)
    return Detector(get_vgg_rcnn_test(), ctx, arg_params, aux_params)
예제 #7
0
파일: demo.py 프로젝트: izhaolei/mx-rcnn
def get_net(prefix, epoch, ctx):
    """Return a ``Detector`` built from a saved checkpoint.

    Parameters are loaded with ``load_param(prefix, epoch, convert=True,
    ctx=ctx)`` and paired with the symbol from ``get_vgg_rcnn_test``.
    """
    # Load the parameters first; the symbol is built afterwards.
    arg_params, aux_params = load_param(prefix, epoch, convert=True, ctx=ctx)
    return Detector(get_vgg_rcnn_test(), ctx, arg_params, aux_params)