예제 #1
0
파일: test.py 프로젝트: neopenx/Dragon
# --------------------------------------------------------
# Seg-FCN for Dragon
# Copyright (c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------

""" Test a FCN-8s(PASCAL VOC) network """

import dragon.vm.caffe as caffe
import score
import numpy as np

weights = 'snapshot/train_iter_100000.caffemodel'


def load_image_ids(path):
    """Load the image ids listed one per line in ``path``.

    Returns a 1-d numpy string array. ``ndmin=1`` matters when the list has
    a single entry: with the default ``ndmin=0``, ``np.loadtxt`` squeezes
    such a file to a 0-d array, which cannot be iterated over.
    """
    return np.loadtxt(path, dtype=str, ndmin=1)


if __name__ == '__main__':

    # init: run on GPU device 0
    caffe.set_mode_gpu()
    caffe.set_device(0)

    # build the solver from its prototxt, then load the trained snapshot
    # named by `weights` into the solver's net
    solver = caffe.SGDSolver('solver.prototxt')
    solver.net.copy_from(weights)

    # scoring over the ids in the validation list; 'D:/seg' is presumably
    # where seg_tests writes its outputs -- TODO confirm against score.py
    val = load_image_ids('../data/seg11valid.txt')
    score.seg_tests(solver, 'D:/seg', val)

예제 #2
0
# --------------------------------------------------------
# Seg-FCN for Dragon
# Copyright (c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
""" Test a FCN-8s(PASCAL VOC) network """

import dragon.vm.caffe as caffe
import score
import numpy as np

weights = 'snapshot/train_iter_100000.caffemodel'


def load_image_ids(path):
    """Load the image ids listed one per line in ``path``.

    Returns a 1-d numpy string array. ``ndmin=1`` matters when the list has
    a single entry: with the default ``ndmin=0``, ``np.loadtxt`` squeezes
    such a file to a 0-d array, which cannot be iterated over.
    """
    return np.loadtxt(path, dtype=str, ndmin=1)


if __name__ == '__main__':

    # init: run on GPU device 0
    caffe.set_mode_gpu()
    caffe.set_device(0)

    # build the solver from its prototxt, then load the trained snapshot
    # named by `weights` into the solver's net
    solver = caffe.SGDSolver('solver.prototxt')
    solver.net.copy_from(weights)

    # scoring over the ids in the validation list; 'D:/seg' is presumably
    # where seg_tests writes its outputs -- TODO confirm against score.py
    val = load_image_ids('../data/seg11valid.txt')
    score.seg_tests(solver, 'D:/seg', val)
예제 #3
0
cfg.IMS_PER_BATCH = cfg.IMS_PER_BATCH / len(gpus)

if __name__ == '__main__':

    # fix the random seeds (numpy and caffe) for reproducibility
    np.random.seed(cfg.RNG_SEED)
    caffe.set_random_seed(cfg.RNG_SEED)

    # setup caffe
    caffe.set_mode_gpu()

    # setup mpi
    if len(gpus) != mpi.Size():
        raise ValueError('Excepted {} mpi nodes, but got {}.'.format(
            len(gpus), mpi.Size()))
    caffe.set_device(gpus[mpi.Rank()])
    mpi.Parallel([i for i in xrange(len(gpus))])
    mpi.Snapshot([0])
    if mpi.Rank() != 0:
        caffe.set_root_solver(False)

    # setup database
    cfg.DATABASE = imdb_name
    imdb = get_imdb(imdb_name)
    print 'Database({}): {} images will be used to train.'.format(
        cfg.DATABASE, imdb.db_size)
    output_dir = osp.abspath(
        osp.join(cfg.ROOT_DIR, 'output', cfg.EXP_DIR, args.imdb_name))
    print 'Output will be saved to `{:s}`'.format(output_dir)

    # train net
예제 #4
0
# Module-level configuration; these calls run at import time, before the
# `if __name__ == '__main__':` section.
import dragon.config
# NOTE(review): judging by the name, this enables logging of the optimized
# graph -- confirm against dragon.config.
dragon.config.LogOptimizedGraph()
import dragon.memonger as opt
# NOTE(review): presumably enables gradient sharing in dragon's memory
# optimizer (dragon.memonger) -- confirm.
opt.ShareGrads()

# Override the data directory in `cfg` with the relative path 'data/quad'.
cfg.DATA_DIR = 'data/quad'

if __name__ == '__main__':

    # fix the random seeds (numpy and caffe) for reproducibility
    np.random.seed(cfg.RNG_SEED)
    caffe.set_random_seed(cfg.RNG_SEED)

    # setup caffe: GPU mode on the device selected by `gpu_id`
    caffe.set_mode_gpu()
    caffe.set_device(gpu_id)

    # setup database: record the dataset name in the config, build the imdb
    # for it, and report how many images it holds
    cfg.DATABASE = imdb_name
    imdb = get_imdb(imdb_name)
    print 'Database({}): {} images will be used to train.'.format(
        cfg.DATABASE, imdb.db_size)
    # outputs go under <cfg.ROOT_DIR>/output/<imdb_name>, made absolute
    output_dir = osp.abspath(osp.join(cfg.ROOT_DIR, 'output', imdb_name))
    print 'Output will be saved to `{:s}`'.format(output_dir)

    # train net
    train_net(solver_txt,
              output_dir,
              pretrained_model=pretrained_model,
              snapshot_model=snapshot_model,
              start_iter=start_iter,
예제 #5
0
    args = parse_args()

    # Network definition and trained weights for the selected demo net:
    # NETS[args.demo_net][0] names a subdirectory of cfg.MODELS_DIR holding
    # test.prototxt, and NETS[args.demo_net][1] names a caffemodel file under
    # cfg.DATA_DIR/ssd_models.
    prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],
                            'test.prototxt')
    caffemodel = os.path.join(cfg.DATA_DIR, 'ssd_models',
                              NETS[args.demo_net][1])

    # Fail early with a hint when the weights file is missing.
    # NOTE(review): the hint names fetch_faster_rcnn_models.sh although the
    # weights are looked up under ssd_models/ -- confirm the script name.
    if not os.path.isfile(caffemodel):
        raise IOError(('{:s} not found.\nDid you run ./data/script/'
                       'fetch_faster_rcnn_models.sh?').format(caffemodel))

    # CPU or GPU execution, chosen by args.cpu_mode; in GPU mode the device
    # id is also recorded in cfg.GPU_ID.
    if args.cpu_mode:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
        caffe.set_device(args.gpu_id)
        cfg.GPU_ID = args.gpu_id
    net = caffe.Net(prototxt, caffemodel, caffe.TEST)

    print '\n\nLoaded network {:s}'.format(caffemodel)

    # Warmup on a dummy image: a constant-128 uint8 array of shape
    # (1, 300, 500, 3) (presumably N, H, W, C -- confirm against im_detect),
    # run through im_detect twice with the outputs discarded.
    im = 128 * np.ones((1, 300, 500, 3), dtype=np.uint8)
    for i in xrange(2):
        _, _ = im_detect(net, im)

    # File names of the demo images processed below.
    im_names = [
        '000456.jpg', '000542.jpg', '001150.jpg', '001763.jpg', '004545.jpg'
    ]
    for im_name in im_names:
        print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'