Пример #1
0
def random_retrieval(args, retrieval, n_values, write=False):
    """Random retrieval baseline for one survey.

    Draws ``args.repeat`` random orderings of the database for the queries
    (via ``retrieval.random_retrieval()``) and averages the resulting mAP and
    recall@N over the draws.

    Args:
        args: user parameters. Reads ``repeat`` (number of random draws) and
            ``top_k`` (ranks kept per query); reads ``trial`` only when
            ``write`` is set, to build the output directory.
        retrieval: retrieval instance providing ``random_retrieval``,
            ``get_retrieval_rank``, ``get_gt_rank`` and ``write_retrieval``.
        n_values: values N at which recall@N is computed.
        write: if True, write the order/rank files of each draw to
            ``res/random/<trial>/``. Every draw writes to the same two files,
            so only the last draw remains on disk.

    Returns:
        ``(mAP_avg, recall_avg)``: the mean mAP over the draws, and an array
        with the mean recall for each value of ``n_values``.

    Raises:
        ValueError: if ``args.repeat`` is smaller than 1 (there would be
            nothing to average).
    """
    if args.repeat < 1:
        raise ValueError("args.repeat must be >= 1, got %r" % (args.repeat,))

    # The ground truth does not depend on the random order: fetch it once
    # instead of once per draw.
    gt_name_d = retrieval.get_gt_rank("name")
    gt_idx_l = retrieval.get_gt_rank("idx")
    if write:
        res_dir = 'res/random/%d/' % args.trial

    mAP_l = []
    recall_l = []
    for _ in range(args.repeat):
        order_l = retrieval.random_retrieval()
        if write:
            retrieval.write_retrieval(order_l, args.top_k,
                                      '%s/order.txt' % res_dir,
                                      '%s/rank.txt' % res_dir)

        rank_l = retrieval.get_retrieval_rank(order_l, args.top_k)
        mAP_l.append(metrics.mAP(rank_l, gt_name_d))
        recall_l.append(metrics.recallN(order_l, gt_idx_l, n_values))

    mAP_avg = np.mean(mAP_l)
    print('mAP: %.3f' % mAP_avg)

    # One row per draw, one column per value of n_values.
    recall_avg = np.mean(np.vstack(recall_l), axis=0)
    for n, recall in zip(n_values, recall_avg):
        print("- Recall@%d: %.4f" % (n, recall))
    return mAP_avg, recall_avg
