def main():
    """Evaluate a pretrained ReidNet on a shuffled subset of VeRi queries.

    Reads run settings from ``Settings`` and builds ``ReidNet`` from the loaded
    config. The model is wrapped in ``DataParallel`` when several GPU ids are
    given. Weights come from ``cfg.TEST.PRETRAINED_MODEL``.

    Up to ``max_query`` shuffled VeRi query images are matched against the VeRi
    gallery via ``inference_on_dataset``. ``summary`` then prints Rank-1/5/10,
    mAP, TPR@FPR={1e-4,1e-3,1e-2} and the elapsed inference time.
    """
    # Upper bound on the number of queries evaluated. VeRi has 1678 query
    # samples in total; lower this for a quicker run. (For Market1501 use the
    # Market1501 dataset class instead of VeRi; it has 3368 queries.)
    max_query = 1678

    Settings.init()
    init_gpus(Settings.gpu_ids)
    cfg = load_config(Settings.conf)

    model = ReidNet(cfg)
    if len(Settings.gpu_ids.split(',')) > 1:
        print('>>> Using multi-GPUs, enable DataParallel')
        model = torch.nn.DataParallel(model)
    print('>>> Using model weights: {}'.format(cfg.TEST.PRETRAINED_MODEL))
    # NOTE(review): when wrapped in DataParallel above, the checkpoint keys must
    # match the wrapped module's naming ("module." prefix) -- confirm how the
    # checkpoints are saved.
    state = torch.load(cfg.TEST.PRETRAINED_MODEL)
    model.load_state_dict(state['model_state_dict'])

    query_source = VeRi(root=cfg.DATASET.PATH, mode='query')
    random.shuffle(query_source.query)
    query_source.query = query_source.query[:max_query]
    # Use the number of queries actually kept (may be < max_query) so the
    # evaluator and the printed summary agree with the data that was fed in.
    qnum = len(query_source.query)
    query_dataset = EvalDataset((256, 256), query_source, mode='query')

    gallery_source = VeRi(root=cfg.DATASET.PATH, mode='gallery')
    gallery_dataset = EvalDataset((256, 256), gallery_source, mode='gallery')

    dataloader = torch.utils.data.DataLoader(query_dataset,
                                             batch_size=32,
                                             shuffle=False,
                                             num_workers=4)
    gallery_loader = torch.utils.data.DataLoader(gallery_dataset,
                                                 batch_size=32,
                                                 shuffle=False,
                                                 num_workers=4)

    evaluator = ReidEvaluator(cfg, num_query=qnum)

    # perf_counter is monotonic and intended for measuring durations.
    start_time = time.perf_counter()
    res = inference_on_dataset(model, dataloader, evaluator, gallery_loader)
    elapsed = time.perf_counter() - start_time

    # Unpack the metrics reported by the evaluator.
    rank_n = [res['Rank-{}'.format(n)] for n in [1, 5, 10]]
    mAP = res['mAP']
    roc = [res['TPR@FPR={:.0e}'.format(fpr)] for fpr in [1e-4, 1e-3, 1e-2]]

    summary(rank_n,
            mAP,
            roc,
            dataset='VeRi',
            num_query=qnum,
            time_cons=elapsed)
Пример #2
0
        for cfg in cfgs:
            self.predictors.append(DefaultPredictor(cfg))

    def run_on_loader(self, data_loader):
        """Yield ``(features, batch)`` for every batch of ``data_loader``.

        Every predictor in ``self.predictors`` is applied to ``batch["images"]``.
        Their outputs are joined, in predictor order, with ``torch.cat`` along
        the last dimension. The batch is yielded unchanged next to the fused
        features so callers can pair them with the batch contents.
        """
        for current in data_loader:
            outputs = [model_fn(current["images"]) for model_fn in self.predictors]
            fused = torch.cat(outputs, dim=-1)
            yield fused, current


if __name__ == "__main__":
    args = get_parser().parse_args()
    logger = setup_logger()
    cfgs = [setup_cfg(config_file, args.opts) for config_file in args.config_file]

    # FeatureExtraction depends only on cfgs, not on the test dataset, so build
    # its predictors once instead of once per dataset.
    feat_extract = FeatureExtraction(cfgs)

    results = OrderedDict()
    # The test datasets, loaders and evaluators are driven by the first config;
    # the features are the last-dim concatenation of every config's predictor.
    for dataset_name in cfgs[0].DATASETS.TESTS:
        test_loader, num_query = build_reid_test_loader(cfgs[0], dataset_name)
        evaluator = ReidEvaluator(cfgs[0], num_query)
        for feat, batch in tqdm.tqdm(feat_extract.run_on_loader(test_loader),
                                     total=len(test_loader)):
            evaluator.process(batch, feat)
        results[dataset_name] = evaluator.evaluate()
    print_csv_format(results)
Пример #3
0
 def build_evaluator(cls, cfg, num_query, output_dir=None):
     """Return a ``ReidEvaluator`` built from the given arguments.

     ``cfg``, ``num_query`` and ``output_dir`` are forwarded positionally, in
     that order.
     """
     evaluator = ReidEvaluator(cfg, num_query, output_dir)
     return evaluator
Пример #4
0
 def build_evaluator(cls, cfg, dataset_name, output_dir=None):
     """Build the test loader for ``dataset_name`` together with its evaluator.

     Returns:
         A ``(data_loader, evaluator)`` pair. The evaluator is created with the
         query count returned by ``cls.build_test_loader`` and ``output_dir``.
     """
     loader, query_count = cls.build_test_loader(cfg, dataset_name)
     evaluator = ReidEvaluator(cfg, query_count, output_dir)
     return loader, evaluator
Пример #5
0
 def build_evaluator(cls, cfg, num_query, output_folder=None):
     """Return a ``ReidEvaluator`` configured with an output directory.

     Args:
         cfg: experiment config; ``cfg.OUTPUT_DIR`` supplies the default folder.
         num_query: number of query samples, forwarded to the evaluator.
         output_folder: directory handed to the evaluator; defaults to
             ``os.path.join(cfg.OUTPUT_DIR, "inference")``.
     """
     if output_folder is None:
         output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
     # Forward the folder as the evaluator's third positional argument (its
     # output directory) instead of computing it and then discarding it.
     return ReidEvaluator(cfg, num_query, output_folder)
Пример #6
0
 def build_evaluator(cls, test_loader, test_pair_file):
     """Construct a ``ReidEvaluator`` from a test loader and a pair file.

     NOTE(review): the other builders in this file call
     ``ReidEvaluator(cfg, num_query, output_dir)``; confirm that the
     ``ReidEvaluator`` in scope here really takes ``(test_loader, test_pair_file)``.
     """
     evaluator = ReidEvaluator(test_loader, test_pair_file)
     return evaluator
def get_evaluator(cfg, dataset_name, output_dir=None):
    """Build the ReID test loader for ``dataset_name`` and a matching evaluator.

    Returns:
        A ``(data_loader, evaluator)`` pair. The evaluator is created with the
        query count reported by ``build_reid_test_loader`` and ``output_dir``.
    """
    loader, query_count = build_reid_test_loader(cfg, dataset_name=dataset_name)
    evaluator = ReidEvaluator(cfg, query_count, output_dir)
    return loader, evaluator