Example #1
0
    # Evaluate a trained network on `imdb`.
    # NOTE(review): `imdb`, `weights_filename`, `args`, `cfg`, `tf`,
    # `get_training_roidb`, `test_net` and `test_net_single_frame` are defined
    # outside this fragment — confirm against the surrounding script.
    roidb = get_training_roidb(imdb)

    # Record which GPU to use; the device string is only echoed here.
    cfg.GPU_ID = args.gpu_id
    device_name = '/gpu:{:d}'.format(args.gpu_id)
    # Single-argument print(...) behaves identically under Python 2 and 3.
    print(device_name)

    # Override the training settings for this evaluation run: a single step,
    # the test-time grid size, and non-trainable weights.
    cfg.TRAIN.NUM_STEPS = 1
    cfg.TRAIN.GRID_SIZE = cfg.TEST.GRID_SIZE
    cfg.TRAIN.TRAINABLE = False

    from networks.factory import get_network
    network = get_network(args.network_name)
    print('Use network `{:s}` in training'.format(args.network_name))

    # start a session
    saver = tf.train.Saver()
    if args.kfusion:
        # Cap this process at 20% of GPU memory when args.kfusion is set
        # (presumably to leave room for the fusion component — confirm).
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        gpu_options=gpu_options)
    else:
        session_config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=session_config)

    # Load the trained weights from the checkpoint given on the command line.
    saver.restore(sess, args.model)
    # Format first, then print: the original `print(...).format(...)` applied
    # .format to print's return value (None) and crashed under Python 3.
    print('Loading model weights from {:s}'.format(args.model))
    # Echo the network name under the line above. Passing two arguments to
    # print printed a tuple repr under Python 2; build one string instead.
    print('{} {}'.format("                           ", args.network_name))

    if cfg.TEST.SINGLE_FRAME:
        test_net_single_frame(sess, network, imdb, weights_filename,
                              args.rig_name, args.kfusion)
    else:
        test_net(sess, network, imdb, roidb, weights_filename, args.rig_name,
                 args.kfusion)
Example #2
0
    # Evaluate a trained checkpoint (args.model) on the dataset args.imdb_name.
    # NOTE(review): `args`, `cfg`, `tf`, `get_imdb`, `get_network` and
    # `test_net` are defined outside this fragment — confirm against the script.

    # Show the effective configuration for this run.
    print('Using config:')
    pprint.pprint(cfg)

    # When args.wait is set, poll every 10 s until the checkpoint file exists
    # (presumably so evaluation can start before training finishes).
    while not os.path.exists(args.model) and args.wait:
        print('Waiting for {} to exist...'.format(args.model))
        time.sleep(10)

    # Checkpoint file name without directory or extension; passed to test_net.
    weights_filename = os.path.splitext(os.path.basename(args.model))[0]

    imdb = get_imdb(args.imdb_name)
    imdb.competition_mode(args.comp_mode)

    cfg.GPU_ID = args.gpu_id
    device_name = '/gpu:{:d}'.format(args.gpu_id)
    # Single-argument print(...) behaves identically under Python 2 and 3.
    print(device_name)

    network = get_network(args.network_name, args.pretrained_model)
    print('Use network `{:s}` in training'.format(args.network_name))

    # build the network: a rank-4 float32 input whose last dimension is 3
    # (presumably NHWC images of any batch size/resolution — confirm).
    network.data = tf.placeholder(tf.float32, shape=[None, None, None, 3])
    network.build(network.data, train=False, num_classes=imdb.num_classes)

    # start a session and restore the trained weights from args.model
    saver = tf.train.Saver()
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    saver.restore(sess, args.model)
    # Format first, then print: the original `print (...).format(...)`
    # applied .format to print's return value (None) under Python 3.
    print('Loading model weights from {:s}'.format(args.model))

    test_net(sess, network, imdb, weights_filename)