Пример #2
0
def bench(args, kwargs, sess, extractor_fn, centroids, n_values):
    """Runs and evaluates retrieval on all specified slices and surveys.

    The instance file ``args.instance`` lists, one entry per non-empty line,
    ``<slice_id> <cam_id> <survey_id> [<survey_id> ...]``. For each entry:

    1. Load the database survey and each query survey.
    2. Compute descriptors for all images, or load them from the on-disk
       cache under ``res/delf/<trial>/retrieval/``.
    3. Rank the database images for every query by L2 distance between
       descriptors.
    4. Compute mAP and recall@N.
    5. Write the retrieval files, and the performance files under
       ``res/delf/<trial>/perf/``.

    Surveys whose mAP file already exists are skipped.

    Args:
        args: user parameters. Reads ``trial``, ``instance``, ``meta_dir``,
            ``data``, ``dist_pos`` and ``top_k``; also forwarded to
            ``describe_survey``.
        kwargs: params to init the surveys. A copy is used internally, so the
            caller's dict is not modified.
        sess: tf session, forwarded to ``describe_survey``.
        extractor_fn: descriptor extractor, forwarded to ``describe_survey``.
        centroids: forwarded to ``describe_survey`` (presumably the
            aggregation codebook — TODO confirm).
        n_values: list of values to compute the recall at.
    """
    res_dir = "res/delf/%d/retrieval/" % args.trial
    perf_dir = "res/delf/%d/perf" % args.trial
    # Work on a copy so that setting "meta_fn" does not leak into the
    # caller's dict.
    kwargs = dict(kwargs)

    # Blank lines are ignored (they used to raise an IndexError below).
    with open(args.instance) as f:
        retrieval_instance_l = [line.split() for line in f if line.strip()]

    for instance in retrieval_instance_l:
        slice_id = int(instance[0])
        cam_id = int(instance[1])

        # load db traversal
        surveyFactory = datasets.survey.SurveyFactory()
        kwargs["meta_fn"] = "%s/%d/c%d_db.txt" % (args.meta_dir, slice_id,
                                                  cam_id)
        db_survey = surveyFactory.create(args.data, **kwargs)

        for survey_id in instance[2:]:
            global_start_time = time.time()
            survey_id = int(survey_id)
            print("\nSlice %d\tCam %d\tSurvey %d" %
                  (slice_id, cam_id, survey_id))

            # skip this bench if its results already exist
            mAP_fn = "%s/%d_c%d_%d_mAP.txt" % (perf_dir, slice_id, cam_id,
                                               survey_id)
            recalls_fn = "%s/%d_c%d_%d_rec.txt" % (perf_dir, slice_id, cam_id,
                                                   survey_id)
            if os.path.exists(mAP_fn):
                continue

            # load query traversal
            kwargs["meta_fn"] = "%s/%d/c%d_%d.txt" % (args.meta_dir, slice_id,
                                                      cam_id, survey_id)
            q_survey = surveyFactory.create(args.data, **kwargs)

            # retrieval instance; the query survey is replaced by the one the
            # retrieval instance returns (it may be a filtered version)
            retrieval = datasets.retrieval.Retrieval(db_survey, q_survey,
                                                     args.dist_pos)
            q_survey = retrieval.get_q_survey()

            # describe db img (cache shared by all surveys of this slice/cam)
            local_start_time = time.time()
            db_des_fn = '%s/%d_c%d_db.txt' % (res_dir, slice_id, cam_id)
            if not os.path.exists(db_des_fn):  # db not described yet
                print('** Compute des for database img **')
                db_des_v = describe_survey(args, sess, extractor_fn, centroids,
                                           db_survey)
                np.savetxt(db_des_fn, db_des_v)
            else:  # already computed: load it from disk
                print('** Load des for database img **')
                # ndmin=2: a single-image survey must stay a (1, D) matrix,
                # otherwise the pairwise distance below broadcasts wrongly.
                db_des_v = np.loadtxt(db_des_fn, ndmin=2)
            duration = (time.time() - local_start_time)
            print('(END) run time: %d:%02d' % (duration / 60, duration % 60))

            # describe q img
            local_start_time = time.time()
            q_des_fn = '%s/%d_c%d_%d.txt' % (res_dir, slice_id, cam_id,
                                             survey_id)
            if not os.path.exists(q_des_fn):  # queries not described yet
                print('\n** Compute des for query img **')
                q_des_v = describe_survey(args, sess, extractor_fn, centroids,
                                          q_survey)
                np.savetxt(q_des_fn, q_des_v)
            else:  # already computed: load it from disk
                print('\n** Load des for query img **')
                q_des_v = np.loadtxt(q_des_fn, ndmin=2)
            duration = (time.time() - local_start_time)
            print('(END) run time: %d:%02d' % (duration / 60, duration % 60))

            # retrieve each query: assuming descriptors are (num_img, D)
            # matrices, d[i, j] is the L2 distance between query i and db
            # image j, and each row of `order` sorts the db by distance.
            print('\n** Retrieve query image **')
            local_start_time = time.time()
            d = np.linalg.norm(np.expand_dims(q_des_v, 1) -
                               np.expand_dims(db_des_v, 0),
                               ord=None,
                               axis=2)
            order = np.argsort(d, axis=1)
            duration = (time.time() - local_start_time)
            print('(END) run time %d:%02d' % (duration / 60, duration % 60))

            # compute perf
            print('\n** Compute performance **')
            local_start_time = time.time()
            rank_l = retrieval.get_retrieval_rank(order, args.top_k)

            gt_name_d = retrieval.get_gt_rank("name")
            mAP = metrics.mAP(rank_l, gt_name_d)

            gt_idx_l = retrieval.get_gt_rank("idx")
            recalls = metrics.recallN(order, gt_idx_l, n_values)

            duration = (time.time() - local_start_time)
            print('(END) run time: %d:%02d' % (duration / 60, duration % 60))

            # log
            print("\nmAP: %.3f" % mAP)
            for n, recall in zip(n_values, recalls):
                print("Recall@%d: %.3f" % (n, recall))
            duration = (time.time() - global_start_time)
            print('Global run time retrieval: %d:%02d' %
                  (duration / 60, duration % 60))

            # write retrieval
            order_fn = "%s/%d_c%d_%d_order.txt" % (res_dir, slice_id, cam_id,
                                                   survey_id)
            rank_fn = "%s/%d_c%d_%d_rank.txt" % (res_dir, slice_id, cam_id,
                                                 survey_id)
            retrieval.write_retrieval(order, args.top_k, order_fn, rank_fn)

            # write perf (the mAP file doubles as the "already done" marker)
            np.savetxt(recalls_fn, np.array(recalls))
            np.savetxt(mAP_fn, np.array([mAP]))
