def main():

    # init args
    args = parse_args()
    assert args.save_path is not None, 'the save path must be provided!'
    assert args.search_modules is not None, 'the search modules must be provided!'

    # init retrieval pipeline settings
    cfg = get_defaults_cfg()

    # load search space
    datasets = load_datasets()
    pre_processes = importlib.import_module("{}.pre_process_dict".format(
        args.search_modules)).pre_processes
    models = importlib.import_module("{}.extract_dict".format(
        args.search_modules)).models
    extracts = importlib.import_module("{}.extract_dict".format(
        args.search_modules)).extracts

    # search in an exhaustive way
    for data_name, data_args in datasets.items():
        for pre_proc_name, pre_proc_args in pre_processes.items():
            for model_name, model_args in models.items():

                feature_full_name = data_name + "_" + pre_proc_name + "_" + model_name
                print(feature_full_name)

                if os.path.exists(
                        os.path.join(args.save_path, feature_full_name)):
                    print("[Search Extract]: config exists...")
                    continue

                # load retrieval pipeline settings
                cfg.datasets.merge_from_other_cfg(pre_proc_args)
                cfg.model.merge_from_other_cfg(model_args)
                cfg.extract.merge_from_other_cfg(extracts[model_name])

                # build dataset and dataloader
                dataset = build_folder(data_args, cfg.datasets)
                dataloader = build_loader(dataset, cfg.datasets)

                # build model
                model = build_model(cfg.model)

                # build helper and extract features
                extract_helper = build_extract_helper(model, cfg.extract)
                extract_helper.do_extract(dataloader,
                                          save_path=os.path.join(
                                              args.save_path,
                                              feature_full_name))
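# A minimal sketch of the parse_args() helper assumed above. The flag names mirror
# the attributes used in main() (save_path, search_modules); the help strings and
# defaults are illustrative assumptions, not part of the original example.
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="search over feature extraction configs")
    parser.add_argument("--save_path", type=str, default=None,
                        help="directory in which extracted features are saved")
    parser.add_argument("--search_modules", type=str, default=None,
                        help="package defining the search-space dicts (pre_process_dict, extract_dict)")
    return parser.parse_args()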
def main():

    # init args
    args = parse_args()

    # init retrieval pipeline settings
    cfg = get_defaults_cfg()

    # load search space
    datasets = load_datasets()
    pre_processes = importlib.import_module("{}.pre_process_dict".format(args.search_modules)).pre_processes
    models = importlib.import_module("{}.extract_dict".format(args.search_modules)).models
    extracts = importlib.import_module("{}.extract_dict".format(args.search_modules)).extracts

    for data_name, data_args in datasets.items():
        for pre_proc_name, pre_proc_args in pre_processes.items():
            for model_name, model_args in models.items():

                feature_full_name = data_name + "_" + pre_proc_name + "_" + model_name
                print(feature_full_name)

                if os.path.exists(os.path.join(args.save_path, feature_full_name)):
                    print("[Search Extract]: config exists...")
                    continue

                # load retrieval pipeline settings
                cfg.datasets.merge_from_other_cfg(pre_proc_args)
                cfg.model.merge_from_other_cfg(model_args)
                cfg.extract.merge_from_other_cfg(extracts[model_name])

                # set train feature path for pwa
                pwa_train_fea_dir = os.path.join("/data/features/test_gap_gmp_gem_crow_spoc", feature_full_name)
                if "query" in pwa_train_fea_dir:
                    pwa_train_fea_dir.replace("query", "gallery")
                elif "paris" in pwa_train_fea_dir:
                    pwa_train_fea_dir.replace("paris", "oxford_gallery")
                print("[PWA Extractor]: train feature: {}".format(pwa_train_fea_dir))
                cfg.extract.aggregators.PWA.train_fea_dir = pwa_train_fea_dir

                # build dataset and dataloader
                dataset = build_folder(data_args, cfg.datasets)
                dataloader = build_loader(dataset, cfg.datasets)

                # build model
                model = build_model(cfg.model)

                # build helper and extract features
                extract_helper = build_extract_helper(model, cfg.extract)
                extract_helper.do_extract(dataloader, save_path=os.path.join(args.save_path, feature_full_name))
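# A minimal sketch of what a search-space module (the target of args.search_modules)
# might contain. The dict names match the attributes imported above (pre_processes,
# models, extracts); yacs CfgNode entries are assumed because each one is passed to
# merge_from_other_cfg. The concrete keys and values here are hypothetical.
from yacs.config import CfgNode

# pre_process_dict.py
pre_processes = {
    "Shorter256Center224": CfgNode({"batch_size": 32}),
}

# extract_dict.py
models = {
    "VGG16": CfgNode({"name": "vgg_16"}),
}
extracts = {
    "VGG16": CfgNode({"assemble": 0}),
}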
def main():

    # init args
    args = parse_args()
    assert args.config_file != "", 'a config file must be provided!'
    assert os.path.exists(args.config_file), 'the config file must exist!'

    # init and load retrieval pipeline settings
    cfg = get_defaults_cfg()
    cfg = setup_cfg(cfg, args.config_file, args.opts)

    # set path for single image
    path = '/data/caltech101/query/airplanes/image_0004.jpg'

    # build transformers
    transformers = build_transformers(cfg.datasets.transformers)

    # build model
    model = build_model(cfg.model)

    # read image and convert it to tensor
    img = Image.open(path).convert("RGB")
    img_tensor = transformers(img)

    # build helper and extract feature for single image
    extract_helper = build_extract_helper(model, cfg.extract)
    img_fea_info = extract_helper.do_single_extract(img_tensor)
    stacked_feature = list()
    for name in cfg.index.feature_names:
        assert name in img_fea_info[
            0], "invalid feature name: {} not in {}!".format(
                name, img_fea_info[0].keys())
        stacked_feature.append(img_fea_info[0][name].cpu())
    img_fea = np.concatenate(stacked_feature, axis=1)

    # load gallery features
    gallery_fea, gallery_info, _ = feature_loader.load(
        cfg.index.gallery_fea_dir, cfg.index.feature_names)

    # build helper and single index feature
    index_helper = build_index_helper(cfg.index)
    index_result_info, query_fea, gallery_fea = index_helper.do_index(
        img_fea, img_fea_info, gallery_fea)

    index_helper.save_topk_retrieved_images('retrieved_images/',
                                            index_result_info[0], 5,
                                            gallery_info)

    print('single index done!')
def main():

    # init args
    args = parse_args()
    assert args.data_json is not None, 'the dataset json must be provided!'
    assert args.save_path is not None, 'the save path must be provided!'
    assert args.config_file is not None, 'a config file must be provided!'

    # init and load retrieval pipeline settings
    cfg = get_defaults_cfg()
    cfg = setup_cfg(cfg, args.config_file, args.opts)

    # build dataset and dataloader
    dataset = build_folder(args.data_json, cfg.datasets)
    dataloader = build_loader(dataset, cfg.datasets)

    # build model
    model = build_model(cfg.model)

    # build helper and extract features
    extract_helper = build_extract_helper(model, cfg.extract)
    extract_helper.do_extract(dataloader, args.save_path, args.save_interval)
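# Typical script entry point, plus an illustrative invocation. The flag spellings are
# assumed from the args attributes used in main() (data_json, save_path, config_file,
# save_interval); the script name and paths below are placeholders.
#
#   python extract_feature.py --data_json data_jsons/caltech_gallery.json \
#       --save_path features/caltech/gallery \
#       --config_file configs/caltech.yaml --save_interval 5000
#
if __name__ == "__main__":
    main()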