예제 #1
0
def main():
    """Index query features against gallery features and hand off the results.

    Steps, in order:
      1. parse the command line and build the pipeline config;
      2. load the query and gallery features named by ``cfg.index``;
      3. run the index helper on the query / gallery features;
      4. pass the index results to ``save_result`` (with ``args.save_file``)
         and to ``save_result_n`` (with a fixed CSV path).

    Elapsed wall-clock time is printed once after feature loading and once
    at the very end. No evaluation (mAP / recall@k) is run by this entry
    point.
    """
    t0 = time.time()

    # command line -> pipeline config
    cli = parse_args()
    assert cli.config_file is not None, 'a config file must be provided!'
    cfg = setup_cfg(get_defaults_cfg(), cli.config_file, cli.opts)
    index_cfg = cfg.index

    # features for both sides of the retrieval
    q_feats, q_info, _ = feature_loader.load(index_cfg.query_fea_dir,
                                             index_cfg.feature_names)
    g_feats, g_info, _ = feature_loader.load(index_cfg.gallery_fea_dir,
                                             index_cfg.feature_names)

    load_elapsed = time.time() - t0
    print("using init_load time: {:6f}s".format(load_elapsed))

    # indexing; the helper also returns the (possibly transformed) features
    helper = build_index_helper(index_cfg)
    ranking, q_feats, g_feats = helper.do_index(q_feats, q_info, g_feats)

    # hand the results to both writers
    save_result(cli.save_file, ranking, q_info, g_info)
    save_result_n(r"./result_tmp/submission_30.csv", ranking, q_info, g_info)

    total_elapsed = time.time() - t0
    print("using total time: {:6f}s".format(total_elapsed))
def main():
    """Index query features against gallery features and show the metrics.

    Parses the command line, builds the pipeline config, loads the query
    and gallery features named by ``cfg.index``, runs the index helper, then
    feeds the index results and gallery info to the evaluate helper and
    displays the two values it returns (bound here as ``mAP`` and
    ``recall_at_k``).
    """
    # command line -> pipeline config
    cli = parse_args()
    assert cli.config_file is not None, 'a config file must be provided!'
    cfg = setup_cfg(get_defaults_cfg(), cli.config_file, cli.opts)
    index_cfg = cfg.index

    # features for both sides of the retrieval
    q_feats, q_info, _ = feature_loader.load(index_cfg.query_fea_dir,
                                             index_cfg.feature_names)
    g_feats, g_info, _ = feature_loader.load(index_cfg.gallery_fea_dir,
                                             index_cfg.feature_names)

    # indexing; the helper also returns the (possibly transformed) features
    helper = build_index_helper(index_cfg)
    ranking, q_feats, g_feats = helper.do_index(q_feats, q_info, g_feats)

    # evaluation followed by reporting
    evaluator = build_evaluate_helper(cfg.evaluate)
    metrics = evaluator.do_eval(ranking, g_info)
    mAP, recall_at_k = metrics
    evaluator.show_results(mAP, recall_at_k)
def main():
    """Run retrieval for one hard-coded query image against the gallery.

    Steps, in order:
      1. parse the command line and validate that a config file was given
         and exists on disk;
      2. build the pipeline config, the image transformers and the model;
      3. open the query image, convert it to RGB and transform it;
      4. extract its features and concatenate the ones listed in
         ``cfg.index.feature_names`` along axis 1;
      5. load the gallery features and index the query against them;
      6. ask the index helper to save the retrieved images for the first
         result entry under ``retrieved_images/``.

    Raises:
        AssertionError: if no config file was given, the config file does
            not exist, or a requested feature name is missing from the
            extracted features.
    """
    # init args
    args = parse_args()
    # Truthiness check rather than `is not ""`: identity comparison against a
    # string literal relies on interpreter interning (and emits a
    # SyntaxWarning on Python 3.8+). This also rejects None, which would
    # otherwise reach os.path.exists() and fail with a TypeError.
    assert args.config_file, 'a config file must be provided!'
    assert os.path.exists(args.config_file), 'the config file must be existed!'

    # init and load retrieval pipeline settings
    cfg = get_defaults_cfg()
    cfg = setup_cfg(cfg, args.config_file, args.opts)

    # set path for single image
    path = '/data/caltech101/query/airplanes/image_0004.jpg'

    # build transformers
    transformers = build_transformers(cfg.datasets.transformers)

    # build model
    model = build_model(cfg.model)

    # read image and convert it to tensor
    img = Image.open(path).convert("RGB")
    img_tensor = transformers(img)

    # build helper and extract feature for single image
    extract_helper = build_extract_helper(model, cfg.extract)
    img_fea_info = extract_helper.do_single_extract(img_tensor)

    # collect every requested feature of the first (only) entry, moved to CPU
    stacked_feature = []
    for name in cfg.index.feature_names:
        assert name in img_fea_info[
            0], "invalid feature name: {} not in {}!".format(
                name, img_fea_info[0].keys())
        stacked_feature.append(img_fea_info[0][name].cpu())
    img_fea = np.concatenate(stacked_feature, axis=1)

    # load gallery features
    gallery_fea, gallery_info, _ = feature_loader.load(
        cfg.index.gallery_fea_dir, cfg.index.feature_names)

    # build helper and single index feature; the returned query / gallery
    # features are not needed afterwards, so they are discarded
    index_helper = build_index_helper(cfg.index)
    index_result_info, _, _ = index_helper.do_index(
        img_fea, img_fea_info, gallery_fea)

    # NOTE(review): the literal 5 is presumably the number of top results to
    # save; confirm against save_topk_retrieved_images' signature.
    index_helper.save_topk_retrieved_images('retrieved_images/',
                                            index_result_info[0], 5,
                                            gallery_info)

    print('single index have done!')
예제 #4
0
def main():
    """Extract features for a dataset described by a JSON file.

    Parses the command line, checks that the dataset JSON, the save path
    and the config file were all supplied, builds the pipeline config, the
    dataset / dataloader and the model, then runs the extract helper over
    the dataloader with ``args.save_path`` and ``args.save_interval``.
    """
    cli = parse_args()

    # required arguments, checked in this order so the first missing one
    # determines the error message
    assert cli.data_json is not None, 'the dataset json must be provided!'
    assert cli.save_path is not None, 'the save path must be provided!'
    assert cli.config_file is not None, 'a config file must be provided!'

    # command line -> pipeline config
    cfg = setup_cfg(get_defaults_cfg(), cli.config_file, cli.opts)

    # data pipeline
    folder = build_folder(cli.data_json, cfg.datasets)
    loader = build_loader(folder, cfg.datasets)

    # model and feature extraction
    net = build_model(cfg.model)
    extractor = build_extract_helper(net, cfg.extract)
    extractor.do_extract(loader, cli.save_path, cli.save_interval)