Пример #3
0
def bench(args, n_values):
    """Runs retrieval on one survey and evaluates it.

    Steps, for (``args.slice_id``, ``args.cam_id``, ``args.survey_id``):

    1. Load the database and query surveys.
    2. Compute descriptors for every image, or load them from the pickle
       cache under ``res/wasabi/<trial>/retrieval/``.
    3. Retrieve the database images for each query.
    4. Compute mAP and recall@N.
    5. Write the retrieval files, and the performance files under
       ``res/wasabi/<trial>/perf/``.

    Args:
        args: retrieval parameters. Reads ``trial``, ``slice_id``, ``cam_id``,
            ``survey_id``, ``meta_dir``, ``img_dir``, ``seg_dir``, ``data``,
            ``dist_pos`` and ``top_k``; also forwarded to
            ``get_img_des_parallel`` and ``retrieve_parallel``.
        n_values: values at which recalls are computed.

    Returns:
        ``(mAP, recalls)``, or ``(-1, -1)`` when the mAP file for this survey
        already exists (nothing is recomputed in that case).
    """
    global_start_time = time.time()

    # check if this bench already exists
    perf_dir = 'res/wasabi/%d/perf' % args.trial
    mAP_fn = "%s/%d_c%d_%d_mAP.txt" % (perf_dir, args.slice_id, args.cam_id,
                                       args.survey_id)
    recalls_fn = "%s/%d_c%d_%d_rec.txt" % (perf_dir, args.slice_id,
                                           args.cam_id, args.survey_id)
    if os.path.exists(mAP_fn):
        return -1, -1

    res_dir = "res/wasabi/%d/retrieval/" % args.trial

    # load db traversal
    surveyFactory = datasets.survey.SurveyFactory()
    meta_fn = "%s/%d/c%d_db.txt" % (args.meta_dir, args.slice_id, args.cam_id)
    kwargs = {"meta_fn": meta_fn, "img_dir": args.img_dir,
              "seg_dir": args.seg_dir}
    db_survey = surveyFactory.create(args.data, **kwargs)

    # load query traversal
    kwargs["meta_fn"] = "%s/%d/c%d_%d.txt" % (args.meta_dir, args.slice_id,
                                              args.cam_id, args.survey_id)
    q_survey = surveyFactory.create(args.data, **kwargs)

    # retrieval instance; the query survey is replaced by the one the
    # retrieval instance returns (it may be a filtered version)
    retrieval = datasets.retrieval.Retrieval(db_survey, q_survey, args.dist_pos)
    q_survey = retrieval.get_q_survey()

    # describe db img.
    # NOTE: the .pickle caches are written by this function. pickle.load can
    # execute arbitrary code, so never point res_dir at untrusted files.
    local_start_time = time.time()
    db_des_fn = '%s/%d_c%d_db.pickle' % (res_dir, args.slice_id, args.cam_id)
    if not os.path.exists(db_des_fn):  # db not described yet
        print('** Compute des for database img **')
        db_img_des_l = get_img_des_parallel(args, db_survey)
        with open(db_des_fn, 'wb') as f:
            pickle.dump(db_img_des_l, f)
    else:  # already computed: load it from disk
        print('** Load des for database img **')
        with open(db_des_fn, 'rb') as f:
            db_img_des_l = pickle.load(f)
    duration = (time.time() - local_start_time)
    print('(END) run time: %d:%02d' % (duration / 60, duration % 60))

    # describe q img
    local_start_time = time.time()
    q_des_fn = '%s/%d_c%d_%d.pickle' % (res_dir, args.slice_id, args.cam_id,
                                        args.survey_id)
    if not os.path.exists(q_des_fn):  # queries not described yet
        print('\n** Compute des for query img **')
        q_img_des_l = get_img_des_parallel(args, q_survey)
        with open(q_des_fn, 'wb') as f:
            pickle.dump(q_img_des_l, f)
    else:  # already computed: load it from disk
        print('\n** Load des for query img **')
        with open(q_des_fn, 'rb') as f:
            q_img_des_l = pickle.load(f)
    duration = (time.time() - local_start_time)
    print('(END) run time: %d:%02d' % (duration / 60, duration % 60))

    # retrieve each query
    print('\n** Retrieve query image **')
    local_start_time = time.time()
    order_l = retrieve_parallel(args, q_img_des_l, db_img_des_l)
    duration = (time.time() - local_start_time)
    print('(END) run time %d:%02d' % (duration / 60, duration % 60))

    # compute perf
    print('\n** Compute performance **')
    local_start_time = time.time()
    rank_l = retrieval.get_retrieval_rank(order_l, args.top_k)

    gt_name_d = retrieval.get_gt_rank("name")
    mAP = metrics.mAP(rank_l, gt_name_d)

    gt_idx_l = retrieval.get_gt_rank("idx")
    recalls = metrics.recallN(order_l, gt_idx_l, n_values)

    duration = (time.time() - local_start_time)
    print('(END) run time: %d:%02d' % (duration / 60, duration % 60))

    # log
    print("\nmAP: %.3f" % mAP)
    for n, recall in zip(n_values, recalls):
        print("Recall@%d: %.3f" % (n, recall))
    duration = (time.time() - global_start_time)
    print('Global run time retrieval: %d:%02d' % (duration / 60, duration % 60))

    # write retrieval
    order_fn = "%s/%d_c%d_%d_order.txt" % (res_dir, args.slice_id, args.cam_id,
                                           args.survey_id)
    rank_fn = "%s/%d_c%d_%d_rank.txt" % (res_dir, args.slice_id, args.cam_id,
                                         args.survey_id)
    retrieval.write_retrieval(order_l, args.top_k, order_fn, rank_fn)

    # write perf (the mAP file doubles as the "already done" marker)
    np.savetxt(recalls_fn, np.array(recalls))
    np.savetxt(mAP_fn, np.array([mAP]))
    return mAP, recalls
