示例#1
0
def get_iou_callable():
    """
    Build a callable that computes pairwise IoU between two box sets.

    Returns:
        A function mapping (boxesA, boxesB) -> pairwise IoU matrix,
        evaluated in a dedicated CPU-only graph/session.
    """
    # A private graph pinned to CPU keeps this session from touching CUDA.
    graph = tf.Graph()
    with graph.as_default(), tf.device('/cpu:0'):
        boxes_a = tf.placeholder(tf.float32, shape=[None, 4])
        boxes_b = tf.placeholder(tf.float32, shape=[None, 4])
        iou_matrix = pairwise_iou(boxes_a, boxes_b)
        session = tf.Session(config=get_default_sess_config())
        # The returned callable keeps the session alive for its lifetime.
        return session.make_callable(iou_matrix, [boxes_a, boxes_b])
示例#2
0
def get_tf_nms():
    """
    Build a callable that runs non-maximum suppression on boxes + scores.

    Returns:
        A function mapping (boxes, scores) -> indices of the boxes kept
        by NMS, using ``config.RESULTS_PER_IM`` as the output cap and
        ``config.FASTRCNN_NMS_THRESH`` as the IoU threshold.
    """
    # Build in a private CPU graph, mirroring get_iou_callable: otherwise
    # each call pollutes the caller's current default graph (and fails if
    # that graph is finalized), and the session may initialize CUDA in a
    # process that must stay CPU-only.
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        boxes = tf.placeholder(tf.float32, shape=[None, 4])
        scores = tf.placeholder(tf.float32, shape=[None])
        indices = tf.image.non_max_suppression(boxes, scores,
                                               config.RESULTS_PER_IM,
                                               config.FASTRCNN_NMS_THRESH)
        # The returned callable keeps the session alive for its lifetime.
        sess = tf.Session(config=get_default_sess_config())
        return sess.make_callable(indices, [boxes, scores])
示例#3
0
def get_iou_callable():
    """
    Create a callable that evaluates pairwise box IoU.

    The graph and session are pinned to the CPU: the dataflow process
    must not touch CUDA (the data pipeline needs TensorFlow), otherwise
    training cannot run on GPUs in EXCLUSIVE_PROCESS mode unless
    multiprocessing prefetch is disabled.

    Returns:
        A function mapping (boxesA, boxesB) -> pairwise IoU matrix.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        lhs = tf.placeholder(tf.float32, shape=[None, 4])
        rhs = tf.placeholder(tf.float32, shape=[None, 4])
        result = pairwise_iou(lhs, rhs)
        session = tf.Session(config=get_default_sess_config())
        return session.make_callable(result, [lhs, rhs])
示例#4
0
if __name__ == '__main__':
    # Benchmark the imagenet input pipeline, either as a tensorpack
    # dataflow (default) or through the symbolic tf.data path.
    import argparse
    from tensorpack.dataflow import TestDataSpeed
    from tensorpack.tfutils import get_default_sess_config

    parser = argparse.ArgumentParser()
    parser.add_argument('--data', required=True)
    parser.add_argument('--batch', type=int, default=32)
    parser.add_argument('--aug', choices=['train', 'val'], default='val')
    parser.add_argument('--symbolic', action='store_true')
    args = parser.parse_args()

    if args.symbolic:
        # The tf.data path is only wired up for training augmentation.
        assert args.aug == 'train'
        data = get_imagenet_tfdata(args.data, 'train', args.batch)

        iterator = data.make_initializable_iterator()
        datapoint = iterator.get_next()
        fetch_op = tf.group(*datapoint)
        with tf.Session(config=get_default_sess_config()) as sess:
            sess.run(iterator.initializer)
            # Warm-up pass first, then the timed run.
            for _ in tqdm.trange(200):
                sess.run(fetch_op)
            for _ in tqdm.trange(5000, smoothing=0.1):
                sess.run(fetch_op)
    else:
        augmentors = fbresnet_augmentor(args.aug == 'train')
        df = get_imagenet_dataflow(args.data, 'train', args.batch, augmentors)
        # With the val augmentor, expect >100 it/s (~3k im/s) on a decent E5 server.
        TestDataSpeed(df).start()
示例#5
0
                                  shape=[num_train_images, 128],
                                  trainable=False)
    net = ResNetModel(num_output=(2048, 128) if args.v2 else (128, ))
    with TowerContext("", is_training=False):
        feat = net.forward(image_input)
        feat = tf.math.l2_normalize(feat, axis=1)  # Nx128
    all_feat = hvd.allgather(feat)  # GN x 128
    all_idx_input = hvd.allgather(idx_input)  # GN
    update_buffer = tf.scatter_update(feat_buffer, all_idx_input, all_feat)

    dist = tf.matmul(feat, tf.transpose(feat_buffer))  # N x #DS
    _, topk_indices = tf.math.top_k(dist, k=args.top_k)  # Nxtopk

    train_ds = build_dataflow(local_train_files)

    config = get_default_sess_config()
    config.gpu_options.visible_device_list = str(hvd.local_rank())

    def evaluate(checkpoint_file):
        result_file = get_checkpoint_path(
            checkpoint_file) + f".knn{args.top_k}.txt"
        if os.path.isfile(result_file):
            logger.info(f"Skipping evaluation of {result_file}.")
            return
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            SmartInit(checkpoint_file).init(sess)
            for batch_img, batch_idx in tqdm.tqdm(train_ds,
                                                  total=len(train_ds)):
                sess.run(update_buffer,
                         feed_dict={