Пример #4
0
def bench(args, kwargs, centroids, n_values):
    """Runs retrieval with aggregated local descriptors on one survey.

    Steps, for (``args.slice_id``, ``args.cam_id``, ``args.survey_id``):

    1. Load the database and query surveys.
    2. Build the local feature extractor selected by ``args.lf_mode``.
    3. Compute one descriptor per image with ``describe_survey``, or load it
       from the pickle cache under ``res/<agg_mode>/<trial>/retrieval/``.
    4. Rank the database images for each query by L2 distance.
    5. Compute mAP and recall@N.
    6. Write the retrieval files, and the performance files under
       ``res/<agg_mode>/<trial>/perf/``.

    Args:
        args: user parameters. Reads ``agg_mode``, ``trial``, ``slice_id``,
            ``cam_id``, ``survey_id``, ``meta_dir``, ``data``, ``dist_pos``,
            ``lf_mode``, ``max_num_feat`` and ``top_k``; also forwarded to
            ``describe_survey``.
        kwargs: params to init the surveys. A copy is used internally, so the
            caller's dict is not modified.
        centroids: forwarded to ``describe_survey`` (presumably the
            aggregation codebook — TODO confirm).
        n_values: values at which recalls are computed.

    Returns:
        ``(mAP, recalls)``, or ``(-1, -1)`` when the mAP file for this survey
        already exists (nothing is recomputed in that case).
    """
    global_start_time = time.time()

    # check if this bench already exists
    perf_dir = "res/%s/%d/perf" % (args.agg_mode, args.trial)
    mAP_fn = "%s/%d_c%d_%d_mAP.txt" % (perf_dir, args.slice_id, args.cam_id,
                                       args.survey_id)
    recalls_fn = "%s/%d_c%d_%d_rec.txt" % (perf_dir, args.slice_id,
                                           args.cam_id, args.survey_id)
    if os.path.exists(mAP_fn):
        return -1, -1

    res_dir = "res/%s/%d/retrieval/" % (args.agg_mode, args.trial)

    # load db traversal. Work on a copy of the survey kwargs so that setting
    # "meta_fn" does not leak into the caller's dict.
    survey_kwargs = dict(kwargs)
    surveyFactory = datasets.survey.SurveyFactory()
    survey_kwargs["meta_fn"] = "%s/%d/c%d_db.txt" % (args.meta_dir,
                                                     args.slice_id, args.cam_id)
    db_survey = surveyFactory.create(args.data, **survey_kwargs)

    # load query traversal
    survey_kwargs["meta_fn"] = "%s/%d/c%d_%d.txt" % (
        args.meta_dir, args.slice_id, args.cam_id, args.survey_id)
    q_survey = surveyFactory.create(args.data, **survey_kwargs)

    # retrieval instance. Filters out queries without matches.
    retrieval = datasets.retrieval.Retrieval(db_survey, q_survey,
                                             args.dist_pos)
    q_survey = retrieval.get_q_survey()

    # choose a local feature extractor
    feFactory = FeatureExtractorFactory()
    fe_kwargs = {"max_num_feat": args.max_num_feat}
    fe = feFactory.create(args.lf_mode, fe_kwargs)

    # describe db img.
    # NOTE: the .pickle caches are written by this function. pickle.load can
    # execute arbitrary code, so never point res_dir at untrusted files.
    local_start_time = time.time()
    db_des_fn = '%s/%d_c%d_db.pickle' % (res_dir, args.slice_id, args.cam_id)
    if not os.path.exists(db_des_fn):  # db not described yet
        print('** Compute des for database img **')
        db_img_des_v = describe_survey(args, fe, centroids, db_survey)
        with open(db_des_fn, 'wb') as f:
            pickle.dump(db_img_des_v, f)
    else:  # already computed: load it from disk
        print('** Load des for database img **')
        with open(db_des_fn, 'rb') as f:
            db_img_des_v = pickle.load(f)
    duration = (time.time() - local_start_time)
    print('(END) run time: %d:%02d' % (duration / 60, duration % 60))

    # describe q img
    local_start_time = time.time()
    q_des_fn = '%s/%d_c%d_%d.pickle' % (res_dir, args.slice_id, args.cam_id,
                                        args.survey_id)
    if not os.path.exists(q_des_fn):  # queries not described yet
        print('\n** Compute des for query img **')
        q_img_des_v = describe_survey(args, fe, centroids, q_survey)
        with open(q_des_fn, 'wb') as f:
            pickle.dump(q_img_des_v, f)
    else:  # already computed: load it from disk
        print('\n** Load des for query img **')
        with open(q_des_fn, 'rb') as f:
            q_img_des_v = pickle.load(f)
    duration = (time.time() - local_start_time)
    print('(END) run time: %d:%02d' % (duration / 60, duration % 60))

    # retrieve each query: assuming descriptors are (num_img, D) matrices,
    # d[i, j] is the L2 distance between query i and db image j, and each
    # row of `order` sorts the db by increasing distance.
    print('\n** Retrieve query image **')
    local_start_time = time.time()
    d = np.linalg.norm(np.expand_dims(q_img_des_v, 1) -
                       np.expand_dims(db_img_des_v, 0),
                       ord=None,
                       axis=2)
    order = np.argsort(d, axis=1)
    duration = (time.time() - local_start_time)
    print('(END) run time %d:%02d' % (duration / 60, duration % 60))

    # compute perf
    print('\n** Compute performance **')
    local_start_time = time.time()
    rank_l = retrieval.get_retrieval_rank(order, args.top_k)

    gt_name_d = retrieval.get_gt_rank("name")
    mAP = metrics.mAP(rank_l, gt_name_d)

    gt_idx_l = retrieval.get_gt_rank("idx")
    recalls = metrics.recallN(order, gt_idx_l, n_values)

    duration = (time.time() - local_start_time)
    print('(END) run time: %d:%02d' % (duration / 60, duration % 60))

    # log
    print("\nmAP: %.3f" % mAP)
    for n, recall in zip(n_values, recalls):
        print("Recall@%d: %.3f" % (n, recall))
    duration = (time.time() - global_start_time)
    print('Global run time retrieval: %d:%02d' %
          (duration / 60, duration % 60))

    # write retrieval
    order_fn = "%s/%d_c%d_%d_order.txt" % (res_dir, args.slice_id, args.cam_id,
                                           args.survey_id)
    rank_fn = "%s/%d_c%d_%d_rank.txt" % (res_dir, args.slice_id, args.cam_id,
                                         args.survey_id)
    retrieval.write_retrieval(order, args.top_k, order_fn, rank_fn)

    # write perf (into the agg_mode-specific perf_dir computed above; the
    # mAP file doubles as the "already done" marker)
    np.savetxt(recalls_fn, np.array(recalls))
    np.savetxt(mAP_fn, np.array([mAP]))
    return mAP, recalls